Spaces:
Sleeping
Sleeping
Sigbjørn Skjæret
committed on
Commit
·
1f97ff4
1
Parent(s):
cbe8006
cuda : add set rows for bf16 (llama/14664)
Browse files
ggml/src/ggml-cuda/ggml-cuda.cu
CHANGED
|
@@ -3226,8 +3226,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
| 3226 |
} break;
|
| 3227 |
case GGML_OP_SET_ROWS:
|
| 3228 |
{
|
| 3229 |
-
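#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")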
|
| 3230 |
-
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
| 3231 |
op->src[0]->type == GGML_TYPE_F32 &&
|
| 3232 |
op->src[1]->type == GGML_TYPE_I64;
|
| 3233 |
} break;
|
|
|
|
| 3226 |
} break;
|
| 3227 |
case GGML_OP_SET_ROWS:
|
| 3228 |
{
|
| 3229 |
+
#pragma message("TODO: implement Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
|
| 3230 |
+
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16) &&
|
| 3231 |
op->src[0]->type == GGML_TYPE_F32 &&
|
| 3232 |
op->src[1]->type == GGML_TYPE_I64;
|
| 3233 |
} break;
|
ggml/src/ggml-cuda/set-rows.cu
CHANGED
|
@@ -10,6 +10,11 @@ __device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, hal
|
|
| 10 |
*dst_h = __float2half(*src_f);
|
| 11 |
}
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
template<>
|
| 14 |
__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
|
| 15 |
*dst_f = *src_f;
|
|
@@ -124,6 +129,16 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
| 124 |
nb1, nb2, nb3,
|
| 125 |
stream
|
| 126 |
);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
} else {
|
| 128 |
GGML_ABORT("unsupported type");
|
| 129 |
}
|
|
|
|
| 10 |
*dst_h = __float2half(*src_f);
|
| 11 |
}
|
| 12 |
|
| 13 |
+
template<>
|
| 14 |
+
__device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) {
|
| 15 |
+
*dst_b = *src_f;
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
template<>
|
| 19 |
__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
|
| 20 |
*dst_f = *src_f;
|
|
|
|
| 129 |
nb1, nb2, nb3,
|
| 130 |
stream
|
| 131 |
);
|
| 132 |
+
} else if (dst->type == GGML_TYPE_BF16) {
|
| 133 |
+
set_rows_cuda(
|
| 134 |
+
src0_d, src1_d, (nv_bfloat16*)dst->data,
|
| 135 |
+
ne00, ne01, ne02, ne03,
|
| 136 |
+
ne10, ne11, ne12, ne13,
|
| 137 |
+
nb01, nb02, nb03,
|
| 138 |
+
nb10, nb11, nb12,
|
| 139 |
+
nb1, nb2, nb3,
|
| 140 |
+
stream
|
| 141 |
+
);
|
| 142 |
} else {
|
| 143 |
GGML_ABORT("unsupported type");
|
| 144 |
}
|