ggml-webgpu(shader): support conv2d kernels. (#21964)
* ggml(webgpu): fix the busy-polls in Emscripten in the waitAny after #20618, and remove the busy webgpu log * Merge with upstream * Fix GET_ROWS packed integer NaN when using f16 as memory buffer in shader quants * Update Unary wgsl EXP and EXPM1 for f16 stability * Fix GET_ROWS IQ4_XS strcut for NaN f16 canonicalization * Fix numerical percision for unary sqrt when working with f16 * Fix NaN canonicalization for packed integers using f16 * Update err threshold for binary div ops when using f16 * backend: Keep one Dawn/WebGPU instance alive for the lifetime of the static backend * clean: uncomment existing code logs * clean: clean the unncessary debug info * Refactor and generalize dequant helpers * Remove deprecated quant structs * Refactor shader defines to reduce repetition * Remove error override for F16 type * fix: fix the accidential removal of the proper initialization of ctx * clean: clean legacy and format code * fix: did not modify tests ops * shader(conv2d): add conv2d shader kernels and pass f32 and f16 tests * shader(conv2d): fix the out of bounds memory access in the weight indexing * shader(conv2d): clean unused variables and optimize the computation * merge: use the new entries function * clean: address the formatting issues * clean: address the warning issues * clear: clean the shader editorconfig-checker issues * clear: clean the shader editorconfig-checker with utf-8 --------- Co-authored-by: Jeremy J. Hartmann <jeremy@mtion.tv>
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-webgpu-shader-lib.hpp"
|
||||
#include "ggml.h"
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
# include <emscripten/emscripten.h>
|
||||
@@ -921,6 +922,87 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx,
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
|
||||
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
|
||||
const int32_t p0 = ggml_get_op_params_i32(dst, 2);
|
||||
const int32_t p1 = ggml_get_op_params_i32(dst, 3);
|
||||
const int32_t d0 = ggml_get_op_params_i32(dst, 4);
|
||||
const int32_t d1 = ggml_get_op_params_i32(dst, 5);
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
|
||||
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
||||
|
||||
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||
|
||||
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
||||
|
||||
(uint32_t) src0->ne[0],
|
||||
(uint32_t) src0->ne[1],
|
||||
(uint32_t) src0->ne[2],
|
||||
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) src1->ne[1],
|
||||
|
||||
(uint32_t) dst->ne[0],
|
||||
(uint32_t) dst->ne[1],
|
||||
(uint32_t) dst->ne[2],
|
||||
(uint32_t) dst->ne[3],
|
||||
|
||||
(uint32_t) s0,
|
||||
(uint32_t) s1,
|
||||
(uint32_t) p0,
|
||||
(uint32_t) p1,
|
||||
(uint32_t) d0,
|
||||
(uint32_t) d1,
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
|
||||
};
|
||||
|
||||
uint32_t max_wg_size =
|
||||
std::min((uint32_t) WEBGPU_MAX_WG_SIZE, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupSizeX);
|
||||
uint32_t wg_size =
|
||||
std::min((uint32_t) ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, max_wg_size);
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.src0 = src0;
|
||||
shader_lib_ctx.src1 = src1;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.max_wg_size = wg_size;
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx);
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t n_out = ggml_nelements(dst);
|
||||
uint32_t total_wg = CEIL_DIV(n_out, decisions->wg_size);
|
||||
uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
uint32_t wg_x = std::min(total_wg, max_wg);
|
||||
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
|
||||
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
@@ -2477,6 +2559,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context c
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
return ggml_webgpu_sum_rows(ctx, src0, node);
|
||||
case GGML_OP_CONV_2D:
|
||||
return ggml_webgpu_conv_2d(ctx, src0, src1, node);
|
||||
default:
|
||||
return std::nullopt;
|
||||
}
|
||||
@@ -3495,6 +3579,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_OP_CONV_2D:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
|
||||
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
|
||||
break;
|
||||
case GGML_OP_SSM_CONV:
|
||||
supports_op = op->type == GGML_TYPE_F32;
|
||||
break;
|
||||
|
||||
Reference in New Issue
Block a user