cmdr2 committed
Commit 2b94a24 · 1 Parent(s): 6c8e7ec

Support pure float16 add/sub/mul/div operations in the CUDA (and CPU) backend (ggml/1121)

* Support float16-to-float16 add/sub/mul/div operations in the CUDA backend

* Add fp16 support for add/sub/mul/div on the CPU backend

* Add test cases for fp16 add/sub/mul/div
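
For context, the sketch below shows what these changes enable from the public ggml API: two F16 tensors can now go through add/sub/mul/div directly, with the result kept in F16. This is an illustration only, not code from this commit; the entry points used (ggml_init, ggml_new_tensor_1d, ggml_add, ggml_new_graph, ggml_build_forward_expand, ggml_graph_compute_with_ctx, ggml_fp32_to_fp16) are assumed from ggml's public headers and their exact header locations may differ between ggml versions.

    // Sketch: element-wise add of two F16 tensors on the CPU backend.
    // Assumes the usual ggml public API; header names and entry points may vary.
    #include <stdio.h>
    #include "ggml.h"
    #include "ggml-cpu.h"   // ggml_graph_compute_with_ctx lives here in recent ggml

    int main(void) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(ip);

        const int64_t n = 8;
        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n);
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n);

        // fill the fp16 buffers by converting from float
        ggml_fp16_t * ad = (ggml_fp16_t *) a->data;
        ggml_fp16_t * bd = (ggml_fp16_t *) b->data;
        for (int64_t i = 0; i < n; ++i) {
            ad[i] = ggml_fp32_to_fp16((float) i);
            bd[i] = ggml_fp32_to_fp16(1.0f);
        }

        // dst inherits src0's type, so it stays F16 and dispatches to the fp16 kernels
        struct ggml_tensor * c = ggml_add(ctx, a, b);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, c);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

        printf("c[0] = %f\n", ggml_fp16_to_fp32(((ggml_fp16_t *) c->data)[0]));

        ggml_free(ctx);
        return 0;
    }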

ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -1415,15 +1415,35 @@ inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x)
 inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
+inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
+    for (int i = 0; i < n; ++i) {
+        z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) + GGML_FP16_TO_FP32(y[i]));
+    }
+}
 inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
 inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
 inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
 inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; }
+inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
+    for (int i = 0; i < n; ++i) {
+        z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) - GGML_FP16_TO_FP32(y[i]));
+    }
+}
 inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
 inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
 inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
+inline static void ggml_vec_mul_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
+    for (int i = 0; i < n; ++i) {
+        z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) * GGML_FP16_TO_FP32(y[i]));
+    }
+}
 inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
+inline static void ggml_vec_div_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
+    for (int i = 0; i < n; ++i) {
+        z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) / GGML_FP16_TO_FP32(y[i]));
+    }
+}
 
 static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
     assert(nrc == 1);
@@ -4379,7 +4399,7 @@ static void ggml_compute_forward_add_f16_f16(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -4404,17 +4424,22 @@ static void ggml_compute_forward_add_f16_f16(
 
     if (nb10 == sizeof(ggml_fp16_t)) {
         for (int ir = ir0; ir < ir1; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
-            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
-            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
 
-            for (int i = 0; i < ne0; i++) {
-                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i]));
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
+                ggml_vec_add_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
             }
         }
     }
@@ -5202,6 +5227,62 @@ static void ggml_compute_forward_sub_f32(
     }
 }
 
+static void ggml_compute_forward_sub_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (nb10 == sizeof(ggml_fp16_t)) {
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
+
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
+                ggml_vec_sub_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+            }
+        }
+    } else {
+        // src1 is not contiguous
+        GGML_ABORT("unimplemented error");
+    }
+}
+
 static void ggml_compute_forward_sub(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -5213,6 +5294,10 @@ static void ggml_compute_forward_sub(
             {
                 ggml_compute_forward_sub_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_sub_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -5293,6 +5378,55 @@ static void ggml_compute_forward_mul_f32(
     }
 }
 
+static void ggml_compute_forward_mul_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(ggml_fp16_t)) {
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
+
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0 ; r < nr0; ++r) {
+                ggml_vec_mul_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+            }
+        }
+    } else {
+        // src1 is not contiguous
+        GGML_ABORT("unimplemented error");
+    }
+}
+
 static void ggml_compute_forward_mul(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -5300,13 +5434,17 @@ static void ggml_compute_forward_mul(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now");
+    GGML_ASSERT((src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16) && "only f32/f16 src1 supported for now");
 
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
                 ggml_compute_forward_mul_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_mul_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -5387,6 +5525,55 @@ static void ggml_compute_forward_div_f32(
     }
 }
 
+static void ggml_compute_forward_div_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(ggml_fp16_t)) {
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
+
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
+                ggml_vec_div_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+            }
+        }
+    } else {
+        // src1 is not contiguous
+        GGML_ABORT("unimplemented error");
+    }
+}
+
 static void ggml_compute_forward_div(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -5398,6 +5585,10 @@ static void ggml_compute_forward_div(
             {
                 ggml_compute_forward_div_f32(params, dst);
            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_div_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
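
A note on the indexing used by the new CPU kernels above: each destination row (i01, i02, i03) picks its src1 row by modulo (i13 = i03 % ne13, and likewise for i12, i11), and within a row a src1 segment of length ne10 is reused nr0 = ne00/ne10 times, which is what the ggml_vec_*_f16 calls iterate over. The standalone snippet below replays that per-row repetition on plain float arrays with made-up sizes; it is illustrative only and not part of the commit.

    // Standalone illustration of the row-broadcast rule used by the new fp16
    // CPU kernels: src1 (length ne10) is repeated nr0 = ne00/ne10 times per row.
    // Sizes and values here are hypothetical.
    #include <stdio.h>

    int main(void) {
        const int ne00 = 6, ne10 = 3;
        const int nr0  = ne00 / ne10;   // src1 segment is reused twice

        float src0[6] = {1, 2, 3, 4, 5, 6};
        float src1[3] = {10, 20, 30};
        float dst[6];

        for (int r = 0; r < nr0; ++r) {
            for (int i = 0; i < ne10; ++i) {
                // same access pattern as
                // ggml_vec_add_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr)
                dst[r*ne10 + i] = src0[r*ne10 + i] + src1[i];
            }
        }

        for (int i = 0; i < ne00; ++i) {
            printf("%g ", dst[i]);      // prints: 11 22 33 14 25 36
        }
        printf("\n");
        return 0;
    }
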
ggml/src/ggml-cuda/binbcast.cu CHANGED
@@ -294,11 +294,13 @@ static void ggml_cuda_op_bin_bcast(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {
 
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
 
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
         op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
-    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (half *) dst_dd, stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
         op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
         op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
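
On the CUDA side, the change above only widens the type dispatch so the existing templated broadcast kernels get a (half, half, half) instantiation; no new kernel is written. For readers unfamiliar with fp16 device code, the fragment below is a hedged, standalone sketch of what a pure-fp16 element-wise add looks like on the device (it mirrors the CPU path by computing in fp32 and rounding back to fp16); it is not taken from binbcast.cu.

    // Standalone sketch, not from the commit: pure-fp16 element-wise add.
    #include <cuda_fp16.h>

    __global__ void add_f16(const half * x, const half * y, half * z, int n) {
        const int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i >= n) {
            return;
        }
        // compute in fp32 and round back to fp16, mirroring the CPU kernels
        z[i] = __float2half(__half2float(x[i]) + __half2float(y[i]));
    }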