Spaces:
Running
Running
Commit ·
e2965b0
1
Parent(s): 2d993ad
ggml : Q2k interleaving implementation - x86/x64 SIMD (llama/14373)
Browse files* Initial Q2_K Block Interleaving Implementation
* Addressed review comments and clean up of the code
* Post rebase fixes
* Initial CI/CD fixes
* Update declarations in arch-fallback.h
* Changes for GEMV Q2_K in arch-fallback.h
* Enable repacking only on AVX-512 machines
* Update comments in repack.cpp
* Address q2k comments
---------
Co-authored-by: Manogna-Sree <elisetti.manognasree@multicorewareinc.com>
- ggml/src/ggml-cpu/arch-fallback.h +14 -0
- ggml/src/ggml-cpu/arch/x86/repack.cpp +0 -0
- ggml/src/ggml-cpu/repack.cpp +263 -0
- ggml/src/ggml-cpu/repack.h +11 -0
ggml/src/ggml-cpu/arch-fallback.h
CHANGED
|
@@ -37,17 +37,21 @@
|
|
| 37 |
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
| 38 |
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
| 39 |
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
|
|
|
| 40 |
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
| 41 |
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
| 42 |
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
| 43 |
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
| 44 |
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
|
|
|
| 45 |
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
| 46 |
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
|
| 47 |
// repack.cpp
|
| 48 |
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
| 49 |
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
|
|
|
| 50 |
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
|
|
|
| 51 |
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
|
| 52 |
// repack.cpp
|
| 53 |
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
|
@@ -72,11 +76,13 @@
|
|
| 72 |
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
| 73 |
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
| 74 |
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
|
|
|
| 75 |
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
| 76 |
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
| 77 |
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
| 78 |
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
| 79 |
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
|
|
|
| 80 |
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
| 81 |
#elif defined(__loongarch64)
|
| 82 |
// quants.c
|
|
@@ -92,11 +98,13 @@
|
|
| 92 |
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
| 93 |
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
| 94 |
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
|
|
|
| 95 |
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
| 96 |
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
| 97 |
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
| 98 |
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
| 99 |
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
|
|
|
| 100 |
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
| 101 |
#elif defined(__riscv)
|
| 102 |
// quants.c
|
|
@@ -119,10 +127,12 @@
|
|
| 119 |
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
| 120 |
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
| 121 |
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
|
|
|
| 122 |
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
| 123 |
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
| 124 |
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
| 125 |
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
|
|
|
| 126 |
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
| 127 |
#elif defined(__s390x__)
|
| 128 |
// quants.c
|
|
@@ -147,11 +157,13 @@
|
|
| 147 |
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
| 148 |
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
| 149 |
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
|
|
|
| 150 |
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
| 151 |
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
| 152 |
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
| 153 |
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
| 154 |
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
|
|
|
| 155 |
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
| 156 |
#elif defined(__wasm__)
|
| 157 |
// quants.c
|
|
@@ -175,10 +187,12 @@
|
|
| 175 |
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
| 176 |
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
| 177 |
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
|
|
|
| 178 |
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
| 179 |
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
| 180 |
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
| 181 |
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
| 182 |
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
|
|
|
| 183 |
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
| 184 |
#endif
|
|
|
|
| 37 |
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
| 38 |
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
| 39 |
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
| 40 |
+
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
| 41 |
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
| 42 |
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
| 43 |
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
| 44 |
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
| 45 |
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
| 46 |
+
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
| 47 |
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
| 48 |
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
|
| 49 |
// repack.cpp
|
| 50 |
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
| 51 |
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
| 52 |
+
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
| 53 |
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
| 54 |
+
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
| 55 |
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
|
| 56 |
// repack.cpp
|
| 57 |
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
|
|
|
| 76 |
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
| 77 |
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
| 78 |
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
| 79 |
+
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
| 80 |
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
| 81 |
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
| 82 |
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
| 83 |
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
| 84 |
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
| 85 |
+
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
| 86 |
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
| 87 |
#elif defined(__loongarch64)
|
| 88 |
// quants.c
|
|
|
|
| 98 |
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
| 99 |
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
| 100 |
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
| 101 |
+
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
| 102 |
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
| 103 |
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
| 104 |
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
| 105 |
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
| 106 |
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
| 107 |
+
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
| 108 |
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
| 109 |
#elif defined(__riscv)
|
| 110 |
// quants.c
|
|
|
|
| 127 |
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
| 128 |
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
| 129 |
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
| 130 |
+
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
| 131 |
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
| 132 |
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
| 133 |
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
| 134 |
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
| 135 |
+
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
| 136 |
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
| 137 |
#elif defined(__s390x__)
|
| 138 |
// quants.c
|
|
|
|
| 157 |
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
| 158 |
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
| 159 |
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
| 160 |
+
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
| 161 |
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
| 162 |
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
| 163 |
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
| 164 |
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
| 165 |
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
| 166 |
+
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
| 167 |
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
| 168 |
#elif defined(__wasm__)
|
| 169 |
// quants.c
|
|
|
|
| 187 |
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
| 188 |
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
| 189 |
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
| 190 |
+
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
| 191 |
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
| 192 |
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
| 193 |
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
| 194 |
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
| 195 |
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
| 196 |
+
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
| 197 |
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
| 198 |
#endif
|
ggml/src/ggml-cpu/arch/x86/repack.cpp
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ggml/src/ggml-cpu/repack.cpp
CHANGED
|
@@ -412,6 +412,82 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|
| 412 |
}
|
| 413 |
}
|
| 414 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
| 416 |
const int qk = QK8_0;
|
| 417 |
const int nb = n / qk;
|
|
@@ -711,6 +787,97 @@ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|
| 711 |
}
|
| 712 |
}
|
| 713 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 714 |
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
| 715 |
const int qk = QK8_0;
|
| 716 |
const int nb = n / qk;
|
|
@@ -914,6 +1081,50 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in
|
|
| 914 |
return out;
|
| 915 |
}
|
| 916 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 917 |
static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
| 918 |
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
| 919 |
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
|
|
@@ -975,6 +1186,37 @@ static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block
|
|
| 975 |
GGML_UNUSED(data_size);
|
| 976 |
}
|
| 977 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 978 |
static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
| 979 |
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
| 980 |
GGML_ASSERT(interleave_block == 8);
|
|
@@ -1095,6 +1337,10 @@ template <> int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * da
|
|
| 1095 |
return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
|
| 1096 |
}
|
| 1097 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1098 |
template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
| 1099 |
return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
|
| 1100 |
}
|
|
@@ -1124,6 +1370,10 @@ template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
|
|
| 1124 |
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
| 1125 |
}
|
| 1126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1127 |
template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
| 1128 |
ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
| 1129 |
}
|
|
@@ -1148,6 +1398,10 @@ template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
|
|
| 1148 |
ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
| 1149 |
}
|
| 1150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1151 |
template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
| 1152 |
ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
| 1153 |
}
|
|
@@ -1421,6 +1675,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
|
| 1421 |
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
|
| 1422 |
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
|
| 1423 |
|
|
|
|
|
|
|
|
|
|
| 1424 |
// instance for IQ4
|
| 1425 |
static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
|
| 1426 |
|
|
@@ -1446,6 +1703,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
|
| 1446 |
return &q4_K_8x8_q8_K;
|
| 1447 |
}
|
| 1448 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1449 |
} else if (cur->type == GGML_TYPE_IQ4_NL) {
|
| 1450 |
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
| 1451 |
if (cur->ne[1] % 4 == 0) {
|
|
|
|
| 412 |
}
|
| 413 |
}
|
| 414 |
|
| 415 |
+
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
| 416 |
+
const int qk = QK_K;
|
| 417 |
+
const int nb = n / qk;
|
| 418 |
+
const int ncols_interleaved = 8;
|
| 419 |
+
const int blocklen = 8;
|
| 420 |
+
|
| 421 |
+
assert (n % qk == 0);
|
| 422 |
+
assert (nc % ncols_interleaved == 0);
|
| 423 |
+
|
| 424 |
+
UNUSED(s);
|
| 425 |
+
UNUSED(bs);
|
| 426 |
+
UNUSED(vx);
|
| 427 |
+
UNUSED(vy);
|
| 428 |
+
UNUSED(nr);
|
| 429 |
+
UNUSED(nc);
|
| 430 |
+
UNUSED(nb);
|
| 431 |
+
UNUSED(ncols_interleaved);
|
| 432 |
+
UNUSED(blocklen);
|
| 433 |
+
|
| 434 |
+
float sumf[8];
|
| 435 |
+
float sum_minf[8];
|
| 436 |
+
int sumi1,sumi2,sumi3,sumi4;
|
| 437 |
+
int sumi;
|
| 438 |
+
|
| 439 |
+
const block_q8_K * a_ptr = (const block_q8_K *)vy;
|
| 440 |
+
for(int x = 0; x < nc / ncols_interleaved; x++) {
|
| 441 |
+
const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
|
| 442 |
+
for (int j = 0; j < ncols_interleaved; j++) {
|
| 443 |
+
sumf[j] = 0.0;
|
| 444 |
+
sum_minf[j] = 0.0;
|
| 445 |
+
}
|
| 446 |
+
for (int l = 0; l < nb; l++) {
|
| 447 |
+
for (int k = 0; k < (qk / (4 * blocklen)); k++) {
|
| 448 |
+
const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ;
|
| 449 |
+
const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;
|
| 450 |
+
const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;
|
| 451 |
+
const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;
|
| 452 |
+
for (int j = 0; j < ncols_interleaved; j++) {
|
| 453 |
+
sumi1 = 0;
|
| 454 |
+
sumi2 = 0;
|
| 455 |
+
sumi3 = 0;
|
| 456 |
+
sumi4 = 0;
|
| 457 |
+
sumi = 0;
|
| 458 |
+
int offset = ((k / 2) % 2) + j * 2;
|
| 459 |
+
for (int i = 0; i < blocklen; ++i){
|
| 460 |
+
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3);
|
| 461 |
+
const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3);
|
| 462 |
+
const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3);
|
| 463 |
+
const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3);
|
| 464 |
+
sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i]);
|
| 465 |
+
sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 32]);
|
| 466 |
+
sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 64]);
|
| 467 |
+
sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 96]);
|
| 468 |
+
|
| 469 |
+
sumi1 = sumi1 * (scales_0[offset] & 0xF);
|
| 470 |
+
sumi2 = sumi2 * (scales_1[offset] & 0xF);
|
| 471 |
+
sumi3 = sumi3 * (scales_2[offset] & 0xF);
|
| 472 |
+
sumi4 = sumi4 * (scales_3[offset] & 0xF);
|
| 473 |
+
sumi += sumi1 + sumi2 + sumi3 + sumi4;
|
| 474 |
+
}
|
| 475 |
+
sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
| 476 |
+
}
|
| 477 |
+
}
|
| 478 |
+
for(int sb = 0; sb < 8; sb++) {
|
| 479 |
+
const uint8_t *mins = b_ptr[l].scales + sb * 16;
|
| 480 |
+
for(int j = 0; j < ncols_interleaved; j++){
|
| 481 |
+
sum_minf[j] += ((mins[j * 2] >> 4) * a_ptr[l].bsums[sb * 2] + (mins[(j * 2)+ 1] >> 4) * a_ptr[l].bsums[sb * 2 + 1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
|
| 482 |
+
}
|
| 483 |
+
}
|
| 484 |
+
}
|
| 485 |
+
for (int j = 0; j < ncols_interleaved; j++) {
|
| 486 |
+
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
|
| 487 |
+
}
|
| 488 |
+
}
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
| 492 |
const int qk = QK8_0;
|
| 493 |
const int nb = n / qk;
|
|
|
|
| 787 |
}
|
| 788 |
}
|
| 789 |
|
| 790 |
+
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
| 791 |
+
const int qk = QK_K;
|
| 792 |
+
const int nb = n / qk;
|
| 793 |
+
const int ncols_interleaved = 8;
|
| 794 |
+
const int blocklen = 8;
|
| 795 |
+
|
| 796 |
+
assert (n % qk == 0);
|
| 797 |
+
assert (nr % 4 == 0);
|
| 798 |
+
assert (nc % ncols_interleaved == 0);
|
| 799 |
+
|
| 800 |
+
UNUSED(s);
|
| 801 |
+
UNUSED(bs);
|
| 802 |
+
UNUSED(vx);
|
| 803 |
+
UNUSED(vy);
|
| 804 |
+
UNUSED(nr);
|
| 805 |
+
UNUSED(nc);
|
| 806 |
+
UNUSED(nb);
|
| 807 |
+
UNUSED(ncols_interleaved);
|
| 808 |
+
UNUSED(blocklen);
|
| 809 |
+
|
| 810 |
+
float sumf[4][8];
|
| 811 |
+
float sum_minf[4][8];
|
| 812 |
+
int sumi1, sumi2, sumi3, sumi4;
|
| 813 |
+
int sumi;
|
| 814 |
+
|
| 815 |
+
for (int y = 0; y < nr / 4; y++) {
|
| 816 |
+
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
| 817 |
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
| 818 |
+
const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
|
| 819 |
+
for (int m = 0; m < 4; m++) {
|
| 820 |
+
for (int j = 0; j < ncols_interleaved; j++) {
|
| 821 |
+
sumf[m][j] = 0.0;
|
| 822 |
+
sum_minf[m][j] = 0.0;
|
| 823 |
+
}
|
| 824 |
+
}
|
| 825 |
+
for (int l = 0; l < nb; l++) {
|
| 826 |
+
for (int k = 0; k < (qk / (4 * blocklen)); k++) {
|
| 827 |
+
|
| 828 |
+
const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ;
|
| 829 |
+
const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;
|
| 830 |
+
const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;
|
| 831 |
+
const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;
|
| 832 |
+
for (int m = 0; m < 4; m++) {
|
| 833 |
+
for (int j = 0; j < ncols_interleaved; j++) {
|
| 834 |
+
sumi1 = 0;
|
| 835 |
+
sumi2 = 0;
|
| 836 |
+
sumi3 = 0;
|
| 837 |
+
sumi4 = 0;
|
| 838 |
+
sumi = 0;
|
| 839 |
+
int offset = ((k / 2) % 2) + j * 2;
|
| 840 |
+
for (int i = 0; i < blocklen; ++i){
|
| 841 |
+
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3);
|
| 842 |
+
const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3);
|
| 843 |
+
const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3);
|
| 844 |
+
const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3);
|
| 845 |
+
sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i]);
|
| 846 |
+
sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
|
| 847 |
+
sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 256]);
|
| 848 |
+
sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 384]);
|
| 849 |
+
sumi1 = sumi1 * (scales_0[offset] & 0xF);
|
| 850 |
+
sumi2 = sumi2 * (scales_1[offset] & 0xF);
|
| 851 |
+
sumi3 = sumi3 * (scales_2[offset] & 0xF);
|
| 852 |
+
sumi4 = sumi4 * (scales_3[offset] & 0xF);
|
| 853 |
+
sumi += sumi1 + sumi2 + sumi3 + sumi4;
|
| 854 |
+
}
|
| 855 |
+
sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
|
| 856 |
+
}
|
| 857 |
+
}
|
| 858 |
+
}
|
| 859 |
+
for(int sb = 0; sb < 8; sb++) {
|
| 860 |
+
const uint8_t *mins = b_ptr[l].scales + sb * 16;
|
| 861 |
+
for(int m = 0; m < 4; m++) {
|
| 862 |
+
const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
|
| 863 |
+
for(int j = 0; j < ncols_interleaved; j++) {
|
| 864 |
+
int mins_prod = ((mins[j * 2] >> 4) * bsums[0] + (mins[(j * 2)+ 1] >> 4) * bsums[1]);
|
| 865 |
+
sum_minf[m][j] += (mins_prod) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
|
| 866 |
+
}
|
| 867 |
+
}
|
| 868 |
+
}
|
| 869 |
+
}
|
| 870 |
+
|
| 871 |
+
for (int m = 0; m < 4; m++) {
|
| 872 |
+
for (int j = 0; j < ncols_interleaved; j++) {
|
| 873 |
+
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
|
| 874 |
+
}
|
| 875 |
+
}
|
| 876 |
+
}
|
| 877 |
+
}
|
| 878 |
+
}
|
| 879 |
+
|
| 880 |
+
|
| 881 |
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
| 882 |
const int qk = QK8_0;
|
| 883 |
const int nb = n / qk;
|
|
|
|
| 1081 |
return out;
|
| 1082 |
}
|
| 1083 |
|
| 1084 |
+
static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_interleave) {
|
| 1085 |
+
block_q2_Kx8 out;
|
| 1086 |
+
|
| 1087 |
+
// Delta(scale) and dmin values of the eight Q2_K structures are copied onto the output interleaved structure
|
| 1088 |
+
for (int i = 0; i < 8; i++) {
|
| 1089 |
+
out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
|
| 1090 |
+
}
|
| 1091 |
+
|
| 1092 |
+
for (int i = 0; i < 8; i++) {
|
| 1093 |
+
out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
|
| 1094 |
+
}
|
| 1095 |
+
|
| 1096 |
+
const int end = QK_K * 2 / blck_size_interleave;
|
| 1097 |
+
|
| 1098 |
+
// Interleave Q2_K quants by taking 8 bytes at a time
|
| 1099 |
+
for (int i = 0; i < end; ++i) {
|
| 1100 |
+
int src_id = i % 8;
|
| 1101 |
+
int src_offset = (i / 8) * blck_size_interleave;
|
| 1102 |
+
int dst_offset = i * blck_size_interleave;
|
| 1103 |
+
|
| 1104 |
+
uint64_t elems;
|
| 1105 |
+
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
|
| 1106 |
+
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
|
| 1107 |
+
}
|
| 1108 |
+
|
| 1109 |
+
// The below logic is designed so as to unpack and rearrange scales and mins values in Q2_K
|
| 1110 |
+
// Currently the Q2_K structure has 16 scales and 16 mins packed in 16 bytes ( 4 bits for each value)
|
| 1111 |
+
// The output Q2_Kx8 structure has 128 bytes for storing scales and mins
|
| 1112 |
+
// Every 16 byte is packed such that it contains scales and mins for corresponding sub blocks from Q2_K structure
|
| 1113 |
+
// For eg - First 16 bytes contains 16 scales and 16 mins - each of first and second sub blocks from different Q2_K structures
|
| 1114 |
+
|
| 1115 |
+
for(int i = 0; i < 128; i++){
|
| 1116 |
+
|
| 1117 |
+
// Index for selecting which q2k super block
|
| 1118 |
+
int src1 = (i % 16) / 2;
|
| 1119 |
+
// Index for selecting scale
|
| 1120 |
+
int src2 = ((i / 16) * 2) + (i % 2);
|
| 1121 |
+
|
| 1122 |
+
out.scales[i] = in[src1].scales[src2];
|
| 1123 |
+
}
|
| 1124 |
+
return out;
|
| 1125 |
+
|
| 1126 |
+
}
|
| 1127 |
+
|
| 1128 |
static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
| 1129 |
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
| 1130 |
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
|
|
|
|
| 1186 |
GGML_UNUSED(data_size);
|
| 1187 |
}
|
| 1188 |
|
| 1189 |
+
static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
| 1190 |
+
GGML_ASSERT(t->type == GGML_TYPE_Q2_K);
|
| 1191 |
+
GGML_ASSERT(interleave_block == 8);
|
| 1192 |
+
constexpr int nrows_interleaved = 8;
|
| 1193 |
+
|
| 1194 |
+
block_q2_Kx8 * dst = (block_q2_Kx8*)t->data;
|
| 1195 |
+
const block_q2_K * src = (const block_q2_K*) data;
|
| 1196 |
+
block_q2_K dst_tmp[8];
|
| 1197 |
+
int nrow = ggml_nrows(t);
|
| 1198 |
+
int nblocks = t->ne[0] / QK_K;
|
| 1199 |
+
|
| 1200 |
+
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K));
|
| 1201 |
+
|
| 1202 |
+
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
| 1203 |
+
return -1;
|
| 1204 |
+
}
|
| 1205 |
+
|
| 1206 |
+
for (int b = 0; b < nrow; b += nrows_interleaved) {
|
| 1207 |
+
for (int64_t x = 0; x < nblocks; x++) {
|
| 1208 |
+
for (int i = 0; i < nrows_interleaved; i++ ) {
|
| 1209 |
+
dst_tmp[i] = src[x + i * nblocks];
|
| 1210 |
+
}
|
| 1211 |
+
*dst++ = make_block_q2_Kx8(dst_tmp, interleave_block);
|
| 1212 |
+
}
|
| 1213 |
+
src += nrows_interleaved * nblocks;
|
| 1214 |
+
}
|
| 1215 |
+
return 0;
|
| 1216 |
+
|
| 1217 |
+
GGML_UNUSED(data_size);
|
| 1218 |
+
}
|
| 1219 |
+
|
| 1220 |
static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
| 1221 |
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
| 1222 |
GGML_ASSERT(interleave_block == 8);
|
|
|
|
| 1337 |
return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
|
| 1338 |
}
|
| 1339 |
|
| 1340 |
+
template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
| 1341 |
+
return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
|
| 1342 |
+
}
|
| 1343 |
+
|
| 1344 |
template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
| 1345 |
return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
|
| 1346 |
}
|
|
|
|
| 1370 |
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
| 1371 |
}
|
| 1372 |
|
| 1373 |
+
template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
| 1374 |
+
ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
| 1375 |
+
}
|
| 1376 |
+
|
| 1377 |
template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
| 1378 |
ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
| 1379 |
}
|
|
|
|
| 1398 |
ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
| 1399 |
}
|
| 1400 |
|
| 1401 |
+
template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
| 1402 |
+
ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
| 1403 |
+
}
|
| 1404 |
+
|
| 1405 |
template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
| 1406 |
ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
| 1407 |
}
|
|
|
|
| 1675 |
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
|
| 1676 |
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
|
| 1677 |
|
| 1678 |
+
// instance for Q2
|
| 1679 |
+
static const ggml::cpu::repack::tensor_traits<block_q2_K, 8, 8, GGML_TYPE_Q8_K> q2_K_8x8_q8_K;
|
| 1680 |
+
|
| 1681 |
// instance for IQ4
|
| 1682 |
static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
|
| 1683 |
|
|
|
|
| 1703 |
return &q4_K_8x8_q8_K;
|
| 1704 |
}
|
| 1705 |
}
|
| 1706 |
+
} else if (cur->type == GGML_TYPE_Q2_K) {
|
| 1707 |
+
if (ggml_cpu_has_avx512()) {
|
| 1708 |
+
if (cur->ne[1] % 8 == 0) {
|
| 1709 |
+
return &q2_K_8x8_q8_K;
|
| 1710 |
+
}
|
| 1711 |
+
}
|
| 1712 |
} else if (cur->type == GGML_TYPE_IQ4_NL) {
|
| 1713 |
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
| 1714 |
if (cur->ne[1] % 4 == 0) {
|
ggml/src/ggml-cpu/repack.h
CHANGED
|
@@ -44,7 +44,14 @@ struct block_q4_Kx8 {
|
|
| 44 |
};
|
| 45 |
|
| 46 |
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
|
|
|
| 48 |
struct block_q8_Kx4 {
|
| 49 |
float d[4]; // delta
|
| 50 |
int8_t qs[QK_K * 4]; // quants
|
|
@@ -71,11 +78,13 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|
| 71 |
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 72 |
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 73 |
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
|
|
| 74 |
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 75 |
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 76 |
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 77 |
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 78 |
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
|
|
| 79 |
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 80 |
|
| 81 |
// Native implementations
|
|
@@ -86,11 +95,13 @@ void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|
| 86 |
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 87 |
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 88 |
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
|
|
| 89 |
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 90 |
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 91 |
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 92 |
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 93 |
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
|
|
| 94 |
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 95 |
|
| 96 |
#if defined(__cplusplus)
|
|
|
|
| 44 |
};
|
| 45 |
|
| 46 |
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
|
| 47 |
+
// Eight Q2_K super-blocks packed together with their fields interleaved,
// so a SIMD kernel can load the same field of 8 blocks contiguously.
// Total size: 8*2 + 8*2 (half-precision scales) + 128 + 512 = 672 bytes,
// enforced by the static_assert below (QK_K == 256: 128 == QK_K/2,
// 512 == QK_K*2, i.e. 2 bits per quant across 8 blocks).
struct block_q2_Kx8 {
    ggml_half d[8];      // super-block scale for quantized scales
    ggml_half dmin[8];   // super-block scale for quantized mins
    uint8_t scales[128]; // scales and mins, quantized with 4 bits
    uint8_t qs[512];     // 2-bit quants
};

static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
|
| 55 |
struct block_q8_Kx4 {
|
| 56 |
float d[4]; // delta
|
| 57 |
int8_t qs[QK_K * 4]; // quants
|
|
|
|
| 78 |
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 79 |
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 80 |
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 81 |
+
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 82 |
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 83 |
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 84 |
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 85 |
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 86 |
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 87 |
+
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 88 |
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 89 |
|
| 90 |
// Native implementations
|
|
|
|
| 95 |
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 96 |
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 97 |
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 98 |
+
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 99 |
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 100 |
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 101 |
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 102 |
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 103 |
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 104 |
+
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 105 |
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 106 |
|
| 107 |
#if defined(__cplusplus)
|