mollysama, ggerganov, and compilade committed
Commit 4a6b7e0
1 Parent(s): fa23a38

llama: add support for QRWKV6 model architecture (llama/11001)

* WIP: Add support for RWKV6Qwen2

Signed-off-by: Molly Sophia <[email protected]>

* RWKV: Some graph simplification

Signed-off-by: Molly Sophia <[email protected]>

* Add support for RWKV6Qwen2 with cpu and cuda GLA

Signed-off-by: Molly Sophia <[email protected]>

* RWKV6[QWEN2]: Concat lerp weights together to reduce cpu overhead

Signed-off-by: Molly Sophia <[email protected]>

* Fix some typos

Signed-off-by: Molly Sophia <[email protected]>

* code format changes

Signed-off-by: Molly Sophia <[email protected]>

* Fix wkv test & add gla test

Signed-off-by: Molly Sophia <[email protected]>

* Fix cuda warning

Signed-off-by: Molly Sophia <[email protected]>

* Update README.md

Signed-off-by: Molly Sophia <[email protected]>

* Update ggml/src/ggml-cuda/gla.cu

Co-authored-by: Georgi Gerganov <[email protected]>

* Fix fused lerp weights loading with RWKV6

Signed-off-by: Molly Sophia <[email protected]>

* better sanity check skipping for QRWKV6 in llama-quant

thanks @compilade

Signed-off-by: Molly Sophia <[email protected]>
Co-authored-by: compilade <[email protected]>

---------

Signed-off-by: Molly Sophia <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: compilade <[email protected]>

ggml/include/ggml.h CHANGED
@@ -501,6 +501,7 @@ extern "C" {
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV6,
+        GGML_OP_GATED_LINEAR_ATTN,

         GGML_OP_UNARY,

@@ -1859,6 +1860,15 @@ extern "C" {
             struct ggml_tensor  * td,
             struct ggml_tensor  * state);

+    GGML_API struct ggml_tensor * ggml_gated_linear_attn(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * g,
+            struct ggml_tensor  * state,
+            float scale);
+
     // custom operators

     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);

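For orientation, the new entry point follows the same shape conventions as ggml_rwkv_wkv6 after this change: k, v, q and the gate g come in as {head_size, n_heads, n_tokens} tensors, and state carries head_size * head_size * n_heads values per sequence. Below is a minimal sketch (not part of the commit) of wiring the op into a graph; the sizes and the 1/sqrt(head_size) scale are hypothetical, chosen only for illustration.

// Sketch only (not from the commit): building a GGML_OP_GATED_LINEAR_ATTN node
// with made-up sizes. Shapes follow the asserts added in ggml.c:
// k/v/q/g are {S, H, n_tokens}; state holds S*S*H values per sequence.
#include <math.h>
#include "ggml.h"

static struct ggml_tensor * build_gla_node(struct ggml_context * ctx) {
    const int64_t S = 64, H = 8, n_tokens = 4, n_seqs = 1; // hypothetical sizes

    struct ggml_tensor * k     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
    struct ggml_tensor * v     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
    struct ggml_tensor * q     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
    struct ggml_tensor * g     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
    struct ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S * S * H, n_seqs);

    // the op only stores the scale; 1/sqrt(S) is a typical attention-style choice
    return ggml_gated_linear_attn(ctx, k, v, q, g, state, 1.0f / sqrtf((float) S));
}
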
ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -11803,9 +11803,9 @@ static void ggml_compute_forward_add_rel_pos(
 static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const int64_t T = dst->src[1]->ne[3];
+    const int64_t T = dst->src[1]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t HEADS = dst->src[1]->ne[1];
     const int64_t n_seqs = dst->src[5]->ne[1];
     const int64_t head_size = C / HEADS;

@@ -12000,6 +12000,197 @@ static void ggml_compute_forward_rwkv_wkv6(
     }
 }

+// ggml_compute_forward_gla
+
+static void ggml_compute_forward_gla_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[4]->ne[1];
+    const int64_t head_size = C / HEADS;
+    const float scale = ggml_get_op_params_f32(dst, 0);
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * k = (float *) dst->src[0]->data;
+    float * v = (float *) dst->src[1]->data;
+    float * q = (float *) dst->src[2]->data;
+    float * g = (float *) dst->src[3]->data;
+
+    size_t t_stride = HEADS * head_size; // Same to C
+
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define GLA_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define GLA_VECTOR_SIZE 4
+    #endif
+
+    #ifdef GLA_VECTOR_SIZE
+        const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);
+                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * GLA_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load x elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = prev_state * g + kv
+                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
+
+                        // Update dst: dst += temp * q
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
+                    }
+
+                    // Handle remaining elements, this will not be used.
+                    for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val + prev_state_val * g_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }

+    #else
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = prev_state_val * g_val + kv_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+    #endif
+}
+
+
+static void ggml_compute_forward_gla(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gla_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary

 static void ggml_compute_forward_map_unary_f32(

@@ -12749,6 +12940,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            {
+                ggml_compute_forward_gla(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;

@@ -13047,6 +13242,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
        case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:

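Spelled out, the recurrence this new CPU path computes per head and per token t (reading off the scalar fallback above, where S_t is the head_size x head_size state, g_t gates along the key dimension, and scale multiplies the query) is:

S_t = \operatorname{diag}(g_t)\, S_{t-1} + k_t\, v_t^{\top}
y_t^{\top} = \mathrm{scale} \cdot q_t^{\top} S_t

with S_0 loaded from src[4] at the start of each sequence; dst stores the T output rows first and the final per-sequence states after them, which is the concatenated layout set up in ggml.c below.
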
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -37,6 +37,7 @@
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv6.cuh"
+#include "ggml-cuda/gla.cuh"

 #include <algorithm>
 #include <array>

@@ -2167,6 +2168,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_RWKV_WKV6:
             ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            ggml_cuda_op_gated_linear_attn(ctx, dst);
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;

@@ -3011,6 +3015,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
            return true;
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE

ggml/src/ggml-cuda/gla.cu ADDED
@@ -0,0 +1,93 @@
+#include "common.cuh"
+#include "gla.cuh"
+
+template<int HEAD_SIZE>
+static __global__ void gated_linear_attn_f32(const int B, const int T, const int C, const int H, const float scale,
+        const float * k, const float * v, const float * r, const float * td, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = HEAD_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _k[head_size], _r[head_size], _td[head_size];
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        __syncthreads();
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4 & k = (float4 &)(_k[j]);
+            const float4 & r = (float4 &)(_r[j]);
+            const float4 & td = (float4 &)(_td[j]);
+            float4 & s = (float4 &)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            s.x = s.x * td.x + kv.x;
+            s.y = s.y * td.y + kv.y;
+            s.z = s.z * td.z + kv.z;
+            s.w = s.w * td.w + kv.w;
+
+            y += r.x * s.x;
+            y += r.y * s.y;
+            y += r.z * s.z;
+            y += r.w * s.w;
+        }
+        dst[t] = y * scale;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * k_d  = (const float *)dst->src[0]->data;
+    const float * v_d  = (const float *)dst->src[1]->data;
+    const float * r_d  = (const float *)dst->src[2]->data;
+    const float * td_d = (const float *)dst->src[3]->data;
+    const float * s_d  = (const float *)dst->src[4]->data;
+
+    const int64_t B = dst->src[4]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    float scale;
+    memcpy(&scale, (float*)dst->op_params, sizeof(float));
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == 64 || C / H == 128);
+
+
+    if (C / H == 64) {
+        gated_linear_attn_f32<64><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+    } else {
+        gated_linear_attn_f32<128><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+    }
+}

ggml/src/ggml-cuda/gla.cuh ADDED
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/wkv6.cu CHANGED
@@ -73,9 +73,9 @@ void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const float * s_d  = (const float *)dst->src[5]->data;

     const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[3];
+    const int64_t T = dst->src[0]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[2];
+    const int64_t H = dst->src[0]->ne[1];

     float * dst_d = (float *)dst->data;

ggml/src/ggml-sycl/wkv6.cpp CHANGED
@@ -109,9 +109,9 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
     float* dst_d = (float*)dst->data;

     const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[3];
+    const int64_t T = dst->src[0]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[2];
+    const int64_t H = dst->src[0]->ne[1];

     GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
     GGML_ASSERT(C % H == 0);

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -5633,9 +5633,9 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
 }

 static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
-    const size_t seq_length = dst->src[0]->ne[3];
+    const size_t seq_length = dst->src[0]->ne[2];
     const size_t n_embed = dst->ne[0];
-    const size_t n_heads = dst->src[0]->ne[2];
+    const size_t n_heads = dst->src[0]->ne[1];
     const size_t n_seqs = dst->src[5]->ne[1];

     ggml_vk_op_f32_rwkv6(

ggml/src/ggml.c CHANGED
@@ -968,6 +968,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GET_REL_POS",
     "ADD_REL_POS",
     "RWKV_WKV6",
+    "GATED_LINEAR_ATTN",

     "UNARY",

@@ -987,7 +988,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };

-static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",

@@ -1064,6 +1065,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "get_rel_pos(x)",
     "add_rel_pos(x)",
     "rwkv_wkv6(k, v, r, tf, td, s)",
+    "gated_linear_attn(k, v, q, gate, s)",

     "unary(x)",

@@ -1083,7 +1085,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };

-static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -4629,15 +4631,13 @@ struct ggml_tensor * ggml_rwkv_wkv6(
     GGML_ASSERT(ggml_is_contiguous(state));

     const int64_t S = k->ne[0];
-    const int64_t H = k->ne[2];
-    const int64_t n_tokens = k->ne[3];
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
     const int64_t n_seqs = state->ne[1];
     {
-        GGML_ASSERT(k->ne[1] == 1);
-        GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
-        GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
-        // TODO: RWKV v4 and v5
-        GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
+        GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);
         GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
     }

@@ -4656,6 +4656,49 @@ struct ggml_tensor * ggml_rwkv_wkv6(
     return result;
 }

+// ggml_gated_linear_attn
+
+struct ggml_tensor * ggml_gated_linear_attn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * state,
+        float scale) {
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
+    const int64_t n_seqs = state->ne[1];
+    {
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
+        GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    ggml_set_op_params_f32(result, 0, scale);
+
+    result->op     = GGML_OP_GATED_LINEAR_ATTN;
+    result->src[0] = k;
+    result->src[1] = v;
+    result->src[2] = q;
+    result->src[3] = g;
+    result->src[4] = state;
+
+    return result;
+}
+
 // ggml_unary

 static struct ggml_tensor * ggml_unary_impl(
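Like GGML_OP_RWKV_WKV6, the new op packs the per-token outputs and the updated recurrent state into a single {S * H, n_tokens + S * n_seqs} result, so the calling graph has to slice them apart. Below is a hedged sketch of that split using ggml views; the helper and its names are illustrative, and the llama.cpp graph code for QRWKV6 is not part of this ggml-side diff.

// Sketch only: split the fused ggml_gated_linear_attn() result into the
// per-token output rows and the flattened updated state, following the
// { S*H, n_tokens + S*n_seqs } layout set up in ggml.c above.
#include "ggml.h"

struct gla_split {
    struct ggml_tensor * out;   // {S*H, n_tokens} per-token outputs
    struct ggml_tensor * state; // {S*S*H, n_seqs} updated recurrent state
};

static struct gla_split split_gla_result(struct ggml_context * ctx, struct ggml_tensor * result,
                                         int64_t S, int64_t H, int64_t n_tokens, int64_t n_seqs) {
    const size_t es = ggml_element_size(result); // result is contiguous F32

    struct gla_split r;
    // first n_tokens rows: the outputs, one S*H row per token
    r.out   = ggml_view_2d(ctx, result, S * H, n_tokens, result->nb[1], 0);
    // remaining S*n_seqs rows: the new state, re-viewed as one row per sequence
    r.state = ggml_view_2d(ctx, result, S * S * H, n_seqs, S * S * H * es, S * H * n_tokens * es);
    return r;
}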