Files

105 lines
3.5 KiB
Python

"""
Data preparation and evaluation for ternary quantization experiments.
READ-ONLY in the autoresearch loop — train.py is the mutable file.
Usage:
python prepare.py # download wikitext val shard
python prepare.py --num-samples 500 # smaller eval set for fast iteration
"""
import argparse
import json
import os
from pathlib import Path

# All prepared artifacts (e.g. the tokenized wikitext eval shard) live under
# a "data" directory next to this script.
DATASETS_DIR = Path(__file__).parent / "data"
def prepare_eval_data(num_samples=500):
    """Download and prepare WikiText-2 evaluation data for perplexity eval.

    Streams the WikiText-2 test split, concatenates ``num_samples`` non-empty
    samples into one corpus, tokenizes it with the SmolLM-135M tokenizer, and
    saves the token ids as a JSON list for fast loading during training.

    Args:
        num_samples: Number of non-empty WikiText samples to collect.

    Returns:
        Path to the saved JSON file of token ids.
    """
    from datasets import load_dataset
    from transformers import AutoTokenizer

    DATASETS_DIR.mkdir(parents=True, exist_ok=True)
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")

    # Load wikitext test split (validation is unreliable with streaming)
    print("Loading WikiText-2 test split...")
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test", streaming=True)

    # Collect text into a single corpus. WikiText-2 interleaves many blank
    # filler rows; counting those toward num_samples would silently shrink
    # the corpus, so only non-empty samples count.
    texts = []
    for sample in dataset:
        text = sample["text"].strip()
        if text:
            texts.append(text)
        if len(texts) >= num_samples:
            break
    corpus = "\n".join(texts)
    print(f"Collected {len(corpus):,} characters from {len(texts)} samples")

    # Tokenize the whole corpus at once (no truncation: we want every token).
    print("Tokenizing...")
    tokenized = tokenizer(corpus, truncation=False)
    input_ids = tokenized["input_ids"]
    print(f"Tokenized to {len(input_ids):,} tokens")

    # Save as a plain JSON list of ints so train.py can load it with no deps.
    eval_path = DATASETS_DIR / "wikitext_eval.json"
    with open(eval_path, "w") as f:
        json.dump(input_ids, f)
    print(f"Saved eval data to {eval_path}")
    return eval_path
def prepare_train_data(num_samples=None):
    """Prepare TinyStories training data (streaming, no download needed).

    TinyStories is loaded streaming in train.py, so there is nothing to
    materialize here — this only verifies the dataset is reachable and has
    the expected ``text`` field.

    Args:
        num_samples: Unused; accepted for signature symmetry with
            prepare_eval_data() and for forward compatibility.

    Returns:
        The Hugging Face dataset name for train.py to load on-the-fly.
    """
    from datasets import load_dataset

    print("Verifying TinyStories dataset access...")
    ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    # Pull a single streamed sample to confirm access and inspect the schema.
    sample = next(iter(ds))
    print(f" Sample keys: {list(sample.keys())}")
    print(f" Sample length: {len(sample['text'])} chars")
    print("TinyStories is accessible (loaded streaming, no local storage)")
    return "roneneldan/TinyStories"
def get_vocab_size():
    """Look up and return the vocabulary size of the SmolLM-135M tokenizer."""
    from transformers import AutoTokenizer

    return AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M").vocab_size
def main():
    """CLI entry point: run the data-preparation tasks selected by flags.

    With no task flags at all, both the eval shard and the training-data
    check run (the "prepare everything" default).
    """
    parser = argparse.ArgumentParser(description="Prepare data for ternary quantization experiments")
    parser.add_argument("--num-eval-samples", type=int, default=500, help="Number of wikitext samples for eval")
    parser.add_argument("--train", action="store_true", help="Verify training dataset access")
    parser.add_argument("--eval", action="store_true", help="Prepare eval dataset")
    parser.add_argument("--vocab", action="store_true", help="Print vocab size")
    args = parser.parse_args()

    # No explicit task flags means "prepare everything" (eval + train only;
    # vocab stays opt-in, matching the historical default).
    run_all = not any([args.vocab, args.eval, args.train])

    if args.vocab:
        print(f"Vocab size: {get_vocab_size()}")
    if args.eval or run_all:
        prepare_eval_data(args.num_eval_samples)
    if args.train or run_all:
        prepare_train_data()


if __name__ == "__main__":
    main()