// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <[email protected]>
// SPDX-License-Identifier: MIT
//

// KleidiAI micro-kernels

static const size_t INT4_PER_BYTE   = 2;
static const size_t INT4_BITS       = 4;
static const int    Q4_0_ZERO_POINT = 8;
const size_t INT4_PER_UINT16 = 4;
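
// Background (illustrative note, not part of the upstream file): a Q4_0 block stores 32
// weights as 4-bit codes q in [0, 15] plus one f16 scale d, decoded as (q - Q4_0_ZERO_POINT) * d.
// For example, q = 11 with d = 0.5f decodes to (11 - 8) * 0.5f = 1.5f. The helpers below
// reverse the KleidiAI packing layout and then apply exactly this formula.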

// Expand one row of a qsi4c32pscalef16-packed RHS back to float.
// Rows are packed in groups of nr_pack; within each block of bl values the f16 scales
// for the whole group come first, followed by the group's interleaved 4-bit data.
static void dequantize_row_qsi4c32pscalef16(
    const void *packed_data,
    int32_t row_idx,
    int64_t nc,
    float *out,
    size_t nr_pack,
    size_t packed_row_stride,
    size_t kr,
    size_t bl,
    size_t num_bytes_multiplier
) {
    size_t group_idx    = row_idx / nr_pack;
    size_t row_in_group = row_idx % nr_pack;

    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
    size_t num_blocks = nc / bl;

    const uint8_t *block_ptr = packed_group;

    for (size_t b = 0; b < num_blocks; ++b) {
        // Per-row f16 scale for this block.
        uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier));
        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);

        // The quantized data follows the scales, split into segments of kr values.
        const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier;
        size_t num_segments = bl / kr;
        size_t num_bytes_per_segment = kr / INT4_PER_BYTE;

        for (size_t s = 0; s < num_segments; ++s) {
            const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment;
            const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment;
            for (size_t k = 0; k < num_bytes_per_segment; ++k) {
                // XOR with 0x88 undoes the sign flip applied at pack time, so each nibble
                // is again an unsigned code to which the Q4_0 zero-point applies.
                uint8_t byte = qbytes[k] ^ 0x88;
                int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT;
                int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT;
                // Low nibbles map to the first half of the block, high nibbles to the second half.
                out[b * bl + s * num_bytes_per_segment + k] = x0 * scale;
                out[b * bl + s * num_bytes_per_segment + k + bl / 2] = x1 * scale;
            }
        }
        block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment;
    }
}
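
// Illustrative call sketch (assumption, not part of the upstream file): expanding row
// `row_idx` of a packed Q4_0 weight tensor with k columns into `row_out`. Here nr, kr and
// the packed row stride are expected to come from the selected kernel's kai_* getters and
// rhs_info.packed_stride; bl is the Q4_0 block length (32) and each f16 scale occupies
// sizeof(uint16_t) bytes.
//
//     dequantize_row_qsi4c32pscalef16(rhs_packed, row_idx, k, row_out,
//                                     nr, packed_stride, kr, /*bl=*/32, sizeof(uint16_t));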

// Expand one row of a qsi4c32ps1s0scalef16-packed RHS back to float.
// In this layout the group's 4-bit data comes first, four nibbles per uint16, and the
// f16 scales are stored at the end of the packed row.
static void dequantize_row_qsi4c32ps1s0scalef16(
    const void *packed_data,
    int32_t row_idx,
    int64_t k,
    float *out,
    size_t nr,
    size_t packed_row_stride,
    size_t kr,
    size_t bl,
    size_t num_bytes_multiplier
) {
    const size_t num_blocks = k / bl;
    const size_t bl4 = bl / INT4_PER_UINT16;

    size_t group_idx    = row_idx / nr;
    size_t row_in_group = row_idx % nr;

    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
    const uint16_t *qdata = (const uint16_t *)packed_group;
    const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier));

    for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {
        uint16_t scale_f16 = scales[row_in_group + block_idx * nr];
        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);

        for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) {
            uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group];

            // Each uint16 carries four 4-bit codes; INT4_BITS happens to equal
            // INT4_PER_UINT16, so it also serves as the per-uint16 output stride.
            for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) {
                int v = ((q >> (qidx * INT4_BITS)) & 0xF) - Q4_0_ZERO_POINT;
                out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale;
            }
        }
    }
    GGML_UNUSED(kr);
}
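
// Kernel selection table. Each entry bundles a GEMM micro-kernel, a GEMV micro-kernel,
// the matching LHS/RHS packing routines, an optional dequantize ("to_float") helper for
// the packed RHS, and the CPU features plus tensor types the entry applies to. Entries
// are scanned in order and the first match wins (see ggml_kleidiai_select_kernels below).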

static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
#if defined(__ARM_FEATURE_SME)
    {
        /* SME GEMM */
        /* .kern_info = */ {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
            /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
            /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
        },
        /* SME GEMV */
        /* .kern_info = */ {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
            /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
            /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
        },
        /* .lhs_info = */ {
            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
            /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
            /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
            /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
        },
        /* .rhs_info = */ {
            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
            /* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16,
        },
        /* .required_cpu = */ CPU_FEATURE_SME,
        /* .lhs_type = */ GGML_TYPE_F32,
        /* .rhs_type = */ GGML_TYPE_Q4_0,
        /* .op_type = */ GGML_TYPE_F32,
    },
    {
        /* SME GEMM */
        /* .kern_info = */ {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
        },
        /* SME GEMV */
        /* .kern_info = */ {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
            /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
        },
        /* .lhs_info = */ {
            /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
            /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
            /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
            /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
        },
        /* .rhs_info = */ {
            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
            /* .packed_stride = */ NULL,
            /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
            /* .to_float = */ NULL,
        },
        /* .required_cpu = */ CPU_FEATURE_SME,
        /* .lhs_type = */ GGML_TYPE_F32,
        /* .rhs_type = */ GGML_TYPE_F16,
        /* .op_type = */ GGML_TYPE_F32,
    },
#endif
#if defined(__APPLE__)
#if defined(__ARM_FEATURE_DOTPROD)
    {
        /* DOTPROD GEMM */
        /* .kern_info = */ {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
        },
        /* DOTPROD GEMV */
        /* .kern_info = */ {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
        },
        /* .lhs_info = */ {
            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
            /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
            /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
            /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
        },
        /* .rhs_info = */ {
            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .to_float = */ dequantize_row_qsi4c32pscalef16,
        },
        /* .required_cpu = */ CPU_FEATURE_DOTPROD,
        /* .lhs_type = */ GGML_TYPE_F32,
        /* .rhs_type = */ GGML_TYPE_Q4_0,
        /* .op_type = */ GGML_TYPE_F32,
    },
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
    {
        /* i8mm GEMM */
        /* .kern_info = */ {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
        },
        /* i8mm GEMV */
        /* .kern_info = */ {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
        },
        /* .lhs_info = */ {
            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
            /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
            /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
            /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
        },
        /* .rhs_info = */ {
            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .to_float = */ dequantize_row_qsi4c32pscalef16,
        },
        /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
        /* .lhs_type = */ GGML_TYPE_F32,
        /* .rhs_type = */ GGML_TYPE_Q4_0,
        /* .op_type = */ GGML_TYPE_F32,
    },
#endif
#else
#if defined(__ARM_FEATURE_MATMUL_INT8)
    {
        /* i8mm GEMM */
        /* .kern_info = */ {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
            /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
        },
        /* i8mm GEMV */
        /* .kern_info = */ {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
            /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
        },
        /* .lhs_info = */ {
            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
            /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
            /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
            /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
        },
        /* .rhs_info = */ {
            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .to_float = */ dequantize_row_qsi4c32pscalef16,
        },
        /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
        /* .lhs_type = */ GGML_TYPE_F32,
        /* .rhs_type = */ GGML_TYPE_Q4_0,
        /* .op_type = */ GGML_TYPE_F32,
    },
#endif
#if defined(__ARM_FEATURE_DOTPROD)
    {
        /* DOTPROD GEMM */
        /* .kern_info = */ {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
            /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
        },
        /* DOTPROD GEMV */
        /* .kern_info = */ {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
            /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
        },
        /* .lhs_info = */ {
            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
            /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
            /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
            /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
        },
        /* .rhs_info = */ {
            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
            /* .to_float = */ dequantize_row_qsi4c32pscalef16,
        },
        /* .required_cpu = */ CPU_FEATURE_DOTPROD,
        /* .lhs_type = */ GGML_TYPE_F32,
        /* .rhs_type = */ GGML_TYPE_Q4_0,
        /* .op_type = */ GGML_TYPE_F32,
    },
#endif
#endif
};

// Pick the first kernel entry whose CPU feature requirements are all present and whose
// LHS/RHS/output types match the MUL_MAT operands (src[0] is matched against rhs_type,
// src[1] against lhs_type).
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
    ggml_kleidiai_kernels * kernel = nullptr;

    if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
        for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
            if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
                gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
                gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
                gemm_gemv_kernels[i].op_type == tensor->type) {
                kernel = &gemm_gemv_kernels[i];
                break;
            }
        }
    }

    return kernel;
}
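
// Illustrative dispatch sketch (assumption, not part of the upstream file): a caller that
// has probed the CPU features once would typically look up a kernel per MUL_MAT node and
// fall back to the generic CPU path when no entry matches, e.g.
//
//     ggml_kleidiai_kernels * k = ggml_kleidiai_select_kernels(features, node);
//     if (k == nullptr) {
//         /* not supported here; let the default backend handle this node */
//     }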

// Q4_0 path: return the first entry supported by the given CPU features, without
// matching tensor types.
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
    ggml_kleidiai_kernels * kernels = nullptr;

    for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
        if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
            kernels = &gemm_gemv_kernels[i];
            break;
        }
    }

    return kernels;
}