From d6a5094004d5be6e404220d2e799daa348739a96 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 29 Apr 2026 00:59:00 -0700 Subject: [PATCH] ggml-webgpu: Fix bug in FlashAttention support check (#22492) * Fix flashattention support check for devices that don't support subgroups * set path to none if kv_tile doesn't fit --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 44 ++++++++++++------- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 4 ++ 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 34cbf3694..b7771ac23 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -494,9 +494,10 @@ struct ggml_webgpu_unary_pipeline_key_hash { /** FlashAttention */ enum ggml_webgpu_flash_attn_path : uint32_t { - GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 0u, - GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 1u, - GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 2u, + GGML_WEBGPU_FLASH_ATTN_PATH_NONE = 0u, + GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 1u, + GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 2u, + GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 3u, }; struct ggml_webgpu_flash_attn_pipeline_key { @@ -534,7 +535,7 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { }; struct ggml_webgpu_flash_attn_decisions { - uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX; + uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; uint32_t q_tile = 0; uint32_t kv_tile = 0; uint32_t wg_size = 0; @@ -709,19 +710,29 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && !use_vec; - decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : - use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : - GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX; + decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : + use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : + context.supports_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : + GGML_WEBGPU_FLASH_ATTN_PATH_NONE; + + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { + return decisions; + } const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path); decisions.kv_direct = key.kv_direct; + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); + // invalidate if even the smallest kv_tile doesn't fit in shared memory + if (max_kv_tile == 0) { + decisions.path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; + return decisions; + } if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); - decisions.q_tile = 1u; - decisions.kv_tile = std::max(8u, std::min(32u, min_kv_tile)); - decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; - decisions.wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + decisions.q_tile = 1u; + decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile)); + decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; + decisions.wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); if (decisions.kv_direct) { decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { @@ -734,9 +745,8 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( decisions.q_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m; decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - std::min(64u, ggml_webgpu_flash_attn_max_kv_tile(context, key)) : - std::min(ggml_webgpu_flash_attn_max_kv_tile(context, key), - context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + std::min(64u, max_kv_tile) : + std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE : std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); @@ -755,7 +765,6 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( context.sg_mat_n; } } - return decisions; } @@ -1364,7 +1373,7 @@ class ggml_webgpu_shader_lib { if (key.src_type == GGML_TYPE_Q1_0) { defines.push_back("BLOCK_SIZE=128u"); } else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || - key.src_type == GGML_TYPE_IQ4_NL) { + key.src_type == GGML_TYPE_IQ4_NL) { defines.push_back("BLOCK_SIZE=32u"); } else if (key.src_type >= GGML_TYPE_Q2_K) { defines.push_back("BLOCK_SIZE=256u"); @@ -2325,6 +2334,7 @@ class ggml_webgpu_shader_lib { size_t storage_offset_alignment) { const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment); + GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_NONE); ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path); auto it = flash_attn_pipelines.find(key); if (it != flash_attn_pipelines.end()) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 762d9f8d1..f7fd73ae1 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3918,6 +3918,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; const bool has_mask = op->src[3] != nullptr; + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { + supports_op = false; + break; + } if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],