Spaces:
Running
Running
Akarshan Biswas
commited on
Commit
·
5de15cd
1
Parent(s):
4da3fb6
SYCL: Add non contiguous support in RMS_NORM and NORM kernels (llama/13611)
Browse files* SYCL: Add non contiguous input support to norm kernel
* refactor and add RMS_NORM non contiguous input support
ggml-ci
* restore subgroup reduction for multi-subgroup thread blocks in norm kernels
* Swap grid dims of nsamples and nrows
ggml-ci
* Revert "Swap grid dims of nsamples and nrows"
This reverts commit 43be2d657fec7f7fba54e2cd154106bc0fc45adf.
* restore not required changes
ggml-ci
* address review comments: change it to more like SYCL
* Use a common function to calculate offset
* remove wrap around logic for handling broadcasts
* remove static from calculate_offset fn and use ceil_div
- ggml/src/ggml-sycl/common.hpp +14 -0
- ggml/src/ggml-sycl/ggml-sycl.cpp +1 -0
- ggml/src/ggml-sycl/norm.cpp +95 -68
ggml/src/ggml-sycl/common.hpp
CHANGED
|
@@ -13,6 +13,7 @@
|
|
| 13 |
#ifndef GGML_SYCL_COMMON_HPP
|
| 14 |
#define GGML_SYCL_COMMON_HPP
|
| 15 |
|
|
|
|
| 16 |
#include <fstream>
|
| 17 |
#include <iostream>
|
| 18 |
#include <string>
|
|
@@ -481,6 +482,19 @@ static __dpct_inline__ float warp_reduce_max(float x,
|
|
| 481 |
return x;
|
| 482 |
}
|
| 483 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 484 |
// Helper for vec loading aligned data
|
| 485 |
template <typename Tp, int n>
|
| 486 |
inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) {
|
|
|
|
| 13 |
#ifndef GGML_SYCL_COMMON_HPP
|
| 14 |
#define GGML_SYCL_COMMON_HPP
|
| 15 |
|
| 16 |
+
#include <cstddef>
|
| 17 |
#include <fstream>
|
| 18 |
#include <iostream>
|
| 19 |
#include <string>
|
|
|
|
| 482 |
return x;
|
| 483 |
}
|
| 484 |
|
| 485 |
+
/* Helper for Computing the linear offset of a ggml_tensor given
|
| 486 |
+
per-dimension sizes, strides, and indices */
|
| 487 |
+
template<int N>
|
| 488 |
+
__dpct_inline__ size_t calculate_offset(const std::array<int, N> & strides, const std::array<int, N> & indices) {
|
| 489 |
+
size_t offset = 0;
|
| 490 |
+
#pragma unroll
|
| 491 |
+
for (int i = 0; i < N; i++) {
|
| 492 |
+
auto index_i = indices[i];
|
| 493 |
+
offset += strides[i] * index_i;
|
| 494 |
+
}
|
| 495 |
+
return offset;
|
| 496 |
+
}
|
| 497 |
+
|
| 498 |
// Helper for vec loading aligned data
|
| 499 |
template <typename Tp, int n>
|
| 500 |
inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) {
|
ggml/src/ggml-sycl/ggml-sycl.cpp
CHANGED
|
@@ -4241,6 +4241,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
| 4241 |
#endif
|
| 4242 |
case GGML_OP_NORM:
|
| 4243 |
case GGML_OP_RMS_NORM:
|
|
|
|
| 4244 |
case GGML_OP_L2_NORM:
|
| 4245 |
case GGML_OP_GROUP_NORM:
|
| 4246 |
return ggml_is_contiguous(op->src[0]);
|
|
|
|
| 4241 |
#endif
|
| 4242 |
case GGML_OP_NORM:
|
| 4243 |
case GGML_OP_RMS_NORM:
|
| 4244 |
+
return true;
|
| 4245 |
case GGML_OP_L2_NORM:
|
| 4246 |
case GGML_OP_GROUP_NORM:
|
| 4247 |
return ggml_is_contiguous(op->src[0]);
|
ggml/src/ggml-sycl/norm.cpp
CHANGED
|
@@ -1,40 +1,50 @@
|
|
| 1 |
#include "norm.hpp"
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
static void norm_f32(const float* x, float* dst, const int ncols, const
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
const int
|
| 8 |
|
| 9 |
const int nthreads = item_ct1.get_local_range(2);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
const int nwarps = nthreads / WARP_SIZE;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
sycl::float2 mean_var = sycl::float2(0.f, 0.f);
|
| 12 |
|
| 13 |
for (int col = tid; col < ncols; col += block_size) {
|
| 14 |
-
const float xi = x[
|
| 15 |
mean_var.x() += xi;
|
| 16 |
mean_var.y() += xi * xi;
|
| 17 |
}
|
| 18 |
|
| 19 |
// sum up partial sums
|
| 20 |
mean_var = warp_reduce_sum(mean_var, item_ct1);
|
| 21 |
-
if
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
if (
|
| 26 |
-
s_sum[
|
| 27 |
}
|
| 28 |
-
/*
|
| 29 |
-
DPCT1118:0: SYCL group functions and algorithms must be encountered in
|
| 30 |
-
converged control flow. You may need to adjust the code.
|
| 31 |
-
*/
|
| 32 |
item_ct1.barrier(sycl::access::fence_space::local_space);
|
| 33 |
mean_var = 0.f;
|
| 34 |
-
size_t nreduce = nwarps
|
| 35 |
for (size_t i = 0; i < nreduce; i += 1)
|
| 36 |
{
|
| 37 |
-
mean_var += s_sum[
|
| 38 |
}
|
| 39 |
mean_var = warp_reduce_sum(mean_var, item_ct1);
|
| 40 |
}
|
|
@@ -44,7 +54,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
|
|
| 44 |
const float inv_std = sycl::rsqrt(var + eps);
|
| 45 |
|
| 46 |
for (int col = tid; col < ncols; col += block_size) {
|
| 47 |
-
dst[
|
| 48 |
}
|
| 49 |
}
|
| 50 |
|
|
@@ -135,39 +145,51 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
|
|
| 135 |
}
|
| 136 |
}
|
| 137 |
|
| 138 |
-
static void rms_norm_f32(const float* x, float* dst, const int ncols, const
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
const int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
const int nthreads = item_ct1.get_local_range(2);
|
|
|
|
|
|
|
| 144 |
const int nwarps = nthreads / WARP_SIZE;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
float tmp = 0.0f; // partial sum for thread in warp
|
| 146 |
|
| 147 |
for (int col = tid; col < ncols; col += block_size) {
|
| 148 |
-
const float xi = x[
|
| 149 |
tmp += xi * xi;
|
| 150 |
}
|
| 151 |
|
| 152 |
// sum up partial sums
|
| 153 |
tmp = warp_reduce_sum(tmp, item_ct1);
|
| 154 |
if (block_size > WARP_SIZE) {
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
if (
|
| 159 |
-
s_sum[
|
| 160 |
}
|
| 161 |
-
|
| 162 |
-
DPCT1118:3: SYCL group functions and algorithms must be encountered in
|
| 163 |
-
converged control flow. You may need to adjust the code.
|
| 164 |
-
*/
|
| 165 |
item_ct1.barrier(sycl::access::fence_space::local_space);
|
| 166 |
-
size_t nreduce = nwarps
|
| 167 |
tmp = 0.f;
|
| 168 |
for (size_t i = 0; i < nreduce; i += 1)
|
| 169 |
{
|
| 170 |
-
tmp += s_sum[
|
| 171 |
}
|
| 172 |
tmp = warp_reduce_sum(tmp, item_ct1);
|
| 173 |
}
|
|
@@ -176,7 +198,7 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
|
|
| 176 |
const float scale = sycl::rsqrt(mean + eps);
|
| 177 |
|
| 178 |
for (int col = tid; col < ncols; col += block_size) {
|
| 179 |
-
dst[
|
| 180 |
}
|
| 181 |
}
|
| 182 |
|
|
@@ -224,20 +246,20 @@ static void l2_norm_f32(const float* x, float* dst, const int ncols, const float
|
|
| 224 |
}
|
| 225 |
}
|
| 226 |
|
| 227 |
-
static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
| 228 |
-
|
| 229 |
-
|
|
|
|
|
|
|
| 230 |
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
| 231 |
if (ncols < 1024) {
|
| 232 |
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
| 233 |
stream->submit([&](sycl::handler& cgh) {
|
| 234 |
cgh.parallel_for(
|
| 235 |
-
sycl::nd_range<3>(
|
| 236 |
-
block_dims),
|
| 237 |
[=](sycl::nd_item<3> item_ct1)
|
| 238 |
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
| 239 |
-
norm_f32(x, dst, ncols, eps, item_ct1,
|
| 240 |
-
nullptr, WARP_SIZE);
|
| 241 |
});
|
| 242 |
});
|
| 243 |
}
|
|
@@ -252,15 +274,12 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|
| 252 |
*/
|
| 253 |
stream->submit([&](sycl::handler& cgh) {
|
| 254 |
sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
|
| 255 |
-
|
| 256 |
-
|
| 257 |
cgh.parallel_for(
|
| 258 |
-
sycl::nd_range<3>(
|
| 259 |
-
block_dims),
|
| 260 |
[=](sycl::nd_item<3> item_ct1)
|
| 261 |
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
| 262 |
-
norm_f32(x, dst, ncols, eps, item_ct1,
|
| 263 |
-
get_pointer(s_sum_acc_ct1), work_group_size);
|
| 264 |
});
|
| 265 |
});
|
| 266 |
}
|
|
@@ -313,21 +332,20 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
|
| 313 |
}
|
| 314 |
}
|
| 315 |
|
| 316 |
-
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
| 317 |
-
|
| 318 |
-
queue_ptr stream, int device) {
|
| 319 |
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
| 320 |
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
|
|
|
|
|
|
| 321 |
if (ncols < 1024) {
|
| 322 |
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
| 323 |
stream->submit([&](sycl::handler& cgh) {
|
| 324 |
cgh.parallel_for(
|
| 325 |
-
sycl::nd_range<3>(
|
| 326 |
-
block_dims),
|
| 327 |
[=](sycl::nd_item<3> item_ct1)
|
| 328 |
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
| 329 |
-
rms_norm_f32(x, dst, ncols, eps, item_ct1,
|
| 330 |
-
nullptr, WARP_SIZE);
|
| 331 |
});
|
| 332 |
});
|
| 333 |
}
|
|
@@ -344,12 +362,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|
| 344 |
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
|
| 345 |
cgh);
|
| 346 |
cgh.parallel_for(
|
| 347 |
-
sycl::nd_range<3>(
|
| 348 |
-
block_dims),
|
| 349 |
[=](sycl::nd_item<3> item_ct1)
|
| 350 |
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
| 351 |
-
rms_norm_f32(x, dst, ncols, eps, item_ct1,
|
| 352 |
-
get_pointer(s_sum_acc_ct1), work_group_size);
|
| 353 |
});
|
| 354 |
});
|
| 355 |
}
|
|
@@ -398,12 +414,12 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|
| 398 |
}
|
| 399 |
|
| 400 |
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
|
|
|
| 401 |
|
| 402 |
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
| 403 |
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
| 404 |
|
| 405 |
-
|
| 406 |
-
const int64_t nrows = ggml_nrows(dst->src[0]);
|
| 407 |
dpct::queue_ptr main_stream = ctx.stream();
|
| 408 |
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
| 409 |
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
|
@@ -411,8 +427,14 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
|
| 411 |
|
| 412 |
float eps;
|
| 413 |
memcpy(&eps, dst->op_params, sizeof(float));
|
| 414 |
-
|
| 415 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
}
|
| 417 |
|
| 418 |
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
|
@@ -436,11 +458,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
|
| 436 |
|
| 437 |
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
| 438 |
|
|
|
|
| 439 |
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
| 440 |
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
| 441 |
|
| 442 |
-
const int64_t ne00 = dst->src[0]->ne[0];
|
| 443 |
-
const int64_t nrows = ggml_nrows(dst->src[0]);
|
| 444 |
dpct::queue_ptr main_stream = ctx.stream();
|
| 445 |
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
| 446 |
|
|
@@ -450,7 +471,13 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
| 450 |
float eps;
|
| 451 |
memcpy(&eps, dst->op_params, sizeof(float));
|
| 452 |
|
| 453 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
}
|
| 455 |
|
| 456 |
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
|
|
|
| 1 |
#include "norm.hpp"
|
| 2 |
+
#include "ggml-sycl/common.hpp"
|
| 3 |
+
#include "ggml-sycl/presets.hpp"
|
| 4 |
|
| 5 |
+
static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
| 6 |
+
const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
|
| 7 |
+
|
| 8 |
+
const int nrows = item_ct1.get_group_range(2);
|
| 9 |
+
const int nchannels = item_ct1.get_group_range(1);
|
| 10 |
|
| 11 |
const int nthreads = item_ct1.get_local_range(2);
|
| 12 |
+
const int sample = item_ct1.get_group(0);
|
| 13 |
+
const int channel = item_ct1.get_group(1);
|
| 14 |
+
const int row = item_ct1.get_group(2);
|
| 15 |
+
|
| 16 |
+
const int tid = item_ct1.get_local_id(2);
|
| 17 |
const int nwarps = nthreads / WARP_SIZE;
|
| 18 |
+
|
| 19 |
+
const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
|
| 20 |
+
const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
|
| 21 |
+
|
| 22 |
+
x += strided_offset;
|
| 23 |
+
dst += packed_offset;
|
| 24 |
+
|
| 25 |
sycl::float2 mean_var = sycl::float2(0.f, 0.f);
|
| 26 |
|
| 27 |
for (int col = tid; col < ncols; col += block_size) {
|
| 28 |
+
const float xi = x[col];
|
| 29 |
mean_var.x() += xi;
|
| 30 |
mean_var.y() += xi * xi;
|
| 31 |
}
|
| 32 |
|
| 33 |
// sum up partial sums
|
| 34 |
mean_var = warp_reduce_sum(mean_var, item_ct1);
|
| 35 |
+
if (block_size > WARP_SIZE) {
|
| 36 |
+
const auto sub_group = item_ct1.get_sub_group();
|
| 37 |
+
const auto sg_id = sub_group.get_group_linear_id();
|
| 38 |
+
const auto wi_in_sg = sub_group.get_local_linear_id();
|
| 39 |
+
if (wi_in_sg == 0) {
|
| 40 |
+
s_sum[sg_id] = mean_var;
|
| 41 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
item_ct1.barrier(sycl::access::fence_space::local_space);
|
| 43 |
mean_var = 0.f;
|
| 44 |
+
const size_t nreduce = ceil_div(nwarps, WARP_SIZE);
|
| 45 |
for (size_t i = 0; i < nreduce; i += 1)
|
| 46 |
{
|
| 47 |
+
mean_var += s_sum[wi_in_sg + i * WARP_SIZE];
|
| 48 |
}
|
| 49 |
mean_var = warp_reduce_sum(mean_var, item_ct1);
|
| 50 |
}
|
|
|
|
| 54 |
const float inv_std = sycl::rsqrt(var + eps);
|
| 55 |
|
| 56 |
for (int col = tid; col < ncols; col += block_size) {
|
| 57 |
+
dst[col] = (x[col] - mean) * inv_std;
|
| 58 |
}
|
| 59 |
}
|
| 60 |
|
|
|
|
| 145 |
}
|
| 146 |
}
|
| 147 |
|
| 148 |
+
static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
| 149 |
+
const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
|
| 150 |
+
|
| 151 |
+
const int nrows = item_ct1.get_group_range(2);
|
| 152 |
+
const int nchannels = item_ct1.get_group_range(1);
|
| 153 |
+
|
| 154 |
+
const int sample = item_ct1.get_group(0);
|
| 155 |
+
const int channel = item_ct1.get_group(1);
|
| 156 |
+
const int row = item_ct1.get_group(2);
|
| 157 |
+
|
| 158 |
const int nthreads = item_ct1.get_local_range(2);
|
| 159 |
+
|
| 160 |
+
const int tid = item_ct1.get_local_id(2);
|
| 161 |
const int nwarps = nthreads / WARP_SIZE;
|
| 162 |
+
|
| 163 |
+
const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
|
| 164 |
+
const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
|
| 165 |
+
|
| 166 |
+
x += strided_offset;
|
| 167 |
+
dst += packed_offset;
|
| 168 |
+
|
| 169 |
+
|
| 170 |
float tmp = 0.0f; // partial sum for thread in warp
|
| 171 |
|
| 172 |
for (int col = tid; col < ncols; col += block_size) {
|
| 173 |
+
const float xi = x[col];
|
| 174 |
tmp += xi * xi;
|
| 175 |
}
|
| 176 |
|
| 177 |
// sum up partial sums
|
| 178 |
tmp = warp_reduce_sum(tmp, item_ct1);
|
| 179 |
if (block_size > WARP_SIZE) {
|
| 180 |
+
const auto sub_group = item_ct1.get_sub_group();
|
| 181 |
+
const auto sg_id = sub_group.get_group_linear_id();
|
| 182 |
+
const auto wi_in_sg = sub_group.get_local_linear_id();
|
| 183 |
+
if (wi_in_sg == 0) {
|
| 184 |
+
s_sum[sg_id] = tmp;
|
| 185 |
}
|
| 186 |
+
|
|
|
|
|
|
|
|
|
|
| 187 |
item_ct1.barrier(sycl::access::fence_space::local_space);
|
| 188 |
+
const size_t nreduce = ceil_div(nwarps, WARP_SIZE);
|
| 189 |
tmp = 0.f;
|
| 190 |
for (size_t i = 0; i < nreduce; i += 1)
|
| 191 |
{
|
| 192 |
+
tmp += s_sum[wi_in_sg + i * WARP_SIZE];
|
| 193 |
}
|
| 194 |
tmp = warp_reduce_sum(tmp, item_ct1);
|
| 195 |
}
|
|
|
|
| 198 |
const float scale = sycl::rsqrt(mean + eps);
|
| 199 |
|
| 200 |
for (int col = tid; col < ncols; col += block_size) {
|
| 201 |
+
dst[col] = scale * x[col];
|
| 202 |
}
|
| 203 |
}
|
| 204 |
|
|
|
|
| 246 |
}
|
| 247 |
}
|
| 248 |
|
| 249 |
+
static void norm_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
| 250 |
+
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
|
| 251 |
+
const float eps, queue_ptr stream, int device) {
|
| 252 |
+
|
| 253 |
+
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
|
| 254 |
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
| 255 |
if (ncols < 1024) {
|
| 256 |
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
| 257 |
stream->submit([&](sycl::handler& cgh) {
|
| 258 |
cgh.parallel_for(
|
| 259 |
+
sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
|
|
|
| 260 |
[=](sycl::nd_item<3> item_ct1)
|
| 261 |
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
| 262 |
+
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
|
|
|
|
| 263 |
});
|
| 264 |
});
|
| 265 |
}
|
|
|
|
| 274 |
*/
|
| 275 |
stream->submit([&](sycl::handler& cgh) {
|
| 276 |
sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
|
| 277 |
+
sycl::range<1>(work_group_size / WARP_SIZE), cgh);
|
|
|
|
| 278 |
cgh.parallel_for(
|
| 279 |
+
sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
|
|
|
| 280 |
[=](sycl::nd_item<3> item_ct1)
|
| 281 |
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
| 282 |
+
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
|
|
|
|
| 283 |
});
|
| 284 |
});
|
| 285 |
}
|
|
|
|
| 332 |
}
|
| 333 |
}
|
| 334 |
|
| 335 |
+
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
| 336 |
+
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
|
|
|
|
| 337 |
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
| 338 |
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
| 339 |
+
|
| 340 |
+
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
|
| 341 |
if (ncols < 1024) {
|
| 342 |
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
| 343 |
stream->submit([&](sycl::handler& cgh) {
|
| 344 |
cgh.parallel_for(
|
| 345 |
+
sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
|
|
|
| 346 |
[=](sycl::nd_item<3> item_ct1)
|
| 347 |
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
| 348 |
+
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
|
|
|
|
| 349 |
});
|
| 350 |
});
|
| 351 |
}
|
|
|
|
| 362 |
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
|
| 363 |
cgh);
|
| 364 |
cgh.parallel_for(
|
| 365 |
+
sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
|
|
|
| 366 |
[=](sycl::nd_item<3> item_ct1)
|
| 367 |
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
| 368 |
+
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
|
|
|
|
| 369 |
});
|
| 370 |
});
|
| 371 |
}
|
|
|
|
| 414 |
}
|
| 415 |
|
| 416 |
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
| 417 |
+
const ggml_tensor * src0 = dst->src[0];
|
| 418 |
|
| 419 |
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
| 420 |
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
| 421 |
|
| 422 |
+
GGML_TENSOR_UNARY_OP_LOCALS
|
|
|
|
| 423 |
dpct::queue_ptr main_stream = ctx.stream();
|
| 424 |
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
| 425 |
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
|
|
|
| 427 |
|
| 428 |
float eps;
|
| 429 |
memcpy(&eps, dst->op_params, sizeof(float));
|
| 430 |
+
GGML_ASSERT(eps >= 0.0f);
|
| 431 |
+
const size_t ts0 = ggml_type_size(src0->type);
|
| 432 |
+
GGML_ASSERT(nb00 == ts0);
|
| 433 |
+
const int64_t s01 = nb01 / ts0;
|
| 434 |
+
const int64_t s02 = nb02 / ts0;
|
| 435 |
+
const int64_t s03 = nb03 / ts0;
|
| 436 |
+
|
| 437 |
+
norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
|
| 438 |
}
|
| 439 |
|
| 440 |
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
|
|
|
| 458 |
|
| 459 |
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
| 460 |
|
| 461 |
+
const ggml_tensor * src0 = dst->src[0];
|
| 462 |
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
| 463 |
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
| 464 |
|
|
|
|
|
|
|
| 465 |
dpct::queue_ptr main_stream = ctx.stream();
|
| 466 |
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
| 467 |
|
|
|
|
| 471 |
float eps;
|
| 472 |
memcpy(&eps, dst->op_params, sizeof(float));
|
| 473 |
|
| 474 |
+
GGML_TENSOR_UNARY_OP_LOCALS
|
| 475 |
+
const size_t ts0 = ggml_type_size(src0->type);
|
| 476 |
+
GGML_ASSERT(nb00 == ts0);
|
| 477 |
+
const int64_t s01 = nb01 / ts0;
|
| 478 |
+
const int64_t s02 = nb02 / ts0;
|
| 479 |
+
const int64_t s03 = nb03 / ts0;
|
| 480 |
+
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
|
| 481 |
}
|
| 482 |
|
| 483 |
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|