fix(shader): handle the buffer aliasing for rms fuse (#22266)
This commit is contained in:
@@ -197,11 +197,12 @@ struct ggml_webgpu_row_norm_pipeline_key_hash {
|
||||
/** RMS_NORM + MUL **/
|
||||
|
||||
struct ggml_webgpu_rms_norm_mul_pipeline_key {
|
||||
bool inplace;
|
||||
bool src_overlap;
|
||||
bool inplace; // rn_src == dst
|
||||
bool overlap; // mul_src == dst
|
||||
bool src_overlap; // rn_src == mul_src
|
||||
|
||||
bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const {
|
||||
return inplace == other.inplace && src_overlap == other.src_overlap;
|
||||
return inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -209,6 +210,7 @@ struct ggml_webgpu_rms_norm_mul_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.inplace);
|
||||
ggml_webgpu_hash_combine(seed, key.overlap);
|
||||
ggml_webgpu_hash_combine(seed, key.src_overlap);
|
||||
return seed;
|
||||
}
|
||||
@@ -556,7 +558,7 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
|
||||
const size_t q_tile = context.sg_mat_m;
|
||||
const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
|
||||
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||
size_t bytes_per_kv = 0;
|
||||
size_t bytes_per_kv = 0;
|
||||
if (!key.kv_direct) {
|
||||
bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v);
|
||||
}
|
||||
@@ -1878,6 +1880,7 @@ class ggml_webgpu_shader_lib {
|
||||
webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_rms_norm_mul_pipeline_key key = {};
|
||||
key.inplace = context.inplace;
|
||||
key.overlap = context.overlap;
|
||||
key.src_overlap = context.src_overlap;
|
||||
|
||||
auto it = rms_norm_mul_pipelines.find(key);
|
||||
@@ -1892,6 +1895,9 @@ class ggml_webgpu_shader_lib {
|
||||
if (key.inplace) {
|
||||
defines.push_back("INPLACE");
|
||||
variant += "_inplace";
|
||||
} else if (key.overlap) {
|
||||
defines.push_back("OVERLAP");
|
||||
variant += "_overlap";
|
||||
} else if (key.src_overlap) {
|
||||
defines.push_back("SRC_OVERLAP");
|
||||
variant += "_src_overlap";
|
||||
|
||||
@@ -2071,8 +2071,9 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
|
||||
GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1");
|
||||
}
|
||||
|
||||
bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
|
||||
bool overlap = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
|
||||
(ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
|
||||
bool inplace = ggml_webgpu_tensor_equal(rn_src, dst);
|
||||
bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src);
|
||||
|
||||
uint32_t offset_merged_rn_src = 0;
|
||||
@@ -2116,7 +2117,7 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries;
|
||||
|
||||
if (inplace) {
|
||||
if (inplace || overlap) {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
|
||||
} else if (src_overlap) {
|
||||
@@ -2136,6 +2137,7 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.inplace = inplace;
|
||||
shader_lib_ctx.overlap = overlap;
|
||||
shader_lib_ctx.src_overlap = src_overlap;
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#ifdef INPLACE
|
||||
#ifdef OVERLAP
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> rn_src: array<f32>;
|
||||
@@ -13,6 +13,21 @@ fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32)
|
||||
mul_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset];
|
||||
}
|
||||
|
||||
#elif INPLACE
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> rn_src: array<f32>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> mul_src: array<f32>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
|
||||
fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) {
|
||||
rn_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset];
|
||||
}
|
||||
|
||||
#elif SRC_OVERLAP
|
||||
|
||||
@group(0) @binding(0)
|
||||
|
||||
Reference in New Issue
Block a user