diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index efc5b8c97..449eae808 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -197,11 +197,12 @@ struct ggml_webgpu_row_norm_pipeline_key_hash { /** RMS_NORM + MUL **/ struct ggml_webgpu_rms_norm_mul_pipeline_key { - bool inplace; - bool src_overlap; + bool inplace; // rn_src == dst + bool overlap; // mul_src == dst + bool src_overlap; // rn_src == mul_src bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const { - return inplace == other.inplace && src_overlap == other.src_overlap; + return inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap; } }; @@ -209,6 +210,7 @@ struct ggml_webgpu_rms_norm_mul_pipeline_key_hash { size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const { size_t seed = 0; ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.overlap); ggml_webgpu_hash_combine(seed, key.src_overlap); return seed; } @@ -556,7 +558,7 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_ const size_t q_tile = context.sg_mat_m; const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; + size_t bytes_per_kv = 0; if (!key.kv_direct) { bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v); } @@ -1878,6 +1880,7 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_rms_norm_mul_pipeline_key key = {}; key.inplace = context.inplace; + key.overlap = context.overlap; key.src_overlap = context.src_overlap; auto it = rms_norm_mul_pipelines.find(key); @@ -1892,6 +1895,9 @@ class ggml_webgpu_shader_lib { if (key.inplace) { defines.push_back("INPLACE"); variant += "_inplace"; + } else if (key.overlap) { + defines.push_back("OVERLAP"); + variant += "_overlap"; } else if (key.src_overlap) { defines.push_back("SRC_OVERLAP"); variant += "_src_overlap"; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index bcca2bd46..acc486cfd 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2071,8 +2071,9 @@ static std::optional ggml_webgpu_rms_norm_mul(webgpu_context GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1"); } - bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || + bool overlap = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst)); + bool inplace = ggml_webgpu_tensor_equal(rn_src, dst); bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src); uint32_t offset_merged_rn_src = 0; @@ -2116,7 +2117,7 @@ static std::optional ggml_webgpu_rms_norm_mul(webgpu_context std::vector entries; - if (inplace) { + if (inplace || overlap) { entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src)); entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src)); } else if (src_overlap) { @@ -2136,6 +2137,7 @@ static std::optional ggml_webgpu_rms_norm_mul(webgpu_context ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; shader_lib_ctx.inplace = inplace; + shader_lib_ctx.overlap = overlap; shader_lib_ctx.src_overlap = src_overlap; webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl index 71f063b51..74aaa2753 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl @@ -1,4 +1,4 @@ -#ifdef INPLACE +#ifdef OVERLAP @group(0) @binding(0) var rn_src: array; @@ -13,6 +13,21 @@ fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) mul_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; } +#elif INPLACE + +@group(0) @binding(0) +var rn_src: array; + +@group(0) @binding(1) +var mul_src: array; + +@group(0) @binding(2) +var params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + rn_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; +} + #elif SRC_OVERLAP @group(0) @binding(0)