Spaces:
Sleeping
Sleeping
cuda : fix rope with partial rotation and non-cont src (llama/14580)
Browse files* cuda : fix rope non-cont
ggml-ci
* cont : fix multi-rope + add test
ggml-ci
* sycl : try fix
ggml-ci
* cont : fix sycl + clean-up cuda
ggml-ci
- ggml/src/ggml-cuda/rope.cu +21 -27
- ggml/src/ggml-sycl/rope.cpp +15 -18
ggml/src/ggml-cuda/rope.cu
CHANGED
|
@@ -50,21 +50,19 @@ static __global__ void rope_norm(
|
|
| 50 |
|
| 51 |
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
| 52 |
|
| 53 |
-
if (i0 >= n_dims) {
|
| 54 |
-
const int i = row_dst*ne0 + i0;
|
| 55 |
-
|
| 56 |
-
dst[i + 0] = x[i + 0];
|
| 57 |
-
dst[i + 1] = x[i + 1];
|
| 58 |
-
|
| 59 |
-
return;
|
| 60 |
-
}
|
| 61 |
-
|
| 62 |
const int row_x = row_dst % ne1;
|
| 63 |
const int channel_x = row_dst / ne1;
|
| 64 |
|
| 65 |
const int idst = row_dst*ne0 + i0;
|
| 66 |
const int ix = channel_x*s2 + row_x*s1 + i0;
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
| 69 |
|
| 70 |
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
|
@@ -94,21 +92,19 @@ static __global__ void rope_neox(
|
|
| 94 |
|
| 95 |
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
| 96 |
|
| 97 |
-
if (i0 >= n_dims) {
|
| 98 |
-
const int i = row_dst*ne0 + i0;
|
| 99 |
-
|
| 100 |
-
dst[i + 0] = x[i + 0];
|
| 101 |
-
dst[i + 1] = x[i + 1];
|
| 102 |
-
|
| 103 |
-
return;
|
| 104 |
-
}
|
| 105 |
-
|
| 106 |
const int row_x = row_dst % ne1;
|
| 107 |
const int channel_x = row_dst / ne1;
|
| 108 |
|
| 109 |
const int idst = row_dst*ne0 + i0/2;
|
| 110 |
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
| 113 |
|
| 114 |
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
|
@@ -138,21 +134,19 @@ static __global__ void rope_multi(
|
|
| 138 |
|
| 139 |
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
| 140 |
|
| 141 |
-
if (i0 >= n_dims) {
|
| 142 |
-
const int i = row_dst*ne0 + i0;
|
| 143 |
-
|
| 144 |
-
dst[i + 0] = x[i + 0];
|
| 145 |
-
dst[i + 1] = x[i + 1];
|
| 146 |
-
|
| 147 |
-
return;
|
| 148 |
-
}
|
| 149 |
-
|
| 150 |
const int row_x = row_dst % ne1;
|
| 151 |
const int channel_x = row_dst / ne1;
|
| 152 |
|
| 153 |
const int idst = row_dst*ne0 + i0/2;
|
| 154 |
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
| 157 |
const int sec_w = sections.v[1] + sections.v[0];
|
| 158 |
const int sector = (i0 / 2) % sect_dims;
|
|
|
|
| 50 |
|
| 51 |
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
const int row_x = row_dst % ne1;
|
| 54 |
const int channel_x = row_dst / ne1;
|
| 55 |
|
| 56 |
const int idst = row_dst*ne0 + i0;
|
| 57 |
const int ix = channel_x*s2 + row_x*s1 + i0;
|
| 58 |
|
| 59 |
+
if (i0 >= n_dims) {
|
| 60 |
+
dst[idst + 0] = x[ix + 0];
|
| 61 |
+
dst[idst + 1] = x[ix + 1];
|
| 62 |
+
|
| 63 |
+
return;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
| 67 |
|
| 68 |
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
|
|
|
| 92 |
|
| 93 |
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
const int row_x = row_dst % ne1;
|
| 96 |
const int channel_x = row_dst / ne1;
|
| 97 |
|
| 98 |
const int idst = row_dst*ne0 + i0/2;
|
| 99 |
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
| 100 |
|
| 101 |
+
if (i0 >= n_dims) {
|
| 102 |
+
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
|
| 103 |
+
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
|
| 104 |
+
|
| 105 |
+
return;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
| 109 |
|
| 110 |
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
|
|
|
| 134 |
|
| 135 |
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
const int row_x = row_dst % ne1;
|
| 138 |
const int channel_x = row_dst / ne1;
|
| 139 |
|
| 140 |
const int idst = row_dst*ne0 + i0/2;
|
| 141 |
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
| 142 |
|
| 143 |
+
if (i0 >= n_dims) {
|
| 144 |
+
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
|
| 145 |
+
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
|
| 146 |
+
|
| 147 |
+
return;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
| 151 |
const int sec_w = sections.v[1] + sections.v[0];
|
| 152 |
const int sector = (i0 / 2) % sect_dims;
|
ggml/src/ggml-sycl/rope.cpp
CHANGED
|
@@ -47,18 +47,17 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
|
|
| 47 |
|
| 48 |
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
|
| 49 |
|
| 50 |
-
if (i0 >= n_dims) {
|
| 51 |
-
const int i = row * ne0 + i0;
|
| 52 |
-
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
|
| 53 |
-
return;
|
| 54 |
-
}
|
| 55 |
-
|
| 56 |
const int row0 = row % ne1;
|
| 57 |
const int channel0 = row / ne1;
|
| 58 |
|
| 59 |
const int i = row * ne0 + i0;
|
| 60 |
const int i2 = channel0 * s2 + row0 * s1 + i0;
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
|
| 63 |
|
| 64 |
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
|
@@ -88,18 +87,17 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
|
|
| 88 |
|
| 89 |
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
|
| 90 |
|
| 91 |
-
if (i0 >= n_dims) {
|
| 92 |
-
const int i = row * ne0 + i0;
|
| 93 |
-
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
|
| 94 |
-
return;
|
| 95 |
-
}
|
| 96 |
-
|
| 97 |
const int row0 = row % ne1;
|
| 98 |
const int channel0 = row / ne1;
|
| 99 |
|
| 100 |
const int i = row * ne0 + i0 / 2;
|
| 101 |
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
|
| 104 |
|
| 105 |
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
|
@@ -129,17 +127,16 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
|
|
| 129 |
}
|
| 130 |
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
|
| 131 |
|
| 132 |
-
if (i0 >= n_dims) {
|
| 133 |
-
const int i = row_dst*ne0 + i0;
|
| 134 |
-
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
|
| 135 |
-
return;
|
| 136 |
-
}
|
| 137 |
-
|
| 138 |
const int row_x = row_dst % ne1;
|
| 139 |
const int channel_x = row_dst / ne1;
|
| 140 |
const int idst = (row_dst * ne0) + (i0 / 2);
|
| 141 |
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
|
| 142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
| 144 |
const int sec_w = sections.v[1] + sections.v[0];
|
| 145 |
const int sector = (i0 / 2) % sect_dims;
|
|
|
|
| 47 |
|
| 48 |
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
const int row0 = row % ne1;
|
| 51 |
const int channel0 = row / ne1;
|
| 52 |
|
| 53 |
const int i = row * ne0 + i0;
|
| 54 |
const int i2 = channel0 * s2 + row0 * s1 + i0;
|
| 55 |
|
| 56 |
+
if (i0 >= n_dims) {
|
| 57 |
+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
|
| 58 |
+
return;
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
|
| 62 |
|
| 63 |
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
|
|
|
| 87 |
|
| 88 |
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
const int row0 = row % ne1;
|
| 91 |
const int channel0 = row / ne1;
|
| 92 |
|
| 93 |
const int i = row * ne0 + i0 / 2;
|
| 94 |
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
|
| 95 |
|
| 96 |
+
if (i0 >= n_dims) {
|
| 97 |
+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
|
| 98 |
+
return;
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
|
| 102 |
|
| 103 |
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
|
|
|
| 127 |
}
|
| 128 |
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
|
| 129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
const int row_x = row_dst % ne1;
|
| 131 |
const int channel_x = row_dst / ne1;
|
| 132 |
const int idst = (row_dst * ne0) + (i0 / 2);
|
| 133 |
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
|
| 134 |
|
| 135 |
+
if (i0 >= n_dims) {
|
| 136 |
+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
|
| 137 |
+
return;
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
| 141 |
const int sec_w = sections.v[1] + sections.v[0];
|
| 142 |
const int sector = (i0 / 2) % sect_dims;
|