ggml-webgpu: Add support for MUL_MAT_ID (#21147)

* Add mul_mat_id support to WebGPU

* Apply suggestion from @reeselevine

---------

Co-authored-by: Reese Levine <reeselevine1@gmail.com>
@@ -1376,6 +1376,163 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
wgpu::CommandEncoder & encoder,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * src2,
ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.src2 = src2,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
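// MUL_MAT_ID is encoded as two passes: mul_mat_id_gather.wgsl groups the
// (token, expert) pairs by expert, then mul_mat_id.wgsl multiplies each
// expert's weights against its gathered rows.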
// Get or create pipelines
webgpu_pipeline gather_pipeline = ctx->shader_lib->get_mul_mat_id_gather_pipeline(shader_lib_ctx);
webgpu_pipeline main_pipeline = ctx->shader_lib->get_mul_mat_id_pipeline(shader_lib_ctx);
std::vector<webgpu_pipeline> pipelines;
std::vector<std::vector<uint32_t>> params_list;
std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
std::vector<std::pair<uint32_t, uint32_t>> workgroups_list;
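// Expert/token geometry: src0->ne[2] is the number of experts, dst->ne[1]
// the experts used per token, and dst->ne[2] the token count.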
const uint32_t param_n_expert = (uint32_t) src0->ne[2];
const uint32_t param_n_expert_used = (uint32_t) dst->ne[1];
const uint32_t param_n_tokens = (uint32_t) dst->ne[2];
// params for mul_mat_id_gather.wgsl
std::vector<uint32_t> gather_params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)),
param_n_expert,
param_n_expert_used,
param_n_tokens,
(uint32_t) (src2->nb[1] / ggml_type_size(src2->type)),
};
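// The gather outputs live in scratch space reserved past the dst tensor
// data in the same buffer (see ggml_backend_webgpu_buffer_type_get_alloc_size):
// expert-used ids, gathered token ids, then per-expert counts, each
// aligned to minStorageBufferOffsetAlignment.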
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
const size_t gathered_buf_nbytes = src0->ne[2] * src1->ne[2] * sizeof(uint32_t);
const size_t gathered_expert_used_align_offset = ROUNDUP_POW2(
dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
const size_t gathered_tokens_align_offset =
ROUNDUP_POW2(gathered_expert_used_align_offset + gathered_buf_nbytes,
ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
const size_t gathered_count_ids_align_offset =
ROUNDUP_POW2(gathered_tokens_align_offset + gathered_buf_nbytes,
ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
const size_t gathered_binding_size = ROUNDUP_POW2(gathered_buf_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT);
const size_t gathered_count_ids_binding_size =
ROUNDUP_POW2(src0->ne[2] * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
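// Binding sizes are rounded up to the storage buffer binding size multiple.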
// bind group entries for mul_mat_id_gather.wgsl
std::vector<wgpu::BindGroupEntry> gather_entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src2),
.offset = ggml_webgpu_tensor_align_offset(ctx, src2),
.size = ggml_webgpu_tensor_binding_size(ctx, src2) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = gathered_expert_used_align_offset,
.size = gathered_binding_size },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = gathered_tokens_align_offset,
.size = gathered_binding_size },
{ .binding = 3,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = gathered_count_ids_align_offset,
.size = gathered_count_ids_binding_size },
};
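// One gather workgroup per expert, split across x/y when the expert count
// exceeds the per-dimension dispatch limit.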
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
const uint32_t gather_total_wg = param_n_expert;
const uint32_t gather_wg_x = std::min(gather_total_wg, max_wg_per_dim);
const uint32_t gather_wg_y = CEIL_DIV(gather_total_wg, gather_wg_x);
pipelines.push_back(gather_pipeline);
params_list.push_back(std::move(gather_params));
entries_list.push_back(std::move(gather_entries));
workgroups_list.push_back({ gather_wg_x, gather_wg_y });
// params for mul_mat_id.wgsl
std::vector<uint32_t> main_params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) src0->ne[0],
(uint32_t) src0->ne[1],
param_n_expert,
param_n_expert_used,
param_n_tokens,
(uint32_t) src1->ne[1],
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
};
// bind group entries for mul_mat_id.wgsl
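// Bindings 3-5 alias the gather outputs in the dst buffer's scratch region.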
std::vector<wgpu::BindGroupEntry> main_entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
{ .binding = 3,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = gathered_expert_used_align_offset,
.size = gathered_binding_size },
{ .binding = 4,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = gathered_tokens_align_offset,
.size = gathered_binding_size },
{ .binding = 5,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = gathered_count_ids_align_offset,
.size = gathered_count_ids_binding_size },
};
// Calculate workgroup dimensions
uint32_t wg_x = 1;
uint32_t wg_y = 1;
auto * main_decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(main_pipeline.context.get());
uint32_t tile_m_s = main_decisions->tile_m * main_decisions->wg_size_m;
uint32_t tile_n_s = main_decisions->tile_n * main_decisions->wg_size_n;
uint32_t wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
// The per-expert gathered counts are only known on the GPU, so size the
// dispatch for the worst case: the contiguous tile count plus up to one
// partially filled tile per active expert.
uint32_t total_gathered = dst->ne[1] * dst->ne[2];
uint32_t max_active_experts = std::min((uint32_t) src0->ne[2], total_gathered);
uint32_t max_wg_n = CEIL_DIV(total_gathered, tile_n_s) + max_active_experts;
uint32_t total_wg = wg_m * max_wg_n;
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
pipelines.push_back(main_pipeline);
params_list.push_back(std::move(main_params));
entries_list.push_back(std::move(main_entries));
workgroups_list.push_back({ wg_x, wg_y });
return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list,
entries_list, workgroups_list);
}
#ifndef __EMSCRIPTEN__
static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
wgpu::CommandEncoder & encoder,
@@ -2638,6 +2795,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context
return ggml_webgpu_get_rows(ctx, encoder, src0, src1, node);
case GGML_OP_MUL_MAT:
return ggml_webgpu_mul_mat(ctx, encoder, src0, src1, node);
case GGML_OP_MUL_MAT_ID:
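// src2 is the ids tensor: the experts selected for each token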
return ggml_webgpu_mul_mat_id(ctx, encoder, src0, src1, src2, node);
case GGML_OP_FLASH_ATTN_EXT:
#ifndef __EMSCRIPTEN__
return ggml_webgpu_flash_attn(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node);
@@ -3082,6 +3241,20 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
}
}
break;
case GGML_OP_MUL_MAT_ID:
{
const ggml_tensor * src0 = tensor->src[0];
const ggml_tensor * src1 = tensor->src[1];
if (src0 && src1) {
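// Reserve scratch space after the tensor data for the gather pass
// outputs: two id buffers plus per-expert counts, with padding so
// each binding offset can be aligned.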
const size_t gathered_size = sizeof(uint32_t) * tensor->src[0]->ne[2] * tensor->src[1]->ne[2];
const size_t gathered_count_ids_size = sizeof(uint32_t) * tensor->src[0]->ne[2];
res = ROUNDUP_POW2(
res + gathered_size * 2 + gathered_count_ids_size +
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment * 3,
WEBGPU_STORAGE_BUF_BINDING_MULT);
}
}
break;
default:
break;
}
@@ -3503,6 +3676,35 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
}
break;
}
case GGML_OP_MUL_MAT_ID:
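// F16 activations require F16 weights; F32 activations pair with
// F32/F16 or the quantized weight types listed below.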
switch (src1->type) {
case GGML_TYPE_F16:
supports_op |= (src0->type == GGML_TYPE_F16);
break;
case GGML_TYPE_F32:
switch (src0->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
supports_op = true;
break;
default:
break;
}
break;
default:
break;
}
break;
case GGML_OP_FLASH_ATTN_EXT:
{
#ifndef __EMSCRIPTEN__