ggml-webgpu: JIT compile binary operators and handle binding overlaps (#19310)

* ggml webgpu: port binary operators to use pre-wgsl

* Add binary.wgsl: unified shader with conditionals for all 4 ops

* Add gen_binary_shaders.cpp: build tool for using pre_wgsl preprocessor

* Remove bin_op.tmpl.wgsl and binary.wgsl (Python template)

* Update CMake to generate binary operator shaders at build time

* ggml-webgpu: migrate binary ops to JIT compilation with overlap handling

* port binary operators from AOT to pre-wgsl JIT compilation

* add src1=dst overlap handling for binary ops

* use compile-time workgroup size defines instead of runtime overrides

* ggml-webgpu: complete overlap handling for binary ops

* add support for inplace & overlap case in binding setup

* restructure conditional logic to handle all overlap cases

* ensure all buffer bindings are correctly assigned for edge cases

* ggml-webgpu: remove unused binary overlap cases

Remove src0==src1 binary overlap case that never occurs in practice.

* keep INPLACE (src0==dst), OVERLAP (src1==dst), DEFAULT

* remove unused src0==src1 and all-same variant

* refactor wgsl to eliminate duplication
This commit is contained in:
Abhijit Ramesh
2026-02-06 10:33:30 -08:00
committed by GitHub
parent 537eadb1b9
commit 7fbd36c50c
5 changed files with 257 additions and 330 deletions
+81 -97
View File
@@ -348,13 +348,12 @@ struct webgpu_context_struct {
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
set_rows_pipelines;
std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
binary_pipelines;
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
@@ -823,6 +822,28 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
(ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
}
// Used to determine if two tensors share the same buffer and their byte ranges overlap,
static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
}
struct binary_overlap_flags {
bool inplace; // src0 == dst
bool overlap; // src1 == dst
};
static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
binary_overlap_flags flags = {};
flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
return flags;
}
static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
@@ -1375,14 +1396,42 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst,
webgpu_pipeline & pipeline,
bool inplace) {
static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
ggml_webgpu_binary_pipeline_key pipeline_key = {
.type = dst->type,
.op = dst->op,
.inplace = flags.inplace,
.overlap = flags.overlap,
};
ggml_webgpu_binary_shader_lib_context shader_lib_ctx = {
.key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
};
webgpu_pipeline pipeline;
auto it = ctx->binary_pipelines.find(pipeline_key);
if (it != ctx->binary_pipelines.end()) {
pipeline = it->second;
} else {
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_binary_shader(ctx->p, wgsl_binary, shader_lib_ctx);
pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
pipeline.context = processed.decisions;
ctx->binary_pipelines.emplace(pipeline_key, pipeline);
}
ggml_webgpu_generic_shader_decisions decisions =
*static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context);
uint32_t ne = (uint32_t) ggml_nelements(dst);
std::vector<uint32_t> params = {
(uint32_t) ggml_nelements(dst),
ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
@@ -1399,24 +1448,30 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
(uint32_t) src1->ne[3],
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) }
};
if (!inplace) {
std::vector<wgpu::BindGroupEntry> entries;
entries.push_back({
.binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0),
});
entries.push_back({
.binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1),
});
if (!flags.inplace && !flags.overlap) {
entries.push_back({ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
@@ -2038,25 +2093,10 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
return std::nullopt;
#endif
case GGML_OP_ADD:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipelines[node->type][inplace], inplace);
}
case GGML_OP_SUB:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipelines[node->type][inplace], inplace);
}
case GGML_OP_MUL:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipelines[node->type][inplace], inplace);
}
case GGML_OP_DIV:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipelines[node->type][inplace], inplace);
}
return ggml_webgpu_binary_op(ctx, src0, src1, node);
case GGML_OP_RMS_NORM:
return ggml_webgpu_rms_norm(ctx, src0, node);
case GGML_OP_ROPE:
@@ -2665,58 +2705,6 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
}
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32, "add_f32", constants);
webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16, "add_f16", constants);
webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants);
webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants);
}
static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32, "sub_f32", constants);
webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16, "sub_f16", constants);
webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants);
webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants);
}
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32, "mul_f32", constants);
webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16, "mul_f16", constants);
webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants);
webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants);
}
static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32, "div_f32", constants);
webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16, "div_f16", constants);
webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants);
webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants);
}
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
@@ -3018,10 +3006,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
ggml_webgpu_init_get_rows_pipeline(webgpu_ctx);
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
ggml_webgpu_init_add_pipeline(webgpu_ctx);
ggml_webgpu_init_sub_pipeline(webgpu_ctx);
ggml_webgpu_init_mul_pipeline(webgpu_ctx);
ggml_webgpu_init_div_pipeline(webgpu_ctx);
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
ggml_webgpu_init_rope_pipeline(webgpu_ctx);
ggml_webgpu_init_glu_pipeline(webgpu_ctx);