JohannesGaessler commited on
Commit
a3fe534
·
1 Parent(s): 6dbe297

CUDA: optimize and refactor MMQ (llama/8416)

Browse files

* CUDA: optimize and refactor MMQ

* explicit q8_1 memory layouts, add documentation

ggml/src/ggml-cuda/mma.cuh CHANGED
@@ -70,6 +70,10 @@ struct mma_int_A_I16K8 {
70
  }
71
  #endif // defined(INT8_MMA_AVAILABLE)
72
  }
 
 
 
 
73
  };
74
 
75
  struct mma_int_B_J8K4 {
 
70
  }
71
  #endif // defined(INT8_MMA_AVAILABLE)
72
  }
73
+
74
+ __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
75
+ ((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
76
+ }
77
  };
78
 
79
  struct mma_int_B_J8K4 {
ggml/src/ggml-cuda/mmq.cuh CHANGED
@@ -8,18 +8,70 @@
8
  #include <cstdint>
9
 
10
  #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
 
 
11
 
12
  typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride);
13
- typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0);
14
  typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
15
 
 
 
 
 
 
 
16
  struct block_q8_1_mmq {
17
- half2 ds[4];
18
- int8_t qs[4*QK8_1];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  };
20
  static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
21
  static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  struct tile_x_sizes {
24
  int qs;
25
  int dm;
@@ -79,49 +131,46 @@ static constexpr __device__ int get_mmq_y_device() {
79
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
80
  }
81
 
82
- #define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
83
- #define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
84
- #define MMQ_DP4A_TXS_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0}
85
- #define MMQ_DP4A_TXS_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0}
86
- #define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0}
87
- #define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
88
- #define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4}
89
- #define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
90
- #define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
91
- #define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
92
 
93
  static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
94
  return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
95
  type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
96
- type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
97
- type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
98
  type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
99
  type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
100
  type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
101
  type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
102
  type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
103
  type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
104
- type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q5_0 :
105
- type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q5_0 :
106
  tile_x_sizes{0, 0, 0};
107
  }
108
 
109
- #define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4)
110
- #define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1 + 4)
111
- #define MMQ_MMA_TILE_X_K_Q5_0 (2*WARP_SIZE + WARP_SIZE/QI5_0 + 4)
112
- #define MMQ_MMA_TILE_X_K_Q5_1 (2*WARP_SIZE + WARP_SIZE/QI5_1 + 4)
113
- #define MMQ_MMA_TILE_X_K_Q8_0 (1*WARP_SIZE + WARP_SIZE/QI8_0 + 0)
114
- #define MMQ_MMA_TILE_X_K_Q2_K (1*WARP_SIZE + WARP_SIZE + 4)
115
- #define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/QI3_K + WARP_SIZE/4 + 2)
116
- #define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7)
117
- #define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7)
118
- #define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)
119
 
120
  static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong padding.");
121
  static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding.");
122
- static_assert(MMQ_MMA_TILE_X_K_Q5_0 % 8 == 4, "Wrong padding.");
123
- static_assert(MMQ_MMA_TILE_X_K_Q5_1 % 8 == 4, "Wrong padding.");
124
  static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
 
125
  static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
126
  static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
127
  static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
@@ -131,21 +180,20 @@ static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
131
  static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
132
  return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
133
  type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
134
- type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
135
- type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
136
  type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
137
  type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
138
  type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
139
  type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
140
  type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
141
  type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
142
- type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q5_0 :
143
- type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q5_0 :
144
  0;
145
  }
146
 
147
  #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
148
- #define MMQ_NWARPS 8
149
 
150
  static int mmq_get_granularity_host(const int mmq_x, const int cc) {
151
  return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
@@ -218,7 +266,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
218
 
219
  template <int mmq_x, int mmq_y, int nwarps>
220
  static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
221
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
222
 
223
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
224
  const int * x_qs = (const int *) x;
@@ -226,34 +274,39 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
226
  const int * y_qs = (const int *) y + 4;
227
  const half2 * y_ds = (const half2 *) y;
228
 
 
 
 
 
229
  #pragma unroll
230
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
231
- const int j = j0 + threadIdx.y;
232
 
233
  #pragma unroll
234
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
235
- const int i = i0 + threadIdx.x;
236
 
237
- const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
238
 
239
- int u[2*VDR_Q4_0_Q8_1_MMQ];
240
 
241
  #pragma unroll
242
- for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
243
- u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
244
- u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_0) % WARP_SIZE];
245
- }
246
 
247
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
248
- (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
249
- y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
 
250
  }
251
  }
252
  }
253
 
254
  template <int mmq_x, int mmq_y, int nwarps>
255
  static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
256
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
257
  #ifdef INT8_MMA_AVAILABLE
258
 
259
  typedef mma_int_A_I16K8 mma_A;
@@ -271,52 +324,60 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
271
  const int * y_qs = (const int *) y + 4;
272
  const half2 * y_ds = (const half2 *) y;
273
 
274
- mma_A A[ntx];
275
- float dA[ntx][mma_C::ne/2];
276
 
277
  const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
278
 
279
  #pragma unroll
280
  for (int n = 0; n < ntx; ++n) {
281
  #pragma unroll
282
- for (int l = 0; l < mma_A::ne; ++l) {
283
- const int i = i0 + n*mma_A::I + mma_A::get_i(l);
284
- const int k = k0 + mma_A::get_k(l) % QI4_0;
285
- const int shift = 4*(mma_A::get_k(l) / QI4_0);
286
 
287
- A[n].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808);
288
- }
 
 
 
 
 
 
289
 
290
  #pragma unroll
291
- for (int l = 0; l < mma_C::ne/2; ++l) {
292
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
293
 
294
- dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/QI4_0];
 
295
  }
296
  }
297
 
298
  #pragma unroll
299
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
300
- mma_B B;
301
- float dB[mma_C::ne/2];
 
 
302
 
303
- B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
304
 
305
  #pragma unroll
306
- for (int l = 0; l < mma_C::ne/2; ++l) {
307
- const int j = j0 + mma_C::get_j(l);
308
 
309
- dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
310
- }
311
 
312
  #pragma unroll
313
- for (int n = 0; n < ntx; ++n) {
314
- mma_C C;
315
- C.mma_K8(A[n], B);
316
 
317
  #pragma unroll
318
- for (int l = 0; l < mma_C::ne; ++l) {
319
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l];
 
320
  }
321
  }
322
  }
@@ -381,7 +442,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
381
 
382
  template <int mmq_x, int mmq_y, int nwarps>
383
  static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
384
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
385
 
386
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
387
  const int * x_qs = (const int *) x;
@@ -389,34 +450,39 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
389
  const int * y_qs = (const int *) y + 4;
390
  const half2 * y_ds = (const half2 *) y;
391
 
 
 
 
 
392
  #pragma unroll
393
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
394
- const int j = j0 + threadIdx.y;
395
 
396
  #pragma unroll
397
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
398
- const int i = i0 + threadIdx.x;
399
 
400
- const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
401
 
402
- int u[2*VDR_Q4_1_Q8_1_MMQ];
403
 
404
  #pragma unroll
405
- for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
406
- u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
407
- u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_1) % WARP_SIZE];
408
- }
409
 
410
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
411
- (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
412
- y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
 
413
  }
414
  }
415
  }
416
 
417
  template <int mmq_x, int mmq_y, int nwarps>
418
  static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
419
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
420
  #ifdef INT8_MMA_AVAILABLE
421
 
422
  typedef mma_int_A_I16K8 mma_A;
@@ -435,50 +501,58 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
435
  const int * y_qs = (const int *) y + 4;
436
  const half2 * y_ds = (const half2 *) y;
437
 
438
- mma_A A[ntx];
439
- half2 dmA[ntx][mma_C::ne/2];
440
 
441
  const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
442
 
443
  #pragma unroll
444
  for (int n = 0; n < ntx; ++n) {
445
- ((mma_A_K4 *) &A[n])[0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0, MMQ_MMA_TILE_X_K_Q4_1);
446
- A[n].x[2] = (A[n].x[0] >> 4) & 0x0F0F0F0F;
447
- A[n].x[3] = (A[n].x[1] >> 4) & 0x0F0F0F0F;
448
- A[n].x[0] &= 0x0F0F0F0F;
449
- A[n].x[1] &= 0x0F0F0F0F;
 
 
 
 
450
 
451
  #pragma unroll
452
- for (int l = 0; l < mma_C::ne/2; ++l) {
453
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
454
 
455
- dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/QI4_1];
 
456
  }
457
  }
458
 
459
  #pragma unroll
460
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
461
- mma_B B;
462
- half2 dsB[mma_C::ne/2];
 
 
463
 
464
- B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
465
 
466
  #pragma unroll
467
- for (int l = 0; l < mma_C::ne/2; ++l) {
468
- const int j = j0 + mma_C::get_j(l);
469
 
470
- dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
471
- }
472
 
473
  #pragma unroll
474
- for (int n = 0; n < ntx; ++n) {
475
- mma_C C;
476
- C.mma_K8(A[n], B);
477
 
478
  #pragma unroll
479
- for (int l = 0; l < mma_C::ne; ++l) {
480
- const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2];
481
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
 
482
  }
483
  }
484
  }
@@ -531,8 +605,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
531
  qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
532
 
533
  #ifdef INT8_MMA_AVAILABLE
534
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
535
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
536
  #else
537
  x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
538
  x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
@@ -553,106 +627,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
553
  const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
554
 
555
  #ifdef INT8_MMA_AVAILABLE
556
- x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + kbxd] = bxi->d;
557
  #else
558
  x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
559
  #endif // INT8_MMA_AVAILABLE
560
  }
561
  }
562
 
563
- template <int mmq_x, int mmq_y, int nwarps>
564
- static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
565
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
566
-
567
- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
568
- const int * x_qs = (const int *) x;
569
- const float * x_df = (const float *) x_qs + txs.qs;
570
- const int * y_qs = (const int *) y + 4;
571
- const float * y_df = (const float *) y;
572
-
573
- #pragma unroll
574
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
575
- const int j = j0 + threadIdx.y;
576
-
577
- #pragma unroll
578
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
579
- const int i = i0 + threadIdx.x;
580
-
581
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
582
- (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE],
583
- x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
584
- }
585
- }
586
- }
587
-
588
- template <int mmq_x, int mmq_y, int nwarps>
589
- static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
590
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
591
- #ifdef INT8_MMA_AVAILABLE
592
-
593
- typedef mma_int_A_I16K8 mma_A;
594
- typedef mma_int_B_J8K8 mma_B;
595
- typedef mma_int_C_I16J8 mma_C;
596
-
597
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
598
- constexpr int rows_per_warp = 2 * granularity;
599
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
600
-
601
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
602
-
603
- const int * x_qs = (const int *) x;
604
- const float * x_df = (const float *) x_qs + WARP_SIZE*2;
605
- const int * y_qs = (const int *) y + 4;
606
- const float * y_df = (const float *) y;
607
-
608
- mma_A A[ntx];
609
- float dA[ntx][mma_C::ne/2];
610
-
611
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
612
-
613
- #pragma unroll
614
- for (int n = 0; n < ntx; ++n) {
615
- A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_0 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_0);
616
-
617
- #pragma unroll
618
- for (int l = 0; l < mma_C::ne/2; ++l) {
619
- const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I;
620
-
621
- dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + k0/QI5_0];
622
- }
623
- }
624
-
625
- #pragma unroll
626
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
627
- mma_B B;
628
- float dB[mma_C::ne/2];
629
-
630
- B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
631
-
632
- #pragma unroll
633
- for (int l = 0; l < mma_C::ne/2; ++l) {
634
- const int j = j0 + mma_C::get_j(l);
635
-
636
- dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
637
- }
638
-
639
- #pragma unroll
640
- for (int n = 0; n < ntx; ++n) {
641
- mma_C C;
642
- C.mma_K8(A[n], B);
643
-
644
- #pragma unroll
645
- for (int l = 0; l < mma_C::ne; ++l) {
646
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l];
647
- }
648
- }
649
- }
650
- #else
651
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
652
- NO_DEVICE_CODE;
653
- #endif // INT8_MMA_AVAILABLE
654
- }
655
-
656
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
657
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
658
 
@@ -694,8 +675,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
694
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
695
 
696
  #ifdef INT8_MMA_AVAILABLE
697
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
698
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
699
  #else
700
  x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
701
  x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
@@ -716,41 +697,101 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
716
  const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
717
 
718
  #ifdef INT8_MMA_AVAILABLE
719
- x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + kbxd] = bxi->dm;
720
  #else
721
  x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
722
  #endif // INT8_MMA_AVAILABLE
723
  }
724
  }
725
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
726
  template <int mmq_x, int mmq_y, int nwarps>
727
- static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
728
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
729
 
730
- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
731
  const int * x_qs = (const int *) x;
732
- const half2 * x_dm = (const half2 *) x_qs + txs.qs;
733
  const int * y_qs = (const int *) y + 4;
734
- const half2 * y_ds = (const half2 *) y;
 
 
 
 
735
 
736
  #pragma unroll
737
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
738
- const int j = j0 + threadIdx.y;
739
 
740
  #pragma unroll
741
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
742
- const int i = i0 + threadIdx.x;
743
 
744
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
745
- (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE],
746
- x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
 
747
  }
748
  }
749
  }
750
 
751
  template <int mmq_x, int mmq_y, int nwarps>
752
- static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
753
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
754
  #ifdef INT8_MMA_AVAILABLE
755
 
756
  typedef mma_int_A_I16K8 mma_A;
@@ -764,140 +805,106 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
764
  y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
765
 
766
  const int * x_qs = (const int *) x;
767
- const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
768
  const int * y_qs = (const int *) y + 4;
769
- const half2 * y_ds = (const half2 *) y;
770
 
771
- mma_A A[ntx];
772
- half2 dmA[ntx][mma_C::ne/2];
773
 
774
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
775
 
776
  #pragma unroll
777
  for (int n = 0; n < ntx; ++n) {
778
- A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_1 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_1);
779
-
780
  #pragma unroll
781
- for (int l = 0; l < mma_C::ne/2; ++l) {
782
- const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I;
783
 
784
- dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + k0/QI5_1];
785
  }
786
- }
787
-
788
- #pragma unroll
789
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
790
- mma_B B;
791
- half2 dsB[mma_C::ne/2];
792
-
793
- B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
794
 
795
  #pragma unroll
796
  for (int l = 0; l < mma_C::ne/2; ++l) {
797
- const int j = j0 + mma_C::get_j(l);
798
-
799
- dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
800
- }
801
 
802
  #pragma unroll
803
- for (int n = 0; n < ntx; ++n) {
804
- mma_C C;
805
- C.mma_K8(A[n], B);
806
 
807
- #pragma unroll
808
- for (int l = 0; l < mma_C::ne; ++l) {
809
- const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2];
810
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
811
  }
812
  }
813
  }
814
- #else
815
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
816
- NO_DEVICE_CODE;
817
- #endif // INT8_MMA_AVAILABLE
818
- }
819
 
820
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
821
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
 
 
822
 
823
- #ifdef INT8_MMA_AVAILABLE
824
- int * x_qs = (int *) x_tile;
825
- float * x_df = (float *) (x_tile + WARP_SIZE);
826
- #else
827
- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
828
- int * x_qs = (int *) x_tile;
829
- float * x_df = (float *) (x_qs + txs.qs);
830
- #endif // INT8_MMA_AVAILABLE
831
 
832
- const int kbx = threadIdx.x / QI8_0;
833
- const int kqsx = threadIdx.x % QI8_0;
834
 
835
  #pragma unroll
836
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
837
- int i = i0 + threadIdx.y;
838
-
839
- if (need_check) {
840
- i = min(i, i_max);
841
- }
842
-
843
- const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
844
 
845
- #ifdef INT8_MMA_AVAILABLE
846
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
847
- #else
848
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
849
- #endif // INT8_MMA_AVAILABLE
850
- }
851
 
852
- const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
853
- const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 
854
 
855
  #pragma unroll
856
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
857
- int i = i0 + threadIdx.y * QI8_0 + threadIdx.x / blocks_per_tile_x_row;
858
-
859
- if (need_check) {
860
- i = min(i, i_max);
861
  }
862
-
863
- const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
864
-
865
- #ifdef INT8_MMA_AVAILABLE
866
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
867
  #else
868
- x_df[i*(WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
 
869
  #endif // INT8_MMA_AVAILABLE
870
- }
871
  }
872
 
873
  template <int mmq_x, int mmq_y, int nwarps>
874
- static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
875
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
876
 
877
- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
878
  const int * x_qs = (const int *) x;
879
- const float * x_df = (const float *) x_qs + txs.qs;
880
  const int * y_qs = (const int *) y + 4;
881
- const float * y_df = (const float *) y;
 
 
 
 
882
 
883
  #pragma unroll
884
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
885
- const int j = j0 + threadIdx.y;
886
 
887
  #pragma unroll
888
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
889
- const int i = i0 + threadIdx.x;
890
 
891
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
892
- (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
893
- y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
 
894
  }
895
  }
896
  }
897
 
898
  template <int mmq_x, int mmq_y, int nwarps>
899
- static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
900
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
901
  #ifdef INT8_MMA_AVAILABLE
902
 
903
  typedef mma_int_A_I16K8 mma_A;
@@ -911,49 +918,65 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
911
  y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
912
 
913
  const int * x_qs = (const int *) x;
914
- const float * x_df = (const float *) x_qs + WARP_SIZE;
915
  const int * y_qs = (const int *) y + 4;
916
- const float * y_df = (const float *) y;
917
 
918
- mma_A A[ntx];
919
- float dA[ntx][mma_C::ne/2];
920
 
921
  const int i0 = (threadIdx.y/ntx)*rows_per_warp;
922
 
923
  #pragma unroll
924
  for (int n = 0; n < ntx; ++n) {
925
- A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
 
 
 
 
 
926
 
927
  #pragma unroll
928
  for (int l = 0; l < mma_C::ne/2; ++l) {
929
  const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
930
 
931
- dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
 
 
 
 
 
932
  }
933
  }
934
 
935
  #pragma unroll
936
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
937
- mma_B B;
938
- float dB[mma_C::ne/2];
 
 
 
 
939
 
940
- B.load(y_qs + j0*MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K);
941
 
942
  #pragma unroll
943
- for (int l = 0; l < mma_C::ne/2; ++l) {
944
- const int j = j0 + mma_C::get_j(l);
945
 
946
- dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1];
947
- }
948
 
949
  #pragma unroll
950
- for (int n = 0; n < ntx; ++n) {
951
- mma_C C;
952
- C.mma_K8(A[n], B);
953
 
954
  #pragma unroll
955
- for (int l = 0; l < mma_C::ne; ++l) {
956
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
 
 
957
  }
958
  }
959
  }
@@ -968,44 +991,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
968
 
969
  #ifdef INT8_MMA_AVAILABLE
970
  int * x_qs = (int *) x_tile;
971
- half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
972
  #else
973
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
974
  int * x_qs = (int *) x_tile;
975
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
976
  #endif // INT8_MMA_AVAILABLE
977
 
978
- const int kbx = threadIdx.x / QI2_K;
979
  const int kqsx = threadIdx.x % QI2_K;
980
 
981
  #pragma unroll
982
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
983
- int i = i0 + threadIdx.y;
984
 
985
  if (need_check) {
986
  i = min(i, i_max);
987
  }
988
 
989
- const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;
990
 
991
  const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
992
 
993
  #pragma unroll
994
  for (int l = 0; l < QR2_K; ++l) {
995
- const int k = kbx*QI2_K + (kqsx/8)*8 + l*2 + (kqsx % 8)/4;
996
 
997
- int x_qs_k = ((x_ql_0 >> (2*l)) & 0x03030303) << (2*(kqsx % 4));
998
- x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
999
- x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 2, WARP_SIZE);
1000
-
1001
- if (kqsx % QR2_K != 0) {
1002
- continue;
1003
- }
1004
 
1005
  #ifdef INT8_MMA_AVAILABLE
1006
  x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
1007
  #else
1008
- x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k;
1009
  #endif // INT8_MMA_AVAILABLE
1010
  }
1011
 
@@ -1018,44 +1034,68 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1018
  #endif // FAST_FP16_AVAILABLE
1019
 
1020
  #ifdef INT8_MMA_AVAILABLE
1021
- x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + threadIdx.x] = x_dm_ik;
1022
  #else
1023
- x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik;
1024
  #endif // INT8_MMA_AVAILABLE
1025
  }
1026
  }
1027
 
1028
  template <int mmq_x, int mmq_y, int nwarps>
1029
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1030
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1031
 
1032
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
1033
  const int * x_qs = (const int *) x;
1034
  const half2 * x_dm = (const half2 *) x_qs + txs.qs;
1035
  const int * y_qs = (const int *) y + 4;
1036
- const float * y_df = (const float *) y;
1037
 
 
1038
  #pragma unroll
1039
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1040
  const int j = j0 + threadIdx.y;
1041
 
 
 
 
1042
  #pragma unroll
1043
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1044
- const int i = i0 + threadIdx.x;
 
 
 
 
 
 
 
 
1045
 
1046
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
1047
- &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE],
1048
- &x_dm[i*(WARP_SIZE + 1) + k0], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
 
 
 
 
 
 
 
 
 
 
 
1049
  }
1050
  }
1051
  }
1052
 
1053
  template <int mmq_x, int mmq_y, int nwarps>
1054
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1055
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1056
  #ifdef INT8_MMA_AVAILABLE
1057
 
1058
  typedef mma_int_A_I16K4 mma_A;
 
1059
  typedef mma_int_B_J8K4 mma_B;
1060
  typedef mma_int_C_I16J8 mma_C;
1061
 
@@ -1066,74 +1106,107 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1066
  y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
1067
 
1068
  const int * x_qs = (const int *) x;
1069
- const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
1070
  const int * y_qs = (const int *) y + 4;
1071
- const float * y_df = (const float *) y;
1072
 
1073
  const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
1074
 
1075
- mma_A A[ntx][2];
1076
- float dA[ntx][mma_C::ne/2][2];
1077
- float mA[ntx][mma_C::ne/2][2];
1078
 
1079
  #pragma unroll
1080
  for (int n = 0; n < ntx; ++n) {
1081
  #pragma unroll
1082
- for (int l = 0; l < mma_A::ne; ++l) {
1083
- const int i = i0 + n*mma_A::I + mma_A::get_i(l);
1084
- const int shift = 2*mma_A::get_k(l);
1085
 
1086
- A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 0] >> shift) & 0x03030303;
1087
- A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 1] >> shift) & 0x03030303;
1088
  }
 
1089
 
 
 
1090
  #pragma unroll
1091
  for (int l = 0; l < mma_C::ne/2; ++l) {
1092
  const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
1093
 
1094
  #pragma unroll
1095
- for (int kdm = 0; kdm < 2; ++kdm) {
1096
- const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + kdm]);
1097
 
1098
- dA[n][l][kdm] = dm.x;
1099
- mA[n][l][kdm] = dm.y;
 
 
1100
  }
1101
  }
1102
  }
1103
 
1104
  #pragma unroll
1105
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
1106
- mma_B B[2];
1107
- float dB[mma_C::ne/2];
1108
-
1109
- B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + 0) % WARP_SIZE, MMQ_TILE_Y_K);
1110
- B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K);
1111
 
1112
  #pragma unroll
1113
  for (int l = 0; l < mma_C::ne/2; ++l) {
1114
  const int j = j0 + mma_C::get_j(l);
1115
 
1116
- dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
1117
  }
1118
 
1119
- mma_C Cm[2];
1120
- mma_A A1;
1121
- A1.x[0] = 0x01010101;
1122
- A1.x[1] = 0x01010101;
1123
- Cm[0].mma_K4(A1, B[0]);
1124
- Cm[1].mma_K4(A1, B[1]);
 
 
 
 
 
 
 
 
 
1125
 
1126
  #pragma unroll
1127
- for (int n = 0; n < ntx; ++n) {
1128
- mma_C Cd[2];
1129
 
1130
- Cd[0].mma_K4(A[n][0], B[0]);
1131
- Cd[1].mma_K4(A[n][1], B[1]);
1132
 
1133
  #pragma unroll
1134
- for (int l = 0; l < mma_C::ne; ++l) {
1135
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += (
1136
- Cd[0].x[l]*dA[n][l/2][0] + Cd[1].x[l]*dA[n][l/2][1] - Cm[0].x[l]*mA[n][l/2][0] - Cm[1].x[l]*mA[n][l/2][1])*dB[l%2];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1137
  }
1138
  }
1139
  }
@@ -1149,7 +1222,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1149
  #ifdef INT8_MMA_AVAILABLE
1150
  int * x_qs = (int *) x_tile;
1151
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
1152
- int * x_sc = (int *) (x_df + WARP_SIZE/QI3_K);
1153
  #else
1154
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
1155
  int * x_qs = (int *) x_tile;
@@ -1157,75 +1230,66 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1157
  int * x_sc = (int *) (x_df + txs.dm);
1158
  #endif // INT8_MMA_AVAILABLE
1159
 
1160
- const int kbx = threadIdx.x / QI3_K;
1161
  const int kqsx = threadIdx.x % QI3_K;
1162
 
1163
  #pragma unroll
1164
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1165
- int i = i0 + threadIdx.y;
1166
 
1167
  if (need_check) {
1168
  i = min(i, i_max);
1169
  }
1170
 
1171
- const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;
1172
 
1173
  const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
1174
  const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
1175
 
1176
  #pragma unroll
1177
  for (int l = 0; l < QR3_K; ++l) {
1178
- const int k = kbx*(QR3_K*QI3_K) + (kqsx/8)*32 + l*8 + kqsx % 8;
1179
 
1180
  const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303;
1181
  const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404;
1182
 
1183
- int x_qs_k = (x_ql_k | x_qh_k) << (4*(k%2));
1184
- x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
1185
-
1186
- if (kqsx % 2 != 0) {
1187
- continue;
1188
- }
1189
 
1190
  #ifdef INT8_MMA_AVAILABLE
1191
- x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2] = x_qs_k;
1192
  #else
1193
- x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k;
1194
  #endif // INT8_MMA_AVAILABLE
1195
  }
1196
  }
1197
 
1198
- const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
1199
- const int kbxd = threadIdx.x % blocks_per_tile_x_row;
1200
-
1201
  #pragma unroll
1202
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
1203
- int i = (i0 + threadIdx.y * QI3_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
1204
 
1205
  if (need_check) {
1206
  i = min(i, i_max);
1207
  }
1208
 
1209
- const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd;
1210
 
1211
  #ifdef INT8_MMA_AVAILABLE
1212
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + kbxd] = bxi->d;
1213
  #else
1214
- x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbxd] = bxi->d;
1215
  #endif // INT8_MMA_AVAILABLE
1216
  }
1217
 
1218
  #pragma unroll
1219
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
1220
- int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
1221
 
1222
  if (need_check) {
1223
  i = min(i, i_max);
1224
  }
1225
 
1226
- const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI3_K/4);
1227
 
1228
- const int ksc = threadIdx.x % (QI3_K/4);
1229
 
1230
  const int ksc_low = ksc % (QI3_K/8);
1231
  const int shift_low = 4 * (ksc / (QI3_K/8));
@@ -1238,16 +1302,16 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1238
  const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
1239
 
1240
  #ifdef INT8_MMA_AVAILABLE
1241
- x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/4)] = sc;
1242
  #else
1243
- x_sc[i*(WARP_SIZE/4) + i/4 + threadIdx.x % (WARP_SIZE/4)] = sc;
1244
  #endif // INT8_MMA_AVAILABLE
1245
  }
1246
  }
1247
 
1248
  template <int mmq_x, int mmq_y, int nwarps>
1249
  static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
1250
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1251
 
1252
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
1253
  const int * x_qs = (const int *) x;
@@ -1256,32 +1320,35 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
1256
  const int * y_qs = (const int *) y + 4;
1257
  const float * y_df = (const float *) y;
1258
 
1259
- #pragma unroll
1260
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1261
- const int j = j0 + threadIdx.y;
1262
 
1263
  #pragma unroll
1264
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1265
- const int i = i0 + threadIdx.x;
1266
 
1267
- const int kbx = k0 / QI3_K;
1268
- const int ky = (k0 % QI3_K) * QR3_K;
 
1269
 
1270
- const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
1271
 
1272
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
1273
- &x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales,
1274
- x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]);
 
1275
  }
1276
  }
1277
  }
1278
 
1279
  template <int mmq_x, int mmq_y, int nwarps>
1280
  static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
1281
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1282
  #ifdef INT8_MMA_AVAILABLE
1283
 
1284
  typedef mma_int_A_I16K4 mma_A;
 
1285
  typedef mma_int_B_J8K4 mma_B;
1286
  typedef mma_int_C_I16J8 mma_C;
1287
 
@@ -1293,73 +1360,74 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
1293
 
1294
  const int * x_qs = (const int *) x;
1295
  const float * x_df = (const float *) x_qs + WARP_SIZE*2;
1296
- const int * x_sc = (const int *) x_df + WARP_SIZE/QI3_K;
1297
  const int * y_qs = (const int *) y + 4;
1298
  const float * y_df = (const float *) y;
1299
 
1300
  const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
1301
 
1302
- mma_A A[ntx][2];
1303
- int scA[ntx][mma_C::ne/2][2];
1304
  float dA[ntx][mma_C::ne/2];
1305
 
1306
  #pragma unroll
1307
  for (int n = 0; n < ntx; ++n) {
1308
  #pragma unroll
1309
- for (int l = 0; l < mma_A::ne; ++l) {
1310
- const int i = i0 + n*mma_A::I + mma_A::get_i(l);
1311
- const int k = QR3_K*k0 + mma_A::get_k(l);
1312
 
1313
- A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + 0] >> (4*(k%2))) & 0x0F0F0F0F;
1314
- A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F;
1315
- A[n][0].x[l] = __vsubss4(A[n][0].x[l], 0x04040404);
1316
- A[n][1].x[l] = __vsubss4(A[n][1].x[l], 0x04040404);
1317
  }
1318
 
1319
  #pragma unroll
1320
  for (int l = 0; l < mma_C::ne/2; ++l) {
1321
  const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
1322
 
1323
- const int kbx = k0 / QI3_K;
1324
- const int ky = (k0 % QI3_K) * QR3_K;
1325
- const int8_t * sc = ((const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q3_K + kbx*4)) + ky/4;
1326
 
1327
- scA[n][l][0] = sc[0];
1328
- scA[n][l][1] = sc[1];
1329
- }
1330
 
1331
  #pragma unroll
1332
- for (int l = 0; l < mma_C::ne/2; ++l) {
1333
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
 
1334
 
1335
- dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/QI3_K];
1336
  }
1337
  }
1338
 
1339
  #pragma unroll
1340
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
1341
- mma_B B[2];
1342
- float dB[mma_C::ne/2];
 
 
1343
 
1344
- B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + 0) % WARP_SIZE, MMQ_TILE_Y_K);
1345
- B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K);
1346
 
1347
  #pragma unroll
1348
- for (int l = 0; l < mma_C::ne/2; ++l) {
1349
- const int j = j0 + mma_C::get_j(l);
1350
 
1351
- dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
1352
- }
1353
 
1354
  #pragma unroll
1355
- for (int n = 0; n < ntx; ++n) {
1356
- mma_C C[2];
1357
- C[0].mma_K4(A[n][0], B[0]);
1358
- C[1].mma_K4(A[n][1], B[1]);
1359
 
1360
  #pragma unroll
1361
- for (int l = 0; l < mma_C::ne; ++l) {
1362
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += (C[0].x[l]*scA[n][l/2][0] + C[1].x[l]*scA[n][l/2][1])*dA[n][l/2]*dB[l%2];
 
 
1363
  }
1364
  }
1365
  }
@@ -1451,7 +1519,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1451
 
1452
  template <int mmq_x, int mmq_y, int nwarps>
1453
  static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1454
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1455
 
1456
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
1457
  const int * x_qs = (const int *) x;
@@ -1460,26 +1528,31 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1460
  const int * y_qs = (const int *) y + 4;
1461
  const half2 * y_ds = (const half2 *) y;
1462
 
 
 
 
 
1463
  #pragma unroll
1464
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1465
- const int j = j0 + threadIdx.y;
1466
 
1467
  #pragma unroll
1468
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1469
- const int i = i0 + threadIdx.x;
1470
 
1471
- const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
1472
 
1473
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
1474
- &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8,
1475
- x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]);
 
1476
  }
1477
  }
1478
  }
1479
 
1480
  template <int mmq_x, int mmq_y, int nwarps>
1481
  static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
1482
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1483
  #ifdef INT8_MMA_AVAILABLE
1484
 
1485
  typedef mma_int_A_I16K8 mma_A;
@@ -1500,35 +1573,40 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
1500
 
1501
  const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
1502
 
1503
- mma_A A[ntx][2];
1504
- int scA[ntx][mma_C::ne/2][2];
1505
- int mA[ntx][mma_C::ne/2][2];
1506
  half2 dmA[ntx][mma_C::ne/2];
1507
 
1508
  #pragma unroll
1509
  for (int n = 0; n < ntx; ++n) {
1510
  #pragma unroll
1511
- for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 8) {
1512
- A[n][kvdr/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0, MMQ_MMA_TILE_X_K_Q4_K);
 
 
1513
 
1514
  #pragma unroll
1515
  for (int l = 0; l < mma_A::ne; ++l) {
1516
- A[n][kvdr/4 + 1].x[l] = (A[n][kvdr/4 + 0].x[l] >> 4) & 0x0F0F0F0F;
1517
- A[n][kvdr/4 + 0].x[l] &= 0x0F0F0F0F;
1518
  }
1519
  }
1520
 
1521
  #pragma unroll
1522
- for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
1523
- #pragma unroll
1524
- for (int l = 0; l < mma_C::ne/2; ++l) {
1525
- const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
 
1526
 
1527
- const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + k0/16]) + 2 * ((k0 % 16) / 8);
1528
- const uint8_t * m = sc + 8;
1529
 
1530
- scA[n][l][kvdr/4] = sc[kvdr/4];
1531
- mA[n][l][kvdr/4] = m[kvdr/4];
 
 
1532
  }
1533
  }
1534
 
@@ -1536,7 +1614,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
1536
  for (int l = 0; l < mma_C::ne/2; ++l) {
1537
  const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
1538
 
1539
- dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + k0/QI4_K];
1540
  }
1541
  }
1542
 
@@ -1546,28 +1624,28 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
1546
  float tmpm[ntx][mma_C::ne] = {{0.0f}};
1547
 
1548
  #pragma unroll
1549
- for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
1550
  mma_B B;
1551
  half2 dsB[mma_C::ne/2];
1552
 
1553
- B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K);
1554
 
1555
  #pragma unroll
1556
  for (int l = 0; l < mma_C::ne/2; ++l) {
1557
  const int j = j0 + mma_C::get_j(l);
1558
 
1559
- dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
1560
  }
1561
 
1562
  #pragma unroll
1563
  for (int n = 0; n < ntx; ++n) {
1564
  mma_C C;
1565
- C.mma_K8(A[n][kvdr/4], B);
1566
 
1567
  #pragma unroll
1568
  for (int l = 0; l < mma_C::ne; ++l) {
1569
- tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) * __low2float(dsB[l%2]);
1570
- tmpm[n][l] += mA[n][l/2][kvdr/4] * __high2float(dsB[l%2]);
1571
  }
1572
  }
1573
  }
@@ -1682,7 +1760,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1682
 
1683
  template <int mmq_x, int mmq_y, int nwarps>
1684
  static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1685
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1686
 
1687
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
1688
  const int * x_qs = (const int *) x;
@@ -1691,26 +1769,31 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1691
  const int * y_qs = (const int *) y + 4;
1692
  const half2 * y_ds = (const half2 *) y;
1693
 
 
 
 
 
1694
  #pragma unroll
1695
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1696
- const int j = j0 + threadIdx.y;
1697
 
1698
  #pragma unroll
1699
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1700
- const int i = i0 + threadIdx.x;
1701
 
1702
- const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
1703
 
1704
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
1705
- &x_qs[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8,
1706
- x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]);
 
1707
  }
1708
  }
1709
  }
1710
 
1711
  template <int mmq_x, int mmq_y, int nwarps>
1712
  static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
1713
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1714
  #ifdef INT8_MMA_AVAILABLE
1715
 
1716
  typedef mma_int_A_I16K8 mma_A;
@@ -1731,26 +1814,34 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
1731
 
1732
  const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
1733
 
1734
- mma_A A[ntx][2];
1735
- int scA[ntx][mma_C::ne/2][2];
1736
- int mA[ntx][mma_C::ne/2][2];
1737
  half2 dmA[ntx][mma_C::ne/2];
1738
 
1739
  #pragma unroll
1740
  for (int n = 0; n < ntx; ++n) {
1741
  #pragma unroll
1742
- for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
1743
- A[n][kvdr/4].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + (QR5_K*k0 + QR5_K*kvdr), MMQ_MMA_TILE_X_K_Q5_K);
 
 
 
1744
 
1745
  #pragma unroll
1746
- for (int l = 0; l < mma_C::ne/2; ++l) {
1747
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
 
 
1748
 
1749
- const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + k0/16]) + 2 * ((k0 % 16) / 8);
1750
- const uint8_t * m = sc + 8;
1751
 
1752
- scA[n][l][kvdr/4] = sc[kvdr/4];
1753
- mA[n][l][kvdr/4] = m[kvdr/4];
 
 
1754
  }
1755
  }
1756
 
@@ -1758,7 +1849,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
1758
  for (int l = 0; l < mma_C::ne/2; ++l) {
1759
  const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
1760
 
1761
- dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + k0/QI5_K];
1762
  }
1763
  }
1764
 
@@ -1768,28 +1859,30 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
1768
  float tmpm[ntx][mma_C::ne] = {{0.0f}};
1769
 
1770
  #pragma unroll
1771
- for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
 
 
1772
  mma_B B;
1773
  half2 dsB[mma_C::ne/2];
1774
 
1775
- B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K);
1776
 
1777
  #pragma unroll
1778
  for (int l = 0; l < mma_C::ne/2; ++l) {
1779
  const int j = j0 + mma_C::get_j(l);
1780
 
1781
- dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
1782
  }
1783
 
1784
  #pragma unroll
1785
  for (int n = 0; n < ntx; ++n) {
1786
  mma_C C;
1787
- C.mma_K8(A[n][kvdr/4], B);
1788
 
1789
  #pragma unroll
1790
  for (int l = 0; l < mma_C::ne; ++l) {
1791
- tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) * __low2float(dsB[l%2]);
1792
- tmpm[n][l] += mA[n][l/2][kvdr/4] * __high2float(dsB[l%2]);
1793
  }
1794
  }
1795
  }
@@ -1896,7 +1989,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1896
 
1897
  template <int mmq_x, int mmq_y, int nwarps>
1898
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1899
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1900
 
1901
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
1902
  const int * x_qs = (const int *) x;
@@ -1905,26 +1998,31 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1905
  const int * y_qs = (const int *) y + 4;
1906
  const float * y_df = (const float *) y;
1907
 
 
 
 
 
1908
  #pragma unroll
1909
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1910
- const int j = j0 + threadIdx.y;
1911
 
1912
  #pragma unroll
1913
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1914
- const int i = i0 + threadIdx.x;
1915
 
1916
- const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
1917
 
1918
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
1919
- &x_qs[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
1920
- x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
 
1921
  }
1922
  }
1923
  }
1924
 
1925
  template <int mmq_x, int mmq_y, int nwarps>
1926
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1927
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1928
  #ifdef INT8_MMA_AVAILABLE
1929
 
1930
  typedef mma_int_A_I16K4 mma_A;
@@ -1945,25 +2043,35 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1945
 
1946
  const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
1947
 
1948
- mma_A A[ntx][4];
1949
- int scA[ntx][mma_C::ne/2][4];
1950
  float dA[ntx][mma_C::ne/2];
1951
 
1952
  #pragma unroll
1953
  for (int n = 0; n < ntx; ++n) {
1954
  #pragma unroll
1955
- for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
1956
- A[n][kvdr/2 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + 0), MMQ_MMA_TILE_X_K_Q6_K);
1957
- A[n][kvdr/2 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
 
 
 
 
 
 
 
1958
 
1959
  #pragma unroll
1960
  for (int l = 0; l < mma_C::ne/2; ++l) {
1961
  const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
1962
 
1963
- const int8_t * sc = ((const int8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/8]);
 
1964
 
1965
- scA[n][l][kvdr/2 + 0] = sc[kvdr/2 + 0];
1966
- scA[n][l][kvdr/2 + 1] = sc[kvdr/2 + 1];
 
 
1967
  }
1968
  }
1969
 
@@ -1971,7 +2079,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1971
  for (int l = 0; l < mma_C::ne/2; ++l) {
1972
  const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
1973
 
1974
- dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K + k0/QI6_K];
1975
  }
1976
  }
1977
 
@@ -1980,30 +2088,29 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1980
  float tmp[ntx][mma_C::ne] = {{0.0f}};
1981
 
1982
  #pragma unroll
1983
- for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
1984
  mma_B B[2];
1985
  float dB[mma_C::ne/2];
1986
 
1987
- const int k0B = (2*k0 + 2*kvdr) % WARP_SIZE;
1988
- B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k0B, MMQ_TILE_Y_K);
1989
- B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k0B, MMQ_TILE_Y_K);
1990
 
1991
  #pragma unroll
1992
  for (int l = 0; l < mma_C::ne/2; ++l) {
1993
  const int j = j0 + mma_C::get_j(l);
1994
 
1995
- dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
1996
  }
1997
 
1998
  #pragma unroll
1999
  for (int n = 0; n < ntx; ++n) {
2000
  mma_C C[2];
2001
- C[0].mma_K4(A[n][kvdr/2 + 0], B[0]);
2002
- C[1].mma_K4(A[n][kvdr/2 + 1], B[1]);
2003
 
2004
  #pragma unroll
2005
  for (int l = 0; l < mma_C::ne; ++l) {
2006
- tmp[n][l] += (C[0].x[l]*scA[n][l/2][kvdr/2 + 0] + C[1].x[l]*scA[n][l/2][kvdr/2 + 1])*dB[l%2];
2007
  }
2008
  }
2009
  }
@@ -2051,8 +2158,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2051
  const int2 v = get_int_from_table_16(aux_q4);
2052
  const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
2053
  #ifdef INT8_MMA_AVAILABLE
2054
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x;
2055
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y;
2056
  #else
2057
  x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
2058
  x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
@@ -2073,7 +2180,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2073
  const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
2074
 
2075
  #ifdef INT8_MMA_AVAILABLE
2076
- x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + kbxd] = __half2float(bxi->d);
2077
  #else
2078
  x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
2079
  #endif // INT8_MMA_AVAILABLE
@@ -2109,8 +2216,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2109
  const int2 v = get_int_from_table_16(aux_q4);
2110
  const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
2111
  #ifdef INT8_MMA_AVAILABLE
2112
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x;
2113
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y;
2114
  #else
2115
  x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
2116
  x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
@@ -2133,7 +2240,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2133
  | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
2134
 
2135
  #ifdef INT8_MMA_AVAILABLE
2136
- x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + threadIdx.x % 8] = d * (ls - 32);
2137
  #else
2138
  x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
2139
  #endif // INT8_MMA_AVAILABLE
@@ -2229,16 +2336,16 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2229
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
2230
  static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
2231
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
2232
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
2233
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2234
  };
2235
 
2236
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2237
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
2238
  static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
2239
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
2240
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2241
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2242
  };
2243
 
2244
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -2293,45 +2400,18 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2293
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
2294
  static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
2295
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
2296
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
2297
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2298
  };
2299
 
2300
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2301
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
2302
  static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
2303
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
2304
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
2305
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2306
  };
2307
 
2308
- static bool mmq_need_sum(const ggml_type type_x) {
2309
- switch (type_x) {
2310
- case GGML_TYPE_Q4_0:
2311
- case GGML_TYPE_Q4_1:
2312
- return true;
2313
- case GGML_TYPE_Q5_0:
2314
- return false;
2315
- case GGML_TYPE_Q5_1:
2316
- return true;
2317
- case GGML_TYPE_Q8_0:
2318
- case GGML_TYPE_Q2_K:
2319
- case GGML_TYPE_Q3_K:
2320
- return false;
2321
- case GGML_TYPE_Q4_K:
2322
- case GGML_TYPE_Q5_K:
2323
- return true;
2324
- case GGML_TYPE_Q6_K:
2325
- case GGML_TYPE_IQ4_XS:
2326
- case GGML_TYPE_IQ4_NL:
2327
- return false;
2328
- default:
2329
- GGML_ASSERT(false);
2330
- break;
2331
- }
2332
- return false;
2333
- }
2334
-
2335
  template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
2336
  static __device__ void mul_mat_q_process_tile(
2337
  const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
@@ -2339,10 +2419,7 @@ static __device__ void mul_mat_q_process_tile(
2339
  const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
2340
 
2341
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
2342
- constexpr int qr = ggml_cuda_type_traits<type>::qr;
2343
- constexpr int qi = ggml_cuda_type_traits<type>::qi;
2344
  constexpr int mmq_y = get_mmq_y_device();
2345
- constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
2346
  constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
2347
 
2348
  extern __shared__ char data_mul_mat_q[];
@@ -2357,7 +2434,7 @@ static __device__ void mul_mat_q_process_tile(
2357
  constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
2358
  #endif // INT8_MMA_AVAILABLE
2359
 
2360
- constexpr int blocks_per_warp = WARP_SIZE / qi;
2361
 
2362
  float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
2363
 
@@ -2366,29 +2443,40 @@ static __device__ void mul_mat_q_process_tile(
2366
 
2367
  const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
2368
 
2369
- for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) {
2370
-
2371
  load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
2372
 
2373
- #pragma unroll
2374
- for (int kr = 0; kr < qr; ++kr) {
2375
- const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + kr*sizeof(block_q8_1_mmq)/sizeof(int));
2376
  #pragma unroll
2377
  for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
2378
  int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
2379
 
2380
  tile_y[l] = by0[l];
2381
  }
 
2382
 
2383
- __syncthreads();
2384
 
2385
- // #pragma unroll // unrolling this loop causes too much register pressure
2386
- for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
2387
- vec_dot(tile_x, tile_y, sum, k0);
2388
- }
 
 
 
 
 
2389
 
2390
- __syncthreads();
 
2391
  }
 
 
 
 
 
 
2392
  }
2393
 
2394
  if (fixup) {
@@ -2424,7 +2512,6 @@ static __global__ void mul_mat_q(
2424
  }
2425
 
2426
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
2427
- constexpr int qi = ggml_cuda_type_traits<type>::qi;
2428
  constexpr int mmq_y = get_mmq_y_device();
2429
 
2430
  // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
@@ -2439,7 +2526,7 @@ static __global__ void mul_mat_q(
2439
  #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
2440
 
2441
  const int64_t blocks_per_ne00 = ne00 / qk;
2442
- constexpr int blocks_per_warp = WARP_SIZE / qi;
2443
 
2444
  const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
2445
  const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
@@ -2448,8 +2535,8 @@ static __global__ void mul_mat_q(
2448
  int64_t kbc = (int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x;
2449
  int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x;
2450
 
2451
- kbc -= (kbc % blocks_per_ne00) % blocks_per_warp;
2452
- kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_warp;
2453
 
2454
  // kb0 == k index when doing the matrix multiplication for an output tile.
2455
  int kb0_start = kbc % blocks_per_ne00;
@@ -2490,8 +2577,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
2490
 
2491
  constexpr int mmq_y = get_mmq_y_device();
2492
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
2493
- constexpr int qi = ggml_cuda_type_traits<type>::qi;
2494
- constexpr int blocks_per_warp = WARP_SIZE / qi;
2495
  const int64_t blocks_per_ne00 = ne00 / qk;
2496
 
2497
  float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
@@ -2501,15 +2587,18 @@ static __global__ void mul_mat_q_stream_k_fixup(
2501
 
2502
  bool any_fixup = false;
2503
 
2504
- const int bidx_start = (blockIdx.y*nty + blockIdx.x) * block_num_mmq / (gridDim.y*gridDim.x);
2505
- const int bidx_stop = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1;
 
 
 
2506
 
2507
  for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
2508
- int64_t kbc = (int64_t) bidx *blocks_per_ne00*ntx*nty / block_num_mmq;
2509
- int64_t kbc_stop = (int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq;
2510
 
2511
- kbc -= (kbc % blocks_per_ne00) % blocks_per_warp;
2512
- kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_warp;
2513
 
2514
  // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
2515
  if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
 
8
  #include <cstdint>
9
 
10
  #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
11
+ #define MMQ_ITER_K 256
12
+ #define MMQ_NWARPS 8
13
 
14
  typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride);
15
+ typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00);
16
  typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
17
 
18
+ enum mmq_q8_1_ds_layout {
19
+ MMQ_Q8_1_DS_LAYOUT_D4,
20
+ MMQ_Q8_1_DS_LAYOUT_DS4,
21
+ MMQ_Q8_1_DS_LAYOUT_D2S6,
22
+ };
23
+
24
  struct block_q8_1_mmq {
25
+ // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block.
26
+ // The y float data is first grouped as blocks of 128 values.
27
+ // These blocks are then treated as individual data values and transposed.
28
+ //
29
+ // To avoid shared memory bank conflicts each block is padded with 16 bytes.
30
+ // This padding is also used to store block scales/partial sums.
31
+ // The scales multiplied with the quantized data are equal to the unquantized values.
32
+ // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization)
33
+ // and are only needed for performance reasons.
34
+ //
35
+ // The exact data stored depends on the x data type.
36
+ union {
37
+ float d4[4]; // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3
38
+ half2 ds4[4]; // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3
39
+ half d2s6[8]; // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values,
40
+ // stored as d0,d1,s1,s2,s3,s4,s5,s6
41
+ };
42
+ int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
43
  };
44
  static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
45
  static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
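// Illustrative only, not part of this commit: hypothetical accessors showing how the three union
// layouts documented above could be read back (assumes <cuda_fp16.h> and the declarations above).
static __device__ __forceinline__ float  mmq_read_d4  (const block_q8_1_mmq & b, const int i32) {
    return b.d4[i32];                     // i32 in [0, 4): one float scale per 32 values
}
static __device__ __forceinline__ float2 mmq_read_ds4 (const block_q8_1_mmq & b, const int i32) {
    return __half22float2(b.ds4[i32]);    // .x = scale, .y = partial sum of the same 32 values
}
static __device__ __forceinline__ float  mmq_read_d2s6_scale(const block_q8_1_mmq & b, const int i64) {
    return __half2float(b.d2s6[i64]);     // i64 in [0, 2): one scale per 64 values
}
static __device__ __forceinline__ float  mmq_read_d2s6_sum  (const block_q8_1_mmq & b, const int i16) {
    return __half2float(b.d2s6[2 + i16]); // i16 in [0, 6): partial sums of the first 96 values, 16 values each
}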
46
 
47
+ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
48
+ switch (type_x) {
49
+ case GGML_TYPE_Q4_0:
50
+ case GGML_TYPE_Q4_1:
51
+ return MMQ_Q8_1_DS_LAYOUT_DS4;
52
+ case GGML_TYPE_Q5_0:
53
+ return MMQ_Q8_1_DS_LAYOUT_D4;
54
+ case GGML_TYPE_Q5_1:
55
+ return MMQ_Q8_1_DS_LAYOUT_DS4;
56
+ case GGML_TYPE_Q8_0:
57
+ return MMQ_Q8_1_DS_LAYOUT_D4;
58
+ case GGML_TYPE_Q2_K:
59
+ return MMQ_Q8_1_DS_LAYOUT_D2S6;
60
+ case GGML_TYPE_Q3_K:
61
+ return MMQ_Q8_1_DS_LAYOUT_D4;
62
+ case GGML_TYPE_Q4_K:
63
+ case GGML_TYPE_Q5_K:
64
+ return MMQ_Q8_1_DS_LAYOUT_DS4;
65
+ case GGML_TYPE_Q6_K:
66
+ case GGML_TYPE_IQ4_XS:
67
+ case GGML_TYPE_IQ4_NL:
68
+ return MMQ_Q8_1_DS_LAYOUT_D4;
69
+ default:
70
+ GGML_ASSERT(false);
71
+ break;
72
+ }
73
+ }
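// Hedged usage sketch, not part of this commit: the host side can use mmq_get_q8_1_ds_layout to pick
// how the q8_1 y data gets packed before launching MMQ; quantize_y_q8_1_for_mmq is a hypothetical name.
static void quantize_y_q8_1_for_mmq(const ggml_type type_x /*, src/dst pointers, sizes, stream */) {
    switch (mmq_get_q8_1_ds_layout(type_x)) {
        case MMQ_Q8_1_DS_LAYOUT_D4:   break; // 4 float scales per 128 values, no partial sums
        case MMQ_Q8_1_DS_LAYOUT_DS4:  break; // 4 half2 (scale, partial sum) pairs per 128 values
        case MMQ_Q8_1_DS_LAYOUT_D2S6: break; // 2 half scales + 6 half partial sums per 128 values
    }
}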
74
+
75
  struct tile_x_sizes {
76
  int qs;
77
  int dm;
 
131
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
132
  }
133
 
134
+ #define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
135
+ #define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
136
+ #define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
137
+ #define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
138
+ #define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
139
+ #define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8}
140
+ #define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
141
+ #define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
142
+ #define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
 
143
 
144
  static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
145
  return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
146
  type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
147
+ type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
148
+ type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
149
  type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
150
  type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
151
  type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
152
  type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
153
  type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
154
  type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
155
+ type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 :
156
+ type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 :
157
  tile_x_sizes{0, 0, 0};
158
  }
159
 
160
+ #define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4)
161
+ #define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1 + 4)
162
+ #define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
163
+ #define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
164
+ #define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4)
165
+ #define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/(2*QI3_K) + WARP_SIZE/8 + 7)
166
+ #define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7)
167
+ #define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7)
168
+ #define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)
 
169
 
170
  static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong padding.");
171
  static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding.");
 
 
172
  static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
173
+ static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
174
  static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
175
  static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
176
  static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
 
180
  static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
181
  return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
182
  type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
183
+ type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
184
+ type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
185
  type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
186
  type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
187
  type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
188
  type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
189
  type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
190
  type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
191
+ type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 :
192
+ type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 :
193
  0;
194
  }
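// Worked example (assuming the usual WARP_SIZE == 32 and QI8_0 == 8): MMQ_MMA_TILE_X_K_Q8_0
// evaluates to 2*32 + 2*32/8 + 4 = 76 ints per tile row, and 76 % 8 == 4, which is exactly the
// padding condition checked by the static_asserts above.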
195
 
196
  #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
 
197
 
198
  static int mmq_get_granularity_host(const int mmq_x, const int cc) {
199
  return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
 
266
 
267
  template <int mmq_x, int mmq_y, int nwarps>
268
  static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
269
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
270
 
271
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
272
  const int * x_qs = (const int *) x;
 
274
  const int * y_qs = (const int *) y + 4;
275
  const half2 * y_ds = (const half2 *) y;
276
 
277
+ // #pragma unroll
278
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
279
+ const int k0 = k00 + k01;
280
+
281
  #pragma unroll
282
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
283
+ const int j = j0 + threadIdx.y;
284
 
285
  #pragma unroll
286
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
287
+ const int i = i0 + threadIdx.x;
288
 
289
+ const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
290
 
291
+ int u[2*VDR_Q4_0_Q8_1_MMQ];
292
 
293
  #pragma unroll
294
+ for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
295
+ u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
296
+ u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
297
+ }
298
 
299
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
300
+ (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u,
301
+ x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
302
+ }
303
  }
304
  }
305
  }
306
 
307
  template <int mmq_x, int mmq_y, int nwarps>
308
  static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
309
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
310
  #ifdef INT8_MMA_AVAILABLE
311
 
312
  typedef mma_int_A_I16K8 mma_A;
 
324
  const int * y_qs = (const int *) y + 4;
325
  const half2 * y_ds = (const half2 *) y;
326
 
327
+ mma_A A[ntx][4];
328
+ float dA[ntx][mma_C::ne/2][4];
329
 
330
  const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
331
 
332
  #pragma unroll
333
  for (int n = 0; n < ntx; ++n) {
334
  #pragma unroll
335
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) {
336
+ const int k0 = k00 + k01;
 
 
337
 
338
+ #pragma unroll
339
+ for (int l = 0; l < mma_A::ne; ++l) {
340
+ const int i = i0 + n*mma_A::I + mma_A::get_i(l);
341
+ const int k = k0/QR4_0 + mma_A::get_k(l) % QI4_0;
342
+ const int shift = 4*(mma_A::get_k(l) / QI4_0);
343
+
344
+ A[n][k01/(QR4_0*QI4_0)].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808);
345
+ }
346
 
347
  #pragma unroll
348
+ for (int l = 0; l < mma_C::ne/2; ++l) {
349
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
350
 
351
+ dA[n][l][k01/(QR4_0*QI4_0)] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/(QR4_0*QI4_0)];
352
+ }
353
  }
354
  }
355
 
356
  #pragma unroll
357
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
358
+ #pragma unroll
359
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) {
360
+ mma_B B;
361
+ float dB[mma_C::ne/2];
362
 
363
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
364
 
365
  #pragma unroll
366
+ for (int l = 0; l < mma_C::ne/2; ++l) {
367
+ const int j = j0 + mma_C::get_j(l);
368
 
369
+ dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
370
+ }
371
 
372
  #pragma unroll
373
+ for (int n = 0; n < ntx; ++n) {
374
+ mma_C C;
375
+ C.mma_K8(A[n][k01/(QR4_0*QI4_0)], B);
376
 
377
  #pragma unroll
378
+ for (int l = 0; l < mma_C::ne; ++l) {
379
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2][k01/(QR4_0*QI4_0)]*dB[l%2]*C.x[l];
380
+ }
381
  }
382
  }
383
  }
 
442
 
443
  template <int mmq_x, int mmq_y, int nwarps>
444
  static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
445
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
446
 
447
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
448
  const int * x_qs = (const int *) x;
 
450
  const int * y_qs = (const int *) y + 4;
451
  const half2 * y_ds = (const half2 *) y;
452
 
453
+ // #pragma unroll
454
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
455
+ const int k0 = k00 + k01;
456
+
457
  #pragma unroll
458
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
459
+ const int j = j0 + threadIdx.y;
460
 
461
  #pragma unroll
462
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
463
+ const int i = i0 + threadIdx.x;
464
 
465
+ const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
466
 
467
+ int u[2*VDR_Q4_1_Q8_1_MMQ];
468
 
469
  #pragma unroll
470
+ for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
471
+ u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
472
+ u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
473
+ }
474
 
475
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
476
+ (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u,
477
+ x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
478
+ }
479
  }
480
  }
481
  }
482
 
483
  template <int mmq_x, int mmq_y, int nwarps>
484
  static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
485
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
486
  #ifdef INT8_MMA_AVAILABLE
487
 
488
  typedef mma_int_A_I16K8 mma_A;
 
501
  const int * y_qs = (const int *) y + 4;
502
  const half2 * y_ds = (const half2 *) y;
503
 
504
+ mma_A A[ntx][4];
505
+ half2 dmA[ntx][mma_C::ne/2][4];
506
 
507
  const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
508
 
509
  #pragma unroll
510
  for (int n = 0; n < ntx; ++n) {
511
+ #pragma unroll
512
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) {
513
+ const int k0 = k00 + k01;
514
+
515
+ A[n][k01/(QR4_1*QI4_1)].load_low(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0/QR4_1, MMQ_MMA_TILE_X_K_Q4_1);
516
+ A[n][k01/(QR4_1*QI4_1)].x[2] = (A[n][k01/(QR4_1*QI4_1)].x[0] >> 4) & 0x0F0F0F0F;
517
+ A[n][k01/(QR4_1*QI4_1)].x[3] = (A[n][k01/(QR4_1*QI4_1)].x[1] >> 4) & 0x0F0F0F0F;
518
+ A[n][k01/(QR4_1*QI4_1)].x[0] &= 0x0F0F0F0F;
519
+ A[n][k01/(QR4_1*QI4_1)].x[1] &= 0x0F0F0F0F;
520
 
521
  #pragma unroll
522
+ for (int l = 0; l < mma_C::ne/2; ++l) {
523
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
524
 
525
+ dmA[n][l][k01/(QR4_1*QI4_1)] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/(QR4_1*QI4_1)];
526
+ }
527
  }
528
  }
529
 
530
  #pragma unroll
531
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
532
+ #pragma unroll
533
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) {
534
+ mma_B B;
535
+ half2 dsB[mma_C::ne/2];
536
 
537
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
538
 
539
  #pragma unroll
540
+ for (int l = 0; l < mma_C::ne/2; ++l) {
541
+ const int j = j0 + mma_C::get_j(l);
542
 
543
+ dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1];
544
+ }
545
 
546
  #pragma unroll
547
+ for (int n = 0; n < ntx; ++n) {
548
+ mma_C C;
549
+ C.mma_K8(A[n][k01/(QR4_1*QI4_1)], B);
550
 
551
  #pragma unroll
552
+ for (int l = 0; l < mma_C::ne; ++l) {
553
+ const half2 dmA_dsB = dmA[n][l/2][k01/(QR4_1*QI4_1)]*dsB[l%2];
554
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
555
+ }
556
  }
557
  }
558
  }
 
605
  qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
606
 
607
  #ifdef INT8_MMA_AVAILABLE
608
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
609
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
610
  #else
611
  x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
612
  x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
 
627
  const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
628
 
629
  #ifdef INT8_MMA_AVAILABLE
630
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
631
  #else
632
  x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
633
  #endif // INT8_MMA_AVAILABLE
634
  }
635
  }
636
 
637
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
638
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
639
 
 
675
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
676
 
677
  #ifdef INT8_MMA_AVAILABLE
678
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
679
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
680
  #else
681
  x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
682
  x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
 
697
  const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
698
 
699
  #ifdef INT8_MMA_AVAILABLE
700
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
701
  #else
702
  x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
703
  #endif // INT8_MMA_AVAILABLE
704
  }
705
  }
706
 
707
+ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
708
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
709
+
710
+ #ifdef INT8_MMA_AVAILABLE
711
+ int * x_qs = (int *) x_tile;
712
+ float * x_df = (float *) (x_tile + 2*WARP_SIZE);
713
+ #else
714
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
715
+ int * x_qs = (int *) x_tile;
716
+ float * x_df = (float *) (x_qs + txs.qs);
717
+ #endif // INT8_MMA_AVAILABLE
718
+
719
+ const int kbx = threadIdx.x / QI8_0;
720
+ const int kqsx = threadIdx.x % QI8_0;
721
+
722
+ #pragma unroll
723
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
724
+ int i = i0 + threadIdx.y;
725
+
726
+ if (need_check) {
727
+ i = min(i, i_max);
728
+ }
729
+
730
+ const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
731
+
732
+ #ifdef INT8_MMA_AVAILABLE
733
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
734
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
735
+ #else
736
+ x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
737
+ x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
738
+ #endif // INT8_MMA_AVAILABLE
739
+ }
740
+
741
+ const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
742
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
743
+
744
+ #pragma unroll
745
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) {
746
+ int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row;
747
+
748
+ if (need_check) {
749
+ i = min(i, i_max);
750
+ }
751
+
752
+ const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
753
+
754
+ #ifdef INT8_MMA_AVAILABLE
755
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
756
+ #else
757
+ x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
758
+ #endif // INT8_MMA_AVAILABLE
759
+ }
760
+ }
761
+
762
  template <int mmq_x, int mmq_y, int nwarps>
763
+ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
764
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
765
 
766
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
767
  const int * x_qs = (const int *) x;
768
+ const float * x_df = (const float *) x_qs + txs.qs;
769
  const int * y_qs = (const int *) y + 4;
770
+ const float * y_df = (const float *) y;
771
+
772
+ // #pragma unroll
773
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
774
+ const int k0 = k00 + k01;
775
 
776
  #pragma unroll
777
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
778
+ const int j = j0 + threadIdx.y;
779
 
780
  #pragma unroll
781
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
782
+ const int i = i0 + threadIdx.x;
783
 
784
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
785
+ (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE],
786
+ x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]);
787
+ }
788
  }
789
  }
790
  }
791
 
792
  template <int mmq_x, int mmq_y, int nwarps>
793
+ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
794
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
795
  #ifdef INT8_MMA_AVAILABLE
796
 
797
  typedef mma_int_A_I16K8 mma_A;
 
805
  y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
806
 
807
  const int * x_qs = (const int *) x;
808
+ const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
809
  const int * y_qs = (const int *) y + 4;
810
+ const float * y_df = (const float *) y;
811
 
812
+ mma_A A[ntx][WARP_SIZE/QI8_0];
813
+ float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
814
 
815
+ const int i0 = (threadIdx.y/ntx)*rows_per_warp;
816
 
817
  #pragma unroll
818
  for (int n = 0; n < ntx; ++n) {
 
 
819
  #pragma unroll
820
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
821
+ const int k0 = k00 + k01;
822
 
823
+ A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
824
  }
 
 
 
 
 
 
 
 
825
 
826
  #pragma unroll
827
  for (int l = 0; l < mma_C::ne/2; ++l) {
828
+ const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
 
 
 
829
 
830
  #pragma unroll
831
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
832
+ const int k0 = k00 + k01;
 
833
 
834
+ dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
 
 
 
835
  }
836
  }
837
  }
 
 
 
 
 
838
 
839
+ #pragma unroll
840
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
841
+ #pragma unroll
842
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
843
+ const int k0 = k00 + k01;
844
 
845
+ mma_B B;
846
+ float dB[mma_C::ne/2];
 
 
 
 
 
 
847
 
848
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
 
849
 
850
  #pragma unroll
851
+ for (int l = 0; l < mma_C::ne/2; ++l) {
852
+ const int j = j0 + mma_C::get_j(l);
 
 
 
 
 
 
853
 
854
+ dB[l] = y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
855
+ }
 
 
 
 
856
 
857
+ #pragma unroll
858
+ for (int n = 0; n < ntx; ++n) {
859
+ mma_C C;
860
+ C.mma_K8(A[n][k01/QI8_0], B);
861
 
862
  #pragma unroll
863
+ for (int l = 0; l < mma_C::ne; ++l) {
864
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
865
+ }
866
+ }
 
867
  }
868
+ }
 
 
 
 
869
  #else
870
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
871
+ NO_DEVICE_CODE;
872
  #endif // INT8_MMA_AVAILABLE
 
873
  }
874
 
875
  template <int mmq_x, int mmq_y, int nwarps>
876
+ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
877
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
878
 
879
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
880
  const int * x_qs = (const int *) x;
881
+ const half2 * x_dm = (const half2 *) x_qs + txs.qs;
882
  const int * y_qs = (const int *) y + 4;
883
+ const half2 * y_ds = (const half2 *) y;
884
+
885
+ // #pragma unroll
886
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
887
+ const int k0 = k00 + k01;
888
 
889
  #pragma unroll
890
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
891
+ const int j = j0 + threadIdx.y;
892
 
893
  #pragma unroll
894
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
895
+ const int i = i0 + threadIdx.x;
896
 
897
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
898
+ (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
899
+ x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
900
+ }
901
  }
902
  }
903
  }
904
 
905
  template <int mmq_x, int mmq_y, int nwarps>
906
+ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
907
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
908
  #ifdef INT8_MMA_AVAILABLE
909
 
910
  typedef mma_int_A_I16K8 mma_A;
 
918
  y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
919
 
920
  const int * x_qs = (const int *) x;
921
+ const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
922
  const int * y_qs = (const int *) y + 4;
923
+ const half2 * y_dm = (const half2 *) y;
924
 
925
+ mma_A A[ntx][WARP_SIZE/QI8_1];
926
+ half2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
927
 
928
  const int i0 = (threadIdx.y/ntx)*rows_per_warp;
929
 
930
  #pragma unroll
931
  for (int n = 0; n < ntx; ++n) {
932
+ #pragma unroll
933
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
934
+ const int k0 = k00 + k01;
935
+
936
+ A[n][k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
937
+ }
938
 
939
  #pragma unroll
940
  for (int l = 0; l < mma_C::ne/2; ++l) {
941
  const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
942
 
943
+ #pragma unroll
944
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
945
+ const int k0 = k00 + k01;
946
+
947
+ dmA[n][l][k01/QI8_1] = x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1];
948
+ }
949
  }
950
  }
951
 
952
  #pragma unroll
953
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
954
+ #pragma unroll
955
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
956
+ const int k0 = k00 + k01;
957
+
958
+ mma_B B;
959
+ half2 dsB[mma_C::ne/2];
960
 
961
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
962
 
963
  #pragma unroll
964
+ for (int l = 0; l < mma_C::ne/2; ++l) {
965
+ const int j = j0 + mma_C::get_j(l);
966
 
967
+ dsB[l] = y_dm[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
968
+ }
969
 
970
  #pragma unroll
971
+ for (int n = 0; n < ntx; ++n) {
972
+ mma_C C;
973
+ C.mma_K8(A[n][k01/QI8_1], B);
974
 
975
  #pragma unroll
976
+ for (int l = 0; l < mma_C::ne; ++l) {
977
+ const half2 dmA_dsB = dmA[n][l/2][k01/QI8_1]*dsB[l%2];
978
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
979
+ }
980
  }
981
  }
982
  }
 
991
 
992
  #ifdef INT8_MMA_AVAILABLE
993
  int * x_qs = (int *) x_tile;
994
+ half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
995
  #else
996
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
997
  int * x_qs = (int *) x_tile;
998
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
999
  #endif // INT8_MMA_AVAILABLE
1000
 
 
1001
  const int kqsx = threadIdx.x % QI2_K;
1002
 
1003
  #pragma unroll
1004
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) {
1005
+ int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K;
1006
 
1007
  if (need_check) {
1008
  i = min(i, i_max);
1009
  }
1010
 
1011
+ const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride;
1012
 
1013
  const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
1014
 
1015
  #pragma unroll
1016
  for (int l = 0; l < QR2_K; ++l) {
1017
+ const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
1018
 
1019
+ const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
 
 
 
 
 
 
1020
 
1021
  #ifdef INT8_MMA_AVAILABLE
1022
  x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
1023
  #else
1024
+ x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
1025
  #endif // INT8_MMA_AVAILABLE
1026
  }
1027
 
 
1034
  #endif // FAST_FP16_AVAILABLE
1035
 
1036
  #ifdef INT8_MMA_AVAILABLE
1037
+ x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
1038
  #else
1039
+ x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik;
1040
  #endif // INT8_MMA_AVAILABLE
1041
  }
1042
  }
1043
 
1044
  template <int mmq_x, int mmq_y, int nwarps>
1045
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1046
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1047
 
1048
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
1049
  const int * x_qs = (const int *) x;
1050
  const half2 * x_dm = (const half2 *) x_qs + txs.qs;
1051
  const int * y_qs = (const int *) y + 4;
1052
+ const half2 * y_ds = (const half2 *) y;
1053
 
1054
+ float2 y_df[mmq_x/nwarps];
1055
  #pragma unroll
1056
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1057
  const int j = j0 + threadIdx.y;
1058
 
1059
+ y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
1060
+ }
1061
+
1062
  #pragma unroll
1063
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
1064
+ const int k0 = k00 + k01;
1065
+
1066
+ #pragma unroll
1067
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1068
+ const int j = j0 + threadIdx.y;
1069
+
1070
+ #pragma unroll
1071
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1072
+ const int i = i0 + threadIdx.x;
1073
 
1074
+ if (k01 < WARP_SIZE/2) {
1075
+ constexpr int ns = 2;
1076
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
1077
+ &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
1078
+ &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
1079
+ &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
1080
+ } else {
1081
+ constexpr int ns = 1;
1082
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
1083
+ &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
1084
+ &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
1085
+ &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
1086
+ }
1087
+ }
1088
  }
1089
  }
1090
  }
1091
 
1092
  template <int mmq_x, int mmq_y, int nwarps>
1093
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1094
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1095
  #ifdef INT8_MMA_AVAILABLE
1096
 
1097
  typedef mma_int_A_I16K4 mma_A;
1098
+ typedef mma_int_A_I16K8 mma_A_K8;
1099
  typedef mma_int_B_J8K4 mma_B;
1100
  typedef mma_int_C_I16J8 mma_C;
1101
 
 
1106
  y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
1107
 
1108
  const int * x_qs = (const int *) x;
1109
+ const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
1110
  const int * y_qs = (const int *) y + 4;
1111
+ const half2 * y_ds = (const half2 *) y;
1112
 
1113
  const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
1114
 
1115
+ mma_A A[ntx][8];
1116
+ float dA[ntx][mma_C::ne/2][8];
1117
+ float mA[ntx][mma_C::ne/2][8];
1118
 
1119
  #pragma unroll
1120
  for (int n = 0; n < ntx; ++n) {
1121
  #pragma unroll
1122
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1123
+ const int k0 = k00 + k01;
 
1124
 
1125
+ ((mma_A_K8 *) A[n])[k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
 
1126
  }
1127
+ }
1128
 
1129
+ #pragma unroll
1130
+ for (int n = 0; n < ntx; ++n) {
1131
  #pragma unroll
1132
  for (int l = 0; l < mma_C::ne/2; ++l) {
1133
  const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
1134
 
1135
  #pragma unroll
1136
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
1137
+ const int k0 = k00 + k01;
1138
 
1139
+ const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
1140
+
1141
+ dA[n][l][k01/(QI8_1/2)] = dm.x;
1142
+ mA[n][l][k01/(QI8_1/2)] = dm.y;
1143
  }
1144
  }
1145
  }
1146
 
1147
  #pragma unroll
1148
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
1149
+ float2 dB[mma_C::ne/2];
 
 
 
 
1150
 
1151
  #pragma unroll
1152
  for (int l = 0; l < mma_C::ne/2; ++l) {
1153
  const int j = j0 + mma_C::get_j(l);
1154
 
1155
+ dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
1156
  }
1157
 
1158
+ #pragma unroll
1159
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1160
+ mma_B B[2];
1161
+
1162
+ B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
1163
+ B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
1164
+
1165
+ mma_C Cm[2];
1166
+ if (k01 >= WARP_SIZE * 3/4) {
1167
+ mma_A A1;
1168
+ A1.x[0] = 0x01010101;
1169
+ A1.x[1] = 0x01010101;
1170
+ Cm[0].mma_K4(A1, B[0]);
1171
+ Cm[1].mma_K4(A1, B[1]);
1172
+ }
1173
 
1174
  #pragma unroll
1175
+ for (int n = 0; n < ntx; ++n) {
1176
+ mma_C Cd[2];
1177
 
1178
+ Cd[0].mma_K4(A[n][k01/4 + 0], B[0]);
1179
+ Cd[1].mma_K4(A[n][k01/4 + 1], B[1]);
1180
 
1181
  #pragma unroll
1182
+ for (int l = 0; l < mma_C::ne; ++l) {
1183
+ float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
1184
+ if (k01 >= WARP_SIZE * 3/4) {
1185
+ tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
1186
+ }
1187
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
1188
+ }
1189
+ }
1190
+ }
1191
+
1192
+ #pragma unroll
1193
+ for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
1194
+ float2 sB[mma_C::ne/2];
1195
+
1196
+ #pragma unroll
1197
+ for (int l = 0; l < mma_C::ne/2; ++l) {
1198
+ const int j = j0 + mma_C::get_j(l);
1199
+
1200
+ sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
1201
+ }
1202
+
1203
+ #pragma unroll
1204
+ for (int n = 0; n < ntx; ++n) {
1205
+ #pragma unroll
1206
+ for (int l = 0; l < mma_C::ne; ++l) {
1207
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
1208
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
1209
+ }
1210
  }
1211
  }
1212
  }
 
1222
  #ifdef INT8_MMA_AVAILABLE
1223
  int * x_qs = (int *) x_tile;
1224
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
1225
+ int * x_sc = (int *) (x_df + 1);
1226
  #else
1227
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
1228
  int * x_qs = (int *) x_tile;
 
1230
  int * x_sc = (int *) (x_df + txs.dm);
1231
  #endif // INT8_MMA_AVAILABLE
1232
 
 
1233
  const int kqsx = threadIdx.x % QI3_K;
1234
 
1235
  #pragma unroll
1236
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) {
1237
+ int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K;
1238
 
1239
  if (need_check) {
1240
  i = min(i, i_max);
1241
  }
1242
 
1243
+ const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
1244
 
1245
  const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
1246
  const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
1247
 
1248
  #pragma unroll
1249
  for (int l = 0; l < QR3_K; ++l) {
1250
+ const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
1251
 
1252
  const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303;
1253
  const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404;
1254
 
1255
+ const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
 
 
 
 
 
1256
 
1257
  #ifdef INT8_MMA_AVAILABLE
1258
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
1259
  #else
1260
+ x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
1261
  #endif // INT8_MMA_AVAILABLE
1262
  }
1263
  }
1264
 
 
 
 
1265
  #pragma unroll
1266
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
1267
+ int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
1268
 
1269
  if (need_check) {
1270
  i = min(i, i_max);
1271
  }
1272
 
1273
+ const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
1274
 
1275
  #ifdef INT8_MMA_AVAILABLE
1276
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K] = bxi->d;
1277
  #else
1278
+ x_df[i] = bxi->d;
1279
  #endif // INT8_MMA_AVAILABLE
1280
  }
1281
 
1282
  #pragma unroll
1283
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
1284
+ int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8);
1285
 
1286
  if (need_check) {
1287
  i = min(i, i_max);
1288
  }
1289
 
1290
+ const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
1291
 
1292
+ const int ksc = threadIdx.x % (WARP_SIZE/8);
1293
 
1294
  const int ksc_low = ksc % (QI3_K/8);
1295
  const int shift_low = 4 * (ksc / (QI3_K/8));
 
1302
  const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
1303
 
1304
  #ifdef INT8_MMA_AVAILABLE
1305
+ x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/8)] = sc;
1306
  #else
1307
+ x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
1308
  #endif // INT8_MMA_AVAILABLE
1309
  }
1310
  }
1311
 
1312
  template <int mmq_x, int mmq_y, int nwarps>
1313
  static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
1314
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1315
 
1316
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
1317
  const int * x_qs = (const int *) x;
 
1320
  const int * y_qs = (const int *) y + 4;
1321
  const float * y_df = (const float *) y;
1322
 
1323
+ // #pragma unroll
1324
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
1325
+ const int k0 = k00 + k01;
1326
 
1327
  #pragma unroll
1328
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1329
+ const int j = j0 + threadIdx.y;
1330
 
1331
+ #pragma unroll
1332
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1333
+ const int i = i0 + threadIdx.x;
1334
 
1335
+ const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4;
1336
 
1337
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
1338
+ &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
1339
+ x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
1340
+ }
1341
  }
1342
  }
1343
  }
1344
 
1345
  template <int mmq_x, int mmq_y, int nwarps>
1346
  static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
1347
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1348
  #ifdef INT8_MMA_AVAILABLE
1349
 
1350
  typedef mma_int_A_I16K4 mma_A;
1351
+ typedef mma_int_A_I16K8 mma_A_K8;
1352
  typedef mma_int_B_J8K4 mma_B;
1353
  typedef mma_int_C_I16J8 mma_C;
1354
 
 
1360
 
1361
  const int * x_qs = (const int *) x;
1362
  const float * x_df = (const float *) x_qs + WARP_SIZE*2;
1363
+ const int * x_sc = (const int *) x_df + 1;
1364
  const int * y_qs = (const int *) y + 4;
1365
  const float * y_df = (const float *) y;
1366
 
1367
  const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
1368
 
1369
+ mma_A A[ntx][8];
1370
+ int scA[ntx][mma_C::ne/2][8];
1371
  float dA[ntx][mma_C::ne/2];
1372
 
1373
  #pragma unroll
1374
  for (int n = 0; n < ntx; ++n) {
1375
  #pragma unroll
1376
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
1377
+ const int k0 = k00 + k01;
 
1378
 
1379
+ ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
 
 
 
1380
  }
1381
 
1382
  #pragma unroll
1383
  for (int l = 0; l < mma_C::ne/2; ++l) {
1384
  const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
1385
 
1386
+ #pragma unroll
1387
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
1388
+ const int k0 = k00 + k01;
1389
 
1390
+ const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + k0/16];
1391
+ const int8_t * sc = (const int8_t *) &sc_packed;
 
1392
 
1393
  #pragma unroll
1394
+ for (int ksc = 0; ksc < sizeof(int); ++ksc) {
1395
+ scA[n][l][k01/4 + ksc] = sc[ksc];
1396
+ }
1397
+ }
1398
 
1399
+ dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K];
1400
  }
1401
  }
1402
 
1403
  #pragma unroll
1404
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
1405
+ #pragma unroll
1406
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
1407
+ mma_B B[2];
1408
+ float dB[mma_C::ne/2];
1409
 
1410
+ B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
1411
+ B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
1412
 
1413
  #pragma unroll
1414
+ for (int l = 0; l < mma_C::ne/2; ++l) {
1415
+ const int j = j0 + mma_C::get_j(l);
1416
 
1417
+ dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
1418
+ }
1419
 
1420
  #pragma unroll
1421
+ for (int n = 0; n < ntx; ++n) {
1422
+ mma_C C[2];
1423
+ C[0].mma_K4(A[n][k01/4 + 0], B[0]);
1424
+ C[1].mma_K4(A[n][k01/4 + 1], B[1]);
1425
 
1426
  #pragma unroll
1427
+ for (int l = 0; l < mma_C::ne; ++l) {
1428
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*
1429
+ (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1]);
1430
+ }
1431
  }
1432
  }
1433
  }
 
1519
 
1520
  template <int mmq_x, int mmq_y, int nwarps>
1521
  static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1522
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1523
 
1524
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
1525
  const int * x_qs = (const int *) x;
 
1528
  const int * y_qs = (const int *) y + 4;
1529
  const half2 * y_ds = (const half2 *) y;
1530
 
1531
+ // #pragma unroll
1532
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
1533
+ const int k0 = k00 + k01;
1534
+
1535
  #pragma unroll
1536
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1537
+ const int j = j0 + threadIdx.y;
1538
 
1539
  #pragma unroll
1540
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1541
+ const int i = i0 + threadIdx.x;
1542
 
1543
+ const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16);
1544
 
1545
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
1546
+ &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
1547
+ x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
1548
+ }
1549
  }
1550
  }
1551
  }
1552
 
1553
  template <int mmq_x, int mmq_y, int nwarps>
1554
  static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
1555
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1556
  #ifdef INT8_MMA_AVAILABLE
1557
 
1558
  typedef mma_int_A_I16K8 mma_A;
 
1573
 
1574
  const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
1575
 
1576
+ mma_A A[ntx][4];
1577
+ int scA[ntx][mma_C::ne/2][4];
1578
+ int mA[ntx][mma_C::ne/2][4];
1579
  half2 dmA[ntx][mma_C::ne/2];
1580
 
1581
  #pragma unroll
1582
  for (int n = 0; n < ntx; ++n) {
1583
  #pragma unroll
1584
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
1585
+ const int k0 = k00 + k01;
1586
+
1587
+ A[n][k01/8 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0/QR4_K, MMQ_MMA_TILE_X_K_Q4_K);
1588
 
1589
  #pragma unroll
1590
  for (int l = 0; l < mma_A::ne; ++l) {
1591
+ A[n][k01/8 + 1].x[l] = (A[n][k01/8 + 0].x[l] >> 4) & 0x0F0F0F0F;
1592
+ A[n][k01/8 + 0].x[l] &= 0x0F0F0F0F;
1593
  }
1594
  }
1595
 
1596
  #pragma unroll
1597
+ for (int l = 0; l < mma_C::ne/2; ++l) {
1598
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
1599
+
1600
+ const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 0)];
1601
+ const int m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 2)];
1602
 
1603
+ const uint8_t * sc = (const uint8_t *) &sc_packed;
1604
+ const uint8_t * m = (const uint8_t *) &m_packed;
1605
 
1606
+ #pragma unroll
1607
+ for (int ksc = 0; ksc < sizeof(int); ++ksc) {
1608
+ scA[n][l][ksc] = sc[ksc];
1609
+ mA[n][l][ksc] = m[ksc];
1610
  }
1611
  }
1612
 
 
1614
  for (int l = 0; l < mma_C::ne/2; ++l) {
1615
  const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
1616
 
1617
+ dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K];
1618
  }
1619
  }
1620
 
 
1624
  float tmpm[ntx][mma_C::ne] = {{0.0f}};
1625
 
1626
  #pragma unroll
1627
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
1628
  mma_B B;
1629
  half2 dsB[mma_C::ne/2];
1630
 
1631
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1632
 
1633
  #pragma unroll
1634
  for (int l = 0; l < mma_C::ne/2; ++l) {
1635
  const int j = j0 + mma_C::get_j(l);
1636
 
1637
+ dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1];
1638
  }
1639
 
1640
  #pragma unroll
1641
  for (int n = 0; n < ntx; ++n) {
1642
  mma_C C;
1643
+ C.mma_K8(A[n][k01/8], B);
1644
 
1645
  #pragma unroll
1646
  for (int l = 0; l < mma_C::ne; ++l) {
1647
+ tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) * __low2float(dsB[l%2]);
1648
+ tmpm[n][l] += mA[n][l/2][k01/8] * __high2float(dsB[l%2]);
1649
  }
1650
  }
1651
  }
 
1760
 
1761
  template <int mmq_x, int mmq_y, int nwarps>
1762
  static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1763
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1764
 
1765
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
1766
  const int * x_qs = (const int *) x;
 
1769
  const int * y_qs = (const int *) y + 4;
1770
  const half2 * y_ds = (const half2 *) y;
1771
 
1772
+ // #pragma unroll
1773
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
1774
+ const int k0 = k00 + k01;
1775
+
1776
  #pragma unroll
1777
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1778
+ const int j = j0 + threadIdx.y;
1779
 
1780
  #pragma unroll
1781
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1782
+ const int i = i0 + threadIdx.x;
1783
 
1784
+ const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);
1785
 
1786
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
1787
+ &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
1788
+ x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
1789
+ }
1790
  }
1791
  }
1792
  }
1793
 
1794
  template <int mmq_x, int mmq_y, int nwarps>
1795
  static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
1796
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1797
  #ifdef INT8_MMA_AVAILABLE
1798
 
1799
  typedef mma_int_A_I16K8 mma_A;
 
1814
 
1815
  const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
1816
 
1817
+ mma_A A[ntx][4];
1818
+ int scA[ntx][mma_C::ne/2][4];
1819
+ int mA[ntx][mma_C::ne/2][4];
1820
  half2 dmA[ntx][mma_C::ne/2];
1821
 
1822
  #pragma unroll
1823
  for (int n = 0; n < ntx; ++n) {
1824
  #pragma unroll
1825
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
1826
+ const int k0 = k00 + k01;
1827
+
1828
+ A[n][k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + k0, MMQ_MMA_TILE_X_K_Q5_K);
1829
+ }
1830
 
1831
  #pragma unroll
1832
+ for (int l = 0; l < mma_C::ne/2; ++l) {
1833
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
1834
+
1835
+ const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 0)];
1836
+ const int m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 2)];
1837
 
1838
+ const uint8_t * sc = (const uint8_t *) &sc_packed;
1839
+ const uint8_t * m = (const uint8_t *) &m_packed;
1840
 
1841
+ #pragma unroll
1842
+ for (int ksc = 0; ksc < sizeof(int); ++ksc) {
1843
+ scA[n][l][ksc] = sc[ksc];
1844
+ mA[n][l][ksc] = m[ksc];
1845
  }
1846
  }
1847
 
 
1849
  for (int l = 0; l < mma_C::ne/2; ++l) {
1850
  const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
1851
 
1852
+ dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K];
1853
  }
1854
  }
1855
 
 
1859
  float tmpm[ntx][mma_C::ne] = {{0.0f}};
1860
 
1861
  #pragma unroll
1862
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
1863
+ const int k0 = k00 + k01;
1864
+
1865
  mma_B B;
1866
  half2 dsB[mma_C::ne/2];
1867
 
1868
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
1869
 
1870
  #pragma unroll
1871
  for (int l = 0; l < mma_C::ne/2; ++l) {
1872
  const int j = j0 + mma_C::get_j(l);
1873
 
1874
+ dsB[l] = y_ds[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
1875
  }
1876
 
1877
  #pragma unroll
1878
  for (int n = 0; n < ntx; ++n) {
1879
  mma_C C;
1880
+ C.mma_K8(A[n][k01/8], B);
1881
 
1882
  #pragma unroll
1883
  for (int l = 0; l < mma_C::ne; ++l) {
1884
+ tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) * __low2float(dsB[l%2]);
1885
+ tmpm[n][l] += mA[n][l/2][k01/8] * __high2float(dsB[l%2]);
1886
  }
1887
  }
1888
  }
 
1989
 
1990
  template <int mmq_x, int mmq_y, int nwarps>
1991
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1992
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1993
 
1994
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
1995
  const int * x_qs = (const int *) x;
 
1998
  const int * y_qs = (const int *) y + 4;
1999
  const float * y_df = (const float *) y;
2000
 
2001
+ // #pragma unroll
2002
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
2003
+ const int k0 = k00 + k01;
2004
+
2005
  #pragma unroll
2006
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
2007
+ const int j = j0 + threadIdx.y;
2008
 
2009
  #pragma unroll
2010
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
2011
+ const int i = i0 + threadIdx.x;
2012
 
2013
+ const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]);
2014
 
2015
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
2016
+ &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
2017
+ x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
2018
+ }
2019
  }
2020
  }
2021
  }
2022
 
2023
  template <int mmq_x, int mmq_y, int nwarps>
2024
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
2025
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
2026
  #ifdef INT8_MMA_AVAILABLE
2027
 
2028
  typedef mma_int_A_I16K4 mma_A;
 
2043
 
2044
  const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
2045
 
2046
+ mma_A A[ntx][8];
2047
+ int scA[ntx][mma_C::ne/2][8];
2048
  float dA[ntx][mma_C::ne/2];
2049
 
2050
  #pragma unroll
2051
  for (int n = 0; n < ntx; ++n) {
2052
  #pragma unroll
2053
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
2054
+ const int k0 = k00 + k01;
2055
+
2056
+ A[n][k01/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
2057
+ A[n][k01/4 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
2058
+ }
2059
+
2060
+ #pragma unroll
2061
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
2062
+ const int k0 = k00 + k01;
2063
 
2064
  #pragma unroll
2065
  for (int l = 0; l < mma_C::ne/2; ++l) {
2066
  const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
2067
 
2068
+ const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
2069
+ const int8_t * sc = (const int8_t *) &sc_packed;
2070
 
2071
+ #pragma unroll
2072
+ for (int ksc = 0; ksc < sizeof(int); ++ksc) {
2073
+ scA[n][l][k01/4 + ksc] = sc[ksc];
2074
+ }
2075
  }
2076
  }
2077
 
 
2079
  for (int l = 0; l < mma_C::ne/2; ++l) {
2080
  const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
2081
 
2082
+ dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
2083
  }
2084
  }
2085
 
 
2088
  float tmp[ntx][mma_C::ne] = {{0.0f}};
2089
 
2090
  #pragma unroll
2091
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
2092
  mma_B B[2];
2093
  float dB[mma_C::ne/2];
2094
 
2095
+ B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
2096
+ B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
 
2097
 
2098
  #pragma unroll
2099
  for (int l = 0; l < mma_C::ne/2; ++l) {
2100
  const int j = j0 + mma_C::get_j(l);
2101
 
2102
+ dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
2103
  }
2104
 
2105
  #pragma unroll
2106
  for (int n = 0; n < ntx; ++n) {
2107
  mma_C C[2];
2108
+ C[0].mma_K4(A[n][k01/4 + 0], B[0]);
2109
+ C[1].mma_K4(A[n][k01/4 + 1], B[1]);
2110
 
2111
  #pragma unroll
2112
  for (int l = 0; l < mma_C::ne; ++l) {
2113
+ tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
2114
  }
2115
  }
2116
  }
 
2158
  const int2 v = get_int_from_table_16(aux_q4);
2159
  const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
2160
  #ifdef INT8_MMA_AVAILABLE
2161
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2162
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
2163
  #else
2164
  x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
2165
  x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
 
2180
  const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
2181
 
2182
  #ifdef INT8_MMA_AVAILABLE
2183
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
2184
  #else
2185
  x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
2186
  #endif // INT8_MMA_AVAILABLE
 
2216
  const int2 v = get_int_from_table_16(aux_q4);
2217
  const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
2218
  #ifdef INT8_MMA_AVAILABLE
2219
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2220
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
2221
  #else
2222
  x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
2223
  x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
 
2240
  | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
2241
 
2242
  #ifdef INT8_MMA_AVAILABLE
2243
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
2244
  #else
2245
  x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
2246
  #endif // INT8_MMA_AVAILABLE
 
2336
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
2337
  static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
2338
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
2339
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
2340
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2341
  };
2342
 
2343
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2344
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
2345
  static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
2346
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
2347
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2348
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2349
  };
2350
 
2351
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 
2400
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
2401
  static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
2402
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
2403
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
2404
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2405
  };
2406
 
2407
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2408
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
2409
  static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
2410
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
2411
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
2412
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2413
  };
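These mmq_type_traits specializations bind each quantized type to its tile loader and to the vec_dot kernels it can share: as shown above, Q5_0, IQ4_NL and IQ4_XS reuse the q8_0 dot products while Q5_1 pairs with the q8_1 ones, presumably because their loaders expand the quants into the corresponding 8-bit tile layout. A stripped-down sketch of the same compile-time dispatch pattern (all names below are illustrative, not the real ggml definitions; compile as C++17 so the static constexpr members need no out-of-class definition):

    #include <cstdio>

    typedef void (*vec_dot_t)(int k0);

    static void vec_dot_like_q8_0(int k0) { std::printf("q8_0-style dot, k0=%d\n", k0); }
    static void vec_dot_like_q8_1(int k0) { std::printf("q8_1-style dot, k0=%d\n", k0); }

    enum demo_type { DEMO_Q5_0, DEMO_Q5_1 };

    template <demo_type type> struct demo_type_traits;

    // A type whose tile is expanded to a q8_0-like layout can share the q8_0 dot product ...
    template <> struct demo_type_traits<DEMO_Q5_0> {
        static constexpr vec_dot_t vec_dot = vec_dot_like_q8_0;
    };
    // ... while a type that also carries a per-block sum pairs with the q8_1-style one.
    template <> struct demo_type_traits<DEMO_Q5_1> {
        static constexpr vec_dot_t vec_dot = vec_dot_like_q8_1;
    };

    int main() {
        demo_type_traits<DEMO_Q5_0>::vec_dot(0);
        demo_type_traits<DEMO_Q5_1>::vec_dot(0);
        return 0;
    }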
2414
 
2415
  template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
2416
  static __device__ void mul_mat_q_process_tile(
2417
  const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
 
2419
  const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
2420
 
2421
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
 
 
2422
  constexpr int mmq_y = get_mmq_y_device();
 
2423
  constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
2424
 
2425
  extern __shared__ char data_mul_mat_q[];
 
2434
  constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
2435
  #endif // INT8_MMA_AVAILABLE
2436
 
2437
+ constexpr int blocks_per_iter = MMQ_ITER_K / qk;
2438
 
2439
  float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
2440
 
 
2443
 
2444
  const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
2445
 
2446
+ for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
 
2447
  load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
2448
 
2449
+ {
2450
+ const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
 
2451
  #pragma unroll
2452
  for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
2453
  int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
2454
 
2455
  tile_y[l] = by0[l];
2456
  }
2457
+ }
2458
 
2459
+ __syncthreads();
2460
 
2461
+ vec_dot(tile_x, tile_y, sum, 0);
2462
+
2463
+ __syncthreads();
2464
+
2465
+ {
2466
+ const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
2467
+ #pragma unroll
2468
+ for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
2469
+ int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
2470
 
2471
+ tile_y[l] = by0[l];
2472
+ }
2473
  }
2474
+
2475
+ __syncthreads();
2476
+
2477
+ vec_dot(tile_x, tile_y, sum, WARP_SIZE);
2478
+
2479
+ __syncthreads();
2480
  }
2481
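Each pass of the kb0 loop above covers one MMQ_ITER_K-wide slice in two halves: load one block_q8_1_mmq per row of y into tile_y, synchronize, run vec_dot at k0 = 0, synchronize, then repeat for the second half offset by sizeof(block_q8_1_mmq)/sizeof(int) ints in y and by WARP_SIZE in k0. The shared tile_y buffer is reused between the halves, which is why both __syncthreads() calls are needed. A self-contained sketch of that load/sync/compute/sync pattern (kernel and buffer names are illustrative, not from the patch):

    #include <cstdio>

    // One shared buffer is reused for the two halves of each iteration, with
    // __syncthreads() separating the phases so no thread overwrites data that
    // another thread is still reading.
    __global__ void two_phase_sum(const int * y, int * out, int n_per_phase) {
        extern __shared__ int tile_y[];
        int sum = 0;

        for (int phase = 0; phase < 2; ++phase) {
            // cooperative load of this phase's slice of y into shared memory
            for (int l = (int) threadIdx.x; l < n_per_phase; l += blockDim.x) {
                tile_y[l] = y[phase*n_per_phase + l];
            }
            __syncthreads();            // everyone sees the freshly loaded tile

            for (int l = (int) threadIdx.x; l < n_per_phase; l += blockDim.x) {
                sum += tile_y[l];       // stand-in for vec_dot(tile_x, tile_y, sum, k0)
            }
            __syncthreads();            // the tile gets overwritten in the next phase
        }

        atomicAdd(out, sum);
    }

    int main() {
        const int n_per_phase = 256;
        int *y, *out;
        cudaMallocManaged(&y, 2*n_per_phase*sizeof(int));
        cudaMallocManaged(&out, sizeof(int));
        for (int i = 0; i < 2*n_per_phase; ++i) y[i] = 1;
        *out = 0;
        two_phase_sum<<<1, 128, n_per_phase*sizeof(int)>>>(y, out, n_per_phase);
        cudaDeviceSynchronize();
        std::printf("sum = %d (expected %d)\n", *out, 2*n_per_phase);
        cudaFree(y); cudaFree(out);
        return 0;
    }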
 
2482
  if (fixup) {
 
2512
  }
2513
 
2514
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
 
2515
  constexpr int mmq_y = get_mmq_y_device();
2516
 
2517
  // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
 
2526
  #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
2527
 
2528
  const int64_t blocks_per_ne00 = ne00 / qk;
2529
+ constexpr int blocks_per_iter = MMQ_ITER_K / qk;
2530
 
2531
  const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
2532
  const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
 
2535
  int64_t kbc = (int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x;
2536
  int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x;
2537
 
2538
+ kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
2539
+ kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
2540
 
2541
  // kb0 == k index when doing the matrix multiplication for an output tile.
2542
  int kb0_start = kbc % blocks_per_ne00;
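The kbc/kbc_stop arithmetic above splits the total work (blocks_per_ne00 K blocks for each of the ntx*nty output tiles) evenly across the CUDA blocks, then rounds each boundary down so that no block starts or stops in the middle of one MMQ_ITER_K iteration. A host-side sketch of that partitioning, with made-up sizes (the real values depend on the tensor shapes and the quantization type):

    #include <cstdio>
    #include <cstdint>

    int main() {
        const int64_t blocks_per_ne00 = 100;  // K blocks per output tile (illustrative)
        const int     ntx = 3, nty = 4;       // number of output tiles in each direction
        const int     grid_dim_x = 7;         // number of CUDA blocks sharing the work
        const int     blocks_per_iter = 8;    // MMQ_ITER_K / qk for some type (illustrative)

        const int64_t total = blocks_per_ne00 * ntx * nty;
        for (int bidx = 0; bidx < grid_dim_x; ++bidx) {
            int64_t kbc      = (int64_t)  bidx      * total / grid_dim_x;
            int64_t kbc_stop = (int64_t) (bidx + 1) * total / grid_dim_x;

            // Round down within the current output tile so iterations are never split:
            kbc      -= (kbc      % blocks_per_ne00) % blocks_per_iter;
            kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;

            std::printf("block %d: kbc=%lld kbc_stop=%lld\n",
                        bidx, (long long) kbc, (long long) kbc_stop);
        }
        return 0;
    }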
 
2577
 
2578
  constexpr int mmq_y = get_mmq_y_device();
2579
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
2580
+ constexpr int blocks_per_iter = MMQ_ITER_K / qk;
 
2581
  const int64_t blocks_per_ne00 = ne00 / qk;
2582
 
2583
  float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
 
2587
 
2588
  bool any_fixup = false;
2589
 
2590
+ const int bidx_start = ((blockIdx.y*nty + blockIdx.x) * block_num_mmq) / (gridDim.y*gridDim.x);
2591
+ const int bidx_stop = ((blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq + gridDim.y*gridDim.x - 1) / (gridDim.y*gridDim.x);
2592
+
2593
+ int64_t kbc_0;
2594
+ int64_t kbc_stop_0 = (int64_t) bidx_start*blocks_per_ne00*ntx*nty / block_num_mmq;
2595
 
2596
  for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
2597
+ kbc_0 = kbc_stop_0;
2598
+ kbc_stop_0 = (int64_t) (bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq;
2599
 
2600
+ const int64_t kbc = kbc_0 - (kbc_0 % blocks_per_ne00) % blocks_per_iter;
2601
+ const int64_t kbc_stop = kbc_stop_0 - (kbc_stop_0 % blocks_per_ne00) % blocks_per_iter;
2602
 
2603
  // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
2604
  if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
ggml/src/ggml-cuda/quantize.cu CHANGED
@@ -37,47 +37,92 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
37
  reinterpret_cast<half&>(y[ib].ds.y) = sum;
38
  }
39
 
40
- template <bool need_sum>
41
  static __global__ void quantize_mmq_q8_1(
42
  const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
43
 
44
- const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
 
 
45
 
46
  if (ix0 >= kx0_padded) {
47
  return;
48
  }
49
 
 
 
50
  const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
51
 
52
  block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
53
 
54
- const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel
55
- const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
56
- const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
57
-
58
- const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f;
59
- float amax = fabsf(xi);
60
-
61
- amax = warp_reduce_max(amax);
62
 
63
  float sum;
64
- if (need_sum) {
65
- sum = warp_reduce_sum(xi);
66
  }
67
 
68
- const float d = amax / 127;
69
- const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
 
 
 
 
70
 
71
- y[ib].qs[iqs] = q;
72
 
73
- if (iqs % QK8_1 != 0) {
74
  return;
75
  }
76
 
77
- if (need_sum) {
78
- y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
79
  } else {
80
- ((float *) y[ib].ds)[iqs/QK8_1] = d;
81
  }
82
  }
83
 
@@ -101,12 +146,24 @@ void quantize_mmq_q8_1_cuda(
101
 
102
  GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
103
 
104
- const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
105
  const dim3 num_blocks(block_num_x, kx1, channels);
106
- const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
107
- if (mmq_need_sum(type_x)) {
108
- quantize_mmq_q8_1<true><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
109
- } else {
110
- quantize_mmq_q8_1<false><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
111
  }
112
  }
 
37
  reinterpret_cast<half&>(y[ib].ds.y) = sum;
38
  }
39
 
40
+ template <mmq_q8_1_ds_layout ds_layout>
41
  static __global__ void quantize_mmq_q8_1(
42
  const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
43
 
44
+ constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
45
+ constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
46
+
47
+ const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
48
 
49
  if (ix0 >= kx0_padded) {
50
  return;
51
  }
52
 
53
+ const float4 * x4 = (const float4 *) x;
54
+
55
  const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
56
 
57
  block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
58
 
59
+ const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
60
+ const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
61
+ const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
62
+
63
+ // Load 4 floats per thread and calculate max. abs. value between them:
64
+ const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
65
+ float amax = fabsf(xi.x);
66
+ amax = fmaxf(amax, fabsf(xi.y));
67
+ amax = fmaxf(amax, fabsf(xi.z));
68
+ amax = fmaxf(amax, fabsf(xi.w));
69
+
70
+ // Exchange max. abs. value between vals_per_scale/4 threads.
71
+ #pragma unroll
72
+ for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
73
+ amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
74
+ }
75
 
76
  float sum;
77
+ if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
78
+ sum = xi.x + xi.y + xi.z + xi.w;
79
+
80
+ // Exchange the calculated sum across vals_per_sum/4 threads.
81
+ #pragma unroll
82
+ for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
83
+ sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
84
+ }
85
  }
86
 
87
+ const float d_inv = 127.0f / amax;
88
+ char4 q;
89
+ q.x = roundf(xi.x*d_inv);
90
+ q.y = roundf(xi.y*d_inv);
91
+ q.z = roundf(xi.z*d_inv);
92
+ q.w = roundf(xi.w*d_inv);
93
 
94
+ // Write back 4 int8 values as a single 32-bit value for better memory bandwidth:
95
+ char4 * yqs4 = (char4 *) y[ib].qs;
96
+ yqs4[iqs/4] = q;
97
+
98
+ if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
99
+ if (iqs % 16 != 0 || iqs >= 96) {
100
+ return;
101
+ }
102
+
103
+ y[ib].d2s6[2 + iqs/16] = sum;
104
+
105
+ if (iqs % 64 != 0) {
106
+ return;
107
+ }
108
+
109
+ const float d = 1.0f / d_inv;
110
+
111
+ y[ib].d2s6[iqs/64] = d;
112
 
 
113
  return;
114
  }
115
 
116
+ if (iqs % 32 != 0) {
117
+ return;
118
+ }
119
+
120
+ const float d = 1.0f / d_inv;
121
+
122
+ if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
123
+ y[ib].ds4[iqs/32] = make_half2(d, sum);
124
  } else {
125
+ y[ib].d4[iqs/32] = d;
126
  }
127
  }
128
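The amax/sum exchange in the new kernel above is a butterfly reduction over a subgroup of lanes via __shfl_xor_sync: with vals_per_scale = 32 and four values already reduced per thread, 8 lanes participate and three xor steps (mask 4, 2, 1) leave every lane of the group holding the group maximum. A standalone sketch of the same reduction (kernel and buffer names are illustrative):

    #include <cstdio>

    // Butterfly max reduction over groups of 8 consecutive lanes, as used above for
    // the 32-value scale groups (each lane already holds the max of its own 4 values).
    __global__ void group8_max(const float * x, float * out) {
        float amax = fabsf(x[threadIdx.x]);

    #pragma unroll
        for (int mask = 4; mask > 0; mask >>= 1) {
            amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32));
        }

        out[threadIdx.x] = amax;  // every lane of an 8-lane group now holds the same value
    }

    int main() {
        float *x, *out;
        cudaMallocManaged(&x, 32*sizeof(float));
        cudaMallocManaged(&out, 32*sizeof(float));
        for (int i = 0; i < 32; ++i) x[i] = (i % 8 == 3) ? -9.0f : (float) (i % 8);
        group8_max<<<1, 32>>>(x, out);
        cudaDeviceSynchronize();
        std::printf("lane 0 max = %g, lane 8 max = %g\n", out[0], out[8]);  // both print 9
        cudaFree(x); cudaFree(out);
        return 0;
    }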
 
 
146
 
147
  GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
148
 
149
+ const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
150
  const dim3 num_blocks(block_num_x, kx1, channels);
151
+ const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
152
+ switch (mmq_get_q8_1_ds_layout(type_x)) {
153
+ case MMQ_Q8_1_DS_LAYOUT_D4:
154
+ quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
155
+ <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
156
+ break;
157
+ case MMQ_Q8_1_DS_LAYOUT_DS4:
158
+ quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
159
+ <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
160
+ break;
161
+ case MMQ_Q8_1_DS_LAYOUT_D2S6:
162
+ quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
163
+ <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
164
+ break;
165
+ default:
166
+ GGML_ASSERT(false);
167
+ break;
168
  }
169
  }
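Because each thread of the new kernel loads a float4 and quantizes four values, one block of CUDA_QUANTIZE_BLOCK_SIZE_MMQ threads now covers 4x that many values along kx0, which is why the grid sizing above divides by 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ. A small host-side check of that arithmetic (the row length is illustrative):

    #include <cstdio>
    #include <cstdint>

    int main() {
        const int64_t kx0_padded     = 4096;   // padded row length (illustrative)
        const int     block_size_mmq = 128;    // CUDA_QUANTIZE_BLOCK_SIZE_MMQ
        const int64_t block_num_x    = (kx0_padded + 4*block_size_mmq - 1) / (4*block_size_mmq);
        std::printf("block_num_x = %lld (covers %lld values per row)\n",
                    (long long) block_num_x, (long long) block_num_x*4*block_size_mmq);
        return 0;
    }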
ggml/src/ggml-cuda/quantize.cuh CHANGED
@@ -5,7 +5,11 @@
5
 
6
  #include <cstdint>
7
 
8
- #define CUDA_QUANTIZE_BLOCK_SIZE 256
 
 
 
 
9
 
10
  typedef void (*quantize_cuda_t)(
11
  const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
 
5
 
6
  #include <cstdint>
7
 
8
+ #define CUDA_QUANTIZE_BLOCK_SIZE 256
9
+ #define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128
10
+
11
+ static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access.");
12
+ static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
13
 
14
  typedef void (*quantize_cuda_t)(
15
  const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
ggml/src/ggml-cuda/vecdotq.cuh CHANGED
@@ -189,7 +189,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
189
  }
190
 
191
  #define VDR_Q2_K_Q8_1_MMVQ 1
192
- #define VDR_Q2_K_Q8_1_MMQ 2
193
 
194
  // contiguous v/x values
195
  static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
@@ -219,32 +219,56 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
219
  return dm2f.x*sumf_d - dm2f.y*sumf_m;
220
  }
221
 
222
- // contiguous u/y values
 
223
  static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
224
- const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) {
225
 
226
- float sumf_d = 0.0f;
227
- float sumf_m = 0.0f;
228
 
229
  #pragma unroll
230
- for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
231
- const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]);
232
- int sumi_d = 0;
233
- int sumi_m = 0;
 
 
234
 
235
- const int vi0 = v[i0/(QI8_1/2)];
236
  #pragma unroll
237
  for (int i = i0; i < i0 + QI8_1/2; ++i) {
238
- const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303;
239
- sumi_d = ggml_cuda_dp4a(vi, u[i], sumi_d); // SIMD dot product
240
- sumi_m = ggml_cuda_dp4a(0x01010101, u[i], sumi_m);
 
 
 
 
241
  }
 
242
 
243
- sumf_d += dm2f.x * sumi_d;
244
- sumf_m += dm2f.y * sumi_m;
245
  }
246
 
247
- return d8*(sumf_d - sumf_m);
248
  }
249
 
250
  #define VDR_Q3_K_Q8_1_MMVQ 1
@@ -283,7 +307,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
283
  return d3 * sumf;
284
  }
285
 
286
- // contiguous u/y values
287
  static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
288
  const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
289
  const float & d3, const float & d8) {
@@ -296,8 +320,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
296
 
297
  #pragma unroll
298
  for (int i = i0; i < i0 + QI8_1/2; ++i) {
299
- const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404);
300
- sumi_sc = ggml_cuda_dp4a(vi, u[i], sumi_sc); // SIMD dot product
301
  }
302
 
303
  sumi += sumi_sc * scales[i0 / (QI8_1/2)];
@@ -334,7 +357,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
334
  return dm4f.x*sumf_d - dm4f.y*sumf_m;
335
  }
336
 
337
- // contiguous u/y values
338
  static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
339
  const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
340
  const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@@ -397,7 +420,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
397
  return dm5f.x*sumf_d - dm5f.y*sumf_m;
398
  }
399
 
400
- // contiguous u/y values
401
  static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
402
  const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
403
  const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@@ -451,13 +474,16 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
451
  return d*sumf;
452
  }
453
 
454
- // contiguous u/y values
455
  static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
456
  const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
457
  const float & d6, const float * __restrict__ d8) {
458
 
459
  float sumf_d = 0.0f;
460
 
 
 
 
461
  #pragma unroll
462
  for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
463
  int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
@@ -471,7 +497,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
471
  sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
472
  }
473
 
474
- sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y);
475
  }
476
 
477
  return d6 * sumf_d;
 
189
  }
190
 
191
  #define VDR_Q2_K_Q8_1_MMVQ 1
192
+ #define VDR_Q2_K_Q8_1_MMQ 4
193
 
194
  // contiguous v/x values
195
  static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
 
219
  return dm2f.x*sumf_d - dm2f.y*sumf_m;
220
  }
221
 
222
+ // contiguous v/x + u/y values
223
+ template <int ns8>
224
  static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
225
+ const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) {
226
 
227
+ float sumf = 0.0f;
228
+ float sumf_d8 = 0.0f;
229
 
230
  #pragma unroll
231
+ for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) {
232
+ const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]);
233
+ int sumi_d0 = 0;
234
+
235
+ const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]);
236
+ int sumi_d1 = 0;
237
 
 
238
  #pragma unroll
239
  for (int i = i0; i < i0 + QI8_1/2; ++i) {
240
+ sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0);
241
+ }
242
+ sumf_d8 += dm2f0.x * sumi_d0;
243
+
244
+ #pragma unroll
245
+ for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
246
+ sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1);
247
  }
248
+ sumf_d8 += dm2f1.x * sumi_d1;
249
 
250
+ if (i0/QI8_1 < ns8) {
251
+ const float2 s8f = __half22float2(s8[i0/QI8_1]);
252
+ sumf -= dm2f0.y*s8f.x;
253
+ sumf -= dm2f1.y*s8f.y;
254
+ } else {
255
+ int sumi_m0 = 0;
256
+ #pragma unroll
257
+ for (int i = i0; i < i0 + QI8_1/2; ++i) {
258
+ sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0);
259
+ }
260
+ sumf_d8 -= dm2f0.y * sumi_m0;
261
+
262
+ int sumi_m1 = 0;
263
+ #pragma unroll
264
+ for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
265
+ sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1);
266
+ }
267
+ sumf_d8 -= dm2f1.y * sumi_m1;
268
+ }
269
  }
270
 
271
+ return sumf + d8*sumf_d8;
272
  }
273
 
274
  #define VDR_Q3_K_Q8_1_MMVQ 1
 
307
  return d3 * sumf;
308
  }
309
 
310
+ // contiguous v/x + u/y values
311
  static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
312
  const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
313
  const float & d3, const float & d8) {
 
320
 
321
  #pragma unroll
322
  for (int i = i0; i < i0 + QI8_1/2; ++i) {
323
+ sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
 
324
  }
325
 
326
  sumi += sumi_sc * scales[i0 / (QI8_1/2)];
 
357
  return dm4f.x*sumf_d - dm4f.y*sumf_m;
358
  }
359
 
360
+ // contiguous v/x + u/y values
361
  static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
362
  const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
363
  const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
 
420
  return dm5f.x*sumf_d - dm5f.y*sumf_m;
421
  }
422
 
423
+ // contiguous v/x + u/y values
424
  static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
425
  const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
426
  const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
 
474
  return d*sumf;
475
  }
476
 
477
+ // contiguous v/x + u/y values
478
  static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
479
  const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
480
  const float & d6, const float * __restrict__ d8) {
481
 
482
  float sumf_d = 0.0f;
483
 
484
+ const int sc_packed = get_int_b4(sc, 0);
485
+ const int8_t * sc_reg = (const int8_t *) &sc_packed;
486
+
487
  #pragma unroll
488
  for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
489
  int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
 
497
  sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
498
  }
499
 
500
+ sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y);
501
  }
502
 
503
  return d6 * sumf_d;