From 13d36cf89178354d9aa6732e5930d89d64caf718 Mon Sep 17 00:00:00 2001
From: Zheyuan Chen
Date: Fri, 24 Apr 2026 10:39:09 -0700
Subject: [PATCH] ggml-webgpu: enable FLASH_ATTN_EXT on browser without
 subgroup matrix (#22199)

* ggml-webgpu: add tile flash attention fallback

* ggml-webgpu: add new fields and discard usage of m/n/k for the tile version

* ggml-webgpu: modify the vec path to discard the m/n/k parameter

* ggml-webgpu: enable flash attention vec and tile versions for browser

* ggml-webgpu: stage KV for the flash attention tile version

* formatting

* turn on the subgroup uniformity check

* remove Q_TILE as it is always 1 for the vec path

* keep row_max and exp_sum in local registers

* give bindings that share an underlying buffer the same usage flags

* move path selection into the shader library and have the host consume a
  single flash-attn decision object

* turn off skip_validation and address buffer overlap when nwg==1

* formatting

* merge the K/V bindings when they overlap
---
 .../ggml-webgpu/ggml-webgpu-shader-lib.hpp    | 326 +++++++++--------
 ggml/src/ggml-webgpu/ggml-webgpu.cpp          | 199 +++++++----
 .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl  |  47 ++-
 .../wgsl-shaders/flash_attn_tile.wgsl         | 330 ++++++++++++++++++
 .../wgsl-shaders/flash_attn_vec_blk.wgsl      |   2 +-
 .../wgsl-shaders/flash_attn_vec_split.wgsl    | 313 ++++++++---------
 6 files changed, 817 insertions(+), 400 deletions(-)
 create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl

diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
index 449eae808..e492c2123 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
@@ -436,19 +436,27 @@ struct ggml_webgpu_unary_pipeline_key_hash {

 /** FlashAttention */

+enum ggml_webgpu_flash_attn_path : uint32_t {
+    GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 0u,
+    GGML_WEBGPU_FLASH_ATTN_PATH_TILE            = 1u,
+    GGML_WEBGPU_FLASH_ATTN_PATH_VEC             = 2u,
+};
+
 struct ggml_webgpu_flash_attn_pipeline_key {
     ggml_type kv_type;
     uint32_t  head_dim_qk;
     uint32_t  head_dim_v;
     bool      kv_direct;
+    bool      kv_overlap;
     bool      has_mask;
     bool      has_sinks;
     bool      uses_logit_softcap;
+    uint32_t  path;

     bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
         return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
-               kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
-               uses_logit_softcap == other.uses_logit_softcap;
+               kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask &&
+               has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap && path == other.path;
     }
 };
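For orientation, the selection order this enum encodes reduces to a few capability and shape checks: vec for short query batches, tile as the browser fallback when the subgroup-matrix feature is missing, subgroup-matrix otherwise. A minimal standalone sketch (the type and helper names here are illustrative only; the patch derives the real decision in ggml_webgpu_flash_attn_get_decisions() below):

    // Sketch of the path priority, not backend code. "DeviceCaps" and the
    // *_shape_ok flags stand in for the shader-lib context fields the patch uses.
    #include <cstdint>

    enum class FaPath : uint32_t { SubgroupMatrix, Tile, Vec };

    struct DeviceCaps {
        bool subgroups;         // wgpu::FeatureName::Subgroups
        bool subgroup_matrix;   // Chromium experimental subgroup-matrix feature
    };

    FaPath pick_flash_attn_path(const DeviceCaps & caps, bool vec_shape_ok, bool tile_shape_ok) {
        if (caps.subgroups && vec_shape_ok) {
            return FaPath::Vec;                     // small Q batch, aligned f16/quantized KV
        }
        if (caps.subgroups && !caps.subgroup_matrix && tile_shape_ok) {
            return FaPath::Tile;                    // browser fallback added by this patch
        }
        return FaPath::SubgroupMatrix;              // native path, needs the matrix feature
    }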
@@ -459,39 +467,70 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
         ggml_webgpu_hash_combine(seed, key.head_dim_qk);
         ggml_webgpu_hash_combine(seed, key.head_dim_v);
         ggml_webgpu_hash_combine(seed, key.kv_direct);
+        ggml_webgpu_hash_combine(seed, key.kv_overlap);
         ggml_webgpu_hash_combine(seed, key.has_mask);
         ggml_webgpu_hash_combine(seed, key.has_sinks);
         ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
+        ggml_webgpu_hash_combine(seed, key.path);
         return seed;
     }
 };

 struct ggml_webgpu_flash_attn_decisions {
-    uint32_t q_tile  = 0;
-    uint32_t kv_tile = 0;
-    uint32_t wg_size = 0;
+    uint32_t path      = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;
+    uint32_t q_tile    = 0;
+    uint32_t kv_tile   = 0;
+    uint32_t wg_size   = 0;
+    bool     kv_direct = false;
 };

-struct ggml_webgpu_flash_attn_vec_decisions {
-    uint32_t kv_tile = 0;
-    uint32_t wg_size = 0;
-};
+inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u;
+inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE       = 4u;
+
+inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
+    if (key.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC || key.kv_type != GGML_TYPE_F16 ||
+        key.head_dim_qk != key.head_dim_v) {
+        return 1u;
+    }
+
+    switch (key.head_dim_qk) {
+        case 64:
+        case 192:
+        case 576:
+            return 2u;
+        case 96:
+            return 4u;
+        default:
+            return 1u;
+    }
+}

 inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key(
-    const ggml_webgpu_shader_lib_context & context) {
+    const ggml_webgpu_shader_lib_context & context,
+    uint32_t                               path) {
     const bool has_mask  = context.src3 != nullptr;
     const bool has_sinks = context.src4 != nullptr;
-    const bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
-                           (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
+    bool       kv_direct = false;
+    if (path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
+        uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH;
+        if (path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) {
+            kv_direct_align = context.sg_mat_k;
+        }
+        kv_direct = (context.src1->type == GGML_TYPE_F16) &&
+                    (context.src0->ne[0] % std::max(1u, kv_direct_align) == 0) &&
+                    (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
+    }

     ggml_webgpu_flash_attn_pipeline_key key = {};
     key.kv_type            = context.src1->type;
     key.head_dim_qk        = (uint32_t) context.src0->ne[0];
     key.head_dim_v         = (uint32_t) context.src2->ne[0];
     key.kv_direct          = kv_direct;
+    key.kv_overlap         = context.src_overlap;
     key.has_mask           = has_mask;
     key.has_sinks          = has_sinks;
     key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
+    key.path               = path;

     return key;
 }
@@ -554,8 +593,16 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
 inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context &     context,
                                                    const ggml_webgpu_flash_attn_pipeline_key & key) {
-    const size_t limit_bytes = context.wg_mem_limit_bytes;
-    const size_t q_tile      = context.sg_mat_m;
+    const size_t limit_bytes    = context.wg_mem_limit_bytes;
+    uint32_t     q_tile         = context.sg_mat_m;
+    uint32_t     kv_granularity = context.sg_mat_n;
+    if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
+        q_tile         = GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
+        kv_granularity = std::max(1u, context.max_subgroup_size);
+    } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
+        q_tile         = 1u;
+        kv_granularity = 8u;
+    }
     const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
                                 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
     size_t bytes_per_kv = 0;
@@ -568,23 +615,90 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
     bytes_per_kv += q_tile;
     bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
     const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
-    return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
+    return (max_kv_tile / kv_granularity) * kv_granularity;
 }
-inline uint32_t ggml_webgpu_flash_attn_vec_get_kv_tile(const ggml_webgpu_shader_lib_context & context) {
-    const ggml_webgpu_flash_attn_pipeline_key key         = ggml_webgpu_flash_attn_make_pipeline_key(context);
-    const uint32_t                            min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
-    uint32_t kv_tile = std::max(context.sg_mat_n, std::min(32u, min_kv_tile));
-    kv_tile          = (kv_tile / context.sg_mat_n) * context.sg_mat_n;
+inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
+    const ggml_webgpu_shader_lib_context & context,
+    size_t                                 storage_offset_alignment) {
+    ggml_webgpu_flash_attn_decisions decisions = {};
+    const size_t                     alignment = std::max<size_t>(1u, storage_offset_alignment);
+    const auto *                     K         = context.src1;
+    const auto *                     V         = context.src2;
+    GGML_ASSERT(K != nullptr);
+    GGML_ASSERT(V != nullptr);

-    if (key.kv_direct) {
-        kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
-        while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
-            kv_tile -= context.sg_mat_n;
+    const auto flash_attn_tensor_offset = [](const ggml_tensor * tensor) -> size_t {
+        constexpr uintptr_t ptr_base_addr = 0x1000u;
+        const ggml_tensor * base          = tensor->view_src != nullptr ? tensor->view_src : tensor;
+        return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs;
+    };
+
+    const uint32_t k_offset_elems =
+        (uint32_t) ((flash_attn_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type));
+    const uint32_t v_offset_elems =
+        (uint32_t) ((flash_attn_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type));
+    const bool f16_vec4_aligned = (k_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u) &&
+                                  (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u);
+    const bool kv_vec_type_supported =
+        K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
+    const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) &&
+                         (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
+                         kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
+                         (context.src2->type == K->type);
+    const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 &&
+                          V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
+                          (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
+                          (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && !use_vec;
+
+    decisions.path = use_vec  ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
+                     use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
+                                GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;
+
+    const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path);
+    decisions.kv_direct                           = key.kv_direct;
+
+    if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
+        const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
+        decisions.q_tile           = 1u;
+        decisions.kv_tile          = std::max(8u, std::min(32u, min_kv_tile));
+        decisions.kv_tile          = (decisions.kv_tile / 8u) * 8u;
+        decisions.wg_size          = std::max(1u, std::min(32u, context.max_subgroup_size));
+        if (decisions.kv_direct) {
+            decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
+            while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
+                decisions.kv_tile -= 8u;
+            }
+        }
+        return decisions;
+    }
+
+    decisions.q_tile =
+        decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m;
+    decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
+                            std::min(64u, ggml_webgpu_flash_attn_max_kv_tile(context, key)) :
+                            std::min(ggml_webgpu_flash_attn_max_kv_tile(context, key),
+                                     context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
+    decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
+                            GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE :
+                            std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
+
+    if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
+        const uint32_t tile_kv_granularity = std::max(1u, context.max_subgroup_size);
+        decisions.kv_tile =
+            std::max(tile_kv_granularity, (decisions.kv_tile / tile_kv_granularity) * tile_kv_granularity);
+    }
+
+    if (decisions.kv_direct) {
+        GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
+        while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
+            decisions.kv_tile -= decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
+                                     std::max(1u, context.max_subgroup_size) :
+                                     context.sg_mat_n;
         }
     }
-    return kv_tile;
+    return decisions;
 }

 /** Matrix Multiplication **/
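The kv_tile sizing above is a budget calculation: workgroup memory must hold the staged Q tile plus kv_tile staged KV rows, and the result is rounded down to a path-specific granularity. A worked example with assumed numbers (the 16 KiB limit and size-32 subgroup are sample values, and the per-KV cost is approximated because part of that formula sits outside this hunk):

    // Standalone illustration of the kv_tile budget; constants are assumed,
    // not queried device limits.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t head_dim    = 128, q_tile = 4;   // tile path: Q_TILE = 4
        const uint32_t limit_bytes = 16384;             // assumed wg_mem_limit_bytes
        const uint32_t granularity = 32;                // assumed max_subgroup_size

        // base_q_bytes: f16 Q/O staging plus two f32 softmax stats per Q row.
        const uint32_t base_q = (head_dim + head_dim) * q_tile * 2 + 2 * q_tile * 4;   // 2080
        // bytes_per_kv: one staged KV row plus one f16 score per Q row (approximation).
        const uint32_t per_kv = (head_dim + q_tile) * 2;                               // 264

        uint32_t max_kv = (limit_bytes - base_q) / per_kv;     // (16384-2080)/264 = 54
        max_kv = (max_kv / granularity) * granularity;         // round down -> 32
        std::printf("kv_tile = %u\n", max_kv);
        return 0;
    }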
"flash_attn_tile" : + "flash_attn"; switch (key.kv_type) { case GGML_TYPE_F32: @@ -2073,7 +2190,12 @@ class ggml_webgpu_shader_lib { if (key.has_mask) { defines.push_back("MASK"); - variant += "_mask"; + if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + defines.push_back("BLK"); + variant += "_mask_blk"; + } else { + variant += "_mask"; + } } if (key.has_sinks) { defines.push_back("SINKS"); @@ -2087,6 +2209,10 @@ class ggml_webgpu_shader_lib { defines.push_back("KV_DIRECT"); variant += "_kvdirect"; } + if (key.kv_overlap) { + defines.push_back("KV_OVERLAP"); + variant += "_kv_overlap"; + } defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); @@ -2094,129 +2220,37 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); variant += std::string("_hsv") + std::to_string(key.head_dim_v); - defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); - defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); - defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); - - auto decisions = std::make_shared(); - decisions->q_tile = context.sg_mat_m; - - const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); - uint32_t kv_tile = std::min(min_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - - if (key.kv_direct) { - kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= context.sg_mat_n; - } + const char * shader_src = wgsl_flash_attn; + if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + defines.push_back("KV_GRANULARITY=8"); + defines.push_back(std::string("VEC_NE=") + std::to_string(ggml_webgpu_flash_attn_pick_vec_ne(key)) + "u"); + shader_src = wgsl_flash_attn_vec_split; + } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + shader_src = wgsl_flash_attn_tile; + defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size)); + defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v))); + variant += "_tile"; + } else { + defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); + defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); + defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); } - decisions->kv_tile = kv_tile; - decisions->wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - - defines.push_back(std::string("Q_TILE=") + std::to_string(decisions->q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile)); - defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size)); + auto pipeline_decisions = std::make_shared(decisions); + defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size)); webgpu_pipeline pipeline = - ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn, defines), variant); - pipeline.context = decisions; + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); + pipeline.context = pipeline_decisions; flash_attn_pipelines[key] = pipeline; return flash_attn_pipelines[key]; } - 
-    webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
-        auto                                      it  = flash_attn_vec_pipelines.find(key);
-        if (it != flash_attn_vec_pipelines.end()) {
-            return it->second;
-        }
-
-        std::vector<std::string> defines;
-        std::string              variant = "flash_attn_vec";
-
-        switch (key.kv_type) {
-            case GGML_TYPE_F32:
-                defines.push_back("KV_F32");
-                break;
-            case GGML_TYPE_F16:
-                defines.push_back("KV_F16");
-                break;
-            case GGML_TYPE_Q4_0:
-                defines.push_back("KV_Q4_0");
-                break;
-            case GGML_TYPE_Q8_0:
-                defines.push_back("KV_Q8_0");
-                break;
-            default:
-                GGML_ABORT("Unsupported KV type for flash attention shader");
-        }
-        variant += std::string("_") + ggml_type_name(key.kv_type);
-
-        if (key.has_mask) {
-            defines.push_back("MASK");
-            defines.push_back("BLK");
-            variant += "_mask_blk";
-        }
-        if (key.has_sinks) {
-            defines.push_back("SINKS");
-            variant += "_sinks";
-        }
-        if (key.uses_logit_softcap) {
-            defines.push_back("LOGIT_SOFTCAP");
-            variant += "_lgsc";
-        }
-        if (key.kv_direct) {
-            defines.push_back("KV_DIRECT");
-            variant += "_kvdirect";
-        }
-
-        defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
-        variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
-
-        defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
-        variant += std::string("_hsv") + std::to_string(key.head_dim_v);
-
-        defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
-        defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
-        defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
-        defines.push_back("Q_TILE=1");
-
-        auto decisions     = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>();
-        decisions->kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context);
-        decisions->wg_size = std::max(1u, std::min(32u, context.max_subgroup_size));
-        uint32_t vec_ne    = 1u;
-
-        // Keep conservative defaults unless this is the f16 vec-split shape family.
-        if (key.kv_type == GGML_TYPE_F16 && key.head_dim_qk == key.head_dim_v) {
-            switch (key.head_dim_qk) {
-                case 64:
-                case 192:
-                case 576:
-                    vec_ne = 2u;
-                    break;
-                case 96:
-                    vec_ne = 4u;
-                    break;
-                default:
-                    break;
-            }
-        }
-
-        defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile));
-        defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size));
-        defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
-
-        webgpu_pipeline pipeline =
-            ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant);
-        pipeline.context = decisions;
-        flash_attn_vec_pipelines[key] = pipeline;
-        return flash_attn_vec_pipelines[key];
-    }
-
-    webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context) {
+    webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) {
         ggml_webgpu_flash_attn_blk_pipeline_key key = {};
-        key.kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context);
+        key.kv_tile = kv_tile;
         auto it = flash_attn_blk_pipelines.find(key);
         if (it != flash_attn_blk_pipelines.end()) {
             return it->second;
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index acc486cfd..7ed6fdd16 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -389,23 +389,6 @@ static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_t
     return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
 }

-static bool ggml_webgpu_flash_attn_use_vec(webgpu_global_context & global_ctx,
-                                           const ggml_tensor *     Q,
-                                           const ggml_tensor *     K,
-                                           const ggml_tensor *     V) {
-    const size_t   alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
-    const uint32_t k_offset_elems =
-        (uint32_t) ((ggml_webgpu_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type));
-    const uint32_t v_offset_elems =
-        (uint32_t) ((ggml_webgpu_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type));
-    const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u);
-    const bool kv_vec_type_supported =
-        K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
-
-    return (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
-           (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type);
-}
-
 static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
     size_t offset = ggml_webgpu_tensor_offset(t);
     return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
@@ -1567,7 +1550,6 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
     return ggml_backend_webgpu_build_multi(ctx, dispatches);
 }

-#ifndef __EMSCRIPTEN__
 static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
                                                 ggml_tensor *    Q,
                                                 ggml_tensor *    K,
@@ -1585,13 +1567,29 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
     float m0 = powf(2.0f, -(max_bias) / n_head_log2);
     float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

-    const int has_mask  = (mask != nullptr);
-    const int has_sinks = (sinks != nullptr);
+    const int  has_mask   = (mask != nullptr);
+    const int  has_sinks  = (sinks != nullptr);
+    const bool kv_overlap = ggml_webgpu_tensor_overlap(K, V) && K->type == V->type;
ggml_type_size(K->type)); + uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); + size_t kv_bind_offset = 0; + size_t kv_bind_size = 0; + if (kv_overlap) { + const size_t k_bind_offset = ggml_webgpu_tensor_align_offset(ctx, K); + const size_t v_bind_offset = ggml_webgpu_tensor_align_offset(ctx, V); + const size_t k_bind_end = k_bind_offset + ggml_webgpu_tensor_binding_size(ctx, K); + const size_t v_bind_end = v_bind_offset + ggml_webgpu_tensor_binding_size(ctx, V); + kv_bind_offset = std::min(k_bind_offset, v_bind_offset); + kv_bind_size = std::max(k_bind_end, v_bind_end) - kv_bind_offset; + offset_k = (uint32_t) ((ggml_webgpu_tensor_offset(K) - kv_bind_offset) / ggml_type_size(K->type)); + offset_v = (uint32_t) ((ggml_webgpu_tensor_offset(V) - kv_bind_offset) / ggml_type_size(V->type)); + } std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)), + offset_k, + offset_v, has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0, has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -1619,10 +1617,15 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, }; std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V), }; - uint32_t binding_index = 3; + if (kv_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V)); + } + uint32_t binding_index = kv_overlap ? 2u : 3u; if (has_mask) { entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask)); } @@ -1638,25 +1641,25 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, shader_lib_ctx.src3 = mask; shader_lib_ctx.src4 = sinks; shader_lib_ctx.dst = dst; + shader_lib_ctx.src_overlap = kv_overlap; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; - const bool use_vec = ggml_webgpu_flash_attn_use_vec(ctx->global_ctx, Q, K, V); - webgpu_pipeline pipeline = use_vec ? 
ctx->shader_lib->get_flash_attn_vec_pipeline(shader_lib_ctx) : - ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline( + shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + auto * decisions = static_cast(pipeline.context.get()); - if (!use_vec) { - auto * decisions = static_cast(pipeline.context.get()); + if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } - auto * decisions = static_cast(pipeline.context.get()); - wgpu::Buffer blk_buf = {}; uint64_t blk_size_bytes = 0; uint32_t blk_nblk0 = 0; @@ -1695,10 +1698,12 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, tmp_bind_size = tmp_size_bytes; scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); } else { - // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation. + // nwg==1 writes final dst directly in vec-split; bind tmp to a tiny non-overlapping scratch region. + tmp_size_bytes = WEBGPU_STORAGE_BUF_BINDING_MULT; tmp_buf = ggml_webgpu_tensor_buf(dst); - tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst); - tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst); + tmp_bind_offset = scratch_offset; + tmp_bind_size = tmp_size_bytes; + scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); } webgpu_pipeline blk_pipeline; @@ -1713,7 +1718,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx; - blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx); + blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx, decisions->kv_tile); blk_params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask @@ -1745,12 +1750,19 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, std::vector split_entries = { ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q), ggml_webgpu_tensor_binding_size(ctx, Q)), - ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), ggml_webgpu_tensor_align_offset(ctx, K), - ggml_webgpu_tensor_binding_size(ctx, K)), - ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V), ggml_webgpu_tensor_align_offset(ctx, V), - ggml_webgpu_tensor_binding_size(ctx, V)), }; - uint32_t split_binding_index = 3; + if (kv_overlap) { + split_entries.push_back( + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size)); + } else { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), + ggml_webgpu_tensor_align_offset(ctx, K), + ggml_webgpu_tensor_binding_size(ctx, K))); + split_entries.push_back(ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V), + ggml_webgpu_tensor_align_offset(ctx, V), + ggml_webgpu_tensor_binding_size(ctx, V))); + } + uint32_t split_binding_index = kv_overlap ? 
     if (has_mask) {
         split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask),
                                                                   ggml_webgpu_tensor_align_offset(ctx, mask),
                                                                   ggml_webgpu_tensor_binding_size(ctx, mask)));
@@ -1820,7 +1832,6 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,

     return ggml_backend_webgpu_build_multi(ctx, dispatches);
 }
-#endif  // __EMSCRIPTEN__

 static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     bool is_unary = dst->op == GGML_OP_UNARY;
@@ -2710,11 +2721,7 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
         case GGML_OP_MUL_MAT_ID:
             return ggml_webgpu_mul_mat_id(ctx, src0, src1, src2, node);
         case GGML_OP_FLASH_ATTN_EXT:
-#ifndef __EMSCRIPTEN__
             return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
-#else
-            return std::nullopt;
-#endif
         case GGML_OP_ADD:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
@@ -3257,13 +3264,19 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
                 ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
             shader_lib_ctx.wg_mem_limit_bytes =
                 ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
+            shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups;
+            shader_lib_ctx.supports_subgroup_matrix =
+                ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
             shader_lib_ctx.sg_mat_m          = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
             shader_lib_ctx.sg_mat_n          = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
             shader_lib_ctx.sg_mat_k          = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
            shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;

-            if (ggml_webgpu_flash_attn_use_vec(ctx->webgpu_global_ctx, Q, K, V)) {
-                const uint32_t kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(shader_lib_ctx);
+            const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
+                shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
+
+            if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
+                const uint32_t kv_tile     = decisions.kv_tile;
                 const uint32_t vec_nwg_cap = std::max(
                     1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size));
@@ -3283,6 +3296,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
                     const size_t tmp_size_bytes = ROUNDUP_POW2(
                         (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
                     res += tmp_size_bytes + align;
+                } else {
+                    res += WEBGPU_STORAGE_BUF_BINDING_MULT + align;
                 }
                 if (mask != nullptr) {
                     const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
@@ -3431,12 +3446,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
     ctx->webgpu_global_ctx->capabilities.supports_subgroups =
         ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);

+    bool valid_subgroup_matrix_config = false;
 #ifndef __EMSCRIPTEN__
     // Accept f16 subgroup matrix configurations (square or non-square).
     // NVIDIA GPUs typically report square configs (e.g. 16x16x16),
     // while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16).
     // The shaders are already parameterized to handle any M/N/K dimensions.
-    bool valid_subgroup_matrix_config = false;
     if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
         for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
             const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
@@ -3450,8 +3465,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
             }
         }
     }
-    ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
 #endif
+    ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;

     // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
     // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
@@ -3499,12 +3514,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
     // Enable Dawn-specific toggles to increase native performance
     // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
     // only for native performance?
-    const char * const deviceEnabledToggles[]  = { "skip_validation", "disable_robustness", "disable_workgroup_init",
-                                                   "disable_polyfills_on_integer_div_and_mod" };
-    const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
+    const char * const deviceEnabledToggles[]  = { "disable_robustness", "disable_workgroup_init",
+                                                   "disable_polyfills_on_integer_div_and_mod" };
+    const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
     wgpu::DawnTogglesDescriptor deviceTogglesDesc;
     deviceTogglesDesc.enabledToggles      = deviceEnabledToggles;
-    deviceTogglesDesc.enabledToggleCount  = 4;
+    deviceTogglesDesc.enabledToggleCount  = 3;
     deviceTogglesDesc.disabledToggles     = deviceDisabledToggles;
     deviceTogglesDesc.disabledToggleCount = 1;
@@ -3782,33 +3797,63 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
             break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
-#ifndef __EMSCRIPTEN__
-                if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
-                    break;
-                }
-                // Head dimensions must be divisible by subgroup matrix dimensions
-                if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 ||
-                    src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) {
-                    break;
-                }
-                // Head dimensions must fit in workgroup memory with minimum tile sizes
-                size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
-                const bool has_mask  = op->src[3] != nullptr;
-                const bool kv_direct = src1->type == GGML_TYPE_F16 &&
-                                       (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
-                                       (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
-                const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
-                    ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
-                    (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
-                if (min_bytes > limit_bytes) {
-                    break;
-                }
-
                 supports_op = src0->type == GGML_TYPE_F32 &&
                               (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
                                src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
                               src2->type == src1->type && op->type == GGML_TYPE_F32;
-#endif
+                if (!supports_op) {
+                    break;
+                }
+                ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+                shader_lib_ctx.src0 = src0;
+                shader_lib_ctx.src1 = src1;
+                shader_lib_ctx.src2 = src2;
+                shader_lib_ctx.src3 = op->src[3];
+                shader_lib_ctx.src4 = op->src[4];
+                shader_lib_ctx.dst  = const_cast<ggml_tensor *>(op);
+                shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups;
+                shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
+                shader_lib_ctx.wg_mem_limit_bytes =
+                    ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
+                shader_lib_ctx.sg_mat_m          = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
+                shader_lib_ctx.sg_mat_n          = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
+                shader_lib_ctx.sg_mat_k          = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
+                shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
+
+                const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
+                    shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
+                const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
+                const bool   has_mask    = op->src[3] != nullptr;
+                if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
+                    const size_t min_bytes =
+                        ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
+                                                            (uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
+                    if (min_bytes > limit_bytes) {
+                        supports_op = false;
+                    }
+                    break;
+                }
+
+                if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
+                    const size_t min_bytes =
+                        ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
+                                                            (uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
+                    if (min_bytes > limit_bytes) {
+                        supports_op = false;
+                    }
+                    break;
+                }
+
+                if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
+                    supports_op = false;
+                    break;
+                }
+                const size_t min_bytes =
+                    ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
+                                                        (uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
+                if (min_bytes > limit_bytes) {
+                    supports_op = false;
+                }
                 break;
             }
         case GGML_OP_RMS_NORM:
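Both shaders below derive a per-head ALiBi slope from m0, m1 and n_head_log2 via a nested select(). The same formula in scalar C++, as a reference rendering of the shader expression (not backend code):

    // Scalar version of the per-head ALiBi slope the WGSL shaders compute.
    #include <cmath>

    float alibi_slope(float head, float max_bias, float n_head_log2, float m0, float m1) {
        if (max_bias <= 0.0f) {
            return 1.0f;   // bias disabled
        }
        return head < n_head_log2 ? std::pow(m0, head + 1.0f)
                                  : std::pow(m1, 2.0f * (head - n_head_log2) + 1.0f);
    }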
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl
index aa2d2e54d..6d5d69fb8 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl
@@ -138,25 +138,54 @@ struct Params {
 };

 @group(0) @binding(0) var<storage, read_write> Q: array<f32>;
+#ifdef KV_OVERLAP
+@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
+#define V K
+#else
 @group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
 @group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
+#endif
 #if defined(MASK) && defined(SINKS)
-@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
-@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
-#define DST_BINDING 5
-#define PARAMS_BINDING 6
-#elif defined(MASK)
-@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
-#define DST_BINDING 4
-#define PARAMS_BINDING 5
-#elif defined(SINKS)
+#ifdef KV_OVERLAP
+@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
 @group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
 #define DST_BINDING 4
 #define PARAMS_BINDING 5
 #else
+@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
+@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
+#define DST_BINDING 5
+#define PARAMS_BINDING 6
+#endif
+#elif defined(MASK)
+#ifdef KV_OVERLAP
+@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
+#define DST_BINDING 3
+#define PARAMS_BINDING 4
+#else
+@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
+#define DST_BINDING 4
+#define PARAMS_BINDING 5
+#endif
+#elif defined(SINKS)
+#ifdef KV_OVERLAP
+@group(0) @binding(2) var<storage, read_write> sinks: array<f32>;
+#define DST_BINDING 3
+#define PARAMS_BINDING 4
+#else
+@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
+#define DST_BINDING 4
+#define PARAMS_BINDING 5
+#endif
+#else
+#ifdef KV_OVERLAP
+#define DST_BINDING 2
+#define PARAMS_BINDING 3
+#else
 #define DST_BINDING 3
 #define PARAMS_BINDING 4
 #endif
+#endif

 @group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl
new file mode 100644
index 000000000..37ea23b80
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl
@@ -0,0 +1,330 @@
+enable f16;
+enable subgroups;
+
+#define HEAD_DIM_QK 64
+#define HEAD_DIM_V 64
+#define KV_STAGE_STRIDE 64
+#define Q_TILE 4
+#define KV_TILE 64
+#define WG_SIZE 128
+
+struct Params {
+    offset_q: u32,
+    offset_k: u32,
+    offset_v: u32,
+    offset_mask: u32,
+    offset_sinks: u32,
+    offset_dst: u32,
+
+    n_heads: u32,
+    seq_len_q: u32,
+    seq_len_kv: u32,
+
+    stride_q1: u32,
+    stride_q2: u32,
+    stride_q3: u32,
+    stride_k1: u32,
+    stride_k2: u32,
+    stride_k3: u32,
+    stride_v1: u32,
+    stride_v2: u32,
+    stride_v3: u32,
+    stride_mask3: u32,
+
+    q_per_kv: u32,
+
+    scale: f32,
+    max_bias: f32,
+    logit_softcap: f32,
+    n_head_log2: f32,
+    m0: f32,
+    m1: f32,
+};
+
+@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
+#ifdef KV_OVERLAP
+@group(0) @binding(1) var<storage, read_write> K: array<vec4<f16>>;
+#define V K
+#else
+@group(0) @binding(1) var<storage, read_write> K: array<vec4<f16>>;
+@group(0) @binding(2) var<storage, read_write> V: array<vec4<f16>>;
+#endif
+
+#if defined(MASK) && defined(SINKS)
+#ifdef KV_OVERLAP
+@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
+@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
+#define DST_BINDING 4
+#define PARAMS_BINDING 5
+#else
+@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
+@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
+#define DST_BINDING 5
+#define PARAMS_BINDING 6
+#endif
+#elif defined(MASK)
+#ifdef KV_OVERLAP
+@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
+#define DST_BINDING 3
+#define PARAMS_BINDING 4
+#else
+@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
+#define DST_BINDING 4
+#define PARAMS_BINDING 5
+#endif
+#elif defined(SINKS)
+#ifdef KV_OVERLAP
+@group(0) @binding(2) var<storage, read_write> sinks: array<f32>;
+#define DST_BINDING 3
+#define PARAMS_BINDING 4
+#else
+@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
+#define DST_BINDING 4
+#define PARAMS_BINDING 5
+#endif
+#else
+#ifdef KV_OVERLAP
+#define DST_BINDING 2
+#define PARAMS_BINDING 3
+#else
+#define DST_BINDING 3
+#define PARAMS_BINDING 4
+#endif
+#endif
+
+@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
+@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
+
+const FLOAT_MIN: f32 = -1.0e9;
+const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u;
+const V_CHUNKS: u32 = HEAD_DIM_V / 4u;
+const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
+const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
+
+var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
+var<workgroup> kv_shmem: array<f16, KV_TILE * KV_STAGE_STRIDE>;
+var<workgroup> p_shmem: array<f32, Q_TILE * KV_TILE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+        @builtin(local_invocation_id) local_id: vec3<u32>,
+        @builtin(subgroup_id) subgroup_id: u32,
+        @builtin(subgroup_size) subgroup_size: u32,
+        @builtin(num_subgroups) num_subgroups: u32,
+        @builtin(subgroup_invocation_id) sg_inv_id: u32) {
+    if (subgroup_size == 0u || num_subgroups < Q_TILE) {
+        return;
+    }
+
+    let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
+    let wg_per_batch = wg_per_head * params.n_heads;
+
+    let dst2_stride = HEAD_DIM_V * params.n_heads;
+    let dst3_stride = dst2_stride * params.seq_len_q;
+
+    let batch_idx = wg_id.x / wg_per_batch;
+    let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
+    let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
+    let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
+    let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride;
+    let wg_in_batch = wg_id.x % wg_per_batch;
+
+    let head_idx = wg_in_batch / wg_per_head;
+    let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
+    let k_head_idx = head_idx / params.q_per_kv;
+    let v_head_offset = v_batch_offset + k_head_idx * params.stride_v2;
+    let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
+
+    let wg_in_head = wg_in_batch % wg_per_head;
+    let q_row_start = wg_in_head * Q_TILE;
+    let global_q_row = q_row_start + subgroup_id;
+    let row_active = subgroup_id < Q_TILE && global_q_row < params.seq_len_q;
+
+#ifdef MASK
+    let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
+#endif
+
+    let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V;
+
+    let head = f32(head_idx);
+    let slope = select(1.0,
+                       select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0),
+                              pow(params.m0, head + 1.0),
+                              head < params.n_head_log2),
+                       params.max_bias > 0.0);
+
+    for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
+        let q_tile_row = elem_idx / HEAD_DIM_QK;
+        let q_col = elem_idx % HEAD_DIM_QK;
+        let head_q_row = q_row_start + q_tile_row;
+        let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
+        q_shmem[elem_idx] = f16(select(
+            0.0,
+            Q[global_q_row_offset + q_col] * params.scale,
+            head_q_row < params.seq_len_q));
+    }
+
+    workgroupBarrier();
+
+    var row_max = FLOAT_MIN;
+    var exp_sum = 0.0;
+    var out_regs: array<vec4<f32>, OUT_REGS_PER_LANE>;
+    for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
+        out_regs[reg_idx] = vec4<f32>(0.0);
+    }
+
+    let q_base = subgroup_id * HEAD_DIM_QK;
+    let subgroup_p_offset = subgroup_id * KV_TILE;
+
+    for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
+        let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile);
+        let score_slots = min(SCORE_REGS_PER_LANE, (kv_count + subgroup_size - 1u) / subgroup_size);
+        let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size);
+        var local_scores: array<f32, SCORE_REGS_PER_LANE>;
+        for (var slot = 0u; slot < SCORE_REGS_PER_LANE; slot += 1u) {
+            local_scores[slot] = FLOAT_MIN;
+        }
+
+        for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) {
+            let kv_local = vec_idx_local / Q_CHUNKS;
+            let chunk = vec_idx_local % Q_CHUNKS;
+            let global_k_row = kv_tile + kv_local;
+            let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u;
+            let k4 = K[k_vec_index];
+            let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
+            kv_shmem[kv_off + 0u] = k4.x;
+            kv_shmem[kv_off + 1u] = k4.y;
+            kv_shmem[kv_off + 2u] = k4.z;
+            kv_shmem[kv_off + 3u] = k4.w;
+        }
+
+        workgroupBarrier();
+
+        var local_max = FLOAT_MIN;
+        if (row_active) {
+            for (var slot = 0u; slot < score_slots; slot += 1u) {
+                let kv_local = sg_inv_id + slot * subgroup_size;
+                if (kv_local >= kv_count) {
+                    continue;
+                }
+
+                let global_k_row = kv_tile + kv_local;
+                var dot_val = 0.0;
+                for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) {
+                    let q_off = q_base + chunk * 4u;
+                    let qv = vec4<f32>(
+                        f32(q_shmem[q_off + 0u]),
+                        f32(q_shmem[q_off + 1u]),
+                        f32(q_shmem[q_off + 2u]),
+                        f32(q_shmem[q_off + 3u]));
+                    let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
+                    let kv = vec4<f32>(
+                        f32(kv_shmem[kv_off + 0u]),
+                        f32(kv_shmem[kv_off + 1u]),
+                        f32(kv_shmem[kv_off + 2u]),
+                        f32(kv_shmem[kv_off + 3u]));
+                    dot_val += dot(qv, kv);
+                }
+#ifdef LOGIT_SOFTCAP
+                dot_val = params.logit_softcap * tanh(dot_val);
+#endif
+#ifdef MASK
+                let mask_idx = mask_global_offset + subgroup_id * params.seq_len_kv + global_k_row;
+                dot_val += slope * f32(mask[mask_idx]);
+#endif
+                local_scores[slot] = dot_val;
+                local_max = max(local_max, dot_val);
+            }
+        }
+
+        let tile_max = subgroupMax(local_max);
+        let new_max = max(row_max, tile_max);
+        let cur_exp = exp(row_max - new_max);
+        exp_sum *= cur_exp;
+        for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
+            out_regs[reg_idx] *= cur_exp;
+        }
+
+        var local_sum = 0.0;
+        for (var slot = 0u; slot < score_slots; slot += 1u) {
+            let kv_local = sg_inv_id + slot * subgroup_size;
+            if (row_active && kv_local < kv_count) {
+                let p = exp(local_scores[slot] - new_max);
+                p_shmem[subgroup_p_offset + kv_local] = p;
+                local_sum += p;
+            }
+        }
+
+        workgroupBarrier();
+
+        for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) {
+            let kv_local = vec_idx_local / V_CHUNKS;
+            let chunk = vec_idx_local % V_CHUNKS;
+            let global_v_row = kv_tile + kv_local;
+            let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u;
+            let v4 = V[v_vec_index];
+            let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
+            kv_shmem[kv_off + 0u] = v4.x;
+            kv_shmem[kv_off + 1u] = v4.y;
+            kv_shmem[kv_off + 2u] = v4.z;
+            kv_shmem[kv_off + 3u] = v4.w;
+        }
+
+        workgroupBarrier();
+
+        let tile_sum = subgroupAdd(local_sum);
+        exp_sum += tile_sum;
+        row_max = new_max;
+
+        if (row_active) {
+            for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) {
+                let chunk = sg_inv_id + reg_idx * subgroup_size;
+                if (chunk >= V_CHUNKS) {
+                    continue;
+                }
+
+                var acc = out_regs[reg_idx];
+                for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) {
+                    let p = p_shmem[subgroup_p_offset + kv_local];
+                    let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
+                    let v4 = vec4<f32>(
+                        f32(kv_shmem[kv_off + 0u]),
+                        f32(kv_shmem[kv_off + 1u]),
+                        f32(kv_shmem[kv_off + 2u]),
+                        f32(kv_shmem[kv_off + 3u]));
+                    acc += p * v4;
+                }
+                out_regs[reg_idx] = acc;
+            }
+        }
+
+        workgroupBarrier();
+    }
+
+#ifdef SINKS
+    if (row_active) {
+        let sink_score = sinks[params.offset_sinks + head_idx];
+        let sink_max = max(row_max, sink_score);
+        let sink_scale = exp(row_max - sink_max);
+        for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
+            out_regs[reg_idx] *= sink_scale;
+        }
+        exp_sum = exp_sum * sink_scale + exp(sink_score - sink_max);
+        row_max = sink_max;
+    }
+#endif
+
+    if (row_active) {
+        let inv_exp_sum = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
+        let row_base = dst_global_offset + subgroup_id * dst2_stride;
+        let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size);
+        for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) {
+            let chunk = sg_inv_id + reg_idx * subgroup_size;
+            if (chunk >= V_CHUNKS) {
+                continue;
+            }
+            let dst_vec_index = (row_base + chunk * 4u) >> 2u;
+            dst[dst_vec_index] = out_regs[reg_idx] * inv_exp_sum;
+        }
+    }
+}
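The KV loop above is a standard online-softmax accumulation: each new KV block first rescales the running output and exp-sum by exp(old_max - new_max), then accumulates the block's probability-weighted V rows. A scalar reference of the same update (self-contained C++, not the shader's code):

    // Scalar online-softmax block update mirroring the tile shader's loop.
    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    void online_softmax_block(const std::vector<float> & scores,               // s_j for this KV block
                              const std::vector<std::vector<float>> & v_rows,  // matching V rows
                              std::vector<float> & out,                        // running O (head_dim_v)
                              float & row_max, float & exp_sum) {
        float block_max = row_max;
        for (float s : scores) {
            block_max = std::max(block_max, s);
        }
        const float rescale = std::exp(row_max - block_max);   // exp(old_max - new_max)
        exp_sum *= rescale;
        for (float & o : out) {
            o *= rescale;
        }
        for (size_t j = 0; j < scores.size(); ++j) {
            const float p = std::exp(scores[j] - block_max);
            exp_sum += p;
            for (size_t d = 0; d < out.size(); ++d) {
                out[d] += p * v_rows[j][d];                    // assumes v_rows[j].size() == out.size()
            }
        }
        row_max = block_max;   // final result is out[d] / exp_sum after the last block
    }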
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl
index 61107c6a9..b4f7c16c3 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl
@@ -15,7 +15,7 @@ struct Params {
     nblk1: u32,
 };

-@group(0) @binding(0) var<storage, read> mask: array<f16>;
+@group(0) @binding(0) var<storage, read_write> mask: array<f16>;
 @group(0) @binding(1) var<storage, read_write> blk: array<u32>;
 @group(0) @binding(2) var<uniform> params: Params;
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl
index a52575871..b1e234784 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl
@@ -1,8 +1,6 @@
-diagnostic(off, chromium.subgroup_matrix_uniformity);
 diagnostic(off, subgroup_uniformity);
 enable f16;
 enable subgroups;
-enable chromium_experimental_subgroup_matrix;

 #ifdef KV_F32
 #define KV_TYPE f32
@@ -13,19 +11,14 @@ enable chromium_experimental_subgroup_matrix;

 #define HEAD_DIM_QK 64
 #define HEAD_DIM_V 64
-
-#define SG_MAT_M 8
-#define SG_MAT_N 8
-#define SG_MAT_K 8
-
-#define Q_TILE SG_MAT_M
+#define KV_GRANULARITY 8
 #define KV_TILE 16
 #define WG_SIZE 64
 #ifndef VEC_NE
 #define VEC_NE 4u
 #endif

-#define KV_BLOCKS (KV_TILE / SG_MAT_N)
+#define KV_BLOCKS (KV_TILE / KV_GRANULARITY)

 #define BLOCK_SIZE 32
 #define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
@@ -97,6 +90,14 @@ struct Params {
 };

 @group(0) @binding(0) var<storage, read_write> Q: array<f32>;
+#ifdef KV_OVERLAP
+#if defined(KV_Q4_0) || defined(KV_Q8_0)
+@group(0) @binding(1) var<storage, read_write> K: array<u32>;
+#else
+@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
+#endif
+#define V K
+#else
 #if defined(KV_Q4_0) || defined(KV_Q8_0)
 @group(0) @binding(1) var<storage, read_write> K: array<u32>;
 #else
@@ -107,7 +108,22 @@ struct Params {
 #else
 @group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
 #endif
+#endif
 #if defined(MASK) && defined(SINKS)
+#ifdef KV_OVERLAP
+@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
+@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
+#ifdef BLK
+#define BLK_BINDING 4
+#define TMP_BINDING 5
+#define DST_BINDING 6
+#define PARAMS_BINDING 7
+#else
+#define TMP_BINDING 4
+#define DST_BINDING 5
+#define PARAMS_BINDING 6
+#endif
+#else
 @group(0) @binding(3) var<storage, read_write> mask: array<f16>;
 @group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
 #ifdef BLK
@@ -120,7 +136,21 @@ struct Params {
 #define DST_BINDING 6
 #define PARAMS_BINDING 7
 #endif
+#endif
 #elif defined(MASK)
+#ifdef KV_OVERLAP
+@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
+#ifdef BLK
+#define BLK_BINDING 3
+#define TMP_BINDING 4
+#define DST_BINDING 5
+#define PARAMS_BINDING 6
+#else
+#define TMP_BINDING 3
+#define DST_BINDING 4
+#define PARAMS_BINDING 5
+#endif
+#else
 @group(0) @binding(3) var<storage, read_write> mask: array<f16>;
 #ifdef BLK
 #define BLK_BINDING 4
@@ -132,16 +162,30 @@ struct Params {
 #define DST_BINDING 5
 #define PARAMS_BINDING 6
 #endif
+#endif
 #elif defined(SINKS)
+#ifdef KV_OVERLAP
+@group(0) @binding(2) var<storage, read_write> sinks: array<f32>;
+#define TMP_BINDING 3
+#define DST_BINDING 4
+#define PARAMS_BINDING 5
+#else
 @group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
 #define TMP_BINDING 4
 #define DST_BINDING 5
 #define PARAMS_BINDING 6
+#endif
+#else
+#ifdef KV_OVERLAP
+#define TMP_BINDING 2
+#define DST_BINDING 3
+#define PARAMS_BINDING 4
 #else
 #define TMP_BINDING 3
 #define DST_BINDING 4
 #define PARAMS_BINDING 5
 #endif
+#endif

 #ifdef BLK
 @group(0) @binding(BLK_BINDING) var<storage, read_write> blk: array<u32>;
@@ -153,7 +197,7 @@ struct Params {

 // Just a very small float value.
 const FLOAT_MIN: f32 = -1.0e9;

-var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
+var<workgroup> q_shmem: array<f16, HEAD_DIM_QK>;

 #ifndef KV_DIRECT
 const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
@@ -161,31 +205,27 @@ const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
 var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
 #endif

-var<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>;
+var<workgroup> o_shmem: array<f16, HEAD_DIM_V>;

 #ifdef MASK
 // storage for mask values
-var<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;
+var<workgroup> mask_shmem: array<f16, KV_TILE>;
 #endif

 // note that we reuse the same storage for both since we only need one at a time
-var<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;
+var<workgroup> inter_shmem: array<f16, KV_TILE>;

 // Storage for row max and exp sum during online softmax
-var<workgroup> row_max_shmem: array<f32, Q_TILE>;
-var<workgroup> exp_sum_shmem: array<f32, Q_TILE>;
-var<workgroup> blk_state_wg: u32;
-
-fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 {
+fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 {
     var v = select(FLOAT_MIN,
-                   f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale,
+                   f32(inter_shmem[kv_idx]) * params.scale,
                    kv_idx < KV_TILE);
 #ifdef LOGIT_SOFTCAP
     v = params.logit_softcap * tanh(v);
 #endif
 #ifdef MASK
     if (apply_mask) {
-        var mask_val = select(0.0, f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE);
+        var mask_val = select(0.0, f32(mask_shmem[kv_idx]), kv_idx < KV_TILE);
         v += select(mask_val, slope * mask_val, has_bias);
     }
 #endif
@@ -199,19 +239,17 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
         @builtin(subgroup_size) subgroup_size: u32,
         @builtin(num_subgroups) num_subgroups: u32,
         @builtin(subgroup_invocation_id) sg_inv_id: u32) {
+    // Vec path processes exactly one query row per workgroup, so subgroup 0 can
+    // keep the running softmax state in private storage.
+    var row_max = FLOAT_MIN;
+    var exp_sum = 0.0;

-    // initialize row max for online softmax
-    for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
-        row_max_shmem[i] = FLOAT_MIN;
-        exp_sum_shmem[i] = 0.0;
-    }
-
-    for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) {
+    for (var i = local_id.x; i < HEAD_DIM_V; i += WG_SIZE) {
         o_shmem[i] = 0.0;
     }

     // workgroups per head/batch
-    let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
+    let wg_per_head = params.seq_len_q;
     let wg_per_batch = wg_per_head * params.n_heads;

     let dst2_stride = HEAD_DIM_V * params.n_heads;
@@ -235,9 +273,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
     let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
     let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;

-    // starting Q row for this workgroup
+    // Vec path handles one Q row per workgroup.
     let wg_in_head = wg_in_batch % wg_per_head;
-    let q_row_start = wg_in_head * Q_TILE;
+    let q_row_start = wg_in_head;

 #ifdef MASK
     // mask offset
@@ -248,21 +286,18 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
     let has_bias = params.max_bias > 0.0;
     let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias);

-    // load q tile into shared memory
-    for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
-        let q_row = elem_idx / HEAD_DIM_QK;
-        let q_col = elem_idx % HEAD_DIM_QK;
-        let head_q_row = q_row_start + q_row;
-        let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
+    // load the single Q row into shared memory
+    for (var elem_idx = local_id.x; elem_idx < HEAD_DIM_QK; elem_idx += WG_SIZE) {
+        let global_q_row_offset = q_head_offset + q_row_start * params.stride_q1;
         q_shmem[elem_idx] = f16(select(
             0.0,
-            Q[global_q_row_offset + q_col],
-            head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));
+            Q[global_q_row_offset + elem_idx],
+            q_row_start < params.seq_len_q));
     }

     for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) {
 #ifdef BLK
-        let q_blk = q_row_start / Q_TILE;
+        let q_blk = q_row_start;
         let kv_blk = kv_tile / KV_TILE;
         let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
         let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk;
@@ -270,13 +305,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
 #else
         let blk_state_local = 1u;
 #endif
-        if (local_id.x == 0u) {
-            blk_state_wg = blk_state_local;
-        }
-        workgroupBarrier();
-        let blk_state = blk_state_wg;
+        let blk_state = blk_state_local;
         let skip_tile = blk_state == 0u;
-        for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
+        for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) {
             inter_shmem[elem_idx] = f16(0.0);
         }
@@ -360,20 +391,14 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
         let num_of_threads = subgroup_size / VEC_NE;
         let tx = sg_inv_id % num_of_threads;
         let ty = sg_inv_id / num_of_threads;
-        for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
-            let global_q_row = q_row_start + q_tile_row;
-            if (global_q_row >= params.seq_len_q) {
-                continue;
-            }
-            let local_q_row_offset = q_tile_row * HEAD_DIM_QK;
-
+        if (subgroup_id == 0u && q_row_start < params.seq_len_q) {
             for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) {
                 let kv_idx = kv_base + ty;
                 var partial_sum: f32 = 0.0;
                 let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv;
                 if (kv_valid) {
                     for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) {
-                        let q_off = local_q_row_offset + i * 4u;
+                        let q_off = i * 4u;

                         let qv = vec4<f32>(
                             f32(q_shmem[q_off + 0u]),
@@ -410,8 +435,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,

                 let sum_bcast = subgroupShuffle(sum, num_of_threads * ty);
                 if (tx == 0u && kv_valid) {
-                    let dst_idx = q_tile_row * KV_TILE + kv_idx;
-                    inter_shmem[dst_idx] = f16(sum_bcast);
+                    inter_shmem[kv_idx] = f16(sum_bcast);
                 }
             }
         }
@@ -422,13 +446,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
         let apply_mask = !skip_tile && (blk_state != 2u);
         if (apply_mask) {
             // load mask tile into shared memory for this KV block
-            for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
-                let mask_row = elem_idx / KV_TILE;
-                let mask_col = elem_idx % KV_TILE;
-                let global_q_row = q_row_start + mask_row;
global_k_col = kv_tile + mask_col; - let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv; - let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col; + for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) { + let global_k_col = kv_tile + elem_idx; + let mask_in_bounds = q_row_start < params.seq_len_q && global_k_col < params.seq_len_kv; + let mask_idx = mask_global_offset + global_k_col; mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); } } @@ -439,50 +460,40 @@ fn main(@builtin(workgroup_id) wg_id: vec3, workgroupBarrier(); // online softmax - if (!skip_tile) { - for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - break; - } + if (!skip_tile && subgroup_id == 0u && q_row_start < params.seq_len_q) { + var prev_max = row_max; + var final_max = prev_max; + // pass 1: compute final max across the full KV tile in chunks + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE; + let softmax_term = select(FLOAT_MIN, + calc_softmax_term(kv_idx, slope, has_bias, apply_mask), + kv_valid); + final_max = subgroupMax(max(final_max, softmax_term)); + } - var prev_max = row_max_shmem[q_tile_row]; - var final_max = prev_max; - // pass 1: compute final max across the full KV tile in chunks - for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { - let kv_idx = kv_offset + sg_inv_id; - let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE; - let softmax_term = select(FLOAT_MIN, - calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask), - kv_valid); - final_max = subgroupMax(max(final_max, softmax_term)); + var total_exp_term: f32 = 0.0; + // pass 2: compute exp sum and write P using final_max + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, slope, has_bias, apply_mask); + let cur_p = select(0.0, + exp(softmax_term - final_max), + kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); + total_exp_term += subgroupAdd(cur_p); + if (kv_idx < KV_TILE) { + inter_shmem[kv_idx] = f16(cur_p); } + } - var total_exp_term: f32 = 0.0; - // pass 2: compute exp sum and write P using final_max - for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { - let kv_idx = kv_offset + sg_inv_id; - let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask); - let cur_p = select(0.0, - exp(softmax_term - final_max), - kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); - total_exp_term += subgroupAdd(cur_p); - if (kv_idx < KV_TILE) { - inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p); - } - } + let cur_exp = exp(prev_max - final_max); - let cur_exp = exp(prev_max - final_max); + row_max = final_max; + exp_sum = exp_sum * cur_exp + total_exp_term; - if (sg_inv_id == 0) { - row_max_shmem[q_tile_row] = final_max; - exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term; - } - - for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - let idx = q_tile_row * HEAD_DIM_V + elem_idx; - o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp); - } + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += 
subgroup_size) { + o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * cur_exp); } } @@ -562,15 +573,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3, workgroupBarrier(); if (!skip_tile) { - // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem + // we have P (KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem // we want to compute O += P * V across the full KV tile let ne_threads : u32 = VEC_NE; let nl_threads = max(1u, subgroup_size / ne_threads); let tx_pv = sg_inv_id % nl_threads; let ty_pv = sg_inv_id / nl_threads; - for (var q_tile_row = subgroup_id; - q_tile_row < Q_TILE; - q_tile_row += num_subgroups) { + if (subgroup_id == 0u && q_row_start < params.seq_len_q) { for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) { var lo = vec4(0.0, 0.0, 0.0, 0.0); for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) { @@ -580,7 +589,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, continue; } - let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]); + let p = f32(inter_shmem[kv_idx]); #ifdef KV_DIRECT let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u; let v4 = vec4(V[v_idx >> 2u]); @@ -621,11 +630,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (ty_pv == 0u) { let elem_base = vec_col * 4u; - let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base; - o_shmem[o_base_idx + 0u] = f16(f32(o_shmem[o_base_idx + 0u]) + lo_x); - o_shmem[o_base_idx + 1u] = f16(f32(o_shmem[o_base_idx + 1u]) + lo_y); - o_shmem[o_base_idx + 2u] = f16(f32(o_shmem[o_base_idx + 2u]) + lo_z); - o_shmem[o_base_idx + 3u] = f16(f32(o_shmem[o_base_idx + 3u]) + lo_w); + o_shmem[elem_base + 0u] = f16(f32(o_shmem[elem_base + 0u]) + lo_x); + o_shmem[elem_base + 1u] = f16(f32(o_shmem[elem_base + 1u]) + lo_y); + o_shmem[elem_base + 2u] = f16(f32(o_shmem[elem_base + 2u]) + lo_z); + o_shmem[elem_base + 3u] = f16(f32(o_shmem[elem_base + 3u]) + lo_w); } } } @@ -637,70 +645,46 @@ fn main(@builtin(workgroup_id) wg_id: vec3, #ifdef SINKS // Sinks are global terms and must be applied exactly once across split workgroups. 
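The `#ifdef SINKS` hunk below folds the per-head sink logit into the same running state, and only once per row (guarded by `iwg == 0u`). A sink behaves like one extra softmax term with no V row: it can raise the running max and adds to the exp sum, but contributes nothing to the output accumulator. A hedged C++ sketch under the same conventions as the earlier one (illustrative names, not the patch's code):

#include <algorithm>
#include <cmath>
#include <vector>

// Apply one sink logit to an online-softmax row: raising the max forces a
// rescale of the accumulated output, and the sink adds a single exp term.
static void fold_sink(float & row_max, float & exp_sum, std::vector<float> & o_acc, float sink_logit) {
    const float new_max  = std::max(row_max, sink_logit);
    const float max_exp  = std::exp(row_max - new_max);     // rescales existing state
    const float sink_exp = std::exp(sink_logit - new_max);  // the sink's own term

    for (float & o : o_acc) o *= max_exp;  // matches o_shmem[i] *= max_exp below
    exp_sum = exp_sum * max_exp + sink_exp;
    row_max = new_max;
}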
-    if (iwg == 0u) {
-        for (var q_tile_row = subgroup_id;
-             q_tile_row < Q_TILE;
-             q_tile_row += num_subgroups) {
-            let global_q_row = q_row_start + q_tile_row;
-            if (global_q_row >= params.seq_len_q) {
-                break;
-            }
+    if (iwg == 0u && subgroup_id == 0u && q_row_start < params.seq_len_q) {
+        var prev_max = row_max;
 
-            var prev_max = row_max_shmem[q_tile_row];
+        // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
+        let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0u);
+        let new_max = subgroupMax(max(prev_max, sink_val));
+        let max_exp = exp(prev_max - new_max);
+        let sink_exp = exp(sink_val - new_max);
 
-            // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
-            let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);
-            let new_max = subgroupMax(max(prev_max, sink_val));
-            let max_exp = exp(prev_max - new_max);
-            let sink_exp = exp(sink_val - new_max);
+        let sink_exp_sum = subgroupAdd(sink_exp);
 
-            let sink_exp_sum = subgroupAdd(sink_exp);
+        row_max = new_max;
+        exp_sum = exp_sum * max_exp + sink_exp_sum;
 
-            if (sg_inv_id == 0) {
-                row_max_shmem[q_tile_row] = new_max;
-                exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;
-            }
-
-            for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
-                let idx = q_tile_row * HEAD_DIM_V + elem_idx;
-                o_shmem[idx] = f16(f32(o_shmem[idx]) * max_exp);
-            }
+        for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
+            o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * max_exp);
         }
-        workgroupBarrier();
     }
+    workgroupBarrier();
 #endif
 
     let rows_per_batch = params.n_heads * params.seq_len_q;
-    for (var q_tile_row = subgroup_id;
-         q_tile_row < Q_TILE;
-         q_tile_row += num_subgroups) {
-
-        let global_q_row = q_row_start + q_tile_row;
-        if (global_q_row >= params.seq_len_q) { break; }
-
+    if (subgroup_id == 0u && q_row_start < params.seq_len_q) {
         if (params.nwg == 1u) {
-            let exp_sum = exp_sum_shmem[q_tile_row];
             let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
-            let row_base: u32 =
-                params.offset_dst + batch_idx * dst3_stride + global_q_row * dst2_stride + head_idx * HEAD_DIM_V;
+            let row_base: u32 = params.offset_dst + batch_idx * dst3_stride + q_row_start * dst2_stride +
+                                head_idx * HEAD_DIM_V;
 
             for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) {
-                let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
-                let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
-                let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
-                let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
-
                 let v = vec4<f32>(
-                    f32(o_shmem[i0]) * scale,
-                    f32(o_shmem[i1]) * scale,
-                    f32(o_shmem[i2]) * scale,
-                    f32(o_shmem[i3]) * scale
+                    f32(o_shmem[elem_base + 0u]) * scale,
+                    f32(o_shmem[elem_base + 1u]) * scale,
+                    f32(o_shmem[elem_base + 2u]) * scale,
+                    f32(o_shmem[elem_base + 3u]) * scale
                 );
                 let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
                 dst[dst_vec_index] = v;
             }
         } else {
-            let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row;
+            let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + q_row_start;
             let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V;
             let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg;
 
@@ -708,21 +692,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
                  elem_base < HEAD_DIM_V;
                  elem_base += subgroup_size * 4u) {
 
-                let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
-                let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
-                let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
-                let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
-
                 let tbase = tmp_row_data_base + elem_base;
-                tmp[tbase + 0u] = f32(o_shmem[i0]);
-                tmp[tbase + 1u] = f32(o_shmem[i1]);
-                tmp[tbase + 2u] = f32(o_shmem[i2]);
-                tmp[tbase + 3u] = f32(o_shmem[i3]);
+                tmp[tbase + 0u] = f32(o_shmem[elem_base + 0u]);
+                tmp[tbase + 1u] = f32(o_shmem[elem_base + 1u]);
+                tmp[tbase + 2u] = f32(o_shmem[elem_base + 2u]);
+                tmp[tbase + 3u] = f32(o_shmem[elem_base + 3u]);
             }
 
             if (sg_inv_id == 0u) {
-                tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row];
-                tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row];
+                tmp[tmp_row_stats_base + 0u] = exp_sum;
+                tmp[tmp_row_stats_base + 1u] = row_max;
            }
        }
    }
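For the `params.nwg > 1u` case above, each split workgroup stores an unnormalized output row plus its `(exp_sum, row_max)` pair in `tmp`; a separate reduction pass (not shown in this hunk) merges the splits. The merge it has to perform is, in illustrative C++ (invented names, not the actual reduction shader):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Merge per-split partials (o_i, sum_i, max_i) into one normalized row:
// O = sum_i(o_i * exp(max_i - M)) / sum_i(sum_i * exp(max_i - M)), with
// M = max over all max_i so every weight stays <= 1.
static std::vector<float> merge_splits(const std::vector<std::vector<float>> & o_parts,
                                       const std::vector<float> &              exp_sums,
                                       const std::vector<float> &              row_maxes) {
    const float m = *std::max_element(row_maxes.begin(), row_maxes.end());

    std::vector<float> out(o_parts[0].size(), 0.0f);
    float total = 0.0f;
    for (size_t i = 0; i < o_parts.size(); i++) {
        const float w = std::exp(row_maxes[i] - m);
        for (size_t j = 0; j < out.size(); j++) {
            out[j] += o_parts[i][j] * w;
        }
        total += exp_sums[i] * w;
    }

    const float scale = total != 0.0f ? 1.0f / total : 0.0f;
    for (float & o : out) o *= scale;
    return out;
}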
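Finally, the nested `select` that computes `slope` earlier in this shader is the usual ALiBi per-head slope schedule, which is easier to audit in scalar form. A sketch assuming `m0`, `m1`, and `n_head_log2` are precomputed host-side exactly as the existing vec path expects:

#include <cmath>
#include <cstdint>

// Per-head ALiBi slope, matching the shader's nested select: heads below
// n_head_log2 use m0^(h+1), the remainder use m1^(2*(h - n_head_log2) + 1).
// With max_bias <= 0 the mask is applied unscaled (slope = 1).
static float alibi_slope(uint32_t head, uint32_t n_head_log2, float m0, float m1, float max_bias) {
    if (max_bias <= 0.0f) {
        return 1.0f;
    }
    if (head < n_head_log2) {
        return std::pow(m0, (float) head + 1.0f);
    }
    return std::pow(m1, 2.0f * (float) (head - n_head_log2) + 1.0f);
}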