ggml webgpu: initial flashattention implementation (#18610)

* FlashAttention (#13)

* Add inplace softmax

* Move rms_norm to split row approach

* Update debug for supports_op

* clean up debug statements

* neg f16xf32xip builds and runs; haven't actually run a model that uses the neg kernel yet though

* neg passes backend test

* unary operators pass ggml tests

* rms_norm double declaration bug atoned

* abides by editor-config

* removed vestigial files

* fixed autoconfig

* All operators (including xielu) working

* removed unnecessary check for whether node->src[1] exists for unary operators

* responded to and addressed PR comments

* implemented REPL_Template support and removed a bug in the unary operators kernel

* formatted embed wgsl and ggml-webgpu.cpp

* Faster tensors (#8)

Add fast matrix and matrix/vector multiplication.

* Use map for shader replacements instead of pair of strings

* Wasm (#9)

* webgpu : fix build on emscripten

* more debugging stuff

* test-backend-ops: force single thread on wasm

* fix single-thread case for init_tensor_uniform

* use jspi

* add pthread

* test: remember to set n_thread for cpu backend

* Add buffer label and enable dawn-specific toggles to turn off some checks

* Intermediate state

* Fast working f16/f32 vec4

* Working float fast mul mat

* Clean up naming of mul_mat to match logical model, start work on q mul_mat

* Setup for subgroup matrix mat mul

* Basic working subgroup matrix

* Working subgroup matrix tiling

* Handle weirder sg matrix sizes (but still a multiple of the sg matrix size)

* Working start to gemv

* working f16 accumulation with shared memory staging

* Print out available subgroup matrix configurations

* Vectorize dst stores for sg matrix shader

* Gemv working scalar

* Minor set_rows optimization (#4)

* updated optimization, fixed errors

* non-vectorized version now dispatches one thread per element

* Simplify

* Change logic for set_rows pipelines

---------

Co-authored-by: Neha Abbas <nehaabbas@macbookpro.lan>
Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>

* Comment on dawn toggles

* Working subgroup matrix code for (semi)generic sizes

* Remove some comments

* Cleanup code

* Update dawn version and move to portable subgroup size

* Try to fix new dawn release

* Update subgroup size comment

* Only check for subgroup matrix configs if they are supported

* Add toggles for subgroup matrix/f16 support on nvidia+vulkan

* Make row/col naming consistent

* Refactor shared memory loading

* Move sg matrix stores to correct file

* Working q4_0

* Formatting

* Work with emscripten builds

* Fix test-backend-ops emscripten for f16/quantized types

* Use emscripten memory64 to support get_memory

* Add build flags and try ci

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>

* Remove extra whitespace

* Move wasm single-thread logic out of test-backend-ops for cpu backend

* Disable multiple threads for emscripten single-thread builds in ggml_graph_plan

* Refactored pipelines and workgroup calculations (#10)

* refactored pipelines

* refactored workgroup calculation

* removed commented out block of prior maps

* Clean up ceiling division pattern

---------

Co-authored-by: Neha Abbas <nehaabbas@eduroam-169-233-141-223.ucsc.edu>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>

* Start work on flash attention

* Shader structure set up (many bugs still)

* debugging

* Working first test

* Working with head grouping, head sizes up to 128, logit softcap, mask/sinks enabled, f32

* Generalize softmax to work with multiple subgroups, f16 accumulation, mask shared memory tiling

* Start work on integrating pre-wgsl

* Separate structs/initial shader compilation library into separate files

* Work on compilation choices for flashattention

* Work on subgroup matrix/tile size portability

* subgroup size agnostic online softmax

* Cleanups, quantization types

* more cleanup

* fix wasm build

* Refactor flashattention to increase parallelism, use direct loads for KV in some cases

* Checkpoint

* formatting

* Update to account for default kv cache padding

* formatting shader

* Add workflow for ggml-ci webgpu

* Try passing absolute path to dawn in ggml-ci

* Avoid error on device destruction, add todos for proper cleanup

* Fix unused warning

* Forgot one parameter unused

* Move some flashattn computation to f32 for correctness
Author: Reese Levine
Date: 2026-01-08 08:23:39 -08:00
Committed by: GitHub
Parent: 2524c26164
Commit: 15bff84bf5
6 changed files with 1838 additions and 47 deletions
+250 -38
@@ -7,7 +7,9 @@
#include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "ggml-webgpu-shader-lib.hpp"
#include "ggml-wgsl-shaders.hpp"
#include "pre_wgsl.hpp"
#ifdef __EMSCRIPTEN__
# include <emscripten/emscripten.h>
@@ -30,7 +32,7 @@
#ifdef GGML_WEBGPU_DEBUG
# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
# define WEBGPU_DEBUG_BUF_ELEMS 32
# define WEBGPU_DEBUG_BUF_ELEMS 512
#else
# define WEBGPU_LOG_DEBUG(msg) ((void) 0)
#endif // GGML_WEBGPU_DEBUG
@@ -251,6 +253,7 @@ struct webgpu_gpu_profile_buf_pool {
struct webgpu_pipeline {
wgpu::ComputePipeline pipeline;
std::string name;
void * context = nullptr;
};
struct webgpu_command {
@@ -263,6 +266,46 @@ struct webgpu_command {
#endif
};
struct flash_attn_pipeline_key {
int q_type;
int kv_type;
int dst_type;
uint32_t head_dim_qk;
uint32_t head_dim_v;
bool kv_direct;
bool has_mask;
bool has_sinks;
bool uses_logit_softcap;
bool operator==(const flash_attn_pipeline_key & other) const {
return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type &&
head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct &&
has_mask == other.has_mask && has_sinks == other.has_sinks &&
uses_logit_softcap == other.uses_logit_softcap;
}
};
// Same hash combine function as in boost
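// (0x9e3779b9 is the 32-bit golden-ratio constant; the shifts mix high and low bits
// so that combining equal fields still perturbs the seed.)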
template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
struct flash_attn_pipeline_key_hash {
size_t operator()(const flash_attn_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.q_type);
ggml_webgpu_hash_combine(seed, key.kv_type);
ggml_webgpu_hash_combine(seed, key.dst_type);
ggml_webgpu_hash_combine(seed, key.head_dim_qk);
ggml_webgpu_hash_combine(seed, key.head_dim_v);
ggml_webgpu_hash_combine(seed, key.kv_direct);
ggml_webgpu_hash_combine(seed, key.has_mask);
ggml_webgpu_hash_combine(seed, key.has_sinks);
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
return seed;
}
};
// All the base objects needed to run operations on a WebGPU device
struct webgpu_context_struct {
wgpu::Instance instance;
@@ -271,12 +314,12 @@ struct webgpu_context_struct {
wgpu::Queue queue;
wgpu::Limits limits;
uint32_t subgroup_size;
uint32_t max_subgroup_size;
#ifndef __EMSCRIPTEN__
bool supports_subgroup_matrix = false;
wgpu::SubgroupMatrixConfig subgroup_matrix_config;
#endif
bool supports_subgroup_matrix = false;
uint32_t sg_mat_m;
uint32_t sg_mat_n;
uint32_t sg_mat_k;
std::recursive_mutex mutex;
std::atomic_uint inflight_threads = 0;
@@ -284,20 +327,24 @@ struct webgpu_context_struct {
webgpu_buf_pool param_buf_pool;
webgpu_buf_pool set_rows_error_buf_pool;
pre_wgsl::Preprocessor p;
std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines; // src0_type, src1_type, vectorized
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines; // dst_type, vectorized
std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
std::unordered_map<flash_attn_pipeline_key, webgpu_pipeline, flash_attn_pipeline_key_hash> flash_attn_pipelines;
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines; // dst_type, vectorized
std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
@@ -361,8 +408,6 @@ struct ggml_backend_webgpu_buffer_context {
label(std::move(lbl)) {}
};
/* End struct definitions */
/* WebGPU object initializations */
// Process a WGSL shader string, replacing tokens of the form {{KEY}} with
@@ -484,14 +529,9 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
wgpu::CommandBuffer commands = encoder.Finish();
ctx->queue.Submit(1, &commands);
ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange();
std::cout << "debug data:";
for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) {
std::cout << " " << i << ": " << debug_data[i];
}
std::cout << "\n";
const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
std::cout << "debug[0]: " << debug_data[0] << "\n";
ctx->debug_host_buf.Unmap();
}
#endif
@@ -673,6 +713,7 @@ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
return ctx->name.c_str();
}
// TODO: implement proper cleanup
static void ggml_backend_webgpu_free(ggml_backend_t backend) {
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
@@ -730,12 +771,12 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
return ctx->buffer;
}
static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) {
static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
size_t offset = ggml_webgpu_tensor_offset(t);
return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
}
static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor * t) {
static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
size_t offset = ggml_webgpu_tensor_offset(t);
return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
}
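// Worked example: with minStorageBufferOffsetAlignment = 256 and offset = 700,
// misalignment = 700 & 255 = 188 and the aligned offset = 700 & ~255 = 512; the
// binding uses the aligned offset and the misalignment is passed to the shader
// in elements (see the first params entries below).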
@@ -964,12 +1005,10 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
#ifndef __EMSCRIPTEN__
if (ctx->supports_subgroup_matrix) {
// The total number of subgroups/workgroups needed per matrix.
uint32_t wg_m_sg_tile =
WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M;
wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
uint32_t wg_n_sg_tile =
WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N;
wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m;
wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n;
wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
} else {
#endif
uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
@@ -986,6 +1025,146 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
ggml_tensor * Q,
ggml_tensor * K,
ggml_tensor * V,
ggml_tensor * mask,
ggml_tensor * sinks,
ggml_tensor * dst) {
float scale = *(float *) dst->op_params;
float max_bias;
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
float logit_softcap;
memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
if (logit_softcap != 0.0f) {
scale /= logit_softcap;
}
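// ALiBi slope bases, as in ggml's CPU backend: head h uses slope m0^(h+1) when
// h < n_head_log2, and m1^(2*(h - n_head_log2) + 1) otherwise.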
float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2])));
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
const int has_mask = (mask != nullptr);
const int has_sinks = (sinks != nullptr);
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),
has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) Q->ne[2], // number of heads
(uint32_t) Q->ne[1], // sequence length (Q)
(uint32_t) K->ne[1], // sequence length (K/V)
(uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1
(uint32_t) (Q->nb[2] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2
(uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3
(uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1
(uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2
(uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3
(uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1
(uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2
(uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3
has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3
(uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA)
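// The float scalars below are bit-cast into the u32 param vector; the shader
// side reinterprets them back as f32.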
*(uint32_t *) &scale, // scale (possibly adjusted for logit softcap)
*(uint32_t *) &max_bias,
*(uint32_t *) &logit_softcap,
*(uint32_t *) &n_head_log2,
*(uint32_t *) &m0,
*(uint32_t *) &m1
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(Q),
.offset = ggml_webgpu_tensor_align_offset(ctx, Q),
.size = ggml_webgpu_tensor_binding_size(ctx, Q) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(K),
.offset = ggml_webgpu_tensor_align_offset(ctx, K),
.size = ggml_webgpu_tensor_binding_size(ctx, K) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(V),
.offset = ggml_webgpu_tensor_align_offset(ctx, V),
.size = ggml_webgpu_tensor_binding_size(ctx, V) }
};
uint32_t binding_index = 3;
if (has_mask) {
entries.push_back({ .binding = binding_index++,
.buffer = ggml_webgpu_tensor_buf(mask),
.offset = ggml_webgpu_tensor_align_offset(ctx, mask),
.size = ggml_webgpu_tensor_binding_size(ctx, mask) });
}
if (has_sinks) {
entries.push_back({ .binding = binding_index++,
.buffer = ggml_webgpu_tensor_buf(sinks),
.offset = ggml_webgpu_tensor_align_offset(ctx, sinks),
.size = ggml_webgpu_tensor_binding_size(ctx, sinks) });
}
entries.push_back({ .binding = binding_index++,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
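// kv_direct: when K/V are f16, the QK head dim is a multiple of the subgroup-matrix
// K dimension, and seq_kv is a multiple of the KV padding, K/V tiles can be loaded
// into subgroup matrices directly instead of being staged through shared memory.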
bool kv_direct =
(K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
flash_attn_pipeline_key key = {
.q_type = Q->type,
.kv_type = K->type,
.dst_type = dst->type,
.head_dim_qk = (uint32_t) Q->ne[0],
.head_dim_v = (uint32_t) V->ne[0],
.kv_direct = kv_direct,
.has_mask = static_cast<bool>(has_mask),
.has_sinks = static_cast<bool>(has_sinks),
.uses_logit_softcap = logit_softcap != 0.0f,
};
webgpu_pipeline pipeline;
ggml_webgpu_flash_attn_shader_decisions decisions = {};
auto it = ctx->flash_attn_pipelines.find(key);
if (it != ctx->flash_attn_pipelines.end()) {
pipeline = it->second;
decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
} else {
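// Double-checked locking: re-probe the cache after taking the mutex, since
// another thread may have compiled this pipeline variant concurrently.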
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
it = ctx->flash_attn_pipelines.find(key);
if (it != ctx->flash_attn_pipelines.end()) {
pipeline = it->second;
decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
} else {
ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .kv_type = K->type,
.head_dim_qk = (uint32_t) Q->ne[0],
.head_dim_v = (uint32_t) V->ne[0],
.kv_direct = kv_direct,
.has_mask = static_cast<bool>(has_mask),
.has_sinks = static_cast<bool>(has_sinks),
.uses_logit_softcap = logit_softcap != 0.0f,
.sg_mat_m = ctx->sg_mat_m,
.sg_mat_n = ctx->sg_mat_n,
.sg_mat_k = ctx->sg_mat_k,
.wg_mem_limit_bytes =
ctx->limits.maxComputeWorkgroupStorageSize,
.max_subgroup_size = ctx->max_subgroup_size };
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
pipeline.context = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions);
ctx->flash_attn_pipelines.emplace(key, pipeline);
decisions = processed.decisions;
}
}
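// One workgroup covers a q_tile-row slice of Q for one (head, batch) pair,
// e.g. seq_q = 70 with q_tile = 32 gives 3 workgroups per head per batch.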
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile);
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
ggml_unary_op unary_op = ggml_get_unary_op(dst);
@@ -1397,6 +1576,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
return ggml_webgpu_get_rows(ctx, src0, src1, node);
case GGML_OP_MUL_MAT:
return ggml_webgpu_mul_mat(ctx, src0, src1, node);
case GGML_OP_FLASH_ATTN_EXT:
return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
case GGML_OP_ADD:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
@@ -1466,6 +1647,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands);
futures.push_back(new_futures);
}
ggml_backend_webgpu_wait(ctx, futures);
ctx->inflight_threads--;
WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx);
@@ -1808,15 +1990,15 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
#ifndef __EMSCRIPTEN__
if (webgpu_ctx->supports_subgroup_matrix) {
std::map<std::string, std::string> sg_matrix_repls;
sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size);
sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size);
sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K);
sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.M);
sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N);
sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K);
sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->sg_mat_m);
sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->sg_mat_n);
sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->sg_mat_k);
proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
proc_mul_mat_f32_f32_vec =
@@ -2328,6 +2510,7 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants);
}
// TODO: move most initialization logic here
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
GGML_UNUSED(params);
@@ -2489,6 +2672,29 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
}
break;
}
case GGML_OP_FLASH_ATTN_EXT:
{
if (!webgpu_ctx->supports_subgroup_matrix) {
break;
}
// Head dimensions must fit in workgroup memory with minimum tile sizes
size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize;
const bool has_mask = op->src[3] != nullptr;
const bool kv_direct = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 &&
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0],
has_mask, kv_direct);
if (min_bytes > limit_bytes) {
break;
}
supports_op = src0->type == GGML_TYPE_F32 &&
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
src2->type == src1->type && op->type == GGML_TYPE_F32;
break;
}
case GGML_OP_RMS_NORM:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;
@@ -2606,6 +2812,7 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
}
// TODO: Does this need to be thread safe? Is it only called once?
// TODO: move most logic to device_init function so backend can be freed/initialized properly
// Only one device is supported for now
static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
GGML_ASSERT(index == 0);
@@ -2665,7 +2872,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
ctx->subgroup_matrix_config = config;
ctx->sg_mat_m = config.M;
ctx->sg_mat_n = config.N;
ctx->sg_mat_k = config.K;
valid_subgroup_matrix_config = true;
break;
}
@@ -2676,7 +2885,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
#endif
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
ctx->subgroup_size = info.subgroupMaxSize;
ctx->max_subgroup_size = info.subgroupMaxSize;
// Initialize device
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
@@ -2701,8 +2910,11 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
wgpu::CallbackMode::AllowSpontaneous,
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
GGML_UNUSED(device);
GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
std::string(message).c_str());
GGML_UNUSED(reason);
GGML_UNUSED(message);
//TODO: uncomment once proper free logic is in place
//GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
//std::string(message).c_str());
});
dev_desc.SetUncapturedErrorCallback(
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {