ggml-webgpu: add support for im2col (#22259)

* shader(im2col): implement the im2col shader

* shader(im2col): clean the formatting issues

* shader(im2col): clean the editorconfig checker warning

* fix(shader): address the workgroup issues of im2col and conv2d
This commit is contained in:
Chen Yuan
2026-04-22 23:17:41 -04:00
committed by GitHub
parent 86db42e97f
commit b76429a69c
3 changed files with 268 additions and 19 deletions
@@ -281,6 +281,25 @@ struct ggml_webgpu_conv2d_pipeline_key_hash {
}
};
/** Im2Col **/
/**
 * Cache key for im2col pipelines: one compiled shader variant exists per
 * (input type, output type) pair. input_type is taken from src1 (the data
 * tensor) and output_type from dst.
 */
struct ggml_webgpu_im2col_pipeline_key {
    ggml_type input_type;
    ggml_type output_type;

    bool operator==(const ggml_webgpu_im2col_pipeline_key & other) const {
        if (input_type != other.input_type) {
            return false;
        }
        return output_type == other.output_type;
    }
};
/** Hash functor for ggml_webgpu_im2col_pipeline_key; folds both type fields. */
struct ggml_webgpu_im2col_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_im2col_pipeline_key & key) const {
        size_t h = 0;
        ggml_webgpu_hash_combine(h, key.input_type);
        ggml_webgpu_hash_combine(h, key.output_type);
        return h;
    }
};
/** Gated Delta Net **/
struct ggml_webgpu_gated_delta_net_pipeline_key {
int type;
@@ -833,6 +852,8 @@ class ggml_webgpu_shader_lib {
soft_max_pipelines;
std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash>
conv2d_pipelines;
std::unordered_map<ggml_webgpu_im2col_pipeline_key, webgpu_pipeline, ggml_webgpu_im2col_pipeline_key_hash>
im2col_pipelines;
std::unordered_map<ggml_webgpu_rms_norm_mul_pipeline_key,
webgpu_pipeline,
@@ -2504,6 +2525,44 @@ class ggml_webgpu_shader_lib {
return conv2d_pipelines[key];
}
// Build (or fetch from the cache) the im2col compute pipeline specialized for
// this node's src1/dst type combination. Only F32 and F16 are supported; any
// other type aborts. The chosen workgroup size is recorded in the pipeline's
// decisions context so dispatch code can size the grid later.
webgpu_pipeline get_im2col_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_im2col_pipeline_key key = {};
    key.input_type                      = context.src1->type;
    key.output_type                     = context.dst->type;

    // Fast path: a pipeline for this type pair was already compiled.
    auto cached = im2col_pipelines.find(key);
    if (cached != im2col_pipelines.end()) {
        return cached->second;
    }

    // Map a ggml type to the shader-define suffix ("INPUT_F32", "OUTPUT_F16", ...).
    auto type_suffix = [](ggml_type type) -> const char * {
        if (type == GGML_TYPE_F32) {
            return "_F32";
        }
        if (type == GGML_TYPE_F16) {
            return "_F16";
        }
        GGML_ABORT("Unsupported type for IM2COL shader");
    };

    std::vector<std::string> shader_defines;
    shader_defines.push_back(std::string("INPUT") + type_suffix(key.input_type));
    shader_defines.push_back(std::string("OUTPUT") + type_suffix(key.output_type));
    shader_defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    std::string variant   = "im2col";
    auto        processed = preprocessor.preprocess(wgsl_im2col, shader_defines);

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;

    auto inserted = im2col_pipelines.emplace(key, pipeline);
    return inserted.first->second;
}
private:
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
std::string shader_code,
+99 -10
View File
@@ -979,25 +979,108 @@ static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx,
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
};
uint32_t max_wg_size =
std::min((uint32_t) WEBGPU_MAX_WG_SIZE, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupSizeX);
uint32_t wg_size =
std::min((uint32_t) ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, max_wg_size);
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = wg_size;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t n_out = ggml_nelements(dst);
uint32_t total_wg = CEIL_DIV(n_out, decisions->wg_size);
uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
uint32_t wg_x = std::min(total_wg, max_wg);
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
// Encode a GGML_OP_IM2COL node for the WebGPU backend.
// src0 supplies only the kernel geometry (KW/KH/IC); src1 is the data tensor
// actually read by the shader; dst receives the unrolled columns. Note that
// src0 is not bound to the shader -- its shape is folded into the params below.
static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
// Op params packed by ggml: strides (s0,s1), padding (p0,p1), dilation
// (d0,d1), then a flag selecting 2D vs 1D im2col.
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
const int32_t p0 = ggml_get_op_params_i32(dst, 2);
const int32_t p1 = ggml_get_op_params_i32(dst, 3);
const int32_t d0 = ggml_get_op_params_i32(dst, 4);
const int32_t d1 = ggml_get_op_params_i32(dst, 5);
const bool is_2D = ggml_get_op_params_i32(dst, 6) == 1;
// Kernel extents; in the 1D case the H axis collapses to 1.
const uint32_t KW = src0->ne[0];
const uint32_t KH = is_2D ? src0->ne[1] : 1;
const uint32_t IC = is_2D ? src0->ne[2] : src0->ne[1];
// Input/batch extents, again with degenerate H for 1D.
const uint32_t IW = src1->ne[0];
const uint32_t IH = is_2D ? src1->ne[1] : 1;
const uint32_t N = is_2D ? src1->ne[3] : src1->ne[2];
// Output spatial extents: dst->ne[0] is the column height K and is not needed here.
const uint32_t OW = dst->ne[1];
const uint32_t OH = is_2D ? dst->ne[2] : 1;
// Byte strides converted to element strides for direct array indexing in WGSL.
// For 1D, the unused H stride is 0 so the kh term vanishes in the shader.
const uint32_t si0 = (uint32_t) (src1->nb[0] / ggml_type_size(src1->type));
const uint32_t si1 = is_2D ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0;
const uint32_t si2 = is_2D ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) :
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type));
const uint32_t si3 = is_2D ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) :
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type));
const uint32_t so0 = (uint32_t) (dst->nb[0] / ggml_type_size(dst->type));
const uint32_t so1 = (uint32_t) (dst->nb[1] / ggml_type_size(dst->type));
const uint32_t so2 = is_2D ? (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)) : 0;
const uint32_t so3 = is_2D ? (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)) :
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type));
// Order below must match the Params struct in the im2col WGSL shader exactly.
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
si0,
si1,
si2,
si3,
so0,
so1,
so2,
so3,
KW,
KH,
IC,
IW,
IH,
N,
OW,
OH,
(uint32_t) s0,
(uint32_t) s1,
(uint32_t) p0,
(uint32_t) p1,
(uint32_t) d0,
(uint32_t) d1,
};
// Only the data tensor and the destination are bound; bindings match
// @binding(0)/@binding(1) in the shader.
std::vector<wgpu::BindGroupEntry> entries = {
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst),
};
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
webgpu_pipeline pipeline = ctx->shader_lib->get_im2col_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
// One thread per output element; fold the workgroup count into a 2D grid so
// very large dispatches stay under maxComputeWorkgroupsPerDimension.
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
@@ -2689,6 +2772,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
return ggml_webgpu_sum_rows(ctx, src0, node);
case GGML_OP_CONV_2D:
return ggml_webgpu_conv_2d(ctx, src0, src1, node);
case GGML_OP_IM2COL:
return ggml_webgpu_im2col(ctx, src0, src1, node);
default:
return std::nullopt;
}
@@ -3802,6 +3887,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
break;
case GGML_OP_IM2COL:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
break;
case GGML_OP_SSM_CONV:
supports_op = op->type == GGML_TYPE_F32;
break;
@@ -0,0 +1,101 @@
#include "common_decls.tmpl"
// NOTE(review): f16 is enabled unconditionally, even for the F32/F32 variant;
// presumably the device always advertises shader-f16 here -- confirm at
// pipeline-creation time.
enable f16;
// Data tensor (src1). Element type is chosen by the INPUT_* define injected by
// the host-side preprocessor.
@group(0) @binding(0)
#if defined(INPUT_F32)
var<storage, read_write> input: array<f32>;
#elif defined(INPUT_F16)
var<storage, read_write> input: array<f16>;
#endif
// Destination column buffer; element type chosen by the OUTPUT_* define.
@group(0) @binding(1)
#if defined(OUTPUT_F32)
var<storage, read_write> output: array<f32>;
#elif defined(OUTPUT_F16)
var<storage, read_write> output: array<f16>;
#endif
// Uniform parameter block. Field order must match the u32 vector packed on the
// host side (ggml_webgpu_im2col) exactly -- do not reorder.
struct Params {
// element offsets compensating for buffer misalignment (input / output)
offset_i: u32,
offset_o: u32,
// element strides
// input strides (si*) and output strides (so*); degenerate axes are 0 in 1D mode
si0: u32, si1: u32, si2: u32, si3: u32,
so0: u32, so1: u32, so2: u32, so3: u32,
// kernel width/height, input channels; input width/height, batch; output width/height
KW: u32, KH: u32, IC: u32,
IW: u32, IH: u32, N: u32,
OW: u32, OH: u32,
// stride
s0: u32, s1: u32,
// padding
p0: u32, p1: u32,
// dilation
d0: u32, d1: u32,
}
@group(0) @binding(2)
var<uniform> params: Params;
// Read one input element and widen it to f32 regardless of the storage type.
fn load_input(idx: u32) -> f32 {
#if defined(INPUT_F32)
    let v: f32 = input[idx];
    return v;
#elif defined(INPUT_F16)
    let v: f16 = input[idx];
    return f32(v);
#endif
}
// Narrow a f32 value to the output storage type and write it at idx.
fn store_output(idx: u32, val: f32) {
#if defined(OUTPUT_F32)
    output[idx] = val;
#elif defined(OUTPUT_F16)
    let narrowed: f16 = f16(val);
    output[idx] = narrowed;
#endif
}
// One invocation produces exactly one element of the im2col output.
// The logical 1D index is reconstructed from a 2D dispatch grid so that large
// outputs stay within the per-dimension workgroup limit.
@compute @workgroup_size(WG_SIZE)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(num_workgroups) num_wg: vec3<u32>
) {
    let wg_threads = u32(WG_SIZE);
    let flat = gid.x + gid.y * (num_wg.x * wg_threads);

    let K = params.KW * params.KH * params.IC; // column height
    let M = params.OW * params.OH;             // output positions per batch
    if (flat >= K * M * params.N) {
        return;
    }

    // Unflatten: n is slowest, then the output position m, then k fastest.
    let n = flat / (K * M);
    let rem = flat % (K * M);
    let m = rem / K;
    let k = rem % K;

    // Output spatial coordinates.
    let oh = m / params.OW;
    let ow = m % params.OW;

    // Within a column, kw varies fastest, then kh, then the channel.
    let kw = k % params.KW;
    let kh = (k / params.KW) % params.KH;
    let ic = (k / params.KW) / params.KH;

    // Source coordinates in signed space: they may land in the padding region
    // (negative or past the input extent), which yields a zero output element.
    let iw = i32(ow * params.s0 + kw * params.d0) - i32(params.p0);
    let ih = i32(oh * params.s1 + kh * params.d1) - i32(params.p1);

    let dst_idx = params.offset_o + k * params.so0 + ow * params.so1 + oh * params.so2 + n * params.so3;

    var value: f32 = 0.0;
    if (iw >= 0 && iw < i32(params.IW) && ih >= 0 && ih < i32(params.IH)) {
        let src_idx = params.offset_i + u32(iw) * params.si0 + u32(ih) * params.si1 + ic * params.si2 + n * params.si3;
        value = load_input(src_idx);
    }
    store_output(dst_idx, value);
}