ggerganov committed
Commit 04a3389 · Parent: 7585c4a

metal : improve FA + improve MoE (llama/12612)


* ggml : FA with different K, V head sizes (CPU)

ggml-ci

* metal : add FA with HS=192

* metal : extend FA to support different K and V head sizes

ggml-ci

* metal : add FA vector kernels for heads K 192 and V 128

ggml-ci

* ggml : restrict op on other backends to equal head sizes

ggml-ci

* metal : optimize FA-vec kernel

ggml-ci

* metal : FA remove mq registers

* metal : improve MoE mul_mat_id condition

ggml-ci

* metal : fix comments + remove unnecessary addition

ggml-ci

* metal : avoid too much shared memory usage with mul_mat_id

ggml-ci
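
A minimal sketch of what the relaxed shape contract allows: a ggml_flash_attn_ext call where the K head size (192) differs from the V head size (128), as in DeepSeek-style attention. This assumes the usual signature (ctx, q, k, v, mask, scale, max_bias, logit_softcap); the helper name and concrete sizes are illustrative, not taken from the commit.

    #include "ggml.h"
    #include <math.h>

    struct ggml_tensor * build_fa(struct ggml_context * ctx) {
        const int64_t DK = 192, DV = 128;          // K and V head sizes -- may now differ
        const int64_t n_head = 16, n_head_kv = 16; // no GQA in this sketch
        const int64_t n_batch = 1, n_kv = 4096;

        struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, DK, n_batch, n_head,    1);
        struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, DK, n_kv,    n_head_kv, 1);
        struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, DV, n_kv,    n_head_kv, 1);

        // mask rows are padded to a multiple of GGML_KQ_MASK_PAD (64) queries
        struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16,
                n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD));

        // the result has shape [DV, n_head, n_batch, 1] -- the V head size, not the K head size
        return ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf((float) DK), 0.0f, 0.0f);
    }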

ggml/include/ggml.h CHANGED
@@ -1791,11 +1791,11 @@ extern "C" {
1791
 
1792
  #define GGML_KQ_MASK_PAD 64
1793
 
1794
- // q: [n_embd, n_batch, n_head, 1]
1795
- // k: [n_embd, n_kv, n_head_kv, 1]
1796
- // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
1797
- // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
1798
- // res: [n_embd, n_head, n_batch, 1] !! permuted !!
1799
  GGML_API struct ggml_tensor * ggml_flash_attn_ext(
1800
  struct ggml_context * ctx,
1801
  struct ggml_tensor * q,
 
1791
 
1792
  #define GGML_KQ_MASK_PAD 64
1793
 
1794
+ // q: [n_embd_k, n_batch, n_head, 1]
1795
+ // k: [n_embd_k, n_kv, n_head_kv, 1]
1796
+ // v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !!
1797
+ // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
1798
+ // res: [n_embd_v, n_head, n_batch, 1] !! permuted !!
1799
  GGML_API struct ggml_tensor * ggml_flash_attn_ext(
1800
  struct ggml_context * ctx,
1801
  struct ggml_tensor * q,
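
A worked reading of the updated comment, with illustrative numbers (n_embd_k = 192, n_embd_v = 128, n_kv = 4096, n_batch = 3, n_head = 16):

    // q:    [192, 3,    16, 1]
    // k:    [192, 4096, 16, 1]
    // v:    [128, 4096, 16, 1]  // !! not transposed !!
    // mask: [4096, 64, 1, 1]    // n_batch_pad = GGML_PAD(3, 64) = 64
    // res:  [128, 16, 3, 1]     // !! permuted !! head size comes from V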
ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -12238,10 +12238,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
12238
  const int ith = params->ith;
12239
  const int nth = params->nth;
12240
 
12241
- const int64_t D = neq0;
12242
- const int64_t N = neq1;
 
12243
 
12244
- GGML_ASSERT(ne0 == D);
12245
  GGML_ASSERT(ne2 == N);
12246
 
12247
  // input tensor rows must be contiguous
@@ -12249,12 +12250,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
12249
  GGML_ASSERT(nbk0 == ggml_type_size(k->type));
12250
  GGML_ASSERT(nbv0 == ggml_type_size(v->type));
12251
 
12252
- GGML_ASSERT(neq0 == D);
12253
- GGML_ASSERT(nek0 == D);
12254
- GGML_ASSERT(nev0 == D);
12255
 
12256
  GGML_ASSERT(neq1 == N);
12257
- GGML_ASSERT(nev0 == D);
12258
 
12259
  // dst cannot be transposed or permuted
12260
  GGML_ASSERT(nb0 == sizeof(float));
@@ -12320,15 +12320,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
12320
  float S = 0.0f; // sum
12321
  float M = -INFINITY; // maximum KQ value
12322
 
12323
- float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
12324
- float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
12325
- ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
12326
- ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
12327
 
12328
  if (v->type == GGML_TYPE_F16) {
12329
- memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
12330
  } else {
12331
- memset(VKQ32, 0, D*sizeof(float));
12332
  }
12333
 
12334
  const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
@@ -12342,7 +12342,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
12342
  const int iv2 = iq2 / rv2;
12343
 
12344
  const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
12345
- q_to_vec_dot(pq, Q_q, D);
12346
 
12347
  // online softmax / attention
12348
  // loop over n_kv and n_head_kv
@@ -12356,7 +12356,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
12356
  float s; // KQ value
12357
 
12358
  const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
12359
- kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
12360
 
12361
  s = s*scale; // scale KQ value
12362
 
@@ -12380,14 +12380,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
12380
  ms = expf(Mold - M);
12381
 
12382
  // V = V*expf(Mold - M)
12383
- ggml_vec_scale_f16(D, VKQ16, ms);
12384
  } else {
12385
  // no new maximum, ms == 1.0f, vs != 1.0f
12386
  vs = expf(s - M);
12387
  }
12388
 
12389
  // V += v*expf(s - M)
12390
- ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
12391
  } else {
12392
  if (s > M) {
12393
  // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
@@ -12395,30 +12395,30 @@ static void ggml_compute_forward_flash_attn_ext_f16(
12395
  ms = expf(Mold - M);
12396
 
12397
  // V = V*expf(Mold - M)
12398
- ggml_vec_scale_f32(D, VKQ32, ms);
12399
  } else {
12400
  // no new maximum, ms == 1.0f, vs != 1.0f
12401
  vs = expf(s - M);
12402
  }
12403
 
12404
- v_to_float(v_data, V32, D);
12405
 
12406
  // V += v*expf(s - M)
12407
- ggml_vec_mad_f32(D, VKQ32, V32, vs);
12408
  }
12409
 
12410
  S = S*ms + vs; // scale and increment sum with partial sum
12411
  }
12412
 
12413
  if (v->type == GGML_TYPE_F16) {
12414
- for (int64_t d = 0; d < D; ++d) {
12415
  VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
12416
  }
12417
  }
12418
 
12419
  // V /= S
12420
  const float S_inv = 1.0f/S;
12421
- ggml_vec_scale_f32(D, VKQ32, S_inv);
12422
 
12423
  // dst indices
12424
  const int i1 = iq1;
@@ -15277,7 +15277,6 @@ struct ggml_cplan ggml_graph_plan(
15277
  size_t cur = 0;
15278
 
15279
  if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {
15280
-
15281
  switch (node->op) {
15282
  case GGML_OP_CPY:
15283
  case GGML_OP_DUP:
@@ -15386,9 +15385,10 @@ struct ggml_cplan ggml_graph_plan(
15386
  } break;
15387
  case GGML_OP_FLASH_ATTN_EXT:
15388
  {
15389
- const int64_t ne00 = node->src[0]->ne[0]; // D
 
15390
 
15391
- cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
15392
  } break;
15393
  case GGML_OP_FLASH_ATTN_BACK:
15394
  {
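
The hunks above implement a one-pass ("online") softmax: each new KQ value s either rescales the running accumulator (when it is a new maximum) or is folded in with weight expf(s - M). A standalone scalar sketch of the recurrence, for illustration only (the real loop also handles F16 accumulation and quantized V via v_to_float):

    #include <math.h>
    #include <stdint.h>

    static void online_softmax_step(float s, float * M, float * S, float * VKQ,
                                    const float * v, int64_t DV) {
        const float Mold = *M;
        float ms = 1.0f; // rescale factor for the running accumulator
        float vs = 1.0f; // weight of the incoming V row

        if (s > Mold) {
            *M = s;
            ms = expf(Mold - *M);
            for (int64_t d = 0; d < DV; ++d) {
                VKQ[d] *= ms;          // V = V*expf(Mold - M)
            }
        } else {
            vs = expf(s - *M);
        }

        for (int64_t d = 0; d < DV; ++d) {
            VKQ[d] += v[d]*vs;         // V += v*expf(s - M)
        }

        *S = *S*ms + vs;               // scale and increment sum with partial sum
    }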
 
12238
  const int ith = params->ith;
12239
  const int nth = params->nth;
12240
 
12241
+ const int64_t DK = nek0;
12242
+ const int64_t DV = nev0;
12243
+ const int64_t N = neq1;
12244
 
12245
+ GGML_ASSERT(ne0 == DV);
12246
  GGML_ASSERT(ne2 == N);
12247
 
12248
  // input tensor rows must be contiguous
 
12250
  GGML_ASSERT(nbk0 == ggml_type_size(k->type));
12251
  GGML_ASSERT(nbv0 == ggml_type_size(v->type));
12252
 
12253
+ GGML_ASSERT(neq0 == DK);
12254
+ GGML_ASSERT(nek0 == DK);
12255
+ GGML_ASSERT(nev0 == DV);
12256
 
12257
  GGML_ASSERT(neq1 == N);
 
12258
 
12259
  // dst cannot be transposed or permuted
12260
  GGML_ASSERT(nb0 == sizeof(float));
 
12320
  float S = 0.0f; // sum
12321
  float M = -INFINITY; // maximum KQ value
12322
 
12323
+ float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
12324
+ float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
12325
+ ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
12326
+ ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
12327
 
12328
  if (v->type == GGML_TYPE_F16) {
12329
+ memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
12330
  } else {
12331
+ memset(VKQ32, 0, DV*sizeof(float));
12332
  }
12333
 
12334
  const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
 
12342
  const int iv2 = iq2 / rv2;
12343
 
12344
  const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
12345
+ q_to_vec_dot(pq, Q_q, DK);
12346
 
12347
  // online softmax / attention
12348
  // loop over n_kv and n_head_kv
 
12356
  float s; // KQ value
12357
 
12358
  const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
12359
+ kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
12360
 
12361
  s = s*scale; // scale KQ value
12362
 
 
12380
  ms = expf(Mold - M);
12381
 
12382
  // V = V*expf(Mold - M)
12383
+ ggml_vec_scale_f16(DV, VKQ16, ms);
12384
  } else {
12385
  // no new maximum, ms == 1.0f, vs != 1.0f
12386
  vs = expf(s - M);
12387
  }
12388
 
12389
  // V += v*expf(s - M)
12390
+ ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
12391
  } else {
12392
  if (s > M) {
12393
  // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
 
12395
  ms = expf(Mold - M);
12396
 
12397
  // V = V*expf(Mold - M)
12398
+ ggml_vec_scale_f32(DV, VKQ32, ms);
12399
  } else {
12400
  // no new maximum, ms == 1.0f, vs != 1.0f
12401
  vs = expf(s - M);
12402
  }
12403
 
12404
+ v_to_float(v_data, V32, DV);
12405
 
12406
  // V += v*expf(s - M)
12407
+ ggml_vec_mad_f32(DV, VKQ32, V32, vs);
12408
  }
12409
 
12410
  S = S*ms + vs; // scale and increment sum with partial sum
12411
  }
12412
 
12413
  if (v->type == GGML_TYPE_F16) {
12414
+ for (int64_t d = 0; d < DV; ++d) {
12415
  VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
12416
  }
12417
  }
12418
 
12419
  // V /= S
12420
  const float S_inv = 1.0f/S;
12421
+ ggml_vec_scale_f32(DV, VKQ32, S_inv);
12422
 
12423
  // dst indices
12424
  const int i1 = iq1;
 
15277
  size_t cur = 0;
15278
 
15279
  if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {
 
15280
  switch (node->op) {
15281
  case GGML_OP_CPY:
15282
  case GGML_OP_DUP:
 
15385
  } break;
15386
  case GGML_OP_FLASH_ATTN_EXT:
15387
  {
15388
+ const int64_t ne10 = node->src[1]->ne[0]; // DK
15389
+ const int64_t ne20 = node->src[2]->ne[0]; // DV
15390
 
15391
+ cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
15392
  } break;
15393
  case GGML_OP_FLASH_ATTN_BACK:
15394
  {
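
The new work-buffer formula mirrors the per-thread scratch layout in ggml_compute_forward_flash_attn_ext_f16 above, reconstructed here with illustrative numbers:

    // per-thread scratch, in floats (from the pointer setup in the diff):
    //   [ VKQ32 : DV ][ V32 / VKQ16 (aliased) : DV ][ Q_q : DK ][ cache-line pad ]
    // hence: cur = sizeof(float)*(1*DK + 2*DV)*n_tasks
    //
    // DK = DV = 128, n_tasks = 8:
    //   old: 3*4*128*8       = 12288 bytes   (3*sizeof(float)*D*n_tasks)
    //   new: 4*(128 + 256)*8 = 12288 bytes   (unchanged for equal head sizes)
    //
    // DK = 192, DV = 128, n_tasks = 8 (previously unsupported):
    //   new: 4*(192 + 256)*8 = 14336 bytes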
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -3232,6 +3232,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3232
  #ifndef FLASH_ATTN_AVAILABLE
3233
  return false;
3234
  #endif // FLASH_ATTN_AVAILABLE
3235
  if (op->src[0]->ne[3] != 1) {
3236
  return false;
3237
  }
 
3232
  #ifndef FLASH_ATTN_AVAILABLE
3233
  return false;
3234
  #endif // FLASH_ATTN_AVAILABLE
3235
+ if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
3236
+ // different head sizes of K and V are not supported yet
3237
+ return false;
3238
+ }
3239
+ if (op->src[0]->ne[0] == 192) {
3240
+ return false;
3241
+ }
3242
  if (op->src[0]->ne[3] != 1) {
3243
  return false;
3244
  }
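
Condensed, the new CUDA gate reads as follows; this is a restatement of the checks above (the function name is illustrative), not new code from the commit:

    #include <stdbool.h>
    #include "ggml.h"

    static bool fa_head_sizes_supported_cuda(const struct ggml_tensor * op) {
        const int64_t DK = op->src[1]->ne[0]; // K head size
        const int64_t DV = op->src[2]->ne[0]; // V head size

        if (DK != DV) {
            return false; // different K/V head sizes are not supported yet
        }
        if (op->src[0]->ne[0] == 192) {
            return false; // no dedicated HS=192 kernel on CUDA yet
        }
        return true;
    }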
ggml/src/ggml-metal/ggml-metal-impl.h CHANGED
@@ -219,9 +219,12 @@ typedef struct {
219
  int32_t ne11;
220
  int32_t ne_12_2; // assume K and V are same shape
221
  int32_t ne_12_3;
222
- uint64_t nb_12_1;
223
- uint64_t nb_12_2;
224
- uint64_t nb_12_3;
 
 
 
225
  uint64_t nb31;
226
  int32_t ne1;
227
  int32_t ne2;
 
219
  int32_t ne11;
220
  int32_t ne_12_2; // assume K and V are same shape
221
  int32_t ne_12_3;
222
+ uint64_t nb11;
223
+ uint64_t nb12;
224
+ uint64_t nb13;
225
+ uint64_t nb21;
226
+ uint64_t nb22;
227
+ uint64_t nb23;
228
  uint64_t nb31;
229
  int32_t ne1;
230
  int32_t ne2;
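
The host-side change that fills these fields is in ggml-metal.m, whose diff is not rendered on this page. Presumably the strides are copied straight from the K and V tensors, roughly as in this hedged sketch (the function name and include path are assumptions):

    #include "ggml.h"
    #include "ggml-metal-impl.h"

    static void fa_set_kv_strides(ggml_metal_kargs_flash_attn_ext * args,
                                  const struct ggml_tensor * k,
                                  const struct ggml_tensor * v) {
        args->nb11 = k->nb[1]; // stride between K rows
        args->nb12 = k->nb[2]; // stride between K heads
        args->nb13 = k->nb[3]; // stride between K sequences
        args->nb21 = v->nb[1]; // V strides may differ from K strides when DK != DV
        args->nb22 = v->nb[2];
        args->nb23 = v->nb[3];
    }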
ggml/src/ggml-metal/ggml-metal.m CHANGED
The diff for this file is too large to render. See raw diff
 
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -48,7 +48,7 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
48
 
49
  template <typename type4>
50
  void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
51
- reg = (type4)(*(src + il));
52
  }
53
 
54
  #if defined(GGML_METAL_USE_BF16)
@@ -56,6 +56,11 @@ template <typename type4x4>
56
  void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
57
  reg = (type4x4)(*src);
58
  }
59
  #endif
60
 
61
  template <typename type4x4>
@@ -3100,7 +3105,8 @@ template<
3100
  typename vd4x4_t, // key type in device memory
3101
  short nl_v,
3102
  void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
3103
- short D, // head size
 
3104
  short Q = 8, // queries per threadgroup
3105
  short KV = 8, // key/value processed per each simdgroup
3106
  short C = 32> // cache items per threadgroup
@@ -3122,20 +3128,23 @@ kernel void kernel_flash_attn_ext(
3122
  const int iq2 = tgpig[1];
3123
  const int iq1 = tgpig[0]*Q;
3124
 
3125
- const short D4 = D/4;
3126
- const short D8 = D/8;
3127
- const short D16 = D/16;
 
 
 
3128
  const short NW = N_SIMDWIDTH;
3129
  const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
3130
 
3131
  const short TS = nsg*SH; // shared memory size per query in (s_t == float)
3132
- const short T = D + 2*TS; // shared memory size per query in (half)
3133
 
3134
- threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data
3135
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t
3136
- threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*D); // reuse query data for accumulation
3137
- threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*D); // same as above but in o4_t
3138
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
3139
 
3140
  threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
3141
  threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
@@ -3144,23 +3153,23 @@ kernel void kernel_flash_attn_ext(
3144
  threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
3145
 
3146
  // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
3147
- o8x8_t lo[D8];
3148
 
3149
  // load heads from Q to shared memory
3150
  for (short j = sgitg; j < Q; j += nsg) {
3151
  device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
3152
 
3153
- for (short i = tiisg; i < D4; i += NW) {
3154
  if (iq1 + j < args.ne01) {
3155
- sq4[j*D4 + i] = (q4_t) q4[i];
3156
  } else {
3157
- sq4[j*D4 + i] = (q4_t) 0.0f;
3158
  }
3159
  }
3160
  }
3161
 
3162
  // zero out lo
3163
- for (short i = 0; i < D8; ++i) {
3164
  lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
3165
  }
3166
 
@@ -3190,13 +3199,6 @@ kernel void kernel_flash_attn_ext(
3190
  const short ikv2 = iq2/(args.ne02/args.ne_12_2);
3191
  const short ikv3 = iq3/(args.ne03/args.ne_12_3);
3192
 
3193
- // load the queries from shared memory into local memory
3194
- q8x8_t mq[D8];
3195
-
3196
- for (short i = 0; i < D8; ++i) {
3197
- simdgroup_load(mq[i], sq + i*8, D);
3198
- }
3199
-
3200
  const bool has_mask = mask != q;
3201
 
3202
  half slope = 1.0f;
@@ -3249,20 +3251,22 @@ kernel void kernel_flash_attn_ext(
3249
  // this is compile-time check, so it does not have runtime overhead
3250
  if (is_same<kd4x4_t, k4x4_t>::value) {
3251
  // we can read directly from global memory
3252
- device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
3253
 
3254
- #pragma unroll(D8)
3255
- for (short i = 0; i < D8; ++i) {
3256
  k8x8_t mk;
3257
- simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
3258
 
3259
- simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
 
 
3260
  }
3261
  } else {
3262
- for (short ii = 0; ii < D16; ii += 4) {
3263
- device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
3264
 
3265
- if (D16%4 == 0) {
3266
  // the head is evenly divisible by 4*16 = 64, so no need for bound checks
3267
  {
3268
  k4x4_t tmp;
@@ -3275,15 +3279,18 @@ kernel void kernel_flash_attn_ext(
3275
  #pragma unroll(4)
3276
  for (short k = 0; k < 4; ++k) {
3277
  k8x8_t mk;
 
3278
 
3279
  simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
3280
- simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
 
3281
 
3282
  simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
3283
- simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
 
3284
  }
3285
  } else {
3286
- if (ii + tx < D16) {
3287
  k4x4_t tmp;
3288
  deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
3289
  sk4x4[4*ty + tx] = tmp;
@@ -3291,14 +3298,17 @@ kernel void kernel_flash_attn_ext(
3291
 
3292
  simdgroup_barrier(mem_flags::mem_threadgroup);
3293
 
3294
- for (short k = 0; k < 4 && ii + k < D16; ++k) {
3295
  k8x8_t mk;
 
3296
 
3297
  simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
3298
- simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
 
3299
 
3300
  simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
3301
- simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
 
3302
  }
3303
  }
3304
  }
@@ -3350,8 +3360,8 @@ kernel void kernel_flash_attn_ext(
3350
  s8x8_t mm;
3351
  simdgroup_load(mm, ss + 2*C, TS, 0, false);
3352
 
3353
- #pragma unroll(D8)
3354
- for (short i = 0; i < D8; ++i) {
3355
  simdgroup_multiply(lo[i], mm, lo[i]);
3356
  }
3357
  }
@@ -3364,20 +3374,20 @@ kernel void kernel_flash_attn_ext(
3364
 
3365
  if (is_same<vd4x4_t, v4x4_t>::value) {
3366
  // we can read directly from global memory
3367
- device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
3368
 
3369
- #pragma unroll(D8)
3370
- for (short i = 0; i < D8; ++i) {
3371
  v8x8_t mv;
3372
- simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
3373
 
3374
  simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
3375
  }
3376
  } else {
3377
- for (short ii = 0; ii < D16; ii += 4) {
3378
- device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
3379
 
3380
- if (D16%4 == 0) {
3381
  // no need for bound checks
3382
  {
3383
  v4x4_t tmp;
@@ -3398,7 +3408,7 @@ kernel void kernel_flash_attn_ext(
3398
  simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
3399
  }
3400
  } else {
3401
- if (ii + tx < D16) {
3402
  v4x4_t tmp;
3403
  deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
3404
  sv4x4[4*ty + tx] = tmp;
@@ -3406,7 +3416,7 @@ kernel void kernel_flash_attn_ext(
3406
 
3407
  simdgroup_barrier(mem_flags::mem_threadgroup);
3408
 
3409
- for (short k = 0; k < 4 && ii + k < D16; ++k) {
3410
  v8x8_t mv;
3411
 
3412
  simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
@@ -3440,8 +3450,8 @@ kernel void kernel_flash_attn_ext(
3440
 
3441
  // each simdgroup stores its output to shared memory, reusing sq
3442
  if (sgitg == sg) {
3443
- for (short i = 0; i < D8; ++i) {
3444
- simdgroup_store(lo[i], so + i*8, D, 0, false);
3445
  }
3446
  }
3447
 
@@ -3480,11 +3490,11 @@ kernel void kernel_flash_attn_ext(
3480
  simdgroup_load(ms0, ss + 2*C, TS, 0, false);
3481
  simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
3482
 
3483
- #pragma unroll(D8)
3484
- for (short i = 0; i < D8; ++i) {
3485
  o8x8_t t;
3486
 
3487
- simdgroup_load (t, so + i*8, D, 0, false);
3488
  simdgroup_multiply(t, ms1, t);
3489
 
3490
  simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
@@ -3495,8 +3505,8 @@ kernel void kernel_flash_attn_ext(
3495
 
3496
  // store result to shared memory (reuse sq)
3497
  if (sgitg == 0) {
3498
- for (short i = 0; i < D8; ++i) {
3499
- simdgroup_store(lo[i], so + i*8, D, 0, false);
3500
  }
3501
  }
3502
 
@@ -3507,8 +3517,8 @@ kernel void kernel_flash_attn_ext(
3507
  for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
3508
  const float S = ss[j*TS + 0];
3509
 
3510
- for (short i = tiisg; i < D4; i += NW) {
3511
- dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
3512
  }
3513
  }
3514
  }
@@ -3525,80 +3535,94 @@ kernel void kernel_flash_attn_ext(
3525
  float, simdgroup_float8x8, \
3526
  half, half4, simdgroup_half8x8
3527
 
3528
- typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
3529
 
3530
- template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>;
3531
- template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80>;
3532
- template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96>;
3533
- template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112>;
3534
- template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
3535
- template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
 
 
3536
 
3537
  #if defined(GGML_METAL_USE_BF16)
3538
- template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64>;
3539
- template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80>;
3540
- template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96>;
3541
- template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112>;
3542
- template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
3543
- template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
 
 
3544
  #endif
3545
 
3546
- template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64>;
3547
- template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80>;
3548
- template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96>;
3549
- template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112>;
3550
- template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
3551
- template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
3552
-
3553
- template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64>;
3554
- template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80>;
3555
- template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96>;
3556
- template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112>;
3557
- template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
3558
- template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
3559
-
3560
- template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64>;
3561
- template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80>;
3562
- template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96>;
3563
- template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112>;
3564
- template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
3565
- template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
3566
-
3567
- template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64>;
3568
- template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80>;
3569
- template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96>;
3570
- template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112>;
3571
- template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
3572
- template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
3573
-
3574
- template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64>;
3575
- template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80>;
3576
- template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96>;
3577
- template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112>;
3578
- template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
3579
- template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
3580
 
3581
  #undef FA_TYPES
3582
 
3583
  template<
3584
- typename q4_t, // query types in shared memory
3585
- typename q4x4_t,
3586
- typename k4x4_t, // key types in shared memory
3587
- typename v4x4_t, // value types in shared memory
3588
- typename qk_t, // Q*K types
3589
- typename s_t, // soft-max types
3590
  typename s4_t,
3591
- typename s4x4_t,
3592
- typename o4x4_t, // attention accumulation types
3593
- typename kd4x4_t, // key type in device memory
3594
  short nl_k,
3595
- void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
3596
- typename vd4x4_t, // key type in device memory
3597
  short nl_v,
3598
- void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
3599
- short D, // head size
3600
- short Q = 1, // queries per threadgroup
3601
- short C = 32> // cache items per threadgroup
 
 
3602
  kernel void kernel_flash_attn_ext_vec(
3603
  constant ggml_metal_kargs_flash_attn_ext & args,
3604
  device const char * q,
@@ -3617,29 +3641,28 @@ kernel void kernel_flash_attn_ext_vec(
3617
  const int iq2 = tgpig[1];
3618
  const int iq1 = tgpig[0];
3619
 
3620
- const short D4 = D/4;
3621
- const short D16 = D/16;
3622
  const short NW = N_SIMDWIDTH;
3623
- const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
3624
- const short SH = 2*C; // shared memory per simdgroup
3625
 
3626
- const short T = D + nsg*SH; // shared memory size per query in (half)
3627
 
3628
- //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data
3629
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t
3630
- threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 + 0*D); // same as above but in q4x4_t
3631
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*D); // scratch buffer for attention
3632
- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*D); // same as above but in s4_t
3633
- threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask
3634
- threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results
3635
 
3636
- // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
3637
- o4x4_t lo[D16/NL];
3638
 
3639
  // load heads from Q to shared memory
3640
  device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
3641
 
3642
- for (short i = tiisg; i < D4; i += NW) {
3643
  if (iq1 < args.ne01) {
3644
  sq4[i] = (q4_t) q4[i];
3645
  } else {
@@ -3648,8 +3671,8 @@ kernel void kernel_flash_attn_ext_vec(
3648
  }
3649
 
3650
  // zero out lo
3651
- for (short i = 0; i < D16/NL; ++i) {
3652
- lo[i] = (o4x4_t) 0.0f;
3653
  }
3654
 
3655
  // zero out shared memory SH
@@ -3674,14 +3697,6 @@ kernel void kernel_flash_attn_ext_vec(
3674
  const short ikv2 = iq2/(args.ne02/args.ne_12_2);
3675
  const short ikv3 = iq3/(args.ne03/args.ne_12_3);
3676
 
3677
- // load the queries from shared memory into local memory
3678
- q4x4_t mq[D16/NL];
3679
-
3680
- #pragma unroll(D16/NL)
3681
- for (short ii = 0; ii < D16; ii += NL) {
3682
- mq[ii/NL] = sq4x4[ii + tx];
3683
- }
3684
-
3685
  const bool has_mask = mask != q;
3686
 
3687
  // pointer to the mask
@@ -3713,43 +3728,56 @@ kernel void kernel_flash_attn_ext_vec(
3713
 
3714
  // Q*K^T
3715
  {
3716
- // each simdgroup processes 1 query and 4 (NW/NL) keys
3717
- for (short cc = 0; cc < C/4; ++cc) {
3718
- qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
3719
 
3720
- device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
3721
 
3722
- #pragma unroll(D16/NL)
3723
- for (short ii = 0; ii < D16; ii += NL) {
3724
  const short i = ii + tx;
3725
 
3726
- k4x4_t mk;
3727
- deq_k(pk + i/nl_k, i%nl_k, mk);
3728
 
3729
  // note: this is less precise than the version below
3730
- //mqka[0] += dot(mq[ii/NL][0], mk[0]);
3731
- //mqka[1] += dot(mq[ii/NL][1], mk[1]);
3732
- //mqka[2] += dot(mq[ii/NL][2], mk[2]);
3733
- //mqka[3] += dot(mq[ii/NL][3], mk[3]);
3734
-
3735
- mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]);
3736
- mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]);
3737
- mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]);
3738
- mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]);
 
 
 
3739
  }
3740
 
3741
- qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
3742
 
3743
- // simdgroup reduce
3744
  // [ 0 .. 7] -> [ 0]
3745
  // [ 8 .. 15] -> [ 8]
3746
  // [16 .. 23] -> [16]
3747
  // [24 .. 31] -> [24]
3748
- //mqk += simd_shuffle_down(mqk, 16);
3749
- //mqk += simd_shuffle_down(mqk, 8);
3750
- mqk += simd_shuffle_down(mqk, 4);
3751
- mqk += simd_shuffle_down(mqk, 2);
3752
- mqk += simd_shuffle_down(mqk, 1);
 
3753
 
3754
  // mqk = mqk*scale + mask*slope
3755
  if (tx == 0) {
@@ -3759,9 +3787,9 @@ kernel void kernel_flash_attn_ext_vec(
3759
  mqk = args.logit_softcap*precise::tanh(mqk);
3760
  }
3761
 
3762
- mqk += sm[4*cc + ty]*slope;
3763
 
3764
- ss[4*cc + ty] = mqk;
3765
  }
3766
  }
3767
  }
@@ -3784,8 +3812,8 @@ kernel void kernel_flash_attn_ext_vec(
3784
  ss[tiisg] = vs;
3785
 
3786
  // O = diag(ms)*O
3787
- #pragma unroll(D16/NL)
3788
- for (short ii = 0; ii < D16; ii += NL) {
3789
  lo[ii/NL] *= ms;
3790
  }
3791
  }
@@ -3794,17 +3822,18 @@ kernel void kernel_flash_attn_ext_vec(
3794
 
3795
  // O = O + (Q*K^T)*V
3796
  {
3797
- for (short cc = 0; cc < C/4; ++cc) {
3798
- device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
 
3799
 
3800
- const s4x4_t ms(ss[4*cc + ty]);
3801
 
3802
- #pragma unroll(D16/NL)
3803
- for (short ii = 0; ii < D16; ii += NL) {
3804
  const short i = ii + tx;
3805
 
3806
- v4x4_t mv;
3807
- deq_v(pv4 + i/nl_v, i%nl_v, mv);
3808
 
3809
  lo[ii/NL] += mv*ms;
3810
  }
@@ -3819,7 +3848,7 @@ kernel void kernel_flash_attn_ext_vec(
3819
  }
3820
  }
3821
 
3822
- // simdgroup reduce
3823
  // [ 0, 8, 16, 24] -> [ 0]
3824
  // [ 1, 9, 17, 25] -> [ 1]
3825
  // [ 2, 10, 18, 26] -> [ 2]
@@ -3828,37 +3857,48 @@ kernel void kernel_flash_attn_ext_vec(
3828
  // [ 5, 13, 21, 29] -> [ 5]
3829
  // [ 6, 14, 22, 30] -> [ 6]
3830
  // [ 7, 15, 23, 31] -> [ 7]
3831
- for (short ii = 0; ii < D16; ii += NL) {
3832
- lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
3833
- lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
3834
- //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
3835
- //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
3836
- //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
3837
-
3838
- lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
3839
- lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
3840
- //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
3841
- //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
3842
- //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
3843
-
3844
- lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
3845
- lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
3846
- //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
3847
- //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
3848
- //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
3849
-
3850
- lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
3851
- lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
3852
- //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
3853
- //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
3854
- //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
3855
  }
3856
 
3857
  threadgroup_barrier(mem_flags::mem_threadgroup);
3858
 
3859
  // store results to shared memory
3860
- for (short i = tiisg; i < D16; i += NL) {
3861
- sr4x4[i] = lo[i/NL];
3862
  }
3863
 
3864
  threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -3885,22 +3925,22 @@ kernel void kernel_flash_attn_ext_vec(
3885
  }
3886
 
3887
  // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
3888
- for (short i = tiisg; i < D16; i += NW) {
3889
- sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1;
3890
  }
3891
  }
3892
 
3893
  threadgroup_barrier(mem_flags::mem_threadgroup);
3894
  }
3895
 
3896
- device float4x4 * dst44 = (device float4x4 *) dst;
3897
 
3898
  // final rescale with 1/S and store to global memory
3899
  if (sgitg == 0) {
3900
  const float S = ss[0];
3901
 
3902
- for (short i = tiisg; i < D16; i += NW) {
3903
- dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
3904
  }
3905
  }
3906
  }
@@ -3909,34 +3949,54 @@ kernel void kernel_flash_attn_ext_vec(
3909
  // in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
3910
  //
3911
  #define FA_TYPES \
3912
- half4, half4x4, \
3913
- half4x4, \
3914
- half4x4, \
3915
- float, \
3916
- half, half4, half4x4, \
3917
- half4x4
3918
 
3919
- typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
3920
 
3921
- template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
3922
  #if defined(GGML_METAL_USE_BF16)
3923
- template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
3924
  #endif
3925
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
3926
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
3927
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
3928
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
3929
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
3930
 
3931
- template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
3932
  #if defined(GGML_METAL_USE_BF16)
3933
- template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
3934
  #endif
3935
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
3936
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
3937
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
3938
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
3939
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
3940
 
3941
  #undef FA_TYPES
3942
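
The rewritten kernel below is parameterized by DK and DV instead of a single D. Plugging illustrative values into the shared-memory formulas that appear in it (SH = 2*C + Q, TS = nsg*SH, T = DK + 2*TS):

    // Q = 8, C = 32, nsg = 4, KV = 8, DK = 192, DV = 128
    //
    //   SH = 2*C + Q   = 72    // per-simdgroup scratch, in s_t (float)
    //   TS = nsg*SH    = 288
    //   T  = DK + 2*TS = 768   // per query row, in half
    //
    //   query/output area: Q*T           = 6144 halves
    //   K/V load scratch:  nsg*(4*16*KV) = 2048 halves
    //   total:                             8192 halves = 16 KiB of threadgroup memory
    //
    // the DV-wide output reuses the DK-wide query storage (so/so4 alias sq/sq4),
    // which is why only DK appears in T; all instantiations have DV <= DK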
 
 
48
 
49
  template <typename type4>
50
  void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
51
+ reg = (type4)(*(src));
52
  }
53
 
54
  #if defined(GGML_METAL_USE_BF16)
 
56
  void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
57
  reg = (type4x4)(*src);
58
  }
59
+
60
+ template <typename type4>
61
+ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
62
+ reg = (type4)(*(src));
63
+ }
64
  #endif
65
 
66
  template <typename type4x4>
 
3105
  typename vd4x4_t, // key type in device memory
3106
  short nl_v,
3107
  void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
3108
+ short DK, // K head size
3109
+ short DV, // V head size
3110
  short Q = 8, // queries per threadgroup
3111
  short KV = 8, // key/value processed per each simdgroup
3112
  short C = 32> // cache items per threadgroup
 
3128
  const int iq2 = tgpig[1];
3129
  const int iq1 = tgpig[0]*Q;
3130
 
3131
+ const short DK4 = DK/4;
3132
+ const short DK8 = DK/8;
3133
+ const short DK16 = DK/16;
3134
+ const short DV4 = DV/4;
3135
+ const short DV8 = DV/8;
3136
+ const short DV16 = DV/16;
3137
  const short NW = N_SIMDWIDTH;
3138
  const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
3139
 
3140
  const short TS = nsg*SH; // shared memory size per query in (s_t == float)
3141
+ const short T = DK + 2*TS; // shared memory size per query in (half)
3142
 
3143
+ threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3144
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
3145
+ threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
3146
+ threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
3147
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
3148
 
3149
  threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
3150
  threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
 
3153
  threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
3154
 
3155
  // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
3156
+ o8x8_t lo[DV8];
3157
 
3158
  // load heads from Q to shared memory
3159
  for (short j = sgitg; j < Q; j += nsg) {
3160
  device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
3161
 
3162
+ for (short i = tiisg; i < DK4; i += NW) {
3163
  if (iq1 + j < args.ne01) {
3164
+ sq4[j*DK4 + i] = (q4_t) q4[i];
3165
  } else {
3166
+ sq4[j*DK4 + i] = (q4_t) 0.0f;
3167
  }
3168
  }
3169
  }
3170
 
3171
  // zero out lo
3172
+ for (short i = 0; i < DV8; ++i) {
3173
  lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
3174
  }
3175
 
 
3199
  const short ikv2 = iq2/(args.ne02/args.ne_12_2);
3200
  const short ikv3 = iq3/(args.ne03/args.ne_12_3);
3201
 
 
 
 
 
 
 
 
3202
  const bool has_mask = mask != q;
3203
 
3204
  half slope = 1.0f;
 
3251
  // this is compile-time check, so it does not have runtime overhead
3252
  if (is_same<kd4x4_t, k4x4_t>::value) {
3253
  // we can read directly from global memory
3254
+ device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
3255
 
3256
+ #pragma unroll(DK8)
3257
+ for (short i = 0; i < DK8; ++i) {
3258
  k8x8_t mk;
3259
+ simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
3260
 
3261
+ q8x8_t mq;
3262
+ simdgroup_load(mq, sq + i*8, DK);
3263
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
3264
  }
3265
  } else {
3266
+ for (short ii = 0; ii < DK16; ii += 4) {
3267
+ device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
3268
 
3269
+ if (DK16%4 == 0) {
3270
  // the head is evenly divisible by 4*16 = 64, so no need for bound checks
3271
  {
3272
  k4x4_t tmp;
 
3279
  #pragma unroll(4)
3280
  for (short k = 0; k < 4; ++k) {
3281
  k8x8_t mk;
3282
+ q8x8_t mq;
3283
 
3284
  simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
3285
+ simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
3286
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
3287
 
3288
  simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
3289
+ simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
3290
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
3291
  }
3292
  } else {
3293
+ if (ii + tx < DK16) {
3294
  k4x4_t tmp;
3295
  deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
3296
  sk4x4[4*ty + tx] = tmp;
 
3298
 
3299
  simdgroup_barrier(mem_flags::mem_threadgroup);
3300
 
3301
+ for (short k = 0; k < 4 && ii + k < DK16; ++k) {
3302
  k8x8_t mk;
3303
+ q8x8_t mq;
3304
 
3305
  simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
3306
+ simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
3307
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
3308
 
3309
  simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
3310
+ simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
3311
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
3312
  }
3313
  }
3314
  }
 
3360
  s8x8_t mm;
3361
  simdgroup_load(mm, ss + 2*C, TS, 0, false);
3362
 
3363
+ #pragma unroll(DV8)
3364
+ for (short i = 0; i < DV8; ++i) {
3365
  simdgroup_multiply(lo[i], mm, lo[i]);
3366
  }
3367
  }
 
3374
 
3375
  if (is_same<vd4x4_t, v4x4_t>::value) {
3376
  // we can read directly from global memory
3377
+ device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
3378
 
3379
+ #pragma unroll(DV8)
3380
+ for (short i = 0; i < DV8; ++i) {
3381
  v8x8_t mv;
3382
+ simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
3383
 
3384
  simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
3385
  }
3386
  } else {
3387
+ for (short ii = 0; ii < DV16; ii += 4) {
3388
+ device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
3389
 
3390
+ if (DV16%4 == 0) {
3391
  // no need for bound checks
3392
  {
3393
  v4x4_t tmp;
 
3408
  simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
3409
  }
3410
  } else {
3411
+ if (ii + tx < DV16) {
3412
  v4x4_t tmp;
3413
  deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
3414
  sv4x4[4*ty + tx] = tmp;
 
3416
 
3417
  simdgroup_barrier(mem_flags::mem_threadgroup);
3418
 
3419
+ for (short k = 0; k < 4 && ii + k < DV16; ++k) {
3420
  v8x8_t mv;
3421
 
3422
  simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
 
3450
 
3451
  // each simdgroup stores its output to shared memory, reusing sq
3452
  if (sgitg == sg) {
3453
+ for (short i = 0; i < DV8; ++i) {
3454
+ simdgroup_store(lo[i], so + i*8, DV, 0, false);
3455
  }
3456
  }
3457
 
 
3490
  simdgroup_load(ms0, ss + 2*C, TS, 0, false);
3491
  simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
3492
 
3493
+ #pragma unroll(DV8)
3494
+ for (short i = 0; i < DV8; ++i) {
3495
  o8x8_t t;
3496
 
3497
+ simdgroup_load (t, so + i*8, DV, 0, false);
3498
  simdgroup_multiply(t, ms1, t);
3499
 
3500
  simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
 
3505
 
3506
  // store result to shared memory (reuse sq)
3507
  if (sgitg == 0) {
3508
+ for (short i = 0; i < DV8; ++i) {
3509
+ simdgroup_store(lo[i], so + i*8, DV, 0, false);
3510
  }
3511
  }
3512
 
 
3517
  for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
3518
  const float S = ss[j*TS + 0];
3519
 
3520
+ for (short i = tiisg; i < DV4; i += NW) {
3521
+ dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
3522
  }
3523
  }
3524
  }
 
3535
  float, simdgroup_float8x8, \
3536
  half, half4, simdgroup_half8x8
3537
 
3538
+ typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
3539
 
3540
+ template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
3541
+ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
3542
+ template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
3543
+ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
3544
+ template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128, 128>;
3545
+ template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
3546
+ template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
3547
+ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
3548
 
3549
  #if defined(GGML_METAL_USE_BF16)
3550
+ template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
3551
+ template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
3552
+ template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
3553
+ template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
3554
+ template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
3555
+ template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
3556
+ template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
3557
+ template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
3558
  #endif
3559
 
3560
+ template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
3561
+ template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
3562
+ template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
3563
+ template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
3564
+ template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128, 128>;
3565
+ template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
3566
+ template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
3567
+ template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
3568
+
3569
+ template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
3570
+ template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
3571
+ template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
3572
+ template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
3573
+ template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128, 128>;
3574
+ template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
3575
+ template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
3576
+ template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
3577
+
3578
+ template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
3579
+ template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
3580
+ template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
3581
+ template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
3582
+ template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128, 128>;
3583
+ template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
3584
+ template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
3585
+ template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
3586
+
3587
+ template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
3588
+ template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
3589
+ template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
3590
+ template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
3591
+ template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128, 128>;
3592
+ template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
3593
+ template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
3594
+ template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
3595
+
3596
+ template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
3597
+ template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
3598
+ template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
3599
+ template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
3600
+ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>;
3601
+ template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
3602
+ template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
3603
+ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
3604
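
The [[host_name]] suffix encodes the head sizes: hN when the K and V head sizes match, hkN_hvM when they differ, so the host side can select a pipeline purely by formatting the name. A minimal sketch of that mapping (the fa_kernel_name helper is illustrative, not the actual ggml-metal.m lookup code):

    #include <stdio.h>

    // illustrative: compose the Metal kernel name for given K/V head sizes
    static void fa_kernel_name(char * buf, size_t n, const char * type, int dk, int dv) {
        if (dk == dv) {
            snprintf(buf, n, "kernel_flash_attn_ext_%s_h%d", type, dk);
        } else {
            snprintf(buf, n, "kernel_flash_attn_ext_%s_hk%d_hv%d", type, dk, dv);
        }
    }

    // fa_kernel_name(buf, sizeof buf, "q8_0", 192, 128)
    //   -> "kernel_flash_attn_ext_q8_0_hk192_hv128"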
 
 #undef FA_TYPES
 
 template<
+ typename q4_t, // query types in shared memory
+ typename k4_t, // key types in shared memory
+ typename v4_t, // value types in shared memory
+ typename qk_t, // Q*K types
+ typename s_t, // soft-max types
 typename s4_t,
+ typename o4_t, // attention accumulation types
+ typename kd4_t, // key type in device memory
 short nl_k,
+ void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
+ typename vd4_t, // value type in device memory
 short nl_v,
+ void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
+ short DK, // K head size
+ short DV, // V head size
+ short NE = 4, // head elements per thread
+ short Q = 1, // queries per threadgroup
+ short C = 32> // cache items per threadgroup
 kernel void kernel_flash_attn_ext_vec(
 constant ggml_metal_kargs_flash_attn_ext & args,
 device const char * q,
 
 const int iq2 = tgpig[1];
 const int iq1 = tgpig[0];
 
+ const short DK4 = DK/4;
+ const short DV4 = DV/4;
 const short NW = N_SIMDWIDTH;
+ const short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
+ const short SH = 2*C; // shared memory per simdgroup
 
+ const short T = DK + nsg*SH; // shared memory size per query in (half)
 
+ //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
+ threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask
+ threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
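 // (layout implied by the pointers above, in half units, with Q = 1:
 //  [0, DK)       - query data (sq4)
 //  [DK, T)       - nsg slices of SH = 2*C halves: C attention scores (ss) + C mask values (sm)
 //  [T, T+nsg*DV) - per-simdgroup partial results (sr4))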
 
+ // store the result for all queries in local memory (the O matrix from the paper)
+ o4_t lo[DV4/NL];
 
 // load heads from Q to shared memory
 device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
 
+ for (short i = tiisg; i < DK4; i += NW) {
 if (iq1 < args.ne01) {
 sq4[i] = (q4_t) q4[i];
 } else {
 
 }
 
 // zero out lo
+ for (short i = 0; i < DV4/NL; ++i) {
+ lo[i] = (o4_t) 0.0f;
 }
 
 // zero out shared memory SH
 
 const short ikv2 = iq2/(args.ne02/args.ne_12_2);
 const short ikv3 = iq3/(args.ne03/args.ne_12_3);
 
 const bool has_mask = mask != q;
 
 // pointer to the mask
 
 // Q*K^T
 {
+ // each simdgroup processes 1 query and NE (NW/NL) head elements
+ for (short cc = 0; cc < C/NE; ++cc) {
+ qk_t mqk = 0.0f;
 
+ device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
 
+ #pragma unroll(DK4/NL)
+ for (short ii = 0; ii < DK4; ii += NL) {
 const short i = ii + tx;
 
+ k4_t mk;
+ deq_k_t4(pk + i/nl_k, i%nl_k, mk);
 
 // note: this is less precise than the version below
+ //mqka[0] += dot(mq[0], mk[0]);
+ //mqka[1] += dot(mq[1], mk[1]);
+ //mqka[2] += dot(mq[2], mk[2]);
+ //mqka[3] += dot(mq[3], mk[3]);
+
+ //q4x4_t mq = sq4x4[i];
+ //mqka[0] += dot((float4) mq[0], (float4) mk[0]);
+ //mqka[1] += dot((float4) mq[1], (float4) mk[1]);
+ //mqka[2] += dot((float4) mq[2], (float4) mk[2]);
+ //mqka[3] += dot((float4) mq[3], (float4) mk[3]);
+
+ mqk += dot((float4) mk, (float4) sq4[i]);
 }
 
+ static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails
 
+ // simdgroup reduce (NE = 4)
 // [ 0 .. 7] -> [ 0]
 // [ 8 .. 15] -> [ 8]
 // [16 .. 23] -> [16]
 // [24 .. 31] -> [24]
+ if (NE <= 1) {
+ mqk += simd_shuffle_down(mqk, 16);
+ }
+ if (NE <= 2) {
+ mqk += simd_shuffle_down(mqk, 8);
+ }
+ if (NE <= 4) {
+ mqk += simd_shuffle_down(mqk, 4);
+ }
+ if (NE <= 8) {
+ mqk += simd_shuffle_down(mqk, 2);
+ }
+ if (NE <= 16) {
+ mqk += simd_shuffle_down(mqk, 1);
+ }
 
 // mqk = mqk*scale + mask*slope
 if (tx == 0) {
 
 mqk = args.logit_softcap*precise::tanh(mqk);
 }
 
+ mqk += sm[NE*cc + ty]*slope;
 
+ ss[NE*cc + ty] = mqk;
 }
 }
 }
 
 ss[tiisg] = vs;
 
 // O = diag(ms)*O
+ #pragma unroll(DV4/NL)
+ for (short ii = 0; ii < DV4; ii += NL) {
 lo[ii/NL] *= ms;
 }
 }
 
 
 // O = O + (Q*K^T)*V
 {
+ //#pragma unroll(C/NE)
+ for (short cc = 0; cc < C/NE; ++cc) {
+ device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
 
+ const s4_t ms(ss[NE*cc + ty]);
 
+ #pragma unroll(DV4/NL)
+ for (short ii = 0; ii < DV4; ii += NL) {
 const short i = ii + tx;
 
+ v4_t mv;
+ deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
 
 lo[ii/NL] += mv*ms;
 }
 
 }
 }
 
+ // simdgroup reduce (NE = 4)
 // [ 0, 8, 16, 24] -> [ 0]
 // [ 1, 9, 17, 25] -> [ 1]
 // [ 2, 10, 18, 26] -> [ 2]
 // [ 3, 11, 19, 27] -> [ 3]
 // [ 4, 12, 20, 28] -> [ 4]
 // [ 5, 13, 21, 29] -> [ 5]
 // [ 6, 14, 22, 30] -> [ 6]
 // [ 7, 15, 23, 31] -> [ 7]
+ for (short ii = 0; ii < DV4; ii += NL) {
+ if (NE > 1) {
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
+ }
+
+ if (NE > 2) {
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
+ }
+
+ if (NE > 4) {
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
+ }
+
+ if (NE > 8) {
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
+ }
+
+ if (NE > 16) {
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
+ }
 }
 
 threadgroup_barrier(mem_flags::mem_threadgroup);
 
 // store results to shared memory
+ for (short i = tiisg; i < DV4; i += NL) {
+ sr4[i] = lo[i/NL];
 }
 
 threadgroup_barrier(mem_flags::mem_threadgroup);
 
 }
 
 // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+ for (short i = tiisg; i < DV4; i += NW) {
+ sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1;
 }
 }
 
 threadgroup_barrier(mem_flags::mem_threadgroup);
 }
 
+ device float4 * dst4 = (device float4 *) dst;
 
 // final rescale with 1/S and store to global memory
 if (sgitg == 0) {
 const float S = ss[0];
 
+ for (short i = tiisg; i < DV4; i += NW) {
+ dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
 }
 }
 }
 
 // in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
 //
 #define FA_TYPES \
+ half4, \
+ half4, \
+ half4, \
+ float, \
+ half, half4, \
+ half4
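 // (read against the template parameters above: q4_t/k4_t/v4_t = half4,
 //  qk_t = float, s_t = half, s4_t = half4, o4_t = half4)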
3958
 
+ typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 128>) flash_attn_ext_vec_t;
+
+ template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>;
+ #if defined(GGML_METAL_USE_BF16)
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 4>;
+ #endif
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 4>;
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 4>;
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 4>;
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 4>;
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 4>;
+
+ template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 4>;
+ #if defined(GGML_METAL_USE_BF16)
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 4>;
+ #endif
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 4>;
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 4>;
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 4>;
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 4>;
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 4>;
 
+ template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 4>;
 #if defined(GGML_METAL_USE_BF16)
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 4>;
 #endif
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 4>;
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 4>;
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 4>;
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 4>;
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 4>;
 
+ template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 4>;
 #if defined(GGML_METAL_USE_BF16)
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 4>;
 #endif
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 4>;
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 4>;
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 4>;
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 4>;
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 4>;
 
 #undef FA_TYPES
 
ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -8764,6 +8764,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
 default:
 return false;
 }
 if (op->src[0]->type != GGML_TYPE_F32) {
 return false;
 }
 
 default:
 return false;
 }
+ if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+ // different head sizes of K and V are not supported yet
+ return false;
+ }
 if (op->src[0]->type != GGML_TYPE_F32) {
 return false;
 }
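
For GGML_OP_FLASH_ATTN_EXT the sources are src[0] = q, src[1] = k, src[2] = v, and ne[0] is the head size, so this guard simply declines the new unequal K/V head sizes on Vulkan. A concrete case that now returns false here (shapes are example values only):

    // K: [192, n_kv, n_head_kv, 1]  ->  op->src[1]->ne[0] == 192
    // V: [128, n_kv, n_head_kv, 1]  ->  op->src[2]->ne[0] == 128
    // 192 != 128  =>  reported unsupported by the Vulkan backend for now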
ggml/src/ggml.c CHANGED
@@ -4369,7 +4369,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
 }
 
 // permute(0, 2, 1, 3)
- int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
 struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
 float params[] = { scale, max_bias, logit_softcap };
 
 }
 
 // permute(0, 2, 1, 3)
+ int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
 struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
 float params[] = { scale, max_bias, logit_softcap };
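
Since the result now takes ne[0] from v, a graph built with unequal K and V head sizes yields a correspondingly shaped output. A minimal usage sketch for DK = 192, DV = 128 (assumes a valid ggml_context * ctx and <math.h>; the dimension values are illustrative):

    const int n_batch = 4, n_head = 8, n_kv = 256;

    // shapes follow the ggml_flash_attn_ext header comment
    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 192, n_batch, n_head, 1);
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 192, n_kv,    n_head, 1);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 128, n_kv,    n_head, 1); // !! not transposed !!
    struct ggml_tensor * m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD), 1, 1);

    struct ggml_tensor * res = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(192.0f), 0.0f, 0.0f);
    // res->ne == { 128, n_head, n_batch, 1 } -- ne[0] follows v->ne[0], not q->ne[0]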