CUDA: refactor mma data loading for AMD (#22051)
* CUDA: refactor mma data loading for AMD * fix CDNA MMQ occupancy * fix CDNA3 mma * fix RDNA3 compile
This commit is contained in:
@@ -269,10 +269,6 @@ static const char * cu_get_error_str(CUresult err) {
|
|||||||
#define FLASH_ATTN_AVAILABLE
|
#define FLASH_ATTN_AVAILABLE
|
||||||
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
|
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
|
||||||
|
|
||||||
#if defined(TURING_MMA_AVAILABLE)
|
|
||||||
#define LDMATRIX_TRANS_AVAILABLE
|
|
||||||
#endif // defined(TURING_MMA_AVAILABLE)
|
|
||||||
|
|
||||||
static bool fp16_available(const int cc) {
|
static bool fp16_available(const int cc) {
|
||||||
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
|
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
|
||||||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
|
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
|
||||||
|
|||||||
@@ -305,12 +305,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|||||||
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
|
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
|
||||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||||
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
|
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
|
||||||
// The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
|
// The minimum granularity is 16 bytes.
|
||||||
if constexpr (use_cp_async) {
|
|
||||||
static_assert(!oob_check, "OOB check not compatible with cp_async");
|
|
||||||
constexpr int preload = 64;
|
|
||||||
constexpr int h2_per_chunk = 16/sizeof(half2);
|
constexpr int h2_per_chunk = 16/sizeof(half2);
|
||||||
const int chunks_per_row = D2 / h2_per_chunk;
|
const int chunks_per_row = D2 / h2_per_chunk;
|
||||||
|
if constexpr (use_cp_async) {
|
||||||
|
static_assert(warp_size == 32, "bad warp_size");
|
||||||
|
static_assert(!oob_check, "OOB check not compatible with cp_async");
|
||||||
|
constexpr int preload = 64;
|
||||||
|
|
||||||
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
|
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
|
||||||
|
|
||||||
@@ -348,11 +349,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|||||||
// 6: max 1*16= 16 bytes, 8 half
|
// 6: max 1*16= 16 bytes, 8 half
|
||||||
ggml_cuda_unroll<6>{}(load);
|
ggml_cuda_unroll<6>{}(load);
|
||||||
} else {
|
} else {
|
||||||
// TODO use ggml_cuda_memcpy_1
|
const half2 zero[4] = {{0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}};
|
||||||
auto load = [&] __device__ (const int n) {
|
auto load = [&] __device__ (const int n) {
|
||||||
const int stride_k = warp_size >> n;
|
const int stride_k = 32 >> n;
|
||||||
const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k);
|
const int k0_start = stride_k == 32 ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
||||||
const int k0_stop = D2 - D2 % (1*stride_k);
|
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
||||||
const int stride_i = warp_size / stride_k;
|
const int stride_i = warp_size / stride_k;
|
||||||
|
|
||||||
if (k0_start == k0_stop) {
|
if (k0_start == k0_stop) {
|
||||||
@@ -371,15 +372,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|||||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||||
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
||||||
|
|
||||||
tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
|
ggml_cuda_memcpy_1<16>(tile_KV + i*stride_tile + k*4,
|
||||||
|
!oob_check || i < i_sup ? KV + i*stride_KV + k*h2_per_chunk : zero);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
// 1: max 32* 4=128 bytes, 64 half
|
// 1: max 32*16=512 bytes, 256 half
|
||||||
// 2: max 16* 4= 64 bytes, 32 half
|
// 2: max 16*16=256 bytes, 128 half
|
||||||
// 3: max 8* 4= 32 bytes, 16 half
|
// 3: max 8*16=128 bytes, 64 half
|
||||||
// 4: max 4* 4= 16 bytes, 8 half
|
// 4: max 4*16= 64 bytes, 32 half
|
||||||
ggml_cuda_unroll<4>{}(load);
|
// 5: max 2*16= 32 bytes, 16 half
|
||||||
|
// 6: max 1*16= 16 bytes, 8 half
|
||||||
|
ggml_cuda_unroll<6>{}(load);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -862,11 +866,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
|
|
||||||
T_A_VKQ A_identity;
|
|
||||||
make_identity_mat(A_identity);
|
|
||||||
#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
|
|
||||||
|
|
||||||
// Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
|
// Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
|
for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
|
||||||
@@ -897,29 +896,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||||||
const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
|
const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
|
||||||
|
|
||||||
T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
|
T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
|
||||||
#if defined(LDMATRIX_TRANS_AVAILABLE)
|
|
||||||
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
||||||
#elif defined(AMD_MFMA_AVAILABLE)
|
|
||||||
// MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg].
|
|
||||||
// Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T.
|
|
||||||
// Load with transposed addressing: 4 strided half loads.
|
|
||||||
{
|
|
||||||
const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2;
|
|
||||||
const half * xs0_h = (const half *) xs0;
|
|
||||||
const int stride_h = stride_tile_V * 2; // stride in half units
|
|
||||||
half * A_h = (half *) A.x;
|
|
||||||
#pragma unroll
|
|
||||||
for (int l = 0; l < 4; ++l) {
|
|
||||||
A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
// TODO: Try to transpose tile_V when loading gmem to smem.
|
|
||||||
// Use mma to transpose T_A_VKQ for RDNA.
|
|
||||||
T_A_VKQ A_trans;
|
|
||||||
load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
|
||||||
mma(A, A_trans, A_identity);
|
|
||||||
#endif // defined(LDMATRIX_TRANS_AVAILABLE)
|
|
||||||
if constexpr (T_B_KQ::I == 8) {
|
if constexpr (T_B_KQ::I == 8) {
|
||||||
mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
|
mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
+79
-166
@@ -86,17 +86,12 @@ namespace ggml_cuda_mma {
|
|||||||
// - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
|
// - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
|
||||||
// - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
|
// - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
|
||||||
|
|
||||||
static constexpr bool is_i_major(const data_layout dl) {
|
|
||||||
return dl == DATA_LAYOUT_I_MAJOR ||
|
|
||||||
dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr __device__ data_layout get_input_data_layout() {
|
static constexpr __device__ data_layout get_input_data_layout() {
|
||||||
#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
#if defined(RDNA3) || defined(VOLTA_MMA_AVAILABLE)
|
||||||
return DATA_LAYOUT_I_MAJOR_MIRRORED;
|
return DATA_LAYOUT_I_MAJOR_MIRRORED;
|
||||||
#else
|
#else
|
||||||
return DATA_LAYOUT_I_MAJOR;
|
return DATA_LAYOUT_I_MAJOR;
|
||||||
#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
#endif // defined(RDNA3) || defined(VOLTA_MMA_AVAILABLE)
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
|
template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
|
||||||
@@ -113,7 +108,6 @@ namespace ggml_cuda_mma {
|
|||||||
T x[ne] = {0};
|
T x[ne] = {0};
|
||||||
|
|
||||||
static constexpr __device__ bool supported() {
|
static constexpr __device__ bool supported() {
|
||||||
if (I == 64 && J == 2) return true;
|
|
||||||
if (I == 16 && J == 8) return true;
|
if (I == 16 && J == 8) return true;
|
||||||
if (I == 32 && J == 4) return true;
|
if (I == 32 && J == 4) return true;
|
||||||
if (I == 16 && J == 16) return true;
|
if (I == 16 && J == 16) return true;
|
||||||
@@ -122,7 +116,7 @@ namespace ggml_cuda_mma {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ int get_i(const int l) {
|
static __device__ __forceinline__ int get_i(const int l) {
|
||||||
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
|
if constexpr (I == 16 && J == 4) {
|
||||||
return threadIdx.x % 16;
|
return threadIdx.x % 16;
|
||||||
} else if constexpr (I == 16 && J == 8) {
|
} else if constexpr (I == 16 && J == 8) {
|
||||||
return threadIdx.x % 16;
|
return threadIdx.x % 16;
|
||||||
@@ -139,8 +133,8 @@ namespace ggml_cuda_mma {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ int get_j(const int l) {
|
static __device__ __forceinline__ int get_j(const int l) {
|
||||||
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
|
if constexpr (I == 16 && J == 4) {
|
||||||
return (2 * ((threadIdx.x / 16) % 2) + l);
|
return threadIdx.x / 16;
|
||||||
} else if constexpr (I == 16 && J == 8) {
|
} else if constexpr (I == 16 && J == 8) {
|
||||||
return 2 * (threadIdx.x / 16) + l;
|
return 2 * (threadIdx.x / 16) + l;
|
||||||
} else if constexpr (I == 32 && J == 4) {
|
} else if constexpr (I == 32 && J == 4) {
|
||||||
@@ -154,7 +148,7 @@ namespace ggml_cuda_mma {
|
|||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
#elif defined(VOLTA_MMA_AVAILABLE)
|
||||||
static constexpr int ne = I * J / 32;
|
static constexpr int ne = I * J / 32;
|
||||||
T x[ne] = {0};
|
T x[ne] = {0};
|
||||||
|
|
||||||
@@ -283,7 +277,7 @@ namespace ggml_cuda_mma {
|
|||||||
static constexpr int J = J_;
|
static constexpr int J = J_;
|
||||||
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
|
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
|
||||||
|
|
||||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
#if defined(VOLTA_MMA_AVAILABLE)
|
||||||
static constexpr int ne = I * J / WARP_SIZE;
|
static constexpr int ne = I * J / WARP_SIZE;
|
||||||
half2 x[ne] = {{0.0f, 0.0f}};
|
half2 x[ne] = {{0.0f, 0.0f}};
|
||||||
|
|
||||||
@@ -407,7 +401,7 @@ namespace ggml_cuda_mma {
|
|||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
#endif // defined(VOLTA_MMA_AVAILABLE)
|
||||||
};
|
};
|
||||||
|
|
||||||
template <int I_, int J_>
|
template <int I_, int J_>
|
||||||
@@ -701,57 +695,12 @@ namespace ggml_cuda_mma {
|
|||||||
}
|
}
|
||||||
#endif // defined(TURING_MMA_AVAILABLE)
|
#endif // defined(TURING_MMA_AVAILABLE)
|
||||||
|
|
||||||
static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) {
|
|
||||||
#if defined(RDNA4)
|
|
||||||
const int row = t.get_i(0);
|
|
||||||
const int left_right = t.get_j(0) / 4;
|
|
||||||
const int up_down = row / 8;
|
|
||||||
const int idx = row % 8;
|
|
||||||
reinterpret_cast<half*>(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f;
|
|
||||||
#else
|
|
||||||
GGML_UNUSED_VARS(t);
|
|
||||||
NO_DEVICE_CODE;
|
|
||||||
#endif // defined(RDNA4)
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int I, int J, typename T, data_layout dl>
|
template <int I, int J, typename T, data_layout dl>
|
||||||
static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
|
static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
|
||||||
#if defined(AMD_MFMA_AVAILABLE)
|
|
||||||
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int l = 0; l < t.ne; ++l) {
|
for (int l = 0; l < t.ne; ++l) {
|
||||||
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
|
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
|
|
||||||
}
|
|
||||||
#elif defined(AMD_WMMA_AVAILABLE)
|
|
||||||
// All wmma layout has contiguous data when i-major.
|
|
||||||
if constexpr (is_i_major(dl)) {
|
|
||||||
// the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes()
|
|
||||||
constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
|
|
||||||
if constexpr (sizeof(t.x) > aligned_copy_bytes) {
|
|
||||||
static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
|
|
||||||
constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < aligned_copy_count; ++i) {
|
|
||||||
ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
#pragma unroll
|
|
||||||
for (int l = 0; l < t.ne; ++l) {
|
|
||||||
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
#pragma unroll
|
|
||||||
for (int l = 0; l < t.ne; ++l) {
|
|
||||||
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
|
|
||||||
}
|
|
||||||
#endif // defined(AMD_MFMA_AVAILABLE)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@@ -764,26 +713,37 @@ namespace ggml_cuda_mma {
|
|||||||
: "=r"(xi[0]), "=r"(xi[1])
|
: "=r"(xi[0]), "=r"(xi[1])
|
||||||
: "l"(xs));
|
: "l"(xs));
|
||||||
#else
|
#else
|
||||||
load_generic(t, xs0, stride);
|
GGML_UNUSED_VARS(t, xs0, stride);
|
||||||
|
NO_DEVICE_CODE;
|
||||||
#endif // TURING_MMA_AVAILABLE
|
#endif // TURING_MMA_AVAILABLE
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T, data_layout dl>
|
||||||
static __device__ __forceinline__ void load_ldmatrix(
|
static __device__ __forceinline__ void load_ldmatrix(
|
||||||
tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
|
tile<16, 4, T, dl> & t, const T * __restrict__ xs0, const int stride) {
|
||||||
#ifdef TURING_MMA_AVAILABLE
|
#ifdef TURING_MMA_AVAILABLE
|
||||||
int * xi = (int *) t.x;
|
int * xi = (int *) t.x;
|
||||||
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
|
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
|
||||||
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
|
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
|
||||||
: "=r"(xi[0]), "=r"(xi[1])
|
: "=r"(xi[0]), "=r"(xi[1])
|
||||||
: "l"(xs));
|
: "l"(xs));
|
||||||
|
#elif defined(AMD_WMMA_AVAILABLE)
|
||||||
|
#ifdef RDNA3
|
||||||
|
static_assert(dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout");
|
||||||
|
static_assert(sizeof(t.x) == 16, "bad ne");
|
||||||
|
ggml_cuda_memcpy_1<8>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
|
||||||
|
ggml_cuda_memcpy_1<8>(t.x + 2, xs0 + t.get_i(0)*stride + 2);
|
||||||
|
#else
|
||||||
|
static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout");
|
||||||
|
static_assert(sizeof(t.x) == 8, "bad ne");
|
||||||
|
ggml_cuda_memcpy_1<8>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0));
|
||||||
|
#endif // RDNA3
|
||||||
|
#elif defined(AMD_MFMA_AVAILABLE)
|
||||||
|
static_assert(sizeof(t.x) == 4, "bad ne");
|
||||||
|
ggml_cuda_memcpy_1<4>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0));
|
||||||
#else
|
#else
|
||||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
||||||
GGML_UNUSED_VARS(t, xs0, stride);
|
GGML_UNUSED_VARS(t, xs0, stride);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#else
|
|
||||||
load_generic(t, xs0, stride);
|
|
||||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
||||||
#endif // TURING_MMA_AVAILABLE
|
#endif // TURING_MMA_AVAILABLE
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -796,19 +756,26 @@ namespace ggml_cuda_mma {
|
|||||||
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
|
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
|
||||||
: "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
|
: "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
|
||||||
: "l"(xs));
|
: "l"(xs));
|
||||||
#else
|
#elif defined(VOLTA_MMA_AVAILABLE)
|
||||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
||||||
#if 1
|
|
||||||
// TODO: more generic handling
|
|
||||||
static_assert(sizeof(T) == 4, "bad type size");
|
|
||||||
ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
|
ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
|
||||||
ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
|
ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
|
||||||
|
#elif defined(AMD_WMMA_AVAILABLE)
|
||||||
|
#ifdef RDNA3
|
||||||
|
static_assert(dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout");
|
||||||
|
static_assert(sizeof(t.x) == 32, "bad ne");
|
||||||
|
ggml_cuda_memcpy_1<16>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
|
||||||
|
ggml_cuda_memcpy_1<16>(t.x + 4, xs0 + t.get_i(0)*stride + 4);
|
||||||
#else
|
#else
|
||||||
load_generic(t, xs0, stride);
|
static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout");
|
||||||
#endif // 1
|
static_assert(sizeof(t.x) == 16, "bad ne");
|
||||||
|
ggml_cuda_memcpy_1<16>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0));
|
||||||
|
#endif // RDNA3
|
||||||
|
#elif defined(AMD_MFMA_AVAILABLE)
|
||||||
|
static_assert(sizeof(t.x) == 8, "bad ne");
|
||||||
|
ggml_cuda_memcpy_1<8>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0));
|
||||||
#else
|
#else
|
||||||
load_generic(t, xs0, stride);
|
GGML_UNUSED_VARS(t, xs0, stride);
|
||||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
NO_DEVICE_CODE;
|
||||||
#endif // TURING_MMA_AVAILABLE
|
#endif // TURING_MMA_AVAILABLE
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -827,23 +794,30 @@ namespace ggml_cuda_mma {
|
|||||||
|
|
||||||
static __device__ __forceinline__ void load_ldmatrix(
|
static __device__ __forceinline__ void load_ldmatrix(
|
||||||
tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
|
tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
|
||||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
#if defined(VOLTA_MMA_AVAILABLE)
|
||||||
ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
|
ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
|
||||||
#else
|
#else
|
||||||
GGML_UNUSED_VARS(t, xs0, stride);
|
GGML_UNUSED_VARS(t, xs0, stride);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
#endif // defined(VOLTA_MMA_AVAILABLE)
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static __device__ __forceinline__ void load_ldmatrix_trans(
|
static __device__ __forceinline__ void load_ldmatrix_trans(
|
||||||
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||||
#ifdef TURING_MMA_AVAILABLE
|
#ifdef TURING_MMA_AVAILABLE
|
||||||
int * xi = (int * ) t.x;
|
int * xi = (int *) t.x;
|
||||||
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
|
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
|
||||||
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
|
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
|
||||||
: "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
|
: "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
|
||||||
: "l"(xs));
|
: "l"(xs));
|
||||||
|
#elif defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||||
|
half * xh = (half *) t.x;
|
||||||
|
#pragma unroll
|
||||||
|
for (int l = 0; l < t.ne; ++l) {
|
||||||
|
xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)];
|
||||||
|
xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)];
|
||||||
|
}
|
||||||
#else
|
#else
|
||||||
GGML_UNUSED_VARS(t, xs0, stride);
|
GGML_UNUSED_VARS(t, xs0, stride);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
@@ -1218,73 +1192,27 @@ namespace ggml_cuda_mma {
|
|||||||
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
||||||
int32x4_t * acc = (int32x4_t *) D.x;
|
int32x4_t * acc = (int32x4_t *) D.x;
|
||||||
#if defined(CDNA4) || defined(CDNA3)
|
#if defined(CDNA4) || defined(CDNA3)
|
||||||
acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
|
acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0);
|
||||||
((int64_t *) B.x)[0],
|
|
||||||
acc[0],
|
|
||||||
0, 0, 0);
|
|
||||||
#elif defined(CDNA2) || defined(CDNA1)
|
#elif defined(CDNA2) || defined(CDNA1)
|
||||||
acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
|
acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0);
|
||||||
B.x[0],
|
acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], B.x[1], acc[0], 0, 0, 0);
|
||||||
acc[0],
|
|
||||||
0, 0, 0);
|
|
||||||
acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],
|
|
||||||
B.x[1],
|
|
||||||
acc[0],
|
|
||||||
0, 0, 0);
|
|
||||||
#endif // defined(CDNA4) || defined(CDNA3)
|
#endif // defined(CDNA4) || defined(CDNA3)
|
||||||
|
|
||||||
#elif defined(AMD_WMMA_AVAILABLE)
|
#elif defined(AMD_WMMA_AVAILABLE)
|
||||||
|
|
||||||
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
|
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
|
||||||
int32x8_t * acc = (int32x8_t *) D.x;
|
int32x8_t * acc = (int32x8_t *) D.x;
|
||||||
|
|
||||||
#if defined(RDNA4)
|
#if defined(RDNA4)
|
||||||
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
||||||
int32x2_t * a_vec = (int32x2_t *) A.x;
|
int32x2_t * a_vec = (int32x2_t *) A.x;
|
||||||
int32x2_t * b_vec = (int32x2_t *) B.x;
|
int32x2_t * b_vec = (int32x2_t *) B.x;
|
||||||
|
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[0], true, b_vec[0], acc[0], true);
|
||||||
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
|
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[1], true, b_vec[1], acc[0], true);
|
||||||
true,
|
|
||||||
a_vec[0],
|
|
||||||
true,
|
|
||||||
b_vec[0],
|
|
||||||
acc[0],
|
|
||||||
true
|
|
||||||
);
|
|
||||||
|
|
||||||
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
|
|
||||||
true,
|
|
||||||
a_vec[1],
|
|
||||||
true,
|
|
||||||
b_vec[1],
|
|
||||||
acc[0],
|
|
||||||
true
|
|
||||||
);
|
|
||||||
|
|
||||||
#elif defined(RDNA3)
|
#elif defined(RDNA3)
|
||||||
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
||||||
int32x4_t * a_vec = (int32x4_t *) A.x;
|
int32x4_t * a_vec = (int32x4_t *) A.x;
|
||||||
int32x4_t * b_vec = (int32x4_t *) B.x;
|
int32x4_t * b_vec = (int32x4_t *) B.x;
|
||||||
|
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[0], true, b_vec[0], acc[0], true);
|
||||||
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
|
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[1], true, b_vec[1], acc[0], true);
|
||||||
true,
|
|
||||||
a_vec[0],
|
|
||||||
true,
|
|
||||||
b_vec[0],
|
|
||||||
acc[0],
|
|
||||||
true
|
|
||||||
);
|
|
||||||
|
|
||||||
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
|
|
||||||
true,
|
|
||||||
a_vec[1],
|
|
||||||
true,
|
|
||||||
b_vec[1],
|
|
||||||
acc[0],
|
|
||||||
true
|
|
||||||
);
|
|
||||||
#endif // RDNA4
|
#endif // RDNA4
|
||||||
|
|
||||||
#else
|
#else
|
||||||
GGML_UNUSED_VARS(D, A, B);
|
GGML_UNUSED_VARS(D, A, B);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
@@ -1297,19 +1225,10 @@ namespace ggml_cuda_mma {
|
|||||||
using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
|
using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
|
||||||
int32x16_t * acc = (int32x16_t *) D.x;
|
int32x16_t * acc = (int32x16_t *) D.x;
|
||||||
#if defined(CDNA4) || defined(CDNA3)
|
#if defined(CDNA4) || defined(CDNA3)
|
||||||
acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
|
acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0);
|
||||||
((int64_t *) B.x)[0],
|
|
||||||
acc[0],
|
|
||||||
0, 0, 0);
|
|
||||||
#elif defined(CDNA2) || defined(CDNA1)
|
#elif defined(CDNA2) || defined(CDNA1)
|
||||||
acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
|
acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], B.x[0], acc[0], 0, 0, 0);
|
||||||
B.x[0],
|
acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], B.x[1], acc[0], 0, 0, 0);
|
||||||
acc[0],
|
|
||||||
0, 0, 0);
|
|
||||||
acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],
|
|
||||||
B.x[1],
|
|
||||||
acc[0],
|
|
||||||
0, 0, 0);
|
|
||||||
#endif // defined(CDNA4) || defined(CDNA3)
|
#endif // defined(CDNA4) || defined(CDNA3)
|
||||||
|
|
||||||
#else
|
#else
|
||||||
@@ -1329,7 +1248,7 @@ namespace ggml_cuda_mma {
|
|||||||
|
|
||||||
static __device__ __forceinline__ void mma(
|
static __device__ __forceinline__ void mma(
|
||||||
tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
|
tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
|
||||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
#if defined(VOLTA_MMA_AVAILABLE)
|
||||||
const int * Axi = (const int *) A.x;
|
const int * Axi = (const int *) A.x;
|
||||||
const int * Bxi = (const int *) B.x;
|
const int * Bxi = (const int *) B.x;
|
||||||
int * Dxi = (int *) D.x;
|
int * Dxi = (int *) D.x;
|
||||||
@@ -1344,12 +1263,12 @@ namespace ggml_cuda_mma {
|
|||||||
#else
|
#else
|
||||||
GGML_UNUSED_VARS(D, A, B);
|
GGML_UNUSED_VARS(D, A, B);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
#endif // defined(VOLTA_MMA_AVAILABLE)
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ void mma(
|
static __device__ __forceinline__ void mma(
|
||||||
tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
|
tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
|
||||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
#if defined(VOLTA_MMA_AVAILABLE)
|
||||||
const int * Axi = (const int *) A.x;
|
const int * Axi = (const int *) A.x;
|
||||||
const int * Bxi = (const int *) B.x;
|
const int * Bxi = (const int *) B.x;
|
||||||
int * Dxi = (int *) D.x;
|
int * Dxi = (int *) D.x;
|
||||||
@@ -1364,41 +1283,35 @@ namespace ggml_cuda_mma {
|
|||||||
#else
|
#else
|
||||||
GGML_UNUSED_VARS(D, A, B);
|
GGML_UNUSED_VARS(D, A, B);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
#endif // defined(VOLTA_MMA_AVAILABLE)
|
||||||
}
|
}
|
||||||
|
|
||||||
template <data_layout dl_d, data_layout dl_ab>
|
template <data_layout dl_d, data_layout dl_ab>
|
||||||
static __device__ __forceinline__ void mma(
|
static __device__ __forceinline__ void mma(
|
||||||
tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
|
tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
|
||||||
#if defined(AMD_WMMA_AVAILABLE)
|
#if defined(AMD_MFMA_AVAILABLE)
|
||||||
|
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
||||||
|
int32x4_t * acc = (int32x4_t *) D.x;
|
||||||
|
#if defined(CDNA4) || defined(CDNA3)
|
||||||
|
const int64_t xA = uint32_t(A.x[0]);
|
||||||
|
const int64_t xB = uint32_t(B.x[0]);
|
||||||
|
acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(xA, xB, acc[0], 0, 0, 0);
|
||||||
|
#elif defined(CDNA2) || defined(CDNA1)
|
||||||
|
acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0);
|
||||||
|
#endif // defined(CDNA4) || defined(CDNA3)
|
||||||
|
#elif defined(AMD_WMMA_AVAILABLE)
|
||||||
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
|
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
|
||||||
int32x8_t * acc = (int32x8_t *) D.x;
|
int32x8_t * acc = (int32x8_t *) D.x;
|
||||||
#if defined(RDNA4)
|
#if defined(RDNA4)
|
||||||
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
||||||
int32x2_t * a_vec = (int32x2_t *) A.x;
|
int32x2_t * a_vec = (int32x2_t *) A.x;
|
||||||
int32x2_t * b_vec = (int32x2_t *) B.x;
|
int32x2_t * b_vec = (int32x2_t *) B.x;
|
||||||
|
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[0], true, b_vec[0], acc[0], false);
|
||||||
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
|
|
||||||
true,
|
|
||||||
a_vec[0],
|
|
||||||
true,
|
|
||||||
b_vec[0],
|
|
||||||
acc[0],
|
|
||||||
false
|
|
||||||
);
|
|
||||||
#elif defined(RDNA3)
|
#elif defined(RDNA3)
|
||||||
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
||||||
int32x4_t * a_vec = (int32x4_t *) A.x;
|
int32x4_t * a_vec = (int32x4_t *) A.x;
|
||||||
int32x4_t * b_vec = (int32x4_t *) B.x;
|
int32x4_t * b_vec = (int32x4_t *) B.x;
|
||||||
|
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[0], true, b_vec[0], acc[0], false);
|
||||||
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
|
|
||||||
true,
|
|
||||||
a_vec[0],
|
|
||||||
true,
|
|
||||||
b_vec[0],
|
|
||||||
acc[0],
|
|
||||||
false
|
|
||||||
);
|
|
||||||
#endif // RDNA4
|
#endif // RDNA4
|
||||||
#else
|
#else
|
||||||
GGML_UNUSED(D);
|
GGML_UNUSED(D);
|
||||||
|
|||||||
+16
-185
@@ -104,7 +104,7 @@ struct tile_x_sizes {
|
|||||||
};
|
};
|
||||||
|
|
||||||
static int get_mmq_x_max_host(const int cc) {
|
static int get_mmq_x_max_host(const int cc) {
|
||||||
return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
|
return (turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
|
||||||
GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
|
GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
|
||||||
#ifdef GGML_CUDA_FORCE_MMQ
|
#ifdef GGML_CUDA_FORCE_MMQ
|
||||||
128 : 64;
|
128 : 64;
|
||||||
@@ -114,9 +114,9 @@ static int get_mmq_x_max_host(const int cc) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static constexpr __device__ int get_mmq_x_max_device() {
|
static constexpr __device__ int get_mmq_x_max_device() {
|
||||||
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||||
return 128;
|
return 128;
|
||||||
#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
#else // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||||
|
|
||||||
#if defined(GGML_USE_HIP)
|
#if defined(GGML_USE_HIP)
|
||||||
return 64;
|
return 64;
|
||||||
@@ -1054,13 +1054,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
|||||||
tile_A A[ntx];
|
tile_A A[ntx];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int n = 0; n < ntx; ++n) {
|
for (int n = 0; n < ntx; ++n) {
|
||||||
load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
|
load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
||||||
tile_B B;
|
tile_B B;
|
||||||
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
||||||
|
|
||||||
float dB;
|
float dB;
|
||||||
const int j = j0 + tile_C::get_j(0);
|
const int j = j0 + tile_C::get_j(0);
|
||||||
@@ -1295,13 +1295,13 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
|||||||
tile_A A[ntx];
|
tile_A A[ntx];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int n = 0; n < ntx; ++n) {
|
for (int n = 0; n < ntx; ++n) {
|
||||||
load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
|
load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
||||||
tile_B B;
|
tile_B B;
|
||||||
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
||||||
|
|
||||||
const int j = j0 + tile_C::get_j(0);
|
const int j = j0 + tile_C::get_j(0);
|
||||||
const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
||||||
@@ -1435,57 +1435,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
|
|||||||
template <int mmq_x, int mmq_y>
|
template <int mmq_x, int mmq_y>
|
||||||
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
||||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
||||||
#if defined(AMD_MFMA_AVAILABLE)
|
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||||
constexpr data_layout input_layout = get_input_data_layout();
|
|
||||||
typedef tile<16, 8, int, input_layout> tile_A;
|
|
||||||
typedef tile<16, 8, int, input_layout> tile_B;
|
|
||||||
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
||||||
typedef tile<64, 2, int, input_layout> tile_load;
|
|
||||||
|
|
||||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
||||||
constexpr int rows_per_warp = granularity;
|
|
||||||
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
||||||
|
|
||||||
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
||||||
|
|
||||||
const int * x_qs = (const int *) x;
|
|
||||||
const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
|
|
||||||
const int * y_qs = (const int *) y + 4;
|
|
||||||
const float * y_df = (const float *) y;
|
|
||||||
|
|
||||||
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
||||||
|
|
||||||
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
|
|
||||||
const int k0 = k00 + k01;
|
|
||||||
|
|
||||||
tile_A A[ntx];
|
|
||||||
#pragma unroll
|
|
||||||
for (int n = 0; n < ntx; ++n) {
|
|
||||||
load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
|
|
||||||
}
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
||||||
tile_B B[1];
|
|
||||||
load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
||||||
|
|
||||||
const int j = j0 + tile_C::get_j(0);
|
|
||||||
const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int n = 0; n < ntx; ++n) {
|
|
||||||
tile_C C;
|
|
||||||
mma(C, A[n], B[0]);
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int l = 0; l < tile_C::ne; ++l) {
|
|
||||||
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
|
||||||
sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
|
||||||
constexpr data_layout input_layout = get_input_data_layout();
|
constexpr data_layout input_layout = get_input_data_layout();
|
||||||
typedef tile<16, 4, int, input_layout> tile_A;
|
typedef tile<16, 4, int, input_layout> tile_A;
|
||||||
typedef tile<16, 4, int, input_layout> tile_B;
|
typedef tile<16, 4, int, input_layout> tile_B;
|
||||||
@@ -1510,13 +1460,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|||||||
tile_A A[ntx];
|
tile_A A[ntx];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int n = 0; n < ntx; ++n) {
|
for (int n = 0; n < ntx; ++n) {
|
||||||
load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
|
load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
||||||
tile_B B;
|
tile_B B;
|
||||||
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
||||||
|
|
||||||
const int j = j0 + tile_C::get_j(0);
|
const int j = j0 + tile_C::get_j(0);
|
||||||
const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
|
const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
|
||||||
@@ -1742,74 +1692,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
|
|||||||
template <int mmq_x, int mmq_y>
|
template <int mmq_x, int mmq_y>
|
||||||
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
||||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
||||||
#if defined(AMD_MFMA_AVAILABLE)
|
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||||
constexpr data_layout input_layout = get_input_data_layout();
|
|
||||||
typedef tile<16, 8, int, input_layout> tile_A;
|
|
||||||
typedef tile<16, 8, int, input_layout> tile_B;
|
|
||||||
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
||||||
typedef tile<64, 2, int, input_layout> tile_load;
|
|
||||||
|
|
||||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
||||||
constexpr int rows_per_warp = granularity;
|
|
||||||
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
||||||
|
|
||||||
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
||||||
|
|
||||||
const int * x_qs = (const int *) x;
|
|
||||||
const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
|
|
||||||
const int * y_qs = (const int *) y + 4;
|
|
||||||
const half2 * y_ds = (const half2 *) y;
|
|
||||||
|
|
||||||
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
||||||
|
|
||||||
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
|
|
||||||
const int k0 = k00 + k01;
|
|
||||||
|
|
||||||
tile_A A[ntx];
|
|
||||||
#pragma unroll
|
|
||||||
for (int n = 0; n < ntx; ++n) {
|
|
||||||
load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
|
|
||||||
}
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
||||||
tile_B B[1];
|
|
||||||
load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
||||||
|
|
||||||
const int j = j0 + tile_C::get_j(0);
|
|
||||||
const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;
|
|
||||||
const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
|
|
||||||
: (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
|
|
||||||
: __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
|
|
||||||
|
|
||||||
tile_C Cm;
|
|
||||||
if (k01 >= MMQ_TILE_NE_K * 3/4) {
|
|
||||||
tile_A A1;
|
|
||||||
A1.x[0] = 0x01010101;
|
|
||||||
A1.x[1] = 0x01010101;
|
|
||||||
mma(Cm, A1, B[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int n = 0; n < ntx; ++n) {
|
|
||||||
tile_C Cd;
|
|
||||||
mma(Cd, A[n], B[0]);
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int l = 0; l < tile_C::ne; ++l) {
|
|
||||||
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
|
||||||
const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
|
|
||||||
float tmp = Cd.x[l]*dm.x;
|
|
||||||
if (k01 >= MMQ_TILE_NE_K * 3/4) {
|
|
||||||
tmp -= Cm.x[l]*dm.y;
|
|
||||||
}
|
|
||||||
sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
|
|
||||||
sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
|
||||||
constexpr data_layout input_layout = get_input_data_layout();
|
constexpr data_layout input_layout = get_input_data_layout();
|
||||||
typedef tile<16, 4, int, input_layout> tile_A;
|
typedef tile<16, 4, int, input_layout> tile_A;
|
||||||
typedef tile<16, 4, int, input_layout> tile_B;
|
typedef tile<16, 4, int, input_layout> tile_B;
|
||||||
@@ -1834,13 +1717,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
|||||||
tile_A A[ntx];
|
tile_A A[ntx];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int n = 0; n < ntx; ++n) {
|
for (int n = 0; n < ntx; ++n) {
|
||||||
load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
|
load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
||||||
tile_B B;
|
tile_B B;
|
||||||
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
||||||
|
|
||||||
const int j = j0 + tile_C::get_j(0);
|
const int j = j0 + tile_C::get_j(0);
|
||||||
const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y;
|
const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y;
|
||||||
@@ -2573,59 +2456,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
|
|||||||
template <int mmq_x, int mmq_y>
|
template <int mmq_x, int mmq_y>
|
||||||
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
||||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
||||||
#if defined(AMD_MFMA_AVAILABLE)
|
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||||
constexpr data_layout input_layout = get_input_data_layout();
|
|
||||||
typedef tile<16, 8, int, input_layout> tile_A;
|
|
||||||
typedef tile<16, 8, int, input_layout> tile_B;
|
|
||||||
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
||||||
typedef tile<64, 2, int, input_layout> tile_load;
|
|
||||||
|
|
||||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
||||||
constexpr int rows_per_warp = granularity;
|
|
||||||
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
||||||
|
|
||||||
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
||||||
|
|
||||||
const int * x_qs = (const int *) x;
|
|
||||||
const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
|
|
||||||
const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
|
|
||||||
const int * y_qs = (const int *) y + 4;
|
|
||||||
const float * y_df = (const float *) y;
|
|
||||||
|
|
||||||
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
||||||
|
|
||||||
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
|
|
||||||
const int k0 = k00 + k01;
|
|
||||||
|
|
||||||
tile_A A[ntx];
|
|
||||||
#pragma unroll
|
|
||||||
for (int n = 0; n < ntx; ++n) {
|
|
||||||
load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
|
|
||||||
}
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
||||||
tile_B B[1];
|
|
||||||
load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
||||||
|
|
||||||
const int j = j0 + tile_C::get_j(0);
|
|
||||||
const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int n = 0; n < ntx; ++n) {
|
|
||||||
tile_C C;
|
|
||||||
mma(C, A[n], B[0]);
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int l = 0; l < tile_C::ne; ++l) {
|
|
||||||
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
|
||||||
const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
|
|
||||||
sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
|
||||||
constexpr data_layout input_layout = get_input_data_layout();
|
constexpr data_layout input_layout = get_input_data_layout();
|
||||||
typedef tile<16, 4, int, input_layout> tile_A;
|
typedef tile<16, 4, int, input_layout> tile_A;
|
||||||
typedef tile<16, 4, int, input_layout> tile_B;
|
typedef tile<16, 4, int, input_layout> tile_B;
|
||||||
@@ -2651,13 +2482,13 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
|||||||
tile_A A[ntx];
|
tile_A A[ntx];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int n = 0; n < ntx; ++n) {
|
for (int n = 0; n < ntx; ++n) {
|
||||||
load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
|
load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
||||||
tile_B B;
|
tile_B B;
|
||||||
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
||||||
|
|
||||||
const int j = j0 + tile_C::get_j(0);
|
const int j = j0 + tile_C::get_j(0);
|
||||||
const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
|
const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
|
||||||
|
|||||||
Reference in New Issue
Block a user