model : refactor QKV into common build_qkv and create_tensor_qkv helpers (#21245)
* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
This commit is contained in:
@@ -17,6 +17,7 @@ struct ggml_context;
|
||||
struct ggml_tensor;
|
||||
|
||||
struct llama_cparams;
|
||||
struct llama_layer;
|
||||
|
||||
struct llama_memory_context_i;
|
||||
|
||||
@@ -707,6 +708,12 @@ using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
|
||||
// Callback type used in build_rs to properly order writes and avoid unnecessary copies.
// Signature: (ctx, states, ids) -> tensor; given the recurrent-state tensor `states` and
// row indices `ids`, returns a tensor selecting/reordering those rows
// (presumably a ggml_get_rows-style gather — confirm against build_rs call sites).
using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
|
||||
|
||||
// Result of the Q/K/V projection for one attention layer, as returned by
// llm_graph_context::build_qkv: the three tensors are already reshaped to
// per-head layout (see the dimension comments below).
struct llm_graph_qkv {
    ggml_tensor * q; // queries: [n_embd_head, n_head,    n_tokens]
    ggml_tensor * k; // keys:    [n_embd_head, n_head_kv, n_tokens]
    ggml_tensor * v; // values:  [n_embd_head, n_head_kv, n_tokens]
};
|
||||
|
||||
struct llm_graph_context {
|
||||
const llm_arch arch;
|
||||
|
||||
@@ -793,6 +800,17 @@ struct llm_graph_context {
|
||||
llm_norm_type type,
|
||||
int il) const;
|
||||
|
||||
|
||||
// compute Q, K, V projections with optional bias and reshape
|
||||
// supports both fused wqkv and separate wq/wk/wv paths
|
||||
llm_graph_qkv build_qkv(
|
||||
const llama_layer & layer,
|
||||
ggml_tensor * cur,
|
||||
int64_t n_embd_head,
|
||||
int64_t n_head,
|
||||
int64_t n_head_kv,
|
||||
int il) const;
|
||||
|
||||
ggml_tensor * build_ffn(
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * up,
|
||||
|
||||
Reference in New Issue
Block a user