diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 6593a9fe1..efc5b8c97 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -281,6 +281,25 @@ struct ggml_webgpu_conv2d_pipeline_key_hash { } }; +/** Im2Col **/ +struct ggml_webgpu_im2col_pipeline_key { + ggml_type input_type; + ggml_type output_type; + + bool operator==(const ggml_webgpu_im2col_pipeline_key & other) const { + return input_type == other.input_type && output_type == other.output_type; + } +}; + +struct ggml_webgpu_im2col_pipeline_key_hash { + size_t operator()(const ggml_webgpu_im2col_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.input_type); + ggml_webgpu_hash_combine(seed, key.output_type); + return seed; + } +}; + /** Gated Delta Net **/ struct ggml_webgpu_gated_delta_net_pipeline_key { int type; @@ -833,6 +852,8 @@ class ggml_webgpu_shader_lib { soft_max_pipelines; std::unordered_map conv2d_pipelines; + std::unordered_map + im2col_pipelines; std::unordered_maptype; + key.output_type = context.dst->type; + + auto it = im2col_pipelines.find(key); + if (it != im2col_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "im2col"; + + auto push_type_defines = [&](const char * prefix, ggml_type type) { + std::string s_prefix = prefix; + if (type == GGML_TYPE_F32) { + defines.push_back(s_prefix + "_F32"); + } else if (type == GGML_TYPE_F16) { + defines.push_back(s_prefix + "_F16"); + } else { + GGML_ABORT("Unsupported type for IM2COL shader"); + } + }; + + push_type_defines("INPUT", key.input_type); + push_type_defines("OUTPUT", key.output_type); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_im2col, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + im2col_pipelines[key] = pipeline; + return im2col_pipelines[key]; + } + private: static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, std::string shader_code, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 44e3bf822..bcca2bd46 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -979,25 +979,108 @@ static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx, ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; - uint32_t max_wg_size = - std::min((uint32_t) WEBGPU_MAX_WG_SIZE, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupSizeX); - uint32_t wg_size = - std::min((uint32_t) ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, max_wg_size); - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src0; shader_lib_ctx.src1 = src1; shader_lib_ctx.dst = dst; - shader_lib_ctx.max_wg_size = wg_size; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); - uint32_t n_out = ggml_nelements(dst); - uint32_t total_wg = CEIL_DIV(n_out, decisions->wg_size); - uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - uint32_t wg_x = std::min(total_wg, max_wg); + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); + uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + const int32_t s0 = ggml_get_op_params_i32(dst, 0); + const int32_t s1 = ggml_get_op_params_i32(dst, 1); + const int32_t p0 = ggml_get_op_params_i32(dst, 2); + const int32_t p1 = ggml_get_op_params_i32(dst, 3); + const int32_t d0 = ggml_get_op_params_i32(dst, 4); + const int32_t d1 = ggml_get_op_params_i32(dst, 5); + const bool is_2D = ggml_get_op_params_i32(dst, 6) == 1; + + const uint32_t KW = src0->ne[0]; + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t IC = is_2D ? src0->ne[2] : src0->ne[1]; + + const uint32_t IW = src1->ne[0]; + const uint32_t IH = is_2D ? src1->ne[1] : 1; + const uint32_t N = is_2D ? src1->ne[3] : src1->ne[2]; + + const uint32_t OW = dst->ne[1]; + const uint32_t OH = is_2D ? dst->ne[2] : 1; + + const uint32_t si0 = (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)); + const uint32_t si1 = is_2D ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0; + const uint32_t si2 = is_2D ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)); + const uint32_t si3 = is_2D ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)); + + const uint32_t so0 = (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)); + const uint32_t so1 = (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)); + const uint32_t so2 = is_2D ? (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)) : 0; + const uint32_t so3 = is_2D ? (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)) : + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + si0, + si1, + si2, + si3, + so0, + so1, + so2, + so3, + + KW, + KH, + IC, + + IW, + IH, + N, + + OW, + OH, + + (uint32_t) s0, + (uint32_t) s1, + (uint32_t) p0, + (uint32_t) p1, + (uint32_t) d0, + (uint32_t) d1, + }; + + std::vector entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), + }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_im2col_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); uint32_t wg_y = CEIL_DIV(total_wg, wg_x); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); @@ -1988,8 +2071,8 @@ static std::optional ggml_webgpu_rms_norm_mul(webgpu_context GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1"); } - bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || - (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst)); + bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || + (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst)); bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src); uint32_t offset_merged_rn_src = 0; @@ -2689,6 +2772,8 @@ static std::optional ggml_webgpu_encode(webgpu_context ctx, return ggml_webgpu_sum_rows(ctx, src0, node); case GGML_OP_CONV_2D: return ggml_webgpu_conv_2d(ctx, src0, src1, node); + case GGML_OP_IM2COL: + return ggml_webgpu_im2col(ctx, src0, src1, node); default: return std::nullopt; } @@ -3455,7 +3540,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context; webgpu_context webgpu_ctx = std::make_shared(); webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; - webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); + webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); webgpu_ctx->param_arena.init( webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES, webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN, @@ -3705,12 +3790,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } // Head dimensions must fit in workgroup memory with minimum tile sizes - size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const bool has_mask = op->src[3] != nullptr; - const bool kv_direct = src1->type == GGML_TYPE_F16 && - (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && - (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + const bool has_mask = op->src[3] != nullptr; + const bool kv_direct = src1->type == GGML_TYPE_F16 && + (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && + (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct); if (min_bytes > limit_bytes) { @@ -3802,6 +3887,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) && (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); break; + case GGML_OP_IM2COL: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + break; case GGML_OP_SSM_CONV: supports_op = op->type == GGML_TYPE_F32; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl new file mode 100644 index 000000000..386ebab87 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl @@ -0,0 +1,101 @@ +#include "common_decls.tmpl" +enable f16; + +@group(0) @binding(0) +#if defined(INPUT_F32) +var input: array; +#elif defined(INPUT_F16) +var input: array; +#endif + +@group(0) @binding(1) +#if defined(OUTPUT_F32) +var output: array; +#elif defined(OUTPUT_F16) +var output: array; +#endif + +struct Params { + offset_i: u32, + offset_o: u32, + + // element strides + si0: u32, si1: u32, si2: u32, si3: u32, + so0: u32, so1: u32, so2: u32, so3: u32, + + KW: u32, KH: u32, IC: u32, + IW: u32, IH: u32, N: u32, + OW: u32, OH: u32, + + // stride + s0: u32, s1: u32, + // padding + p0: u32, p1: u32, + // dilation + d0: u32, d1: u32, +} + +@group(0) @binding(2) +var params: Params; + +fn load_input(idx: u32) -> f32 { + #if defined(INPUT_F32) + return input[idx]; + #elif defined(INPUT_F16) + return f32(input[idx]); + #endif +} + +fn store_output(idx: u32, val: f32) { + #if defined(OUTPUT_F32) + output[idx] = val; + #elif defined(OUTPUT_F16) + output[idx] = f16(val); + #endif +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_wg: vec3 +) { + + let threads_per_group = u32(WG_SIZE); + let i_out = gid.x + (num_wg.x * threads_per_group) * gid.y; + let K = params.KW * params.KH * params.IC; + let M = params.OW * params.OH; + let total = K * M * params.N; + + if (i_out >= total) { + return; + } + + // decode (k, m, n) + var i = i_out; + let n = i / (K * M); + i = i % (K * M); + let m = i / K; + let k = i % K; + + // decode (oh, ow) + let oh = m / params.OW; + let ow = m % params.OW; + + // decode (kw, kh, ic) + let kw = k % params.KW; + let tmp = k / params.KW; + let kh = tmp % params.KH; + let ic = tmp / params.KH; + + let iw_i32 = i32(ow * params.s0 + kw * params.d0) - i32(params.p0); + let ih_i32 = i32(oh * params.s1 + kh * params.d1) - i32(params.p1); + + if (iw_i32 >= 0 && iw_i32 < i32(params.IW) && ih_i32 >= 0 && ih_i32 < i32(params.IH)) { + let iw = u32(iw_i32); + let ih = u32(ih_i32); + let in_idx = params.offset_i + iw * params.si0 + ih * params.si1 + ic * params.si2 + n * params.si3; + store_output(params.offset_o + k * params.so0 + ow * params.so1 + oh * params.so2 + n * params.so3, load_input(in_idx)); + } else { + store_output(params.offset_o + k * params.so0 + ow * params.so1 + oh * params.so2 + n * params.so3, 0.0); + } +}