diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 16ebc32cb..503171ee1 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -26,20 +26,23 @@ // Matrix multiplication parameters // Register tiling parameters -#define WEBGPU_MUL_MAT_TILE_M 8 -#define WEBGPU_MUL_MAT_TILE_N 8 +#define WEBGPU_MUL_MAT_TILE_M 4 +#define WEBGPU_MUL_MAT_TILE_N 4 #define WEBGPU_MUL_MAT_WG_SIZE_M 8 #define WEBGPU_MUL_MAT_WG_SIZE_N 8 -#define WEBGPU_MUL_MAT_TILE_K 32 +#define WEBGPU_MUL_MAT_REG_TILE_K_FLOAT 8 +#define WEBGPU_MUL_MAT_REG_TILE_K_QUANT 32 // Subgroup matrix parameters // The number of subgroups in the M dimension #define WEBGPU_MUL_MAT_SUBGROUP_M 2 // The number of subgroups in the N dimension -#define WEBGPU_MUL_MAT_SUBGROUP_N 2 +#define WEBGPU_MUL_MAT_SUBGROUP_N 4 // The number of subgroup matrices each subgroup accumulates over #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 +#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT 32 +#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT 32 // Matrix-vector multiplication parameters #define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 @@ -1734,13 +1737,24 @@ class ggml_webgpu_shader_lib { // VEC/SCALAR controls defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + const bool is_quant = ggml_is_quantized(context.src0->type); + + uint32_t tile_k; + if (key.use_subgroup_matrix) { + tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT + : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT; + } else { + tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT + : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT; + } + // Tiles defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); - defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u"); // Subgroup matrix specifics if (key.use_subgroup_matrix) { + defines.push_back("TILE_K=" + std::to_string(tile_k) + "u"); defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u"); defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u"); @@ -1760,12 +1774,13 @@ class ggml_webgpu_shader_lib { if (!key.use_subgroup_matrix) { defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u"); defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u"); + defines.push_back("TILE_K=" + std::to_string(tile_k) + "u"); } auto processed = preprocessor.preprocess(shader_src, defines); auto decisions = std::make_shared(); - decisions->tile_k = WEBGPU_MUL_MAT_TILE_K; + decisions->tile_k = tile_k; decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; decisions->use_subgroup_matrix = key.use_subgroup_matrix; @@ -1962,10 +1977,15 @@ class ggml_webgpu_shader_lib { defines.push_back("SCALAR"); + // mul_mat_id is register-tile only. + const uint32_t tile_k = ggml_is_quantized(context.src0->type) + ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT + : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT; + // Tiles defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); - defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u"); + defines.push_back("TILE_K=" + std::to_string(tile_k) + "u"); defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u"); defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u"); @@ -1976,7 +1996,7 @@ class ggml_webgpu_shader_lib { auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines); auto decisions = std::make_shared(); - decisions->tile_k = WEBGPU_MUL_MAT_TILE_K; + decisions->tile_k = tile_k; decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;