ggml-webgpu: enable FLASH_ATTN_EXT on browser without subgroup matrix (#22199)

* ggml-webgpu: add tile flash attention fallback

* ggml-webgpu: add new fields and discard usage of mnk for tile version

* ggml-webgpu: modify the vec path to discard the mnk parameter

* ggml-webgpu: enable flash attention vec and tile version for browser

* ggml-webgpu: staging KV for flash attention tile version

* formatting

* turn on subgroup uniformity check

* remove Q_TILE as it is always 1 for vec path

* make row_max and exp_sum local registers

* make different bindings with the same underlying buffer use the same usage flags

* move path selection into the shader library and have the host consume a single flash-attn decision object.

* turn off skip_validation and address buffer overlapping when nwg==1

* formatting

* merge bindings when K and V overlap
Author: Zheyuan Chen
Date: 2026-04-24 10:39:09 -07:00
Committed by: GitHub
Parent: f65bc34c68
Commit: 13d36cf891
6 changed files with 817 additions and 400 deletions
@@ -389,23 +389,6 @@ static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_t
return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
}
static bool ggml_webgpu_flash_attn_use_vec(webgpu_global_context & global_ctx,
const ggml_tensor * Q,
const ggml_tensor * K,
const ggml_tensor * V) {
const size_t alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
const uint32_t k_offset_elems =
(uint32_t) ((ggml_webgpu_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type));
const uint32_t v_offset_elems =
(uint32_t) ((ggml_webgpu_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type));
const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u);
const bool kv_vec_type_supported =
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
return (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
(K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type);
}
static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
size_t offset = ggml_webgpu_tensor_offset(t);
return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
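Note: the deleted helper above is the old host-side path chooser; per the commit message, this decision now lives in the shader library and the host consumes a single decision object instead. For reference, a minimal standalone sketch of the alignment arithmetic the helper performed (the size constant is a hypothetical stand-in for ggml_type_size(GGML_TYPE_F16)):

#include <cstddef>
#include <cstdint>

constexpr size_t kF16Size = 2;  // hypothetical stand-in for ggml_type_size(GGML_TYPE_F16)

// True when a tensor at byte offset tensor_offset starts on a 4-element
// boundary relative to its aligned binding base, given a power-of-two
// minStorageBufferOffsetAlignment; vec4<f16> loads in the shader need this.
bool f16_vec4_aligned(size_t tensor_offset, size_t min_align) {
    const size_t misalign_bytes = tensor_offset & (min_align - 1);
    const uint32_t offset_elems = (uint32_t) (misalign_bytes / kF16Size);
    return offset_elems % 4u == 0u;
}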
@@ -1567,7 +1550,6 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
return ggml_backend_webgpu_build_multi(ctx, dispatches);
}
#ifndef __EMSCRIPTEN__
static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
ggml_tensor * Q,
ggml_tensor * K,
@@ -1585,13 +1567,29 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
const int has_mask = (mask != nullptr);
const int has_sinks = (sinks != nullptr);
const int has_mask = (mask != nullptr);
const int has_sinks = (sinks != nullptr);
const bool kv_overlap = ggml_webgpu_tensor_overlap(K, V) && K->type == V->type;
uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
size_t kv_bind_offset = 0;
size_t kv_bind_size = 0;
if (kv_overlap) {
const size_t k_bind_offset = ggml_webgpu_tensor_align_offset(ctx, K);
const size_t v_bind_offset = ggml_webgpu_tensor_align_offset(ctx, V);
const size_t k_bind_end = k_bind_offset + ggml_webgpu_tensor_binding_size(ctx, K);
const size_t v_bind_end = v_bind_offset + ggml_webgpu_tensor_binding_size(ctx, V);
kv_bind_offset = std::min(k_bind_offset, v_bind_offset);
kv_bind_size = std::max(k_bind_end, v_bind_end) - kv_bind_offset;
offset_k = (uint32_t) ((ggml_webgpu_tensor_offset(K) - kv_bind_offset) / ggml_type_size(K->type));
offset_v = (uint32_t) ((ggml_webgpu_tensor_offset(V) - kv_bind_offset) / ggml_type_size(V->type));
}
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),
offset_k,
offset_v,
has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
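Note: when K and V alias the same allocation (kv_overlap), the two aligned binding windows are merged into the smallest single window covering both, and each tensor's shader-visible element offset is re-derived from the merged base. A minimal sketch of that interval arithmetic:

#include <algorithm>
#include <cstddef>

struct BindRange {
    size_t offset;  // aligned binding base within the buffer
    size_t size;    // binding size in bytes
};

// Merge two binding windows over the same buffer, as the kv_overlap branch
// above does with k_bind_offset/v_bind_offset and their end positions.
BindRange merge_bindings(const BindRange & k, const BindRange & v) {
    const size_t begin = std::min(k.offset, v.offset);
    const size_t end   = std::max(k.offset + k.size, v.offset + v.size);
    return { begin, end - begin };
}

Each tensor's element offset then becomes (tensor_offset - begin) / type_size, which is exactly what the offset_k/offset_v reassignments compute.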
@@ -1619,10 +1617,15 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
};
std::vector<wgpu::BindGroupEntry> entries = {
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V),
};
uint32_t binding_index = 3;
if (kv_overlap) {
entries.push_back(
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size));
} else {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V));
}
uint32_t binding_index = kv_overlap ? 2u : 3u;
if (has_mask) {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask));
}
@@ -1638,25 +1641,25 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
shader_lib_ctx.src3 = mask;
shader_lib_ctx.src4 = sinks;
shader_lib_ctx.dst = dst;
shader_lib_ctx.src_overlap = kv_overlap;
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
const bool use_vec = ggml_webgpu_flash_attn_use_vec(ctx->global_ctx, Q, K, V);
webgpu_pipeline pipeline = use_vec ? ctx->shader_lib->get_flash_attn_vec_pipeline(shader_lib_ctx) :
ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
if (!use_vec) {
auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
auto * decisions = static_cast<ggml_webgpu_flash_attn_vec_decisions *>(pipeline.context.get());
wgpu::Buffer blk_buf = {};
uint64_t blk_size_bytes = 0;
uint32_t blk_nblk0 = 0;
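Note: the hunk above replaces the host-side use_vec branch with a single pipeline lookup whose context carries the decisions. The struct's definition is not part of this file; a hypothetical sketch consistent with the fields this diff uses (path, q_tile, kv_tile, kv_direct):

#include <cstdint>

// Hypothetical layout; only the VEC and TILE path names appear in the diff,
// and the third enumerator is a placeholder for the subgroup-matrix path.
enum ggml_webgpu_flash_attn_path : uint32_t {
    GGML_WEBGPU_FLASH_ATTN_PATH_VEC,
    GGML_WEBGPU_FLASH_ATTN_PATH_TILE,
    GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX,
};

struct ggml_webgpu_flash_attn_decisions {
    ggml_webgpu_flash_attn_path path;
    uint32_t q_tile;     // Q rows processed per workgroup
    uint32_t kv_tile;    // KV positions staged per iteration
    bool     kv_direct;  // K/V readable without staging to workgroup memory
};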
@@ -1695,10 +1698,12 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
tmp_bind_size = tmp_size_bytes;
scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes);
} else {
// nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation.
// nwg==1 writes final dst directly in vec-split; bind tmp to a tiny non-overlapping scratch region.
tmp_size_bytes = WEBGPU_STORAGE_BUF_BINDING_MULT;
tmp_buf = ggml_webgpu_tensor_buf(dst);
tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst);
tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst);
tmp_bind_offset = scratch_offset;
tmp_bind_size = tmp_size_bytes;
scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes);
}
webgpu_pipeline blk_pipeline;
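Note: the nwg==1 fix above stops binding tmp to dst's own range (which made two bindings in one bind group overlap) and instead carves a tiny dedicated slot out of the scratch buffer. A sketch of the bump-style layout, assuming ROUNDUP_POW2 rounds up to a multiple of a power-of-two alignment:

#include <cstddef>

// Assumed behavior of ROUNDUP_POW2: round x up to a multiple of align.
constexpr size_t roundup_pow2(size_t x, size_t align) {
    return (x + align - 1) & ~(align - 1);
}

// Carve a slot out of the shared scratch buffer and advance the cursor,
// mirroring how tmp_bind_offset and scratch_offset are managed above.
size_t alloc_scratch(size_t & scratch_offset, size_t size_bytes, size_t align) {
    const size_t slot = scratch_offset;
    scratch_offset = roundup_pow2(scratch_offset + size_bytes, align);
    return slot;
}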
@@ -1713,7 +1718,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx;
blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx);
blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx, decisions->kv_tile);
blk_params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask
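Note: the mask pre-pass buffer now sizes itself from the tuned kv_tile in the decision object rather than a fixed constant: one uint32 flag per (KV tile, Q block, batch), rounded up to the storage binding multiple. A sketch of that arithmetic, assuming blk_nblk1 counts Q blocks:

#include <cstdint>

constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

// Mirrors blk_size_bytes above; binding_mult stands in for
// WEBGPU_STORAGE_BUF_BINDING_MULT.
uint64_t blk_buffer_bytes(uint32_t kv_len, uint32_t kv_tile, uint32_t nblk1,
                          uint32_t batch_count, uint64_t binding_mult) {
    const uint64_t nblk0 = ceil_div(kv_len, kv_tile);
    const uint64_t bytes = nblk0 * nblk1 * (uint64_t) batch_count * sizeof(uint32_t);
    return (bytes + binding_mult - 1) / binding_mult * binding_mult;
}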
@@ -1745,12 +1750,19 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
std::vector<wgpu::BindGroupEntry> split_entries = {
ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q),
ggml_webgpu_tensor_binding_size(ctx, Q)),
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), ggml_webgpu_tensor_align_offset(ctx, K),
ggml_webgpu_tensor_binding_size(ctx, K)),
ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V), ggml_webgpu_tensor_align_offset(ctx, V),
ggml_webgpu_tensor_binding_size(ctx, V)),
};
uint32_t split_binding_index = 3;
if (kv_overlap) {
split_entries.push_back(
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size));
} else {
split_entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K),
ggml_webgpu_tensor_align_offset(ctx, K),
ggml_webgpu_tensor_binding_size(ctx, K)));
split_entries.push_back(ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V),
ggml_webgpu_tensor_align_offset(ctx, V),
ggml_webgpu_tensor_binding_size(ctx, V)));
}
uint32_t split_binding_index = kv_overlap ? 2u : 3u;
if (has_mask) {
split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask),
ggml_webgpu_tensor_align_offset(ctx, mask),
@@ -1820,7 +1832,6 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
return ggml_backend_webgpu_build_multi(ctx, dispatches);
}
#endif // __EMSCRIPTEN__
static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool is_unary = dst->op == GGML_OP_UNARY;
@@ -2710,11 +2721,7 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
case GGML_OP_MUL_MAT_ID:
return ggml_webgpu_mul_mat_id(ctx, src0, src1, src2, node);
case GGML_OP_FLASH_ATTN_EXT:
#ifndef __EMSCRIPTEN__
return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
#else
return std::nullopt;
#endif
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
@@ -3257,13 +3264,19 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.wg_mem_limit_bytes =
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups;
shader_lib_ctx.supports_subgroup_matrix =
ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
if (ggml_webgpu_flash_attn_use_vec(ctx->webgpu_global_ctx, Q, K, V)) {
const uint32_t kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(shader_lib_ctx);
const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
const uint32_t kv_tile = decisions.kv_tile;
const uint32_t vec_nwg_cap = std::max(
1u, std::min<uint32_t>(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size));
@@ -3283,6 +3296,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
const size_t tmp_size_bytes = ROUNDUP_POW2(
(tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
res += tmp_size_bytes + align;
} else {
res += WEBGPU_STORAGE_BUF_BINDING_MULT + align;
}
if (mask != nullptr) {
const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
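Note: allocation sizing mirrors the encode-time layout: when the split kernel reduces partials across workgroups, the full tmp staging size plus alignment slack is reserved; otherwise only a binding-multiple placeholder is reserved so the tmp binding never aliases dst. A sketch with hypothetical stand-ins for the constants:

#include <cstddef>

constexpr size_t kBindingMult = 16;   // stand-in for WEBGPU_STORAGE_BUF_BINDING_MULT
constexpr size_t kAlign       = 256;  // stand-in for minStorageBufferOffsetAlignment

// Worst-case scratch reservation for the vec path's tmp buffer; the extra
// align bytes let the slot be rebased to an aligned binding offset.
size_t vec_tmp_reservation(size_t tmp_size_bytes, bool needs_reduce) {
    const size_t slot = needs_reduce ? tmp_size_bytes : kBindingMult;
    return slot + kAlign;
}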
@@ -3431,12 +3446,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
ctx->webgpu_global_ctx->capabilities.supports_subgroups =
ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);
bool valid_subgroup_matrix_config = false;
#ifndef __EMSCRIPTEN__
// Accept f16 subgroup matrix configurations (square or non-square).
// NVIDIA GPUs typically report square configs (e.g. 16x16x16),
// while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16).
// The shaders are already parameterized to handle any M/N/K dimensions.
bool valid_subgroup_matrix_config = false;
if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
@@ -3450,8 +3465,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
}
}
}
ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
#endif
ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
@@ -3499,12 +3514,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
// Enable Dawn-specific toggles to increase native performance
// TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
// only for native performance?
const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
"disable_polyfills_on_integer_div_and_mod" };
const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
const char * const deviceEnabledToggles[] = { "disable_robustness", "disable_workgroup_init",
"disable_polyfills_on_integer_div_and_mod" };
const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
wgpu::DawnTogglesDescriptor deviceTogglesDesc;
deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
deviceTogglesDesc.enabledToggleCount = 4;
deviceTogglesDesc.enabledToggleCount = 3;
deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
deviceTogglesDesc.disabledToggleCount = 1;
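Note: skip_validation is dropped from the enabled toggles, so the enabled count shrinks from 4 to 3 — an easy pair of lines to let drift apart. Deriving the counts from the arrays avoids that; a sketch using a stand-in for wgpu::DawnTogglesDescriptor:

#include <cstddef>
#include <iterator>

const char * const enabledToggles[]  = { "disable_robustness", "disable_workgroup_init",
                                         "disable_polyfills_on_integer_div_and_mod" };
const char * const disabledToggles[] = { "timestamp_quantization" };

struct TogglesDesc {  // stand-in for wgpu::DawnTogglesDescriptor
    const char * const * enabledToggles;
    size_t               enabledToggleCount;
    const char * const * disabledToggles;
    size_t               disabledToggleCount;
};

// std::size keeps the counts in lockstep with the arrays.
TogglesDesc make_toggles() {
    return { enabledToggles, std::size(enabledToggles),
             disabledToggles, std::size(disabledToggles) };
}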
@@ -3782,33 +3797,63 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
break;
case GGML_OP_FLASH_ATTN_EXT:
{
#ifndef __EMSCRIPTEN__
if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
break;
}
// Head dimensions must be divisible by subgroup matrix dimensions
if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 ||
src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) {
break;
}
// Head dimensions must fit in workgroup memory with minimum tile sizes
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
const bool has_mask = op->src[3] != nullptr;
const bool kv_direct = src1->type == GGML_TYPE_F16 &&
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
(uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
if (min_bytes > limit_bytes) {
break;
}
supports_op = src0->type == GGML_TYPE_F32 &&
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
src2->type == src1->type && op->type == GGML_TYPE_F32;
#endif
if (!supports_op) {
break;
}
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
shader_lib_ctx.src2 = src2;
shader_lib_ctx.src3 = op->src[3];
shader_lib_ctx.src4 = op->src[4];
shader_lib_ctx.dst = const_cast<ggml_tensor *>(op);
shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups;
shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
shader_lib_ctx.wg_mem_limit_bytes =
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
const bool has_mask = op->src[3] != nullptr;
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
const size_t min_bytes =
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
if (min_bytes > limit_bytes) {
supports_op = false;
}
break;
}
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
const size_t min_bytes =
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
if (min_bytes > limit_bytes) {
supports_op = false;
}
break;
}
if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
supports_op = false;
break;
}
const size_t min_bytes =
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
if (min_bytes > limit_bytes) {
supports_op = false;
}
break;
}
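Note: all three branches above end in the same workgroup-memory feasibility test; only the subgroup-matrix path additionally requires hardware support, which is what lets the vec and tile paths run in the browser. A condensed sketch of the control flow:

#include <cstddef>

enum class FaPath { Vec, Tile, SubgroupMatrix };

// The subgroup-matrix path is the only one gated on a hardware feature;
// every path must fit its staging tiles into the workgroup storage limit
// (maxComputeWorkgroupStorageSize).
bool flash_attn_supported(FaPath path, bool hw_subgroup_matrix,
                          size_t wg_mem_bytes_needed, size_t wg_mem_limit) {
    if (path == FaPath::SubgroupMatrix && !hw_subgroup_matrix) {
        return false;
    }
    return wg_mem_bytes_needed <= wg_mem_limit;
}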
case GGML_OP_RMS_NORM: