add performance-portable tuning for register-tile and subgroup matmul (#22241)
This commit is contained in:
@@ -26,20 +26,23 @@
|
|||||||
// Matrix multiplication parameters
|
// Matrix multiplication parameters
|
||||||
|
|
||||||
// Register tiling parameters
|
// Register tiling parameters
|
||||||
#define WEBGPU_MUL_MAT_TILE_M 8
|
#define WEBGPU_MUL_MAT_TILE_M 4
|
||||||
#define WEBGPU_MUL_MAT_TILE_N 8
|
#define WEBGPU_MUL_MAT_TILE_N 4
|
||||||
#define WEBGPU_MUL_MAT_WG_SIZE_M 8
|
#define WEBGPU_MUL_MAT_WG_SIZE_M 8
|
||||||
#define WEBGPU_MUL_MAT_WG_SIZE_N 8
|
#define WEBGPU_MUL_MAT_WG_SIZE_N 8
|
||||||
#define WEBGPU_MUL_MAT_TILE_K 32
|
#define WEBGPU_MUL_MAT_REG_TILE_K_FLOAT 8
|
||||||
|
#define WEBGPU_MUL_MAT_REG_TILE_K_QUANT 32
|
||||||
|
|
||||||
// Subgroup matrix parameters
|
// Subgroup matrix parameters
|
||||||
// The number of subgroups in the M dimension
|
// The number of subgroups in the M dimension
|
||||||
#define WEBGPU_MUL_MAT_SUBGROUP_M 2
|
#define WEBGPU_MUL_MAT_SUBGROUP_M 2
|
||||||
// The number of subgroups in the N dimension
|
// The number of subgroups in the N dimension
|
||||||
#define WEBGPU_MUL_MAT_SUBGROUP_N 2
|
#define WEBGPU_MUL_MAT_SUBGROUP_N 4
|
||||||
// The number of subgroup matrices each subgroup accumulates over
|
// The number of subgroup matrices each subgroup accumulates over
|
||||||
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
|
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
|
||||||
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
|
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
|
||||||
|
#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT 32
|
||||||
|
#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT 32
|
||||||
|
|
||||||
// Matrix-vector multiplication parameters
|
// Matrix-vector multiplication parameters
|
||||||
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
|
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
|
||||||
@@ -1734,13 +1737,24 @@ class ggml_webgpu_shader_lib {
|
|||||||
// VEC/SCALAR controls
|
// VEC/SCALAR controls
|
||||||
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
|
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
|
||||||
|
|
||||||
|
const bool is_quant = ggml_is_quantized(context.src0->type);
|
||||||
|
|
||||||
|
uint32_t tile_k;
|
||||||
|
if (key.use_subgroup_matrix) {
|
||||||
|
tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT
|
||||||
|
: WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT;
|
||||||
|
} else {
|
||||||
|
tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT
|
||||||
|
: WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
// Tiles
|
// Tiles
|
||||||
defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
|
defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
|
||||||
defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
|
defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
|
||||||
defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u");
|
|
||||||
|
|
||||||
// Subgroup matrix specifics
|
// Subgroup matrix specifics
|
||||||
if (key.use_subgroup_matrix) {
|
if (key.use_subgroup_matrix) {
|
||||||
|
defines.push_back("TILE_K=" + std::to_string(tile_k) + "u");
|
||||||
defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
|
defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
|
||||||
defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u");
|
defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u");
|
||||||
defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u");
|
defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u");
|
||||||
@@ -1760,12 +1774,13 @@ class ggml_webgpu_shader_lib {
|
|||||||
if (!key.use_subgroup_matrix) {
|
if (!key.use_subgroup_matrix) {
|
||||||
defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
|
defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
|
||||||
defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
|
defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
|
||||||
|
defines.push_back("TILE_K=" + std::to_string(tile_k) + "u");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto processed = preprocessor.preprocess(shader_src, defines);
|
auto processed = preprocessor.preprocess(shader_src, defines);
|
||||||
|
|
||||||
auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
|
auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
|
||||||
decisions->tile_k = WEBGPU_MUL_MAT_TILE_K;
|
decisions->tile_k = tile_k;
|
||||||
decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
|
decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
|
||||||
decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
|
decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
|
||||||
decisions->use_subgroup_matrix = key.use_subgroup_matrix;
|
decisions->use_subgroup_matrix = key.use_subgroup_matrix;
|
||||||
@@ -1962,10 +1977,15 @@ class ggml_webgpu_shader_lib {
|
|||||||
|
|
||||||
defines.push_back("SCALAR");
|
defines.push_back("SCALAR");
|
||||||
|
|
||||||
|
// mul_mat_id is register-tile only.
|
||||||
|
const uint32_t tile_k = ggml_is_quantized(context.src0->type)
|
||||||
|
? WEBGPU_MUL_MAT_REG_TILE_K_QUANT
|
||||||
|
: WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
|
||||||
|
|
||||||
// Tiles
|
// Tiles
|
||||||
defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
|
defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
|
||||||
defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
|
defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
|
||||||
defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u");
|
defines.push_back("TILE_K=" + std::to_string(tile_k) + "u");
|
||||||
|
|
||||||
defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
|
defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
|
||||||
defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
|
defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
|
||||||
@@ -1976,7 +1996,7 @@ class ggml_webgpu_shader_lib {
|
|||||||
auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines);
|
auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines);
|
||||||
|
|
||||||
auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
|
auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
|
||||||
decisions->tile_k = WEBGPU_MUL_MAT_TILE_K;
|
decisions->tile_k = tile_k;
|
||||||
decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
|
decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
|
||||||
decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
|
decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
|
||||||
decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;
|
decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;
|
||||||
|
|||||||
Reference in New Issue
Block a user