snadampal committed
Commit 0d50a29 · unverified · Parent: c276f12

ggml : add mmla kernels for quantized GEMM (llama/4966)

* ggml: aarch64: implement smmla kernel for q8_0_q8_0 quantized gemm

Armv8.2-A and later support MMLA instructions, which have higher
throughput than DOT. This commit adds an smmla kernel for
q8_0_q8_0 gemm. The feature is enabled when the platform defines
"__ARM_FEATURE_MATMUL_INT8".

On AWS Graviton3 processors this kernel yields up to a 1.5x
improvement in prompt-evaluation throughput over the default
sdot kernel.

* ggml: aarch64: implement smmla kernel for q4_0_q8_0 quantized gemm

Armv8.2-A and later support MMLA instructions, which have higher
throughput than DOT. This commit adds an smmla kernel for
q4_0_q8_0 gemm. The feature is enabled when the platform defines
"__ARM_FEATURE_MATMUL_INT8".

On AWS Graviton3 processors this kernel yields up to a 1.5x
improvement in prompt-evaluation throughput over the default
sdot kernel.

* ggml: aarch64: implement smmla kernel for q4_1_q8_1 quantized gemm

Armv8.2-A and later support MMLA instructions, which have higher
throughput than DOT. This commit adds an smmla kernel for
q4_1_q8_1 gemm. The feature is enabled when the platform defines
"__ARM_FEATURE_MATMUL_INT8".

On AWS Graviton3 processors this kernel yields up to a 1.5x
improvement in prompt-evaluation throughput over the default
sdot kernel.

* ggml: update unit tests for the new vec_dot interface

* llama.cpp: add MATMUL_INT8 capability to system_info
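For readers coming to this commit cold: the kernels below hinge on one instruction, smmla (exposed as the vmmlaq_s32 intrinsic under __ARM_FEATURE_MATMUL_INT8). A scalar model of its semantics -- an illustrative sketch, not part of the commit -- shows why the kernels process two rows and two columns at once: each instruction accumulates a full 2x2 int32 tile from 8-element int8 dot products.

#include <stdint.h>

// Scalar model of vmmlaq_s32(acc, a, b): acc holds a 2x2 int32 tile
// (row-major in one 128-bit register), a holds two 8-byte rows of the
// left operand, b holds two 8-byte rows of the right operand.
static void smmla_model(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
    for (int i = 0; i < 2; i++) {          // row of a
        for (int j = 0; j < 2; j++) {      // row of b
            int32_t sum = 0;
            for (int k = 0; k < 8; k++) {  // 8-wide int8 dot product
                sum += (int32_t) a[i*8 + k] * (int32_t) b[j*8 + k];
            }
            acc[i*2 + j] += sum;           // acc[i][j] += a.row(i) . b.row(j)
        }
    }
}

A DOT kernel (sdot / vdotq_s32) instead produces four independent 4-wide dot products per instruction, touching only one output row at a time; the 2x2 tile is what lets the mmla path amortize loads across two rows and two columns.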

Files changed (4)
  1. ggml-quants.c +302 -18
  2. ggml-quants.h +14 -14
  3. ggml.c +114 -50
  4. ggml.h +4 -1
ggml-quants.c CHANGED
@@ -49,6 +49,8 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+#define UNUSED GGML_UNUSED
+
 #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
 
 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
@@ -3677,15 +3679,88 @@ static inline __m128i get_scale_shuffle(int i) {
 }
 #endif
 
-void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
 
     assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    assert((nrc == 2) || (nrc == 1));
+#else
+    assert(nrc == 1);
+#endif
 
     const block_q4_0 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
 
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q4_0 * restrict vx0 = vx;
+        const block_q4_0 * restrict vx1 = vx + bx;
+
+        const block_q8_0 * restrict vy0 = vy;
+        const block_q8_0 * restrict vy1 = vy + by;
+
+        float32x4_t sumv0 = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; i++) {
+            const block_q4_0 * restrict b_x0 = &vx0[i];
+            const block_q4_0 * restrict b_x1 = &vx1[i];
+            const block_q8_0 * restrict b_y0 = &vy0[i];
+            const block_q8_0 * restrict b_y1 = &vy1[i];
+
+            const uint8x16_t m4b = vdupq_n_u8(0x0F);
+            const int8x16_t  s8b = vdupq_n_s8(0x8);
+
+            const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
+            const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
+
+            // 4-bit -> 8-bit
+            const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+            const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+            const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+            const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+            // sub 8
+            const int8x16_t x0_l = vsubq_s8(v0_0l, s8b);
+            const int8x16_t x0_h = vsubq_s8(v0_0h, s8b);
+            const int8x16_t x1_l = vsubq_s8(v0_1l, s8b);
+            const int8x16_t x1_h = vsubq_s8(v0_1h, s8b);
+
+            // load y
+            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+            float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                                 GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                                 GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                                 GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+
+            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+
+            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
+                                                                          l1, r1)), l2, r2)), l3, r3))), scale);
+        }
+        float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+
+        vst1_f32(s,      vget_low_f32(sumv2));
+        vst1_f32(s + bs, vget_high_f32(sumv2));
+        return;
+    }
+#endif
 #if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -3967,15 +4042,89 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
 #endif
 }
 
-void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     const int qk = QK8_1;
     const int nb = n / qk;
 
     assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    assert((nrc == 2) || (nrc == 1));
+#else
+    assert(nrc == 1);
+#endif
 
     const block_q4_1 * restrict x = vx;
     const block_q8_1 * restrict y = vy;
 
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q4_1 * restrict vx0 = vx;
+        const block_q4_1 * restrict vx1 = vx + bx;
+        const block_q8_1 * restrict vy0 = vy;
+        const block_q8_1 * restrict vy1 = vy + by;
+
+        float32x4_t sumv0 = vdupq_n_f32(0.0f);
+        float32x4_t summs0 = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; i++) {
+            const block_q4_1 * restrict b_x0 = &vx0[i];
+            const block_q4_1 * restrict b_x1 = &vx1[i];
+            const block_q8_1 * restrict b_y0 = &vy0[i];
+            const block_q8_1 * restrict b_y1 = &vy1[i];
+
+            float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * b_y0->s,
+                                   GGML_FP16_TO_FP32(b_x1->m) * b_y0->s,
+                                   GGML_FP16_TO_FP32(b_x0->m) * b_y1->s,
+                                   GGML_FP16_TO_FP32(b_x1->m) * b_y1->s};
+            summs0 += summs_t;
+
+            const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+            const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
+            const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
+
+            // 4-bit -> 8-bit
+            const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+            const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+            const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+            const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+            // load y
+            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+            // mmla into int32x4_t
+            float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                                 GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                                 GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                                 GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+
+            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
+                                                                          l1, r1)), l2, r2)), l3, r3))), scale);
+        }
+
+        float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+        sumv2 = sumv2 + summs0;
+
+        vst1_f32(s,      vget_low_f32(sumv2));
+        vst1_f32(s + bs, vget_high_f32(sumv2));
+        return;
+    }
+#endif
 // TODO: add WASM SIMD
 #if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
@@ -4107,12 +4256,17 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
 #endif
 }
 
-void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     const int qk = QK8_0;
    const int nb = n / qk;
 
     assert(n % qk == 0);
     assert(qk == QK5_0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q5_0 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
@@ -4393,12 +4547,17 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
 #endif
 }
 
-void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     const int qk = QK8_1;
     const int nb = n / qk;
 
     assert(n % qk == 0);
     assert(qk == QK5_1);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q5_1 * restrict x = vx;
     const block_q8_1 * restrict y = vy;
@@ -4692,15 +4851,75 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
 #endif
 }
 
-void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
    const int qk = QK8_0;
     const int nb = n / qk;
 
     assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    assert((nrc == 2) || (nrc == 1));
+#else
+    assert(nrc == 1);
+#endif
 
     const block_q8_0 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
 
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q8_0 * restrict vx0 = vx;
+        const block_q8_0 * restrict vx1 = vx + bx;
+        const block_q8_0 * restrict vy0 = vy;
+        const block_q8_0 * restrict vy1 = vy + by;
+
+        float32x4_t sumv0 = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; i++) {
+            const block_q8_0 * restrict b_x0 = &vx0[i];
+            const block_q8_0 * restrict b_y0 = &vy0[i];
+
+            const block_q8_0 * restrict b_x1 = &vx1[i];
+            const block_q8_0 * restrict b_y1 = &vy1[i];
+
+            const int8x16_t x0_l = vld1q_s8(b_x0->qs);
+            const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
+            const int8x16_t x1_l = vld1q_s8(b_x1->qs);
+            const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
+
+            // load y
+            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+            float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                                 GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                                 GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                                 GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+
+            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+
+            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
+                                                                          l1, r1)), l2, r2)), l3, r3))), scale);
+        }
+        float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+
+        vst1_f32(s,      vget_low_f32(sumv2));
+        vst1_f32(s + bs, vget_high_f32(sumv2));
+        return;
+    }
+#endif
 #if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -4795,7 +5014,12 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
 }
 
 #if QK_K == 256
-void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q2_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -5171,7 +5395,12 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
 #else
 
-void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q2_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -5429,8 +5658,13 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 #endif
 
 #if QK_K == 256
-void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const uint32_t kmask1 = 0x03030303;
     const uint32_t kmask2 = 0x0f0f0f0f;
@@ -5949,8 +6183,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
 #else
 
-void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q3_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -6292,8 +6531,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 #endif
 
 #if QK_K == 256
-void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q4_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -6648,8 +6892,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 #endif
 }
 #else
-void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q4_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -6891,8 +7140,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 #endif
 
 #if QK_K == 256
-void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q5_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -7311,8 +7565,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
 #else
 
-void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q5_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -7577,8 +7836,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
 
 #if QK_K == 256
-void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q6_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -8009,8 +8273,13 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
 #else
 
-void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q6_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -8339,8 +8608,13 @@ static const int8_t keven_signs_q2xs[1024] = {
     1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,
 };
 
-void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_iq2_xxs * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -8462,8 +8736,13 @@ void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * res
 #endif
 }
 
-void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_iq2_xs * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -8682,8 +8961,13 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
 }
 
 // TODO
-void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_iq3_xxs * restrict x = vx;
     const block_q8_K * restrict y = vy;
ggml-quants.h CHANGED
@@ -245,20 +245,20 @@ void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_
 void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 
 // Dot product
-void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-
-void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
+void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
 //
 // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
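Aside (not part of the diff): the widened vec_dot signature threads a result stride bs, input strides bx and by (byte strides in ggml, since the kernels advance opaque block pointers), and a row count nrc through every kernel. A toy float version with the same shape, using element strides for simplicity and hypothetical names, sketches the calling convention:

#include <stddef.h>

// With nrc == 1 this degenerates to *s = dot(x, y) and the strides go
// unused, matching the old interface; with nrc == 2 it fills a 2x2 tile,
// writing the two results for the j-th y row at s[j*bs] and s[j*bs + 1].
static void toy_vec_dot(int n, float * s, size_t bs,
                        const float * x, size_t bx,
                        const float * y, size_t by, int nrc) {
    for (int j = 0; j < nrc; j++) {      // rows of y (output columns)
        for (int i = 0; i < nrc; i++) {  // rows of x (output rows)
            float sum = 0.0f;
            for (int k = 0; k < n; k++) {
                sum += x[i*bx + k] * y[j*by + k];
            }
            s[j*bs + i] = sum;
        }
    }
}

This matches the mmla stores above (s[0], s[1], s[bs], s[bs + 1]) and how ggml_compute_forward_mul_mat in ggml.c below passes bs = 16 so the second column's results land in the second half of its tmp buffer.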
ggml.c CHANGED
@@ -428,8 +428,8 @@ int64_t ggml_cycles_per_ms(void) {
428
 
429
  static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
430
 
431
- static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
432
- static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y);
433
 
434
  static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
435
  [GGML_TYPE_I8] = {
@@ -457,6 +457,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
457
  .is_quantized = false,
458
  .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
459
  .vec_dot_type = GGML_TYPE_F32,
 
460
  },
461
  [GGML_TYPE_F16] = {
462
  .type_name = "f16",
@@ -468,6 +469,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
468
  .from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row,
469
  .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
470
  .vec_dot_type = GGML_TYPE_F16,
 
471
  },
472
  [GGML_TYPE_Q4_0] = {
473
  .type_name = "q4_0",
@@ -479,6 +481,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
479
  .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference,
480
  .vec_dot = ggml_vec_dot_q4_0_q8_0,
481
  .vec_dot_type = GGML_TYPE_Q8_0,
 
 
 
 
 
482
  },
483
  [GGML_TYPE_Q4_1] = {
484
  .type_name = "q4_1",
@@ -490,6 +497,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
490
  .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference,
491
  .vec_dot = ggml_vec_dot_q4_1_q8_1,
492
  .vec_dot_type = GGML_TYPE_Q8_1,
 
 
 
 
 
493
  },
494
  [4] = { // GGML_TYPE_Q4_2
495
  .type_name = "DEPRECATED",
@@ -501,6 +513,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
501
  .from_float_reference = NULL,
502
  .vec_dot = NULL,
503
  .vec_dot_type = GGML_TYPE_COUNT,
 
504
  },
505
  [5] = { // GGML_TYPE_Q4_3
506
  .type_name = "DEPRECATED",
@@ -512,6 +525,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
512
  .from_float_reference = NULL,
513
  .vec_dot = NULL,
514
  .vec_dot_type = GGML_TYPE_COUNT,
 
515
  },
516
  [GGML_TYPE_Q5_0] = {
517
  .type_name = "q5_0",
@@ -523,6 +537,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
523
  .from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference,
524
  .vec_dot = ggml_vec_dot_q5_0_q8_0,
525
  .vec_dot_type = GGML_TYPE_Q8_0,
 
526
  },
527
  [GGML_TYPE_Q5_1] = {
528
  .type_name = "q5_1",
@@ -534,6 +549,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
534
  .from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference,
535
  .vec_dot = ggml_vec_dot_q5_1_q8_1,
536
  .vec_dot_type = GGML_TYPE_Q8_1,
 
537
  },
538
  [GGML_TYPE_Q8_0] = {
539
  .type_name = "q8_0",
@@ -545,6 +561,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
545
  .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
546
  .vec_dot = ggml_vec_dot_q8_0_q8_0,
547
  .vec_dot_type = GGML_TYPE_Q8_0,
 
 
 
 
 
548
  },
549
  [GGML_TYPE_Q8_1] = {
550
  .type_name = "q8_1",
@@ -554,6 +575,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
554
  .from_float = quantize_row_q8_1,
555
  .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference,
556
  .vec_dot_type = GGML_TYPE_Q8_1,
 
557
  },
558
  [GGML_TYPE_Q2_K] = {
559
  .type_name = "q2_K",
@@ -565,6 +587,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
565
  .from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference,
566
  .vec_dot = ggml_vec_dot_q2_K_q8_K,
567
  .vec_dot_type = GGML_TYPE_Q8_K,
 
568
  },
569
  [GGML_TYPE_Q3_K] = {
570
  .type_name = "q3_K",
@@ -576,6 +599,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
576
  .from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference,
577
  .vec_dot = ggml_vec_dot_q3_K_q8_K,
578
  .vec_dot_type = GGML_TYPE_Q8_K,
 
579
  },
580
  [GGML_TYPE_Q4_K] = {
581
  .type_name = "q4_K",
@@ -587,6 +611,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
587
  .from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference,
588
  .vec_dot = ggml_vec_dot_q4_K_q8_K,
589
  .vec_dot_type = GGML_TYPE_Q8_K,
 
590
  },
591
  [GGML_TYPE_Q5_K] = {
592
  .type_name = "q5_K",
@@ -598,6 +623,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
598
  .from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference,
599
  .vec_dot = ggml_vec_dot_q5_K_q8_K,
600
  .vec_dot_type = GGML_TYPE_Q8_K,
 
601
  },
602
  [GGML_TYPE_Q6_K] = {
603
  .type_name = "q6_K",
@@ -609,6 +635,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
609
  .from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference,
610
  .vec_dot = ggml_vec_dot_q6_K_q8_K,
611
  .vec_dot_type = GGML_TYPE_Q8_K,
 
612
  },
613
  [GGML_TYPE_IQ2_XXS] = {
614
  .type_name = "iq2_xxs",
@@ -620,6 +647,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
620
  .from_float_reference = NULL,
621
  .vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
622
  .vec_dot_type = GGML_TYPE_Q8_K,
 
623
  },
624
  [GGML_TYPE_IQ2_XS] = {
625
  .type_name = "iq2_xs",
@@ -631,6 +659,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
631
  .from_float_reference = NULL,
632
  .vec_dot = ggml_vec_dot_iq2_xs_q8_K,
633
  .vec_dot_type = GGML_TYPE_Q8_K,
 
634
  },
635
  [GGML_TYPE_IQ3_XXS] = {
636
  .type_name = "iq3_xxs",
@@ -642,6 +671,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
642
  .from_float_reference = (ggml_from_float_t)quantize_row_iq3_xxs_reference,
643
  .vec_dot = ggml_vec_dot_iq3_xxs_q8_K,
644
  .vec_dot_type = GGML_TYPE_Q8_K,
 
645
  },
646
  [GGML_TYPE_Q8_K] = {
647
  .type_name = "q8_K",
@@ -1212,7 +1242,13 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)
1212
  inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
1213
  inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
1214
 
1215
- static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
 
 
 
 
 
 
1216
  #ifdef GGML_SIMD
1217
  float sumf = 0.0f;
1218
  const int np = (n & ~(GGML_F32_STEP - 1));
@@ -1249,7 +1285,13 @@ static void ggml_vec_dot_f32(const int n, float * restrict s, const float * rest
1249
  *s = sumf;
1250
  }
1251
 
1252
- static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
 
 
 
 
 
 
1253
  ggml_float sumf = 0.0;
1254
 
1255
  #if defined(GGML_SIMD)
@@ -1455,7 +1497,7 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
1455
  #endif
1456
  }
1457
 
1458
- inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); }
1459
  inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
1460
  inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
1461
  inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
@@ -9992,6 +10034,7 @@ static void ggml_compute_forward_mul_mat(
9992
  ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
9993
  enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
9994
  ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
 
9995
 
9996
  GGML_ASSERT(ne0 == ne01);
9997
  GGML_ASSERT(ne1 == ne11);
@@ -10159,12 +10202,23 @@ static void ggml_compute_forward_mul_mat(
10159
  const int64_t blck_0 = 16;
10160
  const int64_t blck_1 = 16;
10161
 
 
 
 
 
 
 
 
 
 
 
10162
  // attempt to reduce false-sharing (does not seem to make a difference)
10163
- float tmp[16];
 
10164
 
10165
  for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
10166
  for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
10167
- for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
10168
  const int64_t i13 = (ir1/(ne12*ne1));
10169
  const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
10170
  const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
@@ -10187,17 +10241,19 @@ static void ggml_compute_forward_mul_mat(
10187
  (src1_cont || src1->type != vec_dot_type
10188
  ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
10189
  : (i11*nb11 + i12*nb12 + i13*nb13));
10190
-
10191
  float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
10192
 
10193
  //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
10194
  // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
10195
  //}
10196
 
10197
- for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
10198
- vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
 
 
 
 
10199
  }
10200
- memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
10201
  }
10202
  }
10203
  }
@@ -10386,7 +10442,7 @@ static void ggml_compute_forward_mul_mat_id(
10386
  //}
10387
 
10388
  for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
10389
- vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
10390
  }
10391
  memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
10392
  }
@@ -11568,7 +11624,7 @@ static void ggml_compute_forward_soft_max_back_f32(
11568
 
11569
  // linear runtime, no additional memory
11570
  float dot_y_dy = 0;
11571
- ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
11572
  ggml_vec_cpy_f32 (nc, dx, dy);
11573
  ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
11574
  ggml_vec_mul_f32 (nc, dx, dx, y);
@@ -12369,9 +12425,9 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
12369
  const int i1n = i10*ne11;
12370
  for (int i00 = 0; i00 < ne00; i00++) {
12371
  float v = 0;
12372
- ggml_vec_dot_f16(ne02, &v,
12373
- (ggml_fp16_t *) wdata_src + i1n,
12374
- (ggml_fp16_t *) wdata_kernel + i00*ne02);
12375
  dst_data[i10*s0 + i00] += v;
12376
  }
12377
  }
@@ -12466,9 +12522,9 @@ static void ggml_compute_forward_conv_transpose_1d_f32(
12466
  const int i1n = i10*ne11;
12467
  for (int i00 = 0; i00 < ne00; i00++) {
12468
  float v = 0;
12469
- ggml_vec_dot_f32(ne02, &v,
12470
- wdata_src + i1n,
12471
- wdata_kernel + i00*ne02);
12472
  dst_data[i10*s0 + i00] += v;
12473
  }
12474
  }
@@ -12783,9 +12839,9 @@ static void ggml_compute_forward_conv_transpose_2d(
12783
  for (int i01 = 0; i01 < ne01; i01++) {
12784
  for (int i00 = 0; i00 < ne00; i00++) {
12785
  float v = 0;
12786
- ggml_vec_dot_f16(ne03, &v,
12787
- wdata_src + i1n,
12788
- wdata_kernel + i01*ne00*ne03 + i00*ne03);
12789
  dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
12790
  }
12791
  }
@@ -13214,9 +13270,9 @@ static void ggml_compute_forward_flash_attn_f32(
13214
  const int i1 = ik1;
13215
 
13216
  ggml_vec_dot_f32(neq0,
13217
- S + i1,
13218
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
13219
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
13220
  }
13221
 
13222
  // scale
@@ -13299,9 +13355,9 @@ static void ggml_compute_forward_flash_attn_f32(
13299
  const int iv3 = iq3;
13300
 
13301
  ggml_vec_dot_f32(masked_begin,
13302
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
13303
- (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
13304
- S);
13305
  }
13306
  }
13307
  }
@@ -13404,9 +13460,9 @@ static void ggml_compute_forward_flash_attn_f16(
13404
  const int i1 = ik1;
13405
 
13406
  ggml_vec_dot_f16(neq0,
13407
- S + i1,
13408
- (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
13409
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
13410
  }
13411
  } else {
13412
  for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
@@ -13508,9 +13564,9 @@ static void ggml_compute_forward_flash_attn_f16(
13508
  const int iv3 = iq3;
13509
 
13510
  ggml_vec_dot_f16(nev0,
13511
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
13512
- (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
13513
- S16);
13514
  }
13515
  } else {
13516
  for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
@@ -13652,9 +13708,9 @@ static void ggml_compute_forward_flash_ff_f16(
13652
  const int i1 = ib01;
13653
 
13654
  ggml_vec_dot_f16(nea0,
13655
- S + i1,
13656
- (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)),
13657
- (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)));
13658
  }
13659
 
13660
  ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
@@ -13677,9 +13733,9 @@ static void ggml_compute_forward_flash_ff_f16(
13677
  for (int64_t ic = 0; ic < nec01; ++ic) {
13678
 
13679
  ggml_vec_dot_f16(neb01,
13680
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
13681
- (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)),
13682
- S16);
13683
  }
13684
 
13685
  ggml_vec_add_f32(nec01,
@@ -13866,9 +13922,9 @@ static void ggml_compute_forward_flash_attn_back_f32(
13866
  const int i1 = ik1;
13867
 
13868
  ggml_vec_dot_f32(neq0,
13869
- S + i1,
13870
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
13871
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
13872
  }
13873
 
13874
  // scale
@@ -14013,7 +14069,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
14013
 
14014
  // S = SM * (S - dot(SM, S))
14015
  float dot_SM_gradSM = 0;
14016
- ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S);
14017
  ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
14018
  ggml_vec_mul_f32 (masked_begin, S, S, SM);
14019
 
@@ -18382,7 +18438,7 @@ static enum ggml_opt_result linesearch_backtracking(
18382
  }
18383
 
18384
  // compute the initial gradient in the search direction
18385
- ggml_vec_dot_f32(nx, &dginit, g, d);
18386
 
18387
  // make sure that d points to a descent direction
18388
  if (0 < dginit) {
@@ -18432,7 +18488,7 @@ static enum ggml_opt_result linesearch_backtracking(
18432
  return count;
18433
  }
18434
 
18435
- ggml_vec_dot_f32(nx, &dg, g, d);
18436
 
18437
  // check the Wolfe condition
18438
  if (dg < params->lbfgs.wolfe * dginit) {
@@ -18693,8 +18749,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
18693
  // ys = y^t \cdot s -> 1 / \rho.
18694
  // yy = y^t \cdot y.
18695
  //
18696
- ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]);
18697
- ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
18698
 
18699
  lm_ys[end[0]] = ys;
18700
 
@@ -18713,7 +18769,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
18713
  for (int i = 0; i < bound; ++i) {
18714
  j[0] = (j[0] + m - 1) % m;
18715
  // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
18716
- ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
18717
  lm_alpha[j[0]] /= lm_ys[j[0]];
18718
  // q_{i} = q_{i+1} - \alpha_{i} y_{i}
18719
  ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
@@ -18723,7 +18779,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
18723
 
18724
  for (int i = 0; i < bound; ++i) {
18725
  // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
18726
- ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
18727
  beta /= lm_ys[j[0]];
18728
  // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
18729
  ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
@@ -20621,4 +20677,12 @@ int ggml_cpu_has_vsx(void) {
20621
  #endif
20622
  }
20623
 
 
 
 
 
 
 
 
 
20624
  ////////////////////////////////////////////////////////////////////////////////
 
428
 
429
  static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
430
 
431
+ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
432
+ static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
433
 
434
  static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
435
  [GGML_TYPE_I8] = {
 
457
  .is_quantized = false,
458
  .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
459
  .vec_dot_type = GGML_TYPE_F32,
460
+ .nrows = 1,
461
  },
462
  [GGML_TYPE_F16] = {
463
  .type_name = "f16",
 
469
  .from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row,
470
  .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
471
  .vec_dot_type = GGML_TYPE_F16,
472
+ .nrows = 1,
473
  },
474
  [GGML_TYPE_Q4_0] = {
475
  .type_name = "q4_0",
 
481
  .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference,
482
  .vec_dot = ggml_vec_dot_q4_0_q8_0,
483
  .vec_dot_type = GGML_TYPE_Q8_0,
484
+ #if defined (__ARM_FEATURE_MATMUL_INT8)
485
+ .nrows = 2,
486
+ #else
487
+ .nrows = 1,
488
+ #endif
489
  },
490
  [GGML_TYPE_Q4_1] = {
491
  .type_name = "q4_1",
 
497
  .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference,
498
  .vec_dot = ggml_vec_dot_q4_1_q8_1,
499
  .vec_dot_type = GGML_TYPE_Q8_1,
500
+ #if defined (__ARM_FEATURE_MATMUL_INT8)
501
+ .nrows = 2,
502
+ #else
503
+ .nrows = 1,
504
+ #endif
505
  },
506
  [4] = { // GGML_TYPE_Q4_2
507
  .type_name = "DEPRECATED",
 
513
  .from_float_reference = NULL,
514
  .vec_dot = NULL,
515
  .vec_dot_type = GGML_TYPE_COUNT,
516
+ .nrows = 1,
517
  },
518
  [5] = { // GGML_TYPE_Q4_3
519
  .type_name = "DEPRECATED",
 
525
  .from_float_reference = NULL,
526
  .vec_dot = NULL,
527
  .vec_dot_type = GGML_TYPE_COUNT,
528
+ .nrows = 1,
529
  },
530
  [GGML_TYPE_Q5_0] = {
531
  .type_name = "q5_0",
 
537
  .from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference,
538
  .vec_dot = ggml_vec_dot_q5_0_q8_0,
539
  .vec_dot_type = GGML_TYPE_Q8_0,
540
+ .nrows = 1,
541
  },
542
  [GGML_TYPE_Q5_1] = {
543
  .type_name = "q5_1",
 
549
  .from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference,
550
  .vec_dot = ggml_vec_dot_q5_1_q8_1,
551
  .vec_dot_type = GGML_TYPE_Q8_1,
552
+ .nrows = 1,
553
  },
554
  [GGML_TYPE_Q8_0] = {
555
  .type_name = "q8_0",
 
561
  .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
562
  .vec_dot = ggml_vec_dot_q8_0_q8_0,
563
  .vec_dot_type = GGML_TYPE_Q8_0,
564
+ #if defined (__ARM_FEATURE_MATMUL_INT8)
565
+ .nrows = 2,
566
+ #else
567
+ .nrows = 1,
568
+ #endif
569
  },
570
  [GGML_TYPE_Q8_1] = {
571
  .type_name = "q8_1",
 
575
  .from_float = quantize_row_q8_1,
576
  .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference,
577
  .vec_dot_type = GGML_TYPE_Q8_1,
578
+ .nrows = 1,
579
  },
580
  [GGML_TYPE_Q2_K] = {
581
  .type_name = "q2_K",
 
587
  .from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference,
588
  .vec_dot = ggml_vec_dot_q2_K_q8_K,
589
  .vec_dot_type = GGML_TYPE_Q8_K,
590
+ .nrows = 1,
591
  },
592
  [GGML_TYPE_Q3_K] = {
593
  .type_name = "q3_K",
 
599
  .from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference,
600
  .vec_dot = ggml_vec_dot_q3_K_q8_K,
601
  .vec_dot_type = GGML_TYPE_Q8_K,
602
+ .nrows = 1,
603
  },
604
  [GGML_TYPE_Q4_K] = {
605
  .type_name = "q4_K",
 
611
  .from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference,
612
  .vec_dot = ggml_vec_dot_q4_K_q8_K,
613
  .vec_dot_type = GGML_TYPE_Q8_K,
614
+ .nrows = 1,
615
  },
616
  [GGML_TYPE_Q5_K] = {
617
  .type_name = "q5_K",
 
623
  .from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference,
624
  .vec_dot = ggml_vec_dot_q5_K_q8_K,
625
  .vec_dot_type = GGML_TYPE_Q8_K,
626
+ .nrows = 1,
627
  },
628
  [GGML_TYPE_Q6_K] = {
629
  .type_name = "q6_K",
 
635
  .from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference,
636
  .vec_dot = ggml_vec_dot_q6_K_q8_K,
637
  .vec_dot_type = GGML_TYPE_Q8_K,
638
+ .nrows = 1,
639
  },
640
  [GGML_TYPE_IQ2_XXS] = {
641
  .type_name = "iq2_xxs",
 
647
  .from_float_reference = NULL,
648
  .vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
649
  .vec_dot_type = GGML_TYPE_Q8_K,
650
+ .nrows = 1,
651
  },
652
  [GGML_TYPE_IQ2_XS] = {
653
  .type_name = "iq2_xs",
 
659
  .from_float_reference = NULL,
660
  .vec_dot = ggml_vec_dot_iq2_xs_q8_K,
661
  .vec_dot_type = GGML_TYPE_Q8_K,
662
+ .nrows = 1,
663
  },
664
  [GGML_TYPE_IQ3_XXS] = {
665
  .type_name = "iq3_xxs",
 
671
  .from_float_reference = (ggml_from_float_t)quantize_row_iq3_xxs_reference,
672
  .vec_dot = ggml_vec_dot_iq3_xxs_q8_K,
673
  .vec_dot_type = GGML_TYPE_Q8_K,
674
+ .nrows = 1,
675
  },
676
  [GGML_TYPE_Q8_K] = {
677
  .type_name = "q8_K",
 
1242
  inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
1243
  inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
1244
 
1245
+ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
1246
+ assert(nrc == 1);
1247
+ UNUSED(nrc);
1248
+ UNUSED(bx);
1249
+ UNUSED(by);
1250
+ UNUSED(bs);
1251
+
1252
  #ifdef GGML_SIMD
1253
  float sumf = 0.0f;
1254
  const int np = (n & ~(GGML_F32_STEP - 1));
 
1285
  *s = sumf;
1286
  }
1287
 
1288
+ static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
1289
+ assert(nrc == 1);
1290
+ UNUSED(nrc);
1291
+ UNUSED(bx);
1292
+ UNUSED(by);
1293
+ UNUSED(bs);
1294
+
1295
  ggml_float sumf = 0.0;
1296
 
1297
  #if defined(GGML_SIMD)
 
1497
  #endif
1498
  }
1499
 
1500
+ inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
1501
  inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
1502
  inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
1503
  inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
 
10034
  ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
10035
  enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
10036
  ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
10037
+ int64_t const vec_dot_num_rows = type_traits[type].nrows;
10038
 
10039
  GGML_ASSERT(ne0 == ne01);
10040
  GGML_ASSERT(ne1 == ne11);
 
     const int64_t blck_0 = 16;
     const int64_t blck_1 = 16;
 
+    // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
+    int64_t nrc = vec_dot_num_rows;
+    // TODO: currently the mmla kernels support only even numbered rows/cols.
+    // this check can be removed once they are extended to support odd numbered rows/cols too
+    if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
+        nrc = 1;
+    }
+
+    const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
+
     // attempt to reduce false-sharing (does not seem to make a difference)
+    // 16 * 2, accounting for mmla kernels
+    float tmp[32];
 
     for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
         for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
+            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) {
                 const int64_t i13 = (ir1/(ne12*ne1));
                 const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
                 const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
 
                     (src1_cont || src1->type != vec_dot_type
                      ? (i11      + i12*ne11 + i13*ne12*ne11)*row_size
                      : (i11*nb11 + i12*nb12 + i13*nb13));
 
                 float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
 
                 //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
                 //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
                 //}
 
+                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) {
+                    vec_dot(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? src1_col_stride : 0), nrc);
+                }
+
+                for (int cn = 0; cn < nrc; ++cn) {
+                    memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
+                }
             }
         }
     }
10442
  //}
10443
 
10444
  for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
10445
+ vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_row + ir0*nb01, 0, src1_col, 0, 1);
10446
  }
10447
  memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
10448
  }
 
11624
 
11625
  // linear runtime, no additional memory
11626
  float dot_y_dy = 0;
11627
+ ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
11628
  ggml_vec_cpy_f32 (nc, dx, dy);
11629
  ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
11630
  ggml_vec_mul_f32 (nc, dx, dx, y);
 
                 const int i1n = i10*ne11;
                 for (int i00 = 0; i00 < ne00; i00++) {
                     float v = 0;
+                    ggml_vec_dot_f16(ne02, &v, 0,
+                            (ggml_fp16_t *) wdata_src + i1n, 0,
+                            (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
                     dst_data[i10*s0 + i00] += v;
                 }
             }
 
                 const int i1n = i10*ne11;
                 for (int i00 = 0; i00 < ne00; i00++) {
                     float v = 0;
+                    ggml_vec_dot_f32(ne02, &v, 0,
+                            wdata_src + i1n, 0,
+                            wdata_kernel + i00*ne02, 0, 1);
                     dst_data[i10*s0 + i00] += v;
                 }
             }
 
             for (int i01 = 0; i01 < ne01; i01++) {
                 for (int i00 = 0; i00 < ne00; i00++) {
                     float v = 0;
+                    ggml_vec_dot_f16(ne03, &v, 0,
+                            wdata_src + i1n, 0,
+                            wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
                     dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
                 }
             }
 
                 const int i1 = ik1;
 
                 ggml_vec_dot_f32(neq0,
+                        S + i1, 0,
+                        (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+                        (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
             }
 
             // scale
 
                 const int iv3 = iq3;
 
                 ggml_vec_dot_f32(masked_begin,
+                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2   + i3*nb3)), 0,
+                        (float *) ((char *) v->data   + (        ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
+                        S, 0, 1);
             }
         }
     }
 
                 const int i1 = ik1;
 
                 ggml_vec_dot_f16(neq0,
+                        S + i1, 0,
+                        (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
             }
         } else {
             for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
 
                 const int iv3 = iq3;
 
                 ggml_vec_dot_f16(nev0,
+                        (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2   + i3*nb3)), 0,
+                        (ggml_fp16_t *) ((char *) v->data   + (        ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
+                        S16, 0, 1);
             }
         } else {
             for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
 
             const int i1 = ib01;
 
             ggml_vec_dot_f16(nea0,
+                    S + i1, 0,
+                    (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
+                    (ggml_fp16_t *) ((char *) a->data  + ( ia1*nba1  + ia2*nba2   + ia3*nba3)), 0, 1);
         }
 
         ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
 
         for (int64_t ic = 0; ic < nec01; ++ic) {
 
             ggml_vec_dot_f16(neb01,
+                    (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
+                    (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
+                    S16, 0, 1);
         }
 
         ggml_vec_add_f32(nec01,
 
                 const int i1 = ik1;
 
                 ggml_vec_dot_f32(neq0,
+                        S + i1, 0,
+                        (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+                        (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
             }
 
             // scale
 
 
                 // S = SM * (S - dot(SM, S))
                 float dot_SM_gradSM = 0;
+                ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
                 ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
                 ggml_vec_mul_f32 (masked_begin, S, S, SM);
 
 
     }
 
     // compute the initial gradient in the search direction
+    ggml_vec_dot_f32(nx, &dginit, 0, g, 0, d, 0, 1);
 
     // make sure that d points to a descent direction
     if (0 < dginit) {
 
         return count;
     }
 
+    ggml_vec_dot_f32(nx, &dg, 0, g, 0, d, 0, 1);
 
     // check the Wolfe condition
     if (dg < params->lbfgs.wolfe * dginit) {
 
         // ys = y^t \cdot s -> 1 / \rho.
         // yy = y^t \cdot y.
         //
+        ggml_vec_dot_f32(nx, &ys, 0, &lm_y[end[0]*nx], 0, &lm_s[end[0]*nx], 0, 1);
+        ggml_vec_dot_f32(nx, &yy, 0, &lm_y[end[0]*nx], 0, &lm_y[end[0]*nx], 0, 1);
 
         lm_ys[end[0]] = ys;
 
 
         for (int i = 0; i < bound; ++i) {
             j[0] = (j[0] + m - 1) % m;
             // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
+            ggml_vec_dot_f32(nx, &lm_alpha[j[0]], 0, &lm_s[j[0]*nx], 0, d, 0, 1);
             lm_alpha[j[0]] /= lm_ys[j[0]];
             // q_{i} = q_{i+1} - \alpha_{i} y_{i}
             ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
 
 
         for (int i = 0; i < bound; ++i) {
             // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
+            ggml_vec_dot_f32(nx, &beta, 0, &lm_y[j[0]*nx], 0, d, 0, 1);
             beta /= lm_ys[j[0]];
             // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
             ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
 
 #endif
 }
 
+int ggml_cpu_has_matmul_int8(void) {
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 ////////////////////////////////////////////////////////////////////////////////
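Note: since __ARM_FEATURE_MATMUL_INT8 is a compiler-defined feature macro, the new query reports the compile-time target rather than probing the CPU at run time, in the same style as the other ggml_cpu_has_* helpers. A minimal sketch of checking it from application code, assuming the patched ggml.h below is on the include path:

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        // 1 when the build targets armv8.2-a+i8mm or newer, 0 otherwise
        printf("MATMUL_INT8 = %d\n", ggml_cpu_has_matmul_int8());
        return 0;
    }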
ggml.h CHANGED
@@ -2290,6 +2290,7 @@ extern "C" {
     GGML_API int ggml_cpu_has_ssse3      (void);
     GGML_API int ggml_cpu_has_sycl       (void);
     GGML_API int ggml_cpu_has_vsx       (void);
+    GGML_API int ggml_cpu_has_matmul_int8(void);
 
     //
     // Internal types and functions exposed for tests and benchmarks
@@ -2303,7 +2304,8 @@ extern "C" {
 #endif
     typedef void (*ggml_to_float_t)  (const void  * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
     typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void  * GGML_RESTRICT y, int k);
-    typedef void (*ggml_vec_dot_t)   (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
+    typedef void (*ggml_vec_dot_t)   (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
+                                      const void * GGML_RESTRICT y, size_t by, int nrc);
 
     typedef struct {
         const char      * type_name;
@@ -2315,6 +2317,7 @@ extern "C" {
         ggml_from_float_t from_float_reference;
         ggml_vec_dot_t    vec_dot;
         enum ggml_type    vec_dot_type;
+        int64_t           nrows; // number of rows to process simultaneously
     } ggml_type_traits_t;
 
     GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
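Note: tests and benchmarks reach the extended interface through this traits table. A minimal sketch of driving it from outside ggml.c, assuming a build against the patched header (the helper name is illustrative, not part of the patch):

    #include "ggml.h"

    // Single-row dot product of one quantized row x against one
    // quantized column y, dispatched through the type traits.
    static void dot_one_row(enum ggml_type type, int n,
                            const void * x, const void * y, float * out) {
        ggml_type_traits_t t = ggml_internal_get_type_traits(type);
        // t.nrows is 2 for the mmla-enabled quantized types on i8mm
        // builds and 1 everywhere else; a single-row call passes 0
        // for all three strides and nrc = 1.
        t.vec_dot(n, out, 0, x, 0, y, 0, 1);
    }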