# ternary_tests/train.py
"""
Ternary quantization training loop — MUTABLE in the autoresearch loop.
This file implements:
- BitLinear: ternary quantized linear layer with STE + lambda warmup
- Training loop on TinyStories with WikiText eval
- Debug output for monitoring in tmux
Usage:
python train.py # default config
python train.py --steps 2000 --lr 1e-4 # custom hyperparams
python train.py --group-size 64 --warmup 500 # quantization params
"""
import argparse
import json
import math
import os
import sys
import time
from pathlib import Path
import torch
import torch.nn as nn
from torch.nn import functional as F
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
# ---------------------------------------------------------------------------
# Configuration defaults (tune these, or pass via CLI)
# ---------------------------------------------------------------------------
# Default configuration; every entry can be overridden from the CLI (see main()).
DEFAULTS = dict(
    model_name="HuggingFaceTB/SmolLM-135M",
    # Prefer Apple-silicon GPU (MPS) when available, else CPU.
    device="mps" if torch.backends.mps.is_available() else "cpu",
    dtype=torch.float32,
    # Training
    steps=1000,
    batch_size=4,
    seq_len=256,
    learning_rate=1e-5,
    warmup_steps=100,  # LR warmup steps (linear ramp)
    eval_every=50,     # run eval every N steps
    # Quantization
    group_size=128,  # Bonsai uses 128
    quant_warmup_steps=2000,  # lambda warmup over N steps
    activation_bits=16,  # 16 = no activation quant (use 8 for INT8)
    threshold=0.5,  # deadzone threshold (Bonsai: 0.5)
    soft_quant=False,  # use tanh proxy for smooth gradients
    warmup_schedule="plateau",  # linear or plateau
    plateau_steps=500,  # steps per plateau level
    plateau_max=0.8,  # max lambda for plateau warmup
    # Data
    train_dataset="roneneldan/TinyStories",
    eval_data_path=str(Path(__file__).parent / "data" / "wikitext_eval.json"),
    # Logging
    log_file=str(Path(__file__).parent / "results.tsv"),
)
# ---------------------------------------------------------------------------
# Ternary Quantization Primitives
# ---------------------------------------------------------------------------
def ternary_quantize(w, group_size=128, threshold=0.5):
    """Quantize weights to {-1, 0, +1} with a per-group abs-mean scale.

    Groups are consecutive chunks of `group_size` elements of the flattened
    tensor. The flattened tensor is zero-padded internally to fill the last
    group, but the padding is excluded from the scale computation so it
    cannot bias the last group's abs-mean toward zero (the previous version
    divided by the full group size, shrinking the last group's scale).

    Args:
        w: weight tensor of any shape.
        group_size: number of weights per quantization group.
        threshold: deadzone threshold (0 < t < 1). Normalized weights with
            |w_norm| < t are zeroed; a threshold of 0 falls back to rounding.

    Returns:
        w_quant: ternary weights {-1, 0, +1} in the original shape.
        scale: per-group scale factors, shape (num_groups, 1).
    """
    original_shape = w.shape
    w_flat = w.reshape(-1)
    n = w_flat.numel()
    # Zero-pad so the flattened tensor splits evenly into groups.
    pad = (group_size - n % group_size) % group_size
    if pad > 0:
        w_flat = torch.cat([w_flat, w_flat.new_zeros(pad)])
    w_groups = w_flat.reshape(-1, group_size)
    # Scale: mean(|w|) per group over *real* elements only; dividing by the
    # true element count avoids the pad-induced bias on the last group.
    counts = torch.full((w_groups.size(0), 1), float(group_size),
                        dtype=w_groups.dtype, device=w_groups.device)
    if pad > 0:
        counts[-1, 0] = group_size - pad
    abs_mean = (w_groups.abs().sum(dim=-1, keepdim=True) / counts).clamp(min=1e-6)
    scale = abs_mean
    # Normalize, clamp, then snap to ternary.
    w_norm = w_groups / scale
    w_clamped = w_norm.clamp(-1.0, 1.0)
    if threshold > 0:
        # Deadzone: small normalized weights become 0, the rest keep sign.
        w_quant = torch.where(w_clamped.abs() < threshold,
                              torch.zeros_like(w_clamped),
                              torch.sign(w_clamped))
    else:
        w_quant = torch.round(w_clamped)
    # Trim padding and restore the original shape.
    w_quant = w_quant.reshape(-1)[:n].reshape(original_shape)
    return w_quant, scale
def soft_ternary(w, group_size=128, temperature=0.05, threshold=0.5):
    """Soft ternary quantization with a differentiable tanh proxy.

    tanh(w / temperature) smoothly approximates hard ternary {-1, 0, +1},
    providing useful gradients for weight optimization. The flattened
    tensor is zero-padded internally to fill the last group, but the
    padding is excluded from the scale computation so it cannot bias the
    last group's abs-mean toward zero (the previous version divided by the
    full group size, shrinking the last group's scale).

    Args:
        w: weight tensor of any shape.
        group_size: number of weights per quantization group.
        temperature: sharpness control (lower = closer to hard ternary).
        threshold: deadzone threshold for the soft zero region.

    Returns:
        w_soft: soft-quantized weights in the original shape, values in (-1, 1).
        scale: per-group scale factors, shape (num_groups, 1).
    """
    original_shape = w.shape
    w_flat = w.reshape(-1)
    n = w_flat.numel()
    # Zero-pad so the flattened tensor splits evenly into groups.
    pad = (group_size - n % group_size) % group_size
    if pad > 0:
        w_flat = torch.cat([w_flat, w_flat.new_zeros(pad)])
    w_groups = w_flat.reshape(-1, group_size)
    # Per-group mean(|w|) over real elements only (padding excluded).
    counts = torch.full((w_groups.size(0), 1), float(group_size),
                        dtype=w_groups.dtype, device=w_groups.device)
    if pad > 0:
        counts[-1, 0] = group_size - pad
    abs_mean = (w_groups.abs().sum(dim=-1, keepdim=True) / counts).clamp(min=1e-6)
    scale = abs_mean
    # Normalize, then squash through tanh for a soft ternary shape.
    w_norm = w_groups / scale
    w_clamped = w_norm.clamp(-1.0, 1.0)
    if threshold > 0:
        # Soft deadzone: shrink values near zero toward zero.
        w_deadzone = w_clamped * (1.0 - threshold / (w_clamped.abs() + threshold))
    else:
        w_deadzone = w_clamped
    w_soft = torch.tanh(w_deadzone / temperature)  # smooth, in (-1, 1)
    # Trim padding and restore the original shape.
    w_soft = w_soft.reshape(-1)[:n].reshape(original_shape)
    return w_soft, scale
def ternary_quantize_learnable(w, scale, group_size=128, threshold=0.5):
    """Quantize weights to {-1, 0, +1} using a learnable per-group scale.

    The straight-through estimator here detaches ONLY the non-differentiable
    rounding step, so gradients flow both to the continuous weights `w`
    (through the normalization) and to the learnable `scale`. The previous
    formulation `w + (w_dequant - w).detach()` detached the entire quantized
    path, so `scale` never received a gradient and could not actually be
    learned — contradicting this function's stated purpose. Forward values
    are unchanged: ternary(w / scale) * scale.

    Args:
        w: weight tensor of any shape.
        scale: learnable per-group scale, shape (num_groups,).
        group_size: number of weights per quantization group.
        threshold: deadzone threshold (0 < t < 1).

    Returns:
        w_dequant: dequantized weights in the original shape (for forward pass).
    """
    original_shape = w.shape
    w_flat = w.reshape(-1)
    n = w_flat.numel()
    pad = (group_size - n % group_size) % group_size
    if pad > 0:
        w_flat = torch.cat([w_flat, w_flat.new_zeros(pad)])
    w_groups = w_flat.reshape(-1, group_size)
    # Normalize by the learnable scale (broadcast across each group).
    scale = scale.to(w.device)  # keep scale on the same device as the weights
    scale_expanded = scale.unsqueeze(-1).expand(-1, group_size)
    w_norm = w_groups / scale_expanded.clamp(min=1e-6)
    w_clamped = w_norm.clamp(-1.0, 1.0)
    if threshold > 0:
        w_quant = torch.where(w_clamped.abs() < threshold,
                              torch.zeros_like(w_clamped),
                              torch.sign(w_clamped))
    else:
        w_quant = torch.round(w_clamped)
    # STE on the rounding only: forward value equals w_quant, backward uses
    # the gradient of w_clamped, so both w and scale receive gradients.
    w_ste = w_clamped + (w_quant - w_clamped).detach()
    # Dequantize back to the weight's scale.
    w_dequant = w_ste * scale_expanded
    # Trim padding and restore the original shape.
    w_dequant = w_dequant.reshape(-1)[:n].reshape(original_shape)
    return w_dequant
def ternary_dequantize(w_quant, scale, group_size=128):
    """Reconstruct real-valued weights from ternary codes and group scales.

    Grouping mirrors ternary_quantize: the flattened codes are zero-padded
    to a multiple of `group_size`, each group is multiplied by its scale,
    and the result is trimmed back to the original shape.
    """
    shape = w_quant.shape
    flat = w_quant.reshape(-1)
    numel = flat.numel()
    remainder = numel % group_size
    if remainder:
        flat = torch.cat([flat, flat.new_zeros(group_size - remainder)])
    scaled = flat.reshape(-1, group_size) * scale
    return scaled.reshape(-1)[:numel].reshape(shape)
def activation_quantize(x, bits=8):
    """Fake-quantize activations with per-token absmax scaling.

    Values are snapped to the symmetric integer grid for `bits` bits and
    immediately dequantized, so the output keeps the input's dtype and
    shape. `bits == 16` is a pass-through (no quantization).

    Args:
        x: activation tensor.
        bits: number of bits (default 8).

    Returns:
        Quantized-then-dequantized activations.
    """
    if bits == 16:
        return x
    q_max = 2 ** (bits - 1) - 1  # e.g. 127 for INT8
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6)
    step = absmax / q_max
    quantized = torch.clamp(torch.round(x / step), -q_max, q_max)
    return quantized * step
# ---------------------------------------------------------------------------
# BitLinear — Ternary Quantized Linear Layer
# ---------------------------------------------------------------------------
class BitLinear(nn.Module):
    """Linear layer with ternary-quantized weights and optional INT8 activations.

    Forward pass:
      1. Quantize weights to ternary — hard deadzone ternary, or the tanh
         proxy from soft_ternary when ``soft_quant`` is set.
      2. Optionally fake-quantize activations to ``activation_bits``.
      3. Blend FP and quantized outputs via the ``lambda_`` warmup buffer:
         ``out = out_fp + lambda * (out_quant - out_fp)``.

    In hard mode the quantized branch is detached (straight-through
    estimator: gradients flow only through the FP path). In soft mode the
    tanh proxy is differentiable, so the blend is left attached — that is
    the whole point of soft quantization. Previously the ``soft_quant``
    flag was stored but never consulted in forward(); it is now honored.
    """

    def __init__(self, in_features, out_features, bias=True, group_size=128,
                 activation_bits=8, threshold=0.5, soft_quant=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.group_size = group_size
        self.activation_bits = activation_bits
        self.threshold = threshold
        self.soft_quant = soft_quant  # use tanh proxy for smooth gradients
        # FP32 master weights (learned in full precision, quantized at forward time).
        self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=DEFAULTS['dtype']) * 0.02)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, dtype=DEFAULTS['dtype']))
        else:
            self.bias = None
        # Warmup coefficient in [0, 1]; set externally (see set_lambda).
        self.register_buffer('lambda_', torch.tensor(0.0))

    def forward(self, x):
        lambda_ = self.lambda_.item()
        if lambda_ <= 0.0:
            # Warmup not started — plain FP linear.
            return F.linear(x, self.weight, self.bias)
        # Quantize weights: soft tanh proxy or hard ternary + dequantize.
        if self.soft_quant:
            w_soft, scale = soft_ternary(self.weight, group_size=self.group_size,
                                         threshold=self.threshold)
            w_dequant = ternary_dequantize(w_soft, scale, self.group_size)
        else:
            w_quant, scale = ternary_quantize(self.weight, group_size=self.group_size,
                                              threshold=self.threshold)
            w_dequant = ternary_dequantize(w_quant, scale, self.group_size)
        # Optionally fake-quantize activations.
        if self.activation_bits < 16:
            x_q = activation_quantize(x, self.activation_bits)
        else:
            x_q = x
        out_quant = F.linear(x_q, w_dequant, self.bias)
        out_fp = F.linear(x, self.weight, self.bias)
        if self.soft_quant:
            # Soft path is differentiable — keep gradients through the
            # quantized branch so the tanh proxy can shape them.
            out = out_fp + lambda_ * (out_quant - out_fp)
        else:
            # STE with lambda warmup:
            #   lambda=0 → pure FP forward (no quantization)
            #   lambda=1 → quantized forward, gradients through the FP path
            out = out_fp + lambda_ * (out_quant - out_fp).detach()
        return out

    def extra_repr(self):
        return (f"in_features={self.in_features}, out_features={self.out_features}, "
                f"group_size={self.group_size}, activation_bits={self.activation_bits}")
# ---------------------------------------------------------------------------
# Model Surgery — Replace Linear with BitLinear
# ---------------------------------------------------------------------------
def replace_linears_with_bitlinear(model, group_size=128, activation_bits=8,
                                   exclude_embeddings=True, threshold=0.5,
                                   soft_quant=False):
    """Replace every nn.Linear in `model` with an equivalent BitLinear.

    Targets are collected first and swapped afterwards, so the module tree
    is not mutated while `named_modules()` is still iterating it. (The old
    version also computed an unused `parent`/`parent_name` lookup before
    walking to the real parent — that dead code is gone.)

    Args:
        model: an nn.Module tree (e.g. a HuggingFace model).
        group_size: quantization group size.
        activation_bits: activation quantization bits (16 = no quant).
        exclude_embeddings: skip `embed_tokens` / `lm_head` modules.
        threshold: deadzone threshold for ternary quantization (0 < t < 1).
        soft_quant: use tanh proxy for smooth gradients (vs hard ternary).

    Returns:
        Number of Linear layers replaced.
    """
    targets = []
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        # Usually keep embeddings / output head in full precision.
        if exclude_embeddings and ('embed_tokens' in name or 'lm_head' in name):
            continue
        targets.append((name, module))
    for name, module in targets:
        bit_linear = BitLinear(
            module.in_features,
            module.out_features,
            bias=module.bias is not None,
            group_size=group_size,
            activation_bits=activation_bits,
            threshold=threshold,
            soft_quant=soft_quant,
        )
        # Seed from the FP weights — critical for lambda warmup to work.
        with torch.no_grad():
            bit_linear.weight.data = module.weight.clone()
            if module.bias is not None:
                bit_linear.bias.data = module.bias.clone()
        # Walk to the parent module and swap the child in place.
        parts = name.split(".")
        parent = model
        for part in parts[:-1]:
            parent = getattr(parent, part)
        setattr(parent, parts[-1], bit_linear)
    return len(targets)
def set_lambda(model, value):
    """Broadcast the warmup coefficient `value` to every BitLinear layer."""
    bit_layers = (m for m in model.modules() if isinstance(m, BitLinear))
    for layer in bit_layers:
        layer.lambda_.fill_(value)
def get_quant_stats(model):
    """Tally the fraction of ternary weights at -1, 0 and +1 model-wide.

    Runs the same hard quantizer over every BitLinear layer's weights and
    returns a dict of fractions keyed "-1", "0", "+1". If the model has no
    BitLinear layers, all fractions are 0.
    """
    tallies = {"-1": 0, "0": 0, "+1": 0}
    total = 0
    for layer in model.modules():
        if not isinstance(layer, BitLinear):
            continue
        codes, _ = ternary_quantize(layer.weight, group_size=layer.group_size,
                                    threshold=layer.threshold)
        total += codes.numel()
        tallies["-1"] += (codes == -1).sum().item()
        tallies["0"] += (codes == 0).sum().item()
        tallies["+1"] += (codes == 1).sum().item()
    if total == 0:
        return {"-1": 0, "0": 0, "+1": 0}
    return {key: value / total for key, value in tallies.items()}
# ---------------------------------------------------------------------------
# Data Loading
# ---------------------------------------------------------------------------
def load_train_data(dataset_name="roneneldan/TinyStories", tokenizer=None,
                    seq_len=256, batch_size=4, device="cpu"):
    """Stream training batches of token ids.

    Tokenizes the streamed dataset, accumulates tokens in a rolling buffer,
    and yields tensors of shape (batch_size, seq_len) carved from the
    concatenated stream.
    """
    stream = load_dataset(dataset_name, split="train", streaming=True)

    def batches():
        token_buffer = []
        chunk_len = seq_len * batch_size
        for record in stream:
            token_buffer.extend(tokenizer(record["text"], truncation=False)["input_ids"])
            while len(token_buffer) >= chunk_len:
                chunk, token_buffer[:] = token_buffer[:chunk_len], token_buffer[chunk_len:]
                yield torch.tensor(chunk, dtype=torch.long,
                                   device=device).reshape(batch_size, seq_len)

    return batches()
def load_eval_data(eval_path):
    """Read pre-tokenized eval token ids from a JSON file into a long tensor."""
    ids = json.loads(Path(eval_path).read_text())
    return torch.tensor(ids, dtype=torch.long)
# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------
@torch.no_grad()
def evaluate(model, eval_ids, tokenizer, seq_len=128, device="cpu"):
    """Compute perplexity, bits-per-token and mean NLL on `eval_ids`.

    Slides a non-overlapping window of `seq_len` tokens over the eval
    stream, sums token-level cross-entropy, and restores train mode before
    returning. `tokenizer` is unused but kept for interface compatibility.
    """
    model.eval()
    total_nll = 0.0
    total_tokens = 0
    for start in range(0, len(eval_ids) - seq_len, seq_len):
        # Window of seq_len+1 ids; add a batch dimension of 1.
        window = eval_ids[start : start + seq_len + 1].to(device).unsqueeze(0)
        inputs, targets = window[:, :-1], window[:, 1:]
        logits = model(inputs).logits  # (1, seq_len, vocab)
        total_nll += F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            reduction="sum",
        ).item()
        total_tokens += seq_len
    model.train()
    mean_nll = total_nll / total_tokens
    # NOTE(review): the second value is bits per *token* (nll / ln 2), even
    # though the original labeled it "bits per byte"; numerics unchanged.
    return math.exp(mean_nll), mean_nll / math.log(2), mean_nll
# ---------------------------------------------------------------------------
# Training Loop
# ---------------------------------------------------------------------------
def train(args):
    """Main training loop with ternary quantization.

    Loads the base model, swaps its Linear layers for BitLinear, then trains
    while ramping the quantization blend `lambda_` per the chosen warmup
    schedule. Evaluates perplexity every `eval_every` steps and appends one
    TSV row per step to `args.log_file`.

    Args:
        args: parsed CLI namespace (see main()) carrying the fields used below.

    Returns:
        Best eval perplexity observed during the run.
    """
    device = args.device
    start_time = time.time()
    # ---- Load model and tokenizer ----
    print(f"\n{'='*60}")
    print(f" TERNARY QUANTIZATION TRAINING")
    print(f"{'='*60}")
    print(f" Model: {args.model_name}")
    print(f" Device: {device}")
    print(f" Steps: {args.steps}")
    print(f" Batch size: {args.batch_size}")
    print(f" Seq length: {args.seq_len}")
    print(f" LR: {args.learning_rate}")
    print(f" Group size: {args.group_size}")
    print(f" Act bits: {args.activation_bits}")
    print(f" Warmup: {args.quant_warmup_steps} steps")
    print(f"{'='*60}\n")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print(f"Loading model... (this may take a moment)")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype=DEFAULTS['dtype'],
        low_cpu_mem_usage=True,
    )
    model.to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model loaded: {total_params:,} parameters")
    # ---- Replace Linear with BitLinear ----
    print(f"\nReplacing Linear layers with BitLinear (group_size={args.group_size})...")
    n_replaced = replace_linears_with_bitlinear(
        model,
        group_size=args.group_size,
        activation_bits=args.activation_bits,
        threshold=args.threshold,
        soft_quant=args.soft_quant,
    )
    print(f"Replaced {n_replaced} Linear layers with BitLinear")
    # Count the parameters that now live inside ternary-quantized layers.
    ternary_params = sum(
        p.numel() for m in model.modules()
        if isinstance(m, BitLinear) for p in m.parameters()
    )
    print(f"Ternary-quantized parameters: {ternary_params:,}")
    # ---- Load eval data ----
    eval_ids = load_eval_data(args.eval_data_path)
    print(f"Eval data: {len(eval_ids):,} tokens")
    # ---- Optimizer ----
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=0.01,
    )
    # ---- Training loop ----
    data_iter = load_train_data(
        dataset_name=args.train_dataset,
        tokenizer=tokenizer,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        device=device,
    )
    best_ppl = float("inf")
    step_times = []
    # Open log file; write the TSV header only for a fresh/empty file so
    # repeated runs append instead of clobbering history.
    log_path = Path(args.log_file)
    write_header = not log_path.exists() or log_path.stat().st_size == 0
    if write_header:
        with open(log_path, "w") as f:
            f.write("step\tlambda\ttrain_loss\ttrain_ppl\teval_ppl\teval_bpb\tlr\ttime_s\tbest_ppl\tq_neg1\tq_zero\tq_pos1\n")
    # NOTE(review): ''*60 is an empty string — this probably meant '='*60
    # like the banners above; kept byte-identical here.
    print(f"\n{''*60}")
    print(f" STEP LAMBDA LOSS PPL EVAL_PPL LR TIME")
    print(f"{''*60}")
    for step in range(1, args.steps + 1):
        step_start = time.time()
        # ---- Lambda warmup ----
        if args.warmup_schedule == "plateau":
            # Plateau warmup: hold lambda at each 0.05 level for plateau_steps,
            # capped at plateau_max.
            levels = int(args.plateau_max / 0.05)  # 0.05 increments
            plateau_size = args.plateau_steps
            level = min(step // plateau_size, levels)
            lambda_ = min(level * 0.05, args.plateau_max)
        else:
            # Linear warmup to 1.0 over quant_warmup_steps.
            lambda_ = min(step / args.quant_warmup_steps, 1.0)
        set_lambda(model, lambda_)
        # ---- LR warmup: linear ramp over warmup_steps, then constant ----
        if step <= args.warmup_steps:
            lr_scale = step / args.warmup_steps
        else:
            lr_scale = 1.0
        for param_group in optimizer.param_groups:
            param_group["lr"] = args.learning_rate * lr_scale
        current_lr = args.learning_rate * lr_scale
        # ---- Forward pass ----
        model.train()
        try:
            batch = next(data_iter)
        except StopIteration:
            # Streaming dataset ran dry — restart the stream and continue.
            print("\n Dataset exhausted, restarting...")
            data_iter = load_train_data(
                dataset_name=args.train_dataset,
                tokenizer=tokenizer,
                seq_len=args.seq_len,
                batch_size=args.batch_size,
                device=device,
            )
            batch = next(data_iter)
        # batch: (batch_size, seq_len); shift by one for teacher forcing.
        x = batch[:, :-1]
        y = batch[:, 1:]
        logits = model(x).logits
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            y.reshape(-1),
        )
        train_ppl = math.exp(loss.item())
        # ---- Backward pass ----
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        step_time = time.time() - step_start
        step_times.append(step_time)
        # Rolling average over the last 20 steps, used for the ETA estimate.
        avg_step_time = sum(step_times[-20:]) / len(step_times[-20:])
        # ---- Eval ----
        eval_ppl = "-"
        eval_bpb = "-"
        quant_stats = None
        if step % args.eval_every == 0 or step == args.steps:
            ppl, bpb, _ = evaluate(model, eval_ids, tokenizer,
                                   seq_len=128, device=device)
            eval_ppl = f"{ppl:.1f}"
            eval_bpb = f"{bpb:.2f}"
            if ppl < best_ppl:
                best_ppl = ppl
            # Quantization stats: fraction of ternary weights at -1 / 0 / +1.
            quant_stats = get_quant_stats(model)
            elapsed = time.time() - start_time
            eta = avg_step_time * (args.steps - step)
            q_str = f"[-1:{quant_stats['-1']:.2f} 0:{quant_stats['0']:.2f} +1:{quant_stats['+1']:.2f}]"
            print(f" {step:5d} {lambda_:.3f} {loss.item():.4f} {train_ppl:7.1f} "
                  f"{eval_ppl:>8s} {current_lr:.2e} {elapsed:.0f}s (ETA {eta:.0f}s)")
            print(f" eval_bpb={eval_bpb} best_ppl={best_ppl:.1f} "
                  f"step_time={avg_step_time:.2f}s quant={q_str}")
            # Log to TSV (full row, including eval metrics and quant stats).
            with open(log_path, "a") as f:
                q = quant_stats or {"-1": 0, "0": 0, "+1": 0}
                f.write(f"{step}\t{lambda_:.4f}\t{loss.item():.6f}\t{train_ppl:.2f}\t"
                        f"{ppl:.2f}\t{bpb:.4f}\t{current_lr:.2e}\t{elapsed:.1f}\t{best_ppl:.2f}\t"
                        f"{q['-1']:.4f}\t{q['0']:.4f}\t{q['+1']:.4f}\n")
        else:
            elapsed = time.time() - start_time
            eta = avg_step_time * (args.steps - step)
            # Console output is throttled to every 25 steps (plus step 1).
            if step % 25 == 0 or step == 1:
                print(f" {step:5d} {lambda_:.3f} {loss.item():.4f} {train_ppl:7.1f} "
                      f"{'':>8s} {current_lr:.2e} {elapsed:.0f}s (ETA {eta:.0f}s)")
            # Log to TSV (no eval)
            with open(log_path, "a") as f:
                f.write(f"{step}\t{lambda_:.4f}\t{loss.item():.6f}\t{train_ppl:.2f}\t"
                        f"-\t-\t{current_lr:.2e}\t{elapsed:.1f}\t{best_ppl:.2f}\t-\t-\t-\n")
    # ---- Summary ----
    total_time = time.time() - start_time
    print(f"\n{'='*60}")
    print(f" TRAINING COMPLETE")
    print(f" Total time: {total_time:.0f}s ({total_time/60:.1f} min)")
    print(f" Best eval PPL: {best_ppl:.1f}")
    print(f" Final lambda: {lambda_:.3f}")
    print(f" Avg step time: {sum(step_times)/len(step_times):.2f}s")
    print(f" Results logged to: {log_path}")
    print(f"{'='*60}\n")
    return best_ppl
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main():
    """Parse CLI flags (defaults come from DEFAULTS) and launch training."""
    parser = argparse.ArgumentParser(description="Ternary quantization training")
    add = parser.add_argument
    # Model / device
    add("--model-name", default=DEFAULTS["model_name"])
    add("--device", default=DEFAULTS["device"])
    # Training hyperparameters
    add("--steps", type=int, default=DEFAULTS["steps"])
    add("--batch-size", type=int, default=DEFAULTS["batch_size"])
    add("--seq-len", type=int, default=DEFAULTS["seq_len"])
    add("--lr", type=float, default=DEFAULTS["learning_rate"], dest="learning_rate")
    add("--warmup-steps", type=int, default=DEFAULTS["warmup_steps"], dest="warmup_steps")
    add("--eval-every", type=int, default=DEFAULTS["eval_every"], dest="eval_every")
    # Quantization parameters
    add("--group-size", type=int, default=DEFAULTS["group_size"], dest="group_size")
    add("--quant-warmup-steps", type=int, default=DEFAULTS["quant_warmup_steps"], dest="quant_warmup_steps")
    add("--activation-bits", type=int, default=DEFAULTS["activation_bits"], dest="activation_bits")
    add("--threshold", type=float, default=DEFAULTS["threshold"], dest="threshold")
    add("--soft-quant", action='store_true', default=DEFAULTS["soft_quant"], dest="soft_quant")
    add("--no-soft-quant", action='store_false', dest="soft_quant")
    add("--warmup-schedule", default=DEFAULTS["warmup_schedule"], dest="warmup_schedule")
    add("--plateau-steps", type=int, default=DEFAULTS["plateau_steps"], dest="plateau_steps")
    add("--plateau-max", type=float, default=DEFAULTS["plateau_max"], dest="plateau_max")
    # Data / logging
    add("--train-dataset", default=DEFAULTS["train_dataset"], dest="train_dataset")
    add("--eval-data-path", default=DEFAULTS["eval_data_path"], dest="eval_data_path")
    add("--log-file", default=DEFAULTS["log_file"])
    train(parser.parse_args())


if __name__ == "__main__":
    main()