ggml-webgpu: Add supports for GGML_OP_REPEAT (#20230)
* Add GGML_OP_REPEAT to webgpu backend. * Add i16 support for GGML_OP_REPEAT.
This commit is contained in:
committed by
GitHub
parent
d28961d81e
commit
f2ab047f27
@@ -1567,6 +1567,48 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) {
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
|
||||
std::vector<uint32_t> params = { ne,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) /
|
||||
ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->ne[0]),
|
||||
(uint32_t) (src0->ne[1]),
|
||||
(uint32_t) (src0->ne[2]),
|
||||
(uint32_t) (src0->ne[3]),
|
||||
(uint32_t) (dst->ne[0]),
|
||||
(uint32_t) (dst->ne[1]),
|
||||
(uint32_t) (dst->ne[2]) };
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src0,
|
||||
.dst = dst,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
int inplace = ggml_webgpu_tensor_equal(src, dst);
|
||||
|
||||
@@ -2158,6 +2200,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
||||
return ggml_webgpu_binary_op(ctx, src0, src1, node);
|
||||
case GGML_OP_CONCAT:
|
||||
return ggml_webgpu_concat(ctx, src0, src1, node);
|
||||
case GGML_OP_REPEAT:
|
||||
return ggml_webgpu_repeat(ctx, src0, node);
|
||||
case GGML_OP_RMS_NORM:
|
||||
return ggml_webgpu_rms_norm(ctx, src0, node);
|
||||
case GGML_OP_ROPE:
|
||||
@@ -2919,10 +2963,10 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
|
||||
/* .iface = */ {
|
||||
/* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
|
||||
/* .alloc_buffer = */
|
||||
ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */
|
||||
ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */
|
||||
ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */
|
||||
ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false
|
||||
ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */
|
||||
ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */
|
||||
ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */
|
||||
ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false
|
||||
},
|
||||
/* .device = */
|
||||
dev,
|
||||
@@ -3000,6 +3044,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
case GGML_OP_CONCAT:
|
||||
supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
|
||||
break;
|
||||
case GGML_OP_REPEAT:
|
||||
supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32 || src0->type == GGML_TYPE_I16);
|
||||
break;
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
||||
|
||||
Reference in New Issue
Block a user