ggml-webgpu: support for SSM_SCAN and disable set_rows error checking (#22327)

* Implement ssm_scan

* Remove blocking in graph_compute and check for set rows

* Fix bindings

* Update op support
This commit is contained in:
Reese Levine
2026-04-24 23:18:15 -07:00
committed by GitHub
parent 0adede866d
commit dd2914dc81
5 changed files with 4168 additions and 1669 deletions
+88 -2
View File
@@ -1115,6 +1115,80 @@ static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx,
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * src2,
ggml_tensor * src3,
ggml_tensor * src4,
ggml_tensor * src5,
ggml_tensor * src6,
ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx);
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src3) / ggml_type_size(src3->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src6) / ggml_type_size(src6->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) (src2->nb[1] / ggml_type_size(src2->type)),
(uint32_t) (src2->nb[2] / ggml_type_size(src2->type)),
(uint32_t) src3->ne[0],
(uint32_t) (src3->nb[1] / ggml_type_size(src3->type)),
(uint32_t) (src4->nb[1] / ggml_type_size(src4->type)),
(uint32_t) (src4->nb[2] / ggml_type_size(src4->type)),
(uint32_t) (src4->nb[3] / ggml_type_size(src4->type)),
(uint32_t) (src5->nb[1] / ggml_type_size(src5->type)),
(uint32_t) (src5->nb[2] / ggml_type_size(src5->type)),
(uint32_t) (src5->nb[3] / ggml_type_size(src5->type)),
(uint32_t) src0->ne[0],
(uint32_t) src0->ne[1],
(uint32_t) src0->ne[2],
(uint32_t) src4->ne[1],
(uint32_t) src1->ne[2],
(uint32_t) src1->ne[3],
(uint32_t) ggml_nelements(src1),
};
std::vector<wgpu::BindGroupEntry> entries = {
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6), ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst),
};
const uint32_t total_wg = (uint32_t) (src0->ne[1] * src0->ne[2] * src1->ne[3]);
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
uint32_t wg_x;
uint32_t wg_y;
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
@@ -2764,6 +2838,9 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
return ggml_webgpu_solve_tri(ctx, src0, src1, node);
case GGML_OP_SSM_CONV:
return ggml_webgpu_ssm_conv(ctx, src0, src1, node);
case GGML_OP_SSM_SCAN:
return ggml_webgpu_ssm_scan(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node->src[6],
node);
case GGML_OP_GATED_DELTA_NET:
return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node);
case GGML_OP_PAD:
@@ -2822,7 +2899,10 @@ static void ggml_backend_webgpu_collect_profile_results(webgpu_context &
}
#endif
// Don't bother checking set_rows index overflow for now, since practically the WebGPU doesn't need to support
// models that would require it right now.
static void ggml_backend_webgpu_check_set_rows(webgpu_context & ctx, uint32_t & num_inflight_batches) {
#ifdef GGML_WEBGPU_CHECK_SET_ROWS
wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder();
encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0,
ctx->set_rows_host_error_buf.GetSize());
@@ -2835,6 +2915,10 @@ static void ggml_backend_webgpu_check_set_rows(webgpu_context & ctx, uint32_t &
GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
}
ctx->set_rows_host_error_buf.Unmap();
#else
GGML_UNUSED(ctx);
GGML_UNUSED(num_inflight_batches);
#endif
}
static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
@@ -2920,8 +3004,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
ggml_backend_webgpu_check_set_rows(ctx, num_inflight_batches);
}
ggml_backend_webgpu_wait_queue(ctx->global_ctx);
WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
return GGML_STATUS_SUCCESS;
}
@@ -3941,6 +4023,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_SSM_CONV:
supports_op = op->type == GGML_TYPE_F32;
break;
case GGML_OP_SSM_SCAN:
supports_op = op->type == GGML_TYPE_F32 &&
src0->ne[0] <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
break;
case GGML_OP_GATED_DELTA_NET:
{
const uint32_t s_v = (uint32_t) src2->ne[0];