diff --git a/ggml/src/ggml-cpu/simd-gemm.h b/ggml/src/ggml-cpu/simd-gemm.h
index 78d663e59..4119d04f8 100644
--- a/ggml/src/ggml-cpu/simd-gemm.h
+++ b/ggml/src/ggml-cpu/simd-gemm.h
@@ -109,6 +109,96 @@ static void simd_gemm(
         C += N;
     }
 }
+#elif defined(GGML_SIMD) && defined(__riscv_v_intrinsic)
+// RM accumulators + 1 B vector = RM + 1 <= 8 => RM <= 7 (eight LMUL=4 register groups total)
+// Microkernel: C[RM x vl] += A[RM x K] * B[K x N]
+template <int RM>
+static inline void rvv_simd_gemm_ukernel(
+    float * GGML_RESTRICT C,
+    const float * GGML_RESTRICT A,
+    const float * GGML_RESTRICT B,
+    int K, int N, size_t vl)
+{
+    static_assert(RM >= 1 && RM <= 7, "RM must be 1..7 for LMUL=4");
+
+    vfloat32m4_t acc_0 = __riscv_vle32_v_f32m4(C + 0 * N, vl);
+    vfloat32m4_t acc_1, acc_2, acc_3, acc_4, acc_5, acc_6;  // only the first RM are used
+    if constexpr (RM > 1) acc_1 = __riscv_vle32_v_f32m4(C + 1 * N, vl);
+    if constexpr (RM > 2) acc_2 = __riscv_vle32_v_f32m4(C + 2 * N, vl);
+    if constexpr (RM > 3) acc_3 = __riscv_vle32_v_f32m4(C + 3 * N, vl);
+    if constexpr (RM > 4) acc_4 = __riscv_vle32_v_f32m4(C + 4 * N, vl);
+    if constexpr (RM > 5) acc_5 = __riscv_vle32_v_f32m4(C + 5 * N, vl);
+    if constexpr (RM > 6) acc_6 = __riscv_vle32_v_f32m4(C + 6 * N, vl);
+
+    for (int kk = 0; kk < K; kk++) {
+        vfloat32m4_t b_0 = __riscv_vle32_v_f32m4(B + kk * N, vl);  // one B row, reused for all RM rows of A
+
+        acc_0 = __riscv_vfmacc_vf_f32m4(acc_0, A[0 * K + kk], b_0, vl);
+        if constexpr (RM > 1) acc_1 = __riscv_vfmacc_vf_f32m4(acc_1, A[1 * K + kk], b_0, vl);
+        if constexpr (RM > 2) acc_2 = __riscv_vfmacc_vf_f32m4(acc_2, A[2 * K + kk], b_0, vl);
+        if constexpr (RM > 3) acc_3 = __riscv_vfmacc_vf_f32m4(acc_3, A[3 * K + kk], b_0, vl);
+        if constexpr (RM > 4) acc_4 = __riscv_vfmacc_vf_f32m4(acc_4, A[4 * K + kk], b_0, vl);
+        if constexpr (RM > 5) acc_5 = __riscv_vfmacc_vf_f32m4(acc_5, A[5 * K + kk], b_0, vl);
+        if constexpr (RM > 6) acc_6 = __riscv_vfmacc_vf_f32m4(acc_6, A[6 * K + kk], b_0, vl);
+    }
+
+    __riscv_vse32_v_f32m4(C + 0 * N, acc_0, vl);
+    if constexpr (RM > 1) __riscv_vse32_v_f32m4(C + 1 * N, acc_1, vl);
+    if constexpr (RM > 2) __riscv_vse32_v_f32m4(C + 2 * N, acc_2, vl);
+    if constexpr (RM > 3) __riscv_vse32_v_f32m4(C + 3 * N, acc_3, vl);
+    if constexpr (RM > 4) __riscv_vse32_v_f32m4(C + 4 * N, acc_4, vl);
+    if constexpr (RM > 5) __riscv_vse32_v_f32m4(C + 5 * N, acc_5, vl);
+    if constexpr (RM > 6) __riscv_vse32_v_f32m4(C + 6 * N, acc_6, vl);
+}
+
+template <int RM>
+static inline void rvv_simd_gemm_dispatch_tail(
+    float * GGML_RESTRICT C,
+    const float * GGML_RESTRICT A,
+    const float * GGML_RESTRICT B,
+    int K, int N, int KN, int remaining_rows)
+{
+    if constexpr (RM > 0) {  // compile-time recursion terminates at RM == 0 (remaining_rows == 0 does nothing)
+        if (remaining_rows == RM) {
+            int64_t jj = 0;
+            for (; jj + KN <= N; jj += KN) {
+                rvv_simd_gemm_ukernel<RM>(C + jj, A, B + jj, K, N, KN);
+            }
+            if (jj < N) {
+                rvv_simd_gemm_ukernel<RM>(C + jj, A, B + jj, K, N, N - jj);
+            }
+        } else {
+            rvv_simd_gemm_dispatch_tail<RM - 1>(C, A, B, K, N, KN, remaining_rows);
+        }
+    }
+}
+
+static constexpr int GEMM_RM = 7;
+
+// C[M x N] += A[M x K] * B[K x N]
+static void simd_gemm(
+    float * GGML_RESTRICT C,
+    const float * GGML_RESTRICT A,
+    const float * GGML_RESTRICT B,
+    int M, int K, int N)
+{
+    const int KN = (int)__riscv_vlenb();  // f32 elems per m4 group: VLEN/32 * 4 = vlenb
+    int64_t ii = 0;
+    for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
+        int64_t jj = 0;
+        for (; jj + KN <= N; jj += KN) {
+            rvv_simd_gemm_ukernel<GEMM_RM>(C + jj, A, B + jj, K, N, KN);
+        }
+        if (jj < N) {
+            rvv_simd_gemm_ukernel<GEMM_RM>(C + jj, A, B + jj, K, N, N - jj);
+        }
+        A += GEMM_RM * K;
+        C += GEMM_RM * N;
+    }
+
+    int remaining_rows = M - ii;  // 0..GEMM_RM-1 rows left over
+    rvv_simd_gemm_dispatch_tail<GEMM_RM>(C, A, B, K, N, KN, remaining_rows);
+}
 
 #if defined(__GNUC__) && !defined(__clang__)
 #pragma GCC diagnostic pop