ggml : avoid UB in gemm ukernel (#19642)
This commit is contained in:
@@ -2,9 +2,6 @@
|
|||||||
|
|
||||||
// Computes C[M x N] += A[M x K] * B[K x N]
|
// Computes C[M x N] += A[M x K] * B[K x N]
|
||||||
|
|
||||||
#include "ggml-cpu-impl.h"
|
|
||||||
#include "vec.h"
|
|
||||||
#include "common.h"
|
|
||||||
#include "simd-mappings.h"
|
#include "simd-mappings.h"
|
||||||
|
|
||||||
// TODO: add support for sizeless vector types
|
// TODO: add support for sizeless vector types
|
||||||
@@ -23,44 +20,38 @@
|
|||||||
static constexpr int GEMM_RN = 2;
|
static constexpr int GEMM_RN = 2;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(__GNUC__) && !defined(__clang__)
|
|
||||||
#pragma GCC diagnostic push
|
|
||||||
#pragma GCC diagnostic ignored "-Waggressive-loop-optimizations"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
template <int RM, int RN>
|
template <int RM, int RN>
|
||||||
static inline void simd_gemm_ukernel(
|
static inline void simd_gemm_ukernel(
|
||||||
float * GGML_RESTRICT C,
|
float * GGML_RESTRICT C,
|
||||||
const float * GGML_RESTRICT A,
|
const float * GGML_RESTRICT A,
|
||||||
const float * GGML_RESTRICT B,
|
const float * GGML_RESTRICT B,
|
||||||
int64_t K, int64_t N,
|
int K, int N)
|
||||||
int64_t ii, int64_t jj)
|
|
||||||
{
|
{
|
||||||
static constexpr int KN = GGML_F32_EPR;
|
static constexpr int KN = GGML_F32_EPR;
|
||||||
|
|
||||||
GGML_F32_VEC acc[RM][RN];
|
GGML_F32_VEC acc[RM][RN];
|
||||||
for (int i = 0; i < RM; i++) {
|
for (int64_t i = 0; i < RM; i++) {
|
||||||
for (int r = 0; r < RN; r++) {
|
for (int r = 0; r < RN; r++) {
|
||||||
acc[i][r] = GGML_F32_VEC_LOAD(C + (ii + i) * N + jj + r * KN);
|
acc[i][r] = GGML_F32_VEC_LOAD(C + i * N + r * KN);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int64_t kk = 0; kk < K; kk++) {
|
for (int64_t kk = 0; kk < K; kk++) {
|
||||||
GGML_F32_VEC Bv[RN];
|
GGML_F32_VEC Bv[RN];
|
||||||
for (int r = 0; r < RN; r++) {
|
for (int r = 0; r < RN; r++) {
|
||||||
Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + jj + r * KN);
|
Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + r * KN);
|
||||||
}
|
}
|
||||||
for (int i = 0; i < RM; i++) {
|
for (int64_t i = 0; i < RM; i++) {
|
||||||
GGML_F32_VEC p = GGML_F32_VEC_SET1(A[(ii + i) * K + kk]);
|
GGML_F32_VEC p = GGML_F32_VEC_SET1(A[i * K + kk]);
|
||||||
for (int r = 0; r < RN; r++) {
|
for (int r = 0; r < RN; r++) {
|
||||||
acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p);
|
acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < RM; i++) {
|
for (int64_t i = 0; i < RM; i++) {
|
||||||
for (int r = 0; r < RN; r++) {
|
for (int r = 0; r < RN; r++) {
|
||||||
GGML_F32_VEC_STORE(C + (ii + i) * N + jj + r * KN, acc[i][r]);
|
GGML_F32_VEC_STORE(C + i * N + r * KN, acc[i][r]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -70,7 +61,7 @@ static void simd_gemm(
|
|||||||
float * GGML_RESTRICT C,
|
float * GGML_RESTRICT C,
|
||||||
const float * GGML_RESTRICT A,
|
const float * GGML_RESTRICT A,
|
||||||
const float * GGML_RESTRICT B,
|
const float * GGML_RESTRICT B,
|
||||||
int64_t M, int64_t K, int64_t N)
|
int M, int K, int N)
|
||||||
{
|
{
|
||||||
static constexpr int KN = GGML_F32_EPR;
|
static constexpr int KN = GGML_F32_EPR;
|
||||||
|
|
||||||
@@ -78,38 +69,44 @@ static void simd_gemm(
|
|||||||
for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
|
for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
|
||||||
int64_t jj = 0;
|
int64_t jj = 0;
|
||||||
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
|
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
|
||||||
simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C, A, B, K, N, ii, jj);
|
simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C + jj, A, B + jj, K, N);
|
||||||
}
|
}
|
||||||
for (; jj + KN <= N; jj += KN) {
|
for (; jj + KN <= N; jj += KN) {
|
||||||
simd_gemm_ukernel<GEMM_RM, 1>(C, A, B, K, N, ii, jj);
|
simd_gemm_ukernel<GEMM_RM, 1>(C + jj, A, B + jj, K, N);
|
||||||
}
|
}
|
||||||
for (; jj < N; jj++) {
|
for (; jj < N; jj++) {
|
||||||
for (int i = 0; i < GEMM_RM; i++) {
|
for (int64_t i = 0; i < GEMM_RM; i++) {
|
||||||
float a = C[(ii + i) * N + jj];
|
float a = C[i * N + jj];
|
||||||
for (int64_t kk = 0; kk < K; kk++) {
|
for (int64_t kk = 0; kk < K; kk++) {
|
||||||
a += A[(ii + i) * K + kk] * B[kk * N + jj];
|
a += A[i + kk] * B[kk * N + jj];
|
||||||
}
|
}
|
||||||
C[(ii + i) * N + jj] = a;
|
C[i * N + jj] = a;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
A += GEMM_RM * K;
|
||||||
|
C += GEMM_RM * N;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tail rows: one at a time
|
// Tail rows: one at a time
|
||||||
for (; ii < M; ii++) {
|
for (; ii < M; ii++) {
|
||||||
int64_t jj = 0;
|
int64_t jj = 0;
|
||||||
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
|
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
|
||||||
simd_gemm_ukernel<1, GEMM_RN>(C, A, B, K, N, ii, jj);
|
simd_gemm_ukernel<1, GEMM_RN>(C + jj, A, B + jj, K, N);
|
||||||
}
|
}
|
||||||
for (; jj + KN <= N; jj += KN) {
|
for (; jj + KN <= N; jj += KN) {
|
||||||
simd_gemm_ukernel<1, 1>(C, A, B, K, N, ii, jj);
|
simd_gemm_ukernel<1, 1>(C + jj, A, B + jj, K, N);
|
||||||
}
|
}
|
||||||
for (; jj < N; jj++) {
|
for (; jj < N; jj++) {
|
||||||
float a = C[ii * N + jj];
|
float a = C[jj];
|
||||||
for (int64_t kk = 0; kk < K; kk++) {
|
for (int64_t kk = 0; kk < K; kk++) {
|
||||||
a += A[ii * K + kk] * B[kk * N + jj];
|
a += A[kk] * B[kk * N + jj];
|
||||||
}
|
}
|
||||||
C[ii * N + jj] = a;
|
C[jj] = a;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
A += K;
|
||||||
|
C += N;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,7 +120,7 @@ static void simd_gemm(
|
|||||||
float * GGML_RESTRICT C,
|
float * GGML_RESTRICT C,
|
||||||
const float * GGML_RESTRICT A,
|
const float * GGML_RESTRICT A,
|
||||||
const float * GGML_RESTRICT B,
|
const float * GGML_RESTRICT B,
|
||||||
int64_t M, int64_t K, int64_t N)
|
int M, int K, int N)
|
||||||
{
|
{
|
||||||
for (int64_t i = 0; i < M; i++) {
|
for (int64_t i = 0; i < M; i++) {
|
||||||
for (int64_t j = 0; j < N; j++) {
|
for (int64_t j = 0; j < N; j++) {
|
||||||
|
|||||||
@@ -8301,7 +8301,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||||||
//for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) {
|
//for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) {
|
||||||
for (int kv : { 113, 512, 1024, }) {
|
for (int kv : { 113, 512, 1024, }) {
|
||||||
if (nr2 != 1 && kv != 512) continue;
|
if (nr2 != 1 && kv != 512) continue;
|
||||||
for (int nb : { 1, 3, 32, 35, }) {
|
for (int nb : { 1, 3, 32, 75, }) {
|
||||||
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
|
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
|
||||||
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
|
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
|
||||||
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||||
|
|||||||
Reference in New Issue
Block a user