Kimi-Linear support (backend agnostic + MLA KV cache) (#18755)
* kimi linear model implementation * kimi linear convert_hf_to_gguf * kimi linear constants.py tensor_mapping.py * Kimi Linear ggml.h * kimi linear ggml-cpu * Kimi Linear ggml-cuda * Kimi Linear ggml.c * kimi linear src/llama * remove "const int64_t n_seq_tokens = q->ne[2];" to get rid of unused variable warning * remove type mismatch warning * read MoE params * removed some hard coded code * removed all hard code * use DeepseekV2 tokenizer * removed unnecessary internal methods called by the old set_vocab of KimiLinear * rewrite get_vocab for KimiLinear. Removed all kda_scan code * removed all traces of kda_scan * reduce OP count by 1 due to removal of kda_scan * Move KIMI_LINEAR to llm_arch_is_hybrid to enable KV cache * set n_embd_head_k/v to ensure kv cache works * don't quantize conv1d of Kimi Linear * Kimi Linear backend agnostic * removed LOG_INFO * naive chunking form implemented * fixed some comments * add Kimi-K2 specific tokens to be recognized as EOG * build_kda_autoregressive is implemented to replace build_kda_recurrent for faster inference. sync'd to b7682 * replaced Akk and Aqk with mul_mat and clamp * no clamp version * Moved Aqk computation out of the loop * fixed typo and split wkv_b into wk_b and wv_b * MLA KV cache support * fix trailing spaces * moved const llama_model & model; around to follow qwen3next format and see if it cna pass the -Wunused-private-field error * fix trailing whitespace * removed traling whitespaces in empty line + make sure indentation is multiple of 4 * try to make lint happy * remove blank lines to make lint happy * removed at least blank line containing white space * fixed flake8 complaints locally * return ggml_tensor * pair in kda_autoregressive and kda_chunking as in ngxson's Qwen3Next improvement * removed Kimi-Linear specific change that causes failure at server-windows * removed private: from kimi_linear to make build checks happy * removed unnecessary ggml_cont before ggml_reshape * created static function causal_conv1d to abtract similar code for q/k/v * merged dt_bias to SSM_DT. Do -exp(log_A) in convert_hf_to_gguf.py. * reverted to original * fixed find_hparam calls. Fixed e_score_correction_bias to use bias instead of weight. Removed all ssm_conv bias terms. * remove DT_B from constants.py. remove one comment line in llama-model.cpp * new class llm_graph_input_mem_hybrid_k to get around the new MLA change. switch the concat order of ggml_concat calls in kimi-linear.cpp to accommodate MLA changes. Removed support for exp_probs_b.weight * remove ssm_o_norm_b * remove ssm_o_norm_b * changed hparams.kda_head_dim to hparams.n_embd_head_kda. added TODO comment for class llama_graph_mem_hybrid_k * removed all ggml_cont b4 ggml_reshape_4d * Whitespace * replaced all hparams.get with find_hparams * added new names for n_experts, n_experts_used and score_func in TextModel and removed their code in KimiLinear in convert_hf_to_gguf.py. Removed unnecessary ggml_cont and GGML_ASSERT in kimi-linear.cpp * use is_mla to switch between different mem_hybrid types * fixed logical errors in convert_hf_to_gguf.py pointed out by CISC * removed if else for required parameters kv_lora_rank and qk_rope_head_dim * add back ggml_cont for Vcur * minor changes * removed extra line in llama-vocab.cpp. Added back the comment in llama-graph.cpp * f16 gguf cannot run without context length * made a mistake of adding back n_ctx parsing --------- Co-authored-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com>
This commit is contained in:
@@ -207,6 +207,9 @@ class Keys:
|
||||
GROUP_COUNT = "{arch}.ssm.group_count"
|
||||
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
|
||||
|
||||
class KDA:
|
||||
HEAD_DIM = "{arch}.kda.head_dim"
|
||||
|
||||
class WKV:
|
||||
HEAD_SIZE = "{arch}.wkv.head_size"
|
||||
|
||||
@@ -461,6 +464,7 @@ class MODEL_ARCH(IntEnum):
|
||||
MIMO2 = auto()
|
||||
LLAMA_EMBED = auto()
|
||||
MAINCODER = auto()
|
||||
KIMI_LINEAR = auto()
|
||||
|
||||
|
||||
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
@@ -551,6 +555,14 @@ class MODEL_TENSOR(IntEnum):
|
||||
SSM_NORM = auto()
|
||||
SSM_OUT = auto()
|
||||
SSM_BETA_ALPHA = auto() # qwen3next
|
||||
SSM_CONV1D_Q = auto() # Kimi Linear
|
||||
SSM_CONV1D_K = auto() # Kimi Linear
|
||||
SSM_CONV1D_V = auto() # Kimi Linear
|
||||
SSM_F_A = auto() # Kimi Linear
|
||||
SSM_F_B = auto() # Kimi Linear
|
||||
SSM_BETA = auto() # Kimi Linear
|
||||
SSM_G_A = auto() # Kimi Linear
|
||||
SSM_G_B = auto() # Kimi Linear
|
||||
TIME_MIX_W0 = auto()
|
||||
TIME_MIX_W1 = auto()
|
||||
TIME_MIX_W2 = auto()
|
||||
@@ -882,6 +894,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.MIMO2: "mimo2",
|
||||
MODEL_ARCH.LLAMA_EMBED: "llama-embed",
|
||||
MODEL_ARCH.MAINCODER: "maincoder",
|
||||
MODEL_ARCH.KIMI_LINEAR: "kimi-linear",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
@@ -969,6 +982,14 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
|
||||
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
|
||||
MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba",
|
||||
MODEL_TENSOR.SSM_CONV1D_Q: "blk.{bid}.ssm_conv1d_q", # Kimi Linear
|
||||
MODEL_TENSOR.SSM_CONV1D_K: "blk.{bid}.ssm_conv1d_k", # Kimi Linear
|
||||
MODEL_TENSOR.SSM_CONV1D_V: "blk.{bid}.ssm_conv1d_v", # Kimi Linear
|
||||
MODEL_TENSOR.SSM_F_A: "blk.{bid}.ssm_f_a", # Kimi Linear
|
||||
MODEL_TENSOR.SSM_F_B: "blk.{bid}.ssm_f_b", # Kimi Linear
|
||||
MODEL_TENSOR.SSM_BETA: "blk.{bid}.ssm_beta", # Kimi Linear
|
||||
MODEL_TENSOR.SSM_G_A: "blk.{bid}.ssm_g_a", # Kimi Linear
|
||||
MODEL_TENSOR.SSM_G_B: "blk.{bid}.ssm_g_b", # Kimi Linear
|
||||
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
|
||||
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
|
||||
MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
|
||||
@@ -3379,6 +3400,47 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.KIMI_LINEAR: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_Q_A,
|
||||
MODEL_TENSOR.ATTN_Q_B,
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA,
|
||||
MODEL_TENSOR.ATTN_KV_B,
|
||||
MODEL_TENSOR.ATTN_K_B,
|
||||
MODEL_TENSOR.ATTN_V_B,
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM,
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.SSM_CONV1D_Q,
|
||||
MODEL_TENSOR.SSM_CONV1D_K,
|
||||
MODEL_TENSOR.SSM_CONV1D_V,
|
||||
MODEL_TENSOR.SSM_F_A,
|
||||
MODEL_TENSOR.SSM_F_B,
|
||||
MODEL_TENSOR.SSM_BETA,
|
||||
MODEL_TENSOR.SSM_A,
|
||||
MODEL_TENSOR.SSM_G_A,
|
||||
MODEL_TENSOR.SSM_G_B,
|
||||
MODEL_TENSOR.SSM_DT,
|
||||
MODEL_TENSOR.SSM_NORM,
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
@@ -3706,6 +3768,9 @@ KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
|
||||
KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT
|
||||
KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS
|
||||
|
||||
# KDA
|
||||
KEY_KDA_HEAD_DIM = Keys.KDA.HEAD_DIM
|
||||
|
||||
# tokenization
|
||||
KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
|
||||
KEY_TOKENIZER_PRE = Keys.Tokenizer.PRE
|
||||
|
||||
@@ -980,6 +980,9 @@ class GGUFWriter:
|
||||
def add_ssm_dt_b_c_rms(self, value: bool) -> None:
|
||||
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
|
||||
|
||||
def add_kda_head_dim(self, value: int) -> None:
|
||||
self.add_uint32(Keys.KDA.HEAD_DIM.format(arch=self.arch), value)
|
||||
|
||||
def add_tokenizer_model(self, model: str) -> None:
|
||||
self.add_string(Keys.Tokenizer.MODEL, model)
|
||||
|
||||
|
||||
@@ -438,6 +438,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2
|
||||
"backbone.layers.{bid}.mixer.gate.e_score_correction", # nemotron-h-moe
|
||||
"model.layers.{bid}.mlp.e_score_correction", # exaone-moe
|
||||
"model.layers.{bid}.block_sparse_moe.gate.e_score_correction", # kimi
|
||||
),
|
||||
|
||||
# Feed-forward up
|
||||
@@ -502,6 +503,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
|
||||
"layers.{bid}.shared_experts.w3", # mistral-large
|
||||
"backbone.layers.{bid}.mixer.shared_experts.up_proj", # nemotron-h-moe
|
||||
"model.layers.{bid}.block_sparse_moe.shared_experts.up_proj", # kimi
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_CHEXP: (
|
||||
@@ -549,6 +551,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
|
||||
"model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
|
||||
"layers.{bid}.shared_experts.w1", # mistral-large
|
||||
"model.layers.{bid}.block_sparse_moe.shared_experts.gate_proj", # kimi
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_CHEXP: (
|
||||
@@ -613,6 +616,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
|
||||
"layers.{bid}.shared_experts.w2", # mistral-large
|
||||
"backbone.layers.{bid}.mixer.shared_experts.down_proj", # nemotron-h-moe
|
||||
"model.layers.{bid}.block_sparse_moe.shared_experts.down_proj", # kimi
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_CHEXP: (
|
||||
@@ -759,6 +763,7 @@ class TensorNameMap:
|
||||
"model.layers.layers.{bid}.mixer.dt_proj", # plamo2
|
||||
"model.layers.{bid}.linear_attn.dt_proj", # qwen3next
|
||||
"backbone.layers.{bid}.mixer.dt", # nemotron-h-moe
|
||||
"model.layers.{bid}.self_attn.dt_proj", # kimi
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_DT_NORM: (
|
||||
@@ -772,6 +777,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.layers.{bid}.mixer.A_log", # plamo2
|
||||
"model.layers.{bid}.linear_attn.A_log", # qwen3next
|
||||
"model.layers.{bid}.self_attn.A_log", # kimi
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_B_NORM: (
|
||||
@@ -797,6 +803,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid
|
||||
"model.layers.{bid}.linear_attn.norm", # qwen3next
|
||||
"backbone.layers.{bid}.mixer.norm", # mamba2
|
||||
"model.layers.{bid}.self_attn.o_norm", # kimi
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_OUT: (
|
||||
@@ -811,6 +818,31 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.linear_attn.in_proj_ba", # qwen3next
|
||||
),
|
||||
|
||||
# Kimi Linear KDA (using SSM_ prefix for consistency)
|
||||
MODEL_TENSOR.SSM_CONV1D_Q: (
|
||||
"model.layers.{bid}.self_attn.q_conv1d",
|
||||
),
|
||||
MODEL_TENSOR.SSM_CONV1D_K: (
|
||||
"model.layers.{bid}.self_attn.k_conv1d",
|
||||
),
|
||||
MODEL_TENSOR.SSM_CONV1D_V: (
|
||||
"model.layers.{bid}.self_attn.v_conv1d",
|
||||
),
|
||||
MODEL_TENSOR.SSM_F_A: (
|
||||
"model.layers.{bid}.self_attn.f_a_proj",
|
||||
),
|
||||
MODEL_TENSOR.SSM_F_B: (
|
||||
"model.layers.{bid}.self_attn.f_b_proj",
|
||||
),
|
||||
MODEL_TENSOR.SSM_BETA: (
|
||||
"model.layers.{bid}.self_attn.b_proj",
|
||||
),
|
||||
MODEL_TENSOR.SSM_G_A: (
|
||||
"model.layers.{bid}.self_attn.g_a_proj",
|
||||
),
|
||||
MODEL_TENSOR.SSM_G_B: (
|
||||
"model.layers.{bid}.self_attn.g_b_proj",
|
||||
),
|
||||
MODEL_TENSOR.TIME_MIX_W0: (
|
||||
"model.layers.{bid}.attention.w0", # rwkv7
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user