compilade committed
Commit d1c244a · 1 Parent(s): 061ca37

ggml-quants : ternary packing for TriLMs and BitNet b1.58 (llama/8151)


* ggml-quants : 1.625 bpw ternary packing for BitNet 1.58b

* ggml-quants : faster 1.625 bpw AVX2 vec_dot

No longer using a lookup table makes it match q4_0 speed.

* gguf-py : fix formatting

* llama : remove spaces on empty line

* ggml-quants : subtract 1 when back in epi8

This makes the 1.625 bpw type go faster than q4_0. Still not the fastest.

* ggml-quants : Q2_2 now faster than Q4_K with AVX2

* ggml-quants : cleanup Q1_3 code formatting

* ggml-quants : ARM NEON vec_dot for q2_2 and q1_3

* ggml-quants : use ceiling division when quantizing q1_3

* convert-hf : simplify BitNet pre-quantization

This still results in the exact same tensor weights and scales,
but it reveals some weirdness in the current algorithm.

* convert-hf : allow converting the weird BitNet 1.3B

Its FFN size is 5460, which is not convenient (not a multiple of the quantization block sizes).
The offending tensors are kept in F16,
which makes the final model 5.01 bpw.

* bitnet : replace 1.58b with b1.58, as in the paper

* ggml-quants : fix build failure on Windows

* ggml-quants : attempt to fix Arm 32-bit support

* ggml : add some informative comments in q1_3 vec_dot

* ggml : add TQ1_0 and TQ2_0 ternary quantization types
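
TQ1_0 packs five ternary values into each byte by treating them as base-3
digits and then rescaling so that unpacking only needs a multiplication and a
shift (see quantize_row_tq1_0_ref and dequantize_row_tq1_0 in the diff below).
The following is a minimal, self-contained sketch of that round trip; it is
for illustration only and not part of the patch (pack5/unpack5 are ad-hoc
names for this example):

    #include <stdint.h>
    #include <stdio.h>

    // pack five ternary values t[n] in {-1, 0, 1} into one byte
    static uint8_t pack5(const int8_t t[5]) {
        uint8_t q = 0;
        for (int n = 0; n < 5; ++n) {
            q = q*3 + (uint8_t)(t[n] + 1); // base-3 digits, first value most significant
        }
        // ceiling division: map q in [0, 242] to q*256/243, rounded up
        return (uint8_t)(((uint16_t)q*256 + (243 - 1)) / 243);
    }

    // recover the n-th ternary value from a packed byte
    static int8_t unpack5(uint8_t q, int n) {
        static const uint8_t pow3[5] = {1, 3, 9, 27, 81};
        uint8_t shifted = (uint8_t)(q * pow3[n]);            // shift left by n trits (mod 256)
        return (int8_t)((((uint16_t)shifted * 3) >> 8) - 1); // next trit, mapped back to -1..1
    }

    int main(void) {
        const int8_t t[5] = {-1, 0, 1, 1, -1};
        const uint8_t q = pack5(t);
        for (int n = 0; n < 5; ++n) {
            printf("%d ", unpack5(q, n)); // prints: -1 0 1 1 -1
        }
        printf("\n");
        return 0;
    }

The ceiling division is what allows a power-of-two denominator on decode:
after the rescale the byte acts as a fixed-point fraction with 8 fractional
bits, so multiplying by 3 and shifting right by 8 yields one trit at a time.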

* ggml : even faster TQ2_0

* ggml : also faster TQ1_0

Same optimization as for TQ2_0 by offsetting the sum instead of the weights.
This makes TQ1_0 almost as fast as Q8_0 on AVX2.
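
The trick relies on the identity sum((w+1)*y) = sum(w*y) + sum(y): the kernels
accumulate with the unsigned digits (w + 1) in {0, 1, 2} and subtract the
per-block sums of the q8_K activations (bsums) once at the end, instead of
subtracting 1 from every unpacked weight. A scalar sketch of the idea (not the
actual SIMD kernel; names are illustrative):

    #include <stdint.h>

    // u[i] = w[i] + 1 in {0, 1, 2}; y_sum = sum of y[i] (e.g. from block_q8_K bsums)
    static int32_t dot_ternary_offset(const uint8_t * u, const int8_t * y, int n, int32_t y_sum) {
        int32_t acc = 0;
        for (int i = 0; i < n; ++i) {
            acc += u[i] * y[i]; // no per-element subtraction
        }
        return acc - y_sum;     // equals sum(w[i] * y[i])
    }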

* ggml : fix build issues in certain environments

* ggml : add NEON vec_dot implementation for TQ1_0 and TQ2_0

* ggml : avoid directly using vmlal_high_s8, for 32-bit ARM compat

The compiler seems smart enough to use the same instruction
even when using vget_high_s8 instead.
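
For reference, a minimal sketch of the two spellings being compared (the
function name is illustrative, not from the patch):

    #include <arm_neon.h>

    // widening multiply-accumulate of the high halves of two s8 vectors
    static inline int16x8_t mla_high(int16x8_t acc, int8x16_t a, int8x16_t b) {
        // A64-only spelling would be: return vmlal_high_s8(acc, a, b);
        return vmlal_s8(acc, vget_high_s8(a), vget_high_s8(b)); // also compiles on 32-bit ARM
    }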

* ggml : remove q1_3 and q2_2

No more 1.625 bpw and 2.000 bpw,
now instead using 1.6875 bpw and 2.0625 bpw
with TQ1_0 and TQ2_0, respectively.
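
For reference, the new sizes follow directly from the block layouts added in
ggml-common.h (QK_K = 256 elements per block):

  TQ1_0: 48 bytes qs + 4 bytes qh + 2 bytes d = 54 bytes per 256 weights = 1.6875 bpw
  TQ2_0: 64 bytes qs + 2 bytes d = 66 bytes per 256 weights = 2.0625 bpw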

* llama : remove the separate scale tensors of BitNet b1.58

They won't be needed, since the remaining ternary quant types have
built-in scales.

* ggml-quants : rename fields of TQ1_0 and TQ2_0 structs for consistency

* ggml-quants : allow using vdotq_s32 in TQ2_0 vec_dot

Not yet tested on hardware which supports it,
might not work or might not even compile. But also it might.
It should make the performance better on recent ARM CPUs.

* ggml-quants : remove comment about possible format change of TQ2_0

Making it slightly more convenient for AVX512
but less convenient for everything else is not worth the trouble.
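
The layout being kept stores four ternary values per byte, two bits each, as
the unsigned digits {0, 1, 2} (value + 1). A one-line sketch of how element l
of a byte is read back (illustrative helper, not part of the patch):

    #include <stdint.h>

    // byte: packed TQ2_0 byte; l: element index within the byte, 0..3
    static inline int8_t tq2_0_get(uint8_t byte, int l) {
        return (int8_t)(((byte >> (2*l)) & 3) - 1); // back to {-1, 0, 1}
    }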

* gguf-py : Numpy (de)quantization for TQ1_0 and TQ2_0

* ggml-quants : use roundf instead of nearest_int for TQ1_0 and TQ2_0

This does not change anything for ternary models,
since their values should never end up in halfway cases anyway.

* convert : allow direct conversion to TQ1_0 and TQ2_0

The token embeddings and output tensors are kept in F16
to allow quantizing them to Q4_K and Q6_K with llama-quantize.

* llama : handle fallback for TQ1_0 and TQ2_0 with Q4_0

Q4_0 is not completely symmetric (so not lossless for ternary models),
but it should be good enough.

* ggml-quants : allow using ARM dot product instructions for TQ1_0

* ggml-quants : deduplicate TQ1_0 and TQ2_0 __ARM_FEATURE_DOTPROD support

* ggml : remove unused ggml_mul special case

It would otherwise conflict with the more general
optimization coming with Mamba-2.

* ggml : handle TQ1_0 and TQ2_0 in dequantization-based operators

* test-backend-ops : add TQ1_0 and TQ2_0 comments for later

Not adding them uncommented yet, because some backends like SYCL and Metal
do not properly handle unknown types in supports_op for GGML_OP_MUL_MAT
(and Metal also doesn't handle it with GGML_OP_GET_ROWS).
Support for TQ1_0 and TQ2_0 for other backends than CPU
will be added in follow-up pull requests.

ggml/include/ggml.h CHANGED
@@ -395,6 +395,8 @@ extern "C" {
  GGML_TYPE_Q4_0_4_4 = 31,
  GGML_TYPE_Q4_0_4_8 = 32,
  GGML_TYPE_Q4_0_8_8 = 33,
+ GGML_TYPE_TQ1_0 = 34,
+ GGML_TYPE_TQ2_0 = 35,
  GGML_TYPE_COUNT,
  };
ggml/src/ggml-common.h CHANGED
@@ -227,6 +227,25 @@ typedef struct {
  } block_q8_0x8;
  static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding");
 
+ //
+ // Ternary quantization
+ //
+
+ // 1.6875 bpw
+ typedef struct {
+     uint8_t qs[(QK_K - 4 * QK_K / 64) / 5]; // 5 elements per byte (3^5 = 243 < 256)
+     uint8_t qh[QK_K/64]; // 4 elements per byte
+     ggml_half d;
+ } block_tq1_0;
+ static_assert(sizeof(block_tq1_0) == sizeof(ggml_half) + QK_K / 64 + (QK_K - 4 * QK_K / 64) / 5, "wrong tq1_0 block size/padding");
+
+ // 2.0625 bpw
+ typedef struct {
+     uint8_t qs[QK_K/4]; // 2 bits per element
+     ggml_half d;
+ } block_tq2_0;
+ static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 block size/padding");
+
  //
  // Super-block quantization structures
  //
@@ -361,6 +380,7 @@ typedef struct {
  } block_iq3_s;
  static_assert(sizeof(block_iq3_s) == sizeof(ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
 
+ // 1.5625 bpw
  typedef struct {
      ggml_half d;
      uint8_t qs[QK_K/8];
ggml/src/ggml-impl.h CHANGED
@@ -175,7 +175,7 @@ typedef __fp16 ggml_fp16_internal_t;
 
  // 32-bit ARM compatibility
 
- // vaddvq_s16
+ // vaddlvq_s16
  // vpaddq_s16
  // vpaddq_s32
  // vaddvq_s32
@@ -185,12 +185,9 @@ typedef __fp16 ggml_fp16_internal_t;
  // vzip1_u8
  // vzip2_u8
 
- inline static int32_t vaddvq_s16(int16x8_t v) {
-     return
-         (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
-         (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
-         (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
-         (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+ inline static int32_t vaddlvq_s16(int16x8_t v) {
+     int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v)));
+     return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2);
  }
 
  inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
ggml/src/ggml-quants.c CHANGED
@@ -1630,7 +1630,7 @@ void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int6
  // ===================== Helper functions
  //
  static inline int nearest_int(float fval) {
- assert(fval <= 4194303.f);
  float val = fval + 12582912.f;
  int i; memcpy(&i, &val, sizeof(int));
  return (i & 0x007fffff) - 0x00400000;
@@ -3306,6 +3306,191 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
  return nrow * row_size;
  }
 
  // ====================== "True" 2-bit (de)-quantization
 
  void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
@@ -5470,6 +5655,501 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  *s = sumf;
  }
 
  void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
  assert(nrc == 1);
  UNUSED(nrc);
@@ -14800,6 +15480,14 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
  }
  }
  } break;
  case GGML_TYPE_IQ1_S:
  {
  VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);
 
1630
  // ===================== Helper functions
1631
  //
1632
  static inline int nearest_int(float fval) {
1633
+ assert(fabsf(fval) <= 4194303.f);
1634
  float val = fval + 12582912.f;
1635
  int i; memcpy(&i, &val, sizeof(int));
1636
  return (i & 0x007fffff) - 0x00400000;
 
3306
  return nrow * row_size;
3307
  }
3308
 
3309
+ // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
3310
+
3311
+ void quantize_row_tq1_0_ref(const float * restrict x, block_tq1_0 * restrict y, int64_t k) {
3312
+ assert(k % QK_K == 0);
3313
+ const int64_t nb = k / QK_K;
3314
+
3315
+ for (int64_t i = 0; i < nb; i++) {
3316
+ float amax = 0.0f; // absolute max
3317
+
3318
+ for (int j = 0; j < QK_K; j++) {
3319
+ const float v = x[j];
3320
+ amax = MAX(amax, fabsf(v));
3321
+ }
3322
+
3323
+ const float d = amax;
3324
+ const float id = d ? 1.0f/d : 0.0f;
3325
+
3326
+ y[i].d = GGML_FP32_TO_FP16(d);
3327
+
3328
+ // 5 elements per byte, along 32 bytes
3329
+ for (size_t j = 0; j < sizeof(y->qs) - sizeof(y->qs) % 32; j += 32) {
3330
+ for (size_t m = 0; m < 32; ++m) {
3331
+ uint8_t q = 0;
3332
+ for (size_t n = 0; n < 5; ++n) {
3333
+ int xi = lroundf(x[m + n*32] * id) + 1; // -1, 0, 1 -> 0, 1, 2
3334
+ q *= 3;
3335
+ q += xi;
3336
+ }
3337
+ // ceiling division (243 == pow(3, 5))
3338
+ q = ((uint16_t)q * 256 + (243 - 1)) / 243;
3339
+ y[i].qs[j + m] = q;
3340
+ }
3341
+ x += 5*32;
3342
+ }
3343
+ // along 16 bytes
3344
+ for (size_t j = sizeof(y->qs) - sizeof(y->qs) % 32; j < sizeof(y->qs); j += 16) {
3345
+ for (size_t m = 0; m < 16; ++m) {
3346
+ uint8_t q = 0;
3347
+ for (size_t n = 0; n < 5; ++n) {
3348
+ int xi = lroundf(x[m + n*16] * id) + 1; // -1, 0, 1 -> 0, 1, 2
3349
+ q *= 3;
3350
+ q += xi;
3351
+ }
3352
+ // ceiling division (243 == pow(3, 5))
3353
+ q = ((uint16_t)q * 256 + (243 - 1)) / 243;
3354
+ y[i].qs[j + m] = q;
3355
+ }
3356
+ x += 5*16;
3357
+ }
3358
+ // 4 elements per byte
3359
+ for (size_t j = 0; j < sizeof(y->qh); ++j) {
3360
+ uint8_t q = 0;
3361
+ for (size_t m = 0; m < 4; ++m) {
3362
+ // -1, 0, 1 -> 0, 1, 2
3363
+ int xi = lroundf(x[j + m*sizeof(y->qh)] * id) + 1;
3364
+ q *= 3;
3365
+ q += xi;
3366
+ }
3367
+ // shift the first value to the most significant trit
3368
+ q *= 3;
3369
+ // ceiling division (243 == pow(3, 5))
3370
+ q = ((uint16_t)q * 256 + (243 - 1)) / 243;
3371
+ y[i].qh[j] = q;
3372
+ }
3373
+ x += 4*sizeof(y->qh);
3374
+ }
3375
+ }
3376
+
3377
+ void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y, int64_t k) {
3378
+ assert(k % QK_K == 0);
3379
+ const int64_t nb = k / QK_K;
3380
+
3381
+ for (int64_t i = 0; i < nb; i++) {
3382
+ float amax = 0.0f; // absolute max
3383
+
3384
+ for (int j = 0; j < QK_K; j++) {
3385
+ const float v = x[j];
3386
+ amax = MAX(amax, fabsf(v));
3387
+ }
3388
+
3389
+ const float d = amax;
3390
+ const float id = d ? 1.0f/d : 0.0f;
3391
+
3392
+ y[i].d = GGML_FP32_TO_FP16(d);
3393
+
3394
+ for (size_t j = 0; j < sizeof(y->qs); j += 32) {
3395
+ for (size_t m = 0; m < 32; ++m) {
3396
+ uint8_t q = 0;
3397
+ for (size_t n = 0; n < 4; ++n) {
3398
+ // -1, 0, 1 -> 0, 1, 2
3399
+ int xi = lroundf(x[m + n*32] * id) + 1;
3400
+ q += (xi & 3) << (2*n);
3401
+ }
3402
+ y[i].qs[j + m] = q;
3403
+ }
3404
+ x += 4*32;
3405
+ }
3406
+ }
3407
+ }
3408
+
3409
+ void quantize_row_tq1_0(const float * restrict x, void * restrict vy, int64_t k) {
3410
+ assert(k % QK_K == 0);
3411
+ block_tq1_0 * restrict y = vy;
3412
+ quantize_row_tq1_0_ref(x, y, k);
3413
+ }
3414
+
3415
+ void quantize_row_tq2_0(const float * restrict x, void * restrict vy, int64_t k) {
3416
+ assert(k % QK_K == 0);
3417
+ block_tq2_0 * restrict y = vy;
3418
+ quantize_row_tq2_0_ref(x, y, k);
3419
+ }
3420
+
3421
+ size_t quantize_tq1_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
3422
+ (void)quant_weights; // not used
3423
+ const size_t row_size = ggml_row_size(GGML_TYPE_TQ1_0, n_per_row);
3424
+ quantize_row_tq1_0(src, dst, (int64_t)nrow*n_per_row);
3425
+ return nrow * row_size;
3426
+ }
3427
+
3428
+ size_t quantize_tq2_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
3429
+ (void)quant_weights; // not used
3430
+ const size_t row_size = ggml_row_size(GGML_TYPE_TQ2_0, n_per_row);
3431
+ quantize_row_tq2_0(src, dst, (int64_t)nrow*n_per_row);
3432
+ return nrow * row_size;
3433
+ }
3434
+
3435
+
3436
+ void dequantize_row_tq1_0(const block_tq1_0 * restrict x, float * restrict y, int64_t k) {
3437
+ assert(k % QK_K == 0);
3438
+ const int64_t nb = k / QK_K;
3439
+
3440
+ const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
3441
+
3442
+ for (int64_t i = 0; i < nb; ++i) {
3443
+
3444
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3445
+
3446
+ for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
3447
+ for (size_t n = 0; n < 5; ++n) {
3448
+ for (size_t m = 0; m < 32; ++m) {
3449
+ uint8_t q = x[i].qs[j + m] * pow3[n];
3450
+ int16_t xi = ((uint16_t) q * 3) >> 8;
3451
+ *y++ = (float) (xi - 1) * d;
3452
+ }
3453
+ }
3454
+ }
3455
+ for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
3456
+ for (size_t n = 0; n < 5; ++n) {
3457
+ for (size_t m = 0; m < 16; ++m) {
3458
+ uint8_t q = x[i].qs[j + m] * pow3[n];
3459
+ int16_t xi = ((uint16_t) q * 3) >> 8;
3460
+ *y++ = (float) (xi - 1) * d;
3461
+ }
3462
+ }
3463
+ }
3464
+
3465
+ for (size_t n = 0; n < 4; ++n) {
3466
+ for (size_t j = 0; j < sizeof(x->qh); ++j) {
3467
+ uint8_t q = x[i].qh[j] * pow3[n];
3468
+ int16_t xi = ((uint16_t) q * 3) >> 8;
3469
+ *y++ = (float) (xi - 1) * d;
3470
+ }
3471
+ }
3472
+ }
3473
+ }
3474
+
3475
+ void dequantize_row_tq2_0(const block_tq2_0 * restrict x, float * restrict y, int64_t k) {
3476
+ assert(k % QK_K == 0);
3477
+ const int64_t nb = k / QK_K;
3478
+
3479
+ for (int64_t i = 0; i < nb; ++i) {
3480
+
3481
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3482
+
3483
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
3484
+ for (size_t l = 0; l < 4; ++l) {
3485
+ for (size_t m = 0; m < 32; ++m) {
3486
+ int8_t q = (x[i].qs[j + m] >> (l*2)) & 3;
3487
+ *y++ = (float) (q - 1) * d;
3488
+ }
3489
+ }
3490
+ }
3491
+ }
3492
+ }
3493
+
3494
  // ====================== "True" 2-bit (de)-quantization
3495
 
3496
  void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
 
5655
  *s = sumf;
5656
  }
5657
 
5658
+ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
5659
+ assert(nrc == 1);
5660
+ UNUSED(nrc);
5661
+ UNUSED(bx);
5662
+ UNUSED(by);
5663
+ UNUSED(bs);
5664
+
5665
+ const block_tq1_0 * restrict x = vx;
5666
+ const block_q8_K * restrict y = vy;
5667
+
5668
+ const int nb = n / QK_K;
5669
+
5670
+ #if defined(__ARM_NEON)
5671
+ float sumf = 0.0f;
5672
+
5673
+ uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
5674
+
5675
+ const uint8x16_t shift = vld1q_u8(k_shift);
5676
+
5677
+ for (int i = 0; i < nb; ++i) {
5678
+ #if defined(__ARM_FEATURE_DOTPROD)
5679
+ int32x4_t sumi0 = vdupq_n_s32(0);
5680
+ int32x4_t sumi1 = vdupq_n_s32(0);
5681
+ #else
5682
+ int16x8_t sumi0 = vdupq_n_s16(0);
5683
+ int16x8_t sumi1 = vdupq_n_s16(0);
5684
+ #endif
5685
+
5686
+ // first 32 bytes of 5 elements
5687
+ {
5688
+ uint8x16_t qx0 = vld1q_u8(x[i].qs + 0);
5689
+ uint8x16_t qx1 = vld1q_u8(x[i].qs + 16);
5690
+ uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3));
5691
+ uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3));
5692
+ uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9));
5693
+ uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9));
5694
+ uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27));
5695
+ uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27));
5696
+ uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81));
5697
+ uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81));
5698
+
5699
+ // multiply by 3 and keep the 2 bits above 8 bits
5700
+ int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
5701
+ int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
5702
+ int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
5703
+ int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
5704
+ int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
5705
+ int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
5706
+ int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6));
5707
+ int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6));
5708
+ int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6));
5709
+ int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6));
5710
+
5711
+ const int8x16_t qy0 = vld1q_s8(y[i].qs + 0);
5712
+ const int8x16_t qy1 = vld1q_s8(y[i].qs + 16);
5713
+ const int8x16_t qy2 = vld1q_s8(y[i].qs + 32);
5714
+ const int8x16_t qy3 = vld1q_s8(y[i].qs + 48);
5715
+ const int8x16_t qy4 = vld1q_s8(y[i].qs + 64);
5716
+ const int8x16_t qy5 = vld1q_s8(y[i].qs + 80);
5717
+ const int8x16_t qy6 = vld1q_s8(y[i].qs + 96);
5718
+ const int8x16_t qy7 = vld1q_s8(y[i].qs + 112);
5719
+ const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
5720
+ const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);
5721
+
5722
+ #if defined(__ARM_FEATURE_DOTPROD)
5723
+ sumi0 = vdotq_s32(sumi0, sqx0, qy0);
5724
+ sumi1 = vdotq_s32(sumi1, sqx1, qy1);
5725
+ sumi0 = vdotq_s32(sumi0, sqx2, qy2);
5726
+ sumi1 = vdotq_s32(sumi1, sqx3, qy3);
5727
+ sumi0 = vdotq_s32(sumi0, sqx4, qy4);
5728
+ sumi1 = vdotq_s32(sumi1, sqx5, qy5);
5729
+ sumi0 = vdotq_s32(sumi0, sqx6, qy6);
5730
+ sumi1 = vdotq_s32(sumi1, sqx7, qy7);
5731
+ sumi0 = vdotq_s32(sumi0, sqx8, qy8);
5732
+ sumi1 = vdotq_s32(sumi1, sqx9, qy9);
5733
+ #else
5734
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
5735
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
5736
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
5737
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
5738
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
5739
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
5740
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
5741
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
5742
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
5743
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
5744
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
5745
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
5746
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
5747
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
5748
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
5749
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
5750
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8));
5751
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8));
5752
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9));
5753
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9));
5754
+ #endif
5755
+ }
5756
+
5757
+ // last 16 bytes of 5-element, along with the 4 bytes of 4 elements
5758
+ {
5759
+ uint8x16_t qx0 = vld1q_u8(x[i].qs + 32);
5760
+ uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3));
5761
+ uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9));
5762
+ uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27));
5763
+ uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81));
5764
+ uint32_t qh;
5765
+ memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
5766
+ uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh));
5767
+ qx5 = vmulq_u8(qx5, shift);
5768
+
5769
+ // multiply by 3 and keep the 2 bits above 8 bits
5770
+ int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
5771
+ int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
5772
+ int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
5773
+ int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
5774
+ int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
5775
+ int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
5776
+
5777
+ const int8x16_t qy0 = vld1q_s8(y[i].qs + 160);
5778
+ const int8x16_t qy1 = vld1q_s8(y[i].qs + 176);
5779
+ const int8x16_t qy2 = vld1q_s8(y[i].qs + 192);
5780
+ const int8x16_t qy3 = vld1q_s8(y[i].qs + 208);
5781
+ const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
5782
+ const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);
5783
+
5784
+ #if defined(__ARM_FEATURE_DOTPROD)
5785
+ sumi0 = vdotq_s32(sumi0, sqx0, qy0);
5786
+ sumi1 = vdotq_s32(sumi1, sqx1, qy1);
5787
+ sumi0 = vdotq_s32(sumi0, sqx2, qy2);
5788
+ sumi1 = vdotq_s32(sumi1, sqx3, qy3);
5789
+ sumi0 = vdotq_s32(sumi0, sqx4, qy4);
5790
+ sumi1 = vdotq_s32(sumi1, sqx5, qy5);
5791
+ #else
5792
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
5793
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
5794
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
5795
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
5796
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
5797
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
5798
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
5799
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
5800
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
5801
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
5802
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
5803
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
5804
+ #endif
5805
+ }
5806
+
5807
+ const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
5808
+ const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
5809
+
5810
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
5811
+
5812
+ #if defined(__ARM_FEATURE_DOTPROD)
5813
+ sumi0 = vaddq_s32(sumi0, sumi1);
5814
+ sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
5815
+
5816
+ sumf += d * (float) vaddvq_s32(sumi0);
5817
+ #else
5818
+ sumi0 = vaddq_s16(sumi0, sumi1);
5819
+ sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
5820
+
5821
+ sumf += d * (float) vaddlvq_s16(sumi0);
5822
+ #endif
5823
+ }
5824
+
5825
+ *s = sumf;
5826
+
5827
+ #elif defined(__AVX2__)
5828
+ __m256 sumf = _mm256_setzero_ps();
5829
+
5830
+ for (int i = 0; i < nb; ++i) {
5831
+ // 16-bit sums
5832
+ __m256i sumi0 = _mm256_setzero_si256();
5833
+ __m256i sumi1 = _mm256_setzero_si256();
5834
+ __m256i sumi2 = _mm256_setzero_si256();
5835
+
5836
+ // first 32 bytes of 5 elements
5837
+ {
5838
+ __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs));
5839
+ // 8-bit multiplies with shifts, masks and adds
5840
+ __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3
5841
+ __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9
5842
+ __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9
5843
+ __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9
5844
+
5845
+ // TODO: can _mm256_mulhi_epu16 be faster even if 16-bits?
5846
+
5847
+ // Cancel the +1 from avg so that it behaves like a halving add
5848
+ qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1));
5849
+ qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1));
5850
+ qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1));
5851
+ qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1));
5852
+ qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1));
5853
+ // Multiply by 3 and get the top 2 bits
5854
+ qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256()));
5855
+ qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256()));
5856
+ qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256()));
5857
+ qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256()));
5858
+ qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256()));
5859
+ qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3));
5860
+ qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3));
5861
+ qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3));
5862
+ qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3));
5863
+ qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3));
5864
+
5865
+ const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 0));
5866
+ const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 32));
5867
+ const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 64));
5868
+ const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 96));
5869
+ const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128));
5870
+
5871
+ qx0 = _mm256_maddubs_epi16(qx0, qy0);
5872
+ qx1 = _mm256_maddubs_epi16(qx1, qy1);
5873
+ qx2 = _mm256_maddubs_epi16(qx2, qy2);
5874
+ qx3 = _mm256_maddubs_epi16(qx3, qy3);
5875
+ qx4 = _mm256_maddubs_epi16(qx4, qy4);
5876
+
5877
+ sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
5878
+ sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
5879
+ sumi2 = _mm256_add_epi16(sumi2, qx4);
5880
+ }
5881
+
5882
+ // last 16 bytes of 5-element, along with the 4 bytes of 4 elements
5883
+ {
5884
+ __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32));
5885
+ uint32_t qh;
5886
+ memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
5887
+ __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh));
5888
+ __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3
5889
+ __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9
5890
+ __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9
5891
+ __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9
5892
+ __m256i qx01 = MM256_SET_M128I(qx1, qx0);
5893
+ __m256i qx23 = MM256_SET_M128I(qx3, qx2);
5894
+
5895
+ // avx2 does not have 8-bit multiplies, so 16-bit it is.
5896
+ qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1));
5897
+ qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF));
5898
+ __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1));
5899
+
5900
+ __m256i qx45 = MM256_SET_M128I(qx5, qx4);
5901
+
5902
+ // Cancel the +1 from avg so that it behaves like a halving add
5903
+ qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1));
5904
+ qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1));
5905
+ qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1));
5906
+ // Multiply by 3 and get the top 2 bits
5907
+ qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256()));
5908
+ qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256()));
5909
+ qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256()));
5910
+ qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3));
5911
+ qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3));
5912
+ qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3));
5913
+
5914
+ const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160));
5915
+ const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192));
5916
+ const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224));
5917
+
5918
+ qx01 = _mm256_maddubs_epi16(qx01, qy01);
5919
+ qx23 = _mm256_maddubs_epi16(qx23, qy23);
5920
+ qx45 = _mm256_maddubs_epi16(qx45, qy45);
5921
+
5922
+ sumi0 = _mm256_add_epi16(sumi0, qx01);
5923
+ sumi1 = _mm256_add_epi16(sumi1, qx23);
5924
+ sumi2 = _mm256_add_epi16(sumi2, qx45);
5925
+ }
5926
+
5927
+ const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
5928
+ const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
5929
+
5930
+ sumi0 = _mm256_sub_epi16(sumi0, ysum);
5931
+ sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2));
5932
+ sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
5933
+
5934
+ sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
5935
+ }
5936
+
5937
+ *s = hsum_float_8(sumf);
5938
+
5939
+ #else
5940
+ const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
5941
+
5942
+ float sumf = 0.0f;
5943
+
5944
+ for (int i = 0; i < nb; ++i) {
5945
+ int sum = 0;
5946
+
5947
+ for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
5948
+ for (size_t l = 0; l < 5; ++l) {
5949
+ for (size_t m = 0; m < 32; ++m) {
5950
+ uint8_t q = x[i].qs[j + m] * pow3[l];
5951
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
5952
+ sum += (xi - 1) * y[i].qs[j*5 + l*32 + m];
5953
+ }
5954
+ }
5955
+ }
5956
+ for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
5957
+ for (size_t l = 0; l < 5; ++l) {
5958
+ for (size_t m = 0; m < 16; ++m) {
5959
+ uint8_t q = x[i].qs[j + m] * pow3[l];
5960
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
5961
+ sum += (xi - 1) * y[i].qs[j*5 + l*16 + m];
5962
+ }
5963
+ }
5964
+ }
5965
+
5966
+ for (size_t l = 0; l < 4; ++l) {
5967
+ for (size_t j = 0; j < sizeof(x->qh); ++j) {
5968
+ uint8_t q = x[i].qh[j] * pow3[l];
5969
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
5970
+ sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j];
5971
+ }
5972
+ }
5973
+
5974
+ sumf += (float) sum * (GGML_FP16_TO_FP32(x[i].d) * y[i].d);
5975
+ }
5976
+
5977
+ *s = sumf;
5978
+ #endif
5979
+ }
5980
+
5981
+ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
5982
+ assert(nrc == 1);
5983
+ UNUSED(nrc);
5984
+ UNUSED(bx);
5985
+ UNUSED(by);
5986
+ UNUSED(bs);
5987
+
5988
+ const block_tq2_0 * restrict x = vx;
5989
+ const block_q8_K * restrict y = vy;
5990
+
5991
+ const int nb = n / QK_K;
5992
+
5993
+ #if defined(__ARM_NEON)
5994
+ float sumf = 0.0f;
5995
+
5996
+ const uint8x16_t m3 = vdupq_n_u8(3);
5997
+
5998
+ for (int i = 0; i < nb; ++i) {
5999
+ #if defined(__ARM_FEATURE_DOTPROD)
6000
+ int32x4_t sumi0 = vdupq_n_s32(0);
6001
+ int32x4_t sumi1 = vdupq_n_s32(0);
6002
+ #else
6003
+ int16x8_t sumi0 = vdupq_n_s16(0);
6004
+ int16x8_t sumi1 = vdupq_n_s16(0);
6005
+ #endif
6006
+
6007
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
6008
+ uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
6009
+ uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16);
6010
+ uint8x16_t qx2 = vshrq_n_u8(qx0, 2);
6011
+ uint8x16_t qx3 = vshrq_n_u8(qx1, 2);
6012
+ uint8x16_t qx4 = vshrq_n_u8(qx0, 4);
6013
+ uint8x16_t qx5 = vshrq_n_u8(qx1, 4);
6014
+ uint8x16_t qx6 = vshrq_n_u8(qx0, 6);
6015
+ uint8x16_t qx7 = vshrq_n_u8(qx1, 6);
6016
+
6017
+ int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3));
6018
+ int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3));
6019
+ int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3));
6020
+ int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3));
6021
+ int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3));
6022
+ int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3));
6023
+ int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3));
6024
+ int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3));
6025
+
6026
+ const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0);
6027
+ const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16);
6028
+ const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32);
6029
+ const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48);
6030
+ const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64);
6031
+ const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80);
6032
+ const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96);
6033
+ const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);
6034
+
6035
+ #if defined(__ARM_FEATURE_DOTPROD)
6036
+ sumi0 = vdotq_s32(sumi0, sqx0, qy0);
6037
+ sumi1 = vdotq_s32(sumi1, sqx1, qy1);
6038
+ sumi0 = vdotq_s32(sumi0, sqx2, qy2);
6039
+ sumi1 = vdotq_s32(sumi1, sqx3, qy3);
6040
+ sumi0 = vdotq_s32(sumi0, sqx4, qy4);
6041
+ sumi1 = vdotq_s32(sumi1, sqx5, qy5);
6042
+ sumi0 = vdotq_s32(sumi0, sqx6, qy6);
6043
+ sumi1 = vdotq_s32(sumi1, sqx7, qy7);
6044
+ #else
6045
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
6046
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
6047
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
6048
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
6049
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
6050
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
6051
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
6052
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
6053
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
6054
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
6055
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
6056
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
6057
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
6058
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
6059
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
6060
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
6061
+ #endif
6062
+ }
6063
+
6064
+ const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
6065
+ const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
6066
+
6067
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
6068
+
6069
+ #if defined(__ARM_FEATURE_DOTPROD)
6070
+ sumi0 = vaddq_s32(sumi0, sumi1);
6071
+ sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
6072
+
6073
+ sumf += d * (float) vaddvq_s32(sumi0);
6074
+ #else
6075
+ sumi0 = vaddq_s16(sumi0, sumi1);
6076
+ sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
6077
+
6078
+ sumf += d * (float) vaddlvq_s16(sumi0);
6079
+ #endif
6080
+ }
6081
+
6082
+ *s = sumf;
6083
+
6084
+ #elif defined(__AVX2__)
6085
+ __m256 sumf = _mm256_setzero_ps();
6086
+
6087
+ for (int i = 0; i < nb; ++i) {
6088
+ // 16-bit sums, because 256*127 still fits
6089
+ __m256i sumi0 = _mm256_setzero_si256();
6090
+ __m256i sumi1 = _mm256_setzero_si256();
6091
+
6092
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
6093
+ __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j));
6094
+ __m256i qx1 = _mm256_srli_epi16(qx0, 2);
6095
+ __m256i qx2 = _mm256_srli_epi16(qx0, 4);
6096
+ __m256i qx3 = _mm256_srli_epi16(qx0, 6);
6097
+
6098
+ // 0, 1, 2 (should not be 3)
6099
+ qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3));
6100
+ qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3));
6101
+ qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3));
6102
+ qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3));
6103
+
6104
+ const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0));
6105
+ const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32));
6106
+ const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64));
6107
+ const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96));
6108
+
6109
+ qx0 = _mm256_maddubs_epi16(qx0, qy0);
6110
+ qx1 = _mm256_maddubs_epi16(qx1, qy1);
6111
+ qx2 = _mm256_maddubs_epi16(qx2, qy2);
6112
+ qx3 = _mm256_maddubs_epi16(qx3, qy3);
6113
+
6114
+ sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
6115
+ sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
6116
+ }
6117
+
6118
+ const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
6119
+ const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
6120
+
6121
+ sumi0 = _mm256_add_epi16(sumi0, sumi1);
6122
+ sumi0 = _mm256_sub_epi16(sumi0, ysum);
6123
+ sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
6124
+
6125
+ sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
6126
+ }
6127
+
6128
+ *s = hsum_float_8(sumf);
6129
+
6130
+ #else
6131
+ float sumf = 0.0f;
6132
+
6133
+ for (int i = 0; i < nb; ++i) {
6134
+ int32_t sumi = 0;
6135
+
6136
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
6137
+ for (size_t l = 0; l < 4; ++l) {
6138
+ for (size_t k = 0; k < 32; ++k) {
6139
+ sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1);
6140
+ }
6141
+ }
6142
+ }
6143
+
6144
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
6145
+
6146
+ sumf += (float) sumi * d;
6147
+ }
6148
+
6149
+ *s = sumf;
6150
+ #endif
6151
+ }
6152
+
6153
  void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
6154
  assert(nrc == 1);
6155
  UNUSED(nrc);
 
15480
  }
15481
  }
15482
  } break;
15483
+ case GGML_TYPE_TQ1_0:
15484
+ {
15485
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_tq1_0, data, nb);
15486
+ } break;
15487
+ case GGML_TYPE_TQ2_0:
15488
+ {
15489
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_tq2_0, data, nb);
15490
+ } break;
15491
  case GGML_TYPE_IQ1_S:
15492
  {
15493
  VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);
ggml/src/ggml-quants.h CHANGED
@@ -26,6 +26,9 @@ void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_REST
  void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k);
  void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k);
 
+ void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k);
+ void quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, int64_t k);
+
  void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k);
  void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k);
  void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k);
@@ -46,6 +49,9 @@ void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
  void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
  void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 
+ void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+ void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
  void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
  void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
  void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -67,6 +73,9 @@ void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRI
  void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
  void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 
+ void dequantize_row_tq1_0(const block_tq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+ void dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
  void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
  void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
  void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -90,6 +99,9 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
  void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
  void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
+ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+ void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
  void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
  void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
  void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -111,6 +123,9 @@ size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT ds
  size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
  size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 
+ size_t quantize_tq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+ size_t quantize_tq2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
  size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
  size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
  size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
ggml/src/ggml.c CHANGED
@@ -1054,7 +1054,31 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
  .ncols = 8,
  .gemv = ggml_gemv_q4_0_8x8_q8_0,
  .gemm = ggml_gemm_q4_0_8x8_q8_0,
- }
+ },
+ [GGML_TYPE_TQ1_0] = {
+     .type_name = "tq1_0",
+     .blck_size = QK_K,
+     .type_size = sizeof(block_tq1_0),
+     .is_quantized = true,
+     .to_float = (ggml_to_float_t) dequantize_row_tq1_0,
+     .from_float = quantize_row_tq1_0,
+     .from_float_ref = (ggml_from_float_t) quantize_row_tq1_0_ref,
+     .vec_dot = ggml_vec_dot_tq1_0_q8_K,
+     .vec_dot_type = GGML_TYPE_Q8_K,
+     .nrows = 1,
+ },
+ [GGML_TYPE_TQ2_0] = {
+     .type_name = "tq2_0",
+     .blck_size = QK_K,
+     .type_size = sizeof(block_tq2_0),
+     .is_quantized = true,
+     .to_float = (ggml_to_float_t) dequantize_row_tq2_0,
+     .from_float = quantize_row_tq2_0,
+     .from_float_ref = (ggml_from_float_t) quantize_row_tq2_0_ref,
+     .vec_dot = ggml_vec_dot_tq2_0_q8_K,
+     .vec_dot_type = GGML_TYPE_Q8_K,
+     .nrows = 1,
+ },
  };
 
  // For internal test use
@@ -9888,6 +9912,8 @@ static void ggml_compute_forward_add(
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_TQ1_0:
+ case GGML_TYPE_TQ2_0:
  case GGML_TYPE_IQ2_XXS:
  case GGML_TYPE_IQ2_XS:
  case GGML_TYPE_IQ3_XXS:
@@ -10266,6 +10292,8 @@ static void ggml_compute_forward_add1(
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_TQ1_0:
+ case GGML_TYPE_TQ2_0:
  case GGML_TYPE_IQ2_XXS:
  case GGML_TYPE_IQ2_XS:
  case GGML_TYPE_IQ3_XXS:
@@ -10394,6 +10422,8 @@ static void ggml_compute_forward_acc(
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_TQ1_0:
+ case GGML_TYPE_TQ2_0:
  case GGML_TYPE_IQ2_XXS:
  case GGML_TYPE_IQ2_XS:
  case GGML_TYPE_IQ3_XXS:
@@ -13374,6 +13404,8 @@ static void ggml_compute_forward_out_prod(
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_TQ1_0:
+ case GGML_TYPE_TQ2_0:
  case GGML_TYPE_IQ2_XXS:
  case GGML_TYPE_IQ2_XS:
  case GGML_TYPE_IQ3_XXS:
@@ -13562,6 +13594,8 @@ static void ggml_compute_forward_set(
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_TQ1_0:
+ case GGML_TYPE_TQ2_0:
  case GGML_TYPE_IQ2_XXS:
  case GGML_TYPE_IQ2_XS:
  case GGML_TYPE_IQ3_XXS:
@@ -13824,6 +13858,8 @@ static void ggml_compute_forward_get_rows(
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_TQ1_0:
+ case GGML_TYPE_TQ2_0:
  case GGML_TYPE_IQ2_XXS:
  case GGML_TYPE_IQ2_XS:
  case GGML_TYPE_IQ3_XXS:
@@ -14413,6 +14449,8 @@ static void ggml_compute_forward_clamp(
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_TQ1_0:
+ case GGML_TYPE_TQ2_0:
  case GGML_TYPE_IQ2_XXS:
  case GGML_TYPE_IQ2_XS:
  case GGML_TYPE_IQ3_XXS:
@@ -21853,6 +21891,8 @@ size_t ggml_quantize_chunk(
  case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
  case GGML_TYPE_Q5_K: result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
  case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_TQ1_0: result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_TQ2_0: result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
  case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
  case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
  case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;