diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
index 9d88f9805..f84dfee9d 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
@@ -240,6 +240,27 @@ struct ggml_webgpu_ssm_conv_pipeline_key {
     }
 };
 
+/** CONV 2D */
+struct ggml_webgpu_conv2d_pipeline_key {
+    ggml_type weight_type;
+    ggml_type input_type;
+    ggml_type output_type;
+
+    bool operator==(const ggml_webgpu_conv2d_pipeline_key & other) const {
+        return weight_type == other.weight_type && input_type == other.input_type && output_type == other.output_type;
+    }
+};
+
+struct ggml_webgpu_conv2d_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_conv2d_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.weight_type);
+        ggml_webgpu_hash_combine(seed, key.input_type);
+        ggml_webgpu_hash_combine(seed, key.output_type);
+        return seed;
+    }
+};
+
 /** Gated Delta Net **/
 struct ggml_webgpu_gated_delta_net_pipeline_key {
     int type;
@@ -789,6 +810,8 @@ class ggml_webgpu_shader_lib {
         rope_pipelines;
     std::unordered_map<ggml_webgpu_soft_max_pipeline_key, webgpu_pipeline, ggml_webgpu_soft_max_pipeline_key_hash>
         soft_max_pipelines;
+    std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash>
+        conv2d_pipelines;
 
  public:
    ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
@@ -2382,6 +2405,46 @@ class ggml_webgpu_shader_lib {
         return soft_max_pipelines[key];
     }
 
+    webgpu_pipeline get_conv2d_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_conv2d_pipeline_key key = {};
+        key.weight_type = context.src0->type;
+        key.input_type  = context.src1->type;
+        key.output_type = context.dst->type;
+
+        auto it = conv2d_pipelines.find(key);
+        if (it != conv2d_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "conv_2d";
+
+        auto push_type_defines = [&](const char * prefix, ggml_type type) {
+            std::string s_prefix = prefix;
+            if (type == GGML_TYPE_F32) {
+                defines.push_back(s_prefix + "_F32");
+            } else if (type == GGML_TYPE_F16) {
+                defines.push_back(s_prefix + "_F16");
+            } else {
+                GGML_ABORT("Unsupported type for CONV_2D shader");
+            }
+        };
+
+        push_type_defines("WEIGHT", key.weight_type);
+        push_type_defines("INPUT", key.input_type);
+        push_type_defines("OUTPUT", key.output_type);
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed = preprocessor.preprocess(wgsl_conv2d, defines);
+        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context      = decisions;
+        conv2d_pipelines[key] = pipeline;
+        return conv2d_pipelines[key];
+    }
+
  private:
    static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
                                                       std::string    shader_code,
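Aside: the key/hash pair above follows the header's existing per-op pattern, with ggml_webgpu_hash_combine defined elsewhere in this header. A minimal standalone sketch of the caching idiom, with hypothetical names and a boost-style combiner assumed to match the real one:

// cache_sketch.cpp (hypothetical names; compiles as C++17)
#include <cstddef>
#include <functional>
#include <unordered_map>

// boost-style combiner; ggml_webgpu_hash_combine is assumed to look like this
template <typename T> static void hash_combine(std::size_t & seed, const T & v) {
    seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

struct pipeline_key {
    int weight_type, input_type, output_type;
    bool operator==(const pipeline_key & o) const {
        return weight_type == o.weight_type && input_type == o.input_type && output_type == o.output_type;
    }
};

struct pipeline_key_hash {
    std::size_t operator()(const pipeline_key & k) const {
        std::size_t seed = 0;
        hash_combine(seed, k.weight_type);
        hash_combine(seed, k.input_type);
        hash_combine(seed, k.output_type);
        return seed;
    }
};

// one cached (specialized) pipeline per distinct (weight, input, output) type triple
static std::unordered_map<pipeline_key, int, pipeline_key_hash> pipelines;

Keying on the type triple means the shader is compiled at most once per F32/F16 combination; every later node with the same types reuses the cached pipeline.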
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index a29231452..551586751 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -8,6 +8,7 @@
 #include "ggml-backend-impl.h"
 #include "ggml-impl.h"
 #include "ggml-webgpu-shader-lib.hpp"
+#include "ggml.h"
 
 #ifdef __EMSCRIPTEN__
 #  include <emscripten/emscripten.h>
@@ -921,6 +922,87 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx,
     return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
 }
 
+static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx,
+                                             ggml_tensor *    src0,
+                                             ggml_tensor *    src1,
+                                             ggml_tensor *    dst) {
+    const int32_t s0 = ggml_get_op_params_i32(dst, 0);
+    const int32_t s1 = ggml_get_op_params_i32(dst, 1);
+    const int32_t p0 = ggml_get_op_params_i32(dst, 2);
+    const int32_t p1 = ggml_get_op_params_i32(dst, 3);
+    const int32_t d0 = ggml_get_op_params_i32(dst, 4);
+    const int32_t d1 = ggml_get_op_params_i32(dst, 5);
+
+    std::vector<uint32_t> params = {
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+
+        (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+
+        (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
+        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
+        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
+        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
+
+        (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+
+        (uint32_t) src0->ne[0],
+        (uint32_t) src0->ne[1],
+        (uint32_t) src0->ne[2],
+
+        (uint32_t) src1->ne[0],
+        (uint32_t) src1->ne[1],
+
+        (uint32_t) dst->ne[0],
+        (uint32_t) dst->ne[1],
+        (uint32_t) dst->ne[2],
+        (uint32_t) dst->ne[3],
+
+        (uint32_t) s0,
+        (uint32_t) s1,
+        (uint32_t) p0,
+        (uint32_t) p1,
+        (uint32_t) d0,
+        (uint32_t) d1,
+    };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
+    };
+
+    uint32_t max_wg_size =
+        std::min((uint32_t) WEBGPU_MAX_WG_SIZE, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupSizeX);
+    uint32_t wg_size =
+        std::min((uint32_t) ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, max_wg_size);
+
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+    shader_lib_ctx.src0        = src0;
+    shader_lib_ctx.src1        = src1;
+    shader_lib_ctx.dst         = dst;
+    shader_lib_ctx.max_wg_size = wg_size;
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx);
+
+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+
+    uint32_t n_out    = (uint32_t) ggml_nelements(dst);
+    uint32_t total_wg = CEIL_DIV(n_out, decisions->wg_size);
+    uint32_t max_wg   = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
+    uint32_t wg_x     = std::min(total_wg, max_wg);
+    uint32_t wg_y     = CEIL_DIV(total_wg, wg_x);
+
+    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
+}
+
 static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx,
                                               ggml_tensor *    src0,
                                               ggml_tensor *    src1,
@@ -2477,6 +2559,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context c
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
             return ggml_webgpu_sum_rows(ctx, src0, node);
+        case GGML_OP_CONV_2D:
+            return ggml_webgpu_conv_2d(ctx, src0, src1, node);
         default:
             return std::nullopt;
     }
@@ -3495,6 +3579,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_SOLVE_TRI:
             supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
             break;
+        case GGML_OP_CONV_2D:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+                          (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
+                          (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
+            break;
         case GGML_OP_SSM_CONV:
             supports_op = op->type == GGML_TYPE_F32;
             break;
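Aside: the wg_x/wg_y split above exists because WebGPU caps a single dispatch dimension at maxComputeWorkgroupsPerDimension, so one output element per thread can need more workgroups than fit in x alone. A standalone sketch of the same arithmetic, with made-up sizes (real limits vary per adapter):

// dispatch_sketch.cpp (illustrative numbers only)
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t ceil_div(uint32_t x, uint32_t y) { return (x + y - 1) / y; }

int main() {
    const uint32_t n_out   = 20000000; // pretend output element count
    const uint32_t wg_size = 256;      // workgroup size baked into the pipeline
    const uint32_t max_wg  = 65535;    // pretend maxComputeWorkgroupsPerDimension

    const uint32_t total_wg = ceil_div(n_out, wg_size);   // 78125 workgroups needed
    const uint32_t wg_x     = std::min(total_wg, max_wg); // 65535
    const uint32_t wg_y     = ceil_div(total_wg, wg_x);   // 2

    // the shader flattens the pair back into one index:
    //   i_out = gid.x + num_wg.x * wg_size * gid.y
    // wg_x * wg_y * wg_size >= n_out always holds; excess threads early-return
    std::printf("dispatch %u x %u groups = %u threads for %u outputs\n",
                wg_x, wg_y, wg_x * wg_y * wg_size, n_out);
    return 0;
}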
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl
new file mode 100644
index 000000000..9eb131dc2
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl
@@ -0,0 +1,165 @@
+#include "common_decls.tmpl"
+enable f16;
+
+@group(0) @binding(0)
+#if defined(WEIGHT_F32)
+var<storage, read_write> weights: array<f32>;
+#elif defined(WEIGHT_F16)
+var<storage, read_write> weights: array<f16>;
+#endif
+
+@group(0) @binding(1)
+#if defined(INPUT_F32)
+var<storage, read_write> input: array<f32>;
+#elif defined(INPUT_F16)
+var<storage, read_write> input: array<f16>;
+#endif
+
+@group(0) @binding(2)
+#if defined(OUTPUT_F32)
+var<storage, read_write> output: array<f32>;
+#elif defined(OUTPUT_F16)
+var<storage, read_write> output: array<f16>;
+#endif
+
+struct Params {
+    offset_w: u32,
+    offset_i: u32,
+    offset_o: u32,
+
+    // element strides
+    sw0: u32, sw1: u32, sw2: u32, sw3: u32,
+    si0: u32, si1: u32, si2: u32, si3: u32,
+    so0: u32, so1: u32, so2: u32, so3: u32,
+
+    // kernel dimensions
+    KW: u32, KH: u32, IC: u32,
+    // input dimensions
+    IW: u32, IH: u32,
+    // output dimensions
+    OW: u32, OH: u32, OC_out: u32, N_out: u32,
+
+    // stride
+    s0: u32, s1: u32,
+    // padding
+    p0: u32, p1: u32,
+    // dilation
+    d0: u32, d1: u32,
+};
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+fn load_weight(idx: u32) -> f32 {
+    #if defined(WEIGHT_F32)
+    return weights[idx];
+    #elif defined(WEIGHT_F16)
+    return f32(weights[idx]);
+    #endif
+}
+
+fn load_input(idx: u32) -> f32 {
+    #if defined(INPUT_F32)
+    return input[idx];
+    #elif defined(INPUT_F16)
+    return f32(input[idx]);
+    #endif
+}
+
+fn store_output(idx: u32, val: f32) {
+    #if defined(OUTPUT_F32)
+    output[idx] = val;
+    #elif defined(OUTPUT_F16)
+    output[idx] = f16(val);
+    #endif
+}
+
+fn ceil_div_u32(x: u32, y: u32) -> u32 {
+    return (x + y - 1) / y;
+}
+
+// returns the first valid kernel index k such that base + k * step >= 0
+fn first_valid_k(base: i32, step: u32) -> u32 {
+    if (base >= 0) {
+        return 0;
+    }
+
+    return ceil_div_u32(u32(-base), step);
+}
+
+// returns the first out-of-bounds kernel index k, i.e. the smallest k with
+// base + k * step >= limit, clamped to k_max, so the valid k lie in [0, end_valid_k)
+fn end_valid_k(base: i32, step: u32, limit: u32, k_max: u32) -> u32 {
+    let remaining = i32(limit) - base;
+    if (remaining <= 0) {
+        return 0;
+    }
+
+    return min(k_max, ceil_div_u32(u32(remaining), step));
+}
+
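+// Worked example for the clipping helpers (illustrative numbers, not from this
+// PR): with p0 = 2, d0 = 3, IW = 8, KW = 4, the output column ow = 0 has
+// ow_base = -2. Then first_valid_k(-2, 3) = ceil(2 / 3) = 1 and
+// end_valid_k(-2, 3, 8, 4) = min(4, ceil((8 + 2) / 3)) = 4, so only the taps
+// kw = 1, 2, 3 are accumulated; they read iw = 1, 4, 7, all inside [0, IW).
+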
+@compute @workgroup_size(WG_SIZE)
+fn main(
+        @builtin(global_invocation_id) gid: vec3<u32>,
+        @builtin(num_workgroups)       num_wg: vec3<u32>
+) {
+
+    let threads_per_group = u32(WG_SIZE);
+    let i_out = gid.x + (num_wg.x * threads_per_group) * gid.y;
+    let n_out = params.OW * params.OH * params.OC_out * params.N_out;
+
+    if (i_out >= n_out) {
+        return;
+    }
+
+    // Kernel layout: [KW, KH, IC, ..]
+    // Input layout:  [IW, IH, .., ..]
+    // Output layout: [OW, OH, OC, N]
+
+    var i = i_out;
+    let n = i / (params.OC_out * params.OH * params.OW);
+    i = i % (params.OC_out * params.OH * params.OW);
+    let oc = i / (params.OH * params.OW);
+    i = i % (params.OH * params.OW);
+    let oh = i / params.OW;
+    let ow = i % params.OW;
+
+    let ow_base = i32(ow * params.s0) - i32(params.p0);
+    let oh_base = i32(oh * params.s1) - i32(params.p1);
+
+    // clip the valid kernel window once
+    let kw_begin = first_valid_k(ow_base, params.d0);
+    let kw_end   = end_valid_k(ow_base, params.d0, params.IW, params.KW);
+    let kh_begin = first_valid_k(oh_base, params.d1);
+    let kh_end   = end_valid_k(oh_base, params.d1, params.IH, params.KH);
+
+    // entire receptive field is out of bounds
+    if (kw_begin >= kw_end || kh_begin >= kh_end) {
+        let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3;
+        store_output(out_idx, 0.0);
+        return;
+    }
+
+    let weight_oc_base = params.offset_w + oc * params.sw3;
+    let input_n_base   = params.offset_i + n * params.si3;
+
+    var sum: f32 = 0.0;
+    for (var ic: u32 = 0; ic < params.IC; ic += 1) {
+        let w_base_ic = ic * params.sw2 + weight_oc_base;
+        let in_base   = ic * params.si2 + input_n_base;
+
+        for (var kh: u32 = kh_begin; kh < kh_end; kh += 1) {
+            let ih          = u32(oh_base + i32(kh * params.d1));
+            let w_row_base  = w_base_ic + kh * params.sw1;
+            let in_row_base = in_base + ih * params.si1;
+            for (var kw: u32 = kw_begin; kw < kw_end; kw += 1) {
+                let iw     = u32(ow_base + i32(kw * params.d0));
+                let w_idx  = w_row_base + kw * params.sw0;
+                let in_idx = in_row_base + iw * params.si0;
+                sum += load_weight(w_idx) * load_input(in_idx);
+            }
+        }
+    }
+
+    let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3;
+    store_output(out_idx, sum);
+}
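For reference, a minimal host-side sketch of a graph that exercises this path. It assumes ggml_conv_2d_direct is the public constructor that emits GGML_OP_CONV_2D (plain ggml_conv_2d lowers to im2col + matmul instead); context creation and backend dispatch are elided.

#include "ggml.h"

// weights may be F16 while the input stays F32, per the supports_op change above
static struct ggml_tensor * build_conv2d(struct ggml_context * ctx) {
    // kernel is [KW, KH, IC, OC]; input is [IW, IH, IC, N]
    struct ggml_tensor * w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, 16, 32);
    struct ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 64, 16, 1);
    // stride 1/1, padding 1/1, dilation 1/1 -> output [64, 64, 32, 1]
    return ggml_conv_2d_direct(ctx, w, x, 1, 1, 1, 1, 1, 1);
}

With stride 1, padding 1, and dilation 1, OW = (IW + 2*p0 - d0*(KW - 1) - 1) / s0 + 1 = (64 + 2 - 2 - 1) / 1 + 1 = 64, matching the comment; one shader invocation is dispatched per element of the [64, 64, 32, 1] output.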