Spaces:
Sleeping
Sleeping
metal : improve FA + improve MoE (llama/12612)
Browse files
* ggml : FA with different K, V head sizes (CPU)
ggml-ci
* metal : add FA with HS=192
* metal : extend FA to support different K and V head sizes
ggml-ci
* metal : add FA vector kernels for heads K 192 and V 128
ggml-ci
* ggml : restrict op on other backends to equal head sizes
ggml-ci
* metal : optimize FA-vec kernel
ggml-ci
* metal : FA remove mq registers
* metal : improve MoE mul_mat_id condition
ggml-ci
* metal : fix comments + remove unnecessary addition
ggml-ci
* metal : avoid too much shared memory usage with mul_mat_id
ggml-ci
- ggml/include/ggml.h +5 -5
- ggml/src/ggml-cpu/ggml-cpu.c +25 -25
- ggml/src/ggml-cuda/ggml-cuda.cu +7 -0
- ggml/src/ggml-metal/ggml-metal-impl.h +6 -3
- ggml/src/ggml-metal/ggml-metal.m +0 -0
- ggml/src/ggml-metal/ggml-metal.metal +292 -232
- ggml/src/ggml-vulkan/ggml-vulkan.cpp +4 -0
- ggml/src/ggml.c +1 -1
ggml/include/ggml.h
CHANGED
|
@@ -1791,11 +1791,11 @@ extern "C" {
|
|
| 1791 |
|
| 1792 |
#define GGML_KQ_MASK_PAD 64
|
| 1793 |
|
| 1794 |
-
// q: [
|
| 1795 |
-
// k: [
|
| 1796 |
-
// v: [
|
| 1797 |
-
// mask: [n_kv,
|
| 1798 |
-
// res: [
|
| 1799 |
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
|
| 1800 |
struct ggml_context * ctx,
|
| 1801 |
struct ggml_tensor * q,
|
|
|
|
| 1791 |
|
| 1792 |
#define GGML_KQ_MASK_PAD 64
|
| 1793 |
|
| 1794 |
+
// q: [n_embd_k, n_batch, n_head, 1]
|
| 1795 |
+
// k: [n_embd_k, n_kv, n_head_kv, 1]
|
| 1796 |
+
// v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !!
|
| 1797 |
+
// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
|
| 1798 |
+
// res: [n_embd_v, n_head, n_batch, 1] !! permuted !!
|
| 1799 |
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
|
| 1800 |
struct ggml_context * ctx,
|
| 1801 |
struct ggml_tensor * q,
|
ggml/src/ggml-cpu/ggml-cpu.c
CHANGED
|
@@ -12238,10 +12238,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
| 12238 |
const int ith = params->ith;
|
| 12239 |
const int nth = params->nth;
|
| 12240 |
|
| 12241 |
-
const int64_t
|
| 12242 |
-
const int64_t
|
|
|
|
| 12243 |
|
| 12244 |
-
GGML_ASSERT(ne0 ==
|
| 12245 |
GGML_ASSERT(ne2 == N);
|
| 12246 |
|
| 12247 |
// input tensor rows must be contiguous
|
|
@@ -12249,12 +12250,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
| 12249 |
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
| 12250 |
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
| 12251 |
|
| 12252 |
-
GGML_ASSERT(neq0 ==
|
| 12253 |
-
GGML_ASSERT(nek0 ==
|
| 12254 |
-
GGML_ASSERT(nev0 ==
|
| 12255 |
|
| 12256 |
GGML_ASSERT(neq1 == N);
|
| 12257 |
-
GGML_ASSERT(nev0 == D);
|
| 12258 |
|
| 12259 |
// dst cannot be transposed or permuted
|
| 12260 |
GGML_ASSERT(nb0 == sizeof(float));
|
|
@@ -12320,15 +12320,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
| 12320 |
float S = 0.0f; // sum
|
| 12321 |
float M = -INFINITY; // maximum KQ value
|
| 12322 |
|
| 12323 |
-
float * VKQ32 = (float *) params->wdata + ith*(
|
| 12324 |
-
float * V32 = (VKQ32 + 1*
|
| 12325 |
-
ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*
|
| 12326 |
-
ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*
|
| 12327 |
|
| 12328 |
if (v->type == GGML_TYPE_F16) {
|
| 12329 |
-
memset(VKQ16, 0,
|
| 12330 |
} else {
|
| 12331 |
-
memset(VKQ32, 0,
|
| 12332 |
}
|
| 12333 |
|
| 12334 |
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
|
|
@@ -12342,7 +12342,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
| 12342 |
const int iv2 = iq2 / rv2;
|
| 12343 |
|
| 12344 |
const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
|
| 12345 |
-
q_to_vec_dot(pq, Q_q,
|
| 12346 |
|
| 12347 |
// online softmax / attention
|
| 12348 |
// loop over n_kv and n_head_kv
|
|
@@ -12356,7 +12356,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
| 12356 |
float s; // KQ value
|
| 12357 |
|
| 12358 |
const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
|
| 12359 |
-
kq_vec_dot(
|
| 12360 |
|
| 12361 |
s = s*scale; // scale KQ value
|
| 12362 |
|
|
@@ -12380,14 +12380,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
| 12380 |
ms = expf(Mold - M);
|
| 12381 |
|
| 12382 |
// V = V*expf(Mold - M)
|
| 12383 |
-
ggml_vec_scale_f16(
|
| 12384 |
} else {
|
| 12385 |
// no new maximum, ms == 1.0f, vs != 1.0f
|
| 12386 |
vs = expf(s - M);
|
| 12387 |
}
|
| 12388 |
|
| 12389 |
// V += v*expf(s - M)
|
| 12390 |
-
ggml_vec_mad_f16(
|
| 12391 |
} else {
|
| 12392 |
if (s > M) {
|
| 12393 |
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
|
|
@@ -12395,30 +12395,30 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
| 12395 |
ms = expf(Mold - M);
|
| 12396 |
|
| 12397 |
// V = V*expf(Mold - M)
|
| 12398 |
-
ggml_vec_scale_f32(
|
| 12399 |
} else {
|
| 12400 |
// no new maximum, ms == 1.0f, vs != 1.0f
|
| 12401 |
vs = expf(s - M);
|
| 12402 |
}
|
| 12403 |
|
| 12404 |
-
v_to_float(v_data, V32,
|
| 12405 |
|
| 12406 |
// V += v*expf(s - M)
|
| 12407 |
-
ggml_vec_mad_f32(
|
| 12408 |
}
|
| 12409 |
|
| 12410 |
S = S*ms + vs; // scale and increment sum with partial sum
|
| 12411 |
}
|
| 12412 |
|
| 12413 |
if (v->type == GGML_TYPE_F16) {
|
| 12414 |
-
for (int64_t d = 0; d <
|
| 12415 |
VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
|
| 12416 |
}
|
| 12417 |
}
|
| 12418 |
|
| 12419 |
// V /= S
|
| 12420 |
const float S_inv = 1.0f/S;
|
| 12421 |
-
ggml_vec_scale_f32(
|
| 12422 |
|
| 12423 |
// dst indices
|
| 12424 |
const int i1 = iq1;
|
|
@@ -15277,7 +15277,6 @@ struct ggml_cplan ggml_graph_plan(
|
|
| 15277 |
size_t cur = 0;
|
| 15278 |
|
| 15279 |
if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {
|
| 15280 |
-
|
| 15281 |
switch (node->op) {
|
| 15282 |
case GGML_OP_CPY:
|
| 15283 |
case GGML_OP_DUP:
|
|
@@ -15386,9 +15385,10 @@ struct ggml_cplan ggml_graph_plan(
|
|
| 15386 |
} break;
|
| 15387 |
case GGML_OP_FLASH_ATTN_EXT:
|
| 15388 |
{
|
| 15389 |
-
const int64_t
|
|
|
|
| 15390 |
|
| 15391 |
-
cur =
|
| 15392 |
} break;
|
| 15393 |
case GGML_OP_FLASH_ATTN_BACK:
|
| 15394 |
{
|
|
|
|
| 12238 |
const int ith = params->ith;
|
| 12239 |
const int nth = params->nth;
|
| 12240 |
|
| 12241 |
+
const int64_t DK = nek0;
|
| 12242 |
+
const int64_t DV = nev0;
|
| 12243 |
+
const int64_t N = neq1;
|
| 12244 |
|
| 12245 |
+
GGML_ASSERT(ne0 == DV);
|
| 12246 |
GGML_ASSERT(ne2 == N);
|
| 12247 |
|
| 12248 |
// input tensor rows must be contiguous
|
|
|
|
| 12250 |
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
| 12251 |
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
| 12252 |
|
| 12253 |
+
GGML_ASSERT(neq0 == DK);
|
| 12254 |
+
GGML_ASSERT(nek0 == DK);
|
| 12255 |
+
GGML_ASSERT(nev0 == DV);
|
| 12256 |
|
| 12257 |
GGML_ASSERT(neq1 == N);
|
|
|
|
| 12258 |
|
| 12259 |
// dst cannot be transposed or permuted
|
| 12260 |
GGML_ASSERT(nb0 == sizeof(float));
|
|
|
|
| 12320 |
float S = 0.0f; // sum
|
| 12321 |
float M = -INFINITY; // maximum KQ value
|
| 12322 |
|
| 12323 |
+
float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
|
| 12324 |
+
float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
|
| 12325 |
+
ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
|
| 12326 |
+
ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
|
| 12327 |
|
| 12328 |
if (v->type == GGML_TYPE_F16) {
|
| 12329 |
+
memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
|
| 12330 |
} else {
|
| 12331 |
+
memset(VKQ32, 0, DV*sizeof(float));
|
| 12332 |
}
|
| 12333 |
|
| 12334 |
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
|
|
|
|
| 12342 |
const int iv2 = iq2 / rv2;
|
| 12343 |
|
| 12344 |
const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
|
| 12345 |
+
q_to_vec_dot(pq, Q_q, DK);
|
| 12346 |
|
| 12347 |
// online softmax / attention
|
| 12348 |
// loop over n_kv and n_head_kv
|
|
|
|
| 12356 |
float s; // KQ value
|
| 12357 |
|
| 12358 |
const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
|
| 12359 |
+
kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
|
| 12360 |
|
| 12361 |
s = s*scale; // scale KQ value
|
| 12362 |
|
|
|
|
| 12380 |
ms = expf(Mold - M);
|
| 12381 |
|
| 12382 |
// V = V*expf(Mold - M)
|
| 12383 |
+
ggml_vec_scale_f16(DV, VKQ16, ms);
|
| 12384 |
} else {
|
| 12385 |
// no new maximum, ms == 1.0f, vs != 1.0f
|
| 12386 |
vs = expf(s - M);
|
| 12387 |
}
|
| 12388 |
|
| 12389 |
// V += v*expf(s - M)
|
| 12390 |
+
ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
|
| 12391 |
} else {
|
| 12392 |
if (s > M) {
|
| 12393 |
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
|
|
|
|
| 12395 |
ms = expf(Mold - M);
|
| 12396 |
|
| 12397 |
// V = V*expf(Mold - M)
|
| 12398 |
+
ggml_vec_scale_f32(DV, VKQ32, ms);
|
| 12399 |
} else {
|
| 12400 |
// no new maximum, ms == 1.0f, vs != 1.0f
|
| 12401 |
vs = expf(s - M);
|
| 12402 |
}
|
| 12403 |
|
| 12404 |
+
v_to_float(v_data, V32, DV);
|
| 12405 |
|
| 12406 |
// V += v*expf(s - M)
|
| 12407 |
+
ggml_vec_mad_f32(DV, VKQ32, V32, vs);
|
| 12408 |
}
|
| 12409 |
|
| 12410 |
S = S*ms + vs; // scale and increment sum with partial sum
|
| 12411 |
}
|
| 12412 |
|
| 12413 |
if (v->type == GGML_TYPE_F16) {
|
| 12414 |
+
for (int64_t d = 0; d < DV; ++d) {
|
| 12415 |
VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
|
| 12416 |
}
|
| 12417 |
}
|
| 12418 |
|
| 12419 |
// V /= S
|
| 12420 |
const float S_inv = 1.0f/S;
|
| 12421 |
+
ggml_vec_scale_f32(DV, VKQ32, S_inv);
|
| 12422 |
|
| 12423 |
// dst indices
|
| 12424 |
const int i1 = iq1;
|
|
|
|
| 15277 |
size_t cur = 0;
|
| 15278 |
|
| 15279 |
if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {
|
|
|
|
| 15280 |
switch (node->op) {
|
| 15281 |
case GGML_OP_CPY:
|
| 15282 |
case GGML_OP_DUP:
|
|
|
|
| 15385 |
} break;
|
| 15386 |
case GGML_OP_FLASH_ATTN_EXT:
|
| 15387 |
{
|
| 15388 |
+
const int64_t ne10 = node->src[1]->ne[0]; // DK
|
| 15389 |
+
const int64_t ne20 = node->src[2]->ne[0]; // DV
|
| 15390 |
|
| 15391 |
+
cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
|
| 15392 |
} break;
|
| 15393 |
case GGML_OP_FLASH_ATTN_BACK:
|
| 15394 |
{
|
ggml/src/ggml-cuda/ggml-cuda.cu
CHANGED
|
@@ -3232,6 +3232,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
| 3232 |
#ifndef FLASH_ATTN_AVAILABLE
|
| 3233 |
return false;
|
| 3234 |
#endif // FLASH_ATTN_AVAILABLE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3235 |
if (op->src[0]->ne[3] != 1) {
|
| 3236 |
return false;
|
| 3237 |
}
|
|
|
|
| 3232 |
#ifndef FLASH_ATTN_AVAILABLE
|
| 3233 |
return false;
|
| 3234 |
#endif // FLASH_ATTN_AVAILABLE
|
| 3235 |
+
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
| 3236 |
+
// different head sizes of K and V are not supported yet
|
| 3237 |
+
return false;
|
| 3238 |
+
}
|
| 3239 |
+
if (op->src[0]->ne[0] == 192) {
|
| 3240 |
+
return false;
|
| 3241 |
+
}
|
| 3242 |
if (op->src[0]->ne[3] != 1) {
|
| 3243 |
return false;
|
| 3244 |
}
|
ggml/src/ggml-metal/ggml-metal-impl.h
CHANGED
|
@@ -219,9 +219,12 @@ typedef struct {
|
|
| 219 |
int32_t ne11;
|
| 220 |
int32_t ne_12_2; // assume K and V are same shape
|
| 221 |
int32_t ne_12_3;
|
| 222 |
-
uint64_t
|
| 223 |
-
uint64_t
|
| 224 |
-
uint64_t
|
|
|
|
|
|
|
|
|
|
| 225 |
uint64_t nb31;
|
| 226 |
int32_t ne1;
|
| 227 |
int32_t ne2;
|
|
|
|
| 219 |
int32_t ne11;
|
| 220 |
int32_t ne_12_2; // assume K and V are same shape
|
| 221 |
int32_t ne_12_3;
|
| 222 |
+
uint64_t nb11;
|
| 223 |
+
uint64_t nb12;
|
| 224 |
+
uint64_t nb13;
|
| 225 |
+
uint64_t nb21;
|
| 226 |
+
uint64_t nb22;
|
| 227 |
+
uint64_t nb23;
|
| 228 |
uint64_t nb31;
|
| 229 |
int32_t ne1;
|
| 230 |
int32_t ne2;
|
ggml/src/ggml-metal/ggml-metal.m
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ggml/src/ggml-metal/ggml-metal.metal
CHANGED
|
@@ -48,7 +48,7 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
|
|
| 48 |
|
| 49 |
template <typename type4>
|
| 50 |
void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
|
| 51 |
-
reg = (type4)(*(src
|
| 52 |
}
|
| 53 |
|
| 54 |
#if defined(GGML_METAL_USE_BF16)
|
|
@@ -56,6 +56,11 @@ template <typename type4x4>
|
|
| 56 |
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
|
| 57 |
reg = (type4x4)(*src);
|
| 58 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
#endif
|
| 60 |
|
| 61 |
template <typename type4x4>
|
|
@@ -3100,7 +3105,8 @@ template<
|
|
| 3100 |
typename vd4x4_t, // key type in device memory
|
| 3101 |
short nl_v,
|
| 3102 |
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
|
| 3103 |
-
short
|
|
|
|
| 3104 |
short Q = 8, // queries per threadgroup
|
| 3105 |
short KV = 8, // key/value processed per each simdgroup
|
| 3106 |
short C = 32> // cache items per threadgroup
|
|
@@ -3122,20 +3128,23 @@ kernel void kernel_flash_attn_ext(
|
|
| 3122 |
const int iq2 = tgpig[1];
|
| 3123 |
const int iq1 = tgpig[0]*Q;
|
| 3124 |
|
| 3125 |
-
const short
|
| 3126 |
-
const short
|
| 3127 |
-
const short
|
|
|
|
|
|
|
|
|
|
| 3128 |
const short NW = N_SIMDWIDTH;
|
| 3129 |
const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
|
| 3130 |
|
| 3131 |
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
|
| 3132 |
-
const short T =
|
| 3133 |
|
| 3134 |
-
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*
|
| 3135 |
-
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*
|
| 3136 |
-
threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*
|
| 3137 |
-
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*
|
| 3138 |
-
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*
|
| 3139 |
|
| 3140 |
threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
|
| 3141 |
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
|
|
@@ -3144,23 +3153,23 @@ kernel void kernel_flash_attn_ext(
|
|
| 3144 |
threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
|
| 3145 |
|
| 3146 |
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
| 3147 |
-
o8x8_t lo[
|
| 3148 |
|
| 3149 |
// load heads from Q to shared memory
|
| 3150 |
for (short j = sgitg; j < Q; j += nsg) {
|
| 3151 |
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
|
| 3152 |
|
| 3153 |
-
for (short i = tiisg; i <
|
| 3154 |
if (iq1 + j < args.ne01) {
|
| 3155 |
-
sq4[j*
|
| 3156 |
} else {
|
| 3157 |
-
sq4[j*
|
| 3158 |
}
|
| 3159 |
}
|
| 3160 |
}
|
| 3161 |
|
| 3162 |
// zero out lo
|
| 3163 |
-
for (short i = 0; i <
|
| 3164 |
lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
|
| 3165 |
}
|
| 3166 |
|
|
@@ -3190,13 +3199,6 @@ kernel void kernel_flash_attn_ext(
|
|
| 3190 |
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
|
| 3191 |
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
|
| 3192 |
|
| 3193 |
-
// load the queries from shared memory into local memory
|
| 3194 |
-
q8x8_t mq[D8];
|
| 3195 |
-
|
| 3196 |
-
for (short i = 0; i < D8; ++i) {
|
| 3197 |
-
simdgroup_load(mq[i], sq + i*8, D);
|
| 3198 |
-
}
|
| 3199 |
-
|
| 3200 |
const bool has_mask = mask != q;
|
| 3201 |
|
| 3202 |
half slope = 1.0f;
|
|
@@ -3249,20 +3251,22 @@ kernel void kernel_flash_attn_ext(
|
|
| 3249 |
// this is compile-time check, so it does not have runtime overhead
|
| 3250 |
if (is_same<kd4x4_t, k4x4_t>::value) {
|
| 3251 |
// we can read directly from global memory
|
| 3252 |
-
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.
|
| 3253 |
|
| 3254 |
-
#pragma unroll(
|
| 3255 |
-
for (short i = 0; i <
|
| 3256 |
k8x8_t mk;
|
| 3257 |
-
simdgroup_load(mk, pk + i*8, args.
|
| 3258 |
|
| 3259 |
-
|
|
|
|
|
|
|
| 3260 |
}
|
| 3261 |
} else {
|
| 3262 |
-
for (short ii = 0; ii <
|
| 3263 |
-
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.
|
| 3264 |
|
| 3265 |
-
if (
|
| 3266 |
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
|
| 3267 |
{
|
| 3268 |
k4x4_t tmp;
|
|
@@ -3275,15 +3279,18 @@ kernel void kernel_flash_attn_ext(
|
|
| 3275 |
#pragma unroll(4)
|
| 3276 |
for (short k = 0; k < 4; ++k) {
|
| 3277 |
k8x8_t mk;
|
|
|
|
| 3278 |
|
| 3279 |
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
|
| 3280 |
-
|
|
|
|
| 3281 |
|
| 3282 |
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
|
| 3283 |
-
|
|
|
|
| 3284 |
}
|
| 3285 |
} else {
|
| 3286 |
-
if (ii + tx <
|
| 3287 |
k4x4_t tmp;
|
| 3288 |
deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
|
| 3289 |
sk4x4[4*ty + tx] = tmp;
|
|
@@ -3291,14 +3298,17 @@ kernel void kernel_flash_attn_ext(
|
|
| 3291 |
|
| 3292 |
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 3293 |
|
| 3294 |
-
for (short k = 0; k < 4 && ii + k <
|
| 3295 |
k8x8_t mk;
|
|
|
|
| 3296 |
|
| 3297 |
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
|
| 3298 |
-
|
|
|
|
| 3299 |
|
| 3300 |
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
|
| 3301 |
-
|
|
|
|
| 3302 |
}
|
| 3303 |
}
|
| 3304 |
}
|
|
@@ -3350,8 +3360,8 @@ kernel void kernel_flash_attn_ext(
|
|
| 3350 |
s8x8_t mm;
|
| 3351 |
simdgroup_load(mm, ss + 2*C, TS, 0, false);
|
| 3352 |
|
| 3353 |
-
#pragma unroll(
|
| 3354 |
-
for (short i = 0; i <
|
| 3355 |
simdgroup_multiply(lo[i], mm, lo[i]);
|
| 3356 |
}
|
| 3357 |
}
|
|
@@ -3364,20 +3374,20 @@ kernel void kernel_flash_attn_ext(
|
|
| 3364 |
|
| 3365 |
if (is_same<vd4x4_t, v4x4_t>::value) {
|
| 3366 |
// we can read directly from global memory
|
| 3367 |
-
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.
|
| 3368 |
|
| 3369 |
-
#pragma unroll(
|
| 3370 |
-
for (short i = 0; i <
|
| 3371 |
v8x8_t mv;
|
| 3372 |
-
simdgroup_load(mv, pv + i*8, args.
|
| 3373 |
|
| 3374 |
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
|
| 3375 |
}
|
| 3376 |
} else {
|
| 3377 |
-
for (short ii = 0; ii <
|
| 3378 |
-
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.
|
| 3379 |
|
| 3380 |
-
if (
|
| 3381 |
// no need for bound checks
|
| 3382 |
{
|
| 3383 |
v4x4_t tmp;
|
|
@@ -3398,7 +3408,7 @@ kernel void kernel_flash_attn_ext(
|
|
| 3398 |
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
| 3399 |
}
|
| 3400 |
} else {
|
| 3401 |
-
if (ii + tx <
|
| 3402 |
v4x4_t tmp;
|
| 3403 |
deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
|
| 3404 |
sv4x4[4*ty + tx] = tmp;
|
|
@@ -3406,7 +3416,7 @@ kernel void kernel_flash_attn_ext(
|
|
| 3406 |
|
| 3407 |
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 3408 |
|
| 3409 |
-
for (short k = 0; k < 4 && ii + k <
|
| 3410 |
v8x8_t mv;
|
| 3411 |
|
| 3412 |
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
|
|
@@ -3440,8 +3450,8 @@ kernel void kernel_flash_attn_ext(
|
|
| 3440 |
|
| 3441 |
// each simdgroup stores its output to shared memory, reusing sq
|
| 3442 |
if (sgitg == sg) {
|
| 3443 |
-
for (short i = 0; i <
|
| 3444 |
-
simdgroup_store(lo[i], so + i*8,
|
| 3445 |
}
|
| 3446 |
}
|
| 3447 |
|
|
@@ -3480,11 +3490,11 @@ kernel void kernel_flash_attn_ext(
|
|
| 3480 |
simdgroup_load(ms0, ss + 2*C, TS, 0, false);
|
| 3481 |
simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
|
| 3482 |
|
| 3483 |
-
#pragma unroll(
|
| 3484 |
-
for (short i = 0; i <
|
| 3485 |
o8x8_t t;
|
| 3486 |
|
| 3487 |
-
simdgroup_load (t, so + i*8,
|
| 3488 |
simdgroup_multiply(t, ms1, t);
|
| 3489 |
|
| 3490 |
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
|
|
@@ -3495,8 +3505,8 @@ kernel void kernel_flash_attn_ext(
|
|
| 3495 |
|
| 3496 |
// store result to shared memory (reuse sq)
|
| 3497 |
if (sgitg == 0) {
|
| 3498 |
-
for (short i = 0; i <
|
| 3499 |
-
simdgroup_store(lo[i], so + i*8,
|
| 3500 |
}
|
| 3501 |
}
|
| 3502 |
|
|
@@ -3507,8 +3517,8 @@ kernel void kernel_flash_attn_ext(
|
|
| 3507 |
for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
|
| 3508 |
const float S = ss[j*TS + 0];
|
| 3509 |
|
| 3510 |
-
for (short i = tiisg; i <
|
| 3511 |
-
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*
|
| 3512 |
}
|
| 3513 |
}
|
| 3514 |
}
|
|
@@ -3525,80 +3535,94 @@ kernel void kernel_flash_attn_ext(
|
|
| 3525 |
float, simdgroup_float8x8, \
|
| 3526 |
half, half4, simdgroup_half8x8
|
| 3527 |
|
| 3528 |
-
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
|
| 3529 |
|
| 3530 |
-
template [[host_name("kernel_flash_attn_ext_f16_h64" )]]
|
| 3531 |
-
template [[host_name("kernel_flash_attn_ext_f16_h80" )]]
|
| 3532 |
-
template [[host_name("kernel_flash_attn_ext_f16_h96" )]]
|
| 3533 |
-
template [[host_name("kernel_flash_attn_ext_f16_h112")]]
|
| 3534 |
-
template [[host_name("kernel_flash_attn_ext_f16_h128")]]
|
| 3535 |
-
template [[host_name("
|
|
|
|
|
|
|
| 3536 |
|
| 3537 |
#if defined(GGML_METAL_USE_BF16)
|
| 3538 |
-
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]]
|
| 3539 |
-
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]]
|
| 3540 |
-
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]]
|
| 3541 |
-
template [[host_name("kernel_flash_attn_ext_bf16_h112")]]
|
| 3542 |
-
template [[host_name("kernel_flash_attn_ext_bf16_h128")]]
|
| 3543 |
-
template [[host_name("
|
|
|
|
|
|
|
| 3544 |
#endif
|
| 3545 |
|
| 3546 |
-
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]]
|
| 3547 |
-
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]]
|
| 3548 |
-
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]]
|
| 3549 |
-
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]]
|
| 3550 |
-
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]]
|
| 3551 |
-
template [[host_name("
|
| 3552 |
-
|
| 3553 |
-
template [[host_name("
|
| 3554 |
-
|
| 3555 |
-
template [[host_name("
|
| 3556 |
-
template [[host_name("
|
| 3557 |
-
template [[host_name("
|
| 3558 |
-
template [[host_name("
|
| 3559 |
-
|
| 3560 |
-
template [[host_name("
|
| 3561 |
-
template [[host_name("
|
| 3562 |
-
template [[host_name("
|
| 3563 |
-
|
| 3564 |
-
template [[host_name("
|
| 3565 |
-
template [[host_name("
|
| 3566 |
-
|
| 3567 |
-
template [[host_name("
|
| 3568 |
-
template [[host_name("
|
| 3569 |
-
template [[host_name("
|
| 3570 |
-
template [[host_name("
|
| 3571 |
-
template [[host_name("
|
| 3572 |
-
|
| 3573 |
-
|
| 3574 |
-
template [[host_name("
|
| 3575 |
-
template [[host_name("
|
| 3576 |
-
template [[host_name("
|
| 3577 |
-
template [[host_name("
|
| 3578 |
-
template [[host_name("
|
| 3579 |
-
template [[host_name("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3580 |
|
| 3581 |
#undef FA_TYPES
|
| 3582 |
|
| 3583 |
template<
|
| 3584 |
-
typename q4_t,
|
| 3585 |
-
typename
|
| 3586 |
-
typename
|
| 3587 |
-
typename
|
| 3588 |
-
typename
|
| 3589 |
-
typename s_t, // soft-max types
|
| 3590 |
typename s4_t,
|
| 3591 |
-
typename
|
| 3592 |
-
typename
|
| 3593 |
-
typename kd4x4_t, // key type in device memory
|
| 3594 |
short nl_k,
|
| 3595 |
-
void (*
|
| 3596 |
-
typename
|
| 3597 |
short nl_v,
|
| 3598 |
-
void (*
|
| 3599 |
-
short
|
| 3600 |
-
short
|
| 3601 |
-
short
|
|
|
|
|
|
|
| 3602 |
kernel void kernel_flash_attn_ext_vec(
|
| 3603 |
constant ggml_metal_kargs_flash_attn_ext & args,
|
| 3604 |
device const char * q,
|
|
@@ -3617,29 +3641,28 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3617 |
const int iq2 = tgpig[1];
|
| 3618 |
const int iq1 = tgpig[0];
|
| 3619 |
|
| 3620 |
-
const short
|
| 3621 |
-
const short
|
| 3622 |
const short NW = N_SIMDWIDTH;
|
| 3623 |
-
const short NL = NW/
|
| 3624 |
-
const short SH = 2*C;
|
| 3625 |
|
| 3626 |
-
const short T =
|
| 3627 |
|
| 3628 |
-
//threadgroup q_t
|
| 3629 |
-
threadgroup q4_t
|
| 3630 |
-
threadgroup
|
| 3631 |
-
threadgroup
|
| 3632 |
-
threadgroup
|
| 3633 |
-
threadgroup
|
| 3634 |
-
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results
|
| 3635 |
|
| 3636 |
-
// store the result for all queries in local memory
|
| 3637 |
-
|
| 3638 |
|
| 3639 |
// load heads from Q to shared memory
|
| 3640 |
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
|
| 3641 |
|
| 3642 |
-
for (short i = tiisg; i <
|
| 3643 |
if (iq1 < args.ne01) {
|
| 3644 |
sq4[i] = (q4_t) q4[i];
|
| 3645 |
} else {
|
|
@@ -3648,8 +3671,8 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3648 |
}
|
| 3649 |
|
| 3650 |
// zero out lo
|
| 3651 |
-
for (short i = 0; i <
|
| 3652 |
-
lo[i] = (
|
| 3653 |
}
|
| 3654 |
|
| 3655 |
// zero out shared memory SH
|
|
@@ -3674,14 +3697,6 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3674 |
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
|
| 3675 |
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
|
| 3676 |
|
| 3677 |
-
// load the queries from shared memory into local memory
|
| 3678 |
-
q4x4_t mq[D16/NL];
|
| 3679 |
-
|
| 3680 |
-
#pragma unroll(D16/NL)
|
| 3681 |
-
for (short ii = 0; ii < D16; ii += NL) {
|
| 3682 |
-
mq[ii/NL] = sq4x4[ii + tx];
|
| 3683 |
-
}
|
| 3684 |
-
|
| 3685 |
const bool has_mask = mask != q;
|
| 3686 |
|
| 3687 |
// pointer to the mask
|
|
@@ -3713,43 +3728,56 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3713 |
|
| 3714 |
// Q*K^T
|
| 3715 |
{
|
| 3716 |
-
// each simdgroup processes 1 query and
|
| 3717 |
-
for (short cc = 0; cc < C/
|
| 3718 |
-
qk_t
|
| 3719 |
|
| 3720 |
-
device const
|
| 3721 |
|
| 3722 |
-
#pragma unroll(
|
| 3723 |
-
for (short ii = 0; ii <
|
| 3724 |
const short i = ii + tx;
|
| 3725 |
|
| 3726 |
-
|
| 3727 |
-
|
| 3728 |
|
| 3729 |
// note: this is less precise than the version below
|
| 3730 |
-
//mqka[0] += dot(mq[
|
| 3731 |
-
//mqka[1] += dot(mq[
|
| 3732 |
-
//mqka[2] += dot(mq[
|
| 3733 |
-
//mqka[3] += dot(mq[
|
| 3734 |
-
|
| 3735 |
-
|
| 3736 |
-
mqka[
|
| 3737 |
-
mqka[
|
| 3738 |
-
mqka[
|
|
|
|
|
|
|
|
|
|
| 3739 |
}
|
| 3740 |
|
| 3741 |
-
|
| 3742 |
|
| 3743 |
-
// simdgroup reduce
|
| 3744 |
// [ 0 .. 7] -> [ 0]
|
| 3745 |
// [ 8 .. 15] -> [ 8]
|
| 3746 |
// [16 .. 23] -> [16]
|
| 3747 |
// [24 .. 31] -> [24]
|
| 3748 |
-
|
| 3749 |
-
|
| 3750 |
-
|
| 3751 |
-
|
| 3752 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3753 |
|
| 3754 |
// mqk = mqk*scale + mask*slope
|
| 3755 |
if (tx == 0) {
|
|
@@ -3759,9 +3787,9 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3759 |
mqk = args.logit_softcap*precise::tanh(mqk);
|
| 3760 |
}
|
| 3761 |
|
| 3762 |
-
mqk += sm[
|
| 3763 |
|
| 3764 |
-
ss[
|
| 3765 |
}
|
| 3766 |
}
|
| 3767 |
}
|
|
@@ -3784,8 +3812,8 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3784 |
ss[tiisg] = vs;
|
| 3785 |
|
| 3786 |
// O = diag(ms)*O
|
| 3787 |
-
#pragma unroll(
|
| 3788 |
-
for (short ii = 0; ii <
|
| 3789 |
lo[ii/NL] *= ms;
|
| 3790 |
}
|
| 3791 |
}
|
|
@@ -3794,17 +3822,18 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3794 |
|
| 3795 |
// O = O + (Q*K^T)*V
|
| 3796 |
{
|
| 3797 |
-
|
| 3798 |
-
|
|
|
|
| 3799 |
|
| 3800 |
-
const
|
| 3801 |
|
| 3802 |
-
#pragma unroll(
|
| 3803 |
-
for (short ii = 0; ii <
|
| 3804 |
const short i = ii + tx;
|
| 3805 |
|
| 3806 |
-
|
| 3807 |
-
|
| 3808 |
|
| 3809 |
lo[ii/NL] += mv*ms;
|
| 3810 |
}
|
|
@@ -3819,7 +3848,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3819 |
}
|
| 3820 |
}
|
| 3821 |
|
| 3822 |
-
// simdgroup reduce
|
| 3823 |
// [ 0, 8, 16, 24] -> [ 0]
|
| 3824 |
// [ 1, 9, 17, 25] -> [ 1]
|
| 3825 |
// [ 2, 10, 18, 26] -> [ 2]
|
|
@@ -3828,37 +3857,48 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3828 |
// [ 5, 13, 21, 29] -> [ 5]
|
| 3829 |
// [ 6, 14, 22, 30] -> [ 6]
|
| 3830 |
// [ 7, 15, 23, 31] -> [ 7]
|
| 3831 |
-
for (short ii = 0; ii <
|
| 3832 |
-
|
| 3833 |
-
|
| 3834 |
-
|
| 3835 |
-
|
| 3836 |
-
|
| 3837 |
-
|
| 3838 |
-
|
| 3839 |
-
|
| 3840 |
-
|
| 3841 |
-
|
| 3842 |
-
|
| 3843 |
-
|
| 3844 |
-
|
| 3845 |
-
|
| 3846 |
-
|
| 3847 |
-
|
| 3848 |
-
|
| 3849 |
-
|
| 3850 |
-
|
| 3851 |
-
|
| 3852 |
-
|
| 3853 |
-
|
| 3854 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3855 |
}
|
| 3856 |
|
| 3857 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 3858 |
|
| 3859 |
// store results to shared memory
|
| 3860 |
-
for (short i = tiisg; i <
|
| 3861 |
-
|
| 3862 |
}
|
| 3863 |
|
| 3864 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
@@ -3885,22 +3925,22 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3885 |
}
|
| 3886 |
|
| 3887 |
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
| 3888 |
-
for (short i = tiisg; i <
|
| 3889 |
-
|
| 3890 |
}
|
| 3891 |
}
|
| 3892 |
|
| 3893 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 3894 |
}
|
| 3895 |
|
| 3896 |
-
device
|
| 3897 |
|
| 3898 |
// final rescale with 1/S and store to global memory
|
| 3899 |
if (sgitg == 0) {
|
| 3900 |
const float S = ss[0];
|
| 3901 |
|
| 3902 |
-
for (short i = tiisg; i <
|
| 3903 |
-
|
| 3904 |
}
|
| 3905 |
}
|
| 3906 |
}
|
|
@@ -3909,34 +3949,54 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3909 |
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
|
| 3910 |
//
|
| 3911 |
#define FA_TYPES \
|
| 3912 |
-
half4,
|
| 3913 |
-
|
| 3914 |
-
|
| 3915 |
-
float,
|
| 3916 |
-
half, half4,
|
| 3917 |
-
|
| 3918 |
|
| 3919 |
-
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3920 |
|
| 3921 |
-
template [[host_name("
|
| 3922 |
#if defined(GGML_METAL_USE_BF16)
|
| 3923 |
-
template [[host_name("
|
| 3924 |
#endif
|
| 3925 |
-
template [[host_name("
|
| 3926 |
-
template [[host_name("
|
| 3927 |
-
template [[host_name("
|
| 3928 |
-
template [[host_name("
|
| 3929 |
-
template [[host_name("
|
| 3930 |
|
| 3931 |
-
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,
|
| 3932 |
#if defined(GGML_METAL_USE_BF16)
|
| 3933 |
-
template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,
|
| 3934 |
#endif
|
| 3935 |
-
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0,
|
| 3936 |
-
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1,
|
| 3937 |
-
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0,
|
| 3938 |
-
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1,
|
| 3939 |
-
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0,
|
| 3940 |
|
| 3941 |
#undef FA_TYPES
|
| 3942 |
|
|
|
|
| 48 |
|
| 49 |
template <typename type4>
|
| 50 |
void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
|
| 51 |
+
reg = (type4)(*(src));
|
| 52 |
}
|
| 53 |
|
| 54 |
#if defined(GGML_METAL_USE_BF16)
|
|
|
|
| 56 |
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
|
| 57 |
reg = (type4x4)(*src);
|
| 58 |
}
|
| 59 |
+
|
| 60 |
+
template <typename type4>
|
| 61 |
+
void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
|
| 62 |
+
reg = (type4)(*(src));
|
| 63 |
+
}
|
| 64 |
#endif
|
| 65 |
|
| 66 |
template <typename type4x4>
|
|
|
|
| 3105 |
typename vd4x4_t, // value type in device memory
|
| 3106 |
short nl_v,
|
| 3107 |
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
|
| 3108 |
+
short DK, // K head size
|
| 3109 |
+
short DV, // V head size
|
| 3110 |
short Q = 8, // queries per threadgroup
|
| 3111 |
short KV = 8, // key/value processed per each simdgroup
|
| 3112 |
short C = 32> // cache items per threadgroup
|
|
|
|
| 3128 |
const int iq2 = tgpig[1];
|
| 3129 |
const int iq1 = tgpig[0]*Q;
|
| 3130 |
|
| 3131 |
+
const short DK4 = DK/4;
|
| 3132 |
+
const short DK8 = DK/8;
|
| 3133 |
+
const short DK16 = DK/16;
|
| 3134 |
+
const short DV4 = DV/4;
|
| 3135 |
+
const short DV8 = DV/8;
|
| 3136 |
+
const short DV16 = DV/16;
|
| 3137 |
const short NW = N_SIMDWIDTH;
|
| 3138 |
const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
|
| 3139 |
|
| 3140 |
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
|
| 3141 |
+
const short T = DK + 2*TS; // shared memory size per query in (half)
|
| 3142 |
|
| 3143 |
+
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
|
| 3144 |
+
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
|
| 3145 |
+
threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
|
| 3146 |
+
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
|
| 3147 |
+
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
|
| 3148 |
|
| 3149 |
threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
|
| 3150 |
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
|
|
|
|
| 3153 |
threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
|
| 3154 |
|
| 3155 |
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
| 3156 |
+
o8x8_t lo[DV8];
|
| 3157 |
|
| 3158 |
// load heads from Q to shared memory
|
| 3159 |
for (short j = sgitg; j < Q; j += nsg) {
|
| 3160 |
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
|
| 3161 |
|
| 3162 |
+
for (short i = tiisg; i < DK4; i += NW) {
|
| 3163 |
if (iq1 + j < args.ne01) {
|
| 3164 |
+
sq4[j*DK4 + i] = (q4_t) q4[i];
|
| 3165 |
} else {
|
| 3166 |
+
sq4[j*DK4 + i] = (q4_t) 0.0f;
|
| 3167 |
}
|
| 3168 |
}
|
| 3169 |
}
|
| 3170 |
|
| 3171 |
// zero out lo
|
| 3172 |
+
for (short i = 0; i < DV8; ++i) {
|
| 3173 |
lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
|
| 3174 |
}
|
| 3175 |
|
|
|
|
| 3199 |
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
|
| 3200 |
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
|
| 3201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3202 |
const bool has_mask = mask != q;
|
| 3203 |
|
| 3204 |
half slope = 1.0f;
|
|
|
|
| 3251 |
// this is a compile-time check, so it does not have runtime overhead
|
| 3252 |
if (is_same<kd4x4_t, k4x4_t>::value) {
|
| 3253 |
// we can read directly from global memory
|
| 3254 |
+
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
|
| 3255 |
|
| 3256 |
+
#pragma unroll(DK8)
|
| 3257 |
+
for (short i = 0; i < DK8; ++i) {
|
| 3258 |
k8x8_t mk;
|
| 3259 |
+
simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
|
| 3260 |
|
| 3261 |
+
q8x8_t mq;
|
| 3262 |
+
simdgroup_load(mq, sq + i*8, DK);
|
| 3263 |
+
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
|
| 3264 |
}
|
| 3265 |
} else {
|
| 3266 |
+
for (short ii = 0; ii < DK16; ii += 4) {
|
| 3267 |
+
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
|
| 3268 |
|
| 3269 |
+
if (DK16%4 == 0) {
|
| 3270 |
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
|
| 3271 |
{
|
| 3272 |
k4x4_t tmp;
|
|
|
|
| 3279 |
#pragma unroll(4)
|
| 3280 |
for (short k = 0; k < 4; ++k) {
|
| 3281 |
k8x8_t mk;
|
| 3282 |
+
q8x8_t mq;
|
| 3283 |
|
| 3284 |
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
|
| 3285 |
+
simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
|
| 3286 |
+
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
|
| 3287 |
|
| 3288 |
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
|
| 3289 |
+
simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
|
| 3290 |
+
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
|
| 3291 |
}
|
| 3292 |
} else {
|
| 3293 |
+
if (ii + tx < DK16) {
|
| 3294 |
k4x4_t tmp;
|
| 3295 |
deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
|
| 3296 |
sk4x4[4*ty + tx] = tmp;
|
|
|
|
| 3298 |
|
| 3299 |
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 3300 |
|
| 3301 |
+
for (short k = 0; k < 4 && ii + k < DK16; ++k) {
|
| 3302 |
k8x8_t mk;
|
| 3303 |
+
q8x8_t mq;
|
| 3304 |
|
| 3305 |
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
|
| 3306 |
+
simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
|
| 3307 |
+
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
|
| 3308 |
|
| 3309 |
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
|
| 3310 |
+
simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
|
| 3311 |
+
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
|
| 3312 |
}
|
| 3313 |
}
|
| 3314 |
}
|
|
|
|
| 3360 |
s8x8_t mm;
|
| 3361 |
simdgroup_load(mm, ss + 2*C, TS, 0, false);
|
| 3362 |
|
| 3363 |
+
#pragma unroll(DV8)
|
| 3364 |
+
for (short i = 0; i < DV8; ++i) {
|
| 3365 |
simdgroup_multiply(lo[i], mm, lo[i]);
|
| 3366 |
}
|
| 3367 |
}
|
|
|
|
| 3374 |
|
| 3375 |
if (is_same<vd4x4_t, v4x4_t>::value) {
|
| 3376 |
// we can read directly from global memory
|
| 3377 |
+
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
|
| 3378 |
|
| 3379 |
+
#pragma unroll(DV8)
|
| 3380 |
+
for (short i = 0; i < DV8; ++i) {
|
| 3381 |
v8x8_t mv;
|
| 3382 |
+
simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
|
| 3383 |
|
| 3384 |
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
|
| 3385 |
}
|
| 3386 |
} else {
|
| 3387 |
+
for (short ii = 0; ii < DV16; ii += 4) {
|
| 3388 |
+
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
|
| 3389 |
|
| 3390 |
+
if (DV16%4 == 0) {
|
| 3391 |
// no need for bound checks
|
| 3392 |
{
|
| 3393 |
v4x4_t tmp;
|
|
|
|
| 3408 |
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
| 3409 |
}
|
| 3410 |
} else {
|
| 3411 |
+
if (ii + tx < DV16) {
|
| 3412 |
v4x4_t tmp;
|
| 3413 |
deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
|
| 3414 |
sv4x4[4*ty + tx] = tmp;
|
|
|
|
| 3416 |
|
| 3417 |
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 3418 |
|
| 3419 |
+
for (short k = 0; k < 4 && ii + k < DV16; ++k) {
|
| 3420 |
v8x8_t mv;
|
| 3421 |
|
| 3422 |
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
|
|
|
|
| 3450 |
|
| 3451 |
// each simdgroup stores its output to shared memory, reusing sq
|
| 3452 |
if (sgitg == sg) {
|
| 3453 |
+
for (short i = 0; i < DV8; ++i) {
|
| 3454 |
+
simdgroup_store(lo[i], so + i*8, DV, 0, false);
|
| 3455 |
}
|
| 3456 |
}
|
| 3457 |
|
|
|
|
| 3490 |
simdgroup_load(ms0, ss + 2*C, TS, 0, false);
|
| 3491 |
simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
|
| 3492 |
|
| 3493 |
+
#pragma unroll(DV8)
|
| 3494 |
+
for (short i = 0; i < DV8; ++i) {
|
| 3495 |
o8x8_t t;
|
| 3496 |
|
| 3497 |
+
simdgroup_load (t, so + i*8, DV, 0, false);
|
| 3498 |
simdgroup_multiply(t, ms1, t);
|
| 3499 |
|
| 3500 |
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
|
|
|
|
| 3505 |
|
| 3506 |
// store result to shared memory (reuse sq)
|
| 3507 |
if (sgitg == 0) {
|
| 3508 |
+
for (short i = 0; i < DV8; ++i) {
|
| 3509 |
+
simdgroup_store(lo[i], so + i*8, DV, 0, false);
|
| 3510 |
}
|
| 3511 |
}
|
| 3512 |
|
|
|
|
| 3517 |
for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
|
| 3518 |
const float S = ss[j*TS + 0];
|
| 3519 |
|
| 3520 |
+
for (short i = tiisg; i < DV4; i += NW) {
|
| 3521 |
+
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
|
| 3522 |
}
|
| 3523 |
}
|
| 3524 |
}
|
|
|
|
| 3535 |
float, simdgroup_float8x8, \
|
| 3536 |
half, half4, simdgroup_half8x8
|
| 3537 |
|
| 3538 |
+
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
|
| 3539 |
|
| 3540 |
+
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
|
| 3541 |
+
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
|
| 3542 |
+
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
|
| 3543 |
+
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
|
| 3544 |
+
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128, 128>;
|
| 3545 |
+
template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
|
| 3546 |
+
template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
|
| 3547 |
+
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
|
| 3548 |
|
| 3549 |
#if defined(GGML_METAL_USE_BF16)
|
| 3550 |
+
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
| 3551 |
+
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
|
| 3552 |
+
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
|
| 3553 |
+
template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
|
| 3554 |
+
template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
|
| 3555 |
+
template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
|
| 3556 |
+
template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
|
| 3557 |
+
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
|
| 3558 |
#endif
|
| 3559 |
|
| 3560 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
| 3561 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
|
| 3562 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
|
| 3563 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
|
| 3564 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128, 128>;
|
| 3565 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
|
| 3566 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
|
| 3567 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
|
| 3568 |
+
|
| 3569 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
|
| 3570 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
|
| 3571 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
|
| 3572 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
|
| 3573 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128, 128>;
|
| 3574 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
|
| 3575 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
|
| 3576 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
|
| 3577 |
+
|
| 3578 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
|
| 3579 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
|
| 3580 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
|
| 3581 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
|
| 3582 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128, 128>;
|
| 3583 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
|
| 3584 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
|
| 3585 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
|
| 3586 |
+
|
| 3587 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
|
| 3588 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
|
| 3589 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
|
| 3590 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
|
| 3591 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128, 128>;
|
| 3592 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
|
| 3593 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
|
| 3594 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
|
| 3595 |
+
|
| 3596 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
|
| 3597 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
|
| 3598 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
|
| 3599 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
|
| 3600 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>;
|
| 3601 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
|
| 3602 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
|
| 3603 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
|
| 3604 |
|
| 3605 |
#undef FA_TYPES
|
| 3606 |
|
| 3607 |
template<
|
| 3608 |
+
typename q4_t, // query types in shared memory
|
| 3609 |
+
typename k4_t, // key types in shared memory
|
| 3610 |
+
typename v4_t, // value types in shared memory
|
| 3611 |
+
typename qk_t, // Q*K types
|
| 3612 |
+
typename s_t, // soft-max types
|
|
|
|
| 3613 |
typename s4_t,
|
| 3614 |
+
typename o4_t, // attention accumulation types
|
| 3615 |
+
typename kd4_t, // key type in device memory
|
|
|
|
| 3616 |
short nl_k,
|
| 3617 |
+
void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
|
| 3618 |
+
typename vd4_t, // value type in device memory
|
| 3619 |
short nl_v,
|
| 3620 |
+
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
|
| 3621 |
+
short DK, // K head size
|
| 3622 |
+
short DV, // V head size
|
| 3623 |
+
short NE = 4, // head elements per thread
|
| 3624 |
+
short Q = 1, // queries per threadgroup
|
| 3625 |
+
short C = 32> // cache items per threadgroup
|
| 3626 |
kernel void kernel_flash_attn_ext_vec(
|
| 3627 |
constant ggml_metal_kargs_flash_attn_ext & args,
|
| 3628 |
device const char * q,
|
|
|
|
| 3641 |
const int iq2 = tgpig[1];
|
| 3642 |
const int iq1 = tgpig[0];
|
| 3643 |
|
| 3644 |
+
const short DK4 = DK/4;
|
| 3645 |
+
const short DV4 = DV/4;
|
| 3646 |
const short NW = N_SIMDWIDTH;
|
| 3647 |
+
const short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
|
| 3648 |
+
const short SH = 2*C; // shared memory per simdgroup
|
| 3649 |
|
| 3650 |
+
const short T = DK + nsg*SH; // shared memory size per query in (half)
|
| 3651 |
|
| 3652 |
+
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
|
| 3653 |
+
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
|
| 3654 |
+
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
|
| 3655 |
+
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
|
| 3656 |
+
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask
|
| 3657 |
+
threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
|
|
|
|
| 3658 |
|
| 3659 |
+
// store the result for all queries in local memory (the O matrix from the paper)
|
| 3660 |
+
o4_t lo[DV4/NL];
|
| 3661 |
|
| 3662 |
// load heads from Q to shared memory
|
| 3663 |
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
|
| 3664 |
|
| 3665 |
+
for (short i = tiisg; i < DK4; i += NW) {
|
| 3666 |
if (iq1 < args.ne01) {
|
| 3667 |
sq4[i] = (q4_t) q4[i];
|
| 3668 |
} else {
|
|
|
|
| 3671 |
}
|
| 3672 |
|
| 3673 |
// zero out lo
|
| 3674 |
+
for (short i = 0; i < DV4/NL; ++i) {
|
| 3675 |
+
lo[i] = (o4_t) 0.0f;
|
| 3676 |
}
|
| 3677 |
|
| 3678 |
// zero out shared memory SH
|
|
|
|
| 3697 |
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
|
| 3698 |
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
|
| 3699 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3700 |
const bool has_mask = mask != q;
|
| 3701 |
|
| 3702 |
// pointer to the mask
|
|
|
|
| 3728 |
|
| 3729 |
// Q*K^T
|
| 3730 |
{
|
| 3731 |
+
// each simdgroup processes 1 query and NE (NW/NL) head elements
|
| 3732 |
+
for (short cc = 0; cc < C/NE; ++cc) {
|
| 3733 |
+
qk_t mqk = 0.0f;
|
| 3734 |
|
| 3735 |
+
device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
|
| 3736 |
|
| 3737 |
+
#pragma unroll(DK4/NL)
|
| 3738 |
+
for (short ii = 0; ii < DK4; ii += NL) {
|
| 3739 |
const short i = ii + tx;
|
| 3740 |
|
| 3741 |
+
k4_t mk;
|
| 3742 |
+
deq_k_t4(pk + i/nl_k, i%nl_k, mk);
|
| 3743 |
|
| 3744 |
// note: this is less precise than the version below
|
| 3745 |
+
//mqka[0] += dot(mq[0], mk[0]);
|
| 3746 |
+
//mqka[1] += dot(mq[1], mk[1]);
|
| 3747 |
+
//mqka[2] += dot(mq[2], mk[2]);
|
| 3748 |
+
//mqka[3] += dot(mq[3], mk[3]);
|
| 3749 |
+
|
| 3750 |
+
//q4x4_t mq = sq4x4[i];
|
| 3751 |
+
//mqka[0] += dot((float4) mq[0], (float4) mk[0]);
|
| 3752 |
+
//mqka[1] += dot((float4) mq[1], (float4) mk[1]);
|
| 3753 |
+
//mqka[2] += dot((float4) mq[2], (float4) mk[2]);
|
| 3754 |
+
//mqka[3] += dot((float4) mq[3], (float4) mk[3]);
|
| 3755 |
+
|
| 3756 |
+
mqk += dot((float4) mk, (float4) sq4[i]);
|
| 3757 |
}
|
| 3758 |
|
| 3759 |
+
static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails
|
| 3760 |
|
| 3761 |
+
// simdgroup reduce (NE = 4)
|
| 3762 |
// [ 0 .. 7] -> [ 0]
|
| 3763 |
// [ 8 .. 15] -> [ 8]
|
| 3764 |
// [16 .. 23] -> [16]
|
| 3765 |
// [24 .. 31] -> [24]
|
| 3766 |
+
if (NE <= 1) {
|
| 3767 |
+
mqk += simd_shuffle_down(mqk, 16);
|
| 3768 |
+
}
|
| 3769 |
+
if (NE <= 2) {
|
| 3770 |
+
mqk += simd_shuffle_down(mqk, 8);
|
| 3771 |
+
}
|
| 3772 |
+
if (NE <= 4) {
|
| 3773 |
+
mqk += simd_shuffle_down(mqk, 4);
|
| 3774 |
+
}
|
| 3775 |
+
if (NE <= 8) {
|
| 3776 |
+
mqk += simd_shuffle_down(mqk, 2);
|
| 3777 |
+
}
|
| 3778 |
+
if (NE <= 16) {
|
| 3779 |
+
mqk += simd_shuffle_down(mqk, 1);
|
| 3780 |
+
}
|
| 3781 |
|
| 3782 |
// mqk = mqk*scale + mask*slope
|
| 3783 |
if (tx == 0) {
|
|
|
|
| 3787 |
mqk = args.logit_softcap*precise::tanh(mqk);
|
| 3788 |
}
|
| 3789 |
|
| 3790 |
+
mqk += sm[NE*cc + ty]*slope;
|
| 3791 |
|
| 3792 |
+
ss[NE*cc + ty] = mqk;
|
| 3793 |
}
|
| 3794 |
}
|
| 3795 |
}
|
|
|
|
| 3812 |
ss[tiisg] = vs;
|
| 3813 |
|
| 3814 |
// O = diag(ms)*O
|
| 3815 |
+
#pragma unroll(DV4/NL)
|
| 3816 |
+
for (short ii = 0; ii < DV4; ii += NL) {
|
| 3817 |
lo[ii/NL] *= ms;
|
| 3818 |
}
|
| 3819 |
}
|
|
|
|
| 3822 |
|
| 3823 |
// O = O + (Q*K^T)*V
|
| 3824 |
{
|
| 3825 |
+
//#pragma unroll(C/NE)
|
| 3826 |
+
for (short cc = 0; cc < C/NE; ++cc) {
|
| 3827 |
+
device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
|
| 3828 |
|
| 3829 |
+
const s4_t ms(ss[NE*cc + ty]);
|
| 3830 |
|
| 3831 |
+
#pragma unroll(DV4/NL)
|
| 3832 |
+
for (short ii = 0; ii < DV4; ii += NL) {
|
| 3833 |
const short i = ii + tx;
|
| 3834 |
|
| 3835 |
+
v4_t mv;
|
| 3836 |
+
deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
|
| 3837 |
|
| 3838 |
lo[ii/NL] += mv*ms;
|
| 3839 |
}
|
|
|
|
| 3848 |
}
|
| 3849 |
}
|
| 3850 |
|
| 3851 |
+
// simdgroup reduce (NE = 4)
|
| 3852 |
// [ 0, 8, 16, 24] -> [ 0]
|
| 3853 |
// [ 1, 9, 17, 25] -> [ 1]
|
| 3854 |
// [ 2, 10, 18, 26] -> [ 2]
|
|
|
|
| 3857 |
// [ 5, 13, 21, 29] -> [ 5]
|
| 3858 |
// [ 6, 14, 22, 30] -> [ 6]
|
| 3859 |
// [ 7, 15, 23, 31] -> [ 7]
|
| 3860 |
+
for (short ii = 0; ii < DV4; ii += NL) {
|
| 3861 |
+
if (NE > 1) {
|
| 3862 |
+
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
|
| 3863 |
+
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
|
| 3864 |
+
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
|
| 3865 |
+
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
|
| 3866 |
+
}
|
| 3867 |
+
|
| 3868 |
+
if (NE > 2) {
|
| 3869 |
+
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
|
| 3870 |
+
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
|
| 3871 |
+
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
|
| 3872 |
+
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
|
| 3873 |
+
}
|
| 3874 |
+
|
| 3875 |
+
if (NE > 4) {
|
| 3876 |
+
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
|
| 3877 |
+
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
|
| 3878 |
+
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
|
| 3879 |
+
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
|
| 3880 |
+
}
|
| 3881 |
+
|
| 3882 |
+
if (NE > 8) {
|
| 3883 |
+
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
|
| 3884 |
+
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
|
| 3885 |
+
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
|
| 3886 |
+
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
|
| 3887 |
+
}
|
| 3888 |
+
|
| 3889 |
+
if (NE > 16) {
|
| 3890 |
+
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
|
| 3891 |
+
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
|
| 3892 |
+
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
|
| 3893 |
+
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
|
| 3894 |
+
}
|
| 3895 |
}
|
| 3896 |
|
| 3897 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 3898 |
|
| 3899 |
// store results to shared memory
|
| 3900 |
+
for (short i = tiisg; i < DV4; i += NL) {
|
| 3901 |
+
sr4[i] = lo[i/NL];
|
| 3902 |
}
|
| 3903 |
|
| 3904 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
| 3925 |
}
|
| 3926 |
|
| 3927 |
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
| 3928 |
+
for (short i = tiisg; i < DV4; i += NW) {
|
| 3929 |
+
sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1;
|
| 3930 |
}
|
| 3931 |
}
|
| 3932 |
|
| 3933 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 3934 |
}
|
| 3935 |
|
| 3936 |
+
device float4 * dst4 = (device float4 *) dst;
|
| 3937 |
|
| 3938 |
// final rescale with 1/S and store to global memory
|
| 3939 |
if (sgitg == 0) {
|
| 3940 |
const float S = ss[0];
|
| 3941 |
|
| 3942 |
+
for (short i = tiisg; i < DV4; i += NW) {
|
| 3943 |
+
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
|
| 3944 |
}
|
| 3945 |
}
|
| 3946 |
}
|
|
|
|
| 3949 |
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
|
| 3950 |
//
|
| 3951 |
#define FA_TYPES \
|
| 3952 |
+
half4, \
|
| 3953 |
+
half4, \
|
| 3954 |
+
half4, \
|
| 3955 |
+
float, \
|
| 3956 |
+
half, half4, \
|
| 3957 |
+
half4
|
| 3958 |
|
| 3959 |
+
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 128>) flash_attn_ext_vec_t;
|
| 3960 |
+
|
| 3961 |
+
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>;
|
| 3962 |
+
#if defined(GGML_METAL_USE_BF16)
|
| 3963 |
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 4>;
|
| 3964 |
+
#endif
|
| 3965 |
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 4>;
|
| 3966 |
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 4>;
|
| 3967 |
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 4>;
|
| 3968 |
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 4>;
|
| 3969 |
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 4>;
|
| 3970 |
+
|
| 3971 |
+
// Explicit instantiations of kernel_flash_attn_ext_vec exported under the "*_h192" host names.
// One instantiation per storage format (f16, bf16, q4_0, q4_1, q5_0, q5_1, q8_0), each with the
// numeric arguments 192, 192, 4. The bf16 variant requires GGML_METAL_USE_BF16.
// NOTE(review): the two 192 values presumably are the K and V head sizes - confirm against the
// template parameter list, which is not visible here.
template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 4>;
|
| 3972 |
+
#if defined(GGML_METAL_USE_BF16)
|
| 3973 |
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 4>;
|
| 3974 |
+
#endif
|
| 3975 |
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 4>;
|
| 3976 |
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 4>;
|
| 3977 |
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 4>;
|
| 3978 |
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 4>;
|
| 3979 |
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 4>;
|
| 3980 |
|
| 3981 |
+
// Explicit instantiations of kernel_flash_attn_ext_vec exported under the "*_hk192_hv128" host
// names. This is the only group whose first two numeric arguments differ (192, 128, 4).
// One instantiation per storage format (f16, bf16, q4_0, q4_1, q5_0, q5_1, q8_0); the bf16
// variant requires GGML_METAL_USE_BF16.
// NOTE(review): the host name suggests 192 is the K head size and 128 the V head size - confirm
// against the template parameter list, which is not visible here.
template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 4>;
|
| 3982 |
#if defined(GGML_METAL_USE_BF16)
|
| 3983 |
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 4>;
|
| 3984 |
#endif
|
| 3985 |
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 4>;
|
| 3986 |
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 4>;
|
| 3987 |
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 4>;
|
| 3988 |
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 4>;
|
| 3989 |
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 4>;
|
| 3990 |
|
| 3991 |
+
// Explicit instantiations of kernel_flash_attn_ext_vec exported under the "*_h256" host names.
// One instantiation per storage format (f16, bf16, q4_0, q4_1, q5_0, q5_1, q8_0), each with the
// numeric arguments 256, 256, 4. The bf16 variant requires GGML_METAL_USE_BF16.
// NOTE(review): the two 256 values presumably are the K and V head sizes - confirm against the
// template parameter list, which is not visible here.
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 4>;
|
| 3992 |
#if defined(GGML_METAL_USE_BF16)
|
| 3993 |
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 4>;
|
| 3994 |
#endif
|
| 3995 |
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 4>;
|
| 3996 |
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 4>;
|
| 3997 |
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 4>;
|
| 3998 |
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 4>;
|
| 3999 |
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 4>;
|
| 4000 |
|
| 4001 |
#undef FA_TYPES
|
| 4002 |
|
ggml/src/ggml-vulkan/ggml-vulkan.cpp
CHANGED
|
@@ -8764,6 +8764,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
| 8764 |
default:
|
| 8765 |
return false;
|
| 8766 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8767 |
if (op->src[0]->type != GGML_TYPE_F32) {
|
| 8768 |
return false;
|
| 8769 |
}
|
|
|
|
| 8764 |
default:
|
| 8765 |
return false;
|
| 8766 |
}
|
| 8767 |
+
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
| 8768 |
+
// different head sizes of K and V are not supported yet
|
| 8769 |
+
return false;
|
| 8770 |
+
}
|
| 8771 |
if (op->src[0]->type != GGML_TYPE_F32) {
|
| 8772 |
return false;
|
| 8773 |
}
|
ggml/src/ggml.c
CHANGED
|
@@ -4369,7 +4369,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
|
|
| 4369 |
}
|
| 4370 |
|
| 4371 |
// permute(0, 2, 1, 3)
|
| 4372 |
-
int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
|
| 4373 |
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
| 4374 |
|
| 4375 |
float params[] = { scale, max_bias, logit_softcap };
|
|
|
|
| 4369 |
}
|
| 4370 |
|
| 4371 |
// permute(0, 2, 1, 3)
|
| 4372 |
+
int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
|
| 4373 |
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
| 4374 |
|
| 4375 |
float params[] = { scale, max_bias, logit_softcap };
|