Add plateau warmup schedule for gradual quantization

This commit is contained in:
2026-04-24 04:45:13 +02:00
parent c10212735a
commit 322316fc5f
3 changed files with 2016 additions and 501 deletions
Binary file not shown.
+2000 -499
View File
File diff suppressed because it is too large. Load Diff
+15 -1
View File
@@ -45,7 +45,10 @@ DEFAULTS = dict(
quant_warmup_steps=2000, # lambda warmup over N steps quant_warmup_steps=2000, # lambda warmup over N steps
activation_bits=16, # 16 = no activation quant (use 8 for INT8) activation_bits=16, # 16 = no activation quant (use 8 for INT8)
threshold=0.5, # deadzone threshold (Bonsai: 0.5) threshold=0.5, # deadzone threshold (Bonsai: 0.5)
soft_quant=True, # use tanh proxy for smooth gradients soft_quant=False, # use tanh proxy for smooth gradients
warmup_schedule="plateau", # linear or plateau
plateau_steps=500, # steps per plateau level
plateau_max=0.8, # max lambda for plateau warmup
# Data # Data
train_dataset="roneneldan/TinyStories", train_dataset="roneneldan/TinyStories",
eval_data_path=str(Path(__file__).parent / "data" / "wikitext_eval.json"), eval_data_path=str(Path(__file__).parent / "data" / "wikitext_eval.json"),
@@ -577,6 +580,14 @@ def train(args):
step_start = time.time() step_start = time.time()
# ---- Lambda warmup ---- # ---- Lambda warmup ----
if args.warmup_schedule == "plateau":
# Plateau warmup: raise lambda in 0.05 increments, holding each level for plateau_steps steps, capped at plateau_max
levels = int(args.plateau_max / 0.05) # 0.05 increments
plateau_size = args.plateau_steps
level = min(step // plateau_size, levels)
lambda_ = min(level * 0.05, args.plateau_max)
else:
# Linear warmup
lambda_ = min(step / args.quant_warmup_steps, 1.0) lambda_ = min(step / args.quant_warmup_steps, 1.0)
set_lambda(model, lambda_) set_lambda(model, lambda_)
@@ -702,6 +713,9 @@ def main():
parser.add_argument("--threshold", type=float, default=DEFAULTS["threshold"], dest="threshold") parser.add_argument("--threshold", type=float, default=DEFAULTS["threshold"], dest="threshold")
parser.add_argument("--soft-quant", action='store_true', default=DEFAULTS["soft_quant"], dest="soft_quant") parser.add_argument("--soft-quant", action='store_true', default=DEFAULTS["soft_quant"], dest="soft_quant")
parser.add_argument("--no-soft-quant", action='store_false', dest="soft_quant") parser.add_argument("--no-soft-quant", action='store_false', dest="soft_quant")
parser.add_argument("--warmup-schedule", default=DEFAULTS["warmup_schedule"], dest="warmup_schedule")
parser.add_argument("--plateau-steps", type=int, default=DEFAULTS["plateau_steps"], dest="plateau_steps")
parser.add_argument("--plateau-max", type=float, default=DEFAULTS["plateau_max"], dest="plateau_max")
parser.add_argument("--train-dataset", default=DEFAULTS["train_dataset"], dest="train_dataset") parser.add_argument("--train-dataset", default=DEFAULTS["train_dataset"], dest="train_dataset")
parser.add_argument("--eval-data-path", default=DEFAULTS["eval_data_path"], dest="eval_data_path") parser.add_argument("--eval-data-path", default=DEFAULTS["eval_data_path"], dest="eval_data_path")
parser.add_argument("--log-file", default=DEFAULTS["log_file"]) parser.add_argument("--log-file", default=DEFAULTS["log_file"])