fix(shader): handle the buffer aliasing for rms fuse (#22266)
This commit is contained in:
@@ -2071,8 +2071,9 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
|
||||
GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1");
|
||||
}
|
||||
|
||||
bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
|
||||
bool overlap = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
|
||||
(ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
|
||||
bool inplace = ggml_webgpu_tensor_equal(rn_src, dst);
|
||||
bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src);
|
||||
|
||||
uint32_t offset_merged_rn_src = 0;
|
||||
@@ -2116,7 +2117,7 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries;
|
||||
|
||||
if (inplace) {
|
||||
if (inplace || overlap) {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
|
||||
} else if (src_overlap) {
|
||||
@@ -2136,6 +2137,7 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.inplace = inplace;
|
||||
shader_lib_ctx.overlap = overlap;
|
||||
shader_lib_ctx.src_overlap = src_overlap;
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);
|
||||
|
||||
Reference in New Issue
Block a user