Initial ternary quantization framework: BitLinear, QAT training loop, eval harness

2026-04-24 03:07:55 +02:00
parent 7378d4ef8f
commit 868910b40f
8 changed files with 626 additions and 1038 deletions
Binary file not shown.
autoresearch.sh  -122
@@ -1,122 +0,0 @@
#!/bin/bash
# autoresearch.sh — Autonomous experiment loop
#
# Usage: ./autoresearch.sh [model] [time_budget_seconds]
#
# This script runs the autoresearch loop:
# 1. Agent proposes a change to train.py
# 2. Git commit the change
# 3. Run training for fixed time budget
# 4. Extract val_ppl from results.tsv
# 5. If improved → keep; if worse → git reset
# 6. Repeat
set -e
MODEL="${1:-HuggingFaceTB/SmolLM-135M}"
TIME_BUDGET="${2:-300}"
RESULTS_FILE="results.tsv"
# Ensure git is initialized
if [ ! -d ".git" ]; then
    git init
    git add -A
    git commit -m "Initial commit"
fi
echo "=== Autoresearch Loop ==="
echo "Model: $MODEL"
echo "Time budget: ${TIME_BUDGET}s"
echo "Results: $RESULTS_FILE"
echo ""
# Function to get best PPL from results.tsv
get_best_ppl() {
    if [ ! -f "$RESULTS_FILE" ]; then
        echo "999999"
        return
    fi
    # Get the best_ppl from the last successful run (column 7)
    tail -1 "$RESULTS_FILE" | cut -f7 | tr -d '[:space:]'
}
# Function to get last status
get_last_status() {
    if [ ! -f "$RESULTS_FILE" ]; then
        echo "none"
        return
    fi
    tail -1 "$RESULTS_FILE" | cut -f3 | tr -d '[:space:]'
}
# Initial commit if not committed
git add -A
git commit -m "Initial setup" --allow-empty 2>/dev/null || true
BEST_PPL=$(get_best_ppl)
RUN_NUM=0
while true; do
    RUN_NUM=$((RUN_NUM + 1))
    echo ""
    echo "========================================"
    echo "RUN #$RUN_NUM"
    echo "Current best PPL: $BEST_PPL"
    echo "========================================"

    # Save current state
    PREV_COMMIT=$(git rev-parse HEAD)

    # Prompt the agent to make a change
    # In production, this would call the LLM agent
    # For now, we just run with current config
    echo "Running training..."

    # Run training
    START_TIME=$(date +%s)
    python3 train.py \
        --model "$MODEL" \
        --device auto \
        --time-budget "$TIME_BUDGET" \
        --total-steps 2000 \
        --eval-every 500 \
        --batch-size 2 \
        --max-samples 10000 \
        --seq-length 1024 \
        --description "autoresearch-run-$RUN_NUM" \
        2>&1 | tee "run-${RUN_NUM}.log" || true
    END_TIME=$(date +%s)
    ELAPSED=$((END_TIME - START_TIME))

    # Check results
    STATUS=$(get_last_status)
    NEW_PPL=$(get_best_ppl)

    echo ""
    echo "Run #$RUN_NUM completed in ${ELAPSED}s"
    echo "Status: $STATUS"
    echo "Best PPL: $NEW_PPL"

    if [ "$STATUS" = "success" ]; then
        # Compare with previous best
        if echo "$NEW_PPL $BEST_PPL" | awk '{exit !($1 < $2)}'; then
            echo "IMPROVED! Keeping changes."
            BEST_PPL=$NEW_PPL
            git add results.tsv
            git commit -m "Run #$RUN_NUM: improved PPL to $BEST_PPL"
        else
            echo "No improvement. Reverting."
            git reset --hard $PREV_COMMIT 2>/dev/null || true
            git checkout -- results.tsv 2>/dev/null || true
        fi
    else
        echo "FAILED. Reverting."
        git reset --hard $PREV_COMMIT 2>/dev/null || true
        git checkout -- results.tsv 2>/dev/null || true
    fi

    echo ""
    echo "Continuing... (Ctrl+C to stop)"
done
File diff suppressed because one or more lines are too long
prepare.py  +74 -294
@@ -1,324 +1,104 @@
Old version (removed):

#!/usr/bin/env python3
"""
prepare.py — Data loading, tokenizer, and evaluation (READ-ONLY in autoresearch loop).

This file is NOT modified by the autoresearch agent. All mutable training logic
lives in train.py. This file provides:
  - Dataset loading and tokenization
  - Evaluation harness (WikiText PPL, optional zero-shot tasks)
  - Utility functions for the training loop
"""
import os
import math

import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Default model — small enough for fast iteration on M4 Pro / consumer GPU
DEFAULT_MODEL = "meta-llama/Llama-3.2-1B"
# Alternative tiny models for ultra-fast iteration:
# DEFAULT_MODEL = "HuggingFaceTB/SmolLM-135M"

# Default eval dataset
DEFAULT_EVAL_DATASET = "wikitext"  # "wikitext" or "ptb"
DEFAULT_EVAL_CONFIG = "wikitext-2-raw-v1"

# Default training dataset
DEFAULT_TRAIN_DATASET = "HuggingFaceFW/fineweb-edu"  # FineWeb-edu
# Alternative: "monology/patrickstar-core-1000" for smaller local testing

# Sequence length for training
DEFAULT_SEQ_LENGTH = 2048

# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------

def load_tokenizer(model_name: str = DEFAULT_MODEL) -> AutoTokenizer:
    """Load tokenizer for the given model."""
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

# ---------------------------------------------------------------------------
# Dataset Wrappers
# ---------------------------------------------------------------------------

class TokenizedDataset(Dataset):
    """Simple dataset that yields batches of tokenized text."""

    def __init__(self, data: list, tokenizer: AutoTokenizer, seq_length: int = DEFAULT_SEQ_LENGTH):
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        # Concatenate all texts and chunk into sequences
        self.samples = self._prepare_samples(data)

    def _prepare_samples(self, data: list) -> list:
        """Flatten and chunk text data into fixed-length sequences."""
        all_tokens = []
        for item in data:
            text = item.get("text", item.get("content", str(item)))
            tokens = self.tokenizer.encode(text, add_special_tokens=False)
            all_tokens.extend(tokens)
            all_tokens.append(self.tokenizer.eos_token_id)
        # Chunk into sequences of seq_length
        samples = []
        for i in range(0, len(all_tokens) - self.seq_length + 1, self.seq_length):
            samples.append(all_tokens[i : i + self.seq_length])
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x = torch.tensor(self.samples[idx], dtype=torch.long)
        return {"input_ids": x[:-1], "labels": x[1:]}

def load_eval_dataset(
    dataset_name: str = DEFAULT_EVAL_DATASET,
    config: str = DEFAULT_EVAL_CONFIG,
    split: str = "validation",
    tokenizer: AutoTokenizer = None,
    seq_length: int = DEFAULT_SEQ_LENGTH,
):
    """Load and tokenize the evaluation dataset."""
    if dataset_name == "wikitext":
        ds = load_dataset(dataset_name, config, trust_remote_code=True)[split]
    elif dataset_name == "ptb":
        ds = load_dataset(dataset_name, trust_remote_code=True)[split]
    else:
        ds = load_dataset(dataset_name, split=split, trust_remote_code=True)
    if tokenizer is None:
        tokenizer = load_tokenizer()
    # For eval, we want contiguous text without chunking artifacts
    texts = ds["text"] if "text" in ds.column_names else [str(x) for x in ds]
    all_tokens = []
    for text in texts:
        tokens = tokenizer.encode(text, add_special_tokens=False)
        all_tokens.extend(tokens)
    # Remove empty tokens
    all_tokens = [t for t in all_tokens if t is not None]
    # Create sequences
    samples = []
    for i in range(0, len(all_tokens) - seq_length + 1, seq_length):
        samples.append(all_tokens[i : i + seq_length])
    return samples, tokenizer

# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------

@torch.no_grad()
def evaluate_ppl(
    model: AutoModelForCausalLM,
    eval_samples: list,
    tokenizer: AutoTokenizer,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    batch_size: int = 4,
    device: str = "cpu",
) -> float:
    """
    Evaluate perplexity on the given samples.
    Returns perplexity (lower is better).
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    for i in range(0, len(eval_samples), batch_size):
        batch_samples = eval_samples[i : i + batch_size]
        inputs = torch.tensor(batch_samples, dtype=torch.long, device=device)
        labels = inputs.clone()
        # Shift for causal LM
        input_ids = inputs[:, :-1]
        label_ids = labels[:, 1:]
        device_type = "mps" if device == "mps" else ("cuda" if device == "cuda" else "cpu")
        if device_type != "cpu":
            with torch.amp.autocast(device_type=device_type):
                outputs = model(input_ids=input_ids, labels=label_ids, use_cache=False)
        else:
            outputs = model(input_ids=input_ids, labels=label_ids, use_cache=False)
        loss = outputs.loss
        total_loss += loss.item() * label_ids.size(0)
        total_tokens += label_ids.size(0)
    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)
    model.train()
    return perplexity

@torch.no_grad()
def evaluate_bpb(
    model: AutoModelForCausalLM,
    eval_samples: list,
    tokenizer: AutoTokenizer,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    batch_size: int = 4,
    device: str = "cpu",
) -> float:
    """
    Evaluate bits-per-byte (bpb) on the given samples.
    Returns bpb (lower is better).
    """
    ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
    bpb = math.log2(ppl)
    return bpb

# ---------------------------------------------------------------------------
# Training Dataset Loader
# ---------------------------------------------------------------------------

def load_train_dataset(
    dataset_name: str = DEFAULT_TRAIN_DATASET,
    split: str = "train",
    tokenizer: AutoTokenizer = None,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    max_samples: int = None,
    streaming: bool = True,
):
    """
    Load and prepare the training dataset.
    Returns a TokenizedDataset.
    """
    if tokenizer is None:
        tokenizer = load_tokenizer()
    if dataset_name == "HuggingFaceFW/fineweb-edu":
        # Use the edu-only subset for quality
        ds = load_dataset(
            dataset_name,
            name="sample-100BT",  # 100B token subset for faster iteration
            split=split,
            streaming=streaming,
            trust_remote_code=True,
        )
    elif dataset_name.endswith(".jsonl") or dataset_name.endswith(".json"):
        ds = load_dataset("json", data_files=dataset_name, split=split, streaming=streaming)
    else:
        ds = load_dataset(dataset_name, split=split, streaming=streaming, trust_remote_code=True)
    # Collect texts (handle streaming by taking first N samples)
    texts = []
    for idx, item in enumerate(ds):
        if max_samples and idx >= max_samples:
            break
        text = item.get("text", item.get("content", str(item)))
        texts.append(text)
    if not texts:
        raise ValueError(f"No texts loaded from dataset {dataset_name}")
    dataset = TokenizedDataset(texts, tokenizer, seq_length)
    return dataset

# ---------------------------------------------------------------------------
# Model Loading
# ---------------------------------------------------------------------------

def load_model(
    model_name: str = DEFAULT_MODEL,
    device: str = "cpu",
    dtype: torch.dtype = torch.float16,
    attn_implementation: str = "sdpa",
) -> AutoModelForCausalLM:
    """
    Load a pretrained model from HuggingFace.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        attn_implementation=attn_implementation,
        trust_remote_code=True,
        device_map=device if device != "cpu" else None,
    )
    if device == "cpu":
        model = model.to(device)
    return model

# ---------------------------------------------------------------------------
# CLI Entry Point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Prepare and evaluate ternary quantization datasets")
    parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name or path")
    parser.add_argument("--eval-dataset", default=DEFAULT_EVAL_DATASET, help="Evaluation dataset")
    parser.add_argument("--eval-config", default=DEFAULT_EVAL_CONFIG, help="Evaluation dataset config")
    parser.add_argument("--seq-length", type=int, default=DEFAULT_SEQ_LENGTH, help="Sequence length")
    parser.add_argument("--batch-size", type=int, default=4, help="Evaluation batch size")
    parser.add_argument("--device", default="cpu", help="Device to run on")
    parser.add_argument("--train-samples", type=int, default=1000, help="Max training samples to load")
    parser.add_argument("--dry-run", action="store_true", help="Just print config without loading")
    args = parser.parse_args()

    if args.dry_run:
        print(f"Model: {args.model}")
        print(f"Eval dataset: {args.eval_dataset} ({args.eval_config})")
        print(f"Seq length: {args.seq_length}")
        print(f"Batch size: {args.batch_size}")
        print(f"Device: {args.device}")
        print(f"Train samples: {args.train_samples}")
        exit(0)

    print(f"Loading tokenizer for {args.model}...")
    tokenizer = load_tokenizer(args.model)

    print(f"Loading eval dataset {args.eval_dataset}...")
    eval_samples, _ = load_eval_dataset(
        dataset_name=args.eval_dataset,
        config=args.eval_config,
        tokenizer=tokenizer,
        seq_length=args.seq_length,
    )
    print(f"  Eval samples: {len(eval_samples)}")

    print(f"Loading training dataset ({args.train_samples} samples)...")
    train_dataset = load_train_dataset(
        tokenizer=tokenizer,
        seq_length=args.seq_length,
        max_samples=args.train_samples,
        streaming=True,
    )
    print(f"  Train samples: {len(train_dataset)}")

    # Optional: load model and run eval
    if args.device != "skip-model":
        print(f"Loading model {args.model}...")
        model = load_model(args.model, device=args.device)
        print(f"  Model params: {sum(p.numel() for p in model.parameters()):,}")
        print(f"Evaluating perplexity...")
        ppl = evaluate_ppl(model, eval_samples, tokenizer, args.seq_length, args.batch_size, args.device)
        print(f"  Perplexity: {ppl:.2f}")
        bpb = math.log2(ppl)
        print(f"  BPB: {bpb:.4f}")

New version (added):

#!/usr/bin/env python3
"""
Data preparation and evaluation for ternary quantization experiments.
READ-ONLY in the autoresearch loop — train.py is the mutable file.

Usage:
    python prepare.py                     # download wikitext val shard
    python prepare.py --num-samples 500   # smaller eval set for fast iteration
"""
import argparse
import json
import os
from pathlib import Path

DATASETS_DIR = Path(__file__).parent / "data"


def prepare_eval_data(num_samples=500):
    """Download and prepare WikiText-2 validation data for perplexity evaluation.

    Saves tokenized data as a JSON file for fast loading during training.
    """
    from datasets import load_dataset
    from transformers import AutoTokenizer

    DATASETS_DIR.mkdir(parents=True, exist_ok=True)
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")

    # Load wikitext test split (validation is unreliable with streaming)
    print("Loading WikiText-2 test split...")
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test", streaming=True)

    # Collect text into a single corpus
    texts = []
    for i, sample in enumerate(dataset):
        texts.append(sample["text"].strip())
        if i + 1 >= num_samples:
            break

    corpus = "\n".join(texts)
    print(f"Collected {len(corpus):,} characters from {len(texts)} samples")

    # Tokenize
    print("Tokenizing...")
    tokenized = tokenizer(corpus, truncation=False)
    input_ids = tokenized["input_ids"]
    print(f"Tokenized to {len(input_ids):,} tokens")

    # Save
    eval_path = DATASETS_DIR / "wikitext_eval.json"
    with open(eval_path, "w") as f:
        json.dump(input_ids, f)
    print(f"Saved eval data to {eval_path}")
    return eval_path


def prepare_train_data(num_samples=None):
    """Prepare TinyStories training data (streaming, no download needed).

    Returns the dataset name and config for train.py to load on-the-fly.
    """
    # TinyStories is loaded streaming in train.py, nothing to prepare here
    # Just verify it's accessible
    from datasets import load_dataset

    print("Verifying TinyStories dataset access...")
    ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    sample = next(iter(ds))
    print(f"  Sample keys: {list(sample.keys())}")
    print(f"  Sample length: {len(sample['text'])} chars")
    print("TinyStories is accessible (loaded streaming, no local storage)")
    return "roneneldan/TinyStories"


def get_vocab_size():
    """Return the vocab size for SmolLM-135M."""
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
    return tokenizer.vocab_size


def main():
    parser = argparse.ArgumentParser(description="Prepare data for ternary quantization experiments")
    parser.add_argument("--num-eval-samples", type=int, default=500, help="Number of wikitext samples for eval")
    parser.add_argument("--train", action="store_true", help="Verify training dataset access")
    parser.add_argument("--eval", action="store_true", help="Prepare eval dataset")
    parser.add_argument("--vocab", action="store_true", help="Print vocab size")
    args = parser.parse_args()

    if args.vocab:
        print(f"Vocab size: {get_vocab_size()}")
    if args.eval:
        prepare_eval_data(args.num_eval_samples)
    if args.train:
        prepare_train_data()
    if not any([args.vocab, args.eval, args.train]):
        # Default: prepare everything
        prepare_eval_data(args.num_eval_samples)
        prepare_train_data()


if __name__ == "__main__":
    main()
program.md  +55 -55
@@ -1,78 +1,78 @@
Old version (removed):

# program.md — Instructions for the Autoresearch Agent

You are an autonomous research agent exploring ternary (1.58-bit) quantization for LLMs.

## Your Goal

Iteratively improve the ternary quantization training in `train.py` to achieve the **lowest validation perplexity (val_ppl)** on WikiText-2.

## File Boundaries

- **MUTABLE**: `train.py` — You may modify this file. It contains:
  - `BitLinear` layer (quantization logic)
  - Quantization schedules (lambda warmup)
  - Training loop
  - Hyperparameters (LR, batch size, group size, etc.)
- **READ-ONLY**: `prepare.py` — DO NOT modify this file. It contains:
  - Dataset loading and tokenization
  - Evaluation harness (WikiText PPL)
  - Model loading utilities

## Experiment Protocol

1. Read `results.tsv` to understand what has been tried
2. Read `train.py` to understand current implementation
3. Propose ONE focused change to `train.py`
4. The change will be committed and a training run will execute (~5 minutes)
5. Results are logged to `results.tsv`
6. If improved (lower val_ppl) → change is kept
7. If equal or worse → git reset to previous commit

## What to Explore

### Priority 1: Quantization Schedule
- Lambda warmup shape: linear, cosine, exponential
- Warmup step counts: 200, 500, 1000, 2000, 5000
- Two-phase warmup (fast initial + slow final)

### Priority 2: Learning Rate
- LR values: 1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4
- LR schedule: constant, linear decay, cosine decay

### Priority 3: Quantization Details
- Group size: 32, 64, 128, 256, per-tensor
- Scale initialization: mean-based vs absmax-based
- Ternary threshold adjustments

### Priority 4: Deadzone Recovery
- Tequila-style reactivation (learnable lambda for deadzone weights)
- Bias injection for zero-valued weights
- Gradient scaling for deadzone weights

### Priority 5: Distillation
- OFF loss (cosine similarity between FP and ternary features)
- Logits distillation weight
- Feature distillation weight

## Constraints

- Keep experiments focused — ONE change per iteration
- Always maintain working code — syntax errors waste time
- Use SmolLM-135M or Llama-3.2-1B for fast iteration
- Target metric: val_ppl (lower is better)
- Time budget: 5 minutes per experiment

## Important Notes

- The STE (Straight-Through Estimator) is critical for gradients to flow through quantization
- Warmup quantization prevents catastrophic accuracy loss at the start of training
- Deadzone trapping (weights stuck at 0) is a known problem — explore solutions
- Per-group quantization scales are essential for handling outlier weights
- The quantization formula: `scale = 1.0 / w.abs().mean(); round(clamp(-1, 1))`

## NEVER STOP

Run experiments continuously until manually interrupted. Each experiment should be a small, focused change. Review results.tsv between runs to inform your next decision.

New version (added):

# Ternary Quantization — Agent Instructions (program.md)

You are an autonomous research agent exploring ternary (1.58-bit) quantization for LLMs.

## Context

We are implementing a QAT (quantization-aware training) pipeline that replaces standard `nn.Linear` layers with `BitLinear` layers that quantize weights to {-1, 0, +1} with per-group scales. The goal is to find hyperparameter configurations that minimize eval perplexity (PPL) on WikiText-2 after fine-tuning.

## File Boundary

- **MUTABLE**: `train.py` — you can modify this file to change quantization logic, loss functions, warmup schedules, deadzone recovery, etc.
- **READ-ONLY**: `prepare.py` — data loading and tokenizer. Do not modify.
- **READ-ONLY**: `program.md` — these instructions. Do not modify.

## Current State

- **OUTPUT**: `results.tsv` — Results are automatically logged here after each run. Check `results.tsv` to see previous experiment results. Each row has:
  `step, lambda, train_loss, train_ppl, eval_ppl, eval_bpb, lr, time_s, best_ppl, q_neg1, q_zero, q_pos1`

## What You Can Experiment With

All experiments should be in `train.py`. Focus on:

1. **Quantization warmup schedule**: linear, cosine, exponential, step-wise
   - `lambda_ = min(step / quant_warmup_steps, 1.0)` → try cosine: `0.5 * (1 - cos(pi * step / warmup))`
   - Try different warmup lengths: 500, 1000, 2000, 5000 steps

2. **Learning rate**: 1e-5, 2e-5, 5e-5, 1e-4, 2e-4
   - LR warmup steps: 50, 100, 200, 500

3. **Group size**: 64, 128, 256 (Bonsai uses 128)

4. **Activation quantization**: 8-bit vs 16-bit (no quant)
   - Try different activation quantization strategies (per-token, per-channel)

5. **Weight quantization function**:
   - Current: `scale = abs_mean(w)` → try `scale = abs_max(w)`
   - Try different deadzone thresholds (e.g., |w_norm| < 0.5 → 0)

6. **Deadzone recovery (Tequila-style)**:
   - Track fraction of weights at 0; if > 40%, try reactivation
   - Repurpose deadzone weights as dynamic biases

7. **Gradient clipping**: 0.5, 1.0, 2.0, 5.0

8. **Batch size and seq length**: trade off memory vs gradient quality
   - On M4 Pro with 24GB RAM, be conservative

## Constraints

- **Device**: MPS (Apple Silicon). No CUDA.
- **Memory**: 24GB RAM. Use float32 (float16 breaks on MPS with cross-entropy).
- **Model**: SmolLM-135M (135M params). Don't change the model.
- **Dataset**: TinyStories (streaming). Don't change the dataset.
- **Eval**: WikiText-2 test split (pre-tokenized in data/wikitext_eval.json).
- **Keep it simple**: Changes should be reviewable diffs. Don't rewrite the whole file.

## Evaluation Metric

**Single metric**: `eval_ppl` (perplexity on WikiText-2). Lower is better.
Baseline: SmolLM-135M FP32 should be around 30-40 PPL on WikiText-2.

## Experiment Protocol

1. Read `results.tsv` to understand current state
2. Propose ONE focused change to `train.py`
3. Run training: `python train.py --steps 100 --eval-every 25 --activation-bits 16`
4. Check if `eval_ppl` improved
5. If improved → keep the change
6. If worse → revert to previous version
7. Repeat

## Notes

- Start with `activation-bits 16` (no activation quant) to isolate weight quantization effects
- Gradually introduce activation quantization once weight quant works well
- The `quant=[-1:X 0:Y +1:Z]` stat shows the ternary distribution — aim for balanced, not all zeros
- Lambda warmup is critical — too fast = catastrophic accuracy drop, too slow = no quantization benefit
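Both versions of program.md lean on the same core mechanics: ternary weights with per-group scales, a straight-through estimator so gradients reach the latent full-precision weights, and a lambda warmup that blends full-precision and quantized weights early in training. A minimal sketch of that combination follows; it is illustrative only (the committed BitLinear lives in the suppressed train.py diff), and the class name, group size, and the blend formulation are assumptions.

# Illustrative sketch only, not the committed BitLinear. Assumes the weight
# matrix size is divisible by group_size.
import torch
import torch.nn as nn
import torch.nn.functional as F

class BitLinearSketch(nn.Linear):
    def __init__(self, in_features, out_features, bias=True, group_size=128):
        super().__init__(in_features, out_features, bias=bias)
        self.group_size = group_size
        self.lambda_ = 0.0  # 0 = full precision, 1 = fully ternary (set by the warmup schedule)

    def quantize_weight(self, w):
        # Per-group absmean scale, then round to {-1, 0, +1}; cf. program.md's
        # "scale = 1.0 / w.abs().mean(); round(clamp(-1, 1))"
        g = w.view(-1, self.group_size)
        scale = g.abs().mean(dim=1, keepdim=True).clamp(min=1e-5)
        q = (g / scale).round().clamp(-1, 1) * scale
        return q.view_as(w)

    def forward(self, x):
        w = self.weight
        w_q = self.quantize_weight(w)
        # Straight-through estimator: forward uses the quantized weight,
        # backward sees the identity, so gradients update the latent fp weight.
        w_ste = w + (w_q - w).detach()
        # Lambda warmup blends fp and ternary weights early in training.
        w_eff = (1 - self.lambda_) * w + self.lambda_ * w_ste
        return F.linear(x, w_eff, self.bias)

# Linear warmup of lambda, per program.md: lambda_ = min(step / warmup_steps, 1.0)
def lambda_schedule(step, warmup_steps=1000):
    return min(step / warmup_steps, 1.0)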
-1
@@ -1 +0,0 @@
timestamp	commit	status	val_ppl	val_bpb	initial_ppl	best_ppl	best_step	steps	time_s	model	lr	warmup_steps	total_steps	lambda_schedule	group_size	batch_size	seq_length	description
+1
@@ -0,0 +1 @@
step	lambda	train_loss	train_ppl	eval_ppl	eval_bpb	lr	time_s	best_ppl
+495 -566
File diff suppressed because it is too large.
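The train.py diff itself is suppressed above. For orientation only, here is a skeleton of what a QAT fine-tuning step with the lambda warmup and periodic eval described in program.md typically looks like; every name below is an assumption rather than the committed code, and eval_ppl_from_shard refers to the earlier illustrative sketch.

# Hypothetical QAT loop skeleton (not the committed train.py). Assumes the model's
# BitLinear-style modules expose a `lambda_` attribute, and that `batches` yields
# LongTensor input_ids of shape (batch, seq_len).
import torch

def train_qat(model, batches, optimizer, total_steps=100, warmup_steps=50, eval_every=25):
    model.train()
    for step, batch in zip(range(1, total_steps + 1), batches):
        # Update the quantization blend before the forward pass
        lam = min(step / warmup_steps, 1.0)
        for m in model.modules():
            if hasattr(m, "lambda_"):
                m.lambda_ = lam

        out = model(input_ids=batch, labels=batch)  # HF-style causal LM loss
        out.loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

        if step % eval_every == 0:
            ppl = eval_ppl_from_shard(model)  # assumed helper from the earlier sketch
            print(f"step {step} lambda {lam:.2f} loss {out.loss.item():.3f} eval_ppl {ppl:.2f}")
            model.train()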