"""
|
|
Ternary quantization training loop — MUTABLE in the autoresearch loop.
|
|
|
|
This file implements:
|
|
- BitLinear: ternary quantized linear layer with STE + lambda warmup
|
|
- Training loop on TinyStories with WikiText eval
|
|
- Debug output for monitoring in tmux
|
|
|
|
Usage:
|
|
python train.py # default config
|
|
python train.py --steps 2000 --lr 1e-4 # custom hyperparams
|
|
python train.py --group-size 64 --warmup 500 # quantization params
|
|
"""
|
|
|
|
import argparse
import json
import math
import os
import sys
import time
from pathlib import Path

import torch
import torch.nn as nn
from torch.nn import functional as F
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

# ---------------------------------------------------------------------------
# Configuration defaults (tune these, or pass via CLI)
# ---------------------------------------------------------------------------
DEFAULTS = dict(
    model_name="HuggingFaceTB/SmolLM-135M",
    device="mps" if torch.backends.mps.is_available() else "cpu",
    dtype=torch.float32,
    # Training
    steps=1000,
    batch_size=4,
    seq_len=256,
    learning_rate=1e-5,
    warmup_steps=100,
    eval_every=50,
    # Quantization
    group_size=128,             # Bonsai uses 128
    quant_warmup_steps=2000,    # lambda warmup over N steps
    activation_bits=16,         # 16 = no activation quant (use 8 for INT8)
    threshold=0.5,              # deadzone threshold (Bonsai: 0.5)
    soft_quant=False,           # use tanh proxy for smooth gradients
    warmup_schedule="plateau",  # "linear" or "plateau"
    plateau_steps=500,          # steps per plateau level
    plateau_max=0.8,            # max lambda for plateau warmup
    # Data
    train_dataset="roneneldan/TinyStories",
    eval_data_path=str(Path(__file__).parent / "data" / "wikitext_eval.json"),
    # Logging
    log_file=str(Path(__file__).parent / "results.tsv"),
)

# ---------------------------------------------------------------------------
# Ternary Quantization Primitives
# ---------------------------------------------------------------------------

def ternary_quantize(w, group_size=128, threshold=0.5):
    """Quantize weights to {-1, 0, +1} with per-group abs_mean scale.

    Groups are formed by flattening the weight tensor and taking consecutive
    chunks of `group_size` elements; the flat tensor is zero-padded so the
    last group is full.

    Args:
        w: weight tensor of any shape
        group_size: number of weights per quantization group
        threshold: deadzone threshold (0 < t < 1). Weights with |w_norm| < t are zeroed.

    Returns:
        w_quant: ternary weights {-1, 0, +1} in original shape
        scale: per-group scale factors, shape (num_groups, 1)
    """
    original_shape = w.shape
    w_flat = w.reshape(-1)

    # Pad to multiple of group_size if needed
    n = w_flat.numel()
    pad = (group_size - n % group_size) % group_size
    if pad > 0:
        w_flat = torch.cat([w_flat, w_flat.new_zeros(pad)])

    w_groups = w_flat.reshape(-1, group_size)

    # Scale: mean(|w|) per group (balanced ternary distribution)
    abs_mean = w_groups.abs().mean(dim=-1, keepdim=True).clamp(min=1e-6)
    scale = abs_mean

    # Normalize, clamp, round to nearest ternary
    w_norm = w_groups / scale
    w_clamped = w_norm.clamp(-1.0, 1.0)
    if threshold > 0:
        w_quant = torch.where(w_clamped.abs() < threshold,
                              torch.zeros_like(w_clamped),
                              torch.sign(w_clamped))
    else:
        w_quant = torch.round(w_clamped)

    # Reshape back to original shape (trim padding)
    w_quant = w_quant.reshape(-1)[:n].reshape(original_shape)

    return w_quant, scale

def soft_ternary(w, group_size=128, temperature=0.05, threshold=0.5):
    """Soft ternary quantization with differentiable tanh proxy.

    Uses tanh(w / temperature) as a smooth approximation to ternary {-1, 0, +1}.
    This provides smooth gradients for weight optimization.

    Args:
        w: weight tensor of any shape
        group_size: number of weights per quantization group
        temperature: controls sharpness (lower = closer to hard ternary)
        threshold: deadzone threshold for soft zero region

    Returns:
        w_soft: soft-quantized weights in original shape
        scale: per-group scale factors, shape (num_groups, 1)
    """
    original_shape = w.shape
    w_flat = w.reshape(-1)

    # Pad to multiple of group_size if needed
    n = w_flat.numel()
    pad = (group_size - n % group_size) % group_size
    if pad > 0:
        w_flat = torch.cat([w_flat, w_flat.new_zeros(pad)])

    w_groups = w_flat.reshape(-1, group_size)

    # Scale: mean(|w|) per group
    abs_mean = w_groups.abs().mean(dim=-1, keepdim=True).clamp(min=1e-6)
    scale = abs_mean

    # Normalize, apply tanh for soft ternary
    w_norm = w_groups / scale
    w_clamped = w_norm.clamp(-1.0, 1.0)
    if threshold > 0:
        # Soft deadzone: shrink values near zero
        w_deadzone = w_clamped * (1.0 - threshold / (w_clamped.abs() + threshold))
    else:
        w_deadzone = w_clamped
    w_soft = torch.tanh(w_deadzone / temperature)  # smooth values in (-1, 1)

    # Reshape back to original shape (trim padding)
    w_soft = w_soft.reshape(-1)[:n].reshape(original_shape)

    return w_soft, scale

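# Hedged sketch (hypothetical helper, not part of the original training flow):
# illustrates the point of the tanh proxy, namely that the soft-quantized output
# is differentiable end to end, so the weights receive a smooth gradient signal.
def _demo_soft_ternary_grad():
    w = torch.randn(4, 32, requires_grad=True)
    w_soft, _ = soft_ternary(w, group_size=16, temperature=0.05, threshold=0.5)
    w_soft.sum().backward()              # tanh proxy is differentiable end to end
    return w.grad.abs().mean().item()    # non-degenerate gradient signal
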
def ternary_quantize_learnable(w, scale, group_size=128, threshold=0.5):
    """Quantize weights to {-1, 0, +1} using a learnable per-group scale.

    The scale is a learnable nn.Parameter, allowing the network to optimize
    the quantization scale during training.

    Uses STE: the forward pass uses ternary weights, backward-pass gradients
    flow through the continuous weights.

    Args:
        w: weight tensor of any shape
        scale: learnable per-group scale, shape (num_groups,)
        group_size: number of weights per quantization group
        threshold: deadzone threshold (0 < t < 1)

    Returns:
        w_dequant: dequantized weights in original shape (for forward pass)
    """
    original_shape = w.shape
    w_flat = w.reshape(-1)

    n = w_flat.numel()
    pad = (group_size - n % group_size) % group_size
    if pad > 0:
        w_flat = torch.cat([w_flat, w_flat.new_zeros(pad)])

    w_groups = w_flat.reshape(-1, group_size)

    # Normalize using the learnable scale
    scale = scale.to(w.device)  # ensure scale is on the same device as the weights
    scale_expanded = scale.unsqueeze(-1).expand(-1, group_size)
    w_norm = w_groups / scale_expanded.clamp(min=1e-6)
    w_clamped = w_norm.clamp(-1.0, 1.0)

    if threshold > 0:
        w_quant = torch.where(w_clamped.abs() < threshold,
                              torch.zeros_like(w_clamped),
                              torch.sign(w_clamped))
    else:
        w_quant = torch.round(w_clamped)

    # Dequantize: w_dequant = w_quant * scale
    w_dequant = w_quant * scale_expanded

    # Reshape back to original shape (trim padding)
    w_dequant = w_dequant.reshape(-1)[:n].reshape(original_shape)

    # STE: w_dequant = w + (w_dequant - w).detach()
    # Gradients flow through w, not through the quantization
    w_dequant = w + (w_dequant - w).detach()

    return w_dequant

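# Hedged sketch (hypothetical helper, not used by the training loop): shows the
# straight-through estimator in ternary_quantize_learnable. The forward value is
# the dequantized weight, while the gradient w.r.t. `w` is the identity, so the
# upstream gradient (here all ones) lands on the FP weights unchanged.
def _demo_learnable_ste():
    w = nn.Parameter(torch.randn(4, 32))
    # One scale per group: 128 weights / group_size 16 = 8 groups
    scale = nn.Parameter(torch.full((8,), 0.1))
    w_dq = ternary_quantize_learnable(w, scale, group_size=16, threshold=0.5)
    w_dq.sum().backward()
    # STE passes the upstream gradient straight to w (all ones for a sum loss)
    assert torch.allclose(w.grad, torch.ones_like(w))
    return w_dq
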
def ternary_dequantize(w_quant, scale, group_size=128):
    """Reconstruct weights from ternary codes and scales.

    Groups are formed the same way as in ternary_quantize.
    """
    original_shape = w_quant.shape
    w_flat = w_quant.reshape(-1)

    n = w_flat.numel()
    pad = (group_size - n % group_size) % group_size
    if pad > 0:
        w_flat = torch.cat([w_flat, w_flat.new_zeros(pad)])

    w_groups = w_flat.reshape(-1, group_size)
    w_dequant = w_groups * scale

    return w_dequant.reshape(-1)[:n].reshape(original_shape)

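# Hedged sketch (hypothetical helper, illustrative values only): a quantize ->
# dequantize round trip with the two functions above. The reconstruction is
# lossy; the mean absolute error typically shrinks with smaller groups.
def _demo_ternary_round_trip():
    w = torch.randn(8, 32)
    w_q, scale = ternary_quantize(w, group_size=16, threshold=0.5)
    assert set(w_q.unique().tolist()) <= {-1.0, 0.0, 1.0}    # codes are ternary
    w_hat = ternary_dequantize(w_q, scale, group_size=16)
    return (w - w_hat).abs().mean().item()                   # round-trip error
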
def activation_quantize(x, bits=8):
    """Quantize activations to INT8 with per-token absmax scaling.

    Args:
        x: activation tensor
        bits: number of bits (default 8)

    Returns:
        x_quant: quantized then dequantized activations (same dtype as input)
    """
    if bits == 16:
        return x  # no quantization

    max_val = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6)
    max_q = 2 ** (bits - 1) - 1  # 127 for INT8
    scale = max_val / max_q
    x_q = (x / scale).round().clamp(-max_q, max_q)
    return x_q * scale

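# Hedged sketch (hypothetical helper): per-token absmax fake quantization keeps
# every value within half a quantization step of the original, i.e. within
# scale / 2 = absmax / (2 * (2**(bits-1) - 1)) per token.
def _demo_activation_quantize():
    x = torch.randn(2, 6, 64) * 3.0
    x_q = activation_quantize(x, bits=8)
    step = x.abs().amax(dim=-1, keepdim=True) / 127    # per-token quantization step
    assert ((x - x_q).abs() <= step / 2 + 1e-5).all()
    return (x - x_q).abs().max().item()
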
# ---------------------------------------------------------------------------
# BitLinear — Ternary Quantized Linear Layer
# ---------------------------------------------------------------------------

class BitLinear(nn.Module):
    """Linear layer with ternary quantized weights and INT8 activations.

    Uses straight-through estimator (STE) for gradient flow through quantization.
    Lambda warmup gradually introduces quantization during training.

    The forward pass:
    1. Quantize weights to ternary
    2. (Optional) Quantize activations to INT8
    3. Forward pass with quantized values
    4. STE: gradients flow through unquantized weights
    5. Lambda warmup: blend between FP and quantized forward
    """

    def __init__(self, in_features, out_features, bias=True, group_size=128,
                 activation_bits=8, threshold=0.5, soft_quant=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.group_size = group_size
        self.activation_bits = activation_bits
        self.threshold = threshold
        self.soft_quant = soft_quant  # use tanh proxy for smooth gradients

        # FP32 weights (learned in full precision, quantized at forward time)
        self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=DEFAULTS['dtype']) * 0.02)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, dtype=DEFAULTS['dtype']))
        else:
            self.bias = None

        # Lambda for warmup (set externally or via the training loop)
        self.register_buffer('lambda_', torch.tensor(0.0))

    def forward(self, x):
        lambda_ = self.lambda_.item()

        if lambda_ <= 0.0:
            # No quantization — standard linear
            out = F.linear(x, self.weight, self.bias)
            return out

        # Quantize weights (STE: gradients flow through FP weights)
        w_quant, scale = ternary_quantize(self.weight, group_size=self.group_size,
                                          threshold=self.threshold)
        w_dequant = ternary_dequantize(w_quant, scale, self.group_size)

        # Quantize activations (optional)
        if self.activation_bits < 16:
            x_q = activation_quantize(x, self.activation_bits)
        else:
            x_q = x

        # Quantized forward
        out_quant = F.linear(x_q, w_dequant, self.bias)

        # Straight-through estimator with lambda warmup:
        #   out = out_fp + lambda * (out_quant - out_fp).detach()
        # When lambda=0: pure FP forward (no quantization)
        # When lambda=1: quantized forward, gradients through FP (full STE)
        out_fp = F.linear(x, self.weight, self.bias)
        out = out_fp + lambda_ * (out_quant - out_fp).detach()

        return out

    def extra_repr(self):
        return (f"in_features={self.in_features}, out_features={self.out_features}, "
                f"group_size={self.group_size}, activation_bits={self.activation_bits}")

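# Hedged sketch (hypothetical helper): with lambda=0 a BitLinear behaves exactly
# like F.linear on its FP weights; with lambda=1 the forward value comes from the
# ternary path while gradients still flow through the FP weights (the STE blend).
def _demo_bitlinear_lambda():
    layer = BitLinear(32, 16, group_size=16, activation_bits=16)
    x = torch.randn(2, 32)
    layer.lambda_.fill_(0.0)
    out_fp = layer(x)
    assert torch.equal(out_fp, F.linear(x, layer.weight, layer.bias))
    layer.lambda_.fill_(1.0)
    out_q = layer(x)
    out_q.sum().backward()
    assert layer.weight.grad is not None   # STE: FP weights still receive gradients
    return (out_q - out_fp).abs().mean().item()
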
# ---------------------------------------------------------------------------
# Model Surgery — Replace Linear with BitLinear
# ---------------------------------------------------------------------------

def replace_linears_with_bitlinear(model, group_size=128, activation_bits=8,
                                   exclude_embeddings=True, threshold=0.5,
                                   soft_quant=False):
    """Replace all nn.Linear layers in model with BitLinear.

    Args:
        model: HuggingFace model
        group_size: quantization group size
        activation_bits: activation quantization bits (16 = no quant)
        exclude_embeddings: don't replace lm_head/embedding layers (usually desired)
        threshold: deadzone threshold for ternary quantization (0 < t < 1)
        soft_quant: use tanh proxy for smooth gradients (vs hard ternary)
    """
    count = 0
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            # Skip embedding layers if requested
            if exclude_embeddings and ('embed_tokens' in name or 'lm_head' in name):
                continue

            # Create BitLinear with the same dimensions
            bit_linear = BitLinear(
                module.in_features,
                module.out_features,
                bias=module.bias is not None,
                group_size=group_size,
                activation_bits=activation_bits,
                threshold=threshold,
                soft_quant=soft_quant,
            )

            # Initialize from FP weights (critical for warmup to work)
            with torch.no_grad():
                bit_linear.weight.data = module.weight.clone()
                if module.bias is not None:
                    bit_linear.bias.data = module.bias.clone()

            # Replace in the parent module: walk the dotted name down to the
            # actual parent, then swap the child attribute in place.
            parts = name.split(".")
            obj = model
            for p in parts[:-1]:
                obj = getattr(obj, p)
            setattr(obj, parts[-1], bit_linear)
            count += 1

    return count

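# Hedged sketch (hypothetical helper, not tied to any specific checkpoint): the
# surgery walks any nn.Module tree, so a tiny Sequential is enough to see it work.
def _demo_replace_linears():
    toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
    n = replace_linears_with_bitlinear(toy, group_size=16, activation_bits=16)
    assert n == 2 and isinstance(toy[0], BitLinear) and isinstance(toy[2], BitLinear)
    return toy
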
def set_lambda(model, value):
    """Set lambda on all BitLinear layers."""
    for module in model.modules():
        if isinstance(module, BitLinear):
            module.lambda_.fill_(value)


def get_quant_stats(model):
    """Get quantization statistics across all BitLinear layers.

    Returns a dict with the fraction of weights at {-1, 0, +1}.
    """
    total = 0
    count_neg1 = 0
    count_zero = 0
    count_pos1 = 0

    for module in model.modules():
        if isinstance(module, BitLinear):
            w_q, _ = ternary_quantize(module.weight, group_size=module.group_size,
                                      threshold=module.threshold)
            n = w_q.numel()
            total += n
            count_neg1 += (w_q == -1).sum().item()
            count_zero += (w_q == 0).sum().item()
            count_pos1 += (w_q == 1).sum().item()

    if total == 0:
        return {"-1": 0, "0": 0, "+1": 0}

    return {
        "-1": count_neg1 / total,
        "0": count_zero / total,
        "+1": count_pos1 / total,
    }

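# Hedged sketch (hypothetical helper): the three fractions from get_quant_stats
# sum to 1.0 over the BitLinear weights, which makes the tmux debug line easy to
# sanity-check, and set_lambda writes the same value into every layer's buffer.
def _demo_quant_stats():
    toy = nn.Sequential(nn.Linear(32, 32))
    replace_linears_with_bitlinear(toy, group_size=16, activation_bits=16)
    stats = get_quant_stats(toy)
    assert abs(stats["-1"] + stats["0"] + stats["+1"] - 1.0) < 1e-6
    set_lambda(toy, 0.5)
    assert toy[0].lambda_.item() == 0.5
    return stats
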
# ---------------------------------------------------------------------------
# Data Loading
# ---------------------------------------------------------------------------

def load_train_data(dataset_name="roneneldan/TinyStories", tokenizer=None,
                    seq_len=256, batch_size=4, device="cpu"):
    """Create a streaming dataloader for training.

    Yields tensors of shape (batch_size, seq_len) from concatenated text.
    """
    dataset = load_dataset(dataset_name, split="train", streaming=True)

    def tokenize_and_batch():
        buffer = []
        target_size = seq_len * batch_size

        for sample in dataset:
            tokens = tokenizer(sample["text"], truncation=False)["input_ids"]
            buffer.extend(tokens)

            while len(buffer) >= target_size:
                chunk = buffer[:target_size]
                buffer = buffer[target_size:]
                # Reshape to (batch_size, seq_len)
                x = torch.tensor(chunk, dtype=torch.long, device=device).reshape(batch_size, seq_len)
                yield x

    return tokenize_and_batch()

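# Hedged sketch (hypothetical helper, uses fake token ids instead of a real
# dataset/tokenizer): illustrates the packing rule above, where a flat token
# buffer is cut into consecutive chunks of seq_len * batch_size and reshaped to
# (batch_size, seq_len), so no padding tokens are ever needed.
def _demo_token_packing(seq_len=4, batch_size=2):
    buffer = list(range(17))                  # pretend these are token ids
    target_size = seq_len * batch_size
    batches = []
    while len(buffer) >= target_size:
        chunk, buffer = buffer[:target_size], buffer[target_size:]
        batches.append(torch.tensor(chunk, dtype=torch.long).reshape(batch_size, seq_len))
    # 17 tokens -> two full (2, 4) batches, one leftover token stays buffered
    assert len(batches) == 2 and buffer == [16]
    return batches
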
def load_eval_data(eval_path):
    """Load pre-tokenized eval data."""
    with open(eval_path, "r") as f:
        input_ids = json.load(f)
    return torch.tensor(input_ids, dtype=torch.long)

# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------

@torch.no_grad()
def evaluate(model, eval_ids, tokenizer, seq_len=128, device="cpu"):
    """Compute perplexity on eval data."""
    model.eval()
    nll = 0.0
    n_tokens = 0

    # Move eval data to device in chunks
    for i in range(0, len(eval_ids) - seq_len, seq_len):
        chunk = eval_ids[i : i + seq_len + 1].to(device)
        # Shape: (1, seq_len+1) for batch dim
        chunk = chunk.unsqueeze(0)

        logits = model(chunk[:, :-1]).logits  # (1, seq_len, vocab)
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            chunk[:, 1:].reshape(-1),
            reduction="sum"
        )
        nll += loss.item()
        n_tokens += seq_len

    model.train()
    avg_nll = nll / n_tokens
    ppl = math.exp(avg_nll)
    bpb = avg_nll / math.log(2)  # bits per token (logged as eval_bpb)
    return ppl, bpb, avg_nll

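# Hedged sketch (hypothetical helper, no model involved): a quick check of the
# perplexity arithmetic used above. Uniform logits over a vocab of size V give
# an average NLL of ln(V), hence a perplexity of V.
def _demo_perplexity_sanity(vocab=50, n_tokens=10):
    logits = torch.zeros(n_tokens, vocab)                       # uniform distribution
    targets = torch.randint(0, vocab, (n_tokens,))
    avg_nll = F.cross_entropy(logits, targets, reduction="sum").item() / n_tokens
    assert abs(math.exp(avg_nll) - vocab) < 1e-2                # ppl == V for uniform
    return math.exp(avg_nll), avg_nll / math.log(2)             # (ppl, bits per token)
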
# ---------------------------------------------------------------------------
# Training Loop
# ---------------------------------------------------------------------------

def train(args):
    """Main training loop with ternary quantization."""
    device = args.device
    start_time = time.time()

    # ---- Load model and tokenizer ----
    print(f"\n{'='*60}")
    print(f" TERNARY QUANTIZATION TRAINING")
    print(f"{'='*60}")
    print(f" Model: {args.model_name}")
    print(f" Device: {device}")
    print(f" Steps: {args.steps}")
    print(f" Batch size: {args.batch_size}")
    print(f" Seq length: {args.seq_len}")
    print(f" LR: {args.learning_rate}")
    print(f" Group size: {args.group_size}")
    print(f" Act bits: {args.activation_bits}")
    print(f" Warmup: {args.quant_warmup_steps} steps")
    print(f"{'='*60}\n")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading model... (this may take a moment)")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype=DEFAULTS['dtype'],
        low_cpu_mem_usage=True,
    )
    model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model loaded: {total_params:,} parameters")

    # ---- Replace Linear with BitLinear ----
    print(f"\nReplacing Linear layers with BitLinear (group_size={args.group_size})...")
    n_replaced = replace_linears_with_bitlinear(
        model,
        group_size=args.group_size,
        activation_bits=args.activation_bits,
        threshold=args.threshold,
        soft_quant=args.soft_quant,
    )
    print(f"Replaced {n_replaced} Linear layers with BitLinear")

    # Count ternary params
    ternary_params = sum(
        p.numel() for m in model.modules()
        if isinstance(m, BitLinear) for p in m.parameters()
    )
    print(f"Ternary-quantized parameters: {ternary_params:,}")

    # ---- Load eval data ----
    eval_ids = load_eval_data(args.eval_data_path)
    print(f"Eval data: {len(eval_ids):,} tokens")

    # ---- Optimizer ----
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=0.01,
    )

    # ---- Training loop ----
    data_iter = load_train_data(
        dataset_name=args.train_dataset,
        tokenizer=tokenizer,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        device=device,
    )

    best_ppl = float("inf")
    step_times = []

    # Open log file
    log_path = Path(args.log_file)
    # Check if we need to write header
    write_header = not log_path.exists() or log_path.stat().st_size == 0
    if write_header:
        with open(log_path, "w") as f:
            f.write("step\tlambda\ttrain_loss\ttrain_ppl\teval_ppl\teval_bpb\tlr\ttime_s\tbest_ppl\tq_neg1\tq_zero\tq_pos1\n")

    print(f"\n{'─'*60}")
    print(f" STEP LAMBDA LOSS PPL EVAL_PPL LR TIME")
    print(f"{'─'*60}")

    for step in range(1, args.steps + 1):
        step_start = time.time()

        # ---- Lambda warmup ----
        if args.warmup_schedule == "plateau":
            # Plateau warmup: hold lambda at each level for plateau_steps
            levels = int(args.plateau_max / 0.05)  # 0.05 increments
            plateau_size = args.plateau_steps
            level = min(step // plateau_size, levels)
            lambda_ = min(level * 0.05, args.plateau_max)
        else:
            # Linear warmup
            lambda_ = min(step / args.quant_warmup_steps, 1.0)
        set_lambda(model, lambda_)

        # ---- LR warmup ----
        if step <= args.warmup_steps:
            lr_scale = step / args.warmup_steps
        else:
            lr_scale = 1.0
        for param_group in optimizer.param_groups:
            param_group["lr"] = args.learning_rate * lr_scale
        current_lr = args.learning_rate * lr_scale

        # ---- Forward pass ----
        model.train()
        try:
            batch = next(data_iter)
        except StopIteration:
            print("\n Dataset exhausted, restarting...")
            data_iter = load_train_data(
                dataset_name=args.train_dataset,
                tokenizer=tokenizer,
                seq_len=args.seq_len,
                batch_size=args.batch_size,
                device=device,
            )
            batch = next(data_iter)

        # batch: (batch_size, seq_len)
        # Shift for teacher forcing
        x = batch[:, :-1]
        y = batch[:, 1:]

        logits = model(x).logits
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            y.reshape(-1),
        )

        train_ppl = math.exp(loss.item())

        # ---- Backward pass ----
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        step_time = time.time() - step_start
        step_times.append(step_time)
        avg_step_time = sum(step_times[-20:]) / len(step_times[-20:])

        # ---- Eval ----
        eval_ppl = "-"
        eval_bpb = "-"
        quant_stats = None
        if step % args.eval_every == 0 or step == args.steps:
            ppl, bpb, _ = evaluate(model, eval_ids, tokenizer,
                                   seq_len=128, device=device)
            eval_ppl = f"{ppl:.1f}"
            eval_bpb = f"{bpb:.2f}"
            if ppl < best_ppl:
                best_ppl = ppl

            # Quantization stats
            quant_stats = get_quant_stats(model)

            elapsed = time.time() - start_time
            eta = avg_step_time * (args.steps - step)
            q_str = f"[-1:{quant_stats['-1']:.2f} 0:{quant_stats['0']:.2f} +1:{quant_stats['+1']:.2f}]"
            print(f" {step:5d} {lambda_:.3f} {loss.item():.4f} {train_ppl:7.1f} "
                  f"{eval_ppl:>8s} {current_lr:.2e} {elapsed:.0f}s (ETA {eta:.0f}s)")
            print(f" eval_bpb={eval_bpb} best_ppl={best_ppl:.1f} "
                  f"step_time={avg_step_time:.2f}s quant={q_str}")

            # Log to TSV
            with open(log_path, "a") as f:
                q = quant_stats or {"-1": 0, "0": 0, "+1": 0}
                f.write(f"{step}\t{lambda_:.4f}\t{loss.item():.6f}\t{train_ppl:.2f}\t"
                        f"{ppl:.2f}\t{bpb:.4f}\t{current_lr:.2e}\t{elapsed:.1f}\t{best_ppl:.2f}\t"
                        f"{q['-1']:.4f}\t{q['0']:.4f}\t{q['+1']:.4f}\n")
        else:
            elapsed = time.time() - start_time
            eta = avg_step_time * (args.steps - step)
            if step % 25 == 0 or step == 1:
                print(f" {step:5d} {lambda_:.3f} {loss.item():.4f} {train_ppl:7.1f} "
                      f"{'—':>8s} {current_lr:.2e} {elapsed:.0f}s (ETA {eta:.0f}s)")

            # Log to TSV (no eval)
            with open(log_path, "a") as f:
                f.write(f"{step}\t{lambda_:.4f}\t{loss.item():.6f}\t{train_ppl:.2f}\t"
                        f"-\t-\t{current_lr:.2e}\t{elapsed:.1f}\t{best_ppl:.2f}\t-\t-\t-\n")

    # ---- Summary ----
    total_time = time.time() - start_time
    print(f"\n{'='*60}")
    print(f" TRAINING COMPLETE")
    print(f" Total time: {total_time:.0f}s ({total_time/60:.1f} min)")
    print(f" Best eval PPL: {best_ppl:.1f}")
    print(f" Final lambda: {lambda_:.3f}")
    print(f" Avg step time: {sum(step_times)/len(step_times):.2f}s")
    print(f" Results logged to: {log_path}")
    print(f"{'='*60}\n")

    return best_ppl

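# Hedged sketch (hypothetical helper): mirrors the lambda-warmup arithmetic from
# train() so a schedule can be eyeballed without running training. With the
# defaults (plateau_steps=500, plateau_max=0.8), lambda stays at 0.0 for the
# first 500 steps, then climbs in 0.05 increments, one plateau per 500 steps.
def _demo_lambda_schedule(steps=2500, schedule="plateau", plateau_steps=500,
                          plateau_max=0.8, quant_warmup_steps=2000):
    lambdas = []
    for step in range(1, steps + 1):
        if schedule == "plateau":
            levels = int(plateau_max / 0.05)
            level = min(step // plateau_steps, levels)
            lambdas.append(min(level * 0.05, plateau_max))
        else:
            lambdas.append(min(step / quant_warmup_steps, 1.0))
    return lambdas
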
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def main():
    parser = argparse.ArgumentParser(description="Ternary quantization training")
    parser.add_argument("--model-name", default=DEFAULTS["model_name"])
    parser.add_argument("--device", default=DEFAULTS["device"])
    parser.add_argument("--steps", type=int, default=DEFAULTS["steps"])
    parser.add_argument("--batch-size", type=int, default=DEFAULTS["batch_size"])
    parser.add_argument("--seq-len", type=int, default=DEFAULTS["seq_len"])
    parser.add_argument("--lr", type=float, default=DEFAULTS["learning_rate"], dest="learning_rate")
    parser.add_argument("--warmup-steps", type=int, default=DEFAULTS["warmup_steps"], dest="warmup_steps")
    parser.add_argument("--eval-every", type=int, default=DEFAULTS["eval_every"], dest="eval_every")
    parser.add_argument("--group-size", type=int, default=DEFAULTS["group_size"], dest="group_size")
    parser.add_argument("--quant-warmup-steps", type=int, default=DEFAULTS["quant_warmup_steps"], dest="quant_warmup_steps")
    parser.add_argument("--activation-bits", type=int, default=DEFAULTS["activation_bits"], dest="activation_bits")
    parser.add_argument("--threshold", type=float, default=DEFAULTS["threshold"], dest="threshold")
    parser.add_argument("--soft-quant", action="store_true", default=DEFAULTS["soft_quant"], dest="soft_quant")
    parser.add_argument("--no-soft-quant", action="store_false", dest="soft_quant")
    parser.add_argument("--warmup-schedule", default=DEFAULTS["warmup_schedule"], dest="warmup_schedule")
    parser.add_argument("--plateau-steps", type=int, default=DEFAULTS["plateau_steps"], dest="plateau_steps")
    parser.add_argument("--plateau-max", type=float, default=DEFAULTS["plateau_max"], dest="plateau_max")
    parser.add_argument("--train-dataset", default=DEFAULTS["train_dataset"], dest="train_dataset")
    parser.add_argument("--eval-data-path", default=DEFAULTS["eval_data_path"], dest="eval_data_path")
    parser.add_argument("--log-file", default=DEFAULTS["log_file"])
    args = parser.parse_args()

    train(args)


if __name__ == "__main__":
    main()