ggml-webgpu: add support for im2col (#22259)

* shader(im2col): implement the im2col shader

* shader(im2col): clean the formatting issues

* shader(im2col): clean the editorconfig checker warning

* fix(shader): address the workgroup issues of im2col and conv2d
This commit is contained in:
Chen Yuan
2026-04-22 23:17:41 -04:00
committed by GitHub
parent 86db42e97f
commit b76429a69c
3 changed files with 268 additions and 19 deletions
+108 -19
View File
@@ -979,25 +979,108 @@ static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx,
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
};
uint32_t max_wg_size =
std::min((uint32_t) WEBGPU_MAX_WG_SIZE, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupSizeX);
uint32_t wg_size =
std::min((uint32_t) ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, max_wg_size);
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = wg_size;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t n_out = ggml_nelements(dst);
uint32_t total_wg = CEIL_DIV(n_out, decisions->wg_size);
uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
uint32_t wg_x = std::min(total_wg, max_wg);
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
const int32_t p0 = ggml_get_op_params_i32(dst, 2);
const int32_t p1 = ggml_get_op_params_i32(dst, 3);
const int32_t d0 = ggml_get_op_params_i32(dst, 4);
const int32_t d1 = ggml_get_op_params_i32(dst, 5);
const bool is_2D = ggml_get_op_params_i32(dst, 6) == 1;
const uint32_t KW = src0->ne[0];
const uint32_t KH = is_2D ? src0->ne[1] : 1;
const uint32_t IC = is_2D ? src0->ne[2] : src0->ne[1];
const uint32_t IW = src1->ne[0];
const uint32_t IH = is_2D ? src1->ne[1] : 1;
const uint32_t N = is_2D ? src1->ne[3] : src1->ne[2];
const uint32_t OW = dst->ne[1];
const uint32_t OH = is_2D ? dst->ne[2] : 1;
const uint32_t si0 = (uint32_t) (src1->nb[0] / ggml_type_size(src1->type));
const uint32_t si1 = is_2D ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0;
const uint32_t si2 = is_2D ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) :
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type));
const uint32_t si3 = is_2D ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) :
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type));
const uint32_t so0 = (uint32_t) (dst->nb[0] / ggml_type_size(dst->type));
const uint32_t so1 = (uint32_t) (dst->nb[1] / ggml_type_size(dst->type));
const uint32_t so2 = is_2D ? (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)) : 0;
const uint32_t so3 = is_2D ? (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)) :
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type));
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
si0,
si1,
si2,
si3,
so0,
so1,
so2,
so3,
KW,
KH,
IC,
IW,
IH,
N,
OW,
OH,
(uint32_t) s0,
(uint32_t) s1,
(uint32_t) p0,
(uint32_t) p1,
(uint32_t) d0,
(uint32_t) d1,
};
std::vector<wgpu::BindGroupEntry> entries = {
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst),
};
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
webgpu_pipeline pipeline = ctx->shader_lib->get_im2col_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
@@ -1988,8 +2071,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1");
}
bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
(ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
(ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src);
uint32_t offset_merged_rn_src = 0;
@@ -2689,6 +2772,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
return ggml_webgpu_sum_rows(ctx, src0, node);
case GGML_OP_CONV_2D:
return ggml_webgpu_conv_2d(ctx, src0, src1, node);
case GGML_OP_IM2COL:
return ggml_webgpu_im2col(ctx, src0, src1, node);
default:
return std::nullopt;
}
@@ -3455,7 +3540,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context;
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
webgpu_ctx->param_arena.init(
webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES,
webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN,
@@ -3705,12 +3790,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
break;
}
// Head dimensions must fit in workgroup memory with minimum tile sizes
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
const bool has_mask = op->src[3] != nullptr;
const bool kv_direct = src1->type == GGML_TYPE_F16 &&
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
const bool has_mask = op->src[3] != nullptr;
const bool kv_direct = src1->type == GGML_TYPE_F16 &&
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
(uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
if (min_bytes > limit_bytes) {
@@ -3802,6 +3887,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
break;
case GGML_OP_IM2COL:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
break;
case GGML_OP_SSM_CONV:
supports_op = op->type == GGML_TYPE_F32;
break;