mla : make the V tensor a view of K (#18986)

* mla : pass V as a view of K to the FA op
* cuda : adjust mla logic to new layout
* kv-cache : fix rope shift
* tests : remove comment
* cuda : fix reusable_cutoff

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
@@ -6122,7 +6122,19 @@ struct test_flash_attn_ext : public test_case {
         ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1], true); // the K tensor is usually a view of the K cache
         ggml_set_name(k, "k");

-        ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
+        ggml_tensor * v = nullptr;
+        if (hsk_padded == 576 && hsv_padded == 512) {
+            // TODO: this branch should become a separate test case parameter instead of hardcoding this for these head shapes
+
+            // in this branch, the V cache is sub-view of the K cache. this is used by some MLA-based models
+            // for more info:
+            // - https://github.com/ggml-org/llama.cpp/pull/13435
+            // - https://github.com/ggml-org/llama.cpp/pull/18953#issuecomment-3774948392
+            // - https://github.com/ggml-org/llama.cpp/pull/18986
+            v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0);
+        } else {
+            v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
+        }
         ggml_set_name(v, "v");

         ggml_tensor * m = nullptr;
||||
Reference in New Issue
Block a user