Initial ternary quantization framework: BitLinear, QAT training loop, eval harness
This commit is contained in:
Binary file not shown.
-122
@@ -1,122 +0,0 @@
|
||||
#!/bin/bash
|
||||
# autoresearch.sh — Autonomous experiment loop
|
||||
#
|
||||
# Usage: ./autoresearch.sh [model] [time_budget_seconds]
|
||||
#
|
||||
# This script runs the autoresearch loop:
|
||||
# 1. Agent proposes a change to train.py
|
||||
# 2. Git commit the change
|
||||
# 3. Run training for fixed time budget
|
||||
# 4. Extract val_ppl from results.tsv
|
||||
# 5. If improved → keep; if worse → git reset
|
||||
# 6. Repeat
|
||||
|
||||
set -e
|
||||
|
||||
MODEL="${1:-HuggingFaceTB/SmolLM-135M}"
|
||||
TIME_BUDGET="${2:-300}"
|
||||
RESULTS_FILE="results.tsv"
|
||||
|
||||
# Ensure git is initialized
|
||||
if [ ! -d ".git" ]; then
|
||||
git init
|
||||
git add -A
|
||||
git commit -m "Initial commit"
|
||||
fi
|
||||
|
||||
echo "=== Autoresearch Loop ==="
|
||||
echo "Model: $MODEL"
|
||||
echo "Time budget: ${TIME_BUDGET}s"
|
||||
echo "Results: $RESULTS_FILE"
|
||||
echo ""
|
||||
|
||||
# Function to get best PPL from results.tsv
|
||||
get_best_ppl() {
|
||||
if [ ! -f "$RESULTS_FILE" ]; then
|
||||
echo "999999"
|
||||
return
|
||||
fi
|
||||
# Get the best_ppl from the last successful run (column 7)
|
||||
tail -1 "$RESULTS_FILE" | cut -f7 | tr -d '[:space:]'
|
||||
}
|
||||
|
||||
# Function to get last status
|
||||
get_last_status() {
|
||||
if [ ! -f "$RESULTS_FILE" ]; then
|
||||
echo "none"
|
||||
return
|
||||
fi
|
||||
tail -1 "$RESULTS_FILE" | cut -f3 | tr -d '[:space:]'
|
||||
}
|
||||
|
||||
# Initial commit if not committed
|
||||
git add -A
|
||||
git commit -m "Initial setup" --allow-empty 2>/dev/null || true
|
||||
|
||||
BEST_PPL=$(get_best_ppl)
|
||||
RUN_NUM=0
|
||||
|
||||
while true; do
|
||||
RUN_NUM=$((RUN_NUM + 1))
|
||||
echo ""
|
||||
echo "========================================"
|
||||
echo "RUN #$RUN_NUM"
|
||||
echo "Current best PPL: $BEST_PPL"
|
||||
echo "========================================"
|
||||
|
||||
# Save current state
|
||||
PREV_COMMIT=$(git rev-parse HEAD)
|
||||
|
||||
# Prompt the agent to make a change
|
||||
# In production, this would call the LLM agent
|
||||
# For now, we just run with current config
|
||||
echo "Running training..."
|
||||
|
||||
# Run training
|
||||
START_TIME=$(date +%s)
|
||||
|
||||
python3 train.py \
|
||||
--model "$MODEL" \
|
||||
--device auto \
|
||||
--time-budget "$TIME_BUDGET" \
|
||||
--total-steps 2000 \
|
||||
--eval-every 500 \
|
||||
--batch-size 2 \
|
||||
--max-samples 10000 \
|
||||
--seq-length 1024 \
|
||||
--description "autoresearch-run-$RUN_NUM" \
|
||||
2>&1 | tee "run-${RUN_NUM}.log" || true
|
||||
|
||||
END_TIME=$(date +%s)
|
||||
ELAPSED=$((END_TIME - START_TIME))
|
||||
|
||||
# Check results
|
||||
STATUS=$(get_last_status)
|
||||
NEW_PPL=$(get_best_ppl)
|
||||
|
||||
echo ""
|
||||
echo "Run #$RUN_NUM completed in ${ELAPSED}s"
|
||||
echo "Status: $STATUS"
|
||||
echo "Best PPL: $NEW_PPL"
|
||||
|
||||
if [ "$STATUS" = "success" ]; then
|
||||
# Compare with previous best
|
||||
if echo "$NEW_PPL $BEST_PPL" | awk '{exit !($1 < $2)}'; then
|
||||
echo "IMPROVED! Keeping changes."
|
||||
BEST_PPL=$NEW_PPL
|
||||
git add results.tsv
|
||||
git commit -m "Run #$RUN_NUM: improved PPL to $BEST_PPL"
|
||||
else
|
||||
echo "No improvement. Reverting."
|
||||
git reset --hard $PREV_COMMIT 2>/dev/null || true
|
||||
git checkout -- results.tsv 2>/dev/null || true
|
||||
fi
|
||||
else
|
||||
echo "FAILED. Reverting."
|
||||
git reset --hard $PREV_COMMIT 2>/dev/null || true
|
||||
git checkout -- results.tsv 2>/dev/null || true
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "Continuing... (Ctrl+C to stop)"
|
||||
done
|
||||
File diff suppressed because one or more lines are too long
+74
-294
@@ -1,324 +1,104 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
prepare.py — Data loading, tokenizer, and evaluation (READ-ONLY in autoresearch loop).
|
||||
Data preparation and evaluation for ternary quantization experiments.
|
||||
READ-ONLY in the autoresearch loop — train.py is the mutable file.
|
||||
|
||||
This file is NOT modified by the autoresearch agent. All mutable training logic
|
||||
lives in train.py. This file provides:
|
||||
- Dataset loading and tokenization
|
||||
- Evaluation harness (WikiText PPL, optional zero-shot tasks)
|
||||
- Utility functions for the training loop
|
||||
Usage:
|
||||
python prepare.py # download wikitext val shard
|
||||
python prepare.py --num-samples 500 # smaller eval set for fast iteration
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from datasets import load_dataset
|
||||
from pathlib import Path
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Default model — small enough for fast iteration on M4 Pro / consumer GPU
|
||||
DEFAULT_MODEL = "meta-llama/Llama-3.2-1B"
|
||||
# Alternative tiny models for ultra-fast iteration:
|
||||
# DEFAULT_MODEL = "HuggingFaceTB/SmolLM-135M"
|
||||
|
||||
# Default eval dataset
|
||||
DEFAULT_EVAL_DATASET = "wikitext" # "wikitext" or "ptb"
|
||||
DEFAULT_EVAL_CONFIG = "wikitext-2-raw-v1"
|
||||
|
||||
# Default training dataset
|
||||
DEFAULT_TRAIN_DATASET = "HuggingFaceFW/fineweb-edu" # FineWeb-edu
|
||||
# Alternative: "monology/patrickstar-core-1000" for smaller local testing
|
||||
|
||||
# Sequence length for training
|
||||
DEFAULT_SEQ_LENGTH = 2048
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tokenizer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def load_tokenizer(model_name: str = DEFAULT_MODEL) -> AutoTokenizer:
|
||||
"""Load tokenizer for the given model."""
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
return tokenizer
|
||||
DATASETS_DIR = Path(__file__).parent / "data"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataset Wrappers
|
||||
# ---------------------------------------------------------------------------
|
||||
def prepare_eval_data(num_samples=500):
|
||||
"""Download and prepare WikiText-2 validation data for perplexity evaluation.
|
||||
|
||||
class TokenizedDataset(Dataset):
|
||||
"""Simple dataset that yields batches of tokenized text."""
|
||||
|
||||
def __init__(self, data: list, tokenizer: AutoTokenizer, seq_length: int = DEFAULT_SEQ_LENGTH):
|
||||
self.tokenizer = tokenizer
|
||||
self.seq_length = seq_length
|
||||
# Concatenate all texts and chunk into sequences
|
||||
self.samples = self._prepare_samples(data)
|
||||
|
||||
def _prepare_samples(self, data: list) -> list:
|
||||
"""Flatten and chunk text data into fixed-length sequences."""
|
||||
all_tokens = []
|
||||
for item in data:
|
||||
text = item.get("text", item.get("content", str(item)))
|
||||
tokens = self.tokenizer.encode(text, add_special_tokens=False)
|
||||
all_tokens.extend(tokens)
|
||||
all_tokens.append(self.tokenizer.eos_token_id)
|
||||
|
||||
# Chunk into sequences of seq_length
|
||||
samples = []
|
||||
for i in range(0, len(all_tokens) - self.seq_length + 1, self.seq_length):
|
||||
samples.append(all_tokens[i : i + self.seq_length])
|
||||
return samples
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
x = torch.tensor(self.samples[idx], dtype=torch.long)
|
||||
return {"input_ids": x[:-1], "labels": x[1:]}
|
||||
|
||||
|
||||
def load_eval_dataset(
|
||||
dataset_name: str = DEFAULT_EVAL_DATASET,
|
||||
config: str = DEFAULT_EVAL_CONFIG,
|
||||
split: str = "validation",
|
||||
tokenizer: AutoTokenizer = None,
|
||||
seq_length: int = DEFAULT_SEQ_LENGTH,
|
||||
):
|
||||
"""Load and tokenize the evaluation dataset."""
|
||||
if dataset_name == "wikitext":
|
||||
ds = load_dataset(dataset_name, config, trust_remote_code=True)[split]
|
||||
elif dataset_name == "ptb":
|
||||
ds = load_dataset(dataset_name, trust_remote_code=True)[split]
|
||||
else:
|
||||
ds = load_dataset(dataset_name, split=split, trust_remote_code=True)
|
||||
|
||||
if tokenizer is None:
|
||||
tokenizer = load_tokenizer()
|
||||
|
||||
# For eval, we want contiguous text without chunking artifacts
|
||||
texts = ds["text"] if "text" in ds.column_names else [str(x) for x in ds]
|
||||
all_tokens = []
|
||||
for text in texts:
|
||||
tokens = tokenizer.encode(text, add_special_tokens=False)
|
||||
all_tokens.extend(tokens)
|
||||
|
||||
# Remove empty tokens
|
||||
all_tokens = [t for t in all_tokens if t is not None]
|
||||
|
||||
# Create sequences
|
||||
samples = []
|
||||
for i in range(0, len(all_tokens) - seq_length + 1, seq_length):
|
||||
samples.append(all_tokens[i : i + seq_length])
|
||||
|
||||
return samples, tokenizer
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Evaluation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate_ppl(
|
||||
model: AutoModelForCausalLM,
|
||||
eval_samples: list,
|
||||
tokenizer: AutoTokenizer,
|
||||
seq_length: int = DEFAULT_SEQ_LENGTH,
|
||||
batch_size: int = 4,
|
||||
device: str = "cpu",
|
||||
) -> float:
|
||||
Saves tokenized data as a JSON file for fast loading during training.
|
||||
"""
|
||||
Evaluate perplexity on the given samples.
|
||||
Returns perplexity (lower is better).
|
||||
"""
|
||||
model.eval()
|
||||
total_loss = 0.0
|
||||
total_tokens = 0
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
for i in range(0, len(eval_samples), batch_size):
|
||||
batch_samples = eval_samples[i : i + batch_size]
|
||||
inputs = torch.tensor(batch_samples, dtype=torch.long, device=device)
|
||||
labels = inputs.clone()
|
||||
DATASETS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
|
||||
|
||||
# Shift for causal LM
|
||||
input_ids = inputs[:, :-1]
|
||||
label_ids = labels[:, 1:]
|
||||
# Load wikitext test split (validation is unreliable with streaming)
|
||||
print("Loading WikiText-2 test split...")
|
||||
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test", streaming=True)
|
||||
|
||||
device_type = "mps" if device == "mps" else ("cuda" if device == "cuda" else "cpu")
|
||||
if device_type != "cpu":
|
||||
with torch.amp.autocast(device_type=device_type):
|
||||
outputs = model(input_ids=input_ids, labels=label_ids, use_cache=False)
|
||||
else:
|
||||
outputs = model(input_ids=input_ids, labels=label_ids, use_cache=False)
|
||||
|
||||
loss = outputs.loss
|
||||
total_loss += loss.item() * label_ids.size(0)
|
||||
total_tokens += label_ids.size(0)
|
||||
|
||||
avg_loss = total_loss / total_tokens
|
||||
perplexity = math.exp(avg_loss)
|
||||
|
||||
model.train()
|
||||
return perplexity
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate_bpb(
|
||||
model: AutoModelForCausalLM,
|
||||
eval_samples: list,
|
||||
tokenizer: AutoTokenizer,
|
||||
seq_length: int = DEFAULT_SEQ_LENGTH,
|
||||
batch_size: int = 4,
|
||||
device: str = "cpu",
|
||||
) -> float:
|
||||
"""
|
||||
Evaluate bits-per-byte (bpb) on the given samples.
|
||||
Returns bpb (lower is better).
|
||||
"""
|
||||
ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
|
||||
bpb = math.log2(ppl)
|
||||
return bpb
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Training Dataset Loader
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def load_train_dataset(
|
||||
dataset_name: str = DEFAULT_TRAIN_DATASET,
|
||||
split: str = "train",
|
||||
tokenizer: AutoTokenizer = None,
|
||||
seq_length: int = DEFAULT_SEQ_LENGTH,
|
||||
max_samples: int = None,
|
||||
streaming: bool = True,
|
||||
):
|
||||
"""
|
||||
Load and prepare the training dataset.
|
||||
Returns a TokenizedDataset.
|
||||
"""
|
||||
if tokenizer is None:
|
||||
tokenizer = load_tokenizer()
|
||||
|
||||
if dataset_name == "HuggingFaceFW/fineweb-edu":
|
||||
# Use the edu-only subset for quality
|
||||
ds = load_dataset(
|
||||
dataset_name,
|
||||
name="sample-100BT", # 100B token subset for faster iteration
|
||||
split=split,
|
||||
streaming=streaming,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
elif dataset_name.endswith(".jsonl") or dataset_name.endswith(".json"):
|
||||
ds = load_dataset("json", data_files=dataset_name, split=split, streaming=streaming)
|
||||
else:
|
||||
ds = load_dataset(dataset_name, split=split, streaming=streaming, trust_remote_code=True)
|
||||
|
||||
# Collect texts (handle streaming by taking first N samples)
|
||||
# Collect text into a single corpus
|
||||
texts = []
|
||||
for idx, item in enumerate(ds):
|
||||
if max_samples and idx >= max_samples:
|
||||
for i, sample in enumerate(dataset):
|
||||
texts.append(sample["text"].strip())
|
||||
if i + 1 >= num_samples:
|
||||
break
|
||||
text = item.get("text", item.get("content", str(item)))
|
||||
texts.append(text)
|
||||
|
||||
if not texts:
|
||||
raise ValueError(f"No texts loaded from dataset {dataset_name}")
|
||||
corpus = "\n".join(texts)
|
||||
print(f"Collected {len(corpus):,} characters from {len(texts)} samples")
|
||||
|
||||
dataset = TokenizedDataset(texts, tokenizer, seq_length)
|
||||
return dataset
|
||||
# Tokenize
|
||||
print("Tokenizing...")
|
||||
tokenized = tokenizer(corpus, truncation=False)
|
||||
input_ids = tokenized["input_ids"]
|
||||
print(f"Tokenized to {len(input_ids):,} tokens")
|
||||
|
||||
# Save
|
||||
eval_path = DATASETS_DIR / "wikitext_eval.json"
|
||||
with open(eval_path, "w") as f:
|
||||
json.dump(input_ids, f)
|
||||
print(f"Saved eval data to {eval_path}")
|
||||
return eval_path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model Loading
|
||||
# ---------------------------------------------------------------------------
|
||||
def prepare_train_data(num_samples=None):
|
||||
"""Prepare TinyStories training data (streaming, no download needed).
|
||||
|
||||
def load_model(
|
||||
model_name: str = DEFAULT_MODEL,
|
||||
device: str = "cpu",
|
||||
dtype: torch.dtype = torch.float16,
|
||||
attn_implementation: str = "sdpa",
|
||||
) -> AutoModelForCausalLM:
|
||||
Returns the dataset name and config for train.py to load on-the-fly.
|
||||
"""
|
||||
Load a pretrained model from HuggingFace.
|
||||
"""
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=dtype,
|
||||
attn_implementation=attn_implementation,
|
||||
trust_remote_code=True,
|
||||
device_map=device if device != "cpu" else None,
|
||||
)
|
||||
if device == "cpu":
|
||||
model = model.to(device)
|
||||
return model
|
||||
# TinyStories is loaded streaming in train.py, nothing to prepare here
|
||||
# Just verify it's accessible
|
||||
from datasets import load_dataset
|
||||
|
||||
print("Verifying TinyStories dataset access...")
|
||||
ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
|
||||
sample = next(iter(ds))
|
||||
print(f" Sample keys: {list(sample.keys())}")
|
||||
print(f" Sample length: {len(sample['text'])} chars")
|
||||
print("TinyStories is accessible (loaded streaming, no local storage)")
|
||||
return "roneneldan/TinyStories"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI Entry Point
|
||||
# ---------------------------------------------------------------------------
|
||||
def get_vocab_size():
|
||||
"""Return the vocab size for SmolLM-135M."""
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
|
||||
return tokenizer.vocab_size
|
||||
|
||||
parser = argparse.ArgumentParser(description="Prepare and evaluate ternary quantization datasets")
|
||||
parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name or path")
|
||||
parser.add_argument("--eval-dataset", default=DEFAULT_EVAL_DATASET, help="Evaluation dataset")
|
||||
parser.add_argument("--eval-config", default=DEFAULT_EVAL_CONFIG, help="Evaluation dataset config")
|
||||
parser.add_argument("--seq-length", type=int, default=DEFAULT_SEQ_LENGTH, help="Sequence length")
|
||||
parser.add_argument("--batch-size", type=int, default=4, help="Evaluation batch size")
|
||||
parser.add_argument("--device", default="cpu", help="Device to run on")
|
||||
parser.add_argument("--train-samples", type=int, default=1000, help="Max training samples to load")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Just print config without loading")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Prepare data for ternary quantization experiments")
|
||||
parser.add_argument("--num-eval-samples", type=int, default=500, help="Number of wikitext samples for eval")
|
||||
parser.add_argument("--train", action="store_true", help="Verify training dataset access")
|
||||
parser.add_argument("--eval", action="store_true", help="Prepare eval dataset")
|
||||
parser.add_argument("--vocab", action="store_true", help="Print vocab size")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.dry_run:
|
||||
print(f"Model: {args.model}")
|
||||
print(f"Eval dataset: {args.eval_dataset} ({args.eval_config})")
|
||||
print(f"Seq length: {args.seq_length}")
|
||||
print(f"Batch size: {args.batch_size}")
|
||||
print(f"Device: {args.device}")
|
||||
print(f"Train samples: {args.train_samples}")
|
||||
exit(0)
|
||||
if args.vocab:
|
||||
print(f"Vocab size: {get_vocab_size()}")
|
||||
if args.eval:
|
||||
prepare_eval_data(args.num_eval_samples)
|
||||
if args.train:
|
||||
prepare_train_data()
|
||||
if not any([args.vocab, args.eval, args.train]):
|
||||
# Default: prepare everything
|
||||
prepare_eval_data(args.num_eval_samples)
|
||||
prepare_train_data()
|
||||
|
||||
print(f"Loading tokenizer for {args.model}...")
|
||||
tokenizer = load_tokenizer(args.model)
|
||||
|
||||
print(f"Loading eval dataset {args.eval_dataset}...")
|
||||
eval_samples, _ = load_eval_dataset(
|
||||
dataset_name=args.eval_dataset,
|
||||
config=args.eval_config,
|
||||
tokenizer=tokenizer,
|
||||
seq_length=args.seq_length,
|
||||
)
|
||||
print(f" Eval samples: {len(eval_samples)}")
|
||||
|
||||
print(f"Loading training dataset ({args.train_samples} samples)...")
|
||||
train_dataset = load_train_dataset(
|
||||
tokenizer=tokenizer,
|
||||
seq_length=args.seq_length,
|
||||
max_samples=args.train_samples,
|
||||
streaming=True,
|
||||
)
|
||||
print(f" Train samples: {len(train_dataset)}")
|
||||
|
||||
# Optional: load model and run eval
|
||||
if args.device != "skip-model":
|
||||
print(f"Loading model {args.model}...")
|
||||
model = load_model(args.model, device=args.device)
|
||||
print(f" Model params: {sum(p.numel() for p in model.parameters()):,}")
|
||||
|
||||
print(f"Evaluating perplexity...")
|
||||
ppl = evaluate_ppl(model, eval_samples, tokenizer, args.seq_length, args.batch_size, args.device)
|
||||
print(f" Perplexity: {ppl:.2f}")
|
||||
bpb = math.log2(ppl)
|
||||
print(f" BPB: {bpb:.4f}")
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
+55
-55
@@ -1,78 +1,78 @@
|
||||
# program.md — Instructions for the Autoresearch Agent
|
||||
# Ternary Quantization — Agent Instructions (program.md)
|
||||
|
||||
You are an autonomous research agent exploring ternary (1.58-bit) quantization for LLMs.
|
||||
|
||||
## Your Goal
|
||||
## Context
|
||||
|
||||
Iteratively improve the ternary quantization training in `train.py` to achieve the **lowest validation perplexity (val_ppl)** on WikiText-2.
|
||||
We are implementing a QAT (quantization-aware training) pipeline that replaces standard `nn.Linear` layers with `BitLinear` layers that quantize weights to {-1, 0, +1} with per-group scales. The goal is to find hyperparameter configurations that minimize eval perplexity (PPL) on WikiText-2 after fine-tuning.
|
||||
|
||||
## File Boundaries
|
||||
## File Boundary
|
||||
|
||||
- **MUTABLE**: `train.py` — You may modify this file. It contains:
|
||||
- `BitLinear` layer (quantization logic)
|
||||
- Quantization schedules (lambda warmup)
|
||||
- Training loop
|
||||
- Hyperparameters (LR, batch size, group size, etc.)
|
||||
- **MUTABLE**: `train.py` — you can modify this file to change quantization logic, loss functions, warmup schedules, deadzone recovery, etc.
|
||||
- **READ-ONLY**: `prepare.py` — data loading and tokenizer. Do not modify.
|
||||
- **READ-ONLY**: `program.md` — these instructions. Do not modify.
|
||||
|
||||
- **READ-ONLY**: `prepare.py` — DO NOT modify this file. It contains:
|
||||
- Dataset loading and tokenization
|
||||
- Evaluation harness (WikiText PPL)
|
||||
- Model loading utilities
|
||||
## Current State
|
||||
|
||||
- **OUTPUT**: `results.tsv` — Results are automatically logged here after each run.
|
||||
Check `results.tsv` to see previous experiment results. Each row has:
|
||||
`step, lambda, train_loss, train_ppl, eval_ppl, eval_bpb, lr, time_s, best_ppl, q_neg1, q_zero, q_pos1`
|
||||
|
||||
## Experiment Protocol
|
||||
## What You Can Experiment With
|
||||
|
||||
1. Read `results.tsv` to understand what has been tried
|
||||
2. Read `train.py` to understand current implementation
|
||||
3. Propose ONE focused change to `train.py`
|
||||
4. The change will be committed and a training run will execute (~5 minutes)
|
||||
5. Results are logged to `results.tsv`
|
||||
6. If improved (lower val_ppl) → change is kept
|
||||
7. If equal or worse → git reset to previous commit
|
||||
All experiments should be in `train.py`. Focus on:
|
||||
|
||||
## What to Explore
|
||||
1. **Quantization warmup schedule**: linear, cosine, exponential, step-wise
|
||||
- `lambda_ = min(step / quant_warmup_steps, 1.0)` → try cosine: `0.5 * (1 - cos(pi * step / warmup))`
|
||||
- Try different warmup lengths: 500, 1000, 2000, 5000 steps
|
||||
|
||||
### Priority 1: Quantization Schedule
|
||||
- Lambda warmup shape: linear, cosine, exponential
|
||||
- Warmup step counts: 200, 500, 1000, 2000, 5000
|
||||
- Two-phase warmup (fast initial + slow final)
|
||||
2. **Learning rate**: 1e-5, 2e-5, 5e-5, 1e-4, 2e-4
|
||||
- LR warmup steps: 50, 100, 200, 500
|
||||
|
||||
### Priority 2: Learning Rate
|
||||
- LR values: 1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4
|
||||
- LR schedule: constant, linear decay, cosine decay
|
||||
3. **Group size**: 64, 128, 256 (Bonsai uses 128)
|
||||
|
||||
### Priority 3: Quantization Details
|
||||
- Group size: 32, 64, 128, 256, per-tensor
|
||||
- Scale initialization: mean-based vs absmax-based
|
||||
- Ternary threshold adjustments
|
||||
4. **Activation quantization**: 8-bit vs 16-bit (no quant)
|
||||
- Try different activation quantization strategies (per-token, per-channel)
|
||||
|
||||
### Priority 4: Deadzone Recovery
|
||||
- Tequila-style reactivation (learnable lambda for deadzone weights)
|
||||
- Bias injection for zero-valued weights
|
||||
- Gradient scaling for deadzone weights
|
||||
5. **Weight quantization function**:
|
||||
- Current: `scale = abs_mean(w)` → try `scale = abs_max(w)`
|
||||
- Try different deadzone thresholds (e.g., |w_norm| < 0.5 → 0)
|
||||
|
||||
### Priority 5: Distillation
|
||||
- OFF loss (cosine similarity between FP and ternary features)
|
||||
- Logits distillation weight
|
||||
- Feature distillation weight
|
||||
6. **Deadzone recovery (Tequila-style)**:
|
||||
- Track fraction of weights at 0; if > 40%, try reactivation
|
||||
- Repurpose deadzone weights as dynamic biases
|
||||
|
||||
7. **Gradient clipping**: 0.5, 1.0, 2.0, 5.0
|
||||
|
||||
8. **Batch size and seq length**: trade off memory vs gradient quality
|
||||
- On M4 Pro with 24GB RAM, be conservative
|
||||
|
||||
## Constraints
|
||||
|
||||
- Keep experiments focused — ONE change per iteration
|
||||
- Always maintain working code — syntax errors waste time
|
||||
- Use SmolLM-135M or Llama-3.2-1B for fast iteration
|
||||
- Target metric: val_ppl (lower is better)
|
||||
- Time budget: 5 minutes per experiment
|
||||
- **Device**: MPS (Apple Silicon). No CUDA.
|
||||
- **Memory**: 24GB RAM. Use float32 (float16 breaks on MPS with cross-entropy).
|
||||
- **Model**: SmolLM-135M (135M params). Don't change the model.
|
||||
- **Dataset**: TinyStories (streaming). Don't change the dataset.
|
||||
- **Eval**: WikiText-2 test split (pre-tokenized in data/wikitext_eval.json).
|
||||
- **Keep it simple**: Changes should be reviewable diffs. Don't rewrite the whole file.
|
||||
|
||||
## Important Notes
|
||||
## Evaluation Metric
|
||||
|
||||
- The STE (Straight-Through Estimator) is critical for gradients to flow through quantization
|
||||
- Warmup quantization prevents catastrophic accuracy loss at the start of training
|
||||
- Deadzone trapping (weights stuck at 0) is a known problem — explore solutions
|
||||
- Per-group quantization scales are essential for handling outlier weights
|
||||
- The quantization formula: `scale = 1.0 / w.abs().mean(); round(clamp(-1, 1))`
|
||||
**Single metric**: `eval_ppl` (perplexity on WikiText-2). Lower is better.
|
||||
Baseline: SmolLM-135M FP32 should be around 30-40 PPL on WikiText-2.
|
||||
|
||||
## NEVER STOP
|
||||
## Experiment Protocol
|
||||
|
||||
Run experiments continuously until manually interrupted. Each experiment should be a small, focused change. Review results.tsv between runs to inform your next decision.
|
||||
1. Read `results.tsv` to understand current state
|
||||
2. Propose ONE focused change to `train.py`
|
||||
3. Run training: `python train.py --steps 100 --eval-every 25 --activation-bits 16`
|
||||
4. Check if `eval_ppl` improved
|
||||
5. If improved → keep the change
|
||||
6. If worse → revert to previous version
|
||||
7. Repeat
|
||||
|
||||
## Notes
|
||||
|
||||
- Start with `activation-bits 16` (no activation quant) to isolate weight quantization effects
|
||||
- Gradually introduce activation quantization once weight quant works well
|
||||
- The `quant=[-1:X 0:Y +1:Z]` stat shows the ternary distribution — aim for balanced, not all zeros
|
||||
- Lambda warmup is critical — too fast = catastrophic accuracy drop, too slow = no quantization benefit
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
timestamp commit status val_ppl val_bpb initial_ppl best_ppl best_step steps time_s model lr warmup_steps total_steps lambda_schedule group_size batch_size seq_length description
|
||||
|
@@ -0,0 +1 @@
|
||||
step lambda train_loss train_ppl eval_ppl eval_bpb lr time_s best_ppl
|
||||
|
Reference in New Issue
Block a user