ggml-webgpu: updated matrix-vector multiplication (#21738)

* merged properly, but slow q3_k and q5_k with u32 indexing

* Start on new mat-vec

* New format float paths working

* Working q4_0

* Work on remaining legacy q-types

* port k-quants to new matvec

* remove old shader

* Remove old constants, format

* remove accidental file

---------

Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>
This commit was authored by neha-ha on 2026-04-20 07:37:17 -07:00 and committed by GitHub.
parent a678916623
commit a6cc43c286
4 changed files with 816 additions and 411 deletions
+17 -11
View File
@@ -181,6 +181,7 @@ struct webgpu_dispatch_desc {
struct webgpu_capabilities {
wgpu::Limits limits;
bool supports_subgroups = false;
bool supports_subgroup_matrix = false;
uint32_t sg_mat_m = 0;
@@ -1164,14 +1165,11 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q6_K:
use_fast = true;
break;
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
// we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat
use_fast = !is_vec;
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q2_K:
use_fast = true;
break;
default:
break;
@@ -1182,10 +1180,12 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
}
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
shader_lib_ctx.dst = dst;
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
@@ -1287,7 +1287,8 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
// Get or create pipeline
webgpu_pipeline gather_pipeline, main_pipeline;
webgpu_pipeline gather_pipeline;
webgpu_pipeline main_pipeline;
std::vector<webgpu_dispatch_desc> dispatches;
@@ -3040,6 +3041,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
// we require f16 support
GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
ctx->webgpu_global_ctx->capabilities.supports_subgroups =
ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);
#ifndef __EMSCRIPTEN__
// Accept f16 subgroup matrix configurations (square or non-square).
@@ -3072,11 +3075,14 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
#ifndef __EMSCRIPTEN__
required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
required_features.push_back(wgpu::FeatureName::Subgroups);
required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
}
#endif
if (ctx->webgpu_global_ctx->capabilities.supports_subgroups) {
required_features.push_back(wgpu::FeatureName::Subgroups);
}
#ifdef GGML_WEBGPU_GPU_PROFILE
required_features.push_back(wgpu::FeatureName::TimestampQuery);
#endif