"""
|
|
Data preparation and evaluation for ternary quantization experiments.
|
|
READ-ONLY in the autoresearch loop — train.py is the mutable file.
|
|
|
|
Usage:
|
|
python prepare.py # download wikitext val shard
|
|
python prepare.py --num-samples 500 # smaller eval set for fast iteration
|
|
"""

import argparse
import json
import os
from pathlib import Path

DATASETS_DIR = Path(__file__).parent / "data"


def prepare_eval_data(num_samples=500):
    """Download and prepare WikiText-2 data (test split) for perplexity evaluation.

    Saves the tokenized data as a JSON file for fast loading during training.
    """
    from datasets import load_dataset
    from transformers import AutoTokenizer

    DATASETS_DIR.mkdir(parents=True, exist_ok=True)
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")

    # Load the WikiText-2 test split (the validation split is unreliable with streaming)
    print("Loading WikiText-2 test split...")
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test", streaming=True)

    # Collect text into a single corpus
    texts = []
    for i, sample in enumerate(dataset):
        texts.append(sample["text"].strip())
        if i + 1 >= num_samples:
            break

    corpus = "\n".join(texts)
    print(f"Collected {len(corpus):,} characters from {len(texts)} samples")

    # Tokenize the whole corpus in one pass (no truncation)
    print("Tokenizing...")
    tokenized = tokenizer(corpus, truncation=False)
    input_ids = tokenized["input_ids"]
    print(f"Tokenized to {len(input_ids):,} tokens")

    # Save the flat list of token ids as JSON
    eval_path = DATASETS_DIR / "wikitext_eval.json"
    with open(eval_path, "w") as f:
        json.dump(input_ids, f)
    print(f"Saved eval data to {eval_path}")
    return eval_path
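

# Illustrative helper (an assumption, not referenced by the rest of the repo):
# shows how the shard saved above can be read back. train.py may load the JSON
# differently; the name `load_eval_data` and the block chopping are made up here.
def load_eval_data(eval_path=DATASETS_DIR / "wikitext_eval.json", block_size=None):
    """Load the token ids saved by prepare_eval_data() (a flat JSON list of ints)."""
    with open(eval_path) as f:
        input_ids = json.load(f)
    if block_size is None:
        return input_ids
    # Chop into fixed-length blocks, e.g. for per-block perplexity evaluation
    return [input_ids[i:i + block_size] for i in range(0, len(input_ids), block_size)]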


def prepare_train_data(num_samples=None):
    """Prepare TinyStories training data (streaming, no download needed).

    Returns the dataset name for train.py to load on the fly.
    """
    # TinyStories is loaded streaming in train.py, so there is nothing to prepare here;
    # just verify that the dataset is accessible.
    from datasets import load_dataset

    print("Verifying TinyStories dataset access...")
    ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    sample = next(iter(ds))
    print(f"  Sample keys: {list(sample.keys())}")
    print(f"  Sample length: {len(sample['text'])} chars")
    print("TinyStories is accessible (loaded streaming, no local storage)")
    return "roneneldan/TinyStories"


def get_vocab_size():
    """Return the vocab size for SmolLM-135M."""
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
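    # Note: vocab_size counts only the base vocabulary. If train.py adds special
    # tokens and resizes embeddings, len(tokenizer) would be the size to use instead
    # (an assumption about usage; the original behavior is kept below).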
    return tokenizer.vocab_size


def main():
    parser = argparse.ArgumentParser(description="Prepare data for ternary quantization experiments")
    parser.add_argument("--num-eval-samples", type=int, default=500, help="Number of wikitext samples for eval")
    parser.add_argument("--train", action="store_true", help="Verify training dataset access")
    parser.add_argument("--eval", action="store_true", help="Prepare eval dataset")
    parser.add_argument("--vocab", action="store_true", help="Print vocab size")
    args = parser.parse_args()

    if args.vocab:
        print(f"Vocab size: {get_vocab_size()}")
    if args.eval:
        prepare_eval_data(args.num_eval_samples)
    if args.train:
        prepare_train_data()
    if not any([args.vocab, args.eval, args.train]):
        # Default: prepare everything
        prepare_eval_data(args.num_eval_samples)
        prepare_train_data()


if __name__ == "__main__":
    main()