lhez committed
Commit 4dc1834 · 1 parent: e2965b0

opencl: add f16 for `add`, `sub`, `mul`, `div` (llama/14984)

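This commit teaches the ggml OpenCL backend to run the elementwise binary ops on F16 tensors: `ggml_opencl_supports_op` now accepts `GGML_OP_ADD`/`SUB`/`MUL`/`DIV` when src0, src1, and dst all share the same type and that type is F32 or F16, and each op gains `*_f16` and `*_row_f16` kernel variants. Below is a minimal sketch of a graph that would be routed to the new f16 add kernels; it is illustrative only (backend setup, buffer allocation, and scheduling are omitted, and the shapes are made up):

#include "ggml.h"

// Build a tiny F16 element-wise add. With this commit the OpenCL backend
// reports support for it and dispatches kernel_add_f16 / kernel_add_row_f16
// instead of leaving the op to another backend.
static struct ggml_cgraph * build_f16_add(struct ggml_context * ctx) {
    // src0, src1 and dst must all be GGML_TYPE_F16 (or all F32) to pass the
    // new supports_op check; ggml_add gives dst the same type as src0.
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 64, 8);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 64, 8);
    struct ggml_tensor * c = ggml_add(ctx, a, b);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);
    return gf;
}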
ggml/src/ggml-opencl/ggml-opencl.cpp CHANGED
@@ -400,10 +400,10 @@ struct ggml_backend_opencl_context {
     cl_program program_mul_mm_f32_f32_l4_lm;
     cl_program program_mul_mm_f16_f32_l4_lm;

-    cl_kernel kernel_add, kernel_add_row;
-    cl_kernel kernel_mul, kernel_mul_row;
-    cl_kernel kernel_div, kernel_div_row;
-    cl_kernel kernel_sub, kernel_sub_row;
+    cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16;
+    cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
+    cl_kernel kernel_div, kernel_div_row, kernel_div_f16, kernel_div_row_f16;
+    cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16;
     cl_kernel kernel_scale;
     cl_kernel kernel_silu, kernel_silu_4;
     cl_kernel kernel_gelu, kernel_gelu_4;
@@ -674,8 +674,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_add =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);

-        CL_CHECK((backend_ctx->kernel_add = clCreateKernel(backend_ctx->program_add, "kernel_add", &err), err));
-        CL_CHECK((backend_ctx->kernel_add_row = clCreateKernel(backend_ctx->program_add, "kernel_add_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_add = clCreateKernel(backend_ctx->program_add, "kernel_add", &err), err));
+        CL_CHECK((backend_ctx->kernel_add_row = clCreateKernel(backend_ctx->program_add, "kernel_add_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_add_f16 = clCreateKernel(backend_ctx->program_add, "kernel_add_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_add_row_f16 = clCreateKernel(backend_ctx->program_add, "kernel_add_row_f16", &err), err));
         GGML_LOG_CONT(".");
     }

@@ -1089,8 +1091,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_mul =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);

-        CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program_mul, "kernel_mul", &err), err));
-        CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program_mul, "kernel_mul", &err), err));
+        CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_mul_f16 = clCreateKernel(backend_ctx->program_mul, "kernel_mul_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_mul_row_f16 = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row_f16", &err), err));
         GGML_LOG_CONT(".");
     }

@@ -1288,11 +1292,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
 #else
         const std::string kernel_src = read_file("div.cl");
 #endif
+        std::string compile_opts = std::string("-cl-std=") + opencl_c_std +
+            " -cl-mad-enable -cl-finite-math-only ";
+
         backend_ctx->program_div =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);

-        CL_CHECK((backend_ctx->kernel_div = clCreateKernel(backend_ctx->program_div, "kernel_div", &err), err));
-        CL_CHECK((backend_ctx->kernel_div_row = clCreateKernel(backend_ctx->program_div, "kernel_div_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_div = clCreateKernel(backend_ctx->program_div, "kernel_div", &err), err));
+        CL_CHECK((backend_ctx->kernel_div_row = clCreateKernel(backend_ctx->program_div, "kernel_div_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_div_f16 = clCreateKernel(backend_ctx->program_div, "kernel_div_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_div_row_f16 = clCreateKernel(backend_ctx->program_div, "kernel_div_row_f16", &err), err));
         GGML_LOG_CONT(".");
     }

@@ -1308,8 +1317,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_sub =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);

-        CL_CHECK((backend_ctx->kernel_sub = clCreateKernel(backend_ctx->program_sub, "kernel_sub", &err), err));
-        CL_CHECK((backend_ctx->kernel_sub_row = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_sub = clCreateKernel(backend_ctx->program_sub, "kernel_sub", &err), err));
+        CL_CHECK((backend_ctx->kernel_sub_row = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_sub_f16 = clCreateKernel(backend_ctx->program_sub, "kernel_sub_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_sub_row_f16 = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row_f16", &err), err));
         GGML_LOG_CONT(".");
     }

@@ -2447,12 +2458,15 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                 default:
                     return false;
             }
-        case GGML_OP_ADD:
         case GGML_OP_SCALE:
+            return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
+        case GGML_OP_ADD:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_SUB:
-            return op->src[0]->type == GGML_TYPE_F32;
+            return (op->src[0]->type == op->src[1]->type) &&
+                   (op->src[0]->type == op->type) &&
+                   (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_GELU:
@@ -3680,35 +3694,39 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
     GGML_ASSERT(dst);
     GGML_ASSERT(dst->extra);

-    const int ne00 = src0 ? src0->ne[0] : 0;
-    const int ne01 = src0 ? src0->ne[1] : 0;
-    const int ne02 = src0 ? src0->ne[2] : 0;
-    const int ne03 = src0 ? src0->ne[3] : 0;
+    GGML_ASSERT(src0->type == src1->type);
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];

-    const cl_ulong nb00 = src0 ? src0->nb[0] : 0;
-    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
-    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
-    const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
+    const cl_ulong nb00 = src0->nb[0];
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];

-    const int ne10 = src1 ? src1->ne[0] : 0;
-    const int ne11 = src1 ? src1->ne[1] : 0;
-    const int ne12 = src1 ? src1->ne[2] : 0;
-    const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3]; UNUSED(ne13);

-    const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
-    const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
-    const cl_ulong nb12 = src1 ? src1->nb[2] : 0;
-    const cl_ulong nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+    const cl_ulong nb10 = src1->nb[0];
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13);

-    const int ne0 = dst ? dst->ne[0] : 0;
-    const int ne1 = dst ? dst->ne[1] : 0;
-    const int ne2 = dst ? dst->ne[2] : 0;
-    const int ne3 = dst ? dst->ne[3] : 0;
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    const int ne3 = dst->ne[3];

-    const cl_ulong nb0 = dst ? dst->nb[0] : 0;
-    const cl_ulong nb1 = dst ? dst->nb[1] : 0;
-    const cl_ulong nb2 = dst ? dst->nb[2] : 0;
-    const cl_ulong nb3 = dst ? dst->nb[3] : 0;
+    const cl_ulong nb0 = dst->nb[0];
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];

     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

@@ -3731,7 +3749,12 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const

         bcast_row = true;
         int ne = ne00 / 4;
-        kernel = backend_ctx->kernel_add_row;
+
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_add_row;
+        } else {
+            kernel = backend_ctx->kernel_add_row_f16;
+        }

         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3741,7 +3764,11 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
         CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
         CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
     } else {
-        kernel = backend_ctx->kernel_add;
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_add;
+        } else {
+            kernel = backend_ctx->kernel_add_f16;
+        }

         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3803,35 +3830,39 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
     GGML_ASSERT(dst);
     GGML_ASSERT(dst->extra);

-    const int ne00 = src0 ? src0->ne[0] : 0;
-    const int ne01 = src0 ? src0->ne[1] : 0;
-    const int ne02 = src0 ? src0->ne[2] : 0;
-    const int ne03 = src0 ? src0->ne[3] : 0;
+    GGML_ASSERT(src0->type == src1->type);
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];

-    const cl_ulong nb00 = src0 ? src0->nb[0] : 0;
-    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
-    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
-    const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
+    const cl_ulong nb00 = src0->nb[0];
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];

-    const int ne10 = src1 ? src1->ne[0] : 0;
-    const int ne11 = src1 ? src1->ne[1] : 0;
-    const int ne12 = src1 ? src1->ne[2] : 0;
-    const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3]; UNUSED(ne13);

-    const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
-    const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
-    const cl_ulong nb12 = src1 ? src1->nb[2] : 0;
-    const cl_ulong nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+    const cl_ulong nb10 = src1->nb[0];
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13);

-    const int ne0 = dst ? dst->ne[0] : 0;
-    const int ne1 = dst ? dst->ne[1] : 0;
-    const int ne2 = dst ? dst->ne[2] : 0;
-    const int ne3 = dst ? dst->ne[3] : 0;
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    const int ne3 = dst->ne[3];

-    const cl_ulong nb0 = dst ? dst->nb[0] : 0;
-    const cl_ulong nb1 = dst ? dst->nb[1] : 0;
-    const cl_ulong nb2 = dst ? dst->nb[2] : 0;
-    const cl_ulong nb3 = dst ? dst->nb[3] : 0;
+    const cl_ulong nb0 = dst->nb[0];
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];

     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

@@ -3854,7 +3885,12 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const

         bcast_row = true;
         int ne = ne00 / 4;
-        kernel = backend_ctx->kernel_mul_row;
+
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_mul_row;
+        } else {
+            kernel = backend_ctx->kernel_mul_row_f16;
+        }

         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3864,7 +3900,11 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
         CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
         CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
     } else {
-        kernel = backend_ctx->kernel_mul;
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_mul;
+        } else {
+            kernel = backend_ctx->kernel_mul_f16;
+        }

         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3926,6 +3966,10 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
     GGML_ASSERT(dst);
     GGML_ASSERT(dst->extra);

+    GGML_ASSERT(src0->type == src1->type);
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+
     const int ne00 = src0->ne[0];
     const int ne01 = src0->ne[1];
     const int ne02 = src0->ne[2];
@@ -3974,7 +4018,12 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const

         bcast_row = true;
         int ne = ne00 / 4;
-        kernel = backend_ctx->kernel_div_row;
+
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_div_row;
+        } else {
+            kernel = backend_ctx->kernel_div_row_f16;
+        }

         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3984,7 +4033,11 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
         CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
         CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
     } else {
-        kernel = backend_ctx->kernel_div;
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_div;
+        } else {
+            kernel = backend_ctx->kernel_div_f16;
+        }

         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -4034,6 +4087,10 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const
     GGML_ASSERT(dst);
     GGML_ASSERT(dst->extra);

+    GGML_ASSERT(src0->type == src1->type);
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+
     const int ne00 = src0->ne[0];
     const int ne01 = src0->ne[1];
     const int ne02 = src0->ne[2];
@@ -4082,7 +4139,12 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const

         bcast_row = true;
         int ne = ne00 / 4;
-        kernel = backend_ctx->kernel_sub_row;
+
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_sub_row;
+        } else {
+            kernel = backend_ctx->kernel_sub_row_f16;
+        }

         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -4092,7 +4154,11 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const
         CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
         CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
     } else {
-        kernel = backend_ctx->kernel_sub;
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_sub;
+        } else {
+            kernel = backend_ctx->kernel_sub_f16;
+        }

         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
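On the host side the dispatch is unchanged apart from the type switch shown above: the general `*_f16` kernels index rows through the work-group ids (i01/i02/i03 come from get_group_id in the kernels below), while the `*_row_f16` variants treat the data as half4 vectors, which is why the launch path computes `int ne = ne00 / 4;`. Data fed to these ops must therefore already be stored as F16. A small hedged sketch of converting host F32 data with ggml's own helpers (`ggml_fp16_t` and `ggml_fp32_to_fp16_row` are declared in ggml.h; the wrapper itself is illustrative):

#include "ggml.h"

#include <vector>

// Convert an F32 buffer to the half-precision layout the new kernels read.
// ggml_fp32_to_fp16_row() is the row-conversion helper ggml provides.
static std::vector<ggml_fp16_t> to_f16(const std::vector<float> & src) {
    std::vector<ggml_fp16_t> out(src.size());
    ggml_fp32_to_fp16_row(src.data(), out.data(), (int64_t) src.size());
    return out;
}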
ggml/src/ggml-opencl/kernels/add.cl CHANGED
@@ -81,3 +81,76 @@ kernel void kernel_add_row(
     uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
     dst[gid] = src0[gid] + src1[idx1];
 }
+
+kernel void kernel_add_f16(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        int ne13,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst = dst + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int i13 = i03 % ne13;
+    int i12 = i02 % ne12;
+    int i11 = i01 % ne11;
+
+    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const int i10 = i0 % ne10;
+        *((global half *)(dst_ptr + i0*nb0)) = *((global half *)(src0_ptr + i0*nb00)) + *((global half *)(src1_ptr + i10*nb10));
+    }
+}
+
+kernel void kernel_add_row_f16(
+        global half4 * src0,
+        ulong offset0,
+        global half4 * src1,
+        ulong offset1,
+        global half4 * dst,
+        ulong offsetd,
+        int ne
+) {
+    src0 = (global half4*)((global char*)src0 + offset0);
+    src1 = (global half4*)((global char*)src1 + offset1);
+    dst = (global half4*)((global char*)dst + offsetd);
+
+    // This performs better than using %.
+    uint gid = get_global_id(0);
+    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
+    dst[gid] = src0[gid] + src1[idx1];
+}
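A note on the `_row` kernels above: the broadcast index is written as `uint idx1 = gid - (gid/ne)*ne;` rather than `gid % ne` because, as the kernel comment says, this form performs better than the % operator. Worked through for ne = 4 (a 16-element F16 row viewed as four half4 values) and gid = 9: 9/4 = 2, 2*4 = 8, 9 - 8 = 1, so work-item 9 reads src1[1], exactly the element 9 % 4 would select. The same pattern is used by the div, mul, and sub kernels below.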
ggml/src/ggml-opencl/kernels/div.cl CHANGED
@@ -70,3 +70,69 @@ kernel void kernel_div_row(
     uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
     dst[gid] = src0[gid] / src1[idx1];
 }
+
+kernel void kernel_div_f16(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        int ne13,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst = dst + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int i13 = i03 % ne13;
+    int i12 = i02 % ne12;
+    int i11 = i01 % ne11;
+
+    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const int i10 = i0 % ne10;
+        *((global half *)(dst_ptr + i0*nb0)) = *((global half *)(src0_ptr + i0*nb00)) / *((global half *)(src1_ptr + i10*nb10));
+    }
+}
+
+kernel void kernel_div_row_f16(
+        global half4 * src0,
+        ulong offset0,
+        global half4 * src1,
+        ulong offset1,
+        global half4 * dst,
+        ulong offsetd,
+        int ne
+) {
+    src0 = (global half4*)((global char*)src0 + offset0);
+    src1 = (global half4*)((global char*)src1 + offset1);
+    dst = (global half4*)((global char*)dst + offsetd);
+
+    // This performs better than using %.
+    uint gid = get_global_id(0);
+    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
+    dst[gid] = src0[gid] / src1[idx1];
+}
ggml/src/ggml-opencl/kernels/mul.cl CHANGED
@@ -77,3 +77,76 @@ kernel void kernel_mul_row(
     uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
     dst[gid] = src0[gid] * src1[idx1];
 }
+
+kernel void kernel_mul_f16(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        int ne13,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst = dst + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int i13 = i03 % ne13;
+    int i12 = i02 % ne12;
+    int i11 = i01 % ne11;
+
+    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const int i10 = i0 % ne10;
+        *((global half *)(dst_ptr + i0*nb0)) = *((global half *)(src0_ptr + i0*nb00)) * *((global half *)(src1_ptr + i10*nb10));
+    }
+}
+
+kernel void kernel_mul_row_f16(
+        global half4 * src0,
+        ulong offset0,
+        global half4 * src1,
+        ulong offset1,
+        global half4 * dst,
+        ulong offsetd,
+        int ne
+) {
+    src0 = (global half4*)((global char*)src0 + offset0);
+    src1 = (global half4*)((global char*)src1 + offset1);
+    dst = (global half4*)((global char*)dst + offsetd);
+
+    // This performs better than using %.
+    uint gid = get_global_id(0);
+    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
+    dst[gid] = src0[gid] * src1[idx1];
+}
ggml/src/ggml-opencl/kernels/sub.cl CHANGED
@@ -70,3 +70,69 @@ kernel void kernel_sub_row(
     uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
     dst[gid] = src0[gid] - src1[idx1];
 }
+
+kernel void kernel_sub_f16(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        int ne13,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst = dst + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int i13 = i03 % ne13;
+    int i12 = i02 % ne12;
+    int i11 = i01 % ne11;
+
+    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const int i10 = i0 % ne10;
+        *((global half *)(dst_ptr + i0*nb0)) = *((global half *)(src0_ptr + i0*nb00)) - *((global half *)(src1_ptr + i10*nb10));
+    }
+}
+
+kernel void kernel_sub_row_f16(
+        global half4 * src0,
+        ulong offset0,
+        global half4 * src1,
+        ulong offset1,
+        global half4 * dst,
+        ulong offsetd,
+        int ne
+) {
+    src0 = (global half4*)((global char*)src0 + offset0);
+    src1 = (global half4*)((global char*)src1 + offset1);
+    dst = (global half4*)((global char*)dst + offsetd);
+
+    // This performs better than using %.
+    uint gid = get_global_id(0);
+    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
+    dst[gid] = src0[gid] - src1[idx1];
+}