Add ternary QAT training pipeline: prepare.py (data/eval), train.py (quantization/training), program.md (agent instructions), autoresearch.sh (loop)

2026-04-24 01:36:44 +02:00
parent f4601547d2
commit 7378d4ef8f
5 changed files with 1189 additions and 0 deletions
autoresearch.sh (+122 lines)
@@ -0,0 +1,122 @@
#!/bin/bash
# autoresearch.sh — Autonomous experiment loop
#
# Usage: ./autoresearch.sh [model] [time_budget_seconds]
#
# This script runs the autoresearch loop:
# 1. Agent proposes a change to train.py
# 2. Git commit the change
# 3. Run training for fixed time budget
# 4. Extract val_ppl from results.tsv
# 5. If improved → keep; if worse → git reset
# 6. Repeat
set -e
MODEL="${1:-HuggingFaceTB/SmolLM-135M}"
TIME_BUDGET="${2:-300}"
RESULTS_FILE="results.tsv"
# Ensure git is initialized
if [ ! -d ".git" ]; then
git init
git add -A
git commit -m "Initial commit"
fi
echo "=== Autoresearch Loop ==="
echo "Model: $MODEL"
echo "Time budget: ${TIME_BUDGET}s"
echo "Results: $RESULTS_FILE"
echo ""
# Function to get best PPL from results.tsv
get_best_ppl() {
if [ ! -f "$RESULTS_FILE" ]; then
echo "999999"
return
fi
# best_ppl (column 7) of the most recent run, skipping the header row
local val
val=$(tail -n +2 "$RESULTS_FILE" | tail -1 | cut -f7 | tr -d '[:space:]')
echo "${val:-999999}"
}
# Function to get last status
get_last_status() {
if [ ! -f "$RESULTS_FILE" ]; then
echo "none"
return
fi
# status (column 3) of the most recent run, skipping the header row
local val
val=$(tail -n +2 "$RESULTS_FILE" | tail -1 | cut -f3 | tr -d '[:space:]')
echo "${val:-none}"
}
# Initial commit if not committed
git add -A
git commit -m "Initial setup" --allow-empty 2>/dev/null || true
BEST_PPL=$(get_best_ppl)
RUN_NUM=0
while true; do
RUN_NUM=$((RUN_NUM + 1))
echo ""
echo "========================================"
echo "RUN #$RUN_NUM"
echo "Current best PPL: $BEST_PPL"
echo "========================================"
# Save current state
PREV_COMMIT=$(git rev-parse HEAD)
# Prompt the agent to make a change
# In production, this would call the LLM agent
# For now, we just run with current config
echo "Running training..."
# Run training
START_TIME=$(date +%s)
python3 train.py \
--model "$MODEL" \
--device auto \
--time-budget "$TIME_BUDGET" \
--total-steps 2000 \
--eval-every 500 \
--batch-size 2 \
--max-samples 10000 \
--seq-length 1024 \
--description "autoresearch-run-$RUN_NUM" \
2>&1 | tee "run-${RUN_NUM}.log" || true
END_TIME=$(date +%s)
ELAPSED=$((END_TIME - START_TIME))
# Check results
STATUS=$(get_last_status)
NEW_PPL=$(get_best_ppl)
echo ""
echo "Run #$RUN_NUM completed in ${ELAPSED}s"
echo "Status: $STATUS"
echo "Best PPL: $NEW_PPL"
if [ "$STATUS" = "success" ]; then
# Compare with previous best
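# awk exits 0 (true) when NEW_PPL < BEST_PPL; plain [ ] can only compare integers, not floats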
if echo "$NEW_PPL $BEST_PPL" | awk '{exit !($1 < $2)}'; then
echo "IMPROVED! Keeping changes."
BEST_PPL=$NEW_PPL
# Keep everything, including the train.py change that produced the improvement
git add -A
git commit -m "Run #$RUN_NUM: improved PPL to $BEST_PPL"
else
echo "No improvement. Reverting."
git reset --hard $PREV_COMMIT 2>/dev/null || true
git checkout -- results.tsv 2>/dev/null || true
fi
else
echo "FAILED. Reverting."
git reset --hard $PREV_COMMIT 2>/dev/null || true
git checkout -- results.tsv 2>/dev/null || true
fi
echo ""
echo "Continuing... (Ctrl+C to stop)"
done
prepare.py (+324 lines)
@@ -0,0 +1,324 @@
#!/usr/bin/env python3
"""
prepare.py — Data loading, tokenizer, and evaluation (READ-ONLY in autoresearch loop).
This file is NOT modified by the autoresearch agent. All mutable training logic
lives in train.py. This file provides:
- Dataset loading and tokenization
- Evaluation harness (WikiText PPL, optional zero-shot tasks)
- Utility functions for the training loop
"""
import os
import math
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Default model — small enough for fast iteration on M4 Pro / consumer GPU
DEFAULT_MODEL = "meta-llama/Llama-3.2-1B"
# Alternative tiny models for ultra-fast iteration:
# DEFAULT_MODEL = "HuggingFaceTB/SmolLM-135M"
# Default eval dataset
DEFAULT_EVAL_DATASET = "wikitext" # "wikitext" or "ptb"
DEFAULT_EVAL_CONFIG = "wikitext-2-raw-v1"
# Default training dataset
DEFAULT_TRAIN_DATASET = "HuggingFaceFW/fineweb-edu" # FineWeb-edu
# Alternative: "monology/patrickstar-core-1000" for smaller local testing
# Sequence length for training
DEFAULT_SEQ_LENGTH = 2048
# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------
def load_tokenizer(model_name: str = DEFAULT_MODEL) -> AutoTokenizer:
"""Load tokenizer for the given model."""
tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True,
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
# ---------------------------------------------------------------------------
# Dataset Wrappers
# ---------------------------------------------------------------------------
class TokenizedDataset(Dataset):
"""Simple dataset that yields batches of tokenized text."""
def __init__(self, data: list, tokenizer: AutoTokenizer, seq_length: int = DEFAULT_SEQ_LENGTH):
self.tokenizer = tokenizer
self.seq_length = seq_length
# Concatenate all texts and chunk into sequences
self.samples = self._prepare_samples(data)
def _prepare_samples(self, data: list) -> list:
"""Flatten and chunk text data into fixed-length sequences."""
all_tokens = []
for item in data:
text = item.get("text", item.get("content", str(item)))
tokens = self.tokenizer.encode(text, add_special_tokens=False)
all_tokens.extend(tokens)
all_tokens.append(self.tokenizer.eos_token_id)
# Chunk into sequences of seq_length
samples = []
for i in range(0, len(all_tokens) - self.seq_length + 1, self.seq_length):
samples.append(all_tokens[i : i + self.seq_length])
return samples
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
x = torch.tensor(self.samples[idx], dtype=torch.long)
# HF causal LMs shift labels internally, so input_ids and labels are the same
# sequence; pre-shifting here would cause an off-by-one (double shift) in the loss.
return {"input_ids": x, "labels": x.clone()}
def load_eval_dataset(
dataset_name: str = DEFAULT_EVAL_DATASET,
config: str = DEFAULT_EVAL_CONFIG,
split: str = "validation",
tokenizer: AutoTokenizer = None,
seq_length: int = DEFAULT_SEQ_LENGTH,
):
"""Load and tokenize the evaluation dataset."""
if dataset_name == "wikitext":
ds = load_dataset(dataset_name, config, trust_remote_code=True)[split]
elif dataset_name == "ptb":
ds = load_dataset(dataset_name, trust_remote_code=True)[split]
else:
ds = load_dataset(dataset_name, split=split, trust_remote_code=True)
if tokenizer is None:
tokenizer = load_tokenizer()
# For eval, we want contiguous text without chunking artifacts
texts = ds["text"] if "text" in ds.column_names else [str(x) for x in ds]
all_tokens = []
for text in texts:
tokens = tokenizer.encode(text, add_special_tokens=False)
all_tokens.extend(tokens)
# Remove empty tokens
all_tokens = [t for t in all_tokens if t is not None]
# Create sequences
samples = []
for i in range(0, len(all_tokens) - seq_length + 1, seq_length):
samples.append(all_tokens[i : i + seq_length])
return samples, tokenizer
# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------
@torch.no_grad()
def evaluate_ppl(
model: AutoModelForCausalLM,
eval_samples: list,
tokenizer: AutoTokenizer,
seq_length: int = DEFAULT_SEQ_LENGTH,
batch_size: int = 4,
device: str = "cpu",
) -> float:
"""
Evaluate perplexity on the given samples.
Returns perplexity (lower is better).
"""
model.eval()
total_loss = 0.0
total_tokens = 0
for i in range(0, len(eval_samples), batch_size):
batch_samples = eval_samples[i : i + batch_size]
input_ids = torch.tensor(batch_samples, dtype=torch.long, device=device)
labels = input_ids.clone()
# HF causal LMs shift labels internally; passing the unshifted sequence as both
# input_ids and labels avoids a double shift.
device_type = "mps" if device == "mps" else ("cuda" if device == "cuda" else "cpu")
if device_type != "cpu":
with torch.amp.autocast(device_type=device_type):
outputs = model(input_ids=input_ids, labels=labels, use_cache=False)
else:
outputs = model(input_ids=input_ids, labels=labels, use_cache=False)
loss = outputs.loss
# loss is already averaged over tokens; weight by number of sequences in the batch
total_loss += loss.item() * input_ids.size(0)
total_tokens += input_ids.size(0)
avg_loss = total_loss / total_tokens
perplexity = math.exp(avg_loss)
model.train()
return perplexity
@torch.no_grad()
def evaluate_bpb(
model: AutoModelForCausalLM,
eval_samples: list,
tokenizer: AutoTokenizer,
seq_length: int = DEFAULT_SEQ_LENGTH,
batch_size: int = 4,
device: str = "cpu",
) -> float:
"""
Evaluate bits-per-byte (bpb) on the given samples.
Returns bpb (lower is better).
"""
ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
bpb = math.log2(ppl)
return bpb
# ---------------------------------------------------------------------------
# Training Dataset Loader
# ---------------------------------------------------------------------------
def load_train_dataset(
dataset_name: str = DEFAULT_TRAIN_DATASET,
split: str = "train",
tokenizer: AutoTokenizer = None,
seq_length: int = DEFAULT_SEQ_LENGTH,
max_samples: int = None,
streaming: bool = True,
):
"""
Load and prepare the training dataset.
Returns a TokenizedDataset.
"""
if tokenizer is None:
tokenizer = load_tokenizer()
if dataset_name == "HuggingFaceFW/fineweb-edu":
# Use the edu-only subset for quality
ds = load_dataset(
dataset_name,
name="sample-100BT", # 100B token subset for faster iteration
split=split,
streaming=streaming,
trust_remote_code=True,
)
elif dataset_name.endswith(".jsonl") or dataset_name.endswith(".json"):
ds = load_dataset("json", data_files=dataset_name, split=split, streaming=streaming)
else:
ds = load_dataset(dataset_name, split=split, streaming=streaming, trust_remote_code=True)
# Collect texts (handle streaming by taking first N samples)
texts = []
for idx, item in enumerate(ds):
if max_samples and idx >= max_samples:
break
text = item.get("text", item.get("content", str(item)))
texts.append(text)
if not texts:
raise ValueError(f"No texts loaded from dataset {dataset_name}")
dataset = TokenizedDataset(texts, tokenizer, seq_length)
return dataset
# ---------------------------------------------------------------------------
# Model Loading
# ---------------------------------------------------------------------------
def load_model(
model_name: str = DEFAULT_MODEL,
device: str = "cpu",
dtype: torch.dtype = torch.float16,
attn_implementation: str = "sdpa",
) -> AutoModelForCausalLM:
"""
Load a pretrained model from HuggingFace.
"""
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=dtype,
attn_implementation=attn_implementation,
trust_remote_code=True,
device_map=device if device != "cpu" else None,
)
if device == "cpu":
model = model.to(device)
return model
# ---------------------------------------------------------------------------
# CLI Entry Point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Prepare and evaluate ternary quantization datasets")
parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name or path")
parser.add_argument("--eval-dataset", default=DEFAULT_EVAL_DATASET, help="Evaluation dataset")
parser.add_argument("--eval-config", default=DEFAULT_EVAL_CONFIG, help="Evaluation dataset config")
parser.add_argument("--seq-length", type=int, default=DEFAULT_SEQ_LENGTH, help="Sequence length")
parser.add_argument("--batch-size", type=int, default=4, help="Evaluation batch size")
parser.add_argument("--device", default="cpu", help="Device to run on")
parser.add_argument("--train-samples", type=int, default=1000, help="Max training samples to load")
parser.add_argument("--dry-run", action="store_true", help="Just print config without loading")
args = parser.parse_args()
if args.dry_run:
print(f"Model: {args.model}")
print(f"Eval dataset: {args.eval_dataset} ({args.eval_config})")
print(f"Seq length: {args.seq_length}")
print(f"Batch size: {args.batch_size}")
print(f"Device: {args.device}")
print(f"Train samples: {args.train_samples}")
exit(0)
print(f"Loading tokenizer for {args.model}...")
tokenizer = load_tokenizer(args.model)
print(f"Loading eval dataset {args.eval_dataset}...")
eval_samples, _ = load_eval_dataset(
dataset_name=args.eval_dataset,
config=args.eval_config,
tokenizer=tokenizer,
seq_length=args.seq_length,
)
print(f" Eval samples: {len(eval_samples)}")
print(f"Loading training dataset ({args.train_samples} samples)...")
train_dataset = load_train_dataset(
tokenizer=tokenizer,
seq_length=args.seq_length,
max_samples=args.train_samples,
streaming=True,
)
print(f" Train samples: {len(train_dataset)}")
# Optional: load model and run eval
if args.device != "skip-model":
print(f"Loading model {args.model}...")
model = load_model(args.model, device=args.device)
print(f" Model params: {sum(p.numel() for p in model.parameters()):,}")
print(f"Evaluating perplexity...")
ppl = evaluate_ppl(model, eval_samples, tokenizer, args.seq_length, args.batch_size, args.device)
print(f" Perplexity: {ppl:.2f}")
bpb = math.log2(ppl)
print(f" BPB: {bpb:.4f}")
program.md (+78 lines)
@@ -0,0 +1,78 @@
# program.md — Instructions for the Autoresearch Agent
You are an autonomous research agent exploring ternary (1.58-bit) quantization for LLMs.
## Your Goal
Iteratively improve the ternary quantization training in `train.py` to achieve the **lowest validation perplexity (val_ppl)** on WikiText-2.
## File Boundaries
- **MUTABLE**: `train.py` — You may modify this file. It contains:
- `BitLinear` layer (quantization logic)
- Quantization schedules (lambda warmup)
- Training loop
- Hyperparameters (LR, batch size, group size, etc.)
- **READ-ONLY**: `prepare.py` — DO NOT modify this file. It contains:
- Dataset loading and tokenization
- Evaluation harness (WikiText PPL)
- Model loading utilities
- **OUTPUT**: `results.tsv` — Results are automatically logged here after each run.
## Experiment Protocol
1. Read `results.tsv` to understand what has been tried
2. Read `train.py` to understand current implementation
3. Propose ONE focused change to `train.py`
4. The change will be committed and a training run will execute (~5 minutes)
5. Results are logged to `results.tsv`
6. If improved (lower val_ppl) → change is kept
7. If equal or worse → git reset to previous commit
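For step 1, a minimal parsing sketch (assuming only the column names written by `log_result()` in `train.py`; the helper name `best_run` is illustrative):

```python
import csv

def best_run(path: str = "results.tsv"):
    """Return the successful run with the lowest val_ppl, or None if there is none yet."""
    with open(path) as f:
        rows = [r for r in csv.DictReader(f, delimiter="\t") if r.get("status") == "success"]
    return min(rows, key=lambda r: float(r["val_ppl"]), default=None)
```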
## What to Explore
### Priority 1: Quantization Schedule
- Lambda warmup shape: linear, cosine, exponential
- Warmup step counts: 200, 500, 1000, 2000, 5000
- Two-phase warmup (fast initial + slow final)
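A two-phase warmup is listed above but not yet implemented in `train.py`; one possible shape (an illustrative sketch, not a prescribed design) ramps quickly to an intermediate lambda and then slowly to 1.0:

```python
def two_phase_warmup_lambda(step: int, warmup_steps: int, split: float = 0.25, mid: float = 0.8) -> float:
    """Fast ramp to `mid` over the first `split` fraction of warmup, slow ramp to 1.0 after."""
    if step >= warmup_steps:
        return 1.0
    fast_steps = max(1, int(split * warmup_steps))
    if step < fast_steps:
        return mid * step / fast_steps
    return mid + (1.0 - mid) * (step - fast_steps) / max(1, warmup_steps - fast_steps)
```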
### Priority 2: Learning Rate
- LR values: 1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4
- LR schedule: constant, linear decay, cosine decay
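For the cosine option, `transformers` ships `get_cosine_schedule_with_warmup`, which could replace the linear scheduler built in `train()` (a sketch; the 10% LR warmup fraction mirrors the existing code, and `build_cosine_scheduler` is a hypothetical helper name):

```python
from transformers import get_cosine_schedule_with_warmup

def build_cosine_scheduler(optimizer, total_steps: int):
    # Cosine decay after the same 10% LR warmup used by the current linear schedule.
    return get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(total_steps * 0.1),
        num_training_steps=total_steps,
    )
```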
### Priority 3: Quantization Details
- Group size: 32, 64, 128, 256, per-tensor
- Scale initialization: mean-based vs absmax-based
- Ternary threshold adjustments
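As an example of the scale-initialization axis, an absmax-based per-group scale (a sketch; `ternary_quantize_weight` in `train.py` currently normalises by the group's mean |w|) would look like:

```python
import torch

def absmax_scale(group: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Map the largest-magnitude weight in the group to exactly +/-1,
    # instead of normalising by the group's mean absolute value.
    return 1.0 / (group.abs().max() + eps)
```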
### Priority 4: Deadzone Recovery
- Tequila-style reactivation (learnable lambda for deadzone weights)
- Bias injection for zero-valued weights
- Gradient scaling for deadzone weights
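Before changing the recovery mechanism, it helps to know how large the deadzone actually is. A diagnostic sketch that could live inside `train.py`, reusing its `get_bitlinear_modules()` and `ternary_quantize_weight()` helpers:

```python
import torch

@torch.no_grad()
def deadzone_fraction(model) -> float:
    """Fraction of BitLinear weights that currently quantize to 0."""
    zeros, total = 0, 0
    for m in get_bitlinear_modules(model):
        w_quant, _ = ternary_quantize_weight(m.weight, m.group_size)
        zeros += (w_quant == 0).sum().item()
        total += w_quant.numel()
    return zeros / max(total, 1)
```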
### Priority 5: Distillation
- OFF loss (cosine similarity between FP and ternary features)
- Logits distillation weight
- Feature distillation weight
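For logits distillation, one common formulation (a sketch; nothing in `train.py` currently wires in a teacher model) is a temperature-scaled KL term between a frozen full-precision teacher and the ternary student:

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature: float = 2.0):
    # KL(teacher || student) on temperature-softened distributions,
    # scaled by T^2 so gradient magnitudes stay comparable across temperatures.
    t = temperature
    return F.kl_div(
        F.log_softmax(student_logits / t, dim=-1),
        F.softmax(teacher_logits / t, dim=-1),
        reduction="batchmean",
    ) * (t * t)
```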
## Constraints
- Keep experiments focused — ONE change per iteration
- Always maintain working code — syntax errors waste time
- Use SmolLM-135M or Llama-3.2-1B for fast iteration
- Target metric: val_ppl (lower is better)
- Time budget: 5 minutes per experiment
## Important Notes
- The STE (Straight-Through Estimator) is critical for gradients to flow through quantization
- Warmup quantization prevents catastrophic accuracy loss at the start of training
- Deadzone trapping (weights stuck at 0) is a known problem — explore solutions
- Per-group quantization scales are essential for handling outlier weights
- The quantization formula: `scale = 1.0 / w.abs().mean(); w_q = round(clamp(w * scale, -1, 1))`
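Putting the last two notes together, a per-tensor version of the quantization with an STE pass-through (a compact sketch; `BitLinear` in `train.py` does the same thing per group and adds the lambda warmup):

```python
import torch

def ternary_ste(w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    scale = 1.0 / (w.abs().mean() + eps)
    w_q = (w * scale).clamp(-1, 1).round() / scale  # ternary weights, dequantized
    return w + (w_q - w).detach()                   # STE: forward uses w_q, gradients flow to w
```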
## NEVER STOP
Run experiments continuously until manually interrupted. Each experiment should be a small, focused change. Review results.tsv between runs to inform your next decision.
results.tsv (+1 line)
@@ -0,0 +1 @@
timestamp commit status val_ppl val_bpb initial_ppl best_ppl best_step steps time_s model lr warmup_steps total_steps lambda_schedule group_size batch_size seq_length description
train.py (+664 lines)
@@ -0,0 +1,664 @@
#!/usr/bin/env python3
"""
train.py — Ternary quantization training loop (MUTABLE in autoresearch loop).
This is the single mutable file for the autoresearch agent. The agent may modify:
- Quantization schedule (lambda warmup shape, length)
- Learning rate and schedule
- BitLinear implementation details (group size, deadzone recovery)
- Loss function (distillation weights, OFF loss)
- Training hyperparameters
DO NOT modify prepare.py — all data loading and evaluation lives there.
"""
import os
import sys
import math
import time
import json
import csv
try:
import git  # GitPython; optional, only used to record the current commit hash
except ImportError:
git = None
import random
import argparse
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, get_linear_schedule_with_warmup
# Import from prepare.py (read-only)
from prepare import (
load_tokenizer,
load_train_dataset,
load_eval_dataset,
load_model,
evaluate_ppl,
evaluate_bpb,
DEFAULT_MODEL,
DEFAULT_SEQ_LENGTH,
DEFAULT_TRAIN_DATASET,
DEFAULT_EVAL_DATASET,
DEFAULT_EVAL_CONFIG,
)
# ---------------------------------------------------------------------------
# Quantization Utilities
# ---------------------------------------------------------------------------
def ternary_quantize_weight(w: torch.Tensor, group_size: int = 128) -> tuple:
"""
Ternary quantize weights to {-1, 0, +1} with per-group scales.
Quantization formula (from HF blog):
scale = 1.0 / w.abs().mean()
w_scaled = w * scale
w_clamped = clamp(w_scaled, -1, 1)
w_rounded = round(w_clamped) # -> {-1, 0, +1}
Returns:
w_quant: ternary quantized weights
scales: per-group scales for dequantization
"""
original_shape = w.shape
# Flatten for group-wise processing
w_flat = w.flatten().float()
num_groups = (len(w_flat) + group_size - 1) // group_size
scales = []
quant_parts = []
for i in range(num_groups):
start = i * group_size
end = min(start + group_size, len(w_flat))
group = w_flat[start:end]
# Compute scale: 1 / mean(|w|)
scale = 1.0 / (group.abs().mean() + 1e-6)
scales.append(scale)
# Scale, clamp, round
w_scaled = group * scale
w_clamped = w_scaled.clamp(-1.0, 1.0)
w_rounded = w_clamped.round().sign() # -> {-1, 0, +1}
quant_parts.append(w_rounded)
w_quant = torch.cat(quant_parts, dim=0).reshape(original_shape)
scales = torch.stack(scales).to(device=w.device, dtype=w.dtype)
return w_quant, scales
def ternary_dequantize(w_quant: torch.Tensor, scales: torch.Tensor, group_size: int = 128) -> torch.Tensor:
"""Dequantize ternary weights back to original shape."""
original_shape = w_quant.shape
w_flat = w_quant.flatten().float()
num_groups = (len(w_flat) + group_size - 1) // group_size
dequant_parts = []
for i in range(num_groups):
start = i * group_size
end = min(start + group_size, len(w_flat))
group = w_flat[start:end]
scale = scales[i]
dequant_parts.append(group / (scale + 1e-6))
return torch.cat(dequant_parts, dim=0).reshape(original_shape).to(w_quant.dtype)
# ---------------------------------------------------------------------------
# BitLinear Layer
# ---------------------------------------------------------------------------
class BitLinear(nn.Module):
"""
Linear layer with ternary weight quantization using STE (Straight-Through Estimator).
Forward pass uses quantized weights; backward pass uses STE to pass gradients
through the quantization operation.
Supports warmup quantization: gradually introduce quantization during training
using lambda scheduling to avoid catastrophic accuracy loss.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
group_size: int = 128,
init_scale: float = 1.0,
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size
self.init_scale = init_scale
# FP16 weights (the "learned" weights)
self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
if bias:
self.bias = nn.Parameter(torch.zeros(out_features))
else:
self.bias = None
# Warmup lambda; set externally each training step via self.lambda_val.
# No persistent quantization buffers are kept: weights are re-quantized on every forward.
self.lambda_val = 1.0
def _quantize(self, w: torch.Tensor) -> tuple:
"""Quantize weights to ternary."""
return ternary_quantize_weight(w, self.group_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass with warmup quantization.
Lambda is set externally via self.lambda_val before each forward pass.
Quantization warmup formula:
w_eff = w + lambda * (w_dequant - w).detach()
When lambda=0: output uses full-precision weights
When lambda=1: output uses quantized weights (with STE gradients)
"""
# Get lambda (default 1.0 if not set)
lambda_ = getattr(self, 'lambda_val', 1.0)
# Quantize weights
w_quant, scales = self._quantize(self.weight)
# Dequantize for computation
w_dequant = ternary_dequantize(w_quant, scales, self.group_size)
# STE with warmup: interpolate between FP and quantized
# w_eff = w + lambda * (w_dequant - w)
# Gradient flows through w (the FP weights), not through quantization
if lambda_ > 0:
w_eff = self.weight + lambda_ * (w_dequant - self.weight).detach()
else:
w_eff = self.weight
# Linear forward
x = x.to(w_eff.dtype)
out = F.linear(x, w_eff, self.bias)
return out
def extra_repr(self) -> str:
return (
f"in_features={self.in_features}, "
f"out_features={self.out_features}, "
f"group_size={self.group_size}, "
f"bias={self.bias is not None}"
)
# ---------------------------------------------------------------------------
# Model Conversion
# ---------------------------------------------------------------------------
def replace_linears_with_bitlinear(
model: AutoModelForCausalLM,
group_size: int = 128,
target_modules: list = None,
) -> AutoModelForCausalLM:
"""
Replace all nn.Linear layers in the model with BitLinear layers.
Args:
model: pretrained model
group_size: quantization group size
target_modules: list of module names to replace. If None, replace all.
Returns:
Model with BitLinear layers
"""
if target_modules is None:
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
for name, module in model.named_modules():
parent_name = ".".join(name.split(".")[:-1])
module_name = name.split(".")[-1]
if isinstance(module, nn.Linear) and module_name in target_modules:
bit_linear = BitLinear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
group_size=group_size,
)
# Initialize from pretrained weights
bit_linear.weight.data = module.weight.data.clone()
if module.bias is not None:
bit_linear.bias.data = module.bias.data.clone()
# Replace in parent
parent = model
for attr in parent_name.split("."):
parent = getattr(parent, attr)
setattr(parent, module_name, bit_linear)
return model
def get_bitlinear_modules(model: AutoModelForCausalLM) -> list:
"""Get all BitLinear modules in the model."""
return [m for m in model.modules() if isinstance(m, BitLinear)]
# ---------------------------------------------------------------------------
# Lambda Warmup Schedules
# ---------------------------------------------------------------------------
def linear_warmup_lambda(step: int, warmup_steps: int) -> float:
"""Linear warmup: lambda goes from 0 to 1 over warmup_steps."""
return min(step / warmup_steps, 1.0)
def cosine_warmup_lambda(step: int, warmup_steps: int) -> float:
"""Cosine warmup: lambda follows cosine schedule over warmup_steps."""
if step >= warmup_steps:
return 1.0
return 0.5 * (1 - math.cos(math.pi * step / warmup_steps))
def exponential_warmup_lambda(step: int, warmup_steps: int) -> float:
"""Exponential warmup: lambda grows exponentially over warmup_steps."""
if step >= warmup_steps:
return 1.0
# Exponential from 0 to 1
return 1.0 - math.exp(-3.0 * step / warmup_steps)
def get_lambda(step: int, warmup_steps: int, schedule: str = "linear") -> float:
"""Get lambda value for the given step and schedule."""
if schedule == "linear":
return linear_warmup_lambda(step, warmup_steps)
elif schedule == "cosine":
return cosine_warmup_lambda(step, warmup_steps)
elif schedule == "exponential":
return exponential_warmup_lambda(step, warmup_steps)
else:
raise ValueError(f"Unknown lambda schedule: {schedule}")
# ---------------------------------------------------------------------------
# Training Loop
# ---------------------------------------------------------------------------
def train(
model_name: str = DEFAULT_MODEL,
train_dataset_name: str = DEFAULT_TRAIN_DATASET,
eval_dataset_name: str = DEFAULT_EVAL_DATASET,
eval_config: str = DEFAULT_EVAL_CONFIG,
seq_length: int = DEFAULT_SEQ_LENGTH,
batch_size: int = 4,
learning_rate: float = 1e-4,
warmup_steps: int = 1000,
total_steps: int = 5000,
eval_every: int = 500,
lambda_schedule: str = "linear",
group_size: int = 128,
max_train_samples: int = 50000,
device: str = "cpu",
dtype: torch.dtype = torch.float16,
time_budget: float = 300.0, # seconds (5 minutes default)
seed: int = 42,
output_dir: str = "./results",
):
"""
Main training loop for ternary quantization-aware training.
Args:
model_name: HuggingFace model name or path
train_dataset_name: Training dataset name
eval_dataset_name: Evaluation dataset name
eval_config: Evaluation dataset config
seq_length: Sequence length for training
batch_size: Training batch size
learning_rate: Base learning rate
warmup_steps: Number of steps for lambda warmup
total_steps: Total training steps
eval_every: Evaluate every N steps
lambda_schedule: Lambda warmup schedule ("linear", "cosine", "exponential")
group_size: Quantization group size
max_train_samples: Max training samples to load
device: Device to train on
dtype: Data type
time_budget: Max training time in seconds
seed: Random seed
output_dir: Output directory for results
Returns:
dict with training metrics
"""
# Auto-detect device if not specified
if device == "auto":
if torch.backends.mps.is_available():
device = "mps"
print("Using MPS (Apple Silicon)")
elif torch.cuda.is_available():
device = "cuda"
print("Using CUDA")
else:
device = "cpu"
print("Using CPU")
# Set seeds
random.seed(seed)
torch.manual_seed(seed)
if device == "cuda":
torch.cuda.manual_seed(seed)
# Track time
start_time = time.time()
print(f"=== Ternary QAT Training ===")
print(f"Model: {model_name}")
print(f"Device: {device}")
print(f"LR: {learning_rate}, Warmup: {warmup_steps} steps")
print(f"Lambda schedule: {lambda_schedule}")
print(f"Group size: {group_size}")
print(f"Time budget: {time_budget}s")
print(f"Total steps: {total_steps}")
print()
# Load tokenizer
print("Loading tokenizer...")
tokenizer = load_tokenizer(model_name)
# Load model
print(f"Loading model {model_name}...")
model = load_model(model_name, device=device, dtype=dtype, attn_implementation="sdpa")
model = model.to(dtype)
print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f" Linear layers replaced: ", end="")
# Replace Linear with BitLinear
model = replace_linears_with_bitlinear(model, group_size=group_size)
bitlinears = get_bitlinear_modules(model)
print(f"{len(bitlinears)}")
# Load datasets
print(f"Loading training dataset ({max_train_samples} samples)...")
train_dataset = load_train_dataset(
dataset_name=train_dataset_name,
tokenizer=tokenizer,
seq_length=seq_length,
max_samples=max_train_samples,
streaming=True,
)
print(f" Train samples: {len(train_dataset)}")
print(f"Loading eval dataset...")
eval_samples, _ = load_eval_dataset(
dataset_name=eval_dataset_name,
config=eval_config,
tokenizer=tokenizer,
seq_length=seq_length,
)
print(f" Eval samples: {len(eval_samples)}")
# Initial evaluation
print("\n--- Initial Evaluation ---")
initial_ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
initial_bpb = math.log2(initial_ppl)
print(f" Initial PPL: {initial_ppl:.2f}")
print(f" Initial BPB: {initial_bpb:.4f}")
# Optimizer
# All trainable parameters are updated (BitLinear FP weights, norms, embeddings, lm_head)
optimizer = torch.optim.AdamW(
[p for p in model.parameters() if p.requires_grad],
lr=learning_rate,
weight_decay=0.01,
)
# Learning rate scheduler
lr_scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=int(total_steps * 0.1),
num_training_steps=total_steps,
)
# DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
loader_iter = iter(train_loader)
# Training loop
results = {
"initial_ppl": initial_ppl,
"initial_bpb": initial_bpb,
"final_ppl": None,
"final_bpb": None,
"best_ppl": initial_ppl,
"best_step": 0,
"steps_completed": 0,
"training_time": 0,
}
step = 0
train_loss = 0.0
steps_since_last_eval = 0
print(f"\n--- Training ---")
while step < total_steps:
# Check time budget
elapsed = time.time() - start_time
if elapsed >= time_budget - 30: # Reserve 30s for final eval
print(f"\n Time budget nearly reached ({elapsed:.0f}/{time_budget}s). Stopping.")
break
# Get lambda for this step
lambda_ = get_lambda(step, warmup_steps, lambda_schedule)
# Get batch
try:
batch = next(loader_iter)
except StopIteration:
loader_iter = iter(train_loader)
batch = next(loader_iter)
input_ids = batch["input_ids"].to(device)
labels = batch["labels"].to(device)
# Set lambda on all BitLinear layers for this step
for bl in get_bitlinear_modules(model):
bl.lambda_val = lambda_
# Forward pass
outputs = model(input_ids=input_ids, labels=labels, use_cache=False)
loss = outputs.loss
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
train_loss += loss.item()
step += 1
steps_since_last_eval += 1
# Periodic logging
if step % 100 == 0:
avg_loss = train_loss / 100
avg_ppl = math.exp(avg_loss)
elapsed = time.time() - start_time
print(f" Step {step}/{total_steps} | Loss: {avg_loss:.4f} | PPL: {avg_ppl:.2f} | λ: {lambda_:.3f} | LR: {lr_scheduler.get_last_lr()[0]:.2e} | Time: {elapsed:.0f}s")
train_loss = 0.0
# Evaluation
if steps_since_last_eval >= eval_every or step == total_steps:
ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
bpb = math.log2(ppl)
elapsed = time.time() - start_time
print(f"\n [Eval] Step {step} | PPL: {ppl:.2f} | BPB: {bpb:.4f} | λ: {lambda_:.3f} | Time: {elapsed:.0f}s")
if ppl < results["best_ppl"]:
results["best_ppl"] = ppl
results["best_step"] = step
print(f" *** New best PPL! ***")
steps_since_last_eval = 0
# Final evaluation
print(f"\n--- Final Evaluation ---")
final_ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
final_bpb = math.log2(final_ppl)
results["final_ppl"] = final_ppl
results["final_bpb"] = final_bpb
results["steps_completed"] = step
results["training_time"] = time.time() - start_time
print(f" Final PPL: {final_ppl:.2f}")
print(f" Final BPB: {final_bpb:.4f}")
print(f" Best PPL: {results['best_ppl']:.2f} (step {results['best_step']})")
print(f" Steps completed: {step}")
print(f" Training time: {results['training_time']:.0f}s")
return results
# ---------------------------------------------------------------------------
# Results Logging
# ---------------------------------------------------------------------------
def log_result(results: dict, config: dict, results_file: str = "results.tsv"):
"""Log results to TSV file."""
file_exists = os.path.exists(results_file)
with open(results_file, "a") as f:
if not file_exists:
# Write header
headers = ["timestamp", "commit", "status", "val_ppl", "val_bpb",
"initial_ppl", "best_ppl", "best_step", "steps", "time_s",
"model", "lr", "warmup_steps", "total_steps", "lambda_schedule",
"group_size", "batch_size", "seq_length", "description"]
f.write("\t".join(headers) + "\n")
row = [
time.strftime("%Y-%m-%d %H:%M:%S"),
config.get("commit", ""),
config.get("status", "success"),
f"{results.get('final_ppl', 0):.4f}",
f"{results.get('final_bpb', 0):.4f}",
f"{results.get('initial_ppl', 0):.4f}",
f"{results.get('best_ppl', 0):.4f}",
str(results.get('best_step', 0)),
str(results.get('steps_completed', 0)),
f"{results.get('training_time', 0):.1f}",
config.get("model", ""),
str(config.get("lr", "")),
str(config.get("warmup_steps", "")),
str(config.get("total_steps", "")),
config.get("lambda_schedule", ""),
str(config.get("group_size", "")),
str(config.get("batch_size", "")),
str(config.get("seq_length", "")),
config.get("description", ""),
]
f.write("\t".join(row) + "\n")
print(f"\nResult logged to {results_file}")
def get_git_commit() -> str:
"""Get the current short git commit hash, or "nogit" if unavailable."""
if git is None:
return "nogit"
try:
repo = git.Repo(search_parent_directories=True)
return repo.head.object.hexsha[:8]
except Exception:
return "nogit"
# ---------------------------------------------------------------------------
# CLI Entry Point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Ternary QAT Training")
parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name")
parser.add_argument("--train-dataset", default=DEFAULT_TRAIN_DATASET, help="Training dataset")
parser.add_argument("--eval-dataset", default=DEFAULT_EVAL_DATASET, help="Eval dataset")
parser.add_argument("--eval-config", default=DEFAULT_EVAL_CONFIG, help="Eval dataset config")
parser.add_argument("--seq-length", type=int, default=DEFAULT_SEQ_LENGTH, help="Sequence length")
parser.add_argument("--batch-size", type=int, default=4, help="Batch size")
parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
parser.add_argument("--warmup-steps", type=int, default=1000, help="Lambda warmup steps")
parser.add_argument("--total-steps", type=int, default=5000, help="Total training steps")
parser.add_argument("--eval-every", type=int, default=500, help="Eval every N steps")
parser.add_argument("--lambda-schedule", default="linear", choices=["linear", "cosine", "exponential"], help="Lambda schedule")
parser.add_argument("--group-size", type=int, default=128, help="Quantization group size")
parser.add_argument("--max-samples", type=int, default=50000, help="Max training samples")
parser.add_argument("--device", default="auto", help="Device (auto=cuda/mps/cpu)")
parser.add_argument("--time-budget", type=float, default=300.0, help="Time budget in seconds")
parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument("--description", default="", help="Experiment description")
parser.add_argument("--dry-run", action="store_true", help="Print config without training")
args = parser.parse_args()
config = {
"commit": get_git_commit(),
"model": args.model,
"lr": args.lr,
"warmup_steps": args.warmup_steps,
"total_steps": args.total_steps,
"lambda_schedule": args.lambda_schedule,
"group_size": args.group_size,
"batch_size": args.batch_size,
"seq_length": args.seq_length,
"description": args.description or f"λ={args.lambda_schedule}-{args.warmup_steps} lr={args.lr} gs={args.group_size}",
}
if args.dry_run:
print("=== Configuration (dry run) ===")
for k, v in config.items():
print(f" {k}: {v}")
print(f" time_budget: {args.time_budget}s")
print(f" device: {args.device}")
exit(0)
try:
results = train(
model_name=args.model,
train_dataset_name=args.train_dataset,
eval_dataset_name=args.eval_dataset,
eval_config=args.eval_config,
seq_length=args.seq_length,
batch_size=args.batch_size,
learning_rate=args.lr,
warmup_steps=args.warmup_steps,
total_steps=args.total_steps,
eval_every=args.eval_every,
lambda_schedule=args.lambda_schedule,
group_size=args.group_size,
max_train_samples=args.max_samples,
device=args.device,
time_budget=args.time_budget,
seed=args.seed,
)
config["status"] = "success"
log_result(results, config)
except Exception as e:
print(f"\n!!! Training failed: {e}")
import traceback
traceback.print_exc()
config["status"] = "failed"
config["error"] = str(e)
log_result({}, config)
sys.exit(1)