Add plateau warmup schedule for gradual quantization
This commit is contained in:
Binary file not shown.
+2000
-499
File diff suppressed because it is too large
Load Diff
@@ -45,7 +45,10 @@ DEFAULTS = dict(
|
||||
quant_warmup_steps=2000, # lambda warmup over N steps
|
||||
activation_bits=16, # 16 = no activation quant (use 8 for INT8)
|
||||
threshold=0.5, # deadzone threshold (Bonsai: 0.5)
|
||||
soft_quant=True, # use tanh proxy for smooth gradients
|
||||
soft_quant=False, # use tanh proxy for smooth gradients
|
||||
warmup_schedule="plateau", # linear or plateau
|
||||
plateau_steps=500, # steps per plateau level
|
||||
plateau_max=0.8, # max lambda for plateau warmup
|
||||
# Data
|
||||
train_dataset="roneneldan/TinyStories",
|
||||
eval_data_path=str(Path(__file__).parent / "data" / "wikitext_eval.json"),
|
||||
@@ -577,7 +580,15 @@ def train(args):
|
||||
step_start = time.time()
|
||||
|
||||
# ---- Lambda warmup ----
|
||||
lambda_ = min(step / args.quant_warmup_steps, 1.0)
|
||||
if args.warmup_schedule == "plateau":
|
||||
# Plateau warmup: hold lambda at each level for plateau_steps
|
||||
levels = int(args.plateau_max / 0.05) # 0.05 increments
|
||||
plateau_size = args.plateau_steps
|
||||
level = min(step // plateau_size, levels)
|
||||
lambda_ = min(level * 0.05, args.plateau_max)
|
||||
else:
|
||||
# Linear warmup
|
||||
lambda_ = min(step / args.quant_warmup_steps, 1.0)
|
||||
set_lambda(model, lambda_)
|
||||
|
||||
# ---- LR warmup ----
|
||||
@@ -702,6 +713,9 @@ def main():
|
||||
parser.add_argument("--threshold", type=float, default=DEFAULTS["threshold"], dest="threshold")
|
||||
parser.add_argument("--soft-quant", action='store_true', default=DEFAULTS["soft_quant"], dest="soft_quant")
|
||||
parser.add_argument("--no-soft-quant", action='store_false', dest="soft_quant")
|
||||
parser.add_argument("--warmup-schedule", default=DEFAULTS["warmup_schedule"], dest="warmup_schedule")
|
||||
parser.add_argument("--plateau-steps", type=int, default=DEFAULTS["plateau_steps"], dest="plateau_steps")
|
||||
parser.add_argument("--plateau-max", type=float, default=DEFAULTS["plateau_max"], dest="plateau_max")
|
||||
parser.add_argument("--train-dataset", default=DEFAULTS["train_dataset"], dest="train_dataset")
|
||||
parser.add_argument("--eval-data-path", default=DEFAULTS["eval_data_path"], dest="eval_data_path")
|
||||
parser.add_argument("--log-file", default=DEFAULTS["log_file"])
|
||||
|
||||
Reference in New Issue
Block a user