mollysama committed
Commit 727de7e · 1 Parent(s): 1e69b8c

llama: Add support for RWKV v7 architecture (llama/12412)

* ggml: Add op l2_norm

Signed-off-by: Molly Sophia <[email protected]>

* ggml: Add op rwkv_wkv7

Signed-off-by: Molly Sophia <[email protected]>

* llama: Add support for RWKV7 and ARWKV7 models

Signed-off-by: Molly Sophia <[email protected]>

* llama: fix inference with RWKV6Qwen2

Signed-off-by: Molly Sophia <[email protected]>

* llama: add more (a)rwkv7 variants in size

Signed-off-by: Molly Sophia <[email protected]>

* Apply code-format changes

Signed-off-by: Molly Sophia <[email protected]>

* fix MUSA build

Signed-off-by: Molly Sophia <[email protected]>

* llama: fix shape error with rwkv using llama-parallel

Signed-off-by: Molly Sophia <[email protected]>

---------

Signed-off-by: Molly Sophia <[email protected]>
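For orientation, here is a minimal sketch of how the two operators this commit introduces (ggml_l2_norm and ggml_rwkv_wkv7, declared in the ggml/include/ggml.h changes below) could be called when building a graph. The tensor shapes, the head/token counts, and the choice to normalize k are illustrative assumptions, not taken from the commit; only the two function signatures come from the diff.

    // sketch only: assumed shapes for one RWKV v7 wkv step
    #include "ggml.h"

    static struct ggml_tensor * build_wkv7_step(struct ggml_context * ctx) {
        const int64_t head_size = 64; // backends in this commit assert C / HEADS == 64
        const int64_t n_head    = 8;  // illustrative
        const int64_t n_tokens  = 4;  // illustrative
        const int64_t n_seqs    = 1;  // illustrative

        // per-token projections: [head_size, n_head, n_tokens]
        struct ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
        struct ggml_tensor * w = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
        struct ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
        struct ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
        struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
        struct ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);

        // recurrent state, one block of head_size*head_size per head, per sequence
        struct ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, head_size * head_size * n_head, n_seqs);

        // new op: L2-normalize each row (the header notes it is used in RWKV v7);
        // applying it to k here is purely illustrative
        k = ggml_l2_norm(ctx, k, 1e-12f);

        // new op: returns the per-token outputs followed by the updated state
        return ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, state);
    }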

ggml/include/ggml.h CHANGED
@@ -454,6 +454,7 @@ extern "C" {
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
         GGML_OP_GROUP_NORM,
+        GGML_OP_L2_NORM,
 
         GGML_OP_MUL_MAT,
         GGML_OP_MUL_MAT_ID,
@@ -502,6 +503,7 @@ extern "C" {
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV6,
         GGML_OP_GATED_LINEAR_ATTN,
+        GGML_OP_RWKV_WKV7,
 
         GGML_OP_UNARY,
 
@@ -1095,6 +1097,18 @@ extern "C" {
             int                   n_groups,
             float                 eps);
 
+    // l2 normalize along rows
+    // used in rwkv v7
+    GGML_API struct ggml_tensor * ggml_l2_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
+    GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
     // a - x
     // b - dy
     GGML_API struct ggml_tensor * ggml_rms_norm_back(
@@ -1890,6 +1904,16 @@ extern "C" {
            struct ggml_tensor  * state,
            float                 scale);
 
+    GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * r,
+            struct ggml_tensor  * w,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * state);
+
     // custom operators
 
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -8548,6 +8548,69 @@ static void ggml_compute_forward_group_norm(
     }
 }
 
+// ggml_compute_forward_l2_norm
+
+static void ggml_compute_forward_l2_norm_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    GGML_ASSERT(eps >= 0.0f);
+
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float sum = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum += (ggml_float)(x[i00] * x[i00]);
+                }
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                memcpy(y, x, ne00 * sizeof(float));
+
+                const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_l2_norm(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_l2_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_mul_mat
 
 static void ggml_compute_forward_mul_mat_one_chunk(
@@ -13604,6 +13667,184 @@ static void ggml_compute_forward_gla(
     }
 }
 
+// ggml_compute_forward_rwkv_wkv7
+
+static void ggml_compute_forward_rwkv_wkv7_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[6]->ne[1];
+    const int64_t head_size = C / HEADS;
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * r = (float *) dst->src[0]->data;
+    float * w = (float *) dst->src[1]->data;
+    float * k = (float *) dst->src[2]->data;
+    float * v = (float *) dst->src[3]->data;
+    float * a = (float *) dst->src[4]->data;
+    float * b = (float *) dst->src[5]->data;
+
+    int64_t t_stride = HEADS * head_size; // Same to C
+
+    int64_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    int64_t h_stride_2d = head_size * head_size;
+
+#if defined(GGML_SIMD)
+    for (int64_t t = 0; t < T; t++) {
+        int64_t t_offset = t * t_stride;
+        int64_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+        for (int64_t h = h_start; h < h_end; h++) {
+            int64_t h_offset = h * h_stride;
+            int64_t t_h_offset = t_offset + h_offset;
+            int64_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t ii = 0; ii < head_size; ii++) {
+                int64_t t_h_i_offset = t_h_offset + ii;
+                int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                float sa = 0;
+                {
+                    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                    GGML_F32_VEC ax[GGML_F32_ARR];
+                    GGML_F32_VEC ay[GGML_F32_ARR];
+                    for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                            ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                            ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                            sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                        }
+                    }
+                    GGML_F32_VEC_REDUCE(sa, sum);
+                }
+
+                GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+
+                int64_t j = 0;
+                GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                for (; j < head_size; j += GGML_F32_STEP) {
+                    for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                        int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+
+                        GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                        GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                        GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                        GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+
+                        k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+
+                        GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                        // kv + s * decay + sa * b
+                        state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                        state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                        GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+
+                        result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                    }
+                }
+                GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                // There shouldn't be left-overs though.
+                for (; j < head_size; j++) {
+                    int64_t t_h_j_offset = t_h_offset + j;
+                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float r_val = r[t_h_j_offset];
+                    float w_val = w[t_h_j_offset];
+                    float k_val = k[t_h_j_offset];
+                    float b_val = b[t_h_j_offset];
+                    float kv_val = v[t_h_i_offset] * k_val;
+
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                    dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                }
+            }
+        }
+    }
+#else
+    for (int64_t t = 0; t < T; t++) {
+        int64_t t_offset = t * t_stride;
+        int64_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+        for (int64_t h = h_start; h < h_end; h++) {
+            int64_t h_offset = h * h_stride;
+            int64_t t_h_offset = t_offset + h_offset;
+            int64_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t i = 0; i < head_size; i++) {
+                int64_t t_h_i_offset = t_h_offset + i;
+                int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float v_val = v[t_h_i_offset];
+
+                float sa = 0, result = 0;
+                for (int64_t j = 0; j < head_size; j++) {
+                    sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+                }
+
+                for (int64_t j = 0; j < head_size; j++) {
+                    int64_t t_h_j_offset = t_h_offset + j;
+                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float r_val = r[t_h_j_offset];
+                    float w_val = w[t_h_j_offset];
+                    float k_val = k[t_h_j_offset];
+                    float b_val = b[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                    result += state_cur[h_2d_i_j_offset] * r_val;
+                }
+                dst_data[t_h_i_offset] = result;
+            }
+        }
+    }
+#endif
+}
+
+
+static void ggml_compute_forward_rwkv_wkv7(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv7_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
@@ -14170,6 +14411,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_group_norm(params, tensor);
             } break;
+        case GGML_OP_L2_NORM:
+            {
+                ggml_compute_forward_l2_norm(params, tensor);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 ggml_compute_forward_mul_mat(params, tensor);
@@ -14357,6 +14602,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gla(params, tensor);
             } break;
+        case GGML_OP_RWKV_WKV7:
+            {
+                ggml_compute_forward_rwkv_wkv7(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -14582,6 +14831,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
         case GGML_OP_RMS_NORM_BACK:
+        case GGML_OP_L2_NORM:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_CONCAT:
         case GGML_OP_MUL_MAT:
@@ -14648,14 +14898,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_FLASH_ATTN_BACK:
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
+        case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
+        case GGML_OP_RWKV_WKV7:
            {
                n_tasks = n_threads;
            } break;
        case GGML_OP_WIN_PART:
        case GGML_OP_WIN_UNPART:
        case GGML_OP_GET_REL_POS:
-       case GGML_OP_RWKV_WKV6:
-       case GGML_OP_GATED_LINEAR_ATTN:
        case GGML_OP_MAP_UNARY:
        case GGML_OP_MAP_BINARY:
        case GGML_OP_MAP_CUSTOM1_F32:
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -36,7 +36,7 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
-#include "ggml-cuda/wkv6.cuh"
+#include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
 #include "ggml.h"
 
@@ -2196,6 +2196,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GROUP_NORM:
             ggml_cuda_op_group_norm(ctx, dst);
             break;
+        case GGML_OP_L2_NORM:
+            ggml_cuda_op_l2_norm(ctx, dst);
+            break;
         case GGML_OP_CONCAT:
             ggml_cuda_op_concat(ctx, dst);
             break;
@@ -2304,6 +2307,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GATED_LINEAR_ATTN:
             ggml_cuda_op_gated_linear_attn(ctx, dst);
             break;
+        case GGML_OP_RWKV_WKV7:
+            ggml_cuda_op_rwkv_wkv7(ctx, dst);
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;
@@ -3161,6 +3167,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             break;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
             return true;
         case GGML_OP_RMS_NORM_BACK:
             return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
@@ -3215,6 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
+        case GGML_OP_RWKV_WKV7:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
ggml/src/ggml-cuda/norm.cu CHANGED
@@ -201,6 +201,85 @@ static __global__ void rms_norm_back_f32(
     }
 }
 
+// template <int block_size>
+// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
+//     const int row = blockIdx.x*blockDim.y + threadIdx.y;
+//     const int tid = threadIdx.x;
+
+//     float tmp = 0.0f; // partial sum for thread in warp
+
+//     for (int col = tid; col < ncols; col += block_size) {
+//         const float xi = x[row*ncols + col];
+//         tmp += xi * xi;
+//     }
+
+//     // sum up partial sums
+//     tmp = warp_reduce_sum(tmp);
+//     if (block_size > WARP_SIZE) {
+//         __shared__ float s_sum[32];
+//         int warp_id = threadIdx.x / WARP_SIZE;
+//         int lane_id = threadIdx.x % WARP_SIZE;
+//         if (lane_id == 0) {
+//             s_sum[warp_id] = tmp;
+//         }
+//         __syncthreads();
+//         tmp = s_sum[lane_id];
+//         tmp = warp_reduce_sum(tmp);
+//     }
+
+//     // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
+//     const float scale = rsqrtf(fmaxf(tmp, eps * eps));
+
+//     for (int col = tid; col < ncols; col += block_size) {
+//         dst[row*ncols + col] = scale * x[row*ncols + col];
+//     }
+// }
+
+template <int block_size>
+static __global__ void l2_norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
+
+    const int row     = blockIdx.x;
+    const int channel = blockIdx.y;
+    const int sample  = blockIdx.z;
+    const int tid     = threadIdx.x;
+
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x[col];
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp);
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
+        __shared__ float s_sum[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
+    const float scale = rsqrtf(fmaxf(tmp, eps * eps));
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[col] = scale * x[col];
+    }
+}
+
 static void norm_f32_cuda(
     const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
     const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
@@ -248,6 +327,19 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
     }
 }
 
+static void l2_norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    }
+}
+
 void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *) src0->data;
@@ -340,3 +432,27 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
 
     rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
 }
+
+void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+    GGML_ASSERT(eps >= 0.0f);
+
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
+}
ggml/src/ggml-cuda/norm.cuh CHANGED
@@ -7,3 +7,5 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/wkv.cu ADDED
@@ -0,0 +1,199 @@
+#include "common.cuh"
+#include "wkv.cuh"
+
+template <int block_size>
+static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = block_size;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    __syncthreads();
+    _tf[tid] = tf[head_i * head_size + tid];
+    __syncthreads();
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        __syncthreads();
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4& k = (float4&)(_k[j]);
+            const float4& r = (float4&)(_r[j]);
+            const float4& tf = (float4&)(_tf[j]);
+            const float4& td = (float4&)(_td[j]);
+            float4& s = (float4&)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            y += r.x * (tf.x * kv.x + s.x);
+            y += r.y * (tf.y * kv.y + s.y);
+            y += r.z * (tf.z * kv.z + s.z);
+            y += r.w * (tf.w * kv.w + s.w);
+
+            s.x = s.x * td.x + kv.x;
+            s.y = s.y * td.y + kv.y;
+            s.z = s.z * td.z + kv.z;
+            s.w = s.w * td.w + kv.w;
+        }
+        dst[t] = y;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+template <int block_size>
+static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = block_size;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size];
+
+#ifndef GGML_USE_MUSA
+    #pragma unroll
+#endif
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
+    }
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _r[tid] = r[t];
+        _w[tid] = w[t];
+        _k[tid] = k[t];
+        _a[tid] = a[t];
+        _b[tid] = b[t];
+        __syncthreads();
+
+        float sa = 0;
+        #pragma unroll
+        for (int j = 0; j < head_size; j += 4)
+        {
+            const float4& a = (float4&)(_a[j]);
+            const float4& s = (float4&)(state[j]);
+            sa += a.x * s.x;
+            sa += a.y * s.y;
+            sa += a.z * s.z;
+            sa += a.w * s.w;
+        }
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4& r = (float4&)(_r[j]);
+            const float4& w = (float4&)(_w[j]);
+            const float4& k = (float4&)(_k[j]);
+            const float4& b = (float4&)(_b[j]);
+            float4& s = (float4&)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            s.x = s.x * w.x + kv.x + sa * b.x;
+            s.y = s.y * w.y + kv.y + sa * b.y;
+            s.z = s.z * w.z + kv.z + sa * b.z;
+            s.w = s.w * w.w + kv.w + sa * b.w;
+
+            y += s.x * r.x;
+            y += s.y * r.y;
+            y += s.z * r.z;
+            y += s.w * r.w;
+        }
+        dst[t] = y;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
+    }
+}
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * k_d  = (const float *)dst->src[0]->data;
+    const float * v_d  = (const float *)dst->src[1]->data;
+    const float * r_d  = (const float *)dst->src[2]->data;
+    const float * tf_d = (const float *)dst->src[3]->data;
+    const float * td_d = (const float *)dst->src[4]->data;
+    const float * s_d  = (const float *)dst->src[5]->data;
+
+    const int64_t B = dst->src[5]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
+
+    if (C / H == CUDA_WKV_BLOCK_SIZE) {
+        rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+    } else {
+        rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+    }
+}
+
+void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * r_d = (const float *)dst->src[0]->data;
+    const float * w_d = (const float *)dst->src[1]->data;
+    const float * k_d = (const float *)dst->src[2]->data;
+    const float * v_d = (const float *)dst->src[3]->data;
+    const float * a_d = (const float *)dst->src[4]->data;
+    const float * b_d = (const float *)dst->src[5]->data;
+    const float * s_d = (const float *)dst->src[6]->data;
+
+    const int64_t B = dst->src[6]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
+
+    if (C / H == CUDA_WKV_BLOCK_SIZE) {
+        rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+    } else {
+        rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+    }
+}
ggml/src/ggml-cuda/wkv.cuh ADDED
@@ -0,0 +1,7 @@
+#include "common.cuh"
+
+#define CUDA_WKV_BLOCK_SIZE 64
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-metal/ggml-metal-impl.h CHANGED
@@ -285,6 +285,13 @@ typedef struct {
     float eps;
 } ggml_metal_kargs_rms_norm;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne00_4;
+    uint64_t nb01;
+    float    eps;
+} ggml_metal_kargs_l2_norm;
+
 typedef struct {
     int64_t ne00;
     int64_t ne01;
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -184,10 +184,13 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
+    GGML_METAL_KERNEL_TYPE_L2_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
     GGML_METAL_KERNEL_TYPE_NORM,
     GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
     GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
+    GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
+    GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
@@ -810,10 +813,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
@@ -1251,6 +1257,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
        case GGML_OP_ARGMAX:
            return true;
@@ -1288,6 +1295,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
            return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
        case GGML_OP_SSM_CONV:
        case GGML_OP_SSM_SCAN:
+       case GGML_OP_RWKV_WKV6:
+       case GGML_OP_RWKV_WKV7:
            return true;
        case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID:
@@ -2216,6 +2225,83 @@ static void ggml_metal_encode_node(
 
                [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
            } break;
+       case GGML_OP_RWKV_WKV6:
+           {
+               const int64_t B = dst->src[5]->ne[1];
+               const int64_t T = dst->src[0]->ne[2];
+               const int64_t C = dst->ne[0];
+               const int64_t H = dst->src[0]->ne[1];
+
+               GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+               GGML_ASSERT(C % H == 0);
+               GGML_ASSERT(C / H == 64);
+
+               size_t offs_src3 = 0;
+               size_t offs_src4 = 0;
+               size_t offs_src5 = 0;
+
+               id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
+               id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
+               id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
+
+               id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
+
+               [encoder setComputePipelineState:pipeline];
+               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+               [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+               [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+               [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+               [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
+               [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+               [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+
+               [encoder setBytes:&B length:sizeof(B) atIndex:7];
+               [encoder setBytes:&T length:sizeof(T) atIndex:8];
+               [encoder setBytes:&C length:sizeof(C) atIndex:9];
+               [encoder setBytes:&H length:sizeof(H) atIndex:10];
+
+               [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
+           } break;
+       case GGML_OP_RWKV_WKV7:
+           {
+               const int64_t B = dst->src[6]->ne[1];
+               const int64_t T = dst->src[0]->ne[2];
+               const int64_t C = dst->ne[0];
+               const int64_t H = dst->src[0]->ne[1];
+
+               GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
+               GGML_ASSERT(C % H == 0);
+               GGML_ASSERT(C / H == 64);
+
+               size_t offs_src3 = 0;
+               size_t offs_src4 = 0;
+               size_t offs_src5 = 0;
+               size_t offs_src6 = 0;
+
+               id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
+               id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
+               id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
+               id<MTLBuffer> id_src6 = dst->src[6] ? ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil;
+
+               id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline;
+
+               [encoder setComputePipelineState:pipeline];
+               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+               [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+               [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+               [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+               [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
+               [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+               [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
+               [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
+
+               [encoder setBytes:&B length:sizeof(B) atIndex:8];
+               [encoder setBytes:&T length:sizeof(T) atIndex:9];
+               [encoder setBytes:&C length:sizeof(C) atIndex:10];
+               [encoder setBytes:&H length:sizeof(H) atIndex:11];
+
+               [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
+           } break;
        case GGML_OP_MUL_MAT:
            {
                GGML_ASSERT(ne00 == ne10);
@@ -3122,6 +3208,42 @@ static void ggml_metal_encode_node(
 
            const int64_t nrows = ggml_nrows(src0);
 
+           [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+       } break;
+       case GGML_OP_L2_NORM:
+       {
+           GGML_ASSERT(ne00 % 4 == 0);
+           GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+           float eps;
+           memcpy(&eps, dst->op_params, sizeof(float));
+
+           id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline;
+
+           int nth = 32; // SIMD width
+
+           while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+               nth *= 2;
+           }
+
+           nth = MIN(nth, ne00/4);
+
+           ggml_metal_kargs_l2_norm args = {
+               /*.ne00   =*/ ne00,
+               /*.ne00_4 =*/ ne00/4,
+               /*.nb01   =*/ nb01,
+               /*.eps    =*/ eps,
+           };
+
+           [encoder setComputePipelineState:pipeline];
+           [encoder setBytes:&args length:sizeof(args) atIndex:0];
+           [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+           [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+
+           [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+           const int64_t nrows = ggml_nrows(src0);
+
            [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
        } break;
        case GGML_OP_GROUP_NORM:
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -1295,6 +1295,184 @@ kernel void kernel_ssm_scan_f32(
     }
 }
 
+kernel void kernel_rwkv_wkv6_f32(
+    device const float * k,
+    device const float * v,
+    device const float * r,
+    device const float * tf,
+    device const float * td,
+    device const float * state_in,
+    device       float * dst,
+    constant      uint & B,
+    constant      uint & T,
+    constant      uint & C,
+    constant      uint & H,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const uint head_size = 64; // TODO: support head_size = 128
+    const uint batch_id = tgpig.x / H;
+    const uint head_id = tgpig.x % H;
+    const uint tid = tpitg.x;
+
+    if (batch_id >= B || head_id >= H) {
+        return;
+    }
+
+    const uint state_size = C * head_size;
+    const uint n_seq_tokens = T / B;
+
+    threadgroup float _k[head_size];
+    threadgroup float _r[head_size];
+    threadgroup float _tf[head_size];
+    threadgroup float _td[head_size];
+
+    float state[head_size];
+
+    for (uint i = 0; i < head_size; i++) {
+        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+                            + i * head_size + tid];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    _tf[tid] = tf[head_id * head_size + tid];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+    for (uint t = start_t; t < end_t; t += C) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        const float v_val = v[t];
+        float y = 0.0;
+
+        for (uint j = 0; j < head_size; j += 4) {
+            float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
+            float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
+            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            float4 kv = k_vec * v_val;
+
+            float4 temp = tf_vec * kv + s_vec;
+            y += dot(r_vec, temp);
+
+            s_vec = s_vec * td_vec + kv;
+            state[j]   = s_vec[0];
+            state[j+1] = s_vec[1];
+            state[j+2] = s_vec[2];
+            state[j+3] = s_vec[3];
+        }
+
+        dst[t] = y;
+    }
+
+    for (uint i = 0; i < head_size; i++) {
+        dst[T * C + batch_id * state_size + head_id * head_size * head_size
+            + i * head_size + tid] = state[i];
+    }
+}
+
+kernel void kernel_rwkv_wkv7_f32(
+    device const float * r,
+    device const float * w,
+    device const float * k,
+    device const float * v,
+    device const float * a,
+    device const float * b,
+    device const float * state_in,
+    device       float * dst,
+    constant      uint & B,
+    constant      uint & T,
+    constant      uint & C,
+    constant      uint & H,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const uint head_size = 64; // TODO: support head_size = 128
+    const uint batch_id = tgpig.x / H;
+    const uint head_id = tgpig.x % H;
+    const uint tid = tpitg.x;
+
+    if (batch_id >= B || head_id >= H) {
+        return;
+    }
+
+    const uint state_size = C * head_size;
+    const uint n_seq_tokens = T / B;
+
+    threadgroup float _r[head_size];
+    threadgroup float _w[head_size];
+    threadgroup float _k[head_size];
+    threadgroup float _a[head_size];
+    threadgroup float _b[head_size];
+
+    float state[head_size];
+
+    for (uint i = 0; i < head_size; i++) {
+        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+                            + tid * head_size + i];
+    }
+
+    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+    for (uint t = start_t; t < end_t; t += C) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        _r[tid] = r[t];
+        _w[tid] = w[t];
+        _k[tid] = k[t];
+        _a[tid] = a[t];
+        _b[tid] = b[t];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        const float v_val = v[t];
+        float y = 0.0, sa = 0.0;
+
+        float4 sa_vec(0.0);
+
+        for (int j = 0; j < head_size; j += 4) {
+            float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
+            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
+            sa_vec += a_vec * s_vec;
+        }
+        sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
+
+        for (uint j = 0; j < head_size; j += 4) {
+            float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
+            float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
+            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            float4 kv = k_vec * v_val;
+
+            s_vec = s_vec * w_vec + kv + sa * b_vec;
+            y += dot(s_vec, r_vec);
+
+            state[j]   = s_vec[0];
+            state[j+1] = s_vec[1];
+            state[j+2] = s_vec[2];
+            state[j+3] = s_vec[3];
+        }
+
+        dst[t] = y;
+    }
+
+    for (uint i = 0; i < head_size; i++) {
+        dst[T * C + batch_id * state_size + head_id * head_size * head_size
+            + tid * head_size + i] = state[i];
+    }
+}
+
 kernel void kernel_argmax(
         device const void * x,
         device int32_t * dst,
@@ -1463,6 +1641,49 @@ kernel_rms_norm(
     }
 }
 
+kernel void kernel_l2_norm(
+        constant ggml_metal_kargs_l2_norm & args,
+        device const char * src0,
+        device       char * dst,
+        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        uint   tgpig[[threadgroup_position_in_grid]],
+        ushort tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort   ntg[[threads_per_threadgroup]]) {
+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
+    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+
+    float sumf = 0.0f;
+
+    // parallel sum
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        sumf += dot(x[i00], x[i00]);
+    }
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    const float scale = 1.0f/sqrt(max(sumf, args.eps));
+
+    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        y[i00] = x[i00] * scale;
+    }
+}
+
 kernel void kernel_group_norm(
         device const float * src0,
         device       float * dst,
ggml/src/ggml-sycl/backend.hpp CHANGED
@@ -26,7 +26,7 @@
 #include "softmax.hpp"
 #include "tsembd.hpp"
 #include "im2col.hpp"
-#include "wkv6.hpp"
+#include "wkv.hpp"
 #include "outprod.hpp"
 #include "element_wise.hpp"
 #include "cpy.hpp"
ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED
@@ -2696,6 +2696,12 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
+static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
 static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
     ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
@@ -3410,6 +3416,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
         case GGML_OP_RMS_NORM:
             ggml_sycl_rms_norm(ctx, dst);
             break;
+        case GGML_OP_L2_NORM:
+            ggml_sycl_l2_norm(ctx, dst);
+            break;
         case GGML_OP_MUL_MAT:
             if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                 return false;
@@ -3487,6 +3496,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
         case GGML_OP_RWKV_WKV6:
             ggml_sycl_op_rwkv_wkv6(ctx, dst);
             break;
+        case GGML_OP_RWKV_WKV7:
+            ggml_sycl_op_rwkv_wkv7(ctx, dst);
+            break;
         case GGML_OP_GATED_LINEAR_ATTN:
             ggml_sycl_op_gated_linear_attn(ctx, dst);
             break;
@@ -4012,6 +4024,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return (op->src[0]->type == GGML_TYPE_F32);
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
         case GGML_OP_GROUP_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_SCALE:
@@ -4045,6 +4058,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_RWKV_WKV7:
         case GGML_OP_GATED_LINEAR_ATTN:
             return true;
         default:
ggml/src/ggml-sycl/norm.cpp CHANGED
@@ -180,6 +180,50 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
     }
 }
 
+static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
+                        const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+    const int tid = item_ct1.get_local_id(2);
+    const int nthreads = item_ct1.get_local_range(2);
+    const int nwarps = nthreads / WARP_SIZE;
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x[row * ncols + col];
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp, item_ct1);
+    if (block_size > WARP_SIZE) {
+
+        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
+        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        /*
+        DPCT1118:3: SYCL group functions and algorithms must be encountered in
+        converged control flow. You may need to adjust the code.
+        */
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+        size_t nreduce = nwarps / WARP_SIZE;
+        tmp = 0.f;
+        for (size_t i = 0; i < nreduce; i += 1)
+        {
+            tmp += s_sum[lane_id + i * WARP_SIZE];
+        }
+        tmp = warp_reduce_sum(tmp, item_ct1);
+    }
+
+    const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row * ncols + col] = scale * x[row * ncols + col];
+    }
+}
+
 static void norm_f32_sycl(const float* x, float* dst, const int ncols,
                           const int nrows, const float eps,
                           queue_ptr stream, int device) {
@@ -311,6 +355,48 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
     }
 }
 
+static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
+                             const int nrows, const float eps,
+                             queue_ptr stream, int device) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
+    if (ncols < 1024) {
+        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+        stream->submit([&](sycl::handler& cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+                                  block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                        l2_norm_f32(x, dst, ncols, eps, item_ct1,
+                                    nullptr, WARP_SIZE);
+                    });
+        });
+    }
+    else {
+        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
+        const sycl::range<3> block_dims(1, 1, work_group_size);
+        /*
+        DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        stream->submit([&](sycl::handler& cgh) {
+            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
+                                                         cgh);
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+                                  block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                        l2_norm_f32(x, dst, ncols, eps, item_ct1,
+                                    get_pointer(s_sum_acc_ct1), work_group_size);
+                    });
+        });
+    }
+}
+
 void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
                        ggml_tensor* dst, const float* src0_dd,
                        const float* src1_dd, float* dst_dd,
@@ -376,3 +462,25 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
     (void)dst;
     (void)src1_dd;
 }
+
+void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+                          const ggml_tensor* src1, ggml_tensor* dst,
+                          const float* src0_dd, const float* src1_dd,
+                          float* dst_dd,
+                          const queue_ptr& main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
473
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
474
+
475
+ const int64_t ne00 = src0->ne[0];
476
+ const int64_t nrows = ggml_nrows(src0);
477
+
478
+ float eps;
479
+ memcpy(&eps, dst->op_params, sizeof(float));
480
+
481
+ l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
482
+
483
+ (void)src1;
484
+ (void)dst;
485
+ (void)src1_dd;
486
+ }
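
For reference, here is a minimal scalar sketch of the row-wise L2 normalization these kernels implement; the helper name l2_norm_rows_ref is illustrative and not part of the patch. It mirrors the SYCL kernel's scaling by rsqrt(max(sum(x_i^2), eps^2)):

#include <algorithm>
#include <cmath>

// Scalar reference: normalize each row of x (nrows x ncols) to unit L2 length,
// clamping the squared sum at eps*eps as the device kernel does.
static void l2_norm_rows_ref(const float * x, float * dst, int ncols, int nrows, float eps) {
    for (int row = 0; row < nrows; ++row) {
        float sum = 0.0f;
        for (int col = 0; col < ncols; ++col) {
            const float xi = x[row * ncols + col];
            sum += xi * xi;
        }
        const float scale = 1.0f / std::sqrt(std::max(sum, eps * eps));
        for (int col = 0; col < ncols; ++col) {
            dst[row * ncols + col] = scale * x[row * ncols + col];
        }
    }
}
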
ggml/src/ggml-sycl/norm.hpp CHANGED
@@ -32,4 +32,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
32
  float* dst_dd,
33
  const queue_ptr& main_stream);
34
 
35
  #endif // GGML_SYCL_NORM_HPP
 
32
  float* dst_dd,
33
  const queue_ptr& main_stream);
34
 
35
+ void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
36
+ const ggml_tensor* src1, ggml_tensor* dst,
37
+ const float* src0_dd, const float* src1_dd,
38
+ float* dst_dd,
39
+ const queue_ptr& main_stream);
40
+
41
  #endif // GGML_SYCL_NORM_HPP
ggml/src/ggml-sycl/wkv.cpp ADDED
@@ -0,0 +1,305 @@
1
+ #include <sycl/sycl.hpp>
2
+ #include "wkv.hpp"
3
+
4
+ constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
5
+
6
+ // Helper function for the main kernel
7
+ template <int block_size>
8
+ static void rwkv_wkv6_f32_kernel(
9
+ const int B, const int T, const int C, const int H,
10
+ const float* k, const float* v, const float* r,
11
+ const float* tf, const float* td, const float* s,
12
+ float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
13
+
14
+ const int tid = item_ct1.get_local_id(2);
15
+ const int bid = item_ct1.get_group(2);
16
+
17
+ const int head_size = block_size;
18
+ const int batch_i = bid / H;
19
+ const int head_i = bid % H;
20
+ const int state_size = C * head_size;
21
+ const int n_seq_tokens = T / B;
22
+
23
+ // Set up shared memory pointers
24
+ float* _k = shared_mem;
25
+ float* _r = _k + head_size;
26
+ float* _tf = _r + head_size;
27
+ float* _td = _tf + head_size;
28
+
29
+ // Local state array
30
+ float state[block_size];
31
+
32
+ // Load initial state
33
+ #pragma unroll
34
+ for (int i = 0; i < head_size; i++) {
35
+ state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
36
+ }
37
+
38
+ // Sync threads before shared memory operations
39
+ item_ct1.barrier(sycl::access::fence_space::local_space);
40
+
41
+ // Load time-mixing parameters
42
+ _tf[tid] = tf[head_i * head_size + tid];
43
+ item_ct1.barrier(sycl::access::fence_space::local_space);
44
+
45
+ // Main sequence processing loop
46
+ for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
47
+ t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
48
+ t += C) {
49
+
50
+ item_ct1.barrier(sycl::access::fence_space::local_space);
51
+
52
+ // Load current timestep data to shared memory
53
+ _k[tid] = k[t];
54
+ _r[tid] = r[t];
55
+ _td[tid] = td[t];
56
+
57
+ item_ct1.barrier(sycl::access::fence_space::local_space);
58
+
59
+ const float _v = v[t];
60
+ float y = 0;
61
+
62
+ // Process in chunks of 4 for better vectorization
63
+ sycl::float4 k4, r4, tf4, td4, s4;
64
+ #pragma unroll
65
+ for (int j = 0; j < head_size; j += 4) {
66
+ // Load data in vec4 chunks
67
+ k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
68
+ r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
69
+ tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
70
+ td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
71
+ s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
72
+
73
+ // Compute key-value product
74
+ sycl::float4 kv4 = k4 * _v;
75
+
76
+ // Accumulate weighted sum
77
+ y += sycl::dot(r4, tf4 * kv4 + s4);
78
+
79
+ // Update state
80
+ s4 = s4 * td4 + kv4;
81
+
82
+ // Store updated state
83
+ state[j] = s4.x();
84
+ state[j+1] = s4.y();
85
+ state[j+2] = s4.z();
86
+ state[j+3] = s4.w();
87
+ }
88
+
89
+ dst[t] = y;
90
+ }
91
+
92
+ // Save final state
93
+ #pragma unroll
94
+ for (int i = 0; i < head_size; i++) {
95
+ dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
96
+ }
97
+ }
98
+
99
+ template <int block_size>
100
+ static void rwkv_wkv7_f32_kernel(
101
+ const int B, const int T, const int C, const int H,
102
+ const float* r, const float* w, const float* k, const float* v,
103
+ const float* a, const float* b, const float* s,
104
+ float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
105
+
106
+ const int tid = item_ct1.get_local_id(2);
107
+ const int bid = item_ct1.get_group(2);
108
+
109
+ const int head_size = block_size;
110
+ const int batch_i = bid / H;
111
+ const int head_i = bid % H;
112
+ const int state_size = C * head_size;
113
+ const int n_seq_tokens = T / B;
114
+
115
+ float* _r = shared_mem;
116
+ float* _w = _r + head_size;
117
+ float* _k = _w + head_size;
118
+ float* _a = _k + head_size;
119
+ float* _b = _a + head_size;
120
+
121
+ float state[block_size];
122
+
123
+ #pragma unroll
124
+ for (int i = 0; i < head_size; i++) {
125
+ state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
126
+ }
127
+
128
+ for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
129
+ t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
130
+ t += C) {
131
+
132
+ item_ct1.barrier(sycl::access::fence_space::local_space);
133
+
134
+ _r[tid] = r[t];
135
+ _w[tid] = w[t];
136
+ _k[tid] = k[t];
137
+ _a[tid] = a[t];
138
+ _b[tid] = b[t];
139
+
140
+ item_ct1.barrier(sycl::access::fence_space::local_space);
141
+
142
+ const float _v = v[t];
143
+ float y = 0, sa = 0;
144
+ sycl::float4 a4, s4;
145
+
146
+ #pragma unroll
147
+ for (int j = 0; j < head_size; j += 4) {
148
+ a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
149
+ s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
150
+ sa += sycl::dot(a4, s4);
151
+ }
152
+
153
+ sycl::float4 r4, w4, k4, b4;
154
+ #pragma unroll
155
+ for (int j = 0; j < head_size; j += 4) {
156
+ r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
157
+ w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
158
+ k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
159
+ b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
160
+ s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
161
+
162
+ sycl::float4 kv4 = k4 * _v;
163
+
164
+ s4 = s4 * w4 + kv4 + sa * b4;
165
+ y += sycl::dot(r4, s4);
166
+
167
+ state[j] = s4.x();
168
+ state[j+1] = s4.y();
169
+ state[j+2] = s4.z();
170
+ state[j+3] = s4.w();
171
+ }
172
+
173
+ dst[t] = y;
174
+ }
175
+
176
+ #pragma unroll
177
+ for (int i = 0; i < head_size; i++) {
178
+ dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
179
+ }
180
+ }
181
+
182
+ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
183
+
184
+ const ggml_tensor *src0 = dst->src[0];
185
+ const ggml_tensor *src1 = dst->src[1];
186
+
187
+ const float* k_d = (const float*)dst->src[0]->data;
188
+ const float* v_d = (const float*)dst->src[1]->data;
189
+ const float* r_d = (const float*)dst->src[2]->data;
190
+ const float* tf_d = (const float*)dst->src[3]->data;
191
+ const float* td_d = (const float*)dst->src[4]->data;
192
+ const float* s_d = (const float*)dst->src[5]->data;
193
+ float* dst_d = (float*)dst->data;
194
+
195
+ const int64_t B = dst->src[5]->ne[1];
196
+ const int64_t T = dst->src[0]->ne[2];
197
+ const int64_t C = dst->ne[0];
198
+ const int64_t H = dst->src[0]->ne[1];
199
+
200
+ GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
201
+ GGML_ASSERT(C % H == 0);
202
+ GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
203
+
204
+ dpct::queue_ptr stream = ctx.stream();
205
+
206
+ // Calculate execution configuration
207
+ const size_t shared_mem_size = C / H * 4 * sizeof(float); // For k, r, tf, td
208
+ sycl::range<3> block_dims(1, 1, C / H);
209
+ sycl::range<3> grid_dims(1, 1, B * H);
210
+
211
+ // Submit kernel
212
+ if (C / H == WKV_BLOCK_SIZE) {
213
+ stream->submit([&](sycl::handler& cgh) {
214
+ sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
215
+
216
+ cgh.parallel_for(
217
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
218
+ [=](sycl::nd_item<3> item_ct1) {
219
+ rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
220
+ B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
221
+ item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
222
+ );
223
+ });
224
+ });
225
+ } else {
226
+ stream->submit([&](sycl::handler& cgh) {
227
+ sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
228
+
229
+ cgh.parallel_for(
230
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
231
+ [=](sycl::nd_item<3> item_ct1) {
232
+ rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
233
+ B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
234
+ item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
235
+ );
236
+ });
237
+ });
238
+ }
239
+
240
+ GGML_UNUSED(src0);
241
+ GGML_UNUSED(src1);
242
+ }
243
+
244
+ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
245
+
246
+ const ggml_tensor *src0 = dst->src[0];
247
+ const ggml_tensor *src1 = dst->src[1];
248
+
249
+ const float* r_d = (const float*)dst->src[0]->data;
250
+ const float* w_d = (const float*)dst->src[1]->data;
251
+ const float* k_d = (const float*)dst->src[2]->data;
252
+ const float* v_d = (const float*)dst->src[3]->data;
253
+ const float* a_d = (const float*)dst->src[4]->data;
254
+ const float* b_d = (const float*)dst->src[5]->data;
255
+ const float* s_d = (const float*)dst->src[6]->data;
256
+ float* dst_d = (float*)dst->data;
257
+
258
+ const int64_t B = dst->src[6]->ne[1];
259
+ const int64_t T = dst->src[0]->ne[2];
260
+ const int64_t C = dst->ne[0];
261
+ const int64_t H = dst->src[0]->ne[1];
262
+
263
+ GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
264
+ GGML_ASSERT(C % H == 0);
265
+ GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2);
266
+
267
+ dpct::queue_ptr stream = ctx.stream();
268
+
269
+ // Calculate execution configuration
270
+ const size_t shared_mem_size = C / H * 5 * sizeof(float); // For r, w, k, a, b
271
+ sycl::range<3> block_dims(1, 1, C / H);
272
+ sycl::range<3> grid_dims(1, 1, B * H);
273
+
274
+ // Submit kernel
275
+ if (C / H == WKV_BLOCK_SIZE) {
276
+ stream->submit([&](sycl::handler& cgh) {
277
+ sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
278
+
279
+ cgh.parallel_for(
280
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
281
+ [=](sycl::nd_item<3> item_ct1) {
282
+ rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
283
+ B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
284
+ item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
285
+ );
286
+ });
287
+ });
288
+ } else {
289
+ stream->submit([&](sycl::handler& cgh) {
290
+ sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
291
+
292
+ cgh.parallel_for(
293
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
294
+ [=](sycl::nd_item<3> item_ct1) {
295
+ rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
296
+ B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
297
+ item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
298
+ );
299
+ });
300
+ });
301
+ }
302
+
303
+ GGML_UNUSED(src0);
304
+ GGML_UNUSED(src1);
305
+ }
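
The per-thread math in rwkv_wkv7_f32_kernel reduces to a short recurrence. Below is a hypothetical scalar reference for one output element (names are illustrative, not from the patch): for a head of size S, the thread handling column tid keeps a state row of S floats and, per token, computes sa = sum_j a[j]*state[j], then updates state[j] = state[j]*w[j] + k[j]*v[tid] + sa*b[j] and accumulates y = sum_j r[j]*state[j].

#include <cstddef>

// One token step of the WKV7 recurrence for a single output column.
// state points to S floats and is updated in place; the return value is dst[t].
static float wkv7_step_ref(float * state, const float * r, const float * w,
                           const float * k, const float * a, const float * b,
                           float v_tid, size_t S) {
    float sa = 0.0f;
    for (size_t j = 0; j < S; ++j) {
        sa += a[j] * state[j];          // dot(a, state), computed before the update
    }
    float y = 0.0f;
    for (size_t j = 0; j < S; ++j) {
        state[j] = state[j] * w[j] + k[j] * v_tid + sa * b[j];
        y += r[j] * state[j];
    }
    return y;
}
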
ggml/src/ggml-sycl/wkv.hpp ADDED
@@ -0,0 +1,10 @@
1
+ #ifndef GGML_SYCL_WKV_HPP
2
+ #define GGML_SYCL_WKV_HPP
3
+
4
+ #include "common.hpp"
5
+
6
+ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7
+
8
+ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
9
+
10
+ #endif // GGML_SYCL_WKV_HPP
ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -304,6 +304,7 @@ struct vk_device_struct {
304
  vk_pipeline pipeline_group_norm_f32;
305
  vk_pipeline pipeline_rms_norm_f32;
306
  vk_pipeline pipeline_rms_norm_back_f32;
 
307
  vk_pipeline pipeline_gelu_f32;
308
  vk_pipeline pipeline_gelu_quick_f32;
309
  vk_pipeline pipeline_silu_f32;
@@ -328,6 +329,7 @@ struct vk_device_struct {
328
  vk_pipeline pipeline_timestep_embedding_f32;
329
  vk_pipeline pipeline_pool2d_f32;
330
  vk_pipeline pipeline_rwkv_wkv6_f32;
 
331
  vk_pipeline pipeline_opt_step_adamw_f32;
332
 
333
  // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
@@ -629,6 +631,13 @@ struct vk_op_rwkv_wkv6_push_constants {
629
  uint32_t H;
630
  };
631
 
632
  // Allow pre-recording command buffers
633
  struct vk_staging_memcpy {
634
  vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -2263,6 +2272,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2263
  ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2264
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2265
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
2266
 
2267
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2268
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -2374,6 +2384,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2374
 
2375
  ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
2376
 
 
 
2377
  ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2378
 
2379
  for (auto &c : compiles) {
@@ -5473,6 +5485,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5473
  return ctx->device->pipeline_rms_norm_back_f32;
5474
  }
5475
  return nullptr;
 
 
 
 
 
5476
  case GGML_OP_UNARY:
5477
  switch (ggml_get_unary_op(dst)) {
5478
  case GGML_UNARY_OP_SILU:
@@ -5612,6 +5629,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5612
  return ctx->device->pipeline_rwkv_wkv6_f32;
5613
  }
5614
  return nullptr;
 
 
 
 
 
5615
  case GGML_OP_OPT_STEP_ADAMW:
5616
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5617
  return ctx->device->pipeline_opt_step_adamw_f32;
@@ -5859,6 +5881,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5859
  case GGML_OP_NORM:
5860
  case GGML_OP_RMS_NORM:
5861
  case GGML_OP_RMS_NORM_BACK:
 
5862
  case GGML_OP_SOFT_MAX:
5863
  case GGML_OP_SOFT_MAX_BACK:
5864
  case GGML_OP_SUM_ROWS:
@@ -6108,23 +6131,17 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
6108
  }, dryrun);
6109
  }
6110
 
6111
- static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
6112
- const ggml_tensor * k = dst->src[0];
6113
- const ggml_tensor * v = dst->src[1];
6114
- const ggml_tensor * r = dst->src[2];
6115
- const ggml_tensor * tf = dst->src[3];
6116
- const ggml_tensor * td = dst->src[4];
6117
- const ggml_tensor * state = dst->src[5];
6118
-
6119
- GGML_ASSERT(!ggml_is_quantized(k->type));
6120
- GGML_ASSERT(!ggml_is_quantized(v->type));
6121
- GGML_ASSERT(!ggml_is_quantized(r->type));
6122
- GGML_ASSERT(!ggml_is_quantized(tf->type));
6123
- GGML_ASSERT(!ggml_is_quantized(td->type));
6124
- GGML_ASSERT(!ggml_is_quantized(state->type));
6125
  GGML_ASSERT(dst->buffer != nullptr);
6126
 
6127
- vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
6128
  GGML_ASSERT(pipeline != nullptr);
6129
 
6130
  if (dryrun) {
@@ -6133,89 +6150,73 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
6133
  }
6134
 
6135
  ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
6136
- ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
6137
- ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
6138
- ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
6139
- ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
6140
- ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
6141
- ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
6142
 
6143
  ggml_vk_sync_buffers(subctx);
6144
 
6145
- vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr;
6146
- size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0;
6147
- bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
6148
 
6149
  if (ctx->device->uma) {
6150
- ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
6151
- ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
6152
- ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
6153
- ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
6154
- ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
6155
- ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
6156
- ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
6157
 
6158
- K_uma = d_K != nullptr;
6159
- V_uma = d_V != nullptr;
6160
- R_uma = d_R != nullptr;
6161
- TF_uma = d_TF != nullptr;
6162
- TD_uma = d_TD != nullptr;
6163
- STATE_uma = d_State != nullptr;
6164
- DST_uma = d_D != nullptr;
6165
  }
6166
 
6167
- if (!K_uma) {
6168
- d_K = k_buf_ctx->dev_buffer;
6169
- k_offset = vk_tensor_offset(k) + k->view_offs;
6170
- }
6171
- if (!V_uma) {
6172
- d_V = v_buf_ctx->dev_buffer;
6173
- v_offset = vk_tensor_offset(v) + v->view_offs;
6174
- }
6175
- if (!R_uma) {
6176
- d_R = r_buf_ctx->dev_buffer;
6177
- r_offset = vk_tensor_offset(r) + r->view_offs;
6178
- }
6179
- if (!TF_uma) {
6180
- d_TF = tf_buf_ctx->dev_buffer;
6181
- tf_offset = vk_tensor_offset(tf) + tf->view_offs;
6182
- }
6183
- if (!TD_uma) {
6184
- d_TD = td_buf_ctx->dev_buffer;
6185
- td_offset = vk_tensor_offset(td) + td->view_offs;
6186
- }
6187
- if (!STATE_uma) {
6188
- d_State = state_buf_ctx->dev_buffer;
6189
- state_offset = vk_tensor_offset(state) + state->view_offs;
6190
  }
6191
- if (!DST_uma) {
 
 
6192
  d_D = dst_buf_ctx->dev_buffer;
6193
  dst_offset = vk_tensor_offset(dst) + dst->view_offs;
6194
  }
6195
 
6196
- const uint64_t k_size = ggml_nbytes(k);
6197
- const uint64_t v_size = ggml_nbytes(v);
6198
- const uint64_t r_size = ggml_nbytes(r);
6199
- const uint64_t tf_size = ggml_nbytes(tf);
6200
- const uint64_t td_size = ggml_nbytes(td);
6201
- const uint64_t state_size = ggml_nbytes(state);
6202
- const uint64_t dst_size = ggml_nbytes(dst);
6203
-
6204
  std::array<uint32_t, 3> elements = {
6205
  (uint32_t)(pc.B * pc.H),
6206
  1,
6207
  1
6208
  };
6209
 
6210
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
6211
- vk_subbuffer{ d_K, k_offset, k_size },
6212
- vk_subbuffer{ d_V, v_offset, v_size },
6213
- vk_subbuffer{ d_R, r_offset, r_size },
6214
- vk_subbuffer{ d_TF, tf_offset, tf_size },
6215
- vk_subbuffer{ d_TD, td_offset, td_size },
6216
- vk_subbuffer{ d_State, state_offset, state_size },
6217
- vk_subbuffer{ d_D, dst_offset, dst_size }
6218
- }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
6219
  }
6220
 
6221
  static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
@@ -6224,7 +6225,26 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
6224
  const size_t n_heads = dst->src[0]->ne[1];
6225
  const size_t n_seqs = dst->src[5]->ne[1];
6226
 
6227
- ggml_vk_op_f32_rwkv6(
6228
  ctx, subctx, dst,
6229
  {
6230
  (uint32_t)n_seqs,
@@ -6232,6 +6252,7 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
6232
  (uint32_t)n_embed,
6233
  (uint32_t)n_heads,
6234
  },
 
6235
  dryrun
6236
  );
6237
  }
@@ -6533,6 +6554,11 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
6533
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
6534
  }
6535
 
 
 
 
 
 
6536
  static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6537
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
6538
  }
@@ -7528,6 +7554,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7528
  case GGML_OP_GROUP_NORM:
7529
  case GGML_OP_RMS_NORM:
7530
  case GGML_OP_RMS_NORM_BACK:
 
7531
  case GGML_OP_DIAG_MASK_INF:
7532
  case GGML_OP_SOFT_MAX:
7533
  case GGML_OP_SOFT_MAX_BACK:
@@ -7544,6 +7571,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7544
  case GGML_OP_TIMESTEP_EMBEDDING:
7545
  case GGML_OP_POOL_2D:
7546
  case GGML_OP_RWKV_WKV6:
 
7547
  case GGML_OP_LEAKY_RELU:
7548
  case GGML_OP_FLASH_ATTN_EXT:
7549
  case GGML_OP_OPT_STEP_ADAMW:
@@ -7590,6 +7618,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7590
  case GGML_OP_GROUP_NORM:
7591
  case GGML_OP_RMS_NORM:
7592
  case GGML_OP_RMS_NORM_BACK:
 
7593
  case GGML_OP_UNARY:
7594
  case GGML_OP_DIAG_MASK_INF:
7595
  case GGML_OP_SOFT_MAX:
@@ -7707,6 +7736,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7707
  case GGML_OP_RMS_NORM_BACK:
7708
  ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
7709
 
 
 
 
 
7710
  break;
7711
  case GGML_OP_UNARY:
7712
  switch (ggml_get_unary_op(node)) {
@@ -7797,6 +7830,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7797
 
7798
  break;
7799
 
 
 
 
 
 
7800
  case GGML_OP_OPT_STEP_ADAMW:
7801
  ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
7802
 
@@ -7870,6 +7908,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
7870
  case GGML_OP_GROUP_NORM:
7871
  case GGML_OP_RMS_NORM:
7872
  case GGML_OP_RMS_NORM_BACK:
 
7873
  case GGML_OP_DIAG_MASK_INF:
7874
  case GGML_OP_SOFT_MAX:
7875
  case GGML_OP_SOFT_MAX_BACK:
@@ -7889,6 +7928,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
7889
  case GGML_OP_TIMESTEP_EMBEDDING:
7890
  case GGML_OP_POOL_2D:
7891
  case GGML_OP_RWKV_WKV6:
 
7892
  case GGML_OP_LEAKY_RELU:
7893
  case GGML_OP_REPEAT:
7894
  case GGML_OP_REPEAT_BACK:
@@ -8806,6 +8846,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8806
  case GGML_OP_NORM:
8807
  case GGML_OP_GROUP_NORM:
8808
  case GGML_OP_RMS_NORM:
 
8809
  return ggml_is_contiguous(op->src[0]);
8810
  case GGML_OP_ADD:
8811
  case GGML_OP_SUB:
@@ -8835,6 +8876,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8835
  case GGML_OP_TIMESTEP_EMBEDDING:
8836
  case GGML_OP_POOL_2D:
8837
  case GGML_OP_RWKV_WKV6:
 
8838
  case GGML_OP_LEAKY_RELU:
8839
  case GGML_OP_OPT_STEP_ADAMW:
8840
  return true;
@@ -9219,6 +9261,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
9219
  tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
9220
  } else if (tensor->op == GGML_OP_SILU_BACK) {
9221
  tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
 
 
 
9222
  } else if (tensor->op == GGML_OP_SOFT_MAX) {
9223
  if (src1 != nullptr) {
9224
  tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
@@ -9338,6 +9383,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
9338
  } else if (tensor->op == GGML_OP_RWKV_WKV6) {
9339
  tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
9340
  src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
 
 
 
9341
  } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
9342
  src_clone[0]->flags = src0->flags;
9343
  tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
 
304
  vk_pipeline pipeline_group_norm_f32;
305
  vk_pipeline pipeline_rms_norm_f32;
306
  vk_pipeline pipeline_rms_norm_back_f32;
307
+ vk_pipeline pipeline_l2_norm_f32;
308
  vk_pipeline pipeline_gelu_f32;
309
  vk_pipeline pipeline_gelu_quick_f32;
310
  vk_pipeline pipeline_silu_f32;
 
329
  vk_pipeline pipeline_timestep_embedding_f32;
330
  vk_pipeline pipeline_pool2d_f32;
331
  vk_pipeline pipeline_rwkv_wkv6_f32;
332
+ vk_pipeline pipeline_rwkv_wkv7_f32;
333
  vk_pipeline pipeline_opt_step_adamw_f32;
334
 
335
  // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
 
631
  uint32_t H;
632
  };
633
 
634
+ struct vk_op_rwkv_wkv7_push_constants {
635
+ uint32_t B;
636
+ uint32_t T;
637
+ uint32_t C;
638
+ uint32_t H;
639
+ };
640
+
641
  // Allow pre-recording command buffers
642
  struct vk_staging_memcpy {
643
  vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
 
2272
  ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2273
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2274
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2275
+ ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2276
 
2277
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2278
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
2384
 
2385
  ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
2386
 
2387
+ ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
2388
+
2389
  ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2390
 
2391
  for (auto &c : compiles) {
 
5485
  return ctx->device->pipeline_rms_norm_back_f32;
5486
  }
5487
  return nullptr;
5488
+ case GGML_OP_L2_NORM:
5489
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5490
+ return ctx->device->pipeline_l2_norm_f32;
5491
+ }
5492
+ return nullptr;
5493
  case GGML_OP_UNARY:
5494
  switch (ggml_get_unary_op(dst)) {
5495
  case GGML_UNARY_OP_SILU:
 
5629
  return ctx->device->pipeline_rwkv_wkv6_f32;
5630
  }
5631
  return nullptr;
5632
+ case GGML_OP_RWKV_WKV7:
5633
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5634
+ return ctx->device->pipeline_rwkv_wkv7_f32;
5635
+ }
5636
+ return nullptr;
5637
  case GGML_OP_OPT_STEP_ADAMW:
5638
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5639
  return ctx->device->pipeline_opt_step_adamw_f32;
 
5881
  case GGML_OP_NORM:
5882
  case GGML_OP_RMS_NORM:
5883
  case GGML_OP_RMS_NORM_BACK:
5884
+ case GGML_OP_L2_NORM:
5885
  case GGML_OP_SOFT_MAX:
5886
  case GGML_OP_SOFT_MAX_BACK:
5887
  case GGML_OP_SUM_ROWS:
 
6131
  }, dryrun);
6132
  }
6133
 
6134
+ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
6135
+ GGML_ASSERT(version == 6 || version == 7);
6136
+ int num_srcs = version == 6 ? 6 : 7;
6137
+
6138
+ for (int i = 0; i < num_srcs; i++) {
6139
+ GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
6140
+ }
6141
+
6142
  GGML_ASSERT(dst->buffer != nullptr);
6143
 
6144
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
6145
  GGML_ASSERT(pipeline != nullptr);
6146
 
6147
  if (dryrun) {
 
6150
  }
6151
 
6152
  ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
6153
+ ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
6154
+ for (int i = 0; i < num_srcs; i++) {
6155
+ src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
6156
+ }
 
 
6157
 
6158
  ggml_vk_sync_buffers(subctx);
6159
 
6160
+ vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
6161
+ size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
6162
+ bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
6163
 
6164
  if (ctx->device->uma) {
6165
+ for (int i = 0; i < num_srcs; i++) {
6166
+ ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
6167
+ srcs_uma[i] = d_srcs[i] != nullptr;
6168
+ }
 
 
 
6169
 
6170
+ ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
6171
+ dst_uma = d_D != nullptr;
 
 
 
 
 
6172
  }
6173
 
6174
+ uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
6175
+ for (int i = 0; i < num_srcs; i++) {
6176
+ src_sizes[i] = ggml_nbytes(dst->src[i]);
6177
+ if (!srcs_uma[i]) {
6178
+ d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
6179
+ src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
6180
+ }
6181
  }
6182
+
6183
+ const uint64_t dst_size = ggml_nbytes(dst);
6184
+ if (!dst_uma) {
6185
  d_D = dst_buf_ctx->dev_buffer;
6186
  dst_offset = vk_tensor_offset(dst) + dst->view_offs;
6187
  }
6188
 
 
  std::array<uint32_t, 3> elements = {
6190
  (uint32_t)(pc.B * pc.H),
6191
  1,
6192
  1
6193
  };
6194
 
6195
+ if (version == 6) {
6196
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
6197
+ vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
6198
+ vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
6199
+ vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
6200
+ vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
6201
+ vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
6202
+ vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
6203
+ vk_subbuffer{ d_D, dst_offset, dst_size }
6204
+ }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
6205
+ } else if (version == 7) {
6206
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
6207
+ vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
6208
+ vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
6209
+ vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
6210
+ vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
6211
+ vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
6212
+ vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
6213
+ vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
6214
+ vk_subbuffer{ d_D, dst_offset, dst_size }
6215
+ }, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
6216
+ } else {
6217
+ // shouldn't happen
6218
+ GGML_ASSERT(false);
6219
+ }
6220
  }
6221
 
6222
  static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
 
6225
  const size_t n_heads = dst->src[0]->ne[1];
6226
  const size_t n_seqs = dst->src[5]->ne[1];
6227
 
6228
+ ggml_vk_op_f32_wkv(
6229
+ ctx, subctx, dst,
6230
+ {
6231
+ (uint32_t)n_seqs,
6232
+ (uint32_t)seq_length,
6233
+ (uint32_t)n_embed,
6234
+ (uint32_t)n_heads,
6235
+ },
6236
+ 6,
6237
+ dryrun
6238
+ );
6239
+ }
6240
+
6241
+ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
6242
+ const size_t seq_length = dst->src[0]->ne[2];
6243
+ const size_t n_embed = dst->ne[0];
6244
+ const size_t n_heads = dst->src[0]->ne[1];
6245
+ const size_t n_seqs = dst->src[6]->ne[1];
6246
+
6247
+ ggml_vk_op_f32_wkv(
6248
  ctx, subctx, dst,
6249
  {
6250
  (uint32_t)n_seqs,
 
6252
  (uint32_t)n_embed,
6253
  (uint32_t)n_heads,
6254
  },
6255
+ 7,
6256
  dryrun
6257
  );
6258
  }
 
6554
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
6555
  }
6556
 
6557
+ static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6558
+ float * op_params = (float *)dst->op_params;
6559
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
6560
+ }
6561
+
6562
  static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6563
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
6564
  }
 
7554
  case GGML_OP_GROUP_NORM:
7555
  case GGML_OP_RMS_NORM:
7556
  case GGML_OP_RMS_NORM_BACK:
7557
+ case GGML_OP_L2_NORM:
7558
  case GGML_OP_DIAG_MASK_INF:
7559
  case GGML_OP_SOFT_MAX:
7560
  case GGML_OP_SOFT_MAX_BACK:
 
7571
  case GGML_OP_TIMESTEP_EMBEDDING:
7572
  case GGML_OP_POOL_2D:
7573
  case GGML_OP_RWKV_WKV6:
7574
+ case GGML_OP_RWKV_WKV7:
7575
  case GGML_OP_LEAKY_RELU:
7576
  case GGML_OP_FLASH_ATTN_EXT:
7577
  case GGML_OP_OPT_STEP_ADAMW:
 
7618
  case GGML_OP_GROUP_NORM:
7619
  case GGML_OP_RMS_NORM:
7620
  case GGML_OP_RMS_NORM_BACK:
7621
+ case GGML_OP_L2_NORM:
7622
  case GGML_OP_UNARY:
7623
  case GGML_OP_DIAG_MASK_INF:
7624
  case GGML_OP_SOFT_MAX:
 
7736
  case GGML_OP_RMS_NORM_BACK:
7737
  ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
7738
 
7739
+ break;
7740
+ case GGML_OP_L2_NORM:
7741
+ ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun);
7742
+
7743
  break;
7744
  case GGML_OP_UNARY:
7745
  switch (ggml_get_unary_op(node)) {
 
7830
 
7831
  break;
7832
 
7833
+ case GGML_OP_RWKV_WKV7:
7834
+ ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun);
7835
+
7836
+ break;
7837
+
7838
  case GGML_OP_OPT_STEP_ADAMW:
7839
  ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
7840
 
 
7908
  case GGML_OP_GROUP_NORM:
7909
  case GGML_OP_RMS_NORM:
7910
  case GGML_OP_RMS_NORM_BACK:
7911
+ case GGML_OP_L2_NORM:
7912
  case GGML_OP_DIAG_MASK_INF:
7913
  case GGML_OP_SOFT_MAX:
7914
  case GGML_OP_SOFT_MAX_BACK:
 
7928
  case GGML_OP_TIMESTEP_EMBEDDING:
7929
  case GGML_OP_POOL_2D:
7930
  case GGML_OP_RWKV_WKV6:
7931
+ case GGML_OP_RWKV_WKV7:
7932
  case GGML_OP_LEAKY_RELU:
7933
  case GGML_OP_REPEAT:
7934
  case GGML_OP_REPEAT_BACK:
 
8846
  case GGML_OP_NORM:
8847
  case GGML_OP_GROUP_NORM:
8848
  case GGML_OP_RMS_NORM:
8849
+ case GGML_OP_L2_NORM:
8850
  return ggml_is_contiguous(op->src[0]);
8851
  case GGML_OP_ADD:
8852
  case GGML_OP_SUB:
 
8876
  case GGML_OP_TIMESTEP_EMBEDDING:
8877
  case GGML_OP_POOL_2D:
8878
  case GGML_OP_RWKV_WKV6:
8879
+ case GGML_OP_RWKV_WKV7:
8880
  case GGML_OP_LEAKY_RELU:
8881
  case GGML_OP_OPT_STEP_ADAMW:
8882
  return true;
 
9261
  tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
9262
  } else if (tensor->op == GGML_OP_SILU_BACK) {
9263
  tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
9264
+ } else if (tensor->op == GGML_OP_L2_NORM) {
9265
+ const float eps = ((float *) tensor->op_params)[0];
9266
+ tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
9267
  } else if (tensor->op == GGML_OP_SOFT_MAX) {
9268
  if (src1 != nullptr) {
9269
  tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
 
9383
  } else if (tensor->op == GGML_OP_RWKV_WKV6) {
9384
  tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
9385
  src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
9386
+ } else if (tensor->op == GGML_OP_RWKV_WKV7) {
9387
+ tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
9388
+ src_clone[4], src_clone[5], src_clone[6]);
9389
  } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
9390
  src_clone[0]->flags = src0->flags;
9391
  tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp ADDED
@@ -0,0 +1,41 @@
1
+ #version 450
2
+
3
+ #include "generic_head.comp"
4
+ #include "types.comp"
5
+
6
+ #extension GL_EXT_control_flow_attributes : enable
7
+ #define BLOCK_SIZE 512
8
+
9
+ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
10
+
11
+ layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
12
+ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
13
+
14
+ shared FLOAT_TYPE sum[BLOCK_SIZE];
15
+
16
+ void main() {
17
+ const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
18
+ const uint tid = gl_LocalInvocationID.x;
19
+
20
+ sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
21
+
22
+ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
23
+ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
24
+ sum[tid] += xi * xi;
25
+ }
26
+
27
+ // sum up partial sums and write back result
28
+ barrier();
29
+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
30
+ if (tid < s) {
31
+ sum[tid] += sum[tid + s];
32
+ }
33
+ barrier();
34
+ }
35
+
36
+ const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
37
+
38
+ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
39
+ data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
40
+ }
41
+ }
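
The shader above combines per-thread partial sums with the standard shared-memory halving reduction. A minimal C++ sketch of that reduction, illustrative only and assuming a power-of-two block size as l2_norm.comp does:

#include <vector>

// Mirrors the `for (s = BLOCK_SIZE / 2; s > 0; s >>= 1)` loop in l2_norm.comp.
// In the shader each iteration is followed by barrier(); here the loop runs
// sequentially, so no synchronization is needed.
static float block_reduce_ref(std::vector<float> sum) {
    for (size_t s = sum.size() / 2; s > 0; s >>= 1) {
        for (size_t tid = 0; tid < s; ++tid) {
            sum[tid] += sum[tid + s];
        }
    }
    return sum[0];
}
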
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -434,6 +434,7 @@ void process_shaders() {
434
  string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
435
  string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
436
  string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 
437
 
438
  string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
439
  string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
@@ -528,6 +529,8 @@ void process_shaders() {
528
 
529
  string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
530
 
 
 
531
  string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
532
 
533
  for (auto &c : compiles) {
 
434
  string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
435
  string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
436
  string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
437
+ string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
438
 
439
  string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
440
  string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
 
529
 
530
  string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
531
 
532
+ string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
533
+
534
  string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
535
 
536
  for (auto &c : compiles) {
ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp ADDED
@@ -0,0 +1,91 @@
1
+ #version 450
2
+
3
+ #extension GL_EXT_control_flow_attributes : require
4
+
5
+ #define BLOCK_SIZE 64
6
+ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
7
+
8
+ layout(push_constant) uniform Parameters {
9
+ uint B;
10
+ uint T;
11
+ uint C;
12
+ uint H;
13
+ };
14
+
15
+ layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; };
16
+ layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; };
17
+ layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; };
18
+ layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; };
19
+ layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; };
20
+ layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; };
21
+ layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; };
22
+ layout(binding = 7) buffer DstBuf { A_TYPE dst[]; };
23
+
24
+ shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE];
25
+
26
+ void main() {
27
+ const uint head_size = BLOCK_SIZE;
28
+ const uint batch_id = gl_WorkGroupID.x / H;
29
+ const uint head_id = gl_WorkGroupID.x % H;
30
+ const uint tid = gl_LocalInvocationID.x;
31
+
32
+ const uint state_size = C * head_size;
33
+ const uint n_seq_tokens = T / B;
34
+
35
+ if (batch_id >= B || head_id >= H) {
36
+ return;
37
+ }
38
+
39
+ A_TYPE state[BLOCK_SIZE];
40
+ [[unroll]] for (uint i = 0; i < head_size; i++) {
41
+ state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
42
+ + tid * head_size + i];
43
+ }
44
+
45
+ const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
46
+ const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
47
+
48
+ for (uint t = start_t; t < end_t; t += C) {
49
+ barrier();
50
+ _r[tid] = r[t];
51
+ _w[tid] = w[t];
52
+ _k[tid] = k[t];
53
+ _a[tid] = a[t];
54
+ _b[tid] = b[t];
55
+ barrier();
56
+
57
+ A_TYPE sa = 0.0;
58
+ [[unroll]] for (uint j = 0; j < head_size; j += 4) {
59
+ vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
60
+ vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
61
+ sa += dot(s_vec, a_vec);
62
+ }
63
+
64
+ const A_TYPE v_val = v[t];
65
+ A_TYPE y = 0.0;
66
+
67
+ [[unroll]] for (uint j = 0; j < head_size; j += 4) {
68
+ vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
69
+ vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
70
+ vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
71
+ vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
72
+ vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
73
+
74
+ vec4 kv = k_vec * v_val;
75
+ s_vec = s_vec * w_vec + kv + sa * b_vec;
76
+ y += dot(r_vec, s_vec);
77
+
78
+ state[j] = s_vec.x;
79
+ state[j+1] = s_vec.y;
80
+ state[j+2] = s_vec.z;
81
+ state[j+3] = s_vec.w;
82
+ }
83
+
84
+ dst[t] = y;
85
+ }
86
+
87
+ [[unroll]] for (uint i = 0; i < head_size; i++) {
88
+ dst[T * C + batch_id * state_size + head_id * head_size * head_size
89
+ + tid * head_size + i] = state[i];
90
+ }
91
+ }
ggml/src/ggml.c CHANGED
@@ -929,6 +929,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
929
  "RMS_NORM",
930
  "RMS_NORM_BACK",
931
  "GROUP_NORM",
 
932
 
933
  "MUL_MAT",
934
  "MUL_MAT_ID",
@@ -977,6 +978,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
977
  "ADD_REL_POS",
978
  "RWKV_WKV6",
979
  "GATED_LINEAR_ATTN",
 
980
 
981
  "UNARY",
982
 
@@ -996,7 +998,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
996
  "OPT_STEP_ADAMW",
997
  };
998
 
999
- static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
1000
 
1001
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1002
  "none",
@@ -1026,6 +1028,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1026
  "rms_norm(x)",
1027
  "rms_norm_back(x)",
1028
  "group_norm(x)",
 
1029
 
1030
  "X*Y",
1031
  "X[i]*Y",
@@ -1074,6 +1077,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1074
  "add_rel_pos(x)",
1075
  "rwkv_wkv6(k, v, r, tf, td, s)",
1076
  "gated_linear_attn(k, v, q, gate, s)",
 
1077
 
1078
  "unary(x)",
1079
 
@@ -1093,7 +1097,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1093
  "adamw(x)",
1094
  };
1095
 
1096
- static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
1097
 
1098
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
1099
 
@@ -2686,6 +2690,37 @@ struct ggml_tensor * ggml_group_norm_inplace(
2686
  return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
2687
  }
2688
 
 
2689
  // ggml_mul_mat
2690
 
2691
  static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
@@ -4720,6 +4755,54 @@ struct ggml_tensor * ggml_gated_linear_attn(
4720
  return result;
4721
  }
4722
 
 
4723
  // ggml_unary
4724
 
4725
  static struct ggml_tensor * ggml_unary_impl(
 
929
  "RMS_NORM",
930
  "RMS_NORM_BACK",
931
  "GROUP_NORM",
932
+ "L2_NORM",
933
 
934
  "MUL_MAT",
935
  "MUL_MAT_ID",
 
978
  "ADD_REL_POS",
979
  "RWKV_WKV6",
980
  "GATED_LINEAR_ATTN",
981
+ "RWKV_WKV7",
982
 
983
  "UNARY",
984
 
 
998
  "OPT_STEP_ADAMW",
999
  };
1000
 
1001
+ static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
1002
 
1003
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1004
  "none",
 
1028
  "rms_norm(x)",
1029
  "rms_norm_back(x)",
1030
  "group_norm(x)",
1031
+ "l2_norm(x)",
1032
 
1033
  "X*Y",
1034
  "X[i]*Y",
 
1077
  "add_rel_pos(x)",
1078
  "rwkv_wkv6(k, v, r, tf, td, s)",
1079
  "gated_linear_attn(k, v, q, gate, s)",
1080
+ "rwkv_wkv7(r, w, k, v, a, b, s)",
1081
 
1082
  "unary(x)",
1083
 
 
1097
  "adamw(x)",
1098
  };
1099
 
1100
+ static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
1101
 
1102
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
1103
 
 
2690
  return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
2691
  }
2692
 
2693
+ // ggml_l2_norm
2694
+
2695
+ static struct ggml_tensor * ggml_l2_norm_impl(
2696
+ struct ggml_context * ctx,
2697
+ struct ggml_tensor * a,
2698
+ float eps,
2699
+ bool inplace) {
2700
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
2701
+
2702
+ ggml_set_op_params_f32(result, 0, eps);
2703
+
2704
+ result->op = GGML_OP_L2_NORM;
2705
+ result->src[0] = a;
2706
+
2707
+ return result;
2708
+ }
2709
+
2710
+ struct ggml_tensor * ggml_l2_norm(
2711
+ struct ggml_context * ctx,
2712
+ struct ggml_tensor * a,
2713
+ float eps) {
2714
+ return ggml_l2_norm_impl(ctx, a, eps, false);
2715
+ }
2716
+
2717
+ struct ggml_tensor * ggml_l2_norm_inplace(
2718
+ struct ggml_context * ctx,
2719
+ struct ggml_tensor * a,
2720
+ float eps) {
2721
+ return ggml_l2_norm_impl(ctx, a, eps, true);
2722
+ }
2723
+
2724
  // ggml_mul_mat
2725
 
2726
  static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
 
4755
  return result;
4756
  }
4757
 
4758
+ // ggml_rwkv_wkv7
4759
+
4760
+ struct ggml_tensor * ggml_rwkv_wkv7(
4761
+ struct ggml_context * ctx,
4762
+ struct ggml_tensor * r,
4763
+ struct ggml_tensor * w,
4764
+ struct ggml_tensor * k,
4765
+ struct ggml_tensor * v,
4766
+ struct ggml_tensor * a,
4767
+ struct ggml_tensor * b,
4768
+ struct ggml_tensor * state) {
4769
+ GGML_ASSERT(ggml_is_contiguous(r));
4770
+ GGML_ASSERT(ggml_is_contiguous(w));
4771
+ GGML_ASSERT(ggml_is_contiguous(k));
4772
+ GGML_ASSERT(ggml_is_contiguous(v));
4773
+ GGML_ASSERT(ggml_is_contiguous(a));
4774
+ GGML_ASSERT(ggml_is_contiguous(b));
4775
+ GGML_ASSERT(ggml_is_contiguous(state));
4776
+
4777
+ const int64_t S = k->ne[0];
4778
+ const int64_t H = k->ne[1];
4779
+ const int64_t n_tokens = k->ne[2];
4780
+ const int64_t n_seqs = state->ne[1];
4781
+ {
4782
+ GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
4783
+ GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
4784
+ GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
4785
+ GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
4786
+ GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
4787
+ GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
4788
+ }
4789
+
4790
+ // concat output and new_state
4791
+ const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
4792
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
4793
+
4794
+ result->op = GGML_OP_RWKV_WKV7;
4795
+ result->src[0] = r;
4796
+ result->src[1] = w;
4797
+ result->src[2] = k;
4798
+ result->src[3] = v;
4799
+ result->src[4] = a;
4800
+ result->src[5] = b;
4801
+ result->src[6] = state;
4802
+
4803
+ return result;
4804
+ }
4805
+
4806
  // ggml_unary
4807
 
4808
  static struct ggml_tensor * ggml_unary_impl(
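
Taken together, the two new public ops can be wired into a graph roughly as follows. This is an illustrative sketch only: tensor shapes follow the asserts in ggml_rwkv_wkv7 above, while the eps value and the choice of which tensor a real RWKV7 graph L2-normalizes are model details not shown by this diff.

#include "ggml.h"

// S = head size, H = number of heads; r/w/k/v/a/b are [S, H, n_tokens] and
// state carries S*S*H*n_seqs elements, matching the shape asserts above.
static struct ggml_tensor * build_wkv7_node(
        struct ggml_context * ctx,
        struct ggml_tensor  * r,
        struct ggml_tensor  * w,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        struct ggml_tensor  * state) {
    // Hypothetical use of the new norm op; eps is a placeholder value.
    struct ggml_tensor * k_norm = ggml_l2_norm(ctx, k, 1e-12f);
    // Result ne = { S*H, n_tokens + S*n_seqs, 1, 1 }: per-token outputs first,
    // followed by the updated per-sequence states.
    return ggml_rwkv_wkv7(ctx, r, w, k_norm, v, a, b, state);
}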