ggml-webgpu: support non-square subgroup matrix configs for Intel GPUs (#21669)
This commit is contained in:
@@ -3461,13 +3461,15 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
// Only support square f16 matrices of size 8 or 16 for now
|
||||
// Accept f16 subgroup matrix configurations (square or non-square).
|
||||
// NVIDIA GPUs typically report square configs (e.g. 16x16x16),
|
||||
// while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16).
|
||||
// The shaders are already parameterized to handle any M/N/K dimensions.
|
||||
bool valid_subgroup_matrix_config = false;
|
||||
if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
|
||||
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
|
||||
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
|
||||
if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
|
||||
config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
|
||||
if (config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
|
||||
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
|
||||
ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
|
||||
ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
|
||||
@@ -3805,6 +3807,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
|
||||
break;
|
||||
}
|
||||
// Head dimensions must be divisible by subgroup matrix dimensions
|
||||
if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 ||
|
||||
src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) {
|
||||
break;
|
||||
}
|
||||
// Head dimensions must fit in workgroup memory with minimum tile sizes
|
||||
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||
const bool has_mask = op->src[3] != nullptr;
|
||||
|
||||
Reference in New Issue
Block a user