diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index 80ff5ce54..a75344430 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -5023,6 +5023,71 @@ void ggml_gemm_q8_0_4x8_q8_0(int n, UNUSED(ncols_interleaved); UNUSED(blocklen); +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntb() * 8 == 256) { + const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx; + + static const uint32_t idx_arr[8] = {0, 1, 4, 5, 2, 3, 6, 7}; + svuint32_t idx = svld1(svptrue_b32(), idx_arr); + static const uint32_t idx_arr1[8] = {0, 1, 2, 3, 1, 2, 3, 0}; + svuint32_t idx_sc1 = svld1(svptrue_b32(), idx_arr1); + static const uint32_t idx_arr2[8] = {0, 1, 2, 3, 0, 1, 2, 3}; + svuint32_t idx_sc2 = svld1(svptrue_b32(), idx_arr2); + + for (int y = 0; y < nr; y += 4) { + const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb; + + for (int x = 0; x < nc; x += ncols_interleaved) { + const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb; + const block_q8_0x4 * a_ptr = a_ptr_base; + + svfloat32_t acc_f32_01 = svdup_f32(0); + svfloat32_t acc_f32_23 = svdup_f32(0); + + for (int b = 0; b < nb; b++) { + + svint32_t acc_01 = svdup_s32(0); + svint32_t acc_23 = svdup_s32(0); + + // Process 4 chunks of 8 positions each + for (int chunk = 0; chunk < 4; chunk++) { + svint8_t s_a01 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32); + svint8_t s_a23 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32 + 16); + svint8_t s_b0123 = svld1_s8(svptrue_b8(), b_ptr->qs + chunk * 32); + + acc_01 = svmmla_s32(acc_01, s_a01, s_b0123); + acc_23 = svmmla_s32(acc_23, s_a23, s_b0123); + } + + // Reorder outputs from 2×2 tiles to row-major + // acc[01] = [r0c0, r0c1, r1c0, r1c1, r0c2, r0c3, r1c2, r1c3] + // acc[23] = [r2c0, r2c1, r3c0, r3c1, r2c2, r2c3, r3c2, r3c3] + + svint32_t row01 = svtbl_s32(acc_01, idx); + svint32_t row23 = svtbl_s32(acc_23, idx); + + 
+ svfloat16_t temp1 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) a_ptr->d); + svfloat16_t temp2 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) b_ptr->d); + svfloat32_t sv_a_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp1, temp1)), idx_sc1); + svfloat32_t sv_b_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp2, temp2)), idx_sc2); + + acc_f32_01 = svmla_f32_x(svptrue_b32(), acc_f32_01, svcvt_f32_s32_x(svptrue_b32(), row01), svmul_lane_f32(sv_b_d, sv_a_d, 0)); + acc_f32_23 = svmla_f32_x(svptrue_b32(), acc_f32_23, svcvt_f32_s32_x(svptrue_b32(), row23), svmul_lane_f32(sv_b_d, sv_a_d, 2)); + a_ptr++; + b_ptr++; + } + + svbool_t pg4 = svptrue_pat_b32(SV_VL4); + svst1_f32(pg4, s + (y+0) * bs + x, acc_f32_01); + svst1_f32(pg4, s + (y+1) * bs + x, svext_f32(acc_f32_01, acc_f32_01, 4)); + svst1_f32(pg4, s + (y+2) * bs + x, acc_f32_23); + svst1_f32(pg4, s + (y+3) * bs + x, svext_f32(acc_f32_23, acc_f32_23, 4)); + } + } + return; + } +#endif // SVE compile-time end + #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;