CUDA: refactor mma data loading for AMD (#22051)

* CUDA: refactor mma data loading for AMD

* fix CDNA MMQ occupancy

* fix CDNA3 mma

* fix RDNA3 compile
Johannes Gäßler
2026-04-19 18:26:59 +02:00
committed by GitHub
parent d5b780a676
commit 4eac5b4509
4 changed files with 112 additions and 395 deletions
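The central change: K/V tile data is now loaded with a single minimum granularity of 16 bytes on both the cp.async path and the synchronous fallback, instead of 4-byte half2 copies in the latter. A minimal sketch of that pattern, assuming 32 active lanes and 16-byte-aligned pointers; load_row_chunks is a hypothetical name, not a ggml function:

    // Illustrative sketch only: each lane copies one 16-byte chunk (4 half2) per
    // iteration. The real loader in the fattn-mma-f16.cuh diff below additionally
    // drops to fewer lanes per row for the tail, per its "decreasing granularity"
    // comments.
    #include <cuda_fp16.h>

    __device__ void load_row_chunks(half2 * dst, const half2 * src, const int n_half2) {
        constexpr int h2_per_chunk = 16/sizeof(half2); // 4 half2 == 16 bytes
        const int chunks_per_row = n_half2 / h2_per_chunk;

        for (int c = threadIdx.x % 32; c < chunks_per_row; c += 32) {
            // one vectorized 16-byte copy per lane and iteration (assumes alignment)
            *(float4 *) (dst + c*h2_per_chunk) = *(const float4 *) (src + c*h2_per_chunk);
        }
    }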
ggml/src/ggml-cuda/common.cuh  +0 -4
@@ -269,10 +269,6 @@ static const char * cu_get_error_str(CUresult err) {
 #define FLASH_ATTN_AVAILABLE
 #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
-#if defined(TURING_MMA_AVAILABLE)
-#define LDMATRIX_TRANS_AVAILABLE
-#endif // defined(TURING_MMA_AVAILABLE)
 static bool fp16_available(const int cc) {
 return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
 (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
ggml/src/ggml-cuda/fattn-mma-f16.cuh  +17 -40
@@ -305,12 +305,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
-// The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
+// The minimum granularity is 16 bytes.
+constexpr int h2_per_chunk = 16/sizeof(half2);
+const int chunks_per_row = D2 / h2_per_chunk;
 if constexpr (use_cp_async) {
+static_assert(warp_size == 32, "bad warp_size");
 static_assert(!oob_check, "OOB check not compatible with cp_async");
 constexpr int preload = 64;
-constexpr int h2_per_chunk = 16/sizeof(half2);
-const int chunks_per_row = D2 / h2_per_chunk;
 const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
@@ -348,11 +349,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 // 6: max 1*16= 16 bytes, 8 half
 ggml_cuda_unroll<6>{}(load);
 } else {
-// TODO use ggml_cuda_memcpy_1
+const half2 zero[4] = {{0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}};
 auto load = [&] __device__ (const int n) {
-const int stride_k = warp_size >> n;
-const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k);
-const int k0_stop = D2 - D2 % (1*stride_k);
+const int stride_k = 32 >> n;
+const int k0_start = stride_k == 32 ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
+const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
 const int stride_i = warp_size / stride_k;
 if (k0_start == k0_stop) {
@@ -371,15 +372,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
 const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
-tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
+ggml_cuda_memcpy_1<16>(tile_KV + i*stride_tile + k*4,
+    !oob_check || i < i_sup ? KV + i*stride_KV + k*h2_per_chunk : zero);
 }
 }
 };
-// 1: max 32* 4=128 bytes, 64 half
-// 2: max 16* 4= 64 bytes, 32 half
-// 3: max 8* 4= 32 bytes, 16 half
-// 4: max 4* 4= 16 bytes, 8 half
-ggml_cuda_unroll<4>{}(load);
+// 1: max 32*16=512 bytes, 256 half
+// 2: max 16*16=256 bytes, 128 half
+// 3: max 8*16=128 bytes, 64 half
+// 4: max 4*16= 64 bytes, 32 half
+// 5: max 2*16= 32 bytes, 16 half
+// 6: max 1*16= 16 bytes, 8 half
+ggml_cuda_unroll<6>{}(load);
 }
 }
@@ -862,11 +866,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 }
-#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
-T_A_VKQ A_identity;
-make_identity_mat(A_identity);
-#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
 // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
 #pragma unroll
 for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
@@ -897,29 +896,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
 T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
-#if defined(LDMATRIX_TRANS_AVAILABLE)
 load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
-#elif defined(AMD_MFMA_AVAILABLE)
-// MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg].
-// Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T.
-// Load with transposed addressing: 4 strided half loads.
-{
-const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2;
-const half * xs0_h = (const half *) xs0;
-const int stride_h = stride_tile_V * 2; // stride in half units
-half * A_h = (half *) A.x;
-#pragma unroll
-for (int l = 0; l < 4; ++l) {
-A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16];
-}
-}
-#else
-// TODO: Try to transpose tile_V when loading gmem to smem.
-// Use mma to transpose T_A_VKQ for RDNA.
-T_A_VKQ A_trans;
-load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
-mma(A, A_trans, A_identity);
-#endif // defined(LDMATRIX_TRANS_AVAILABLE)
 if constexpr (T_B_KQ::I == 8) {
 mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
 } else {
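The removed MFMA branch above documents the register layout A_mat[i=lane%16][k=4*(lane/16)+reg]; after this commit the transposed V load is handled by load_ldmatrix_trans (see the mma.cuh diff below). A minimal sketch of such a strided, transposing load, assuming a 16-wide tile, four k values per lane, and hypothetical names:

    // Sketch only: read column-wise from a row-major shared-memory tile so the
    // registers hold the transposed operand; stride is given in half elements.
    #include <cuda_fp16.h>

    __device__ void load_transposed_a(half A_regs[4], const half * smem_tile, const int stride) {
        const int i  = threadIdx.x % 16;      // output row of A == source column
        const int k0 = 4*(threadIdx.x / 16);  // first of this lane's four k indices
    #pragma unroll
        for (int l = 0; l < 4; ++l) {
            A_regs[l] = smem_tile[(k0 + l)*stride + i];
        }
    }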
ggml/src/ggml-cuda/mma.cuh  +79 -166
@@ -86,17 +86,12 @@ namespace ggml_cuda_mma {
 // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
 // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
-static constexpr bool is_i_major(const data_layout dl) {
-return dl == DATA_LAYOUT_I_MAJOR ||
-dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
-}
 static constexpr __device__ data_layout get_input_data_layout() {
-#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if defined(RDNA3) || defined(VOLTA_MMA_AVAILABLE)
 return DATA_LAYOUT_I_MAJOR_MIRRORED;
 #else
 return DATA_LAYOUT_I_MAJOR;
-#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // defined(RDNA3) || defined(VOLTA_MMA_AVAILABLE)
 }
 template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
@@ -113,7 +108,6 @@ namespace ggml_cuda_mma {
 T x[ne] = {0};
 static constexpr __device__ bool supported() {
-if (I == 64 && J == 2) return true;
 if (I == 16 && J == 8) return true;
 if (I == 32 && J == 4) return true;
 if (I == 16 && J == 16) return true;
@@ -122,7 +116,7 @@ namespace ggml_cuda_mma {
 }
 static __device__ __forceinline__ int get_i(const int l) {
-if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+if constexpr (I == 16 && J == 4) {
 return threadIdx.x % 16;
 } else if constexpr (I == 16 && J == 8) {
 return threadIdx.x % 16;
@@ -139,8 +133,8 @@ namespace ggml_cuda_mma {
 }
 static __device__ __forceinline__ int get_j(const int l) {
-if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
-return (2 * ((threadIdx.x / 16) % 2) + l);
+if constexpr (I == 16 && J == 4) {
+return threadIdx.x / 16;
 } else if constexpr (I == 16 && J == 8) {
 return 2 * (threadIdx.x / 16) + l;
 } else if constexpr (I == 32 && J == 4) {
@@ -154,7 +148,7 @@ namespace ggml_cuda_mma {
 return -1;
 }
 }
-#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#elif defined(VOLTA_MMA_AVAILABLE)
 static constexpr int ne = I * J / 32;
 T x[ne] = {0};
@@ -283,7 +277,7 @@ namespace ggml_cuda_mma {
 static constexpr int J = J_;
 static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if defined(VOLTA_MMA_AVAILABLE)
 static constexpr int ne = I * J / WARP_SIZE;
 half2 x[ne] = {{0.0f, 0.0f}};
@@ -407,7 +401,7 @@ namespace ggml_cuda_mma {
 return -1;
 }
 }
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // defined(VOLTA_MMA_AVAILABLE)
 };
 template <int I_, int J_>
@@ -701,57 +695,12 @@ namespace ggml_cuda_mma {
 }
 #endif // defined(TURING_MMA_AVAILABLE)
-static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) {
-#if defined(RDNA4)
-const int row = t.get_i(0);
-const int left_right = t.get_j(0) / 4;
-const int up_down = row / 8;
-const int idx = row % 8;
-reinterpret_cast<half*>(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f;
-#else
-GGML_UNUSED_VARS(t);
-NO_DEVICE_CODE;
-#endif // defined(RDNA4)
-}
 template <int I, int J, typename T, data_layout dl>
 static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
-#if defined(AMD_MFMA_AVAILABLE)
-if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
-#pragma unroll
-for (int l = 0; l < t.ne; ++l) {
-t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
-}
-} else {
-ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
-}
-#elif defined(AMD_WMMA_AVAILABLE)
-// All wmma layout has contiguous data when i-major.
-if constexpr (is_i_major(dl)) {
-// the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes()
-constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
-if constexpr (sizeof(t.x) > aligned_copy_bytes) {
-static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
-constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
-#pragma unroll
-for (int i = 0; i < aligned_copy_count; ++i) {
-ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
-}
-} else {
-ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
-}
-} else {
-#pragma unroll
-for (int l = 0; l < t.ne; ++l) {
-t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
-}
-}
-#else
 #pragma unroll
 for (int l = 0; l < t.ne; ++l) {
 t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
 }
-#endif // defined(AMD_MFMA_AVAILABLE)
 }
 template <typename T>
@@ -764,26 +713,37 @@ namespace ggml_cuda_mma {
 : "=r"(xi[0]), "=r"(xi[1])
 : "l"(xs));
 #else
-load_generic(t, xs0, stride);
+GGML_UNUSED_VARS(t, xs0, stride);
+NO_DEVICE_CODE;
 #endif // TURING_MMA_AVAILABLE
 }
-template <typename T>
+template <typename T, data_layout dl>
 static __device__ __forceinline__ void load_ldmatrix(
-tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
+tile<16, 4, T, dl> & t, const T * __restrict__ xs0, const int stride) {
 #ifdef TURING_MMA_AVAILABLE
 int * xi = (int *) t.x;
 const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
 asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
 : "=r"(xi[0]), "=r"(xi[1])
 : "l"(xs));
+#elif defined(AMD_WMMA_AVAILABLE)
+#ifdef RDNA3
+static_assert(dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout");
+static_assert(sizeof(t.x) == 16, "bad ne");
+ggml_cuda_memcpy_1<8>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
+ggml_cuda_memcpy_1<8>(t.x + 2, xs0 + t.get_i(0)*stride + 2);
+#else
+static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout");
+static_assert(sizeof(t.x) == 8, "bad ne");
+ggml_cuda_memcpy_1<8>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0));
+#endif // RDNA3
+#elif defined(AMD_MFMA_AVAILABLE)
+static_assert(sizeof(t.x) == 4, "bad ne");
+ggml_cuda_memcpy_1<4>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0));
 #else
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 GGML_UNUSED_VARS(t, xs0, stride);
 NO_DEVICE_CODE;
-#else
-load_generic(t, xs0, stride);
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #endif // TURING_MMA_AVAILABLE
 }
@@ -796,19 +756,26 @@ namespace ggml_cuda_mma {
 asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
 : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
 : "l"(xs));
-#else
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
-#if 1
-// TODO: more generic handling
-static_assert(sizeof(T) == 4, "bad type size");
+#elif defined(VOLTA_MMA_AVAILABLE)
 ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
 ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
+#elif defined(AMD_WMMA_AVAILABLE)
+#ifdef RDNA3
+static_assert(dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout");
+static_assert(sizeof(t.x) == 32, "bad ne");
+ggml_cuda_memcpy_1<16>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
+ggml_cuda_memcpy_1<16>(t.x + 4, xs0 + t.get_i(0)*stride + 4);
 #else
-load_generic(t, xs0, stride);
-#endif // 1
+static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout");
+static_assert(sizeof(t.x) == 16, "bad ne");
+ggml_cuda_memcpy_1<16>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0));
+#endif // RDNA3
+#elif defined(AMD_MFMA_AVAILABLE)
+static_assert(sizeof(t.x) == 8, "bad ne");
+ggml_cuda_memcpy_1<8>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0));
 #else
-load_generic(t, xs0, stride);
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+GGML_UNUSED_VARS(t, xs0, stride);
+NO_DEVICE_CODE;
 #endif // TURING_MMA_AVAILABLE
 }
@@ -827,23 +794,30 @@ namespace ggml_cuda_mma {
 static __device__ __forceinline__ void load_ldmatrix(
 tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if defined(VOLTA_MMA_AVAILABLE)
 ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
 #else
 GGML_UNUSED_VARS(t, xs0, stride);
 NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // defined(VOLTA_MMA_AVAILABLE)
 }
 template <typename T>
 static __device__ __forceinline__ void load_ldmatrix_trans(
 tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
 #ifdef TURING_MMA_AVAILABLE
-int * xi = (int * ) t.x;
+int * xi = (int *) t.x;
 const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
 asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
 : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
 : "l"(xs));
+#elif defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+half * xh = (half *) t.x;
+#pragma unroll
+for (int l = 0; l < t.ne; ++l) {
+xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)];
+xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)];
+}
 #else
 GGML_UNUSED_VARS(t, xs0, stride);
 NO_DEVICE_CODE;
@@ -1218,73 +1192,27 @@ namespace ggml_cuda_mma {
 using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
 int32x4_t * acc = (int32x4_t *) D.x;
 #if defined(CDNA4) || defined(CDNA3)
-acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
-((int64_t *) B.x)[0],
-acc[0],
-0, 0, 0);
+acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0);
 #elif defined(CDNA2) || defined(CDNA1)
-acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
-B.x[0],
-acc[0],
-0, 0, 0);
-acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],
-B.x[1],
-acc[0],
-0, 0, 0);
+acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0);
+acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], B.x[1], acc[0], 0, 0, 0);
 #endif // defined(CDNA4) || defined(CDNA3)
 #elif defined(AMD_WMMA_AVAILABLE)
 using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
 int32x8_t * acc = (int32x8_t *) D.x;
 #if defined(RDNA4)
 using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
 int32x2_t * a_vec = (int32x2_t *) A.x;
 int32x2_t * b_vec = (int32x2_t *) B.x;
-acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
-true,
-a_vec[0],
-true,
-b_vec[0],
-acc[0],
-true
-);
-acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
-true,
-a_vec[1],
-true,
-b_vec[1],
-acc[0],
-true
-);
+acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[0], true, b_vec[0], acc[0], true);
+acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[1], true, b_vec[1], acc[0], true);
 #elif defined(RDNA3)
 using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
 int32x4_t * a_vec = (int32x4_t *) A.x;
 int32x4_t * b_vec = (int32x4_t *) B.x;
-acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
-true,
-a_vec[0],
-true,
-b_vec[0],
-acc[0],
-true
-);
-acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
-true,
-a_vec[1],
-true,
-b_vec[1],
-acc[0],
-true
-);
+acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[0], true, b_vec[0], acc[0], true);
+acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[1], true, b_vec[1], acc[0], true);
 #endif // RDNA4
 #else
 GGML_UNUSED_VARS(D, A, B);
 NO_DEVICE_CODE;
@@ -1297,19 +1225,10 @@ namespace ggml_cuda_mma {
 using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
 int32x16_t * acc = (int32x16_t *) D.x;
 #if defined(CDNA4) || defined(CDNA3)
-acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
-((int64_t *) B.x)[0],
-acc[0],
-0, 0, 0);
+acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0);
 #elif defined(CDNA2) || defined(CDNA1)
-acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
-B.x[0],
-acc[0],
-0, 0, 0);
-acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],
-B.x[1],
-acc[0],
-0, 0, 0);
+acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], B.x[0], acc[0], 0, 0, 0);
+acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], B.x[1], acc[0], 0, 0, 0);
 #endif // defined(CDNA4) || defined(CDNA3)
 #else
@@ -1329,7 +1248,7 @@ namespace ggml_cuda_mma {
 static __device__ __forceinline__ void mma(
 tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if defined(VOLTA_MMA_AVAILABLE)
 const int * Axi = (const int *) A.x;
 const int * Bxi = (const int *) B.x;
 int * Dxi = (int *) D.x;
@@ -1344,12 +1263,12 @@ namespace ggml_cuda_mma {
 #else
 GGML_UNUSED_VARS(D, A, B);
 NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#endif // defined(VOLTA_MMA_AVAILABLE)
 }
 static __device__ __forceinline__ void mma(
 tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if defined(VOLTA_MMA_AVAILABLE)
 const int * Axi = (const int *) A.x;
 const int * Bxi = (const int *) B.x;
 int * Dxi = (int *) D.x;
@@ -1364,41 +1283,35 @@ namespace ggml_cuda_mma {
 #else
 GGML_UNUSED_VARS(D, A, B);
 NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#endif // defined(VOLTA_MMA_AVAILABLE)
 }
 template <data_layout dl_d, data_layout dl_ab>
 static __device__ __forceinline__ void mma(
 tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
-#if defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE)
+using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
+int32x4_t * acc = (int32x4_t *) D.x;
+#if defined(CDNA4) || defined(CDNA3)
+const int64_t xA = uint32_t(A.x[0]);
+const int64_t xB = uint32_t(B.x[0]);
+acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(xA, xB, acc[0], 0, 0, 0);
+#elif defined(CDNA2) || defined(CDNA1)
+acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0);
+#endif // defined(CDNA4) || defined(CDNA3)
+#elif defined(AMD_WMMA_AVAILABLE)
 using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
 int32x8_t * acc = (int32x8_t *) D.x;
 #if defined(RDNA4)
 using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
 int32x2_t * a_vec = (int32x2_t *) A.x;
 int32x2_t * b_vec = (int32x2_t *) B.x;
-acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
-true,
-a_vec[0],
-true,
-b_vec[0],
-acc[0],
-false
-);
+acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[0], true, b_vec[0], acc[0], false);
 #elif defined(RDNA3)
 using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
 int32x4_t * a_vec = (int32x4_t *) A.x;
 int32x4_t * b_vec = (int32x4_t *) B.x;
-acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
-true,
-a_vec[0],
-true,
-b_vec[0],
-acc[0],
-false
-);
+acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[0], true, b_vec[0], acc[0], false);
 #endif // RDNA4
 #else
 GGML_UNUSED(D);
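A recurring pattern in the new load_ldmatrix overloads above: when a tile's per-lane registers are contiguous in memory, the whole block is fetched with one copy sized by sizeof(t.x) via ggml_cuda_memcpy_1. A minimal sketch of that idea, assuming a tile type exposing ne, x[], get_i() and get_j(); memcpy_bytes and load_contiguous are hypothetical stand-ins, not ggml functions:

    // Sketch only: fixed-size, word-wise copy standing in for ggml_cuda_memcpy_1.
    template <int nbytes>
    __device__ void memcpy_bytes(void * dst, const void * src) {
        static_assert(nbytes % 4 == 0, "copy size must be a multiple of 4 bytes");
    #pragma unroll
        for (int i = 0; i < nbytes/4; ++i) {
            ((int *) dst)[i] = ((const int *) src)[i]; // 4-byte word copies
        }
    }

    template <typename TileT, typename T>
    __device__ void load_contiguous(TileT & t, const T * xs0, const int stride) {
        // one copy starting at this lane's (i, j) origin within the tile
        memcpy_bytes<sizeof(t.x)>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0));
    }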
ggml/src/ggml-cuda/mmq.cuh  +16 -185
@@ -104,7 +104,7 @@ struct tile_x_sizes {
 };
 static int get_mmq_x_max_host(const int cc) {
-return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
+return (turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
 GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
 #ifdef GGML_CUDA_FORCE_MMQ
 128 : 64;
@@ -114,9 +114,9 @@ static int get_mmq_x_max_host(const int cc) {
 }
 static constexpr __device__ int get_mmq_x_max_device() {
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 return 128;
-#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#else // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 #if defined(GGML_USE_HIP)
 return 64;
@@ -1054,13 +1054,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
 tile_A A[ntx];
 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
-load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
 }
 #pragma unroll
 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 tile_B B;
-load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
 float dB;
 const int j = j0 + tile_C::get_j(0);
@@ -1295,13 +1295,13 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
 tile_A A[ntx];
 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
-load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
 }
 #pragma unroll
 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 tile_B B;
-load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
 const int j = j0 + tile_C::get_j(0);
 const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
@@ -1435,57 +1435,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-#if defined(AMD_MFMA_AVAILABLE)
-constexpr data_layout input_layout = get_input_data_layout();
-typedef tile<16, 8, int, input_layout> tile_A;
-typedef tile<16, 8, int, input_layout> tile_B;
-typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
-typedef tile<64, 2, int, input_layout> tile_load;
-constexpr int granularity = mmq_get_granularity_device(mmq_x);
-constexpr int rows_per_warp = granularity;
-constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
-y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
-const int * x_qs = (const int *) x;
-const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
-const int * y_qs = (const int *) y + 4;
-const float * y_df = (const float *) y;
-const int i0 = (threadIdx.y / ntx) * rows_per_warp;
-for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
-const int k0 = k00 + k01;
-tile_A A[ntx];
-#pragma unroll
-for (int n = 0; n < ntx; ++n) {
-load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
-}
-#pragma unroll
-for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
-tile_B B[1];
-load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
-const int j = j0 + tile_C::get_j(0);
-const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
-#pragma unroll
-for (int n = 0; n < ntx; ++n) {
-tile_C C;
-mma(C, A[n], B[0]);
-#pragma unroll
-for (int l = 0; l < tile_C::ne; ++l) {
-const int i = i0 + n*tile_C::I + tile_C::get_i(l);
-sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
-}
-}
-}
-}
-#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 constexpr data_layout input_layout = get_input_data_layout();
 typedef tile<16, 4, int, input_layout> tile_A;
 typedef tile<16, 4, int, input_layout> tile_B;
@@ -1510,13 +1460,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 tile_A A[ntx];
 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
-load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
 }
 #pragma unroll
 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 tile_B B;
-load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
 const int j = j0 + tile_C::get_j(0);
 const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
@@ -1742,74 +1692,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-#if defined(AMD_MFMA_AVAILABLE)
-constexpr data_layout input_layout = get_input_data_layout();
-typedef tile<16, 8, int, input_layout> tile_A;
-typedef tile<16, 8, int, input_layout> tile_B;
-typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
-typedef tile<64, 2, int, input_layout> tile_load;
-constexpr int granularity = mmq_get_granularity_device(mmq_x);
-constexpr int rows_per_warp = granularity;
-constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
-y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
-const int * x_qs = (const int *) x;
-const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
-const int * y_qs = (const int *) y + 4;
-const half2 * y_ds = (const half2 *) y;
-const int i0 = (threadIdx.y / ntx) * rows_per_warp;
-for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
-const int k0 = k00 + k01;
-tile_A A[ntx];
-#pragma unroll
-for (int n = 0; n < ntx; ++n) {
-load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
-}
-#pragma unroll
-for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
-tile_B B[1];
-load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
-const int j = j0 + tile_C::get_j(0);
-const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;
-const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
-: (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
-: __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
-tile_C Cm;
-if (k01 >= MMQ_TILE_NE_K * 3/4) {
-tile_A A1;
-A1.x[0] = 0x01010101;
-A1.x[1] = 0x01010101;
-mma(Cm, A1, B[0]);
-}
-#pragma unroll
-for (int n = 0; n < ntx; ++n) {
-tile_C Cd;
-mma(Cd, A[n], B[0]);
-#pragma unroll
-for (int l = 0; l < tile_C::ne; ++l) {
-const int i = i0 + n*tile_C::I + tile_C::get_i(l);
-const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
-float tmp = Cd.x[l]*dm.x;
-if (k01 >= MMQ_TILE_NE_K * 3/4) {
-tmp -= Cm.x[l]*dm.y;
-}
-sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
-sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
-}
-}
-}
-}
-#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 constexpr data_layout input_layout = get_input_data_layout();
 typedef tile<16, 4, int, input_layout> tile_A;
 typedef tile<16, 4, int, input_layout> tile_B;
@@ -1834,13 +1717,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 tile_A A[ntx];
 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
-load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
 }
 #pragma unroll
 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 tile_B B;
-load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
 const int j = j0 + tile_C::get_j(0);
 const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y;
@@ -2573,59 +2456,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-#if defined(AMD_MFMA_AVAILABLE)
-constexpr data_layout input_layout = get_input_data_layout();
-typedef tile<16, 8, int, input_layout> tile_A;
-typedef tile<16, 8, int, input_layout> tile_B;
-typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
-typedef tile<64, 2, int, input_layout> tile_load;
-constexpr int granularity = mmq_get_granularity_device(mmq_x);
-constexpr int rows_per_warp = granularity;
-constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
-y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
-const int * x_qs = (const int *) x;
-const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
-const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
-const int * y_qs = (const int *) y + 4;
-const float * y_df = (const float *) y;
-const int i0 = (threadIdx.y / ntx) * rows_per_warp;
-for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
-const int k0 = k00 + k01;
-tile_A A[ntx];
-#pragma unroll
-for (int n = 0; n < ntx; ++n) {
-load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
-}
-#pragma unroll
-for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
-tile_B B[1];
-load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
-const int j = j0 + tile_C::get_j(0);
-const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
-#pragma unroll
-for (int n = 0; n < ntx; ++n) {
-tile_C C;
-mma(C, A[n], B[0]);
-#pragma unroll
-for (int l = 0; l < tile_C::ne; ++l) {
-const int i = i0 + n*tile_C::I + tile_C::get_i(l);
-const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
-sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
-}
-}
-}
-}
-#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 constexpr data_layout input_layout = get_input_data_layout();
 typedef tile<16, 4, int, input_layout> tile_A;
 typedef tile<16, 4, int, input_layout> tile_B;
@@ -2651,13 +2482,13 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 tile_A A[ntx];
 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
-load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
+load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
 }
 #pragma unroll
 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 tile_B B;
-load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
 const int j = j0 + tile_C::get_j(0);
 const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];