#!/usr/bin/env python3
"""
prepare.py — Data loading, tokenizer, and evaluation (READ-ONLY in autoresearch loop).
This file is NOT modified by the autoresearch agent. All mutable training logic
lives in train.py. This file provides:
- Dataset loading and tokenization
- Evaluation harness (WikiText PPL, optional zero-shot tasks)
- Utility functions for the training loop
"""
import os
import math
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Default model — small enough for fast iteration on M4 Pro / consumer GPU
DEFAULT_MODEL = "meta-llama/Llama-3.2-1B"
# Alternative tiny models for ultra-fast iteration:
# DEFAULT_MODEL = "HuggingFaceTB/SmolLM-135M"

# Default eval dataset
DEFAULT_EVAL_DATASET = "wikitext"  # "wikitext" or "ptb"
DEFAULT_EVAL_CONFIG = "wikitext-2-raw-v1"

# Default training dataset
DEFAULT_TRAIN_DATASET = "HuggingFaceFW/fineweb-edu"  # FineWeb-Edu
# Alternative: "monology/patrickstar-core-1000" for smaller local testing

# Sequence length for training
DEFAULT_SEQ_LENGTH = 2048

# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------
def load_tokenizer(model_name: str = DEFAULT_MODEL) -> AutoTokenizer:
    """Load tokenizer for the given model."""
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

# ---------------------------------------------------------------------------
# Dataset Wrappers
# ---------------------------------------------------------------------------
class TokenizedDataset(Dataset):
    """Simple dataset that yields fixed-length sequences of tokenized text."""

    def __init__(self, data: list, tokenizer: AutoTokenizer, seq_length: int = DEFAULT_SEQ_LENGTH):
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        # Concatenate all texts and chunk into sequences
        self.samples = self._prepare_samples(data)

    def _prepare_samples(self, data: list) -> list:
        """Flatten and chunk text data into fixed-length sequences."""
        all_tokens = []
        for item in data:
            # Accept both raw strings and dataset records with a text/content field
            if isinstance(item, str):
                text = item
            else:
                text = item.get("text", item.get("content", str(item)))
            tokens = self.tokenizer.encode(text, add_special_tokens=False)
            all_tokens.extend(tokens)
            all_tokens.append(self.tokenizer.eos_token_id)
        # Chunk into non-overlapping sequences of seq_length
        samples = []
        for i in range(0, len(all_tokens) - self.seq_length + 1, self.seq_length):
            samples.append(all_tokens[i : i + self.seq_length])
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x = torch.tensor(self.samples[idx], dtype=torch.long)
        return {"input_ids": x[:-1], "labels": x[1:]}
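
# A TokenizedDataset plugs straight into a torch DataLoader; the sketch below
# is illustrative (batch size and shuffling are train.py's call):
#
#   from torch.utils.data import DataLoader
#   loader = DataLoader(dataset, batch_size=8, shuffle=True)
#   batch = next(iter(loader))
#   # batch["input_ids"] and batch["labels"] are both (8, seq_length - 1)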

def load_eval_dataset(
    dataset_name: str = DEFAULT_EVAL_DATASET,
    config: str = DEFAULT_EVAL_CONFIG,
    split: str = "validation",
    tokenizer: AutoTokenizer = None,
    seq_length: int = DEFAULT_SEQ_LENGTH,
):
    """Load and tokenize the evaluation dataset."""
    if dataset_name == "wikitext":
        ds = load_dataset(dataset_name, config, trust_remote_code=True)[split]
    elif dataset_name == "ptb":
        # Penn Treebank lives on the Hub as "ptb_text_only"
        ds = load_dataset("ptb_text_only", "penn_treebank", trust_remote_code=True)[split]
    else:
        ds = load_dataset(dataset_name, split=split, trust_remote_code=True)
    if tokenizer is None:
        tokenizer = load_tokenizer()
    # For eval, we want contiguous text without chunking artifacts
    if "text" in ds.column_names:
        texts = ds["text"]
    elif "sentence" in ds.column_names:  # ptb_text_only uses "sentence"
        texts = ds["sentence"]
    else:
        texts = [str(x) for x in ds]
    all_tokens = []
    for text in texts:
        tokens = tokenizer.encode(text, add_special_tokens=False)
        all_tokens.extend(tokens)
    # Create non-overlapping sequences of seq_length tokens
    samples = []
    for i in range(0, len(all_tokens) - seq_length + 1, seq_length):
        samples.append(all_tokens[i : i + seq_length])
    return samples, tokenizer
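
# Note: load_eval_dataset returns plain lists of token ids (one list per
# seq_length-token window), not tensors; evaluate_ppl below handles batching.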

# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------
@torch.no_grad()
def evaluate_ppl(
    model: AutoModelForCausalLM,
    eval_samples: list,
    tokenizer: AutoTokenizer,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    batch_size: int = 4,
    device: str = "cpu",
) -> float:
    """
    Evaluate perplexity on the given samples.
    Returns perplexity (lower is better).
    """
    model.eval()
    total_loss = 0.0
    total_sequences = 0
    for i in range(0, len(eval_samples), batch_size):
        batch_samples = eval_samples[i : i + batch_size]
        input_ids = torch.tensor(batch_samples, dtype=torch.long, device=device)
        # HF causal LMs shift labels internally, so pass the unshifted ids;
        # shifting here as well would misalign the targets by one position.
        labels = input_ids.clone()
        device_type = device.split(":")[0]  # "cuda:0" -> "cuda"
        if device_type != "cpu":
            with torch.amp.autocast(device_type=device_type):
                outputs = model(input_ids=input_ids, labels=labels, use_cache=False)
        else:
            outputs = model(input_ids=input_ids, labels=labels, use_cache=False)
        # outputs.loss is a per-token mean; weight by batch size so a smaller
        # final batch does not skew the average.
        total_loss += outputs.loss.item() * input_ids.size(0)
        total_sequences += input_ids.size(0)
    avg_loss = total_loss / total_sequences
    perplexity = math.exp(avg_loss)
    model.train()
    return perplexity
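
# Worked example with illustrative numbers: a mean loss of 2.0 nats/token gives
# ppl = exp(2.0) ≈ 7.39; improving to 1.0 nat/token gives ppl = exp(1.0) ≈ 2.72.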

@torch.no_grad()
def evaluate_bpb(
    model: AutoModelForCausalLM,
    eval_samples: list,
    tokenizer: AutoTokenizer,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    batch_size: int = 4,
    device: str = "cpu",
) -> float:
    """
    Evaluate bits-per-byte (BPB) on the given samples.
    Returns BPB (lower is better).
    """
    ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
    # log2(ppl) is bits per *token*; rescale by the tokens-to-bytes ratio of the
    # decoded text so the metric is comparable across tokenizers.
    bits_per_token = math.log2(ppl)
    n_tokens = sum(len(s) for s in eval_samples)
    n_bytes = sum(len(tokenizer.decode(s).encode("utf-8")) for s in eval_samples)
    return bits_per_token * n_tokens / n_bytes
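
# Worked example with illustrative numbers: ppl = 7.39 is log2(7.39) ≈ 2.89
# bits/token; at an assumed ~4 bytes per token that is 2.89 / 4 ≈ 0.72 BPB.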

# ---------------------------------------------------------------------------
# Training Dataset Loader
# ---------------------------------------------------------------------------
def load_train_dataset(
    dataset_name: str = DEFAULT_TRAIN_DATASET,
    split: str = "train",
    tokenizer: AutoTokenizer = None,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    max_samples: int = None,
    streaming: bool = True,
):
    """
    Load and prepare the training dataset.
    Returns a TokenizedDataset.
    """
    if tokenizer is None:
        tokenizer = load_tokenizer()
    if dataset_name == "HuggingFaceFW/fineweb-edu":
        # Use the edu-only subset for quality
        ds = load_dataset(
            dataset_name,
            name="sample-100BT",  # 100B-token sample for faster iteration
            split=split,
            streaming=streaming,
            trust_remote_code=True,
        )
    elif dataset_name.endswith(".jsonl") or dataset_name.endswith(".json"):
        ds = load_dataset("json", data_files=dataset_name, split=split, streaming=streaming)
    else:
        ds = load_dataset(dataset_name, split=split, streaming=streaming, trust_remote_code=True)
    # Collect texts (handle streaming by taking the first max_samples records)
    texts = []
    for idx, item in enumerate(ds):
        if max_samples and idx >= max_samples:
            break
        text = item.get("text", item.get("content", str(item)))
        texts.append(text)
    if not texts:
        raise ValueError(f"No texts loaded from dataset {dataset_name}")
    return TokenizedDataset(texts, tokenizer, seq_length)
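
# Illustrative smoke test (hypothetical sizes; any small values work):
#
#   ds = load_train_dataset(max_samples=200, seq_length=512)
#   len(ds)  # number of 512-token chunks built from the first 200 records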

# ---------------------------------------------------------------------------
# Model Loading
# ---------------------------------------------------------------------------
def load_model(
    model_name: str = DEFAULT_MODEL,
    device: str = "cpu",
    dtype: torch.dtype = torch.float16,
    attn_implementation: str = "sdpa",
) -> AutoModelForCausalLM:
    """
    Load a pretrained model from HuggingFace.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        attn_implementation=attn_implementation,
        trust_remote_code=True,
        device_map=device if device != "cpu" else None,
    )
    if device == "cpu":
        model = model.to(device)
    return model
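
# Illustrative device choices (assuming a recent torch/accelerate build): on
# Apple silicon use load_model(device="mps"); on NVIDIA, load_model(device="cuda").
# device_map places the weights, so no extra .to() call is needed off-CPU.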

# ---------------------------------------------------------------------------
# CLI Entry Point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Prepare and evaluate ternary quantization datasets")
    parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name or path")
    parser.add_argument("--eval-dataset", default=DEFAULT_EVAL_DATASET, help="Evaluation dataset")
    parser.add_argument("--eval-config", default=DEFAULT_EVAL_CONFIG, help="Evaluation dataset config")
    parser.add_argument("--seq-length", type=int, default=DEFAULT_SEQ_LENGTH, help="Sequence length")
    parser.add_argument("--batch-size", type=int, default=4, help="Evaluation batch size")
    parser.add_argument("--device", default="cpu", help="Device to run on ('skip-model' skips model loading)")
    parser.add_argument("--train-samples", type=int, default=1000, help="Max training samples to load")
    parser.add_argument("--dry-run", action="store_true", help="Just print config without loading")
    args = parser.parse_args()

    if args.dry_run:
        print(f"Model: {args.model}")
        print(f"Eval dataset: {args.eval_dataset} ({args.eval_config})")
        print(f"Seq length: {args.seq_length}")
        print(f"Batch size: {args.batch_size}")
        print(f"Device: {args.device}")
        print(f"Train samples: {args.train_samples}")
        raise SystemExit(0)

    print(f"Loading tokenizer for {args.model}...")
    tokenizer = load_tokenizer(args.model)

    print(f"Loading eval dataset {args.eval_dataset}...")
    eval_samples, _ = load_eval_dataset(
        dataset_name=args.eval_dataset,
        config=args.eval_config,
        tokenizer=tokenizer,
        seq_length=args.seq_length,
    )
    print(f"  Eval samples: {len(eval_samples)}")

    print(f"Loading training dataset ({args.train_samples} samples)...")
    train_dataset = load_train_dataset(
        tokenizer=tokenizer,
        seq_length=args.seq_length,
        max_samples=args.train_samples,
        streaming=True,
    )
    print(f"  Train samples: {len(train_dataset)}")

    # Optional: load model and run eval (pass --device skip-model to skip)
    if args.device != "skip-model":
        print(f"Loading model {args.model}...")
        model = load_model(args.model, device=args.device)
        print(f"  Model params: {sum(p.numel() for p in model.parameters()):,}")
        print("Evaluating perplexity...")
        ppl = evaluate_ppl(model, eval_samples, tokenizer, args.seq_length, args.batch_size, args.device)
        print(f"  Perplexity: {ppl:.2f}")
        # Reuse the measured ppl for BPB instead of a second eval pass
        # (same conversion as evaluate_bpb)
        n_tokens = sum(len(s) for s in eval_samples)
        n_bytes = sum(len(tokenizer.decode(s).encode("utf-8")) for s in eval_samples)
        bpb = math.log2(ppl) * n_tokens / n_bytes
        print(f"  BPB: {bpb:.4f}")