Akarshan Biswas committed
Commit 5de15cd · 1 Parent(s): 4da3fb6

SYCL: Add non contiguous support in RMS_NORM and NORM kernels (llama/13611)


* SYCL: Add non contiguous input support to norm kernel

* refactor and add RMS_NORM non contiguous input support

ggml-ci

* restore subgroup reduction for multi-subgroup thread blocks in norm kernels

* Swap grid dims of nsamples and nrows

ggml-ci

* Revert "Swap grid dims of nsamples and nrows"

This reverts commit 43be2d657fec7f7fba54e2cd154106bc0fc45adf.

* restore original code where changes were not required
ggml-ci

* address review comments: make it more SYCL-like

* Use a common function to calculate offset

* remove wrap around logic for handling broadcasts

* remove static from calculate_offset fn and use ceil_div

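Editor's note: the core of the change is to read the input through explicit per-dimension element strides while writing a packed output. Below is a minimal host-side sketch of that offset arithmetic; calc_offset only mirrors the calculate_offset helper added in common.hpp further down, and the sizes and strides are hypothetical, not taken from this commit.

#include <array>
#include <cstddef>
#include <cstdio>

// Same idea as the calculate_offset helper below: a dot product of strides and indices.
template <int N>
size_t calc_offset(const std::array<int, N> & strides, const std::array<int, N> & indices) {
    size_t offset = 0;
    for (int i = 0; i < N; i++) {
        offset += size_t(strides[i]) * indices[i];
    }
    return offset;
}

int main() {
    // Hypothetical 3D tensor: ncols=4, nrows=3, nchannels=2, row stride padded to 8 elements.
    const int ncols = 4, nrows = 3, nchannels = 2;
    const int stride_row     = 8;
    const int stride_channel = stride_row * nrows;        // 24
    const int stride_sample  = stride_channel * nchannels; // 48
    const int sample = 0, channel = 1, row = 2;

    // Input is addressed with the (possibly padded) strides, output as if tightly packed.
    const size_t strided = calc_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row}); // 40
    const size_t packed  = calc_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row}); // 20
    std::printf("strided=%zu packed=%zu\n", strided, packed);
    return 0;
}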
ggml/src/ggml-sycl/common.hpp CHANGED
@@ -13,6 +13,7 @@
 #ifndef GGML_SYCL_COMMON_HPP
 #define GGML_SYCL_COMMON_HPP

+#include <cstddef>
 #include <fstream>
 #include <iostream>
 #include <string>
@@ -481,6 +482,19 @@ static __dpct_inline__ float warp_reduce_max(float x,
     return x;
 }

+/* Helper for Computing the linear offset of a ggml_tensor given
+per-dimension sizes, strides, and indices */
+template<int N>
+__dpct_inline__ size_t calculate_offset(const std::array<int, N> & strides, const std::array<int, N> & indices) {
+    size_t offset = 0;
+#pragma unroll
+    for (int i = 0; i < N; i++) {
+        auto index_i = indices[i];
+        offset += strides[i] * index_i;
+    }
+    return offset;
+}
+
 // Helper for vec loading aligned data
 template <typename Tp, int n>
 inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) {
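Editor's note on "use ceil_div" from the notes above: the cross-subgroup reduction loops in the kernels iterate nreduce times, and plain integer division nwarps / WARP_SIZE truncates to 0 whenever there are fewer than WARP_SIZE sub-groups, which would skip the accumulation; ceil_div rounds up so there is always at least one iteration. A standalone illustration follows (ceil_div_example is a stand-in written here, not ggml's helper).

#include <cstdio>
#include <initializer_list>

// Same rounding-up idea as ggml's ceil_div: (a + b - 1) / b.
static int ceil_div_example(int a, int b) { return (a + b - 1) / b; }

int main() {
    const int WARP_SIZE = 32;
    for (int nwarps : {1, 8, 32, 33}) {
        std::printf("nwarps=%2d  truncated=%d  ceil=%d\n",
                    nwarps, nwarps / WARP_SIZE, ceil_div_example(nwarps, WARP_SIZE));
    }
    return 0;
}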
ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED
@@ -4241,6 +4241,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
 #endif
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+            return true;
         case GGML_OP_L2_NORM:
         case GGML_OP_GROUP_NORM:
             return ggml_is_contiguous(op->src[0]);
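Editor's note: with this change NORM and RMS_NORM return true unconditionally, while L2_NORM and GROUP_NORM still require a contiguous src0. For readers unfamiliar with the contiguity notion: a tensor is contiguous when each byte stride equals the previous stride times the previous dimension size, starting from the element size. A rough standalone paraphrase of that condition for a 3D float case (illustrative only, not the ggml_is_contiguous source):

#include <cstdint>
#include <cstdio>

// Simplified 3D contiguity check: nb[0] is the element size, and every higher
// stride is the previous stride times the previous dimension size.
static bool is_contiguous_3d(const int64_t ne[3], const int64_t nb[3], int64_t type_size) {
    return nb[0] == type_size &&
           nb[1] == nb[0] * ne[0] &&
           nb[2] == nb[1] * ne[1];
}

int main() {
    const int64_t ne[3]      = {4, 3, 2};
    const int64_t packed[3]  = {4, 16, 48};  // fp32, tightly packed
    const int64_t strided[3] = {4, 32, 96};  // row stride padded to 8 elements -> non-contiguous
    std::printf("packed: %d, strided: %d\n",
                (int) is_contiguous_3d(ne, packed, 4), (int) is_contiguous_3d(ne, strided, 4));
    return 0;
}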
ggml/src/ggml-sycl/norm.cpp CHANGED
@@ -1,40 +1,50 @@
 #include "norm.hpp"
+#include "ggml-sycl/common.hpp"
+#include "ggml-sycl/presets.hpp"

-static void norm_f32(const float* x, float* dst, const int ncols, const float eps,
-    const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-        item_ct1.get_local_id(1);
-    const int tid = item_ct1.get_local_id(2);
+static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
+
+    const int nrows     = item_ct1.get_group_range(2);
+    const int nchannels = item_ct1.get_group_range(1);

     const int nthreads = item_ct1.get_local_range(2);
+    const int sample   = item_ct1.get_group(0);
+    const int channel  = item_ct1.get_group(1);
+    const int row      = item_ct1.get_group(2);
+
+    const int tid = item_ct1.get_local_id(2);
     const int nwarps = nthreads / WARP_SIZE;
+
+    const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
+    const auto packed_offset  = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
+
+    x   += strided_offset;
+    dst += packed_offset;
+
     sycl::float2 mean_var = sycl::float2(0.f, 0.f);

     for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row * ncols + col];
+        const float xi = x[col];
         mean_var.x() += xi;
         mean_var.y() += xi * xi;
     }

     // sum up partial sums
     mean_var = warp_reduce_sum(mean_var, item_ct1);
-    if (block_size > WARP_SIZE) {
-
-        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
-        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = mean_var;
+    if (block_size > WARP_SIZE) {
+        const auto sub_group = item_ct1.get_sub_group();
+        const auto sg_id     = sub_group.get_group_linear_id();
+        const auto wi_in_sg  = sub_group.get_local_linear_id();
+        if (wi_in_sg == 0) {
+            s_sum[sg_id] = mean_var;
         }
-        /*
-        DPCT1118:0: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
         item_ct1.barrier(sycl::access::fence_space::local_space);
         mean_var = 0.f;
-        size_t nreduce = nwarps / WARP_SIZE;
+        const size_t nreduce = ceil_div(nwarps, WARP_SIZE);
         for (size_t i = 0; i < nreduce; i += 1)
         {
-            mean_var += s_sum[lane_id + i * WARP_SIZE];
+            mean_var += s_sum[wi_in_sg + i * WARP_SIZE];
         }
         mean_var = warp_reduce_sum(mean_var, item_ct1);
     }
@@ -44,7 +54,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
     const float inv_std = sycl::rsqrt(var + eps);

     for (int col = tid; col < ncols; col += block_size) {
-        dst[row * ncols + col] = (x[row * ncols + col] - mean) * inv_std;
+        dst[col] = (x[col] - mean) * inv_std;
     }
 }

@@ -135,39 +145,51 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
     }
 }

-static void rms_norm_f32(const float* x, float* dst, const int ncols, const float eps,
-    const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-        item_ct1.get_local_id(1);
-    const int tid = item_ct1.get_local_id(2);
+static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
+
+    const int nrows     = item_ct1.get_group_range(2);
+    const int nchannels = item_ct1.get_group_range(1);
+
+    const int sample  = item_ct1.get_group(0);
+    const int channel = item_ct1.get_group(1);
+    const int row     = item_ct1.get_group(2);
+
     const int nthreads = item_ct1.get_local_range(2);
+
+    const int tid = item_ct1.get_local_id(2);
     const int nwarps = nthreads / WARP_SIZE;
+
+    const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
+    const auto packed_offset  = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
+
+    x   += strided_offset;
+    dst += packed_offset;
+
+
     float tmp = 0.0f; // partial sum for thread in warp

     for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row * ncols + col];
+        const float xi = x[col];
         tmp += xi * xi;
     }

     // sum up partial sums
     tmp = warp_reduce_sum(tmp, item_ct1);
     if (block_size > WARP_SIZE) {
-
-        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
-        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
+        const auto sub_group = item_ct1.get_sub_group();
+        const auto sg_id     = sub_group.get_group_linear_id();
+        const auto wi_in_sg  = sub_group.get_local_linear_id();
+        if (wi_in_sg == 0) {
+            s_sum[sg_id] = tmp;
         }
-        /*
-        DPCT1118:3: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
+
         item_ct1.barrier(sycl::access::fence_space::local_space);
-        size_t nreduce = nwarps / WARP_SIZE;
+        const size_t nreduce = ceil_div(nwarps, WARP_SIZE);
         tmp = 0.f;
         for (size_t i = 0; i < nreduce; i += 1)
         {
-            tmp += s_sum[lane_id + i * WARP_SIZE];
+            tmp += s_sum[wi_in_sg + i * WARP_SIZE];
         }
         tmp = warp_reduce_sum(tmp, item_ct1);
     }
@@ -176,7 +198,7 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
     const float scale = sycl::rsqrt(mean + eps);

     for (int col = tid; col < ncols; col += block_size) {
-        dst[row * ncols + col] = scale * x[row * ncols + col];
+        dst[col] = scale * x[col];
     }
 }

@@ -224,20 +246,20 @@ static void l2_norm_f32(const float* x, float* dst, const int ncols, const float
     }
 }

-static void norm_f32_sycl(const float* x, float* dst, const int ncols,
-    const int nrows, const float eps,
-    queue_ptr stream, int device) {
+static void norm_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+                          const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
+                          const float eps, queue_ptr stream, int device) {
+
+    const sycl::range<3> global_dims(nsamples, nchannels, nrows);
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
         stream->submit([&](sycl::handler& cgh) {
             cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                    block_dims),
+                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
                     [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        norm_f32(x, dst, ncols, eps, item_ct1,
-                            nullptr, WARP_SIZE);
+                        norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
                     });
             });
     }
@@ -252,15 +274,12 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
         */
         stream->submit([&](sycl::handler& cgh) {
             sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
-                sycl::range<1>(work_group_size / WARP_SIZE), cgh);
-
+                sycl::range<1>(work_group_size / WARP_SIZE), cgh);
             cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                    block_dims),
+                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
                     [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        norm_f32(x, dst, ncols, eps, item_ct1,
-                            get_pointer(s_sum_acc_ct1), work_group_size);
+                        norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
                     });
             });
     }
@@ -313,21 +332,20 @@ static void group_norm_f32_sycl(const float* x, float* dst,
     }
 }

-static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
-    const int nrows, const float eps,
-    queue_ptr stream, int device) {
+static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+                              const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
+
+    const sycl::range<3> global_dims(nsamples, nchannels, nrows);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
         stream->submit([&](sycl::handler& cgh) {
             cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                    block_dims),
+                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
                     [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        rms_norm_f32(x, dst, ncols, eps, item_ct1,
-                            nullptr, WARP_SIZE);
+                        rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
                     });
             });
     }
@@ -344,12 +362,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
             sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                 cgh);
             cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                    block_dims),
+                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
                     [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        rms_norm_f32(x, dst, ncols, eps, item_ct1,
-                            get_pointer(s_sum_acc_ct1), work_group_size);
+                        rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
                     });
             });
     }
@@ -398,12 +414,12 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
 }

 void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
+    const ggml_tensor * src0 = dst->src[0];

     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);

-    const int64_t ne00 = dst->src[0]->ne[0];
-    const int64_t nrows = ggml_nrows(dst->src[0]);
+    GGML_TENSOR_UNARY_OP_LOCALS
     dpct::queue_ptr main_stream = ctx.stream();
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
@@ -411,8 +427,14 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {

     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
-
-    norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
+    GGML_ASSERT(eps >= 0.0f);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
 }

 void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
@@ -436,11 +458,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {

 void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {

+    const ggml_tensor * src0 = dst->src[0];
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);

-    const int64_t ne00 = dst->src[0]->ne[0];
-    const int64_t nrows = ggml_nrows(dst->src[0]);
     dpct::queue_ptr main_stream = ctx.stream();
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));

@@ -450,7 +471,13 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));

-    rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
+    GGML_TENSOR_UNARY_OP_LOCALS
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+    rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
 }

 void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
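Editor's note: the "subgroup reduction for multi-subgroup thread blocks" mentioned in the notes is the two-level pattern visible in norm_f32 and rms_norm_f32 above: each sub-group reduces its own partial sum, work-item 0 of every sub-group stages the result in local memory, and the staged values are reduced again after a barrier. A simplified standalone SYCL sketch of that pattern follows; it assumes a SYCL 2020 compiler and whatever default device is available, it is not code from this commit, and for brevity the second level is done by a single work-item instead of a second warp reduction.

#include <sycl/sycl.hpp>
#include <cstdio>

int main() {
    constexpr int WG = 256;  // one work-group ("thread block")
    sycl::queue q;
    float *out = sycl::malloc_shared<float>(1, q);

    q.submit([&](sycl::handler &cgh) {
        // one local-memory slot per sub-group; WG slots is a safe upper bound
        sycl::local_accessor<float, 1> partial(sycl::range<1>(WG), cgh);
        cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(WG), sycl::range<1>(WG)),
                         [=](sycl::nd_item<1> it) {
            auto sg = it.get_sub_group();
            float v = 1.0f;  // every work-item contributes 1, so the total must be WG

            // level 1: reduce within the sub-group
            v = sycl::reduce_over_group(sg, v, sycl::plus<float>());
            if (sg.get_local_linear_id() == 0) {
                partial[sg.get_group_linear_id()] = v;
            }
            it.barrier(sycl::access::fence_space::local_space);

            // level 2: one work-item sums the per-sub-group partials
            if (it.get_local_linear_id() == 0) {
                float sum = 0.0f;
                const size_t nsg = sg.get_group_range().size();
                for (size_t i = 0; i < nsg; ++i) {
                    sum += partial[i];
                }
                *out = sum;
            }
        });
    }).wait();

    std::printf("sum = %g (expected %d)\n", *out, WG);
    sycl::free(out, q);
    return 0;
}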