Add ternary QAT training pipeline: prepare.py (data/eval), train.py (quantization/training), program.md (agent instructions), autoresearch.sh (loop)
This commit is contained in:
Executable
+122
@@ -0,0 +1,122 @@
|
|||||||
|
#!/bin/bash
# autoresearch.sh — Autonomous experiment loop
#
# Usage: ./autoresearch.sh [model] [time_budget_seconds]
#
# This script runs the autoresearch loop:
# 1. Agent proposes a change to train.py
# 2. Git commit the change
# 3. Run training for fixed time budget
# 4. Extract val_ppl from results.tsv
# 5. If improved → keep; if worse → git reset
# 6. Repeat

# Abort on any unguarded command failure; commands that are allowed to fail
# (training run, git reset/checkout) are individually guarded with `|| true`.
set -e

# Positional defaults: $1 = HF model id, $2 = per-run time budget in seconds.
MODEL="${1:-HuggingFaceTB/SmolLM-135M}"
TIME_BUDGET="${2:-300}"
RESULTS_FILE="results.tsv"

# Ensure git is initialized
if [ ! -d ".git" ]; then
    git init
    git add -A
    git commit -m "Initial commit"
fi

echo "=== Autoresearch Loop ==="
echo "Model: $MODEL"
echo "Time budget: ${TIME_BUDGET}s"
echo "Results: $RESULTS_FILE"
echo ""
|
|
||||||
|
# Function to get best PPL from results.tsv
# Prints the best_ppl (column 7) of the most recent data row, or the
# sentinel 999999 when no run has been recorded yet.
get_best_ppl() {
    if [ ! -f "$RESULTS_FILE" ]; then
        echo "999999"
        return
    fi
    # BUG FIX: skip the header row. Previously, a results file containing
    # only the header made this print the literal string "best_ppl",
    # which then poisoned the numeric awk comparison in the main loop.
    local ppl
    ppl=$(tail -n +2 "$RESULTS_FILE" | tail -1 | cut -f7 | tr -d '[:space:]')
    if [ -z "$ppl" ]; then
        echo "999999"
    else
        echo "$ppl"
    fi
}
|
||||||
|
|
||||||
|
# Function to get last status
# Prints the status column (column 3) of the most recent data row, or
# "none" when no run has been recorded yet.
get_last_status() {
    if [ ! -f "$RESULTS_FILE" ]; then
        echo "none"
        return
    fi
    # BUG FIX: skip the header row so a header-only file yields "none"
    # rather than the literal column name "status".
    local status
    status=$(tail -n +2 "$RESULTS_FILE" | tail -1 | cut -f3 | tr -d '[:space:]')
    if [ -z "$status" ]; then
        echo "none"
    else
        echo "$status"
    fi
}
|
||||||
|
|
||||||
|
# Initial commit if not committed
git add -A
git commit -m "Initial setup" --allow-empty 2>/dev/null || true

BEST_PPL=$(get_best_ppl)
RUN_NUM=0

while true; do
    RUN_NUM=$((RUN_NUM + 1))
    echo ""
    echo "========================================"
    echo "RUN #$RUN_NUM"
    echo "Current best PPL: $BEST_PPL"
    echo "========================================"

    # Save current state so a failed or regressing run can be rolled back.
    PREV_COMMIT=$(git rev-parse HEAD)

    # Prompt the agent to make a change
    # In production, this would call the LLM agent
    # For now, we just run with current config
    echo "Running training..."

    # Run training
    START_TIME=$(date +%s)

    # `|| true` keeps the loop alive if train.py crashes; the failure is
    # detected afterwards via the status column of results.tsv.
    python3 train.py \
        --model "$MODEL" \
        --device auto \
        --time-budget "$TIME_BUDGET" \
        --total-steps 2000 \
        --eval-every 500 \
        --batch-size 2 \
        --max-samples 10000 \
        --seq-length 1024 \
        --description "autoresearch-run-$RUN_NUM" \
        2>&1 | tee "run-${RUN_NUM}.log" || true

    END_TIME=$(date +%s)
    ELAPSED=$((END_TIME - START_TIME))

    # Check results
    STATUS=$(get_last_status)
    NEW_PPL=$(get_best_ppl)

    echo ""
    echo "Run #$RUN_NUM completed in ${ELAPSED}s"
    echo "Status: $STATUS"
    echo "Best PPL: $NEW_PPL"

    if [ "$STATUS" = "success" ]; then
        # Compare with previous best.
        # awk exits 0 (shell "true") when NEW_PPL < BEST_PPL — a floating
        # point comparison that plain `[` cannot perform.
        if echo "$NEW_PPL $BEST_PPL" | awk '{exit !($1 < $2)}'; then
            echo "IMPROVED! Keeping changes."
            BEST_PPL=$NEW_PPL
            git add results.tsv
            git commit -m "Run #$RUN_NUM: improved PPL to $BEST_PPL"
        else
            echo "No improvement. Reverting."
            # reset --hard also reverts the agent's edit to train.py — that
            # rollback is the point of the experiment protocol.
            git reset --hard $PREV_COMMIT 2>/dev/null || true
            git checkout -- results.tsv 2>/dev/null || true
        fi
    else
        echo "FAILED. Reverting."
        git reset --hard $PREV_COMMIT 2>/dev/null || true
        git checkout -- results.tsv 2>/dev/null || true
    fi

    echo ""
    echo "Continuing... (Ctrl+C to stop)"
done
||||||
+324
@@ -0,0 +1,324 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
prepare.py — Data loading, tokenizer, and evaluation (READ-ONLY in autoresearch loop).
|
||||||
|
|
||||||
|
This file is NOT modified by the autoresearch agent. All mutable training logic
|
||||||
|
lives in train.py. This file provides:
|
||||||
|
- Dataset loading and tokenization
|
||||||
|
- Evaluation harness (WikiText PPL, optional zero-shot tasks)
|
||||||
|
- Utility functions for the training loop
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Configuration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Default model — small enough for fast iteration on M4 Pro / consumer GPU
|
||||||
|
DEFAULT_MODEL = "meta-llama/Llama-3.2-1B"
|
||||||
|
# Alternative tiny models for ultra-fast iteration:
|
||||||
|
# DEFAULT_MODEL = "HuggingFaceTB/SmolLM-135M"
|
||||||
|
|
||||||
|
# Default eval dataset
|
||||||
|
DEFAULT_EVAL_DATASET = "wikitext" # "wikitext" or "ptb"
|
||||||
|
DEFAULT_EVAL_CONFIG = "wikitext-2-raw-v1"
|
||||||
|
|
||||||
|
# Default training dataset
|
||||||
|
DEFAULT_TRAIN_DATASET = "HuggingFaceFW/fineweb-edu" # FineWeb-edu
|
||||||
|
# Alternative: "monology/patrickstar-core-1000" for smaller local testing
|
||||||
|
|
||||||
|
# Sequence length for training
|
||||||
|
DEFAULT_SEQ_LENGTH = 2048
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tokenizer
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def load_tokenizer(model_name: str = DEFAULT_MODEL) -> AutoTokenizer:
    """Return the tokenizer for *model_name*, guaranteeing a pad token.

    Some causal-LM tokenizers ship without a pad token; in that case the
    EOS token is reused so padded batches can be built downstream.
    """
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Dataset Wrappers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TokenizedDataset(Dataset):
    """Simple dataset that yields fixed-length tokenized samples.

    Accepts a list of raw strings, or of dicts carrying a ``"text"`` /
    ``"content"`` field; all texts are tokenized, concatenated into one
    token stream (documents separated by EOS), and chunked into
    non-overlapping ``seq_length`` windows for next-token prediction.
    """

    def __init__(self, data: list, tokenizer: AutoTokenizer, seq_length: int = DEFAULT_SEQ_LENGTH):
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        # Concatenate all texts and chunk into sequences
        self.samples = self._prepare_samples(data)

    def _prepare_samples(self, data: list) -> list:
        """Flatten and chunk text data into fixed-length sequences."""
        all_tokens = []
        for item in data:
            # BUG FIX: plain strings previously hit item.get() and raised
            # AttributeError (load_train_dataset passes a list of str).
            # Handle str, dict, and arbitrary objects explicitly.
            if isinstance(item, str):
                text = item
            elif isinstance(item, dict):
                text = item.get("text", item.get("content", str(item)))
            else:
                text = str(item)
            tokens = self.tokenizer.encode(text, add_special_tokens=False)
            all_tokens.extend(tokens)
            # Separate documents with EOS so chunks never silently blend docs.
            all_tokens.append(self.tokenizer.eos_token_id)

        # Chunk into non-overlapping windows of seq_length; the trailing
        # remainder shorter than seq_length is dropped.
        samples = []
        for i in range(0, len(all_tokens) - self.seq_length + 1, self.seq_length):
            samples.append(all_tokens[i : i + self.seq_length])
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Pre-shifted pair: input_ids = tokens[:-1], labels = tokens[1:].
        # NOTE(review): HF models shift labels internally when called with
        # labels= — confirm the consumer's convention before feeding these in.
        x = torch.tensor(self.samples[idx], dtype=torch.long)
        return {"input_ids": x[:-1], "labels": x[1:]}
|
||||||
|
|
||||||
|
|
||||||
|
def load_eval_dataset(
    dataset_name: str = DEFAULT_EVAL_DATASET,
    config: str = DEFAULT_EVAL_CONFIG,
    split: str = "validation",
    tokenizer: AutoTokenizer = None,
    seq_length: int = DEFAULT_SEQ_LENGTH,
):
    """Load and tokenize the evaluation dataset.

    Tokenizes the whole split into one contiguous token stream, then cuts
    it into non-overlapping windows of ``seq_length`` token ids.

    Returns:
        (samples, tokenizer): ``samples`` is a list of equal-length lists of
        token ids; ``tokenizer`` is the (possibly freshly loaded) tokenizer.
    """
    if dataset_name == "wikitext":
        ds = load_dataset(dataset_name, config, trust_remote_code=True)[split]
    elif dataset_name == "ptb":
        # NOTE(review): "ptb" is not an obvious HF hub dataset id — verify
        # this branch actually resolves (the hub id may be "ptb_text_only").
        ds = load_dataset(dataset_name, trust_remote_code=True)[split]
    else:
        ds = load_dataset(dataset_name, split=split, trust_remote_code=True)

    if tokenizer is None:
        tokenizer = load_tokenizer()

    # For eval, we want contiguous text without chunking artifacts
    texts = ds["text"] if "text" in ds.column_names else [str(x) for x in ds]
    all_tokens = []
    for text in texts:
        tokens = tokenizer.encode(text, add_special_tokens=False)
        all_tokens.extend(tokens)

    # Remove empty tokens
    # NOTE(review): encode() normally yields ints, so this filter is likely
    # a no-op — kept as-is; confirm before removing.
    all_tokens = [t for t in all_tokens if t is not None]

    # Create sequences (non-overlapping windows; trailing remainder dropped)
    samples = []
    for i in range(0, len(all_tokens) - seq_length + 1, seq_length):
        samples.append(all_tokens[i : i + seq_length])

    return samples, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Evaluation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@torch.no_grad()
def evaluate_ppl(
    model: AutoModelForCausalLM,
    eval_samples: list,
    tokenizer: AutoTokenizer,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    batch_size: int = 4,
    device: str = "cpu",
) -> float:
    """
    Evaluate perplexity on the given samples.

    Args:
        model: causal LM that returns ``.loss`` when called with ``labels=``.
        eval_samples: list of equal-length lists of token ids.
        tokenizer: unused here; kept for interface compatibility.
        seq_length: unused here; kept for interface compatibility.
        batch_size: sequences per forward pass.
        device: "cpu", "cuda", or "mps".

    Returns perplexity (lower is better); ``inf`` if eval_samples is empty.
    """
    model.eval()
    total_loss = 0.0
    total_seqs = 0  # number of sequences (model.loss is a per-batch mean)

    for i in range(0, len(eval_samples), batch_size):
        batch_samples = eval_samples[i : i + batch_size]
        inputs = torch.tensor(batch_samples, dtype=torch.long, device=device)

        # BUG FIX: HF causal-LM models shift labels internally (loss pairs
        # logits[t] with labels[t+1]). The original pre-shifted the tensors
        # AND passed the shifted labels, which shifted twice and scored the
        # model on predicting two tokens ahead — inflating PPL. Pass the
        # full sequence as both input_ids and labels; the model shifts once.
        device_type = "mps" if device == "mps" else ("cuda" if device == "cuda" else "cpu")
        if device_type != "cpu":
            with torch.amp.autocast(device_type=device_type):
                outputs = model(input_ids=inputs, labels=inputs, use_cache=False)
        else:
            outputs = model(input_ids=inputs, labels=inputs, use_cache=False)

        # Weight the batch-mean loss by the number of sequences in the batch
        # so the final average is a true per-sequence mean.
        total_loss += outputs.loss.item() * inputs.size(0)
        total_seqs += inputs.size(0)

    # Robustness: avoid ZeroDivisionError on an empty eval set.
    if total_seqs == 0:
        model.train()
        return float("inf")

    avg_loss = total_loss / total_seqs
    perplexity = math.exp(avg_loss)

    model.train()
    return perplexity
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def evaluate_bpb(
    model: AutoModelForCausalLM,
    eval_samples: list,
    tokenizer: AutoTokenizer,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    batch_size: int = 4,
    device: str = "cpu",
) -> float:
    """
    Evaluate bits-per-byte (bpb) on the given samples.
    Returns bpb (lower is better).

    NOTE(review): log2(ppl) is bits-per-TOKEN, not true bits-per-byte —
    a real bpb needs the token-to-byte ratio of the eval text. The value
    is still a valid lower-is-better metric; confirm downstream consumers
    (e.g. the results.tsv "val_bpb" column) expect this definition.
    """
    # Delegates the forward passes to evaluate_ppl and converts the result.
    ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
    bpb = math.log2(ppl)
    return bpb
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Training Dataset Loader
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def load_train_dataset(
    dataset_name: str = DEFAULT_TRAIN_DATASET,
    split: str = "train",
    tokenizer: AutoTokenizer = None,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    max_samples: int = None,
    streaming: bool = True,
):
    """
    Load and prepare the training dataset.

    Args:
        dataset_name: HF hub id, or a local .json/.jsonl path.
        split: dataset split to load.
        tokenizer: tokenizer to use; loaded fresh when None.
        seq_length: chunk length for TokenizedDataset.
        max_samples: cap on the number of documents read (None = all).
        streaming: stream the dataset instead of downloading it fully.

    Returns a TokenizedDataset.

    Raises:
        ValueError: if no documents could be read from the dataset.
    """
    if tokenizer is None:
        tokenizer = load_tokenizer()

    if dataset_name == "HuggingFaceFW/fineweb-edu":
        # Use the edu-only subset for quality
        ds = load_dataset(
            dataset_name,
            name="sample-100BT",  # 100B token subset for faster iteration
            split=split,
            streaming=streaming,
            trust_remote_code=True,
        )
    elif dataset_name.endswith(".jsonl") or dataset_name.endswith(".json"):
        ds = load_dataset("json", data_files=dataset_name, split=split, streaming=streaming)
    else:
        ds = load_dataset(dataset_name, split=split, streaming=streaming, trust_remote_code=True)

    # Collect texts (handle streaming by taking first N samples)
    records = []
    for idx, item in enumerate(ds):
        if max_samples and idx >= max_samples:
            break
        text = item.get("text", item.get("content", str(item)))
        # BUG FIX: TokenizedDataset._prepare_samples calls .get() on each
        # item, so passing raw strings crashed with AttributeError. Wrap
        # each text in a dict with the "text" key it expects.
        records.append({"text": text})

    if not records:
        raise ValueError(f"No texts loaded from dataset {dataset_name}")

    dataset = TokenizedDataset(records, tokenizer, seq_length)
    return dataset
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Model Loading
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def load_model(
    model_name: str = DEFAULT_MODEL,
    device: str = "cpu",
    dtype: torch.dtype = torch.float16,
    attn_implementation: str = "sdpa",
) -> AutoModelForCausalLM:
    """Fetch a pretrained causal LM from the HuggingFace hub.

    On accelerator devices the weights are placed via ``device_map``; on
    CPU they are loaded normally and moved explicitly afterwards.
    """
    placement = None if device == "cpu" else device
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        attn_implementation=attn_implementation,
        trust_remote_code=True,
        device_map=placement,
    )
    if device == "cpu":
        model = model.to(device)
    return model
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CLI Entry Point
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import argparse

    # Smoke-test CLI: loads tokenizer + datasets and (optionally) the model,
    # then reports sample counts and a baseline perplexity.
    parser = argparse.ArgumentParser(description="Prepare and evaluate ternary quantization datasets")
    parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name or path")
    parser.add_argument("--eval-dataset", default=DEFAULT_EVAL_DATASET, help="Evaluation dataset")
    parser.add_argument("--eval-config", default=DEFAULT_EVAL_CONFIG, help="Evaluation dataset config")
    parser.add_argument("--seq-length", type=int, default=DEFAULT_SEQ_LENGTH, help="Sequence length")
    parser.add_argument("--batch-size", type=int, default=4, help="Evaluation batch size")
    parser.add_argument("--device", default="cpu", help="Device to run on")
    parser.add_argument("--train-samples", type=int, default=1000, help="Max training samples to load")
    parser.add_argument("--dry-run", action="store_true", help="Just print config without loading")

    args = parser.parse_args()

    if args.dry_run:
        # Print the resolved configuration and exit without touching the
        # network or loading any weights.
        print(f"Model: {args.model}")
        print(f"Eval dataset: {args.eval_dataset} ({args.eval_config})")
        print(f"Seq length: {args.seq_length}")
        print(f"Batch size: {args.batch_size}")
        print(f"Device: {args.device}")
        print(f"Train samples: {args.train_samples}")
        exit(0)

    print(f"Loading tokenizer for {args.model}...")
    tokenizer = load_tokenizer(args.model)

    print(f"Loading eval dataset {args.eval_dataset}...")
    eval_samples, _ = load_eval_dataset(
        dataset_name=args.eval_dataset,
        config=args.eval_config,
        tokenizer=tokenizer,
        seq_length=args.seq_length,
    )
    print(f"  Eval samples: {len(eval_samples)}")

    print(f"Loading training dataset ({args.train_samples} samples)...")
    train_dataset = load_train_dataset(
        tokenizer=tokenizer,
        seq_length=args.seq_length,
        max_samples=args.train_samples,
        streaming=True,
    )
    print(f"  Train samples: {len(train_dataset)}")

    # Optional: load model and run eval
    # NOTE(review): the sentinel "skip-model" is smuggled through --device,
    # so it doubles as a device name — confirm this overload is intentional.
    if args.device != "skip-model":
        print(f"Loading model {args.model}...")
        model = load_model(args.model, device=args.device)
        print(f"  Model params: {sum(p.numel() for p in model.parameters()):,}")

        print(f"Evaluating perplexity...")
        ppl = evaluate_ppl(model, eval_samples, tokenizer, args.seq_length, args.batch_size, args.device)
        print(f"  Perplexity: {ppl:.2f}")
        # NOTE(review): log2(ppl) is bits-per-token, not true bits-per-byte.
        bpb = math.log2(ppl)
        print(f"  BPB: {bpb:.4f}")
||||||
+78
@@ -0,0 +1,78 @@
|
|||||||
|
# program.md — Instructions for the Autoresearch Agent
|
||||||
|
|
||||||
|
You are an autonomous research agent exploring ternary (1.58-bit) quantization for LLMs.
|
||||||
|
|
||||||
|
## Your Goal
|
||||||
|
|
||||||
|
Iteratively improve the ternary quantization training in `train.py` to achieve the **lowest validation perplexity (val_ppl)** on WikiText-2.
|
||||||
|
|
||||||
|
## File Boundaries
|
||||||
|
|
||||||
|
- **MUTABLE**: `train.py` — You may modify this file. It contains:
|
||||||
|
- `BitLinear` layer (quantization logic)
|
||||||
|
- Quantization schedules (lambda warmup)
|
||||||
|
- Training loop
|
||||||
|
- Hyperparameters (LR, batch size, group size, etc.)
|
||||||
|
|
||||||
|
- **READ-ONLY**: `prepare.py` — DO NOT modify this file. It contains:
|
||||||
|
- Dataset loading and tokenization
|
||||||
|
- Evaluation harness (WikiText PPL)
|
||||||
|
- Model loading utilities
|
||||||
|
|
||||||
|
- **OUTPUT**: `results.tsv` — Results are automatically logged here after each run.
|
||||||
|
|
||||||
|
## Experiment Protocol
|
||||||
|
|
||||||
|
1. Read `results.tsv` to understand what has been tried
|
||||||
|
2. Read `train.py` to understand current implementation
|
||||||
|
3. Propose ONE focused change to `train.py`
|
||||||
|
4. The change will be committed and a training run will execute (~5 minutes)
|
||||||
|
5. Results are logged to `results.tsv`
|
||||||
|
6. If improved (lower val_ppl) → change is kept
|
||||||
|
7. If equal or worse → git reset to previous commit
|
||||||
|
|
||||||
|
## What to Explore
|
||||||
|
|
||||||
|
### Priority 1: Quantization Schedule
|
||||||
|
- Lambda warmup shape: linear, cosine, exponential
|
||||||
|
- Warmup step counts: 200, 500, 1000, 2000, 5000
|
||||||
|
- Two-phase warmup (fast initial + slow final)
|
||||||
|
|
||||||
|
### Priority 2: Learning Rate
|
||||||
|
- LR values: 1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4
|
||||||
|
- LR schedule: constant, linear decay, cosine decay
|
||||||
|
|
||||||
|
### Priority 3: Quantization Details
|
||||||
|
- Group size: 32, 64, 128, 256, per-tensor
|
||||||
|
- Scale initialization: mean-based vs absmax-based
|
||||||
|
- Ternary threshold adjustments
|
||||||
|
|
||||||
|
### Priority 4: Deadzone Recovery
|
||||||
|
- Tequila-style reactivation (learnable lambda for deadzone weights)
|
||||||
|
- Bias injection for zero-valued weights
|
||||||
|
- Gradient scaling for deadzone weights
|
||||||
|
|
||||||
|
### Priority 5: Distillation
|
||||||
|
- OFF loss (cosine similarity between FP and ternary features)
|
||||||
|
- Logits distillation weight
|
||||||
|
- Feature distillation weight
|
||||||
|
|
||||||
|
## Constraints
|
||||||
|
|
||||||
|
- Keep experiments focused — ONE change per iteration
|
||||||
|
- Always maintain working code — syntax errors waste time
|
||||||
|
- Use SmolLM-135M or Llama-3.2-1B for fast iteration
|
||||||
|
- Target metric: val_ppl (lower is better)
|
||||||
|
- Time budget: 5 minutes per experiment
|
||||||
|
|
||||||
|
## Important Notes
|
||||||
|
|
||||||
|
- The STE (Straight-Through Estimator) is critical for gradients to flow through quantization
|
||||||
|
- Warmup quantization prevents catastrophic accuracy loss at the start of training
|
||||||
|
- Deadzone trapping (weights stuck at 0) is a known problem — explore solutions
|
||||||
|
- Per-group quantization scales are essential for handling outlier weights
|
||||||
|
- The quantization formula: `scale = 1.0 / w.abs().mean(); w_q = round(clamp(w * scale, -1, 1))`
|
||||||
|
|
||||||
|
## NEVER STOP
|
||||||
|
|
||||||
|
Run experiments continuously until manually interrupted. Each experiment should be a small, focused change. Review results.tsv between runs to inform your next decision.
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
timestamp commit status val_ppl val_bpb initial_ppl best_ppl best_step steps time_s model lr warmup_steps total_steps lambda_schedule group_size batch_size seq_length description
|
||||||
|
@@ -0,0 +1,664 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
train.py — Ternary quantization training loop (MUTABLE in autoresearch loop).
|
||||||
|
|
||||||
|
This is the single mutable file for the autoresearch agent. The agent may modify:
|
||||||
|
- Quantization schedule (lambda warmup shape, length)
|
||||||
|
- Learning rate and schedule
|
||||||
|
- BitLinear implementation details (group size, deadzone recovery)
|
||||||
|
- Loss function (distillation weights, OFF loss)
|
||||||
|
- Training hyperparameters
|
||||||
|
|
||||||
|
DO NOT modify prepare.py — all data loading and evaluation lives there.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import csv
|
||||||
|
import git
|
||||||
|
import random
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM, get_linear_schedule_with_warmup
|
||||||
|
|
||||||
|
# Import from prepare.py (read-only)
|
||||||
|
from prepare import (
|
||||||
|
load_tokenizer,
|
||||||
|
load_train_dataset,
|
||||||
|
load_eval_dataset,
|
||||||
|
load_model,
|
||||||
|
evaluate_ppl,
|
||||||
|
evaluate_bpb,
|
||||||
|
DEFAULT_MODEL,
|
||||||
|
DEFAULT_SEQ_LENGTH,
|
||||||
|
DEFAULT_TRAIN_DATASET,
|
||||||
|
DEFAULT_EVAL_DATASET,
|
||||||
|
DEFAULT_EVAL_CONFIG,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Quantization Utilities
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def ternary_quantize_weight(w: torch.Tensor, group_size: int = 128) -> tuple:
    """
    Ternary quantize weights to {-1, 0, +1} with per-group scales.

    Quantization formula (from HF blog):
        scale = 1.0 / w.abs().mean()
        w_scaled = w * scale
        w_clamped = clamp(w_scaled, -1, 1)
        w_rounded = round(w_clamped)  # -> {-1, 0, +1}

    Args:
        w: weight tensor of any shape (must be non-empty).
        group_size: number of flattened elements sharing one scale.

    Returns:
        w_quant: ternary quantized weights, same shape as w (float32 values
                 in {-1, 0, +1}).
        scales: 1-D tensor of per-group scales (in w.dtype) for dequantization.
    """
    original_shape = w.shape
    # Flatten for group-wise processing
    w_flat = w.flatten().float()
    n = w_flat.numel()
    num_groups = (n + group_size - 1) // group_size

    # PERF: the original looped over groups in Python (O(num_groups)
    # interpreter overhead). Vectorize by zero-padding the flat tensor to a
    # (num_groups, group_size) rectangle and doing all group math at once.
    padded = torch.zeros(num_groups * group_size, dtype=w_flat.dtype, device=w_flat.device)
    padded[:n] = w_flat
    groups = padded.view(num_groups, group_size)

    # Per-group element counts: the final group may be short, and the
    # padding zeros must not dilute its mean(|w|).
    counts = torch.full((num_groups,), float(group_size), dtype=w_flat.dtype, device=w_flat.device)
    counts[-1] = float(n - (num_groups - 1) * group_size)

    # scale = 1 / (mean(|w|) + eps); eps guards all-zero groups.
    mean_abs = groups.abs().sum(dim=1) / counts
    group_scales = 1.0 / (mean_abs + 1e-6)

    # Scale, clamp, round -> {-1, 0, +1}. (round after clamp(-1, 1) already
    # yields ternary values, so the original's trailing .sign() was redundant.)
    quant = (groups * group_scales.unsqueeze(1)).clamp(-1.0, 1.0).round()
    w_quant = quant.flatten()[:n].reshape(original_shape)
    scales = group_scales.to(device=w.device, dtype=w.dtype)

    return w_quant, scales
|
||||||
|
|
||||||
|
|
||||||
|
def ternary_dequantize(w_quant: torch.Tensor, scales: torch.Tensor, group_size: int = 128) -> torch.Tensor:
    """Dequantize ternary weights back to original shape.

    Inverts ternary_quantize_weight: each flattened group of ``group_size``
    elements is divided by its per-group scale.

    Args:
        w_quant: ternary tensor produced by ternary_quantize_weight.
        scales: 1-D tensor of per-group scales (one per group).
        group_size: group size used during quantization.

    Returns:
        Dequantized tensor with w_quant's shape and dtype.
    """
    original_shape = w_quant.shape
    w_flat = w_quant.flatten().float()
    n = w_flat.numel()
    num_groups = (n + group_size - 1) // group_size

    # PERF: vectorized replacement for the original Python per-group loop —
    # pad to a rectangle, divide row-wise by the scales, then un-pad.
    padded = torch.zeros(num_groups * group_size, dtype=w_flat.dtype, device=w_flat.device)
    padded[:n] = w_flat
    groups = padded.view(num_groups, group_size)

    # Same epsilon as the original loop. NOTE: quantize already folds an
    # epsilon into the scale, so round-trips are approximate, not exact.
    dequant = groups / (scales.float().unsqueeze(1) + 1e-6)

    return dequant.flatten()[:n].reshape(original_shape).to(w_quant.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# BitLinear Layer
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class BitLinear(nn.Module):
    """
    Linear layer with ternary weight quantization using STE (Straight-Through Estimator).

    Forward pass uses quantized weights; backward pass uses STE to pass gradients
    through the quantization operation.

    Supports warmup quantization: gradually introduce quantization during training
    using lambda scheduling to avoid catastrophic accuracy loss.

    NOTE(review): the buffers registered in __init__ ("w_quant", "scales",
    "lambda") are never read or written anywhere in this class — forward()
    recomputes the quantization on every call and reads the plain attribute
    ``lambda_val`` instead. Moreover, "lambda" is a Python keyword, so that
    buffer can never be accessed as ``self.lambda`` (only via get_buffer).
    They appear to be dead state that only bloats the state_dict — confirm
    nothing external relies on them before removing.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        group_size: int = 128,
        init_scale: float = 1.0,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.group_size = group_size
        # init_scale is stored but not used anywhere in this class.
        self.init_scale = init_scale

        # FP16 weights (the "learned" weights)
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.bias = None

        # Quantization state (see class NOTE: currently unused by forward()).
        self.register_buffer("w_quant", torch.zeros(out_features, in_features))
        self.register_buffer("scales", torch.zeros((out_features * in_features + group_size - 1) // group_size))
        self.register_buffer("lambda", torch.tensor(0.0))  # warmup lambda (dead — see class NOTE)

    def _quantize(self, w: torch.Tensor) -> tuple:
        """Quantize weights to ternary."""
        return ternary_quantize_weight(w, self.group_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with warmup quantization.

        Lambda is set externally via self.lambda_val before each forward pass.

        Quantization warmup formula:
            w_eff = w + lambda * (w_dequant - w).detach()

        When lambda=0: output uses full-precision weights
        When lambda=1: output uses quantized weights (with STE gradients)
        """
        # Get lambda (default 1.0 if not set)
        # NOTE(review): if the training loop never assigns lambda_val, this
        # default means every forward is fully quantized from step 0 (no warmup).
        lambda_ = getattr(self, 'lambda_val', 1.0)

        # Quantize weights (recomputed fresh on every forward call)
        w_quant, scales = self._quantize(self.weight)

        # Dequantize for computation
        w_dequant = ternary_dequantize(w_quant, scales, self.group_size)

        # STE with warmup: interpolate between FP and quantized
        # w_eff = w + lambda * (w_dequant - w)
        # Gradient flows through w (the FP weights), not through quantization
        if lambda_ > 0:
            w_eff = self.weight + lambda_ * (w_dequant - self.weight).detach()
        else:
            w_eff = self.weight

        # Linear forward (cast input to the effective weight dtype first)
        x = x.to(w_eff.dtype)
        out = F.linear(x, w_eff, self.bias)
        return out

    def extra_repr(self) -> str:
        # Shown by print(model) / repr() for quick inspection of layer config.
        return (
            f"in_features={self.in_features}, "
            f"out_features={self.out_features}, "
            f"group_size={self.group_size}, "
            f"bias={self.bias is not None}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Model Conversion
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def replace_linears_with_bitlinear(
    model: AutoModelForCausalLM,
    group_size: int = 128,
    target_modules: list = None,
) -> AutoModelForCausalLM:
    """
    Replace nn.Linear layers in the model with BitLinear layers (in place).

    Args:
        model: pretrained model
        group_size: quantization group size
        target_modules: list of module names to replace. Defaults to the
            standard attention/MLP projection names of Llama-style models.

    Returns:
        The same model object, with matching Linear layers swapped for
        BitLinear layers initialized from the pretrained weights.
    """
    if target_modules is None:
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

    # BUG FIX: snapshot the targets first — the original called setattr()
    # while still iterating named_modules(), mutating the module tree
    # mid-traversal.
    to_replace = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear) and name.split(".")[-1] in target_modules
    ]

    for name, module in to_replace:
        parent_name, _, module_name = name.rpartition(".")

        bit_linear = BitLinear(
            in_features=module.in_features,
            out_features=module.out_features,
            bias=module.bias is not None,
            group_size=group_size,
        )
        # Initialize from pretrained weights
        bit_linear.weight.data = module.weight.data.clone()
        if module.bias is not None:
            bit_linear.bias.data = module.bias.data.clone()

        # BUG FIX: the original getattr-walk broke for top-level children
        # (parent_name == "" made it call getattr(model, "")). Use
        # get_submodule, and fall back to the model itself at the root.
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, module_name, bit_linear)

    return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_bitlinear_modules(model: AutoModelForCausalLM) -> list:
    """Collect every BitLinear layer contained in *model*, in traversal order."""
    found = []
    for module in model.modules():
        if isinstance(module, BitLinear):
            found.append(module)
    return found
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Lambda Warmup Schedules
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def linear_warmup_lambda(step: int, warmup_steps: int) -> float:
    """Linear warmup: lambda ramps from 0 to 1 over warmup_steps.

    Args:
        step: current training step (>= 0).
        warmup_steps: number of steps over which to ramp up.

    Returns:
        min(step / warmup_steps, 1.0), or 1.0 when warmup_steps <= 0.
        The guard avoids a ZeroDivisionError and matches the cosine and
        exponential schedules, which both return 1.0 once warmup is over
        (including a zero-length warmup).
    """
    if warmup_steps <= 0:
        return 1.0
    return min(step / warmup_steps, 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_warmup_lambda(step: int, warmup_steps: int) -> float:
    """Cosine warmup: lambda follows a half-cosine from 0 to 1 over warmup_steps."""
    if step < warmup_steps:
        # Half-period cosine: 0 at step 0, 1 at step == warmup_steps.
        phase = math.pi * step / warmup_steps
        return 0.5 * (1 - math.cos(phase))
    return 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def exponential_warmup_lambda(step: int, warmup_steps: int) -> float:
    """Exponential warmup: lambda grows as a saturating exponential over warmup_steps."""
    if step < warmup_steps:
        # 1 - e^(-3t): starts at 0, reaches ~0.95 by end of warmup,
        # then snaps to exactly 1.0 once warmup completes.
        progress = step / warmup_steps
        return 1.0 - math.exp(-3.0 * progress)
    return 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def get_lambda(step: int, warmup_steps: int, schedule: str = "linear") -> float:
    """Get lambda value for the given step and schedule.

    Dispatches to one of the warmup schedules; raises ValueError for an
    unrecognized schedule name.
    """
    dispatch = {
        "linear": linear_warmup_lambda,
        "cosine": cosine_warmup_lambda,
        "exponential": exponential_warmup_lambda,
    }
    if schedule not in dispatch:
        raise ValueError(f"Unknown lambda schedule: {schedule}")
    return dispatch[schedule](step, warmup_steps)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Training Loop
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def train(
    model_name: str = DEFAULT_MODEL,
    train_dataset_name: str = DEFAULT_TRAIN_DATASET,
    eval_dataset_name: str = DEFAULT_EVAL_DATASET,
    eval_config: str = DEFAULT_EVAL_CONFIG,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    batch_size: int = 4,
    learning_rate: float = 1e-4,
    warmup_steps: int = 1000,
    total_steps: int = 5000,
    eval_every: int = 500,
    lambda_schedule: str = "linear",
    group_size: int = 128,
    max_train_samples: int = 50000,
    device: str = "cpu",
    dtype: torch.dtype = torch.float16,
    time_budget: float = 300.0,  # seconds (5 minutes default)
    seed: int = 42,
    output_dir: str = "./results",
) -> dict:
    """
    Main training loop for ternary quantization-aware training.

    Args:
        model_name: HuggingFace model name or path
        train_dataset_name: Training dataset name
        eval_dataset_name: Evaluation dataset name
        eval_config: Evaluation dataset config
        seq_length: Sequence length for training
        batch_size: Training batch size
        learning_rate: Base learning rate
        warmup_steps: Number of steps for lambda warmup
        total_steps: Total training steps
        eval_every: Evaluate every N steps
        lambda_schedule: Lambda warmup schedule ("linear", "cosine", "exponential")
        group_size: Quantization group size
        max_train_samples: Max training samples to load
        device: Device to train on ("auto" picks mps/cuda/cpu)
        dtype: Data type
        time_budget: Max training time in seconds
        seed: Random seed
        output_dir: Output directory for results

    Returns:
        dict with training metrics
    """
    # NOTE(review): output_dir is accepted but never used anywhere in this
    # function body — confirm whether checkpointing was intended.

    # Auto-detect device if not specified
    if device == "auto":
        if torch.backends.mps.is_available():
            device = "mps"
            print("Using MPS (Apple Silicon)")
        elif torch.cuda.is_available():
            device = "cuda"
            print("Using CUDA")
        else:
            device = "cpu"
            print("Using CPU")

    # Set seeds
    random.seed(seed)
    torch.manual_seed(seed)
    if device == "cuda":
        torch.cuda.manual_seed(seed)

    # Track time (wall clock governs the time_budget check below)
    start_time = time.time()

    print(f"=== Ternary QAT Training ===")
    print(f"Model: {model_name}")
    print(f"Device: {device}")
    print(f"LR: {learning_rate}, Warmup: {warmup_steps} steps")
    print(f"Lambda schedule: {lambda_schedule}")
    print(f"Group size: {group_size}")
    print(f"Time budget: {time_budget}s")
    print(f"Total steps: {total_steps}")
    print()

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = load_tokenizer(model_name)

    # Load model
    print(f"Loading model {model_name}...")
    model = load_model(model_name, device=device, dtype=dtype, attn_implementation="sdpa")
    # Explicit cast; presumably load_model already honors dtype — this makes it certain.
    model = model.to(dtype)
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f" Linear layers replaced: ", end="")

    # Replace Linear with BitLinear (in place; count printed below)
    model = replace_linears_with_bitlinear(model, group_size=group_size)
    bitlinears = get_bitlinear_modules(model)
    print(f"{len(bitlinears)}")

    # Load datasets
    print(f"Loading training dataset ({max_train_samples} samples)...")
    train_dataset = load_train_dataset(
        dataset_name=train_dataset_name,
        tokenizer=tokenizer,
        seq_length=seq_length,
        max_samples=max_train_samples,
        streaming=True,
    )
    print(f" Train samples: {len(train_dataset)}")

    print(f"Loading eval dataset...")
    eval_samples, _ = load_eval_dataset(
        dataset_name=eval_dataset_name,
        config=eval_config,
        tokenizer=tokenizer,
        seq_length=seq_length,
    )
    print(f" Eval samples: {len(eval_samples)}")

    # Initial evaluation (baseline PPL before any QAT updates)
    print("\n--- Initial Evaluation ---")
    initial_ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
    # NOTE(review): log2(ppl) is bits per *token*; the "BPB" (bits per byte)
    # naming assumes one byte per token — confirm intended metric.
    initial_bpb = math.log2(initial_ppl)
    print(f" Initial PPL: {initial_ppl:.2f}")
    print(f" Initial BPB: {initial_bpb:.4f}")

    # Optimizer
    # Only optimize BitLinear weights and biases (not the quantization buffers)
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=learning_rate,
        weight_decay=0.01,
    )

    # Learning rate scheduler (LR warmup is 10% of total steps — independent
    # of the lambda warmup controlled by warmup_steps)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(total_steps * 0.1),
        num_training_steps=total_steps,
    )

    # DataLoader
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    loader_iter = iter(train_loader)

    # Training loop results; best_ppl starts at the pre-training baseline so
    # "best" only updates if QAT actually improves on it.
    results = {
        "initial_ppl": initial_ppl,
        "initial_bpb": initial_bpb,
        "final_ppl": None,
        "final_bpb": None,
        "best_ppl": initial_ppl,
        "best_step": 0,
        "steps_completed": 0,
        "training_time": 0,
    }

    step = 0
    train_loss = 0.0
    steps_since_last_eval = 0

    print(f"\n--- Training ---")

    while step < total_steps:
        # Check time budget
        elapsed = time.time() - start_time
        if elapsed >= time_budget - 30:  # Reserve 30s for final eval
            print(f"\n Time budget nearly reached ({elapsed:.0f}/{time_budget}s). Stopping.")
            break

        # Get lambda for this step
        lambda_ = get_lambda(step, warmup_steps, lambda_schedule)

        # Get batch; restart the iterator when the epoch is exhausted
        try:
            batch = next(loader_iter)
        except StopIteration:
            loader_iter = iter(train_loader)
            batch = next(loader_iter)

        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)

        # Set lambda on all BitLinear layers for this step
        # (side-channel: the layers read lambda_val during forward)
        for bl in get_bitlinear_modules(model):
            bl.lambda_val = lambda_

        # Forward pass
        outputs = model(input_ids=input_ids, labels=labels, use_cache=False)

        loss = outputs.loss
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        train_loss += loss.item()
        step += 1
        steps_since_last_eval += 1

        # Periodic logging; train_loss is reset below so /100 averages
        # exactly the last 100 steps
        if step % 100 == 0:
            avg_loss = train_loss / 100
            avg_ppl = math.exp(avg_loss)
            elapsed = time.time() - start_time
            print(f" Step {step}/{total_steps} | Loss: {avg_loss:.4f} | PPL: {avg_ppl:.2f} | λ: {lambda_:.3f} | LR: {lr_scheduler.get_last_lr()[0]:.2e} | Time: {elapsed:.0f}s")
            train_loss = 0.0

        # Evaluation
        if steps_since_last_eval >= eval_every or step == total_steps:
            ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
            bpb = math.log2(ppl)
            elapsed = time.time() - start_time

            print(f"\n [Eval] Step {step} | PPL: {ppl:.2f} | BPB: {bpb:.4f} | λ: {lambda_:.3f} | Time: {elapsed:.0f}s")

            if ppl < results["best_ppl"]:
                results["best_ppl"] = ppl
                results["best_step"] = step
                print(f" *** New best PPL! ***")

            steps_since_last_eval = 0

    # Final evaluation (always runs, even after a time-budget break)
    print(f"\n--- Final Evaluation ---")
    final_ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
    final_bpb = math.log2(final_ppl)
    results["final_ppl"] = final_ppl
    results["final_bpb"] = final_bpb
    results["steps_completed"] = step
    results["training_time"] = time.time() - start_time

    print(f" Final PPL: {final_ppl:.2f}")
    print(f" Final BPB: {final_bpb:.4f}")
    print(f" Best PPL: {results['best_ppl']:.2f} (step {results['best_step']})")
    print(f" Steps completed: {step}")
    print(f" Training time: {results['training_time']:.0f}s")

    return results
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Results Logging
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def log_result(results: dict, config: dict, results_file: str = "results.tsv"):
    """Append one experiment row to a TSV results file.

    Writes the header row the first time the file is created. Metrics that
    are missing OR present-but-None (e.g. a run that failed before the
    final eval) are logged as 0 so the float format specs never raise and
    the row stays well-formed for downstream parsers (autoresearch.sh cuts
    columns by index).

    Args:
        results: metrics dict as returned by train() (may be empty).
        config: run configuration (commit, status, hyperparameters, ...).
        results_file: path of the TSV file to append to.
    """

    def _metric(key: str, spec: str) -> str:
        # results.get(key, 0) is not enough: a key present with value None
        # would make format() raise TypeError.
        value = results.get(key)
        return format(value if value is not None else 0, spec)

    file_exists = os.path.exists(results_file)

    with open(results_file, "a") as f:
        if not file_exists:
            # Write header
            headers = ["timestamp", "commit", "status", "val_ppl", "val_bpb",
                       "initial_ppl", "best_ppl", "best_step", "steps", "time_s",
                       "model", "lr", "warmup_steps", "total_steps", "lambda_schedule",
                       "group_size", "batch_size", "seq_length", "description"]
            f.write("\t".join(headers) + "\n")

        # Tabs/newlines inside the free-text description would corrupt the
        # TSV; collapse all whitespace runs to single spaces.
        description = " ".join(str(config.get("description", "")).split())

        row = [
            time.strftime("%Y-%m-%d %H:%M:%S"),
            config.get("commit", ""),
            config.get("status", "success"),
            _metric("final_ppl", ".4f"),
            _metric("final_bpb", ".4f"),
            _metric("initial_ppl", ".4f"),
            _metric("best_ppl", ".4f"),
            str(results.get("best_step", 0)),
            str(results.get("steps_completed", 0)),
            _metric("training_time", ".1f"),
            config.get("model", ""),
            str(config.get("lr", "")),
            str(config.get("warmup_steps", "")),
            str(config.get("total_steps", "")),
            config.get("lambda_schedule", ""),
            str(config.get("group_size", "")),
            str(config.get("batch_size", "")),
            str(config.get("seq_length", "")),
            description,
        ]
        f.write("\t".join(row) + "\n")

    print(f"\nResult logged to {results_file}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_commit() -> str:
    """Return the short (8-char) hash of the current git HEAD.

    Returns:
        The first 8 hex characters of HEAD, or "nogit" when no repository
        can be found or the git bindings fail for any reason.
    """
    try:
        repo = git.Repo(search_parent_directories=True)
        return repo.head.object.hexsha[:8]
    except Exception:
        # A bare `except:` would also swallow KeyboardInterrupt/SystemExit;
        # Exception is broad enough for "anything git-related went wrong".
        return "nogit"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CLI Entry Point
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # CLI entry point: parse hyperparameters, run one training experiment,
    # and append the outcome (success or failure) to results.tsv.
    parser = argparse.ArgumentParser(description="Ternary QAT Training")
    parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name")
    parser.add_argument("--train-dataset", default=DEFAULT_TRAIN_DATASET, help="Training dataset")
    parser.add_argument("--eval-dataset", default=DEFAULT_EVAL_DATASET, help="Eval dataset")
    parser.add_argument("--eval-config", default=DEFAULT_EVAL_CONFIG, help="Eval dataset config")
    parser.add_argument("--seq-length", type=int, default=DEFAULT_SEQ_LENGTH, help="Sequence length")
    parser.add_argument("--batch-size", type=int, default=4, help="Batch size")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--warmup-steps", type=int, default=1000, help="Lambda warmup steps")
    parser.add_argument("--total-steps", type=int, default=5000, help="Total training steps")
    parser.add_argument("--eval-every", type=int, default=500, help="Eval every N steps")
    parser.add_argument("--lambda-schedule", default="linear", choices=["linear", "cosine", "exponential"], help="Lambda schedule")
    parser.add_argument("--group-size", type=int, default=128, help="Quantization group size")
    parser.add_argument("--max-samples", type=int, default=50000, help="Max training samples")
    parser.add_argument("--device", default="auto", help="Device (auto=cuda/mps/cpu)")
    parser.add_argument("--time-budget", type=float, default=300.0, help="Time budget in seconds")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--description", default="", help="Experiment description")
    parser.add_argument("--dry-run", action="store_true", help="Print config without training")

    args = parser.parse_args()

    # Config is recorded verbatim into results.tsv for reproducibility.
    config = {
        "commit": get_git_commit(),
        "model": args.model,
        "lr": args.lr,
        "warmup_steps": args.warmup_steps,
        "total_steps": args.total_steps,
        "lambda_schedule": args.lambda_schedule,
        "group_size": args.group_size,
        "batch_size": args.batch_size,
        "seq_length": args.seq_length,
        "description": args.description or f"λ={args.lambda_schedule}-{args.warmup_steps} lr={args.lr} gs={args.group_size}",
    }

    if args.dry_run:
        print("=== Configuration (dry run) ===")
        for k, v in config.items():
            print(f" {k}: {v}")
        print(f" time_budget: {args.time_budget}s")
        print(f" device: {args.device}")
        # Was exit(0): the `exit` builtin comes from the `site` module and is
        # not guaranteed under all launchers; sys.exit matches the failure
        # path below.
        sys.exit(0)

    try:
        results = train(
            model_name=args.model,
            train_dataset_name=args.train_dataset,
            eval_dataset_name=args.eval_dataset,
            eval_config=args.eval_config,
            seq_length=args.seq_length,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            warmup_steps=args.warmup_steps,
            total_steps=args.total_steps,
            eval_every=args.eval_every,
            lambda_schedule=args.lambda_schedule,
            group_size=args.group_size,
            max_train_samples=args.max_samples,
            device=args.device,
            time_budget=args.time_budget,
            seed=args.seed,
        )

        config["status"] = "success"
        log_result(results, config)

    except Exception as e:
        print(f"\n!!! Training failed: {e}")
        import traceback
        traceback.print_exc()

        # Log the failure so the autoresearch loop sees status="failed";
        # the extra "error" key is not a TSV column and is ignored there.
        config["status"] = "failed"
        config["error"] = str(e)
        log_result({}, config)
        sys.exit(1)
|
||||||
Reference in New Issue
Block a user