Commit a3fe534 · Parent: 6dbe297

CUDA: optimize and refactor MMQ (llama/8416)

* CUDA: optimize and refactor MMQ
* explicit q8_1 memory layouts, add documentation

Files changed:
- ggml/src/ggml-cuda/mma.cuh +4 -0
- ggml/src/ggml-cuda/mmq.cuh +704 -615
- ggml/src/ggml-cuda/quantize.cu +82 -25
- ggml/src/ggml-cuda/quantize.cuh +5 -1
- ggml/src/ggml-cuda/vecdotq.cuh +49 -23
ggml/src/ggml-cuda/mma.cuh
CHANGED
@@ -70,6 +70,10 @@ struct mma_int_A_I16K8 {
         }
 #endif // defined(INT8_MMA_AVAILABLE)
     }
+
+    __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
+        ((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
+    }
 };

 struct mma_int_B_J8K4 {
ggml/src/ggml-cuda/mmq.cuh
CHANGED
@@ -8,18 +8,70 @@
 #include <cstdint>

 #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
+#define MMQ_ITER_K 256
+#define MMQ_NWARPS 8

 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride);
-typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &
+typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00);
 typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);

+enum mmq_q8_1_ds_layout {
+    MMQ_Q8_1_DS_LAYOUT_D4,
+    MMQ_Q8_1_DS_LAYOUT_DS4,
+    MMQ_Q8_1_DS_LAYOUT_D2S6,
+};
+
 struct block_q8_1_mmq {
-
-
+    // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block.
+    // The y float data is first grouped as blocks of 128 values.
+    // These blocks are then treated as individual data values and transposed.
+    //
+    // To avoid shared memory bank conflicts each block is padded with 16 bytes.
+    // This padding is also used to store block scales/partial sums.
+    // The scales multiplied with the quantized data are equal to the unquantized values.
+    // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization)
+    // and are only needed for performance reasons.
+    //
+    // The exact data stored depends on the x data type.
+    union {
+        float d4[4];    // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3
+        half2 ds4[4];   // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3
+        half  d2s6[8];  // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values,
+                        //     stored as d0,d1,s1,s2,s3,s4,s5
+    };
+    int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
 };
 static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
 static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");

+static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
+    switch (type_x) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
+        case GGML_TYPE_Q5_0:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_Q5_1:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
+        case GGML_TYPE_Q8_0:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_Q2_K:
+            return MMQ_Q8_1_DS_LAYOUT_D2S6;
+        case GGML_TYPE_Q3_K:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ4_NL:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
+
 struct tile_x_sizes {
     int qs;
     int dm;
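As a quick sanity check on the layout documented in this hunk, the sizes work out as follows (a standalone sketch, not part of the diff; it assumes QK8_1 == 32 and 2-byte half values, which is how ggml defines them, and uses uint16_t as a stand-in for half so it compiles without CUDA headers):

    // Size check for the block_q8_1_mmq layout described above.
    #include <cstdint>

    constexpr int QK8_1_ASSUMED = 32;   // assumption: ggml's QK8_1

    struct block_q8_1_mmq_sketch {
        union {                         // 16 bytes of scales / partial sums:
            float    d4[4];             //   4 float scales, or
            uint16_t ds4[8];            //   4 (scale, partial sum) half pairs, or
            uint16_t d2s6[8];           //   2 half scales + 6 half partial sums
        };
        int8_t qs[4*QK8_1_ASSUMED];     // 128 quantized values, 1 byte each
    };

    // 16 + 128 = 144 bytes, i.e. 4*QK8_1 + 4*sizeof(half2), which is also
    // 4*sizeof(block_q8_1) (32 quants + one half2 = 36 bytes per block).
    static_assert(sizeof(block_q8_1_mmq_sketch) == 144, "unexpected sketch size");

Both static_asserts in the hunk pin down the same 144-byte figure, so the union can be reinterpreted per x data type without changing the size of the shared-memory copy.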
@@ -79,49 +131,46 @@ static constexpr __device__ int get_mmq_y_device() {
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 }

-#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0
-#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1
-#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
+#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
+#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
+#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
+#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
+#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}

 static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
     return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
         type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
-        type == GGML_TYPE_Q5_0 ?
-        type == GGML_TYPE_Q5_1 ?
+        type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
         type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
         type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
         type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
         type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
         type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
         type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
-        type == GGML_TYPE_IQ4_XS ?
-        type == GGML_TYPE_IQ4_NL ?
+        type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 :
         tile_x_sizes{0, 0, 0};
 }

-#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0
-#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1
-#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1 + 4)
+#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/(2*QI3_K) + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)

 static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding.");
-static_assert(MMQ_MMA_TILE_X_K_Q5_0 % 8 == 4, "Wrong padding.");
-static_assert(MMQ_MMA_TILE_X_K_Q5_1 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
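The "% 8 == 4" padding asserts in the hunk above can be checked by hand; a standalone sketch, assuming WARP_SIZE == 32, QI4_0 == 4 and QI6_K == 32 (the values those ggml macros expand to), not part of the diff:

    // Worked check of two of the "Wrong padding." asserts.
    constexpr int WARP_SIZE_ASSUMED = 32;
    constexpr int QI4_0_ASSUMED     = 4;
    constexpr int QI6_K_ASSUMED     = 32;

    // MMQ_MMA_TILE_X_K_Q4_0 = 1*32 + 32/4 + 4 = 44 ints per tile row.
    constexpr int tile_x_k_q4_0 = 1*WARP_SIZE_ASSUMED + WARP_SIZE_ASSUMED/QI4_0_ASSUMED + 4;
    static_assert(tile_x_k_q4_0 == 44 && tile_x_k_q4_0 % 8 == 4, "q4_0 row stride");

    // MMQ_MMA_TILE_X_K_Q6_K = 2*32 + 32/32 + 32/8 + 7 = 76 ints per tile row.
    constexpr int tile_x_k_q6_K = 2*WARP_SIZE_ASSUMED + WARP_SIZE_ASSUMED/QI6_K_ASSUMED + WARP_SIZE_ASSUMED/8 + 7;
    static_assert(tile_x_k_q6_K == 76 && tile_x_k_q6_K % 8 == 4, "q6_K row stride");

In both cases the extra +4 or +7 term keeps the per-row stride off a multiple of 8 ints, which is exactly the property the asserts in the hunk above pin down.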
@@ -131,21 +180,20 @@ static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
         type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
-        type == GGML_TYPE_Q5_0 ?
-        type == GGML_TYPE_Q5_1 ?
         type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
         type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
         type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
         type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
         type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
         type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
-        type == GGML_TYPE_IQ4_XS ?
-        type == GGML_TYPE_IQ4_NL ?
         0;
 }

 #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
-#define MMQ_NWARPS 8

 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
     return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
@@ -218,7 +266,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     const int * x_qs = (const int *) x;

@@ -226,34 +274,39 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     const int * y_qs = (const int *) y + 4;
     const half2 * y_ds = (const half2 *) y;

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &
 #ifdef INT8_MMA_AVAILABLE
     typedef mma_int_A_I16K8 mma_A;

@@ -271,52 +324,60 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
     const int * y_qs = (const int *) y + 4;
     const half2 * y_ds = (const half2 *) y;

-    mma_A A[ntx];
-    float dA[ntx][mma_C::ne/2];

     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

-    const int k = k0 + mma_A::get_k(l) % QI4_0;
-    const int shift = 4*(mma_A::get_k(l) / QI4_0);

@@ -381,7 +442,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     const int * x_qs = (const int *) x;

@@ -389,34 +450,39 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
     const int * y_qs = (const int *) y + 4;
     const half2 * y_ds = (const half2 *) y;

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &
 #ifdef INT8_MMA_AVAILABLE
     typedef mma_int_A_I16K8 mma_A;

@@ -435,50 +501,58 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
     const int * y_qs = (const int *) y + 4;
     const half2 * y_ds = (const half2 *) y;

-    mma_A A[ntx];
-    half2 dmA[ntx][mma_C::ne/2];

     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

@@ -531,8 +605,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
     qs1 = __vsubss4(qs1, 0x10101010); // subtract 16

 #ifdef INT8_MMA_AVAILABLE
-    x_qs[i*
-    x_qs[i*
 #else
     x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
     x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
@@ -553,106 +627,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
     const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;

 #ifdef INT8_MMA_AVAILABLE
-        x_df[i*
 #else
         x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
 #endif // INT8_MMA_AVAILABLE
     }
 }

-template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
-    const int * x_qs = (const int *) x;
-    const float * x_df = (const float *) x_qs + txs.qs;
-    const int * y_qs = (const int *) y + 4;
-    const float * y_df = (const float *) y;
-
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
-
-#pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
-
-            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
-                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE],
-                 x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
-        }
-    }
-}
-
-template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-#ifdef INT8_MMA_AVAILABLE
-
-    typedef mma_int_A_I16K8 mma_A;
-    typedef mma_int_B_J8K8 mma_B;
-    typedef mma_int_C_I16J8 mma_C;
-
-    constexpr int granularity = mmq_get_granularity_device(mmq_x);
-    constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
-
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
-
-    const int * x_qs = (const int *) x;
-    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
-    const int * y_qs = (const int *) y + 4;
-    const float * y_df = (const float *) y;
-
-    mma_A A[ntx];
-    float dA[ntx][mma_C::ne/2];
-
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
-
-#pragma unroll
-    for (int n = 0; n < ntx; ++n) {
-        A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_0 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_0);
-
-#pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I;
-
-            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + k0/QI5_0];
-        }
-    }
-
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        mma_B B;
-        float dB[mma_C::ne/2];
-
-        B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
-
-#pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int j = j0 + mma_C::get_j(l);
-
-            dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
-        }
-
-#pragma unroll
-        for (int n = 0; n < ntx; ++n) {
-            mma_C C;
-            C.mma_K8(A[n], B);
-
-#pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l];
-            }
-        }
-    }
-#else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
-    NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
-}
-
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

@@ -694,8 +675,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
     qs1 |= (qh <<  9) & 0x10000000; // 19 -> 28

 #ifdef INT8_MMA_AVAILABLE
-    x_qs[i*
-    x_qs[i*
 #else
     x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
     x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;

@@ -716,41 +697,101 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
     const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;

 #ifdef INT8_MMA_AVAILABLE
-    x_dm[i*
 #else
     x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
 #endif // INT8_MMA_AVAILABLE
     }
 }

 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(
     const int * x_qs = (const int *) x;
     const int * y_qs = (const int *) y + 4;

 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &
 #ifdef INT8_MMA_AVAILABLE
     typedef mma_int_A_I16K8 mma_A;

@@ -764,140 +805,106 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
     y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);

     const int * x_qs = (const int *) x;
     const int * y_qs = (const int *) y + 4;

-    mma_A A[ntx];
-    A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_1 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_1);
-    B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
-    dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
-    C.mma_K8(A[n], B);
-    const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2];
-    sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
-#else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
-    NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
-}

-    float * x_df = (float *) (x_tile + WARP_SIZE);
-#else
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
-    int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE

-    const int kqsx = threadIdx.x % QI8_0;

-    const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
-    x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b2(bxi->qs, kqsx);

-    const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
-#ifdef INT8_MMA_AVAILABLE
-    x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
 #endif // INT8_MMA_AVAILABLE
 }

 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(
     const int * x_qs = (const int *) x;
     const int * y_qs = (const int *) y + 4;

 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &
 #ifdef INT8_MMA_AVAILABLE
     typedef mma_int_A_I16K8 mma_A;
@@ -911,49 +918,65 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);

     const int * x_qs = (const int *) x;
     const int * y_qs = (const int *) y + 4;

     const int i0 = (threadIdx.y/ntx)*rows_per_warp;

     for (int n = 0; n < ntx; ++n) {
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
         }
     }

     for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {

@@ -968,44 +991,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
 #ifdef INT8_MMA_AVAILABLE
     int * x_qs = (int *) x_tile;
-    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
 #endif // INT8_MMA_AVAILABLE

-    const int kbx  = threadIdx.x / QI2_K;
     const int kqsx = threadIdx.x % QI2_K;

-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
-        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride

     const int x_ql_0 = get_int_b2(bxi->qs, kqsx);

     for (int l = 0; l < QR2_K; ++l) {
-        x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
-        x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 2, WARP_SIZE);
-        if (kqsx % QR2_K != 0) {
-            continue;
-        }

 #ifdef INT8_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
 #else
-        x_qs[i*(WARP_SIZE + 1)
 #endif // INT8_MMA_AVAILABLE
     }

@@ -1018,44 +1034,68 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
 #endif // FAST_FP16_AVAILABLE

 #ifdef INT8_MMA_AVAILABLE
-    x_dm[i*MMQ_MMA_TILE_X_K_Q2_K +
 #else
-    x_dm[i*(WARP_SIZE + 1) +
 #endif // INT8_MMA_AVAILABLE
 }

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     const int * x_qs = (const int *) x;
     const half2 * x_dm = (const half2 *) x_qs + txs.qs;
     const int * y_qs = (const int *) y + 4;

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &
 #ifdef INT8_MMA_AVAILABLE
     typedef mma_int_A_I16K4 mma_A;
     typedef mma_int_B_J8K4 mma_B;
     typedef mma_int_C_I16J8 mma_C;

@@ -1066,74 +1106,107 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);

     const int * x_qs = (const int *) x;
-    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
     const int * y_qs = (const int *) y + 4;

     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

-    const int shift = 2*mma_A::get_k(l);
-    A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 1] >> shift) & 0x03030303;
-    B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + 0)        % WARP_SIZE, MMQ_TILE_Y_K);
-    B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K);

     for (int l = 0; l < mma_C::ne/2; ++l) {
         const int j = j0 + mma_C::get_j(l);
     }
@@ -1149,7 +1222,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
 #ifdef INT8_MMA_AVAILABLE
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
-    int * x_sc = (int *) (x_df +
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
     int * x_qs = (int *) x_tile;

@@ -1157,75 +1230,66 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
     int * x_sc = (int *) (x_df + txs.dm);
 #endif // INT8_MMA_AVAILABLE

-    const int kbx  = threadIdx.x / QI3_K;
     const int kqsx = threadIdx.x % QI3_K;

-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
-        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride

     const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
     const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));

     for (int l = 0; l < QR3_K; ++l) {
         const int x_ql_k =  (x_ql_0 >> (2*l))       & 0x03030303;
         const int x_qh_k = ((x_qh_0 >>    l)  << 2) & 0x04040404;

-        x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
-        if (kqsx % 2 != 0) {
-            continue;
-        }

 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k
 #else
-        x_qs[i*(2*WARP_SIZE + 1) + k
 #endif // INT8_MMA_AVAILABLE
     }

-    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
-    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
-        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride
-        x_df[i*MMQ_MMA_TILE_X_K_Q3_K

-        const int ksc = threadIdx.x % (
     const int ksc_low = ksc % (QI3_K/8);
     const int shift_low = 4 * (ksc / (QI3_K/8));

@@ -1238,16 +1302,16 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
     const int sc = __vsubss4(sc_low | sc_high, 0x20202020);

 #ifdef INT8_MMA_AVAILABLE
-    x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/
 #else
-    x_sc[i*(WARP_SIZE/
 #endif // INT8_MMA_AVAILABLE
 }

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
     const int * x_qs = (const int *) x;

@@ -1256,32 +1320,35 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
     const int * y_qs = (const int *) y + 4;
     const float * y_df = (const float *) y;

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &
 #ifdef INT8_MMA_AVAILABLE
     typedef mma_int_A_I16K4 mma_A;
     typedef mma_int_B_J8K4 mma_B;
     typedef mma_int_C_I16J8 mma_C;

@@ -1293,73 +1360,74 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
     const int * x_qs = (const int *) x;
     const float * x_df = (const float *) x_qs + WARP_SIZE*2;
-    const int * x_sc = (const int *) x_df +
     const int * y_qs = (const int *) y + 4;
     const float * y_df = (const float *) y;

     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

     float dA[ntx][mma_C::ne/2];

-    const int k = QR3_K*k0 + mma_A::get_k(l);
-    A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F;
-    A[n][0].x[l] = __vsubss4(A[n][0].x[l], 0x04040404);
-    A[n][1].x[l] = __vsubss4(A[n][1].x[l], 0x04040404);
-    dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K
@@ -1451,7 +1519,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
     const int * x_qs = (const int *) x;

@@ -1460,26 +1528,31 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
     const int * y_qs = (const int *) y + 4;
     const half2 * y_ds = (const half2 *) y;

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &
 #ifdef INT8_MMA_AVAILABLE
     typedef mma_int_A_I16K8 mma_A;

@@ -1500,35 +1573,40 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

     half2 dmA[ntx][mma_C::ne/2];

     for (int l = 0; l < mma_A::ne; ++l) {
     }

@@ -1536,7 +1614,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
     for (int l = 0; l < mma_C::ne/2; ++l) {
         const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);

-        dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K
     }

@@ -1546,28 +1624,28 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
     float tmpm[ntx][mma_C::ne] = {{0.0f}};

     mma_B B;
     half2 dsB[mma_C::ne/2];

-    B.load(y_qs + j0*MMQ_TILE_Y_K +
-    dsB[l] = y_ds[j*MMQ_TILE_Y_K +
     mma_C C;
-    C.mma_K8(A[n][
-    tmpd[n][l] += (C.x[l]*scA[n][l/2][
-    tmpm[n][l] +=  mA[n][l/2][

@@ -1682,7 +1760,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
     const int * x_qs = (const int *) x;

@@ -1691,26 +1769,31 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
     const int * y_qs = (const int *) y + 4;
     const half2 * y_ds = (const half2 *) y;

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &
 #ifdef INT8_MMA_AVAILABLE
     typedef mma_int_A_I16K8 mma_A;

@@ -1731,26 +1814,34 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

     half2 dmA[ntx][mma_C::ne/2];

@@ -1758,7 +1849,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
     for (int l = 0; l < mma_C::ne/2; ++l) {
         const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);

-        dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K
     }

@@ -1768,28 +1859,30 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
     float tmpm[ntx][mma_C::ne] = {{0.0f}};

     mma_B B;
     half2 dsB[mma_C::ne/2];

-    B.load(y_qs + j0*MMQ_TILE_Y_K +
-    dsB[l] = y_ds[j*MMQ_TILE_Y_K + (
     mma_C C;
-    C.mma_K8(A[n][
-    tmpd[n][l] += (C.x[l]*scA[n][l/2][
-    tmpm[n][l] +=  mA[n][l/2][
@@ -1896,7 +1989,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
     const int * x_qs = (const int *) x;

@@ -1905,26 +1998,31 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
     const int * y_qs = (const int *) y + 4;
     const float * y_df = (const float *) y;

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int &
 #ifdef INT8_MMA_AVAILABLE
     typedef mma_int_A_I16K4 mma_A;

@@ -1945,25 +2043,35 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

     float dA[ntx][mma_C::ne/2];

@@ -1971,7 +2079,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     for (int l = 0; l < mma_C::ne/2; ++l) {
         const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);

-        dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K
     }

@@ -1980,30 +2088,29 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     float tmp[ntx][mma_C::ne] = {{0.0f}};

     mma_B B[2];
     float dB[mma_C::ne/2];

-    B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k0B, MMQ_TILE_Y_K);
-    dB[l] = y_df[j*MMQ_TILE_Y_K +
     mma_C C[2];
-    C[0].mma_K4(A[n][
-    C[1].mma_K4(A[n][
-    tmp[n][l] += (C[0].x[l]*scA[n][l/2][
@@ -2051,8 +2158,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
     const int2 v = get_int_from_table_16(aux_q4);
     const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
 #ifdef INT8_MMA_AVAILABLE
-    x_qs[i*
-    x_qs[i*
 #else
     x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
     x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;

@@ -2073,7 +2180,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
     const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;

 #ifdef INT8_MMA_AVAILABLE
-    x_df[i*
 #else
     x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
 #endif // INT8_MMA_AVAILABLE

@@ -2109,8 +2216,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
     const int2 v = get_int_from_table_16(aux_q4);
     const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
 #ifdef INT8_MMA_AVAILABLE
-    x_qs[i*
-    x_qs[i*
 #else
     x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
     x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;

@@ -2133,7 +2240,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
                       | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);

 #ifdef INT8_MMA_AVAILABLE
-    x_df[i*
 #else
     x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
 #endif // INT8_MMA_AVAILABLE
@@ -2229,16 +2336,16 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
     static constexpr int              vdr        = VDR_Q5_0_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  =
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a =
 };

 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
     static constexpr int              vdr        = VDR_Q5_1_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  =
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a =
 };

 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -2293,45 +2400,18 @@
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
     static constexpr int              vdr        = VDR_IQ4_NL_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  =
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a =
 };

 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
     static constexpr int              vdr        = VDR_IQ4_XS_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  =
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a =
 };

-static bool mmq_need_sum(const ggml_type type_x) {
-    switch (type_x) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-            return true;
-        case GGML_TYPE_Q5_0:
-            return false;
-        case GGML_TYPE_Q5_1:
-            return true;
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-            return false;
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-            return true;
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ4_NL:
-            return false;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
-    return false;
-}
-
 template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
 static __device__ void mul_mat_q_process_tile(
     const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
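The removed mmq_need_sum() boolean is superseded by the mmq_get_q8_1_ds_layout() query added near the top of this file, so callers can now distinguish all three q8_1 scale/sum layouts instead of getting a yes/no answer. A hedged sketch of how a host-side helper might dispatch on it (pick_quantize_path is a hypothetical name, not a function from this diff):

    // Hypothetical dispatch on the new layout enum; D4 carries scales only,
    // DS4 and D2S6 additionally carry partial sums, per the block_q8_1_mmq docs.
    static const char * pick_quantize_path(const ggml_type type_x) {
        switch (mmq_get_q8_1_ds_layout(type_x)) {
            case MMQ_Q8_1_DS_LAYOUT_D4:   return "scales only";
            case MMQ_Q8_1_DS_LAYOUT_DS4:  return "scale + partial sum per 32 values";
            case MMQ_Q8_1_DS_LAYOUT_D2S6: return "2 scales + 6 partial sums per 128 values";
        }
        return "unknown";
    }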
@@ -2339,10 +2419,7 @@
     const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {

     constexpr int qk = ggml_cuda_type_traits<type>::qk;
-    constexpr int qr = ggml_cuda_type_traits<type>::qr;
-    constexpr int qi = ggml_cuda_type_traits<type>::qi;
     constexpr int mmq_y = get_mmq_y_device();
-    constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
     constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;

     extern __shared__ char data_mul_mat_q[];

@@ -2357,7 +2434,7 @@
     constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
 #endif // INT8_MMA_AVAILABLE

-    constexpr int

     float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};

@@ -2366,29 +2443,40 @@
     const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));

-    for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 +=
         load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);

-        const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + kr*sizeof(block_q8_1_mmq)/sizeof(int));
 #pragma unroll
         for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
             int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;

             tile_y[l] = by0[l];
         }
     }

     if (fixup) {

@@ -2424,7 +2512,6 @@ static __global__ void mul_mat_q(
     }

     constexpr int qk = ggml_cuda_type_traits<type>::qk;
-    constexpr int qi = ggml_cuda_type_traits<type>::qi;
     constexpr int mmq_y = get_mmq_y_device();

     // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:

@@ -2439,7 +2526,7 @@
 #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA

     const int64_t blocks_per_ne00 = ne00 / qk;
-    constexpr int

     const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
     const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y

@@ -2448,8 +2535,8 @@
     int64_t kbc      = (int64_t) blockIdx.x     *blocks_per_ne00*ntx*nty / gridDim.x;
     int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x;

-    kbc      -= (kbc      % blocks_per_ne00) %
-    kbc_stop -= (kbc_stop % blocks_per_ne00) %

     // kb0 == k index when doing the matrix multiplication for an output tile.
     int kb0_start = kbc % blocks_per_ne00;

@@ -2490,8 +2577,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
     constexpr int mmq_y = get_mmq_y_device();
     constexpr int qk = ggml_cuda_type_traits<type>::qk;
-    constexpr int
-    constexpr int blocks_per_warp = WARP_SIZE / qi;
     const int64_t blocks_per_ne00 = ne00 / qk;

     float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};

@@ -2501,15 +2587,18 @@
     bool any_fixup = false;

-    const int bidx_start = (blockIdx.y*nty + blockIdx.x) * block_num_mmq
-    const int bidx_stop  = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x)

     for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {

-        kbc
-        kbc_stop

         // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
         if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
|
|
|
|
...
 #include <cstdint>

 #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
+#define MMQ_ITER_K 256
+#define MMQ_NWARPS 8

 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride);
+typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00);
 typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);

+enum mmq_q8_1_ds_layout {
+    MMQ_Q8_1_DS_LAYOUT_D4,
+    MMQ_Q8_1_DS_LAYOUT_DS4,
+    MMQ_Q8_1_DS_LAYOUT_D2S6,
+};
+
 struct block_q8_1_mmq {
+    // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block.
+    // The y float data is first grouped as blocks of 128 values.
+    // These blocks are then treated as individual data values and transposed.
+    //
+    // To avoid shared memory bank conflicts each block is padded with 16 bytes.
+    // This padding is also used to store block scales/partial sums.
+    // The scales multiplied with the quantized data are equal to the unquantized values.
+    // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization)
+    // and are only needed for performance reasons.
+    //
+    // The exact data stored depends on the x data type.
+    union {
+        float d4[4];   // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3
+        half2 ds4[4];  // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3
+        half  d2s6[8]; // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values,
+                       //     stored as d0,d1,s1,s2,s3,s4,s5
+    };
+    int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
 };
 static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
 static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1),      "Unexpected block_q8_1_mmq size");

+static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
+    switch (type_x) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
+        case GGML_TYPE_Q5_0:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_Q5_1:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
+        case GGML_TYPE_Q8_0:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_Q2_K:
+            return MMQ_Q8_1_DS_LAYOUT_D2S6;
+        case GGML_TYPE_Q3_K:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ4_NL:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
+
 struct tile_x_sizes {
     int qs;
     int dm;
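The comments above fully determine the size of block_q8_1_mmq, and the two static_asserts pin it down. The following standalone host-side sketch (an illustration, assuming QK8_1 == 32 and 2-byte halves as in ggml) spells out the arithmetic and how the same 16 padding bytes are reinterpreted by each ds layout.

// Standalone sketch, not part of the patch: sizes implied by the comments above,
// assuming QK8_1 == 32 and sizeof(half) == 2.
#include <cstddef>

constexpr std::size_t QK8_1_sketch = 32;
constexpr std::size_t half_bytes   = 2;

constexpr std::size_t qs_bytes    = 4*QK8_1_sketch; // 128 int8 quantized values
constexpr std::size_t scale_bytes = 4*2*half_bytes; // 16 bytes of padding for scales/partial sums
constexpr std::size_t block_bytes = qs_bytes + scale_bytes;

static_assert(block_bytes == 144, "128 quantized values + 16 scale/sum bytes per block_q8_1_mmq");

// The 16 scale/sum bytes are read differently depending on mmq_q8_1_ds_layout:
//   MMQ_Q8_1_DS_LAYOUT_D4   -> float d4[4]   : one fp32 scale per 32 values
//   MMQ_Q8_1_DS_LAYOUT_DS4  -> half2 ds4[4]  : (scale, partial sum) pair per 32 values
//   MMQ_Q8_1_DS_LAYOUT_D2S6 -> half  d2s6[8] : 2 scales per 64 values + 6 partial sums
static_assert(4*sizeof(float) == scale_bytes && 8*half_bytes == scale_bytes,
              "all three layouts fit in the same 16 byte pad");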
...
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 }

+#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
+#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
+#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
+#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
+#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
+#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}

 static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
     return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
         type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
+        type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
         type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
         type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
         type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
         type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
         type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
         type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
+        type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 :
         tile_x_sizes{0, 0, 0};
 }

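As a concrete reading of these macros, the host-side sketch below (an illustration, assuming WARP_SIZE == 32 and QI4_0 == 4 as in the CUDA backend) evaluates MMQ_DP4A_TXS_Q4_0 for mmq_y == 64: one padded row of packed quants per tile row plus one scale slot per QI4_0 quant ints.

// Host-side illustration, not part of the patch: tile_x_sizes for Q4_0 with mmq_y == 64,
// assuming WARP_SIZE == 32 and QI4_0 == 4.
constexpr int WARP_SIZE_sketch = 32;
constexpr int QI4_0_sketch     = 4;
constexpr int mmq_y_sketch     = 64;

constexpr int txs_qs = mmq_y_sketch*WARP_SIZE_sketch + mmq_y_sketch;                             // quants: 1 padding int per row
constexpr int txs_dm = mmq_y_sketch*WARP_SIZE_sketch/QI4_0_sketch + mmq_y_sketch/QI4_0_sketch;   // per-block scales
constexpr int txs_sc = 0;                                                                        // Q4_0 has no sub-block scales

static_assert(txs_qs == 2112 && txs_dm == 528, "about 8.25 KiB of quants plus about 2 KiB of scales in shared memory");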
+#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1 + 4)
+#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/(2*QI3_K) + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)

 static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
...
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
         type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
+        type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
         type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
         type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
         type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
         type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
         type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
         type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
+        type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 :
         0;
 }

 #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)

 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
     return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
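The "% 8 == 4" asserts can be checked by hand. The sketch below (assuming WARP_SIZE == 32, QI4_0 == 4 and QI8_0 == 8) evaluates two of the per-row strides: each stride sits 4 ints past a multiple of 8, which is exactly the padding the "Wrong padding." asserts guard.

// Worked check of the padding asserts above, as a standalone host-side sketch.
//   MMQ_MMA_TILE_X_K_Q4_0 = 1*32 + 32/4   + 4 = 44, and 44 % 8 == 4
//   MMQ_MMA_TILE_X_K_Q8_0 = 2*32 + 2*32/8 + 4 = 76, and 76 % 8 == 4
static_assert((1*32 + 32/4   + 4) % 8 == 4, "Q4_0 MMA row stride, 44 ints per tile row");
static_assert((2*32 + 2*32/8 + 4) % 8 == 4, "Q8_0 MMA row stride, 76 ints per tile row");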
...

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     const int   * x_qs = (const int   *) x;
...
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;

+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
 #pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;

 #pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;

+                const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);

+                int u[2*VDR_Q4_0_Q8_1_MMQ];

 #pragma unroll
+                for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
+                    u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
+                    u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
+                }

+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+                    (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u,
+                     x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
         }
     }
 }

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef INT8_MMA_AVAILABLE

     typedef mma_int_A_I16K8 mma_A;
...
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;

+    mma_A A[ntx][4];
+    float dA[ntx][mma_C::ne/2][4];

     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) {
+            const int k0 = k00 + k01;

+#pragma unroll
+            for (int l = 0; l < mma_A::ne; ++l) {
+                const int i     = i0 + n*mma_A::I + mma_A::get_i(l);
+                const int k     = k0/QR4_0 + mma_A::get_k(l) % QI4_0;
+                const int shift = 4*(mma_A::get_k(l) / QI4_0);
+
+                A[n][k01/(QR4_0*QI4_0)].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808);
+            }

 #pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);

+                dA[n][l][k01/(QR4_0*QI4_0)] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/(QR4_0*QI4_0)];
+            }
         }
     }

 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) {
+            mma_B B;
+            float dB[mma_C::ne/2];

+            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

 #pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);

+                dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }

 #pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C;
+                C.mma_K8(A[n][k01/(QR4_0*QI4_0)], B);

 #pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2][k01/(QR4_0*QI4_0)]*dB[l%2]*C.x[l];
+                }
             }
         }
     }
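The __vsubss4((... >> shift) & 0x0F0F0F0F, 0x08080808) line in the MMA variant unpacks four Q4_0 quants at once; because the +8 bias is removed here, the subsequent scaling only needs the q8_1 scale (note that dB uses __low2float only), not the partial sum. A host-side sketch of the per-byte arithmetic (the exact mapping of nibbles to k positions is simplified here):

// Host-side sketch of the 4-bit unpack done with __vsubss4 in vec_dot_q4_0_q8_1_mma:
// each byte packs two Q4_0 quants stored with a +8 bias, so nibble - 8 recovers the
// signed value in [-8, 7]. __vsubss4 performs this for four bytes in one instruction.
#include <cstdint>
#include <cstdio>

static int8_t unpack_q4_0_nibble(uint32_t packed, int idx /* 0..7 within one int */) {
    const uint32_t shifted = idx < 4 ? packed : packed >> 4; // low vs. high nibbles
    const uint32_t nibble  = (shifted >> (8*(idx % 4))) & 0x0F;
    return (int8_t) nibble - 8;
}

int main() {
    const uint32_t packed = 0x0F08F170u; // arbitrary example word
    for (int idx = 0; idx < 8; ++idx) {
        std::printf("quant %d = %d\n", idx, unpack_q4_0_nibble(packed, idx));
    }
    return 0;
}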
...

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     const int   * x_qs = (const int   *) x;
...
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;

+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
 #pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;

 #pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;

+                const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);

+                int u[2*VDR_Q4_1_Q8_1_MMQ];

 #pragma unroll
+                for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
+                    u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
+                    u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
+                }

+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+                    (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u,
+                     x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
         }
     }
 }

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef INT8_MMA_AVAILABLE

     typedef mma_int_A_I16K8 mma_A;
...
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;

+    mma_A    A[ntx][4];
+    half2 dmA[ntx][mma_C::ne/2][4];

     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) {
+            const int k0 = k00 + k01;
+
+            A[n][k01/(QR4_1*QI4_1)].load_low(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0/QR4_1, MMQ_MMA_TILE_X_K_Q4_1);
+            A[n][k01/(QR4_1*QI4_1)].x[2] = (A[n][k01/(QR4_1*QI4_1)].x[0] >> 4) & 0x0F0F0F0F;
+            A[n][k01/(QR4_1*QI4_1)].x[3] = (A[n][k01/(QR4_1*QI4_1)].x[1] >> 4) & 0x0F0F0F0F;
+            A[n][k01/(QR4_1*QI4_1)].x[0] &= 0x0F0F0F0F;
+            A[n][k01/(QR4_1*QI4_1)].x[1] &= 0x0F0F0F0F;

 #pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);

+                dmA[n][l][k01/(QR4_1*QI4_1)] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/(QR4_1*QI4_1)];
+            }
         }
     }

 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) {
+            mma_B B;
+            half2 dsB[mma_C::ne/2];

+            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

 #pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);

+                dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1];
+            }

 #pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C;
+                C.mma_K8(A[n][k01/(QR4_1*QI4_1)], B);

 #pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    const half2 dmA_dsB = dmA[n][l/2][k01/(QR4_1*QI4_1)]*dsB[l%2];
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+                }
             }
         }
     }
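For Q4_1 the MMA path loads only the low half of the A fragment (load_low) and derives the other half by shifting nibbles, as in the x[2]/x[3] lines above. This keeps the x tile at one int per 8 quants in shared memory (the 1*WARP_SIZE term in MMQ_MMA_TILE_X_K_Q4_1) and expands to 8 bit only in registers. A minimal host-side sketch of the split:

// Host-side sketch of the nibble split used after load_low() in vec_dot_q4_1_q8_1_mma:
// one 32 bit word holds 8 packed 4-bit quants; the low nibbles stay in place and the
// high nibbles are shifted down into a second word (no sign bias for Q4_1).
#include <cstdint>

static void split_q4_1_nibbles(uint32_t packed, uint32_t & low, uint32_t & high) {
    low  =  packed       & 0x0F0F0F0Fu; // first four quants, values 0..15
    high = (packed >> 4) & 0x0F0F0F0Fu; // remaining four quants, values 0..15
}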
...
         qs1 = __vsubss4(qs1, 0x10101010); // subtract 16

 #ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
 #else
         x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
         x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
...
         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;

 #ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
         x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
 #endif // INT8_MMA_AVAILABLE
     }
 }

 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

...
         qs1 |= (qh << 9) & 0x10000000; // 19 -> 28

 #ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
 #else
         x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
         x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
...
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;

 #ifdef INT8_MMA_AVAILABLE
+        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
 #else
         x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
 #endif // INT8_MMA_AVAILABLE
     }
 }

+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_tile + 2*WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI8_0;
+    const int kqsx = threadIdx.x % QI8_0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0         + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
+#else
+        x_qs[i*(2*WARP_SIZE + 1) + 0         + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
+        x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) {
+        int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0             + kbxd] = bxi->d;
+#else
+        x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
 template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {

+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
     const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
+        const int k0 = k00 + k01;

 #pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;

 #pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;

+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
+                    (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE],
+                     x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+            }
         }
     }
 }

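All of the *_dp4a functions in this file accumulate into the same flat sum layout, so it is worth pinning down once. The sketch below (host-side, with a hypothetical standalone helper) shows how the index j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE linearizes the (j0, i0) sub-tiles owned by a thread, j-major.

// Hypothetical helper mirroring the accumulator index used by the dp4a vec_dot functions.
constexpr int sum_index(int j0, int i0, int mmq_y, int nwarps, int warp_size) {
    return (j0/nwarps) * (mmq_y/warp_size) + i0/warp_size;
}

// Example with nwarps == 8, mmq_y == 64, WARP_SIZE == 32: two i sub-tiles per j step.
static_assert(sum_index(0,  0, 64, 8, 32) == 0, "");
static_assert(sum_index(0, 32, 64, 8, 32) == 1, "");
static_assert(sum_index(8,  0, 64, 8, 32) == 2, "");
static_assert(sum_index(8, 32, 64, 8, 32) == 3, "");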
|
| 792 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 793 |
+
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
| 794 |
+
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
|
| 795 |
#ifdef INT8_MMA_AVAILABLE
|
| 796 |
|
| 797 |
typedef mma_int_A_I16K8 mma_A;
|
|
|
|
| 805 |
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
|
| 806 |
|
| 807 |
const int * x_qs = (const int *) x;
|
| 808 |
+
const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
|
| 809 |
const int * y_qs = (const int *) y + 4;
|
| 810 |
+
const float * y_df = (const float *) y;
|
| 811 |
|
| 812 |
+
mma_A A[ntx][WARP_SIZE/QI8_0];
|
| 813 |
+
float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
|
| 814 |
|
| 815 |
+
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
|
| 816 |
|
| 817 |
#pragma unroll
|
| 818 |
for (int n = 0; n < ntx; ++n) {
|
|
|
|
|
|
|
| 819 |
#pragma unroll
|
| 820 |
+
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
|
| 821 |
+
const int k0 = k00 + k01;
|
| 822 |
|
| 823 |
+
A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
|
| 824 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 825 |
|
| 826 |
#pragma unroll
|
| 827 |
for (int l = 0; l < mma_C::ne/2; ++l) {
|
| 828 |
+
const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
|
|
|
|
|
|
|
|
|
|
| 829 |
|
| 830 |
#pragma unroll
|
| 831 |
+
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
|
| 832 |
+
const int k0 = k00 + k01;
|
|
|
|
| 833 |
|
| 834 |
+
dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
|
|
|
|
|
|
|
|
|
|
| 835 |
}
|
| 836 |
}
|
| 837 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 838 |
|
| 839 |
+
#pragma unroll
|
| 840 |
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
|
| 841 |
+
#pragma unroll
|
| 842 |
+
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
|
| 843 |
+
const int k0 = k00 + k01;
|
| 844 |
|
| 845 |
+
mma_B B;
|
| 846 |
+
float dB[mma_C::ne/2];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 847 |
|
| 848 |
+
B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
|
|
|
|
| 849 |
|
| 850 |
#pragma unroll
|
| 851 |
+
for (int l = 0; l < mma_C::ne/2; ++l) {
|
| 852 |
+
const int j = j0 + mma_C::get_j(l);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 853 |
|
| 854 |
+
dB[l] = y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
|
| 855 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 856 |
|
| 857 |
+
#pragma unroll
|
| 858 |
+
for (int n = 0; n < ntx; ++n) {
|
| 859 |
+
mma_C C;
|
| 860 |
+
C.mma_K8(A[n][k01/QI8_0], B);
|
| 861 |
|
| 862 |
#pragma unroll
|
| 863 |
+
for (int l = 0; l < mma_C::ne; ++l) {
|
| 864 |
+
sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
|
| 865 |
+
}
|
| 866 |
+
}
|
|
|
|
| 867 |
}
|
| 868 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 869 |
#else
|
| 870 |
+
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
|
| 871 |
+
NO_DEVICE_CODE;
|
| 872 |
#endif // INT8_MMA_AVAILABLE
|
|
|
|
| 873 |
}
|
| 874 |
|
 template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {

+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
     const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
     const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
+        const int k0 = k00 + k01;

 #pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;

 #pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;

+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
+                    (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                     x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
         }
     }
 }

 template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef INT8_MMA_AVAILABLE

     typedef mma_int_A_I16K8 mma_A;
...
     y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);

     const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
     const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_dm = (const half2 *) y;

+    mma_A    A[ntx][WARP_SIZE/QI8_1];
+    half2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];

     const int i0 = (threadIdx.y/ntx)*rows_per_warp;

 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+            const int k0 = k00 + k01;
+
+            A[n][k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+        }

 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);

+#pragma unroll
+            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+                const int k0 = k00 + k01;
+
+                dmA[n][l][k01/QI8_1] = x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1];
+            }
         }
     }

 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+            const int k0 = k00 + k01;
+
+            mma_B B;
+            half2 dsB[mma_C::ne/2];

+            B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);

 #pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);

+                dsB[l] = y_dm[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
+            }

 #pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C;
+                C.mma_K8(A[n][k01/QI8_1], B);

 #pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    const half2 dmA_dsB = dmA[n][l/2][k01/QI8_1]*dsB[l%2];
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+                }
             }
         }
     }
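The last loop above condenses the Q5_1-style math into one fused half2 multiply: with x carrying (d, m) in the DS4 layout and y carrying (d, s), the per-fragment update is d_x*d_y*dot + m_x*s_y. A scalar sketch of that identity (hypothetical helper types, float instead of half2):

// Scalar sketch (hypothetical helper) of the dmA_dsB update above: the int dot product
// of the quants is scaled by d_x*d_y, and the x offset m_x contributes m_x*s_y, where
// s_y is the pre-quantization sum of the y values in that group.
struct x_scale_min { float d, m; }; // per 32 x values
struct y_scale_sum { float d, s; }; // per 32 y values (DS4 layout)

static float q8_1_q8_1_update(int sumi, x_scale_min x, y_scale_sum y) {
    return x.d*y.d*sumi + x.m*y.s;
}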
...

 #ifdef INT8_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
 #endif // INT8_MMA_AVAILABLE

     const int kqsx = threadIdx.x % QI2_K;

 #pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) {
+        int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K;

         if (need_check) {
             i = min(i, i_max);
         }

+        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride;

         const int x_ql_0 = get_int_b2(bxi->qs, kqsx);

 #pragma unroll
         for (int l = 0; l < QR2_K; ++l) {
+            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;

+            const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;

 #ifdef INT8_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
 #else
+            x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
 #endif // INT8_MMA_AVAILABLE
         }
...
 #endif // FAST_FP16_AVAILABLE

 #ifdef INT8_MMA_AVAILABLE
+        x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
 #else
+        x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik;
 #endif // INT8_MMA_AVAILABLE
     }
 }

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     const int   * x_qs = (const int   *) x;
     const half2 * x_dm = (const half2 *) x_qs + txs.qs;
     const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;

+    float2 y_df[mmq_x/nwarps];
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
         const int j = j0 + threadIdx.y;

+        y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
+    }
+
 #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;

+                if (k01 < WARP_SIZE/2) {
+                    constexpr int ns = 2;
+                    sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+                        &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                        &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+                        &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+                } else {
+                    constexpr int ns = 1;
+                    sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+                        &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                        &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+                        &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+                }
+            }
         }
     }
 }
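This is where the D2S6 layout selected by mmq_get_q8_1_ds_layout for Q2_K pays off: y_ds[j*MMQ_TILE_Y_K] is read once as a half2 holding the two 64-value scales, and the k01 < WARP_SIZE/2 test picks d0 for the first half of the 128-value block and d1 for the rest. A small sketch of that selection (hypothetical helper):

// Hypothetical helper sketching the scale selection above. k01 indexes 32 bit columns,
// and 32 columns * 4 int8 values cover one 128-value block_q8_1_mmq; the first 16
// columns (64 values) use d0, the remaining ones use d1, matching k01 < WARP_SIZE/2.
static float d2s6_scale_for_column(float d0, float d1, int k01 /* 0..31 */) {
    return k01 < 16 ? d0 : d1;
}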

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef INT8_MMA_AVAILABLE

     typedef mma_int_A_I16K4 mma_A;
+    typedef mma_int_A_I16K8 mma_A_K8;
     typedef mma_int_B_J8K4 mma_B;
     typedef mma_int_C_I16J8 mma_C;

...
     y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);

     const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
     const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;

     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

+    mma_A A[ntx][8];
+    float dA[ntx][mma_C::ne/2][8];
+    float mA[ntx][mma_C::ne/2][8];

 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+            const int k0 = k00 + k01;

+            ((mma_A_K8 *) A[n])[k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
         }
+    }

+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);

 #pragma unroll
+            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
+                const int k0 = k00 + k01;

+                const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
+
+                dA[n][l][k01/(QI8_1/2)] = dm.x;
+                mA[n][l][k01/(QI8_1/2)] = dm.y;
             }
         }
     }

 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+        float2 dB[mma_C::ne/2];

 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int j = j0 + mma_C::get_j(l);

+            dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
         }

+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+            mma_B B[2];
+
+            B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
+            B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+
+            mma_C Cm[2];
+            if (k01 >= WARP_SIZE * 3/4) {
+                mma_A A1;
+                A1.x[0] = 0x01010101;
+                A1.x[1] = 0x01010101;
+                Cm[0].mma_K4(A1, B[0]);
+                Cm[1].mma_K4(A1, B[1]);
+            }

 #pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C Cd[2];

+                Cd[0].mma_K4(A[n][k01/4 + 0], B[0]);
+                Cd[1].mma_K4(A[n][k01/4 + 1], B[1]);

 #pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
+                    if (k01 >= WARP_SIZE * 3/4) {
+                        tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
+                    }
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
+                }
+            }
+        }
+
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
+            float2 sB[mma_C::ne/2];
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
+                }
             }
         }
     }
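The A1.x[0] = 0x01010101 fragment above is the usual trick for getting plain column sums out of the integer MMA pipeline: a dot product against a vector of ones is just a sum, which is what the Q2_K minimum correction needs. A scalar sketch of the identity:

// Scalar sketch of the all-ones trick used for Cm above: a 4-way int8 dot product with
// 0x01010101 yields the plain sum of the four packed int8 values, so the same tensor
// core instruction that computes q*q dot products can also produce the column sums.
#include <cstdint>

static int sum_of_int8(uint32_t packed_q8) {
    int sum = 0;
    for (int b = 0; b < 4; ++b) {
        sum += (int8_t) (packed_q8 >> (8*b)); // equivalent to dp4a(0x01010101, packed_q8, 0)
    }
    return sum;
}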
...
 #ifdef INT8_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
+    int   * x_sc = (int   *) (x_df + 1);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
     int   * x_qs = (int   *)  x_tile;
...
     int   * x_sc = (int   *) (x_df + txs.dm);
 #endif // INT8_MMA_AVAILABLE

     const int kqsx = threadIdx.x % QI3_K;

 #pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) {
+        int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K;

         if (need_check) {
             i = min(i, i_max);
         }

+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;

         const int x_ql_0 = get_int_b2(bxi->qs,    kqsx);
         const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));

 #pragma unroll
         for (int l = 0; l < QR3_K; ++l) {
+            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;

             const int x_ql_k =  (x_ql_0 >> (2*l))      & 0x03030303;
             const int x_qh_k = ((x_qh_0 >>    l) << 2) & 0x04040404;

+            const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);

 #ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
 #else
+            x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
 #endif // INT8_MMA_AVAILABLE
         }
     }

 #pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
+        int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;

         if (need_check) {
             i = min(i, i_max);
         }

+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;

 #ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K] = bxi->d;
 #else
+        x_df[i] = bxi->d;
 #endif // INT8_MMA_AVAILABLE
     }

 #pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
+        int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8);

         if (need_check) {
             i = min(i, i_max);
         }

+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;

+        const int ksc = threadIdx.x % (WARP_SIZE/8);

         const int ksc_low = ksc % (QI3_K/8);
         const int shift_low = 4 * (ksc / (QI3_K/8));
...
         const int sc = __vsubss4(sc_low | sc_high, 0x20202020);

 #ifdef INT8_MMA_AVAILABLE
+        x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/8)] = sc;
 #else
+        x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
 #endif // INT8_MMA_AVAILABLE
     }
 }

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
     const int   * x_qs = (const int   *) x;
...
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;

+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;

 #pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;

+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;

+                const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4;

+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
+                    &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
+                    x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
         }
     }
 }

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef INT8_MMA_AVAILABLE

     typedef mma_int_A_I16K4 mma_A;
+    typedef mma_int_A_I16K8 mma_A_K8;
     typedef mma_int_B_J8K4 mma_B;
     typedef mma_int_C_I16J8 mma_C;

...

     const int   * x_qs = (const int   *) x;
     const float * x_df = (const float *) x_qs + WARP_SIZE*2;
+    const int   * x_sc = (const int   *) x_df + 1;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;

     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

+    mma_A A[ntx][8];
+    int scA[ntx][mma_C::ne/2][8];
     float dA[ntx][mma_C::ne/2];

 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+            const int k0 = k00 + k01;

+            ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
         }

 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);

+#pragma unroll
+            for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
+                const int k0 = k00 + k01;

+                const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + k0/16];
+                const int8_t * sc = (const int8_t *) &sc_packed;

 #pragma unroll
+                for (int ksc = 0; ksc < sizeof(int); ++ksc) {
+                    scA[n][l][k01/4 + ksc] = sc[ksc];
+                }
+            }

+            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K];
         }
     }

 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+            mma_B B[2];
+            float dB[mma_C::ne/2];

+            B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
+            B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);

 #pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);

+                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+            }

 #pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C[2];
+                C[0].mma_K4(A[n][k01/4 + 0], B[0]);
+                C[1].mma_K4(A[n][k01/4 + 1], B[1]);

 #pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*
+                        (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1]);
+                }
             }
         }
     }
...

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
     const int   * x_qs = (const int   *) x;
...
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;

+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
 #pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;

 #pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;

+                const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16);

+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
+                    &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+                    x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
         }
     }
 }

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef INT8_MMA_AVAILABLE

     typedef mma_int_A_I16K8 mma_A;
...

     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

+    mma_A   A[ntx][4];
+    int   scA[ntx][mma_C::ne/2][4];
+    int    mA[ntx][mma_C::ne/2][4];
     half2 dmA[ntx][mma_C::ne/2];

 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
+            const int k0 = k00 + k01;
+
+            A[n][k01/8 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0/QR4_K, MMQ_MMA_TILE_X_K_Q4_K);

 #pragma unroll
             for (int l = 0; l < mma_A::ne; ++l) {
+                A[n][k01/8 + 1].x[l]  = (A[n][k01/8 + 0].x[l] >> 4) & 0x0F0F0F0F;
+                A[n][k01/8 + 0].x[l] &= 0x0F0F0F0F;
             }
         }

 #pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+            const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 0)];
+            const int m_packed  = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 2)];

+            const uint8_t * sc = (const uint8_t *) &sc_packed;
+            const uint8_t * m  = (const uint8_t *) &m_packed;

+#pragma unroll
+            for (int ksc = 0; ksc < sizeof(int); ++ksc) {
+                scA[n][l][ksc] = sc[ksc];
+                mA[n][l][ksc]  = m[ksc];
             }
         }

...
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);

+            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K];
         }
     }

...
     float tmpm[ntx][mma_C::ne] = {{0.0f}};

 #pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
             mma_B B;
             half2 dsB[mma_C::ne/2];

+            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
                 const int j = j0 + mma_C::get_j(l);

+                dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1];
             }

 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
                 mma_C C;
+                C.mma_K8(A[n][k01/8], B);

 #pragma unroll
                 for (int l = 0; l < mma_C::ne; ++l) {
+                    tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) *  __low2float(dsB[l%2]);
+                    tmpm[n][l] +=  mA[n][l/2][k01/8]          * __high2float(dsB[l%2]);
                 }
             }
         }
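For Q4_K (and Q5_K below) the per-sub-block scales and minima arrive as two packed ints, one byte per sub-block; the scA/mA fill above just widens those bytes into ints for the per-fragment multiplies. A minimal sketch of that unpack (hypothetical helper):

// Hypothetical helper mirroring the scA/mA fill above: one packed int holds four
// unsigned 8 bit sub-block scales (or minima), widened to int before use.
#include <cstdint>

static void unpack_4x_u8(int packed, int out[4]) {
    const uint8_t * bytes = (const uint8_t *) &packed;
    for (int ksc = 0; ksc < 4; ++ksc) {
        out[ksc] = bytes[ksc];
    }
}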
...

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
     const int   * x_qs = (const int   *) x;
...
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;

+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
 #pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;

 #pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;

+                const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);

+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
+                    &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+                    x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
         }
     }
 }

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef INT8_MMA_AVAILABLE

     typedef mma_int_A_I16K8 mma_A;
...

     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

+    mma_A   A[ntx][4];
+    int   scA[ntx][mma_C::ne/2][4];
+    int    mA[ntx][mma_C::ne/2][4];
     half2 dmA[ntx][mma_C::ne/2];

 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+            const int k0 = k00 + k01;
+
+            A[n][k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + k0, MMQ_MMA_TILE_X_K_Q5_K);
+        }

 #pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+            const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 0)];
+            const int m_packed  = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 2)];

+            const uint8_t * sc = (const uint8_t *) &sc_packed;
+            const uint8_t * m  = (const uint8_t *) &m_packed;

+#pragma unroll
+            for (int ksc = 0; ksc < sizeof(int); ++ksc) {
+                scA[n][l][ksc] = sc[ksc];
+                mA[n][l][ksc]  = m[ksc];
             }
         }

...
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);

+            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K];
         }
     }

...
     float tmpm[ntx][mma_C::ne] = {{0.0f}};

 #pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+            const int k0 = k00 + k01;
+
             mma_B B;
             half2 dsB[mma_C::ne/2];

+            B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);

 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
                 const int j = j0 + mma_C::get_j(l);

+                dsB[l] = y_ds[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
             }

 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
                 mma_C C;
+                C.mma_K8(A[n][k01/8], B);

 #pragma unroll
                 for (int l = 0; l < mma_C::ne; ++l) {
+                    tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) *  __low2float(dsB[l%2]);
+                    tmpm[n][l] +=  mA[n][l/2][k01/8]          * __high2float(dsB[l%2]);
                 }
             }
         }
...

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
     const int   * x_qs = (const int   *) x;
...
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;

+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
 #pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;

 #pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;

+                const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]);

+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
+                    &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
+                    x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
         }
     }
 }

 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef INT8_MMA_AVAILABLE

     typedef mma_int_A_I16K4 mma_A;
...

     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);

+    mma_A   A[ntx][8];
+    int   scA[ntx][mma_C::ne/2][8];
     float  dA[ntx][mma_C::ne/2];

 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+            const int k0 = k00 + k01;
+
+            A[n][k01/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0),        MMQ_MMA_TILE_X_K_Q6_K);
+            A[n][k01/4 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
+        }
+
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
+            const int k0 = k00 + k01;

 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
                 const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);

+                const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
+                const int8_t * sc = (const int8_t *) &sc_packed;

+#pragma unroll
+                for (int ksc = 0; ksc < sizeof(int); ++ksc) {
+                    scA[n][l][k01/4 + ksc] = sc[ksc];
+                }
             }
         }

...
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);

+            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
         }
     }

...
     float tmp[ntx][mma_C::ne] = {{0.0f}};

 #pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
             mma_B B[2];
             float dB[mma_C::ne/2];

+            B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0        + k01, MMQ_TILE_Y_K);
+            B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);

 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
                 const int j = j0 + mma_C::get_j(l);

+                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
             }

 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
                 mma_C C[2];
+                C[0].mma_K4(A[n][k01/4 + 0], B[0]);
+                C[1].mma_K4(A[n][k01/4 + 1], B[1]);

 #pragma unroll
                 for (int l = 0; l < mma_C::ne; ++l) {
+                    tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
                 }
             }
         }
...
         const int2 v = get_int_from_table_16(aux_q4);
         const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
 #ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
 #else
         x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
         x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
...
         const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;

 #ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
 #else
         x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
 #endif // INT8_MMA_AVAILABLE
...
         const int2 v = get_int_from_table_16(aux_q4);
         const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
 #ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
 #else
         x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
         x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
...
             | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);

 #ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
 #else
         x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
 #endif // INT8_MMA_AVAILABLE
|
|
|
| 2336 | struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
| 2337 | static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
| 2338 | static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
| 2339 | + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
| 2340 | + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
| 2341 | };
| 2342 |
| 2343 | template <int mmq_x, int mmq_y, int nwarps, bool need_check>
| 2344 | struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
| 2345 | static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
| 2346 | static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
| 2347 | + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
| 2348 | + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
| 2349 | };
| 2350 |
| 2351 | template <int mmq_x, int mmq_y, int nwarps, bool need_check>
| 2400 | struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
| 2401 | static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
| 2402 | static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
| 2403 | + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
| 2404 | + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
| 2405 | };
| 2406 |
| 2407 | template <int mmq_x, int mmq_y, int nwarps, bool need_check>
| 2408 | struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
| 2409 | static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
| 2410 | static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
| 2411 | + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
| 2412 | + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
| 2413 | };
| 2414 |
| 2415 | template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
| 2416 | static __device__ void mul_mat_q_process_tile(
| 2417 | const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
| 2419 | const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
| 2420 |
| 2421 | constexpr int qk = ggml_cuda_type_traits<type>::qk;
| 2422 | constexpr int mmq_y = get_mmq_y_device();
| 2423 | constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
| 2424 |
| 2425 | extern __shared__ char data_mul_mat_q[];
| 2434 | constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
| 2435 | #endif // INT8_MMA_AVAILABLE
| 2436 |
| 2437 | + constexpr int blocks_per_iter = MMQ_ITER_K / qk;
| 2438 |
| 2439 | float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
| 2440 |
| 2443 |
| 2444 | const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
| 2445 |
| 2446 | + for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
| 2447 | load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
| 2448 |
| 2449 | + {
| 2450 | + const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
| 2451 | #pragma unroll
| 2452 | for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
| 2453 | int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
| 2454 |
| 2455 | tile_y[l] = by0[l];
| 2456 | }
| 2457 | + }
| 2458 |
| 2459 | + __syncthreads();
| 2460 |
| 2461 | + vec_dot(tile_x, tile_y, sum, 0);
| 2462 | +
| 2463 | + __syncthreads();
| 2464 | +
| 2465 | + {
| 2466 | + const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
| 2467 | + #pragma unroll
| 2468 | + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
| 2469 | + int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
| 2470 |
| 2471 | + tile_y[l] = by0[l];
| 2472 | + }
| 2473 | }
| 2474 | +
| 2475 | + __syncthreads();
| 2476 | +
| 2477 | + vec_dot(tile_x, tile_y, sum, WARP_SIZE);
| 2478 | +
| 2479 | + __syncthreads();
| 2480 | }
| 2481 |
| 2482 | if (fixup) {
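For reference, each pass of the kb0 loop above adds one slice of an ordinary blocked dot product to the per-tile accumulators; the two vec_dot calls simply cover the two halves of the loaded y tile. A scalar sketch of the arithmetic being accumulated, with made-up tile sizes and already dequantized values:

    #include <cstdio>

    int main() {
        // Stand-ins for mmq_y, mmq_x and the k dimension; the real kernel works on quantized tiles.
        const int kn = 8;
        float x[4][8], y[3][8], sum[4][3] = {{0.0f}};

        for (int i = 0; i < 4; ++i) for (int k = 0; k < kn; ++k) x[i][k] = 0.5f*i - 0.1f*k;
        for (int j = 0; j < 3; ++j) for (int k = 0; k < kn; ++k) y[j][k] = 0.3f*j + 0.2f*k;

        for (int k = 0; k < kn; ++k) {                // plays the role of the kb0 loop
            for (int i = 0; i < 4; ++i) {
                for (int j = 0; j < 3; ++j) {
                    sum[i][j] += x[i][k]*y[j][k];     // what the vec_dot calls accumulate per tile
                }
            }
        }
        std::printf("sum[0][0] = %f\n", sum[0][0]);
        return 0;
    }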
| 2512 | }
| 2513 |
| 2514 | constexpr int qk = ggml_cuda_type_traits<type>::qk;
| 2515 | constexpr int mmq_y = get_mmq_y_device();
| 2516 |
| 2517 | // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
| 2526 | #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
| 2527 |
| 2528 | const int64_t blocks_per_ne00 = ne00 / qk;
| 2529 | + constexpr int blocks_per_iter = MMQ_ITER_K / qk;
| 2530 |
| 2531 | const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
| 2532 | const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
| 2535 | int64_t kbc = (int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x;
| 2536 | int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x;
| 2537 |
| 2538 | + kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
| 2539 | + kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
| 2540 |
| 2541 | // kb0 == k index when doing the matrix multiplication for an output tile.
| 2542 | int kb0_start = kbc % blocks_per_ne00;
| 2577 |
| 2578 | constexpr int mmq_y = get_mmq_y_device();
| 2579 | constexpr int qk = ggml_cuda_type_traits<type>::qk;
| 2580 | + constexpr int blocks_per_iter = MMQ_ITER_K / qk;
| 2581 | const int64_t blocks_per_ne00 = ne00 / qk;
| 2582 |
| 2583 | float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
| 2587 |
| 2588 | bool any_fixup = false;
| 2589 |
| 2590 | + const int bidx_start = ((blockIdx.y*nty + blockIdx.x) * block_num_mmq) / (gridDim.y*gridDim.x);
| 2591 | + const int bidx_stop = ((blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq + gridDim.y*gridDim.x - 1) / (gridDim.y*gridDim.x);
| 2592 | +
| 2593 | + int64_t kbc_0;
| 2594 | + int64_t kbc_stop_0 = (int64_t) bidx_start*blocks_per_ne00*ntx*nty / block_num_mmq;
| 2595 |
| 2596 | for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
| 2597 | + kbc_0 = kbc_stop_0;
| 2598 | + kbc_stop_0 = (int64_t) (bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq;
| 2599 |
| 2600 | + const int64_t kbc = kbc_0 - (kbc_0 % blocks_per_ne00) % blocks_per_iter;
| 2601 | + const int64_t kbc_stop = kbc_stop_0 - (kbc_stop_0 % blocks_per_ne00) % blocks_per_iter;
| 2602 |
| 2603 | // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
| 2604 | if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
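The stream-k split used by mul_mat_q and its fixup kernel can be reproduced on the host: the total work is blocks_per_ne00*ntx*nty fixed-size k blocks, every CUDA block gets an (almost) even share, and both ends of that share are rounded down so a partial range always starts on a blocks_per_iter boundary. A small sketch with made-up sizes (the variable names mirror the kernel, the numbers do not):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Illustrative sizes only; in the kernel these come from the tensor shapes and the launch grid.
        const int64_t blocks_per_ne00 = 24;
        const int64_t ntx = 2, nty = 3;
        const int64_t blocks_per_iter = 4;
        const int64_t grid_dim_x = 5;

        for (int64_t bidx = 0; bidx < grid_dim_x; ++bidx) {
            int64_t kbc      =  bidx     *blocks_per_ne00*ntx*nty / grid_dim_x;
            int64_t kbc_stop = (bidx + 1)*blocks_per_ne00*ntx*nty / grid_dim_x;

            // Round down so a partially processed output tile always starts on an iteration boundary:
            kbc      -= (kbc      % blocks_per_ne00) % blocks_per_iter;
            kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;

            std::printf("CUDA block %lld works on k blocks [%lld, %lld)\n",
                        (long long) bidx, (long long) kbc, (long long) kbc_stop);
        }
        return 0;
    }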
ggml/src/ggml-cuda/quantize.cu
CHANGED

@@ -37,47 +37,92 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
| 37 | reinterpret_cast<half&>(y[ib].ds.y) = sum;
| 38 | }
| 39 |
| 40 | - template <
| 41 | static __global__ void quantize_mmq_q8_1(
| 42 | const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
| 43 |
| 44 | -
| 45 |
| 46 | if (ix0 >= kx0_padded) {
| 47 | return;
| 48 | }
| 49 |
| 50 | const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
| 51 |
| 52 | block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
| 53 |
| 54 | - const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/
| 55 | - const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;
| 56 | - const int64_t iqs = ix0 % (4*QK8_1);
| 57 | -
| 58 | -
| 59 | -
| 60 | -
| 61 | - amax =
| 62 |
| 63 | float sum;
| 64 | - if (
| 65 | - sum =
| 66 | }
| 67 |
| 68 | - const float
| 69 | -
| 70 |
| 71 | -
| 72 |
| 73 | - if (iqs % QK8_1 != 0) {
| 74 | return;
| 75 | }
| 76 |
| 77 | - if (
| 78 | -
| 79 | } else {
| 80 | -
| 81 | }
| 82 | }
| 83 |

@@ -101,12 +146,24 @@ void quantize_mmq_q8_1_cuda(
| 101 |
| 102 | GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
| 103 |
| 104 | - const int64_t block_num_x = (kx0_padded +
| 105 | const dim3 num_blocks(block_num_x, kx1, channels);
| 106 | - const dim3 block_size(
| 107 | -
| 108 | -
| 109 | -
| 110 | -
| 111 | }
| 112 | }

| 37 | reinterpret_cast<half&>(y[ib].ds.y) = sum;
| 38 | }
| 39 |
| 40 | + template <mmq_q8_1_ds_layout ds_layout>
| 41 | static __global__ void quantize_mmq_q8_1(
| 42 | const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
| 43 |
| 44 | + constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
| 45 | + constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
| 46 | +
| 47 | + const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
| 48 |
| 49 | if (ix0 >= kx0_padded) {
| 50 | return;
| 51 | }
| 52 |
| 53 | + const float4 * x4 = (const float4 *) x;
| 54 | +
| 55 | const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
| 56 |
| 57 | block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
| 58 |
| 59 | + const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
| 60 | + const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
| 61 | + const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
| 62 | +
| 63 | + // Load 4 floats per thread and calculate max. abs. value between them:
| 64 | + const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
| 65 | + float amax = fabsf(xi.x);
| 66 | + amax = fmaxf(amax, fabsf(xi.y));
| 67 | + amax = fmaxf(amax, fabsf(xi.z));
| 68 | + amax = fmaxf(amax, fabsf(xi.w));
| 69 | +
| 70 | + // Exchange max. abs. value between vals_per_scale/4 threads.
| 71 | + #pragma unroll
| 72 | + for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
| 73 | + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
| 74 | + }
| 75 |
| 76 | float sum;
| 77 | + if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
| 78 | + sum = xi.x + xi.y + xi.z + xi.w;
| 79 | +
| 80 | + // Exchange the calculated sum across vals_per_sum/4 threads.
| 81 | + #pragma unroll
| 82 | + for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
| 83 | + sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
| 84 | + }
| 85 | }
| 86 |
| 87 | + const float d_inv = 127.0f / amax;
| 88 | + char4 q;
| 89 | + q.x = roundf(xi.x*d_inv);
| 90 | + q.y = roundf(xi.y*d_inv);
| 91 | + q.z = roundf(xi.z*d_inv);
| 92 | + q.w = roundf(xi.w*d_inv);
| 93 |
| 94 | + // Write back 4 int8 values as a single 32 bit value for better memory bandwidth:
| 95 | + char4 * yqs4 = (char4 *) y[ib].qs;
| 96 | + yqs4[iqs/4] = q;
| 97 | +
| 98 | + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
| 99 | + if (iqs % 16 != 0 || iqs >= 96) {
| 100 | + return;
| 101 | + }
| 102 | +
| 103 | + y[ib].d2s6[2 + iqs/16] = sum;
| 104 | +
| 105 | + if (iqs % 64 != 0) {
| 106 | + return;
| 107 | + }
| 108 | +
| 109 | + const float d = 1.0f / d_inv;
| 110 | +
| 111 | + y[ib].d2s6[iqs/64] = d;
| 112 |
| 113 | return;
| 114 | }
| 115 |
| 116 | + if (iqs % 32 != 0) {
| 117 | + return;
| 118 | + }
| 119 | +
| 120 | + const float d = 1.0f / d_inv;
| 121 | +
| 122 | + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
| 123 | + y[ib].ds4[iqs/32] = make_half2(d, sum);
| 124 | } else {
| 125 | + y[ib].d4[iqs/32] = d;
| 126 | }
| 127 | }
| 128 |

| 146 |
| 147 | GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
| 148 |
| 149 | + const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
| 150 | const dim3 num_blocks(block_num_x, kx1, channels);
| 151 | + const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
| 152 | + switch (mmq_get_q8_1_ds_layout(type_x)) {
| 153 | + case MMQ_Q8_1_DS_LAYOUT_D4:
| 154 | + quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
| 155 | + <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
| 156 | + break;
| 157 | + case MMQ_Q8_1_DS_LAYOUT_DS4:
| 158 | + quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
| 159 | + <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
| 160 | + break;
| 161 | + case MMQ_Q8_1_DS_LAYOUT_D2S6:
| 162 | + quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
| 163 | + <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
| 164 | + break;
| 165 | + default:
| 166 | + GGML_ASSERT(false);
| 167 | + break;
| 168 | }
| 169 | }
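A scalar sketch of what quantize_mmq_q8_1 computes for each 32-value group: the warp shuffles above perform the max/sum reductions across threads, which here become plain loops, and the input values are made up:

    #include <cmath>
    #include <cstdio>

    int main() {
        float x[32];                                    // one 32-value group; values are made up
        for (int i = 0; i < 32; ++i) {
            x[i] = 0.1f*i - 1.0f;
        }

        float amax = 0.0f, sum = 0.0f;
        for (int i = 0; i < 32; ++i) {                  // in the kernel these reductions run
            amax = std::fmax(amax, std::fabs(x[i]));    // across threads via __shfl_xor_sync
            sum += x[i];
        }

        const float d_inv = 127.0f/amax;                // scale so the largest magnitude maps to 127
        signed char q[32];
        for (int i = 0; i < 32; ++i) {
            q[i] = (signed char) std::round(x[i]*d_inv);
        }

        const float d = 1.0f/d_inv;                     // d (and optionally sum) is what the
        std::printf("d = %f  sum = %f  q[31] = %d\n",   // ds4/d4/d2s6 layouts store per group
                    d, sum, (int) q[31]);
        return 0;
    }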
ggml/src/ggml-cuda/quantize.cuh
CHANGED

@@ -5,7 +5,11 @@
| 5 |
| 6 | #include <cstdint>
| 7 |
| 8 | - #define CUDA_QUANTIZE_BLOCK_SIZE
| 9 |
| 10 | typedef void (*quantize_cuda_t)(
| 11 | const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,

| 5 |
| 6 | #include <cstdint>
| 7 |
| 8 | + #define CUDA_QUANTIZE_BLOCK_SIZE 256
| 9 | + #define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128
| 10 | +
| 11 | + static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access.");
| 12 | + static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
| 13 |
| 14 | typedef void (*quantize_cuda_t)(
| 15 | const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
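The two new constants fix the launch geometry of the MMQ quantization kernel: each thread handles 4 floats, so one block covers 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ values, and the static_asserts tie this to MATRIX_ROW_PADDING so the last block cannot run out of bounds. A small host-side sketch of that grid computation (kx0_padded is an illustrative value):

    #include <cstdint>
    #include <cstdio>

    #define QK8_1 32
    #define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128

    int main() {
        const int64_t kx0_padded = 4096;   // illustrative padded row length, multiple of 4*QK8_1
        const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1)
                                  / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
        std::printf("values per CUDA block = %d, blocks along x = %lld\n",
                    4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ, (long long) block_num_x);
        return 0;
    }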
ggml/src/ggml-cuda/vecdotq.cuh
CHANGED

@@ -189,7 +189,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
| 189 | }
| 190 |
| 191 | #define VDR_Q2_K_Q8_1_MMVQ 1
| 192 | - #define VDR_Q2_K_Q8_1_MMQ
| 193 |
| 194 | // contiguous v/x values
| 195 | static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(

@@ -219,32 +219,56 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
| 219 | return dm2f.x*sumf_d - dm2f.y*sumf_m;
| 220 | }
| 221 |
| 222 | - // contiguous u/y values
| 223 | static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
| 224 | - const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) {
| 225 |
| 226 | - float
| 227 | - float
| 228 |
| 229 | #pragma unroll
| 230 | - for (int i0 = 0; i0 <
| 231 | - const float2
| 232 | - int
| 233 | -
| 234 |
| 235 | - const int vi0 = v[i0/(QI8_1/2)];
| 236 | #pragma unroll
| 237 | for (int i = i0; i < i0 + QI8_1/2; ++i) {
| 238 | -
| 239 | -
| 240 | -
| 241 | }
| 242 |
| 243 | -
| 244 | -
| 245 | }
| 246 |
| 247 | - return d8*
| 248 | }
| 249 |
| 250 | #define VDR_Q3_K_Q8_1_MMVQ 1

@@ -283,7 +307,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
| 283 | return d3 * sumf;
| 284 | }
| 285 |
| 286 | - // contiguous u/y values
| 287 | static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
| 288 | const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
| 289 | const float & d3, const float & d8) {

@@ -296,8 +320,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
| 296 |
| 297 | #pragma unroll
| 298 | for (int i = i0; i < i0 + QI8_1/2; ++i) {
| 299 | -
| 300 | - sumi_sc = ggml_cuda_dp4a(vi, u[i], sumi_sc); // SIMD dot product
| 301 | }
| 302 |
| 303 | sumi += sumi_sc * scales[i0 / (QI8_1/2)];

@@ -334,7 +357,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
| 334 | return dm4f.x*sumf_d - dm4f.y*sumf_m;
| 335 | }
| 336 |
| 337 | - // contiguous u/y values
| 338 | static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
| 339 | const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
| 340 | const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {

@@ -397,7 +420,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
| 397 | return dm5f.x*sumf_d - dm5f.y*sumf_m;
| 398 | }
| 399 |
| 400 | - // contiguous u/y values
| 401 | static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
| 402 | const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
| 403 | const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {

@@ -451,13 +474,16 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
| 451 | return d*sumf;
| 452 | }
| 453 |
| 454 | - // contiguous u/y values
| 455 | static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
| 456 | const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
| 457 | const float & d6, const float * __restrict__ d8) {
| 458 |
| 459 | float sumf_d = 0.0f;
| 460 |
| 461 | #pragma unroll
| 462 | for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
| 463 | int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale

@@ -471,7 +497,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
| 471 | sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
| 472 | }
| 473 |
| 474 | - sumf_d += d8[i0/4] * (
| 475 | }
| 476 |
| 477 | return d6 * sumf_d;

| 189 | }
| 190 |
| 191 | #define VDR_Q2_K_Q8_1_MMVQ 1
| 192 | + #define VDR_Q2_K_Q8_1_MMQ 4
| 193 |
| 194 | // contiguous v/x values
| 195 | static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(

| 219 | return dm2f.x*sumf_d - dm2f.y*sumf_m;
| 220 | }
| 221 |
| 222 | + // contiguous v/x + u/y values
| 223 | + template <int ns8>
| 224 | static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
| 225 | + const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) {
| 226 |
| 227 | + float sumf = 0.0f;
| 228 | + float sumf_d8 = 0.0f;
| 229 |
| 230 | #pragma unroll
| 231 | + for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) {
| 232 | + const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]);
| 233 | + int sumi_d0 = 0;
| 234 | +
| 235 | + const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]);
| 236 | + int sumi_d1 = 0;
| 237 |
| 238 | #pragma unroll
| 239 | for (int i = i0; i < i0 + QI8_1/2; ++i) {
| 240 | + sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0);
| 241 | + }
| 242 | + sumf_d8 += dm2f0.x * sumi_d0;
| 243 | +
| 244 | + #pragma unroll
| 245 | + for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
| 246 | + sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1);
| 247 | }
| 248 | + sumf_d8 += dm2f1.x * sumi_d1;
| 249 |
| 250 | + if (i0/QI8_1 < ns8) {
| 251 | + const float2 s8f = __half22float2(s8[i0/QI8_1]);
| 252 | + sumf -= dm2f0.y*s8f.x;
| 253 | + sumf -= dm2f1.y*s8f.y;
| 254 | + } else {
| 255 | + int sumi_m0 = 0;
| 256 | + #pragma unroll
| 257 | + for (int i = i0; i < i0 + QI8_1/2; ++i) {
| 258 | + sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0);
| 259 | + }
| 260 | + sumf_d8 -= dm2f0.y * sumi_m0;
| 261 | +
| 262 | + int sumi_m1 = 0;
| 263 | + #pragma unroll
| 264 | + for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
| 265 | + sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1);
| 266 | + }
| 267 | + sumf_d8 -= dm2f1.y * sumi_m1;
| 268 | + }
| 269 | }
| 270 |
| 271 | + return sumf + d8*sumf_d8;
| 272 | }
| 273 |
| 274 | #define VDR_Q3_K_Q8_1_MMVQ 1

| 307 | return d3 * sumf;
| 308 | }
| 309 |
| 310 | + // contiguous v/x + u/y values
| 311 | static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
| 312 | const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
| 313 | const float & d3, const float & d8) {

| 320 |
| 321 | #pragma unroll
| 322 | for (int i = i0; i < i0 + QI8_1/2; ++i) {
| 323 | + sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
| 324 | }
| 325 |
| 326 | sumi += sumi_sc * scales[i0 / (QI8_1/2)];

| 357 | return dm4f.x*sumf_d - dm4f.y*sumf_m;
| 358 | }
| 359 |
| 360 | + // contiguous v/x + u/y values
| 361 | static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
| 362 | const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
| 363 | const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {

| 420 | return dm5f.x*sumf_d - dm5f.y*sumf_m;
| 421 | }
| 422 |
| 423 | + // contiguous v/x + u/y values
| 424 | static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
| 425 | const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
| 426 | const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {

| 474 | return d*sumf;
| 475 | }
| 476 |
| 477 | + // contiguous v/x + u/y values
| 478 | static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
| 479 | const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
| 480 | const float & d6, const float * __restrict__ d8) {
| 481 |
| 482 | float sumf_d = 0.0f;
| 483 |
| 484 | + const int sc_packed = get_int_b4(sc, 0);
| 485 | + const int8_t * sc_reg = (const int8_t *) &sc_packed;
| 486 | +
| 487 | #pragma unroll
| 488 | for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
| 489 | int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale

| 497 | sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
| 498 | }
| 499 |
| 500 | + sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y);
| 501 | }
| 502 |
| 503 | return d6 * sumf_d;
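All of the MMQ dot products above are built from ggml_cuda_dp4a, which multiplies two 32-bit words as four packed int8 pairs and accumulates the result (mapped to the hardware __dp4a instruction where available). A scalar reference of that operation; it also shows the 0x01010101 trick used for the q2_K minima, where a dot product against a word of int8 ones simply sums the four bytes of u[i]:

    #include <cstdint>
    #include <cstdio>

    // Reference implementation of the 4-way int8 dot product with accumulate.
    static int dp4a_ref(int a, int b, int c) {
        const int8_t * a8 = (const int8_t *) &a;
        const int8_t * b8 = (const int8_t *) &b;
        for (int i = 0; i < 4; ++i) {
            c += a8[i]*b8[i];
        }
        return c;
    }

    int main() {
        const int ones = 0x01010101;   // four int8 ones
        const int u    = 0x02FE0305;   // bytes 5, 3, -2, 2 (little endian)

        // dp4a(ones, u, 0) just sums the bytes of u -- the q2_K minima trick above:
        std::printf("byte sum  = %d\n", dp4a_ref(ones, u, 0));
        std::printf("dot accum = %d\n", dp4a_ref(0x01020304, u, 10));
        return 0;
    }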