Anton Mitkov committed
Commit 2722bea · 1 Parent(s): 1f97ff4

sycl: Batched mulmat rework for oneDNN dispatch (llama/14617)
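
Illustrative note (not part of the commit message): the rework passes oneDNN raw element strides derived from ggml's byte strides (tensor->nb[]) instead of assuming contiguous operands, e.g. str_a1 = nb01 / type_size_src0 in the patch below. A minimal sketch of that conversion, using hypothetical tensor extents rather than a real ggml_tensor:

// Sketch only: ggml stores byte strides in nb[]; oneDNN wants element strides,
// so each nb value is divided by the element size of the tensor's type.
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical contiguous fp16 src0 with ne = {64, 32, 4, 1}.
    const int64_t ne00 = 64, ne01 = 32, ne02 = 4;
    const int64_t type_size_src0 = 2;       // bytes per fp16 element
    const int64_t nb00 = type_size_src0;    // byte step to the next element in a row
    const int64_t nb01 = nb00 * ne00;       // byte step to the next row
    const int64_t nb02 = nb01 * ne01;       // byte step to the next matrix in the batch

    const int64_t str_a0 = nb00 / type_size_src0;  // element stride between columns
    const int64_t str_a1 = nb01 / type_size_src0;  // element stride between rows
    const int64_t str_a2 = nb02 / type_size_src0;  // element stride between batched matrices

    std::printf("str_a0=%lld str_a1=%lld str_a2=%lld (batches=%lld)\n",
                (long long) str_a0, (long long) str_a1, (long long) str_a2, (long long) ne02);
    return 0;
}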

ggml/src/ggml-sycl/gemm.hpp CHANGED
@@ -32,39 +32,28 @@ public:
         else static_assert(0);
     }
 
-    // matrix A has m rows, k columns
-    // matrix B has k rows, n columns
-    // nra - number of elements to skip when moving into next row in A
-    // nrb - number of elements to skip when moving into next row in B
-    // nca - number of elements to skip when moving into next column in A
-    // ncb - number of elements to skip when moving into next column in B
-    // stride_a - number of elements to skip when moving to next A matrix
-    // stride_b - number of elements to skip when moving to next B matrix
-    // batches_a - number of A matrices
-    // batches_b - number of B matrices
     static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
-        const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
-        const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
+        const void * a, dt at, dnnl_dim_t stra0, dnnl_dim_t stra1, dnnl_dim_t stra2,
+        const void * b, dt bt, dnnl_dim_t strb0, dnnl_dim_t strb1, dnnl_dim_t strb2,
         void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
 
         auto stream = ctx.stream_dnnl(q);
         auto eng = ctx.engine_dnnl(q);
 
-        // { # strides, # rows, # columns }
-        dnnl::memory::dims a_dims = { batches_a, m, k };
-        dnnl::memory::dims b_dims = { batches_b, k, n };
-        dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
-
-        // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
-        dnnl::memory::dims a_strides = { stride_a, nra, nca };
-        dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
-
+        dnnl::memory::dims a_dims = {batches_a, m, k };
+        dnnl::memory::dims a_strides = {stra2, stra1, stra0};
         const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
+
+        dnnl::memory::dims b_dims = {batches_b, k, n };
+        dnnl::memory::dims b_strides = {strb2, strb0, strb1};
         const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
-        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
 
+        dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n};
+        dnnl::memory::dims c_strides = {m*n, 1, m };
+        const auto c_md = dnnl::memory::desc(c_dims, ct, c_strides);
         dnnl::primitive_attr primitive_attr;
         primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+
 #ifdef GGML_SYCL_F16
         primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16);
 #endif
@@ -76,24 +65,23 @@ public:
 
         auto scratchpad_md = matmul_pd.scratchpad_desc();
         auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
+
         auto matmul_prim = dnnl::matmul(matmul_pd);
 
         std::unordered_map<int, dnnl::memory> matmul_args;
         matmul_args.insert({ DNNL_ARG_SRC, a_mem });
         matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
+
         matmul_args.insert({ DNNL_ARG_DST, c_mem });
         matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
 
         matmul_prim.execute(stream, matmul_args);
     }
 
-    // matrices A and B are column major, both having k rows
-    // matrix A has m columns, matrix B has n columns
-    // output: column major matrix C = A transposed * B
     static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
         const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
 
-        gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
+        gemm(ctx, m, n, k, a, at, 1, k, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
     }
 };
 
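
Illustrative note (not part of the patch): the reworked gemm() now describes every operand to oneDNN as dims {batches, rows, cols} plus an explicit element-stride triple, rather than the fixed tag::abc layout used before, and the output is written column-major via the hard-coded {m*n, 1, m} strides. A minimal standalone sketch of that descriptor convention, assuming oneDNN 3.x headers and reusing the stride triples row_gemm() passes above:

// Sketch only: the {batches, rows, cols} dims / {batch, row, col} element-stride
// convention the reworked gemm() uses for its dnnl::memory::desc objects.
#include <oneapi/dnnl/dnnl.hpp>
#include <cstdio>

int main() {
    using dt = dnnl::memory::data_type;
    const dnnl::memory::dim batches = 2, m = 4, n = 8, k = 16;

    // A as row-major m x k per batch: next column is 1 element away, next row is k,
    // next matrix is k*m (row_gemm's stra0=1, stra1=k, stra2=k*m).
    const dnnl::memory::desc a_md({batches, m, k}, dt::f16, {k * m, k, 1});

    // B as column-major k x n per batch (row_gemm's strb0=1, strb1=k, strb2=n*k).
    const dnnl::memory::desc b_md({batches, k, n}, dt::f16, {n * k, 1, k});

    // C written column-major, matching the {m*n, 1, m} strides in the patch.
    const dnnl::memory::desc c_md({batches, m, n}, dt::f32, {m * n, 1, m});

    std::printf("A: %zu bytes, B: %zu bytes, C: %zu bytes\n",
                a_md.get_size(), b_md.get_size(), c_md.get_size());
    return 0;
}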
ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED
@@ -1546,7 +1546,7 @@ static void mul_mat_p021_f16_f32(
 
 static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
+    const int row_stride_x, const int channel_stride_x, const int channel_stride_y, const int channel_x_divisor,
     const sycl::nd_item<3> &item_ct1) {
 
     const sycl::half *x = (const sycl::half *)vx;
@@ -1557,7 +1557,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
                           item_ct1.get_local_id(0);
     const int channel_x = channel / channel_x_divisor;
 
-    const int nrows_y = ncols_x;
     const int nrows_dst = nrows_x;
     const int row_dst = row_x;
 
@@ -1576,7 +1575,7 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
         const int row_y = col_x;
 
         const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-        const int iy = channel*nrows_y + row_y;
+        const int iy = channel * channel_stride_y + row_y;
 
         const float xi =
             sycl::vec<sycl::half, 1>(x[ix])
@@ -1823,7 +1822,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
 static void ggml_mul_mat_vec_nc_f16_f32_sycl(
     const void *vx, const float *y, float *dst, const int ncols_x,
     const int nrows_x, const int row_stride_x, const int nchannels_x,
-    const int nchannels_y, const int channel_stride_x, queue_ptr stream) {
+    const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) {
 
     const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -1835,7 +1834,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
-                                       row_stride_x, channel_stride_x,
+                                       row_stride_x, channel_stride_x, channel_stride_y,
                                        nchannels_y / nchannels_x, item_ct1);
             });
 }
@@ -2124,8 +2123,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
 #if GGML_SYCL_DNNL
         if (!g_ggml_sycl_disable_dnn) {
-            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
-                DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+            DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ptr,
+                DnnlGemmWrapper::to_dt<sycl::half>(), src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
                 dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
         }
         else
@@ -2171,8 +2170,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
 #if GGML_SYCL_DNNL
         if (!g_ggml_sycl_disable_dnn) {
-            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
-                DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+            DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i,
+                DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
                 dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
         }
         else
@@ -2776,6 +2775,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
     const int64_t nb02 = src0->nb[2];
 
     const int64_t ne12 = src1->ne[2];
+    const int64_t nb11 = src1->nb[1];
 
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     queue_ptr main_stream = ctx.stream();
@@ -2786,8 +2786,9 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
 
     const int64_t row_stride_x = nb01 / sizeof(sycl::half);
     const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
+    const int64_t channel_stride_y = nb11 / sizeof(float);
 
-    ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
+    ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, channel_stride_y, main_stream);
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -2841,8 +2842,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     float * dst_ddf = static_cast<float *>(dst->data);
 
     const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
+    const size_t type_size_src0 = ggml_type_size(src0->type);
     const size_t type_size_src1 = ggml_type_size(src1->type);
-    GGML_ASSERT(nb10 == type_size_src1);
 
     // SRC1 strides
     int64_t s11 = nb11 / type_size_src1;
@@ -2854,11 +2855,32 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     if (src1->type != GGML_TYPE_F16) {
         scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2,
                                              " : converting src1 to fp16");
-        const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
-        GGML_ASSERT(to_fp16_nc_sycl != nullptr);
-        const int64_t ne_src1 = ggml_nelements(src1);
+
+        // iterate tensor dims and find the slowest moving dim and stride
+        int64_t last_dim = 0;
+        int64_t last_str = 0;
+        int64_t largest_str = 0;
+        for (int i = 0; i < 4; i++) {
+            // last stride is always the largest
+            if (src1->nb[i] == largest_str) {
+                if (src1->ne[last_dim] == 1) {
+                    last_str = i;
+                    last_dim = i;
+                }
+            }
+            if (src1->nb[i] > largest_str) {
+                largest_str = src1->nb[i];
+                last_str = i;
+                last_dim = i;
+            }
+
+        }
+        const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
         src1_f16_alloc.alloc(ne_src1);
-        to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
+
+        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
+        GGML_ASSERT(to_fp16_sycl != nullptr);
+        to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue);
 
         src1_f16 = src1_f16_alloc.get();
         s11 = ne10;
@@ -2892,38 +2914,89 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
 
 #if GGML_SYCL_DNNL
     if (!g_ggml_sycl_disable_dnn) {
-        auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
-            (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
-
-            DnnlGemmWrapper::gemm(ctx, ne11, ne01, ne10,
-                src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
-                src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
-                dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
-        };
-
-        if (r2 == 1 && r3 == 1) {
-            if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
-                dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
-            }
-            else {
-                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
-                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
-                    const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
-                    float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
-                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
+        int64_t str_a0 = nb00 / type_size_src0;
+        int64_t str_a1 = nb01 / type_size_src0;
+        int64_t str_a2 = nb02 / type_size_src0;
+
+        int64_t str_b0 = nb10 / type_size_src1;
+        int64_t str_b1 = nb11 / type_size_src1;
+        int64_t str_b2 = nb12 / type_size_src1;
+
+        auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
+                                           const sycl::half *src1, float *dst,
+                                           int64_t a0, int64_t a1, int64_t batcha,
+                                           int64_t b0, int64_t b1, int64_t batchb,
+                                           int64_t sa0, int64_t sa1, int64_t sa2,
+                                           int64_t sb0, int64_t sb1, int64_t sb2,
+                                           int64_t sd2) {
+            bool supported_broadcast = batchb == batcha ? true
+                                       : batchb == 1 || batcha == 1 ? true
+                                                                    : false;
+            if (supported_broadcast) {
+                DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0,
+                    DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, src1,
+                    DnnlGemmWrapper::to_dt<sycl::half>(), sb0, sb1, sb2, dst,
+                    DnnlGemmWrapper::to_dt<float>(), queue, batcha, batchb);
+            } else {
+                // iterate over batches from smaller set of matrices (matrix 0)
+                int64_t batches0 = batcha;
+                int64_t batches1 = batchb;
+
+                if (batches0 > batches1) {
+                    int64_t num_mul_mats = batches1;
+                    int64_t sub_batch = batches0 / num_mul_mats;
+                    // src0 is batched and bigger, shift and multiply with src1
+                    for (int64_t i0 = 0; i0 < num_mul_mats; i0++) {
+                        const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch);
+                        const sycl::half *src1_shifted = src1 + (sb2 * i0);
+                        float *dst_shifted = dst + (sd2 * i0 * sub_batch);
+                        DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
+                            DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
+                            src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
+                            sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
+                            queue, sub_batch, 1);
+                    }
+                } else {
+                    int64_t num_mul_mats = batches0;
+                    int64_t sub_batch = batches1 / num_mul_mats;
+                    // src1 is batched and bigger, shift and multiply with src0
+                    for (int64_t i1 = 0; i1 < num_mul_mats; i1++) {
+                        const sycl::half *src0_shifted = src0 + (sa2 * i1);
+                        const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch);
+                        float *dst_shifted = dst + (sd2 * i1 * sub_batch);
+                        DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
+                            DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
+                            src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
+                            sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
+                            queue, 1, sub_batch);
+                    }
+                }
             }
-        }
-        } else {
-            // iterate over batches from smaller set of matrices (matrix 0)
-            for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
-                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
-                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
-                    const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
-                    float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
-                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
+        };
+
+        bool cont_batches_a = nb02 * ne02 == nb03;
+        bool cont_batches_b = nb12 * ne12 == nb13;
+        if (cont_batches_a && cont_batches_b) {
+            int64_t batches0 = ne02 * ne03;
+            int64_t batches1 = ne12 * ne13;
+            launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
+                ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
+                str_b2, nb2 / sizeof(float));
+        } else {
+            for (int64_t b_a = 0; b_a < ne03; b_a++) {
+                const sycl::half *src0_f16_shifted
+                    = src0_f16 + (nb03 * b_a / type_size_src0);
+                const sycl::half *src1_f16_shifted
+                    = src1_f16 + (nb13 * b_a / type_size_src1);
+                float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float));
+                int64_t batches0 = ne02;
+                int64_t batches1 = ne12;
+                launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted,
+                    ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1,
+                    str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float));
             }
         }
-        }
+
     }
     else
 #endif
@@ -3263,10 +3336,10 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
             // The kernel from the if path is faster for that specific case, but does not support all mul mats.
             ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
         }
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {
         // KQ + KQV multi-batch
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
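
Illustrative note (not part of the patch): launch_gemm_for_batches hands the work to a single oneDNN call only when the two batch counts are equal or one of them is 1; otherwise it iterates the smaller batch set and multiplies sub_batch = larger/smaller matrices per call (the integer division in the patch assumes the larger count divides evenly). A small standalone sketch of that dispatch rule:

// Sketch only: the batch-dispatch rule used by launch_gemm_for_batches.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static void plan_batched_gemm(int64_t batcha, int64_t batchb) {
    // Equal batch counts, or a batch count of 1, map onto one broadcast-capable call.
    const bool supported_broadcast = batcha == batchb || batcha == 1 || batchb == 1;
    if (supported_broadcast) {
        std::printf("one gemm: batches_a=%lld batches_b=%lld\n",
                    (long long) batcha, (long long) batchb);
        return;
    }
    // Otherwise iterate the smaller set, advancing the larger operand by sub_batch
    // matrices per call and the smaller operand by one.
    const int64_t num_mul_mats = std::min(batcha, batchb);
    const int64_t sub_batch    = std::max(batcha, batchb) / num_mul_mats;
    for (int64_t i = 0; i < num_mul_mats; ++i) {
        std::printf("gemm %lld: %lld matrices of the larger operand vs 1 of the smaller\n",
                    (long long) i, (long long) sub_batch);
    }
}

int main() {
    plan_batched_gemm(4, 4);  // broadcast path: one call
    plan_batched_gemm(6, 1);  // broadcast path: one call
    plan_batched_gemm(8, 2);  // uneven: 2 calls, 4 matrices each
    return 0;
}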