ggml webgpu: fix workgroup dispatch limit for large batch sizes (#19965)
* ggml-webgpu: fix workgroup dispatch limit for large batch sizes WebGPU limits workgroup sizes to 65535 per dimension. Large MUL_MAT operations with batch sizes exceedeing this limi would fail. * add compute_2d_workgroups() helper to split total workgroup ID across X/Y dimensions * update mul_mat_reg_tile.wgsl to reconstruct linear workgroup ID from 2D dispatch * update mul_mat_subgroup_matrix.wgsl to reconstruct linear workgroup ID from 2D dispatch * update mul_mat.wgsl to compute global index from 2D workgroup coordinates * refactor all three mul_mat dispatch paths to use the shared helper * ggml-webgpu: add bounds checking for over-dispatched workgroups 2D workgroup dispatch can over-dispatch when total workgroups don't divide evenly into the 65535 per-dimension limit. Extra workgroups would compute invalid batch indices, causing memory corruption. * add batch_idx bound check to mul_mat_reg_tile.wgsl and mul_mat_subgroup_matrix.wgsl to prevent over-dispatched workgroups from accessing invalid memory * fixes test failures with large batch sizes (eg., bs=[128, 1024]) * ggml-webgpu: add back TODO for spliting large sizes into batches * Optimize 2d workgroup provisioning * Set some parameters that increase speed --------- Co-authored-by: Reese Levine <reeselevine1@gmail.com>
This commit is contained in:
@@ -31,6 +31,13 @@
|
||||
#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
|
||||
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
|
||||
|
||||
// Return a rectangular grid of workgroups with minimal over-provisioned workgroups.
|
||||
// Assumes that the total number of workgroups does not exceed max_per_dim^2.
|
||||
static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) {
|
||||
wg_y = std::max(1u, CEIL_DIV(total_wg, max_per_dim));
|
||||
wg_x = CEIL_DIV(total_wg, wg_y);
|
||||
}
|
||||
|
||||
#ifdef GGML_WEBGPU_DEBUG
|
||||
# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
|
||||
# define WEBGPU_DEBUG_BUF_ELEMS 512
|
||||
@@ -69,8 +76,8 @@
|
||||
|
||||
/* Constants */
|
||||
|
||||
#define WEBGPU_NUM_PARAM_BUFS 16u
|
||||
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u
|
||||
#define WEBGPU_NUM_PARAM_BUFS 48u
|
||||
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u
|
||||
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
|
||||
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
|
||||
// parameter buffer pool
|
||||
@@ -1146,8 +1153,9 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
};
|
||||
|
||||
// Calculate workgroup dimensions
|
||||
uint32_t wg_x = 1;
|
||||
uint32_t wg_y = 1;
|
||||
uint32_t wg_x = 1;
|
||||
uint32_t wg_y = 1;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
if (use_fast && is_vec) {
|
||||
auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
|
||||
@@ -1155,9 +1163,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
uint32_t batches = dst->ne[2] * dst->ne[3];
|
||||
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
|
||||
uint32_t total_wg = output_groups * batches;
|
||||
// TODO: split large sizes into multiple batches to avoid way over-provisioning workgroups
|
||||
wg_x = std::min(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
|
||||
wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
} else if (use_fast) {
|
||||
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
@@ -1176,12 +1182,14 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
|
||||
wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
|
||||
}
|
||||
wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
|
||||
uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3];
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
|
||||
} else { // legacy
|
||||
auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
uint32_t wg_size = decisions->wg_size;
|
||||
wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
|
||||
wg_y = 1;
|
||||
uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
}
|
||||
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
|
||||
|
||||
Reference in New Issue
Block a user