vulkan: Implement top-k (#17418)

* vulkan: Implement top-k

Each pass launches workgroups that each sort 2^N elements (where N is usually 7-10)
and discards all but the top K. Repeat until only K are left. And there's a fast
path when K==1 to just find the max value rather than sorting.

* fix pipeline selection

* vulkan: Add N-ary search algorithm for topk

* microoptimizations
This commit is contained in:
Jeff Bolz
2025-11-26 09:45:43 -06:00
committed by GitHub
parent 6ab4e50d9c
commit 879d673759
5 changed files with 480 additions and 1 deletions
@@ -913,6 +913,9 @@ void process_shaders() {
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
string_to_spv("topk_argsort_f32", "topk_argsort.comp", {{"A_TYPE", "float"}});
string_to_spv("topk_nary_search_f32", "topk_nary_search.comp", {{"A_TYPE", "float"}});
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));