ggml-webgpu: add vectorized flash attention (#20709)

* naive vectorized version

* add vectorized flash attention

* update vec version

* remove unused path and shader

* remove unused helper functions

* add comments

* remove pad path

* ggml-webgpu: fix flash-attn vec nwg=1 path and tighten vec specialization

* change back to vec4

* enable multi-split

* enable vec path when:
- Q->ne[1] < 20
- Q->ne[0] % 32 == 0
- V->ne[0] % 4 == 0
- K->type == f16

* update flash_attn_vec_split.wgsl to reduce redundant workgroup barrier usage and use select

* enable vec path for q4 and q8

* flash-attn vec nwg=1 fast path (skip tmp/reduce staging)

* use packed f16 K loads in flash-attn vec split

* use packed f16 K loads in flash-attn vec split on host side

* tune flash-attn vec f16 VEC_NE by head dim

* cleanup

* cleanup

* keep host side clean

* cleanup host side

* change back to original host wait/submit behavior

* formatting

* reverted param-buffer pool refactor

* add helper functions

* ggml-webgpu: move flash-attn vec pipeline caching back into shader lib

* ggml-webgpu: remove duplicate functions

* ggml-webgpu: reserve flash-attn vec scratch in dst buffer allocation

* ggml-webgpu: revert unrelated change

* ggml-webgpu: revert deleted comment

* disable uniformity check

* remove unnecessary change

* Update ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl

* Update ggml/src/ggml-webgpu/ggml-webgpu.cpp

---------

Co-authored-by: Reese Levine <reeselevine1@gmail.com>
Authored by Zheyuan Chen on 2026-04-02 10:40:42 -07:00; committed by GitHub.
parent 5803c8d115
commit a1cfb64530
5 changed files with 1412 additions and 53 deletions
ggml/src/ggml-webgpu/ggml-webgpu.cpp +309 -14
@@ -658,7 +658,6 @@ static webgpu_command ggml_backend_webgpu_build_multi(
for (size_t i = 0; i < params_bufs_list.size(); i++) {
ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
}
#ifdef GGML_WEBGPU_GPU_PROFILE
webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
@@ -1481,7 +1480,6 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
}
#ifndef __EMSCRIPTEN__
static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
ggml_tensor * Q,
ggml_tensor * K,
@@ -1565,30 +1563,248 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
- ggml_webgpu_shader_lib_context shader_lib_ctx = {
-     .src0 = Q,
-     .src1 = K,
-     .src2 = V,
-     .src3 = mask,
-     .src4 = sinks,
-     .dst = dst,
-     .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
-     .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
const uint32_t k_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
const uint32_t v_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u);
const bool kv_direct = (K->type == GGML_TYPE_F16) && f16_vec4_aligned &&
(Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) &&
(K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
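// vec4 f16 loads fetch four elements at a time, so the K and V element
// offsets must land on a 4-element boundary after binding alignment.
// kv_direct additionally requires f16 KV, a QK head dim divisible by
// sg_mat_k, and a KV sequence length padded to GGML_WEBGPU_KV_SEQ_PAD.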
const bool kv_vec_type_supported =
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
(K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type);
const uint32_t vec_nwg_cap =
std::max(1u, std::min<uint32_t>(32u, ctx->global_ctx->capabilities.max_subgroup_size));
const bool use_blk = use_vec && has_mask;
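// Vec path eligibility: short query batches (Q->ne[1] < 20), head dims
// divisible by 32 (QK) and 4 (V), K and V of the same supported type
// (f16, q4_0, q8_0), and vec4-aligned bindings for f16. The key below
// captures every specialization axis, so each combination compiles and
// caches its own shader variant.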
ggml_webgpu_flash_attn_pipeline_key key = {
.kv_type = K->type,
.head_dim_qk = (uint32_t) Q->ne[0],
.head_dim_v = (uint32_t) V->ne[0],
.kv_direct = kv_direct,
.has_mask = static_cast<bool>(has_mask),
.has_sinks = static_cast<bool>(has_sinks),
.uses_logit_softcap = logit_softcap != 0.0f,
.use_vec = use_vec,
};
ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
.key = key,
.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
wgpu::Buffer blk_buf = {};
uint64_t blk_size_bytes = 0;
uint32_t blk_nblk0 = 0;
uint32_t blk_nblk1 = 0;
uint32_t blk_batch_count = 0;
if (use_vec) {
uint32_t nwg = 1u;
const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile);
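// Grow the number of KV splits per row (nwg) in powers of two while each
// split would still cover more than two KV tiles, capped by the
// subgroup-size-derived limit. For example, K->ne[1] = 4096 with
// kv_tile = 32 doubles while 2*nwg*32 < 4096, stopping at nwg = 32 when
// vec_nwg_cap = 32.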
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
nwg <<= 1;
}
nwg = std::min(nwg, vec_nwg_cap);
GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size);
const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
const bool use_vec_reduce = nwg > 1u;
GGML_ASSERT(nrows <= UINT32_MAX);
uint64_t tmp_stats_base = 0;
uint64_t tmp_size_bytes = 0;
wgpu::Buffer tmp_buf = {};
uint64_t tmp_bind_offset = 0;
uint64_t tmp_bind_size = 0;
const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes);
if (use_vec_reduce) {
const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg;
const uint64_t tmp_stats_elems = nrows * 2u * nwg;
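// Scratch layout, in floats: [0, nrows*DV*nwg) holds each split's partial
// output row, followed by nrows*2*nwg stats entries -- one (local row max,
// exp-sum) pair per (row, split) -- consumed by the reduce pass.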
tmp_stats_base = tmp_data_elems;
tmp_size_bytes =
ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
GGML_ASSERT(tmp_stats_base <= UINT32_MAX);
tmp_buf = ggml_webgpu_tensor_buf(dst);
tmp_bind_offset = scratch_offset;
tmp_bind_size = tmp_size_bytes;
scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes);
} else {
// nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation.
tmp_buf = ggml_webgpu_tensor_buf(dst);
tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst);
tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst);
}
webgpu_pipeline blk_pipeline;
std::vector<uint32_t> blk_params;
std::vector<wgpu::BindGroupEntry> blk_entries;
if (use_blk) {
GGML_ASSERT(has_mask);
blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile);
blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile);
blk_buf = ggml_webgpu_tensor_buf(dst);
const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
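// One uint32 flag per (KV tile, Q tile, batch) block, filled from the mask
// by a small pre-pass so the split kernel can cheaply skip blocks the mask
// excludes entirely.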
ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = {
.key =
{
.q_tile = decisions->q_tile,
.kv_tile = decisions->kv_tile,
},
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx);
blk_params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask
(uint32_t) Q->ne[1], // seq_len_q
(uint32_t) K->ne[1], // seq_len_kv
stride_mask3, // stride_mask3
blk_nblk0, // nblk0
blk_nblk1, // nblk1
};
blk_entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(mask),
.offset = ggml_webgpu_tensor_align_offset(ctx, mask),
.size = ggml_webgpu_tensor_binding_size(ctx, mask) },
{ .binding = 1, .buffer = blk_buf, .offset = scratch_offset, .size = blk_size_bytes },
};
scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes);
}
std::vector<uint32_t> split_params = params;
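// Params appended below must line up, in order, with the trailing fields
// of the vec-split shader's parameter struct (blk_* fields only when the
// blk pre-pass is active).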
if (use_blk) {
split_params.push_back(0u); // blk_base
split_params.push_back(blk_nblk0); // blk_nblk0
split_params.push_back(blk_nblk1); // blk_nblk1
}
split_params.push_back(0u); // tmp_data_base
split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base
split_params.push_back(nwg); // nwg
std::vector<wgpu::BindGroupEntry> split_entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(Q),
.offset = ggml_webgpu_tensor_align_offset(ctx, Q),
.size = ggml_webgpu_tensor_binding_size(ctx, Q) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(K),
.offset = ggml_webgpu_tensor_align_offset(ctx, K),
.size = ggml_webgpu_tensor_binding_size(ctx, K) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(V),
.offset = ggml_webgpu_tensor_align_offset(ctx, V),
.size = ggml_webgpu_tensor_binding_size(ctx, V) },
};
uint32_t split_binding_index = 3;
if (has_mask) {
split_entries.push_back({ .binding = split_binding_index++,
.buffer = ggml_webgpu_tensor_buf(mask),
.offset = ggml_webgpu_tensor_align_offset(ctx, mask),
.size = ggml_webgpu_tensor_binding_size(ctx, mask) });
}
if (has_sinks) {
split_entries.push_back({ .binding = split_binding_index++,
.buffer = ggml_webgpu_tensor_buf(sinks),
.offset = ggml_webgpu_tensor_align_offset(ctx, sinks),
.size = ggml_webgpu_tensor_binding_size(ctx, sinks) });
}
if (use_blk) {
split_entries.push_back(
{ .binding = split_binding_index++, .buffer = blk_buf, .offset = blk_entries[1].offset, .size = blk_size_bytes });
}
split_entries.push_back(
{ .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size });
split_entries.push_back({ .binding = split_binding_index++,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
webgpu_pipeline reduce_pipeline;
std::vector<uint32_t> reduce_params;
std::vector<wgpu::BindGroupEntry> reduce_entries;
if (use_vec_reduce) {
const uint32_t reduce_wg_size = std::max(
32u,
std::min<uint32_t>(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
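// One 32-lane slice per split to merge, clamped to the device's maximum
// workgroup size and never below a single slice.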
ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = {
.key =
{
.head_dim_v = (uint32_t) V->ne[0],
.wg_size = reduce_wg_size,
},
.max_wg_size = reduce_wg_size,
};
reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx);
reduce_params = {
(uint32_t) nrows, // nrows
(uint32_t) Q->ne[1], // seq_len_q
(uint32_t) Q->ne[2], // n_heads
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst
nwg, // nwg
0u, // tmp_data_base
(uint32_t) tmp_stats_base, // tmp_stats_base
};
reduce_entries = {
{ .binding = 0, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_size_bytes },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
};
}
const uint64_t split_wg_total = (uint64_t) wg_x * nwg;
GGML_ASSERT(split_wg_total <= UINT32_MAX);
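// The vec path records up to three passes in order: the optional mask
// block pre-pass, the split kernel (wg_x rows of work x nwg KV splits),
// and the optional reduce that folds the nwg partials into dst.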
std::vector<webgpu_pipeline> pipelines;
std::vector<std::vector<uint32_t>> params_list;
std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
std::vector<std::pair<uint32_t, uint32_t>> workgroups_list;
if (use_blk) {
pipelines.push_back(blk_pipeline);
params_list.push_back(std::move(blk_params));
entries_list.push_back(std::move(blk_entries));
workgroups_list.push_back({ blk_nblk0, blk_nblk1 * blk_batch_count });
}
pipelines.push_back(pipeline);
params_list.push_back(std::move(split_params));
entries_list.push_back(std::move(split_entries));
workgroups_list.push_back({ (uint32_t) split_wg_total, 1u });
if (use_vec_reduce) {
pipelines.push_back(reduce_pipeline);
params_list.push_back(std::move(reduce_params));
entries_list.push_back(std::move(reduce_entries));
workgroups_list.push_back({ (uint32_t) nrows, 1u });
}
return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list,
entries_list, workgroups_list);
}
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
#endif
static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool is_unary = dst->op == GGML_OP_UNARY;
@@ -2559,7 +2775,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
std::vector<webgpu_submission> subs;
uint32_t num_batched_kernels = 0;
bool contains_set_rows = false;
for (int i = 0; i < cgraph->n_nodes; i++) {
if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
contains_set_rows = true;
@@ -2834,6 +3049,86 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
}
}
break;
case GGML_OP_FLASH_ATTN_EXT:
{
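// Mirror the dispatch-time kv_tile/nwg decisions so dst's buffer is
// over-allocated with enough aligned scratch for the vec path's partials
// and the mask block table (see ggml_webgpu_flash_attn above).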
const ggml_tensor * Q = tensor->src[0];
const ggml_tensor * K = tensor->src[1];
const ggml_tensor * V = tensor->src[2];
const ggml_tensor * mask = tensor->src[3];
const ggml_tensor * sinks = tensor->src[4];
if (Q && K && V) {
GGML_UNUSED(sinks);
const bool kv_direct = (K->type == GGML_TYPE_F16) &&
(Q->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k == 0) &&
(K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
const bool kv_vec_type_supported =
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
const bool use_vec =
(Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
(V->type == K->type);
if (use_vec) {
const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
const uint32_t sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
const size_t limit_bytes =
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
const size_t q_tile = sg_mat_m;
const size_t base_q_bytes =
(Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
size_t bytes_per_kv = 0;
if (!kv_direct) {
bytes_per_kv += std::max(Q->ne[0], V->ne[0]);
}
if (mask != nullptr) {
bytes_per_kv += q_tile;
}
bytes_per_kv += q_tile;
bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
uint32_t kv_tile =
((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n;
kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile));
kv_tile = (kv_tile / sg_mat_n) * sg_mat_n;
if (kv_direct) {
GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
kv_tile -= sg_mat_n;
}
}
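// kv_tile: the largest multiple of sg_mat_n that fits the workgroup
// storage budget, clamped to [sg_mat_n, 32]; the direct-load path also
// shrinks it until it divides GGML_WEBGPU_KV_SEQ_PAD.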
const uint32_t vec_nwg_cap = std::max(
1u, std::min<uint32_t>(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size));
uint32_t nwg = 1u;
const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile);
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
nwg <<= 1;
}
nwg = std::min(nwg, vec_nwg_cap);
const size_t align = ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
if (nwg > 1u) {
const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg;
const uint64_t tmp_stats_elems = nrows * 2u * nwg;
const size_t tmp_size_bytes = ROUNDUP_POW2(
(tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
res += tmp_size_bytes + align;
}
if (mask != nullptr) {
const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u);
const uint32_t stride_mask3 =
(uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
const size_t blk_size_bytes =
ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
res += blk_size_bytes + align;
}
res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT);
}
}
}
break;
default:
break;
}
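
For reference, the reduce pass combines the per-split partials with the
standard streaming-softmax merge used in split-K flash attention. Below is a
minimal CPU sketch of that merge for a single output row, assuming each of the
nwg partials stores a local row max, an exp-sum, and an unnormalized DV-wide
accumulator (matching the two-floats-per-partial stats region reserved above);
merge_partials is an illustrative name, not a function in this patch.

#include <algorithm>
#include <cmath>
#include <vector>

// Merge nwg split-K partials for one output row of width DV.
// m[i], l[i]: local row max and exp-sum of split i.
// o[i*DV .. i*DV+DV): split i's unnormalized accumulator.
static void merge_partials(int nwg, int DV,
                           const std::vector<float> & m,
                           const std::vector<float> & l,
                           const std::vector<float> & o,
                           std::vector<float> & dst) {
    // Global row max across all splits.
    float m_all = -INFINITY;
    for (int i = 0; i < nwg; i++) {
        m_all = std::max(m_all, m[i]);
    }
    // Rescale each split's exp-sum to the global max and accumulate.
    float l_all = 0.0f;
    for (int i = 0; i < nwg; i++) {
        l_all += l[i] * std::exp(m[i] - m_all);
    }
    // Rescale and sum the partial outputs, then apply the final
    // softmax normalization.
    dst.assign(DV, 0.0f);
    for (int i = 0; i < nwg; i++) {
        const float scale = std::exp(m[i] - m_all);
        for (int d = 0; d < DV; d++) {
            dst[d] += o[(size_t) i * DV + d] * scale;
        }
    }
    const float inv_l = l_all > 0.0f ? 1.0f / l_all : 0.0f;
    for (int d = 0; d < DV; d++) {
        dst[d] *= inv_l;
    }
}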