rgerganov and ggerganov committed
Commit ac46a22 · Parent: 7988638

ggml : add ggml_set_rows (llama/14274)


* ggml : add ggml_set_rows

Add ggml_set_rows(a, b, c) which copies rows from 'b' into 'a' using
indices from 'c'.

ref: #8366

* use I64 for indices

* ggml : add repeat impl for i64

* ggml : add ggml_is_contiguous_rows

* ggml : ggml_set_rows support broadcast

* ggml : ggml_set_rows support quantized dst

ggml-ci

* ggml : support GGML_TYPE_F32 ".from_float" trait

* ggml : ggml_set_rows update comment + better index name

* tests : add ggml_set_rows

* metal : add ggml_set_rows implementation

ggml-ci

* ggml : simplify forward_dup_f32

* ggml : fix supports_op

* tests : add comment to set_rows

* ggml : leave the repeat_i64 for a separate PR

ggml-ci

* ggml : set_rows use std::min instead of MIN

* ggml : better error message for set_rows unsupported type

* metal : perform op->type check only once

* tests : more consistent implementation + more tests

ggml-ci

---------

Co-authored-by: Georgi Gerganov <[email protected]>
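
For orientation, here is a minimal usage sketch of the new API in C. The context setup and all shapes are illustrative assumptions, not taken from this commit; only the ggml_set_rows(ctx, a, b, c) signature and its semantics come from the change below:

    // Hypothetical example: scatter 4 source rows into a 16-row destination.
    // `ctx` is assumed to be an already-initialized ggml context.
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 16); // destination [n_embd = 64, ne1 = 16]
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64,  4); // 4 source rows to write
    struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_I64,  4);     // row indices, each in [0, 16)

    // Returns a view of 'a'; row b[r] is copied (and quantized, when 'a'
    // has a quantized type) into a[c[r]].
    struct ggml_tensor * out = ggml_set_rows(ctx, a, b, c);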

ggml/include/ggml-cpu.h CHANGED
@@ -134,6 +134,7 @@ extern "C" {
 
     GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
 
+    GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);
ggml/include/ggml.h CHANGED
@@ -470,6 +470,7 @@ extern "C" {
         GGML_OP_TRANSPOSE,
         GGML_OP_GET_ROWS,
         GGML_OP_GET_ROWS_BACK,
+        GGML_OP_SET_ROWS,
         GGML_OP_DIAG,
         GGML_OP_DIAG_MASK_INF,
         GGML_OP_DIAG_MASK_ZERO,
@@ -687,6 +688,9 @@ extern "C" {
     // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
     GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);
 
+    // true if the elements in dimension 0 are contiguous, or there is just 1 block of elements
+    GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor);
+
     GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
     GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
 
@@ -1375,6 +1379,23 @@ extern "C" {
             struct ggml_tensor  * b,  // row indices
             struct ggml_tensor  * c); // data for ggml_get_rows, only used for its shape
 
+    // a TD  [n_embd, ne1,    ne2,  ne3]
+    // b TS  [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3
+    // c I64 [n_rows, ne11,   ne12, 1]    | c[i] in [0, ne1)
+    //
+    // undefined behavior if destination rows overlap
+    //
+    // broadcast:
+    //   ne2 % ne11 == 0
+    //   ne3 % ne12 == 0
+    //
+    // return view(a)
+    GGML_API struct ggml_tensor * ggml_set_rows(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,  // destination
+            struct ggml_tensor  * b,  // source
+            struct ggml_tensor  * c); // row indices
+
     GGML_API struct ggml_tensor * ggml_diag(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
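
To make the broadcast rules in the header comment concrete, here is a small hedged sketch of the index arithmetic; it mirrors the CPU loop added in ops.cpp further down, and the names follow the comment above:

    // Illustrative: with ne2 = 4 destination planes and ne11 = 2 index
    // columns (4 % 2 == 0), each plane i2 reuses index column i2 % ne11.
    for (int64_t i3 = 0; i3 < ne3; ++i3) {
        for (int64_t i2 = 0; i2 < ne2; ++i2) {
            const int64_t i12 = i3 % ne12; // index plane reused along dim 3
            const int64_t i11 = i2 % ne11; // index column reused along dim 2
            // row r of b[:, :, i2, i3] is written to dst row c[r, i11, i12]
        }
    }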
ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -195,6 +195,7 @@ typedef pthread_t ggml_thread_t;
 
 static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32] = {
+        .from_float               = (ggml_from_float_t) ggml_cpu_fp32_to_fp32,
         .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,
         .vec_dot_type             = GGML_TYPE_F32,
         .nrows                    = 1,
@@ -1817,6 +1818,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_get_rows_back(params, tensor);
             } break;
+        case GGML_OP_SET_ROWS:
+            {
+                ggml_compute_forward_set_rows(params, tensor);
+            } break;
         case GGML_OP_DIAG:
             {
                 ggml_compute_forward_diag(params, tensor);
@@ -2170,6 +2175,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_GET_ROWS:
+        case GGML_OP_SET_ROWS:
            {
                // FIXME: get_rows can use additional threads, but the cost of launching additional threads
                // decreases performance with GPU offloading
@@ -3124,6 +3130,10 @@ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct g
     return ggml_graph_compute(cgraph, &cplan);
 }
 
+void ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) {
+    memcpy(y, x, n * sizeof(float));
+}
+
 void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {
     int64_t i = 0;
 #if defined(__F16C__)
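
The trivial ggml_cpu_fp32_to_fp32 exists so that F32 gains a ".from_float" trait like every other destination type; callers can then dispatch through one path instead of special-casing F32, which is exactly what the simplified forward_dup_f32 below relies on. A sketch of the resulting call site:

    // Single dispatch for any dst type: a memcpy for F32, quantization otherwise.
    ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
    from_float(src_row, dst_row, ne00);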
ggml/src/ggml-cpu/ggml-cpu.cpp CHANGED
@@ -416,6 +416,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
 
     switch (op->op) {
         case GGML_OP_CPY:
+        case GGML_OP_SET_ROWS:
             return
                 op->type != GGML_TYPE_IQ3_XXS &&
                 op->type != GGML_TYPE_IQ3_S   &&
ggml/src/ggml-cpu/ops.cpp CHANGED
@@ -696,24 +696,8 @@ static void ggml_compute_forward_dup_f32(
     if (ggml_is_contiguous(dst)) {
         // TODO: simplify
         if (nb00 == sizeof(float)) {
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                const size_t rs = ne00 * nb00;
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, rs);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
-                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+            if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+                ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
 
                 size_t id = 0;
                 size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
@@ -724,7 +708,7 @@ static void ggml_compute_forward_dup_f32(
                         id += rs * ir0;
                         for (int i01 = ir0; i01 < ir1; i01++) {
                             const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+                            from_float(src0_ptr, dst_ptr + id, ne00);
                             id += rs;
                         }
                         id += rs * (ne01 - ir1);
@@ -2300,6 +2284,12 @@ void ggml_compute_forward_repeat(
             {
                 ggml_compute_forward_repeat_f32(params, dst);
             } break;
+        // TODO: templateify the implementation and support for I64
+        //       ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
+        //case GGML_TYPE_I64:
+        //    {
+        //        ggml_compute_forward_repeat_i64(params, dst);
+        //    } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -4470,6 +4460,74 @@ void ggml_compute_forward_get_rows(
     //}
 }
 
+static void ggml_compute_forward_set_rows_f32(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ne01;
+
+    assert(ne0  == nc);
+    assert(ne2  == ne02);
+    assert(ne3  == ne03);
+    assert(src0->type == GGML_TYPE_F32);
+    assert(ne02 % ne11 == 0);
+    assert(ne03 % ne12 == 0);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = std::min(ir0 + dr, nr);
+
+    ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
+
+    for (int64_t i03 = 0; i03 < ne03; ++i03) {
+        for (int64_t i02 = 0; i02 < ne02; ++i02) {
+            for (int64_t i = ir0; i < ir1; ++i) {
+                const int64_t i12 = i03%ne12;
+                const int64_t i11 = i02%ne11;
+                const int64_t i10 = i;
+
+                const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+                GGML_ASSERT(i1 >= 0 && i1 < ne1);
+
+                from_float(
+                    (const float *) ((char *) src0->data +  i*nb01 + i02*nb02 + i03*nb03),
+                          ((char *) dst->data  + i1*nb1  + i02*nb2  + i03*nb3), nc);
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_set_rows(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_set_rows_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type));
+            }
+    }
+}
+
 // ggml_compute_forward_get_rows_back
 
 static void ggml_compute_forward_get_rows_back_f32_f16(
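
The row partitioning in ggml_compute_forward_set_rows_f32 is a plain ceiling division across the thread pool; a worked example for the illustrative values nr = 10 rows and nth = 4 threads:

    const int64_t dr  = (nr + nth - 1)/nth; // ceil(10/4) = 3 rows per thread
    const int64_t ir0 = dr*ith;             // thread 0 -> rows [0,3), thread 1 -> [3,6),
                                            // thread 2 -> [6,9), thread 3 -> [9,10)
    const int64_t ir1 = ir0 + dr < nr ? ir0 + dr : nr; // std::min clamps the last thread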
ggml/src/ggml-cpu/ops.h CHANGED
@@ -53,6 +53,7 @@ void ggml_compute_forward_permute(const struct ggml_compute_params * params, str
 void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_diag(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_diag_mask_inf(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst);
ggml/src/ggml-metal/ggml-metal-impl.h CHANGED
@@ -521,6 +521,22 @@ typedef struct {
     uint64_t nb2;
 } ggml_metal_kargs_get_rows;
 
+typedef struct {
+    int32_t  nk0;
+    int32_t  ne01;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne11;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_set_rows;
+
 typedef struct {
     int64_t  ne00;
     int64_t  ne01;
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -202,6 +202,15 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_L2_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -1169,6 +1178,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
@@ -1635,6 +1653,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     const bool use_bfloat = ctx_dev->use_bfloat;
 
     if (!use_bfloat) {
+        if (op->type == GGML_TYPE_BF16) {
+            return false;
+        }
+
         for (size_t i = 0, n = 3; i < n; ++i) {
             if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
                 return false;
@@ -1804,6 +1826,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             {
                 return op->ne[3] == 1;
             }
+        case GGML_OP_SET_ROWS:
+            {
+                if (op->src[0]->type != GGML_TYPE_F32) {
+                    return false;
+                }
+
+                switch (op->type) {
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_BF16:
+                    case GGML_TYPE_Q8_0:
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_IQ4_NL:
+                        return true;
+                    default:
+                        return false;
+                };
+            }
         default:
             return false;
     }
@@ -3777,13 +3820,74 @@ static bool ggml_metal_encode_node(
            };
 
            [encoder setComputePipelineState:pipeline];
-           [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-           [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-           [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-           [encoder setBytes:&args length:sizeof(args) atIndex:3];
+           [encoder setBytes:&args length:sizeof(args) atIndex:0];
+           [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+           [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+           [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
            [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
        } break;
+   case GGML_OP_SET_ROWS:
+       {
+           id<MTLComputePipelineState> pipeline = nil;
+
+           switch (dst->type) {
+               case GGML_TYPE_F32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32   ].pipeline; break;
+               case GGML_TYPE_F16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16   ].pipeline; break;
+               case GGML_TYPE_BF16:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16  ].pipeline; break;
+               case GGML_TYPE_Q8_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0  ].pipeline; break;
+               case GGML_TYPE_Q4_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0  ].pipeline; break;
+               case GGML_TYPE_Q4_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1  ].pipeline; break;
+               case GGML_TYPE_Q5_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0  ].pipeline; break;
+               case GGML_TYPE_Q5_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1  ].pipeline; break;
+               case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
+               default: GGML_ABORT("not implemented");
+           }
+
+           const int32_t nk0 = ne0/ggml_blck_size(dst->type);
+
+           int nth = 32; // SIMD width
+
+           while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+               nth *= 2;
+           }
+
+           int nrptg = 1;
+           if (nth > nk0) {
+               nrptg = (nth + nk0 - 1)/nk0;
+               nth   = nk0;
+
+               if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                   nrptg--;
+               }
+           }
+
+           nth = MIN(nth, nk0);
+
+           ggml_metal_kargs_set_rows args = {
+               /*.nk0  =*/ nk0,
+               /*.ne01 =*/ ne01,
+               /*.nb01 =*/ nb01,
+               /*.nb02 =*/ nb02,
+               /*.nb03 =*/ nb03,
+               /*.ne11 =*/ ne11,
+               /*.ne12 =*/ ne12,
+               /*.nb10 =*/ nb10,
+               /*.nb11 =*/ nb11,
+               /*.nb12 =*/ nb12,
+               /*.nb1  =*/ nb1,
+               /*.nb2  =*/ nb2,
+               /*.nb3  =*/ nb3,
+           };
+
+           [encoder setComputePipelineState:pipeline];
+           [encoder setBytes:&args length:sizeof(args) atIndex:0];
+           [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+           [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+           [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
+
+           [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
+       } break;
    case GGML_OP_RMS_NORM:
        {
            GGML_ASSERT(ne00 % 4 == 0);
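
The threadgroup sizing above packs several short rows into one threadgroup. Restated as a standalone C helper for clarity (max_threads stands in for pipeline.maxTotalThreadsPerThreadgroup, and the function name is made up for this sketch):

    static void set_rows_launch_dims(int nk0, int max_threads, int * nth, int * nrptg) {
        *nth = 32; // start at the SIMD width
        while (*nth < nk0 && *nth < max_threads) {
            *nth *= 2; // grow until one row's nk0 blocks are covered
        }
        *nrptg = 1;
        if (*nth > nk0) {
            *nrptg = (*nth + nk0 - 1)/nk0; // pack multiple rows per threadgroup
            *nth   = nk0;
            if (*nrptg * *nth > max_threads) {
                (*nrptg)--;
            }
        }
        if (*nth > nk0) {
            *nth = nk0; // MIN(nth, nk0)
        }
    }
    // e.g. a Q8_0 dst with ne0 = 64 gives nk0 = 64/32 = 2 blocks per row:
    // nth = 2, nrptg = 16, i.e. 16 rows per 32-thread threadgroup,
    // with one Q8_0 block quantized per thread.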
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -35,6 +35,17 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
    -127.f, -104.f,  -83.f,  -65.f,  -49.f,  -35.f,  -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
 
+static inline int best_index_int8(int n, constant float * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
 // NOTE: this is not dequantizing - we are simply fitting the template
 template <typename type4x4>
 void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -97,6 +108,173 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
     }
 }
 
+void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
+    float amax = 0.0f; // absolute max
+    float max  = 0.0f;
+
+    for (int j = 0; j < QK4_0; j++) {
+        const float v = src[j];
+        if (amax < fabs(v)) {
+            amax = fabs(v);
+            max  = v;
+        }
+    }
+
+    const float d  = max / -8;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    for (int j = 0; j < QK4_0/2; ++j) {
+        const float x0 = src[0       + j]*id;
+        const float x1 = src[QK4_0/2 + j]*id;
+
+        const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+        const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+        dst.qs[j]  = xi0;
+        dst.qs[j] |= xi1 << 4;
+    }
+}
+
+void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
+    float min = FLT_MAX;
+    float max = -FLT_MAX;
+
+    for (int j = 0; j < QK4_1; j++) {
+        const float v = src[j];
+        if (min > v) min = v;
+        if (max < v) max = v;
+    }
+
+    const float d  = (max - min) / ((1 << 4) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+    dst.m = min;
+
+    for (int j = 0; j < QK4_1/2; ++j) {
+        const float x0 = (src[0       + j] - min)*id;
+        const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+        const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+        const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+        dst.qs[j]  = xi0;
+        dst.qs[j] |= xi1 << 4;
+    }
+}
+
+void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
+    float amax = 0.0f; // absolute max
+    float max  = 0.0f;
+
+    for (int j = 0; j < QK5_0; j++) {
+        const float v = src[j];
+        if (amax < fabs(v)) {
+            amax = fabs(v);
+            max  = v;
+        }
+    }
+
+    const float d  = max / -16;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_0/2; ++j) {
+        const float x0 = src[0       + j]*id;
+        const float x1 = src[QK5_0/2 + j]*id;
+
+        const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+        const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+        dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+    }
+
+    thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+
+    for (int j = 0; j < 4; ++j) {
+        dst.qh[j] = qh8[j];
+    }
+}
+
+void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
+    float max = src[0];
+    float min = src[0];
+
+    for (int j = 1; j < QK5_1; j++) {
+        const float v = src[j];
+        min = v < min ? v : min;
+        max = v > max ? v : max;
+    }
+
+    const float d  = (max - min) / 31;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+    dst.m = min;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_1/2; ++j) {
+        const float x0 = (src[0       + j] - min)*id;
+        const float x1 = (src[QK5_1/2 + j] - min)*id;
+
+        const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+        const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+        dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+    }
+
+    thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+
+    for (int j = 0; j < 4; ++j) {
+        dst.qh[j] = qh8[j];
+    }
+}
+
+void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
+    float amax = 0.0f; // absolute max
+    float max  = 0.0f;
+
+    for (int j = 0; j < QK4_NL; j++) {
+        const float v = src[j];
+        if (amax < fabs(v)) {
+            amax = fabs(v);
+            max  = v;
+        }
+    }
+
+    const float d  = max / kvalues_iq4nl_f[0];
+    const float id = d ? 1.0f/d : 0.0f;
+
+    float sumqx = 0, sumq2 = 0;
+    for (int j = 0; j < QK4_NL/2; ++j) {
+        const float x0 = src[0        + j]*id;
+        const float x1 = src[QK4_NL/2 + j]*id;
+
+        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
+        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
+
+        dst.qs[j] = xi0 | (xi1 << 4);
+
+        const float v0 = kvalues_iq4nl_f[xi0];
+        const float v1 = kvalues_iq4nl_f[xi1];
+        const float w0 = src[0        + j]*src[0        + j];
+        const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
+        sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
+        sumq2 += w0*v0*v0 + w1*v1*v1;
+
+    }
+
+    dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
+}
+
 template <typename type4x4>
 void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 2);
@@ -279,6 +457,26 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
     }
 }
 
+void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
+    float amax = 0.0f; // absolute max
+
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = src[j];
+        amax = MAX(amax, fabs(v));
+    }
+
+    const float d  = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = src[j]*id;
+
+        dst.qs[j] = round(x0);
+    }
+}
+
 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
     const float d = xb->d;
@@ -4410,6 +4608,7 @@ template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bf
 template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
 #endif
 
+// TODO: templateify these kernels
 kernel void kernel_cpy_f32_q8_0(
         constant ggml_metal_kargs_cpy & args,
         device const char * src0,
@@ -4433,23 +4632,7 @@ kernel void kernel_cpy_f32_q8_0(
     for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float amax = 0.0f; // absolute max
-
-        for (int j = 0; j < QK8_0; j++) {
-            const float v = src[j];
-            amax = MAX(amax, fabs(v));
-        }
-
-        const float d = amax / ((1 << 7) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK8_0].d = d;
-
-        for (int j = 0; j < QK8_0; ++j) {
-            const float x0 = src[j]*id;
-
-            dst_data[i00/QK8_0].qs[j] = round(x0);
-        }
+        quantize_q8_0(src, dst_data[i00/QK8_0]);
     }
 }
@@ -4476,32 +4659,7 @@ kernel void kernel_cpy_f32_q4_0(
     for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float amax = 0.0f; // absolute max
-        float max  = 0.0f;
-
-        for (int j = 0; j < QK4_0; j++) {
-            const float v = src[j];
-            if (amax < fabs(v)) {
-                amax = fabs(v);
-                max  = v;
-            }
-        }
-
-        const float d  = max / -8;
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK4_0].d = d;
-
-        for (int j = 0; j < QK4_0/2; ++j) {
-            const float x0 = src[0       + j]*id;
-            const float x1 = src[QK4_0/2 + j]*id;
-
-            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
-            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
-
-            dst_data[i00/QK4_0].qs[j]  = xi0;
-            dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
-        }
+        quantize_q4_0(src, dst_data[i00/QK4_0]);
     }
 }
@@ -4528,31 +4686,7 @@ kernel void kernel_cpy_f32_q4_1(
     for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float min = FLT_MAX;
-        float max = -FLT_MAX;
-
-        for (int j = 0; j < QK4_1; j++) {
-            const float v = src[j];
-            if (min > v) min = v;
-            if (max < v) max = v;
-        }
-
-        const float d  = (max - min) / ((1 << 4) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK4_1].d = d;
-        dst_data[i00/QK4_1].m = min;
-
-        for (int j = 0; j < QK4_1/2; ++j) {
-            const float x0 = (src[0       + j] - min)*id;
-            const float x1 = (src[QK4_1/2 + j] - min)*id;
-
-            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
-            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
-
-            dst_data[i00/QK4_1].qs[j]  = xi0;
-            dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
-        }
+        quantize_q4_1(src, dst_data[i00/QK4_1]);
     }
 }
@@ -4579,38 +4713,7 @@ kernel void kernel_cpy_f32_q5_0(
     for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float amax = 0.0f; // absolute max
-        float max  = 0.0f;
-
-        for (int j = 0; j < QK5_0; j++) {
-            const float v = src[j];
-            if (amax < fabs(v)) {
-                amax = fabs(v);
-                max  = v;
-            }
-        }
-
-        const float d  = max / -16;
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK5_0].d = d;
-
-        uint32_t qh = 0;
-        for (int j = 0; j < QK5_0/2; ++j) {
-            const float x0 = src[0       + j]*id;
-            const float x1 = src[QK5_0/2 + j]*id;
-
-            const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
-            const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
-
-            dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
-            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
-            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
-        }
-        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
-        for (int j = 0; j < 4; ++j) {
-            dst_data[i00/QK5_0].qh[j] = qh8[j];
-        }
+        quantize_q5_0(src, dst_data[i00/QK5_0]);
     }
 }
@@ -4637,49 +4740,8 @@ kernel void kernel_cpy_f32_q5_1(
     for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float max = src[0];
-        float min = src[0];
-
-        for (int j = 1; j < QK5_1; j++) {
-            const float v = src[j];
-            min = v < min ? v : min;
-            max = v > max ? v : max;
-        }
-
-        const float d  = (max - min) / 31;
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK5_1].d = d;
-        dst_data[i00/QK5_1].m = min;
-
-        uint32_t qh = 0;
-        for (int j = 0; j < QK5_1/2; ++j) {
-            const float x0 = (src[0       + j] - min)*id;
-            const float x1 = (src[QK5_1/2 + j] - min)*id;
-
-            const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
-            const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
-
-            dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
-            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
-            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
-        }
-        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
-        for (int j = 0; j < 4; ++j) {
-            dst_data[i00/QK5_1].qh[j] = qh8[j];
-        }
-    }
-}
-
-static inline int best_index_int8(int n, constant float * val, float x) {
-    if (x <= val[0]) return 0;
-    if (x >= val[n-1]) return n-1;
-    int ml = 0, mu = n-1;
-    while (mu-ml > 1) {
-        int mav = (ml+mu)/2;
-        if (x < val[mav]) mu = mav; else ml = mav;
-    }
-    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+        quantize_q5_1(src, dst_data[i00/QK5_1]);
     }
 }
 
 kernel void kernel_cpy_f32_iq4_nl(
@@ -4705,40 +4767,7 @@ kernel void kernel_cpy_f32_iq4_nl(
     for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float amax = 0.0f; // absolute max
-        float max  = 0.0f;
-
-        for (int j = 0; j < QK4_NL; j++) {
-            const float v = src[j];
-            if (amax < fabs(v)) {
-                amax = fabs(v);
-                max  = v;
-            }
-        }
-
-        const float d  = max / kvalues_iq4nl_f[0];
-        const float id = d ? 1.0f/d : 0.0f;
-
-        float sumqx = 0, sumq2 = 0;
-        for (int j = 0; j < QK4_NL/2; ++j) {
-            const float x0 = src[0        + j]*id;
-            const float x1 = src[QK4_NL/2 + j]*id;
-
-            const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
-            const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
-
-            dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
-
-            const float v0 = kvalues_iq4nl_f[xi0];
-            const float v1 = kvalues_iq4nl_f[xi1];
-            const float w0 = src[0        + j]*src[0        + j];
-            const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
-            sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
-            sumq2 += w0*v0*v0 + w1*v1*v1;
-
-        }
-
-        dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
+        quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
     }
 }
@@ -6419,10 +6448,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows_q(
+        constant ggml_metal_kargs_get_rows & args,
         device const void * src0,
         device const void * src1,
         device float * dst,
-        constant ggml_metal_kargs_get_rows & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint3 tptg [[threads_per_threadgroup]]) {
@@ -6442,10 +6471,10 @@ kernel void kernel_get_rows_q(
 
 template<typename T>
 kernel void kernel_get_rows_f(
+        constant ggml_metal_kargs_get_rows & args,
        device const void * src0,
        device const void * src1,
        device float * dst,
-        constant ggml_metal_kargs_get_rows & args,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
@@ -6463,10 +6492,10 @@ kernel void kernel_get_rows_f(
 }
 
 kernel void kernel_get_rows_i32(
+        constant ggml_metal_kargs_get_rows & args,
        device const void * src0,
        device const void * src1,
        device int32_t * dst,
-        constant ggml_metal_kargs_get_rows & args,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
@@ -6483,6 +6512,67 @@ kernel void kernel_get_rows_i32(
     }
 }
 
+template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
+kernel void kernel_set_rows_q32(
+        constant ggml_metal_kargs_set_rows & args,
+        device const void * src0,
+        device const void * src1,
+        device float * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint3 tptg [[threads_per_threadgroup]]) {
+    const int32_t i03 = tgpig.z;
+    const int32_t i02 = tgpig.y;
+
+    const int32_t i12 = i03%args.ne12;
+    const int32_t i11 = i02%args.ne11;
+
+    const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
+    if (i01 >= args.ne01) {
+        return;
+    }
+
+    const int32_t i10 = i01;
+    const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+
+    device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
+    const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+
+    for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
+        quantize_func(src_row + 32*ind, dst_row[ind]);
+    }
+}
+
+template<typename T>
+kernel void kernel_set_rows_f(
+        constant ggml_metal_kargs_set_rows & args,
+        device const void * src0,
+        device const void * src1,
+        device float * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint3 tptg [[threads_per_threadgroup]]) {
+    const int32_t i03 = tgpig.z;
+    const int32_t i02 = tgpig.y;
+
+    const int32_t i12 = i03%args.ne12;
+    const int32_t i11 = i02%args.ne11;
+
+    const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
+    if (i01 >= args.ne01) {
+        return;
+    }
+
+    const int32_t i10 = i01;
+    const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+
+    device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
+    const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+
+    for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
+        dst_row[ind] = (T) src_row[ind];
+    }
+}
 
 #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
 #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
@@ -6906,6 +6996,27 @@ template [[host_name("kernel_get_rows_iq1_m")]]   kernel get_rows_q_t kernel_get
 template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl,  2,     dequantize_iq4_nl>;
 template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
 
+//
+// set rows
+//
+
+typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
+
+template [[host_name("kernel_set_rows_f32")]]  kernel set_rows_f_t kernel_set_rows_f<float>;
+template [[host_name("kernel_set_rows_f16")]]  kernel set_rows_f_t kernel_set_rows_f<half>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
+#endif
+
+typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;
+
+template [[host_name("kernel_set_rows_q8_0")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0,   quantize_q8_0>;
+template [[host_name("kernel_set_rows_q4_0")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0,   quantize_q4_0>;
+template [[host_name("kernel_set_rows_q4_1")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1,   quantize_q4_1>;
+template [[host_name("kernel_set_rows_q5_0")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0,   quantize_q5_0>;
+template [[host_name("kernel_set_rows_q5_1")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1,   quantize_q5_1>;
+template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
+
 //
 // matrix-matrix multiplication
 //
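
The thread-to-row mapping shared by both set_rows kernels is worth spelling out; a hedged restatement using the illustrative launch dims from the encoder discussion above (tptg = (nth = 2, nrptg = 16, 1)):

    // Each threadgroup covers nrptg consecutive dst rows; within it,
    // tiitg/tptg.x selects the row and tiitg%tptg.x the starting block lane:
    //   i01  = tgpig.x*tptg.y + tiitg/tptg.x; // global destination row
    //   lane = tiitg%tptg.x;                  // first block handled by this thread
    // With nth = 2: threads 0-1 handle row tgpig.x*16 + 0, threads 2-3 handle
    // row +1, and so on, each thread striding over args.nk0 blocks by tptg.x.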
ggml/src/ggml.c CHANGED
@@ -933,6 +933,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "TRANSPOSE",
    "GET_ROWS",
    "GET_ROWS_BACK",
+    "SET_ROWS",
    "DIAG",
    "DIAG_MASK_INF",
    "DIAG_MASK_ZERO",
@@ -983,7 +984,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "OPT_STEP_ADAMW",
};
 
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "none",
@@ -1029,6 +1030,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "transpose(x)",
    "get_rows(x)",
    "get_rows_back(x)",
+    "set_rows(x)",
    "diag(x)",
    "diag_mask_inf(x)",
    "diag_mask_zero(x)",
@@ -1079,7 +1081,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "adamw(x)",
};
 
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1348,6 +1350,12 @@ bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
        tensor->nb[2] == ggml_type_size(tensor->type);
}
 
+bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor) {
+    return
+        tensor->ne[0] == ggml_blck_size(tensor->type) ||
+        tensor->nb[0] == ggml_type_size(tensor->type);
+}
+
 static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
@@ -3384,6 +3392,35 @@ struct ggml_tensor * ggml_get_rows_back(
    return result;
}
 
+// ggml_set_rows
+
+struct ggml_tensor * ggml_set_rows(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c) {
+    GGML_ASSERT(a->ne[0] == b->ne[0]);
+    GGML_ASSERT(a->ne[2] == b->ne[2]);
+    GGML_ASSERT(a->ne[3] == b->ne[3]);
+    GGML_ASSERT(b->ne[1] == c->ne[0]);
+    GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
+    GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
+    GGML_ASSERT(c->ne[3] == 1);
+    GGML_ASSERT(b->type == GGML_TYPE_F32);
+    GGML_ASSERT(c->type == GGML_TYPE_I64);
+
+    GGML_ASSERT(ggml_is_contiguous_rows(a));
+    GGML_ASSERT(ggml_is_contiguous_rows(b));
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op     = GGML_OP_SET_ROWS;
+    result->src[0] = b;
+    result->src[1] = c;
+
+    return result;
+}
+
 // ggml_diag
 
 struct ggml_tensor * ggml_diag(
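
A note on the new predicate: ggml_is_contiguous_rows is deliberately weaker than ggml_is_contiguous. A hedged example of a tensor that passes the row check but not the full one:

    // A view selecting every other row of a [64, 16] F32 tensor:
    //   ne[0] = 64, nb[0] = sizeof(float)   -> each row is contiguous
    //   nb[1] = 2*64*sizeof(float)          -> the rows themselves are strided
    // ggml_is_contiguous(t) is false while ggml_is_contiguous_rows(t) is true,
    // so such a tensor is still a valid ggml_set_rows operand.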