Anton Mitkov committed · commit 2722bea · 1 parent: 1f97ff4

sycl: Batched mulmat rework for oneDNN dispatch (llama/14617)

Files changed:
- ggml/src/ggml-sycl/gemm.hpp      +14 -26
- ggml/src/ggml-sycl/ggml-sycl.cpp +119 -46
ggml/src/ggml-sycl/gemm.hpp
CHANGED
@@ -32,39 +32,28 @@ public:
         else static_assert(0);
     }
 
-    // matrix A has m rows, k columns
-    // matrix B has k rows, n columns
-    // nra - number of elements to skip when moving into next row in A
-    // nrb - number of elements to skip when moving into next row in B
-    // nca - number of elements to skip when moving into next column in A
-    // ncb - number of elements to skip when moving into next column in B
-    // stride_a - number of elements to skip when moving to next A matrix
-    // stride_b - number of elements to skip when moving to next B matrix
-    // batches_a - number of A matrices
-    // batches_b - number of B matrices
     static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
-        const void * a, dt at, dnnl_dim_t
-        const void * b, dt bt, dnnl_dim_t
+        const void * a, dt at, dnnl_dim_t stra0, dnnl_dim_t stra1, dnnl_dim_t stra2,
+        const void * b, dt bt, dnnl_dim_t strb0, dnnl_dim_t strb1, dnnl_dim_t strb2,
         void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
 
         auto stream = ctx.stream_dnnl(q);
         auto eng = ctx.engine_dnnl(q);
 
-
-        dnnl::memory::dims a_dims = { batches_a, m, k };
-        dnnl::memory::dims b_dims = { batches_b, k, n };
-        dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
-
-        // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
-        dnnl::memory::dims a_strides = { stride_a, nra, nca };
-        dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
-
+        dnnl::memory::dims a_dims = { batches_a, m, k };
+        dnnl::memory::dims a_strides = { stra2, stra1, stra0 };
         const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
+
+        dnnl::memory::dims b_dims = { batches_b, k, n };
+        dnnl::memory::dims b_strides = { strb2, strb0, strb1 };
         const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
-        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
 
+        dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
+        dnnl::memory::dims c_strides = { m * n, 1, m };
+        const auto c_md = dnnl::memory::desc(c_dims, ct, c_strides);
         dnnl::primitive_attr primitive_attr;
         primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+
 #ifdef GGML_SYCL_F16
         primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16);
 #endif
@@ -76,24 +65,23 @@ public:
 
         auto scratchpad_md = matmul_pd.scratchpad_desc();
         auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
+
         auto matmul_prim = dnnl::matmul(matmul_pd);
 
         std::unordered_map<int, dnnl::memory> matmul_args;
         matmul_args.insert({ DNNL_ARG_SRC, a_mem });
         matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
+
         matmul_args.insert({ DNNL_ARG_DST, c_mem });
         matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
 
         matmul_prim.execute(stream, matmul_args);
     }
 
-    // matrices A and B are column major, both having k rows
-    // matrix A has m columns, matrix B has n columns
-    // output: column major matrix C = A transposed * B
     static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
         const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
 
-        gemm(ctx, m, n, k, a, at,
+        gemm(ctx, m, n, k, a, at, 1, k, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
     }
 };
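The reworked gemm() describes each operand to oneDNN purely as dims plus per-dimension element strides, which is what lets arbitrarily strided ggml tensors be dispatched without repacking. Below is a minimal standalone sketch of that descriptor construction, assuming only the public oneDNN C++ API (dnnl.hpp); the sizes are illustrative and the stride values are the ones row_gemm passes for a single, tightly packed matrix.

// Sketch only: builds strided oneDNN memory descriptors the same way the
// reworked DnnlGemmWrapper::gemm() does. Sizes are illustrative.
#include <dnnl.hpp>
#include <cstdio>

int main() {
    using dims = dnnl::memory::dims;
    using dt   = dnnl::memory::data_type;

    const dnnl_dim_t batches = 4, m = 8, n = 16, k = 32;

    // A is {batch, m, k}; strides follow row_gemm: 1 along k, k along m,
    // k*m from one matrix to the next (rows of length k are contiguous).
    dims a_dims    = { batches, m, k };
    dims a_strides = { k * m, k, 1 };

    // B is {batch, k, n}; contiguous along k, k elements to the next column.
    dims b_dims    = { batches, k, n };
    dims b_strides = { n * k, 1, k };

    // C is written column-major: {m*n, 1, m}, exactly as in the wrapper.
    dims c_dims    = { batches, m, n };
    dims c_strides = { m * n, 1, m };

    dnnl::memory::desc a_md(a_dims, dt::f16, a_strides);
    dnnl::memory::desc b_md(b_dims, dt::f16, b_strides);
    dnnl::memory::desc c_md(c_dims, dt::f32, c_strides);

    std::printf("A: %zu bytes, B: %zu bytes, C: %zu bytes\n",
                a_md.get_size(), b_md.get_size(), c_md.get_size());
    return 0;
}

The destination's batch dimension is std::max(batches_a, batches_b), which is what allows a single matrix on one side to be broadcast across the other side's batches.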
ggml/src/ggml-sycl/ggml-sycl.cpp
CHANGED
@@ -1546,7 +1546,7 @@ static void mul_mat_p021_f16_f32(
 
 static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
+    const int row_stride_x, const int channel_stride_x, const int channel_stride_y, const int channel_x_divisor,
     const sycl::nd_item<3> &item_ct1) {
 
     const sycl::half *x = (const sycl::half *)vx;
@@ -1557,7 +1557,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
                           item_ct1.get_local_id(0);
     const int channel_x = channel / channel_x_divisor;
 
-    const int nrows_y = ncols_x;
     const int nrows_dst = nrows_x;
     const int row_dst = row_x;
 
@@ -1576,7 +1575,7 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
         const int row_y = col_x;
 
         const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-        const int iy = channel*nrows_y + row_y;
+        const int iy = channel * channel_stride_y + row_y;
 
         const float xi =
             sycl::vec<sycl::half, 1>(x[ix])
@@ -1823,7 +1822,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
 static void ggml_mul_mat_vec_nc_f16_f32_sycl(
     const void *vx, const float *y, float *dst, const int ncols_x,
     const int nrows_x, const int row_stride_x, const int nchannels_x,
-    const int nchannels_y, const int channel_stride_x, queue_ptr stream) {
+    const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) {
 
     const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -1835,7 +1834,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
         [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
             mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
-                                   row_stride_x, channel_stride_x,
+                                   row_stride_x, channel_stride_x, channel_stride_y,
                                    nchannels_y / nchannels_x, item_ct1);
         });
 }
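With channel_stride_y the kernel no longer assumes that the src1 vectors for consecutive channels are packed back to back. A scalar, host-only restatement of the same index arithmetic follows (sketch only: plain float buffers stand in for the fp16 src0, there is no sub-group reduction, and the dst indexing is assumed from the surrounding kernel rather than shown in the hunks above).

#include <cstdio>
#include <vector>

// Scalar reference for the indexing used by mul_mat_vec_nc_f16_f32 after this change.
static void mul_mat_vec_nc_reference(
        const std::vector<float> & x,   // src0 data, possibly padded rows/channels
        const std::vector<float> & y,   // src1 data, one column vector per channel
        std::vector<float> & dst,
        int ncols_x, int nrows_x,
        int row_stride_x, int channel_stride_x, int channel_stride_y,
        int nchannels_y, int channel_x_divisor) {
    for (int channel = 0; channel < nchannels_y; ++channel) {
        const int channel_x = channel / channel_x_divisor;      // src0 channels are broadcast
        for (int row_x = 0; row_x < nrows_x; ++row_x) {
            float acc = 0.0f;
            for (int col_x = 0; col_x < ncols_x; ++col_x) {
                const int row_y = col_x;
                const int ix = channel_x * channel_stride_x + row_x * row_stride_x + col_x;
                const int iy = channel   * channel_stride_y + row_y;  // new: explicit src1 channel stride
                acc += x[ix] * y[iy];
            }
            dst[channel * nrows_x + row_x] = acc;                     // assumed dst layout
        }
    }
}

int main() {
    const int ncols_x = 4, nrows_x = 2, nchannels = 3;
    // tightly packed toy case: strides equal the logical extents
    std::vector<float> x(nchannels * nrows_x * ncols_x, 1.0f);
    std::vector<float> y(nchannels * ncols_x, 2.0f);
    std::vector<float> dst(nchannels * nrows_x, 0.0f);
    mul_mat_vec_nc_reference(x, y, dst, ncols_x, nrows_x,
                             /*row_stride_x=*/ncols_x, /*channel_stride_x=*/nrows_x * ncols_x,
                             /*channel_stride_y=*/ncols_x, nchannels, /*channel_x_divisor=*/1);
    std::printf("dst[0] = %f\n", dst[0]);  // 4 * 1.0 * 2.0 = 8
    return 0;
}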
@@ -2124,8 +2123,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
 #if GGML_SYCL_DNNL
         if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ptr,
+                DnnlGemmWrapper::to_dt<sycl::half>(), src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
                 dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
         }
         else
@@ -2171,8 +2170,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
 #if GGML_SYCL_DNNL
         if (!g_ggml_sycl_disable_dnn) {
-            DnnlGemmWrapper::row_gemm(ctx,
-            DnnlGemmWrapper::to_dt<float>(),
+            DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i,
+                DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
                 dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
         }
         else
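Both call sites select the oneDNN data type through DnnlGemmWrapper::to_dt<...>(). The wrapper's own helper lives in gemm.hpp; the fragment below is only an illustration of that compile-time mapping, not the backend code.

#include <sycl/sycl.hpp>
#include <dnnl.hpp>
#include <type_traits>

// Illustration of the to_dt<T>() idea: pick the oneDNN data type tag from the
// element type at compile time (f32 for float, f16 for sycl::half).
using dt = dnnl::memory::data_type;

template <typename T> static constexpr dt to_dt() {
    if constexpr (std::is_same_v<T, float>) {
        return dt::f32;
    } else if constexpr (std::is_same_v<T, sycl::half>) {
        return dt::f16;
    } else {
        static_assert(sizeof(T) == 0, "unsupported element type");
    }
}

static_assert(to_dt<float>() == dt::f32);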
@@ -2776,6 +2775,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
     const int64_t nb02 = src0->nb[2];
 
     const int64_t ne12 = src1->ne[2];
+    const int64_t nb11 = src1->nb[1];
 
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     queue_ptr main_stream = ctx.stream();
@@ -2786,8 +2786,9 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
 
     const int64_t row_stride_x = nb01 / sizeof(sycl::half);
     const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
+    const int64_t channel_stride_y = nb11 / sizeof(float);
 
-    ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
+    ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, channel_stride_y, main_stream);
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -2841,8 +2842,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     float * dst_ddf = static_cast<float *>(dst->data);
 
     const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
+    const size_t type_size_src0 = ggml_type_size(src0->type);
     const size_t type_size_src1 = ggml_type_size(src1->type);
-    GGML_ASSERT(nb10 == type_size_src1);
 
     // SRC1 strides
     int64_t s11 = nb11 / type_size_src1;
@@ -2854,11 +2855,32 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     if (src1->type != GGML_TYPE_F16) {
         scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2,
                                              " : converting src1 to fp16");
+
+        // iterate tensor dims and find the slowest moving dim and stride
+        int64_t last_dim = 0;
+        int64_t last_str = 0;
+        int64_t largest_str = 0;
+        for (int i = 0; i < 4; i++) {
+            // last stride is always the largest
+            if (src1->nb[i] == largest_str) {
+                if (src1->ne[last_dim] == 1) {
+                    last_str = i;
+                    last_dim = i;
+                }
+            }
+            if (src1->nb[i] > largest_str) {
+                largest_str = src1->nb[i];
+                last_str = i;
+                last_dim = i;
+            }
+        }
+        const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
         src1_f16_alloc.alloc(ne_src1);
+
+        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
+        GGML_ASSERT(to_fp16_sycl != nullptr);
+        to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue);
 
         src1_f16 = src1_f16_alloc.get();
         s11 = ne10;
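The fp16 conversion now sizes its buffer from the slowest-moving dimension of src1 instead of assuming a densely packed tensor, so padded or permuted src1 layouts allocate enough storage. A standalone restatement of that search, with nb/ne as plain arrays standing in for the ggml_tensor fields and illustrative values:

#include <cstdint>
#include <cstdio>

// Sketch: find the slowest-moving dimension of a 4-D tensor from its byte
// strides nb[] and extents ne[], then size the destination buffer from it.
// This mirrors the loop added above; the nb/ne values are made up.
int main() {
    const int64_t ne[4] = { 128, 1, 16, 1 };                              // extents
    const int64_t nb[4] = { 4, 4 * 128, 4 * 128 * 1, 4 * 128 * 1 * 16 };  // byte strides (fp32)
    const int64_t type_size = 4;

    int64_t last_dim = 0, last_str = 0, largest_str = 0;
    for (int i = 0; i < 4; i++) {
        // prefer a later dim when strides tie but the current slowest dim has extent 1
        if (nb[i] == largest_str && ne[last_dim] == 1) {
            last_str = i;
            last_dim = i;
        }
        if (nb[i] > largest_str) {
            largest_str = nb[i];
            last_str = i;
            last_dim = i;
        }
    }
    const int64_t n_elems = nb[last_str] * ne[last_dim] / type_size;
    std::printf("slowest dim: %d, elements to allocate: %ld\n", (int)last_dim, (long)n_elems);
    return 0;
}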
@@ -2892,38 +2914,89 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
 
 #if GGML_SYCL_DNNL
     if (!g_ggml_sycl_disable_dnn) {
+        int64_t str_a0 = nb00 / type_size_src0;
+        int64_t str_a1 = nb01 / type_size_src0;
+        int64_t str_a2 = nb02 / type_size_src0;
+
+        int64_t str_b0 = nb10 / type_size_src1;
+        int64_t str_b1 = nb11 / type_size_src1;
+        int64_t str_b2 = nb12 / type_size_src1;
+
+        auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
+                const sycl::half *src1, float *dst,
+                int64_t a0, int64_t a1, int64_t batcha,
+                int64_t b0, int64_t b1, int64_t batchb,
+                int64_t sa0, int64_t sa1, int64_t sa2,
+                int64_t sb0, int64_t sb1, int64_t sb2,
+                int64_t sd2) {
+            bool supported_broadcast = batchb == batcha ? true
+                                       : batchb == 1 || batcha == 1 ? true
+                                                                    : false;
+            if (supported_broadcast) {
+                DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0,
+                    DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, src1,
+                    DnnlGemmWrapper::to_dt<sycl::half>(), sb0, sb1, sb2, dst,
+                    DnnlGemmWrapper::to_dt<float>(), queue, batcha, batchb);
+            } else {
+                // iterate over batches from smaller set of matrices (matrix 0)
+                int64_t batches0 = batcha;
+                int64_t batches1 = batchb;
+
+                if (batches0 > batches1) {
+                    int64_t num_mul_mats = batches1;
+                    int64_t sub_batch = batches0 / num_mul_mats;
+                    // src0 is batched and bigger, shift and multiply with src1
+                    for (int64_t i0 = 0; i0 < num_mul_mats; i0++) {
+                        const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch);
+                        const sycl::half *src1_shifted = src1 + (sb2 * i0);
+                        float *dst_shifted = dst + (sd2 * i0 * sub_batch);
+                        DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
+                            DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
+                            src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
+                            sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
+                            queue, sub_batch, 1);
+                    }
+                } else {
+                    int64_t num_mul_mats = batches0;
+                    int64_t sub_batch = batches1 / num_mul_mats;
+                    // src1 is batched and bigger, shift and multiply with src0
+                    for (int64_t i1 = 0; i1 < num_mul_mats; i1++) {
+                        const sycl::half *src0_shifted = src0 + (sa2 * i1);
+                        const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch);
+                        float *dst_shifted = dst + (sd2 * i1 * sub_batch);
+                        DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
+                            DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
+                            src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
+                            sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
+                            queue, 1, sub_batch);
+                    }
+                }
             }
-        }
+        };
+
+        bool cont_batches_a = nb02 * ne02 == nb03;
+        bool cont_batches_b = nb12 * ne12 == nb13;
+        if (cont_batches_a && cont_batches_b) {
+            int64_t batches0 = ne02 * ne03;
+            int64_t batches1 = ne12 * ne13;
+            launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
+                ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
+                str_b2, nb2 / sizeof(float));
+        } else {
+            for (int64_t b_a = 0; b_a < ne03; b_a++) {
+                const sycl::half *src0_f16_shifted
+                    = src0_f16 + (nb03 * b_a / type_size_src0);
+                const sycl::half *src1_f16_shifted
+                    = src1_f16 + (nb13 * b_a / type_size_src1);
+                float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float));
+                int64_t batches0 = ne02;
+                int64_t batches1 = ne12;
+                launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted,
+                    ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1,
+                    str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float));
             }
         }
-
+
     }
     else
 #endif
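launch_gemm_for_batches either hands the whole problem to oneDNN as a single broadcasted batched matmul (equal batch counts, or a count of 1 on either side) or iterates the smaller side while splitting the larger side into sub-batches. A host-only sketch of that planning logic, printing which slices each dispatch would cover (illustrative, not the backend code):

#include <cstdint>
#include <cstdio>

// Sketch of the batching policy used by launch_gemm_for_batches above.
static void plan_batches(int64_t batcha, int64_t batchb) {
    const bool supported_broadcast = (batcha == batchb) || batcha == 1 || batchb == 1;
    if (supported_broadcast) {
        std::printf("single call: batches_a=%lld batches_b=%lld\n",
                    (long long) batcha, (long long) batchb);
        return;
    }
    if (batcha > batchb) {
        const int64_t num_mul_mats = batchb;
        const int64_t sub_batch    = batcha / num_mul_mats;
        // each call multiplies sub_batch A matrices against one B matrix
        for (int64_t i0 = 0; i0 < num_mul_mats; i0++) {
            std::printf("call %lld: a[%lld..%lld) x b[%lld]\n", (long long) i0,
                        (long long) (i0 * sub_batch), (long long) ((i0 + 1) * sub_batch),
                        (long long) i0);
        }
    } else {
        const int64_t num_mul_mats = batcha;
        const int64_t sub_batch    = batchb / num_mul_mats;
        // each call multiplies one A matrix against sub_batch B matrices
        for (int64_t i1 = 0; i1 < num_mul_mats; i1++) {
            std::printf("call %lld: a[%lld] x b[%lld..%lld)\n", (long long) i1,
                        (long long) i1, (long long) (i1 * sub_batch),
                        (long long) ((i1 + 1) * sub_batch));
        }
    }
}

int main() {
    plan_batches(8, 8);   // broadcast-friendly: one batched call
    plan_batches(8, 1);   // broadcast-friendly: one batched call
    plan_batches(8, 2);   // two calls, each pairing 4 A matrices with 1 B matrix
    return 0;
}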
@@ -3263,10 +3336,10 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
             // The kernel from the if path is faster for that specific case, but does not support all mul mats.
             ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
         }
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) &&
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {
         // KQ + KQV multi-batch
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {