From 6da7168312a01da0f1ea50c8eb909340bb192c4e Mon Sep 17 00:00:00 2001
From: Masashi Yoshimura
Date: Thu, 23 Apr 2026 02:51:40 +0900
Subject: [PATCH] ggml-webgpu: Add fused RMS_NORM + MUL (#21983)

* fused rms_norm + mul

* Add GGML_WEBGPU_DISABLE_FUSION to allow disabling kernel fusion.

* Decouple num_fused_ops from webgpu_context; misc cleanup

* Fix eps handling and remove disable_fusion.

* Avoid C++20 initializers.
---
 .../ggml-webgpu/ggml-webgpu-shader-lib.hpp    |  71 +++++++-
 ggml/src/ggml-webgpu/ggml-webgpu.cpp          | 157 ++++++++++++++++--
 .../wgsl-shaders/rms_norm_mul.wgsl            | 139 ++++++++++++++++
 3 files changed, 349 insertions(+), 18 deletions(-)
 create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl

diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
index f84dfee9d..6593a9fe1 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
@@ -194,6 +194,26 @@ struct ggml_webgpu_row_norm_pipeline_key_hash {
     }
 };
 
+/** RMS_NORM + MUL **/
+
+struct ggml_webgpu_rms_norm_mul_pipeline_key {
+    bool inplace;
+    bool src_overlap;
+
+    bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const {
+        return inplace == other.inplace && src_overlap == other.src_overlap;
+    }
+};
+
+struct ggml_webgpu_rms_norm_mul_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.inplace);
+        ggml_webgpu_hash_combine(seed, key.src_overlap);
+        return seed;
+    }
+};
+
 /** Pad **/
 
 struct ggml_webgpu_pad_pipeline_key {
     bool circular;
@@ -517,7 +537,7 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
     const size_t q_tile = context.sg_mat_m;
     const size_t base_q_bytes =
         (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
-    size_t bytes_per_kv = 0;
+    size_t       bytes_per_kv = 0;
     if (!key.kv_direct) {
         bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v);
     }
@@ -755,16 +775,17 @@ class ggml_webgpu_shader_lib {
     std::unordered_map<ggml_webgpu_cumsum_pipeline_key, webgpu_pipeline, ggml_webgpu_cumsum_pipeline_key_hash> cumsum_pipelines;  // key is fixed, no variants yet
     std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash> row_norm_pipelines;  // op/inplace
+
     std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
-        get_rows_pipelines;  // src_type, vectorized
+        get_rows_pipelines;      // src_type, vectorized
     std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
-        unary_pipelines;  // type/op/inplace
+        unary_pipelines;         // type/op/inplace
     std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
-        scale_pipelines;  // inplace
+        scale_pipelines;         // inplace
     std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
-        solve_tri_pipelines;  // type
+        solve_tri_pipelines;     // type
     std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
-        ssm_conv_pipelines;  // type/vectorized
+        ssm_conv_pipelines;      // type/vectorized
 
     std::unordered_map
 
     std::unordered_map
@@ -813,6 +834,11 @@ class ggml_webgpu_shader_lib {
     std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash> conv2d_pipelines;
 
+    std::unordered_map<ggml_webgpu_rms_norm_mul_pipeline_key, webgpu_pipeline,
+                       ggml_webgpu_rms_norm_mul_pipeline_key_hash>
+        rms_norm_mul_pipelines;
+
   public:
     ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
 
@@ -1828,6 +1854,39 @@ class ggml_webgpu_shader_lib {
         return unary_pipelines[key];
     }
 
+    webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_rms_norm_mul_pipeline_key key = {};
+        key.inplace     = context.inplace;
+        key.src_overlap = context.src_overlap;
+
+        auto it = rms_norm_mul_pipelines.find(key);
+        if (it != rms_norm_mul_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string op_name = "RMS_NORM_MUL";
+        std::string variant = op_name;
+
+        if (key.inplace) {
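+            // INPLACE: the shader reuses the mul operand's buffer (binding 1) as the output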
defines.push_back("INPLACE"); + variant += "_inplace"; + } else if (key.src_overlap) { + defines.push_back("SRC_OVERLAP"); + variant += "_src_overlap"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + rms_norm_mul_pipelines[key] = pipeline; + return rms_norm_mul_pipelines[key]; + } + webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_binary_pipeline_key key = {}; key.type = context.dst->type; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 551586751..5d3169904 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1972,6 +1972,94 @@ static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } +static std::optional ggml_webgpu_rms_norm_mul(webgpu_context & ctx, + ggml_tensor * rn_src, + ggml_tensor * rn_dst, + ggml_tensor * mul_src0, + ggml_tensor * mul_src1, + ggml_tensor * dst) { + ggml_tensor * mul_src; + + if (ggml_webgpu_tensor_equal(rn_dst, mul_src0)) { + mul_src = mul_src1; + } else if (ggml_webgpu_tensor_equal(rn_dst, mul_src1)) { + mul_src = mul_src0; + } else { + GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1"); + } + + bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || + (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst)); + bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src); + + uint32_t offset_merged_rn_src = 0; + uint32_t offset_merged_mul_src = 0; + size_t rn_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, rn_src); + size_t mul_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, mul_src); + + if (src_overlap) { + size_t min_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset); + offset_merged_rn_src = + (uint32_t) ((rn_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(rn_src->type)); + offset_merged_mul_src = + (uint32_t) ((mul_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(mul_src->type)); + } + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)), + offset_merged_rn_src, + offset_merged_mul_src, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (rn_src->nb[1] / ggml_type_size(rn_src->type)), + (uint32_t) (rn_src->nb[2] / ggml_type_size(rn_src->type)), + (uint32_t) (rn_src->nb[3] / ggml_type_size(rn_src->type)), + (uint32_t) (mul_src->nb[1] / ggml_type_size(mul_src->type)), + (uint32_t) (mul_src->nb[2] / ggml_type_size(mul_src->type)), + (uint32_t) (mul_src->nb[3] / ggml_type_size(mul_src->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) mul_src->ne[0], + (uint32_t) mul_src->ne[1], + (uint32_t) mul_src->ne[2], + (uint32_t) mul_src->ne[3], + (uint32_t) dst->ne[0], + (uint32_t) 
+        (uint32_t) dst->ne[0],
+        (uint32_t) dst->ne[1],
+        (uint32_t) dst->ne[2],
+        (uint32_t) dst->ne[3],
+        ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(rn_dst, 0))  // epsilon, treated as f32 in the shader
+    };
+
+    std::vector<wgpu::BindGroupEntry> entries;
+
+    if (inplace) {
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
+    } else if (src_overlap) {
+        size_t merged_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset);
+        size_t merged_end =
+            std::max(rn_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, rn_src),
+                     mul_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, mul_src));
+        entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset,
+                                                            merged_end - merged_offset));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
+    } else {
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
+    }
+
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
+    shader_lib_ctx.inplace     = inplace;
+    shader_lib_ctx.src_overlap = src_overlap;
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);
+
+    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst));
+}
+
 static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     bool inplace = ggml_webgpu_tensor_equal(src, dst);
 
@@ -2468,15 +2556,48 @@ static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor
     return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
 }
 
+static bool ggml_webgpu_can_fuse_rms_norm_mul(const struct ggml_cgraph * cgraph, int node_idx) {
+    if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+        return false;
+    }
+
+    // additional constraints specific to this fusion
+    const ggml_tensor * rms_norm = cgraph->nodes[node_idx];
+    const ggml_tensor * mul      = cgraph->nodes[node_idx + 1];
+
+    GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+    // rms_norm only supports f32
+    if (mul->src[0]->type != GGML_TYPE_F32 || mul->src[1]->type != GGML_TYPE_F32 || mul->type != GGML_TYPE_F32) {
+        return false;
+    }
+    // if rms_norm is the B operand, then we don't handle broadcast
+    if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
+        return false;
+    }
+    // rms_norm shader assumes contiguous rows
+    if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+        return false;
+    }
+
+    return true;
+}
+
 // Returns the encoded command, or std::nullopt if the operation is a no-op
-static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
+static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
+                                                           ggml_cgraph *  cgraph,
+                                                           int            node_idx,
+                                                           int &          num_encoded_ops) {
+    ggml_tensor ** nodes = cgraph->nodes;
+    ggml_tensor *  node  = nodes[node_idx];
+
     if (ggml_is_empty(node)) {
         return std::nullopt;
     }
     if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
         return std::nullopt;
     }
-    WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
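+    // may encode more than one graph node when ops are fused; the caller advances by num_encoded_ops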
", " << ggml_op_name(node->op) << ")"); ggml_tensor * src0 = node->src[0]; ggml_tensor * src1 = node->src[1]; @@ -2519,6 +2640,13 @@ static std::optional ggml_webgpu_encode_node(webgpu_context c case GGML_OP_REPEAT: return ggml_webgpu_repeat(ctx, src0, node); case GGML_OP_RMS_NORM: + if (ggml_webgpu_can_fuse_rms_norm_mul(cgraph, node_idx)) { + num_encoded_ops = 2; + ggml_tensor * mul_node = nodes[node_idx + 1]; + return ggml_webgpu_rms_norm_mul(ctx, src0, node, mul_node->src[0], mul_node->src[1], mul_node); + } else { + return ggml_webgpu_row_norm(ctx, src0, node); + } case GGML_OP_L2_NORM: return ggml_webgpu_row_norm(ctx, src0, node); case GGML_OP_ROPE: @@ -2629,6 +2757,8 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str uint32_t num_inflight_batches = 0; bool contains_set_rows = false; bool batch_compute_passes = true; + int num_encoded_ops = 1; + int node_idx = 0; #ifdef GGML_WEBGPU_GPU_PROFILE ctx->profile_timestamp_query_count = 0; @@ -2641,11 +2771,11 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); } - for (int i = 0; i < cgraph->n_nodes; i++) { - if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { + while (node_idx < cgraph->n_nodes) { + if (cgraph->nodes[node_idx]->op == GGML_OP_SET_ROWS) { contains_set_rows = true; } - if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { + if (auto cmd = ggml_webgpu_encode(ctx, cgraph, node_idx, num_encoded_ops)) { commands.push_back(*cmd); num_batched_kernels += cmd.value().num_kernels; #ifdef GGML_WEBGPU_GPU_PROFILE @@ -2670,6 +2800,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ctx->param_arena.reset(); commands.clear(); } + + node_idx += num_encoded_ops; + num_encoded_ops = 1; } if (ctx->active_compute_pass) { @@ -3237,7 +3370,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context; webgpu_context webgpu_ctx = std::make_shared(); webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; - webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); + webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); webgpu_ctx->param_arena.init( webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES, webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN, @@ -3487,12 +3620,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } // Head dimensions must fit in workgroup memory with minimum tile sizes - size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const bool has_mask = op->src[3] != nullptr; - const bool kv_direct = src1->type == GGML_TYPE_F16 && - (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && - (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + const bool has_mask = op->src[3] != nullptr; + const bool kv_direct = src1->type == GGML_TYPE_F16 && + (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && + (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( ctx->webgpu_global_ctx->capabilities.sg_mat_m, 
                 ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
                 (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
             if (min_bytes > limit_bytes) {
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl
new file mode 100644
index 000000000..71f063b51
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl
@@ -0,0 +1,139 @@
+#ifdef INPLACE
+
+@group(0) @binding(0)
+var<storage, read_write> rn_src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> mul_src: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) {
+    mul_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset];
+}
+
+#elif SRC_OVERLAP
+
+@group(0) @binding(0)
+var<storage, read_write> merged_src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) {
+    dst[dst_offset] = scale * merged_src[rn_src_offset] * merged_src[mul_src_offset];
+}
+
+#else
+
+@group(0) @binding(0)
+var<storage, read_write> rn_src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> mul_src: array<f32>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) {
+    dst[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset];
+}
+
+#endif
+
+struct Params {
+    offset_rn_src: u32,
+    offset_mul_src: u32,
+    offset_merged_rn_src: u32,
+    offset_merged_mul_src: u32,
+    offset_dst: u32,
+
+    stride_rn_src1: u32,
+    stride_rn_src2: u32,
+    stride_rn_src3: u32,
+
+    stride_mul_src1: u32,
+    stride_mul_src2: u32,
+    stride_mul_src3: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    mul_src_ne0: u32,
+    mul_src_ne1: u32,
+    mul_src_ne2: u32,
+    mul_src_ne3: u32,
+
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+    ne3: u32,
+
+    eps: f32
+};
+
+var<workgroup> scratch: array<f32, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
+
+    // one thread per row
+    var i = wid.x;
+    let i3 = i / (params.ne2 * params.ne1);
+    i = i % (params.ne2 * params.ne1);
+    let i2 = i / params.ne1;
+    let i1 = i % params.ne1;
+    let i_rn_src_row = params.offset_rn_src + params.offset_merged_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1;
+    let i_mul_src_row = params.offset_mul_src + params.offset_merged_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1;
+    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+
+    let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
+
+    var sum = 0.0f;
+    var col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+#ifdef SRC_OVERLAP
+        sum += pow(merged_src[i_rn_src_row + col], 2.0);
+#else
+        sum += pow(rn_src[i_rn_src_row + col], 2.0);
+#endif
+        col += WG_SIZE;
+    }
+
+    scratch[lid.x] = sum;
+
+    workgroupBarrier();
+
+    var offset: u32 = WG_SIZE / 2;
+    while (offset > 0) {
+        if (lid.x < offset) {
+            scratch[lid.x] += scratch[lid.x + offset];
+        }
+        offset = offset / 2;
+        workgroupBarrier();
+    }
+    sum = scratch[0];
+
+    let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
+
+    col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        update(i_rn_src_row + col, i_dst_row + col, scale,
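+               // mul_src may broadcast along ne0; wrap the column index (% binds tighter than + in WGSL)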
+               i_mul_src_row + col % params.mul_src_ne0);
+        col += WG_SIZE;
+    }
+}
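
Usage sketch (editor's note, not part of the patch): the fusion triggers whenever an RMS_NORM node is immediately consumed by a MUL node, which is exactly how RMSNorm weights are applied in llama.cpp-style graphs. The following minimal ggml graph should hit the fused path; the shapes, the 1e-6f epsilon, and the scratch size are illustrative assumptions, and actually executing the graph on a WebGPU backend instance is elided here.

#include "ggml.h"

// Build an RMS_NORM -> MUL pair that ggml_webgpu_can_fuse_rms_norm_mul() accepts:
// both ops are F32, rms_norm feeds mul->src[0] (so the B-operand broadcast
// restriction does not apply), and all rows are contiguous.
static struct ggml_cgraph * build_rms_norm_mul_graph(struct ggml_context * ctx) {
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 8);  // activations
    struct ggml_tensor * w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4096);     // norm weight

    struct ggml_tensor * n   = ggml_rms_norm(ctx, x, 1e-6f);  // eps ends up in params.eps
    struct ggml_tensor * out = ggml_mul(ctx, n, w);           // adjacent MUL -> fusable

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    // When ggml_backend_webgpu_graph_compute() walks this graph, ggml_webgpu_encode()
    // sees the RMS_NORM node, the fusion check passes, a single rms_norm_mul dispatch
    // is encoded, and node_idx advances past both nodes (num_encoded_ops == 2).
    return gf;
}

The shader then computes each row's RMS scale once per workgroup and applies the elementwise multiply in the same pass, saving a full read and write of the intermediate normalized tensor.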