vulkan: support solve_tri with larger N/K values (#17781)
Split N into batches of rows so that each batch fits into shared memory. If K > 128, use a larger workgroup so that there are at least K invocations. Add perf tests matching qwen3next.
This commit is contained in:
@@ -4033,10 +4033,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
|
||||
for (auto &s : device->pipeline_solve_tri_f32) {
|
||||
const vk_solve_tri_pipeline_state &state = s.first;
|
||||
|
||||
// Max number of rows to load at a time, limited by shared memory
|
||||
const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((state.N + state.K) * sizeof(float));
|
||||
// Need at least K invocations, and prefer a minimum of 128 to spread out loading shared memory
|
||||
const uint32_t block_size = std::max(128u, 1u << (uint32_t)ceilf(log2f(float(state.K))));
|
||||
|
||||
ggml_vk_create_pipeline(
|
||||
device, s.second, "solve_tri_f32",
|
||||
solve_tri_f32_len, solve_tri_f32_data, "main", 3,
|
||||
sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K }, 1, true);
|
||||
sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K, batch_N, block_size }, 1, true);
|
||||
}
|
||||
|
||||
#define IM2COL(bda) \
|
||||
@@ -14025,10 +14031,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
const uint32_t N = op->src[0]->ne[0];
|
||||
const uint32_t K = op->src[1]->ne[0];
|
||||
// K dimension limited to workgroup size
|
||||
if (K > 128) {
|
||||
if (K > 1u << device->max_workgroup_size_log2) {
|
||||
return false;
|
||||
}
|
||||
if (N * N * sizeof(float) + N * K * sizeof(float) > device->properties.limits.maxComputeSharedMemorySize) {
|
||||
const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((N + K) * sizeof(float));
|
||||
|
||||
if (batch_N == 0) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
||||
Reference in New Issue
Block a user