diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index e185449d4..3f01e858d 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -66,6 +66,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); @@ -85,6 +88,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); @@ -118,6 +124,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 64, 160, 128, 64, 2, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); @@ -1217,7 +1226,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float KQ_max_scale[cols_per_thread]; #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { - const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col); + const int jc = (threadIdx.y/np)*cols_per_warp + (cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col)); const float sink = sinks_f[jc % ncols2]; const float KQ_max_new = fmaxf(KQ_max[col], sink); @@ -1825,6 +1834,10 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); +// Mistral Small 4 (DKQ=320, DV=256), GQA=32-only build: +extern DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32); +extern DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32); + // For GLM 4.7 Flash extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index 25b16e83c..d60634cc0 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); } break; + case 320: { + GGML_ASSERT(V->ne[0] == 256); + ggml_cuda_flash_attn_ext_tile_case<320, 256>(ctx, dst); + } break; case 512: { GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 26721cc4c..928b856f9 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) @@ -128,6 +130,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64) @@ -195,6 +199,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 512, 1, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) @@ -264,6 +270,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64) @@ -1144,14 +1152,16 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm } } - if (Q->ne[1] > 8/ncols2) { - constexpr int cols_per_block = 16; - const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; - const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - fattn_kernel_t fattn_kernel = flash_attn_tile; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); - return; + if constexpr (ncols2 <= 16) { + if (Q->ne[1] > 8/ncols2) { + constexpr int cols_per_block = 16; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } } if constexpr (ncols2 <= 8) { @@ -1210,6 +1220,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX; const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; + if constexpr (DKQ == 320) { // Mistral Small 4 + if (use_gqa_opt && gqa_ratio % 32 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32"); + } + if constexpr (DKQ == 576) { if (use_gqa_opt && gqa_ratio % 16 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); @@ -1221,7 +1239,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm } } - if constexpr (DKQ <= 512) { + if constexpr (DKQ <= 512 && DKQ != 320) { if (use_gqa_opt && gqa_ratio % 8 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; @@ -1275,5 +1293,6 @@ extern DECL_FATTN_TILE_CASE( 96, 96); extern DECL_FATTN_TILE_CASE(112, 112); extern DECL_FATTN_TILE_CASE(128, 128); extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(320, 256); extern DECL_FATTN_TILE_CASE(512, 512); extern DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index ea6607cd3..8256591b2 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -143,6 +143,22 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(V->ne[0] == 256); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); break; + case 320: + // For Mistral Small 4, go straight to the ncols1 switch (ncols2=32-only build). + GGML_ASSERT(V->ne[0] == 256); + { + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const bool use_gqa_opt = mask && max_bias == 0.0f; + GGML_ASSERT(use_gqa_opt); + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + GGML_ASSERT(gqa_ratio % 32 == 0); + + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<320, 256, 32>(ctx, dst); + } + break; case 512: GGML_ASSERT(V->ne[0] == 512); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst); @@ -352,6 +368,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; } break; + case 320: + if (V->ne[0] != 256 || !gqa_opt_applies) { + return BEST_FATTN_KERNEL_NONE; + } + if (gqa_ratio % 32 != 0) { + return BEST_FATTN_KERNEL_NONE; + } + break; case 512: if (V->ne[0] != K->ne[0]) { return BEST_FATTN_KERNEL_NONE; diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu index 1f554d81e..8fc3b1797 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32); DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu index 264751d65..abd2b21ce 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32); DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu new file mode 100644 index 000000000..c91f50807 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(320, 256); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 841059c15..5e9a1cb2e 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -3,7 +3,7 @@ from glob import glob import os -HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576] +HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 320, 512, 576] TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"] @@ -62,7 +62,7 @@ for filename in glob("*.cu"): os.remove(filename) for head_size_kq in HEAD_SIZES_KQ: - head_size_v = head_size_kq if head_size_kq != 576 else 512 + head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512) with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f: f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v)) @@ -84,13 +84,16 @@ for ncols in [8, 16, 32, 64]: continue if head_size_kq == 72: continue - if head_size_kq == 512 and ncols2 not in (4, 8): + # Skip compilation of unused ncols2 values for niche head sizes: + if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4 continue - if head_size_kq != 576 and ncols2 in (16, 32): + if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4 continue - if head_size_kq == 576 and ncols2 not in (4, 16, 32): + if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash continue - head_size_v = head_size_kq if head_size_kq != 576 else 512 + if head_size_kq not in (320, 576) and ncols2 in (16, 32): + continue + head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512) f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) for type in TYPES_MMQ: