ggml-webgpu: fix compiler warnings and refactor FlashAttention encoding (#21052)

* Update workflows to remove dependence on llvmpipe

* Try setting Dawn_DIR

* remove c++20 initializers

* Move to proper guid

* Try avoiding segfaults on vulkan backend process exit

* Remove compiler warnings on parameter casting

* Fix soft_max and update reg_tile accumulation to f32 for better precision

* Refactor flash_attn a bit

* remove c++20 initializers and format

* Increase div precision for NVIDIA

* revert div precision and comment out ggml-ci node for now

* Formatting

* Try debugging on a failing CI node

* Revert "Try debugging on a failing CI node"

This reverts commit 1971e33cba919915e12bcfd5828abfbd54ca942e.
This commit is contained in:
Reese Levine
2026-04-17 09:17:11 -07:00
committed by GitHub
parent b94050e896
commit 45cac7ca70
4 changed files with 947 additions and 1176 deletions
+307 -276
View File
@@ -390,12 +390,11 @@ struct ggml_webgpu_flash_attn_pipeline_key {
bool has_mask;
bool has_sinks;
bool uses_logit_softcap;
bool use_vec;
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec;
uses_logit_softcap == other.uses_logit_softcap;
}
};
@@ -409,47 +408,37 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.has_mask);
ggml_webgpu_hash_combine(seed, key.has_sinks);
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
ggml_webgpu_hash_combine(seed, key.use_vec);
return seed;
}
};
struct ggml_webgpu_flash_attn_shader_lib_context {
ggml_webgpu_flash_attn_pipeline_key key;
uint32_t sg_mat_m;
uint32_t sg_mat_n;
uint32_t sg_mat_k;
size_t wg_mem_limit_bytes;
uint32_t max_subgroup_size;
};
struct ggml_webgpu_flash_attn_shader_decisions {
struct ggml_webgpu_flash_attn_decisions {
uint32_t q_tile = 0;
uint32_t kv_tile = 0;
uint32_t wg_size = 0;
};
inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
// Keep conservative defaults unless this is the f16 vec-split shape family.
if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) {
return 1u;
}
struct ggml_webgpu_flash_attn_vec_decisions {
uint32_t kv_tile = 0;
uint32_t wg_size = 0;
};
// Head-dim specializations used by the tuned vec f16 path.
switch (key.head_dim_qk) {
case 64:
return 2u;
case 96:
return 4u;
case 128:
return 1u;
case 192:
return 2u;
case 576:
return 2u;
default:
return 1u;
}
// Build the pipeline-cache key for a flash-attention dispatch from the shader-lib
// context. src0/src1/src2 appear to be Q/K/V and src3/src4 the optional mask and
// attention sinks -- TODO(review): confirm against the encoding call sites.
inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key(
const ggml_webgpu_shader_lib_context & context) {
const bool has_mask = context.src3 != nullptr;
const bool has_sinks = context.src4 != nullptr;
// "Direct" KV reads (which skip the per-KV workgroup-memory staging accounted for
// in ggml_webgpu_flash_attn_max_kv_tile) are only allowed when KV is f16, the QK
// head dim is a multiple of the subgroup-matrix K dim, and the KV sequence length
// is already padded to GGML_WEBGPU_KV_SEQ_PAD.
const bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
(context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
ggml_webgpu_flash_attn_pipeline_key key = {};
key.kv_type = context.src1->type;
key.head_dim_qk = (uint32_t) context.src0->ne[0];
key.head_dim_v = (uint32_t) context.src2->ne[0];
key.kv_direct = kv_direct;
key.has_mask = has_mask;
key.has_sinks = has_sinks;
// Op param 2 carries the logit-softcap value; any non-zero value selects the
// LOGIT_SOFTCAP shader variant.
key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
return key;
}
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
@@ -471,79 +460,20 @@ inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lh
return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size;
}
struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context {
ggml_webgpu_flash_attn_vec_reduce_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "flash_attn_vec_reduce";
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
variant += std::string("_wg") + std::to_string(context.max_wg_size);
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
return result;
}
struct ggml_webgpu_flash_attn_blk_pipeline_key {
uint32_t q_tile;
uint32_t kv_tile;
bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const {
return q_tile == other.q_tile && kv_tile == other.kv_tile;
}
bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { return kv_tile == other.kv_tile; }
};
struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.q_tile);
ggml_webgpu_hash_combine(seed, key.kv_tile);
return seed;
}
};
struct ggml_webgpu_flash_attn_blk_shader_lib_context {
ggml_webgpu_flash_attn_blk_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "flash_attn_vec_blk";
defines.push_back(std::string("Q_TILE=") + std::to_string(context.key.q_tile));
variant += std::string("_qt") + std::to_string(context.key.q_tile);
defines.push_back(std::string("KV_TILE=") + std::to_string(context.key.kv_tile));
variant += std::string("_kvt") + std::to_string(context.key.kv_tile);
uint32_t wg_size = 1;
while ((wg_size << 1) <= context.max_wg_size) {
wg_size <<= 1;
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
variant += std::string("_wg") + std::to_string(wg_size);
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
return result;
}
// This is exposed because it's necessary in supports_op
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
uint32_t kv_tile,
@@ -568,6 +498,41 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
}
// Largest KV tile, rounded down to a multiple of sg_mat_n, whose workgroup-memory
// footprint fits in context.wg_mem_limit_bytes for the given pipeline key.
// The byte accounting mirrors the shader's shared-memory layout:
//  - per Q row: head_dim_qk + head_dim_v f16 elements plus two f32 values
//    (presumably the running max and sum of the online softmax -- TODO confirm);
//  - per KV element: a staged K/V column (only when !kv_direct), an optional mask
//    value, and one score value per Q row, all in f16.
inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context & context,
const ggml_webgpu_flash_attn_pipeline_key & key) {
const size_t limit_bytes = context.wg_mem_limit_bytes;
// The Q tile is fixed to the subgroup-matrix M dimension here.
const size_t q_tile = context.sg_mat_m;
const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
size_t bytes_per_kv = 0;
if (!key.kv_direct) {
// K and V share one staging buffer sized for the larger head dimension.
bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v);
}
if (key.has_mask) {
bytes_per_kv += q_tile;
}
bytes_per_kv += q_tile;
bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
// NOTE(review): assumes base_q_bytes <= limit_bytes; the subtraction would
// underflow (size_t) otherwise -- verify supports_op guards this.
const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
// Round down to a whole number of subgroup-matrix N columns.
return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
}
// Choose the KV tile used by the vectorized (Q_TILE == 1) flash-attention path.
// Both the vec pipeline and the block-index pipeline derive their tiling from
// this helper, so they stay in agreement.
inline uint32_t ggml_webgpu_flash_attn_vec_get_kv_tile(const ggml_webgpu_shader_lib_context & context) {
    const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
    // Upper bound imposed by the workgroup-memory budget; prefer 32 when it fits,
    // but never go below one subgroup-matrix N column.
    const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
    uint32_t tile = std::min(32u, max_kv_tile);
    if (tile < context.sg_mat_n) {
        tile = context.sg_mat_n;
    }
    // Snap down to a whole number of sg_mat_n columns.
    tile -= tile % context.sg_mat_n;
    if (key.kv_direct) {
        // Direct KV loads require the tile to divide the padded KV length evenly.
        tile = std::min(tile, GGML_WEBGPU_KV_SEQ_PAD);
        while (GGML_WEBGPU_KV_SEQ_PAD % tile != 0) {
            tile -= context.sg_mat_n;
        }
    }
    return tile;
}
/** Matrix Multiplication **/
struct ggml_webgpu_legacy_mul_mat_pipeline_key {
@@ -802,6 +767,8 @@ class ggml_webgpu_shader_lib {
repeat_pipelines; // type
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
flash_attn_pipelines;
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
flash_attn_vec_pipelines;
std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
webgpu_pipeline,
ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
@@ -849,10 +816,9 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_row_norm_pipeline_key key = {
.op = context.dst->op,
.inplace = context.inplace,
};
ggml_webgpu_row_norm_pipeline_key key = {};
key.op = context.dst->op;
key.inplace = context.inplace;
auto it = row_norm_pipelines.find(key);
if (it != row_norm_pipelines.end()) {
@@ -908,9 +874,10 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type,
.vec4 = context.src0->ne[0] % 4 == 0,
.i64_idx = context.src1->type == GGML_TYPE_I64 };
ggml_webgpu_set_rows_pipeline_key key = {};
key.dst_type = context.dst->type;
key.vec4 = context.src0->ne[0] % 4 == 0;
key.i64_idx = context.src1->type == GGML_TYPE_I64;
auto it = set_rows_pipelines.find(key);
if (it != set_rows_pipelines.end()) {
@@ -955,7 +922,9 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace };
ggml_webgpu_set_pipeline_key key = {};
key.type = context.dst->type;
key.inplace = context.inplace;
auto it = set_pipelines.find(key);
if (it != set_pipelines.end()) {
@@ -1062,10 +1031,9 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0;
ggml_webgpu_get_rows_pipeline_key key = {
.src_type = context.src0->type,
.vectorized = (int) vectorized,
};
ggml_webgpu_get_rows_pipeline_key key = {};
key.src_type = context.src0->type;
key.vectorized = (int) vectorized;
auto it = get_rows_pipelines.find(key);
if (it != get_rows_pipelines.end()) {
@@ -1115,8 +1083,7 @@ class ggml_webgpu_shader_lib {
std::string type_upper = type_str;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
switch (key.src_type)
{
switch (key.src_type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q8_0:
@@ -1136,9 +1103,9 @@ class ggml_webgpu_shader_lib {
break;
}
default:
{
defines.push_back(std::string("SRC_TYPE=") + type_str);
}
{
defines.push_back(std::string("SRC_TYPE=") + type_str);
}
}
defines.push_back("BYTE_HELPERS");
@@ -1181,7 +1148,8 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace };
ggml_webgpu_scale_pipeline_key key = {};
key.inplace = context.inplace;
auto it = scale_pipelines.find(key);
if (it != scale_pipelines.end()) {
@@ -1208,11 +1176,10 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_solve_tri_pipeline_key key = {
.type = context.dst->type,
.n = (int) context.src0->ne[0],
.k = (int) context.src1->ne[0],
};
ggml_webgpu_solve_tri_pipeline_key key = {};
key.type = context.dst->type;
key.n = (int) context.src0->ne[0];
key.k = (int) context.src1->ne[0];
auto it = solve_tri_pipelines.find(key);
if (it != solve_tri_pipelines.end()) {
@@ -1250,10 +1217,9 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_ssm_conv_pipeline_key key = {
.type = context.dst->type,
.vectorized = context.src1->ne[0] == 4,
};
ggml_webgpu_ssm_conv_pipeline_key key = {};
key.type = context.dst->type;
key.vectorized = context.src1->ne[0] == 4;
auto it = ssm_conv_pipelines.find(key);
if (it != ssm_conv_pipelines.end()) {
@@ -1293,11 +1259,10 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_gated_delta_net_pipeline_key key = {
.type = context.dst->type,
.s_v = (int) context.src2->ne[0],
.kda = context.src3->ne[0] == context.src2->ne[0],
};
ggml_webgpu_gated_delta_net_pipeline_key key = {};
key.type = context.dst->type;
key.s_v = (int) context.src2->ne[0];
key.kda = context.src3->ne[0] == context.src2->ne[0];
auto it = gated_delta_net_pipelines.find(key);
if (it != gated_delta_net_pipelines.end()) {
@@ -1330,7 +1295,8 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };
ggml_webgpu_pad_pipeline_key key = {};
key.circular = ggml_get_op_params_i32(context.dst, 8) != 0;
auto it = pad_pipelines.find(key);
if (it != pad_pipelines.end()) {
@@ -1357,15 +1323,13 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_mul_mat_vec_pipeline_key key = {
.src0_type = context.src0->type,
.src1_type = context.src1->type,
// Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float
.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0,
};
ggml_webgpu_mul_mat_vec_pipeline_key key = {};
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
auto it = mul_mat_vec_pipelines.find(key);
if (it != mul_mat_vec_pipelines.end()) {
@@ -1451,15 +1415,14 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_mul_mat_pipeline_key key = {
.src0_type = context.src0->type,
.src1_type = context.src1->type,
.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0,
.use_subgroup_matrix = context.supports_subgroup_matrix
};
ggml_webgpu_mul_mat_pipeline_key key = {};
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
key.use_subgroup_matrix = context.supports_subgroup_matrix;
auto it = mul_mat_fast_pipelines.find(key);
if (it != mul_mat_fast_pipelines.end()) {
@@ -1578,8 +1541,9 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type,
.src1_type = context.src1->type };
ggml_webgpu_legacy_mul_mat_pipeline_key key = {};
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
auto it = mul_mat_legacy_pipelines.find(key);
if (it != mul_mat_legacy_pipelines.end()) {
@@ -1621,8 +1585,7 @@ class ggml_webgpu_shader_lib {
std::string type_upper = src0_name;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
switch (context.src0->type)
{
switch (context.src0->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q8_0:
@@ -1642,9 +1605,9 @@ class ggml_webgpu_shader_lib {
break;
}
default:
{
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
}
{
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
}
}
defines.push_back("BYTE_HELPERS");
@@ -1689,10 +1652,9 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_mul_mat_id_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_mul_mat_id_pipeline_key key = {
.src0_type = context.src0->type,
.src1_type = context.src1->type,
};
ggml_webgpu_mul_mat_id_pipeline_key key = {};
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
auto it = mul_mat_id_pipelines.find(key);
if (it != mul_mat_id_pipelines.end()) {
@@ -1782,13 +1744,12 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {
const bool is_unary = context.dst->op == GGML_OP_UNARY;
const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op;
ggml_webgpu_unary_pipeline_key key = {
.type = context.dst->type,
.op = op,
.is_unary = is_unary,
.inplace = context.inplace,
.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0),
};
ggml_webgpu_unary_pipeline_key key = {};
key.type = context.dst->type;
key.op = op;
key.is_unary = is_unary;
key.inplace = context.inplace;
key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0);
auto it = unary_pipelines.find(key);
if (it != unary_pipelines.end()) {
@@ -1853,13 +1814,12 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_binary_pipeline_key key = {
.type = context.dst->type,
.op = context.dst->op,
.inplace = context.inplace,
.overlap = context.overlap,
.src_overlap = context.src_overlap,
};
ggml_webgpu_binary_pipeline_key key = {};
key.type = context.dst->type;
key.op = context.dst->op;
key.inplace = context.inplace;
key.overlap = context.overlap;
key.src_overlap = context.src_overlap;
auto it = binary_pipelines.find(key);
if (it != binary_pipelines.end()) {
@@ -1908,9 +1868,8 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_concat_pipeline_key key = {
.type = context.dst->type,
};
ggml_webgpu_concat_pipeline_key key = {};
key.type = context.dst->type;
auto it = concat_pipelines.find(key);
if (it != concat_pipelines.end()) {
@@ -1945,9 +1904,8 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_repeat_pipeline_key key = {
.type = context.dst->type,
};
ggml_webgpu_repeat_pipeline_key key = {};
key.type = context.dst->type;
auto it = repeat_pipelines.find(key);
if (it != repeat_pipelines.end()) {
@@ -1985,16 +1943,16 @@ class ggml_webgpu_shader_lib {
return repeat_pipelines[key];
}
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) {
auto it = flash_attn_pipelines.find(context.key);
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
auto it = flash_attn_pipelines.find(key);
if (it != flash_attn_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "flash_attn";
switch (context.key.kv_type) {
switch (key.kv_type) {
case GGML_TYPE_F32:
defines.push_back("KV_F32");
break;
@@ -2010,111 +1968,206 @@ class ggml_webgpu_shader_lib {
default:
GGML_ABORT("Unsupported KV type for flash attention shader");
}
variant += std::string("_") + ggml_type_name(context.key.kv_type);
variant += std::string("_") + ggml_type_name(key.kv_type);
if (context.key.has_mask) {
if (key.has_mask) {
defines.push_back("MASK");
variant += "_mask";
}
if (context.key.has_sinks) {
if (key.has_sinks) {
defines.push_back("SINKS");
variant += "_sinks";
}
if (context.key.uses_logit_softcap) {
if (key.uses_logit_softcap) {
defines.push_back("LOGIT_SOFTCAP");
variant += "_lgsc";
}
if (context.key.kv_direct) {
if (key.kv_direct) {
defines.push_back("KV_DIRECT");
variant += "_kvdirect";
}
if (context.key.has_mask && context.key.use_vec) {
defines.push_back("BLK");
variant += "_blk";
}
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
uint32_t q_tile = context.sg_mat_m;
uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
if (context.key.use_vec) {
q_tile = 1;
kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context)));
kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n;
const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key);
defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
}
if (context.key.kv_direct) {
GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
auto decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>();
decisions->q_tile = context.sg_mat_m;
const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
uint32_t kv_tile = std::min(min_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
if (key.kv_direct) {
kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
kv_tile -= context.sg_mat_n;
}
}
defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
decisions->kv_tile = kv_tile;
decisions->wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
uint32_t wg_size = 0;
if (context.key.use_vec) {
wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
} else {
wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
defines.push_back(std::string("Q_TILE=") + std::to_string(decisions->q_tile));
defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile));
defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size));
const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn;
webgpu_pipeline pipeline =
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
decisions->q_tile = q_tile;
decisions->kv_tile = kv_tile;
decisions->wg_size = wg_size;
pipeline.context = decisions;
flash_attn_pipelines[context.key] = pipeline;
return flash_attn_pipelines[context.key];
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn, defines), variant);
pipeline.context = decisions;
flash_attn_pipelines[key] = pipeline;
return flash_attn_pipelines[key];
}
webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
auto it = flash_attn_blk_pipelines.find(context.key);
// Get (or lazily build and cache) the vectorized flash-attention compute pipeline
// for the shapes/flags described by `context`. The WGSL is specialized via
// preprocessor defines; `variant` doubles as the cache-visible pipeline label.
webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
// Fast path: reuse a previously compiled pipeline for this key.
auto it = flash_attn_vec_pipelines.find(key);
if (it != flash_attn_vec_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "flash_attn_vec";
// Select the KV storage-type variant of the shader.
switch (key.kv_type) {
case GGML_TYPE_F32:
defines.push_back("KV_F32");
break;
case GGML_TYPE_F16:
defines.push_back("KV_F16");
break;
case GGML_TYPE_Q4_0:
defines.push_back("KV_Q4_0");
break;
case GGML_TYPE_Q8_0:
defines.push_back("KV_Q8_0");
break;
default:
GGML_ABORT("Unsupported KV type for flash attention shader");
}
variant += std::string("_") + ggml_type_name(key.kv_type);
if (key.has_mask) {
// Masked vec dispatches also enable the BLK (block-index) path.
defines.push_back("MASK");
defines.push_back("BLK");
variant += "_mask_blk";
}
if (key.has_sinks) {
defines.push_back("SINKS");
variant += "_sinks";
}
if (key.uses_logit_softcap) {
defines.push_back("LOGIT_SOFTCAP");
variant += "_lgsc";
}
if (key.kv_direct) {
defines.push_back("KV_DIRECT");
variant += "_kvdirect";
}
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
// The vec path processes a single Q row per workgroup.
defines.push_back("Q_TILE=1");
// Record the tiling decisions so the encoder can size its dispatch to match.
auto decisions = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>();
decisions->kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context);
decisions->wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
uint32_t vec_ne = 1u;
// Keep conservative defaults unless this is the f16 vec-split shape family.
if (key.kv_type == GGML_TYPE_F16 && key.head_dim_qk == key.head_dim_v) {
// Head-dim-specific vector widths; presumably tuned empirically -- TODO confirm.
switch (key.head_dim_qk) {
case 64:
case 192:
case 576:
vec_ne = 2u;
break;
case 96:
vec_ne = 4u;
break;
default:
break;
}
}
defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile));
defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size));
defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
webgpu_pipeline pipeline =
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant);
// Attach the decisions to the pipeline and cache it under the key.
pipeline.context = decisions;
flash_attn_vec_pipelines[key] = pipeline;
return flash_attn_vec_pipelines[key];
}
webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_flash_attn_blk_pipeline_key key = {};
key.kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context);
auto it = flash_attn_blk_pipelines.find(key);
if (it != flash_attn_blk_pipelines.end()) {
return it->second;
}
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context);
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
flash_attn_blk_pipelines[context.key] = pipeline;
return flash_attn_blk_pipelines[context.key];
std::vector<std::string> defines;
std::string variant = "flash_attn_vec_blk";
defines.push_back(std::string("KV_TILE=") + std::to_string(key.kv_tile));
variant += std::string("_kvt") + std::to_string(key.kv_tile);
uint32_t wg_size = 1;
while ((wg_size << 1) <= context.max_wg_size) {
wg_size <<= 1;
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
variant += std::string("_wg") + std::to_string(wg_size);
webgpu_pipeline pipeline =
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_blk, defines), variant);
flash_attn_blk_pipelines[key] = pipeline;
return flash_attn_blk_pipelines[key];
}
webgpu_pipeline get_flash_attn_vec_reduce_pipeline(
const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
auto it = flash_attn_vec_reduce_pipelines.find(context.key);
webgpu_pipeline get_flash_attn_vec_reduce_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_flash_attn_vec_reduce_pipeline_key key = {};
key.head_dim_v = (uint32_t) context.src2->ne[0];
key.wg_size = context.max_wg_size;
auto it = flash_attn_vec_reduce_pipelines.find(key);
if (it != flash_attn_vec_reduce_pipelines.end()) {
return it->second;
}
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(preprocessor, wgsl_flash_attn_vec_reduce, context);
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
flash_attn_vec_reduce_pipelines[context.key] = pipeline;
return flash_attn_vec_reduce_pipelines[context.key];
std::vector<std::string> defines;
std::string variant = "flash_attn_vec_reduce";
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
variant += std::string("_wg") + std::to_string(context.max_wg_size);
webgpu_pipeline pipeline =
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_reduce, defines), variant);
flash_attn_vec_reduce_pipelines[key] = pipeline;
return flash_attn_vec_reduce_pipelines[key];
}
webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_cpy_pipeline_key key = {
.src_type = context.src0->type,
.dst_type = context.dst->type,
};
ggml_webgpu_cpy_pipeline_key key = {};
key.src_type = context.src0->type;
key.dst_type = context.dst->type;
auto it = cpy_pipelines.find(key);
if (it != cpy_pipelines.end()) {
@@ -2166,11 +2219,10 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_glu_pipeline_key key = {
.glu_op = ggml_get_glu_op(context.dst),
.type = context.dst->type,
.split = (context.src1 != nullptr),
};
ggml_webgpu_glu_pipeline_key key = {};
key.glu_op = ggml_get_glu_op(context.dst);
key.type = context.dst->type;
key.split = (context.src1 != nullptr);
auto it = glu_pipelines.find(key);
if (it != glu_pipelines.end()) {
@@ -2239,11 +2291,10 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_rope_pipeline_key key = {
.type = context.dst->type,
.inplace = context.inplace,
.has_ff = (context.src2 != nullptr),
};
ggml_webgpu_rope_pipeline_key key = {};
key.type = context.dst->type;
key.inplace = context.inplace;
key.has_ff = (context.src2 != nullptr);
auto it = rope_pipelines.find(key);
if (it != rope_pipelines.end()) {
@@ -2288,12 +2339,11 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_soft_max_pipeline_key key = {
.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32,
.has_mask = (context.src1 != nullptr),
.has_sink = (context.src2 != nullptr),
.inplace = context.inplace,
};
ggml_webgpu_soft_max_pipeline_key key = {};
key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32;
key.has_mask = (context.src1 != nullptr);
key.has_sink = (context.src2 != nullptr);
key.inplace = context.inplace;
auto it = soft_max_pipelines.find(key);
if (it != soft_max_pipelines.end()) {
@@ -2359,25 +2409,6 @@ class ggml_webgpu_shader_lib {
pipeline_desc.layout = nullptr; // nullptr means auto layout
return { device.CreateComputePipeline(&pipeline_desc), label };
}
static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
const size_t limit_bytes = context.wg_mem_limit_bytes;
const size_t q_tile = context.sg_mat_m;
const size_t base_q_bytes =
(context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
size_t bytes_per_kv = 0;
if (!context.key.kv_direct) {
bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
}
if (context.key.has_mask) {
bytes_per_kv += q_tile;
}
bytes_per_kv += q_tile;
bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
}
};
#endif // GGML_WEBGPU_SHADER_LIB_HPP
File diff suppressed because it is too large Load Diff
@@ -1,7 +1,6 @@
diagnostic(off, subgroup_uniformity);
enable f16;
#define Q_TILE 1
#define KV_TILE 32
#define WG_SIZE 32
@@ -11,7 +10,7 @@ struct Params {
seq_len_kv: u32,
stride_mask3: u32,
// Number of KV blocks and Q blocks per batch.
// nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = ceil(seq_len_q / Q_TILE).
// nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = seq_len_q.
nblk0: u32,
nblk1: u32,
};
@@ -40,7 +39,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
return;
}
let q_start = q_blk * Q_TILE;
let q_start = q_blk;
let k_start = kv_blk * KV_TILE;
let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
@@ -54,11 +53,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
var local_max = -MASK_MAX;
var local_any = 0u;
for (var q_rel = 0u; q_rel < Q_TILE; q_rel += 1u) {
let q_row = q_start + q_rel;
if (q_row >= params.seq_len_q) {
continue;
}
let q_row = q_start;
if (q_row < params.seq_len_q) {
let row_base = mask_batch_base + q_row * params.seq_len_kv;
for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) {
let k_col = k_start + k_rel;