ggerganov committed
Commit d51c0d3
1 Parent(s): a245fbf

metal : optimize MoE for large batches (llama/13388)

ggml/src/ggml-metal/ggml-metal-impl.h CHANGED
@@ -299,21 +299,42 @@ typedef struct {
 } ggml_metal_kargs_mul_mv_ext;
 
 typedef struct {
-    int32_t  nei0;
-    int32_t  nei1;
-    uint64_t nbi1;
+    int32_t  ne10;
+    int32_t  ne11;  // n_expert_used (bcast)
+    uint64_t nb11;
+    uint64_t nb12;
+    int32_t  neh11; // n_tokens
+    uint64_t nbh11;
+    int32_t  ne20;  // n_expert_used
+    uint64_t nb21;
+} ggml_metal_kargs_mul_mm_id_map0;
+
+typedef struct {
+    int32_t  ne20; // n_expert_used
+    int32_t  neh0;
+    int32_t  neh1;
+    uint64_t nbh1;
+    uint64_t nbh2;
+    int32_t  ne0;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_mul_mm_id_map1;
+
+typedef struct {
     int32_t  ne00;
     int32_t  ne02;
     uint64_t nb01;
     uint64_t nb02;
-    int32_t  ne11;
-    int32_t  ne12;
-    int32_t  ne13;
-    uint64_t nb10;
-    uint64_t nb11;
-    uint64_t nb12;
-    int32_t  ne0;
-    int32_t  ne1;
+    uint64_t nb03;
+    int32_t  neh12;
+    uint64_t nbh10;
+    uint64_t nbh11;
+    uint64_t nbh12;
+    uint64_t nbh13;
+    int32_t  neh0;
+    int32_t  neh1;
+    int16_t  r2;
+    int16_t  r3;
 } ggml_metal_kargs_mul_mm_id;
 
 typedef struct {
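
The three structs added above encode a new three-pass MUL_MAT_ID scheme for large batches: a map0 pass gathers each token's activations into a contiguous per-expert F16 buffer (recording tokens-per-expert counts and where every gathered row must eventually land), a dense per-expert matrix-matrix pass consumes that buffer, and a map1 pass scatters the results back into the [ne0, n_expert_used, n_tokens] destination. The plain-C sketch below is only a rough CPU reference of that flow under simplifying assumptions (contiguous f32 tensors, no broadcasting, and a simpler, larger id-map layout than the kernels use); it is not the exact kernel contract.

#include <stdint.h>
#include <string.h>

// Hypothetical reference of the gather -> per-expert GEMM -> scatter flow.
// ids   : [n_expert_used, n_tokens]         - expert chosen for each slot of each token
// src1  : [n_embd, n_tokens]                - input activations
// src0  : [n_embd, n_rows, n_expert]        - expert weight matrices
// dst   : [n_rows, n_expert_used, n_tokens] - per-slot outputs
// h_*   : transient buffers (the Metal version takes these from a memory pool);
//         h_src1/h_dst/h_ids are sized n_expert*n_tokens rows here for simplicity
void mul_mat_id_ref(
        const int32_t * ids, const float * src1, const float * src0, float * dst,
        int n_expert, int n_expert_used, int n_tokens, int n_embd, int n_rows,
        float * h_src1, float * h_dst, int32_t * h_ids, int32_t * h_tpe) {
    // pass 1 ("map0"): for every expert, collect the rows routed to it
    memset(h_tpe, 0, n_expert*sizeof(int32_t));
    for (int t = 0; t < n_tokens; ++t) {
        for (int s = 0; s < n_expert_used; ++s) {
            const int e = ids[t*n_expert_used + s];
            const int r = h_tpe[e]++; // row inside expert e's packed batch
            memcpy(h_src1 + ((size_t) e*n_tokens + r)*n_embd,
                   src1   +  (size_t) t*n_embd, n_embd*sizeof(float));
            h_ids[e*n_tokens + r] = t*n_expert_used + s; // remember the dst slot
        }
    }
    // pass 2 ("mul_mm_id"): one dense matmul per expert over its gathered rows
    for (int e = 0; e < n_expert; ++e) {
        for (int r = 0; r < h_tpe[e]; ++r) {
            for (int i = 0; i < n_rows; ++i) {
                float sum = 0.0f;
                for (int k = 0; k < n_embd; ++k) {
                    sum += src0  [((size_t) e*n_rows   + i)*n_embd + k] *
                           h_src1[((size_t) e*n_tokens + r)*n_embd + k];
                }
                h_dst[((size_t) e*n_tokens + r)*n_rows + i] = sum;
            }
        }
    }
    // pass 3 ("map1"): scatter each computed row back to its (slot, token) position
    for (int e = 0; e < n_expert; ++e) {
        for (int r = 0; r < h_tpe[e]; ++r) {
            memcpy(dst   +  (size_t) h_ids[e*n_tokens + r]*n_rows,
                   h_dst + ((size_t) e*n_tokens + r)*n_rows, n_rows*sizeof(float));
        }
    }
}
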
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -44,8 +44,8 @@ static struct ggml_backend_device g_ggml_backend_metal_device;
44
  // note: assumes single GPU device - the default one
45
  // TODO: support multiple GPU devices
46
  static struct ggml_backend_metal_device_context {
47
- id<MTLDevice> mtl_device;
48
- int mtl_device_ref_count;
49
  id<MTLLibrary> mtl_library;
50
 
51
  bool has_simdgroup_reduction;
@@ -306,28 +306,30 @@ enum ggml_metal_kernel_type {
306
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
307
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
308
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
309
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
310
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
311
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
312
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
313
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
314
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
315
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
316
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
317
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
318
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
319
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
320
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
321
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
322
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
323
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
324
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
325
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
326
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
327
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
328
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
329
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
330
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
 
 
331
  GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
332
  GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
333
  GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
@@ -490,7 +492,264 @@ enum ggml_metal_kernel_type {
490
  GGML_METAL_KERNEL_TYPE_COUNT
491
  };
492
 
493
  struct ggml_backend_metal_context {
 
494
  id<MTLCommandQueue> queue;
495
 
496
  dispatch_queue_t d_queue;
@@ -515,7 +774,7 @@ struct ggml_backend_metal_context {
515
  void (^encode_async)(size_t ith);
516
 
517
  // n_cb command buffers + 1 used by the main thread
518
- id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
519
 
520
  // abort ggml_metal_graph_compute if callback returns true
521
  ggml_abort_callback abort_callback;
@@ -705,9 +964,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
705
  struct ggml_backend_metal_device_context * ctx_dev = dev->context;
706
 
707
  id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
 
708
  GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
709
 
710
- ctx->queue = [device newCommandQueue];
 
711
  if (ctx->queue == nil) {
712
  GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
713
  return NULL;
@@ -768,7 +1029,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
768
  ctx->gf = nil;
769
  ctx->encode_async = nil;
770
  for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
771
- ctx->command_buffers[i] = nil;
 
 
 
772
  }
773
 
774
  #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
@@ -985,28 +1249,30 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
985
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
986
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
987
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
988
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
989
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
990
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat);
991
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
992
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
993
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
994
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm);
995
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm);
996
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm);
997
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm);
998
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm);
999
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm);
1000
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm);
1001
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm);
1002
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm);
1003
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm);
1004
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm);
1005
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm);
1006
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm);
1007
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm);
1008
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm);
1009
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm);
 
 
1010
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
1011
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
1012
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
@@ -1181,6 +1447,12 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
1181
 
1182
  [ctx->queue release];
1183
 
 
1184
  dispatch_release(ctx->d_queue);
1185
 
1186
  free(ctx);
@@ -1486,10 +1758,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1486
  }
1487
  }
1488
 
1489
- static void ggml_metal_encode_node(
1490
  ggml_backend_t backend,
1491
  int idx,
1492
- id<MTLComputeCommandEncoder> encoder) {
 
1493
  struct ggml_backend_metal_context * ctx = backend->context;
1494
  struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
1495
 
@@ -1505,7 +1778,7 @@ static void ggml_metal_encode_node(
1505
  struct ggml_tensor * dst = node;
1506
 
1507
  if (ggml_is_empty(dst)) {
1508
- return;
1509
  }
1510
 
1511
  switch (dst->op) {
@@ -1516,7 +1789,7 @@ static void ggml_metal_encode_node(
1516
  case GGML_OP_PERMUTE:
1517
  {
1518
  // noop -> next node
1519
- } return;
1520
  default:
1521
  {
1522
  } break;
@@ -1527,6 +1800,8 @@ static void ggml_metal_encode_node(
1527
  GGML_ABORT("unsupported op");
1528
  }
1529
 
 
 
1530
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
1531
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
1532
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -2173,26 +2448,76 @@ static void ggml_metal_encode_node(
2173
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
2174
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
2175
 
2176
- ggml_metal_kargs_soft_max args = {
2177
  /*.ne00 =*/ ne00,
2178
  /*.ne01 =*/ ne01,
2179
  /*.ne02 =*/ ne02,
2180
- /*.scale =*/ scale,
2181
- /*.max_bias =*/ max_bias,
2182
- /*.m0 =*/ m0,
2183
- /*.m1 =*/ m1,
2184
  /*.n_head_log2 =*/ n_head_log2,
2185
  };
2186
 
2187
  [encoder setComputePipelineState:pipeline];
2188
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2189
  if (id_src1) {
2190
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2191
  } else {
2192
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2193
  }
2194
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2195
- [encoder setBytes:&args length:sizeof(args) atIndex:3];
2196
 
2197
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2198
 
@@ -2683,7 +3008,7 @@ static void ggml_metal_encode_node(
2683
  [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2684
 
2685
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
2686
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
2687
  } else {
2688
  id<MTLComputePipelineState> pipeline = nil;
2689
 
@@ -2903,8 +3228,6 @@ static void ggml_metal_encode_node(
2903
  } break;
2904
  case GGML_OP_MUL_MAT_ID:
2905
  {
2906
- const int n_as = src0->ne[2];
2907
-
2908
  // src2 = ids
2909
  const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
2910
 
@@ -2918,24 +3241,21 @@ static void ggml_metal_encode_node(
2918
  GGML_ASSERT(ne03 == 1);
2919
  GGML_ASSERT(ne13 == 1);
2920
 
 
 
 
2921
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
2922
  // to the matrix-vector kernel
2923
  // ne20 = n_used_experts
2924
- // ne21 = n_rows
2925
- const int dst_rows = ne20*ne21;
2926
- const int dst_rows_min = n_as;
2927
- const int dst_rows_max = (device.maxThreadgroupMemoryLength/2 - 8192)/4;
2928
-
2929
- // max size of the rowids array in the kernel shared buffer
2930
- //GGML_ASSERT(dst_rows <= dst_rows_max);
2931
 
2932
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
2933
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
2934
  if ([device supportsFamily:MTLGPUFamilyApple7] &&
2935
  ne00 % 32 == 0 && ne00 >= 64 &&
2936
- //ne01 / ne02 >= 512 && // NOTE: this is based on Mixtral shapes, might need adjustments
2937
- dst_rows > dst_rows_min &&
2938
- dst_rows <= dst_rows_max) {
2939
 
2940
  // some Metal matrix data types require aligned pointers
2941
  // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -2946,62 +3266,169 @@ static void ggml_metal_encode_node(
2946
  default: break;
2947
  }
2948
 
2949
- id<MTLComputePipelineState> pipeline = nil;
 
 
2950
 
2951
- switch (src0->type) {
2952
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
2953
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
2954
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
2955
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
2956
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
2957
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
2958
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
2959
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
2960
- case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
2961
- case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
2962
- case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
2963
- case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
2964
- case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
2965
- case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
2966
- case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
2967
- case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
2968
- case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
2969
- case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
2970
- case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
2971
- case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
2972
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
2973
- case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
2974
- default: GGML_ABORT("MUL_MAT_ID not implemented");
2975
  }
2976
 
2977
- ggml_metal_kargs_mul_mm_id args = {
2978
- /*.nei0 =*/ ne20,
2979
- /*.nei1 =*/ ne21,
2980
- /*.nbi1 =*/ nb21,
2981
- /*.ne00 =*/ ne00,
2982
- /*.ne02 =*/ ne02,
2983
- /*.nb01 =*/ nb01,
2984
- /*.nb02 =*/ nb02,
2985
- /*.ne11 =*/ ne11,
2986
- /*.ne12 =*/ ne12,
2987
- /*.ne13 =*/ ne13,
2988
- /*.nb10 =*/ nb10,
2989
- /*.nb11 =*/ nb11,
2990
- /*.nb12 =*/ nb12,
2991
- /*.ne0 =*/ ne0,
2992
- /*.ne1 =*/ ne1,
2993
- };
2994
 
2995
- [encoder setComputePipelineState:pipeline];
2996
- [encoder setBytes:&args length:sizeof(args) atIndex:0];
2997
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2998
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2999
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3000
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
 
 
 
 
 
3001
 
3002
- [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
3003
 
3004
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
3005
  } else {
3006
  id<MTLComputePipelineState> pipeline = nil;
3007
 
@@ -3195,7 +3622,7 @@ static void ggml_metal_encode_node(
3195
  [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
3196
 
3197
  const int64_t _ne1 = 1;
3198
- const int64_t ne123 = dst_rows;
3199
 
3200
  if (smem > 0) {
3201
  [encoder setThreadgroupMemoryLength:smem atIndex:0];
@@ -4601,6 +5028,8 @@ static void ggml_metal_encode_node(
4601
  GGML_ABORT("fatal error");
4602
  }
4603
  }
 
 
4604
  }
4605
 
4606
  static enum ggml_status ggml_metal_graph_compute(
@@ -4654,25 +5083,25 @@ static enum ggml_status ggml_metal_graph_compute(
4654
  }
4655
 
4656
  // the main thread commits the first few commands immediately
4657
- // command_buffer[n_cb]
4658
  {
4659
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
4660
- ctx->command_buffers[n_cb] = command_buffer;
4661
 
4662
- [command_buffer enqueue];
4663
  ctx->encode_async(n_cb);
4664
  }
4665
 
4666
  // prepare the rest of the command buffers asynchronously
4667
- // command_buffer[0.. n_cb)
4668
  for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
4669
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
4670
- ctx->command_buffers[cb_idx] = command_buffer;
4671
 
4672
  // always enqueue the first two command buffers
4673
  // enqueue all of the command buffers if we don't need to abort
4674
  if (cb_idx < 2 || ctx->abort_callback == NULL) {
4675
- [command_buffer enqueue];
4676
  }
4677
  }
4678
 
@@ -4681,14 +5110,14 @@ static enum ggml_status ggml_metal_graph_compute(
4681
  // wait for completion and check status of each command buffer
4682
  // needed to detect if the device ran out-of-memory for example (#1881)
4683
  {
4684
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
4685
- [command_buffer waitUntilCompleted];
4686
 
4687
- MTLCommandBufferStatus status = [command_buffer status];
4688
  if (status != MTLCommandBufferStatusCompleted) {
4689
  GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
4690
  if (status == MTLCommandBufferStatusError) {
4691
- GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
4692
  }
4693
 
4694
  return GGML_STATUS_FAILED;
@@ -4696,20 +5125,20 @@ static enum ggml_status ggml_metal_graph_compute(
4696
  }
4697
 
4698
  for (int i = 0; i < n_cb; ++i) {
4699
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
4700
- [command_buffer waitUntilCompleted];
4701
 
4702
- MTLCommandBufferStatus status = [command_buffer status];
4703
  if (status != MTLCommandBufferStatusCompleted) {
4704
  GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
4705
  if (status == MTLCommandBufferStatusError) {
4706
- GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
4707
  }
4708
 
4709
  return GGML_STATUS_FAILED;
4710
  }
4711
 
4712
- id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
4713
  if (!next_buffer) {
4714
  continue;
4715
  }
@@ -5092,8 +5521,9 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
5092
 
5093
  const int n_nodes_per_cb = ctx->n_nodes_per_cb;
5094
 
5095
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
5096
- id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
 
5097
 
5098
  int node_start = 0;
5099
  int node_end = n_nodes_0;
@@ -5105,22 +5535,29 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
5105
 
5106
  const bool should_capture = ctx->capture_next_compute;
5107
 
 
 
 
5108
  for (int idx = node_start; idx < node_end; ++idx) {
5109
  if (should_capture) {
5110
  [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
5111
  }
5112
 
5113
- ggml_metal_encode_node(backend, idx, encoder);
5114
 
5115
  if (should_capture) {
5116
  [encoder popDebugGroup];
5117
  }
 
 
 
 
5118
  }
5119
 
5120
  [encoder endEncoding];
5121
 
5122
  if (cb_idx < 2 || ctx->abort_callback == NULL) {
5123
- [command_buffer commit];
5124
  }
5125
  });
5126
  }
 
44
  // note: assumes single GPU device - the default one
45
  // TODO: support multiple GPU devices
46
  static struct ggml_backend_metal_device_context {
47
+ id<MTLDevice> mtl_device;
48
+ int mtl_device_ref_count;
49
  id<MTLLibrary> mtl_library;
50
 
51
  bool has_simdgroup_reduction;
 
306
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
307
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
308
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
309
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
310
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
311
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
312
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
313
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
314
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16,
315
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16,
316
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
317
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
318
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
319
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
320
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
321
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
322
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16,
323
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16,
324
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16,
325
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16,
326
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16,
327
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16,
328
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16,
329
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16,
330
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16,
331
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16,
332
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
333
  GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
334
  GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
335
  GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
 
492
  GGML_METAL_KERNEL_TYPE_COUNT
493
  };
494
 
495
+ //
496
+ // ggml_metal_heap
497
+ //
498
+
499
+ struct ggml_metal_heap {
500
+ // number of times the heap was unused
501
+ int n_unused;
502
+
503
+ // total number of buffer allocations in this heap across all computes
504
+ int64_t n_alloc;
505
+
506
+ // current offset in the heap - we reset this after each node in order to reuse the memory
507
+ size_t offs;
508
+
509
+ // the currently allocated MTLBuffer objects in this heap
510
+ id<MTLHeap> obj;
511
+
512
+ NSMutableArray * bufs;
513
+ };
514
+
515
+ static struct ggml_metal_heap * ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
516
+ struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap));
517
+
518
+ MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
519
+ desc.storageMode = MTLStorageModePrivate;
520
+ desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
521
+ desc.type = MTLHeapTypePlacement;
522
+ desc.size = size;
523
+
524
+ heap->n_unused = 0;
525
+ heap->n_alloc = 0;
526
+
527
+ heap->obj = [device newHeapWithDescriptor:desc];
528
+ if (!heap->obj) {
529
+ GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
530
+
531
+ free(heap);
532
+
533
+ return false;
534
+ }
535
+
536
+ [desc release];
537
+
538
+ heap->bufs = [[NSMutableArray alloc] init];
539
+
540
+ return heap;
541
+ }
542
+
543
+ static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
544
+ heap->offs = 0;
545
+
546
+ // count how many graph computes the heap ended up being unused
547
+ if ([heap->bufs count] > 0) {
548
+ heap->n_unused = 0;
549
+ } else {
550
+ heap->n_unused++;
551
+ }
552
+
553
+ for (id<MTLBuffer> buf in heap->bufs) {
554
+ [buf release];
555
+ }
556
+ [heap->bufs removeAllObjects];
557
+
558
+ // tell the OS that it can reuse this memory if needed
559
+ // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
560
+ [heap->obj setPurgeableState:MTLPurgeableStateVolatile];
561
+ }
562
+
563
+ static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
564
+ if (heap == nil) {
565
+ return;
566
+ }
567
+
568
+ ggml_metal_heap_reset(heap);
569
+
570
+ [heap->obj release];
571
+ [heap->bufs release];
572
+
573
+ free(heap);
574
+ }
575
+
576
+ @interface ggml_metal_heap_ptr : NSObject
577
+
578
+ @property (nonatomic, assign) struct ggml_metal_heap * data;
579
+
580
+ @end
581
+
582
+ @implementation ggml_metal_heap_ptr
583
+ @end
584
+
585
+ //
586
+ // ggml_metal_mem_pool
587
+ //
588
+
589
+ struct ggml_metal_mem_pool {
590
+ id<MTLDevice> device;
591
+
592
+ int n_heaps; // total number of heaps ever created (including those that were removed)
593
+
594
+ NSMutableArray * heaps;
595
+ NSMutableArray * heaps_to_remove;
596
+ };
597
+
598
+ static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) {
599
+ struct ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct ggml_metal_mem_pool));
600
+
601
+ mem_pool->n_heaps = 0;
602
+
603
+ mem_pool->heaps = [[NSMutableArray alloc] init];
604
+ mem_pool->heaps_to_remove = [[NSMutableArray alloc] init];
605
+
606
+ return mem_pool;
607
+ }
608
+
609
+ static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) {
610
+ GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps);
611
+
612
+ size_t size_all = 0;
613
+ size_t size_cur = 0;
614
+
615
+ for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
616
+ GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data);
617
+ GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc);
618
+ GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused);
619
+ GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0);
620
+ GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]);
621
+
622
+ if ([ptr.data->bufs count] > 0) {
623
+ size_cur += [ptr.data->obj size];
624
+ }
625
+ size_all += [ptr.data->obj size];
626
+
627
+ ggml_metal_heap_free(ptr.data);
628
+ [ptr release];
629
+ }
630
+ [mem_pool->heaps release];
631
+ [mem_pool->heaps_to_remove release];
632
+
633
+ if (size_all > 0) {
634
+ GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
635
+ GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
636
+ }
637
+
638
+ free(mem_pool);
639
+ }
640
+
641
+ static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
642
+ for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) {
643
+ ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i];
644
+
645
+ struct ggml_metal_heap * heap = ptr.data;
646
+ ggml_metal_heap_reset(heap);
647
+
648
+ // if the heap hasn't been used for a while, remove it
649
+ if (heap->n_unused >= 128) {
650
+ [mem_pool->heaps_to_remove addObject:@(i)];
651
+ }
652
+ }
653
+
654
+ if (mem_pool->heaps_to_remove.count > 0) {
655
+ // remove in reverse order
656
+ for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) {
657
+ NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
658
+ ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
659
+
660
+ struct ggml_metal_heap * heap = ptr.data;
661
+ ggml_metal_heap_free(heap);
662
+
663
+ [mem_pool->heaps removeObjectAtIndex:index];
664
+ [ptr release];
665
+
666
+ if (i == 0) {
667
+ break;
668
+ }
669
+ }
670
+
671
+ [mem_pool->heaps_to_remove removeAllObjects];
672
+ }
673
+ }
674
+
675
+ static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
676
+ for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
677
+ ptr.data->offs = 0;
678
+ }
679
+ }
680
+
681
+ static id<MTLBuffer> ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) {
682
+ const size_t alignment = 256;
683
+
684
+ const size_t size_aligned = GGML_PAD(size, alignment);
685
+
686
+ // try one of the existing heaps
687
+ for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
688
+ struct ggml_metal_heap * heap = ptr.data;
689
+ if (heap->offs + size_aligned <= [heap->obj size]) {
690
+ // if this is the first buffer in the heap for the current command buffer, tell the OS that
691
+ // it cannot free the memory used by the heap
692
+ // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
693
+ if ([heap->bufs count] == 0) {
694
+ [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
695
+ }
696
+
697
+ id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
698
+ if (buf == nil) {
699
+ GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
700
+ return nil;
701
+ }
702
+
703
+ heap->n_alloc++;
704
+ heap->offs += size_aligned;
705
+
706
+ [heap->bufs addObject:buf];
707
+
708
+ return buf;
709
+ }
710
+ }
711
+
712
+ // create a new heap that can fit this buffer
713
+ ggml_metal_heap_ptr * heap_ptr = [ggml_metal_heap_ptr new];
714
+
715
+ struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned);
716
+ if (heap == NULL) {
717
+ GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
718
+ return NULL;
719
+ }
720
+
721
+ //GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);
722
+
723
+ heap_ptr.data = heap;
724
+ ggml_metal_heap_reset(heap);
725
+
726
+ [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
727
+ id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
728
+ if (buf == nil) {
729
+ GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
730
+ return NULL;
731
+ }
732
+
733
+ heap->n_alloc++;
734
+ heap->offs += size_aligned;
735
+
736
+ [heap->bufs addObject:buf];
737
+
738
+ [mem_pool->heaps addObject:heap_ptr];
739
+ mem_pool->n_heaps++;
740
+
741
+ return buf;
742
+ }
743
+
744
+ struct ggml_metal_command_buffer {
745
+ id<MTLCommandBuffer> obj;
746
+
747
+ // each command buffer has a memory pool from which it can allocate temporary buffers during the compute
748
+ struct ggml_metal_mem_pool * mem_pool;
749
+ };
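
To make the lifecycle of this pool concrete, here is a minimal usage sketch. It only calls the helpers defined above (ggml_metal_mem_pool_init/reset/clear/alloc/free) and assumes it lives in the same translation unit, since those helpers are static; size_needed and the single-node body are placeholders, not code from this commit.

#import <Metal/Metal.h>

// Illustrative driver for the per-command-buffer memory pool.
void example_pool_usage(id<MTLDevice> device, size_t size_needed) {
    struct ggml_metal_mem_pool * pool = ggml_metal_mem_pool_init();
    pool->device = device;

    // once per command buffer, before encoding its nodes:
    // releases last run's buffers and may evict heaps unused for many computes
    ggml_metal_mem_pool_reset(pool);

    // per node: rewind the heap offsets so the same memory is reused
    ggml_metal_mem_pool_clear(pool);

    // transient GPU-private scratch for this node (e.g. the gathered src1)
    id<MTLBuffer> tmp = ggml_metal_mem_pool_alloc(pool, size_needed);
    if (tmp != nil) {
        // ... bind `tmp` to the node's kernels via setBuffer:offset:atIndex: ...
    } // on nil, encoding for this command buffer is aborted

    // on backend teardown
    ggml_metal_mem_pool_free(pool);
}
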
750
+
751
  struct ggml_backend_metal_context {
752
+ id<MTLDevice> device;
753
  id<MTLCommandQueue> queue;
754
 
755
  dispatch_queue_t d_queue;
 
774
  void (^encode_async)(size_t ith);
775
 
776
  // n_cb command buffers + 1 used by the main thread
777
+ struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
778
 
779
  // abort ggml_metal_graph_compute if callback returns true
780
  ggml_abort_callback abort_callback;
 
964
  struct ggml_backend_metal_device_context * ctx_dev = dev->context;
965
 
966
  id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
967
+
968
  GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
969
 
970
+ ctx->device = device;
971
+ ctx->queue = [device newCommandQueue];
972
  if (ctx->queue == nil) {
973
  GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
974
  return NULL;
 
1029
  ctx->gf = nil;
1030
  ctx->encode_async = nil;
1031
  for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
1032
+ ctx->cmd_bufs[i].obj = nil;
1033
+
1034
+ ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
1035
+ ctx->cmd_bufs[i].mem_pool->device = device;
1036
  }
1037
 
1038
  #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
 
1249
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
1250
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
1251
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
1252
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
1253
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
1254
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
1255
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
1256
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
1257
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, mul_mm_id_q4_0_f16, has_simdgroup_mm);
1258
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, mul_mm_id_q4_1_f16, has_simdgroup_mm);
1259
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
1260
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
1261
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
1262
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
1263
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
1264
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
1265
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, mul_mm_id_q5_K_f16, has_simdgroup_mm);
1266
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, mul_mm_id_q6_K_f16, has_simdgroup_mm);
1267
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, mul_mm_id_iq2_xxs_f16, has_simdgroup_mm);
1268
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, mul_mm_id_iq2_xs_f16, has_simdgroup_mm);
1269
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, mul_mm_id_iq3_xxs_f16, has_simdgroup_mm);
1270
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, mul_mm_id_iq3_s_f16, has_simdgroup_mm);
1271
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, mul_mm_id_iq2_s_f16, has_simdgroup_mm);
1272
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, mul_mm_id_iq1_s_f16, has_simdgroup_mm);
1273
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm);
1274
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm);
1275
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
1276
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
1277
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
1278
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
 
1447
 
1448
  [ctx->queue release];
1449
 
1450
+ for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
1451
+ // ctx->cmd_bufs[i].obj is auto released
1452
+
1453
+ ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
1454
+ }
1455
+
1456
  dispatch_release(ctx->d_queue);
1457
 
1458
  free(ctx);
 
1758
  }
1759
  }
1760
 
1761
+ static bool ggml_metal_encode_node(
1762
  ggml_backend_t backend,
1763
  int idx,
1764
+ id<MTLComputeCommandEncoder> encoder,
1765
+ struct ggml_metal_mem_pool * mem_pool) {
1766
  struct ggml_backend_metal_context * ctx = backend->context;
1767
  struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
1768
 
 
1778
  struct ggml_tensor * dst = node;
1779
 
1780
  if (ggml_is_empty(dst)) {
1781
+ return true;
1782
  }
1783
 
1784
  switch (dst->op) {
 
1789
  case GGML_OP_PERMUTE:
1790
  {
1791
  // noop -> next node
1792
+ } return true;
1793
  default:
1794
  {
1795
  } break;
 
1800
  GGML_ABORT("unsupported op");
1801
  }
1802
 
1803
+ ggml_metal_mem_pool_clear(mem_pool);
1804
+
1805
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
1806
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
1807
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
 
2448
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
2449
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
2450
 
2451
+ // use this branch to test the ggml_metal_mem_pool functionality
2452
+ #if 0
2453
+ // cpy to tmp buffer in MTLHeap
2454
+
2455
+ id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
2456
+ if (!h_src0) {
2457
+ GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
2458
+ return false;
2459
+ }
2460
+
2461
+ offs_src0 = 0;
2462
+
2463
+ ggml_metal_kargs_cpy args_cpy = {
2464
  /*.ne00 =*/ ne00,
2465
  /*.ne01 =*/ ne01,
2466
  /*.ne02 =*/ ne02,
2467
+ /*.ne03 =*/ ne03,
2468
+ /*.nb00 =*/ nb00,
2469
+ /*.nb01 =*/ nb01,
2470
+ /*.nb02 =*/ nb02,
2471
+ /*.nb03 =*/ nb03,
2472
+ /*.ne0 =*/ ne00,
2473
+ /*.ne1 =*/ ne01,
2474
+ /*.ne2 =*/ ne02,
2475
+ /*.ne3 =*/ ne03,
2476
+ /*.nb0 =*/ nb00,
2477
+ /*.nb1 =*/ nb01,
2478
+ /*.nb2 =*/ nb02,
2479
+ /*.nb3 =*/ nb03,
2480
+ };
2481
+
2482
+ if (src0->type == GGML_TYPE_F16) {
2483
+ [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
2484
+ } else {
2485
+ [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
2486
+ }
2487
+ [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
2488
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2489
+ [encoder setBuffer:h_src0 offset:0 atIndex:2];
2490
+
2491
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
2492
+ int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
2493
+
2494
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
2495
+
2496
+ #else
2497
+ id<MTLBuffer> h_src0 = id_src0;
2498
+ #endif
2499
+ // softmax
2500
+
2501
+ ggml_metal_kargs_soft_max args = {
2502
+ /*.ne00 =*/ ne00,
2503
+ /*.ne01 =*/ ne01,
2504
+ /*.ne02 =*/ ne02,
2505
+ /*.scale =*/ scale,
2506
+ /*.max_bias =*/ max_bias,
2507
+ /*.m0 =*/ m0,
2508
+ /*.m1 =*/ m1,
2509
  /*.n_head_log2 =*/ n_head_log2,
2510
  };
2511
 
2512
  [encoder setComputePipelineState:pipeline];
2513
+ [encoder setBuffer:h_src0 offset:offs_src0 atIndex:0];
2514
  if (id_src1) {
2515
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2516
  } else {
2517
+ [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
2518
  }
2519
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2520
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
2521
 
2522
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2523
 
 
3008
  [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3009
 
3010
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
3011
+ [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
3012
  } else {
3013
  id<MTLComputePipelineState> pipeline = nil;
3014
 
 
3228
  } break;
3229
  case GGML_OP_MUL_MAT_ID:
3230
  {
 
 
3231
  // src2 = ids
3232
  const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
3233
 
 
3241
  GGML_ASSERT(ne03 == 1);
3242
  GGML_ASSERT(ne13 == 1);
3243
 
3244
+ const uint32_t r2 = 1;
3245
+ const uint32_t r3 = 1;
3246
+
3247
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
3248
  // to the matrix-vector kernel
3249
  // ne20 = n_used_experts
3250
+ // ne21 = n_rows (batch size)
3251
+ const int ne21_mm_id_min = 32;
 
 
 
 
 
3252
 
3253
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
3254
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
3255
  if ([device supportsFamily:MTLGPUFamilyApple7] &&
3256
  ne00 % 32 == 0 && ne00 >= 64 &&
3257
+ (ne21 >= ne21_mm_id_min)) {
3258
+ GGML_ASSERT(ne00 % 4 == 0);
 
3259
 
3260
  // some Metal matrix data types require aligned pointers
3261
  // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
 
3266
  default: break;
3267
  }
3268
 
3269
+ const int64_t neh10 = ne10; // n_embd
3270
+ const int64_t neh11 = ne21; // n_tokens
3271
+ const int64_t neh12 = ne02; // n_expert
3272
 
3273
+ const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
3274
+ const uint64_t nbh11 = nbh10*neh10;
3275
+ const uint64_t nbh12 = nbh11*neh11;
3276
+ const uint64_t nbh13 = nbh12*neh12;
3277
+
3278
+ const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
3279
+ id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
3280
+ if (!h_src1) {
3281
+ GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
3282
+ return false;
 
3283
  }
3284
 
3285
+ const int64_t neh0 = ne0;
3286
+ const int64_t neh1 = ne21;
3287
+ const int64_t neh2 = ne02;
 
3288
 
3289
+ const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
3290
+ const uint64_t nbh1 = nbh0*neh0;
3291
+ const uint64_t nbh2 = nbh1*neh1;
3292
+ //const uint64_t nbh3 = nbh2*neh2;
3293
+
3294
+ const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
3295
+ id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
3296
+ if (!h_dst) {
3297
+ GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
3298
+ return false;
3299
+ }
3300
 
3301
+ // tokens per expert
3302
+ const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
3303
+ id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
3304
+ if (!h_tpe) {
3305
+ GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
3306
+ return false;
3307
+ }
3308
+
3309
+ // id map
3310
+ // [n_expert_used, n_tokens]
3311
+ const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
3312
+ id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
3313
+ if (!h_ids) {
3314
+ GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
3315
+ return false;
3316
+ }
3317
+
3318
+ {
3319
+ const int nth = MIN(1024, ne10/4);
3320
+
3321
+ ggml_metal_kargs_mul_mm_id_map0 args = {
3322
+ ne10,
3323
+ ne11, // n_expert_used (bcast)
3324
+ nb11,
3325
+ nb12,
3326
+ neh11, // n_tokens
3327
+ nbh11,
3328
+ ne20, // n_expert_used
3329
+ nb21,
3330
+ };
3331
+
3332
+ id<MTLComputePipelineState> pipeline = nil;
3333
+
3334
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
3335
+
3336
+ [encoder setComputePipelineState:pipeline];
3337
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3338
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3339
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
3340
+ [encoder setBuffer: h_src1 offset:0 atIndex:3];
3341
+ [encoder setBuffer: h_tpe offset:0 atIndex:4];
3342
+ [encoder setBuffer: h_ids offset:0 atIndex:5];
3343
+
3344
+ [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3345
+ }
3346
+
3347
+ {
3348
+ id<MTLComputePipelineState> pipeline = nil;
3349
+
3350
+ switch (src0->type) {
3351
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16 ].pipeline; break;
3352
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16 ].pipeline; break;
3353
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16 ].pipeline; break;
3354
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16 ].pipeline; break;
3355
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16 ].pipeline; break;
3356
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
3357
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
3358
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
3359
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
3360
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
3361
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
3362
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16 ].pipeline; break;
3363
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16 ].pipeline; break;
3364
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16].pipeline; break;
3365
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16 ].pipeline; break;
3366
+ case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16].pipeline; break;
3367
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16 ].pipeline; break;
3368
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16 ].pipeline; break;
3369
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16 ].pipeline; break;
3370
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break;
3371
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break;
3372
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break;
3373
+ default: GGML_ABORT("MUL_MAT_ID not implemented");
3374
+ }
3375
+
3376
+ ggml_metal_kargs_mul_mm_id args = {
3377
+ /*.ne00 =*/ ne00,
3378
+ /*.ne02 =*/ ne02,
3379
+ /*.nb01 =*/ nb01,
3380
+ /*.nb02 =*/ nb02,
3381
+ /*.nb03 =*/ nb03,
3382
+ /*.neh12 =*/ neh12,
3383
+ /*.nbh10 =*/ nbh10,
3384
+ /*.nbh11 =*/ nbh11,
3385
+ /*.nbh12 =*/ nbh12,
3386
+ /*.nbh13 =*/ nbh13,
3387
+ /*.neh0 =*/ neh0,
3388
+ /*.neh1 =*/ neh1,
3389
+ /*.r2 =*/ r2,
3390
+ /*.r3 =*/ r3,
3391
+ };
3392
+
3393
+ [encoder setComputePipelineState:pipeline];
3394
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3395
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3396
+ [encoder setBuffer: h_src1 offset:0 atIndex:2];
3397
+ [encoder setBuffer: h_tpe offset:0 atIndex:3];
3398
+ [encoder setBuffer: h_dst offset:0 atIndex:4];
3399
+
3400
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
3401
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
3402
+ }
3403
+
3404
+ {
3405
+ GGML_ASSERT(ne0 % 4 == 0);
3406
+
3407
+ const int nth = MIN(1024, ne0/4);
3408
 
3409
+ ggml_metal_kargs_mul_mm_id_map1 args = {
3410
+ ne20, // n_expert_used
3411
+ neh0,
3412
+ neh1,
3413
+ nbh1,
3414
+ nbh2,
3415
+ ne0,
3416
+ nb1,
3417
+ nb2,
3418
+ };
3419
+
3420
+ id<MTLComputePipelineState> pipeline = nil;
3421
+
3422
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
3423
+
3424
+ [encoder setComputePipelineState:pipeline];
3425
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3426
+ [encoder setBuffer: h_dst offset:0 atIndex:1];
3427
+ [encoder setBuffer: h_ids offset:0 atIndex:2];
3428
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3429
+
3430
+ [encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3431
+ }
3432
  } else {
3433
  id<MTLComputePipelineState> pipeline = nil;
3434
 
 
3622
  [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
3623
 
3624
  const int64_t _ne1 = 1;
3625
+ const int64_t ne123 = ne20*ne21;
3626
 
3627
  if (smem > 0) {
3628
  [encoder setThreadgroupMemoryLength:smem atIndex:0];
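
Stepping back from the MUL_MAT_ID branch above: its four pool allocations (h_src1, h_dst, h_tpe, h_ids) scale with batch size and expert count, so the transient footprint can be sizable for large batches. The stand-alone helper below mirrors the s_src1/s_dst/s_tpe/s_ids formulas (ignoring the 256-byte pool alignment); the shapes in main are purely illustrative Mixtral-like assumptions, not measurements from this commit.

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

// h_src1 is F16, h_dst is F32, h_tpe and h_ids are I32; result in bytes.
size_t mul_mat_id_scratch_bytes(int64_t n_embd, int64_t n_rows_out,
                                int64_t n_expert, int64_t n_expert_used,
                                int64_t n_tokens) {
    const size_t s_src1 = sizeof(uint16_t)*n_embd    *n_tokens*n_expert; // gathered activations
    const size_t s_dst  = sizeof(float)   *n_rows_out*n_tokens*n_expert; // per-expert results
    const size_t s_tpe  = sizeof(int32_t) *n_expert;                     // tokens per expert
    const size_t s_ids  = sizeof(int32_t) *n_expert_used*n_tokens;       // scatter map
    return s_src1 + s_dst + s_tpe + s_ids;
}

int main(void) {
    // hypothetical FFN shapes: n_embd=4096, n_ff=14336, 8 experts, 2 used, batch of 512 tokens
    const size_t bytes = mul_mat_id_scratch_bytes(4096, 14336, 8, 2, 512);
    printf("transient pool usage: %.1f MiB\n", bytes/(1024.0*1024.0));
    return 0;
}
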
 
5028
  GGML_ABORT("fatal error");
5029
  }
5030
  }
5031
+
5032
+ return true;
5033
  }
5034
 
5035
  static enum ggml_status ggml_metal_graph_compute(
 
5083
  }
5084
 
5085
  // the main thread commits the first few commands immediately
5086
+ // cmd_buf[n_cb]
5087
  {
5088
+ id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
5089
+ ctx->cmd_bufs[n_cb].obj = cmd_buf;
5090
 
5091
+ [cmd_buf enqueue];
5092
  ctx->encode_async(n_cb);
5093
  }
5094
 
5095
  // prepare the rest of the command buffers asynchronously
5096
+ // cmd_buf[0.. n_cb)
5097
  for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
5098
+ id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
5099
+ ctx->cmd_bufs[cb_idx].obj = cmd_buf;
5100
 
5101
  // always enqueue the first two command buffers
5102
  // enqueue all of the command buffers if we don't need to abort
5103
  if (cb_idx < 2 || ctx->abort_callback == NULL) {
5104
+ [cmd_buf enqueue];
5105
  }
5106
  }
5107
 
 
5110
  // wait for completion and check status of each command buffer
5111
  // needed to detect if the device ran out-of-memory for example (#1881)
5112
  {
5113
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
5114
+ [cmd_buf waitUntilCompleted];
5115
 
5116
+ MTLCommandBufferStatus status = [cmd_buf status];
5117
  if (status != MTLCommandBufferStatusCompleted) {
5118
  GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
5119
  if (status == MTLCommandBufferStatusError) {
5120
+ GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
5121
  }
5122
 
5123
  return GGML_STATUS_FAILED;
 
5125
  }
5126
 
5127
  for (int i = 0; i < n_cb; ++i) {
5128
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
5129
+ [cmd_buf waitUntilCompleted];
5130
 
5131
+ MTLCommandBufferStatus status = [cmd_buf status];
5132
  if (status != MTLCommandBufferStatusCompleted) {
5133
  GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
5134
  if (status == MTLCommandBufferStatusError) {
5135
+ GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
5136
  }
5137
 
5138
  return GGML_STATUS_FAILED;
5139
  }
5140
 
5141
+ id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
5142
  if (!next_buffer) {
5143
  continue;
5144
  }
 
5521
 
5522
  const int n_nodes_per_cb = ctx->n_nodes_per_cb;
5523
 
5524
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
5525
+
5526
+ id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
5527
 
5528
  int node_start = 0;
5529
  int node_end = n_nodes_0;
 
5535
 
5536
  const bool should_capture = ctx->capture_next_compute;
5537
 
5538
+ struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
5539
+ ggml_metal_mem_pool_reset(mem_pool);
5540
+
5541
  for (int idx = node_start; idx < node_end; ++idx) {
5542
  if (should_capture) {
5543
  [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
5544
  }
5545
 
5546
+ const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
5547
 
5548
  if (should_capture) {
5549
  [encoder popDebugGroup];
5550
  }
5551
+
5552
+ if (!res) {
5553
+ break;
5554
+ }
5555
  }
5556
 
5557
  [encoder endEncoding];
5558
 
5559
  if (cb_idx < 2 || ctx->abort_callback == NULL) {
5560
+ [cmd_buf commit];
5561
  }
5562
  });
5563
  }
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -6336,127 +6336,219 @@ kernel void kernel_mul_mm(
6336
  }
6337
  }
6338
 
6339
- // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
6340
- // TODO: this kernel needs to be reimplemented from scratch for better performance
6341
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
6342
- void kernel_mul_mm_id_impl(
6343
- int32_t ne00,
6344
- int32_t ne02,
6345
- uint64_t nb01,
6346
- uint64_t nb02,
6347
- int32_t ne11,
6348
- int32_t ne12,
6349
- uint64_t nb10,
6350
- uint64_t nb11,
6351
- uint64_t nb12,
6352
- int32_t ne0,
6353
- int32_t ne1,
6354
- int64_t ne0ne1,
6355
- device const char * src0,
6356
- device const char * src1,
6357
- threadgroup ushort2 * rowids,
6358
- device char * dst,
6359
- threadgroup char * shmem,
 
6360
  uint3 tgpig[[threadgroup_position_in_grid]],
6361
  ushort tiitg[[thread_index_in_threadgroup]],
6362
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
6363
 
6364
- threadgroup half * sa = (threadgroup half *)(shmem);
6365
- threadgroup float * sb = (threadgroup float *)(shmem + 4096);
6366
 
6367
  const int r0 = tgpig.y;
6368
  const int r1 = tgpig.x;
 
 
 
 
 
6369
 
6370
- if (r1*BLOCK_SIZE_N >= ne1) return;
 
 
6371
 
6372
  // if this block is of 64x32 shape or smaller
6373
- short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
6374
- short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
6375
 
6376
  // a thread shouldn't load data outside of the matrix
6377
- short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
6378
- short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
6379
 
6380
- simdgroup_half8x8 ma[4];
6381
- simdgroup_float8x8 mb[2];
6382
  simdgroup_float8x8 mc[8];
6383
- for (int i = 0; i < 8; i++){
 
6384
  mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
6385
  }
 
6386
  short il = (tiitg % THREAD_PER_ROW);
6387
 
6388
- ushort offset1 = il/nl;
 
6389
 
6390
- threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
 
6391
 
6392
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
6393
- device const float * y = (device const float *)(src1
6394
- + nb12 * id[1]
6395
- + nb11 * (id[0] % ne11)
6396
- + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
6397
 
6398
- for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
 
 
 
 
 
 
6399
  // load data and store to threadgroup memory
6400
- half4x4 temp_a;
6401
  dequantize_func(x, il, temp_a);
 
6402
  threadgroup_barrier(mem_flags::mem_threadgroup);
6403
 
6404
- for (int i = 0; i < 16; i++) {
6405
- *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
6406
- + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
6407
- + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
 
6408
  }
6409
 
6410
- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
6411
 
6412
  il = (il + 2 < nl) ? il + 2 : il % 2;
6413
- x = (il < 2) ? x + (2+nl-1)/nl : x;
6414
  y += BLOCK_SIZE_K;
6415
 
6416
  threadgroup_barrier(mem_flags::mem_threadgroup);
6417
 
6418
  // load matrices from threadgroup memory and conduct outer products
6419
- threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
6420
- threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
6421
 
6422
- #pragma unroll(BLOCK_SIZE_K/8)
6423
- for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
6424
  #pragma unroll(4)
6425
- for (int i = 0; i < 4; i++) {
6426
  simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
6427
  }
 
6428
  simdgroup_barrier(mem_flags::mem_none);
 
6429
  #pragma unroll(2)
6430
- for (int i = 0; i < 2; i++) {
6431
  simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
6432
  }
6433
 
6434
- lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
6435
- lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
6436
-
6437
  #pragma unroll(8)
6438
- for (int i = 0; i < 8; i++){
6439
  simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
6440
  }
 
 
 
6441
  }
6442
  }
6443
 
6444
- {
6445
  threadgroup_barrier(mem_flags::mem_threadgroup);
6446
  threadgroup float * temp_str = ((threadgroup float *) shmem) \
6447
- + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
6448
- for (int i = 0; i < 8; i++) {
6449
- simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
6450
  }
6451
 
6452
  threadgroup_barrier(mem_flags::mem_threadgroup);
6453
 
6454
  if (sgitg == 0) {
6455
  for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
6456
- threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
6457
- int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;
6458
-
6459
- device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff;
6460
  device float4 * D4 = (device float4 *) D;
6461
 
6462
  threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
@@ -6476,66 +6568,6 @@ void kernel_mul_mm_id_impl(
6476
  }
6477
  }
6478
 
6479
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
6480
- kernel void kernel_mul_mm_id(
6481
- constant ggml_metal_kargs_mul_mm_id & args,
6482
- device const char * src0s,
6483
- device const char * src1,
6484
- device char * dst,
6485
- device const char * ids,
6486
- threadgroup char * shmem [[threadgroup(0)]],
6487
- uint3 tgpig[[threadgroup_position_in_grid]],
6488
- ushort tiitg[[thread_index_in_threadgroup]],
6489
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
6490
-
6491
- const int32_t i02 = tgpig.z;
6492
-
6493
- tgpig.z = 0;
6494
-
6495
- device const char * src0 = src0s + i02*args.nb02;
6496
-
6497
- // row indices
6498
- threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
6499
-
6500
- // TODO: parallelize this loop
6501
- int32_t _ne1 = 0;
6502
- for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
6503
- for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
6504
- int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
6505
- if (id == i02) {
6506
- if (tiitg == 0) {
6507
- rowids[_ne1] = ushort2(ii0, ii1);
6508
- }
6509
- _ne1++;
6510
- }
6511
- }
6512
- }
6513
-
6514
- threadgroup_barrier(mem_flags::mem_threadgroup);
6515
-
6516
- kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
6517
- args.ne00,
6518
- args.ne02,
6519
- args.nb01,
6520
- args.nb02,
6521
- args.ne11,
6522
- args.ne12,
6523
- args.nb10,
6524
- args.nb11,
6525
- args.nb12,
6526
- args.ne0,
6527
- _ne1,
6528
- (int64_t)args.ne0*args.ne1,
6529
- src0,
6530
- src1,
6531
- rowids,
6532
- dst,
6533
- shmem,
6534
- tgpig,
6535
- tiitg,
6536
- sgitg);
6537
- }
6538
-
6539
  #define QK_NL 16
6540
 
6541
  //
@@ -6576,63 +6608,64 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get
6576
  // matrix-matrix multiplication
6577
  //
6578
 
6579
- typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
6580
 
6581
- template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
6582
- template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
6583
  #if defined(GGML_METAL_USE_BF16)
6584
- template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
6585
  #endif
6586
- template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
6587
- template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
6588
- template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
6589
- template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
6590
- template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
6591
- template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
6592
- template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
6593
- template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
6594
- template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
6595
- template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
6596
- template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
6597
- template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
6598
- template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6599
- template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
6600
- template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
6601
- template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
6602
- template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
6603
- template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
6604
- template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6605
 
6606
  //
6607
  // indirect matrix-matrix multiplication
6608
  //
6609
 
6610
- typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
6611
 
6612
- template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
6613
- template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
6614
  #if defined(GGML_METAL_USE_BF16)
6615
- template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>;
6616
  #endif
6617
- template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
6618
- template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
6619
- template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
6620
- template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
6621
- template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
6622
- template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
6623
- template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
6624
- template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
6625
- template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
6626
- template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
6627
- template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
6628
- template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
6629
- template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6630
- template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
6631
- template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
6632
- template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
6633
- template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
6634
- template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
6635
- template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6636
 
6637
  //
6638
  // matrix-vector multiplication
 
6336
  }
6337
  }
6338
 
6339
+ template<typename T4>
6340
+ kernel void kernel_mul_mm_id_map0(
6341
+ constant ggml_metal_kargs_mul_mm_id_map0 & args,
6342
+ device const char * src1,
6343
+ device const char * src2,
6344
+ device char * hsrc1,
6345
+ device char * htpe,
6346
+ device char * hids,
6347
+ uint3 tgpig[[threadgroup_position_in_grid]],
6348
+ ushort3 tpitg[[thread_position_in_threadgroup]],
6349
+ ushort3 ntg[[threads_per_threadgroup]]) {
6350
+ const int ide = tgpig[0]; // expert id
6351
+
6352
+ int n_all = 0;
6353
+
6354
+ device int32_t * ids_i32 = (device int32_t *) (hids);
6355
+
6356
+ for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
6357
+ device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
6358
+
6359
+ for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
6360
+ if (src2_i32[i20] != ide) {
6361
+ continue;
6362
+ }
6363
+
6364
+ device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
6365
+ device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
6366
+
6367
+ for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
6368
+ hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
6369
+ }
6370
+
6371
+ if (tpitg.x == 0) {
6372
+ ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
6373
+ }
6374
+
6375
+ ++n_all;
6376
+ }
6377
+ }
6378
+
6379
+ if (tpitg.x == 0) {
6380
+ device int32_t * tpe_i32 = (device int32_t *) (htpe);
6381
+ tpe_i32[ide] = n_all;
6382
+ }
6383
+ }
6384
+
6385
+ typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
6386
+
6387
+ template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
6388
+
6389
+ template<typename T>
6390
+ kernel void kernel_mul_mm_id_map1(
6391
+ constant ggml_metal_kargs_mul_mm_id_map1 & args,
6392
+ device const char * hdst,
6393
+ device const char * hids,
6394
+ device char * dst,
6395
+ uint3 tgpig[[threadgroup_position_in_grid]],
6396
+ ushort3 tpitg[[thread_position_in_threadgroup]],
6397
+ ushort3 ntg[[threads_per_threadgroup]]) {
6398
+ const int i20 = tgpig[0]; // used expert
6399
+ const int i21 = tgpig[1]; // token
6400
+
6401
+ device const int32_t * ids_i32 = (device const int32_t *) (hids);
6402
+ device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2);
6403
+
6404
+ const int id = ids_i32[i21*args.ne20 + i20];
6405
+
6406
+ const int ide = id / args.neh1;
6407
+ const int idt = id % args.neh1;
6408
+
6409
+ device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
6410
+
6411
+ for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
6412
+ dst_f32x4[i0] = hdst_f32x4[i0];
6413
+ }
6414
+ }
6415
+
6416
+ typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t;
6417
+
6418
+ template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>;
6419
+
6420
+ template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
6421
+ kernel void kernel_mul_mm_id(
6422
+ constant ggml_metal_kargs_mul_mm_id & args,
6423
+ device const char * src0,
6424
+ device const char * src1,
6425
+ device const char * tpe,
6426
+ device char * dst,
6427
+ threadgroup char * shmem [[threadgroup(0)]],
6428
  uint3 tgpig[[threadgroup_position_in_grid]],
6429
  ushort tiitg[[thread_index_in_threadgroup]],
6430
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
6431
 
6432
+ threadgroup T * sa = (threadgroup T *)(shmem);
6433
+ threadgroup half * sb = (threadgroup half *)(shmem + 4096);
6434
 
6435
  const int r0 = tgpig.y;
6436
  const int r1 = tgpig.x;
6437
+ const int im = tgpig.z;
6438
+
6439
+ device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
6440
+
6441
+ const int neh1 = tpe_i32[im];
6442
 
6443
+ if (r1*BLOCK_SIZE_N >= neh1) {
6444
+ return;
6445
+ }
6446
 
6447
  // if this block is of 64x32 shape or smaller
6448
+ const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
6449
+ const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
6450
 
6451
  // a thread shouldn't load data outside of the matrix
6452
+ const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
6453
+ const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
6454
 
6455
+ simdgroup_T8x8 ma[4];
6456
+ simdgroup_half8x8 mb[2];
6457
  simdgroup_float8x8 mc[8];
6458
+
6459
+ for (short i = 0; i < 8; i++){
6460
  mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
6461
  }
6462
+
6463
  short il = (tiitg % THREAD_PER_ROW);
6464
 
6465
+ const int i12 = im%args.neh12;
6466
+ const int i13 = im/args.neh12;
6467
 
6468
+ const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
6469
+ const short offset1 = il/nl;
6470
 
6471
+ device const block_q * x = (device const block_q *)(src0
6472
+ + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
6473
 
6474
+ device const half * y = (device const half *)(src1
6475
+ + args.nbh13*i13
6476
+ + args.nbh12*i12
6477
+ + args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
6478
+ + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
6479
+
6480
+ for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
6481
  // load data and store to threadgroup memory
6482
+ T4x4 temp_a;
6483
  dequantize_func(x, il, temp_a);
6484
+
6485
  threadgroup_barrier(mem_flags::mem_threadgroup);
6486
 
6487
+ #pragma unroll(16)
6488
+ for (short i = 0; i < 16; i++) {
6489
+ *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
6490
+ + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
6491
+ + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
6492
  }
6493
 
6494
+ *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
6495
 
6496
  il = (il + 2 < nl) ? il + 2 : il % 2;
6497
+ x = (il < 2) ? x + (2 + nl - 1)/nl : x;
6498
  y += BLOCK_SIZE_K;
6499
 
6500
  threadgroup_barrier(mem_flags::mem_threadgroup);
6501
 
6502
  // load matrices from threadgroup memory and conduct outer products
6503
+ threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
6504
+ threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
6505
 
6506
+ #pragma unroll(4)
6507
+ for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
6508
  #pragma unroll(4)
6509
+ for (short i = 0; i < 4; i++) {
6510
  simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
6511
  }
6512
+
6513
  simdgroup_barrier(mem_flags::mem_none);
6514
+
6515
  #pragma unroll(2)
6516
+ for (short i = 0; i < 2; i++) {
6517
  simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
6518
  }
6519
 
6520
  #pragma unroll(8)
6521
+ for (short i = 0; i < 8; i++){
6522
  simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
6523
  }
6524
+
6525
+ lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
6526
+ lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
6527
  }
6528
  }
6529
 
6530
+ if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
6531
+ device float * C = (device float *) dst +
6532
+ (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
6533
+ (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
6534
+
6535
+ for (short i = 0; i < 8; i++) {
6536
+ simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
6537
+ }
6538
+ } else {
6539
+ // block is smaller than 64x32, we should avoid writing data outside of the matrix
6540
  threadgroup_barrier(mem_flags::mem_threadgroup);
6541
  threadgroup float * temp_str = ((threadgroup float *) shmem) \
6542
+ + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
6543
+ for (short i = 0; i < 8; i++) {
6544
+ simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
6545
  }
6546
 
6547
  threadgroup_barrier(mem_flags::mem_threadgroup);
6548
 
6549
  if (sgitg == 0) {
6550
  for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
6551
+ device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
6552
  device float4 * D4 = (device float4 *) D;
6553
 
6554
  threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
6568
  }
6569
  }
6570
 
6571
  #define QK_NL 16
6572
 
6573
  //
 
6608
  // matrix-matrix multiplication
6609
  //
6610
 
6611
+ typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_t;
6612
 
6613
+ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
6614
+ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
6615
  #if defined(GGML_METAL_USE_BF16)
6616
+ template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
6617
  #endif
6618
+ template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
6619
+ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
6620
+ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
6621
+ template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
6622
+ template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
6623
+ template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
6624
+ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
6625
+ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
6626
+ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
6627
+ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
6628
+ template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
6629
+ template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
6630
+ template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6631
+ template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
6632
+ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
6633
+ template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
6634
+ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
6635
+ template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
6636
+ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6637
 
6638
  //
6639
  // indirect matrix-matrix multiplication
6640
  //
6641
 
6642
+ typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_id;
6643
 
6644
+ template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
6645
+ template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
6646
  #if defined(GGML_METAL_USE_BF16)
6647
+ template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
6648
  #endif
6649
+ template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
6650
+ template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
6651
+ template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
6652
+ template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
6653
+ template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
6654
+ template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
6655
+ template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
6656
+ template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
6657
+ template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
6658
+ template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
6659
+ template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
6660
+ template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
6661
+ template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6662
+ template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
6663
+ template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
6664
+ template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
6665
+ template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
6666
+ template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
6667
+ template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6668
+
6669
 
6670
  //
6671
  // matrix-vector multiplication
ggml/src/ggml.c CHANGED
@@ -2732,11 +2732,11 @@ void ggml_mul_mat_set_prec(
2732
  c = ggml_mul_mat_id(ctx, as, b, ids);
2733
 
2734
  as -> [cols, rows, n_expert]
2735
- ids -> [n_experts_used, n_tokens] (i32)
2736
  b -> [cols, n_expert_used, n_tokens]
2737
  c -> [rows, n_expert_used, n_tokens]
2738
 
2739
- in b, n_experts_used can be broadcasted to match the n_expert_used of ids
2740
 
2741
  c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
2742
  */
 
2732
  c = ggml_mul_mat_id(ctx, as, b, ids);
2733
 
2734
  as -> [cols, rows, n_expert]
2735
  b -> [cols, n_expert_used, n_tokens]
2736
+ ids -> [n_expert_used, n_tokens] (i32)
2737
  c -> [rows, n_expert_used, n_tokens]
2738
 
2739
+ in b, n_expert_used can be broadcasted to match the n_expert_used of ids
2740
 
2741
  c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
2742
  */
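
The shape comment above pins the operation down completely; a naive reference of those semantics, assuming row-major float buffers and no broadcasting of b (i.e. b carries one row per used expert), could read as follows. All names here are illustrative, not ggml internals:

    // Naive reference for the ggml_mul_mat_id semantics (illustrative only).
    #include <cstdint>
    #include <vector>

    void mul_mat_id_ref(
            int cols, int rows, int n_expert, int n_expert_used, int n_tokens,
            const std::vector<float>   & as,   // [n_expert][rows][cols]
            const std::vector<float>   & b,    // [n_tokens][n_expert_used][cols]
            const std::vector<int32_t> & ids,  // [n_tokens][n_expert_used]
            std::vector<float>         & c) {  // [n_tokens][n_expert_used][rows]
        for (int t = 0; t < n_tokens; t++) {
            for (int e = 0; e < n_expert_used; e++) {
                const int i = ids[t*n_expert_used + e]; // expert chosen for this slot
                for (int m = 0; m < rows; m++) {
                    float sum = 0.0f;
                    for (int k = 0; k < cols; k++) {
                        sum += as[(i*rows + m)*cols + k] * b[(t*n_expert_used + e)*cols + k];
                    }
                    c[(t*n_expert_used + e)*rows + m] = sum;
                }
            }
        }
    }

The Metal changes in this commit compute the same result, but route it through the map0 / mul_mm_id / map1 kernels so that, for large batches, each expert's rows are packed contiguously and multiplied as regular matrix blocks.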