CUDA: experimental native mxfp4 support for blackwell (#17906)

* CUDA: experimental native mxfp4 support for blackwell

* optimize load_tiles

* optimize quantize_mxfp4

* cleanup

* first pass review: formatting

* use interleaved layout for mma

* mmq: add assert for size

* use __nv_fp4x4_e2m1

* use iter_k as 512, cleanup

* Use 1200 as blackwell instead of 1000

* address review comments

* mmq: fix stride

* quantize.cu: use reference impl of e8m0 scale

* address review comments

* add 120f-virtual + minor fixes

---------

Co-authored-by: Aman Gupta <aman>
This commit is contained in:
Aman Gupta
2025-12-24 22:28:26 +08:00
committed by GitHub
parent 54132f1b1f
commit c8a2417d7b
8 changed files with 426 additions and 18 deletions
+35
View File
@@ -50,6 +50,10 @@
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_ADA_LOVELACE 890
// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see
// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
#define GGML_CUDA_CC_BLACKWELL 1200
#define GGML_CUDA_CC_RUBIN 1300
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
@@ -246,6 +250,10 @@ static const char * cu_get_error_str(CUresult err) {
#define AMPERE_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN
# define BLACKWELL_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@@ -316,6 +324,11 @@ static bool cp_async_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}
static bool blackwell_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_BLACKWELL &&
ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_RUBIN;
}
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
return 64;
@@ -701,6 +714,28 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#endif // CUDART_VERSION >= 12050
}
__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
const uint8_t sign_bit = (x < 0.0f) << 3;
float ax = fabsf(x) * e;
// Positive LUT
static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
int best_i = 0;
float best_err = fabsf(ax - pos_lut[0]);
#pragma unroll
for (int i = 1; i < 8; ++i) {
const float err = fabsf(ax - pos_lut[i]);
if (err < best_err) {
best_err = err;
best_i = i;
}
}
return static_cast<uint8_t>(best_i | sign_bit);
}
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)