llama: add support for QRWKV6 model architecture (llama/11001)
* WIP: Add support for RWKV6Qwen2
Signed-off-by: Molly Sophia <[email protected]>
* RWKV: Some graph simplification
Signed-off-by: Molly Sophia <[email protected]>
* Add support for RWKV6Qwen2 with cpu and cuda GLA
Signed-off-by: Molly Sophia <[email protected]>
* RWKV6[QWEN2]: Concat lerp weights together to reduce cpu overhead
Signed-off-by: Molly Sophia <[email protected]>
* Fix some typos
Signed-off-by: Molly Sophia <[email protected]>
* code format changes
Signed-off-by: Molly Sophia <[email protected]>
* Fix wkv test & add gla test
Signed-off-by: Molly Sophia <[email protected]>
* Fix cuda warning
Signed-off-by: Molly Sophia <[email protected]>
* Update README.md
Signed-off-by: Molly Sophia <[email protected]>
* Update ggml/src/ggml-cuda/gla.cu
Co-authored-by: Georgi Gerganov <[email protected]>
* Fix fused lerp weights loading with RWKV6
Signed-off-by: Molly Sophia <[email protected]>
* better sanity check skipping for QRWKV6 in llama-quant
thanks @compilade
Signed-off-by: Molly Sophia <[email protected]>
Co-authored-by: compilade <[email protected]>
---------
Signed-off-by: Molly Sophia <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: compilade <[email protected]>
- ggml/include/ggml.h +10 -0
- ggml/src/ggml-cpu/ggml-cpu.c +198 -2
- ggml/src/ggml-cuda/ggml-cuda.cu +5 -0
- ggml/src/ggml-cuda/gla.cu +93 -0
- ggml/src/ggml-cuda/gla.cuh +3 -0
- ggml/src/ggml-cuda/wkv6.cu +2 -2
- ggml/src/ggml-sycl/wkv6.cpp +2 -2
- ggml/src/ggml-vulkan/ggml-vulkan.cpp +2 -2
- ggml/src/ggml.c +52 -9
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -501,6 +501,7 @@ extern "C" {
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV6,
+        GGML_OP_GATED_LINEAR_ATTN,
 
         GGML_OP_UNARY,
 
@@ -1859,6 +1860,15 @@ extern "C" {
            struct ggml_tensor * td,
            struct ggml_tensor * state);
 
+    GGML_API struct ggml_tensor * ggml_gated_linear_attn(
+            struct ggml_context * ctx,
+            struct ggml_tensor * k,
+            struct ggml_tensor * v,
+            struct ggml_tensor * q,
+            struct ggml_tensor * g,
+            struct ggml_tensor * state,
+            float scale);
+
     // custom operators
 
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
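For orientation, a minimal caller-side sketch of the new ggml_gated_linear_attn API. The shape contract (k, v, q, g of shape (S, H, n_tokens); a state tensor holding S*S*H values per sequence) follows the asserts added in ggml.c further down; the concrete sizes, the 1/sqrt(S) scale and the main() wrapper are illustrative assumptions, not part of the patch.

// Hypothetical caller-side sketch (not from the patch): builds the op node only.
#include <math.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // Illustrative sizes: head size S, number of heads H, tokens and sequences.
    const int64_t S = 64, H = 8, n_tokens = 4, n_seqs = 1;

    // k, v, q, g are (S, H, n_tokens); state carries S*S*H values per sequence.
    struct ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
    struct ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
    struct ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
    struct ggml_tensor * g = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
    struct ggml_tensor * s = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S * S * H, n_seqs);

    // The result concatenates the per-token output (first n_tokens rows of width S*H)
    // with the updated recurrent state (the remaining S*n_seqs rows).
    struct ggml_tensor * out = ggml_gated_linear_attn(ctx, k, v, q, g, s, 1.0f / sqrtf((float) S));
    (void) out;

    ggml_free(ctx);
    return 0;
}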
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -11803,9 +11803,9 @@ static void ggml_compute_forward_add_rel_pos(
 static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const int64_t T = dst->src[1]->ne[
+    const int64_t T = dst->src[1]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t HEADS = dst->src[1]->ne[
+    const int64_t HEADS = dst->src[1]->ne[1];
     const int64_t n_seqs = dst->src[5]->ne[1];
     const int64_t head_size = C / HEADS;
 
@@ -12000,6 +12000,197 @@ static void ggml_compute_forward_rwkv_wkv6(
     }
 }
 
+// ggml_compute_forward_gla
+
+static void ggml_compute_forward_gla_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[4]->ne[1];
+    const int64_t head_size = C / HEADS;
+    const float scale = ggml_get_op_params_f32(dst, 0);
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                        (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * k = (float *) dst->src[0]->data;
+    float * v = (float *) dst->src[1]->data;
+    float * q = (float *) dst->src[2]->data;
+    float * g = (float *) dst->src[3]->data;
+
+    size_t t_stride = HEADS * head_size; // Same to C
+
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+
+#if defined(__AVX__) && !defined(__AVX512F__)
+    #define GGML_F32X GGML_F32x8
+    #define GGML_F32X_SET1 GGML_F32x8_SET1
+    #define GGML_F32X_LOAD GGML_F32x8_LOAD
+    #define GGML_F32X_STORE GGML_F32x8_STORE
+    #define GGML_F32X_MUL GGML_F32x8_MUL
+    #define GGML_F32X_FMA GGML_F32x8_FMA
+    #define GLA_VECTOR_SIZE 8
+#elif defined(__AVX512F__)
+    #define GGML_F32X GGML_F32x16
+    #define GGML_F32X_SET1 GGML_F32x16_SET1
+    #define GGML_F32X_LOAD GGML_F32x16_LOAD
+    #define GGML_F32X_STORE GGML_F32x16_STORE
+    #define GGML_F32X_MUL GGML_F32x16_MUL
+    #define GGML_F32X_FMA GGML_F32x16_FMA
+    #define GLA_VECTOR_SIZE 16
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    #define GGML_F32X GGML_F32x4
+    #define GGML_F32X_SET1 GGML_F32x4_SET1
+    #define GGML_F32X_LOAD GGML_F32x4_LOAD
+    #define GGML_F32X_STORE GGML_F32x4_STORE
+    #define GGML_F32X_MUL GGML_F32x4_MUL
+    #define GGML_F32X_FMA GGML_F32x4_FMA
+    #define GLA_VECTOR_SIZE 4
+#endif
+
+#ifdef GLA_VECTOR_SIZE
+    const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+
+    for (int64_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+        size_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+        for (int64_t h = h_start; h < h_end; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t i = 0; i < head_size; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float q_val = q[t_h_i_offset] * scale;
+                float g_val = g[t_h_i_offset];
+
+                // Broadcast scalar values to vectors
+                GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                GGML_F32X q_vec = GGML_F32X_SET1(q_val);
+                GGML_F32X g_vec = GGML_F32X_SET1(g_val);
+
+                for (int64_t j = 0; j < vec_count; j++) {
+                    size_t base_j = j * GLA_VECTOR_SIZE;
+                    size_t t_h_j_offset = t_h_offset + base_j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                    // Load x elements at once
+                    GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                    GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                    GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                    // Compute kv = v * k
+                    GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                    // Compute temp = prev_state * g + kv
+                    GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
+
+                    // Update dst: dst += temp * q
+                    dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
+                    GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                    // Update state
+                    GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
+                }
+
+                // Handle remaining elements, this will not be used.
+                for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    float temp_val = kv_val + prev_state_val * g_val;
+                    dst_data[t_h_j_offset] += temp_val * q_val;
+                    state_cur[h_2d_i_j_offset] = temp_val;
+                }
+            }
+        }
+    }
+
+#else
+    for (int64_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+        size_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+        for (int64_t h = h_start; h < h_end; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t i = 0; i < head_size; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float q_val = q[t_h_i_offset] * scale;
+                float g_val = g[t_h_i_offset];
+
+                for (int64_t j = 0; j < head_size; j++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    float temp_val = prev_state_val * g_val + kv_val;
+                    dst_data[t_h_j_offset] += temp_val * q_val;
+                    state_cur[h_2d_i_j_offset] = temp_val;
+                }
+            }
+        }
+    }
+#endif
+}
+
+
+static void ggml_compute_forward_gla(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gla_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
@@ -12749,6 +12940,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            {
+                ggml_compute_forward_gla(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -13047,6 +13242,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
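The SIMD and scalar branches above compute the same per-head recurrence. As a reading aid, a plain scalar restatement for a single head and token (an illustrative sketch with hypothetical names, not code from the patch): the state row for input channel i decays by g[i] and accumulates k[i]*v[j], and the output is the scaled q-weighted sum over the state.

// Scalar reference for one head and one token (illustrative, not from the patch).
static void gla_head_step(int head_size, float scale,
                          const float * k, const float * v,
                          const float * q, const float * g,
                          float * state,   // head_size x head_size, row i = input channel
                          float * out) {   // head_size outputs for this token
    for (int j = 0; j < head_size; j++) {
        out[j] = 0.0f;
    }
    for (int i = 0; i < head_size; i++) {
        const float q_i = q[i] * scale;
        for (int j = 0; j < head_size; j++) {
            // decay the state by the gate, add the new key/value outer product
            const float s_new = state[i*head_size + j] * g[i] + k[i] * v[j];
            state[i*head_size + j] = s_new;
            // accumulate the query-weighted readout
            out[j] += q_i * s_new;
        }
    }
}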
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -37,6 +37,7 @@
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv6.cuh"
+#include "ggml-cuda/gla.cuh"
 
 #include <algorithm>
 #include <array>
@@ -2167,6 +2168,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_RWKV_WKV6:
             ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            ggml_cuda_op_gated_linear_attn(ctx, dst);
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;
@@ -3011,6 +3015,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
--- /dev/null
+++ b/ggml/src/ggml-cuda/gla.cu
@@ -0,0 +1,93 @@
+#include "common.cuh"
+#include "gla.cuh"
+
+template<int HEAD_SIZE>
+static __global__ void gated_linear_attn_f32(const int B, const int T, const int C, const int H, const float scale,
+    const float * k, const float * v, const float * r, const float * td, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = HEAD_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _k[head_size], _r[head_size], _td[head_size];
+
+#pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        __syncthreads();
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4 & k = (float4 &)(_k[j]);
+            const float4 & r = (float4 &)(_r[j]);
+            const float4 & td = (float4 &)(_td[j]);
+            float4 & s = (float4 &)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            s.x = s.x * td.x + kv.x;
+            s.y = s.y * td.y + kv.y;
+            s.z = s.z * td.z + kv.z;
+            s.w = s.w * td.w + kv.w;
+
+            y += r.x * s.x;
+            y += r.y * s.y;
+            y += r.z * s.z;
+            y += r.w * s.w;
+        }
+        dst[t] = y * scale;
+    }
+
+#pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * k_d = (const float *)dst->src[0]->data;
+    const float * v_d = (const float *)dst->src[1]->data;
+    const float * r_d = (const float *)dst->src[2]->data;
+    const float * td_d = (const float *)dst->src[3]->data;
+    const float * s_d = (const float *)dst->src[4]->data;
+
+    const int64_t B = dst->src[4]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    float scale;
+    memcpy(&scale, (float*)dst->op_params, sizeof(float));
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == 64 || C / H == 128);
+
+
+    if (C / H == 64) {
+        gated_linear_attn_f32<64><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+    } else {
+        gated_linear_attn_f32<128><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+    }
+}
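The CUDA path launches one block per (sequence, head) pair with head_size threads per block and walks each head in float4 steps, which is why only head sizes 64 and 128 pass the asserts. A small host-side helper mirroring that constraint (hypothetical, not part of the patch):

// Hypothetical host-side check (not from the patch) mirroring the kernel's launch
// constraints: one block per (sequence, head) pair, head_size threads per block,
// and a float4 inner loop, so the head size must be exactly 64 or 128.
#include <stdbool.h>
#include <stdint.h>

static inline bool gla_cuda_head_size_supported(int64_t C, int64_t H) {
    if (H <= 0 || C % H != 0) {
        return false;
    }
    const int64_t head_size = C / H;
    return head_size == 64 || head_size == 128;
}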
--- /dev/null
+++ b/ggml/src/ggml-cuda/gla.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
--- a/ggml/src/ggml-cuda/wkv6.cu
+++ b/ggml/src/ggml-cuda/wkv6.cu
@@ -73,9 +73,9 @@ void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const float * s_d = (const float *)dst->src[5]->data;
 
     const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[
+    const int64_t T = dst->src[0]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[
+    const int64_t H = dst->src[0]->ne[1];
 
     float * dst_d = (float *)dst->data;
 
--- a/ggml/src/ggml-sycl/wkv6.cpp
+++ b/ggml/src/ggml-sycl/wkv6.cpp
@@ -109,9 +109,9 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
     float* dst_d = (float*)dst->data;
 
     const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[
+    const int64_t T = dst->src[0]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[
+    const int64_t H = dst->src[0]->ne[1];
 
     GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
     GGML_ASSERT(C % H == 0);
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -5633,9 +5633,9 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
 }
 
 static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
-    const size_t seq_length = dst->src[0]->ne[
+    const size_t seq_length = dst->src[0]->ne[2];
     const size_t n_embed = dst->ne[0];
-    const size_t n_heads = dst->src[0]->ne[
+    const size_t n_heads = dst->src[0]->ne[1];
     const size_t n_seqs = dst->src[5]->ne[1];
 
     ggml_vk_op_f32_rwkv6(
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -968,6 +968,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GET_REL_POS",
     "ADD_REL_POS",
     "RWKV_WKV6",
+    "GATED_LINEAR_ATTN",
 
     "UNARY",
 
@@ -987,7 +988,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT ==
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1064,6 +1065,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "get_rel_pos(x)",
     "add_rel_pos(x)",
     "rwkv_wkv6(k, v, r, tf, td, s)",
+    "gated_linear_attn(k, v, q, gate, s)",
 
     "unary(x)",
 
@@ -1083,7 +1085,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT ==
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4629,15 +4631,13 @@ struct ggml_tensor * ggml_rwkv_wkv6(
     GGML_ASSERT(ggml_is_contiguous(state));
 
     const int64_t S = k->ne[0];
-    const int64_t H = k->ne[
-    const int64_t n_tokens = k->ne[
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
     const int64_t n_seqs = state->ne[1];
     {
-        GGML_ASSERT(
-        GGML_ASSERT(
-        GGML_ASSERT(
-        // TODO: RWKV v4 and v5
-        GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
+        GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);
         GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
     }
 
@@ -4656,6 +4656,49 @@ struct ggml_tensor * ggml_rwkv_wkv6(
     return result;
 }
 
+// ggml_gated_linear_attn
+
+struct ggml_tensor * ggml_gated_linear_attn(
+        struct ggml_context * ctx,
+        struct ggml_tensor * k,
+        struct ggml_tensor * v,
+        struct ggml_tensor * q,
+        struct ggml_tensor * g,
+        struct ggml_tensor * state,
+        float scale) {
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
+    const int64_t n_seqs = state->ne[1];
+    {
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
+        GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    ggml_set_op_params_f32(result, 0, scale);
+
+    result->op     = GGML_OP_GATED_LINEAR_ATTN;
+    result->src[0] = k;
+    result->src[1] = v;
+    result->src[2] = q;
+    result->src[3] = g;
+    result->src[4] = state;
+
+    return result;
+}
+
 // ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
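Like ggml_rwkv_wkv6, the new op returns the per-token output and the updated recurrent state fused into a single tensor of shape { S*H, n_tokens + S*n_seqs }. A sketch of how a caller might split the two with views (hypothetical helper, not part of the patch):

#include "ggml.h"

// Hypothetical helper (not from the patch): split the fused result of
// ggml_gated_linear_attn into the per-token output and the updated state.
// S, H, n_tokens, n_seqs must match the tensors passed when building the op.
static void gla_split_result(struct ggml_context * ctx, struct ggml_tensor * result,
                             int64_t S, int64_t H, int64_t n_tokens, int64_t n_seqs,
                             struct ggml_tensor ** out, struct ggml_tensor ** new_state) {
    // rows 0 .. n_tokens-1 of width S*H hold the attention output
    *out = ggml_view_2d(ctx, result, S * H, n_tokens, result->nb[1], 0);
    // the remaining S*n_seqs rows hold the updated recurrent state
    *new_state = ggml_view_2d(ctx, result, S * H, S * n_seqs,
                              result->nb[1], n_tokens * result->nb[1]);
}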