Sigbjørn Skjæret committed on
Commit
1f97ff4
·
1 Parent(s): cbe8006

cuda : add set rows for bf16 (llama/14664)

Browse files
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -3226,8 +3226,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3226
  } break;
3227
  case GGML_OP_SET_ROWS:
3228
  {
3229
- #pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
3230
- return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
3231
  op->src[0]->type == GGML_TYPE_F32 &&
3232
  op->src[1]->type == GGML_TYPE_I64;
3233
  } break;
 
3226
  } break;
3227
  case GGML_OP_SET_ROWS:
3228
  {
3229
+ #pragma message("TODO: implement Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
3230
+ return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16) &&
3231
  op->src[0]->type == GGML_TYPE_F32 &&
3232
  op->src[1]->type == GGML_TYPE_I64;
3233
  } break;
ggml/src/ggml-cuda/set-rows.cu CHANGED
@@ -10,6 +10,11 @@ __device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, hal
10
  *dst_h = __float2half(*src_f);
11
  }
12
 
 
 
 
 
 
13
  template<>
14
  __device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
15
  *dst_f = *src_f;
@@ -124,6 +129,16 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
124
  nb1, nb2, nb3,
125
  stream
126
  );
 
 
 
 
 
 
 
 
 
 
127
  } else {
128
  GGML_ABORT("unsupported type");
129
  }
 
10
  *dst_h = __float2half(*src_f);
11
  }
12
 
13
+ template<>
14
+ __device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) {
15
+ *dst_b = *src_f;
16
+ }
17
+
18
  template<>
19
  __device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
20
  *dst_f = *src_f;
 
129
  nb1, nb2, nb3,
130
  stream
131
  );
132
+ } else if (dst->type == GGML_TYPE_BF16) {
133
+ set_rows_cuda(
134
+ src0_d, src1_d, (nv_bfloat16*)dst->data,
135
+ ne00, ne01, ne02, ne03,
136
+ ne10, ne11, ne12, ne13,
137
+ nb01, nb02, nb03,
138
+ nb10, nb11, nb12,
139
+ nb1, nb2, nb3,
140
+ stream
141
+ );
142
  } else {
143
  GGML_ABORT("unsupported type");
144
  }