ggml-webgpu: support non-square subgroup matrix configs for Intel GPUs (#21669)

This commit is contained in:
Rithik Sharma
2026-04-10 10:52:38 -07:00
committed by GitHub
parent e4fed9d08d
commit bfd1f453cb
2 changed files with 27 additions and 20 deletions
+10 -3
View File
@@ -3461,13 +3461,15 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
#ifndef __EMSCRIPTEN__
// Only support square f16 matrices of size 8 or 16 for now
// Accept f16 subgroup matrix configurations (square or non-square).
// NVIDIA GPUs typically report square configs (e.g. 16x16x16),
// while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16).
// The shaders are already parameterized to handle any M/N/K dimensions.
bool valid_subgroup_matrix_config = false;
if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
if (config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
@@ -3805,6 +3807,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
break;
}
// Head dimensions must be divisible by subgroup matrix dimensions
if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 ||
src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) {
break;
}
// Head dimensions must fit in workgroup memory with minimum tile sizes
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
const bool has_mask = op->src[3] != nullptr;