""" Ternary quantization training loop — MUTABLE in the autoresearch loop. This file implements: - BitLinear: ternary quantized linear layer with STE + lambda warmup - Training loop on TinyStories with WikiText eval - Debug output for monitoring in tmux Usage: python train.py # default config python train.py --steps 2000 --lr 1e-4 # custom hyperparams python train.py --group-size 64 --warmup 500 # quantization params """ import argparse import json import math import os import sys import time from pathlib import Path import torch import torch.nn as nn from torch.nn import functional as F from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig # --------------------------------------------------------------------------- # Configuration defaults (tune these, or pass via CLI) # --------------------------------------------------------------------------- DEFAULTS = dict( model_name="HuggingFaceTB/SmolLM-135M", device="mps" if torch.backends.mps.is_available() else "cpu", dtype=torch.float32, # Training steps=1000, batch_size=4, seq_len=256, learning_rate=1e-5, warmup_steps=100, eval_every=50, # Quantization group_size=128, # Bonsai uses 128 quant_warmup_steps=2000, # lambda warmup over N steps activation_bits=16, # 16 = no activation quant (use 8 for INT8) threshold=0.5, # deadzone threshold (Bonsai: 0.5) soft_quant=False, # use tanh proxy for smooth gradients warmup_schedule="plateau", # linear or plateau plateau_steps=500, # steps per plateau level plateau_max=0.8, # max lambda for plateau warmup # Data train_dataset="roneneldan/TinyStories", eval_data_path=str(Path(__file__).parent / "data" / "wikitext_eval.json"), # Logging log_file=str(Path(__file__).parent / "results.tsv"), ) # --------------------------------------------------------------------------- # Ternary Quantization Primitives # --------------------------------------------------------------------------- def ternary_quantize(w, group_size=128, threshold=0.5): """Quantize weights to {-1, 0, +1} with per-group abs_mean scale. Groups are formed by flattening the weight tensor and taking consecutive chunks of `group_size` elements. The last group may be smaller. Args: w: weight tensor of any shape group_size: number of weights per quantization group threshold: deadzone threshold (0 < t < 1). Weights with |w_norm| < t are zeroed. Returns: w_quant: ternary weights {-1, 0, +1} in original shape scale: per-group scale factors, shape (num_groups, 1) """ original_shape = w.shape w_flat = w.reshape(-1) # Pad to multiple of group_size if needed n = w_flat.numel() pad = (group_size - n % group_size) % group_size if pad > 0: w_flat = torch.cat([w_flat, w_flat.new_zeros(pad)]) w_groups = w_flat.reshape(-1, group_size) # Scale: mean(|w|) per group (balanced ternary distribution) abs_mean = w_groups.abs().mean(dim=-1, keepdim=True).clamp(min=1e-6) scale = abs_mean # Normalize, clamp, round to nearest ternary w_norm = w_groups / scale w_clamped = w_norm.clamp(-1.0, 1.0) if threshold > 0: w_quant = torch.where(w_clamped.abs() < threshold, torch.zeros_like(w_clamped), torch.sign(w_clamped)) else: w_quant = torch.round(w_clamped) # Reshape back to original (trim padding) w_quant = w_quant.reshape(-1)[:n].reshape(original_shape) return w_quant, scale def soft_ternary(w, group_size=128, temperature=0.05, threshold=0.5): """Soft ternary quantization with differentiable tanh proxy. Uses tanh(w / temperature) as a smooth approximation to ternary {-1, 0, +1}. This provides smooth gradients for weight optimization. Args: w: weight tensor of any shape group_size: number of weights per quantization group temperature: controls sharpness (lower = closer to hard ternary) threshold: deadzone threshold for soft zero region Returns: w_soft: soft-quantized weights in original shape scale: per-group scale factors, shape (num_groups, 1) """ original_shape = w.shape w_flat = w.reshape(-1) # Pad to multiple of group_size if needed n = w_flat.numel() pad = (group_size - n % group_size) % group_size if pad > 0: w_flat = torch.cat([w_flat, w_flat.new_zeros(pad)]) w_groups = w_flat.reshape(-1, group_size) # Scale: mean(|w|) per group abs_mean = w_groups.abs().mean(dim=-1, keepdim=True).clamp(min=1e-6) scale = abs_mean # Normalize, apply tanh for soft ternary w_norm = w_groups / scale w_clamped = w_norm.clamp(-1.0, 1.0) if threshold > 0: # Soft deadzone: shrink values near zero w_deadzone = w_clamped * (1.0 - threshold / (w_clamped.abs() + threshold)) else: w_deadzone = w_clamped w_soft = torch.tanh(w_deadzone / temperature) # (-1, 1) smooth # Reshape back to original (trim padding) w_soft = w_soft.reshape(-1)[:n].reshape(original_shape) return w_soft, scale def ternary_quantize_learnable(w, scale, group_size=128, threshold=0.5): """Quantize weights to {-1, 0, +1} using a learnable per-group scale. The scale is a learnable nn.Parameter, allowing the network to optimize the quantization scale during training. Uses STE: forward pass uses ternary weights, backward pass gradients flow through the continuous weights. Args: w: weight tensor of any shape scale: learnable per-group scale, shape (num_groups,) group_size: number of weights per quantization group threshold: deadzone threshold (0 < t < 1) Returns: w_dequant: dequantized weights in original shape (for forward pass) """ original_shape = w.shape w_flat = w.reshape(-1) n = w_flat.numel() pad = (group_size - n % group_size) % group_size if pad > 0: w_flat = torch.cat([w_flat, w_flat.new_zeros(pad)]) w_groups = w_flat.reshape(-1, group_size) # Normalize using learnable scale scale = scale.to(w.device) # ensure scale is on same device as weights scale_expanded = scale.unsqueeze(-1).expand(-1, group_size) w_norm = w_groups / scale_expanded.clamp(min=1e-6) w_clamped = w_norm.clamp(-1.0, 1.0) if threshold > 0: w_quant = torch.where(w_clamped.abs() < threshold, torch.zeros_like(w_clamped), torch.sign(w_clamped)) else: w_quant = torch.round(w_clamped) # Dequantize: w_dequant = w_quant * scale w_dequant = w_quant * scale_expanded # Reshape back to original (trim padding) w_dequant = w_dequant.reshape(-1)[:n].reshape(original_shape) # STE: w_dequant = w + (w_dequant - w).detach() # Gradients flow through w, not through the quantization w_dequant = w + (w_dequant - w).detach() return w_dequant def ternary_dequantize(w_quant, scale, group_size=128): """Reconstruct weights from ternary codes and scales. Groups are formed the same way as in ternary_quantize. """ original_shape = w_quant.shape w_flat = w_quant.reshape(-1) n = w_flat.numel() pad = (group_size - n % group_size) % group_size if pad > 0: w_flat = torch.cat([w_flat, w_flat.new_zeros(pad)]) w_groups = w_flat.reshape(-1, group_size) w_dequant = w_groups * scale return w_dequant.reshape(-1)[:n].reshape(original_shape) def activation_quantize(x, bits=8): """Quantize activations to INT8 with per-token absmax scaling. Args: x: activation tensor bits: number of bits (default 8) Returns: x_quant: quantized then dequantized activations (same dtype as input) """ if bits == 16: return x # no quantization max_val = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) max_q = 2 ** (bits - 1) - 1 # 127 for INT8 scale = max_val / max_q x_q = (x / scale).round().clamp(-max_q, max_q) return x_q * scale # --------------------------------------------------------------------------- # BitLinear — Ternary Quantized Linear Layer # --------------------------------------------------------------------------- class BitLinear(nn.Module): """Linear layer with ternary quantized weights and INT8 activations. Uses straight-through estimator (STE) for gradient flow through quantization. Lambda warmup gradually introduces quantization during training. The forward pass: 1. Quantize weights to ternary 2. (Optional) Quantize activations to INT8 3. Forward pass with quantized values 4. STE: gradients flow through unquantized weights 5. Lambda warmup: blend between FP and quantized forward """ def __init__(self, in_features, out_features, bias=True, group_size=128, activation_bits=8, threshold=0.5, soft_quant=False): super().__init__() self.in_features = in_features self.out_features = out_features self.group_size = group_size self.activation_bits = activation_bits self.threshold = threshold self.soft_quant = soft_quant # use tanh proxy for smooth gradients # FP32 weights (learned in full precision, quantized at forward time) self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=DEFAULTS['dtype']) * 0.02) if bias: self.bias = nn.Parameter(torch.zeros(out_features, dtype=DEFAULTS['dtype'])) else: self.bias = None # Lambda for warmup (set externally or via training loop) self.register_buffer('lambda_', torch.tensor(0.0)) def forward(self, x): lambda_ = self.lambda_.item() if lambda_ <= 0.0: # No quantization — standard linear out = F.linear(x, self.weight, self.bias) return out # Quantize weights (STE: gradients flow through FP weights) w_quant, scale = ternary_quantize(self.weight, group_size=self.group_size, threshold=self.threshold) w_dequant = ternary_dequantize(w_quant, scale, self.group_size) # Quantize activations (optional) if self.activation_bits < 16: x_q = activation_quantize(x, self.activation_bits) else: x_q = x # Quantized forward out_quant = F.linear(x_q, w_dequant, self.bias) # Straight-through estimator with lambda warmup: # out = out_fp + lambda * (out_quant - out_fp).detach() # When lambda=0: pure FP forward (no quantization) # When lambda=1: quantized forward, gradients through FP (full STE) out_fp = F.linear(x, self.weight, self.bias) out = out_fp + lambda_ * (out_quant - out_fp).detach() return out def extra_repr(self): return (f"in_features={self.in_features}, out_features={self.out_features}, " f"group_size={self.group_size}, activation_bits={self.activation_bits}") # --------------------------------------------------------------------------- # Model Surgery — Replace Linear with BitLinear # --------------------------------------------------------------------------- def replace_linears_with_bitlinear(model, group_size=128, activation_bits=8, exclude_embeddings=True, threshold=0.5, soft_quant=False): """Replace all nn.Linear layers in model with BitLinear. Args: model: HuggingFace model group_size: quantization group size activation_bits: activation quantization bits (16 = no quant) exclude_embeddings: don't replace lm_head/embedding (usually) threshold: deadzone threshold for ternary quantization (0 < t < 1) soft_quant: use tanh proxy for smooth gradients (vs hard ternary) """ count = 0 for name, module in model.named_modules(): if isinstance(module, nn.Linear): # Skip embedding layers if requested if exclude_embeddings and ('embed_tokens' in name or 'lm_head' in name): continue # Create BitLinear with same dimensions and device bit_linear = BitLinear( module.in_features, module.out_features, bias=module.bias is not None, group_size=group_size, activation_bits=activation_bits, threshold=threshold, soft_quant=soft_quant, ) # Initialize from FP weights (critical for warmup to work) with torch.no_grad(): bit_linear.weight.data = module.weight.clone() if module.bias is not None: bit_linear.bias.data = module.bias.clone() # Replace in parent module parent_name = ".".join(name.split(".")[:-1]) child_name = name.split(".")[-1] parent = getattr(model, parent_name, None) if parent_name else model # Need to walk to the actual parent parts = name.split(".") obj = model for p in parts[:-1]: obj = getattr(obj, p) setattr(obj, parts[-1], bit_linear) count += 1 return count def set_lambda(model, value): """Set lambda on all BitLinear layers.""" for module in model.modules(): if isinstance(module, BitLinear): module.lambda_.fill_(value) def get_quant_stats(model): """Get quantization statistics across all BitLinear layers. Returns dict with fraction of weights at {-1, 0, +1}. """ total = 0 count_neg1 = 0 count_zero = 0 count_pos1 = 0 for module in model.modules(): if isinstance(module, BitLinear): w_q, _ = ternary_quantize(module.weight, group_size=module.group_size, threshold=module.threshold) n = w_q.numel() total += n count_neg1 += (w_q == -1).sum().item() count_zero += (w_q == 0).sum().item() count_pos1 += (w_q == 1).sum().item() if total == 0: return {"-1": 0, "0": 0, "+1": 0} return { "-1": count_neg1 / total, "0": count_zero / total, "+1": count_pos1 / total, } # --------------------------------------------------------------------------- # Data Loading # --------------------------------------------------------------------------- def load_train_data(dataset_name="roneneldan/TinyStories", tokenizer=None, seq_len=256, batch_size=4, device="cpu"): """Create a streaming dataloader for training. Yields tensors of shape (batch_size, seq_len) from concatenated text. """ dataset = load_dataset(dataset_name, split="train", streaming=True) def tokenize_and_batch(): buffer = [] target_size = seq_len * batch_size for sample in dataset: tokens = tokenizer(sample["text"], truncation=False)["input_ids"] buffer.extend(tokens) while len(buffer) >= target_size: chunk = buffer[:target_size] buffer = buffer[target_size:] # Reshape to (batch_size, seq_len) x = torch.tensor(chunk, dtype=torch.long, device=device).reshape(batch_size, seq_len) yield x return tokenize_and_batch() def load_eval_data(eval_path): """Load pre-tokenized eval data.""" with open(eval_path, "r") as f: input_ids = json.load(f) return torch.tensor(input_ids, dtype=torch.long) # --------------------------------------------------------------------------- # Evaluation # --------------------------------------------------------------------------- @torch.no_grad() def evaluate(model, eval_ids, tokenizer, seq_len=128, device="cpu"): """Compute perplexity on eval data.""" model.eval() nll = 0.0 n_tokens = 0 # Move eval data to device in chunks for i in range(0, len(eval_ids) - seq_len, seq_len): chunk = eval_ids[i : i + seq_len + 1].to(device) # Shape: (1, seq_len+1) for batch dim chunk = chunk.unsqueeze(0) logits = model(chunk[:, :-1]).logits # (1, seq_len, vocab) loss = F.cross_entropy( logits.reshape(-1, logits.size(-1)), chunk[:, 1:].reshape(-1), reduction="sum" ) nll += loss.item() n_tokens += seq_len model.train() avg_nll = nll / n_tokens ppl = math.exp(avg_nll) bpb = avg_nll / math.log(2) # bits per byte return ppl, bpb, avg_nll # --------------------------------------------------------------------------- # Training Loop # --------------------------------------------------------------------------- def train(args): """Main training loop with ternary quantization.""" device = args.device start_time = time.time() # ---- Load model and tokenizer ---- print(f"\n{'='*60}") print(f" TERNARY QUANTIZATION TRAINING") print(f"{'='*60}") print(f" Model: {args.model_name}") print(f" Device: {device}") print(f" Steps: {args.steps}") print(f" Batch size: {args.batch_size}") print(f" Seq length: {args.seq_len}") print(f" LR: {args.learning_rate}") print(f" Group size: {args.group_size}") print(f" Act bits: {args.activation_bits}") print(f" Warmup: {args.quant_warmup_steps} steps") print(f"{'='*60}\n") tokenizer = AutoTokenizer.from_pretrained(args.model_name) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token print(f"Loading model... (this may take a moment)") model = AutoModelForCausalLM.from_pretrained( args.model_name, torch_dtype=DEFAULTS['dtype'], low_cpu_mem_usage=True, ) model.to(device) total_params = sum(p.numel() for p in model.parameters()) print(f"Model loaded: {total_params:,} parameters") # ---- Replace Linear with BitLinear ---- print(f"\nReplacing Linear layers with BitLinear (group_size={args.group_size})...") n_replaced = replace_linears_with_bitlinear( model, group_size=args.group_size, activation_bits=args.activation_bits, threshold=args.threshold, soft_quant=args.soft_quant, ) print(f"Replaced {n_replaced} Linear layers with BitLinear") # Count ternary params ternary_params = sum( p.numel() for m in model.modules() if isinstance(m, BitLinear) for p in m.parameters() ) print(f"Ternary-quantized parameters: {ternary_params:,}") # ---- Load eval data ---- eval_ids = load_eval_data(args.eval_data_path) print(f"Eval data: {len(eval_ids):,} tokens") # ---- Optimizer ---- optimizer = torch.optim.AdamW( model.parameters(), lr=args.learning_rate, weight_decay=0.01, ) # ---- Training loop ---- data_iter = load_train_data( dataset_name=args.train_dataset, tokenizer=tokenizer, seq_len=args.seq_len, batch_size=args.batch_size, device=device, ) best_ppl = float("inf") step_times = [] # Open log file log_path = Path(args.log_file) # Check if we need to write header write_header = not log_path.exists() or log_path.stat().st_size == 0 if write_header: with open(log_path, "w") as f: f.write("step\tlambda\ttrain_loss\ttrain_ppl\teval_ppl\teval_bpb\tlr\ttime_s\tbest_ppl\tq_neg1\tq_zero\tq_pos1\n") print(f"\n{'─'*60}") print(f" STEP LAMBDA LOSS PPL EVAL_PPL LR TIME") print(f"{'─'*60}") for step in range(1, args.steps + 1): step_start = time.time() # ---- Lambda warmup ---- if args.warmup_schedule == "plateau": # Plateau warmup: hold lambda at each level for plateau_steps levels = int(args.plateau_max / 0.05) # 0.05 increments plateau_size = args.plateau_steps level = min(step // plateau_size, levels) lambda_ = min(level * 0.05, args.plateau_max) else: # Linear warmup lambda_ = min(step / args.quant_warmup_steps, 1.0) set_lambda(model, lambda_) # ---- LR warmup ---- if step <= args.warmup_steps: lr_scale = step / args.warmup_steps else: lr_scale = 1.0 for param_group in optimizer.param_groups: param_group["lr"] = args.learning_rate * lr_scale current_lr = args.learning_rate * lr_scale # ---- Forward pass ---- model.train() try: batch = next(data_iter) except StopIteration: print("\n Dataset exhausted, restarting...") data_iter = load_train_data( dataset_name=args.train_dataset, tokenizer=tokenizer, seq_len=args.seq_len, batch_size=args.batch_size, device=device, ) batch = next(data_iter) # batch: (batch_size, seq_len) # Shift for teacher forcing x = batch[:, :-1] y = batch[:, 1:] logits = model(x).logits loss = F.cross_entropy( logits.reshape(-1, logits.size(-1)), y.reshape(-1), ) train_ppl = math.exp(loss.item()) # ---- Backward pass ---- optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() step_time = time.time() - step_start step_times.append(step_time) avg_step_time = sum(step_times[-20:]) / len(step_times[-20:]) # ---- Eval ---- eval_ppl = "-" eval_bpb = "-" quant_stats = None if step % args.eval_every == 0 or step == args.steps: ppl, bpb, _ = evaluate(model, eval_ids, tokenizer, seq_len=128, device=device) eval_ppl = f"{ppl:.1f}" eval_bpb = f"{bpb:.2f}" if ppl < best_ppl: best_ppl = ppl # Quantization stats quant_stats = get_quant_stats(model) elapsed = time.time() - start_time eta = avg_step_time * (args.steps - step) q_str = f"[-1:{quant_stats['-1']:.2f} 0:{quant_stats['0']:.2f} +1:{quant_stats['+1']:.2f}]" print(f" {step:5d} {lambda_:.3f} {loss.item():.4f} {train_ppl:7.1f} " f"{eval_ppl:>8s} {current_lr:.2e} {elapsed:.0f}s (ETA {eta:.0f}s)") print(f" eval_bpb={eval_bpb} best_ppl={best_ppl:.1f} " f"step_time={avg_step_time:.2f}s quant={q_str}") # Log to TSV with open(log_path, "a") as f: q = quant_stats or {"-1": 0, "0": 0, "+1": 0} f.write(f"{step}\t{lambda_:.4f}\t{loss.item():.6f}\t{train_ppl:.2f}\t" f"{ppl:.2f}\t{bpb:.4f}\t{current_lr:.2e}\t{elapsed:.1f}\t{best_ppl:.2f}\t" f"{q['-1']:.4f}\t{q['0']:.4f}\t{q['+1']:.4f}\n") else: elapsed = time.time() - start_time eta = avg_step_time * (args.steps - step) if step % 25 == 0 or step == 1: print(f" {step:5d} {lambda_:.3f} {loss.item():.4f} {train_ppl:7.1f} " f"{'—':>8s} {current_lr:.2e} {elapsed:.0f}s (ETA {eta:.0f}s)") # Log to TSV (no eval) with open(log_path, "a") as f: f.write(f"{step}\t{lambda_:.4f}\t{loss.item():.6f}\t{train_ppl:.2f}\t" f"-\t-\t{current_lr:.2e}\t{elapsed:.1f}\t{best_ppl:.2f}\t-\t-\t-\n") # ---- Summary ---- total_time = time.time() - start_time print(f"\n{'='*60}") print(f" TRAINING COMPLETE") print(f" Total time: {total_time:.0f}s ({total_time/60:.1f} min)") print(f" Best eval PPL: {best_ppl:.1f}") print(f" Final lambda: {lambda_:.3f}") print(f" Avg step time: {sum(step_times)/len(step_times):.2f}s") print(f" Results logged to: {log_path}") print(f"{'='*60}\n") return best_ppl # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def main(): parser = argparse.ArgumentParser(description="Ternary quantization training") parser.add_argument("--model-name", default=DEFAULTS["model_name"]) parser.add_argument("--device", default=DEFAULTS["device"]) parser.add_argument("--steps", type=int, default=DEFAULTS["steps"]) parser.add_argument("--batch-size", type=int, default=DEFAULTS["batch_size"]) parser.add_argument("--seq-len", type=int, default=DEFAULTS["seq_len"]) parser.add_argument("--lr", type=float, default=DEFAULTS["learning_rate"], dest="learning_rate") parser.add_argument("--warmup-steps", type=int, default=DEFAULTS["warmup_steps"], dest="warmup_steps") parser.add_argument("--eval-every", type=int, default=DEFAULTS["eval_every"], dest="eval_every") parser.add_argument("--group-size", type=int, default=DEFAULTS["group_size"], dest="group_size") parser.add_argument("--quant-warmup-steps", type=int, default=DEFAULTS["quant_warmup_steps"], dest="quant_warmup_steps") parser.add_argument("--activation-bits", type=int, default=DEFAULTS["activation_bits"], dest="activation_bits") parser.add_argument("--threshold", type=float, default=DEFAULTS["threshold"], dest="threshold") parser.add_argument("--soft-quant", action='store_true', default=DEFAULTS["soft_quant"], dest="soft_quant") parser.add_argument("--no-soft-quant", action='store_false', dest="soft_quant") parser.add_argument("--warmup-schedule", default=DEFAULTS["warmup_schedule"], dest="warmup_schedule") parser.add_argument("--plateau-steps", type=int, default=DEFAULTS["plateau_steps"], dest="plateau_steps") parser.add_argument("--plateau-max", type=float, default=DEFAULTS["plateau_max"], dest="plateau_max") parser.add_argument("--train-dataset", default=DEFAULTS["train_dataset"], dest="train_dataset") parser.add_argument("--eval-data-path", default=DEFAULTS["eval_data_path"], dest="eval_data_path") parser.add_argument("--log-file", default=DEFAULTS["log_file"]) args = parser.parse_args() train(args) if __name__ == "__main__": main()