From 7378d4ef8f8211f9aba23d397ceb5aa78097bd53 Mon Sep 17 00:00:00 2001 From: Kaloyan Nikolov Date: Fri, 24 Apr 2026 01:36:44 +0200 Subject: [PATCH] Add ternary QAT training pipeline: prepare.py (data/eval), train.py (quantization/training), program.md (agent instructions), autoresearch.sh (loop) --- autoresearch.sh | 122 +++++++++ prepare.py | 324 +++++++++++++++++++++++ program.md | 78 ++++++ results.tsv | 1 + train.py | 664 ++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 1189 insertions(+) create mode 100755 autoresearch.sh create mode 100644 prepare.py create mode 100644 program.md create mode 100644 results.tsv create mode 100644 train.py diff --git a/autoresearch.sh b/autoresearch.sh new file mode 100755 index 0000000..7bf5880 --- /dev/null +++ b/autoresearch.sh @@ -0,0 +1,122 @@ +#!/bin/bash +# autoresearch.sh — Autonomous experiment loop +# +# Usage: ./autoresearch.sh [model] [time_budget_seconds] +# +# This script runs the autoresearch loop: +# 1. Agent proposes a change to train.py +# 2. Git commit the change +# 3. Run training for fixed time budget +# 4. Extract val_ppl from results.tsv +# 5. If improved → keep; if worse → git reset +# 6. Repeat + +set -e + +MODEL="${1:-HuggingFaceTB/SmolLM-135M}" +TIME_BUDGET="${2:-300}" +RESULTS_FILE="results.tsv" + +# Ensure git is initialized +if [ ! -d ".git" ]; then + git init + git add -A + git commit -m "Initial commit" +fi + +echo "=== Autoresearch Loop ===" +echo "Model: $MODEL" +echo "Time budget: ${TIME_BUDGET}s" +echo "Results: $RESULTS_FILE" +echo "" + +# Function to get best PPL from results.tsv +get_best_ppl() { + if [ ! -f "$RESULTS_FILE" ]; then + echo "999999" + return + fi + # Get the best_ppl from the last successful run (column 7) + tail -1 "$RESULTS_FILE" | cut -f7 | tr -d '[:space:]' +} + +# Function to get last status +get_last_status() { + if [ ! -f "$RESULTS_FILE" ]; then + echo "none" + return + fi + tail -1 "$RESULTS_FILE" | cut -f3 | tr -d '[:space:]' +} + +# Initial commit if not committed +git add -A +git commit -m "Initial setup" --allow-empty 2>/dev/null || true + +BEST_PPL=$(get_best_ppl) +RUN_NUM=0 + +while true; do + RUN_NUM=$((RUN_NUM + 1)) + echo "" + echo "========================================" + echo "RUN #$RUN_NUM" + echo "Current best PPL: $BEST_PPL" + echo "========================================" + + # Save current state + PREV_COMMIT=$(git rev-parse HEAD) + + # Prompt the agent to make a change + # In production, this would call the LLM agent + # For now, we just run with current config + echo "Running training..." + + # Run training + START_TIME=$(date +%s) + + python3 train.py \ + --model "$MODEL" \ + --device auto \ + --time-budget "$TIME_BUDGET" \ + --total-steps 2000 \ + --eval-every 500 \ + --batch-size 2 \ + --max-samples 10000 \ + --seq-length 1024 \ + --description "autoresearch-run-$RUN_NUM" \ + 2>&1 | tee "run-${RUN_NUM}.log" || true + + END_TIME=$(date +%s) + ELAPSED=$((END_TIME - START_TIME)) + + # Check results + STATUS=$(get_last_status) + NEW_PPL=$(get_best_ppl) + + echo "" + echo "Run #$RUN_NUM completed in ${ELAPSED}s" + echo "Status: $STATUS" + echo "Best PPL: $NEW_PPL" + + if [ "$STATUS" = "success" ]; then + # Compare with previous best + if echo "$NEW_PPL $BEST_PPL" | awk '{exit !($1 < $2)}'; then + echo "IMPROVED! Keeping changes." + BEST_PPL=$NEW_PPL + git add results.tsv + git commit -m "Run #$RUN_NUM: improved PPL to $BEST_PPL" + else + echo "No improvement. Reverting." 
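+            # Note: the improvement check above shells out to awk because
+            # bash cannot compare floating-point numbers natively; the reset
+            # below discards the candidate change and restores results.tsv
+            # so the next run starts from the last known-good commit.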
+            git reset --hard $PREV_COMMIT 2>/dev/null || true
+            git checkout -- results.tsv 2>/dev/null || true
+        fi
+    else
+        echo "FAILED. Reverting."
+        git reset --hard $PREV_COMMIT 2>/dev/null || true
+        git checkout -- results.tsv 2>/dev/null || true
+    fi
+
+    echo ""
+    echo "Continuing... (Ctrl+C to stop)"
+done
diff --git a/prepare.py b/prepare.py
new file mode 100644
index 0000000..6efb76a
--- /dev/null
+++ b/prepare.py
@@ -0,0 +1,324 @@
+#!/usr/bin/env python3
+"""
+prepare.py — Data loading, tokenizer, and evaluation (READ-ONLY in autoresearch loop).
+
+This file is NOT modified by the autoresearch agent. All mutable training logic
+lives in train.py. This file provides:
+  - Dataset loading and tokenization
+  - Evaluation harness (WikiText PPL, optional zero-shot tasks)
+  - Utility functions for the training loop
+"""
+
+import os
+import math
+import torch
+from torch.utils.data import Dataset
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from datasets import load_dataset
+
+# ---------------------------------------------------------------------------
+# Configuration
+# ---------------------------------------------------------------------------
+
+# Default model — small enough for fast iteration on M4 Pro / consumer GPU
+DEFAULT_MODEL = "meta-llama/Llama-3.2-1B"
+# Alternative tiny model for ultra-fast iteration:
+# DEFAULT_MODEL = "HuggingFaceTB/SmolLM-135M"
+
+# Default eval dataset
+DEFAULT_EVAL_DATASET = "wikitext"  # "wikitext" or "ptb"
+DEFAULT_EVAL_CONFIG = "wikitext-2-raw-v1"
+
+# Default training dataset
+DEFAULT_TRAIN_DATASET = "HuggingFaceFW/fineweb-edu"  # FineWeb-Edu
+# Alternative: "monology/patrickstar-core-1000" for smaller local testing
+
+# Sequence length for training
+DEFAULT_SEQ_LENGTH = 2048
+
+# ---------------------------------------------------------------------------
+# Tokenizer
+# ---------------------------------------------------------------------------
+
+def load_tokenizer(model_name: str = DEFAULT_MODEL) -> AutoTokenizer:
+    """Load the tokenizer for the given model."""
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name,
+        trust_remote_code=True,
+    )
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    return tokenizer
+
+
+# ---------------------------------------------------------------------------
+# Dataset Wrappers
+# ---------------------------------------------------------------------------
+
+class TokenizedDataset(Dataset):
+    """Simple dataset that yields fixed-length sequences of tokenized text."""
+
+    def __init__(self, data: list, tokenizer: AutoTokenizer, seq_length: int = DEFAULT_SEQ_LENGTH):
+        self.tokenizer = tokenizer
+        self.seq_length = seq_length
+        # Concatenate all texts and chunk into sequences
+        self.samples = self._prepare_samples(data)
+
+    def _prepare_samples(self, data: list) -> list:
+        """Flatten and chunk text data into fixed-length sequences."""
+        all_tokens = []
+        for item in data:
+            # Accept plain strings (as produced by load_train_dataset) as
+            # well as dict-style records with a "text" or "content" field
+            if isinstance(item, str):
+                text = item
+            elif isinstance(item, dict):
+                text = item.get("text", item.get("content", str(item)))
+            else:
+                text = str(item)
+            tokens = self.tokenizer.encode(text, add_special_tokens=False)
+            all_tokens.extend(tokens)
+            all_tokens.append(self.tokenizer.eos_token_id)
+
+        # Chunk into sequences of seq_length
+        samples = []
+        for i in range(0, len(all_tokens) - self.seq_length + 1, self.seq_length):
+            samples.append(all_tokens[i : i + self.seq_length])
+        return samples
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, idx):
+        x = torch.tensor(self.samples[idx], dtype=torch.long)
+        # HF causal LMs shift labels internally, so input_ids and labels are
+        # the same unshifted sequence; pre-shifting here as well would make
+        # the model predict two tokens ahead
+        return {"input_ids": x, "labels": x.clone()}
+
+
+def load_eval_dataset(
+    dataset_name: str = DEFAULT_EVAL_DATASET,
+    config: str = DEFAULT_EVAL_CONFIG,
+    split: str = "validation",
+    tokenizer: AutoTokenizer = None,
+    seq_length: int = DEFAULT_SEQ_LENGTH,
+):
+    """Load and tokenize the evaluation dataset."""
+    if dataset_name == "wikitext":
+        ds = load_dataset(dataset_name, config, trust_remote_code=True)[split]
+    elif dataset_name == "ptb":
+        ds = load_dataset(dataset_name, trust_remote_code=True)[split]
+    else:
+        ds = load_dataset(dataset_name, split=split, trust_remote_code=True)
+
+    if tokenizer is None:
+        tokenizer = load_tokenizer()
+
+    # For eval, we want contiguous text without chunking artifacts
+    texts = ds["text"] if "text" in ds.column_names else [str(x) for x in ds]
+    all_tokens = []
+    for text in texts:
+        tokens = tokenizer.encode(text, add_special_tokens=False)
+        all_tokens.extend(tokens)
+
+    # Chunk into non-overlapping sequences of seq_length
+    samples = []
+    for i in range(0, len(all_tokens) - seq_length + 1, seq_length):
+        samples.append(all_tokens[i : i + seq_length])
+
+    return samples, tokenizer
+
+
+# ---------------------------------------------------------------------------
+# Evaluation
+# ---------------------------------------------------------------------------
+
+@torch.no_grad()
+def evaluate_ppl(
+    model: AutoModelForCausalLM,
+    eval_samples: list,
+    tokenizer: AutoTokenizer,
+    seq_length: int = DEFAULT_SEQ_LENGTH,
+    batch_size: int = 4,
+    device: str = "cpu",
+) -> float:
+    """
+    Evaluate perplexity on the given samples.
+    Returns perplexity (lower is better).
+    """
+    model.eval()
+    total_loss = 0.0
+    total_seqs = 0
+
+    for i in range(0, len(eval_samples), batch_size):
+        batch_samples = eval_samples[i : i + batch_size]
+        inputs = torch.tensor(batch_samples, dtype=torch.long, device=device)
+
+        # HF causal LMs shift labels internally, so pass the unshifted
+        # sequence as both input_ids and labels; shifting manually as well
+        # would make the model predict two tokens ahead
+        device_type = device if device in ("cuda", "mps") else "cpu"
+        if device_type != "cpu":
+            with torch.amp.autocast(device_type=device_type):
+                outputs = model(input_ids=inputs, labels=inputs, use_cache=False)
+        else:
+            outputs = model(input_ids=inputs, labels=inputs, use_cache=False)
+
+        # All sequences have the same length, so averaging the per-batch mean
+        # losses weighted by batch size gives the per-token average loss
+        loss = outputs.loss
+        total_loss += loss.item() * inputs.size(0)
+        total_seqs += inputs.size(0)
+
+    avg_loss = total_loss / total_seqs
+    perplexity = math.exp(avg_loss)
+
+    model.train()
+    return perplexity
+
+
+@torch.no_grad()
+def evaluate_bpb(
+    model: AutoModelForCausalLM,
+    eval_samples: list,
+    tokenizer: AutoTokenizer,
+    seq_length: int = DEFAULT_SEQ_LENGTH,
+    batch_size: int = 4,
+    device: str = "cpu",
+) -> float:
+    """
+    Evaluate bits per token, i.e. log2(perplexity) (lower is better).
+    Logged as "bpb" in results.tsv for brevity, but note this is bits per
+    *token*; true bits-per-byte would divide total bits by the byte length
+    of the evaluated text.
+    """
+    ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
+    return math.log2(ppl)
+
+
+# ---------------------------------------------------------------------------
+# Training Dataset Loader
+# ---------------------------------------------------------------------------
+
+def load_train_dataset(
+    dataset_name: str = DEFAULT_TRAIN_DATASET,
+    split: str = "train",
+    tokenizer: AutoTokenizer = None,
+    seq_length: int = DEFAULT_SEQ_LENGTH,
+    max_samples: int = None,
+    streaming: bool = True,
+):
+    """
+    Load and prepare the training dataset.
+    Returns a TokenizedDataset.
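+
+    Illustrative usage (a sketch; assumes Hugging Face Hub access and the
+    default FineWeb-Edu streaming config):
+
+        tokenizer = load_tokenizer("HuggingFaceTB/SmolLM-135M")
+        ds = load_train_dataset(
+            tokenizer=tokenizer,
+            seq_length=1024,
+            max_samples=1000,
+        )
+        batch = ds[0]  # {"input_ids": LongTensor, "labels": LongTensor}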
+ """ + if tokenizer is None: + tokenizer = load_tokenizer() + + if dataset_name == "HuggingFaceFW/fineweb-edu": + # Use the edu-only subset for quality + ds = load_dataset( + dataset_name, + name="sample-100BT", # 100B token subset for faster iteration + split=split, + streaming=streaming, + trust_remote_code=True, + ) + elif dataset_name.endswith(".jsonl") or dataset_name.endswith(".json"): + ds = load_dataset("json", data_files=dataset_name, split=split, streaming=streaming) + else: + ds = load_dataset(dataset_name, split=split, streaming=streaming, trust_remote_code=True) + + # Collect texts (handle streaming by taking first N samples) + texts = [] + for idx, item in enumerate(ds): + if max_samples and idx >= max_samples: + break + text = item.get("text", item.get("content", str(item))) + texts.append(text) + + if not texts: + raise ValueError(f"No texts loaded from dataset {dataset_name}") + + dataset = TokenizedDataset(texts, tokenizer, seq_length) + return dataset + + +# --------------------------------------------------------------------------- +# Model Loading +# --------------------------------------------------------------------------- + +def load_model( + model_name: str = DEFAULT_MODEL, + device: str = "cpu", + dtype: torch.dtype = torch.float16, + attn_implementation: str = "sdpa", +) -> AutoModelForCausalLM: + """ + Load a pretrained model from HuggingFace. + """ + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=dtype, + attn_implementation=attn_implementation, + trust_remote_code=True, + device_map=device if device != "cpu" else None, + ) + if device == "cpu": + model = model.to(device) + return model + + +# --------------------------------------------------------------------------- +# CLI Entry Point +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Prepare and evaluate ternary quantization datasets") + parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name or path") + parser.add_argument("--eval-dataset", default=DEFAULT_EVAL_DATASET, help="Evaluation dataset") + parser.add_argument("--eval-config", default=DEFAULT_EVAL_CONFIG, help="Evaluation dataset config") + parser.add_argument("--seq-length", type=int, default=DEFAULT_SEQ_LENGTH, help="Sequence length") + parser.add_argument("--batch-size", type=int, default=4, help="Evaluation batch size") + parser.add_argument("--device", default="cpu", help="Device to run on") + parser.add_argument("--train-samples", type=int, default=1000, help="Max training samples to load") + parser.add_argument("--dry-run", action="store_true", help="Just print config without loading") + + args = parser.parse_args() + + if args.dry_run: + print(f"Model: {args.model}") + print(f"Eval dataset: {args.eval_dataset} ({args.eval_config})") + print(f"Seq length: {args.seq_length}") + print(f"Batch size: {args.batch_size}") + print(f"Device: {args.device}") + print(f"Train samples: {args.train_samples}") + exit(0) + + print(f"Loading tokenizer for {args.model}...") + tokenizer = load_tokenizer(args.model) + + print(f"Loading eval dataset {args.eval_dataset}...") + eval_samples, _ = load_eval_dataset( + dataset_name=args.eval_dataset, + config=args.eval_config, + tokenizer=tokenizer, + seq_length=args.seq_length, + ) + print(f" Eval samples: {len(eval_samples)}") + + print(f"Loading training dataset ({args.train_samples} samples)...") + train_dataset = load_train_dataset( + 
tokenizer=tokenizer, + seq_length=args.seq_length, + max_samples=args.train_samples, + streaming=True, + ) + print(f" Train samples: {len(train_dataset)}") + + # Optional: load model and run eval + if args.device != "skip-model": + print(f"Loading model {args.model}...") + model = load_model(args.model, device=args.device) + print(f" Model params: {sum(p.numel() for p in model.parameters()):,}") + + print(f"Evaluating perplexity...") + ppl = evaluate_ppl(model, eval_samples, tokenizer, args.seq_length, args.batch_size, args.device) + print(f" Perplexity: {ppl:.2f}") + bpb = math.log2(ppl) + print(f" BPB: {bpb:.4f}") diff --git a/program.md b/program.md new file mode 100644 index 0000000..e83de5c --- /dev/null +++ b/program.md @@ -0,0 +1,78 @@ +# program.md — Instructions for the Autoresearch Agent + +You are an autonomous research agent exploring ternary (1.58-bit) quantization for LLMs. + +## Your Goal + +Iteratively improve the ternary quantization training in `train.py` to achieve the **lowest validation perplexity (val_ppl)** on WikiText-2. + +## File Boundaries + +- **MUTABLE**: `train.py` — You may modify this file. It contains: + - `BitLinear` layer (quantization logic) + - Quantization schedules (lambda warmup) + - Training loop + - Hyperparameters (LR, batch size, group size, etc.) + +- **READ-ONLY**: `prepare.py` — DO NOT modify this file. It contains: + - Dataset loading and tokenization + - Evaluation harness (WikiText PPL) + - Model loading utilities + +- **OUTPUT**: `results.tsv` — Results are automatically logged here after each run. + +## Experiment Protocol + +1. Read `results.tsv` to understand what has been tried +2. Read `train.py` to understand current implementation +3. Propose ONE focused change to `train.py` +4. The change will be committed and a training run will execute (~5 minutes) +5. Results are logged to `results.tsv` +6. If improved (lower val_ppl) → change is kept +7. 
If equal or worse → git reset to previous commit + +## What to Explore + +### Priority 1: Quantization Schedule +- Lambda warmup shape: linear, cosine, exponential +- Warmup step counts: 200, 500, 1000, 2000, 5000 +- Two-phase warmup (fast initial + slow final) + +### Priority 2: Learning Rate +- LR values: 1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4 +- LR schedule: constant, linear decay, cosine decay + +### Priority 3: Quantization Details +- Group size: 32, 64, 128, 256, per-tensor +- Scale initialization: mean-based vs absmax-based +- Ternary threshold adjustments + +### Priority 4: Deadzone Recovery +- Tequila-style reactivation (learnable lambda for deadzone weights) +- Bias injection for zero-valued weights +- Gradient scaling for deadzone weights + +### Priority 5: Distillation +- OFF loss (cosine similarity between FP and ternary features) +- Logits distillation weight +- Feature distillation weight + +## Constraints + +- Keep experiments focused — ONE change per iteration +- Always maintain working code — syntax errors waste time +- Use SmolLM-135M or Llama-3.2-1B for fast iteration +- Target metric: val_ppl (lower is better) +- Time budget: 5 minutes per experiment + +## Important Notes + +- The STE (Straight-Through Estimator) is critical for gradients to flow through quantization +- Warmup quantization prevents catastrophic accuracy loss at the start of training +- Deadzone trapping (weights stuck at 0) is a known problem — explore solutions +- Per-group quantization scales are essential for handling outlier weights +- The quantization formula: `scale = 1.0 / w.abs().mean(); round(clamp(-1, 1))` + +## NEVER STOP + +Run experiments continuously until manually interrupted. Each experiment should be a small, focused change. Review results.tsv between runs to inform your next decision. diff --git a/results.tsv b/results.tsv new file mode 100644 index 0000000..507550a --- /dev/null +++ b/results.tsv @@ -0,0 +1 @@ +timestamp commit status val_ppl val_bpb initial_ppl best_ppl best_step steps time_s model lr warmup_steps total_steps lambda_schedule group_size batch_size seq_length description diff --git a/train.py b/train.py new file mode 100644 index 0000000..9ce0328 --- /dev/null +++ b/train.py @@ -0,0 +1,664 @@ +#!/usr/bin/env python3 +""" +train.py — Ternary quantization training loop (MUTABLE in autoresearch loop). + +This is the single mutable file for the autoresearch agent. The agent may modify: + - Quantization schedule (lambda warmup shape, length) + - Learning rate and schedule + - BitLinear implementation details (group size, deadzone recovery) + - Loss function (distillation weights, OFF loss) + - Training hyperparameters + +DO NOT modify prepare.py — all data loading and evaluation lives there. 
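+
+Typical invocation (illustrative; all flags are defined in the CLI section
+at the bottom of this file):
+
+    python3 train.py --model HuggingFaceTB/SmolLM-135M --device auto \
+        --time-budget 300 --total-steps 2000 --eval-every 500 \
+        --batch-size 2 --seq-length 1024 --lambda-schedule linear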
+""" + +import os +import sys +import math +import time +import json +import csv +import git +import random +import argparse +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from transformers import AutoTokenizer, AutoModelForCausalLM, get_linear_schedule_with_warmup + +# Import from prepare.py (read-only) +from prepare import ( + load_tokenizer, + load_train_dataset, + load_eval_dataset, + load_model, + evaluate_ppl, + evaluate_bpb, + DEFAULT_MODEL, + DEFAULT_SEQ_LENGTH, + DEFAULT_TRAIN_DATASET, + DEFAULT_EVAL_DATASET, + DEFAULT_EVAL_CONFIG, +) + +# --------------------------------------------------------------------------- +# Quantization Utilities +# --------------------------------------------------------------------------- + +def ternary_quantize_weight(w: torch.Tensor, group_size: int = 128) -> tuple: + """ + Ternary quantize weights to {-1, 0, +1} with per-group scales. + + Quantization formula (from HF blog): + scale = 1.0 / w.abs().mean() + w_scaled = w * scale + w_clamped = clamp(w_scaled, -1, 1) + w_rounded = round(w_clamped) # -> {-1, 0, +1} + + Returns: + w_quant: ternary quantized weights + scales: per-group scales for dequantization + """ + original_shape = w.shape + # Flatten for group-wise processing + w_flat = w.flatten().float() + + num_groups = (len(w_flat) + group_size - 1) // group_size + scales = [] + quant_parts = [] + + for i in range(num_groups): + start = i * group_size + end = min(start + group_size, len(w_flat)) + group = w_flat[start:end] + + # Compute scale: 1 / mean(|w|) + scale = 1.0 / (group.abs().mean() + 1e-6) + scales.append(scale) + + # Scale, clamp, round + w_scaled = group * scale + w_clamped = w_scaled.clamp(-1.0, 1.0) + w_rounded = w_clamped.round().sign() # -> {-1, 0, +1} + + quant_parts.append(w_rounded) + + w_quant = torch.cat(quant_parts, dim=0).reshape(original_shape) + scales = torch.tensor(scales, device=w.device, dtype=w.dtype) + + return w_quant, scales + + +def ternary_dequantize(w_quant: torch.Tensor, scales: torch.Tensor, group_size: int = 128) -> torch.Tensor: + """Dequantize ternary weights back to original shape.""" + original_shape = w_quant.shape + w_flat = w_quant.flatten().float() + + num_groups = (len(w_flat) + group_size - 1) // group_size + dequant_parts = [] + + for i in range(num_groups): + start = i * group_size + end = min(start + group_size, len(w_flat)) + group = w_flat[start:end] + scale = scales[i] + dequant_parts.append(group / (scale + 1e-6)) + + return torch.cat(dequant_parts, dim=0).reshape(original_shape).to(w_quant.dtype) + + +# --------------------------------------------------------------------------- +# BitLinear Layer +# --------------------------------------------------------------------------- + +class BitLinear(nn.Module): + """ + Linear layer with ternary weight quantization using STE (Straight-Through Estimator). + + Forward pass uses quantized weights; backward pass uses STE to pass gradients + through the quantization operation. + + Supports warmup quantization: gradually introduce quantization during training + using lambda scheduling to avoid catastrophic accuracy loss. 
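+
+    Warmup interpolation, as implemented in forward(): with Q(w) the
+    dequantized ternary weight, the effective weight is
+
+        w_eff = w + lambda * (Q(w) - w).detach()
+
+    so lambda=0 reproduces the full-precision layer exactly, lambda=0.5 is a
+    halfway blend, and lambda=1 runs fully ternary while gradients still flow
+    to the full-precision w (the straight-through estimator).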
+ """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + group_size: int = 128, + init_scale: float = 1.0, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.group_size = group_size + self.init_scale = init_scale + + # FP16 weights (the "learned" weights) + self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.bias = None + + # Quantization state + self.register_buffer("w_quant", torch.zeros(out_features, in_features)) + self.register_buffer("scales", torch.zeros((out_features * in_features + group_size - 1) // group_size)) + self.register_buffer("lambda", torch.tensor(0.0)) # warmup lambda + + def _quantize(self, w: torch.Tensor) -> tuple: + """Quantize weights to ternary.""" + return ternary_quantize_weight(w, self.group_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass with warmup quantization. + + Lambda is set externally via self.lambda_val before each forward pass. + + Quantization warmup formula: + w_eff = w + lambda * (w_dequant - w).detach() + + When lambda=0: output uses full-precision weights + When lambda=1: output uses quantized weights (with STE gradients) + """ + # Get lambda (default 1.0 if not set) + lambda_ = getattr(self, 'lambda_val', 1.0) + + # Quantize weights + w_quant, scales = self._quantize(self.weight) + + # Dequantize for computation + w_dequant = ternary_dequantize(w_quant, scales, self.group_size) + + # STE with warmup: interpolate between FP and quantized + # w_eff = w + lambda * (w_dequant - w) + # Gradient flows through w (the FP weights), not through quantization + if lambda_ > 0: + w_eff = self.weight + lambda_ * (w_dequant - self.weight).detach() + else: + w_eff = self.weight + + # Linear forward + x = x.to(w_eff.dtype) + out = F.linear(x, w_eff, self.bias) + return out + + def extra_repr(self) -> str: + return ( + f"in_features={self.in_features}, " + f"out_features={self.out_features}, " + f"group_size={self.group_size}, " + f"bias={self.bias is not None}" + ) + + +# --------------------------------------------------------------------------- +# Model Conversion +# --------------------------------------------------------------------------- + +def replace_linears_with_bitlinear( + model: AutoModelForCausalLM, + group_size: int = 128, + target_modules: list = None, +) -> AutoModelForCausalLM: + """ + Replace all nn.Linear layers in the model with BitLinear layers. + + Args: + model: pretrained model + group_size: quantization group size + target_modules: list of module names to replace. If None, replace all. 
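+            (with the default list below, "all" means the standard
+            Llama-style projections q/k/v/o and gate/up/down, so embeddings
+            and lm_head stay in full precision)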
+ + Returns: + Model with BitLinear layers + """ + if target_modules is None: + target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + + for name, module in model.named_modules(): + parent_name = ".".join(name.split(".")[:-1]) + module_name = name.split(".")[-1] + + if isinstance(module, nn.Linear) and module_name in target_modules: + bit_linear = BitLinear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + group_size=group_size, + ) + # Initialize from pretrained weights + bit_linear.weight.data = module.weight.data.clone() + if module.bias is not None: + bit_linear.bias.data = module.bias.data.clone() + + # Replace in parent + parent = model + for attr in parent_name.split("."): + parent = getattr(parent, attr) + setattr(parent, module_name, bit_linear) + + return model + + +def get_bitlinear_modules(model: AutoModelForCausalLM) -> list: + """Get all BitLinear modules in the model.""" + return [m for m in model.modules() if isinstance(m, BitLinear)] + + +# --------------------------------------------------------------------------- +# Lambda Warmup Schedules +# --------------------------------------------------------------------------- + +def linear_warmup_lambda(step: int, warmup_steps: int) -> float: + """Linear warmup: lambda goes from 0 to 1 over warmup_steps.""" + return min(step / warmup_steps, 1.0) + + +def cosine_warmup_lambda(step: int, warmup_steps: int) -> float: + """Cosine warmup: lambda follows cosine schedule over warmup_steps.""" + if step >= warmup_steps: + return 1.0 + return 0.5 * (1 - math.cos(math.pi * step / warmup_steps)) + + +def exponential_warmup_lambda(step: int, warmup_steps: int) -> float: + """Exponential warmup: lambda grows exponentially over warmup_steps.""" + if step >= warmup_steps: + return 1.0 + # Exponential from 0 to 1 + return 1.0 - math.exp(-3.0 * step / warmup_steps) + + +def get_lambda(step: int, warmup_steps: int, schedule: str = "linear") -> float: + """Get lambda value for the given step and schedule.""" + if schedule == "linear": + return linear_warmup_lambda(step, warmup_steps) + elif schedule == "cosine": + return cosine_warmup_lambda(step, warmup_steps) + elif schedule == "exponential": + return exponential_warmup_lambda(step, warmup_steps) + else: + raise ValueError(f"Unknown lambda schedule: {schedule}") + + +# --------------------------------------------------------------------------- +# Training Loop +# --------------------------------------------------------------------------- + +def train( + model_name: str = DEFAULT_MODEL, + train_dataset_name: str = DEFAULT_TRAIN_DATASET, + eval_dataset_name: str = DEFAULT_EVAL_DATASET, + eval_config: str = DEFAULT_EVAL_CONFIG, + seq_length: int = DEFAULT_SEQ_LENGTH, + batch_size: int = 4, + learning_rate: float = 1e-4, + warmup_steps: int = 1000, + total_steps: int = 5000, + eval_every: int = 500, + lambda_schedule: str = "linear", + group_size: int = 128, + max_train_samples: int = 50000, + device: str = "cpu", + dtype: torch.dtype = torch.float16, + time_budget: float = 300.0, # seconds (5 minutes default) + seed: int = 42, + output_dir: str = "./results", +): + """ + Main training loop for ternary quantization-aware training. 
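+
+    Two schedules run side by side: the quantization lambda ramps from 0 to 1
+    over `warmup_steps` (see get_lambda above), while the learning rate
+    follows a linear schedule with its own warmup of 0.1 * total_steps. For
+    example, with warmup_steps=1000 and the linear schedule, step 500 trains
+    at lambda=0.5, i.e. halfway between full precision and ternary.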
+ + Args: + model_name: HuggingFace model name or path + train_dataset_name: Training dataset name + eval_dataset_name: Evaluation dataset name + eval_config: Evaluation dataset config + seq_length: Sequence length for training + batch_size: Training batch size + learning_rate: Base learning rate + warmup_steps: Number of steps for lambda warmup + total_steps: Total training steps + eval_every: Evaluate every N steps + lambda_schedule: Lambda warmup schedule ("linear", "cosine", "exponential") + group_size: Quantization group size + max_train_samples: Max training samples to load + device: Device to train on + dtype: Data type + time_budget: Max training time in seconds + seed: Random seed + output_dir: Output directory for results + + Returns: + dict with training metrics + """ + # Auto-detect device if not specified + if device == "auto": + if torch.backends.mps.is_available(): + device = "mps" + print("Using MPS (Apple Silicon)") + elif torch.cuda.is_available(): + device = "cuda" + print("Using CUDA") + else: + device = "cpu" + print("Using CPU") + + # Set seeds + random.seed(seed) + torch.manual_seed(seed) + if device == "cuda": + torch.cuda.manual_seed(seed) + + # Track time + start_time = time.time() + + print(f"=== Ternary QAT Training ===") + print(f"Model: {model_name}") + print(f"Device: {device}") + print(f"LR: {learning_rate}, Warmup: {warmup_steps} steps") + print(f"Lambda schedule: {lambda_schedule}") + print(f"Group size: {group_size}") + print(f"Time budget: {time_budget}s") + print(f"Total steps: {total_steps}") + print() + + # Load tokenizer + print("Loading tokenizer...") + tokenizer = load_tokenizer(model_name) + + # Load model + print(f"Loading model {model_name}...") + model = load_model(model_name, device=device, dtype=dtype, attn_implementation="sdpa") + model = model.to(dtype) + print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}") + print(f" Linear layers replaced: ", end="") + + # Replace Linear with BitLinear + model = replace_linears_with_bitlinear(model, group_size=group_size) + bitlinears = get_bitlinear_modules(model) + print(f"{len(bitlinears)}") + + # Load datasets + print(f"Loading training dataset ({max_train_samples} samples)...") + train_dataset = load_train_dataset( + dataset_name=train_dataset_name, + tokenizer=tokenizer, + seq_length=seq_length, + max_samples=max_train_samples, + streaming=True, + ) + print(f" Train samples: {len(train_dataset)}") + + print(f"Loading eval dataset...") + eval_samples, _ = load_eval_dataset( + dataset_name=eval_dataset_name, + config=eval_config, + tokenizer=tokenizer, + seq_length=seq_length, + ) + print(f" Eval samples: {len(eval_samples)}") + + # Initial evaluation + print("\n--- Initial Evaluation ---") + initial_ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device) + initial_bpb = math.log2(initial_ppl) + print(f" Initial PPL: {initial_ppl:.2f}") + print(f" Initial BPB: {initial_bpb:.4f}") + + # Optimizer + # Only optimize BitLinear weights and biases (not the quantization buffers) + optimizer = torch.optim.AdamW( + [p for p in model.parameters() if p.requires_grad], + lr=learning_rate, + weight_decay=0.01, + ) + + # Learning rate scheduler + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=int(total_steps * 0.1), + num_training_steps=total_steps, + ) + + # DataLoader + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) + loader_iter = iter(train_loader) + + # Training loop + results = { + 
"initial_ppl": initial_ppl, + "initial_bpb": initial_bpb, + "final_ppl": None, + "final_bpb": None, + "best_ppl": initial_ppl, + "best_step": 0, + "steps_completed": 0, + "training_time": 0, + } + + step = 0 + train_loss = 0.0 + steps_since_last_eval = 0 + + print(f"\n--- Training ---") + + while step < total_steps: + # Check time budget + elapsed = time.time() - start_time + if elapsed >= time_budget - 30: # Reserve 30s for final eval + print(f"\n Time budget nearly reached ({elapsed:.0f}/{time_budget}s). Stopping.") + break + + # Get lambda for this step + lambda_ = get_lambda(step, warmup_steps, lambda_schedule) + + # Get batch + try: + batch = next(loader_iter) + except StopIteration: + loader_iter = iter(train_loader) + batch = next(loader_iter) + + input_ids = batch["input_ids"].to(device) + labels = batch["labels"].to(device) + + # Set lambda on all BitLinear layers for this step + for bl in get_bitlinear_modules(model): + bl.lambda_val = lambda_ + + # Forward pass + outputs = model(input_ids=input_ids, labels=labels, use_cache=False) + + loss = outputs.loss + loss.backward() + + # Gradient clipping + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + train_loss += loss.item() + step += 1 + steps_since_last_eval += 1 + + # Periodic logging + if step % 100 == 0: + avg_loss = train_loss / 100 + avg_ppl = math.exp(avg_loss) + elapsed = time.time() - start_time + print(f" Step {step}/{total_steps} | Loss: {avg_loss:.4f} | PPL: {avg_ppl:.2f} | λ: {lambda_:.3f} | LR: {lr_scheduler.get_last_lr()[0]:.2e} | Time: {elapsed:.0f}s") + train_loss = 0.0 + + # Evaluation + if steps_since_last_eval >= eval_every or step == total_steps: + ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device) + bpb = math.log2(ppl) + elapsed = time.time() - start_time + + print(f"\n [Eval] Step {step} | PPL: {ppl:.2f} | BPB: {bpb:.4f} | λ: {lambda_:.3f} | Time: {elapsed:.0f}s") + + if ppl < results["best_ppl"]: + results["best_ppl"] = ppl + results["best_step"] = step + print(f" *** New best PPL! 
***") + + steps_since_last_eval = 0 + + # Final evaluation + print(f"\n--- Final Evaluation ---") + final_ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device) + final_bpb = math.log2(final_ppl) + results["final_ppl"] = final_ppl + results["final_bpb"] = final_bpb + results["steps_completed"] = step + results["training_time"] = time.time() - start_time + + print(f" Final PPL: {final_ppl:.2f}") + print(f" Final BPB: {final_bpb:.4f}") + print(f" Best PPL: {results['best_ppl']:.2f} (step {results['best_step']})") + print(f" Steps completed: {step}") + print(f" Training time: {results['training_time']:.0f}s") + + return results + + +# --------------------------------------------------------------------------- +# Results Logging +# --------------------------------------------------------------------------- + +def log_result(results: dict, config: dict, results_file: str = "results.tsv"): + """Log results to TSV file.""" + file_exists = os.path.exists(results_file) + + with open(results_file, "a") as f: + if not file_exists: + # Write header + headers = ["timestamp", "commit", "status", "val_ppl", "val_bpb", + "initial_ppl", "best_ppl", "best_step", "steps", "time_s", + "model", "lr", "warmup_steps", "total_steps", "lambda_schedule", + "group_size", "batch_size", "seq_length", "description"] + f.write("\t".join(headers) + "\n") + + row = [ + time.strftime("%Y-%m-%d %H:%M:%S"), + config.get("commit", ""), + config.get("status", "success"), + f"{results.get('final_ppl', 0):.4f}", + f"{results.get('final_bpb', 0):.4f}", + f"{results.get('initial_ppl', 0):.4f}", + f"{results.get('best_ppl', 0):.4f}", + str(results.get('best_step', 0)), + str(results.get('steps_completed', 0)), + f"{results.get('training_time', 0):.1f}", + config.get("model", ""), + str(config.get("lr", "")), + str(config.get("warmup_steps", "")), + str(config.get("total_steps", "")), + config.get("lambda_schedule", ""), + str(config.get("group_size", "")), + str(config.get("batch_size", "")), + str(config.get("seq_length", "")), + config.get("description", ""), + ] + f.write("\t".join(row) + "\n") + + print(f"\nResult logged to {results_file}") + + +def get_git_commit() -> str: + """Get current git commit hash.""" + try: + repo = git.Repo(search_parent_directories=True) + return repo.head.object.hexsha[:8] + except: + return "nogit" + + +# --------------------------------------------------------------------------- +# CLI Entry Point +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Ternary QAT Training") + parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name") + parser.add_argument("--train-dataset", default=DEFAULT_TRAIN_DATASET, help="Training dataset") + parser.add_argument("--eval-dataset", default=DEFAULT_EVAL_DATASET, help="Eval dataset") + parser.add_argument("--eval-config", default=DEFAULT_EVAL_CONFIG, help="Eval dataset config") + parser.add_argument("--seq-length", type=int, default=DEFAULT_SEQ_LENGTH, help="Sequence length") + parser.add_argument("--batch-size", type=int, default=4, help="Batch size") + parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") + parser.add_argument("--warmup-steps", type=int, default=1000, help="Lambda warmup steps") + parser.add_argument("--total-steps", type=int, default=5000, help="Total training steps") + parser.add_argument("--eval-every", type=int, default=500, help="Eval every N steps") + 
parser.add_argument("--lambda-schedule", default="linear", choices=["linear", "cosine", "exponential"], help="Lambda schedule") + parser.add_argument("--group-size", type=int, default=128, help="Quantization group size") + parser.add_argument("--max-samples", type=int, default=50000, help="Max training samples") + parser.add_argument("--device", default="auto", help="Device (auto=cuda/mps/cpu)") + parser.add_argument("--time-budget", type=float, default=300.0, help="Time budget in seconds") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument("--description", default="", help="Experiment description") + parser.add_argument("--dry-run", action="store_true", help="Print config without training") + + args = parser.parse_args() + + config = { + "commit": get_git_commit(), + "model": args.model, + "lr": args.lr, + "warmup_steps": args.warmup_steps, + "total_steps": args.total_steps, + "lambda_schedule": args.lambda_schedule, + "group_size": args.group_size, + "batch_size": args.batch_size, + "seq_length": args.seq_length, + "description": args.description or f"λ={args.lambda_schedule}-{args.warmup_steps} lr={args.lr} gs={args.group_size}", + } + + if args.dry_run: + print("=== Configuration (dry run) ===") + for k, v in config.items(): + print(f" {k}: {v}") + print(f" time_budget: {args.time_budget}s") + print(f" device: {args.device}") + exit(0) + + try: + results = train( + model_name=args.model, + train_dataset_name=args.train_dataset, + eval_dataset_name=args.eval_dataset, + eval_config=args.eval_config, + seq_length=args.seq_length, + batch_size=args.batch_size, + learning_rate=args.lr, + warmup_steps=args.warmup_steps, + total_steps=args.total_steps, + eval_every=args.eval_every, + lambda_schedule=args.lambda_schedule, + group_size=args.group_size, + max_train_samples=args.max_samples, + device=args.device, + time_budget=args.time_budget, + seed=args.seed, + ) + + config["status"] = "success" + log_result(results, config) + + except Exception as e: + print(f"\n!!! Training failed: {e}") + import traceback + traceback.print_exc() + + config["status"] = "failed" + config["error"] = str(e) + log_result({}, config) + sys.exit(1)