ggml-webgpu: add support for im2col (#22259)

* shader(im2col): implement the im2col shader

* shader(im2col): clean the formatting issues

* shader(im2col): clean the editorconfig checker warning

* fix(shader): address the workgroup issues of im2col and conv2d
This commit is contained in:
Chen Yuan
2026-04-22 23:17:41 -04:00
committed by GitHub
parent 86db42e97f
commit b76429a69c
3 changed files with 268 additions and 19 deletions
@@ -281,6 +281,25 @@ struct ggml_webgpu_conv2d_pipeline_key_hash {
} }
}; };
/** Im2Col **/
struct ggml_webgpu_im2col_pipeline_key {
ggml_type input_type;
ggml_type output_type;
bool operator==(const ggml_webgpu_im2col_pipeline_key & other) const {
return input_type == other.input_type && output_type == other.output_type;
}
};
struct ggml_webgpu_im2col_pipeline_key_hash {
size_t operator()(const ggml_webgpu_im2col_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.input_type);
ggml_webgpu_hash_combine(seed, key.output_type);
return seed;
}
};
/** Gated Delta Net **/ /** Gated Delta Net **/
struct ggml_webgpu_gated_delta_net_pipeline_key { struct ggml_webgpu_gated_delta_net_pipeline_key {
int type; int type;
@@ -833,6 +852,8 @@ class ggml_webgpu_shader_lib {
soft_max_pipelines; soft_max_pipelines;
std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash> std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash>
conv2d_pipelines; conv2d_pipelines;
std::unordered_map<ggml_webgpu_im2col_pipeline_key, webgpu_pipeline, ggml_webgpu_im2col_pipeline_key_hash>
im2col_pipelines;
std::unordered_map<ggml_webgpu_rms_norm_mul_pipeline_key, std::unordered_map<ggml_webgpu_rms_norm_mul_pipeline_key,
webgpu_pipeline, webgpu_pipeline,
@@ -2504,6 +2525,44 @@ class ggml_webgpu_shader_lib {
return conv2d_pipelines[key]; return conv2d_pipelines[key];
} }
webgpu_pipeline get_im2col_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_im2col_pipeline_key key = {};
key.input_type = context.src1->type;
key.output_type = context.dst->type;
auto it = im2col_pipelines.find(key);
if (it != im2col_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "im2col";
auto push_type_defines = [&](const char * prefix, ggml_type type) {
std::string s_prefix = prefix;
if (type == GGML_TYPE_F32) {
defines.push_back(s_prefix + "_F32");
} else if (type == GGML_TYPE_F16) {
defines.push_back(s_prefix + "_F16");
} else {
GGML_ABORT("Unsupported type for IM2COL shader");
}
};
push_type_defines("INPUT", key.input_type);
push_type_defines("OUTPUT", key.output_type);
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_im2col, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
im2col_pipelines[key] = pipeline;
return im2col_pipelines[key];
}
private: private:
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
std::string shader_code, std::string shader_code,
+108 -19
View File
@@ -979,25 +979,108 @@ static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx,
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
}; };
uint32_t max_wg_size =
std::min((uint32_t) WEBGPU_MAX_WG_SIZE, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupSizeX);
uint32_t wg_size =
std::min((uint32_t) ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, max_wg_size);
ggml_webgpu_shader_lib_context shader_lib_ctx = {}; ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0; shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1; shader_lib_ctx.src1 = src1;
shader_lib_ctx.dst = dst; shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = wg_size; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx); webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t n_out = ggml_nelements(dst); uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
uint32_t total_wg = CEIL_DIV(n_out, decisions->wg_size); uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
uint32_t wg_x = std::min(total_wg, max_wg);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
const int32_t p0 = ggml_get_op_params_i32(dst, 2);
const int32_t p1 = ggml_get_op_params_i32(dst, 3);
const int32_t d0 = ggml_get_op_params_i32(dst, 4);
const int32_t d1 = ggml_get_op_params_i32(dst, 5);
const bool is_2D = ggml_get_op_params_i32(dst, 6) == 1;
const uint32_t KW = src0->ne[0];
const uint32_t KH = is_2D ? src0->ne[1] : 1;
const uint32_t IC = is_2D ? src0->ne[2] : src0->ne[1];
const uint32_t IW = src1->ne[0];
const uint32_t IH = is_2D ? src1->ne[1] : 1;
const uint32_t N = is_2D ? src1->ne[3] : src1->ne[2];
const uint32_t OW = dst->ne[1];
const uint32_t OH = is_2D ? dst->ne[2] : 1;
const uint32_t si0 = (uint32_t) (src1->nb[0] / ggml_type_size(src1->type));
const uint32_t si1 = is_2D ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0;
const uint32_t si2 = is_2D ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) :
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type));
const uint32_t si3 = is_2D ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) :
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type));
const uint32_t so0 = (uint32_t) (dst->nb[0] / ggml_type_size(dst->type));
const uint32_t so1 = (uint32_t) (dst->nb[1] / ggml_type_size(dst->type));
const uint32_t so2 = is_2D ? (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)) : 0;
const uint32_t so3 = is_2D ? (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)) :
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type));
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
si0,
si1,
si2,
si3,
so0,
so1,
so2,
so3,
KW,
KH,
IC,
IW,
IH,
N,
OW,
OH,
(uint32_t) s0,
(uint32_t) s1,
(uint32_t) p0,
(uint32_t) p1,
(uint32_t) d0,
(uint32_t) d1,
};
std::vector<wgpu::BindGroupEntry> entries = {
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst),
};
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
webgpu_pipeline pipeline = ctx->shader_lib->get_im2col_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
uint32_t wg_y = CEIL_DIV(total_wg, wg_x); uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
@@ -1988,8 +2071,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1"); GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1");
} }
bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
(ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst)); (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src); bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src);
uint32_t offset_merged_rn_src = 0; uint32_t offset_merged_rn_src = 0;
@@ -2689,6 +2772,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
return ggml_webgpu_sum_rows(ctx, src0, node); return ggml_webgpu_sum_rows(ctx, src0, node);
case GGML_OP_CONV_2D: case GGML_OP_CONV_2D:
return ggml_webgpu_conv_2d(ctx, src0, src1, node); return ggml_webgpu_conv_2d(ctx, src0, src1, node);
case GGML_OP_IM2COL:
return ggml_webgpu_im2col(ctx, src0, src1, node);
default: default:
return std::nullopt; return std::nullopt;
} }
@@ -3455,7 +3540,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context; ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context;
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>(); webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device); webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
webgpu_ctx->param_arena.init( webgpu_ctx->param_arena.init(
webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES, webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES,
webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN, webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN,
@@ -3705,12 +3790,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
break; break;
} }
// Head dimensions must fit in workgroup memory with minimum tile sizes // Head dimensions must fit in workgroup memory with minimum tile sizes
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
const bool has_mask = op->src[3] != nullptr; const bool has_mask = op->src[3] != nullptr;
const bool kv_direct = src1->type == GGML_TYPE_F16 && const bool kv_direct = src1->type == GGML_TYPE_F16 &&
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n, ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
(uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct); (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
if (min_bytes > limit_bytes) { if (min_bytes > limit_bytes) {
@@ -3802,6 +3887,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
break; break;
case GGML_OP_IM2COL:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
break;
case GGML_OP_SSM_CONV: case GGML_OP_SSM_CONV:
supports_op = op->type == GGML_TYPE_F32; supports_op = op->type == GGML_TYPE_F32;
break; break;
@@ -0,0 +1,101 @@
#include "common_decls.tmpl"
enable f16;
@group(0) @binding(0)
#if defined(INPUT_F32)
var<storage, read_write> input: array<f32>;
#elif defined(INPUT_F16)
var<storage, read_write> input: array<f16>;
#endif
@group(0) @binding(1)
#if defined(OUTPUT_F32)
var<storage, read_write> output: array<f32>;
#elif defined(OUTPUT_F16)
var<storage, read_write> output: array<f16>;
#endif
struct Params {
offset_i: u32,
offset_o: u32,
// element strides
si0: u32, si1: u32, si2: u32, si3: u32,
so0: u32, so1: u32, so2: u32, so3: u32,
KW: u32, KH: u32, IC: u32,
IW: u32, IH: u32, N: u32,
OW: u32, OH: u32,
// stride
s0: u32, s1: u32,
// padding
p0: u32, p1: u32,
// dilation
d0: u32, d1: u32,
}
@group(0) @binding(2)
var<uniform> params: Params;
fn load_input(idx: u32) -> f32 {
#if defined(INPUT_F32)
return input[idx];
#elif defined(INPUT_F16)
return f32(input[idx]);
#endif
}
fn store_output(idx: u32, val: f32) {
#if defined(OUTPUT_F32)
output[idx] = val;
#elif defined(OUTPUT_F16)
output[idx] = f16(val);
#endif
}
@compute @workgroup_size(WG_SIZE)
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>
) {
let threads_per_group = u32(WG_SIZE);
let i_out = gid.x + (num_wg.x * threads_per_group) * gid.y;
let K = params.KW * params.KH * params.IC;
let M = params.OW * params.OH;
let total = K * M * params.N;
if (i_out >= total) {
return;
}
// decode (k, m, n)
var i = i_out;
let n = i / (K * M);
i = i % (K * M);
let m = i / K;
let k = i % K;
// decode (oh, ow)
let oh = m / params.OW;
let ow = m % params.OW;
// decode (kw, kh, ic)
let kw = k % params.KW;
let tmp = k / params.KW;
let kh = tmp % params.KH;
let ic = tmp / params.KH;
let iw_i32 = i32(ow * params.s0 + kw * params.d0) - i32(params.p0);
let ih_i32 = i32(oh * params.s1 + kh * params.d1) - i32(params.p1);
if (iw_i32 >= 0 && iw_i32 < i32(params.IW) && ih_i32 >= 0 && ih_i32 < i32(params.IH)) {
let iw = u32(iw_i32);
let ih = u32(ih_i32);
let in_idx = params.offset_i + iw * params.si0 + ih * params.si1 + ic * params.si2 + n * params.si3;
store_output(params.offset_o + k * params.so0 + ow * params.so1 + oh * params.so2 + n * params.so3, load_input(in_idx));
} else {
store_output(params.offset_o + k * params.so0 + ow * params.so1 + oh * params.so2 + n * params.so3, 0.0);
}
}