Spaces:
Running
ggml : add ggml_set_rows (llama/14274)
Browse files* ggml : add ggml_set_rows
Add ggml_set_rows(a, b, c) which copies rows from 'b' into 'a' using
indices from 'c'.
ref: #8366
* use I64 for indices
* ggml : add repeat impl for i64
* ggml : add ggml_is_contiguous_rows
* ggml : ggml_set_rows support broadcast
* ggml : ggml_set_rows support quantized dst
ggml-ci
* ggml : support GGML_TYPE_F32 ".from_float" trait
* ggml : ggml_set_rows update comment + better index name
* tests : add ggml_set_rows
* metal : add ggml_set_rows implementation
ggml-ci
* ggml : simplify forward_dup_f32
* ggml : fix supports_op
* tests : add comment to set_rows
* ggml : leave the repeat_i64 for a separate PR
ggml-ci
* ggml : set_rows use std::min instead of MIN
* ggml : better error message for set_rows unsupported type
* metal : perform op->type check only once
* tests : more consistent implementation + more tests
ggml-ci
---------
Co-authored-by: Georgi Gerganov <[email protected]>
- ggml/include/ggml-cpu.h +1 -0
- ggml/include/ggml.h +21 -0
- ggml/src/ggml-cpu/ggml-cpu.c +10 -0
- ggml/src/ggml-cpu/ggml-cpu.cpp +1 -0
- ggml/src/ggml-cpu/ops.cpp +77 -19
- ggml/src/ggml-cpu/ops.h +1 -0
- ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
- ggml/src/ggml-metal/ggml-metal.m +108 -4
- ggml/src/ggml-metal/ggml-metal.metal +290 -179
- ggml/src/ggml.c +39 -2
|
@@ -134,6 +134,7 @@ extern "C" {
|
|
| 134 |
|
| 135 |
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
|
| 136 |
|
|
|
|
| 137 |
GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
|
| 138 |
GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
|
| 139 |
GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);
|
|
|
|
| 134 |
|
| 135 |
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
|
| 136 |
|
| 137 |
+
GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
|
| 138 |
GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
|
| 139 |
GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
|
| 140 |
GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);
|
|
@@ -470,6 +470,7 @@ extern "C" {
|
|
| 470 |
GGML_OP_TRANSPOSE,
|
| 471 |
GGML_OP_GET_ROWS,
|
| 472 |
GGML_OP_GET_ROWS_BACK,
|
|
|
|
| 473 |
GGML_OP_DIAG,
|
| 474 |
GGML_OP_DIAG_MASK_INF,
|
| 475 |
GGML_OP_DIAG_MASK_ZERO,
|
|
@@ -687,6 +688,9 @@ extern "C" {
|
|
| 687 |
// true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
|
| 688 |
GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);
|
| 689 |
|
|
|
|
|
|
|
|
|
|
| 690 |
GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
|
| 691 |
GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
|
| 692 |
|
|
@@ -1375,6 +1379,23 @@ extern "C" {
|
|
| 1375 |
struct ggml_tensor * b, // row indices
|
| 1376 |
struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape
|
| 1377 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1378 |
GGML_API struct ggml_tensor * ggml_diag(
|
| 1379 |
struct ggml_context * ctx,
|
| 1380 |
struct ggml_tensor * a);
|
|
|
|
| 470 |
GGML_OP_TRANSPOSE,
|
| 471 |
GGML_OP_GET_ROWS,
|
| 472 |
GGML_OP_GET_ROWS_BACK,
|
| 473 |
+
GGML_OP_SET_ROWS,
|
| 474 |
GGML_OP_DIAG,
|
| 475 |
GGML_OP_DIAG_MASK_INF,
|
| 476 |
GGML_OP_DIAG_MASK_ZERO,
|
|
|
|
| 688 |
// true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
|
| 689 |
GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);
|
| 690 |
|
| 691 |
+
// true if the elements in dimension 0 are contiguous, or there is just 1 block of elements
|
| 692 |
+
GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor);
|
| 693 |
+
|
| 694 |
GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
|
| 695 |
GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
|
| 696 |
|
|
|
|
| 1379 |
struct ggml_tensor * b, // row indices
|
| 1380 |
struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape
|
| 1381 |
|
| 1382 |
+
// a TD [n_embd, ne1, ne2, ne3]
|
| 1383 |
+
// b TS [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3
|
| 1384 |
+
// c I64 [n_rows, ne11, ne12, 1] | c[i] in [0, ne1)
|
| 1385 |
+
//
|
| 1386 |
+
// undefined behavior if destination rows overlap
|
| 1387 |
+
//
|
| 1388 |
+
// broadcast:
|
| 1389 |
+
// ne2 % ne11 == 0
|
| 1390 |
+
// ne3 % ne12 == 0
|
| 1391 |
+
//
|
| 1392 |
+
// return view(a)
|
| 1393 |
+
GGML_API struct ggml_tensor * ggml_set_rows(
|
| 1394 |
+
struct ggml_context * ctx,
|
| 1395 |
+
struct ggml_tensor * a, // destination
|
| 1396 |
+
struct ggml_tensor * b, // source
|
| 1397 |
+
struct ggml_tensor * c); // row indices
|
| 1398 |
+
|
| 1399 |
GGML_API struct ggml_tensor * ggml_diag(
|
| 1400 |
struct ggml_context * ctx,
|
| 1401 |
struct ggml_tensor * a);
|
|
@@ -195,6 +195,7 @@ typedef pthread_t ggml_thread_t;
|
|
| 195 |
|
| 196 |
static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
|
| 197 |
[GGML_TYPE_F32] = {
|
|
|
|
| 198 |
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
|
| 199 |
.vec_dot_type = GGML_TYPE_F32,
|
| 200 |
.nrows = 1,
|
|
@@ -1817,6 +1818,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
| 1817 |
{
|
| 1818 |
ggml_compute_forward_get_rows_back(params, tensor);
|
| 1819 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1820 |
case GGML_OP_DIAG:
|
| 1821 |
{
|
| 1822 |
ggml_compute_forward_diag(params, tensor);
|
|
@@ -2170,6 +2175,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|
| 2170 |
n_tasks = n_threads;
|
| 2171 |
} break;
|
| 2172 |
case GGML_OP_GET_ROWS:
|
|
|
|
| 2173 |
{
|
| 2174 |
// FIXME: get_rows can use additional threads, but the cost of launching additional threads
|
| 2175 |
// decreases performance with GPU offloading
|
|
@@ -3124,6 +3130,10 @@ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct g
|
|
| 3124 |
return ggml_graph_compute(cgraph, &cplan);
|
| 3125 |
}
|
| 3126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3127 |
void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {
|
| 3128 |
int64_t i = 0;
|
| 3129 |
#if defined(__F16C__)
|
|
|
|
| 195 |
|
| 196 |
static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
|
| 197 |
[GGML_TYPE_F32] = {
|
| 198 |
+
.from_float = (ggml_from_float_t) ggml_cpu_fp32_to_fp32,
|
| 199 |
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
|
| 200 |
.vec_dot_type = GGML_TYPE_F32,
|
| 201 |
.nrows = 1,
|
|
|
|
| 1818 |
{
|
| 1819 |
ggml_compute_forward_get_rows_back(params, tensor);
|
| 1820 |
} break;
|
| 1821 |
+
case GGML_OP_SET_ROWS:
|
| 1822 |
+
{
|
| 1823 |
+
ggml_compute_forward_set_rows(params, tensor);
|
| 1824 |
+
} break;
|
| 1825 |
case GGML_OP_DIAG:
|
| 1826 |
{
|
| 1827 |
ggml_compute_forward_diag(params, tensor);
|
|
|
|
| 2175 |
n_tasks = n_threads;
|
| 2176 |
} break;
|
| 2177 |
case GGML_OP_GET_ROWS:
|
| 2178 |
+
case GGML_OP_SET_ROWS:
|
| 2179 |
{
|
| 2180 |
// FIXME: get_rows can use additional threads, but the cost of launching additional threads
|
| 2181 |
// decreases performance with GPU offloading
|
|
|
|
| 3130 |
return ggml_graph_compute(cgraph, &cplan);
|
| 3131 |
}
|
| 3132 |
|
| 3133 |
+
void ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) {
|
| 3134 |
+
memcpy(y, x, n * sizeof(float));
|
| 3135 |
+
}
|
| 3136 |
+
|
| 3137 |
void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {
|
| 3138 |
int64_t i = 0;
|
| 3139 |
#if defined(__F16C__)
|
|
@@ -416,6 +416,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
|
|
| 416 |
|
| 417 |
switch (op->op) {
|
| 418 |
case GGML_OP_CPY:
|
|
|
|
| 419 |
return
|
| 420 |
op->type != GGML_TYPE_IQ3_XXS &&
|
| 421 |
op->type != GGML_TYPE_IQ3_S &&
|
|
|
|
| 416 |
|
| 417 |
switch (op->op) {
|
| 418 |
case GGML_OP_CPY:
|
| 419 |
+
case GGML_OP_SET_ROWS:
|
| 420 |
return
|
| 421 |
op->type != GGML_TYPE_IQ3_XXS &&
|
| 422 |
op->type != GGML_TYPE_IQ3_S &&
|
|
@@ -696,24 +696,8 @@ static void ggml_compute_forward_dup_f32(
|
|
| 696 |
if (ggml_is_contiguous(dst)) {
|
| 697 |
// TODO: simplify
|
| 698 |
if (nb00 == sizeof(float)) {
|
| 699 |
-
if (dst->type
|
| 700 |
-
|
| 701 |
-
const size_t rs = ne00 * nb00;
|
| 702 |
-
char * dst_ptr = (char *) dst->data;
|
| 703 |
-
|
| 704 |
-
for (int i03 = 0; i03 < ne03; i03++) {
|
| 705 |
-
for (int i02 = 0; i02 < ne02; i02++) {
|
| 706 |
-
id += rs * ir0;
|
| 707 |
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
| 708 |
-
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
| 709 |
-
memcpy(dst_ptr + id, src0_ptr, rs);
|
| 710 |
-
id += rs;
|
| 711 |
-
}
|
| 712 |
-
id += rs * (ne01 - ir1);
|
| 713 |
-
}
|
| 714 |
-
}
|
| 715 |
-
} else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
|
| 716 |
-
ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
|
| 717 |
|
| 718 |
size_t id = 0;
|
| 719 |
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
|
|
@@ -724,7 +708,7 @@ static void ggml_compute_forward_dup_f32(
|
|
| 724 |
id += rs * ir0;
|
| 725 |
for (int i01 = ir0; i01 < ir1; i01++) {
|
| 726 |
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
| 727 |
-
|
| 728 |
id += rs;
|
| 729 |
}
|
| 730 |
id += rs * (ne01 - ir1);
|
|
@@ -2300,6 +2284,12 @@ void ggml_compute_forward_repeat(
|
|
| 2300 |
{
|
| 2301 |
ggml_compute_forward_repeat_f32(params, dst);
|
| 2302 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2303 |
default:
|
| 2304 |
{
|
| 2305 |
GGML_ABORT("fatal error");
|
|
@@ -4470,6 +4460,74 @@ void ggml_compute_forward_get_rows(
|
|
| 4470 |
//}
|
| 4471 |
}
|
| 4472 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4473 |
// ggml_compute_forward_get_rows_back
|
| 4474 |
|
| 4475 |
static void ggml_compute_forward_get_rows_back_f32_f16(
|
|
|
|
| 696 |
if (ggml_is_contiguous(dst)) {
|
| 697 |
// TODO: simplify
|
| 698 |
if (nb00 == sizeof(float)) {
|
| 699 |
+
if (ggml_get_type_traits_cpu(dst->type)->from_float) {
|
| 700 |
+
ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 701 |
|
| 702 |
size_t id = 0;
|
| 703 |
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
|
|
|
|
| 708 |
id += rs * ir0;
|
| 709 |
for (int i01 = ir0; i01 < ir1; i01++) {
|
| 710 |
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
| 711 |
+
from_float(src0_ptr, dst_ptr + id, ne00);
|
| 712 |
id += rs;
|
| 713 |
}
|
| 714 |
id += rs * (ne01 - ir1);
|
|
|
|
| 2284 |
{
|
| 2285 |
ggml_compute_forward_repeat_f32(params, dst);
|
| 2286 |
} break;
|
| 2287 |
+
// TODO: templateify the implemenation and support for I64
|
| 2288 |
+
// ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
|
| 2289 |
+
//case GGML_TYPE_I64:
|
| 2290 |
+
// {
|
| 2291 |
+
// ggml_compute_forward_repeat_i64(params, dst);
|
| 2292 |
+
// } break;
|
| 2293 |
default:
|
| 2294 |
{
|
| 2295 |
GGML_ABORT("fatal error");
|
|
|
|
| 4460 |
//}
|
| 4461 |
}
|
| 4462 |
|
| 4463 |
+
static void ggml_compute_forward_set_rows_f32(
|
| 4464 |
+
const ggml_compute_params * params,
|
| 4465 |
+
ggml_tensor * dst) {
|
| 4466 |
+
|
| 4467 |
+
const ggml_tensor * src0 = dst->src[0];
|
| 4468 |
+
const ggml_tensor * src1 = dst->src[1];
|
| 4469 |
+
|
| 4470 |
+
GGML_TENSOR_BINARY_OP_LOCALS
|
| 4471 |
+
|
| 4472 |
+
const int64_t nc = ne00;
|
| 4473 |
+
const int64_t nr = ne01;
|
| 4474 |
+
|
| 4475 |
+
assert(ne0 == nc);
|
| 4476 |
+
assert(ne2 == ne02);
|
| 4477 |
+
assert(ne3 == ne03);
|
| 4478 |
+
assert(src0->type == GGML_TYPE_F32);
|
| 4479 |
+
assert(ne02 % ne11 == 0);
|
| 4480 |
+
assert(ne03 % ne12 == 0);
|
| 4481 |
+
|
| 4482 |
+
const int ith = params->ith;
|
| 4483 |
+
const int nth = params->nth;
|
| 4484 |
+
|
| 4485 |
+
// rows per thread
|
| 4486 |
+
const int64_t dr = (nr + nth - 1)/nth;
|
| 4487 |
+
|
| 4488 |
+
// row range for this thread
|
| 4489 |
+
const int64_t ir0 = dr*ith;
|
| 4490 |
+
const int64_t ir1 = std::min(ir0 + dr, nr);
|
| 4491 |
+
|
| 4492 |
+
ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
|
| 4493 |
+
|
| 4494 |
+
for (int64_t i03 = 0; i03 < ne03; ++i03) {
|
| 4495 |
+
for (int64_t i02 = 0; i02 < ne02; ++i02) {
|
| 4496 |
+
for (int64_t i = ir0; i < ir1; ++i) {
|
| 4497 |
+
const int64_t i12 = i03%ne12;
|
| 4498 |
+
const int64_t i11 = i02%ne11;
|
| 4499 |
+
const int64_t i10 = i;
|
| 4500 |
+
|
| 4501 |
+
const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
|
| 4502 |
+
|
| 4503 |
+
GGML_ASSERT(i1 >= 0 && i1 < ne1);
|
| 4504 |
+
|
| 4505 |
+
from_float(
|
| 4506 |
+
(const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
|
| 4507 |
+
((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
|
| 4508 |
+
}
|
| 4509 |
+
}
|
| 4510 |
+
}
|
| 4511 |
+
}
|
| 4512 |
+
|
| 4513 |
+
void ggml_compute_forward_set_rows(
|
| 4514 |
+
const ggml_compute_params * params,
|
| 4515 |
+
ggml_tensor * dst) {
|
| 4516 |
+
|
| 4517 |
+
const ggml_tensor * src0 = dst->src[0];
|
| 4518 |
+
|
| 4519 |
+
switch (src0->type) {
|
| 4520 |
+
case GGML_TYPE_F32:
|
| 4521 |
+
{
|
| 4522 |
+
ggml_compute_forward_set_rows_f32(params, dst);
|
| 4523 |
+
} break;
|
| 4524 |
+
default:
|
| 4525 |
+
{
|
| 4526 |
+
GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type));
|
| 4527 |
+
}
|
| 4528 |
+
}
|
| 4529 |
+
}
|
| 4530 |
+
|
| 4531 |
// ggml_compute_forward_get_rows_back
|
| 4532 |
|
| 4533 |
static void ggml_compute_forward_get_rows_back_f32_f16(
|
|
@@ -53,6 +53,7 @@ void ggml_compute_forward_permute(const struct ggml_compute_params * params, str
|
|
| 53 |
void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
| 54 |
void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
| 55 |
void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
|
|
| 56 |
void ggml_compute_forward_diag(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
| 57 |
void ggml_compute_forward_diag_mask_inf(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
| 58 |
void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
|
|
| 53 |
void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
| 54 |
void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
| 55 |
void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
| 56 |
+
void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
| 57 |
void ggml_compute_forward_diag(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
| 58 |
void ggml_compute_forward_diag_mask_inf(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
| 59 |
void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
@@ -521,6 +521,22 @@ typedef struct {
|
|
| 521 |
uint64_t nb2;
|
| 522 |
} ggml_metal_kargs_get_rows;
|
| 523 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 524 |
typedef struct {
|
| 525 |
int64_t ne00;
|
| 526 |
int64_t ne01;
|
|
|
|
| 521 |
uint64_t nb2;
|
| 522 |
} ggml_metal_kargs_get_rows;
|
| 523 |
|
| 524 |
+
typedef struct {
|
| 525 |
+
int32_t nk0;
|
| 526 |
+
int32_t ne01;
|
| 527 |
+
uint64_t nb01;
|
| 528 |
+
uint64_t nb02;
|
| 529 |
+
uint64_t nb03;
|
| 530 |
+
int32_t ne11;
|
| 531 |
+
int32_t ne12;
|
| 532 |
+
uint64_t nb10;
|
| 533 |
+
uint64_t nb11;
|
| 534 |
+
uint64_t nb12;
|
| 535 |
+
uint64_t nb1;
|
| 536 |
+
uint64_t nb2;
|
| 537 |
+
uint64_t nb3;
|
| 538 |
+
} ggml_metal_kargs_set_rows;
|
| 539 |
+
|
| 540 |
typedef struct {
|
| 541 |
int64_t ne00;
|
| 542 |
int64_t ne01;
|
|
@@ -202,6 +202,15 @@ enum ggml_metal_kernel_type {
|
|
| 202 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
| 203 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
| 204 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
| 206 |
GGML_METAL_KERNEL_TYPE_L2_NORM,
|
| 207 |
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
|
@@ -1169,6 +1178,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1169 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
| 1170 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
| 1171 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1172 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
| 1173 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
| 1174 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
|
@@ -1635,6 +1653,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
| 1635 |
const bool use_bfloat = ctx_dev->use_bfloat;
|
| 1636 |
|
| 1637 |
if (!use_bfloat) {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1638 |
for (size_t i = 0, n = 3; i < n; ++i) {
|
| 1639 |
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
|
| 1640 |
return false;
|
|
@@ -1804,6 +1826,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
| 1804 |
{
|
| 1805 |
return op->ne[3] == 1;
|
| 1806 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1807 |
default:
|
| 1808 |
return false;
|
| 1809 |
}
|
|
@@ -3777,13 +3820,74 @@ static bool ggml_metal_encode_node(
|
|
| 3777 |
};
|
| 3778 |
|
| 3779 |
[encoder setComputePipelineState:pipeline];
|
| 3780 |
-
[encoder
|
| 3781 |
-
[encoder setBuffer:
|
| 3782 |
-
[encoder setBuffer:
|
| 3783 |
-
[encoder
|
| 3784 |
|
| 3785 |
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
| 3786 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3787 |
case GGML_OP_RMS_NORM:
|
| 3788 |
{
|
| 3789 |
GGML_ASSERT(ne00 % 4 == 0);
|
|
|
|
| 202 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
| 203 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
| 204 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
| 205 |
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
|
| 206 |
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
|
| 207 |
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
|
| 208 |
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
|
| 209 |
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
|
| 210 |
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
|
| 211 |
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
|
| 212 |
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
|
| 213 |
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
|
| 214 |
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
| 215 |
GGML_METAL_KERNEL_TYPE_L2_NORM,
|
| 216 |
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
|
|
|
| 1178 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
| 1179 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
| 1180 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
| 1181 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
|
| 1182 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
|
| 1183 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
|
| 1184 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true);
|
| 1185 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true);
|
| 1186 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true);
|
| 1187 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
|
| 1188 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
|
| 1189 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
|
| 1190 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
| 1191 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
| 1192 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
|
|
|
| 1653 |
const bool use_bfloat = ctx_dev->use_bfloat;
|
| 1654 |
|
| 1655 |
if (!use_bfloat) {
|
| 1656 |
+
if (op->type == GGML_TYPE_BF16) {
|
| 1657 |
+
return false;
|
| 1658 |
+
}
|
| 1659 |
+
|
| 1660 |
for (size_t i = 0, n = 3; i < n; ++i) {
|
| 1661 |
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
|
| 1662 |
return false;
|
|
|
|
| 1826 |
{
|
| 1827 |
return op->ne[3] == 1;
|
| 1828 |
}
|
| 1829 |
+
case GGML_OP_SET_ROWS:
|
| 1830 |
+
{
|
| 1831 |
+
if (op->src[0]->type != GGML_TYPE_F32) {
|
| 1832 |
+
return false;
|
| 1833 |
+
}
|
| 1834 |
+
|
| 1835 |
+
switch (op->type) {
|
| 1836 |
+
case GGML_TYPE_F32:
|
| 1837 |
+
case GGML_TYPE_F16:
|
| 1838 |
+
case GGML_TYPE_BF16:
|
| 1839 |
+
case GGML_TYPE_Q8_0:
|
| 1840 |
+
case GGML_TYPE_Q4_0:
|
| 1841 |
+
case GGML_TYPE_Q4_1:
|
| 1842 |
+
case GGML_TYPE_Q5_0:
|
| 1843 |
+
case GGML_TYPE_Q5_1:
|
| 1844 |
+
case GGML_TYPE_IQ4_NL:
|
| 1845 |
+
return true;
|
| 1846 |
+
default:
|
| 1847 |
+
return false;
|
| 1848 |
+
};
|
| 1849 |
+
}
|
| 1850 |
default:
|
| 1851 |
return false;
|
| 1852 |
}
|
|
|
|
| 3820 |
};
|
| 3821 |
|
| 3822 |
[encoder setComputePipelineState:pipeline];
|
| 3823 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
| 3824 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
| 3825 |
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
| 3826 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
| 3827 |
|
| 3828 |
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
| 3829 |
} break;
|
| 3830 |
+
case GGML_OP_SET_ROWS:
|
| 3831 |
+
{
|
| 3832 |
+
id<MTLComputePipelineState> pipeline = nil;
|
| 3833 |
+
|
| 3834 |
+
switch (dst->type) {
|
| 3835 |
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break;
|
| 3836 |
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break;
|
| 3837 |
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break;
|
| 3838 |
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break;
|
| 3839 |
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break;
|
| 3840 |
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break;
|
| 3841 |
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break;
|
| 3842 |
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break;
|
| 3843 |
+
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
|
| 3844 |
+
default: GGML_ABORT("not implemented");
|
| 3845 |
+
}
|
| 3846 |
+
|
| 3847 |
+
const int32_t nk0 = ne0/ggml_blck_size(dst->type);
|
| 3848 |
+
|
| 3849 |
+
int nth = 32; // SIMD width
|
| 3850 |
+
|
| 3851 |
+
while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
| 3852 |
+
nth *= 2;
|
| 3853 |
+
}
|
| 3854 |
+
|
| 3855 |
+
int nrptg = 1;
|
| 3856 |
+
if (nth > nk0) {
|
| 3857 |
+
nrptg = (nth + nk0 - 1)/nk0;
|
| 3858 |
+
nth = nk0;
|
| 3859 |
+
|
| 3860 |
+
if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
| 3861 |
+
nrptg--;
|
| 3862 |
+
}
|
| 3863 |
+
}
|
| 3864 |
+
|
| 3865 |
+
nth = MIN(nth, nk0);
|
| 3866 |
+
|
| 3867 |
+
ggml_metal_kargs_set_rows args = {
|
| 3868 |
+
/*.nk0 =*/ nk0,
|
| 3869 |
+
/*.ne01 =*/ ne01,
|
| 3870 |
+
/*.nb01 =*/ nb01,
|
| 3871 |
+
/*.nb02 =*/ nb02,
|
| 3872 |
+
/*.nb03 =*/ nb03,
|
| 3873 |
+
/*.ne11 =*/ ne11,
|
| 3874 |
+
/*.ne12 =*/ ne12,
|
| 3875 |
+
/*.nb10 =*/ nb10,
|
| 3876 |
+
/*.nb11 =*/ nb11,
|
| 3877 |
+
/*.nb12 =*/ nb12,
|
| 3878 |
+
/*.nb1 =*/ nb1,
|
| 3879 |
+
/*.nb2 =*/ nb2,
|
| 3880 |
+
/*.nb3 =*/ nb3,
|
| 3881 |
+
};
|
| 3882 |
+
|
| 3883 |
+
[encoder setComputePipelineState:pipeline];
|
| 3884 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
| 3885 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
| 3886 |
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
| 3887 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
| 3888 |
+
|
| 3889 |
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
|
| 3890 |
+
} break;
|
| 3891 |
case GGML_OP_RMS_NORM:
|
| 3892 |
{
|
| 3893 |
GGML_ASSERT(ne00 % 4 == 0);
|
|
@@ -35,6 +35,17 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
|
|
| 35 |
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
| 36 |
};
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
// NOTE: this is not dequantizing - we are simply fitting the template
|
| 39 |
template <typename type4x4>
|
| 40 |
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
|
@@ -97,6 +108,173 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
|
|
| 97 |
}
|
| 98 |
}
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
template <typename type4x4>
|
| 101 |
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
|
| 102 |
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
|
@@ -279,6 +457,26 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
|
|
| 279 |
}
|
| 280 |
}
|
| 281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
template <typename type4x4>
|
| 283 |
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
| 284 |
const float d = xb->d;
|
|
@@ -4410,6 +4608,7 @@ template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bf
|
|
| 4410 |
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
|
| 4411 |
#endif
|
| 4412 |
|
|
|
|
| 4413 |
kernel void kernel_cpy_f32_q8_0(
|
| 4414 |
constant ggml_metal_kargs_cpy & args,
|
| 4415 |
device const char * src0,
|
|
@@ -4433,23 +4632,7 @@ kernel void kernel_cpy_f32_q8_0(
|
|
| 4433 |
for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
|
| 4434 |
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
| 4435 |
|
| 4436 |
-
|
| 4437 |
-
|
| 4438 |
-
for (int j = 0; j < QK8_0; j++) {
|
| 4439 |
-
const float v = src[j];
|
| 4440 |
-
amax = MAX(amax, fabs(v));
|
| 4441 |
-
}
|
| 4442 |
-
|
| 4443 |
-
const float d = amax / ((1 << 7) - 1);
|
| 4444 |
-
const float id = d ? 1.0f/d : 0.0f;
|
| 4445 |
-
|
| 4446 |
-
dst_data[i00/QK8_0].d = d;
|
| 4447 |
-
|
| 4448 |
-
for (int j = 0; j < QK8_0; ++j) {
|
| 4449 |
-
const float x0 = src[j]*id;
|
| 4450 |
-
|
| 4451 |
-
dst_data[i00/QK8_0].qs[j] = round(x0);
|
| 4452 |
-
}
|
| 4453 |
}
|
| 4454 |
}
|
| 4455 |
|
|
@@ -4476,32 +4659,7 @@ kernel void kernel_cpy_f32_q4_0(
|
|
| 4476 |
for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
|
| 4477 |
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
| 4478 |
|
| 4479 |
-
|
| 4480 |
-
float max = 0.0f;
|
| 4481 |
-
|
| 4482 |
-
for (int j = 0; j < QK4_0; j++) {
|
| 4483 |
-
const float v = src[j];
|
| 4484 |
-
if (amax < fabs(v)) {
|
| 4485 |
-
amax = fabs(v);
|
| 4486 |
-
max = v;
|
| 4487 |
-
}
|
| 4488 |
-
}
|
| 4489 |
-
|
| 4490 |
-
const float d = max / -8;
|
| 4491 |
-
const float id = d ? 1.0f/d : 0.0f;
|
| 4492 |
-
|
| 4493 |
-
dst_data[i00/QK4_0].d = d;
|
| 4494 |
-
|
| 4495 |
-
for (int j = 0; j < QK4_0/2; ++j) {
|
| 4496 |
-
const float x0 = src[0 + j]*id;
|
| 4497 |
-
const float x1 = src[QK4_0/2 + j]*id;
|
| 4498 |
-
|
| 4499 |
-
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
|
| 4500 |
-
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
|
| 4501 |
-
|
| 4502 |
-
dst_data[i00/QK4_0].qs[j] = xi0;
|
| 4503 |
-
dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
|
| 4504 |
-
}
|
| 4505 |
}
|
| 4506 |
}
|
| 4507 |
|
|
@@ -4528,31 +4686,7 @@ kernel void kernel_cpy_f32_q4_1(
|
|
| 4528 |
for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
|
| 4529 |
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
| 4530 |
|
| 4531 |
-
|
| 4532 |
-
float max = -FLT_MAX;
|
| 4533 |
-
|
| 4534 |
-
for (int j = 0; j < QK4_1; j++) {
|
| 4535 |
-
const float v = src[j];
|
| 4536 |
-
if (min > v) min = v;
|
| 4537 |
-
if (max < v) max = v;
|
| 4538 |
-
}
|
| 4539 |
-
|
| 4540 |
-
const float d = (max - min) / ((1 << 4) - 1);
|
| 4541 |
-
const float id = d ? 1.0f/d : 0.0f;
|
| 4542 |
-
|
| 4543 |
-
dst_data[i00/QK4_1].d = d;
|
| 4544 |
-
dst_data[i00/QK4_1].m = min;
|
| 4545 |
-
|
| 4546 |
-
for (int j = 0; j < QK4_1/2; ++j) {
|
| 4547 |
-
const float x0 = (src[0 + j] - min)*id;
|
| 4548 |
-
const float x1 = (src[QK4_1/2 + j] - min)*id;
|
| 4549 |
-
|
| 4550 |
-
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
|
| 4551 |
-
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
|
| 4552 |
-
|
| 4553 |
-
dst_data[i00/QK4_1].qs[j] = xi0;
|
| 4554 |
-
dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
|
| 4555 |
-
}
|
| 4556 |
}
|
| 4557 |
}
|
| 4558 |
|
|
@@ -4579,38 +4713,7 @@ kernel void kernel_cpy_f32_q5_0(
|
|
| 4579 |
for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
|
| 4580 |
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
| 4581 |
|
| 4582 |
-
|
| 4583 |
-
float max = 0.0f;
|
| 4584 |
-
|
| 4585 |
-
for (int j = 0; j < QK5_0; j++) {
|
| 4586 |
-
const float v = src[j];
|
| 4587 |
-
if (amax < fabs(v)) {
|
| 4588 |
-
amax = fabs(v);
|
| 4589 |
-
max = v;
|
| 4590 |
-
}
|
| 4591 |
-
}
|
| 4592 |
-
|
| 4593 |
-
const float d = max / -16;
|
| 4594 |
-
const float id = d ? 1.0f/d : 0.0f;
|
| 4595 |
-
|
| 4596 |
-
dst_data[i00/QK5_0].d = d;
|
| 4597 |
-
|
| 4598 |
-
uint32_t qh = 0;
|
| 4599 |
-
for (int j = 0; j < QK5_0/2; ++j) {
|
| 4600 |
-
const float x0 = src[0 + j]*id;
|
| 4601 |
-
const float x1 = src[QK5_0/2 + j]*id;
|
| 4602 |
-
|
| 4603 |
-
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
|
| 4604 |
-
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
|
| 4605 |
-
|
| 4606 |
-
dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
| 4607 |
-
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
| 4608 |
-
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
|
| 4609 |
-
}
|
| 4610 |
-
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
| 4611 |
-
for (int j = 0; j < 4; ++j) {
|
| 4612 |
-
dst_data[i00/QK5_0].qh[j] = qh8[j];
|
| 4613 |
-
}
|
| 4614 |
}
|
| 4615 |
}
|
| 4616 |
|
|
@@ -4637,49 +4740,8 @@ kernel void kernel_cpy_f32_q5_1(
|
|
| 4637 |
for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
|
| 4638 |
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
| 4639 |
|
| 4640 |
-
|
| 4641 |
-
float min = src[0];
|
| 4642 |
-
|
| 4643 |
-
for (int j = 1; j < QK5_1; j++) {
|
| 4644 |
-
const float v = src[j];
|
| 4645 |
-
min = v < min ? v : min;
|
| 4646 |
-
max = v > max ? v : max;
|
| 4647 |
-
}
|
| 4648 |
-
|
| 4649 |
-
const float d = (max - min) / 31;
|
| 4650 |
-
const float id = d ? 1.0f/d : 0.0f;
|
| 4651 |
-
|
| 4652 |
-
dst_data[i00/QK5_1].d = d;
|
| 4653 |
-
dst_data[i00/QK5_1].m = min;
|
| 4654 |
-
|
| 4655 |
-
uint32_t qh = 0;
|
| 4656 |
-
for (int j = 0; j < QK5_1/2; ++j) {
|
| 4657 |
-
const float x0 = (src[0 + j] - min)*id;
|
| 4658 |
-
const float x1 = (src[QK5_1/2 + j] - min)*id;
|
| 4659 |
-
|
| 4660 |
-
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
|
| 4661 |
-
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
|
| 4662 |
-
|
| 4663 |
-
dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
| 4664 |
-
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
| 4665 |
-
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
|
| 4666 |
-
}
|
| 4667 |
-
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
| 4668 |
-
for (int j = 0; j < 4; ++j) {
|
| 4669 |
-
dst_data[i00/QK5_1].qh[j] = qh8[j];
|
| 4670 |
-
}
|
| 4671 |
-
}
|
| 4672 |
-
}
|
| 4673 |
-
|
| 4674 |
-
static inline int best_index_int8(int n, constant float * val, float x) {
|
| 4675 |
-
if (x <= val[0]) return 0;
|
| 4676 |
-
if (x >= val[n-1]) return n-1;
|
| 4677 |
-
int ml = 0, mu = n-1;
|
| 4678 |
-
while (mu-ml > 1) {
|
| 4679 |
-
int mav = (ml+mu)/2;
|
| 4680 |
-
if (x < val[mav]) mu = mav; else ml = mav;
|
| 4681 |
}
|
| 4682 |
-
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
|
| 4683 |
}
|
| 4684 |
|
| 4685 |
kernel void kernel_cpy_f32_iq4_nl(
|
|
@@ -4705,40 +4767,7 @@ kernel void kernel_cpy_f32_iq4_nl(
|
|
| 4705 |
for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
|
| 4706 |
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
| 4707 |
|
| 4708 |
-
|
| 4709 |
-
float max = 0.0f;
|
| 4710 |
-
|
| 4711 |
-
for (int j = 0; j < QK4_NL; j++) {
|
| 4712 |
-
const float v = src[j];
|
| 4713 |
-
if (amax < fabs(v)) {
|
| 4714 |
-
amax = fabs(v);
|
| 4715 |
-
max = v;
|
| 4716 |
-
}
|
| 4717 |
-
}
|
| 4718 |
-
|
| 4719 |
-
const float d = max / kvalues_iq4nl_f[0];
|
| 4720 |
-
const float id = d ? 1.0f/d : 0.0f;
|
| 4721 |
-
|
| 4722 |
-
float sumqx = 0, sumq2 = 0;
|
| 4723 |
-
for (int j = 0; j < QK4_NL/2; ++j) {
|
| 4724 |
-
const float x0 = src[0 + j]*id;
|
| 4725 |
-
const float x1 = src[QK4_NL/2 + j]*id;
|
| 4726 |
-
|
| 4727 |
-
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
|
| 4728 |
-
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
|
| 4729 |
-
|
| 4730 |
-
dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
|
| 4731 |
-
|
| 4732 |
-
const float v0 = kvalues_iq4nl_f[xi0];
|
| 4733 |
-
const float v1 = kvalues_iq4nl_f[xi1];
|
| 4734 |
-
const float w0 = src[0 + j]*src[0 + j];
|
| 4735 |
-
const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
|
| 4736 |
-
sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
|
| 4737 |
-
sumq2 += w0*v0*v0 + w1*v1*v1;
|
| 4738 |
-
|
| 4739 |
-
}
|
| 4740 |
-
|
| 4741 |
-
dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
|
| 4742 |
}
|
| 4743 |
}
|
| 4744 |
|
|
@@ -6419,10 +6448,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
| 6419 |
|
| 6420 |
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
| 6421 |
kernel void kernel_get_rows_q(
|
|
|
|
| 6422 |
device const void * src0,
|
| 6423 |
device const void * src1,
|
| 6424 |
device float * dst,
|
| 6425 |
-
constant ggml_metal_kargs_get_rows & args,
|
| 6426 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6427 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6428 |
uint3 tptg [[threads_per_threadgroup]]) {
|
|
@@ -6442,10 +6471,10 @@ kernel void kernel_get_rows_q(
|
|
| 6442 |
|
| 6443 |
template<typename T>
|
| 6444 |
kernel void kernel_get_rows_f(
|
|
|
|
| 6445 |
device const void * src0,
|
| 6446 |
device const void * src1,
|
| 6447 |
device float * dst,
|
| 6448 |
-
constant ggml_metal_kargs_get_rows & args,
|
| 6449 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6450 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6451 |
uint3 tptg [[threads_per_threadgroup]]) {
|
|
@@ -6463,10 +6492,10 @@ kernel void kernel_get_rows_f(
|
|
| 6463 |
}
|
| 6464 |
|
| 6465 |
kernel void kernel_get_rows_i32(
|
|
|
|
| 6466 |
device const void * src0,
|
| 6467 |
device const void * src1,
|
| 6468 |
device int32_t * dst,
|
| 6469 |
-
constant ggml_metal_kargs_get_rows & args,
|
| 6470 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6471 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6472 |
uint3 tptg [[threads_per_threadgroup]]) {
|
|
@@ -6483,6 +6512,67 @@ kernel void kernel_get_rows_i32(
|
|
| 6483 |
}
|
| 6484 |
}
|
| 6485 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6486 |
|
| 6487 |
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
| 6488 |
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
|
@@ -6906,6 +6996,27 @@ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get
|
|
| 6906 |
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
|
| 6907 |
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
| 6908 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6909 |
//
|
| 6910 |
// matrix-matrix multiplication
|
| 6911 |
//
|
|
|
|
| 35 |
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
| 36 |
};
|
| 37 |
|
| 38 |
+
static inline int best_index_int8(int n, constant float * val, float x) {
|
| 39 |
+
if (x <= val[0]) return 0;
|
| 40 |
+
if (x >= val[n-1]) return n-1;
|
| 41 |
+
int ml = 0, mu = n-1;
|
| 42 |
+
while (mu-ml > 1) {
|
| 43 |
+
int mav = (ml+mu)/2;
|
| 44 |
+
if (x < val[mav]) mu = mav; else ml = mav;
|
| 45 |
+
}
|
| 46 |
+
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
// NOTE: this is not dequantizing - we are simply fitting the template
|
| 50 |
template <typename type4x4>
|
| 51 |
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
|
|
|
| 108 |
}
|
| 109 |
}
|
| 110 |
|
| 111 |
+
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
|
| 112 |
+
float amax = 0.0f; // absolute max
|
| 113 |
+
float max = 0.0f;
|
| 114 |
+
|
| 115 |
+
for (int j = 0; j < QK4_0; j++) {
|
| 116 |
+
const float v = src[j];
|
| 117 |
+
if (amax < fabs(v)) {
|
| 118 |
+
amax = fabs(v);
|
| 119 |
+
max = v;
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
const float d = max / -8;
|
| 124 |
+
const float id = d ? 1.0f/d : 0.0f;
|
| 125 |
+
|
| 126 |
+
dst.d = d;
|
| 127 |
+
|
| 128 |
+
for (int j = 0; j < QK4_0/2; ++j) {
|
| 129 |
+
const float x0 = src[0 + j]*id;
|
| 130 |
+
const float x1 = src[QK4_0/2 + j]*id;
|
| 131 |
+
|
| 132 |
+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
|
| 133 |
+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
|
| 134 |
+
|
| 135 |
+
dst.qs[j] = xi0;
|
| 136 |
+
dst.qs[j] |= xi1 << 4;
|
| 137 |
+
}
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
|
| 141 |
+
float min = FLT_MAX;
|
| 142 |
+
float max = -FLT_MAX;
|
| 143 |
+
|
| 144 |
+
for (int j = 0; j < QK4_1; j++) {
|
| 145 |
+
const float v = src[j];
|
| 146 |
+
if (min > v) min = v;
|
| 147 |
+
if (max < v) max = v;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
const float d = (max - min) / ((1 << 4) - 1);
|
| 151 |
+
const float id = d ? 1.0f/d : 0.0f;
|
| 152 |
+
|
| 153 |
+
dst.d = d;
|
| 154 |
+
dst.m = min;
|
| 155 |
+
|
| 156 |
+
for (int j = 0; j < QK4_1/2; ++j) {
|
| 157 |
+
const float x0 = (src[0 + j] - min)*id;
|
| 158 |
+
const float x1 = (src[QK4_1/2 + j] - min)*id;
|
| 159 |
+
|
| 160 |
+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
|
| 161 |
+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
|
| 162 |
+
|
| 163 |
+
dst.qs[j] = xi0;
|
| 164 |
+
dst.qs[j] |= xi1 << 4;
|
| 165 |
+
}
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
|
| 169 |
+
float amax = 0.0f; // absolute max
|
| 170 |
+
float max = 0.0f;
|
| 171 |
+
|
| 172 |
+
for (int j = 0; j < QK5_0; j++) {
|
| 173 |
+
const float v = src[j];
|
| 174 |
+
if (amax < fabs(v)) {
|
| 175 |
+
amax = fabs(v);
|
| 176 |
+
max = v;
|
| 177 |
+
}
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
const float d = max / -16;
|
| 181 |
+
const float id = d ? 1.0f/d : 0.0f;
|
| 182 |
+
|
| 183 |
+
dst.d = d;
|
| 184 |
+
|
| 185 |
+
uint32_t qh = 0;
|
| 186 |
+
for (int j = 0; j < QK5_0/2; ++j) {
|
| 187 |
+
const float x0 = src[0 + j]*id;
|
| 188 |
+
const float x1 = src[QK5_0/2 + j]*id;
|
| 189 |
+
|
| 190 |
+
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
|
| 191 |
+
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
|
| 192 |
+
|
| 193 |
+
dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
| 194 |
+
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
| 195 |
+
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
| 199 |
+
|
| 200 |
+
for (int j = 0; j < 4; ++j) {
|
| 201 |
+
dst.qh[j] = qh8[j];
|
| 202 |
+
}
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
|
| 206 |
+
float max = src[0];
|
| 207 |
+
float min = src[0];
|
| 208 |
+
|
| 209 |
+
for (int j = 1; j < QK5_1; j++) {
|
| 210 |
+
const float v = src[j];
|
| 211 |
+
min = v < min ? v : min;
|
| 212 |
+
max = v > max ? v : max;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
const float d = (max - min) / 31;
|
| 216 |
+
const float id = d ? 1.0f/d : 0.0f;
|
| 217 |
+
|
| 218 |
+
dst.d = d;
|
| 219 |
+
dst.m = min;
|
| 220 |
+
|
| 221 |
+
uint32_t qh = 0;
|
| 222 |
+
for (int j = 0; j < QK5_1/2; ++j) {
|
| 223 |
+
const float x0 = (src[0 + j] - min)*id;
|
| 224 |
+
const float x1 = (src[QK5_1/2 + j] - min)*id;
|
| 225 |
+
|
| 226 |
+
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
|
| 227 |
+
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
|
| 228 |
+
|
| 229 |
+
dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
| 230 |
+
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
| 231 |
+
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
| 235 |
+
|
| 236 |
+
for (int j = 0; j < 4; ++j) {
|
| 237 |
+
dst.qh[j] = qh8[j];
|
| 238 |
+
}
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
|
| 242 |
+
float amax = 0.0f; // absolute max
|
| 243 |
+
float max = 0.0f;
|
| 244 |
+
|
| 245 |
+
for (int j = 0; j < QK4_NL; j++) {
|
| 246 |
+
const float v = src[j];
|
| 247 |
+
if (amax < fabs(v)) {
|
| 248 |
+
amax = fabs(v);
|
| 249 |
+
max = v;
|
| 250 |
+
}
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
const float d = max / kvalues_iq4nl_f[0];
|
| 254 |
+
const float id = d ? 1.0f/d : 0.0f;
|
| 255 |
+
|
| 256 |
+
float sumqx = 0, sumq2 = 0;
|
| 257 |
+
for (int j = 0; j < QK4_NL/2; ++j) {
|
| 258 |
+
const float x0 = src[0 + j]*id;
|
| 259 |
+
const float x1 = src[QK4_NL/2 + j]*id;
|
| 260 |
+
|
| 261 |
+
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
|
| 262 |
+
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
|
| 263 |
+
|
| 264 |
+
dst.qs[j] = xi0 | (xi1 << 4);
|
| 265 |
+
|
| 266 |
+
const float v0 = kvalues_iq4nl_f[xi0];
|
| 267 |
+
const float v1 = kvalues_iq4nl_f[xi1];
|
| 268 |
+
const float w0 = src[0 + j]*src[0 + j];
|
| 269 |
+
const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
|
| 270 |
+
sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
|
| 271 |
+
sumq2 += w0*v0*v0 + w1*v1*v1;
|
| 272 |
+
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
template <typename type4x4>
|
| 279 |
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
|
| 280 |
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
|
|
|
| 457 |
}
|
| 458 |
}
|
| 459 |
|
| 460 |
+
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
|
| 461 |
+
float amax = 0.0f; // absolute max
|
| 462 |
+
|
| 463 |
+
for (int j = 0; j < QK8_0; j++) {
|
| 464 |
+
const float v = src[j];
|
| 465 |
+
amax = MAX(amax, fabs(v));
|
| 466 |
+
}
|
| 467 |
+
|
| 468 |
+
const float d = amax / ((1 << 7) - 1);
|
| 469 |
+
const float id = d ? 1.0f/d : 0.0f;
|
| 470 |
+
|
| 471 |
+
dst.d = d;
|
| 472 |
+
|
| 473 |
+
for (int j = 0; j < QK8_0; ++j) {
|
| 474 |
+
const float x0 = src[j]*id;
|
| 475 |
+
|
| 476 |
+
dst.qs[j] = round(x0);
|
| 477 |
+
}
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
template <typename type4x4>
|
| 481 |
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
| 482 |
const float d = xb->d;
|
|
|
|
| 4608 |
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
|
| 4609 |
#endif
|
| 4610 |
|
| 4611 |
+
// TODO: templetify these kernels
|
| 4612 |
kernel void kernel_cpy_f32_q8_0(
|
| 4613 |
constant ggml_metal_kargs_cpy & args,
|
| 4614 |
device const char * src0,
|
|
|
|
| 4632 |
for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
|
| 4633 |
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
| 4634 |
|
| 4635 |
+
quantize_q8_0(src, dst_data[i00/QK8_0]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4636 |
}
|
| 4637 |
}
|
| 4638 |
|
|
|
|
| 4659 |
for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
|
| 4660 |
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
| 4661 |
|
| 4662 |
+
quantize_q4_0(src, dst_data[i00/QK4_0]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4663 |
}
|
| 4664 |
}
|
| 4665 |
|
|
|
|
| 4686 |
for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
|
| 4687 |
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
| 4688 |
|
| 4689 |
+
quantize_q4_1(src, dst_data[i00/QK4_1]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4690 |
}
|
| 4691 |
}
|
| 4692 |
|
|
|
|
| 4713 |
for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
|
| 4714 |
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
| 4715 |
|
| 4716 |
+
quantize_q5_0(src, dst_data[i00/QK5_0]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4717 |
}
|
| 4718 |
}
|
| 4719 |
|
|
|
|
| 4740 |
for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
|
| 4741 |
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
| 4742 |
|
| 4743 |
+
quantize_q5_1(src, dst_data[i00/QK5_1]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4744 |
}
|
|
|
|
| 4745 |
}
|
| 4746 |
|
| 4747 |
kernel void kernel_cpy_f32_iq4_nl(
|
|
|
|
| 4767 |
for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
|
| 4768 |
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
| 4769 |
|
| 4770 |
+
quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4771 |
}
|
| 4772 |
}
|
| 4773 |
|
|
|
|
| 6448 |
|
| 6449 |
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
| 6450 |
kernel void kernel_get_rows_q(
|
| 6451 |
+
constant ggml_metal_kargs_get_rows & args,
|
| 6452 |
device const void * src0,
|
| 6453 |
device const void * src1,
|
| 6454 |
device float * dst,
|
|
|
|
| 6455 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6456 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6457 |
uint3 tptg [[threads_per_threadgroup]]) {
|
|
|
|
| 6471 |
|
| 6472 |
template<typename T>
|
| 6473 |
kernel void kernel_get_rows_f(
|
| 6474 |
+
constant ggml_metal_kargs_get_rows & args,
|
| 6475 |
device const void * src0,
|
| 6476 |
device const void * src1,
|
| 6477 |
device float * dst,
|
|
|
|
| 6478 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6479 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6480 |
uint3 tptg [[threads_per_threadgroup]]) {
|
|
|
|
| 6492 |
}
|
| 6493 |
|
| 6494 |
kernel void kernel_get_rows_i32(
|
| 6495 |
+
constant ggml_metal_kargs_get_rows & args,
|
| 6496 |
device const void * src0,
|
| 6497 |
device const void * src1,
|
| 6498 |
device int32_t * dst,
|
|
|
|
| 6499 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6500 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6501 |
uint3 tptg [[threads_per_threadgroup]]) {
|
|
|
|
| 6512 |
}
|
| 6513 |
}
|
| 6514 |
|
| 6515 |
+
template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
|
| 6516 |
+
kernel void kernel_set_rows_q32(
|
| 6517 |
+
constant ggml_metal_kargs_set_rows & args,
|
| 6518 |
+
device const void * src0,
|
| 6519 |
+
device const void * src1,
|
| 6520 |
+
device float * dst,
|
| 6521 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6522 |
+
uint tiitg[[thread_index_in_threadgroup]],
|
| 6523 |
+
uint3 tptg [[threads_per_threadgroup]]) {
|
| 6524 |
+
const int32_t i03 = tgpig.z;
|
| 6525 |
+
const int32_t i02 = tgpig.y;
|
| 6526 |
+
|
| 6527 |
+
const int32_t i12 = i03%args.ne12;
|
| 6528 |
+
const int32_t i11 = i02%args.ne11;
|
| 6529 |
+
|
| 6530 |
+
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
|
| 6531 |
+
if (i01 >= args.ne01) {
|
| 6532 |
+
return;
|
| 6533 |
+
}
|
| 6534 |
+
|
| 6535 |
+
const int32_t i10 = i01;
|
| 6536 |
+
const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
|
| 6537 |
+
|
| 6538 |
+
device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
| 6539 |
+
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
| 6540 |
+
|
| 6541 |
+
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
|
| 6542 |
+
quantize_func(src_row + 32*ind, dst_row[ind]);
|
| 6543 |
+
}
|
| 6544 |
+
}
|
| 6545 |
+
|
| 6546 |
+
template<typename T>
|
| 6547 |
+
kernel void kernel_set_rows_f(
|
| 6548 |
+
constant ggml_metal_kargs_set_rows & args,
|
| 6549 |
+
device const void * src0,
|
| 6550 |
+
device const void * src1,
|
| 6551 |
+
device float * dst,
|
| 6552 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6553 |
+
uint tiitg[[thread_index_in_threadgroup]],
|
| 6554 |
+
uint3 tptg [[threads_per_threadgroup]]) {
|
| 6555 |
+
const int32_t i03 = tgpig.z;
|
| 6556 |
+
const int32_t i02 = tgpig.y;
|
| 6557 |
+
|
| 6558 |
+
const int32_t i12 = i03%args.ne12;
|
| 6559 |
+
const int32_t i11 = i02%args.ne11;
|
| 6560 |
+
|
| 6561 |
+
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
|
| 6562 |
+
if (i01 >= args.ne01) {
|
| 6563 |
+
return;
|
| 6564 |
+
}
|
| 6565 |
+
|
| 6566 |
+
const int32_t i10 = i01;
|
| 6567 |
+
const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
|
| 6568 |
+
|
| 6569 |
+
device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
| 6570 |
+
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
| 6571 |
+
|
| 6572 |
+
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
|
| 6573 |
+
dst_row[ind] = (T) src_row[ind];
|
| 6574 |
+
}
|
| 6575 |
+
}
|
| 6576 |
|
| 6577 |
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
| 6578 |
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
|
|
|
| 6996 |
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
|
| 6997 |
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
| 6998 |
|
| 6999 |
+
//
|
| 7000 |
+
// set rows
|
| 7001 |
+
//
|
| 7002 |
+
|
| 7003 |
+
typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
|
| 7004 |
+
|
| 7005 |
+
template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f<float>;
|
| 7006 |
+
template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f<half>;
|
| 7007 |
+
#if defined(GGML_METAL_USE_BF16)
|
| 7008 |
+
template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
|
| 7009 |
+
#endif
|
| 7010 |
+
|
| 7011 |
+
typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;
|
| 7012 |
+
|
| 7013 |
+
template [[host_name("kernel_set_rows_q8_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0, quantize_q8_0>;
|
| 7014 |
+
template [[host_name("kernel_set_rows_q4_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0, quantize_q4_0>;
|
| 7015 |
+
template [[host_name("kernel_set_rows_q4_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1, quantize_q4_1>;
|
| 7016 |
+
template [[host_name("kernel_set_rows_q5_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0, quantize_q5_0>;
|
| 7017 |
+
template [[host_name("kernel_set_rows_q5_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1, quantize_q5_1>;
|
| 7018 |
+
template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
|
| 7019 |
+
|
| 7020 |
//
|
| 7021 |
// matrix-matrix multiplication
|
| 7022 |
//
|
|
@@ -933,6 +933,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
| 933 |
"TRANSPOSE",
|
| 934 |
"GET_ROWS",
|
| 935 |
"GET_ROWS_BACK",
|
|
|
|
| 936 |
"DIAG",
|
| 937 |
"DIAG_MASK_INF",
|
| 938 |
"DIAG_MASK_ZERO",
|
|
@@ -983,7 +984,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
| 983 |
"OPT_STEP_ADAMW",
|
| 984 |
};
|
| 985 |
|
| 986 |
-
static_assert(GGML_OP_COUNT ==
|
| 987 |
|
| 988 |
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
| 989 |
"none",
|
|
@@ -1029,6 +1030,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
| 1029 |
"transpose(x)",
|
| 1030 |
"get_rows(x)",
|
| 1031 |
"get_rows_back(x)",
|
|
|
|
| 1032 |
"diag(x)",
|
| 1033 |
"diag_mask_inf(x)",
|
| 1034 |
"diag_mask_zero(x)",
|
|
@@ -1079,7 +1081,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
| 1079 |
"adamw(x)",
|
| 1080 |
};
|
| 1081 |
|
| 1082 |
-
static_assert(GGML_OP_COUNT ==
|
| 1083 |
|
| 1084 |
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
| 1085 |
|
|
@@ -1348,6 +1350,12 @@ bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
|
|
| 1348 |
tensor->nb[2] == ggml_type_size(tensor->type);
|
| 1349 |
}
|
| 1350 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1351 |
static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
|
| 1352 |
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
| 1353 |
|
|
@@ -3384,6 +3392,35 @@ struct ggml_tensor * ggml_get_rows_back(
|
|
| 3384 |
return result;
|
| 3385 |
}
|
| 3386 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3387 |
// ggml_diag
|
| 3388 |
|
| 3389 |
struct ggml_tensor * ggml_diag(
|
|
|
|
| 933 |
"TRANSPOSE",
|
| 934 |
"GET_ROWS",
|
| 935 |
"GET_ROWS_BACK",
|
| 936 |
+
"SET_ROWS",
|
| 937 |
"DIAG",
|
| 938 |
"DIAG_MASK_INF",
|
| 939 |
"DIAG_MASK_ZERO",
|
|
|
|
| 984 |
"OPT_STEP_ADAMW",
|
| 985 |
};
|
| 986 |
|
| 987 |
+
static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
|
| 988 |
|
| 989 |
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
| 990 |
"none",
|
|
|
|
| 1030 |
"transpose(x)",
|
| 1031 |
"get_rows(x)",
|
| 1032 |
"get_rows_back(x)",
|
| 1033 |
+
"set_rows(x)",
|
| 1034 |
"diag(x)",
|
| 1035 |
"diag_mask_inf(x)",
|
| 1036 |
"diag_mask_zero(x)",
|
|
|
|
| 1081 |
"adamw(x)",
|
| 1082 |
};
|
| 1083 |
|
| 1084 |
+
static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
|
| 1085 |
|
| 1086 |
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
| 1087 |
|
|
|
|
| 1350 |
tensor->nb[2] == ggml_type_size(tensor->type);
|
| 1351 |
}
|
| 1352 |
|
| 1353 |
+
bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor) {
|
| 1354 |
+
return
|
| 1355 |
+
tensor->ne[0] == ggml_blck_size(tensor->type) ||
|
| 1356 |
+
tensor->nb[0] == ggml_type_size(tensor->type);
|
| 1357 |
+
}
|
| 1358 |
+
|
| 1359 |
static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
|
| 1360 |
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
| 1361 |
|
|
|
|
| 3392 |
return result;
|
| 3393 |
}
|
| 3394 |
|
| 3395 |
+
// ggml_set_rows
|
| 3396 |
+
|
| 3397 |
+
struct ggml_tensor * ggml_set_rows(
|
| 3398 |
+
struct ggml_context * ctx,
|
| 3399 |
+
struct ggml_tensor * a,
|
| 3400 |
+
struct ggml_tensor * b,
|
| 3401 |
+
struct ggml_tensor * c) {
|
| 3402 |
+
GGML_ASSERT(a->ne[0] == b->ne[0]);
|
| 3403 |
+
GGML_ASSERT(a->ne[2] == b->ne[2]);
|
| 3404 |
+
GGML_ASSERT(a->ne[3] == b->ne[3]);
|
| 3405 |
+
GGML_ASSERT(b->ne[1] == c->ne[0]);
|
| 3406 |
+
GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
|
| 3407 |
+
GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
|
| 3408 |
+
GGML_ASSERT(c->ne[3] == 1);
|
| 3409 |
+
GGML_ASSERT(b->type == GGML_TYPE_F32);
|
| 3410 |
+
GGML_ASSERT(c->type == GGML_TYPE_I64);
|
| 3411 |
+
|
| 3412 |
+
GGML_ASSERT(ggml_is_contiguous_rows(a));
|
| 3413 |
+
GGML_ASSERT(ggml_is_contiguous_rows(b));
|
| 3414 |
+
|
| 3415 |
+
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
| 3416 |
+
|
| 3417 |
+
result->op = GGML_OP_SET_ROWS;
|
| 3418 |
+
result->src[0] = b;
|
| 3419 |
+
result->src[1] = c;
|
| 3420 |
+
|
| 3421 |
+
return result;
|
| 3422 |
+
}
|
| 3423 |
+
|
| 3424 |
// ggml_diag
|
| 3425 |
|
| 3426 |
struct ggml_tensor * ggml_diag(
|