ggml-webgpu: Add the support of MUL_MAT_ID (#21147)
* Add mul_mat_id support to WebGPU
* Apply suggestion from @reeselevine

Co-authored-by: Reese Levine <reeselevine1@gmail.com>
This commit is contained in:
committed by
GitHub
parent
2e1f0a889e
commit
d0a6dfeb28
@@ -1376,6 +1376,163 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
// Encodes GGML_OP_MUL_MAT_ID (expert-routed matrix multiply) as two compute
// dispatches: a "gather" pass that materializes the expert routing read from
// src2 (the expert-id tensor) into scratch buffers, followed by the main
// mul_mat_id pass that consumes those scratch buffers alongside src0/src1.
// The scratch regions live in the SAME wgpu buffer as dst, appended after the
// dst payload — see the get_alloc_size handler for MUL_MAT_ID, which reserves
// this extra space at allocation time.
//
// NOTE(review): names suggest src0 = expert weight stack (ne[2] = n_expert),
// src1 = activations, src2 = per-token expert ids — consistent with ggml's
// MUL_MAT_ID convention, but the shader sources are not visible here.
static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
                                                wgpu::CommandEncoder & encoder,
                                                ggml_tensor * src0,
                                                ggml_tensor * src1,
                                                ggml_tensor * src2,
                                                ggml_tensor * dst) {
    // Shader-library key: pipeline selection/specialization depends on the
    // operand types/shapes and the device's max invocations per workgroup.
    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0 = src0,
        .src1 = src1,
        .src2 = src2,
        .dst = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
    };

    // Get or create pipeline
    webgpu_pipeline gather_pipeline, main_pipeline;

    // Parallel lists, one entry per dispatch, consumed together by
    // ggml_backend_webgpu_build_multi at the end.
    std::vector<webgpu_pipeline> pipelines;
    std::vector<std::vector<uint32_t>> params_list;
    std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
    std::vector<std::pair<uint32_t, uint32_t>> workgroups_list;

    gather_pipeline = ctx->shader_lib->get_mul_mat_id_gather_pipeline(shader_lib_ctx);
    main_pipeline = ctx->shader_lib->get_mul_mat_id_pipeline(shader_lib_ctx);

    // Routing dimensions shared by both shaders.
    const uint32_t param_n_expert = (uint32_t) src0->ne[2];
    const uint32_t param_n_expert_used = (uint32_t) dst->ne[1];
    const uint32_t param_n_tokens = (uint32_t) dst->ne[2];

    // params for mul_mat_id_gather.wgsl
    // Misalignment is expressed in elements (bytes / type size) because the
    // shader indexes typed arrays, while buffer binding offsets below are
    // rounded down to the device's storage-buffer alignment.
    std::vector<uint32_t> gather_params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)),
        param_n_expert,
        param_n_expert_used,
        param_n_tokens,
        (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)),
    };

    // Lay out three scratch regions after the dst payload inside dst's buffer.
    // Each region start is rounded up to minStorageBufferOffsetAlignment so it
    // can be bound at a dynamic offset:
    //   [dst payload][gathered_expert_used][gathered_tokens][gathered_count_ids]
    const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
    // Worst case: every (expert, batch) pair gets a slot.
    const size_t gathered_buf_nbytes = src0->ne[2] * src1->ne[2] * sizeof(uint32_t);

    const size_t gathered_expert_used_align_offset = ROUNDUP_POW2(
        dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
    const size_t gathered_tokens_align_offset =
        ROUNDUP_POW2(gathered_expert_used_align_offset + gathered_buf_nbytes,
                     ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
    const size_t gathered_count_ids_align_offset =
        ROUNDUP_POW2(gathered_tokens_align_offset + gathered_buf_nbytes,
                     ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);

    // Binding sizes are rounded up to the backend's storage-binding multiple
    // (WEBGPU_STORAGE_BUF_BINDING_MULT); the count-ids region holds one
    // uint32_t per expert.
    const size_t gathered_binding_size = ROUNDUP_POW2(gathered_buf_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT);
    const size_t gathered_count_ids_binding_size =
        ROUNDUP_POW2(src0->ne[2] * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);

    // bind group entries for mul_mat_id_gather.wgsl
    // binding 0: expert ids (src2, read); bindings 1-3: the scratch regions
    // carved out of dst's buffer above (written by the gather pass).
    std::vector<wgpu::BindGroupEntry> gather_entries = {
        { .binding = 0,
          .buffer = ggml_webgpu_tensor_buf(src2),
          .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
          .size = ggml_webgpu_tensor_binding_size(ctx, src2) },
        { .binding = 1,
          .buffer = ggml_webgpu_tensor_buf(dst),
          .offset = gathered_expert_used_align_offset,
          .size = gathered_binding_size },
        { .binding = 2,
          .buffer = ggml_webgpu_tensor_buf(dst),
          .offset = gathered_tokens_align_offset,
          .size = gathered_binding_size },
        { .binding = 3,
          .buffer = ggml_webgpu_tensor_buf(dst),
          .offset = gathered_count_ids_align_offset,
          .size = gathered_count_ids_binding_size },
    };

    const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;

    // Gather pass: one workgroup per expert, folded into a 2D grid when the
    // expert count exceeds the per-dimension dispatch limit.
    const uint32_t gather_total_wg = param_n_expert;
    const uint32_t gather_wg_x = std::min(gather_total_wg, max_wg_per_dim);
    const uint32_t gather_wg_y = CEIL_DIV(gather_total_wg, gather_wg_x);

    pipelines.push_back(gather_pipeline);
    params_list.push_back(std::move(gather_params));
    entries_list.push_back(std::move(gather_entries));
    workgroups_list.push_back({ gather_wg_x, gather_wg_y });

    // params for mul_mat_id.wgsl
    // Offsets/strides are all in elements of the respective tensor's type.
    std::vector<uint32_t> main_params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) src0->ne[0],
        (uint32_t) src0->ne[1],
        param_n_expert,
        param_n_expert_used,
        param_n_tokens,
        (uint32_t) src1->ne[1],
        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
    };

    // bind group entries for mul_mat_id.wgsl
    // bindings 0-2: src0/src1/dst proper; bindings 3-5: the scratch regions
    // produced by the gather pass (same offsets/sizes as bindings 1-3 above —
    // the two dispatches must agree on this layout).
    std::vector<wgpu::BindGroupEntry> main_entries = {
        { .binding = 0,
          .buffer = ggml_webgpu_tensor_buf(src0),
          .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
          .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
        { .binding = 1,
          .buffer = ggml_webgpu_tensor_buf(src1),
          .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
          .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
        { .binding = 2,
          .buffer = ggml_webgpu_tensor_buf(dst),
          .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
          .size = ggml_webgpu_tensor_binding_size(ctx, dst) },
        { .binding = 3,
          .buffer = ggml_webgpu_tensor_buf(dst),
          .offset = gathered_expert_used_align_offset,
          .size = gathered_binding_size },
        { .binding = 4,
          .buffer = ggml_webgpu_tensor_buf(dst),
          .offset = gathered_tokens_align_offset,
          .size = gathered_binding_size },
        { .binding = 5,
          .buffer = ggml_webgpu_tensor_buf(dst),
          .offset = gathered_count_ids_align_offset,
          .size = gathered_count_ids_binding_size },
    };

    // Calculate workgroup dimensions
    uint32_t wg_x = 1;
    uint32_t wg_y = 1;

    // Tiling parameters chosen when the main pipeline was specialized.
    auto * main_decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(main_pipeline.context.get());

    uint32_t wg_m;

    // Effective tile extent per workgroup = tile-per-invocation * workgroup span.
    uint32_t tile_m_s = main_decisions->tile_m * main_decisions->wg_size_m;
    uint32_t tile_n_s = main_decisions->tile_n * main_decisions->wg_size_n;
    wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
    uint32_t total_gathered = dst->ne[1] * dst->ne[2];
    uint32_t max_active_experts = std::min((uint32_t) src0->ne[2], total_gathered);
    // Upper bound on N-tiles: routing is only known on the GPU, so dispatch
    // enough workgroups for the worst-case partitioning — per-expert tiling
    // can add up to one partial tile per active expert.
    uint32_t max_wg_n = CEIL_DIV(total_gathered, tile_n_s) + max_active_experts;
    uint32_t total_wg = wg_m * max_wg_n;

    // Fold the (possibly huge) linear count into an x/y grid within device limits.
    compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);

    pipelines.push_back(main_pipeline);
    params_list.push_back(std::move(main_params));
    entries_list.push_back(std::move(main_entries));
    workgroups_list.push_back({ wg_x, wg_y });

    // Record both dispatches (gather first, then main) into the encoder.
    return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list,
                                           entries_list, workgroups_list);
}
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
wgpu::CommandEncoder & encoder,
|
||||
@@ -2638,6 +2795,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context
|
||||
return ggml_webgpu_get_rows(ctx, encoder, src0, src1, node);
|
||||
case GGML_OP_MUL_MAT:
|
||||
return ggml_webgpu_mul_mat(ctx, encoder, src0, src1, node);
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return ggml_webgpu_mul_mat_id(ctx, encoder, src0, src1, src2, node);
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
#ifndef __EMSCRIPTEN__
|
||||
return ggml_webgpu_flash_attn(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node);
|
||||
@@ -3082,6 +3241,20 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
||||
}
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
const ggml_tensor * src0 = tensor->src[0];
|
||||
const ggml_tensor * src1 = tensor->src[1];
|
||||
if (src0 && src1) {
|
||||
const size_t gathered_size = sizeof(uint32_t) * tensor->src[0]->ne[2] * tensor->src[1]->ne[2];
|
||||
const size_t gathered_count_ids_size = sizeof(uint32_t) * tensor->src[0]->ne[2];
|
||||
res = ROUNDUP_POW2(
|
||||
res + gathered_size * 2 + gathered_count_ids_size +
|
||||
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment * 3,
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -3503,6 +3676,35 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
}
|
||||
break;
|
||||
}
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16:
|
||||
supports_op |= (src0->type == GGML_TYPE_F16);
|
||||
break;
|
||||
case GGML_TYPE_F32:
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
supports_op = true;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
#ifndef __EMSCRIPTEN__
|
||||
|
||||
Reference in New Issue
Block a user