vulkan: Implement SOLVE_TRI (#17486)

* vulkan: Implement SOLVE_TRI

* load B matrix through shared memory

* use FLOAT_TYPE
This commit is contained in:
Jeff Bolz
2025-11-27 08:48:00 -06:00
committed by GitHub
parent c386114922
commit 4abef75f2c
3 changed files with 167 additions and 0 deletions
@@ -944,6 +944,8 @@ void process_shaders() {
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("solve_tri_f32", "solve_tri.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
for (auto transpose : {false, true}) {
for (auto unroll : {false, true}) {
for (auto a_f16 : {false, true}) {