#!/usr/bin/env python3
"""
prepare.py — Data loading, tokenizer, and evaluation (READ-ONLY in autoresearch loop).

This file is NOT modified by the autoresearch agent. All mutable training logic
lives in train.py. This file provides:
- Dataset loading and tokenization
- Evaluation harness (WikiText PPL, optional zero-shot tasks)
- Utility functions for the training loop
"""

import os
import math
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Default model — small enough for fast iteration on M4 Pro / consumer GPU
DEFAULT_MODEL = "meta-llama/Llama-3.2-1B"
# Alternative tiny models for ultra-fast iteration:
# DEFAULT_MODEL = "HuggingFaceTB/SmolLM-135M"

# Default eval dataset
DEFAULT_EVAL_DATASET = "wikitext"  # "wikitext" or "ptb"
DEFAULT_EVAL_CONFIG = "wikitext-2-raw-v1"

# Default training dataset
DEFAULT_TRAIN_DATASET = "HuggingFaceFW/fineweb-edu"  # FineWeb-Edu
# Alternative: "monology/patrickstar-core-1000" for smaller local testing

# Sequence length for training
DEFAULT_SEQ_LENGTH = 2048


# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------

def load_tokenizer(model_name: str = DEFAULT_MODEL) -> AutoTokenizer:
    """Load tokenizer for the given model."""
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

# ---------------------------------------------------------------------------
# Dataset Wrappers
# ---------------------------------------------------------------------------

class TokenizedDataset(Dataset):
    """Simple dataset that yields fixed-length sequences of tokenized text."""

    def __init__(self, data: list, tokenizer: AutoTokenizer, seq_length: int = DEFAULT_SEQ_LENGTH):
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        # Concatenate all texts and chunk into sequences
        self.samples = self._prepare_samples(data)

    def _prepare_samples(self, data: list) -> list:
        """Flatten and chunk text data into fixed-length sequences."""
        all_tokens = []
        for item in data:
            # Accept either raw strings or dict-like records with a text field
            if isinstance(item, dict):
                text = item.get("text", item.get("content", str(item)))
            else:
                text = str(item)
            tokens = self.tokenizer.encode(text, add_special_tokens=False)
            all_tokens.extend(tokens)
            all_tokens.append(self.tokenizer.eos_token_id)

        # Chunk into sequences of seq_length
        samples = []
        for i in range(0, len(all_tokens) - self.seq_length + 1, self.seq_length):
            samples.append(all_tokens[i : i + self.seq_length])
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Shift by one position: the model predicts token t+1 from tokens <= t
        x = torch.tensor(self.samples[idx], dtype=torch.long)
        return {"input_ids": x[:-1], "labels": x[1:]}
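

# Illustrative sketch (not used elsewhere in this file): TokenizedDataset plugs
# into a standard torch DataLoader, since __getitem__ returns a dict of equal-
# length tensors that the default collate_fn can stack. The batch size and
# shuffling below are placeholder assumptions, not settings from train.py.
def example_dataloader(dataset: Dataset, batch_size: int = 8):
    """Wrap a TokenizedDataset in a DataLoader; drop_last keeps batch shapes fixed."""
    from torch.utils.data import DataLoader
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)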


def load_eval_dataset(
    dataset_name: str = DEFAULT_EVAL_DATASET,
    config: str = DEFAULT_EVAL_CONFIG,
    split: str = "validation",
    tokenizer: AutoTokenizer = None,
    seq_length: int = DEFAULT_SEQ_LENGTH,
):
    """Load and tokenize the evaluation dataset."""
    if dataset_name == "wikitext":
        ds = load_dataset(dataset_name, config, trust_remote_code=True)[split]
    elif dataset_name == "ptb":
        # Penn Treebank lives on the Hub as "ptb_text_only" with a "sentence" column
        ds = load_dataset("ptb_text_only", "penn_treebank", trust_remote_code=True)[split]
    else:
        ds = load_dataset(dataset_name, split=split, trust_remote_code=True)

    if tokenizer is None:
        tokenizer = load_tokenizer()

    # For eval, we want contiguous text without chunking artifacts
    if "text" in ds.column_names:
        texts = ds["text"]
    elif "sentence" in ds.column_names:
        texts = ds["sentence"]
    else:
        texts = [str(x) for x in ds]

    all_tokens = []
    for text in texts:
        tokens = tokenizer.encode(text, add_special_tokens=False)
        all_tokens.extend(tokens)

    # Drop any None placeholders before chunking
    all_tokens = [t for t in all_tokens if t is not None]

    # Create non-overlapping sequences of seq_length tokens
    samples = []
    for i in range(0, len(all_tokens) - seq_length + 1, seq_length):
        samples.append(all_tokens[i : i + seq_length])

    return samples, tokenizer


# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------

@torch.no_grad()
def evaluate_ppl(
    model: AutoModelForCausalLM,
    eval_samples: list,
    tokenizer: AutoTokenizer,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    batch_size: int = 4,
    device: str = "cpu",
) -> float:
    """
    Evaluate perplexity on the given samples.
    Returns perplexity (lower is better).
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    for i in range(0, len(eval_samples), batch_size):
        batch_samples = eval_samples[i : i + batch_size]
        input_ids = torch.tensor(batch_samples, dtype=torch.long, device=device)

        # Pass the unshifted sequence as both inputs and labels; the HF causal-LM
        # head applies the one-token shift internally, so shifting here as well
        # would misalign predictions and targets by one position.
        device_type = device.split(":")[0]  # e.g. "cuda:0" -> "cuda"
        if device_type != "cpu":
            with torch.amp.autocast(device_type=device_type):
                outputs = model(input_ids=input_ids, labels=input_ids, use_cache=False)
        else:
            outputs = model(input_ids=input_ids, labels=input_ids, use_cache=False)

        # outputs.loss is the mean over predicted positions in this batch; weight
        # by that count so the final average is per-token across all batches.
        n_predicted = input_ids.size(0) * (input_ids.size(1) - 1)
        total_loss += outputs.loss.item() * n_predicted
        total_tokens += n_predicted

    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)

    model.train()
    return perplexity
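

# Reference identities (standard definitions, not repo-specific): with mean
# cross-entropy L in nats per token, perplexity = exp(L) and bits/token =
# L / ln(2) = log2(perplexity). A minimal helper converting between them:
def loss_to_metrics(mean_loss_nats: float) -> dict:
    """Convert mean token-level cross-entropy (nats) into perplexity and bits/token."""
    return {
        "ppl": math.exp(mean_loss_nats),
        "bits_per_token": mean_loss_nats / math.log(2),
    }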


@torch.no_grad()
def evaluate_bpb(
    model: AutoModelForCausalLM,
    eval_samples: list,
    tokenizer: AutoTokenizer,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    batch_size: int = 4,
    device: str = "cpu",
) -> float:
    """
    Evaluate bits-per-byte (bpb) on the given samples.
    Returns bpb (lower is better).
    """
    ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
    # log2(ppl) is bits per *token*; rescale by the tokens-to-bytes ratio of the
    # eval text so the metric is comparable across tokenizers.
    bits_per_token = math.log2(ppl)
    n_tokens = sum(len(sample) for sample in eval_samples)
    n_bytes = sum(len(tokenizer.decode(sample).encode("utf-8")) for sample in eval_samples)
    bpb = bits_per_token * n_tokens / max(n_bytes, 1)
    return bpb
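

# Worked example (illustrative numbers, not measured results): a mean loss of
# 2.0 nats/token gives bits/token = 2.0 / ln(2) ≈ 2.885; if the tokenizer emits
# roughly 1 token per 4 UTF-8 bytes on the eval text, bpb ≈ 2.885 / 4 ≈ 0.72.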


# ---------------------------------------------------------------------------
# Training Dataset Loader
# ---------------------------------------------------------------------------

def load_train_dataset(
    dataset_name: str = DEFAULT_TRAIN_DATASET,
    split: str = "train",
    tokenizer: AutoTokenizer = None,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    max_samples: int = None,
    streaming: bool = True,
):
    """
    Load and prepare the training dataset.
    Returns a TokenizedDataset.
    """
    if tokenizer is None:
        tokenizer = load_tokenizer()

    if dataset_name == "HuggingFaceFW/fineweb-edu":
        # FineWeb-Edu is already education-filtered; the sampled config keeps
        # download size manageable for fast iteration.
        ds = load_dataset(
            dataset_name,
            name="sample-100BT",  # ~100B-token sample of the full dataset
            split=split,
            streaming=streaming,
            trust_remote_code=True,
        )
    elif dataset_name.endswith(".jsonl") or dataset_name.endswith(".json"):
        ds = load_dataset("json", data_files=dataset_name, split=split, streaming=streaming)
    else:
        ds = load_dataset(dataset_name, split=split, streaming=streaming, trust_remote_code=True)

    # Collect texts (handle streaming by taking the first N samples)
    texts = []
    for idx, item in enumerate(ds):
        if max_samples and idx >= max_samples:
            break
        text = item.get("text", item.get("content", str(item)))
        texts.append(text)

    if not texts:
        raise ValueError(f"No texts loaded from dataset {dataset_name}")

    dataset = TokenizedDataset(texts, tokenizer, seq_length)
    return dataset
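

# Usage sketch (hedged): for quick local iteration, load_train_dataset also accepts
# a local JSON/JSONL file with a "text" field. The path below is a placeholder,
# not a file that ships with this repo:
#
#     tok = load_tokenizer()
#     train_ds = load_train_dataset("my_corpus.jsonl", tokenizer=tok,
#                                   max_samples=200, streaming=False)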


# ---------------------------------------------------------------------------
# Model Loading
# ---------------------------------------------------------------------------

def load_model(
    model_name: str = DEFAULT_MODEL,
    device: str = "cpu",
    dtype: torch.dtype = torch.float16,
    attn_implementation: str = "sdpa",
) -> AutoModelForCausalLM:
    """
    Load a pretrained model from HuggingFace.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        attn_implementation=attn_implementation,
        trust_remote_code=True,
        device_map=device if device != "cpu" else None,
    )
    if device == "cpu":
        model = model.to(device)
    return model
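

# Convenience sketch (an assumption, not part of the original API): pick a default
# device for load_model / evaluate_ppl based on the backends this machine exposes.
def detect_device() -> str:
    """Return "cuda", "mps", or "cpu" depending on what is available."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"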


# ---------------------------------------------------------------------------
# CLI Entry Point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Prepare and evaluate ternary quantization datasets")
    parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name or path")
    parser.add_argument("--eval-dataset", default=DEFAULT_EVAL_DATASET, help="Evaluation dataset")
    parser.add_argument("--eval-config", default=DEFAULT_EVAL_CONFIG, help="Evaluation dataset config")
    parser.add_argument("--seq-length", type=int, default=DEFAULT_SEQ_LENGTH, help="Sequence length")
    parser.add_argument("--batch-size", type=int, default=4, help="Evaluation batch size")
    parser.add_argument("--device", default="cpu", help='Device to run on ("skip-model" skips model loading and eval)')
    parser.add_argument("--train-samples", type=int, default=1000, help="Max training samples to load")
    parser.add_argument("--dry-run", action="store_true", help="Just print config without loading")

    args = parser.parse_args()

    if args.dry_run:
        print(f"Model: {args.model}")
        print(f"Eval dataset: {args.eval_dataset} ({args.eval_config})")
        print(f"Seq length: {args.seq_length}")
        print(f"Batch size: {args.batch_size}")
        print(f"Device: {args.device}")
        print(f"Train samples: {args.train_samples}")
        raise SystemExit(0)

    print(f"Loading tokenizer for {args.model}...")
    tokenizer = load_tokenizer(args.model)

    print(f"Loading eval dataset {args.eval_dataset}...")
    eval_samples, _ = load_eval_dataset(
        dataset_name=args.eval_dataset,
        config=args.eval_config,
        tokenizer=tokenizer,
        seq_length=args.seq_length,
    )
    print(f" Eval samples: {len(eval_samples)}")

    print(f"Loading training dataset ({args.train_samples} samples)...")
    train_dataset = load_train_dataset(
        tokenizer=tokenizer,
        seq_length=args.seq_length,
        max_samples=args.train_samples,
        streaming=True,
    )
    print(f" Train samples: {len(train_dataset)}")

    # Optional: pass --device skip-model to stop here without loading the model
    if args.device != "skip-model":
        print(f"Loading model {args.model}...")
        model = load_model(args.model, device=args.device)
        print(f" Model params: {sum(p.numel() for p in model.parameters()):,}")

        print("Evaluating perplexity...")
        ppl = evaluate_ppl(model, eval_samples, tokenizer, args.seq_length, args.batch_size, args.device)
        print(f" Perplexity: {ppl:.2f}")
        # log2(ppl) is bits per token; call evaluate_bpb() for true bits-per-byte
        bits_per_token = math.log2(ppl)
        print(f" Bits/token: {bits_per_token:.4f}")