mtmd: build_attn modified, flash_attn on/off via ctx_params (#19729)
This commit is contained in:
+1
-4
@@ -628,9 +628,6 @@ ggml_tensor * clip_graph::build_attn(
|
||||
ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
|
||||
v = ggml_cont(ctx0, v);
|
||||
|
||||
const auto n_tokens = q->ne[1];
|
||||
const auto n_head = q->ne[2];
|
||||
|
||||
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
// F32 may not needed for vision encoders?
|
||||
// ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
||||
@@ -639,7 +636,7 @@ ggml_tensor * clip_graph::build_attn(
|
||||
|
||||
ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
|
||||
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
|
||||
cur = ggml_cont_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2] * cur->ne[3]);
|
||||
}
|
||||
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
Reference in New Issue
Block a user