Initial ternary quantization framework: BitLinear, QAT training loop, eval harness

This commit is contained in:
2026-04-24 03:07:55 +02:00
parent 7378d4ef8f
commit 868910b40f
8 changed files with 626 additions and 1038 deletions
Binary file not shown.
-122
View File
@@ -1,122 +0,0 @@
#!/bin/bash
# autoresearch.sh — Autonomous experiment loop
#
# Usage: ./autoresearch.sh [model] [time_budget_seconds]
#
# This script runs the autoresearch loop:
# 1. Agent proposes a change to train.py
# 2. Git commit the change
# 3. Run training for fixed time budget
# 4. Extract val_ppl from results.tsv
# 5. If improved → keep; if worse → git reset
# 6. Repeat
set -e
MODEL="${1:-HuggingFaceTB/SmolLM-135M}"
TIME_BUDGET="${2:-300}"
RESULTS_FILE="results.tsv"
# Ensure git is initialized
if [ ! -d ".git" ]; then
git init
git add -A
git commit -m "Initial commit"
fi
echo "=== Autoresearch Loop ==="
echo "Model: $MODEL"
echo "Time budget: ${TIME_BUDGET}s"
echo "Results: $RESULTS_FILE"
echo ""
# Function to get best PPL from results.tsv
get_best_ppl() {
if [ ! -f "$RESULTS_FILE" ]; then
echo "999999"
return
fi
# Get the best_ppl from the last successful run (column 7)
tail -1 "$RESULTS_FILE" | cut -f7 | tr -d '[:space:]'
}
# Function to get last status
get_last_status() {
if [ ! -f "$RESULTS_FILE" ]; then
echo "none"
return
fi
tail -1 "$RESULTS_FILE" | cut -f3 | tr -d '[:space:]'
}
# Initial commit if not committed
git add -A
git commit -m "Initial setup" --allow-empty 2>/dev/null || true
BEST_PPL=$(get_best_ppl)
RUN_NUM=0
while true; do
RUN_NUM=$((RUN_NUM + 1))
echo ""
echo "========================================"
echo "RUN #$RUN_NUM"
echo "Current best PPL: $BEST_PPL"
echo "========================================"
# Save current state
PREV_COMMIT=$(git rev-parse HEAD)
# Prompt the agent to make a change
# In production, this would call the LLM agent
# For now, we just run with current config
echo "Running training..."
# Run training
START_TIME=$(date +%s)
python3 train.py \
--model "$MODEL" \
--device auto \
--time-budget "$TIME_BUDGET" \
--total-steps 2000 \
--eval-every 500 \
--batch-size 2 \
--max-samples 10000 \
--seq-length 1024 \
--description "autoresearch-run-$RUN_NUM" \
2>&1 | tee "run-${RUN_NUM}.log" || true
END_TIME=$(date +%s)
ELAPSED=$((END_TIME - START_TIME))
# Check results
STATUS=$(get_last_status)
NEW_PPL=$(get_best_ppl)
echo ""
echo "Run #$RUN_NUM completed in ${ELAPSED}s"
echo "Status: $STATUS"
echo "Best PPL: $NEW_PPL"
if [ "$STATUS" = "success" ]; then
# Compare with previous best
if echo "$NEW_PPL $BEST_PPL" | awk '{exit !($1 < $2)}'; then
echo "IMPROVED! Keeping changes."
BEST_PPL=$NEW_PPL
git add results.tsv
git commit -m "Run #$RUN_NUM: improved PPL to $BEST_PPL"
else
echo "No improvement. Reverting."
git reset --hard $PREV_COMMIT 2>/dev/null || true
git checkout -- results.tsv 2>/dev/null || true
fi
else
echo "FAILED. Reverting."
git reset --hard $PREV_COMMIT 2>/dev/null || true
git checkout -- results.tsv 2>/dev/null || true
fi
echo ""
echo "Continuing... (Ctrl+C to stop)"
done
File diff suppressed because one or more lines are too long
+74 -294
View File
@@ -1,324 +1,104 @@
#!/usr/bin/env python3
"""
prepare.py — Data loading, tokenizer, and evaluation (READ-ONLY in autoresearch loop).
Data preparation and evaluation for ternary quantization experiments.
READ-ONLY in the autoresearch loop — train.py is the mutable file.
This file is NOT modified by the autoresearch agent. All mutable training logic
lives in train.py. This file provides:
- Dataset loading and tokenization
- Evaluation harness (WikiText PPL, optional zero-shot tasks)
- Utility functions for the training loop
Usage:
python prepare.py # download wikitext val shard
python prepare.py --num-samples 500 # smaller eval set for fast iteration
"""
import argparse
import json
import os
import math
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from pathlib import Path
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Default model — small enough for fast iteration on M4 Pro / consumer GPU
DEFAULT_MODEL = "meta-llama/Llama-3.2-1B"
# Alternative tiny models for ultra-fast iteration:
# DEFAULT_MODEL = "HuggingFaceTB/SmolLM-135M"
# Default eval dataset
DEFAULT_EVAL_DATASET = "wikitext" # "wikitext" or "ptb"
DEFAULT_EVAL_CONFIG = "wikitext-2-raw-v1"
# Default training dataset
DEFAULT_TRAIN_DATASET = "HuggingFaceFW/fineweb-edu" # FineWeb-edu
# Alternative: "monology/patrickstar-core-1000" for smaller local testing
# Sequence length for training
DEFAULT_SEQ_LENGTH = 2048
# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------
def load_tokenizer(model_name: str = DEFAULT_MODEL) -> AutoTokenizer:
"""Load tokenizer for the given model."""
tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True,
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
DATASETS_DIR = Path(__file__).parent / "data"
# ---------------------------------------------------------------------------
# Dataset Wrappers
# ---------------------------------------------------------------------------
def prepare_eval_data(num_samples=500):
"""Download and prepare WikiText-2 validation data for perplexity evaluation.
class TokenizedDataset(Dataset):
"""Simple dataset that yields batches of tokenized text."""
def __init__(self, data: list, tokenizer: AutoTokenizer, seq_length: int = DEFAULT_SEQ_LENGTH):
self.tokenizer = tokenizer
self.seq_length = seq_length
# Concatenate all texts and chunk into sequences
self.samples = self._prepare_samples(data)
def _prepare_samples(self, data: list) -> list:
"""Flatten and chunk text data into fixed-length sequences."""
all_tokens = []
for item in data:
text = item.get("text", item.get("content", str(item)))
tokens = self.tokenizer.encode(text, add_special_tokens=False)
all_tokens.extend(tokens)
all_tokens.append(self.tokenizer.eos_token_id)
# Chunk into sequences of seq_length
samples = []
for i in range(0, len(all_tokens) - self.seq_length + 1, self.seq_length):
samples.append(all_tokens[i : i + self.seq_length])
return samples
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
x = torch.tensor(self.samples[idx], dtype=torch.long)
return {"input_ids": x[:-1], "labels": x[1:]}
def load_eval_dataset(
dataset_name: str = DEFAULT_EVAL_DATASET,
config: str = DEFAULT_EVAL_CONFIG,
split: str = "validation",
tokenizer: AutoTokenizer = None,
seq_length: int = DEFAULT_SEQ_LENGTH,
):
"""Load and tokenize the evaluation dataset."""
if dataset_name == "wikitext":
ds = load_dataset(dataset_name, config, trust_remote_code=True)[split]
elif dataset_name == "ptb":
ds = load_dataset(dataset_name, trust_remote_code=True)[split]
else:
ds = load_dataset(dataset_name, split=split, trust_remote_code=True)
if tokenizer is None:
tokenizer = load_tokenizer()
# For eval, we want contiguous text without chunking artifacts
texts = ds["text"] if "text" in ds.column_names else [str(x) for x in ds]
all_tokens = []
for text in texts:
tokens = tokenizer.encode(text, add_special_tokens=False)
all_tokens.extend(tokens)
# Remove empty tokens
all_tokens = [t for t in all_tokens if t is not None]
# Create sequences
samples = []
for i in range(0, len(all_tokens) - seq_length + 1, seq_length):
samples.append(all_tokens[i : i + seq_length])
return samples, tokenizer
# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------
@torch.no_grad()
def evaluate_ppl(
model: AutoModelForCausalLM,
eval_samples: list,
tokenizer: AutoTokenizer,
seq_length: int = DEFAULT_SEQ_LENGTH,
batch_size: int = 4,
device: str = "cpu",
) -> float:
Saves tokenized data as a JSON file for fast loading during training.
"""
Evaluate perplexity on the given samples.
Returns perplexity (lower is better).
"""
model.eval()
total_loss = 0.0
total_tokens = 0
from datasets import load_dataset
from transformers import AutoTokenizer
for i in range(0, len(eval_samples), batch_size):
batch_samples = eval_samples[i : i + batch_size]
inputs = torch.tensor(batch_samples, dtype=torch.long, device=device)
labels = inputs.clone()
DATASETS_DIR.mkdir(parents=True, exist_ok=True)
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
# Shift for causal LM
input_ids = inputs[:, :-1]
label_ids = labels[:, 1:]
# Load wikitext test split (validation is unreliable with streaming)
print("Loading WikiText-2 test split...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test", streaming=True)
device_type = "mps" if device == "mps" else ("cuda" if device == "cuda" else "cpu")
if device_type != "cpu":
with torch.amp.autocast(device_type=device_type):
outputs = model(input_ids=input_ids, labels=label_ids, use_cache=False)
else:
outputs = model(input_ids=input_ids, labels=label_ids, use_cache=False)
loss = outputs.loss
total_loss += loss.item() * label_ids.size(0)
total_tokens += label_ids.size(0)
avg_loss = total_loss / total_tokens
perplexity = math.exp(avg_loss)
model.train()
return perplexity
@torch.no_grad()
def evaluate_bpb(
model: AutoModelForCausalLM,
eval_samples: list,
tokenizer: AutoTokenizer,
seq_length: int = DEFAULT_SEQ_LENGTH,
batch_size: int = 4,
device: str = "cpu",
) -> float:
"""
Evaluate bits-per-byte (bpb) on the given samples.
Returns bpb (lower is better).
"""
ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
bpb = math.log2(ppl)
return bpb
# ---------------------------------------------------------------------------
# Training Dataset Loader
# ---------------------------------------------------------------------------
def load_train_dataset(
dataset_name: str = DEFAULT_TRAIN_DATASET,
split: str = "train",
tokenizer: AutoTokenizer = None,
seq_length: int = DEFAULT_SEQ_LENGTH,
max_samples: int = None,
streaming: bool = True,
):
"""
Load and prepare the training dataset.
Returns a TokenizedDataset.
"""
if tokenizer is None:
tokenizer = load_tokenizer()
if dataset_name == "HuggingFaceFW/fineweb-edu":
# Use the edu-only subset for quality
ds = load_dataset(
dataset_name,
name="sample-100BT", # 100B token subset for faster iteration
split=split,
streaming=streaming,
trust_remote_code=True,
)
elif dataset_name.endswith(".jsonl") or dataset_name.endswith(".json"):
ds = load_dataset("json", data_files=dataset_name, split=split, streaming=streaming)
else:
ds = load_dataset(dataset_name, split=split, streaming=streaming, trust_remote_code=True)
# Collect texts (handle streaming by taking first N samples)
# Collect text into a single corpus
texts = []
for idx, item in enumerate(ds):
if max_samples and idx >= max_samples:
for i, sample in enumerate(dataset):
texts.append(sample["text"].strip())
if i + 1 >= num_samples:
break
text = item.get("text", item.get("content", str(item)))
texts.append(text)
if not texts:
raise ValueError(f"No texts loaded from dataset {dataset_name}")
corpus = "\n".join(texts)
print(f"Collected {len(corpus):,} characters from {len(texts)} samples")
dataset = TokenizedDataset(texts, tokenizer, seq_length)
return dataset
# Tokenize
print("Tokenizing...")
tokenized = tokenizer(corpus, truncation=False)
input_ids = tokenized["input_ids"]
print(f"Tokenized to {len(input_ids):,} tokens")
# Save
eval_path = DATASETS_DIR / "wikitext_eval.json"
with open(eval_path, "w") as f:
json.dump(input_ids, f)
print(f"Saved eval data to {eval_path}")
return eval_path
# ---------------------------------------------------------------------------
# Model Loading
# ---------------------------------------------------------------------------
def prepare_train_data(num_samples=None):
"""Prepare TinyStories training data (streaming, no download needed).
def load_model(
model_name: str = DEFAULT_MODEL,
device: str = "cpu",
dtype: torch.dtype = torch.float16,
attn_implementation: str = "sdpa",
) -> AutoModelForCausalLM:
Returns the dataset name and config for train.py to load on-the-fly.
"""
Load a pretrained model from HuggingFace.
"""
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=dtype,
attn_implementation=attn_implementation,
trust_remote_code=True,
device_map=device if device != "cpu" else None,
)
if device == "cpu":
model = model.to(device)
return model
# TinyStories is loaded streaming in train.py, nothing to prepare here
# Just verify it's accessible
from datasets import load_dataset
print("Verifying TinyStories dataset access...")
ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
sample = next(iter(ds))
print(f" Sample keys: {list(sample.keys())}")
print(f" Sample length: {len(sample['text'])} chars")
print("TinyStories is accessible (loaded streaming, no local storage)")
return "roneneldan/TinyStories"
# ---------------------------------------------------------------------------
# CLI Entry Point
# ---------------------------------------------------------------------------
def get_vocab_size():
"""Return the vocab size for SmolLM-135M."""
from transformers import AutoTokenizer
if __name__ == "__main__":
import argparse
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
return tokenizer.vocab_size
parser = argparse.ArgumentParser(description="Prepare and evaluate ternary quantization datasets")
parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name or path")
parser.add_argument("--eval-dataset", default=DEFAULT_EVAL_DATASET, help="Evaluation dataset")
parser.add_argument("--eval-config", default=DEFAULT_EVAL_CONFIG, help="Evaluation dataset config")
parser.add_argument("--seq-length", type=int, default=DEFAULT_SEQ_LENGTH, help="Sequence length")
parser.add_argument("--batch-size", type=int, default=4, help="Evaluation batch size")
parser.add_argument("--device", default="cpu", help="Device to run on")
parser.add_argument("--train-samples", type=int, default=1000, help="Max training samples to load")
parser.add_argument("--dry-run", action="store_true", help="Just print config without loading")
def main():
parser = argparse.ArgumentParser(description="Prepare data for ternary quantization experiments")
parser.add_argument("--num-eval-samples", type=int, default=500, help="Number of wikitext samples for eval")
parser.add_argument("--train", action="store_true", help="Verify training dataset access")
parser.add_argument("--eval", action="store_true", help="Prepare eval dataset")
parser.add_argument("--vocab", action="store_true", help="Print vocab size")
args = parser.parse_args()
if args.dry_run:
print(f"Model: {args.model}")
print(f"Eval dataset: {args.eval_dataset} ({args.eval_config})")
print(f"Seq length: {args.seq_length}")
print(f"Batch size: {args.batch_size}")
print(f"Device: {args.device}")
print(f"Train samples: {args.train_samples}")
exit(0)
if args.vocab:
print(f"Vocab size: {get_vocab_size()}")
if args.eval:
prepare_eval_data(args.num_eval_samples)
if args.train:
prepare_train_data()
if not any([args.vocab, args.eval, args.train]):
# Default: prepare everything
prepare_eval_data(args.num_eval_samples)
prepare_train_data()
print(f"Loading tokenizer for {args.model}...")
tokenizer = load_tokenizer(args.model)
print(f"Loading eval dataset {args.eval_dataset}...")
eval_samples, _ = load_eval_dataset(
dataset_name=args.eval_dataset,
config=args.eval_config,
tokenizer=tokenizer,
seq_length=args.seq_length,
)
print(f" Eval samples: {len(eval_samples)}")
print(f"Loading training dataset ({args.train_samples} samples)...")
train_dataset = load_train_dataset(
tokenizer=tokenizer,
seq_length=args.seq_length,
max_samples=args.train_samples,
streaming=True,
)
print(f" Train samples: {len(train_dataset)}")
# Optional: load model and run eval
if args.device != "skip-model":
print(f"Loading model {args.model}...")
model = load_model(args.model, device=args.device)
print(f" Model params: {sum(p.numel() for p in model.parameters()):,}")
print(f"Evaluating perplexity...")
ppl = evaluate_ppl(model, eval_samples, tokenizer, args.seq_length, args.batch_size, args.device)
print(f" Perplexity: {ppl:.2f}")
bpb = math.log2(ppl)
print(f" BPB: {bpb:.4f}")
if __name__ == "__main__":
main()
+55 -55
View File
@@ -1,78 +1,78 @@
# program.md — Instructions for the Autoresearch Agent
# Ternary Quantization — Agent Instructions (program.md)
You are an autonomous research agent exploring ternary (1.58-bit) quantization for LLMs.
## Your Goal
## Context
Iteratively improve the ternary quantization training in `train.py` to achieve the **lowest validation perplexity (val_ppl)** on WikiText-2.
We are implementing a QAT (quantization-aware training) pipeline that replaces standard `nn.Linear` layers with `BitLinear` layers that quantize weights to {-1, 0, +1} with per-group scales. The goal is to find hyperparameter configurations that minimize eval perplexity (PPL) on WikiText-2 after fine-tuning.
## File Boundaries
## File Boundary
- **MUTABLE**: `train.py`You may modify this file. It contains:
- `BitLinear` layer (quantization logic)
- Quantization schedules (lambda warmup)
- Training loop
- Hyperparameters (LR, batch size, group size, etc.)
- **MUTABLE**: `train.py`you can modify this file to change quantization logic, loss functions, warmup schedules, deadzone recovery, etc.
- **READ-ONLY**: `prepare.py` — data loading and tokenizer. Do not modify.
- **READ-ONLY**: `program.md` — these instructions. Do not modify.
- **READ-ONLY**: `prepare.py` — DO NOT modify this file. It contains:
- Dataset loading and tokenization
- Evaluation harness (WikiText PPL)
- Model loading utilities
## Current State
- **OUTPUT**: `results.tsv` — Results are automatically logged here after each run.
Check `results.tsv` to see previous experiment results. Each row has:
`step, lambda, train_loss, train_ppl, eval_ppl, eval_bpb, lr, time_s, best_ppl, q_neg1, q_zero, q_pos1`
## Experiment Protocol
## What You Can Experiment With
1. Read `results.tsv` to understand what has been tried
2. Read `train.py` to understand current implementation
3. Propose ONE focused change to `train.py`
4. The change will be committed and a training run will execute (~5 minutes)
5. Results are logged to `results.tsv`
6. If improved (lower val_ppl) → change is kept
7. If equal or worse → git reset to previous commit
All experiments should be in `train.py`. Focus on:
## What to Explore
1. **Quantization warmup schedule**: linear, cosine, exponential, step-wise
- `lambda_ = min(step / quant_warmup_steps, 1.0)` → try cosine: `0.5 * (1 - cos(pi * step / warmup))`
- Try different warmup lengths: 500, 1000, 2000, 5000 steps
### Priority 1: Quantization Schedule
- Lambda warmup shape: linear, cosine, exponential
- Warmup step counts: 200, 500, 1000, 2000, 5000
- Two-phase warmup (fast initial + slow final)
2. **Learning rate**: 1e-5, 2e-5, 5e-5, 1e-4, 2e-4
- LR warmup steps: 50, 100, 200, 500
### Priority 2: Learning Rate
- LR values: 1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4
- LR schedule: constant, linear decay, cosine decay
3. **Group size**: 64, 128, 256 (Bonsai uses 128)
### Priority 3: Quantization Details
- Group size: 32, 64, 128, 256, per-tensor
- Scale initialization: mean-based vs absmax-based
- Ternary threshold adjustments
4. **Activation quantization**: 8-bit vs 16-bit (no quant)
- Try different activation quantization strategies (per-token, per-channel)
### Priority 4: Deadzone Recovery
- Tequila-style reactivation (learnable lambda for deadzone weights)
- Bias injection for zero-valued weights
- Gradient scaling for deadzone weights
5. **Weight quantization function**:
- Current: `scale = abs_mean(w)` → try `scale = abs_max(w)`
- Try different deadzone thresholds (e.g., |w_norm| < 0.5 → 0)
### Priority 5: Distillation
- OFF loss (cosine similarity between FP and ternary features)
- Logits distillation weight
- Feature distillation weight
6. **Deadzone recovery (Tequila-style)**:
- Track fraction of weights at 0; if > 40%, try reactivation
- Repurpose deadzone weights as dynamic biases
7. **Gradient clipping**: 0.5, 1.0, 2.0, 5.0
8. **Batch size and seq length**: trade off memory vs gradient quality
- On M4 Pro with 24GB RAM, be conservative
## Constraints
- Keep experiments focused — ONE change per iteration
- Always maintain working code — syntax errors waste time
- Use SmolLM-135M or Llama-3.2-1B for fast iteration
- Target metric: val_ppl (lower is better)
- Time budget: 5 minutes per experiment
- **Device**: MPS (Apple Silicon). No CUDA.
- **Memory**: 24GB RAM. Use float32 (float16 breaks on MPS with cross-entropy).
- **Model**: SmolLM-135M (135M params). Don't change the model.
- **Dataset**: TinyStories (streaming). Don't change the dataset.
- **Eval**: WikiText-2 test split (pre-tokenized in data/wikitext_eval.json).
- **Keep it simple**: Changes should be reviewable diffs. Don't rewrite the whole file.
## Important Notes
## Evaluation Metric
- The STE (Straight-Through Estimator) is critical for gradients to flow through quantization
- Warmup quantization prevents catastrophic accuracy loss at the start of training
- Deadzone trapping (weights stuck at 0) is a known problem — explore solutions
- Per-group quantization scales are essential for handling outlier weights
- The quantization formula: `scale = 1.0 / w.abs().mean(); round(clamp(-1, 1))`
**Single metric**: `eval_ppl` (perplexity on WikiText-2). Lower is better.
Baseline: SmolLM-135M FP32 should be around 30-40 PPL on WikiText-2.
## NEVER STOP
## Experiment Protocol
Run experiments continuously until manually interrupted. Each experiment should be a small, focused change. Review results.tsv between runs to inform your next decision.
1. Read `results.tsv` to understand current state
2. Propose ONE focused change to `train.py`
3. Run training: `python train.py --steps 100 --eval-every 25 --activation-bits 16`
4. Check if `eval_ppl` improved
5. If improved → keep the change
6. If worse → revert to previous version
7. Repeat
## Notes
- Start with `activation-bits 16` (no activation quant) to isolate weight quantization effects
- Gradually introduce activation quantization once weight quant works well
- The `quant=[-1:X 0:Y +1:Z]` stat shows the ternary distribution — aim for balanced, not all zeros
- Lambda warmup is critical — too fast = catastrophic accuracy drop, too slow = no quantization benefit
-1
View File
@@ -1 +0,0 @@
timestamp commit status val_ppl val_bpb initial_ppl best_ppl best_step steps time_s model lr warmup_steps total_steps lambda_schedule group_size batch_size seq_length description
1 timestamp commit status val_ppl val_bpb initial_ppl best_ppl best_step steps time_s model lr warmup_steps total_steps lambda_schedule group_size batch_size seq_length description
+1
View File
@@ -0,0 +1 @@
step lambda train_loss train_ppl eval_ppl eval_bpb lr time_s best_ppl
1 step lambda train_loss train_ppl eval_ppl eval_bpb lr time_s best_ppl
+464 -535
View File
File diff suppressed because it is too large Load Diff