PABannier committed
Commit 154bbc0 · 1 parent: d3e3ea1

ggml : add `GGML_PAD_REFLECT_1D` operation (ggml/1034)


* ggml_pad_reflect_1d defined in header
* implemented on CPU
* called the forward pass
* impl Metal kernel
* added Metal kernel
* added OP_PAD_REFLECT_1D in test-backend-ops.cpp
* add test-pad-reflect-1d test case
* test case supports multiple backends
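For orientation: `GGML_PAD_REFLECT_1D` pads dim 0 by mirroring around the edge elements without repeating them, matching the `[a, b, c, d] -> [b, a, b, c, d, c]` example in the header comment below. A minimal standalone C sketch of that semantics (the helper `pad_reflect_1d` here is hypothetical, not part of the commit):

#include <stdio.h>

/* Hypothetical helper, for illustration only: reflect-pad dim 0 of a 1D
 * float array, mirroring around the edges without repeating them. */
static void pad_reflect_1d(const float * src, int n, int p0, int p1, float * dst) {
    for (int i = 0; i <  n;  i++) { dst[p0 + i]         = src[i];         } // interior
    for (int i = 1; i <= p0; i++) { dst[p0 - i]         = src[i];         } // left mirror
    for (int i = 1; i <= p1; i++) { dst[p0 + n - 1 + i] = src[n - 1 - i]; } // right mirror
}

int main(void) {
    const float src[4] = { 1, 2, 3, 4 };             // [a, b, c, d]
    float dst[6];
    pad_reflect_1d(src, 4, /*p0=*/1, /*p1=*/1, dst);
    for (int i = 0; i < 6; i++) { printf("%g ", dst[i]); } // prints: 2 1 2 3 4 3
    printf("\n");
    return 0;
}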

ggml/include/ggml.h CHANGED
@@ -499,6 +499,7 @@ extern "C" {
         GGML_OP_POOL_2D_BACK,
         GGML_OP_UPSCALE, // nearest interpolate
         GGML_OP_PAD,
+        GGML_OP_PAD_REFLECT_1D,
         GGML_OP_ARANGE,
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
@@ -1695,6 +1696,13 @@ extern "C" {
             int                   p2,
             int                   p3);
 
+    // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
+    GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   p0,
+            int                   p1);
+
     // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
     // timesteps: [N,]
     // return: [N, dim]
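A hedged usage sketch of the new public API (the surrounding setup is illustrative, not from the commit; it assumes the post-backend-split layout where `ggml_graph_compute_with_ctx` lives in `ggml-cpu.h`):

#include "ggml.h"
#include "ggml-cpu.h" // ggml_graph_compute_with_ctx lives here after the backend split

void pad_reflect_example(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4); // [a, b, c, d]
    struct ggml_tensor * b = ggml_pad_reflect_1d(ctx, a, 1, 1);         // 6 elements out

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, b);

    // ... fill `a` with data, then:
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    ggml_free(ctx);
}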
ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -10439,6 +10439,40 @@ static void ggml_compute_forward_pad(
     }
 }
 
+// ggml_compute_forward_pad_reflect_1d
+
+static void ggml_compute_forward_pad_reflect_1d(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int32_t * opts = (const int32_t *) dst->op_params;
+    const int p0 = opts[0];
+    const int p1 = opts[1];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+            for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+                float * left  = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 +         p0*nb0);
+                float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);
+
+                ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
+
+                for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0];   }
+                for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
+            }
+        }
+    }
+}
 
 // ggml_compute_forward_arange
 
@@ -12535,6 +12569,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_pad(params, tensor);
            } break;
+        case GGML_OP_PAD_REFLECT_1D:
+            {
+                ggml_compute_forward_pad_reflect_1d(params, tensor);
+            } break;
         case GGML_OP_ARANGE:
             {
                 ggml_compute_forward_arange(params, tensor);
@@ -12877,6 +12915,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
+        case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
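The CPU path splits rows across threads (`i1 += nth`), copies each source row into the interior of the destination with `ggml_vec_cpy_f32`, and then mirrors `p0`/`p1` elements through the `left`/`right` pointers using negative indexing. A small self-contained check of that pointer trick (a sketch, not code from the commit):

#include <assert.h>

int main(void) {
    const float src[4] = { 1, 2, 3, 4 };
    float dst[6] = { 0 };
    const int p0 = 1, p1 = 1, ne00 = 4, ne0 = 6;

    float * left  = dst + p0;             // first interior element of dst
    float * right = dst + (ne0 - p1 - 1); // last interior element of dst

    for (int i = 0; i < ne00; i++) { left[i] = src[i]; } // stands in for ggml_vec_cpy_f32

    for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0];   } // dst[p0-i0] = src[i0]
    for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; } // mirrors the tail

    assert(dst[0] == 2 && dst[5] == 3); // dst == [2, 1, 2, 3, 4, 3]
    return 0;
}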
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -310,6 +310,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
     GGML_METAL_KERNEL_TYPE_PAD_F32,
+    GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,
     GGML_METAL_KERNEL_TYPE_ARANGE_F32,
     GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@@ -877,6 +878,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,               upscale_f32,               true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                   pad_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,        pad_reflect_1d_f32,        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,    timestep_embedding_f32,    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                arange_f32,                true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,       argsort_f32_i32_asc,       true);
@@ -1099,6 +1101,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_POOL_2D:
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
+        case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
@@ -3258,6 +3261,38 @@ static void ggml_metal_encode_node(
 
                 const int nth = MIN(1024, ne0);
 
+                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_PAD_REFLECT_1D:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                const int32_t p0 = ((const int32_t *)(dst->op_params))[0];
+                const int32_t p1 = ((const int32_t *)(dst->op_params))[1];
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:6];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:11];
+                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:12];
+                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:13];
+                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:14];
+                [encoder setBytes:&p0   length:sizeof(p0)   atIndex:15];
+                [encoder setBytes:&p1   length:sizeof(p1)   atIndex:16];
+
+                const int nth = MIN(1024, ne0);
+
                 [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_ARANGE:
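The encoder launches one threadgroup per output row (a `ne1 x ne2 x ne3` grid) with `nth = MIN(1024, ne0)` threads that stride across the row. In CPU terms the mapping looks roughly like this (a sketch of the dispatch shape only, not code from the commit):

#include <stdint.h>

// Emulates the dispatch shape: one threadgroup per output row,
// nth threads per group, each thread striding across the row.
static void emulate_dispatch(int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, int nth) {
    for (int64_t i3 = 0; i3 < ne3; i3++) {
    for (int64_t i2 = 0; i2 < ne2; i2++) {
    for (int64_t i1 = 0; i1 < ne1; i1++) {              // threadgroup (i1, i2, i3)
        for (int t = 0; t < nth; t++) {                 // thread t within the group
            for (int64_t i0 = t; i0 < ne0; i0 += nth) {
                (void) i0; // the kernel body writes row (i1, i2, i3) at column i0
            }
        }
    }}}
}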
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -2897,6 +2897,53 @@ kernel void kernel_pad_f32(
     }
 }
 
+kernel void kernel_pad_reflect_1d_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant   int64_t & ne0,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    constant   int32_t & p0,
+    constant   int32_t & p1,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3  tgpg[[threadgroups_per_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < p0) {
+                dst_ptr[i0] = src0_ptr[p0 - i0];
+            } else if (i0 < ne0 - p1) {
+                dst_ptr[i0] = src0_ptr[i0 - p0];
+            } else {
+                dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1];
+            }
+        }
+    }
+}
+
 kernel void kernel_arange_f32(
     device char * dst,
     constant int64_t & ne0,
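The `else` branch's index expression looks opaque, but with `ne0 = ne00 + p0 + p1` it simplifies to the usual reflected index `2*(ne00 - 1) - (i0 - p0)`. A brute-force check of that identity (a sketch, not part of the commit):

#include <assert.h>

int main(void) {
    for (int ne00 = 2; ne00 <= 16;  ne00++) {
    for (int p0   = 0; p0   < ne00; p0++)   {
    for (int p1   = 1; p1   < ne00; p1++)   {
        const int ne0 = ne00 + p0 + p1;
        for (int i0 = ne0 - p1; i0 < ne0; i0++) { // right padding region
            const int kernel_idx  = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1;
            const int reflect_idx = 2*(ne00 - 1) - (i0 - p0);
            assert(kernel_idx == reflect_idx);
        }
    }}}
    return 0;
}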
ggml/src/ggml.c CHANGED
@@ -950,6 +950,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "POOL_2D_BACK",
     "UPSCALE",
     "PAD",
+    "PAD_REFLECT_1D",
     "ARANGE",
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
@@ -983,7 +984,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
+static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1045,6 +1046,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "pool_2d_back(x)",
     "upscale(x)",
     "pad(x)",
+    "pad_reflect_1d(x)",
     "arange(start, stop, step)",
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
@@ -1078,7 +1080,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
+static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4097,6 +4099,37 @@ struct ggml_tensor * ggml_pad(
     return result;
 }
 
+// ggml_pad_reflect_1d
+
+struct ggml_tensor * ggml_pad_reflect_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   p0,
+        int                   p1) {
+    GGML_ASSERT(p0 >= 0);
+    GGML_ASSERT(p1 >= 0);
+
+    GGML_ASSERT(p0 < a->ne[0]); // padding length on each side must be less than the
+    GGML_ASSERT(p1 < a->ne[0]); // existing length of the dimension being padded
+
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(a->type == GGML_TYPE_F32);
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
+            a->ne[0] + p0 + p1,
+            a->ne[1],
+            a->ne[2],
+            a->ne[3]);
+
+    int32_t params[] = { p0, p1 };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_PAD_REFLECT_1D;
+    result->src[0] = a;
+
+    return result;
+}
+
 // ggml_arange
 
 struct ggml_tensor * ggml_arange(
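The graph-level op records the output shape `[ne0 + p0 + p1, ne1, ne2, ne3]` and packs `(p0, p1)` into `op_params`, which is exactly where the CPU and Metal backends read them back. A short sketch of the resulting invariants (assumes a live `ggml_context`; the numbers are illustrative only):

#include "ggml.h"

void shape_example(struct ggml_context * ctx) {
    struct ggml_tensor * t   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 100, 8);
    struct ggml_tensor * pad = ggml_pad_reflect_1d(ctx, t, 3, 2);

    GGML_ASSERT(pad->ne[0] == 105); // 100 + 3 + 2
    GGML_ASSERT(pad->ne[1] == 8);   // other dims unchanged

    const int32_t * opts = (const int32_t *) pad->op_params;
    GGML_ASSERT(opts[0] == 3 && opts[1] == 2); // what the backends read back
}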