Initial ternary quantization framework: BitLinear, QAT training loop, eval harness
Binary file not shown.
-122
@@ -1,122 +0,0 @@
Deleted file: autoresearch.sh

#!/bin/bash
# autoresearch.sh — Autonomous experiment loop
#
# Usage: ./autoresearch.sh [model] [time_budget_seconds]
#
# This script runs the autoresearch loop:
# 1. Agent proposes a change to train.py
# 2. Git commit the change
# 3. Run training for fixed time budget
# 4. Extract val_ppl from results.tsv
# 5. If improved → keep; if worse → git reset
# 6. Repeat

set -e

MODEL="${1:-HuggingFaceTB/SmolLM-135M}"
TIME_BUDGET="${2:-300}"
RESULTS_FILE="results.tsv"

# Ensure git is initialized
if [ ! -d ".git" ]; then
    git init
    git add -A
    git commit -m "Initial commit"
fi

echo "=== Autoresearch Loop ==="
echo "Model: $MODEL"
echo "Time budget: ${TIME_BUDGET}s"
echo "Results: $RESULTS_FILE"
echo ""

# Function to get best PPL from results.tsv
get_best_ppl() {
    if [ ! -f "$RESULTS_FILE" ]; then
        echo "999999"
        return
    fi
    # Get the best_ppl from the last successful run (column 7)
    tail -1 "$RESULTS_FILE" | cut -f7 | tr -d '[:space:]'
}

# Function to get last status
get_last_status() {
    if [ ! -f "$RESULTS_FILE" ]; then
        echo "none"
        return
    fi
    tail -1 "$RESULTS_FILE" | cut -f3 | tr -d '[:space:]'
}

# Initial commit if not committed
git add -A
git commit -m "Initial setup" --allow-empty 2>/dev/null || true

BEST_PPL=$(get_best_ppl)
RUN_NUM=0

while true; do
    RUN_NUM=$((RUN_NUM + 1))
    echo ""
    echo "========================================"
    echo "RUN #$RUN_NUM"
    echo "Current best PPL: $BEST_PPL"
    echo "========================================"

    # Save current state
    PREV_COMMIT=$(git rev-parse HEAD)

    # Prompt the agent to make a change
    # In production, this would call the LLM agent
    # For now, we just run with current config
    echo "Running training..."

    # Run training
    START_TIME=$(date +%s)

    python3 train.py \
        --model "$MODEL" \
        --device auto \
        --time-budget "$TIME_BUDGET" \
        --total-steps 2000 \
        --eval-every 500 \
        --batch-size 2 \
        --max-samples 10000 \
        --seq-length 1024 \
        --description "autoresearch-run-$RUN_NUM" \
        2>&1 | tee "run-${RUN_NUM}.log" || true

    END_TIME=$(date +%s)
    ELAPSED=$((END_TIME - START_TIME))

    # Check results
    STATUS=$(get_last_status)
    NEW_PPL=$(get_best_ppl)

    echo ""
    echo "Run #$RUN_NUM completed in ${ELAPSED}s"
    echo "Status: $STATUS"
    echo "Best PPL: $NEW_PPL"

    if [ "$STATUS" = "success" ]; then
        # Compare with previous best
        if echo "$NEW_PPL $BEST_PPL" | awk '{exit !($1 < $2)}'; then
            echo "IMPROVED! Keeping changes."
            BEST_PPL=$NEW_PPL
            git add results.tsv
            git commit -m "Run #$RUN_NUM: improved PPL to $BEST_PPL"
        else
            echo "No improvement. Reverting."
            git reset --hard $PREV_COMMIT 2>/dev/null || true
            git checkout -- results.tsv 2>/dev/null || true
        fi
    else
        echo "FAILED. Reverting."
        git reset --hard $PREV_COMMIT 2>/dev/null || true
        git checkout -- results.tsv 2>/dev/null || true
    fi

    echo ""
    echo "Continuing... (Ctrl+C to stop)"
done
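The comments above note that in production the agent, not a fixed config, proposes each change. For reference, the keep-or-revert decision that the script makes can be mirrored in Python. This is a minimal sketch assuming the same results.tsv layout the script reads (status in column 3, best_ppl in column 7, 1-indexed); the function names are illustrative and not part of the repository.

```python
# Python mirror of the keep/revert decision in the (now deleted) autoresearch.sh.
# Assumes the results.tsv layout the script reads: status in column 3,
# best_ppl in column 7 (1-indexed). Function names are illustrative.
import csv
import subprocess


def last_run(results_file="results.tsv"):
    """Return (status, best_ppl) from the last row of results.tsv."""
    with open(results_file) as f:
        rows = list(csv.reader(f, delimiter="\t"))
    last = rows[-1]
    return last[2], float(last[6])


def keep_or_revert(prev_commit, prev_best_ppl, results_file="results.tsv"):
    """Commit if the latest run improved best PPL, otherwise hard-reset."""
    status, new_ppl = last_run(results_file)
    if status == "success" and new_ppl < prev_best_ppl:
        subprocess.run(["git", "commit", "-am", f"improved PPL to {new_ppl}"], check=True)
        return new_ppl
    subprocess.run(["git", "reset", "--hard", prev_commit], check=True)
    return prev_best_ppl
```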
File diff suppressed because one or more lines are too long
+88
-308
@@ -1,324 +1,104 @@
Changed file: prepare.py

Removed (previous version):

#!/usr/bin/env python3
"""
prepare.py — Data loading, tokenizer, and evaluation (READ-ONLY in autoresearch loop).

This file is NOT modified by the autoresearch agent. All mutable training logic
lives in train.py. This file provides:
- Dataset loading and tokenization
- Evaluation harness (WikiText PPL, optional zero-shot tasks)
- Utility functions for the training loop
"""

import os
import math
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Default model — small enough for fast iteration on M4 Pro / consumer GPU
DEFAULT_MODEL = "meta-llama/Llama-3.2-1B"
# Alternative tiny models for ultra-fast iteration:
# DEFAULT_MODEL = "HuggingFaceTB/SmolLM-135M"

# Default eval dataset
DEFAULT_EVAL_DATASET = "wikitext"  # "wikitext" or "ptb"
DEFAULT_EVAL_CONFIG = "wikitext-2-raw-v1"

# Default training dataset
DEFAULT_TRAIN_DATASET = "HuggingFaceFW/fineweb-edu"  # FineWeb-edu
# Alternative: "monology/patrickstar-core-1000" for smaller local testing

# Sequence length for training
DEFAULT_SEQ_LENGTH = 2048

# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------

def load_tokenizer(model_name: str = DEFAULT_MODEL) -> AutoTokenizer:
    """Load tokenizer for the given model."""
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

# ---------------------------------------------------------------------------
# Dataset Wrappers
# ---------------------------------------------------------------------------

class TokenizedDataset(Dataset):
    """Simple dataset that yields batches of tokenized text."""

    def __init__(self, data: list, tokenizer: AutoTokenizer, seq_length: int = DEFAULT_SEQ_LENGTH):
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        # Concatenate all texts and chunk into sequences
        self.samples = self._prepare_samples(data)

    def _prepare_samples(self, data: list) -> list:
        """Flatten and chunk text data into fixed-length sequences."""
        all_tokens = []
        for item in data:
            text = item.get("text", item.get("content", str(item)))
            tokens = self.tokenizer.encode(text, add_special_tokens=False)
            all_tokens.extend(tokens)
            all_tokens.append(self.tokenizer.eos_token_id)

        # Chunk into sequences of seq_length
        samples = []
        for i in range(0, len(all_tokens) - self.seq_length + 1, self.seq_length):
            samples.append(all_tokens[i : i + self.seq_length])
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x = torch.tensor(self.samples[idx], dtype=torch.long)
        return {"input_ids": x[:-1], "labels": x[1:]}


def load_eval_dataset(
    dataset_name: str = DEFAULT_EVAL_DATASET,
    config: str = DEFAULT_EVAL_CONFIG,
    split: str = "validation",
    tokenizer: AutoTokenizer = None,
    seq_length: int = DEFAULT_SEQ_LENGTH,
):
    """Load and tokenize the evaluation dataset."""
    if dataset_name == "wikitext":
        ds = load_dataset(dataset_name, config, trust_remote_code=True)[split]
    elif dataset_name == "ptb":
        ds = load_dataset(dataset_name, trust_remote_code=True)[split]
    else:
        ds = load_dataset(dataset_name, split=split, trust_remote_code=True)

    if tokenizer is None:
        tokenizer = load_tokenizer()

    # For eval, we want contiguous text without chunking artifacts
    texts = ds["text"] if "text" in ds.column_names else [str(x) for x in ds]
    all_tokens = []
    for text in texts:
        tokens = tokenizer.encode(text, add_special_tokens=False)
        all_tokens.extend(tokens)

    # Remove empty tokens
    all_tokens = [t for t in all_tokens if t is not None]

    # Create sequences
    samples = []
    for i in range(0, len(all_tokens) - seq_length + 1, seq_length):
        samples.append(all_tokens[i : i + seq_length])

    return samples, tokenizer

# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------

@torch.no_grad()
def evaluate_ppl(
    model: AutoModelForCausalLM,
    eval_samples: list,
    tokenizer: AutoTokenizer,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    batch_size: int = 4,
    device: str = "cpu",
) -> float:
    """
    Evaluate perplexity on the given samples.
    Returns perplexity (lower is better).
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    for i in range(0, len(eval_samples), batch_size):
        batch_samples = eval_samples[i : i + batch_size]
        inputs = torch.tensor(batch_samples, dtype=torch.long, device=device)
        labels = inputs.clone()

        # Shift for causal LM
        input_ids = inputs[:, :-1]
        label_ids = labels[:, 1:]

        device_type = "mps" if device == "mps" else ("cuda" if device == "cuda" else "cpu")
        if device_type != "cpu":
            with torch.amp.autocast(device_type=device_type):
                outputs = model(input_ids=input_ids, labels=label_ids, use_cache=False)
        else:
            outputs = model(input_ids=input_ids, labels=label_ids, use_cache=False)

        loss = outputs.loss
        total_loss += loss.item() * label_ids.size(0)
        total_tokens += label_ids.size(0)

    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)

    model.train()
    return perplexity


@torch.no_grad()
def evaluate_bpb(
    model: AutoModelForCausalLM,
    eval_samples: list,
    tokenizer: AutoTokenizer,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    batch_size: int = 4,
    device: str = "cpu",
) -> float:
    """
    Evaluate bits-per-byte (bpb) on the given samples.
    Returns bpb (lower is better).
    """
    ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
    bpb = math.log2(ppl)
    return bpb

# ---------------------------------------------------------------------------
# Training Dataset Loader
# ---------------------------------------------------------------------------

def load_train_dataset(
    dataset_name: str = DEFAULT_TRAIN_DATASET,
    split: str = "train",
    tokenizer: AutoTokenizer = None,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    max_samples: int = None,
    streaming: bool = True,
):
    """
    Load and prepare the training dataset.
    Returns a TokenizedDataset.
    """
    if tokenizer is None:
        tokenizer = load_tokenizer()

    if dataset_name == "HuggingFaceFW/fineweb-edu":
        # Use the edu-only subset for quality
        ds = load_dataset(
            dataset_name,
            name="sample-100BT",  # 100B token subset for faster iteration
            split=split,
            streaming=streaming,
            trust_remote_code=True,
        )
    elif dataset_name.endswith(".jsonl") or dataset_name.endswith(".json"):
        ds = load_dataset("json", data_files=dataset_name, split=split, streaming=streaming)
    else:
        ds = load_dataset(dataset_name, split=split, streaming=streaming, trust_remote_code=True)

    # Collect texts (handle streaming by taking first N samples)
    texts = []
    for idx, item in enumerate(ds):
        if max_samples and idx >= max_samples:
            break
        text = item.get("text", item.get("content", str(item)))
        texts.append(text)

    if not texts:
        raise ValueError(f"No texts loaded from dataset {dataset_name}")

    dataset = TokenizedDataset(texts, tokenizer, seq_length)
    return dataset

# ---------------------------------------------------------------------------
# Model Loading
# ---------------------------------------------------------------------------

def load_model(
    model_name: str = DEFAULT_MODEL,
    device: str = "cpu",
    dtype: torch.dtype = torch.float16,
    attn_implementation: str = "sdpa",
) -> AutoModelForCausalLM:
    """
    Load a pretrained model from HuggingFace.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        attn_implementation=attn_implementation,
        trust_remote_code=True,
        device_map=device if device != "cpu" else None,
    )
    if device == "cpu":
        model = model.to(device)
    return model

# ---------------------------------------------------------------------------
# CLI Entry Point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Prepare and evaluate ternary quantization datasets")
    parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name or path")
    parser.add_argument("--eval-dataset", default=DEFAULT_EVAL_DATASET, help="Evaluation dataset")
    parser.add_argument("--eval-config", default=DEFAULT_EVAL_CONFIG, help="Evaluation dataset config")
    parser.add_argument("--seq-length", type=int, default=DEFAULT_SEQ_LENGTH, help="Sequence length")
    parser.add_argument("--batch-size", type=int, default=4, help="Evaluation batch size")
    parser.add_argument("--device", default="cpu", help="Device to run on")
    parser.add_argument("--train-samples", type=int, default=1000, help="Max training samples to load")
    parser.add_argument("--dry-run", action="store_true", help="Just print config without loading")

    args = parser.parse_args()

    if args.dry_run:
        print(f"Model: {args.model}")
        print(f"Eval dataset: {args.eval_dataset} ({args.eval_config})")
        print(f"Seq length: {args.seq_length}")
        print(f"Batch size: {args.batch_size}")
        print(f"Device: {args.device}")
        print(f"Train samples: {args.train_samples}")
        exit(0)

    print(f"Loading tokenizer for {args.model}...")
    tokenizer = load_tokenizer(args.model)

    print(f"Loading eval dataset {args.eval_dataset}...")
    eval_samples, _ = load_eval_dataset(
        dataset_name=args.eval_dataset,
        config=args.eval_config,
        tokenizer=tokenizer,
        seq_length=args.seq_length,
    )
    print(f"  Eval samples: {len(eval_samples)}")

    print(f"Loading training dataset ({args.train_samples} samples)...")
    train_dataset = load_train_dataset(
        tokenizer=tokenizer,
        seq_length=args.seq_length,
        max_samples=args.train_samples,
        streaming=True,
    )
    print(f"  Train samples: {len(train_dataset)}")

    # Optional: load model and run eval
    if args.device != "skip-model":
        print(f"Loading model {args.model}...")
        model = load_model(args.model, device=args.device)
        print(f"  Model params: {sum(p.numel() for p in model.parameters()):,}")

        print(f"Evaluating perplexity...")
        ppl = evaluate_ppl(model, eval_samples, tokenizer, args.seq_length, args.batch_size, args.device)
        print(f"  Perplexity: {ppl:.2f}")
        bpb = math.log2(ppl)
        print(f"  BPB: {bpb:.4f}")

Added (new version):

"""
Data preparation and evaluation for ternary quantization experiments.
READ-ONLY in the autoresearch loop — train.py is the mutable file.

Usage:
    python prepare.py                           # download wikitext val shard
    python prepare.py --num-eval-samples 500    # smaller eval set for fast iteration
"""

import argparse
import json
import os
from pathlib import Path

DATASETS_DIR = Path(__file__).parent / "data"


def prepare_eval_data(num_samples=500):
    """Download and prepare WikiText-2 validation data for perplexity evaluation.

    Saves tokenized data as a JSON file for fast loading during training.
    """
    from datasets import load_dataset
    from transformers import AutoTokenizer

    DATASETS_DIR.mkdir(parents=True, exist_ok=True)
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")

    # Load wikitext test split (validation is unreliable with streaming)
    print("Loading WikiText-2 test split...")
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test", streaming=True)

    # Collect text into a single corpus
    texts = []
    for i, sample in enumerate(dataset):
        texts.append(sample["text"].strip())
        if i + 1 >= num_samples:
            break

    corpus = "\n".join(texts)
    print(f"Collected {len(corpus):,} characters from {len(texts)} samples")

    # Tokenize
    print("Tokenizing...")
    tokenized = tokenizer(corpus, truncation=False)
    input_ids = tokenized["input_ids"]
    print(f"Tokenized to {len(input_ids):,} tokens")

    # Save
    eval_path = DATASETS_DIR / "wikitext_eval.json"
    with open(eval_path, "w") as f:
        json.dump(input_ids, f)
    print(f"Saved eval data to {eval_path}")
    return eval_path


def prepare_train_data(num_samples=None):
    """Prepare TinyStories training data (streaming, no download needed).

    Returns the dataset name and config for train.py to load on-the-fly.
    """
    # TinyStories is loaded streaming in train.py, nothing to prepare here
    # Just verify it's accessible
    from datasets import load_dataset

    print("Verifying TinyStories dataset access...")
    ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    sample = next(iter(ds))
    print(f"  Sample keys: {list(sample.keys())}")
    print(f"  Sample length: {len(sample['text'])} chars")
    print("TinyStories is accessible (loaded streaming, no local storage)")
    return "roneneldan/TinyStories"


def get_vocab_size():
    """Return the vocab size for SmolLM-135M."""
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
    return tokenizer.vocab_size


def main():
    parser = argparse.ArgumentParser(description="Prepare data for ternary quantization experiments")
    parser.add_argument("--num-eval-samples", type=int, default=500, help="Number of wikitext samples for eval")
    parser.add_argument("--train", action="store_true", help="Verify training dataset access")
    parser.add_argument("--eval", action="store_true", help="Prepare eval dataset")
    parser.add_argument("--vocab", action="store_true", help="Print vocab size")
    args = parser.parse_args()

    if args.vocab:
        print(f"Vocab size: {get_vocab_size()}")
    if args.eval:
        prepare_eval_data(args.num_eval_samples)
    if args.train:
        prepare_train_data()
    if not any([args.vocab, args.eval, args.train]):
        # Default: prepare everything
        prepare_eval_data(args.num_eval_samples)
        prepare_train_data()


if __name__ == "__main__":
    main()
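The new prepare_eval_data saves only a flat list of token IDs to data/wikitext_eval.json. As a minimal sketch, this is how a training script might consume that file for perplexity evaluation; the windowing mirrors the removed evaluate_ppl, but the function name and its exact integration into train.py are assumptions, since train.py is not part of this diff.

```python
# Sketch of how a training script might consume data/wikitext_eval.json
# (the flat token-ID list saved by prepare_eval_data). The windowed-perplexity
# logic mirrors the removed evaluate_ppl; the function name and its place in
# train.py are assumptions.
import json
import math

import torch


def eval_ppl_from_json(model, path="data/wikitext_eval.json", seq_length=512, device="cpu"):
    with open(path) as f:
        ids = json.load(f)

    model.eval()
    losses = []
    with torch.no_grad():
        for i in range(0, len(ids) - seq_length, seq_length):
            chunk = torch.tensor(ids[i : i + seq_length], dtype=torch.long, device=device)
            chunk = chunk.unsqueeze(0)  # batch of 1
            # HF causal-LM models shift labels internally when labels=input_ids
            out = model(input_ids=chunk, labels=chunk, use_cache=False)
            losses.append(out.loss.item())
    model.train()
    return math.exp(sum(losses) / len(losses))
```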
+55
-55
@@ -1,78 +1,78 @@
Changed file: program.md

Removed (previous version):

# program.md — Instructions for the Autoresearch Agent

You are an autonomous research agent exploring ternary (1.58-bit) quantization for LLMs.

## Your Goal

Iteratively improve the ternary quantization training in `train.py` to achieve the **lowest validation perplexity (val_ppl)** on WikiText-2.

## File Boundaries

- **MUTABLE**: `train.py` — You may modify this file. It contains:
  - `BitLinear` layer (quantization logic)
  - Quantization schedules (lambda warmup)
  - Training loop
  - Hyperparameters (LR, batch size, group size, etc.)

- **READ-ONLY**: `prepare.py` — DO NOT modify this file. It contains:
  - Dataset loading and tokenization
  - Evaluation harness (WikiText PPL)
  - Model loading utilities

- **OUTPUT**: `results.tsv` — Results are automatically logged here after each run.

## Experiment Protocol

1. Read `results.tsv` to understand what has been tried
2. Read `train.py` to understand current implementation
3. Propose ONE focused change to `train.py`
4. The change will be committed and a training run will execute (~5 minutes)
5. Results are logged to `results.tsv`
6. If improved (lower val_ppl) → change is kept
7. If equal or worse → git reset to previous commit

## What to Explore

### Priority 1: Quantization Schedule
- Lambda warmup shape: linear, cosine, exponential
- Warmup step counts: 200, 500, 1000, 2000, 5000
- Two-phase warmup (fast initial + slow final)

### Priority 2: Learning Rate
- LR values: 1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4
- LR schedule: constant, linear decay, cosine decay

### Priority 3: Quantization Details
- Group size: 32, 64, 128, 256, per-tensor
- Scale initialization: mean-based vs absmax-based
- Ternary threshold adjustments

### Priority 4: Deadzone Recovery
- Tequila-style reactivation (learnable lambda for deadzone weights)
- Bias injection for zero-valued weights
- Gradient scaling for deadzone weights

### Priority 5: Distillation
- OFF loss (cosine similarity between FP and ternary features)
- Logits distillation weight
- Feature distillation weight

## Constraints

- Keep experiments focused — ONE change per iteration
- Always maintain working code — syntax errors waste time
- Use SmolLM-135M or Llama-3.2-1B for fast iteration
- Target metric: val_ppl (lower is better)
- Time budget: 5 minutes per experiment

## Important Notes

- The STE (Straight-Through Estimator) is critical for gradients to flow through quantization
- Warmup quantization prevents catastrophic accuracy loss at the start of training
- Deadzone trapping (weights stuck at 0) is a known problem — explore solutions
- Per-group quantization scales are essential for handling outlier weights
- The quantization formula: `scale = 1.0 / w.abs().mean(); round(clamp(-1, 1))`

## NEVER STOP

Run experiments continuously until manually interrupted. Each experiment should be a small, focused change. Review results.tsv between runs to inform your next decision.

Added (new version):

# Ternary Quantization — Agent Instructions (program.md)

You are an autonomous research agent exploring ternary (1.58-bit) quantization for LLMs.

## Context

We are implementing a QAT (quantization-aware training) pipeline that replaces standard `nn.Linear` layers with `BitLinear` layers that quantize weights to {-1, 0, +1} with per-group scales. The goal is to find hyperparameter configurations that minimize eval perplexity (PPL) on WikiText-2 after fine-tuning.

## File Boundary

- **MUTABLE**: `train.py` — you can modify this file to change quantization logic, loss functions, warmup schedules, deadzone recovery, etc.
- **READ-ONLY**: `prepare.py` — data loading and tokenizer. Do not modify.
- **READ-ONLY**: `program.md` — these instructions. Do not modify.

## Current State

Check `results.tsv` to see previous experiment results. Each row has:
`step, lambda, train_loss, train_ppl, eval_ppl, eval_bpb, lr, time_s, best_ppl, q_neg1, q_zero, q_pos1`

## What You Can Experiment With

All experiments should be in `train.py`. Focus on:

1. **Quantization warmup schedule**: linear, cosine, exponential, step-wise
   - `lambda_ = min(step / quant_warmup_steps, 1.0)` → try cosine: `0.5 * (1 - cos(pi * step / warmup))`
   - Try different warmup lengths: 500, 1000, 2000, 5000 steps

2. **Learning rate**: 1e-5, 2e-5, 5e-5, 1e-4, 2e-4
   - LR warmup steps: 50, 100, 200, 500

3. **Group size**: 64, 128, 256 (Bonsai uses 128)

4. **Activation quantization**: 8-bit vs 16-bit (no quant)
   - Try different activation quantization strategies (per-token, per-channel)

5. **Weight quantization function**:
   - Current: `scale = abs_mean(w)` → try `scale = abs_max(w)`
   - Try different deadzone thresholds (e.g., |w_norm| < 0.5 → 0)

6. **Deadzone recovery (Tequila-style)**:
   - Track fraction of weights at 0; if > 40%, try reactivation
   - Repurpose deadzone weights as dynamic biases

7. **Gradient clipping**: 0.5, 1.0, 2.0, 5.0

8. **Batch size and seq length**: trade off memory vs gradient quality
   - On M4 Pro with 24GB RAM, be conservative

## Constraints

- **Device**: MPS (Apple Silicon). No CUDA.
- **Memory**: 24GB RAM. Use float32 (float16 breaks on MPS with cross-entropy).
- **Model**: SmolLM-135M (135M params). Don't change the model.
- **Dataset**: TinyStories (streaming). Don't change the dataset.
- **Eval**: WikiText-2 test split (pre-tokenized in data/wikitext_eval.json).
- **Keep it simple**: Changes should be reviewable diffs. Don't rewrite the whole file.

## Evaluation Metric

**Single metric**: `eval_ppl` (perplexity on WikiText-2). Lower is better.
Baseline: SmolLM-135M FP32 should be around 30-40 PPL on WikiText-2.

## Experiment Protocol

1. Read `results.tsv` to understand current state
2. Propose ONE focused change to `train.py`
3. Run training: `python train.py --steps 100 --eval-every 25 --activation-bits 16`
4. Check if `eval_ppl` improved
5. If improved → keep the change
6. If worse → revert to previous version
7. Repeat

## Notes

- Start with `activation-bits 16` (no activation quant) to isolate weight quantization effects
- Gradually introduce activation quantization once weight quant works well
- The `quant=[-1:X 0:Y +1:Z]` stat shows the ternary distribution — aim for balanced, not all zeros
- Lambda warmup is critical — too fast = catastrophic accuracy drop, too slow = no quantization benefit
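train.py itself is not included in this diff, so as a reference for the mechanisms program.md describes (per-group ternary weights with scale = 1 / mean|w|, a straight-through estimator, and lambda warmup), here is a minimal BitLinear sketch under those assumptions. It is an illustration, not the project's implementation.

```python
# Illustrative sketch of the BitLinear mechanism described in program.md,
# NOT the actual train.py implementation (train.py is not part of this diff).
# Assumes in_features * out_features is divisible by group_size.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def lambda_schedule(step: int, warmup_steps: int = 1000, kind: str = "linear") -> float:
    """Quantization-strength warmup: the linear and cosine forms quoted in program.md."""
    t = min(step / warmup_steps, 1.0)
    if kind == "cosine":
        return 0.5 * (1 - math.cos(math.pi * t))
    return t


class BitLinear(nn.Linear):
    """Linear layer with per-group ternary ({-1, 0, +1}) weight quantization."""

    def __init__(self, in_features, out_features, bias=True, group_size=128):
        super().__init__(in_features, out_features, bias)
        self.group_size = group_size
        self.lambda_ = 0.0  # warmed up from 0 (pure FP) to 1 (fully quantized)

    def ternary_weight(self) -> torch.Tensor:
        w = self.weight
        g = w.view(-1, self.group_size)
        # Per-group scale = 1 / mean(|w|); round and clamp to {-1, 0, +1}, then rescale
        scale = 1.0 / g.abs().mean(dim=1, keepdim=True).clamp(min=1e-5)
        w_q = (g * scale).round().clamp(-1, 1) / scale
        w_q = w_q.view_as(w)
        # Straight-through estimator: quantized forward pass, full-precision gradient
        return w + (w_q - w).detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Blend FP and quantized weights so quantization can be introduced gradually
        w_eff = (1.0 - self.lambda_) * self.weight + self.lambda_ * self.ternary_weight()
        return F.linear(x, w_eff, self.bias)
```

Swapping such a layer in for `nn.Linear` and updating `self.lambda_` with `lambda_schedule(step)` each optimizer step is one way to realize the warmup the instructions describe.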
Changed file: results.tsv

@@ -1 +0,0 @@
Removed header:
timestamp commit status val_ppl val_bpb initial_ppl best_ppl best_step steps time_s model lr warmup_steps total_steps lambda_schedule group_size batch_size seq_length description

@@ -0,0 +1 @@
Added header:
step lambda train_loss train_ppl eval_ppl eval_bpb lr time_s best_ppl