From 4eac5b45095a4e8a1ff1cce4f6d030e0872fb4ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 19 Apr 2026 18:26:59 +0200 Subject: [PATCH] CUDA: refactor mma data loading for AMD (#22051) * CUDA: refactor mma data loading for AMD * fix CDNA MMQ occupancy * fix CDNA3 mma * fix RDNA3 compile --- ggml/src/ggml-cuda/common.cuh | 4 - ggml/src/ggml-cuda/fattn-mma-f16.cuh | 57 ++----- ggml/src/ggml-cuda/mma.cuh | 245 +++++++++------------------ ggml/src/ggml-cuda/mmq.cuh | 201 ++-------------------- 4 files changed, 112 insertions(+), 395 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index ddf50baf4..3aec1742e 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -269,10 +269,6 @@ static const char * cu_get_error_str(CUresult err) { #define FLASH_ATTN_AVAILABLE #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220) -#if defined(TURING_MMA_AVAILABLE) -#define LDMATRIX_TRANS_AVAILABLE -#endif // defined(TURING_MMA_AVAILABLE) - static bool fp16_available(const int cc) { return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index b613ae61f..e185449d4 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -305,12 +305,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) { constexpr int warp_size = ggml_cuda_get_physical_warp_size(); // K/V data is loaded with decreasing granularity for D for better memory bandwidth. - // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. + // The minimum granularity is 16 bytes. + constexpr int h2_per_chunk = 16/sizeof(half2); + const int chunks_per_row = D2 / h2_per_chunk; if constexpr (use_cp_async) { + static_assert(warp_size == 32, "bad warp_size"); static_assert(!oob_check, "OOB check not compatible with cp_async"); constexpr int preload = 64; - constexpr int h2_per_chunk = 16/sizeof(half2); - const int chunks_per_row = D2 / h2_per_chunk; const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); @@ -348,11 +349,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( // 6: max 1*16= 16 bytes, 8 half ggml_cuda_unroll<6>{}(load); } else { - // TODO use ggml_cuda_memcpy_1 + const half2 zero[4] = {{0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}}; auto load = [&] __device__ (const int n) { - const int stride_k = warp_size >> n; - const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k); - const int k0_stop = D2 - D2 % (1*stride_k); + const int stride_k = 32 >> n; + const int k0_start = stride_k == 32 ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); + const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); const int stride_i = warp_size / stride_k; if (k0_start == k0_stop) { @@ -371,15 +372,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); - tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f); + ggml_cuda_memcpy_1<16>(tile_KV + i*stride_tile + k*4, + !oob_check || i < i_sup ? KV + i*stride_KV + k*h2_per_chunk : zero); } } }; - // 1: max 32* 4=128 bytes, 64 half - // 2: max 16* 4= 64 bytes, 32 half - // 3: max 8* 4= 32 bytes, 16 half - // 4: max 4* 4= 16 bytes, 8 half - ggml_cuda_unroll<4>{}(load); + // 1: max 32*16=512 bytes, 256 half + // 2: max 16*16=256 bytes, 128 half + // 3: max 8*16=128 bytes, 64 half + // 4: max 4*16= 64 bytes, 32 half + // 5: max 2*16= 32 bytes, 16 half + // 6: max 1*16= 16 bytes, 8 half + ggml_cuda_unroll<6>{}(load); } } @@ -862,11 +866,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } -#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) - T_A_VKQ A_identity; - make_identity_mat(A_identity); -#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) - // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V: #pragma unroll for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) { @@ -897,29 +896,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J; T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. -#if defined(LDMATRIX_TRANS_AVAILABLE) load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); -#elif defined(AMD_MFMA_AVAILABLE) - // MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg]. - // Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T. - // Load with transposed addressing: 4 strided half loads. - { - const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2; - const half * xs0_h = (const half *) xs0; - const int stride_h = stride_tile_V * 2; // stride in half units - half * A_h = (half *) A.x; -#pragma unroll - for (int l = 0; l < 4; ++l) { - A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16]; - } - } -#else - // TODO: Try to transpose tile_V when loading gmem to smem. - // Use mma to transpose T_A_VKQ for RDNA. - T_A_VKQ A_trans; - load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); - mma(A, A_trans, A_identity); -#endif // defined(LDMATRIX_TRANS_AVAILABLE) if constexpr (T_B_KQ::I == 8) { mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); } else { diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index c91dd2d9a..b0f674635 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -86,17 +86,12 @@ namespace ggml_cuda_mma { // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR - static constexpr bool is_i_major(const data_layout dl) { - return dl == DATA_LAYOUT_I_MAJOR || - dl == DATA_LAYOUT_I_MAJOR_MIRRORED; - } - static constexpr __device__ data_layout get_input_data_layout() { -#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(RDNA3) || defined(VOLTA_MMA_AVAILABLE) return DATA_LAYOUT_I_MAJOR_MIRRORED; #else return DATA_LAYOUT_I_MAJOR; -#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(RDNA3) || defined(VOLTA_MMA_AVAILABLE) } template @@ -113,7 +108,6 @@ namespace ggml_cuda_mma { T x[ne] = {0}; static constexpr __device__ bool supported() { - if (I == 64 && J == 2) return true; if (I == 16 && J == 8) return true; if (I == 32 && J == 4) return true; if (I == 16 && J == 16) return true; @@ -122,7 +116,7 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_i(const int l) { - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> + if constexpr (I == 16 && J == 4) { return threadIdx.x % 16; } else if constexpr (I == 16 && J == 8) { return threadIdx.x % 16; @@ -139,8 +133,8 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_j(const int l) { - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> - return (2 * ((threadIdx.x / 16) % 2) + l); + if constexpr (I == 16 && J == 4) { + return threadIdx.x / 16; } else if constexpr (I == 16 && J == 8) { return 2 * (threadIdx.x / 16) + l; } else if constexpr (I == 32 && J == 4) { @@ -154,7 +148,7 @@ namespace ggml_cuda_mma { return -1; } } -#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#elif defined(VOLTA_MMA_AVAILABLE) static constexpr int ne = I * J / 32; T x[ne] = {0}; @@ -283,7 +277,7 @@ namespace ggml_cuda_mma { static constexpr int J = J_; static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; @@ -407,7 +401,7 @@ namespace ggml_cuda_mma { return -1; } } -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) }; template @@ -701,57 +695,12 @@ namespace ggml_cuda_mma { } #endif // defined(TURING_MMA_AVAILABLE) - static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) { -#if defined(RDNA4) - const int row = t.get_i(0); - const int left_right = t.get_j(0) / 4; - const int up_down = row / 8; - const int idx = row % 8; - reinterpret_cast(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f; -#else - GGML_UNUSED_VARS(t); - NO_DEVICE_CODE; -#endif // defined(RDNA4) - } - template static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { -#if defined(AMD_MFMA_AVAILABLE) - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> -#pragma unroll - for (int l = 0; l < t.ne; ++l) { - t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; - } - } else { - ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); - } -#elif defined(AMD_WMMA_AVAILABLE) - // All wmma layout has contiguous data when i-major. - if constexpr (is_i_major(dl)) { - // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes() - constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes(); - if constexpr (sizeof(t.x) > aligned_copy_bytes) { - static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size"); - constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes; -#pragma unroll - for (int i = 0; i < aligned_copy_count; ++i) { - ggml_cuda_memcpy_1(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i)); - } - } else { - ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); - } - } else { -#pragma unroll - for (int l = 0; l < t.ne; ++l) { - t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; - } - } -#else #pragma unroll for (int l = 0; l < t.ne; ++l) { t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; } -#endif // defined(AMD_MFMA_AVAILABLE) } template @@ -764,26 +713,37 @@ namespace ggml_cuda_mma { : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); #else - load_generic(t, xs0, stride); + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } - template + template static __device__ __forceinline__ void load_ldmatrix( - tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) { + tile<16, 4, T, dl> & t, const T * __restrict__ xs0, const int stride) { #ifdef TURING_MMA_AVAILABLE int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride; asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 + static_assert(dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + static_assert(sizeof(t.x) == 16, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x + 0, xs0 + t.get_i(0)*stride + 0); + ggml_cuda_memcpy_1<8>(t.x + 2, xs0 + t.get_i(0)*stride + 2); +#else + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); + static_assert(sizeof(t.x) == 8, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) + static_assert(sizeof(t.x) == 4, "bad ne"); + ggml_cuda_memcpy_1<4>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); #else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; -#else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #endif // TURING_MMA_AVAILABLE } @@ -796,19 +756,26 @@ namespace ggml_cuda_mma { asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) : "l"(xs)); -#else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA -#if 1 - // TODO: more generic handling - static_assert(sizeof(T) == 4, "bad type size"); +#elif defined(VOLTA_MMA_AVAILABLE) ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0); ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4); +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 + static_assert(dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + static_assert(sizeof(t.x) == 32, "bad ne"); + ggml_cuda_memcpy_1<16>(t.x + 0, xs0 + t.get_i(0)*stride + 0); + ggml_cuda_memcpy_1<16>(t.x + 4, xs0 + t.get_i(0)*stride + 4); #else - load_generic(t, xs0, stride); -#endif // 1 + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); + static_assert(sizeof(t.x) == 16, "bad ne"); + ggml_cuda_memcpy_1<16>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) + static_assert(sizeof(t.x) == 8, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); #else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -827,23 +794,30 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ void load_ldmatrix( tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); #else GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } template static __device__ __forceinline__ void load_ldmatrix_trans( tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { #ifdef TURING_MMA_AVAILABLE - int * xi = (int * ) t.x; + int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) : "l"(xs)); +#elif defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + half * xh = (half *) t.x; +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)]; + xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)]; + } #else GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; @@ -1218,73 +1192,27 @@ namespace ggml_cuda_mma { using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * acc = (int32x4_t *) D.x; #if defined(CDNA4) || defined(CDNA3) - acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], - ((int64_t *) B.x)[0], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); #elif defined(CDNA2) || defined(CDNA1) - acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], - B.x[0], - acc[0], - 0, 0, 0); - acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], - B.x[1], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], B.x[1], acc[0], 0, 0, 0); #endif // defined(CDNA4) || defined(CDNA3) - #elif defined(AMD_WMMA_AVAILABLE) - using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; int32x8_t * acc = (int32x8_t *) D.x; - #if defined(RDNA4) using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; int32x2_t * a_vec = (int32x2_t *) A.x; int32x2_t * b_vec = (int32x2_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - true - ); - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[1], - true, - b_vec[1], - acc[0], - true - ); - + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[0], true, b_vec[0], acc[0], true); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[1], true, b_vec[1], acc[0], true); #elif defined(RDNA3) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * a_vec = (int32x4_t *) A.x; int32x4_t * b_vec = (int32x4_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - true - ); - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[1], - true, - b_vec[1], - acc[0], - true - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[0], true, b_vec[0], acc[0], true); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[1], true, b_vec[1], acc[0], true); #endif // RDNA4 - #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -1297,19 +1225,10 @@ namespace ggml_cuda_mma { using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int; int32x16_t * acc = (int32x16_t *) D.x; #if defined(CDNA4) || defined(CDNA3) - acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], - ((int64_t *) B.x)[0], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); #elif defined(CDNA2) || defined(CDNA1) - acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], - B.x[0], - acc[0], - 0, 0, 0); - acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], - B.x[1], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], B.x[0], acc[0], 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], B.x[1], acc[0], 0, 0, 0); #endif // defined(CDNA4) || defined(CDNA3) #else @@ -1329,7 +1248,7 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ void mma( tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -1344,12 +1263,12 @@ namespace ggml_cuda_mma { #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } static __device__ __forceinline__ void mma( tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -1364,41 +1283,35 @@ namespace ggml_cuda_mma { #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } template static __device__ __forceinline__ void mma( tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) + using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + int32x4_t * acc = (int32x4_t *) D.x; +#if defined(CDNA4) || defined(CDNA3) + const int64_t xA = uint32_t(A.x[0]); + const int64_t xB = uint32_t(B.x[0]); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(xA, xB, acc[0], 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA1) + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0); +#endif // defined(CDNA4) || defined(CDNA3) +#elif defined(AMD_WMMA_AVAILABLE) using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; int32x8_t * acc = (int32x8_t *) D.x; #if defined(RDNA4) using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; int32x2_t * a_vec = (int32x2_t *) A.x; int32x2_t * b_vec = (int32x2_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - false - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[0], true, b_vec[0], acc[0], false); #elif defined(RDNA3) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * a_vec = (int32x4_t *) A.x; int32x4_t * b_vec = (int32x4_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - false - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[0], true, b_vec[0], acc[0], false); #endif // RDNA4 #else GGML_UNUSED(D); diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 28b662df9..b1a319de9 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -104,7 +104,7 @@ struct tile_x_sizes { }; static int get_mmq_x_max_host(const int cc) { - return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : + return (turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? #ifdef GGML_CUDA_FORCE_MMQ 128 : 64; @@ -114,9 +114,9 @@ static int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) return 128; -#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#else // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) #if defined(GGML_USE_HIP) return 64; @@ -1054,13 +1054,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); float dB; const int j = j0 + tile_C::get_j(0); @@ -1295,13 +1295,13 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); @@ -1435,57 +1435,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -1510,13 +1460,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; @@ -1742,74 +1692,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2; - const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 - : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y - : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); - - tile_C Cm; - if (k01 >= MMQ_TILE_NE_K * 3/4) { - tile_A A1; - A1.x[0] = 0x01010101; - A1.x[1] = 0x01010101; - mma(Cm, A1, B[0]); - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C Cd; - mma(Cd, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); - float tmp = Cd.x[l]*dm.x; - if (k01 >= MMQ_TILE_NE_K * 3/4) { - tmp -= Cm.x[l]*dm.y; - } - sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; - sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -1834,13 +1717,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y; @@ -2573,59 +2456,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; - const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); - sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -2651,13 +2482,13 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];