ggerganov committed on
Commit
aaf2d96
·
1 Parent(s): 68ded09

cuda : fix rope with partial rotation and non-contiguous src (llama/14580)

Browse files

* cuda : fix rope with non-contiguous src

ggml-ci

* cont : fix multi-rope + add test

ggml-ci

* sycl : try fix

ggml-ci

* cont : fix sycl + clean-up cuda

ggml-ci

ggml/src/ggml-cuda/rope.cu CHANGED
@@ -50,21 +50,19 @@ static __global__ void rope_norm(
50
 
51
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
52
 
53
- if (i0 >= n_dims) {
54
- const int i = row_dst*ne0 + i0;
55
-
56
- dst[i + 0] = x[i + 0];
57
- dst[i + 1] = x[i + 1];
58
-
59
- return;
60
- }
61
-
62
  const int row_x = row_dst % ne1;
63
  const int channel_x = row_dst / ne1;
64
 
65
  const int idst = row_dst*ne0 + i0;
66
  const int ix = channel_x*s2 + row_x*s1 + i0;
67
 
 
 
 
 
 
 
 
68
  const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
69
 
70
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -94,21 +92,19 @@ static __global__ void rope_neox(
94
 
95
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
96
 
97
- if (i0 >= n_dims) {
98
- const int i = row_dst*ne0 + i0;
99
-
100
- dst[i + 0] = x[i + 0];
101
- dst[i + 1] = x[i + 1];
102
-
103
- return;
104
- }
105
-
106
  const int row_x = row_dst % ne1;
107
  const int channel_x = row_dst / ne1;
108
 
109
  const int idst = row_dst*ne0 + i0/2;
110
  const int ix = channel_x*s2 + row_x*s1 + i0/2;
111
 
 
 
 
 
 
 
 
112
  const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
113
 
114
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -138,21 +134,19 @@ static __global__ void rope_multi(
138
 
139
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
140
 
141
- if (i0 >= n_dims) {
142
- const int i = row_dst*ne0 + i0;
143
-
144
- dst[i + 0] = x[i + 0];
145
- dst[i + 1] = x[i + 1];
146
-
147
- return;
148
- }
149
-
150
  const int row_x = row_dst % ne1;
151
  const int channel_x = row_dst / ne1;
152
 
153
  const int idst = row_dst*ne0 + i0/2;
154
  const int ix = channel_x*s2 + row_x*s1 + i0/2;
155
 
 
 
 
 
 
 
 
156
  const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
157
  const int sec_w = sections.v[1] + sections.v[0];
158
  const int sector = (i0 / 2) % sect_dims;
 
50
 
51
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
52
 
 
 
 
 
 
 
 
 
 
53
  const int row_x = row_dst % ne1;
54
  const int channel_x = row_dst / ne1;
55
 
56
  const int idst = row_dst*ne0 + i0;
57
  const int ix = channel_x*s2 + row_x*s1 + i0;
58
 
59
+ if (i0 >= n_dims) {
60
+ dst[idst + 0] = x[ix + 0];
61
+ dst[idst + 1] = x[ix + 1];
62
+
63
+ return;
64
+ }
65
+
66
  const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
67
 
68
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
92
 
93
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
94
 
 
 
 
 
 
 
 
 
 
95
  const int row_x = row_dst % ne1;
96
  const int channel_x = row_dst / ne1;
97
 
98
  const int idst = row_dst*ne0 + i0/2;
99
  const int ix = channel_x*s2 + row_x*s1 + i0/2;
100
 
101
+ if (i0 >= n_dims) {
102
+ dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
103
+ dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
104
+
105
+ return;
106
+ }
107
+
108
  const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
109
 
110
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
134
 
135
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
136
 
 
 
 
 
 
 
 
 
 
137
  const int row_x = row_dst % ne1;
138
  const int channel_x = row_dst / ne1;
139
 
140
  const int idst = row_dst*ne0 + i0/2;
141
  const int ix = channel_x*s2 + row_x*s1 + i0/2;
142
 
143
+ if (i0 >= n_dims) {
144
+ dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
145
+ dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
146
+
147
+ return;
148
+ }
149
+
150
  const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
151
  const int sec_w = sections.v[1] + sections.v[0];
152
  const int sector = (i0 / 2) % sect_dims;
ggml/src/ggml-sycl/rope.cpp CHANGED
@@ -47,18 +47,17 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
47
 
48
  const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
49
 
50
- if (i0 >= n_dims) {
51
- const int i = row * ne0 + i0;
52
- *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
53
- return;
54
- }
55
-
56
  const int row0 = row % ne1;
57
  const int channel0 = row / ne1;
58
 
59
  const int i = row * ne0 + i0;
60
  const int i2 = channel0 * s2 + row0 * s1 + i0;
61
 
 
 
 
 
 
62
  const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
63
 
64
  const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@@ -88,18 +87,17 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
88
 
89
  const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
90
 
91
- if (i0 >= n_dims) {
92
- const int i = row * ne0 + i0;
93
- *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
94
- return;
95
- }
96
-
97
  const int row0 = row % ne1;
98
  const int channel0 = row / ne1;
99
 
100
  const int i = row * ne0 + i0 / 2;
101
  const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
102
 
 
 
 
 
 
103
  const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
104
 
105
  const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@@ -129,17 +127,16 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
129
  }
130
  const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
131
 
132
- if (i0 >= n_dims) {
133
- const int i = row_dst*ne0 + i0;
134
- *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
135
- return;
136
- }
137
-
138
  const int row_x = row_dst % ne1;
139
  const int channel_x = row_dst / ne1;
140
  const int idst = (row_dst * ne0) + (i0 / 2);
141
  const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
142
 
 
 
 
 
 
143
  const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
144
  const int sec_w = sections.v[1] + sections.v[0];
145
  const int sector = (i0 / 2) % sect_dims;
 
47
 
48
  const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
49
 
 
 
 
 
 
 
50
  const int row0 = row % ne1;
51
  const int channel0 = row / ne1;
52
 
53
  const int i = row * ne0 + i0;
54
  const int i2 = channel0 * s2 + row0 * s1 + i0;
55
 
56
+ if (i0 >= n_dims) {
57
+ *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
58
+ return;
59
+ }
60
+
61
  const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
62
 
63
  const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
 
87
 
88
  const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
89
 
 
 
 
 
 
 
90
  const int row0 = row % ne1;
91
  const int channel0 = row / ne1;
92
 
93
  const int i = row * ne0 + i0 / 2;
94
  const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
95
 
96
+ if (i0 >= n_dims) {
97
+ *reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
98
+ return;
99
+ }
100
+
101
  const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
102
 
103
  const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
 
127
  }
128
  const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
129
 
 
 
 
 
 
 
130
  const int row_x = row_dst % ne1;
131
  const int channel_x = row_dst / ne1;
132
  const int idst = (row_dst * ne0) + (i0 / 2);
133
  const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
134
 
135
+ if (i0 >= n_dims) {
136
+ *reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
137
+ return;
138
+ }
139
+
140
  const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
141
  const int sec_w = sections.v[1] + sections.v[0];
142
  const int sector = (i0 / 2) % sect_dims;