llama: Add support for RWKV v7 architecture (llama/12412)
* ggml: Add op l2_norm
Signed-off-by: Molly Sophia <[email protected]>
* ggml: Add op rwkv_wkv7
Signed-off-by: Molly Sophia <[email protected]>
* llama: Add support for RWKV7 and ARWKV7 models
Signed-off-by: Molly Sophia <[email protected]>
* llama: fix inference with RWKV6Qwen2
Signed-off-by: Molly Sophia <[email protected]>
* llama: add more (a)rwkv7 variants in size
Signed-off-by: Molly Sophia <[email protected]>
* Apply code-format changes
Signed-off-by: Molly Sophia <[email protected]>
* fix MUSA build
Signed-off-by: Molly Sophia <[email protected]>
* llama: fix shape error with rwkv using llama-parallel
Signed-off-by: Molly Sophia <[email protected]>
---------
Signed-off-by: Molly Sophia <[email protected]>
- ggml/include/ggml.h +24 -0
- ggml/src/ggml-cpu/ggml-cpu.c +253 -2
- ggml/src/ggml-cuda/ggml-cuda.cu +9 -1
- ggml/src/ggml-cuda/norm.cu +116 -0
- ggml/src/ggml-cuda/norm.cuh +2 -0
- ggml/src/ggml-cuda/wkv.cu +199 -0
- ggml/src/ggml-cuda/wkv.cuh +7 -0
- ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
- ggml/src/ggml-metal/ggml-metal.m +122 -0
- ggml/src/ggml-metal/ggml-metal.metal +221 -0
- ggml/src/ggml-sycl/backend.hpp +1 -1
- ggml/src/ggml-sycl/ggml-sycl.cpp +14 -0
- ggml/src/ggml-sycl/norm.cpp +108 -0
- ggml/src/ggml-sycl/norm.hpp +6 -0
- ggml/src/ggml-sycl/wkv.cpp +305 -0
- ggml/src/ggml-sycl/wkv.hpp +10 -0
- ggml/src/ggml-vulkan/ggml-vulkan.cpp +128 -80
- ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
- ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +3 -0
- ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
- ggml/src/ggml.c +85 -2
ggml/include/ggml.h
CHANGED
@@ -454,6 +454,7 @@ extern "C" {
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
         GGML_OP_GROUP_NORM,
+        GGML_OP_L2_NORM,

         GGML_OP_MUL_MAT,
         GGML_OP_MUL_MAT_ID,

@@ -502,6 +503,7 @@ extern "C" {
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV6,
         GGML_OP_GATED_LINEAR_ATTN,
+        GGML_OP_RWKV_WKV7,

         GGML_OP_UNARY,

@@ -1095,6 +1097,18 @@ extern "C" {
             int                   n_groups,
             float                 eps);

+    // l2 normalize along rows
+    // used in rwkv v7
+    GGML_API struct ggml_tensor * ggml_l2_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
+    GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
     // a - x
     // b - dy
     GGML_API struct ggml_tensor * ggml_rms_norm_back(

@@ -1890,6 +1904,16 @@ extern "C" {
             struct ggml_tensor  * state,
             float                 scale);

+    GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * r,
+            struct ggml_tensor  * w,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * state);
+
     // custom operators

     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
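To make the new header entries concrete, here is a minimal, hypothetical sketch of how a caller could wire the two new ops into a ggml graph. The helper name build_wkv7_graph, the tensor shapes, the choice of applying ggml_l2_norm to k, and the eps value are illustrative assumptions inferred from the kernels in this commit, not code taken from llama.cpp itself.

// Sketch only (not from the PR): builds a tiny graph using the new ops.
// Assumed shapes: per-token head vectors [head_size, n_head, n_tokens],
// recurrent state [head_size*head_size*n_head, n_seqs]; eps is arbitrary.
#include "ggml.h"

static struct ggml_cgraph * build_wkv7_graph(struct ggml_context * ctx,
        int64_t head_size, int64_t n_head, int64_t n_tokens, int64_t n_seqs) {
    struct ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
    struct ggml_tensor * w = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
    struct ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
    struct ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
    struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
    struct ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);

    // one head_size x head_size state matrix per head, per sequence
    struct ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,
            head_size * head_size * n_head, n_seqs);

    // ggml_l2_norm normalizes each row to unit length (clamped by eps);
    // applying it to k here is purely illustrative.
    k = ggml_l2_norm(ctx, k, 1e-12f);

    // the op output packs n_tokens * (head_size * n_head) activations
    // followed by the updated state
    struct ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, state);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    return gf;
}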
ggml/src/ggml-cpu/ggml-cpu.c
CHANGED
@@ -8548,6 +8548,69 @@ static void ggml_compute_forward_group_norm(
     }
 }

+// ggml_compute_forward_l2_norm
+
+static void ggml_compute_forward_l2_norm_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    GGML_ASSERT(eps >= 0.0f);
+
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float sum = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum += (ggml_float)(x[i00] * x[i00]);
+                }
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                memcpy(y, x, ne00 * sizeof(float));
+
+                const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_l2_norm(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_l2_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_mul_mat

 static void ggml_compute_forward_mul_mat_one_chunk(

@@ -13604,6 +13667,184 @@ static void ggml_compute_forward_gla(
     }
 }

+// ggml_compute_forward_rwkv_wkv7
+
+static void ggml_compute_forward_rwkv_wkv7_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[6]->ne[1];
+    const int64_t head_size = C / HEADS;
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * r = (float *) dst->src[0]->data;
+    float * w = (float *) dst->src[1]->data;
+    float * k = (float *) dst->src[2]->data;
+    float * v = (float *) dst->src[3]->data;
+    float * a = (float *) dst->src[4]->data;
+    float * b = (float *) dst->src[5]->data;
+
+    int64_t t_stride = HEADS * head_size; // Same to C
+
+    int64_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    int64_t h_stride_2d = head_size * head_size;
+
+    #if defined(GGML_SIMD)
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t ii = 0; ii < head_size; ii++) {
+                    int64_t t_h_i_offset = t_h_offset + ii;
+                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                    GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                    float sa = 0;
+                    {
+                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                        GGML_F32_VEC ax[GGML_F32_ARR];
+                        GGML_F32_VEC ay[GGML_F32_ARR];
+                        for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                                ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                                ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                                sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                            }
+                        }
+                        GGML_F32_VEC_REDUCE(sa, sum);
+                    }
+
+                    GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+
+                    int64_t j = 0;
+                    GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                    for (; j < head_size; j += GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                            int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+
+                            GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                            GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                            GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                            GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+
+                            k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+
+                            GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                            // kv + s * decay + sa * b
+                            state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                            state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                            GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+
+                            result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                        }
+                    }
+                    GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                    // There shouldn't be left-overs though.
+                    for (; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v[t_h_i_offset] * k_val;
+
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                    }
+                }
+            }
+        }
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    int64_t t_h_i_offset = t_h_offset + i;
+                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float v_val = v[t_h_i_offset];
+
+                    float sa = 0, result = 0;
+                    for (int64_t j = 0; j < head_size; j++) {
+                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+                    }
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        result += state_cur[h_2d_i_j_offset] * r_val;
+                    }
+                    dst_data[t_h_i_offset] = result;
+                }
+            }
+        }
+    #endif
+}
+
+
+static void ggml_compute_forward_rwkv_wkv7(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv7_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary

 static void ggml_compute_forward_map_unary_f32(

@@ -14170,6 +14411,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_group_norm(params, tensor);
             } break;
+        case GGML_OP_L2_NORM:
+            {
+                ggml_compute_forward_l2_norm(params, tensor);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 ggml_compute_forward_mul_mat(params, tensor);

@@ -14357,6 +14602,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gla(params, tensor);
             } break;
+        case GGML_OP_RWKV_WKV7:
+            {
+                ggml_compute_forward_rwkv_wkv7(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;

@@ -14582,6 +14831,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
         case GGML_OP_RMS_NORM_BACK:
+        case GGML_OP_L2_NORM:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_CONCAT:
         case GGML_OP_MUL_MAT:

@@ -14648,14 +14898,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_FLASH_ATTN_BACK:
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
+        case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
+        case GGML_OP_RWKV_WKV7:
             {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
-        case GGML_OP_RWKV_WKV6:
-        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
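For reference, the scalar #else path above computes, for every head and token t, the recurrence below; S is the head_size x head_size per-head state (state_prev is the previous token's state within the same sequence, or the incoming src[6] state at a sequence boundary), and r, w, k, v, a, b are the per-token head vectors:

$$
\mathrm{sa}_i = \sum_j a_j\, S^{(t-1)}_{ij}, \qquad
S^{(t)}_{ij} = S^{(t-1)}_{ij}\, w_j + v_i\, k_j + \mathrm{sa}_i\, b_j, \qquad
y^{(t)}_i = \sum_j S^{(t)}_{ij}\, r_j .
$$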
ggml/src/ggml-cuda/ggml-cuda.cu
CHANGED
@@ -36,7 +36,7 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
-#include "ggml-cuda/
+#include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
 #include "ggml.h"

@@ -2196,6 +2196,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GROUP_NORM:
             ggml_cuda_op_group_norm(ctx, dst);
             break;
+        case GGML_OP_L2_NORM:
+            ggml_cuda_op_l2_norm(ctx, dst);
+            break;
         case GGML_OP_CONCAT:
             ggml_cuda_op_concat(ctx, dst);
             break;

@@ -2304,6 +2307,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GATED_LINEAR_ATTN:
             ggml_cuda_op_gated_linear_attn(ctx, dst);
             break;
+        case GGML_OP_RWKV_WKV7:
+            ggml_cuda_op_rwkv_wkv7(ctx, dst);
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;

@@ -3161,6 +3167,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             break;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
             return true;
         case GGML_OP_RMS_NORM_BACK:
             return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;

@@ -3215,6 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
+        case GGML_OP_RWKV_WKV7:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
CHANGED
|
@@ -201,6 +201,85 @@ static __global__ void rms_norm_back_f32(
|
|
| 201 |
}
|
| 202 |
}
|
| 203 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
static void norm_f32_cuda(
|
| 205 |
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
| 206 |
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
|
@@ -248,6 +327,19 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
|
|
| 248 |
}
|
| 249 |
}
|
| 250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 252 |
const ggml_tensor * src0 = dst->src[0];
|
| 253 |
const float * src0_d = (const float *) src0->data;
|
|
@@ -340,3 +432,27 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
|
|
| 340 |
|
| 341 |
rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
|
| 342 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
}
|
| 202 |
}
|
| 203 |
|
| 204 |
+
// template <int block_size>
|
| 205 |
+
// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
|
| 206 |
+
// const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
| 207 |
+
// const int tid = threadIdx.x;
|
| 208 |
+
|
| 209 |
+
// float tmp = 0.0f; // partial sum for thread in warp
|
| 210 |
+
|
| 211 |
+
// for (int col = tid; col < ncols; col += block_size) {
|
| 212 |
+
// const float xi = x[row*ncols + col];
|
| 213 |
+
// tmp += xi * xi;
|
| 214 |
+
// }
|
| 215 |
+
|
| 216 |
+
// // sum up partial sums
|
| 217 |
+
// tmp = warp_reduce_sum(tmp);
|
| 218 |
+
// if (block_size > WARP_SIZE) {
|
| 219 |
+
// __shared__ float s_sum[32];
|
| 220 |
+
// int warp_id = threadIdx.x / WARP_SIZE;
|
| 221 |
+
// int lane_id = threadIdx.x % WARP_SIZE;
|
| 222 |
+
// if (lane_id == 0) {
|
| 223 |
+
// s_sum[warp_id] = tmp;
|
| 224 |
+
// }
|
| 225 |
+
// __syncthreads();
|
| 226 |
+
// tmp = s_sum[lane_id];
|
| 227 |
+
// tmp = warp_reduce_sum(tmp);
|
| 228 |
+
// }
|
| 229 |
+
|
| 230 |
+
// // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
|
| 231 |
+
// const float scale = rsqrtf(fmaxf(tmp, eps * eps));
|
| 232 |
+
|
| 233 |
+
// for (int col = tid; col < ncols; col += block_size) {
|
| 234 |
+
// dst[row*ncols + col] = scale * x[row*ncols + col];
|
| 235 |
+
// }
|
| 236 |
+
// }
|
| 237 |
+
|
| 238 |
+
template <int block_size>
|
| 239 |
+
static __global__ void l2_norm_f32(
|
| 240 |
+
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
| 241 |
+
const int64_t stride_sample, const float eps) {
|
| 242 |
+
const int nrows = gridDim.x;
|
| 243 |
+
const int nchannels = gridDim.y;
|
| 244 |
+
|
| 245 |
+
const int row = blockIdx.x;
|
| 246 |
+
const int channel = blockIdx.y;
|
| 247 |
+
const int sample = blockIdx.z;
|
| 248 |
+
const int tid = threadIdx.x;
|
| 249 |
+
|
| 250 |
+
x += sample*stride_sample + channel*stride_channel + row*stride_row;
|
| 251 |
+
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
|
| 252 |
+
|
| 253 |
+
float tmp = 0.0f; // partial sum for thread in warp
|
| 254 |
+
|
| 255 |
+
for (int col = tid; col < ncols; col += block_size) {
|
| 256 |
+
const float xi = x[col];
|
| 257 |
+
tmp += xi * xi;
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
// sum up partial sums
|
| 261 |
+
tmp = warp_reduce_sum(tmp);
|
| 262 |
+
if constexpr (block_size > WARP_SIZE) {
|
| 263 |
+
static_assert(block_size == 1024, "unexpected block_size");
|
| 264 |
+
__shared__ float s_sum[32];
|
| 265 |
+
const int warp_id = threadIdx.x / WARP_SIZE;
|
| 266 |
+
const int lane_id = threadIdx.x % WARP_SIZE;
|
| 267 |
+
if (lane_id == 0) {
|
| 268 |
+
s_sum[warp_id] = tmp;
|
| 269 |
+
}
|
| 270 |
+
__syncthreads();
|
| 271 |
+
tmp = s_sum[lane_id];
|
| 272 |
+
tmp = warp_reduce_sum(tmp);
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
// from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
|
| 276 |
+
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
|
| 277 |
+
|
| 278 |
+
for (int col = tid; col < ncols; col += block_size) {
|
| 279 |
+
dst[col] = scale * x[col];
|
| 280 |
+
}
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
static void norm_f32_cuda(
|
| 284 |
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
| 285 |
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
|
|
|
| 327 |
}
|
| 328 |
}
|
| 329 |
|
| 330 |
+
static void l2_norm_f32_cuda(
|
| 331 |
+
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
| 332 |
+
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
| 333 |
+
const dim3 blocks_num(nrows, nchannels, nsamples);
|
| 334 |
+
if (ncols < 1024) {
|
| 335 |
+
const dim3 block_dims(WARP_SIZE, 1, 1);
|
| 336 |
+
l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
| 337 |
+
} else {
|
| 338 |
+
const dim3 block_dims(1024, 1, 1);
|
| 339 |
+
l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
| 340 |
+
}
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 344 |
const ggml_tensor * src0 = dst->src[0];
|
| 345 |
const float * src0_d = (const float *) src0->data;
|
|
|
|
| 432 |
|
| 433 |
rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
|
| 434 |
}
|
| 435 |
+
|
| 436 |
+
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 437 |
+
const ggml_tensor * src0 = dst->src[0];
|
| 438 |
+
const float * src0_d = (const float *) src0->data;
|
| 439 |
+
float * dst_d = (float *) dst->data;
|
| 440 |
+
cudaStream_t stream = ctx.stream();
|
| 441 |
+
|
| 442 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 443 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 444 |
+
|
| 445 |
+
GGML_TENSOR_UNARY_OP_LOCALS;
|
| 446 |
+
|
| 447 |
+
float eps;
|
| 448 |
+
memcpy(&eps, dst->op_params, sizeof(float));
|
| 449 |
+
GGML_ASSERT(eps >= 0.0f);
|
| 450 |
+
|
| 451 |
+
const size_t ts0 = ggml_type_size(src0->type);
|
| 452 |
+
GGML_ASSERT(nb00 == ts0);
|
| 453 |
+
const int64_t s01 = nb01 / ts0;
|
| 454 |
+
const int64_t s02 = nb02 / ts0;
|
| 455 |
+
const int64_t s03 = nb03 / ts0;
|
| 456 |
+
|
| 457 |
+
l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
|
| 458 |
+
}
|
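Note that the CPU and CUDA implementations clamp by eps in different-looking but equivalent forms (both assert eps >= 0): the CPU path computes scale = 1.0f/fmaxf(sqrtf(sum), eps), while the CUDA kernel computes rsqrtf(fmaxf(tmp, eps*eps)). Since the square root is monotone,

$$
\frac{1}{\max\!\left(\sqrt{\sum_i x_i^2},\; \varepsilon\right)}
\;=\;
\frac{1}{\sqrt{\max\!\left(\sum_i x_i^2,\; \varepsilon^2\right)}} .
$$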
ggml/src/ggml-cuda/norm.cuh
CHANGED
@@ -7,3 +7,5 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/wkv.cu
ADDED
@@ -0,0 +1,199 @@
+#include "common.cuh"
+#include "wkv.cuh"
+
+template <int block_size>
+static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = block_size;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
+
+#pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    __syncthreads();
+    _tf[tid] = tf[head_i * head_size + tid];
+    __syncthreads();
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        __syncthreads();
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4& k = (float4&)(_k[j]);
+            const float4& r = (float4&)(_r[j]);
+            const float4& tf = (float4&)(_tf[j]);
+            const float4& td = (float4&)(_td[j]);
+            float4& s = (float4&)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            y += r.x * (tf.x * kv.x + s.x);
+            y += r.y * (tf.y * kv.y + s.y);
+            y += r.z * (tf.z * kv.z + s.z);
+            y += r.w * (tf.w * kv.w + s.w);
+
+            s.x = s.x * td.x + kv.x;
+            s.y = s.y * td.y + kv.y;
+            s.z = s.z * td.z + kv.z;
+            s.w = s.w * td.w + kv.w;
+        }
+        dst[t] = y;
+    }
+
+#pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+template <int block_size>
+static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = block_size;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size];
+
+#ifndef GGML_USE_MUSA
+#pragma unroll
+#endif
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
+    }
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _r[tid] = r[t];
+        _w[tid] = w[t];
+        _k[tid] = k[t];
+        _a[tid] = a[t];
+        _b[tid] = b[t];
+        __syncthreads();
+
+        float sa = 0;
+#pragma unroll
+        for (int j = 0; j < head_size; j += 4)
+        {
+            const float4& a = (float4&)(_a[j]);
+            const float4& s = (float4&)(state[j]);
+            sa += a.x * s.x;
+            sa += a.y * s.y;
+            sa += a.z * s.z;
+            sa += a.w * s.w;
+        }
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4& r = (float4&)(_r[j]);
+            const float4& w = (float4&)(_w[j]);
+            const float4& k = (float4&)(_k[j]);
+            const float4& b = (float4&)(_b[j]);
+            float4& s = (float4&)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            s.x = s.x * w.x + kv.x + sa * b.x;
+            s.y = s.y * w.y + kv.y + sa * b.y;
+            s.z = s.z * w.z + kv.z + sa * b.z;
+            s.w = s.w * w.w + kv.w + sa * b.w;
+
+            y += s.x * r.x;
+            y += s.y * r.y;
+            y += s.z * r.z;
+            y += s.w * r.w;
+        }
+        dst[t] = y;
+    }
+
+#pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
+    }
+}
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * k_d  = (const float *)dst->src[0]->data;
+    const float * v_d  = (const float *)dst->src[1]->data;
+    const float * r_d  = (const float *)dst->src[2]->data;
+    const float * tf_d = (const float *)dst->src[3]->data;
+    const float * td_d = (const float *)dst->src[4]->data;
+    const float * s_d  = (const float *)dst->src[5]->data;
+
+    const int64_t B = dst->src[5]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
+
+    if (C / H == CUDA_WKV_BLOCK_SIZE) {
+        rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+    } else {
+        rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+    }
+}
+
+void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * r_d = (const float *)dst->src[0]->data;
+    const float * w_d = (const float *)dst->src[1]->data;
+    const float * k_d = (const float *)dst->src[2]->data;
+    const float * v_d = (const float *)dst->src[3]->data;
+    const float * a_d = (const float *)dst->src[4]->data;
+    const float * b_d = (const float *)dst->src[5]->data;
+    const float * s_d = (const float *)dst->src[6]->data;
+
+    const int64_t B = dst->src[6]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
+
+    if (C / H == CUDA_WKV_BLOCK_SIZE) {
+        rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+    } else {
+        rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+    }
+}
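A quick sanity check on the launch geometry and output layout used by ggml_cuda_op_rwkv_wkv7 above: the grid is B*H blocks of C/H threads, C/H (the head size) must be 64 or 128, and dst packs the per-token outputs followed by the final per-sequence states,

$$
|\mathrm{dst}| = \underbrace{T \cdot C}_{\text{outputs}} \;+\; \underbrace{B \cdot \tfrac{C}{H} \cdot C}_{\text{final states}} \quad \text{floats}.
$$

For example, with assumed sizes B = 2, H = 32 and C = 2048 (so C/H = 64), the kernel launches 64 blocks of 64 threads and the state section begins at float offset 2048*T; these concrete numbers are illustrative, not taken from the PR.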
ggml/src/ggml-cuda/wkv.cuh
ADDED
@@ -0,0 +1,7 @@
+#include "common.cuh"
+
+#define CUDA_WKV_BLOCK_SIZE 64
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-metal/ggml-metal-impl.h
CHANGED
@@ -285,6 +285,13 @@ typedef struct {
     float eps;
 } ggml_metal_kargs_rms_norm;

+typedef struct {
+    int32_t  ne00;
+    int32_t  ne00_4;
+    uint64_t nb01;
+    float    eps;
+} ggml_metal_kargs_l2_norm;
+
 typedef struct {
     int64_t ne00;
     int64_t ne01;
ggml/src/ggml-metal/ggml-metal.m
CHANGED
@@ -184,10 +184,13 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
+    GGML_METAL_KERNEL_TYPE_L2_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
     GGML_METAL_KERNEL_TYPE_NORM,
     GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
     GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
+    GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
+    GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,

@@ -810,10 +813,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);

@@ -1251,6 +1257,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ARGMAX:
             return true;

@@ -1288,6 +1295,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
+        case GGML_OP_RWKV_WKV6:
+        case GGML_OP_RWKV_WKV7:
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:

@@ -2216,6 +2225,83 @@ static void ggml_metal_encode_node(

                 [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
+        case GGML_OP_RWKV_WKV6:
+            {
+                const int64_t B = dst->src[5]->ne[1];
+                const int64_t T = dst->src[0]->ne[2];
+                const int64_t C = dst->ne[0];
+                const int64_t H = dst->src[0]->ne[1];
+
+                GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+                GGML_ASSERT(C % H == 0);
+                GGML_ASSERT(C / H == 64);
+
+                size_t offs_src3 = 0;
+                size_t offs_src4 = 0;
+                size_t offs_src5 = 0;
+
+                id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
+                id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
+                id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
+                [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+
+                [encoder setBytes:&B length:sizeof(B) atIndex:7];
+                [encoder setBytes:&T length:sizeof(T) atIndex:8];
+                [encoder setBytes:&C length:sizeof(C) atIndex:9];
+                [encoder setBytes:&H length:sizeof(H) atIndex:10];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
+            } break;
+        case GGML_OP_RWKV_WKV7:
+            {
+                const int64_t B = dst->src[6]->ne[1];
+                const int64_t T = dst->src[0]->ne[2];
+                const int64_t C = dst->ne[0];
+                const int64_t H = dst->src[0]->ne[1];
+
+                GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
+                GGML_ASSERT(C % H == 0);
+                GGML_ASSERT(C / H == 64);
+
+                size_t offs_src3 = 0;
+                size_t offs_src4 = 0;
+                size_t offs_src5 = 0;
+                size_t offs_src6 = 0;
+
+                id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
+                id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
+                id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
+                id<MTLBuffer> id_src6 = dst->src[6] ? ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil;
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
+                [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+                [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
+
+                [encoder setBytes:&B length:sizeof(B) atIndex:8];
+                [encoder setBytes:&T length:sizeof(T) atIndex:9];
+                [encoder setBytes:&C length:sizeof(C) atIndex:10];
+                [encoder setBytes:&H length:sizeof(H) atIndex:11];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 GGML_ASSERT(ne00 == ne10);

@@ -3122,6 +3208,42 @@ static void ggml_metal_encode_node(

                 const int64_t nrows = ggml_nrows(src0);

+                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_L2_NORM:
+            {
+                GGML_ASSERT(ne00 % 4 == 0);
+                GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+                float eps;
+                memcpy(&eps, dst->op_params, sizeof(float));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline;
+
+                int nth = 32; // SIMD width
+
+                while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }
+
+                nth = MIN(nth, ne00/4);
+
+                ggml_metal_kargs_l2_norm args = {
+                    /*.ne00   =*/ ne00,
+                    /*.ne00_4 =*/ ne00/4,
+                    /*.nb01   =*/ nb01,
+                    /*.eps    =*/ eps,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+                const int64_t nrows = ggml_nrows(src0);
+
                 [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_GROUP_NORM:
ggml/src/ggml-metal/ggml-metal.metal
CHANGED
|
@@ -1295,6 +1295,184 @@ kernel void kernel_ssm_scan_f32(
|
|
| 1295 |
}
|
| 1296 |
}
|
| 1297 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1298 |
kernel void kernel_argmax(
|
| 1299 |
device const void * x,
|
| 1300 |
device int32_t * dst,
|
|
@@ -1463,6 +1641,49 @@ kernel void kernel_rms_norm(
|
|
| 1463 |
}
|
| 1464 |
}
|
| 1465 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1466 |
kernel void kernel_group_norm(
|
| 1467 |
device const float * src0,
|
| 1468 |
device float * dst,
|
|
|
|
| 1295 |
}
|
| 1296 |
}
|
| 1297 |
|
| 1298 |
+
kernel void kernel_rwkv_wkv6_f32(
|
| 1299 |
+
device const float * k,
|
| 1300 |
+
device const float * v,
|
| 1301 |
+
device const float * r,
|
| 1302 |
+
device const float * tf,
|
| 1303 |
+
device const float * td,
|
| 1304 |
+
device const float * state_in,
|
| 1305 |
+
device float * dst,
|
| 1306 |
+
constant uint & B,
|
| 1307 |
+
constant uint & T,
|
| 1308 |
+
constant uint & C,
|
| 1309 |
+
constant uint & H,
|
| 1310 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1311 |
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 1312 |
+
uint3 ntg[[threads_per_threadgroup]]) {
|
| 1313 |
+
|
| 1314 |
+
const uint head_size = 64; // TODO: support head_size = 128
|
| 1315 |
+
const uint batch_id = tgpig.x / H;
|
| 1316 |
+
const uint head_id = tgpig.x % H;
|
| 1317 |
+
const uint tid = tpitg.x;
|
| 1318 |
+
|
| 1319 |
+
if (batch_id >= B || head_id >= H) {
|
| 1320 |
+
return;
|
| 1321 |
+
}
|
| 1322 |
+
|
| 1323 |
+
const uint state_size = C * head_size;
|
| 1324 |
+
const uint n_seq_tokens = T / B;
|
| 1325 |
+
|
| 1326 |
+
threadgroup float _k[head_size];
|
| 1327 |
+
threadgroup float _r[head_size];
|
| 1328 |
+
threadgroup float _tf[head_size];
|
| 1329 |
+
threadgroup float _td[head_size];
|
| 1330 |
+
|
| 1331 |
+
float state[head_size];
|
| 1332 |
+
|
| 1333 |
+
for (uint i = 0; i < head_size; i++) {
|
| 1334 |
+
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
| 1335 |
+
+ i * head_size + tid];
|
| 1336 |
+
}
|
| 1337 |
+
|
| 1338 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 1339 |
+
_tf[tid] = tf[head_id * head_size + tid];
|
| 1340 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 1341 |
+
|
| 1342 |
+
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
|
| 1343 |
+
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
| 1344 |
+
|
| 1345 |
+
for (uint t = start_t; t < end_t; t += C) {
|
| 1346 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 1347 |
+
_k[tid] = k[t];
|
| 1348 |
+
_r[tid] = r[t];
|
| 1349 |
+
_td[tid] = td[t];
|
| 1350 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 1351 |
+
|
| 1352 |
+
const float v_val = v[t];
|
| 1353 |
+
float y = 0.0;
|
| 1354 |
+
|
| 1355 |
+
for (uint j = 0; j < head_size; j += 4) {
|
| 1356 |
+
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
| 1357 |
+
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
| 1358 |
+
float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
| 1359 |
+
float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
| 1360 |
+
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
| 1361 |
+
|
| 1362 |
+
float4 kv = k_vec * v_val;
|
| 1363 |
+
|
| 1364 |
+
float4 temp = tf_vec * kv + s_vec;
|
| 1365 |
+
y += dot(r_vec, temp);
|
| 1366 |
+
|
| 1367 |
+
s_vec = s_vec * td_vec + kv;
|
| 1368 |
+
state[j] = s_vec[0];
|
| 1369 |
+
state[j+1] = s_vec[1];
|
| 1370 |
+
state[j+2] = s_vec[2];
|
| 1371 |
+
state[j+3] = s_vec[3];
|
| 1372 |
+
}
|
| 1373 |
+
|
| 1374 |
+
dst[t] = y;
|
| 1375 |
+
}
|
| 1376 |
+
|
| 1377 |
+
for (uint i = 0; i < head_size; i++) {
|
| 1378 |
+
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
| 1379 |
+
+ i * head_size + tid] = state[i];
|
| 1380 |
+
}
|
| 1381 |
+
}
|
| 1382 |
+
|
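For orientation, kernel_rwkv_wkv6_f32 above assigns one thread per value channel (tid) of one head and, for each token, evaluates the RWKV6 linear-attention step: with kv_j = k_j * v, the output is y = sum_j r_j * (tf_j * kv_j + state_j) and the state updates as state_j = state_j * td_j + kv_j. A minimal scalar sketch of a single token step for one value channel (illustrative only, not part of the patch):

#include <vector>

// One WKV6 token step for one head and one value channel.
// state holds head_size accumulators (indexed by the key channel j).
static float wkv6_step(const std::vector<float> & r, const std::vector<float> & k,
                       const std::vector<float> & tf, const std::vector<float> & td,
                       float v, std::vector<float> & state) {
    float y = 0.0f;
    for (size_t j = 0; j < state.size(); ++j) {
        const float kv = k[j] * v;               // rank-1 contribution k_j * v
        y += r[j] * (tf[j] * kv + state[j]);     // "bonus" term tf plus accumulated state
        state[j] = state[j] * td[j] + kv;        // per-channel decay, then accumulate
    }
    return y;                                    // one element of the output row
}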
| 1383 |
+
kernel void kernel_rwkv_wkv7_f32(
|
| 1384 |
+
device const float * r,
|
| 1385 |
+
device const float * w,
|
| 1386 |
+
device const float * k,
|
| 1387 |
+
device const float * v,
|
| 1388 |
+
device const float * a,
|
| 1389 |
+
device const float * b,
|
| 1390 |
+
device const float * state_in,
|
| 1391 |
+
device float * dst,
|
| 1392 |
+
constant uint & B,
|
| 1393 |
+
constant uint & T,
|
| 1394 |
+
constant uint & C,
|
| 1395 |
+
constant uint & H,
|
| 1396 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1397 |
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 1398 |
+
uint3 ntg[[threads_per_threadgroup]]) {
|
| 1399 |
+
|
| 1400 |
+
const uint head_size = 64; // TODO: support head_size = 128
|
| 1401 |
+
const uint batch_id = tgpig.x / H;
|
| 1402 |
+
const uint head_id = tgpig.x % H;
|
| 1403 |
+
const uint tid = tpitg.x;
|
| 1404 |
+
|
| 1405 |
+
if (batch_id >= B || head_id >= H) {
|
| 1406 |
+
return;
|
| 1407 |
+
}
|
| 1408 |
+
|
| 1409 |
+
const uint state_size = C * head_size;
|
| 1410 |
+
const uint n_seq_tokens = T / B;
|
| 1411 |
+
|
| 1412 |
+
threadgroup float _r[head_size];
|
| 1413 |
+
threadgroup float _w[head_size];
|
| 1414 |
+
threadgroup float _k[head_size];
|
| 1415 |
+
threadgroup float _a[head_size];
|
| 1416 |
+
threadgroup float _b[head_size];
|
| 1417 |
+
|
| 1418 |
+
float state[head_size];
|
| 1419 |
+
|
| 1420 |
+
for (uint i = 0; i < head_size; i++) {
|
| 1421 |
+
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
| 1422 |
+
+ tid * head_size + i];
|
| 1423 |
+
}
|
| 1424 |
+
|
| 1425 |
+
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
|
| 1426 |
+
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
| 1427 |
+
|
| 1428 |
+
for (uint t = start_t; t < end_t; t += C) {
|
| 1429 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 1430 |
+
_r[tid] = r[t];
|
| 1431 |
+
_w[tid] = w[t];
|
| 1432 |
+
_k[tid] = k[t];
|
| 1433 |
+
_a[tid] = a[t];
|
| 1434 |
+
_b[tid] = b[t];
|
| 1435 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 1436 |
+
|
| 1437 |
+
const float v_val = v[t];
|
| 1438 |
+
float y = 0.0, sa = 0.0;
|
| 1439 |
+
|
| 1440 |
+
float4 sa_vec(0.0);
|
| 1441 |
+
|
| 1442 |
+
for (int j = 0; j < head_size; j += 4) {
|
| 1443 |
+
float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
|
| 1444 |
+
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
| 1445 |
+
sa_vec += a_vec * s_vec;
|
| 1446 |
+
}
|
| 1447 |
+
sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
|
| 1448 |
+
|
| 1449 |
+
for (uint j = 0; j < head_size; j += 4) {
|
| 1450 |
+
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
| 1451 |
+
float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
|
| 1452 |
+
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
| 1453 |
+
float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
|
| 1454 |
+
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
| 1455 |
+
|
| 1456 |
+
float4 kv = k_vec * v_val;
|
| 1457 |
+
|
| 1458 |
+
s_vec = s_vec * w_vec + kv + sa * b_vec;
|
| 1459 |
+
y += dot(s_vec, r_vec);
|
| 1460 |
+
|
| 1461 |
+
state[j] = s_vec[0];
|
| 1462 |
+
state[j+1] = s_vec[1];
|
| 1463 |
+
state[j+2] = s_vec[2];
|
| 1464 |
+
state[j+3] = s_vec[3];
|
| 1465 |
+
}
|
| 1466 |
+
|
| 1467 |
+
dst[t] = y;
|
| 1468 |
+
}
|
| 1469 |
+
|
| 1470 |
+
for (uint i = 0; i < head_size; i++) {
|
| 1471 |
+
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
| 1472 |
+
+ tid * head_size + i] = state[i];
|
| 1473 |
+
}
|
| 1474 |
+
}
|
| 1475 |
+
|
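kernel_rwkv_wkv7_f32 drops the tf/td parameters of v6 in favour of a per-token decay w and an in-context update driven by a and b: first sa = sum_j a_j * state_j is computed from the old state, then each channel updates as state_j = state_j * w_j + k_j * v + sa * b_j and the output is y = sum_j r_j * state_j. Note also that the state is indexed transposed relative to the wkv6 kernel (tid * head_size + i here versus i * head_size + tid there), as the load/store loops show. A scalar sketch of one token step (illustrative only, not part of the patch):

#include <vector>

// One WKV7 token step for one head and one value channel.
static float wkv7_step(const std::vector<float> & r, const std::vector<float> & w,
                       const std::vector<float> & k, const std::vector<float> & a,
                       const std::vector<float> & b, float v, std::vector<float> & state) {
    float sa = 0.0f;
    for (size_t j = 0; j < state.size(); ++j) {
        sa += a[j] * state[j];                   // project the previous state onto a
    }
    float y = 0.0f;
    for (size_t j = 0; j < state.size(); ++j) {
        state[j] = state[j] * w[j] + k[j] * v + sa * b[j];   // decay, write, redistribute
        y += r[j] * state[j];                                // readout with r (updated state)
    }
    return y;
}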
| 1476 |
kernel void kernel_argmax(
|
| 1477 |
device const void * x,
|
| 1478 |
device int32_t * dst,
|
|
|
|
| 1641 |
}
|
| 1642 |
}
|
| 1643 |
|
| 1644 |
+
kernel void kernel_l2_norm(
|
| 1645 |
+
constant ggml_metal_kargs_l2_norm & args,
|
| 1646 |
+
device const char * src0,
|
| 1647 |
+
device char * dst,
|
| 1648 |
+
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
| 1649 |
+
uint tgpig[[threadgroup_position_in_grid]],
|
| 1650 |
+
ushort tpitg[[thread_position_in_threadgroup]],
|
| 1651 |
+
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
| 1652 |
+
ushort tiisg[[thread_index_in_simdgroup]],
|
| 1653 |
+
ushort ntg[[threads_per_threadgroup]]) {
|
| 1654 |
+
if (sgitg == 0) {
|
| 1655 |
+
shmem_f32[tiisg] = 0.0f;
|
| 1656 |
+
}
|
| 1657 |
+
|
| 1658 |
+
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
|
| 1659 |
+
|
| 1660 |
+
float sumf = 0.0f;
|
| 1661 |
+
|
| 1662 |
+
// parallel sum
|
| 1663 |
+
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
| 1664 |
+
sumf += dot(x[i00], x[i00]);
|
| 1665 |
+
}
|
| 1666 |
+
sumf = simd_sum(sumf);
|
| 1667 |
+
|
| 1668 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 1669 |
+
|
| 1670 |
+
if (tiisg == 0) {
|
| 1671 |
+
shmem_f32[sgitg] = sumf;
|
| 1672 |
+
}
|
| 1673 |
+
|
| 1674 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 1675 |
+
|
| 1676 |
+
sumf = shmem_f32[tiisg];
|
| 1677 |
+
sumf = simd_sum(sumf);
|
| 1678 |
+
|
| 1679 |
+
const float scale = 1.0f/sqrt(max(sumf, args.eps));
|
| 1680 |
+
|
| 1681 |
+
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
|
| 1682 |
+
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
| 1683 |
+
y[i00] = x[i00] * scale;
|
| 1684 |
+
}
|
| 1685 |
+
}
|
| 1686 |
+
|
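kernel_l2_norm scales each row by the reciprocal square root of its clamped sum of squares, i.e. y = x / sqrt(max(sum(x^2), eps)). A plain scalar equivalent of one row, as an illustrative sketch rather than the ggml API:

#include <algorithm>
#include <cmath>

// Row-wise L2 normalization matching the math of kernel_l2_norm above.
static void l2_norm_row(const float * x, float * y, int n, float eps) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum += x[i] * x[i];              // squared L2 norm of the row
    }
    const float scale = 1.0f / std::sqrt(std::max(sum, eps));
    for (int i = 0; i < n; ++i) {
        y[i] = x[i] * scale;
    }
}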
| 1687 |
kernel void kernel_group_norm(
|
| 1688 |
device const float * src0,
|
| 1689 |
device float * dst,
|
ggml/src/ggml-sycl/backend.hpp
CHANGED
|
@@ -26,7 +26,7 @@
|
|
| 26 |
#include "softmax.hpp"
|
| 27 |
#include "tsembd.hpp"
|
| 28 |
#include "im2col.hpp"
|
| 29 |
-
#include "
|
| 30 |
#include "outprod.hpp"
|
| 31 |
#include "element_wise.hpp"
|
| 32 |
#include "cpy.hpp"
|
|
|
|
| 26 |
#include "softmax.hpp"
|
| 27 |
#include "tsembd.hpp"
|
| 28 |
#include "im2col.hpp"
|
| 29 |
+
#include "wkv.hpp"
|
| 30 |
#include "outprod.hpp"
|
| 31 |
#include "element_wise.hpp"
|
| 32 |
#include "cpy.hpp"
|
ggml/src/ggml-sycl/ggml-sycl.cpp
CHANGED
|
@@ -2696,6 +2696,12 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
|
| 2696 |
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 2697 |
}
|
| 2698 |
|
|
|
|
|
| 2699 |
static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
| 2700 |
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 2701 |
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
|
|
@@ -3410,6 +3416,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
|
|
| 3410 |
case GGML_OP_RMS_NORM:
|
| 3411 |
ggml_sycl_rms_norm(ctx, dst);
|
| 3412 |
break;
|
|
|
|
|
| 3413 |
case GGML_OP_MUL_MAT:
|
| 3414 |
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
|
| 3415 |
return false;
|
|
@@ -3487,6 +3496,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
|
|
| 3487 |
case GGML_OP_RWKV_WKV6:
|
| 3488 |
ggml_sycl_op_rwkv_wkv6(ctx, dst);
|
| 3489 |
break;
|
|
|
|
|
| 3490 |
case GGML_OP_GATED_LINEAR_ATTN:
|
| 3491 |
ggml_sycl_op_gated_linear_attn(ctx, dst);
|
| 3492 |
break;
|
|
@@ -4012,6 +4024,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
| 4012 |
return (op->src[0]->type == GGML_TYPE_F32);
|
| 4013 |
case GGML_OP_NORM:
|
| 4014 |
case GGML_OP_RMS_NORM:
|
|
|
|
| 4015 |
case GGML_OP_GROUP_NORM:
|
| 4016 |
return ggml_is_contiguous(op->src[0]);
|
| 4017 |
case GGML_OP_SCALE:
|
|
@@ -4045,6 +4058,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
| 4045 |
case GGML_OP_LEAKY_RELU:
|
| 4046 |
case GGML_OP_TIMESTEP_EMBEDDING:
|
| 4047 |
case GGML_OP_RWKV_WKV6:
|
|
|
|
| 4048 |
case GGML_OP_GATED_LINEAR_ATTN:
|
| 4049 |
return true;
|
| 4050 |
default:
|
|
|
|
| 2696 |
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 2697 |
}
|
| 2698 |
|
| 2699 |
+
static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
| 2700 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 2701 |
+
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm);
|
| 2702 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 2703 |
+
}
|
| 2704 |
+
|
| 2705 |
static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
| 2706 |
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 2707 |
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
|
|
|
|
| 3416 |
case GGML_OP_RMS_NORM:
|
| 3417 |
ggml_sycl_rms_norm(ctx, dst);
|
| 3418 |
break;
|
| 3419 |
+
case GGML_OP_L2_NORM:
|
| 3420 |
+
ggml_sycl_l2_norm(ctx, dst);
|
| 3421 |
+
break;
|
| 3422 |
case GGML_OP_MUL_MAT:
|
| 3423 |
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
|
| 3424 |
return false;
|
|
|
|
| 3496 |
case GGML_OP_RWKV_WKV6:
|
| 3497 |
ggml_sycl_op_rwkv_wkv6(ctx, dst);
|
| 3498 |
break;
|
| 3499 |
+
case GGML_OP_RWKV_WKV7:
|
| 3500 |
+
ggml_sycl_op_rwkv_wkv7(ctx, dst);
|
| 3501 |
+
break;
|
| 3502 |
case GGML_OP_GATED_LINEAR_ATTN:
|
| 3503 |
ggml_sycl_op_gated_linear_attn(ctx, dst);
|
| 3504 |
break;
|
|
|
|
| 4024 |
return (op->src[0]->type == GGML_TYPE_F32);
|
| 4025 |
case GGML_OP_NORM:
|
| 4026 |
case GGML_OP_RMS_NORM:
|
| 4027 |
+
case GGML_OP_L2_NORM:
|
| 4028 |
case GGML_OP_GROUP_NORM:
|
| 4029 |
return ggml_is_contiguous(op->src[0]);
|
| 4030 |
case GGML_OP_SCALE:
|
|
|
|
| 4058 |
case GGML_OP_LEAKY_RELU:
|
| 4059 |
case GGML_OP_TIMESTEP_EMBEDDING:
|
| 4060 |
case GGML_OP_RWKV_WKV6:
|
| 4061 |
+
case GGML_OP_RWKV_WKV7:
|
| 4062 |
case GGML_OP_GATED_LINEAR_ATTN:
|
| 4063 |
return true;
|
| 4064 |
default:
|
ggml/src/ggml-sycl/norm.cpp
CHANGED
|
@@ -180,6 +180,50 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
|
|
| 180 |
}
|
| 181 |
}
|
| 182 |
|
|
|
|
|
| 183 |
static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
| 184 |
const int nrows, const float eps,
|
| 185 |
queue_ptr stream, int device) {
|
|
@@ -311,6 +355,48 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|
| 311 |
}
|
| 312 |
}
|
| 313 |
|
|
|
|
|
| 314 |
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
|
| 315 |
ggml_tensor* dst, const float* src0_dd,
|
| 316 |
const float* src1_dd, float* dst_dd,
|
|
@@ -376,3 +462,25 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
|
|
| 376 |
(void)dst;
|
| 377 |
(void)src1_dd;
|
| 378 |
}
|
|
|
|
|
| 180 |
}
|
| 181 |
}
|
| 182 |
|
| 183 |
+
static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
|
| 184 |
+
const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
|
| 185 |
+
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
|
| 186 |
+
item_ct1.get_local_id(1);
|
| 187 |
+
const int tid = item_ct1.get_local_id(2);
|
| 188 |
+
const int nthreads = item_ct1.get_local_range(2);
|
| 189 |
+
const int nwarps = nthreads / WARP_SIZE;
|
| 190 |
+
float tmp = 0.0f; // partial sum for thread in warp
|
| 191 |
+
|
| 192 |
+
for (int col = tid; col < ncols; col += block_size) {
|
| 193 |
+
const float xi = x[row * ncols + col];
|
| 194 |
+
tmp += xi * xi;
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
// sum up partial sums
|
| 198 |
+
tmp = warp_reduce_sum(tmp, item_ct1);
|
| 199 |
+
if (block_size > WARP_SIZE) {
|
| 200 |
+
|
| 201 |
+
int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
| 202 |
+
int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
| 203 |
+
if (lane_id == 0) {
|
| 204 |
+
s_sum[warp_id] = tmp;
|
| 205 |
+
}
|
| 206 |
+
/*
|
| 207 |
+
DPCT1118:3: SYCL group functions and algorithms must be encountered in
|
| 208 |
+
converged control flow. You may need to adjust the code.
|
| 209 |
+
*/
|
| 210 |
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
| 211 |
+
size_t nreduce = nwarps / WARP_SIZE;
|
| 212 |
+
tmp = 0.f;
|
| 213 |
+
for (size_t i = 0; i < nreduce; i += 1)
|
| 214 |
+
{
|
| 215 |
+
tmp += s_sum[lane_id + i * WARP_SIZE];
|
| 216 |
+
}
|
| 217 |
+
tmp = warp_reduce_sum(tmp, item_ct1);
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
|
| 221 |
+
|
| 222 |
+
for (int col = tid; col < ncols; col += block_size) {
|
| 223 |
+
dst[row * ncols + col] = scale * x[row * ncols + col];
|
| 224 |
+
}
|
| 225 |
+
}
|
| 226 |
+
|
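One detail worth flagging: this SYCL kernel clamps the squared sum with eps * eps (sycl::rsqrt(sycl::max(tmp, eps * eps))), while the Metal kernel_l2_norm earlier in this diff clamps with eps directly (1.0f/sqrt(max(sumf, args.eps))). A condensed scalar view of the SYCL variant's scale, offered as a sketch for comparison only:

#include <algorithm>
#include <cmath>

// Scale factor as computed by l2_norm_f32 above; note the eps*eps floor on the
// squared sum, unlike the Metal kernel which floors the squared sum with eps.
static float l2_scale_sycl_style(float sum_of_squares, float eps) {
    return 1.0f / std::sqrt(std::max(sum_of_squares, eps * eps));
}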
| 227 |
static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
| 228 |
const int nrows, const float eps,
|
| 229 |
queue_ptr stream, int device) {
|
|
|
|
| 355 |
}
|
| 356 |
}
|
| 357 |
|
| 358 |
+
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
| 359 |
+
const int nrows, const float eps,
|
| 360 |
+
queue_ptr stream, int device) {
|
| 361 |
+
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
| 362 |
+
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
| 363 |
+
if (ncols < 1024) {
|
| 364 |
+
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
| 365 |
+
stream->submit([&](sycl::handler& cgh) {
|
| 366 |
+
cgh.parallel_for(
|
| 367 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
| 368 |
+
block_dims),
|
| 369 |
+
[=](sycl::nd_item<3> item_ct1)
|
| 370 |
+
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
| 371 |
+
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
| 372 |
+
nullptr, WARP_SIZE);
|
| 373 |
+
});
|
| 374 |
+
});
|
| 375 |
+
}
|
| 376 |
+
else {
|
| 377 |
+
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
| 378 |
+
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
|
| 379 |
+
const sycl::range<3> block_dims(1, 1, work_group_size);
|
| 380 |
+
/*
|
| 381 |
+
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
|
| 382 |
+
the limit. To get the device limit, query
|
| 383 |
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
| 384 |
+
*/
|
| 385 |
+
stream->submit([&](sycl::handler& cgh) {
|
| 386 |
+
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
|
| 387 |
+
cgh);
|
| 388 |
+
cgh.parallel_for(
|
| 389 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
| 390 |
+
block_dims),
|
| 391 |
+
[=](sycl::nd_item<3> item_ct1)
|
| 392 |
+
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
| 393 |
+
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
| 394 |
+
get_pointer(s_sum_acc_ct1), work_group_size);
|
| 395 |
+
});
|
| 396 |
+
});
|
| 397 |
+
}
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
|
| 401 |
ggml_tensor* dst, const float* src0_dd,
|
| 402 |
const float* src1_dd, float* dst_dd,
|
|
|
|
| 462 |
(void)dst;
|
| 463 |
(void)src1_dd;
|
| 464 |
}
|
| 465 |
+
|
| 466 |
+
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
| 467 |
+
const ggml_tensor* src1, ggml_tensor* dst,
|
| 468 |
+
const float* src0_dd, const float* src1_dd,
|
| 469 |
+
float* dst_dd,
|
| 470 |
+
const queue_ptr& main_stream) {
|
| 471 |
+
|
| 472 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 473 |
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
| 474 |
+
|
| 475 |
+
const int64_t ne00 = src0->ne[0];
|
| 476 |
+
const int64_t nrows = ggml_nrows(src0);
|
| 477 |
+
|
| 478 |
+
float eps;
|
| 479 |
+
memcpy(&eps, dst->op_params, sizeof(float));
|
| 480 |
+
|
| 481 |
+
l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
|
| 482 |
+
|
| 483 |
+
(void)src1;
|
| 484 |
+
(void)dst;
|
| 485 |
+
(void)src1_dd;
|
| 486 |
+
}
|
ggml/src/ggml-sycl/norm.hpp
CHANGED
|
@@ -32,4 +32,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
|
|
| 32 |
float* dst_dd,
|
| 33 |
const queue_ptr& main_stream);
|
| 34 |
|
|
|
|
|
| 35 |
#endif // GGML_SYCL_NORM_HPP
|
|
|
|
| 32 |
float* dst_dd,
|
| 33 |
const queue_ptr& main_stream);
|
| 34 |
|
| 35 |
+
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
| 36 |
+
const ggml_tensor* src1, ggml_tensor* dst,
|
| 37 |
+
const float* src0_dd, const float* src1_dd,
|
| 38 |
+
float* dst_dd,
|
| 39 |
+
const queue_ptr& main_stream);
|
| 40 |
+
|
| 41 |
#endif // GGML_SYCL_NORM_HPP
|
ggml/src/ggml-sycl/wkv.cpp
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
| 1 |
+
#include <sycl/sycl.hpp>
|
| 2 |
+
#include "wkv.hpp"
|
| 3 |
+
|
| 4 |
+
constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
|
| 5 |
+
|
| 6 |
+
// Helper function for the main kernel
|
| 7 |
+
template <int block_size>
|
| 8 |
+
static void rwkv_wkv6_f32_kernel(
|
| 9 |
+
const int B, const int T, const int C, const int H,
|
| 10 |
+
const float* k, const float* v, const float* r,
|
| 11 |
+
const float* tf, const float* td, const float* s,
|
| 12 |
+
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
| 13 |
+
|
| 14 |
+
const int tid = item_ct1.get_local_id(2);
|
| 15 |
+
const int bid = item_ct1.get_group(2);
|
| 16 |
+
|
| 17 |
+
const int head_size = block_size;
|
| 18 |
+
const int batch_i = bid / H;
|
| 19 |
+
const int head_i = bid % H;
|
| 20 |
+
const int state_size = C * head_size;
|
| 21 |
+
const int n_seq_tokens = T / B;
|
| 22 |
+
|
| 23 |
+
// Set up shared memory pointers
|
| 24 |
+
float* _k = shared_mem;
|
| 25 |
+
float* _r = _k + head_size;
|
| 26 |
+
float* _tf = _r + head_size;
|
| 27 |
+
float* _td = _tf + head_size;
|
| 28 |
+
|
| 29 |
+
// Local state array
|
| 30 |
+
float state[block_size];
|
| 31 |
+
|
| 32 |
+
// Load initial state
|
| 33 |
+
#pragma unroll
|
| 34 |
+
for (int i = 0; i < head_size; i++) {
|
| 35 |
+
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
// Sync threads before shared memory operations
|
| 39 |
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
| 40 |
+
|
| 41 |
+
// Load time-mixing parameters
|
| 42 |
+
_tf[tid] = tf[head_i * head_size + tid];
|
| 43 |
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
| 44 |
+
|
| 45 |
+
// Main sequence processing loop
|
| 46 |
+
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
| 47 |
+
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
| 48 |
+
t += C) {
|
| 49 |
+
|
| 50 |
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
| 51 |
+
|
| 52 |
+
// Load current timestep data to shared memory
|
| 53 |
+
_k[tid] = k[t];
|
| 54 |
+
_r[tid] = r[t];
|
| 55 |
+
_td[tid] = td[t];
|
| 56 |
+
|
| 57 |
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
| 58 |
+
|
| 59 |
+
const float _v = v[t];
|
| 60 |
+
float y = 0;
|
| 61 |
+
|
| 62 |
+
// Process in chunks of 4 for better vectorization
|
| 63 |
+
sycl::float4 k4, r4, tf4, td4, s4;
|
| 64 |
+
#pragma unroll
|
| 65 |
+
for (int j = 0; j < head_size; j += 4) {
|
| 66 |
+
// Load data in vec4 chunks
|
| 67 |
+
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
| 68 |
+
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
| 69 |
+
tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
| 70 |
+
td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
| 71 |
+
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
| 72 |
+
|
| 73 |
+
// Compute key-value product
|
| 74 |
+
sycl::float4 kv4 = k4 * _v;
|
| 75 |
+
|
| 76 |
+
// Accumulate weighted sum
|
| 77 |
+
y += sycl::dot(r4, tf4 * kv4 + s4);
|
| 78 |
+
|
| 79 |
+
// Update state
|
| 80 |
+
s4 = s4 * td4 + kv4;
|
| 81 |
+
|
| 82 |
+
// Store updated state
|
| 83 |
+
state[j] = s4.x();
|
| 84 |
+
state[j+1] = s4.y();
|
| 85 |
+
state[j+2] = s4.z();
|
| 86 |
+
state[j+3] = s4.w();
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
dst[t] = y;
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
// Save final state
|
| 93 |
+
#pragma unroll
|
| 94 |
+
for (int i = 0; i < head_size; i++) {
|
| 95 |
+
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
| 96 |
+
}
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
template <int block_size>
|
| 100 |
+
static void rwkv_wkv7_f32_kernel(
|
| 101 |
+
const int B, const int T, const int C, const int H,
|
| 102 |
+
const float* r, const float* w, const float* k, const float* v,
|
| 103 |
+
const float* a, const float* b, const float* s,
|
| 104 |
+
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
| 105 |
+
|
| 106 |
+
const int tid = item_ct1.get_local_id(2);
|
| 107 |
+
const int bid = item_ct1.get_group(2);
|
| 108 |
+
|
| 109 |
+
const int head_size = block_size;
|
| 110 |
+
const int batch_i = bid / H;
|
| 111 |
+
const int head_i = bid % H;
|
| 112 |
+
const int state_size = C * head_size;
|
| 113 |
+
const int n_seq_tokens = T / B;
|
| 114 |
+
|
| 115 |
+
float* _r = shared_mem;
|
| 116 |
+
float* _w = _r + head_size;
|
| 117 |
+
float* _k = _w + head_size;
|
| 118 |
+
float* _a = _k + head_size;
|
| 119 |
+
float* _b = _a + head_size;
|
| 120 |
+
|
| 121 |
+
float state[block_size];
|
| 122 |
+
|
| 123 |
+
#pragma unroll
|
| 124 |
+
for (int i = 0; i < head_size; i++) {
|
| 125 |
+
state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
| 129 |
+
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
| 130 |
+
t += C) {
|
| 131 |
+
|
| 132 |
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
| 133 |
+
|
| 134 |
+
_r[tid] = r[t];
|
| 135 |
+
_w[tid] = w[t];
|
| 136 |
+
_k[tid] = k[t];
|
| 137 |
+
_a[tid] = a[t];
|
| 138 |
+
_b[tid] = b[t];
|
| 139 |
+
|
| 140 |
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
| 141 |
+
|
| 142 |
+
const float _v = v[t];
|
| 143 |
+
float y = 0, sa = 0;
|
| 144 |
+
sycl::float4 a4, s4;
|
| 145 |
+
|
| 146 |
+
#pragma unroll
|
| 147 |
+
for (int j = 0; j < head_size; j += 4) {
|
| 148 |
+
a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
|
| 149 |
+
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
| 150 |
+
sa += sycl::dot(a4, s4);
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
sycl::float4 r4, w4, k4, b4;
|
| 154 |
+
#pragma unroll
|
| 155 |
+
for (int j = 0; j < head_size; j += 4) {
|
| 156 |
+
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
| 157 |
+
w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
|
| 158 |
+
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
| 159 |
+
b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
|
| 160 |
+
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
| 161 |
+
|
| 162 |
+
sycl::float4 kv4 = k4 * _v;
|
| 163 |
+
|
| 164 |
+
s4 = s4 * w4 + kv4 + sa * b4;
|
| 165 |
+
y += sycl::dot(r4, s4);
|
| 166 |
+
|
| 167 |
+
state[j] = s4.x();
|
| 168 |
+
state[j+1] = s4.y();
|
| 169 |
+
state[j+2] = s4.z();
|
| 170 |
+
state[j+3] = s4.w();
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
dst[t] = y;
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
#pragma unroll
|
| 177 |
+
for (int i = 0; i < head_size; i++) {
|
| 178 |
+
dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
|
| 179 |
+
}
|
| 180 |
+
}
|
| 181 |
+
|
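Both SYCL kernels above walk the token dimension with a flattened index t that starts at batch_i * n_seq_tokens * C + head_i * head_size + tid and advances by C per token. Written out, the mapping from (batch, token, head, channel) to that index is the helper below; this is an illustrative sketch of the layout implied by the loops, not code from the patch:

// Tensors are laid out with the channel dimension (C = H * head_size) fastest,
// then tokens, then batches; one work-item handles one channel of one head.
static inline int wkv_flat_index(int batch, int token, int head, int channel,
                                 int n_seq_tokens, int C, int head_size) {
    return (batch * n_seq_tokens + token) * C + head * head_size + channel;
}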
| 182 |
+
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
| 183 |
+
|
| 184 |
+
const ggml_tensor *src0 = dst->src[0];
|
| 185 |
+
const ggml_tensor *src1 = dst->src[1];
|
| 186 |
+
|
| 187 |
+
const float* k_d = (const float*)dst->src[0]->data;
|
| 188 |
+
const float* v_d = (const float*)dst->src[1]->data;
|
| 189 |
+
const float* r_d = (const float*)dst->src[2]->data;
|
| 190 |
+
const float* tf_d = (const float*)dst->src[3]->data;
|
| 191 |
+
const float* td_d = (const float*)dst->src[4]->data;
|
| 192 |
+
const float* s_d = (const float*)dst->src[5]->data;
|
| 193 |
+
float* dst_d = (float*)dst->data;
|
| 194 |
+
|
| 195 |
+
const int64_t B = dst->src[5]->ne[1];
|
| 196 |
+
const int64_t T = dst->src[0]->ne[2];
|
| 197 |
+
const int64_t C = dst->ne[0];
|
| 198 |
+
const int64_t H = dst->src[0]->ne[1];
|
| 199 |
+
|
| 200 |
+
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
| 201 |
+
GGML_ASSERT(C % H == 0);
|
| 202 |
+
GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
|
| 203 |
+
|
| 204 |
+
dpct::queue_ptr stream = ctx.stream();
|
| 205 |
+
|
| 206 |
+
// Calculate execution configuration
|
| 207 |
+
const size_t shared_mem_size = C / H * 4 * sizeof(float); // For k, r, tf, td
|
| 208 |
+
sycl::range<3> block_dims(1, 1, C / H);
|
| 209 |
+
sycl::range<3> grid_dims(1, 1, B * H);
|
| 210 |
+
|
| 211 |
+
// Submit kernel
|
| 212 |
+
if (C / H == WKV_BLOCK_SIZE) {
|
| 213 |
+
stream->submit([&](sycl::handler& cgh) {
|
| 214 |
+
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
| 215 |
+
|
| 216 |
+
cgh.parallel_for(
|
| 217 |
+
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
| 218 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 219 |
+
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
|
| 220 |
+
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
|
| 221 |
+
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
| 222 |
+
);
|
| 223 |
+
});
|
| 224 |
+
});
|
| 225 |
+
} else {
|
| 226 |
+
stream->submit([&](sycl::handler& cgh) {
|
| 227 |
+
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
| 228 |
+
|
| 229 |
+
cgh.parallel_for(
|
| 230 |
+
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
| 231 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 232 |
+
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
|
| 233 |
+
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
|
| 234 |
+
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
| 235 |
+
);
|
| 236 |
+
});
|
| 237 |
+
});
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
GGML_UNUSED(src0);
|
| 241 |
+
GGML_UNUSED(src1);
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
| 245 |
+
|
| 246 |
+
const ggml_tensor *src0 = dst->src[0];
|
| 247 |
+
const ggml_tensor *src1 = dst->src[1];
|
| 248 |
+
|
| 249 |
+
const float* r_d = (const float*)dst->src[0]->data;
|
| 250 |
+
const float* w_d = (const float*)dst->src[1]->data;
|
| 251 |
+
const float* k_d = (const float*)dst->src[2]->data;
|
| 252 |
+
const float* v_d = (const float*)dst->src[3]->data;
|
| 253 |
+
const float* a_d = (const float*)dst->src[4]->data;
|
| 254 |
+
const float* b_d = (const float*)dst->src[5]->data;
|
| 255 |
+
const float* s_d = (const float*)dst->src[6]->data;
|
| 256 |
+
float* dst_d = (float*)dst->data;
|
| 257 |
+
|
| 258 |
+
const int64_t B = dst->src[6]->ne[1];
|
| 259 |
+
const int64_t T = dst->src[0]->ne[2];
|
| 260 |
+
const int64_t C = dst->ne[0];
|
| 261 |
+
const int64_t H = dst->src[0]->ne[1];
|
| 262 |
+
|
| 263 |
+
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
|
| 264 |
+
GGML_ASSERT(C % H == 0);
|
| 265 |
+
GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2);
|
| 266 |
+
|
| 267 |
+
dpct::queue_ptr stream = ctx.stream();
|
| 268 |
+
|
| 269 |
+
// Calculate execution configuration
|
| 270 |
+
const size_t shared_mem_size = C / H * 5 * sizeof(float); // For r, w, k, a, b
|
| 271 |
+
sycl::range<3> block_dims(1, 1, C / H);
|
| 272 |
+
sycl::range<3> grid_dims(1, 1, B * H);
|
| 273 |
+
|
| 274 |
+
// Submit kernel
|
| 275 |
+
if (C / H == WKV_BLOCK_SIZE) {
|
| 276 |
+
stream->submit([&](sycl::handler& cgh) {
|
| 277 |
+
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
| 278 |
+
|
| 279 |
+
cgh.parallel_for(
|
| 280 |
+
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
| 281 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 282 |
+
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
|
| 283 |
+
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
|
| 284 |
+
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
| 285 |
+
);
|
| 286 |
+
});
|
| 287 |
+
});
|
| 288 |
+
} else {
|
| 289 |
+
stream->submit([&](sycl::handler& cgh) {
|
| 290 |
+
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
| 291 |
+
|
| 292 |
+
cgh.parallel_for(
|
| 293 |
+
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
| 294 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 295 |
+
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
|
| 296 |
+
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
|
| 297 |
+
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
| 298 |
+
);
|
| 299 |
+
});
|
| 300 |
+
});
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
GGML_UNUSED(src0);
|
| 304 |
+
GGML_UNUSED(src1);
|
| 305 |
+
}
|
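Both host wrappers above recover the launch shape the same way: B from the state tensor (src[5] for WKV6, src[6] for WKV7) at ne[1], T from src[0]->ne[2], C from dst->ne[0], H from src[0]->ne[1], and a head size of C / H that must equal WKV_BLOCK_SIZE (64) or twice that. A small standalone sketch of those constraints (illustrative names, not part of the patch):

#include <cassert>
#include <cstdint>

struct wkv_dims { int64_t B, T, C, H, head_size; };

// Mirrors the GGML_ASSERTs in ggml_sycl_op_rwkv_wkv6/7 above.
static wkv_dims wkv_check_dims(int64_t B, int64_t T, int64_t C, int64_t H) {
    assert(C % H == 0);
    const int64_t head_size = C / H;
    assert(head_size == 64 || head_size == 128);
    return { B, T, C, H, head_size };
}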
ggml/src/ggml-sycl/wkv.hpp
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
| 1 |
+
#ifndef GGML_SYCL_WKV_HPP
|
| 2 |
+
#define GGML_SYCL_WKV_HPP
|
| 3 |
+
|
| 4 |
+
#include "common.hpp"
|
| 5 |
+
|
| 6 |
+
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
| 7 |
+
|
| 8 |
+
void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
| 9 |
+
|
| 10 |
+
#endif // GGML_SYCL_WKV_HPP
|
ggml/src/ggml-vulkan/ggml-vulkan.cpp
CHANGED
|
@@ -304,6 +304,7 @@ struct vk_device_struct {
|
|
| 304 |
vk_pipeline pipeline_group_norm_f32;
|
| 305 |
vk_pipeline pipeline_rms_norm_f32;
|
| 306 |
vk_pipeline pipeline_rms_norm_back_f32;
|
|
|
|
| 307 |
vk_pipeline pipeline_gelu_f32;
|
| 308 |
vk_pipeline pipeline_gelu_quick_f32;
|
| 309 |
vk_pipeline pipeline_silu_f32;
|
|
@@ -328,6 +329,7 @@ struct vk_device_struct {
|
|
| 328 |
vk_pipeline pipeline_timestep_embedding_f32;
|
| 329 |
vk_pipeline pipeline_pool2d_f32;
|
| 330 |
vk_pipeline pipeline_rwkv_wkv6_f32;
|
|
|
|
| 331 |
vk_pipeline pipeline_opt_step_adamw_f32;
|
| 332 |
|
| 333 |
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
|
@@ -629,6 +631,13 @@ struct vk_op_rwkv_wkv6_push_constants {
|
|
| 629 |
uint32_t H;
|
| 630 |
};
|
| 631 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 632 |
// Allow pre-recording command buffers
|
| 633 |
struct vk_staging_memcpy {
|
| 634 |
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
|
|
@@ -2263,6 +2272,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
| 2263 |
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
| 2264 |
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
| 2265 |
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
|
|
| 2266 |
|
| 2267 |
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
| 2268 |
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
@@ -2374,6 +2384,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
| 2374 |
|
| 2375 |
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
| 2376 |
|
|
|
|
|
|
|
| 2377 |
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
| 2378 |
|
| 2379 |
for (auto &c : compiles) {
|
|
@@ -5473,6 +5485,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
| 5473 |
return ctx->device->pipeline_rms_norm_back_f32;
|
| 5474 |
}
|
| 5475 |
return nullptr;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5476 |
case GGML_OP_UNARY:
|
| 5477 |
switch (ggml_get_unary_op(dst)) {
|
| 5478 |
case GGML_UNARY_OP_SILU:
|
|
@@ -5612,6 +5629,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
| 5612 |
return ctx->device->pipeline_rwkv_wkv6_f32;
|
| 5613 |
}
|
| 5614 |
return nullptr;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5615 |
case GGML_OP_OPT_STEP_ADAMW:
|
| 5616 |
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
| 5617 |
return ctx->device->pipeline_opt_step_adamw_f32;
|
|
@@ -5859,6 +5881,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
| 5859 |
case GGML_OP_NORM:
|
| 5860 |
case GGML_OP_RMS_NORM:
|
| 5861 |
case GGML_OP_RMS_NORM_BACK:
|
|
|
|
| 5862 |
case GGML_OP_SOFT_MAX:
|
| 5863 |
case GGML_OP_SOFT_MAX_BACK:
|
| 5864 |
case GGML_OP_SUM_ROWS:
|
|
@@ -6108,23 +6131,17 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
| 6108 |
}, dryrun);
|
| 6109 |
}
|
| 6110 |
|
| 6111 |
-
static void
|
| 6112 |
-
|
| 6113 |
-
|
| 6114 |
-
|
| 6115 |
-
|
| 6116 |
-
|
| 6117 |
-
|
| 6118 |
-
|
| 6119 |
-
GGML_ASSERT(!ggml_is_quantized(k->type));
|
| 6120 |
-
GGML_ASSERT(!ggml_is_quantized(v->type));
|
| 6121 |
-
GGML_ASSERT(!ggml_is_quantized(r->type));
|
| 6122 |
-
GGML_ASSERT(!ggml_is_quantized(tf->type));
|
| 6123 |
-
GGML_ASSERT(!ggml_is_quantized(td->type));
|
| 6124 |
-
GGML_ASSERT(!ggml_is_quantized(state->type));
|
| 6125 |
GGML_ASSERT(dst->buffer != nullptr);
|
| 6126 |
|
| 6127 |
-
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx,
|
| 6128 |
GGML_ASSERT(pipeline != nullptr);
|
| 6129 |
|
| 6130 |
if (dryrun) {
|
|
@@ -6133,89 +6150,73 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
|
|
| 6133 |
}
|
| 6134 |
|
| 6135 |
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
| 6136 |
-
ggml_backend_vk_buffer_context *
|
| 6137 |
-
|
| 6138 |
-
|
| 6139 |
-
|
| 6140 |
-
ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
|
| 6141 |
-
ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
|
| 6142 |
|
| 6143 |
ggml_vk_sync_buffers(subctx);
|
| 6144 |
|
| 6145 |
-
vk_buffer d_D = nullptr,
|
| 6146 |
-
size_t
|
| 6147 |
-
bool
|
| 6148 |
|
| 6149 |
if (ctx->device->uma) {
|
| 6150 |
-
|
| 6151 |
-
|
| 6152 |
-
|
| 6153 |
-
|
| 6154 |
-
ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
|
| 6155 |
-
ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
|
| 6156 |
-
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
|
| 6157 |
|
| 6158 |
-
|
| 6159 |
-
|
| 6160 |
-
R_uma = d_R != nullptr;
|
| 6161 |
-
TF_uma = d_TF != nullptr;
|
| 6162 |
-
TD_uma = d_TD != nullptr;
|
| 6163 |
-
STATE_uma = d_State != nullptr;
|
| 6164 |
-
DST_uma = d_D != nullptr;
|
| 6165 |
}
|
| 6166 |
|
| 6167 |
-
|
| 6168 |
-
|
| 6169 |
-
|
| 6170 |
-
|
| 6171 |
-
|
| 6172 |
-
|
| 6173 |
-
|
| 6174 |
-
}
|
| 6175 |
-
if (!R_uma) {
|
| 6176 |
-
d_R = r_buf_ctx->dev_buffer;
|
| 6177 |
-
r_offset = vk_tensor_offset(r) + r->view_offs;
|
| 6178 |
-
}
|
| 6179 |
-
if (!TF_uma) {
|
| 6180 |
-
d_TF = tf_buf_ctx->dev_buffer;
|
| 6181 |
-
tf_offset = vk_tensor_offset(tf) + tf->view_offs;
|
| 6182 |
-
}
|
| 6183 |
-
if (!TD_uma) {
|
| 6184 |
-
d_TD = td_buf_ctx->dev_buffer;
|
| 6185 |
-
td_offset = vk_tensor_offset(td) + td->view_offs;
|
| 6186 |
-
}
|
| 6187 |
-
if (!STATE_uma) {
|
| 6188 |
-
d_State = state_buf_ctx->dev_buffer;
|
| 6189 |
-
state_offset = vk_tensor_offset(state) + state->view_offs;
|
| 6190 |
}
|
| 6191 |
-
|
|
|
|
|
|
|
| 6192 |
d_D = dst_buf_ctx->dev_buffer;
|
| 6193 |
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
|
| 6194 |
}
|
| 6195 |
|
| 6196 |
-
const uint64_t k_size = ggml_nbytes(k);
|
| 6197 |
-
const uint64_t v_size = ggml_nbytes(v);
|
| 6198 |
-
const uint64_t r_size = ggml_nbytes(r);
|
| 6199 |
-
const uint64_t tf_size = ggml_nbytes(tf);
|
| 6200 |
-
const uint64_t td_size = ggml_nbytes(td);
|
| 6201 |
-
const uint64_t state_size = ggml_nbytes(state);
|
| 6202 |
-
const uint64_t dst_size = ggml_nbytes(dst);
|
| 6203 |
-
|
| 6204 |
std::array<uint32_t, 3> elements = {
|
| 6205 |
(uint32_t)(pc.B * pc.H),
|
| 6206 |
1,
|
| 6207 |
1
|
| 6208 |
};
|
| 6209 |
|
| 6210 |
-
|
| 6211 |
-
|
| 6212 |
-
|
| 6213 |
-
|
| 6214 |
-
|
| 6215 |
-
|
| 6216 |
-
|
| 6217 |
-
|
| 6218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6219 |
}
|
| 6220 |
|
| 6221 |
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
|
|
@@ -6224,7 +6225,26 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
| 6224 |
const size_t n_heads = dst->src[0]->ne[1];
|
| 6225 |
const size_t n_seqs = dst->src[5]->ne[1];
|
| 6226 |
|
| 6227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6228 |
ctx, subctx, dst,
|
| 6229 |
{
|
| 6230 |
(uint32_t)n_seqs,
|
|
@@ -6232,6 +6252,7 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
| 6232 |
(uint32_t)n_embed,
|
| 6233 |
(uint32_t)n_heads,
|
| 6234 |
},
|
|
|
|
| 6235 |
dryrun
|
| 6236 |
);
|
| 6237 |
}
|
|
@@ -6533,6 +6554,11 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
|
|
| 6533 |
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
|
| 6534 |
}
|
| 6535 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6536 |
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
| 6537 |
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
|
| 6538 |
}
|
|
@@ -7528,6 +7554,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
| 7528 |
case GGML_OP_GROUP_NORM:
|
| 7529 |
case GGML_OP_RMS_NORM:
|
| 7530 |
case GGML_OP_RMS_NORM_BACK:
|
|
|
|
| 7531 |
case GGML_OP_DIAG_MASK_INF:
|
| 7532 |
case GGML_OP_SOFT_MAX:
|
| 7533 |
case GGML_OP_SOFT_MAX_BACK:
|
|
@@ -7544,6 +7571,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
| 7544 |
case GGML_OP_TIMESTEP_EMBEDDING:
|
| 7545 |
case GGML_OP_POOL_2D:
|
| 7546 |
case GGML_OP_RWKV_WKV6:
|
|
|
|
| 7547 |
case GGML_OP_LEAKY_RELU:
|
| 7548 |
case GGML_OP_FLASH_ATTN_EXT:
|
| 7549 |
case GGML_OP_OPT_STEP_ADAMW:
|
|
@@ -7590,6 +7618,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
| 7590 |
case GGML_OP_GROUP_NORM:
|
| 7591 |
case GGML_OP_RMS_NORM:
|
| 7592 |
case GGML_OP_RMS_NORM_BACK:
|
|
|
|
| 7593 |
case GGML_OP_UNARY:
|
| 7594 |
case GGML_OP_DIAG_MASK_INF:
|
| 7595 |
case GGML_OP_SOFT_MAX:
|
|
@@ -7707,6 +7736,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
| 7707 |
case GGML_OP_RMS_NORM_BACK:
|
| 7708 |
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
|
| 7709 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7710 |
break;
|
| 7711 |
case GGML_OP_UNARY:
|
| 7712 |
switch (ggml_get_unary_op(node)) {
|
|
@@ -7797,6 +7830,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
| 7797 |
|
| 7798 |
break;
|
| 7799 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7800 |
case GGML_OP_OPT_STEP_ADAMW:
|
| 7801 |
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
|
| 7802 |
|
|
@@ -7870,6 +7908,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
| 7870 |
case GGML_OP_GROUP_NORM:
|
| 7871 |
case GGML_OP_RMS_NORM:
|
| 7872 |
case GGML_OP_RMS_NORM_BACK:
|
|
|
|
| 7873 |
case GGML_OP_DIAG_MASK_INF:
|
| 7874 |
case GGML_OP_SOFT_MAX:
|
| 7875 |
case GGML_OP_SOFT_MAX_BACK:
|
|
@@ -7889,6 +7928,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
| 7889 |
case GGML_OP_TIMESTEP_EMBEDDING:
|
| 7890 |
case GGML_OP_POOL_2D:
|
| 7891 |
case GGML_OP_RWKV_WKV6:
|
|
|
|
| 7892 |
case GGML_OP_LEAKY_RELU:
|
| 7893 |
case GGML_OP_REPEAT:
|
| 7894 |
case GGML_OP_REPEAT_BACK:
|
|
@@ -8806,6 +8846,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
| 8806 |
case GGML_OP_NORM:
|
| 8807 |
case GGML_OP_GROUP_NORM:
|
| 8808 |
case GGML_OP_RMS_NORM:
|
|
|
|
| 8809 |
return ggml_is_contiguous(op->src[0]);
|
| 8810 |
case GGML_OP_ADD:
|
| 8811 |
case GGML_OP_SUB:
|
|
@@ -8835,6 +8876,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
| 8835 |
case GGML_OP_TIMESTEP_EMBEDDING:
|
| 8836 |
case GGML_OP_POOL_2D:
|
| 8837 |
case GGML_OP_RWKV_WKV6:
|
|
|
|
| 8838 |
case GGML_OP_LEAKY_RELU:
|
| 8839 |
case GGML_OP_OPT_STEP_ADAMW:
|
| 8840 |
return true;
|
|
@@ -9219,6 +9261,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
| 9219 |
tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
|
| 9220 |
} else if (tensor->op == GGML_OP_SILU_BACK) {
|
| 9221 |
tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
|
|
|
|
|
|
|
|
|
|
| 9222 |
} else if (tensor->op == GGML_OP_SOFT_MAX) {
|
| 9223 |
if (src1 != nullptr) {
|
| 9224 |
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
|
|
@@ -9338,6 +9383,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
| 9338 |
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
|
| 9339 |
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
|
| 9340 |
src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
|
|
|
|
|
|
|
|
|
|
| 9341 |
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
|
| 9342 |
src_clone[0]->flags = src0->flags;
|
| 9343 |
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
|
|
|
|
| 304 |
vk_pipeline pipeline_group_norm_f32;
|
| 305 |
vk_pipeline pipeline_rms_norm_f32;
|
| 306 |
vk_pipeline pipeline_rms_norm_back_f32;
|
| 307 |
+
vk_pipeline pipeline_l2_norm_f32;
|
| 308 |
vk_pipeline pipeline_gelu_f32;
|
| 309 |
vk_pipeline pipeline_gelu_quick_f32;
|
| 310 |
vk_pipeline pipeline_silu_f32;
|
|
|
|
| 329 |
vk_pipeline pipeline_timestep_embedding_f32;
|
| 330 |
vk_pipeline pipeline_pool2d_f32;
|
| 331 |
vk_pipeline pipeline_rwkv_wkv6_f32;
|
| 332 |
+
vk_pipeline pipeline_rwkv_wkv7_f32;
|
| 333 |
vk_pipeline pipeline_opt_step_adamw_f32;
|
| 334 |
|
| 335 |
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
|
|
|
| 631 |
uint32_t H;
|
| 632 |
};
|
| 633 |
|
| 634 |
+
struct vk_op_rwkv_wkv7_push_constants {
|
| 635 |
+
uint32_t B;
|
| 636 |
+
uint32_t T;
|
| 637 |
+
uint32_t C;
|
| 638 |
+
uint32_t H;
|
| 639 |
+
};
|
| 640 |
+
|
| 641 |
// Allow pre-recording command buffers
|
| 642 |
struct vk_staging_memcpy {
|
| 643 |
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
|
|
|
|
| 2272 |
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
| 2273 |
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
| 2274 |
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
| 2275 |
+
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
| 2276 |
|
| 2277 |
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
| 2278 |
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
|
|
| 2384 |
|
| 2385 |
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
| 2386 |
|
| 2387 |
+
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
| 2388 |
+
|
| 2389 |
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
| 2390 |
|
| 2391 |
for (auto &c : compiles) {
|
|
|
|
| 5485 |
return ctx->device->pipeline_rms_norm_back_f32;
|
| 5486 |
}
|
| 5487 |
return nullptr;
|
| 5488 |
+
case GGML_OP_L2_NORM:
|
| 5489 |
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
| 5490 |
+
return ctx->device->pipeline_l2_norm_f32;
|
| 5491 |
+
}
|
| 5492 |
+
return nullptr;
|
| 5493 |
case GGML_OP_UNARY:
|
| 5494 |
switch (ggml_get_unary_op(dst)) {
|
| 5495 |
case GGML_UNARY_OP_SILU:
|
|
|
|
| 5629 |
return ctx->device->pipeline_rwkv_wkv6_f32;
|
| 5630 |
}
|
| 5631 |
return nullptr;
|
| 5632 |
+
case GGML_OP_RWKV_WKV7:
|
| 5633 |
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
| 5634 |
+
return ctx->device->pipeline_rwkv_wkv7_f32;
|
| 5635 |
+
}
|
| 5636 |
+
return nullptr;
|
| 5637 |
case GGML_OP_OPT_STEP_ADAMW:
|
| 5638 |
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
| 5639 |
return ctx->device->pipeline_opt_step_adamw_f32;
|
|
|
|
| 5881 |
case GGML_OP_NORM:
|
| 5882 |
case GGML_OP_RMS_NORM:
|
| 5883 |
case GGML_OP_RMS_NORM_BACK:
|
| 5884 |
+
case GGML_OP_L2_NORM:
|
| 5885 |
case GGML_OP_SOFT_MAX:
|
| 5886 |
case GGML_OP_SOFT_MAX_BACK:
|
| 5887 |
case GGML_OP_SUM_ROWS:
|
|
|
|
| 6131 |
}, dryrun);
|
| 6132 |
}
|
| 6133 |
|
| 6134 |
+
static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
|
| 6135 |
+
GGML_ASSERT(version == 6 || version == 7);
|
| 6136 |
+
int num_srcs = version == 6 ? 6 : 7;
|
| 6137 |
+
|
| 6138 |
+
for (int i = 0; i < num_srcs; i++) {
|
| 6139 |
+
GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
|
| 6140 |
+
}
|
| 6141 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6142 |
GGML_ASSERT(dst->buffer != nullptr);
|
| 6143 |
|
| 6144 |
+
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
|
| 6145 |
GGML_ASSERT(pipeline != nullptr);
|
| 6146 |
|
| 6147 |
if (dryrun) {
|
|
|
|
| 6150 |
}
|
| 6151 |
|
| 6152 |
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
| 6153 |
+
ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
|
| 6154 |
+
for (int i = 0; i < num_srcs; i++) {
|
| 6155 |
+
src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
|
| 6156 |
+
}
|
|
|
|
|
|
|
| 6157 |
|
| 6158 |
ggml_vk_sync_buffers(subctx);
|
| 6159 |
|
| 6160 |
+
vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
|
| 6161 |
+
size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
|
| 6162 |
+
bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
|
| 6163 |
|
| 6164 |
if (ctx->device->uma) {
|
| 6165 |
+
for (int i = 0; i < num_srcs; i++) {
|
| 6166 |
+
ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
|
| 6167 |
+
srcs_uma[i] = d_srcs[i] != nullptr;
|
| 6168 |
+
}
|
|
|
|
|
|
|
|
|
|
| 6169 |
|
| 6170 |
+
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
|
| 6171 |
+
dst_uma = d_D != nullptr;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6172 |
}
|
| 6173 |
|
| 6174 |
+
uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
|
| 6175 |
+
for (int i = 0; i < num_srcs; i++) {
|
| 6176 |
+
src_sizes[i] = ggml_nbytes(dst->src[i]);
|
| 6177 |
+
if (!srcs_uma[i]) {
|
| 6178 |
+
d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
|
| 6179 |
+
src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
|
| 6180 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6181 |
}
|
| 6182 |
+
|
| 6183 |
+
const uint64_t dst_size = ggml_nbytes(dst);
|
| 6184 |
+
if (!dst_uma) {
|
| 6185 |
d_D = dst_buf_ctx->dev_buffer;
|
| 6186 |
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
|
| 6187 |
}
|
| 6188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6189 |
std::array<uint32_t, 3> elements = {
|
| 6190 |
(uint32_t)(pc.B * pc.H),
|
| 6191 |
1,
|
| 6192 |
1
|
| 6193 |
};
|
| 6194 |
|
| 6195 |
+
if (version == 6) {
|
| 6196 |
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
|
| 6197 |
+
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
|
| 6198 |
+
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
|
| 6199 |
+
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
|
| 6200 |
+
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
|
| 6201 |
+
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
|
| 6202 |
+
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
|
| 6203 |
+
vk_subbuffer{ d_D, dst_offset, dst_size }
|
| 6204 |
+
}, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
|
| 6205 |
+
} else if (version == 7) {
|
| 6206 |
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
|
| 6207 |
+
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
|
| 6208 |
+
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
|
| 6209 |
+
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
|
| 6210 |
+
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
|
| 6211 |
+
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
|
| 6212 |
+
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
|
| 6213 |
+
vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
|
| 6214 |
+
vk_subbuffer{ d_D, dst_offset, dst_size }
|
| 6215 |
+
}, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
|
| 6216 |
+
} else {
|
| 6217 |
+
// shouldn't happen
|
| 6218 |
+
GGML_ASSERT(false);
|
| 6219 |
+
}
|
| 6220 |
}
|
| 6221 |
|
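The unified ggml_vk_op_f32_wkv helper above reuses a single push-constant argument for both versions: vk_op_rwkv_wkv6_push_constants and vk_op_rwkv_wkv7_push_constants carry the same four fields (B, T, C, H), so passing the wkv6 struct while sizing the version-7 update with sizeof(vk_op_rwkv_wkv7_push_constants) works only because the layouts match. A hedged sketch of a compile-time guard one could add (not part of the patch):

// Illustrative guard: the two push-constant structs must stay layout-compatible
// for the shared dispatch path above to remain correct.
static_assert(sizeof(vk_op_rwkv_wkv6_push_constants) == sizeof(vk_op_rwkv_wkv7_push_constants),
              "wkv6/wkv7 push constants must have identical size");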
| 6222 |
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
|
|
|
|
| 6225 |
const size_t n_heads = dst->src[0]->ne[1];
|
| 6226 |
const size_t n_seqs = dst->src[5]->ne[1];
|
| 6227 |
|
| 6228 |
+
ggml_vk_op_f32_wkv(
|
| 6229 |
+
ctx, subctx, dst,
|
| 6230 |
+
{
|
| 6231 |
+
(uint32_t)n_seqs,
|
| 6232 |
+
(uint32_t)seq_length,
|
| 6233 |
+
(uint32_t)n_embed,
|
| 6234 |
+
(uint32_t)n_heads,
|
| 6235 |
+
},
|
| 6236 |
+
6,
|
| 6237 |
+
dryrun
|
| 6238 |
+
);
|
| 6239 |
+
}
|
| 6240 |
+
|
| 6241 |
+
static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
|
| 6242 |
+
const size_t seq_length = dst->src[0]->ne[2];
|
| 6243 |
+
const size_t n_embed = dst->ne[0];
|
| 6244 |
+
const size_t n_heads = dst->src[0]->ne[1];
|
| 6245 |
+
const size_t n_seqs = dst->src[6]->ne[1];
|
| 6246 |
+
|
| 6247 |
+
ggml_vk_op_f32_wkv(
|
| 6248 |
ctx, subctx, dst,
|
| 6249 |
{
|
| 6250 |
(uint32_t)n_seqs,
|
|
|
|
| 6252 |
(uint32_t)n_embed,
|
| 6253 |
(uint32_t)n_heads,
|
| 6254 |
},
|
| 6255 |
+
7,
|
| 6256 |
dryrun
|
| 6257 |
);
|
| 6258 |
}
|
|
|
|
| 6554 |
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
|
| 6555 |
}
|
| 6556 |
|
| 6557 |
+
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
| 6558 |
+
float * op_params = (float *)dst->op_params;
|
| 6559 |
+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
|
| 6560 |
+
}
|
| 6561 |
+
|
| 6562 |
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
| 6563 |
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
|
| 6564 |
}
|
|
|
|
| 7554 |
case GGML_OP_GROUP_NORM:
|
| 7555 |
case GGML_OP_RMS_NORM:
|
| 7556 |
case GGML_OP_RMS_NORM_BACK:
|
| 7557 |
+
case GGML_OP_L2_NORM:
|
| 7558 |
case GGML_OP_DIAG_MASK_INF:
|
| 7559 |
case GGML_OP_SOFT_MAX:
|
| 7560 |
case GGML_OP_SOFT_MAX_BACK:
|
|
|
|
| 7571 |
case GGML_OP_TIMESTEP_EMBEDDING:
|
| 7572 |
case GGML_OP_POOL_2D:
|
| 7573 |
case GGML_OP_RWKV_WKV6:
|
| 7574 |
+
case GGML_OP_RWKV_WKV7:
|
| 7575 |
case GGML_OP_LEAKY_RELU:
|
| 7576 |
case GGML_OP_FLASH_ATTN_EXT:
|
| 7577 |
case GGML_OP_OPT_STEP_ADAMW:
|
|
|
|
| 7618 |
case GGML_OP_GROUP_NORM:
|
| 7619 |
case GGML_OP_RMS_NORM:
|
| 7620 |
case GGML_OP_RMS_NORM_BACK:
|
| 7621 |
+
case GGML_OP_L2_NORM:
|
| 7622 |
case GGML_OP_UNARY:
|
| 7623 |
case GGML_OP_DIAG_MASK_INF:
|
| 7624 |
case GGML_OP_SOFT_MAX:
|
|
|
|
| 7736 |
case GGML_OP_RMS_NORM_BACK:
|
| 7737 |
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
|
| 7738 |
|
| 7739 |
+
break;
|
| 7740 |
+
case GGML_OP_L2_NORM:
|
| 7741 |
+
ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun);
|
| 7742 |
+
|
| 7743 |
break;
|
| 7744 |
case GGML_OP_UNARY:
|
| 7745 |
switch (ggml_get_unary_op(node)) {
|
|
|
|
| 7830 |
|
| 7831 |
break;
|
| 7832 |
|
| 7833 |
+
case GGML_OP_RWKV_WKV7:
|
| 7834 |
+
ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun);
|
| 7835 |
+
|
| 7836 |
+
break;
|
| 7837 |
+
|
| 7838 |
case GGML_OP_OPT_STEP_ADAMW:
|
| 7839 |
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
|
| 7840 |
|
|
|
|
| 7908 |
case GGML_OP_GROUP_NORM:
|
| 7909 |
case GGML_OP_RMS_NORM:
|
| 7910 |
case GGML_OP_RMS_NORM_BACK:
|
| 7911 |
+
case GGML_OP_L2_NORM:
|
| 7912 |
case GGML_OP_DIAG_MASK_INF:
|
| 7913 |
case GGML_OP_SOFT_MAX:
|
| 7914 |
case GGML_OP_SOFT_MAX_BACK:
|
|
|
|
| 7928 |
case GGML_OP_TIMESTEP_EMBEDDING:
|
| 7929 |
case GGML_OP_POOL_2D:
|
| 7930 |
case GGML_OP_RWKV_WKV6:
|
| 7931 |
+
case GGML_OP_RWKV_WKV7:
|
| 7932 |
case GGML_OP_LEAKY_RELU:
|
| 7933 |
case GGML_OP_REPEAT:
|
| 7934 |
case GGML_OP_REPEAT_BACK:
|
|
|
|
| 8846 |
case GGML_OP_NORM:
|
| 8847 |
case GGML_OP_GROUP_NORM:
|
| 8848 |
case GGML_OP_RMS_NORM:
|
| 8849 |
+
case GGML_OP_L2_NORM:
|
| 8850 |
return ggml_is_contiguous(op->src[0]);
|
| 8851 |
case GGML_OP_ADD:
|
| 8852 |
case GGML_OP_SUB:
|
|
|
|
| 8876 |
case GGML_OP_TIMESTEP_EMBEDDING:
|
| 8877 |
case GGML_OP_POOL_2D:
|
| 8878 |
case GGML_OP_RWKV_WKV6:
|
| 8879 |
+
case GGML_OP_RWKV_WKV7:
|
| 8880 |
case GGML_OP_LEAKY_RELU:
|
| 8881 |
case GGML_OP_OPT_STEP_ADAMW:
|
| 8882 |
return true;
|
|
|
|
| 9261 |
tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
|
| 9262 |
} else if (tensor->op == GGML_OP_SILU_BACK) {
|
| 9263 |
tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
|
| 9264 |
+
} else if (tensor->op == GGML_OP_L2_NORM) {
|
| 9265 |
+
const float eps = ((float *) tensor->op_params)[0];
|
| 9266 |
+
tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
|
| 9267 |
} else if (tensor->op == GGML_OP_SOFT_MAX) {
|
| 9268 |
if (src1 != nullptr) {
|
| 9269 |
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
|
|
|
|
| 9383 |
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
|
| 9384 |
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
|
| 9385 |
src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
|
| 9386 |
+
} else if (tensor->op == GGML_OP_RWKV_WKV7) {
|
| 9387 |
+
tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
|
| 9388 |
+
src_clone[4], src_clone[5], src_clone[6]);
|
| 9389 |
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
|
| 9390 |
src_clone[0]->flags = src0->flags;
|
| 9391 |
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
|
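The dispatch above binds six source buffers for the WKV6 pipeline and seven for WKV7, plus the destination, and pushes the same four scalars either way. As a point of reference, here is a minimal C sketch of the push-constant layout the wrapper functions fill in; the real structs are named `vk_op_rwkv_wkv6_push_constants` / `vk_op_rwkv_wkv7_push_constants` in the backend, but the field names below are an assumption taken from the `wkv7.comp` parameter block further down, not copied from the source:

#include <stdint.h>

// Hypothetical mirror of the shader's push_constant block:
//   B = n_seqs, T = seq_length, C = n_embed, H = n_heads,
// i.e. the { n_seqs, seq_length, n_embed, n_heads } initializer passed by
// ggml_vk_rwkv_wkv6 / ggml_vk_rwkv_wkv7 above.
typedef struct {
    uint32_t B;
    uint32_t T;
    uint32_t C;
    uint32_t H;
} wkv_push_constants_sketch;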
ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp
ADDED
|
@@ -0,0 +1,41 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+shared FLOAT_TYPE sum[BLOCK_SIZE];
+
+void main() {
+    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+    const uint tid = gl_LocalInvocationID.x;
+
+    sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
+
+    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
+        sum[tid] += xi * xi;
+    }
+
+    // sum up partial sums and write back result
+    barrier();
+    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            sum[tid] += sum[tid + s];
+        }
+        barrier();
+    }
+
+    const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
+
+    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+        data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
+    }
+}
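Each workgroup of the shader above normalizes one row of `p.KX` elements: threads accumulate partial sums of squares in shared memory, tree-reduce them, and rescale the row by `inversesqrt(max(sum, eps))`, with `eps` passed in `p.param1`. A scalar C reference of the same per-row computation, included only as an illustration (the function name and signature are made up for this note):

#include <math.h>
#include <stddef.h>

// y[i] = x[i] / sqrt(max(sum_j x[j]^2, eps)) -- same formula as l2_norm.comp
static void l2_norm_row_ref(const float * x, float * y, size_t n, float eps) {
    float sum = 0.0f;
    for (size_t i = 0; i < n; i++) {
        sum += x[i] * x[i];
    }
    const float scale = 1.0f / sqrtf(fmaxf(sum, eps));
    for (size_t i = 0; i < n; i++) {
        y[i] = scale * x[i];
    }
}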
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
CHANGED
|
@@ -434,6 +434,7 @@ void process_shaders() {
    string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
    string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
    string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

    string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});

@@ -528,6 +529,8 @@ void process_shaders() {

    string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

+    string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+
    string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

    for (auto &c : compiles) {
ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp
ADDED
|
@@ -0,0 +1,91 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+#define BLOCK_SIZE 64
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout(push_constant) uniform Parameters {
+    uint B;
+    uint T;
+    uint C;
+    uint H;
+};
+
+layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; };
+layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; };
+layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; };
+layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; };
+layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; };
+layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; };
+layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; };
+layout(binding = 7) buffer DstBuf { A_TYPE dst[]; };
+
+shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE];
+
+void main() {
+    const uint head_size = BLOCK_SIZE;
+    const uint batch_id = gl_WorkGroupID.x / H;
+    const uint head_id = gl_WorkGroupID.x % H;
+    const uint tid = gl_LocalInvocationID.x;
+
+    const uint state_size = C * head_size;
+    const uint n_seq_tokens = T / B;
+
+    if (batch_id >= B || head_id >= H) {
+        return;
+    }
+
+    A_TYPE state[BLOCK_SIZE];
+    [[unroll]] for (uint i = 0; i < head_size; i++) {
+        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+                            + tid * head_size + i];
+    }
+
+    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+    for (uint t = start_t; t < end_t; t += C) {
+        barrier();
+        _r[tid] = r[t];
+        _w[tid] = w[t];
+        _k[tid] = k[t];
+        _a[tid] = a[t];
+        _b[tid] = b[t];
+        barrier();
+
+        A_TYPE sa = 0.0;
+        [[unroll]] for (uint j = 0; j < head_size; j += 4) {
+            vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
+            vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
+            sa += dot(s_vec, a_vec);
+        }
+
+        const A_TYPE v_val = v[t];
+        A_TYPE y = 0.0;
+
+        [[unroll]] for (uint j = 0; j < head_size; j += 4) {
+            vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
+            vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
+            vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            vec4 kv = k_vec * v_val;
+            s_vec = s_vec * w_vec + kv + sa * b_vec;
+            y += dot(r_vec, s_vec);
+
+            state[j] = s_vec.x;
+            state[j+1] = s_vec.y;
+            state[j+2] = s_vec.z;
+            state[j+3] = s_vec.w;
+        }
+
+        dst[t] = y;
+    }
+
+    [[unroll]] for (uint i = 0; i < head_size; i++) {
+        dst[T * C + batch_id * state_size + head_id * head_size * head_size
+            + tid * head_size + i] = state[i];
+    }
+}
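In the shader above, one workgroup handles one (sequence, head) pair and each of the 64 threads owns one row of the 64x64 head state. Per token it computes `sa = dot(state_row, a)`, updates `state_row = state_row * w + k * v_i + sa * b`, and writes `y_i = dot(r, state_row)`; after the last token the state rows are appended to `dst`. A scalar C sketch of the same per-token update for a single head, written only to illustrate the recurrence (names and memory layout are simplified assumptions, not the ggml tensor layout):

#include <stddef.h>

// One wkv7 token step for one head.
// S is the head state, laid out as S[i][j] with i the value/output index and
// j the key index; r, w, k, v, a, b are the per-token, per-head vectors of
// length N. Each Vulkan thread corresponds to one value of i here.
static void rwkv_wkv7_token_ref(int N, float * S,
                                const float * r, const float * w,
                                const float * k, const float * v,
                                const float * a, const float * b,
                                float * y) {
    for (int i = 0; i < N; i++) {
        float * s = S + (size_t) i * N;

        float sa = 0.0f;                  // sa = dot(state row, a)
        for (int j = 0; j < N; j++) {
            sa += s[j] * a[j];
        }

        float yi = 0.0f;
        for (int j = 0; j < N; j++) {
            s[j] = s[j] * w[j] + k[j] * v[i] + sa * b[j];   // state update
            yi  += r[j] * s[j];                             // output
        }
        y[i] = yi;
    }
}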
ggml/src/ggml.c
CHANGED
|
@@ -929,6 +929,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "RMS_NORM",
    "RMS_NORM_BACK",
    "GROUP_NORM",
+    "L2_NORM",

    "MUL_MAT",
    "MUL_MAT_ID",

@@ -977,6 +978,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "ADD_REL_POS",
    "RWKV_WKV6",
    "GATED_LINEAR_ATTN",
+    "RWKV_WKV7",

    "UNARY",

@@ -996,7 +998,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "OPT_STEP_ADAMW",
};

-static_assert(GGML_OP_COUNT ==
+static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "none",

@@ -1026,6 +1028,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "rms_norm(x)",
    "rms_norm_back(x)",
    "group_norm(x)",
+    "l2_norm(x)",

    "X*Y",
    "X[i]*Y",

@@ -1074,6 +1077,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "add_rel_pos(x)",
    "rwkv_wkv6(k, v, r, tf, td, s)",
    "gated_linear_attn(k, v, q, gate, s)",
+    "rwkv_wkv7(r, w, k, v, a, b, s)",

    "unary(x)",

@@ -1093,7 +1097,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "adamw(x)",
};

-static_assert(GGML_OP_COUNT ==
+static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -2686,6 +2690,37 @@ struct ggml_tensor * ggml_group_norm_inplace(
    return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
}

+// ggml_l2_norm
+
+static struct ggml_tensor * ggml_l2_norm_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        float eps,
+        bool inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params_f32(result, 0, eps);
+
+    result->op     = GGML_OP_L2_NORM;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_l2_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        float eps) {
+    return ggml_l2_norm_impl(ctx, a, eps, false);
+}
+
+struct ggml_tensor * ggml_l2_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        float eps) {
+    return ggml_l2_norm_impl(ctx, a, eps, true);
+}
+
// ggml_mul_mat

static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {

@@ -4720,6 +4755,54 @@ struct ggml_tensor * ggml_gated_linear_attn(
    return result;
}

+// ggml_rwkv_wkv7
+
+struct ggml_tensor * ggml_rwkv_wkv7(
+        struct ggml_context * ctx,
+        struct ggml_tensor * r,
+        struct ggml_tensor * w,
+        struct ggml_tensor * k,
+        struct ggml_tensor * v,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        struct ggml_tensor * state) {
+    GGML_ASSERT(ggml_is_contiguous(r));
+    GGML_ASSERT(ggml_is_contiguous(w));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_is_contiguous(b));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
+    const int64_t n_seqs = state->ne[1];
+    {
+        GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
+        GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
+        GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op     = GGML_OP_RWKV_WKV7;
+    result->src[0] = r;
+    result->src[1] = w;
+    result->src[2] = k;
+    result->src[3] = v;
+    result->src[4] = a;
+    result->src[5] = b;
+    result->src[6] = state;
+
+    return result;
+}
+
// ggml_unary

static struct ggml_tensor * ggml_unary_impl(
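With the constructors in place, the two ops compose like any other ggml graph nodes: ggml_l2_norm row-normalizes a tensor (RWKV v7 applies it to a key-derived vector per head), and ggml_rwkv_wkv7 returns the per-token outputs concatenated with the updated recurrent state. A minimal, illustrative C sketch of how a caller might split that result, assuming r/w/k/v/a/b are [S, H, n_tokens] and state is [S*S*H, n_seqs] as asserted above (the helper name, the eps value and the placement of the normalization are placeholders, not the llama.cpp graph code):

#include "ggml.h"

// Build (not compute) a small subgraph around the new ops.
static struct ggml_tensor * wkv7_example(struct ggml_context * ctx,
                                         struct ggml_tensor * r, struct ggml_tensor * w,
                                         struct ggml_tensor * k, struct ggml_tensor * v,
                                         struct ggml_tensor * a, struct ggml_tensor * b,
                                         struct ggml_tensor * state,
                                         struct ggml_tensor ** state_out) {
    const int64_t S = k->ne[0];      // head size
    const int64_t H = k->ne[1];      // number of heads
    const int64_t T = k->ne[2];      // tokens
    const int64_t N = state->ne[1];  // sequences

    // row-wise L2 normalization (illustrative placement)
    k = ggml_l2_norm(ctx, k, 1e-12f);

    // result is [S*H, T + S*N]: the first T rows are the outputs,
    // the remaining S*N rows hold the updated state
    struct ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, state);

    *state_out = ggml_view_1d(ctx, out, S * S * H * N,
                              (size_t) T * S * H * ggml_element_size(out));

    return ggml_view_2d(ctx, out, S * H, T, out->nb[1], 0);
}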
|