model : refactor QKV into common build_qkv and create_tensor_qkv helpers (#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
commit 9db77a020c
parent f772f6e434
Author: PikaPikachu
Date:   2026-04-16 23:41:34 +08:00
Committed by: GitHub
88 changed files with 351 additions and 1764 deletions
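
The headline change is visible in the llama-graph.h diff below: a small llm_graph_qkv aggregate plus a build_qkv helper replace the per-model Q/K/V projection boilerplate (the loader side gets an analogous create_tensor_qkv helper, not shown in this hunk). As rough orientation, here is a minimal sketch of what such a helper has to do, written against the existing llm_graph_context helpers (ctx0, cb, build_lora_mm) and stock ggml ops; the committed body may differ in detail, e.g. by slicing the fused tensor with 3-D views instead of the ggml_cont copies used here:

    // hedged sketch of the helper's shape -- not the literal commit body
    llm_graph_qkv llm_graph_context::build_qkv(
            const llama_layer & layer,
            ggml_tensor * cur,   // [n_embd, n_tokens] input activations
            int64_t n_embd_head,
            int64_t n_head,
            int64_t n_head_kv,
            int il) const {
        const int64_t n_tokens  = cur->ne[1];
        const int64_t n_embd_q  = n_embd_head*n_head;
        const int64_t n_embd_kv = n_embd_head*n_head_kv;

        ggml_tensor * q;
        ggml_tensor * k;
        ggml_tensor * v;

        if (layer.wqkv) {
            // fused path: one matmul, then slice the rows into Q, K and V
            // (ggml_cont keeps the sketch simple; views can avoid the copies)
            ggml_tensor * qkv = build_lora_mm(layer.wqkv, cur);

            const size_t es = ggml_element_size(qkv);

            q = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_q,  n_tokens, qkv->nb[1], 0));
            k = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_kv, n_tokens, qkv->nb[1], es*n_embd_q));
            v = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_kv, n_tokens, qkv->nb[1], es*(n_embd_q + n_embd_kv)));
        } else {
            // separate path: one matmul per projection
            q = build_lora_mm(layer.wq, cur);
            k = build_lora_mm(layer.wk, cur);
            v = build_lora_mm(layer.wv, cur);
        }

        // optional per-projection biases (fused bqkv handling omitted here)
        if (layer.bq) { q = ggml_add(ctx0, q, layer.bq); }
        if (layer.bk) { k = ggml_add(ctx0, k, layer.bk); }
        if (layer.bv) { v = ggml_add(ctx0, v, layer.bv); }

        // reshape to the layouts documented on llm_graph_qkv
        q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head,    n_tokens);
        k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens);
        v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens);

        cb(q, "Qcur", il);
        cb(k, "Kcur", il);
        cb(v, "Vcur", il);

        return {q, k, v};
    }
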
src/llama-graph.h (+18)
@@ -17,6 +17,7 @@ struct ggml_context;
 struct ggml_tensor;
 struct llama_cparams;
+struct llama_layer;
 struct llama_memory_context_i;
@@ -707,6 +708,12 @@ using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
 // used in build_rs to properly order writes and avoid unnecessary copies
 using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
 
+struct llm_graph_qkv {
+    ggml_tensor * q; // [n_embd_head, n_head,    n_tokens]
+    ggml_tensor * k; // [n_embd_head, n_head_kv, n_tokens]
+    ggml_tensor * v; // [n_embd_head, n_head_kv, n_tokens]
+};
+
 struct llm_graph_context {
     const llm_arch arch;
@@ -793,6 +800,17 @@ struct llm_graph_context {
             llm_norm_type type,
                       int il) const;
+
+    // compute Q, K, V projections with optional bias and reshape
+    // supports both fused wqkv and separate wq/wk/wv paths
+    llm_graph_qkv build_qkv(
+        const llama_layer & layer,
+              ggml_tensor * cur,
+                    int64_t n_embd_head,
+                    int64_t n_head,
+                    int64_t n_head_kv,
+                        int il) const;
+
     ggml_tensor * build_ffn(
             ggml_tensor * cur,
             ggml_tensor * up,
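
At the call sites, the helper collapses the per-model projection boilerplate into one call; most of the 1764 deleted lines are repetitions of the "before" pattern below. A hedged before/after sketch (the "before" lines mirror the common pre-refactor pattern; model, cur and the dimension variables come from the surrounding model-builder context and are assumptions here):

    // before (repeated, with small variations, in nearly every build function):
    //   ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
    //   if (model.layers[il].bq) Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
    //   Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
    //   ... and the same again for K and V ...

    // after: one call; q/k/v come back reshaped to [n_embd_head, n_head(_kv), n_tokens]
    llm_graph_qkv qkv = build_qkv(model.layers[il], cur, n_embd_head, n_head, n_head_kv, il);

From there the builders proceed as before, applying RoPE to qkv.q and qkv.k where the architecture calls for it and passing all three tensors on to build_attn.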