Add plateau warmup schedule for gradual quantization
This commit is contained in:
Binary file not shown.
+2000
-499
File diff suppressed because it is too large
Load Diff
@@ -45,7 +45,10 @@ DEFAULTS = dict(
|
|||||||
quant_warmup_steps=2000, # lambda warmup over N steps
|
quant_warmup_steps=2000, # lambda warmup over N steps
|
||||||
activation_bits=16, # 16 = no activation quant (use 8 for INT8)
|
activation_bits=16, # 16 = no activation quant (use 8 for INT8)
|
||||||
threshold=0.5, # deadzone threshold (Bonsai: 0.5)
|
threshold=0.5, # deadzone threshold (Bonsai: 0.5)
|
||||||
soft_quant=True, # use tanh proxy for smooth gradients
|
soft_quant=False, # use tanh proxy for smooth gradients
|
||||||
|
warmup_schedule="plateau", # linear or plateau
|
||||||
|
plateau_steps=500, # steps per plateau level
|
||||||
|
plateau_max=0.8, # max lambda for plateau warmup
|
||||||
# Data
|
# Data
|
||||||
train_dataset="roneneldan/TinyStories",
|
train_dataset="roneneldan/TinyStories",
|
||||||
eval_data_path=str(Path(__file__).parent / "data" / "wikitext_eval.json"),
|
eval_data_path=str(Path(__file__).parent / "data" / "wikitext_eval.json"),
|
||||||
@@ -577,6 +580,14 @@ def train(args):
|
|||||||
step_start = time.time()
|
step_start = time.time()
|
||||||
|
|
||||||
# ---- Lambda warmup ----
|
# ---- Lambda warmup ----
|
||||||
|
if args.warmup_schedule == "plateau":
|
||||||
|
# Plateau warmup: hold lambda at each level for plateau_steps
|
||||||
|
levels = int(args.plateau_max / 0.05) # 0.05 increments
|
||||||
|
plateau_size = args.plateau_steps
|
||||||
|
level = min(step // plateau_size, levels)
|
||||||
|
lambda_ = min(level * 0.05, args.plateau_max)
|
||||||
|
else:
|
||||||
|
# Linear warmup
|
||||||
lambda_ = min(step / args.quant_warmup_steps, 1.0)
|
lambda_ = min(step / args.quant_warmup_steps, 1.0)
|
||||||
set_lambda(model, lambda_)
|
set_lambda(model, lambda_)
|
||||||
|
|
||||||
@@ -702,6 +713,9 @@ def main():
|
|||||||
parser.add_argument("--threshold", type=float, default=DEFAULTS["threshold"], dest="threshold")
|
parser.add_argument("--threshold", type=float, default=DEFAULTS["threshold"], dest="threshold")
|
||||||
parser.add_argument("--soft-quant", action='store_true', default=DEFAULTS["soft_quant"], dest="soft_quant")
|
parser.add_argument("--soft-quant", action='store_true', default=DEFAULTS["soft_quant"], dest="soft_quant")
|
||||||
parser.add_argument("--no-soft-quant", action='store_false', dest="soft_quant")
|
parser.add_argument("--no-soft-quant", action='store_false', dest="soft_quant")
|
||||||
|
parser.add_argument("--warmup-schedule", default=DEFAULTS["warmup_schedule"], dest="warmup_schedule")
|
||||||
|
parser.add_argument("--plateau-steps", type=int, default=DEFAULTS["plateau_steps"], dest="plateau_steps")
|
||||||
|
parser.add_argument("--plateau-max", type=float, default=DEFAULTS["plateau_max"], dest="plateau_max")
|
||||||
parser.add_argument("--train-dataset", default=DEFAULTS["train_dataset"], dest="train_dataset")
|
parser.add_argument("--train-dataset", default=DEFAULTS["train_dataset"], dest="train_dataset")
|
||||||
parser.add_argument("--eval-data-path", default=DEFAULTS["eval_data_path"], dest="eval_data_path")
|
parser.add_argument("--eval-data-path", default=DEFAULTS["eval_data_path"], dest="eval_data_path")
|
||||||
parser.add_argument("--log-file", default=DEFAULTS["log_file"])
|
parser.add_argument("--log-file", default=DEFAULTS["log_file"])
|
||||||
|
|||||||
Reference in New Issue
Block a user