vulkan: Flash Attention DP4A shader for quantized KV cache (#20797)

* use integer dot product for quantized KV flash attention

* small improvements

* fix SHMEM_STAGING indexing

* add missing KV type quants

* fixes

* add supported quants to FA tests

* readd fast paths for <8bit quants

* fix mmq gate and shmem checks
This commit is contained in:
Ruben Ortlam
2026-04-13 14:21:31 +02:00
committed by GitHub
parent aa00911d12
commit 75f3bc94e6
8 changed files with 431 additions and 28 deletions
+1 -1
View File
@@ -8580,7 +8580,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (int nb : { 1, 3, 32, 75, }) {
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0, GGML_TYPE_IQ4_NL}) {
if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72) continue;
test_cases.emplace_back(new test_flash_attn_ext(
hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));