Spaces:
Running
Running
ggml : add `GGML_PAD_REFLECT_1D` operation (ggml/1034)
Browse files
* ggml_pad_reflect_1d defined in header
* implemented on CPU
* called the forward pass
* impl Metal kernel
* added Metal kernel
* added OP_PAD_REFLECT_1D in test-backend-ops.cpp
* add test-pad-reflect-1d test case
* test case support multiple backend
- ggml/include/ggml.h +8 -0
- ggml/src/ggml-cpu/ggml-cpu.c +39 -0
- ggml/src/ggml-metal/ggml-metal.m +35 -0
- ggml/src/ggml-metal/ggml-metal.metal +47 -0
- ggml/src/ggml.c +35 -2
ggml/include/ggml.h
CHANGED
|
@@ -499,6 +499,7 @@ extern "C" {
|
|
| 499 |
GGML_OP_POOL_2D_BACK,
|
| 500 |
GGML_OP_UPSCALE, // nearest interpolate
|
| 501 |
GGML_OP_PAD,
|
|
|
|
| 502 |
GGML_OP_ARANGE,
|
| 503 |
GGML_OP_TIMESTEP_EMBEDDING,
|
| 504 |
GGML_OP_ARGSORT,
|
|
@@ -1695,6 +1696,13 @@ extern "C" {
|
|
| 1695 |
int p2,
|
| 1696 |
int p3);
|
| 1697 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1698 |
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
|
| 1699 |
// timesteps: [N,]
|
| 1700 |
// return: [N, dim]
|
|
|
|
| 499 |
GGML_OP_POOL_2D_BACK,
|
| 500 |
GGML_OP_UPSCALE, // nearest interpolate
|
| 501 |
GGML_OP_PAD,
|
| 502 |
+
GGML_OP_PAD_REFLECT_1D,
|
| 503 |
GGML_OP_ARANGE,
|
| 504 |
GGML_OP_TIMESTEP_EMBEDDING,
|
| 505 |
GGML_OP_ARGSORT,
|
|
|
|
| 1696 |
int p2,
|
| 1697 |
int p3);
|
| 1698 |
|
| 1699 |
+
// pad dim 0 with reflection, p0 elements before and p1 after each row; e.g. p0 = p1 = 1: [a, b, c, d] -> [b, a, b, c, d, c]
|
| 1700 |
+
GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
|
| 1701 |
+
struct ggml_context * ctx,
|
| 1702 |
+
struct ggml_tensor * a,
|
| 1703 |
+
int p0,
|
| 1704 |
+
int p1);
|
| 1705 |
+
|
| 1706 |
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
|
| 1707 |
// timesteps: [N,]
|
| 1708 |
// return: [N, dim]
|
ggml/src/ggml-cpu/ggml-cpu.c
CHANGED
|
@@ -10439,6 +10439,40 @@ static void ggml_compute_forward_pad(
|
|
| 10439 |
}
|
| 10440 |
}
|
| 10441 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10442 |
|
| 10443 |
// ggml_compute_forward_arange
|
| 10444 |
|
|
@@ -12535,6 +12569,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
| 12535 |
{
|
| 12536 |
ggml_compute_forward_pad(params, tensor);
|
| 12537 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12538 |
case GGML_OP_ARANGE:
|
| 12539 |
{
|
| 12540 |
ggml_compute_forward_arange(params, tensor);
|
|
@@ -12877,6 +12915,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|
| 12877 |
} break;
|
| 12878 |
case GGML_OP_UPSCALE:
|
| 12879 |
case GGML_OP_PAD:
|
|
|
|
| 12880 |
case GGML_OP_ARANGE:
|
| 12881 |
case GGML_OP_TIMESTEP_EMBEDDING:
|
| 12882 |
case GGML_OP_ARGSORT:
|
|
|
|
| 10439 |
}
|
| 10440 |
}
|
| 10441 |
|
| 10442 |
+
// ggml_compute_forward_pad_reflect_1d: copy each src0 row into dst at offset p0, then mirror p0 elements before it and p1 after it
|
| 10443 |
+
|
| 10444 |
+
static void ggml_compute_forward_pad_reflect_1d(
|
| 10445 |
+
const struct ggml_compute_params * params,
|
| 10446 |
+
struct ggml_tensor * dst) {
|
| 10447 |
+
|
| 10448 |
+
const struct ggml_tensor * src0 = dst->src[0];
|
| 10449 |
+
|
| 10450 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32); // rows are read and written as float below
|
| 10451 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 10452 |
+
|
| 10453 |
+
const int ith = params->ith; // first row index handled by this call (see the i1 loop)
|
| 10454 |
+
const int nth = params->nth; // row stride of the i1 loop
|
| 10455 |
+
|
| 10456 |
+
const int32_t * opts = (const int32_t *) dst->op_params; // the two pad lengths, stored as int32
|
| 10457 |
+
const int p0 = opts[0]; // number of elements mirrored before the row
|
| 10458 |
+
const int p1 = opts[1]; // number of elements mirrored after the row
|
| 10459 |
+
|
| 10460 |
+
GGML_TENSOR_UNARY_OP_LOCALS // presumably declares the ne0x/nb0x (src0) and nex/nbx (dst) locals used below - defined outside this view
|
| 10461 |
+
|
| 10462 |
+
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
| 10463 |
+
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
| 10464 |
+
for (int64_t i1 = ith; i1 < ne1; i1 += nth) { // rows ith, ith+nth, ith+2*nth, ...
|
| 10465 |
+
float * left = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + p0*nb0); // dst element p0 of this row: where the first source element lands
|
| 10466 |
+
float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0); // dst element ne0-p1-1 of this row: the last one before the trailing pad
|
| 10467 |
+
|
| 10468 |
+
ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); // copy the ne00 source elements starting at left
|
| 10469 |
+
|
| 10470 |
+
for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0]; } // leading pad: mirror around left[0]
|
| 10471 |
+
for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; } // trailing pad: mirror around right[0]
|
| 10472 |
+
}
|
| 10473 |
+
}
|
| 10474 |
+
}
|
| 10475 |
+
}
|
| 10476 |
|
| 10477 |
// ggml_compute_forward_arange
|
| 10478 |
|
|
|
|
| 12569 |
{
|
| 12570 |
ggml_compute_forward_pad(params, tensor);
|
| 12571 |
} break;
|
| 12572 |
+
case GGML_OP_PAD_REFLECT_1D:
|
| 12573 |
+
{
|
| 12574 |
+
ggml_compute_forward_pad_reflect_1d(params, tensor);
|
| 12575 |
+
} break;
|
| 12576 |
case GGML_OP_ARANGE:
|
| 12577 |
{
|
| 12578 |
ggml_compute_forward_arange(params, tensor);
|
|
|
|
| 12915 |
} break;
|
| 12916 |
case GGML_OP_UPSCALE:
|
| 12917 |
case GGML_OP_PAD:
|
| 12918 |
+
case GGML_OP_PAD_REFLECT_1D:
|
| 12919 |
case GGML_OP_ARANGE:
|
| 12920 |
case GGML_OP_TIMESTEP_EMBEDDING:
|
| 12921 |
case GGML_OP_ARGSORT:
|
ggml/src/ggml-metal/ggml-metal.m
CHANGED
|
@@ -310,6 +310,7 @@ enum ggml_metal_kernel_type {
|
|
| 310 |
GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
|
| 311 |
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
| 312 |
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
|
|
|
| 313 |
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
|
| 314 |
GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
|
| 315 |
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
|
@@ -877,6 +878,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 877 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
|
| 878 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
| 879 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
|
|
|
| 880 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
| 881 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
|
| 882 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
|
@@ -1099,6 +1101,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
| 1099 |
case GGML_OP_POOL_2D:
|
| 1100 |
case GGML_OP_UPSCALE:
|
| 1101 |
case GGML_OP_PAD:
|
|
|
|
| 1102 |
case GGML_OP_ARANGE:
|
| 1103 |
case GGML_OP_TIMESTEP_EMBEDDING:
|
| 1104 |
case GGML_OP_ARGSORT:
|
|
@@ -3258,6 +3261,38 @@ static void ggml_metal_encode_node(
|
|
| 3258 |
|
| 3259 |
const int nth = MIN(1024, ne0);
|
| 3260 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3261 |
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 3262 |
} break;
|
| 3263 |
case GGML_OP_ARANGE:
|
|
|
|
| 310 |
GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
|
| 311 |
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
| 312 |
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
| 313 |
+
GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,
|
| 314 |
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
|
| 315 |
GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
|
| 316 |
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
|
|
|
| 878 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
|
| 879 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
| 880 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
| 881 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true);
|
| 882 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
| 883 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
|
| 884 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
|
|
|
| 1101 |
case GGML_OP_POOL_2D:
|
| 1102 |
case GGML_OP_UPSCALE:
|
| 1103 |
case GGML_OP_PAD:
|
| 1104 |
+
case GGML_OP_PAD_REFLECT_1D:
|
| 1105 |
case GGML_OP_ARANGE:
|
| 1106 |
case GGML_OP_TIMESTEP_EMBEDDING:
|
| 1107 |
case GGML_OP_ARGSORT:
|
|
|
|
| 3261 |
|
| 3262 |
const int nth = MIN(1024, ne0);
|
| 3263 |
|
| 3264 |
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 3265 |
+
} break;
|
| 3266 |
+
case GGML_OP_PAD_REFLECT_1D:
|
| 3267 |
+
{
|
| 3268 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 3269 |
+
|
| 3270 |
+
const int32_t p0 = ((const int32_t *)(dst->op_params))[0];
|
| 3271 |
+
const int32_t p1 = ((const int32_t *)(dst->op_params))[1];
|
| 3272 |
+
|
| 3273 |
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
|
| 3274 |
+
|
| 3275 |
+
[encoder setComputePipelineState:pipeline];
|
| 3276 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3277 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 3278 |
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
| 3279 |
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
| 3280 |
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
| 3281 |
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
| 3282 |
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6];
|
| 3283 |
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
| 3284 |
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
| 3285 |
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
| 3286 |
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
| 3287 |
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11];
|
| 3288 |
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:12];
|
| 3289 |
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13];
|
| 3290 |
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14];
|
| 3291 |
+
[encoder setBytes:&p0 length:sizeof(p0) atIndex:15];
|
| 3292 |
+
[encoder setBytes:&p1 length:sizeof(p1) atIndex:16];
|
| 3293 |
+
|
| 3294 |
+
const int nth = MIN(1024, ne0);
|
| 3295 |
+
|
| 3296 |
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 3297 |
} break;
|
| 3298 |
case GGML_OP_ARANGE:
|
ggml/src/ggml-metal/ggml-metal.metal
CHANGED
|
@@ -2897,6 +2897,53 @@ kernel void kernel_pad_f32(
|
|
| 2897 |
}
|
| 2898 |
}
|
| 2899 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2900 |
kernel void kernel_arange_f32(
|
| 2901 |
device char * dst,
|
| 2902 |
constant int64_t & ne0,
|
|
|
|
| 2897 |
}
|
| 2898 |
}
|
| 2899 |
|
| 2900 |
+
kernel void kernel_pad_reflect_1d_f32( // reflect-pad dim 0 of src0 into dst; each threadgroup handles the row (i1, i2, i3) given by its grid position
|
| 2901 |
+
device const char * src0,
|
| 2902 |
+
device char * dst,
|
| 2903 |
+
constant int64_t & ne00,
|
| 2904 |
+
constant int64_t & ne01,
|
| 2905 |
+
constant int64_t & ne02,
|
| 2906 |
+
constant int64_t & ne03,
|
| 2907 |
+
constant int64_t & ne0,
|
| 2908 |
+
constant uint64_t & nb00,
|
| 2909 |
+
constant uint64_t & nb01,
|
| 2910 |
+
constant uint64_t & nb02,
|
| 2911 |
+
constant uint64_t & nb03,
|
| 2912 |
+
constant uint64_t & nb0,
|
| 2913 |
+
constant uint64_t & nb1,
|
| 2914 |
+
constant uint64_t & nb2,
|
| 2915 |
+
constant uint64_t & nb3,
|
| 2916 |
+
constant int32_t & p0, // number of mirrored elements written before the row
|
| 2917 |
+
constant int32_t & p1, // number of mirrored elements written after the row
|
| 2918 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2919 |
+
uint3 tgpg[[threadgroups_per_grid]],
|
| 2920 |
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 2921 |
+
uint3 ntg[[threads_per_threadgroup]]) {
|
| 2922 |
+
|
| 2923 |
+
const int64_t i3 = tgpig.z;
|
| 2924 |
+
const int64_t i2 = tgpig.y;
|
| 2925 |
+
const int64_t i1 = tgpig.x;
|
| 2926 |
+
|
| 2927 |
+
const int64_t i03 = i3; // source row indices equal the destination row indices
|
| 2928 |
+
const int64_t i02 = i2;
|
| 2929 |
+
const int64_t i01 = i1;
|
| 2930 |
+
|
| 2931 |
+
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); // start of the source row
|
| 2932 |
+
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); // start of the destination row
|
| 2933 |
+
|
| 2934 |
+
if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
| 2935 |
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { // the group's threads stride across the destination row
|
| 2936 |
+
if (i0 < p0) { // leading pad
|
| 2937 |
+
dst_ptr[i0] = src0_ptr[p0 - i0]; // mirror around source element 0
|
| 2938 |
+
} else if (i0 < ne0 - p1) { // middle: plain copy shifted by p0
|
| 2939 |
+
dst_ptr[i0] = src0_ptr[i0 - p0];
|
| 2940 |
+
} else { // trailing pad
|
| 2941 |
+
dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1]; // with n = ne0 - p1 - p0, reads n-2 at i0 = ne0-p1 and steps down by one per element
|
| 2942 |
+
}
|
| 2943 |
+
}
|
| 2944 |
+
}
|
| 2945 |
+
}
|
| 2946 |
+
|
| 2947 |
kernel void kernel_arange_f32(
|
| 2948 |
device char * dst,
|
| 2949 |
constant int64_t & ne0,
|
ggml/src/ggml.c
CHANGED
|
@@ -950,6 +950,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
| 950 |
"POOL_2D_BACK",
|
| 951 |
"UPSCALE",
|
| 952 |
"PAD",
|
|
|
|
| 953 |
"ARANGE",
|
| 954 |
"TIMESTEP_EMBEDDING",
|
| 955 |
"ARGSORT",
|
|
@@ -983,7 +984,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
| 983 |
"OPT_STEP_ADAMW",
|
| 984 |
};
|
| 985 |
|
| 986 |
-
static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
|
| 987 |
|
| 988 |
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
| 989 |
"none",
|
|
@@ -1045,6 +1046,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
| 1045 |
"pool_2d_back(x)",
|
| 1046 |
"upscale(x)",
|
| 1047 |
"pad(x)",
|
|
|
|
| 1048 |
"arange(start, stop, step)",
|
| 1049 |
"timestep_embedding(timesteps, dim, max_period)",
|
| 1050 |
"argsort(x)",
|
|
@@ -1078,7 +1080,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
| 1078 |
"adamw(x)",
|
| 1079 |
};
|
| 1080 |
|
| 1081 |
-
static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
|
| 1082 |
|
| 1083 |
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
| 1084 |
|
|
@@ -4097,6 +4099,37 @@ struct ggml_tensor * ggml_pad(
|
|
| 4097 |
return result;
|
| 4098 |
}
|
| 4099 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4100 |
// ggml_arange
|
| 4101 |
|
| 4102 |
struct ggml_tensor * ggml_arange(
|
|
|
|
| 950 |
"POOL_2D_BACK",
|
| 951 |
"UPSCALE",
|
| 952 |
"PAD",
|
| 953 |
+
"PAD_REFLECT_1D",
|
| 954 |
"ARANGE",
|
| 955 |
"TIMESTEP_EMBEDDING",
|
| 956 |
"ARGSORT",
|
|
|
|
| 984 |
"OPT_STEP_ADAMW",
|
| 985 |
};
|
| 986 |
|
| 987 |
+
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
|
| 988 |
|
| 989 |
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
| 990 |
"none",
|
|
|
|
| 1046 |
"pool_2d_back(x)",
|
| 1047 |
"upscale(x)",
|
| 1048 |
"pad(x)",
|
| 1049 |
+
"pad_reflect_1d(x)",
|
| 1050 |
"arange(start, stop, step)",
|
| 1051 |
"timestep_embedding(timesteps, dim, max_period)",
|
| 1052 |
"argsort(x)",
|
|
|
|
| 1080 |
"adamw(x)",
|
| 1081 |
};
|
| 1082 |
|
| 1083 |
+
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
|
| 1084 |
|
| 1085 |
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
| 1086 |
|
|
|
|
| 4099 |
return result;
|
| 4100 |
}
|
| 4101 |
|
| 4102 |
+
// ggml_pad_reflect_1d: build a GGML_OP_PAD_REFLECT_1D node whose dim 0 is a->ne[0] + p0 + p1; dims 1..3 are unchanged
|
| 4103 |
+
|
| 4104 |
+
struct ggml_tensor * ggml_pad_reflect_1d(
|
| 4105 |
+
struct ggml_context * ctx,
|
| 4106 |
+
struct ggml_tensor * a,
|
| 4107 |
+
int p0,
|
| 4108 |
+
int p1) {
|
| 4109 |
+
GGML_ASSERT(p0 >= 0);
|
| 4110 |
+
GGML_ASSERT(p1 >= 0);
|
| 4111 |
+
|
| 4112 |
+
GGML_ASSERT(p0 < a->ne[0]); // padding length on each side must be less than the
|
| 4113 |
+
GGML_ASSERT(p1 < a->ne[0]); // existing length of the dimension being padded
|
| 4114 |
+
|
| 4115 |
+
GGML_ASSERT(ggml_is_contiguous(a));
|
| 4116 |
+
GGML_ASSERT(a->type == GGML_TYPE_F32);
|
| 4117 |
+
|
| 4118 |
+
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
|
| 4119 |
+
a->ne[0] + p0 + p1, // only dim 0 grows
|
| 4120 |
+
a->ne[1],
|
| 4121 |
+
a->ne[2],
|
| 4122 |
+
a->ne[3]);
|
| 4123 |
+
|
| 4124 |
+
int32_t params[] = { p0, p1 }; // stored in op_params as [0] = p0, [1] = p1
|
| 4125 |
+
ggml_set_op_params(result, params, sizeof(params));
|
| 4126 |
+
|
| 4127 |
+
result->op = GGML_OP_PAD_REFLECT_1D;
|
| 4128 |
+
result->src[0] = a;
|
| 4129 |
+
|
| 4130 |
+
return result;
|
| 4131 |
+
}
|
| 4132 |
+
|
| 4133 |
// ggml_arange
|
| 4134 |
|
| 4135 |
struct ggml_tensor * ggml_arange(
|