Ronsor, ggerganov committed
Commit f541d31 · 1 Parent(s): 93e1056

feat: add new `sin` and `cos` operators (ggml/919)


* ggml : add sin/cos operators

* ggml-cuda : add sin/cos operators

* ggml : add corresponding tests for sin/cos

* ggml : add backward computation for sin/cos operators

* ggml-vulkan : add sin/cos operators

* ggml-vulkan : add sin/cos shader source

* metal : add sin, cos

---------

Co-authored-by: Georgi Gerganov <[email protected]>

ggml/include/ggml.h CHANGED
@@ -451,6 +451,8 @@ extern "C" {
         GGML_OP_SQR,
         GGML_OP_SQRT,
         GGML_OP_LOG,
+        GGML_OP_SIN,
+        GGML_OP_COS,
         GGML_OP_SUM,
         GGML_OP_SUM_ROWS,
         GGML_OP_MEAN,
@@ -967,6 +969,22 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_sin(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sin_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_cos(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_cos_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // return scalar
     GGML_API struct ggml_tensor * ggml_sum(
             struct ggml_context * ctx,
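The header exposes the same out-of-place/in-place pair used by the neighboring unary ops (`ggml_sqr`, `ggml_sqrt`, `ggml_log`). As a minimal sketch of how the new entry points compose with the rest of the graph API (the context size, tensor shape, and thread count below are illustrative, not part of the commit):

```c
#include "ggml.h"
#include <stdio.h>

int main(void) {
    // small scratch context; the size is arbitrary for this example
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    for (int i = 0; i < 8; ++i) {
        ((float *) x->data)[i] = 0.25f*i;
    }

    // sin/cos become ordinary graph nodes, just like ggml_sqr or ggml_log
    struct ggml_tensor * s = ggml_sin(ctx, x);
    struct ggml_tensor * c = ggml_cos(ctx, x);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, s);
    ggml_build_forward_expand(gf, c);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    printf("sin(0.25) = %f, cos(0.25) = %f\n",
            ggml_get_f32_1d(s, 1), ggml_get_f32_1d(c, 1));

    ggml_free(ctx);
    return 0;
}
```

The `_inplace` variants return a view of `a` and overwrite its data during compute.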
ggml/src/ggml-cuda.cu CHANGED
@@ -2267,6 +2267,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SQRT:
             ggml_cuda_op_sqrt(ctx, dst);
             break;
+        case GGML_OP_SIN:
+            ggml_cuda_op_sin(ctx, dst);
+            break;
+        case GGML_OP_COS:
+            ggml_cuda_op_cos(ctx, dst);
+            break;
         case GGML_OP_CLAMP:
             ggml_cuda_op_clamp(ctx, dst);
             break;
@@ -2859,6 +2865,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
         case GGML_OP_CLAMP:
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
ggml/src/ggml-cuda/unary.cu CHANGED
@@ -101,6 +101,24 @@ static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
     dst[i] = sqrtf(x[i]);
 }
 
+static __global__ void sin_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sinf(x[i]);
+}
+
+static __global__ void cos_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = cosf(x[i]);
+}
+
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
     gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -156,6 +174,16 @@ static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void sin_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE;
+    sin_f32<<<num_blocks, CUDA_SIN_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
+    cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
@@ -312,3 +340,31 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }
+
+void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sin_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    cos_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
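The new helpers follow the file's existing elementwise pattern: one thread per element, with the grid size computed by ceiling division. For example, with the 256-thread block size and k = 1000 elements, `num_blocks = (1000 + 255) / 256 = 4`, so 1024 threads are launched and the `if (i >= k) { return; }` guard in each kernel masks off the 24 surplus threads.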
ggml/src/ggml-cuda/unary.cuh CHANGED
@@ -9,6 +9,8 @@
 #define CUDA_HARDSWISH_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
 #define CUDA_SQRT_BLOCK_SIZE 256
+#define CUDA_SIN_BLOCK_SIZE 256
+#define CUDA_COS_BLOCK_SIZE 256
 
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
@@ -31,3 +33,7 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-metal.m CHANGED
@@ -205,6 +205,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
+    GGML_METAL_KERNEL_TYPE_SIN,
+    GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
 
     GGML_METAL_KERNEL_TYPE_COUNT
@@ -665,6 +667,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
     }
 
@@ -771,9 +775,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
         case GGML_OP_CLAMP:
+            return true;
         case GGML_OP_SQR:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_SUM_ROWS:
-            return true;
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
         case GGML_OP_GROUP_NORM:
@@ -1409,6 +1416,34 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         const int64_t n = ggml_nelements(dst);
 
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                case GGML_OP_SIN:
+                    {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                        const int64_t n = ggml_nelements(dst);
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                case GGML_OP_COS:
+                    {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                        const int64_t n = ggml_nelements(dst);
+
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 case GGML_OP_SUM_ROWS:
ggml/src/ggml-metal.metal CHANGED
@@ -358,6 +358,20 @@ kernel void kernel_sqr(
     dst[tpig] = src0[tpig] * src0[tpig];
 }
 
+kernel void kernel_sin(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = sin(src0[tpig]);
+}
+
+kernel void kernel_cos(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = cos(src0[tpig]);
+}
+
 kernel void kernel_sum_rows(
         device const float * src0,
         device       float * dst,
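Like `kernel_sqr` above them, the new Metal kernels index the flat buffer directly with `tpig`, and the host side dispatches `n = ggml_nelements(dst)` threadgroups of a single thread each. That flat addressing is why both `ggml_metal_supports_op` and the encoder cases require `ggml_is_contiguous(src0)` for `GGML_OP_SIN`/`GGML_OP_COS`, a restriction the CUDA path shares via its own `GGML_ASSERT`.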
ggml/src/ggml-vulkan.cpp CHANGED
@@ -184,6 +184,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_upscale_f32;
     vk_pipeline pipeline_scale_f32;
     vk_pipeline pipeline_sqr_f32;
+    vk_pipeline pipeline_sin_f32;
+    vk_pipeline pipeline_cos_f32;
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
     vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
@@ -1654,6 +1656,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
@@ -3972,6 +3976,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_sqr_f32;
         }
         return nullptr;
+    case GGML_OP_SIN:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_sin_f32;
+        }
+        return nullptr;
+    case GGML_OP_COS:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_cos_f32;
+        }
+        return nullptr;
     case GGML_OP_CLAMP:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_clamp_f32;
@@ -4124,6 +4138,8 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SIN:
+    case GGML_OP_COS:
     case GGML_OP_CLAMP:
    case GGML_OP_PAD:
         return true;
@@ -4335,6 +4351,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_MUL:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SIN:
+    case GGML_OP_COS:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_CPY:
@@ -4576,6 +4594,32 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const
     });
 }
 
+static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        0.0f, 0.0f,
+    });
+}
+
+static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        0.0f, 0.0f,
+    });
+}
+
 static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
     const uint32_t src0_type_size = ggml_type_size(src0->type);
@@ -5481,6 +5525,8 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
     case GGML_OP_ADD:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SIN:
+    case GGML_OP_COS:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_CPY:
@@ -5761,6 +5807,8 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
    case GGML_OP_SQR:
+    case GGML_OP_SIN:
+    case GGML_OP_COS:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_CPY:
@@ -5832,6 +5880,14 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SQR:
         ggml_vk_sqr(ctx, compute_ctx, src0, node);
 
+        break;
+    case GGML_OP_SIN:
+        ggml_vk_sin(ctx, compute_ctx, src0, node);
+
+        break;
+    case GGML_OP_COS:
+        ggml_vk_cos(ctx, compute_ctx, src0, node);
+
         break;
     case GGML_OP_CLAMP:
         ggml_vk_clamp(ctx, compute_ctx, src0, node);
@@ -5943,6 +5999,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SIN:
+    case GGML_OP_COS:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_CPY:
@@ -6658,6 +6716,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SIN:
+    case GGML_OP_COS:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_CONT:
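The Vulkan path, by contrast, reuses the generic `vk_op_unary_push_constants` layout and passes the full `ne`/`nb` shape and stride information of `src0` and `dst` (strides pre-divided by the element size). That is what lets `ggml_vk_op_supports_incontiguous` report `true` for `GGML_OP_SIN` and `GGML_OP_COS`; the `sin_f32`/`cos_f32` shader sources referenced via `sin_f32_len`/`sin_f32_data` and `cos_f32_len`/`cos_f32_data` are added in the generated shader collection, which does not appear in this excerpt.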
ggml/src/ggml.c CHANGED
@@ -2310,7 +2310,9 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
 inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
 inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
-inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
+inline static void ggml_vec_log_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]);  }
+inline static void ggml_vec_sin_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]);  }
+inline static void ggml_vec_cos_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]);  }
 inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
 inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
 inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
@@ -2760,6 +2762,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "SQR",
     "SQRT",
     "LOG",
+    "SIN",
+    "COS",
     "SUM",
     "SUM_ROWS",
     "MEAN",
@@ -2833,7 +2837,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2848,6 +2852,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "x^2",
     "√x",
     "log(x)",
+    "sin(x)",
+    "cos(x)",
     "Σx",
     "Σx_k",
     "Σx/n",
@@ -2921,7 +2927,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4882,6 +4888,72 @@ struct ggml_tensor * ggml_log_inplace(
     return ggml_log_impl(ctx, a, true);
 }
 
+// ggml_sin
+
+static struct ggml_tensor * ggml_sin_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool                  inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op     = GGML_OP_SIN;
+    result->grad   = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_sin(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sin_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sin_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sin_impl(ctx, a, true);
+}
+
+// ggml_cos
+
+static struct ggml_tensor * ggml_cos_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool                  inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op     = GGML_OP_COS;
+    result->grad   = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_cos(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_cos_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_cos_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_cos_impl(ctx, a, true);
+}
+
 // ggml_sum
 
 struct ggml_tensor * ggml_sum(
@@ -10512,6 +10584,96 @@ static void ggml_compute_forward_log(
     }
 }
 
+// ggml_compute_forward_sin
+
+static void ggml_compute_forward_sin_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sin_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sin(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sin_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_cos
+
+static void ggml_compute_forward_cos_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_cos_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_cos(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cos_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_sum
 
 static void ggml_compute_forward_sum_f32(
@@ -16787,6 +16949,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_log(params, tensor);
             } break;
+        case GGML_OP_SIN:
+            {
+                ggml_compute_forward_sin(params, tensor);
+            } break;
+        case GGML_OP_COS:
+            {
+                ggml_compute_forward_cos(params, tensor);
+            } break;
         case GGML_OP_SUM:
             {
                 ggml_compute_forward_sum(params, tensor);
@@ -17433,6 +17603,30 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         zero_table);
                 }
             } break;
+        case GGML_OP_SIN:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_or_set(ctx,
+                                src0->grad,
+                                ggml_mul(ctx,
+                                    tensor->grad,
+                                    ggml_cos(ctx, src0)),
+                                zero_table);
+                }
+            } break;
+        case GGML_OP_COS:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_sub_or_set(ctx,
+                                src0->grad,
+                                ggml_mul(ctx,
+                                    tensor->grad,
+                                    ggml_sin(ctx, src0)),
+                                zero_table);
+                }
+            } break;
         case GGML_OP_SUM:
             {
                 if (src0->grad) {
@@ -18520,6 +18714,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_SQR:
        case GGML_OP_SQRT:
         case GGML_OP_LOG:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
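For reference, the new backward cases encode the textbook derivatives. Writing $g$ for the incoming gradient `tensor->grad`:

$$
\frac{\partial L}{\partial x} \;\leftarrow\; \frac{\partial L}{\partial x} + g \odot \cos(x) \quad \text{for } y = \sin(x),
\qquad
\frac{\partial L}{\partial x} \;\leftarrow\; \frac{\partial L}{\partial x} - g \odot \sin(x) \quad \text{for } y = \cos(x),
$$

which is why the `GGML_OP_SIN` case accumulates with `ggml_add_or_set` and `ggml_cos(ctx, src0)` while the `GGML_OP_COS` case uses `ggml_sub_or_set` and `ggml_sin(ctx, src0)`. On the CPU side, `ggml_compute_forward_sin_f32`/`_cos_f32` return early for `params->ith != 0`, so the forward kernels run on a single thread, consistent with `GGML_OP_SIN` and `GGML_OP_COS` being grouped with the other single-task ops in `ggml_get_n_tasks`.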