ggml : implemented simd_gemm kernel for riscv vector extension (#20627)
Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>
This commit is contained in:
committed by
GitHub
parent
90fb96a7b3
commit
5637536517
@@ -109,6 +109,96 @@ static void simd_gemm(
|
||||
C += N;
|
||||
}
|
||||
}
|
||||
#elif defined(GGML_SIMD) && defined(__riscv_v_intrinsic)
|
||||
// RM accumulators + 1 B vector = RM + 1 <= 8 => RM <= 7
|
||||
// Microkernel: C[RM x vl] += A[RM x K] * B[K x N]
|
||||
template <int RM>
|
||||
static inline void rvv_simd_gemm_ukernel(
|
||||
float * GGML_RESTRICT C,
|
||||
const float * GGML_RESTRICT A,
|
||||
const float * GGML_RESTRICT B,
|
||||
int K, int N, size_t vl)
|
||||
{
|
||||
static_assert(RM >= 1 && RM <= 7, "RM must be 1..7 for LMUL=4");
|
||||
|
||||
vfloat32m4_t acc_0 = __riscv_vle32_v_f32m4(C + 0 * N, vl);
|
||||
vfloat32m4_t acc_1, acc_2, acc_3, acc_4, acc_5, acc_6;
|
||||
if constexpr (RM > 1) acc_1 = __riscv_vle32_v_f32m4(C + 1 * N, vl);
|
||||
if constexpr (RM > 2) acc_2 = __riscv_vle32_v_f32m4(C + 2 * N, vl);
|
||||
if constexpr (RM > 3) acc_3 = __riscv_vle32_v_f32m4(C + 3 * N, vl);
|
||||
if constexpr (RM > 4) acc_4 = __riscv_vle32_v_f32m4(C + 4 * N, vl);
|
||||
if constexpr (RM > 5) acc_5 = __riscv_vle32_v_f32m4(C + 5 * N, vl);
|
||||
if constexpr (RM > 6) acc_6 = __riscv_vle32_v_f32m4(C + 6 * N, vl);
|
||||
|
||||
for (int kk = 0; kk < K; kk++) {
|
||||
vfloat32m4_t b_0 = __riscv_vle32_v_f32m4(B + kk * N, vl);
|
||||
|
||||
acc_0 = __riscv_vfmacc_vf_f32m4(acc_0, A[0 * K + kk], b_0, vl);
|
||||
if constexpr (RM > 1) acc_1 = __riscv_vfmacc_vf_f32m4(acc_1, A[1 * K + kk], b_0, vl);
|
||||
if constexpr (RM > 2) acc_2 = __riscv_vfmacc_vf_f32m4(acc_2, A[2 * K + kk], b_0, vl);
|
||||
if constexpr (RM > 3) acc_3 = __riscv_vfmacc_vf_f32m4(acc_3, A[3 * K + kk], b_0, vl);
|
||||
if constexpr (RM > 4) acc_4 = __riscv_vfmacc_vf_f32m4(acc_4, A[4 * K + kk], b_0, vl);
|
||||
if constexpr (RM > 5) acc_5 = __riscv_vfmacc_vf_f32m4(acc_5, A[5 * K + kk], b_0, vl);
|
||||
if constexpr (RM > 6) acc_6 = __riscv_vfmacc_vf_f32m4(acc_6, A[6 * K + kk], b_0, vl);
|
||||
}
|
||||
|
||||
__riscv_vse32_v_f32m4(C + 0 * N, acc_0, vl);
|
||||
if constexpr (RM > 1) __riscv_vse32_v_f32m4(C + 1 * N, acc_1, vl);
|
||||
if constexpr (RM > 2) __riscv_vse32_v_f32m4(C + 2 * N, acc_2, vl);
|
||||
if constexpr (RM > 3) __riscv_vse32_v_f32m4(C + 3 * N, acc_3, vl);
|
||||
if constexpr (RM > 4) __riscv_vse32_v_f32m4(C + 4 * N, acc_4, vl);
|
||||
if constexpr (RM > 5) __riscv_vse32_v_f32m4(C + 5 * N, acc_5, vl);
|
||||
if constexpr (RM > 6) __riscv_vse32_v_f32m4(C + 6 * N, acc_6, vl);
|
||||
}
|
||||
|
||||
template <int RM>
|
||||
static inline void rvv_simd_gemm_dispatch_tail(
|
||||
float * GGML_RESTRICT C,
|
||||
const float * GGML_RESTRICT A,
|
||||
const float * GGML_RESTRICT B,
|
||||
int K, int N, int KN, int remaining_rows)
|
||||
{
|
||||
if constexpr (RM > 0) {
|
||||
if (remaining_rows == RM) {
|
||||
int64_t jj = 0;
|
||||
for (; jj + KN <= N; jj += KN) {
|
||||
rvv_simd_gemm_ukernel<RM>(C + jj, A, B + jj, K, N, KN);
|
||||
}
|
||||
if (jj < N) {
|
||||
rvv_simd_gemm_ukernel<RM>(C + jj, A, B + jj, K, N, N - jj);
|
||||
}
|
||||
} else {
|
||||
rvv_simd_gemm_dispatch_tail<RM - 1>(C, A, B, K, N, KN, remaining_rows);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr int GEMM_RM = 7;
|
||||
|
||||
// C[M x N] += A[M x K] * B[K x N]
|
||||
static void simd_gemm(
|
||||
float * GGML_RESTRICT C,
|
||||
const float * GGML_RESTRICT A,
|
||||
const float * GGML_RESTRICT B,
|
||||
int M, int K, int N)
|
||||
{
|
||||
const int KN = (int)__riscv_vlenb();
|
||||
int64_t ii = 0;
|
||||
for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
|
||||
int64_t jj = 0;
|
||||
for (; jj + KN <= N; jj += KN) {
|
||||
rvv_simd_gemm_ukernel<GEMM_RM>(C + jj, A, B + jj, K, N, KN);
|
||||
}
|
||||
if (jj < N) {
|
||||
rvv_simd_gemm_ukernel<GEMM_RM>(C + jj, A, B + jj, K, N, N - jj);
|
||||
}
|
||||
A += GEMM_RM * K;
|
||||
C += GEMM_RM * N;
|
||||
}
|
||||
|
||||
int remaining_rows = M - ii;
|
||||
rvv_simd_gemm_dispatch_tail<GEMM_RM - 1>(C, A, B, K, N, KN, remaining_rows);
|
||||
}
|
||||
|
||||
#if defined(__GNUC__) && !defined(__clang__)
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
Reference in New Issue
Block a user