fix(shader): handle the buffer aliasing for rms fuse (#22266)

This commit is contained in:
Chen Yuan
2026-04-23 19:32:59 -04:00
committed by GitHub
parent fa0b8a70a8
commit e5f070a1dc
3 changed files with 30 additions and 7 deletions
@@ -197,11 +197,12 @@ struct ggml_webgpu_row_norm_pipeline_key_hash {
/** RMS_NORM + MUL **/
struct ggml_webgpu_rms_norm_mul_pipeline_key {
bool inplace;
bool src_overlap;
bool inplace; // rn_src == dst
bool overlap; // mul_src == dst
bool src_overlap; // rn_src == mul_src
bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const {
return inplace == other.inplace && src_overlap == other.src_overlap;
return inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
}
};
@@ -209,6 +210,7 @@ struct ggml_webgpu_rms_norm_mul_pipeline_key_hash {
size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.inplace);
ggml_webgpu_hash_combine(seed, key.overlap);
ggml_webgpu_hash_combine(seed, key.src_overlap);
return seed;
}
@@ -556,7 +558,7 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
const size_t q_tile = context.sg_mat_m;
const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
size_t bytes_per_kv = 0;
size_t bytes_per_kv = 0;
if (!key.kv_direct) {
bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v);
}
@@ -1878,6 +1880,7 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_rms_norm_mul_pipeline_key key = {};
key.inplace = context.inplace;
key.overlap = context.overlap;
key.src_overlap = context.src_overlap;
auto it = rms_norm_mul_pipelines.find(key);
@@ -1892,6 +1895,9 @@ class ggml_webgpu_shader_lib {
if (key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
} else if (key.overlap) {
defines.push_back("OVERLAP");
variant += "_overlap";
} else if (key.src_overlap) {
defines.push_back("SRC_OVERLAP");
variant += "_src_overlap";
+4 -2
View File
@@ -2071,8 +2071,9 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1");
}
bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
bool overlap = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
(ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
bool inplace = ggml_webgpu_tensor_equal(rn_src, dst);
bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src);
uint32_t offset_merged_rn_src = 0;
@@ -2116,7 +2117,7 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
std::vector<wgpu::BindGroupEntry> entries;
if (inplace) {
if (inplace || overlap) {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
} else if (src_overlap) {
@@ -2136,6 +2137,7 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.inplace = inplace;
shader_lib_ctx.overlap = overlap;
shader_lib_ctx.src_overlap = src_overlap;
webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);
@@ -1,4 +1,4 @@
#ifdef INPLACE
#ifdef OVERLAP
@group(0) @binding(0)
var<storage, read_write> rn_src: array<f32>;
@@ -13,6 +13,21 @@ fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32)
mul_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset];
}
#elif INPLACE
@group(0) @binding(0)
var<storage, read_write> rn_src: array<f32>;
@group(0) @binding(1)
var<storage, read_write> mul_src: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) {
rn_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset];
}
#elif SRC_OVERLAP
@group(0) @binding(0)