ggml : add sve tuned code for gemm_q8_0_4x8_q8_0() kernel (#21916)

* Added SVE-tuned code for the gemm_q8_0_4x8_q8_0() kernel

* Changed the lookup-table arrays to static const in repack.cpp (a minimal sketch of the pattern follows below)
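
The static const change pins the svtbl lookup tables in read-only storage instead of rebuilding them on each call. A minimal sketch of the pattern (illustrative only, not the committed diff itself):

    // A plain local array may be re-initialized on every call:
    //     const uint32_t idx_arr[8] = {0, 1, 4, 5, 2, 3, 6, 7};
    // `static const` emits the table once into read-only data instead:
    static const uint32_t idx_arr[8] = {0, 1, 4, 5, 2, 3, 6, 7};
    svuint32_t idx = svld1(svptrue_b32(), idx_arr);  // table load is unchanged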

---------

Co-authored-by: Vithulep <prashant.vithule@fujitsu.com>
Author: hrushitfujitsu
Date: 2026-04-29 13:27:37 +05:30
Committed by: GitHub
Parent: 739393beeb
Commit: bdc9c743a5
@@ -5023,6 +5023,71 @@ void ggml_gemm_q8_0_4x8_q8_0(int n,
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
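    // Fast path for 256-bit SVE vectors only: svcntb() returns the vector
    // length in bytes, so svcntb() * 8 is the width in bits (8 f32 lanes).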
    if (svcntb() * 8 == 256) {
        const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
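        // svtbl lookup tables: `idx` rearranges the two 2×2 MMLA output tiles
        // into row-major order, while `idx_sc1`/`idx_sc2` lay out the row and
        // column scales so svmul_lane_f32 can pick the right row scale per
        // 128-bit segment (the lane index of *_lane intrinsics is per-segment).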
        static const uint32_t idx_arr[8] = {0, 1, 4, 5, 2, 3, 6, 7};
        svuint32_t idx = svld1(svptrue_b32(), idx_arr);
        static const uint32_t idx_arr1[8] = {0, 1, 2, 3, 1, 2, 3, 0};
        svuint32_t idx_sc1 = svld1(svptrue_b32(), idx_arr1);
        static const uint32_t idx_arr2[8] = {0, 1, 2, 3, 0, 1, 2, 3};
        svuint32_t idx_sc2 = svld1(svptrue_b32(), idx_arr2);
        for (int y = 0; y < nr; y += 4) {
            const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb;
            for (int x = 0; x < nc; x += ncols_interleaved) {
                const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb;
                const block_q8_0x4 * a_ptr = a_ptr_base;
                svfloat32_t acc_f32_01 = svdup_f32(0);
                svfloat32_t acc_f32_23 = svdup_f32(0);
                for (int b = 0; b < nb; b++) {
                    svint32_t acc_01 = svdup_s32(0);
                    svint32_t acc_23 = svdup_s32(0);
                    // Process 4 chunks of 8 positions each
                    for (int chunk = 0; chunk < 4; chunk++) {
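                        // svld1rq_s8 broadcasts a 16-byte slice (8 quants each
                        // from two rows of A) to every 128-bit segment; the
                        // 32-byte B load keeps columns 0-1 in segment 0 and
                        // columns 2-3 in segment 1, so each svmmla produces
                        // two independent 2×2 int32 tiles.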
                        svint8_t s_a01 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32);
                        svint8_t s_a23 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32 + 16);
                        svint8_t s_b0123 = svld1_s8(svptrue_b8(), b_ptr->qs + chunk * 32);
                        acc_01 = svmmla_s32(acc_01, s_a01, s_b0123);
                        acc_23 = svmmla_s32(acc_23, s_a23, s_b0123);
                    }
                    // Reorder outputs from 2×2 tiles to row-major
                    // acc[01] = [r0c0, r0c1, r1c0, r1c1, r0c2, r0c3, r1c2, r1c3]
                    // acc[23] = [r2c0, r2c1, r3c0, r3c1, r2c2, r2c3, r3c2, r3c3]
                    svint32_t row01 = svtbl_s32(acc_01, idx);
                    svint32_t row23 = svtbl_s32(acc_23, idx);
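                    // Fold in the q8_0 deltas: widen the four f16 scales of A
                    // and B to f32, then svmul_lane_f32 (a per-segment lane
                    // select) forms the a_d[row] * b_d[col] factor per lane.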
                    svfloat16_t temp1 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) a_ptr->d);
                    svfloat16_t temp2 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) b_ptr->d);
                    svfloat32_t sv_a_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp1, temp1)), idx_sc1);
                    svfloat32_t sv_b_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp2, temp2)), idx_sc2);
                    acc_f32_01 = svmla_f32_x(svptrue_b32(), acc_f32_01, svcvt_f32_s32_x(svptrue_b32(), row01), svmul_lane_f32(sv_b_d, sv_a_d, 0));
                    acc_f32_23 = svmla_f32_x(svptrue_b32(), acc_f32_23, svcvt_f32_s32_x(svptrue_b32(), row23), svmul_lane_f32(sv_b_d, sv_a_d, 2));
                    a_ptr++;
                    b_ptr++;
                }
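                // Each f32 accumulator holds two output rows of 4 values:
                // store the low half directly, then svext the high half down
                // by 4 lanes for the following row.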
                svbool_t pg4 = svptrue_pat_b32(SV_VL4);
                svst1_f32(pg4, s + (y+0) * bs + x, acc_f32_01);
                svst1_f32(pg4, s + (y+1) * bs + x, svext_f32(acc_f32_01, acc_f32_01, 4));
                svst1_f32(pg4, s + (y+2) * bs + x, acc_f32_23);
                svst1_f32(pg4, s + (y+3) * bs + x, svext_f32(acc_f32_23, acc_f32_23, 4));
            }
        }
        return;
    }
#endif // end of 256-bit SVE + I8MM path
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
    const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;