Spaces:
Running
Running
opencl: allow mixed f16/f32 `add` (llama/15140)
Browse files
ggml/src/ggml-opencl/ggml-opencl.cpp
CHANGED
|
@@ -2481,6 +2481,13 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
|
| 2481 |
case GGML_OP_SCALE:
|
| 2482 |
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
|
| 2483 |
case GGML_OP_ADD:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2484 |
case GGML_OP_MUL:
|
| 2485 |
case GGML_OP_DIV:
|
| 2486 |
case GGML_OP_SUB:
|
|
@@ -3717,34 +3724,30 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
|
|
| 3717 |
GGML_ASSERT(dst);
|
| 3718 |
GGML_ASSERT(dst->extra);
|
| 3719 |
|
| 3720 |
-
|
| 3721 |
-
|
| 3722 |
-
|
| 3723 |
-
|
| 3724 |
-
const int ne00 = src0->ne[0];
|
| 3725 |
-
const int ne01 = src0->ne[1];
|
| 3726 |
-
const int ne02 = src0->ne[2];
|
| 3727 |
-
const int ne03 = src0->ne[3];
|
| 3728 |
|
| 3729 |
const cl_ulong nb00 = src0->nb[0];
|
| 3730 |
const cl_ulong nb01 = src0->nb[1];
|
| 3731 |
const cl_ulong nb02 = src0->nb[2];
|
| 3732 |
const cl_ulong nb03 = src0->nb[3];
|
| 3733 |
|
| 3734 |
-
const int
|
| 3735 |
-
const int
|
| 3736 |
-
const int
|
| 3737 |
-
const int
|
| 3738 |
|
| 3739 |
const cl_ulong nb10 = src1->nb[0];
|
| 3740 |
const cl_ulong nb11 = src1->nb[1];
|
| 3741 |
const cl_ulong nb12 = src1->nb[2];
|
| 3742 |
-
const cl_ulong nb13 = src1->nb[3];
|
| 3743 |
|
| 3744 |
-
const int
|
| 3745 |
-
const int
|
| 3746 |
-
const int
|
| 3747 |
-
const int
|
| 3748 |
|
| 3749 |
const cl_ulong nb0 = dst->nb[0];
|
| 3750 |
const cl_ulong nb1 = dst->nb[1];
|
|
@@ -3761,68 +3764,114 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
|
|
| 3761 |
cl_ulong offset1 = extra1->offset + src1->view_offs;
|
| 3762 |
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
| 3763 |
|
| 3764 |
-
bool bcast_row = false;
|
| 3765 |
cl_kernel kernel;
|
| 3766 |
|
| 3767 |
-
|
| 3768 |
-
GGML_ASSERT(ggml_is_contiguous(src0));
|
| 3769 |
|
| 3770 |
-
|
|
|
|
| 3771 |
GGML_ASSERT(ne11 == 1);
|
|
|
|
| 3772 |
|
| 3773 |
-
|
| 3774 |
-
|
| 3775 |
-
|
| 3776 |
-
if (src0->type == GGML_TYPE_F32) {
|
| 3777 |
kernel = backend_ctx->kernel_add_row;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3778 |
} else {
|
| 3779 |
-
kernel = backend_ctx->kernel_add_row_f16;
|
| 3780 |
-
}
|
| 3781 |
-
|
| 3782 |
-
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
| 3783 |
-
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
| 3784 |
-
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
| 3785 |
-
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
| 3786 |
-
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
| 3787 |
-
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
| 3788 |
-
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
|
| 3789 |
-
} else {
|
| 3790 |
-
if (src0->type == GGML_TYPE_F32) {
|
| 3791 |
kernel = backend_ctx->kernel_add;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3792 |
} else {
|
| 3793 |
kernel = backend_ctx->kernel_add_f16;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3794 |
}
|
| 3795 |
-
|
| 3796 |
-
|
| 3797 |
-
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
| 3798 |
-
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
| 3799 |
-
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
| 3800 |
-
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
| 3801 |
-
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
| 3802 |
-
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
| 3803 |
-
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
|
| 3804 |
-
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
|
| 3805 |
-
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
|
| 3806 |
-
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
|
| 3807 |
-
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
|
| 3808 |
-
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
|
| 3809 |
-
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
|
| 3810 |
-
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
|
| 3811 |
-
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
|
| 3812 |
-
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
|
| 3813 |
-
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
|
| 3814 |
-
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
|
| 3815 |
-
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
|
| 3816 |
-
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
|
| 3817 |
-
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
|
| 3818 |
-
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
|
| 3819 |
-
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
|
| 3820 |
-
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
|
| 3821 |
-
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
|
| 3822 |
-
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
|
| 3823 |
-
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
|
| 3824 |
-
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
|
| 3825 |
-
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
|
| 3826 |
}
|
| 3827 |
|
| 3828 |
if (bcast_row) {
|
|
@@ -3832,13 +3881,13 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
|
|
| 3832 |
|
| 3833 |
size_t * local_work_size_ptr = local_work_size;
|
| 3834 |
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
|
| 3835 |
-
local_work_size_ptr = nullptr;
|
| 3836 |
}
|
| 3837 |
|
| 3838 |
-
backend_ctx->enqueue_ndrange_kernel(kernel,
|
| 3839 |
} else {
|
| 3840 |
unsigned int nth = MIN(64, ne0);
|
| 3841 |
-
size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03};
|
| 3842 |
size_t local_work_size[] = {nth, 1, 1};
|
| 3843 |
|
| 3844 |
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
|
|
|
| 2481 |
case GGML_OP_SCALE:
|
| 2482 |
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
|
| 2483 |
case GGML_OP_ADD:
|
| 2484 |
+
if (op->type == GGML_TYPE_F16) {
|
| 2485 |
+
const bool src0_ok = op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32;
|
| 2486 |
+
const bool src1_ok = op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32;
|
| 2487 |
+
if (src0_ok && src1_ok) {
|
| 2488 |
+
return true;
|
| 2489 |
+
}
|
| 2490 |
+
}
|
| 2491 |
case GGML_OP_MUL:
|
| 2492 |
case GGML_OP_DIV:
|
| 2493 |
case GGML_OP_SUB:
|
|
|
|
| 3724 |
GGML_ASSERT(dst);
|
| 3725 |
GGML_ASSERT(dst->extra);
|
| 3726 |
|
| 3727 |
+
const int ne00 = src0->ne[0];
|
| 3728 |
+
const int ne01 = src0->ne[1];
|
| 3729 |
+
const int ne02 = src0->ne[2];
|
| 3730 |
+
const int ne03 = src0->ne[3];
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3731 |
|
| 3732 |
const cl_ulong nb00 = src0->nb[0];
|
| 3733 |
const cl_ulong nb01 = src0->nb[1];
|
| 3734 |
const cl_ulong nb02 = src0->nb[2];
|
| 3735 |
const cl_ulong nb03 = src0->nb[3];
|
| 3736 |
|
| 3737 |
+
const int ne10 = src1->ne[0];
|
| 3738 |
+
const int ne11 = src1->ne[1];
|
| 3739 |
+
const int ne12 = src1->ne[2];
|
| 3740 |
+
const int ne13 = src1->ne[3];
|
| 3741 |
|
| 3742 |
const cl_ulong nb10 = src1->nb[0];
|
| 3743 |
const cl_ulong nb11 = src1->nb[1];
|
| 3744 |
const cl_ulong nb12 = src1->nb[2];
|
| 3745 |
+
const cl_ulong nb13 = src1->nb[3];
|
| 3746 |
|
| 3747 |
+
const int ne0 = dst->ne[0];
|
| 3748 |
+
const int ne1 = dst->ne[1];
|
| 3749 |
+
const int ne2 = dst->ne[2];
|
| 3750 |
+
const int ne3 = dst->ne[3];
|
| 3751 |
|
| 3752 |
const cl_ulong nb0 = dst->nb[0];
|
| 3753 |
const cl_ulong nb1 = dst->nb[1];
|
|
|
|
| 3764 |
cl_ulong offset1 = extra1->offset + src1->view_offs;
|
| 3765 |
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
| 3766 |
|
|
|
|
| 3767 |
cl_kernel kernel;
|
| 3768 |
|
| 3769 |
+
const bool bcast_row = ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0;
|
|
|
|
| 3770 |
|
| 3771 |
+
if (bcast_row) {
|
| 3772 |
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
| 3773 |
GGML_ASSERT(ne11 == 1);
|
| 3774 |
+
}
|
| 3775 |
|
| 3776 |
+
if (dst->type == GGML_TYPE_F32) {
|
| 3777 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32);
|
| 3778 |
+
if (bcast_row) {
|
|
|
|
| 3779 |
kernel = backend_ctx->kernel_add_row;
|
| 3780 |
+
const int ne = ne00 / 4;
|
| 3781 |
+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
| 3782 |
+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
| 3783 |
+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
| 3784 |
+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
| 3785 |
+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
| 3786 |
+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
| 3787 |
+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
|
| 3788 |
} else {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3789 |
kernel = backend_ctx->kernel_add;
|
| 3790 |
+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
| 3791 |
+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
| 3792 |
+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
| 3793 |
+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
| 3794 |
+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
| 3795 |
+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
| 3796 |
+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
| 3797 |
+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
|
| 3798 |
+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
|
| 3799 |
+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
|
| 3800 |
+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
|
| 3801 |
+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
|
| 3802 |
+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
|
| 3803 |
+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
|
| 3804 |
+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
|
| 3805 |
+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
|
| 3806 |
+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
|
| 3807 |
+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
|
| 3808 |
+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
|
| 3809 |
+
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
|
| 3810 |
+
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
|
| 3811 |
+
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
|
| 3812 |
+
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
|
| 3813 |
+
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
|
| 3814 |
+
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
|
| 3815 |
+
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
|
| 3816 |
+
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
|
| 3817 |
+
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
|
| 3818 |
+
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
|
| 3819 |
+
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
|
| 3820 |
+
}
|
| 3821 |
+
} else if (dst->type == GGML_TYPE_F16) {
|
| 3822 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
|
| 3823 |
+
GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
| 3824 |
+
const int type_src0 = (src0->type == GGML_TYPE_F32);
|
| 3825 |
+
const int type_src1 = (src1->type == GGML_TYPE_F32);
|
| 3826 |
+
if (bcast_row) {
|
| 3827 |
+
kernel = backend_ctx->kernel_add_row_f16;
|
| 3828 |
+
const int ne = ne00 / 4;
|
| 3829 |
+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
| 3830 |
+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
| 3831 |
+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
| 3832 |
+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
| 3833 |
+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
| 3834 |
+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
| 3835 |
+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
|
| 3836 |
+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &type_src0));
|
| 3837 |
+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &type_src1));
|
| 3838 |
} else {
|
| 3839 |
kernel = backend_ctx->kernel_add_f16;
|
| 3840 |
+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
| 3841 |
+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
| 3842 |
+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
| 3843 |
+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
| 3844 |
+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
| 3845 |
+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
| 3846 |
+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
| 3847 |
+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
|
| 3848 |
+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
|
| 3849 |
+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
|
| 3850 |
+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
|
| 3851 |
+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
|
| 3852 |
+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
|
| 3853 |
+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
|
| 3854 |
+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
|
| 3855 |
+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
|
| 3856 |
+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
|
| 3857 |
+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
|
| 3858 |
+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
|
| 3859 |
+
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
|
| 3860 |
+
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
|
| 3861 |
+
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
|
| 3862 |
+
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
|
| 3863 |
+
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
|
| 3864 |
+
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
|
| 3865 |
+
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
|
| 3866 |
+
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
|
| 3867 |
+
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
|
| 3868 |
+
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
|
| 3869 |
+
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
|
| 3870 |
+
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &type_src0));
|
| 3871 |
+
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(int), &type_src1));
|
| 3872 |
}
|
| 3873 |
+
} else {
|
| 3874 |
+
GGML_ASSERT(false && "unsupported data types for add");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3875 |
}
|
| 3876 |
|
| 3877 |
if (bcast_row) {
|
|
|
|
| 3881 |
|
| 3882 |
size_t * local_work_size_ptr = local_work_size;
|
| 3883 |
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
|
| 3884 |
+
local_work_size_ptr = nullptr;
|
| 3885 |
}
|
| 3886 |
|
| 3887 |
+
backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size_ptr, dst);
|
| 3888 |
} else {
|
| 3889 |
unsigned int nth = MIN(64, ne0);
|
| 3890 |
+
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
| 3891 |
size_t local_work_size[] = {nth, 1, 1};
|
| 3892 |
|
| 3893 |
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
ggml/src/ggml-opencl/kernels/add.cl
CHANGED
|
@@ -112,7 +112,9 @@ kernel void kernel_add_f16(
|
|
| 112 |
ulong nb0,
|
| 113 |
ulong nb1,
|
| 114 |
ulong nb2,
|
| 115 |
-
ulong nb3
|
|
|
|
|
|
|
| 116 |
) {
|
| 117 |
src0 = src0 + offset0;
|
| 118 |
src1 = src1 + offset1;
|
|
@@ -132,25 +134,57 @@ kernel void kernel_add_f16(
|
|
| 132 |
|
| 133 |
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
| 134 |
const int i10 = i0 % ne10;
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
}
|
| 137 |
}
|
| 138 |
|
| 139 |
kernel void kernel_add_row_f16(
|
| 140 |
-
global
|
| 141 |
ulong offset0,
|
| 142 |
-
global
|
| 143 |
ulong offset1,
|
| 144 |
global half4 * dst,
|
| 145 |
ulong offsetd,
|
| 146 |
-
int ne
|
|
|
|
|
|
|
| 147 |
) {
|
| 148 |
-
src0 = (global half4*)((global char*)src0 + offset0);
|
| 149 |
-
src1 = (global half4*)((global char*)src1 + offset1);
|
| 150 |
dst = (global half4*)((global char*)dst + offsetd);
|
| 151 |
|
| 152 |
// This performs better than using %.
|
| 153 |
uint gid = get_global_id(0);
|
| 154 |
uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
}
|
|
|
|
| 112 |
ulong nb0,
|
| 113 |
ulong nb1,
|
| 114 |
ulong nb2,
|
| 115 |
+
ulong nb3,
|
| 116 |
+
int type_src0,
|
| 117 |
+
int type_src1
|
| 118 |
) {
|
| 119 |
src0 = src0 + offset0;
|
| 120 |
src1 = src1 + offset1;
|
|
|
|
| 134 |
|
| 135 |
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
| 136 |
const int i10 = i0 % ne10;
|
| 137 |
+
|
| 138 |
+
half v0, v1;
|
| 139 |
+
if (type_src0 == 1) {
|
| 140 |
+
v0 = convert_half(*((global float *)(src0_ptr + i0*nb00)));
|
| 141 |
+
} else {
|
| 142 |
+
v0 = *((global half *)(src0_ptr + i0*nb00));
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
if (type_src1 == 1) {
|
| 146 |
+
v1 = convert_half(*((global float *)(src1_ptr + i10*nb10)));
|
| 147 |
+
} else {
|
| 148 |
+
v1 = *((global half *)(src1_ptr + i10*nb10));
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
*((global half *)(dst_ptr + i0*nb0)) = v0 + v1;
|
| 152 |
}
|
| 153 |
}
|
| 154 |
|
| 155 |
kernel void kernel_add_row_f16(
|
| 156 |
+
global char * src0,
|
| 157 |
ulong offset0,
|
| 158 |
+
global char * src1,
|
| 159 |
ulong offset1,
|
| 160 |
global half4 * dst,
|
| 161 |
ulong offsetd,
|
| 162 |
+
int ne,
|
| 163 |
+
int type_src0,
|
| 164 |
+
int type_src1
|
| 165 |
) {
|
|
|
|
|
|
|
| 166 |
dst = (global half4*)((global char*)dst + offsetd);
|
| 167 |
|
| 168 |
// This performs better than using %.
|
| 169 |
uint gid = get_global_id(0);
|
| 170 |
uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
|
| 171 |
+
|
| 172 |
+
half4 v0, v1;
|
| 173 |
+
if (type_src0 == 1) {
|
| 174 |
+
global float4* src0_f32 = (global float4*)((global char*)src0 + offset0);
|
| 175 |
+
v0 = convert_half4(src0_f32[gid]);
|
| 176 |
+
} else {
|
| 177 |
+
global half4* src0_f16 = (global half4*)((global char*)src0 + offset0);
|
| 178 |
+
v0 = src0_f16[gid];
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
if (type_src1 == 1) {
|
| 182 |
+
global float4* src1_f32 = (global float4*)((global char*)src1 + offset1);
|
| 183 |
+
v1 = convert_half4(src1_f32[idx1]);
|
| 184 |
+
} else {
|
| 185 |
+
global half4* src1_f16 = (global half4*)((global char*)src1 + offset1);
|
| 186 |
+
v1 = src1_f16[idx1];
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
dst[gid] = v0 + v1;
|
| 190 |
}
|