llama: Add option to merge gate and exp weights (#19139)
* llama: Add option to merge gate and exp weights

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* update constants.py

* add gate_up for all the MoE models

* convert: simplify merge tensor condition

* update constants.py

* reduce number of models, add create_tensor_gate_up helper

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
convert_hf_to_gguf.py: +37 −4
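This change lets `convert_hf_to_gguf.py` optionally write a single fused `gate_up_exps` tensor per MoE layer instead of separate `gate_exps` and `up_exps` tensors. The converter buffers each layer's gate and up expert tensors as they stream through `modify_tensors` and, once both halves of a pair have been seen, concatenates them along the feed-forward dimension before writing. The behavior is opt-in via a new `--fuse-gate-up-exps` command-line flag, threaded into the `ModelBase` constructor as `fuse_gate_up_exps`.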
```diff
@@ -116,7 +116,8 @@ class ModelBase:
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
                  small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
                  disable_mistral_community_chat_template: bool = False,
-                 sentence_transformers_dense_modules: bool = False):
+                 sentence_transformers_dense_modules: bool = False,
+                 fuse_gate_up_exps: bool = False):
         if type(self) is ModelBase or \
                 type(self) is TextModel or \
                 type(self) is MmprojModel:
@@ -135,6 +136,9 @@ class ModelBase:
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
         self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
+        self.fuse_gate_up_exps = fuse_gate_up_exps
+        self._gate_exp_buffer: dict[int, Tensor] = {}
+        self._up_exp_buffer: dict[int, Tensor] = {}
         self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
         self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
         self.metadata_override = metadata_override
```
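The two buffers are keyed by block id, so a layer's gate and up projections can arrive in either order and each layer is fused independently; `modify_tensors` below pops a layer's pair as soon as both halves are present.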
```diff
@@ -512,8 +516,31 @@ class ModelBase:
         raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses")
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        del bid  # unused
-        return [(self.map_tensor_name(name), data_torch)]
+        new_name = self.map_tensor_name(name)
+
+        # Handle gate/up expert tensor fusion if enabled
+        if self.fuse_gate_up_exps and bid is not None:
+            if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid):
+                self._gate_exp_buffer[bid] = data_torch
+            elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
+                self._up_exp_buffer[bid] = data_torch
+
+            # Check if both gate and up are buffered for this layer
+            if bid in self._gate_exp_buffer and bid in self._up_exp_buffer:
+                gate_data = self._gate_exp_buffer.pop(bid)
+                up_data = self._up_exp_buffer.pop(bid)
+                # gate/up shape: (n_expert, n_ff, n_embd), concatenate to (n_expert, n_ff*2, n_embd)
+                fused_data = torch.cat([gate_data, up_data], dim=1)
+                fused_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_UP_EXP, bid)
+                logger.info(f"Fused gate_exps and up_exps for layer {bid}")
+                return [(fused_name, fused_data)]
+
+            # If we buffered a gate/up tensor, wait for the other
+            if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid) or \
+                    self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
+                return []
+
+        return [(new_name, data_torch)]
 
     def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
         del name, new_name, bid, n_dims  # unused
```
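As a standalone illustration of the buffering round trip (not the converter's actual API: the `offer` helper and the concrete sizes below are made up, while the `(n_expert, n_ff, n_embd)` layout and the `dim=1` concatenation follow the comment in the patch):

```python
import torch
from torch import Tensor

n_expert, n_ff, n_embd = 8, 64, 32   # made-up sizes for illustration
gate = torch.randn(n_expert, n_ff, n_embd)
up = torch.randn(n_expert, n_ff, n_embd)

# Mirror of the patch's two per-layer buffers, keyed by block id.
gate_buf: dict[int, Tensor] = {}
up_buf: dict[int, Tensor] = {}

def offer(bid: int, tensor: Tensor, is_gate: bool) -> list[Tensor]:
    """Hold the first half of a layer's pair; emit the fused tensor on the second."""
    (gate_buf if is_gate else up_buf)[bid] = tensor
    if bid in gate_buf and bid in up_buf:
        # Gate always comes first in the fused tensor, regardless of arrival order.
        return [torch.cat([gate_buf.pop(bid), up_buf.pop(bid)], dim=1)]
    return []

assert offer(0, gate, is_gate=True) == []   # gate buffered, nothing emitted yet
(fused,) = offer(0, up, is_gate=False)      # up completes the pair
assert fused.shape == (n_expert, 2 * n_ff, n_embd)

# A consumer can recover the two halves by splitting along the same dimension.
gate_back, up_back = torch.chunk(fused, 2, dim=1)
assert torch.equal(gate_back, gate) and torch.equal(up_back, up)
```

The `torch.chunk` split is shown only to demonstrate that the fusion is lossless; how the runtime consumes the fused tensor is outside this diff.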
```diff
@@ -11942,6 +11969,11 @@ def parse_args() -> argparse.Namespace:
              "Default these modules are not included.")
     )
 
+    parser.add_argument(
+        "--fuse-gate-up-exps", action="store_true",
+        help="Fuse gate_exps and up_exps tensors into a single gate_up_exps tensor for MoE models.",
+    )
+
     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:
         parser.error("the following arguments are required: model")
```
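With the flag registered above, a conversion run opts in from the command line, e.g. `python convert_hf_to_gguf.py /path/to/hf-model --fuse-gate-up-exps` (path hypothetical, other options omitted); `main()` then forwards `args.fuse_gate_up_exps` into the model constructor, as the final hunk shows.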
```diff
@@ -12079,7 +12111,8 @@ def main() -> None:
         split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
         small_first_shard=args.no_tensor_first_split,
         remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
-        sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
+        sentence_transformers_dense_modules=args.sentence_transformers_dense_modules,
+        fuse_gate_up_exps=args.fuse_gate_up_exps
     )
 
     if args.vocab_only:
```