ggml-webgpu: Add fused RMS_NORM + MUL (#21983)

* fused rms_norm + mul

* Add GGML_WEBGPU_DISABLE_FUSION to allow disabling kernel fusion.

* Decouple num_fused_ops from webgpu_context; misc cleanup

* Fix eps handling and remove disable_fusion.

* Avoid C++20 initializers.
Author: Masashi Yoshimura
Date: 2026-04-23 02:51:40 +09:00
Committed by: GitHub
Parent: 8bccdbbff9
Commit: 6da7168312
3 changed files with 349 additions and 18 deletions (this file: +145 -12)
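For reference, the fused path computes an RMS normalization over each row followed by an elementwise multiply in a single dispatch. A minimal CPU sketch of the per-row semantics (illustrative only; this is not the WGSL shader, and the helper name is hypothetical):

#include <cmath>
#include <cstddef>

// out[i] = (x[i] / sqrt(mean(x^2) + eps)) * w[i], applied independently per row.
static void rms_norm_mul_row_ref(const float * x, const float * w, float * out, size_t n, float eps) {
    float sum_sq = 0.0f;
    for (size_t i = 0; i < n; i++) {
        sum_sq += x[i] * x[i];
    }
    const float scale = 1.0f / std::sqrt(sum_sq / (float) n + eps);
    for (size_t i = 0; i < n; i++) {
        out[i] = (x[i] * scale) * w[i];
    }
}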
@@ -1972,6 +1972,94 @@ static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor *
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context & ctx,
ggml_tensor * rn_src,
ggml_tensor * rn_dst,
ggml_tensor * mul_src0,
ggml_tensor * mul_src1,
ggml_tensor * dst) {
ggml_tensor * mul_src;
if (ggml_webgpu_tensor_equal(rn_dst, mul_src0)) {
mul_src = mul_src1;
} else if (ggml_webgpu_tensor_equal(rn_dst, mul_src1)) {
mul_src = mul_src0;
} else {
GGML_ABORT("rms_norm output must be equal to one of mul_src0 and mul_src1");
}
bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
(ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src);
uint32_t offset_merged_rn_src = 0;
uint32_t offset_merged_mul_src = 0;
size_t rn_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, rn_src);
size_t mul_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, mul_src);
if (src_overlap) {
size_t min_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset);
offset_merged_rn_src =
(uint32_t) ((rn_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(rn_src->type));
offset_merged_mul_src =
(uint32_t) ((mul_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(mul_src->type));
}
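// The merged offsets are element counts (aligned-offset delta divided by the
// element size) from the start of the shared binding; the shader adds them to
// locate each logical tensor within the single merged buffer range.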
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)),
offset_merged_rn_src,
offset_merged_mul_src,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (rn_src->nb[1] / ggml_type_size(rn_src->type)),
(uint32_t) (rn_src->nb[2] / ggml_type_size(rn_src->type)),
(uint32_t) (rn_src->nb[3] / ggml_type_size(rn_src->type)),
(uint32_t) (mul_src->nb[1] / ggml_type_size(mul_src->type)),
(uint32_t) (mul_src->nb[2] / ggml_type_size(mul_src->type)),
(uint32_t) (mul_src->nb[3] / ggml_type_size(mul_src->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
(uint32_t) mul_src->ne[0],
(uint32_t) mul_src->ne[1],
(uint32_t) mul_src->ne[2],
(uint32_t) mul_src->ne[3],
(uint32_t) dst->ne[0],
(uint32_t) dst->ne[1],
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(rn_dst, 0)) // epsilon, treated as f32 in the shader
};
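// Note: eps is an f32 op param on rn_dst but rides in the u32 params vector;
// ggml_webgpu_u32_from_f32 is presumably a bit-preserving reinterpretation
// (a plain value cast to uint32_t would collapse small epsilons to 0).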
std::vector<wgpu::BindGroupEntry> entries;
if (inplace) {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
} else if (src_overlap) {
size_t merged_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset);
size_t merged_end =
std::max(rn_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, rn_src),
mul_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, mul_src));
entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset,
merged_end - merged_offset));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
} else {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
}
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.inplace = inplace;
shader_lib_ctx.src_overlap = src_overlap;
webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst));
}
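When the rms_norm and mul sources overlap in the same buffer, the code above binds one range covering both and rebases each tensor to an element offset inside it. A standalone sketch with hypothetical byte offsets (f32 elements, 4 bytes each):

#include <algorithm>
#include <cstddef>

int main() {
    // Hypothetical aligned byte offsets and binding sizes within one buffer.
    size_t rn_off  = 256;  size_t rn_size  = 4096; // rms_norm source
    size_t mul_off = 1024; size_t mul_size = 4096; // mul source

    size_t merged_offset = std::min(rn_off, mul_off);                      // 256
    size_t merged_end    = std::max(rn_off + rn_size, mul_off + mul_size); // 5120
    size_t binding_size  = merged_end - merged_offset;                     // 4864

    // Element offsets passed to the shader to index each tensor in the range.
    size_t rn_elem  = (rn_off  - merged_offset) / 4; // 0
    size_t mul_elem = (mul_off - merged_offset) / 4; // 192
    (void) binding_size; (void) rn_elem; (void) mul_elem;
    return 0;
}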
static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool inplace = ggml_webgpu_tensor_equal(src, dst);
@@ -2468,15 +2556,48 @@ static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
static bool ggml_webgpu_can_fuse_rms_norm_mul(const struct ggml_cgraph * cgraph, int node_idx) {
if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
return false;
}
// additional constraints specific to this fusion
const ggml_tensor * rms_norm = cgraph->nodes[node_idx];
const ggml_tensor * mul = cgraph->nodes[node_idx + 1];
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
// rms_norm only supports f32
if (mul->src[0]->type != GGML_TYPE_F32 || mul->src[1]->type != GGML_TYPE_F32 || mul->type != GGML_TYPE_F32) {
return false;
}
// if rms_norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
return false;
}
// rms_norm shader assumes contiguous rows
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
return false;
}
return true;
}
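One way to read the operand-order gate above: broadcasting is only supported when the rms_norm output is the mul's A operand; as the B operand, the shapes must match exactly. A self-contained sketch with hypothetical shapes:

#include <array>
#include <cstdint>

using shape4 = std::array<int64_t, 4>;

// Mirrors the gate: reject when rms_norm is mul->src[1] and the shapes differ.
static bool operand_order_ok(bool rms_is_src1, const shape4 & mul_src0, const shape4 & rms_out) {
    return !rms_is_src1 || mul_src0 == rms_out;
}

// mul(rms_out, w): operand_order_ok(false, {4096, 32, 1, 1}, {4096, 32, 1, 1}) == true
//   (w may still broadcast against the rms_norm rows)
// mul(w, rms_out): operand_order_ok(true, {4096, 1, 1, 1}, {4096, 32, 1, 1}) == false
//   (rms_norm is the B operand and the shapes differ, so fall back to unfused)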
// Returns the encoded command, or std::nullopt if the operation is a no-op
static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
ggml_cgraph * cgraph,
int node_idx,
int & num_encoded_ops) {
ggml_tensor ** nodes = cgraph->nodes;
ggml_tensor * node = nodes[node_idx];
if (ggml_is_empty(node)) {
return std::nullopt;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
return std::nullopt;
}
WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
WEBGPU_LOG_DEBUG("ggml_webgpu_encode(" << node << ", " << ggml_op_name(node->op) << ")");
ggml_tensor * src0 = node->src[0];
ggml_tensor * src1 = node->src[1];
@@ -2519,6 +2640,13 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context c
case GGML_OP_REPEAT:
return ggml_webgpu_repeat(ctx, src0, node);
case GGML_OP_RMS_NORM:
if (ggml_webgpu_can_fuse_rms_norm_mul(cgraph, node_idx)) {
num_encoded_ops = 2;
ggml_tensor * mul_node = nodes[node_idx + 1];
return ggml_webgpu_rms_norm_mul(ctx, src0, node, mul_node->src[0], mul_node->src[1], mul_node);
} else {
return ggml_webgpu_row_norm(ctx, src0, node);
}
case GGML_OP_L2_NORM:
return ggml_webgpu_row_norm(ctx, src0, node);
case GGML_OP_ROPE:
@@ -2629,6 +2757,8 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
uint32_t num_inflight_batches = 0;
bool contains_set_rows = false;
bool batch_compute_passes = true;
int num_encoded_ops = 1;
int node_idx = 0;
#ifdef GGML_WEBGPU_GPU_PROFILE
ctx->profile_timestamp_query_count = 0;
@@ -2641,11 +2771,11 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass();
}
while (node_idx < cgraph->n_nodes) {
if (cgraph->nodes[node_idx]->op == GGML_OP_SET_ROWS) {
contains_set_rows = true;
}
if (auto cmd = ggml_webgpu_encode(ctx, cgraph, node_idx, num_encoded_ops)) {
commands.push_back(*cmd);
num_batched_kernels += cmd.value().num_kernels;
#ifdef GGML_WEBGPU_GPU_PROFILE
@@ -2670,6 +2800,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
ctx->param_arena.reset();
commands.clear();
}
node_idx += num_encoded_ops;
num_encoded_ops = 1;
}
if (ctx->active_compute_pass) {
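The reworked walk above advances by however many graph nodes the encoder consumed, so a fused RMS_NORM+MUL pair is skipped as one unit. A runnable sketch of the pattern; encode_one is a hypothetical stand-in for ggml_webgpu_encode:

#include <vector>

// Pretend equal neighbors fuse: consume 2 nodes instead of 1.
static void encode_one(const std::vector<int> & nodes, int idx, int & consumed) {
    if (idx + 1 < (int) nodes.size() && nodes[idx] == nodes[idx + 1]) {
        consumed = 2;
    }
}

int main() {
    std::vector<int> nodes = { 1, 2, 2, 3 };
    int node_idx = 0;
    while (node_idx < (int) nodes.size()) {
        int consumed = 1; // reset before each encode, as the real loop does
        encode_one(nodes, node_idx, consumed);
        node_idx += consumed; // a fused pair advances the walk by 2
    }
    return 0;
}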
@@ -3237,7 +3370,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context;
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
webgpu_ctx->param_arena.init(
webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES,
webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN,
@@ -3487,12 +3620,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
break;
}
// Head dimensions must fit in workgroup memory with minimum tile sizes
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
const bool has_mask = op->src[3] != nullptr;
const bool kv_direct = src1->type == GGML_TYPE_F16 &&
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
(uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
if (min_bytes > limit_bytes) {