metal : copy kernels for quant to F32/F16 conversions (llama/12017)
metal: use dequantize_q templates
---------
Co-authored-by: Georgi Gerganov <[email protected]>
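For context (not part of the commit itself): the sketch below shows one way the new path can be exercised from the host side, by building a GGML_OP_CPY node from a quantized tensor to an F32 tensor and running it directly on the Metal backend. It assumes the standard ggml/ggml-alloc/ggml-backend C API; the shapes, the Q8_0 choice, and the memory sizing are illustrative only. In llama.cpp's normal scheduling it is the ggml_metal_supports_op change in the diff below that decides whether such a copy is placed on Metal at all.

/* hypothetical usage sketch, not from the commit */
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-metal.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // tensor data lives in the backend buffer
    };
    struct ggml_context * ctx = ggml_init(params);

    // ne00 = 64 is a multiple of the Q8_0 block size (32), as required by
    // the GGML_ASSERT in ggml_metal_encode_node
    struct ggml_tensor * src = ggml_new_tensor_2d(ctx, GGML_TYPE_Q8_0, 64, 4);
    struct ggml_tensor * dst = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  64, 4);

    // GGML_OP_CPY with a quantized source now maps to kernel_cpy_q8_0_f32
    struct ggml_tensor * out = ggml_cpy(ctx, src, dst);

    ggml_backend_t backend = ggml_backend_metal_init();
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
    // real code would upload quantized data into src here, e.g. with ggml_backend_tensor_set()

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_backend_graph_compute(backend, gf);

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}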
ggml/src/ggml-metal/ggml-metal.m
CHANGED
@@ -407,6 +407,16 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
     GGML_METAL_KERNEL_TYPE_SQRT,

@@ -1012,6 +1022,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);

@@ -1287,6 +1307,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                         default:
                             return false;
                     }
+                case GGML_TYPE_Q4_0:
+                case GGML_TYPE_Q4_1:
+                case GGML_TYPE_Q5_0:
+                case GGML_TYPE_Q5_1:
+                case GGML_TYPE_Q8_0:
+                    switch (op->type) {
+                        case GGML_TYPE_F32:
+                        case GGML_TYPE_F16:
+                            return true;
+                        default:
+                            return false;
+                    }
                 default:
                     return false;
             };

@@ -3899,10 +3931,6 @@ static void ggml_metal_encode_node(
         case GGML_OP_CPY:
         case GGML_OP_CONT:
             {
-                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-
-                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
-
                 id<MTLComputePipelineState> pipeline = nil;

                 switch (src0t) {

@@ -3936,7 +3964,47 @@ static void ggml_metal_encode_node(
                            switch (dstt) {
                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
                                case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
-                                default:
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q4_0:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q4_1:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q5_0:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q5_1:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q8_0:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
                            };
                        } break;
                    default: GGML_ABORT("not implemented");

@@ -3966,7 +4034,11 @@ static void ggml_metal_encode_node(
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:2];

+                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
+
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+
             } break;
         case GGML_OP_SET:
             {
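A note on the dispatch sizing in the last hunk: ggml_blck_size() is 32 for Q4_0/Q4_1/Q5_0/Q5_1/Q8_0 and 1 for F32/F16, which is what the relocated GGML_ASSERT and the nth computation rely on. The fragment below only works through that arithmetic with made-up numbers; MIN is assumed to be the usual min macro, as in ggml-metal.m.

/* illustrative only, not from the commit */
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int blck = 32;                   /* ggml_blck_size() for the Q4_0..Q8_0 types */
    const int ne00 = 4096;                 /* hypothetical row length, a multiple of blck */
    const int nth  = MIN(1024, ne00/blck); /* threads per threadgroup -> 128 */

    /* one threadgroup per source row: the grid is (ne01, ne02, ne03) */
    printf("nth = %d\n", nth);
    return 0;
}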
ggml/src/ggml-metal/ggml-metal.metal
CHANGED
@@ -4341,6 +4341,49 @@ kernel void kernel_cpy_f32_iq4_nl(
     }
 }
 
+template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+kernel void kernel_cpy_q_f32(
+        constant ggml_metal_kargs_cpy & args,
+        device  const char * src0,
+        device        char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
+
+    device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
+        T4x4 temp;
+        dequantize_func(src_data + i00/nl, i00%nl, temp);
+        dst_data[i00] = temp;
+    }
+}
+
+typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
+
+template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
+
+template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
+
 kernel void kernel_concat(
         constant ggml_metal_kargs_concat & args,
         device const char * src0,