Spaces:
Running
Running
fj-y-saito
committed on
Commit
·
607a196
1
Parent(s):
55088d3
ggml : add SVE support for q6_K_q8_K (llama/12361)
Browse files
ggml/src/ggml-cpu/ggml-cpu-quants.c
CHANGED
|
@@ -8158,7 +8158,156 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
|
| 8158 |
|
| 8159 |
const int nb = n / QK_K;
|
| 8160 |
|
| 8161 |
-
#ifdef
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8162 |
float sum = 0;
|
| 8163 |
|
| 8164 |
const uint8x16_t m4b = vdupq_n_u8(0xF);
|
|
|
|
| 8158 |
|
| 8159 |
const int nb = n / QK_K;
|
| 8160 |
|
| 8161 |
+
#ifdef __ARM_FEATURE_SVE
|
| 8162 |
+
const int vector_length = ggml_cpu_get_sve_cnt()*8;
|
| 8163 |
+
float sum = 0;
|
| 8164 |
+
svuint8_t m4b = svdup_n_u8(0xf);
|
| 8165 |
+
svint32_t vzero = svdup_n_s32(0);
|
| 8166 |
+
svuint8_t mone = svdup_n_u8(0x30);
|
| 8167 |
+
svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
|
| 8168 |
+
svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
|
| 8169 |
+
|
| 8170 |
+
for (int i = 0; i < nb; ++i) {
|
| 8171 |
+
const float d_all = GGML_FP16_TO_FP32(x[i].d);
|
| 8172 |
+
|
| 8173 |
+
const uint8_t * GGML_RESTRICT q6 = x[i].ql;
|
| 8174 |
+
const uint8_t * GGML_RESTRICT qh = x[i].qh;
|
| 8175 |
+
const int8_t * GGML_RESTRICT q8 = y[i].qs;
|
| 8176 |
+
|
| 8177 |
+
const int8_t * GGML_RESTRICT scale = x[i].scales;
|
| 8178 |
+
|
| 8179 |
+
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
|
| 8180 |
+
const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
|
| 8181 |
+
const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
|
| 8182 |
+
const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
|
| 8183 |
+
const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
|
| 8184 |
+
const svint64_t prod = svdup_n_s64(0);
|
| 8185 |
+
int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
|
| 8186 |
+
svdot_s64(prod, q8sums_2, q6scales_2)));
|
| 8187 |
+
int32_t isum = 0;
|
| 8188 |
+
|
| 8189 |
+
switch (vector_length) {
|
| 8190 |
+
case 128:
|
| 8191 |
+
{
|
| 8192 |
+
const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
|
| 8193 |
+
const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
|
| 8194 |
+
svint32_t isum_tmp = svdup_n_s32(0);
|
| 8195 |
+
for (int j = 0; j < QK_K/128; ++j) {
|
| 8196 |
+
svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
|
| 8197 |
+
svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
|
| 8198 |
+
qh += 32;
|
| 8199 |
+
svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
|
| 8200 |
+
svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
|
| 8201 |
+
svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
|
| 8202 |
+
svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
|
| 8203 |
+
q6 += 64;
|
| 8204 |
+
svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
|
| 8205 |
+
svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
|
| 8206 |
+
svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
|
| 8207 |
+
svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
|
| 8208 |
+
q8 += 64;
|
| 8209 |
+
|
| 8210 |
+
q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
|
| 8211 |
+
q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
|
| 8212 |
+
q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
|
| 8213 |
+
q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
|
| 8214 |
+
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
|
| 8215 |
+
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
|
| 8216 |
+
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
|
| 8217 |
+
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
|
| 8218 |
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
|
| 8219 |
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
|
| 8220 |
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
|
| 8221 |
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
|
| 8222 |
+
|
| 8223 |
+
scale += 4;
|
| 8224 |
+
q8bytes_1 = svld1_s8(pg8_16, q8);
|
| 8225 |
+
q8bytes_2 = svld1_s8(pg8_16, q8+16);
|
| 8226 |
+
q8bytes_3 = svld1_s8(pg8_16, q8+32);
|
| 8227 |
+
q8bytes_4 = svld1_s8(pg8_16, q8+48);
|
| 8228 |
+
q8 += 64;
|
| 8229 |
+
|
| 8230 |
+
q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
|
| 8231 |
+
q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
|
| 8232 |
+
q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
|
| 8233 |
+
q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
|
| 8234 |
+
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
|
| 8235 |
+
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
|
| 8236 |
+
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
|
| 8237 |
+
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
|
| 8238 |
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
|
| 8239 |
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
|
| 8240 |
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
|
| 8241 |
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
|
| 8242 |
+
scale += 4;
|
| 8243 |
+
}
|
| 8244 |
+
isum += svaddv_s32(pg32_4, isum_tmp);
|
| 8245 |
+
sum += d_all * y[i].d * (isum - 32 * isum_mins);
|
| 8246 |
+
}
|
| 8247 |
+
break;
|
| 8248 |
+
case 256:
|
| 8249 |
+
case 512:
|
| 8250 |
+
{
|
| 8251 |
+
const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
|
| 8252 |
+
const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
|
| 8253 |
+
const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
|
| 8254 |
+
svint32_t isum_tmp = svdup_n_s32(0);
|
| 8255 |
+
for (int j = 0; j < QK_K/128; j++) {
|
| 8256 |
+
svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
|
| 8257 |
+
qh += 32;
|
| 8258 |
+
svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
|
| 8259 |
+
svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
|
| 8260 |
+
q6 += 64;
|
| 8261 |
+
svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
|
| 8262 |
+
svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
|
| 8263 |
+
svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
|
| 8264 |
+
svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
|
| 8265 |
+
q8 += 128;
|
| 8266 |
+
q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
|
| 8267 |
+
q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
|
| 8268 |
+
q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
|
| 8269 |
+
q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
|
| 8270 |
+
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
|
| 8271 |
+
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
|
| 8272 |
+
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
|
| 8273 |
+
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
|
| 8274 |
+
|
| 8275 |
+
svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
|
| 8276 |
+
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
|
| 8277 |
+
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
|
| 8278 |
+
svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
|
| 8279 |
+
scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
|
| 8280 |
+
scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
|
| 8281 |
+
svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
|
| 8282 |
+
scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
|
| 8283 |
+
scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
|
| 8284 |
+
svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
|
| 8285 |
+
scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
|
| 8286 |
+
scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
|
| 8287 |
+
svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
|
| 8288 |
+
svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
|
| 8289 |
+
svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
|
| 8290 |
+
svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
|
| 8291 |
+
|
| 8292 |
+
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
|
| 8293 |
+
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
|
| 8294 |
+
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
|
| 8295 |
+
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
|
| 8296 |
+
scale += 8;
|
| 8297 |
+
}
|
| 8298 |
+
isum += svaddv_s32(pg32_8, isum_tmp);
|
| 8299 |
+
sum += d_all * y[i].d * (isum - 32 * isum_mins);
|
| 8300 |
+
}
|
| 8301 |
+
break;
|
| 8302 |
+
default:
|
| 8303 |
+
assert(false && "Unsupported vector length");
|
| 8304 |
+
break;
|
| 8305 |
+
}
|
| 8306 |
+
}
|
| 8307 |
+
|
| 8308 |
+
*s = sum;
|
| 8309 |
+
|
| 8310 |
+
#elif __ARM_NEON
|
| 8311 |
float sum = 0;
|
| 8312 |
|
| 8313 |
const uint8x16_t m4b = vdupq_n_u8(0xF);
|