ggerganov committed
Commit 6c8e7ec · 1 Parent(s): 4532dc6

metal : copy kernels for quant to F32/F16 conversions (llama/12017)


metal: use dequantize_q templates

---------

Co-authored-by: Georgi Gerganov <[email protected]>
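The change teaches the Metal backend to run GGML_OP_CPY/GGML_OP_CONT with a quantized source (Q4_0, Q4_1, Q5_0, Q5_1, Q8_0) and an F32 or F16 destination, routing through a single templated kernel that reuses the existing dequantize_q templates rather than adding a hand-written kernel per type. As a rough CPU-side sketch of what one such copy does per Q4_0 block (assuming ggml's usual Q4_0 layout of one per-block scale plus 32 packed 4-bit quants; the names and the float scale parameter here are illustrative, not ggml API):

#include <stdint.h>
#include <stdio.h>

#define QK4_0 32

// Illustrative stand-in for ggml's block_q4_0: 32 packed 4-bit quants,
// two per byte. The real struct also stores the scale `d` as fp16; it is
// passed in as a float below to keep the sketch short.
typedef struct {
    uint8_t qs[QK4_0/2];
} block_q4_0_sketch;

// Expand one block: low nibble -> out[j], high nibble -> out[j + 16],
// each recentered by -8 and scaled by d (matches ggml's Q4_0 decoding).
static void dequantize_q4_0_sketch(const block_q4_0_sketch *b, float d, float *out) {
    for (int j = 0; j < QK4_0/2; ++j) {
        out[j]           = ((b->qs[j] & 0x0F) - 8) * d;
        out[j + QK4_0/2] = ((b->qs[j] >>   4) - 8) * d;
    }
}

int main(void) {
    block_q4_0_sketch b;
    for (int j = 0; j < QK4_0/2; ++j) b.qs[j] = (uint8_t)(j | (j << 4));
    float out[QK4_0];
    dequantize_q4_0_sketch(&b, 0.5f, out);
    printf("out[0]=%.2f out[16]=%.2f\n", out[0], out[16]); // -4.00 -4.00
    return 0;
}

The Metal kernels below do the same expansion on the GPU, but in 4x4 chunks (float4x4 or half4x4) via the dequantize_q4_0/dequantize_q4_1/... templates already used by the matmul kernels.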

ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -407,6 +407,16 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
     GGML_METAL_KERNEL_TYPE_SQRT,
@@ -1012,6 +1022,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,   cpy_f32_q5_0,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,   cpy_f32_q5_1,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,   cpy_q4_0_f32,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,   cpy_q4_0_f16,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,   cpy_q4_1_f32,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,   cpy_q4_1_f16,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,   cpy_q5_0_f32,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,   cpy_q5_0_f16,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,   cpy_q5_1_f32,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,   cpy_q5_1_f16,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,   cpy_q8_0_f32,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,   cpy_q8_0_f16,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,         concat,         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,            sqr,            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,           sqrt,           true);
@@ -1287,6 +1307,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                         default:
                             return false;
                     }
+                case GGML_TYPE_Q4_0:
+                case GGML_TYPE_Q4_1:
+                case GGML_TYPE_Q5_0:
+                case GGML_TYPE_Q5_1:
+                case GGML_TYPE_Q8_0:
+                    switch (op->type) {
+                        case GGML_TYPE_F32:
+                        case GGML_TYPE_F16:
+                            return true;
+                        default:
+                            return false;
+                    }
                 default:
                     return false;
             };
@@ -3899,10 +3931,6 @@ static void ggml_metal_encode_node(
         case GGML_OP_CPY:
         case GGML_OP_CONT:
             {
-                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-
-                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
-
                 id<MTLComputePipelineState> pipeline = nil;

                 switch (src0t) {
@@ -3936,7 +3964,47 @@ static void ggml_metal_encode_node(
                         switch (dstt) {
                             case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline;  break;
                             case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
-                            default: GGML_ASSERT(false && "not implemented");
+                            default: GGML_ABORT("not implemented");
+                        };
+                    } break;
+                case GGML_TYPE_Q4_0:
+                    {
+                        switch (dstt) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
+                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        };
+                    } break;
+                case GGML_TYPE_Q4_1:
+                    {
+                        switch (dstt) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
+                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        };
+                    } break;
+                case GGML_TYPE_Q5_0:
+                    {
+                        switch (dstt) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
+                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        };
+                    } break;
+                case GGML_TYPE_Q5_1:
+                    {
+                        switch (dstt) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
+                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        };
+                    } break;
+                case GGML_TYPE_Q8_0:
+                    {
+                        switch (dstt) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
+                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break;
+                            default: GGML_ABORT("not implemented");
                         };
                     } break;
                 default: GGML_ABORT("not implemented");
@@ -3966,7 +4034,11 @@ static void ggml_metal_encode_node(
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];

+                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
+
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+
             } break;
         case GGML_OP_SET:
             {
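With quantized sources now allowed through ggml_metal_supports_op, the threadgroup sizing moves below the pipeline selection, next to the dispatch: ne00 must be validated against the source's block size, which is no longer always 1. A minimal sketch of that arithmetic (the row length here is made up for illustration; block size 32 is ggml's QK for these quant types):

#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    // Illustrative values: a Q4_0 row of 4096 elements, block size 32.
    const int ne00 = 4096;  // row length in elements
    const int blck = 32;    // ggml_blck_size(GGML_TYPE_Q4_0)

    // One thread per quant block in the row, capped at 1024 threads
    // per threadgroup; float sources (blck = 1) still get MIN(1024, ne00).
    const int nth = MIN(1024, ne00/blck);
    printf("nth = %d\n", nth); // 128

    // The grid itself is unchanged: one threadgroup per row, (ne01, ne02, ne03).
    return 0;
}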
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -4341,6 +4341,49 @@ kernel void kernel_cpy_f32_iq4_nl(
     }
 }

+template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+kernel void kernel_cpy_q_f32(
+        constant ggml_metal_kargs_cpy & args,
+        device  const char * src0,
+        device        char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
+
+    device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
+        T4x4 temp;
+        dequantize_func(src_data + i00/nl, i00%nl, temp);
+        dst_data[i00] = temp;
+    }
+}
+
+typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
+
+template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
+
+template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
+
 kernel void kernel_concat(
         constant ggml_metal_kargs_concat & args,
         device const char * src0,
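The flatten/unflatten arithmetic in kernel_cpy_q_f32 follows the existing kernel_cpy_* kernels: each threadgroup owns one source row (i01, i02, i03), flattens that row's first element to a linear index n, and re-derives destination coordinates (i0, i1, i2, i3) from n, so the copy also covers CONT-style cases where src and dst disagree on shape but have the same element count. A host-side C sketch of the same arithmetic, with made-up dimensions (both tensors hold 512 elements):

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t ne00 = 64, ne01 = 4, ne02 = 2;  // source dims (elements)
    const int64_t ne0 = 128, ne1  = 2, ne2  = 2;  // destination dims
    const int i01 = 3, i02 = 1, i03 = 0;          // one threadgroup's row

    // flatten the row's first element into a linear index over the source ...
    const int64_t n  = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // ... then unflatten it against the destination's shape
    const int64_t i3 = n/(ne2*ne1*ne0);
    const int64_t i2 = (n - i3*ne2*ne1*ne0)/(ne1*ne0);
    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0)/ne0;
    const int64_t i0 =  n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0;

    printf("n=%lld -> dst (i0=%lld, i1=%lld, i2=%lld, i3=%lld)\n",
           (long long)n, (long long)i0, (long long)i1,
           (long long)i2, (long long)i3);  // n=448 -> (64, 1, 1, 0)
    return 0;
}

From there, each thread strides over the row in 4x4 chunks (args.ne00/16 of them), dequantizing one chunk per iteration via dequantize_func into a float4x4 or half4x4; nl = 2 for all ten instantiations because each 32-element quant block spans two 16-element chunks, hence the src_data + i00/nl block index and the i00%nl sub-block selector.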