mrfatso committed on
Commit 345810b · 1 Parent(s): 008e169

opencl: allow mixed f16/f32 `add` (llama/15140)

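In short: the OpenCL backend used to require src0, src1, and dst of GGML_OP_ADD to share a single type. With this change, ggml_opencl_supports_op also accepts an add whose destination is F16 and whose sources are any mix of F16 and F32; the host code then picks the kernel from the destination type and passes one int flag per source (1 = F32, 0 = F16) so the kernels convert on load.

A minimal sketch of an add the backend can now claim, assuming the public ggml C API (illustrative names; in ggml, ggml_add gives the result the type of its first operand, so an F16 src0 yields an F16 dst):

```c
#include "ggml.h"

// Illustrative only: build a mixed f16/f32 add. With this commit the OpenCL
// backend reports support for it instead of leaving it to another backend.
struct ggml_tensor * mixed_add(struct ggml_context * ctx, int64_t n) {
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n); // src0: f16
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n); // src1: f32
    return ggml_add(ctx, a, b); // dst inherits f16 from src0
}
```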
ggml/src/ggml-opencl/ggml-opencl.cpp CHANGED
@@ -2481,6 +2481,13 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
        case GGML_OP_SCALE:
            return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
        case GGML_OP_ADD:
+            if (op->type == GGML_TYPE_F16) {
+                const bool src0_ok = op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32;
+                const bool src1_ok = op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32;
+                if (src0_ok && src1_ok) {
+                    return true;
+                }
+            }
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_SUB:
@@ -3717,34 +3724,30 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
    GGML_ASSERT(dst);
    GGML_ASSERT(dst->extra);

-    GGML_ASSERT(src0->type == src1->type);
-    GGML_ASSERT(src0->type == dst->type);
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
-
-    const int ne00 = src0->ne[0];
-    const int ne01 = src0->ne[1];
-    const int ne02 = src0->ne[2];
-    const int ne03 = src0->ne[3];
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];

    const cl_ulong nb00 = src0->nb[0];
    const cl_ulong nb01 = src0->nb[1];
    const cl_ulong nb02 = src0->nb[2];
    const cl_ulong nb03 = src0->nb[3];

-    const int ne10 = src1->ne[0];
-    const int ne11 = src1->ne[1];
-    const int ne12 = src1->ne[2];
-    const int ne13 = src1->ne[3]; UNUSED(ne13);
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];

    const cl_ulong nb10 = src1->nb[0];
    const cl_ulong nb11 = src1->nb[1];
    const cl_ulong nb12 = src1->nb[2];
-    const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13);
+    const cl_ulong nb13 = src1->nb[3];

-    const int ne0 = dst->ne[0];
-    const int ne1 = dst->ne[1];
-    const int ne2 = dst->ne[2];
-    const int ne3 = dst->ne[3];
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    const int ne3 = dst->ne[3];

    const cl_ulong nb0 = dst->nb[0];
    const cl_ulong nb1 = dst->nb[1];
@@ -3761,68 +3764,114 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
    cl_ulong offset1 = extra1->offset + src1->view_offs;
    cl_ulong offsetd = extrad->offset + dst->view_offs;

-    bool bcast_row = false;
    cl_kernel kernel;

-    if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
-        GGML_ASSERT(ggml_is_contiguous(src0));
+    const bool bcast_row = ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0;

-        // src1 is a row
+    if (bcast_row) {
+        GGML_ASSERT(ggml_is_contiguous(src0));
        GGML_ASSERT(ne11 == 1);
+    }

-        bcast_row = true;
-        int ne = ne00 / 4;
-
-        if (src0->type == GGML_TYPE_F32) {
+    if (dst->type == GGML_TYPE_F32) {
+        GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32);
+        if (bcast_row) {
            kernel = backend_ctx->kernel_add_row;
+            const int ne = ne00 / 4;
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
        } else {
-            kernel = backend_ctx->kernel_add_row_f16;
-        }
-
-        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
-        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
-        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
-        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
-        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
-        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
-        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
-    } else {
-        if (src0->type == GGML_TYPE_F32) {
            kernel = backend_ctx->kernel_add;
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
+            CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
+            CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
+            CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
+            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
+            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
+            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
+            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
+            CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
+            CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
+            CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
+            CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
+            CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
+        }
+    } else if (dst->type == GGML_TYPE_F16) {
+        GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
+        GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
+        const int type_src0 = (src0->type == GGML_TYPE_F32);
+        const int type_src1 = (src1->type == GGML_TYPE_F32);
+        if (bcast_row) {
+            kernel = backend_ctx->kernel_add_row_f16;
+            const int ne = ne00 / 4;
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
+            CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &type_src0));
+            CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &type_src1));
        } else {
            kernel = backend_ctx->kernel_add_f16;
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
+            CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
+            CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
+            CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
+            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
+            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
+            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
+            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
+            CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
+            CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
+            CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
+            CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
+            CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
+            CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &type_src0));
+            CL_CHECK(clSetKernelArg(kernel, 31, sizeof(int), &type_src1));
        }
-
-        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
-        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
-        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
-        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
-        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
-        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
-        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
-        CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
-        CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
-        CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
-        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
-        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
-        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
-        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
-        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
-        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
-        CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
-        CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
-        CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
-        CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
-        CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
-        CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
-        CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
-        CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
-        CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
-        CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
-        CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
-        CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
-        CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
-        CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
+    } else {
+        GGML_ASSERT(false && "unsupported data types for add");
    }

    if (bcast_row) {
@@ -3832,13 +3881,13 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const

        size_t * local_work_size_ptr = local_work_size;
        if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
-            local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
+            local_work_size_ptr = nullptr;
        }

-        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
+        backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size_ptr, dst);
    } else {
        unsigned int nth = MIN(64, ne0);
-        size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03};
+        size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
        size_t local_work_size[] = {nth, 1, 1};

        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
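The host-side rule above, in brief: dst->type selects the kernel family (the f32 kernels assert both sources are F32; the f16 kernels take two extra int args, type_src0 and type_src1, set to 1 when the corresponding source is F32). A hedged CPU sketch of the per-element semantics the f16 path implements, using the public ggml f16 helpers (the helper name ref_add_f16 is made up for illustration; results may differ from the device in the last bit):

```c
#include "ggml.h"

// Mirror the kernel: an f32 source is first rounded to half (convert_half),
// then the two half values are added and the sum is stored as half.
static ggml_fp16_t ref_add_f16(const void * s0, int type_src0,
                               const void * s1, int type_src1) {
    const ggml_fp16_t h0 = type_src0 ? ggml_fp32_to_fp16(*(const float *)s0)
                                     : *(const ggml_fp16_t *)s0;
    const ggml_fp16_t h1 = type_src1 ? ggml_fp32_to_fp16(*(const float *)s1)
                                     : *(const ggml_fp16_t *)s1;
    return ggml_fp32_to_fp16(ggml_fp16_to_fp32(h0) + ggml_fp16_to_fp32(h1));
}
```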
ggml/src/ggml-opencl/kernels/add.cl CHANGED
@@ -112,7 +112,9 @@ kernel void kernel_add_f16(
        ulong nb0,
        ulong nb1,
        ulong nb2,
-        ulong nb3
+        ulong nb3,
+        int type_src0,
+        int type_src1
) {
    src0 = src0 + offset0;
    src1 = src1 + offset1;
@@ -132,25 +134,57 @@ kernel void kernel_add_f16(

    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
        const int i10 = i0 % ne10;
-        *((global half *)(dst_ptr + i0*nb0)) = *((global half *)(src0_ptr + i0*nb00)) + *((global half *)(src1_ptr + i10*nb10));
+
+        half v0, v1;
+        if (type_src0 == 1) {
+            v0 = convert_half(*((global float *)(src0_ptr + i0*nb00)));
+        } else {
+            v0 = *((global half *)(src0_ptr + i0*nb00));
+        }
+
+        if (type_src1 == 1) {
+            v1 = convert_half(*((global float *)(src1_ptr + i10*nb10)));
+        } else {
+            v1 = *((global half *)(src1_ptr + i10*nb10));
+        }
+
+        *((global half *)(dst_ptr + i0*nb0)) = v0 + v1;
    }
}

kernel void kernel_add_row_f16(
-        global half4 * src0,
+        global char * src0,
        ulong offset0,
-        global half4 * src1,
+        global char * src1,
        ulong offset1,
        global half4 * dst,
        ulong offsetd,
-        int ne
+        int ne,
+        int type_src0,
+        int type_src1
) {
-    src0 = (global half4*)((global char*)src0 + offset0);
-    src1 = (global half4*)((global char*)src1 + offset1);
    dst = (global half4*)((global char*)dst + offsetd);

    // This performs better than using %.
    uint gid = get_global_id(0);
    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
-    dst[gid] = src0[gid] + src1[idx1];
+
+    half4 v0, v1;
+    if (type_src0 == 1) {
+        global float4* src0_f32 = (global float4*)((global char*)src0 + offset0);
+        v0 = convert_half4(src0_f32[gid]);
+    } else {
+        global half4* src0_f16 = (global half4*)((global char*)src0 + offset0);
+        v0 = src0_f16[gid];
+    }
+
+    if (type_src1 == 1) {
+        global float4* src1_f32 = (global float4*)((global char*)src1 + offset1);
+        v1 = convert_half4(src1_f32[idx1]);
+    } else {
+        global half4* src1_f16 = (global half4*)((global char*)src1 + offset1);
+        v1 = src1_f16[idx1];
+    }
+
+    dst[gid] = v0 + v1;
}
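One note on the row kernel: kernel_add_row_f16 treats src1 as a single contiguous row of ne 4-wide vectors (ne = ne00 / 4 on the host) and tiles it across src0, with idx1 computed as gid - (gid/ne)*ne, i.e. gid % ne, because that form performed better in this kernel. A plain-C sketch of that indexing, using float instead of half4 purely for illustration:

```c
// add_row_ref is illustrative, not part of ggml-opencl: src1 is one row of
// ne 4-wide vectors broadcast over every row of src0.
static void add_row_ref(const float * src0, const float * src1, float * dst,
                        unsigned n_vec4, unsigned ne /* = ne00/4 */) {
    for (unsigned gid = 0; gid < n_vec4; ++gid) {
        const unsigned idx1 = gid - (gid / ne) * ne; // same as gid % ne
        for (int k = 0; k < 4; ++k) {
            dst[4*gid + k] = src0[4*gid + k] + src1[4*idx1 + k];
        }
    }
}
```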