Commit 8d359ad by ggerganov
Parent: 7097123

ggml : generalize GGML_OP_CONCAT (llama/7563)

* ggml : generalize GGML_OP_CONCAT (WIP)

ggml-ci

* tests : add dim != 2 tests

* metal : generalize concat kernel

* tests : naming

* cuda : generalize concat kernel

ggml-ci

* sycl : add warning and assert

* ggml : fix op params handling

* metal : bugfix kernel

ggml-ci

* ggml : reimplement CPU and Metal

* cuda : add asserts

ggml-ci

* ggml : fix ptrs

ggml-ci

Files changed (6)
  1. ggml-cuda/concat.cu +87 -6
  2. ggml-metal.m +3 -0
  3. ggml-metal.metal +14 -15
  4. ggml-sycl.cpp +4 -0
  5. ggml.c +40 -23
  6. ggml.h +3 -2
ggml-cuda/concat.cu CHANGED
@@ -1,15 +1,68 @@
 #include "concat.cuh"
 
-static __global__ void concat_f32(const float * x,const float * y, float * dst, const int ne0, const int ne02) {
+static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
         return;
     }
-    // operation
+
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+
+    if (nidx < ne00) { // src0
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * gridDim.y;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        int offset_src =
+            (nidx - ne00) +
+            blockIdx.y * (ne0 - ne00) +
+            blockIdx.z * (ne0 - ne00) * gridDim.y;
+        dst[offset_dst] = y[offset_src];
+    }
+}
+
+static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+
+    if (blockIdx.y < ne01) { // src0
+        int offset_src =
+            nidx +
+            blockIdx.y * ne0 +
+            blockIdx.z * ne0 * ne01;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        int offset_src =
+            nidx +
+            (blockIdx.y - ne01) * ne0 +
+            blockIdx.z * ne0 * (gridDim.y - ne01);
+        dst[offset_dst] = y[offset_src];
+    }
+}
+
+static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
     int offset_dst =
         nidx +
         blockIdx.y * ne0 +
         blockIdx.z * ne0 * gridDim.y;
+
     if (blockIdx.z < ne02) { // src0
         int offset_src =
             nidx +
@@ -25,25 +78,53 @@ static __global__ void concat_f32(const float * x,const float * y, float * dst,
     }
 }
 
-static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) {
+static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
     int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
     dim3 gridDim(num_blocks, ne1, ne2);
-    concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
+    if (dim == 0) {
+        concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
+        return;
+    }
+    if (dim == 1) {
+        concat_f32_dim1<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne01);
+        return;
+    }
+    concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
 }
 
 void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+
     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;
+
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    const int32_t dim = ((int32_t *) dst->op_params)[0];
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-    for (int i3 = 0; i3 < dst->ne[3]; i3++) {
-        concat_f32_cuda(src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4), dst_d + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], stream);
+    if (dim != 3) {
+        for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+            concat_f32_cuda(
+                src0_d + i3 * (src0->nb[3] / 4),
+                src1_d + i3 * (src1->nb[3] / 4),
+                dst_d  + i3 * ( dst->nb[3] / 4),
+                src0->ne[0], src0->ne[1], src0->ne[2],
+                 dst->ne[0],  dst->ne[1],  dst->ne[2], dim, stream);
+        }
+    } else {
+        const size_t size0 = ggml_nbytes(src0);
+        const size_t size1 = ggml_nbytes(src1);
+
+        CUDA_CHECK(cudaMemcpyAsync(dst_d,           src0_d, size0, cudaMemcpyDeviceToDevice, stream));
+        CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
     }
 }
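
Note: the three CUDA kernels share the same destination indexing and differ only in which coordinate selects the source tensor and how the src1 index is rebased. As a rough illustration (not part of the commit), the dim == 1 kernel computes the same mapping as the following plain-C reference loop, with i0/i1/i2 standing in for nidx/blockIdx.y/blockIdx.z and ne1 for gridDim.y:

// Reference sketch only: CPU equivalent of concat_f32_dim1 for contiguous f32 data.
// x is ne0 x ne01 x ne2, y is ne0 x (ne1 - ne01) x ne2, dst is ne0 x ne1 x ne2.
static void concat_dim1_reference(const float * x, const float * y, float * dst,
                                  int ne0, int ne01, int ne1, int ne2) {
    for (int i2 = 0; i2 < ne2; i2++) {
        for (int i1 = 0; i1 < ne1; i1++) {
            for (int i0 = 0; i0 < ne0; i0++) {
                const int offset_dst = i0 + i1*ne0 + i2*ne0*ne1;
                if (i1 < ne01) {
                    // rows before ne01 come from src0
                    dst[offset_dst] = x[i0 + i1*ne0 + i2*ne0*ne01];
                } else {
                    // remaining rows come from src1, shifted back by ne01
                    dst[offset_dst] = y[i0 + (i1 - ne01)*ne0 + i2*ne0*(ne1 - ne01)];
                }
            }
        }
    }
}

The dim == 3 case needs no kernel at all: for contiguous tensors, src0 occupies the first ggml_nbytes(src0) bytes of dst and src1 the rest, hence the two cudaMemcpyAsync calls.
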
ggml-metal.m CHANGED
@@ -990,6 +990,8 @@ static enum ggml_status ggml_metal_graph_compute(
                 {
                     id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
 
+                    const int32_t dim = ((int32_t *) dst->op_params)[0];
+
                     [encoder setComputePipelineState:pipeline];
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1018,6 +1020,7 @@ static enum ggml_status ggml_metal_graph_compute(
                     [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
                     [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
                     [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+                    [encoder setBytes:&dim length:sizeof(dim) atIndex:27];
 
                     const int nth = MIN(1024, ne0);
 
ggml-metal.metal CHANGED
@@ -3366,31 +3366,30 @@ kernel void kernel_concat(
        constant uint64_t & nb1,
        constant uint64_t & nb2,
        constant uint64_t & nb3,
+       constant  int32_t & dim,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
 
-    const int64_t i03 = tgpig.z;
-    const int64_t i02 = tgpig.y;
-    const int64_t i01 = tgpig.x;
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
 
-    const int64_t i13 = i03 % ne13;
-    const int64_t i12 = i02 % ne12;
-    const int64_t i11 = i01 % ne11;
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
+    device const float * x;
 
     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        if (i02 < ne02) {
-            ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
-            src0_ptr += ntg.x*nb00;
+        if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+            x = (device const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
         } else {
-            ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
-            src1_ptr += ntg.x*nb10;
+            x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
         }
-        dst_ptr += ntg.x*nb0;
+
+        device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+        *y = *x;
     }
 }
 
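
The generalized kernel (like the CPU path in ggml.c below) picks the source per element: a coordinate inside src0's extents reads from src0, anything else reads from src1 after subtracting src0's size along the concat dimension. A small numeric sketch of that o[dim] rebasing, with assumed values:

// Illustration only (assumed: concat along dim == 1, src0->ne[1] == 3).
int64_t o[4] = {0, 0, 0, 0};
o[1] = 3; // src0 extent along the concat dimension

// The dst element at (i0, 5, i2, i3) lies outside src0 (5 >= 3), so it is read
// from src1 at (i0 - o[0], 5 - o[1], i2 - o[2], i3 - o[3]) == (i0, 2, i2, i3);
// the other coordinates are unchanged because o[] is zero there.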
 
ggml-sycl.cpp CHANGED
@@ -13512,6 +13512,10 @@ inline void ggml_sycl_op_concat(const ggml_tensor *src0,
                                 const float *src0_dd, const float *src1_dd,
                                 float *dst_dd,
                                 const dpct::queue_ptr &main_stream) {
+#pragma message("TODO: generalize concat kernel for dim != 2")
+#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7563")
+    int dim = dst->op_params[0];
+    GGML_ASSERT(dim == 2);
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
ggml.c CHANGED
@@ -4882,10 +4882,21 @@ struct ggml_tensor * ggml_repeat_back(
 // ggml_concat
 
 struct ggml_tensor * ggml_concat(
-    struct ggml_context* ctx,
-    struct ggml_tensor* a,
-    struct ggml_tensor* b) {
-    GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]);
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    struct ggml_tensor * b,
+    int dim) {
+    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+
+    int64_t ne[GGML_MAX_DIMS];
+    for (int d = 0; d < GGML_MAX_DIMS; ++d) {
+        if (d == dim) {
+            ne[d] = a->ne[d] + b->ne[d];
+            continue;
+        }
+        GGML_ASSERT(a->ne[d] == b->ne[d]);
+        ne[d] = a->ne[d];
+    }
 
     bool is_node = false;
 
@@ -4893,7 +4904,9 @@ struct ggml_tensor * ggml_concat(
         is_node = true;
     }
 
-    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
+
+    ggml_set_op_params_i32(result, 0, dim);
 
     result->op = GGML_OP_CONCAT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5013,6 +5026,7 @@ struct ggml_tensor * ggml_leaky_relu(
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
     ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
 
     result->op = GGML_OP_LEAKY_RELU;
@@ -10977,26 +10991,29 @@ static void ggml_compute_forward_concat_f32(
     GGML_ASSERT(nb00 == sizeof(float));
     GGML_ASSERT(nb10 == sizeof(float));
 
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(dim >= 0 && dim < 4);
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = src0->ne[dim];
+
+    const float * x;
+
+    // TODO: smarter multi-theading
     for (int i3 = 0; i3 < ne3; i3++) {
         for (int i2 = ith; i2 < ne2; i2 += nth) {
-            if (i2 < ne02) { // src0
-                for (int i1 = 0; i1 < ne1; i1++) {
-                    for (int i0 = 0; i0 < ne0; i0++) {
-                        const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03);
-
-                        float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
-                        *y = *x;
-                    }
-                }
-            } // src1
-            else {
-                for (int i1 = 0; i1 < ne1; i1++) {
-                    for (int i0 = 0; i0 < ne0; i0++) {
-                        const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13);
-
-                        float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
-                        *y = *x;
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < ne0; i0++) {
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        x = (const float *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
+                    } else {
+                        x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
                     }
+
+                    float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+                    *y = *x;
                 }
             }
         }
@@ -11004,7 +11021,7 @@ static void ggml_compute_forward_concat_f32(
 }
 
 static void ggml_compute_forward_concat(
-    const struct ggml_compute_params* params,
+    const struct ggml_compute_params * params,
     struct ggml_tensor* dst) {
 
     const struct ggml_tensor * src0 = dst->src[0];
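
With the new constructor, every dimension except dim must match and the result's extent along dim is the sum of the two inputs; the chosen dimension is stored in the op params so the backends can read it back. A minimal sketch of the shape rule (assumes a valid ggml_context named ctx):

// e.g. a 3x4x2x1 tensor concatenated with a 3x5x2x1 tensor along dim 1
struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 4, 2, 1);
struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 5, 2, 1);
struct ggml_tensor * c = ggml_concat(ctx, a, b, 1);
// c->ne == {3, 9, 2, 1} and ((const int32_t *) c->op_params)[0] == 1
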
ggml.h CHANGED
@@ -1007,12 +1007,13 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
-    // concat a and b on dim 2
+    // concat a and b along dim
     // used in stable-diffusion
     GGML_API struct ggml_tensor * ggml_concat(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            int                   dim);
 
     GGML_API struct ggml_tensor * ggml_abs(
             struct ggml_context * ctx,
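
For reference, a minimal end-to-end sketch of the updated API on the CPU backend (not part of the commit; buffer size and fill values are arbitrary):

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // two F32 tensors that agree on every dimension except dim 2
    struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 4, 4, 2);
    struct ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 4, 4, 3);

    ggml_set_f32(a, 1.0f);
    ggml_set_f32(b, 2.0f);

    // previously ggml_concat(ctx, a, b) always concatenated along dim 2;
    // the dimension is now an explicit argument
    struct ggml_tensor * c = ggml_concat(ctx, a, b, 2); // 4 x 4 x 5

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);
    ggml_graph_compute_with_ctx(ctx, gf, 4 /*n_threads*/);

    // the first 4*4*2 floats of c come from a, the remaining 4*4*3 from b

    ggml_free(ctx);
    return 0;
}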