ggml-webgpu: add support for im2col (#22259)
* shader(im2col): implement the im2col shader * shader(im2col): clean the formatting issues * shader(im2col): clean the editorconfig checker warning * fix(shader): address the workgroup issues of im2col and conv2d
This commit is contained in:
@@ -979,25 +979,108 @@ static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx,
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
|
||||
};
|
||||
|
||||
uint32_t max_wg_size =
|
||||
std::min((uint32_t) WEBGPU_MAX_WG_SIZE, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupSizeX);
|
||||
uint32_t wg_size =
|
||||
std::min((uint32_t) ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, max_wg_size);
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.src0 = src0;
|
||||
shader_lib_ctx.src1 = src1;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.max_wg_size = wg_size;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx);
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t n_out = ggml_nelements(dst);
|
||||
uint32_t total_wg = CEIL_DIV(n_out, decisions->wg_size);
|
||||
uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
uint32_t wg_x = std::min(total_wg, max_wg);
|
||||
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
|
||||
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
|
||||
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
|
||||
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
|
||||
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
|
||||
const int32_t p0 = ggml_get_op_params_i32(dst, 2);
|
||||
const int32_t p1 = ggml_get_op_params_i32(dst, 3);
|
||||
const int32_t d0 = ggml_get_op_params_i32(dst, 4);
|
||||
const int32_t d1 = ggml_get_op_params_i32(dst, 5);
|
||||
const bool is_2D = ggml_get_op_params_i32(dst, 6) == 1;
|
||||
|
||||
const uint32_t KW = src0->ne[0];
|
||||
const uint32_t KH = is_2D ? src0->ne[1] : 1;
|
||||
const uint32_t IC = is_2D ? src0->ne[2] : src0->ne[1];
|
||||
|
||||
const uint32_t IW = src1->ne[0];
|
||||
const uint32_t IH = is_2D ? src1->ne[1] : 1;
|
||||
const uint32_t N = is_2D ? src1->ne[3] : src1->ne[2];
|
||||
|
||||
const uint32_t OW = dst->ne[1];
|
||||
const uint32_t OH = is_2D ? dst->ne[2] : 1;
|
||||
|
||||
const uint32_t si0 = (uint32_t) (src1->nb[0] / ggml_type_size(src1->type));
|
||||
const uint32_t si1 = is_2D ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0;
|
||||
const uint32_t si2 = is_2D ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) :
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type));
|
||||
const uint32_t si3 = is_2D ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) :
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type));
|
||||
|
||||
const uint32_t so0 = (uint32_t) (dst->nb[0] / ggml_type_size(dst->type));
|
||||
const uint32_t so1 = (uint32_t) (dst->nb[1] / ggml_type_size(dst->type));
|
||||
const uint32_t so2 = is_2D ? (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)) : 0;
|
||||
const uint32_t so3 = is_2D ? (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)) :
|
||||
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type));
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
|
||||
si0,
|
||||
si1,
|
||||
si2,
|
||||
si3,
|
||||
so0,
|
||||
so1,
|
||||
so2,
|
||||
so3,
|
||||
|
||||
KW,
|
||||
KH,
|
||||
IC,
|
||||
|
||||
IW,
|
||||
IH,
|
||||
N,
|
||||
|
||||
OW,
|
||||
OH,
|
||||
|
||||
(uint32_t) s0,
|
||||
(uint32_t) s1,
|
||||
(uint32_t) p0,
|
||||
(uint32_t) p1,
|
||||
(uint32_t) d0,
|
||||
(uint32_t) d1,
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1),
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst),
|
||||
};
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.src0 = src0;
|
||||
shader_lib_ctx.src1 = src1;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_im2col_pipeline(shader_lib_ctx);
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
|
||||
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
|
||||
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
|
||||
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||
@@ -1988,8 +2071,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
|
||||
GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1");
|
||||
}
|
||||
|
||||
bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
|
||||
(ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
|
||||
bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
|
||||
(ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
|
||||
bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src);
|
||||
|
||||
uint32_t offset_merged_rn_src = 0;
|
||||
@@ -2689,6 +2772,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
|
||||
return ggml_webgpu_sum_rows(ctx, src0, node);
|
||||
case GGML_OP_CONV_2D:
|
||||
return ggml_webgpu_conv_2d(ctx, src0, src1, node);
|
||||
case GGML_OP_IM2COL:
|
||||
return ggml_webgpu_im2col(ctx, src0, src1, node);
|
||||
default:
|
||||
return std::nullopt;
|
||||
}
|
||||
@@ -3455,7 +3540,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
|
||||
ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context;
|
||||
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
|
||||
webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
|
||||
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
|
||||
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
|
||||
webgpu_ctx->param_arena.init(
|
||||
webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
||||
webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN,
|
||||
@@ -3705,12 +3790,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
break;
|
||||
}
|
||||
// Head dimensions must fit in workgroup memory with minimum tile sizes
|
||||
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||
const bool has_mask = op->src[3] != nullptr;
|
||||
const bool kv_direct = src1->type == GGML_TYPE_F16 &&
|
||||
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
|
||||
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
|
||||
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
|
||||
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||
const bool has_mask = op->src[3] != nullptr;
|
||||
const bool kv_direct = src1->type == GGML_TYPE_F16 &&
|
||||
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
|
||||
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
|
||||
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
|
||||
ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
|
||||
(uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
|
||||
if (min_bytes > limit_bytes) {
|
||||
@@ -3802,6 +3887,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
|
||||
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
|
||||
break;
|
||||
case GGML_OP_IM2COL:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||
break;
|
||||
case GGML_OP_SSM_CONV:
|
||||
supports_op = op->type == GGML_TYPE_F32;
|
||||
break;
|
||||
|
||||
Reference in New Issue
Block a user