llama: Add option to merge gate and exp weights (#19139)

* llama: Add option to merge gate and exp weights

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* update constants.py

* add gate_up for the all MoE models

* convert: simplify merge tensor condition

* update constants.py

* reduce number of models, add create_tensor_gate_up helper

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
Aman Gupta
2026-02-26 21:01:08 +08:00
committed by GitHub
parent ffaafde16f
commit b68d75165a
12 changed files with 135 additions and 44 deletions
+37 -4
View File
@@ -116,7 +116,8 @@ class ModelBase:
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
disable_mistral_community_chat_template: bool = False,
sentence_transformers_dense_modules: bool = False):
sentence_transformers_dense_modules: bool = False,
fuse_gate_up_exps: bool = False):
if type(self) is ModelBase or \
type(self) is TextModel or \
type(self) is MmprojModel:
@@ -135,6 +136,9 @@ class ModelBase:
self.dry_run = dry_run
self.remote_hf_model_id = remote_hf_model_id
self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
self.fuse_gate_up_exps = fuse_gate_up_exps
self._gate_exp_buffer: dict[int, Tensor] = {}
self._up_exp_buffer: dict[int, Tensor] = {}
self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
self.metadata_override = metadata_override
@@ -512,8 +516,31 @@ class ModelBase:
raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses")
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
return [(self.map_tensor_name(name), data_torch)]
new_name = self.map_tensor_name(name)
# Handle gate/up expert tensor fusion if enabled
if self.fuse_gate_up_exps and bid is not None:
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid):
self._gate_exp_buffer[bid] = data_torch
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
self._up_exp_buffer[bid] = data_torch
# Check if both gate and up are buffered for this layer
if bid in self._gate_exp_buffer and bid in self._up_exp_buffer:
gate_data = self._gate_exp_buffer.pop(bid)
up_data = self._up_exp_buffer.pop(bid)
# gate/up shape: (n_expert, n_ff, n_embd), concatenate to (n_expert, n_ff*2, n_embd)
fused_data = torch.cat([gate_data, up_data], dim=1)
fused_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_UP_EXP, bid)
logger.info(f"Fused gate_exps and up_exps for layer {bid}")
return [(fused_name, fused_data)]
# If we buffered a gate/up tensor, wait for the other
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid) or \
self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
return []
return [(new_name, data_torch)]
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
del name, new_name, bid, n_dims # unused
@@ -11942,6 +11969,11 @@ def parse_args() -> argparse.Namespace:
"Default these modules are not included.")
)
parser.add_argument(
"--fuse-gate-up-exps", action="store_true",
help="Fuse gate_exps and up_exps tensors into a single gate_up_exps tensor for MoE models.",
)
args = parser.parse_args()
if not args.print_supported_models and args.model is None:
parser.error("the following arguments are required: model")
@@ -12079,7 +12111,8 @@ def main() -> None:
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split,
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules,
fuse_gate_up_exps=args.fuse_gate_up_exps
)
if args.vocab_only: