Add ternary QAT training pipeline: prepare.py (data/eval), train.py (quantization/training), program.md (agent instructions), autoresearch.sh (loop)
autoresearch.sh (new, executable, +122)
@@ -0,0 +1,122 @@
#!/bin/bash
# autoresearch.sh — Autonomous experiment loop
#
# Usage: ./autoresearch.sh [model] [time_budget_seconds]
#
# This script runs the autoresearch loop:
#   1. Agent proposes a change to train.py
#   2. Git commit the change
#   3. Run training for a fixed time budget
#   4. Extract val_ppl from results.tsv
#   5. If improved → keep; if worse → git reset
#   6. Repeat

set -e

MODEL="${1:-HuggingFaceTB/SmolLM-135M}"
TIME_BUDGET="${2:-300}"
RESULTS_FILE="results.tsv"

# Ensure git is initialized
if [ ! -d ".git" ]; then
    git init
    git add -A
    git commit -m "Initial commit"
fi

echo "=== Autoresearch Loop ==="
echo "Model: $MODEL"
echo "Time budget: ${TIME_BUDGET}s"
echo "Results: $RESULTS_FILE"
echo ""

# Function to get the best PPL from results.tsv
get_best_ppl() {
    if [ ! -f "$RESULTS_FILE" ]; then
        echo "999999"
        return
    fi
    # Get the best_ppl from the last run (column 7 of the TSV)
    tail -1 "$RESULTS_FILE" | cut -f7 | tr -d '[:space:]'
}

# Function to get the last run's status
get_last_status() {
    if [ ! -f "$RESULTS_FILE" ]; then
        echo "none"
        return
    fi
    tail -1 "$RESULTS_FILE" | cut -f3 | tr -d '[:space:]'
}

# Initial commit if not committed
git add -A
git commit -m "Initial setup" --allow-empty 2>/dev/null || true

BEST_PPL=$(get_best_ppl)
RUN_NUM=0

while true; do
    RUN_NUM=$((RUN_NUM + 1))
    echo ""
    echo "========================================"
    echo "RUN #$RUN_NUM"
    echo "Current best PPL: $BEST_PPL"
    echo "========================================"

    # Save current state
    PREV_COMMIT=$(git rev-parse HEAD)

    # Prompt the agent to make a change.
    # In production, this would call the LLM agent;
    # for now, we just run with the current config.
    echo "Running training..."

    # Run training
    START_TIME=$(date +%s)

    python3 train.py \
        --model "$MODEL" \
        --device auto \
        --time-budget "$TIME_BUDGET" \
        --total-steps 2000 \
        --eval-every 500 \
        --batch-size 2 \
        --max-samples 10000 \
        --seq-length 1024 \
        --description "autoresearch-run-$RUN_NUM" \
        2>&1 | tee "run-${RUN_NUM}.log" || true

    END_TIME=$(date +%s)
    ELAPSED=$((END_TIME - START_TIME))

    # Check results
    STATUS=$(get_last_status)
    NEW_PPL=$(get_best_ppl)

    echo ""
    echo "Run #$RUN_NUM completed in ${ELAPSED}s"
    echo "Status: $STATUS"
    echo "Best PPL: $NEW_PPL"

    if [ "$STATUS" = "success" ]; then
        # Compare with previous best (awk handles floats, unlike [ -lt ])
        if echo "$NEW_PPL $BEST_PPL" | awk '{exit !($1 < $2)}'; then
            echo "IMPROVED! Keeping changes."
            BEST_PPL=$NEW_PPL
            git add results.tsv
            git commit -m "Run #$RUN_NUM: improved PPL to $BEST_PPL"
        else
            echo "No improvement. Reverting."
            git reset --hard "$PREV_COMMIT" 2>/dev/null || true
            git checkout -- results.tsv 2>/dev/null || true
        fi
    else
        echo "FAILED. Reverting."
        git reset --hard "$PREV_COMMIT" 2>/dev/null || true
        git checkout -- results.tsv 2>/dev/null || true
    fi

    echo ""
    echo "Continuing... (Ctrl+C to stop)"
done

prepare.py (new, +324)
@@ -0,0 +1,324 @@
#!/usr/bin/env python3
"""
prepare.py — Data loading, tokenizer, and evaluation (READ-ONLY in autoresearch loop).

This file is NOT modified by the autoresearch agent. All mutable training logic
lives in train.py. This file provides:
- Dataset loading and tokenization
- Evaluation harness (WikiText PPL, optional zero-shot tasks)
- Utility functions for the training loop
"""

import os
import math
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Default model — small enough for fast iteration on M4 Pro / consumer GPU
DEFAULT_MODEL = "meta-llama/Llama-3.2-1B"
# Alternative tiny model for ultra-fast iteration:
# DEFAULT_MODEL = "HuggingFaceTB/SmolLM-135M"

# Default eval dataset
DEFAULT_EVAL_DATASET = "wikitext"  # "wikitext" or "ptb"
DEFAULT_EVAL_CONFIG = "wikitext-2-raw-v1"

# Default training dataset
DEFAULT_TRAIN_DATASET = "HuggingFaceFW/fineweb-edu"  # FineWeb-Edu
# Alternative: "monology/patrickstar-core-1000" for smaller local testing

# Sequence length for training
DEFAULT_SEQ_LENGTH = 2048

# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------

def load_tokenizer(model_name: str = DEFAULT_MODEL) -> AutoTokenizer:
    """Load the tokenizer for the given model."""
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


# ---------------------------------------------------------------------------
# Dataset Wrappers
# ---------------------------------------------------------------------------

class TokenizedDataset(Dataset):
    """Simple dataset that yields fixed-length sequences of tokenized text."""

    def __init__(self, data: list, tokenizer: AutoTokenizer, seq_length: int = DEFAULT_SEQ_LENGTH):
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        # Concatenate all texts and chunk into sequences
        self.samples = self._prepare_samples(data)

    def _prepare_samples(self, data: list) -> list:
        """Flatten and chunk text data into fixed-length sequences."""
        all_tokens = []
        for item in data:
            # Items may be plain strings (from load_train_dataset) or dicts
            if isinstance(item, dict):
                text = item.get("text", item.get("content", str(item)))
            else:
                text = str(item)
            tokens = self.tokenizer.encode(text, add_special_tokens=False)
            all_tokens.extend(tokens)
            all_tokens.append(self.tokenizer.eos_token_id)

        # Chunk into sequences of seq_length
        samples = []
        for i in range(0, len(all_tokens) - self.seq_length + 1, self.seq_length):
            samples.append(all_tokens[i : i + self.seq_length])
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x = torch.tensor(self.samples[idx], dtype=torch.long)
        # HF causal LMs shift labels internally, so input_ids == labels here
        # (a manual shift would make the model predict two tokens ahead)
        return {"input_ids": x, "labels": x.clone()}


def load_eval_dataset(
    dataset_name: str = DEFAULT_EVAL_DATASET,
    config: str = DEFAULT_EVAL_CONFIG,
    split: str = "validation",
    tokenizer: AutoTokenizer = None,
    seq_length: int = DEFAULT_SEQ_LENGTH,
):
    """Load and tokenize the evaluation dataset."""
    if dataset_name == "wikitext":
        ds = load_dataset(dataset_name, config, trust_remote_code=True)[split]
    elif dataset_name == "ptb":
        ds = load_dataset(dataset_name, trust_remote_code=True)[split]
    else:
        ds = load_dataset(dataset_name, split=split, trust_remote_code=True)

    if tokenizer is None:
        tokenizer = load_tokenizer()

    # For eval, we want contiguous text without chunking artifacts
    texts = ds["text"] if "text" in ds.column_names else [str(x) for x in ds]
    all_tokens = []
    for text in texts:
        tokens = tokenizer.encode(text, add_special_tokens=False)
        all_tokens.extend(tokens)

    # Defensive: drop any None entries (encode should never produce them)
    all_tokens = [t for t in all_tokens if t is not None]

    # Create sequences
    samples = []
    for i in range(0, len(all_tokens) - seq_length + 1, seq_length):
        samples.append(all_tokens[i : i + seq_length])

    return samples, tokenizer


# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------

@torch.no_grad()
def evaluate_ppl(
    model: AutoModelForCausalLM,
    eval_samples: list,
    tokenizer: AutoTokenizer,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    batch_size: int = 4,
    device: str = "cpu",
) -> float:
    """
    Evaluate perplexity on the given samples.
    Returns perplexity (lower is better).
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    for i in range(0, len(eval_samples), batch_size):
        batch_samples = eval_samples[i : i + batch_size]
        inputs = torch.tensor(batch_samples, dtype=torch.long, device=device)
        labels = inputs.clone()
        # HF causal LMs shift labels internally, so pass the unshifted
        # sequence as both input_ids and labels (a manual shift here would
        # double-shift and score the wrong token positions)
        device_type = "mps" if device == "mps" else ("cuda" if device == "cuda" else "cpu")
        if device_type != "cpu":
            with torch.amp.autocast(device_type=device_type):
                outputs = model(input_ids=inputs, labels=labels, use_cache=False)
        else:
            outputs = model(input_ids=inputs, labels=labels, use_cache=False)

        # outputs.loss is the mean per-token loss; weight by the number of
        # sequences so the average is exact (all sequences are equal length)
        loss = outputs.loss
        total_loss += loss.item() * inputs.size(0)
        total_tokens += inputs.size(0)

    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)

    model.train()
    return perplexity


@torch.no_grad()
def evaluate_bpb(
    model: AutoModelForCausalLM,
    eval_samples: list,
    tokenizer: AutoTokenizer,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    batch_size: int = 4,
    device: str = "cpu",
) -> float:
    """
    Evaluate bits per token (logged as "bpb"); lower is better.
    Note: this is log2(perplexity), i.e. bits per *token*; a true
    bits-per-byte figure would also need the byte length of the text.
    """
    ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
    bpb = math.log2(ppl)  # bits per token
    return bpb


# ---------------------------------------------------------------------------
# Training Dataset Loader
# ---------------------------------------------------------------------------

def load_train_dataset(
    dataset_name: str = DEFAULT_TRAIN_DATASET,
    split: str = "train",
    tokenizer: AutoTokenizer = None,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    max_samples: int = None,
    streaming: bool = True,
):
    """
    Load and prepare the training dataset.
    Returns a TokenizedDataset.
    """
    if tokenizer is None:
        tokenizer = load_tokenizer()

    if dataset_name == "HuggingFaceFW/fineweb-edu":
        # Use the edu-only subset for quality
        ds = load_dataset(
            dataset_name,
            name="sample-100BT",  # 100B-token subset for faster iteration
            split=split,
            streaming=streaming,
            trust_remote_code=True,
        )
    elif dataset_name.endswith(".jsonl") or dataset_name.endswith(".json"):
        ds = load_dataset("json", data_files=dataset_name, split=split, streaming=streaming)
    else:
        ds = load_dataset(dataset_name, split=split, streaming=streaming, trust_remote_code=True)

    # Collect texts (handle streaming by taking the first N samples)
    texts = []
    for idx, item in enumerate(ds):
        if max_samples and idx >= max_samples:
            break
        text = item.get("text", item.get("content", str(item)))
        texts.append(text)

    if not texts:
        raise ValueError(f"No texts loaded from dataset {dataset_name}")

    dataset = TokenizedDataset(texts, tokenizer, seq_length)
    return dataset


# ---------------------------------------------------------------------------
# Model Loading
# ---------------------------------------------------------------------------

def load_model(
    model_name: str = DEFAULT_MODEL,
    device: str = "cpu",
    dtype: torch.dtype = torch.float16,
    attn_implementation: str = "sdpa",
) -> AutoModelForCausalLM:
    """Load a pretrained model from Hugging Face."""
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        attn_implementation=attn_implementation,
        trust_remote_code=True,
        device_map=device if device != "cpu" else None,
    )
    if device == "cpu":
        model = model.to(device)
    return model


# ---------------------------------------------------------------------------
# CLI Entry Point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Prepare and evaluate ternary quantization datasets")
    parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name or path")
    parser.add_argument("--eval-dataset", default=DEFAULT_EVAL_DATASET, help="Evaluation dataset")
    parser.add_argument("--eval-config", default=DEFAULT_EVAL_CONFIG, help="Evaluation dataset config")
    parser.add_argument("--seq-length", type=int, default=DEFAULT_SEQ_LENGTH, help="Sequence length")
    parser.add_argument("--batch-size", type=int, default=4, help="Evaluation batch size")
    parser.add_argument("--device", default="cpu", help="Device to run on ('skip-model' skips model loading)")
    parser.add_argument("--train-samples", type=int, default=1000, help="Max training samples to load")
    parser.add_argument("--dry-run", action="store_true", help="Just print config without loading")

    args = parser.parse_args()

    if args.dry_run:
        print(f"Model: {args.model}")
        print(f"Eval dataset: {args.eval_dataset} ({args.eval_config})")
        print(f"Seq length: {args.seq_length}")
        print(f"Batch size: {args.batch_size}")
        print(f"Device: {args.device}")
        print(f"Train samples: {args.train_samples}")
        exit(0)

    print(f"Loading tokenizer for {args.model}...")
    tokenizer = load_tokenizer(args.model)

    print(f"Loading eval dataset {args.eval_dataset}...")
    eval_samples, _ = load_eval_dataset(
        dataset_name=args.eval_dataset,
        config=args.eval_config,
        tokenizer=tokenizer,
        seq_length=args.seq_length,
    )
    print(f"  Eval samples: {len(eval_samples)}")

    print(f"Loading training dataset ({args.train_samples} samples)...")
    train_dataset = load_train_dataset(
        tokenizer=tokenizer,
        seq_length=args.seq_length,
        max_samples=args.train_samples,
        streaming=True,
    )
    print(f"  Train samples: {len(train_dataset)}")

    # Optional: pass --device skip-model to stop here without loading the model
    if args.device != "skip-model":
        print(f"Loading model {args.model}...")
        model = load_model(args.model, device=args.device)
        print(f"  Model params: {sum(p.numel() for p in model.parameters()):,}")

        print("Evaluating perplexity...")
        ppl = evaluate_ppl(model, eval_samples, tokenizer, args.seq_length, args.batch_size, args.device)
        print(f"  Perplexity: {ppl:.2f}")
        bpb = math.log2(ppl)
        print(f"  BPB: {bpb:.4f}")

program.md (new, +78)
@@ -0,0 +1,78 @@
# program.md — Instructions for the Autoresearch Agent

You are an autonomous research agent exploring ternary (1.58-bit) quantization for LLMs.

## Your Goal

Iteratively improve the ternary quantization training in `train.py` to achieve the **lowest validation perplexity (val_ppl)** on WikiText-2.

## File Boundaries

- **MUTABLE**: `train.py` — You may modify this file. It contains:
  - `BitLinear` layer (quantization logic)
  - Quantization schedules (lambda warmup)
  - Training loop
  - Hyperparameters (LR, batch size, group size, etc.)

- **READ-ONLY**: `prepare.py` — DO NOT modify this file. It contains:
  - Dataset loading and tokenization
  - Evaluation harness (WikiText PPL)
  - Model loading utilities

- **OUTPUT**: `results.tsv` — Results are automatically logged here after each run.

## Experiment Protocol

1. Read `results.tsv` to understand what has been tried
2. Read `train.py` to understand the current implementation
3. Propose ONE focused change to `train.py`
4. The change will be committed and a training run will execute (~5 minutes)
5. Results are logged to `results.tsv`
6. If improved (lower val_ppl) → the change is kept
7. If equal or worse → git reset to the previous commit

## What to Explore

### Priority 1: Quantization Schedule
- Lambda warmup shape: linear, cosine, exponential (see the sketch after this list)
- Warmup step counts: 200, 500, 1000, 2000, 5000
- Two-phase warmup (fast initial + slow final)
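
For reference, these are the warmup shapes `train.py` currently implements, condensed from its schedule functions (the step counts above are candidate values for `--warmup-steps`):

```python
import math

def get_lambda(step: int, warmup_steps: int, schedule: str = "linear") -> float:
    """Lambda in [0, 1]: 0 = full precision, 1 = fully quantized."""
    if step >= warmup_steps:
        return 1.0
    t = step / warmup_steps
    if schedule == "linear":
        return t
    if schedule == "cosine":
        return 0.5 * (1 - math.cos(math.pi * t))
    if schedule == "exponential":
        return 1.0 - math.exp(-3.0 * t)  # ~0.95 at warmup_steps, then clamped to 1
    raise ValueError(f"Unknown lambda schedule: {schedule}")
```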

### Priority 2: Learning Rate
- LR values: 1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4
- LR schedule: constant, linear decay, cosine decay

### Priority 3: Quantization Details
- Group size: 32, 64, 128, 256, per-tensor
- Scale initialization: mean-based vs absmax-based
- Ternary threshold adjustments

### Priority 4: Deadzone Recovery
- Tequila-style reactivation (learnable lambda for deadzone weights)
- Bias injection for zero-valued weights
- Gradient scaling for deadzone weights (see the measurement sketch after this list)
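
A useful first step for any deadzone experiment is to measure how many weights are trapped at zero. A minimal sketch using the helpers defined in `train.py` (assumes `ternary_quantize_weight` and `BitLinear` as implemented there):

```python
def deadzone_fraction(layer) -> float:
    """Fraction of a BitLinear layer's weights that quantize to 0."""
    w_quant, _ = ternary_quantize_weight(layer.weight.detach(), layer.group_size)
    return (w_quant == 0).float().mean().item()
```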

### Priority 5: Distillation
- OFF loss (cosine similarity between FP and ternary features) — see the sketch after this list
- Logits distillation weight
- Feature distillation weight
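
One plausible reading of the OFF-loss bullet, as a hedged sketch (the exact formulation is left to the agent; `feat_fp` and `feat_ternary` are hypothetical hidden states captured from the full-precision and ternary models):

```python
import torch
import torch.nn.functional as F

def off_loss(feat_fp: torch.Tensor, feat_ternary: torch.Tensor) -> torch.Tensor:
    # 1 - cosine similarity between FP and ternary features, averaged
    return (1.0 - F.cosine_similarity(feat_fp, feat_ternary, dim=-1)).mean()
```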

## Constraints

- Keep experiments focused — ONE change per iteration
- Always maintain working code — syntax errors waste time
- Use SmolLM-135M or Llama-3.2-1B for fast iteration
- Target metric: val_ppl (lower is better)
- Time budget: 5 minutes per experiment

## Important Notes

- The STE (Straight-Through Estimator) is critical for gradients to flow through quantization
- Warmup quantization prevents catastrophic accuracy loss at the start of training
- Deadzone trapping (weights stuck at 0) is a known problem — explore solutions
- Per-group quantization scales are essential for handling outlier weights
- The quantization formula: `scale = 1 / mean(|w|); w_q = round(clamp(w * scale, -1, 1))` (see the sketch after this list)
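
Putting the formula and the STE together, this is a condensed restatement of what `BitLinear.forward` in `train.py` does (per-tensor scale shown here for brevity; `train.py` uses per-group scales):

```python
import torch

def ste_warmup_forward(w: torch.Tensor, lambda_: float) -> torch.Tensor:
    """Effective weight: FP at lambda=0, ternary at lambda=1, STE gradients."""
    scale = 1.0 / (w.abs().mean() + 1e-6)       # 1 / mean(|w|)
    w_q = (w * scale).clamp(-1.0, 1.0).round()  # ternary {-1, 0, +1}
    w_dq = w_q / scale                          # dequantize back to weight space
    # .detach() means gradients flow through w only (straight-through)
    return w + lambda_ * (w_dq - w).detach()
```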

## NEVER STOP

Run experiments continuously until manually interrupted. Each experiment should be a small, focused change. Review `results.tsv` between runs to inform your next decision.

results.tsv (new, +1)
@@ -0,0 +1 @@
timestamp commit status val_ppl val_bpb initial_ppl best_ppl best_step steps time_s model lr warmup_steps total_steps lambda_schedule group_size batch_size seq_length description

train.py (new, +664)
@@ -0,0 +1,664 @@
#!/usr/bin/env python3
"""
train.py — Ternary quantization training loop (MUTABLE in autoresearch loop).

This is the single mutable file for the autoresearch agent. The agent may modify:
- Quantization schedule (lambda warmup shape, length)
- Learning rate and schedule
- BitLinear implementation details (group size, deadzone recovery)
- Loss function (distillation weights, OFF loss)
- Training hyperparameters

DO NOT modify prepare.py — all data loading and evaluation lives there.
"""

import os
import sys
import math
import time
import json
import csv
import git  # GitPython, used to record the current commit in results.tsv
import random
import argparse
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, get_linear_schedule_with_warmup

# Import from prepare.py (read-only)
from prepare import (
    load_tokenizer,
    load_train_dataset,
    load_eval_dataset,
    load_model,
    evaluate_ppl,
    evaluate_bpb,
    DEFAULT_MODEL,
    DEFAULT_SEQ_LENGTH,
    DEFAULT_TRAIN_DATASET,
    DEFAULT_EVAL_DATASET,
    DEFAULT_EVAL_CONFIG,
)

# ---------------------------------------------------------------------------
# Quantization Utilities
# ---------------------------------------------------------------------------

def ternary_quantize_weight(w: torch.Tensor, group_size: int = 128) -> tuple:
    """
    Ternary quantize weights to {-1, 0, +1} with per-group scales.

    Quantization formula (from the HF blog):
        scale = 1.0 / w.abs().mean()
        w_scaled = w * scale
        w_clamped = clamp(w_scaled, -1, 1)
        w_rounded = round(w_clamped)  # -> {-1, 0, +1}

    Returns:
        w_quant: ternary quantized weights
        scales: per-group scales for dequantization
    """
    original_shape = w.shape
    # Flatten for group-wise processing
    w_flat = w.flatten().float()

    num_groups = (len(w_flat) + group_size - 1) // group_size
    scales = []
    quant_parts = []

    for i in range(num_groups):
        start = i * group_size
        end = min(start + group_size, len(w_flat))
        group = w_flat[start:end]

        # Compute scale: 1 / mean(|w|)
        scale = 1.0 / (group.abs().mean() + 1e-6)
        scales.append(scale)

        # Scale, clamp, round (values are already in {-1, 0, +1} after
        # round, so the trailing sign() is a no-op kept for clarity)
        w_scaled = group * scale
        w_clamped = w_scaled.clamp(-1.0, 1.0)
        w_rounded = w_clamped.round().sign()  # -> {-1, 0, +1}

        quant_parts.append(w_rounded)

    w_quant = torch.cat(quant_parts, dim=0).reshape(original_shape)
    # scales is a list of 0-dim tensors; stack them (torch.tensor(list) warns)
    scales = torch.stack(scales).to(device=w.device, dtype=w.dtype)

    return w_quant, scales


def ternary_dequantize(w_quant: torch.Tensor, scales: torch.Tensor, group_size: int = 128) -> torch.Tensor:
    """Dequantize ternary weights back to the original shape."""
    original_shape = w_quant.shape
    w_flat = w_quant.flatten().float()

    num_groups = (len(w_flat) + group_size - 1) // group_size
    dequant_parts = []

    for i in range(num_groups):
        start = i * group_size
        end = min(start + group_size, len(w_flat))
        group = w_flat[start:end]
        scale = scales[i]
        dequant_parts.append(group / (scale + 1e-6))

    return torch.cat(dequant_parts, dim=0).reshape(original_shape).to(w_quant.dtype)
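

# The per-group loops above are simple but slow in pure Python. A vectorized
# equivalent is sketched below (an optional optimization, not wired into the
# training path; it reproduces the same per-group scales, including for a
# short final group, by padding and dividing by the true element counts).
def ternary_quantize_weight_vectorized(w: torch.Tensor, group_size: int = 128) -> tuple:
    """Vectorized sketch of ternary_quantize_weight (same formula)."""
    original_shape = w.shape
    w_flat = w.flatten().float()
    pad = (-len(w_flat)) % group_size
    groups = F.pad(w_flat, (0, pad)).view(-1, group_size)
    # True per-group counts (the last group may be shorter than group_size)
    counts = torch.full((groups.size(0), 1), float(group_size), device=w.device)
    if pad:
        counts[-1] = group_size - pad
    scales = 1.0 / (groups.abs().sum(dim=1, keepdim=True) / counts + 1e-6)
    w_quant = (groups * scales).clamp(-1.0, 1.0).round()
    w_quant = w_quant.flatten()[: len(w_flat)].reshape(original_shape)
    return w_quant, scales.flatten().to(device=w.device, dtype=w.dtype)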


# ---------------------------------------------------------------------------
# BitLinear Layer
# ---------------------------------------------------------------------------

class BitLinear(nn.Module):
    """
    Linear layer with ternary weight quantization using STE (Straight-Through Estimator).

    The forward pass uses quantized weights; the backward pass uses the STE to
    pass gradients through the quantization operation.

    Supports warmup quantization: gradually introduce quantization during
    training via lambda scheduling to avoid catastrophic accuracy loss.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        group_size: int = 128,
        init_scale: float = 1.0,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.group_size = group_size
        self.init_scale = init_scale

        # FP weights (the "learned" weights)
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.bias = None

        # Warmup lambda, set externally by the training loop each step.
        # A plain attribute rather than a buffer: "lambda" is a reserved word,
        # and the quantized weights/scales are recomputed on the fly in
        # forward(), so there is no persistent quantization state to store.
        self.lambda_val = 1.0

    def _quantize(self, w: torch.Tensor) -> tuple:
        """Quantize weights to ternary."""
        return ternary_quantize_weight(w, self.group_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with warmup quantization.

        Lambda is set externally via self.lambda_val before each forward pass.

        Quantization warmup formula:
            w_eff = w + lambda * (w_dequant - w).detach()

        When lambda=0: output uses full-precision weights.
        When lambda=1: output uses quantized weights (with STE gradients).
        """
        lambda_ = getattr(self, "lambda_val", 1.0)

        # Quantize weights
        w_quant, scales = self._quantize(self.weight)

        # Dequantize for computation
        w_dequant = ternary_dequantize(w_quant, scales, self.group_size)

        # STE with warmup: interpolate between FP and quantized.
        # w_eff = w + lambda * (w_dequant - w)
        # The gradient flows through w (the FP weights), not through quantization.
        if lambda_ > 0:
            w_eff = self.weight + lambda_ * (w_dequant - self.weight).detach()
        else:
            w_eff = self.weight

        # Linear forward
        x = x.to(w_eff.dtype)
        out = F.linear(x, w_eff, self.bias)
        return out

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, "
            f"out_features={self.out_features}, "
            f"group_size={self.group_size}, "
            f"bias={self.bias is not None}"
        )
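

# A quick, illustrative smoke test (a sketch; never called during training):
# at lambda_val=0 a BitLinear layer should exactly match a plain linear map
# of its FP weights, which is the starting point of the warmup schedule.
def _bitlinear_smoke_test():
    layer = BitLinear(16, 8, bias=False)
    layer.lambda_val = 0.0
    x = torch.randn(2, 16)
    assert torch.allclose(layer(x), F.linear(x, layer.weight))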


# ---------------------------------------------------------------------------
# Model Conversion
# ---------------------------------------------------------------------------

def replace_linears_with_bitlinear(
    model: AutoModelForCausalLM,
    group_size: int = 128,
    target_modules: list = None,
) -> AutoModelForCausalLM:
    """
    Replace nn.Linear layers in the model with BitLinear layers.

    Args:
        model: pretrained model
        group_size: quantization group size
        target_modules: list of module names to replace. If None, replace the
            standard attention and MLP projections.

    Returns:
        Model with BitLinear layers
    """
    if target_modules is None:
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

    # Replacing entries under existing names keeps the module dicts the same
    # size, so mutating while iterating named_modules() is safe here.
    for name, module in model.named_modules():
        parent_name = ".".join(name.split(".")[:-1])
        module_name = name.split(".")[-1]

        if isinstance(module, nn.Linear) and module_name in target_modules:
            bit_linear = BitLinear(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias is not None,
                group_size=group_size,
            )
            # Initialize from pretrained weights (carries dtype and device)
            bit_linear.weight.data = module.weight.data.clone()
            if module.bias is not None:
                bit_linear.bias.data = module.bias.data.clone()

            # Replace in parent (parent_name is non-empty for nested modules)
            parent = model
            if parent_name:
                for attr in parent_name.split("."):
                    parent = getattr(parent, attr)
            setattr(parent, module_name, bit_linear)

    return model


def get_bitlinear_modules(model: AutoModelForCausalLM) -> list:
    """Get all BitLinear modules in the model."""
    return [m for m in model.modules() if isinstance(m, BitLinear)]


# ---------------------------------------------------------------------------
# Lambda Warmup Schedules
# ---------------------------------------------------------------------------

def linear_warmup_lambda(step: int, warmup_steps: int) -> float:
    """Linear warmup: lambda goes from 0 to 1 over warmup_steps."""
    return min(step / warmup_steps, 1.0)


def cosine_warmup_lambda(step: int, warmup_steps: int) -> float:
    """Cosine warmup: lambda follows a cosine schedule over warmup_steps."""
    if step >= warmup_steps:
        return 1.0
    return 0.5 * (1 - math.cos(math.pi * step / warmup_steps))


def exponential_warmup_lambda(step: int, warmup_steps: int) -> float:
    """Exponential warmup: lambda saturates exponentially over warmup_steps."""
    if step >= warmup_steps:
        return 1.0
    # Exponential rise from 0 toward 1 (reaches ~0.95 at warmup_steps)
    return 1.0 - math.exp(-3.0 * step / warmup_steps)


def get_lambda(step: int, warmup_steps: int, schedule: str = "linear") -> float:
    """Get the lambda value for the given step and schedule."""
    if schedule == "linear":
        return linear_warmup_lambda(step, warmup_steps)
    elif schedule == "cosine":
        return cosine_warmup_lambda(step, warmup_steps)
    elif schedule == "exponential":
        return exponential_warmup_lambda(step, warmup_steps)
    else:
        raise ValueError(f"Unknown lambda schedule: {schedule}")


# ---------------------------------------------------------------------------
# Training Loop
# ---------------------------------------------------------------------------

def train(
    model_name: str = DEFAULT_MODEL,
    train_dataset_name: str = DEFAULT_TRAIN_DATASET,
    eval_dataset_name: str = DEFAULT_EVAL_DATASET,
    eval_config: str = DEFAULT_EVAL_CONFIG,
    seq_length: int = DEFAULT_SEQ_LENGTH,
    batch_size: int = 4,
    learning_rate: float = 1e-4,
    warmup_steps: int = 1000,
    total_steps: int = 5000,
    eval_every: int = 500,
    lambda_schedule: str = "linear",
    group_size: int = 128,
    max_train_samples: int = 50000,
    device: str = "cpu",
    dtype: torch.dtype = torch.float16,
    time_budget: float = 300.0,  # seconds (5 minutes default)
    seed: int = 42,
    output_dir: str = "./results",
):
    """
    Main training loop for ternary quantization-aware training.

    Args:
        model_name: HuggingFace model name or path
        train_dataset_name: Training dataset name
        eval_dataset_name: Evaluation dataset name
        eval_config: Evaluation dataset config
        seq_length: Sequence length for training
        batch_size: Training batch size
        learning_rate: Base learning rate
        warmup_steps: Number of steps for lambda warmup
        total_steps: Total training steps
        eval_every: Evaluate every N steps
        lambda_schedule: Lambda warmup schedule ("linear", "cosine", "exponential")
        group_size: Quantization group size
        max_train_samples: Max training samples to load
        device: Device to train on
        dtype: Data type
        time_budget: Max training time in seconds
        seed: Random seed
        output_dir: Output directory for results

    Returns:
        dict with training metrics
    """
    # Auto-detect device if not specified
    if device == "auto":
        if torch.backends.mps.is_available():
            device = "mps"
            print("Using MPS (Apple Silicon)")
        elif torch.cuda.is_available():
            device = "cuda"
            print("Using CUDA")
        else:
            device = "cpu"
            print("Using CPU")

    # Set seeds
    random.seed(seed)
    torch.manual_seed(seed)
    if device == "cuda":
        torch.cuda.manual_seed(seed)

    # Track time
    start_time = time.time()

    print("=== Ternary QAT Training ===")
    print(f"Model: {model_name}")
    print(f"Device: {device}")
    print(f"LR: {learning_rate}, Warmup: {warmup_steps} steps")
    print(f"Lambda schedule: {lambda_schedule}")
    print(f"Group size: {group_size}")
    print(f"Time budget: {time_budget}s")
    print(f"Total steps: {total_steps}")
    print()

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = load_tokenizer(model_name)

    # Load model
    print(f"Loading model {model_name}...")
    model = load_model(model_name, device=device, dtype=dtype, attn_implementation="sdpa")
    model = model.to(dtype)
    print(f"  Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print("  Linear layers replaced: ", end="")

    # Replace Linear with BitLinear
    model = replace_linears_with_bitlinear(model, group_size=group_size)
    bitlinears = get_bitlinear_modules(model)
    print(f"{len(bitlinears)}")

    # Load datasets
    print(f"Loading training dataset ({max_train_samples} samples)...")
    train_dataset = load_train_dataset(
        dataset_name=train_dataset_name,
        tokenizer=tokenizer,
        seq_length=seq_length,
        max_samples=max_train_samples,
        streaming=True,
    )
    print(f"  Train samples: {len(train_dataset)}")

    print("Loading eval dataset...")
    eval_samples, _ = load_eval_dataset(
        dataset_name=eval_dataset_name,
        config=eval_config,
        tokenizer=tokenizer,
        seq_length=seq_length,
    )
    print(f"  Eval samples: {len(eval_samples)}")

    # Initial evaluation
    print("\n--- Initial Evaluation ---")
    initial_ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
    initial_bpb = math.log2(initial_ppl)
    print(f"  Initial PPL: {initial_ppl:.2f}")
    print(f"  Initial BPB: {initial_bpb:.4f}")

    # Optimizer over all trainable parameters (the BitLinear FP weights and
    # biases, plus embeddings, norms, and the LM head)
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=learning_rate,
        weight_decay=0.01,
    )

    # Learning rate scheduler
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(total_steps * 0.1),
        num_training_steps=total_steps,
    )

    # DataLoader
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    loader_iter = iter(train_loader)

    # Training loop
    results = {
        "initial_ppl": initial_ppl,
        "initial_bpb": initial_bpb,
        "final_ppl": None,
        "final_bpb": None,
        "best_ppl": initial_ppl,
        "best_step": 0,
        "steps_completed": 0,
        "training_time": 0,
    }

    step = 0
    train_loss = 0.0
    steps_since_last_eval = 0

    print("\n--- Training ---")

    while step < total_steps:
        # Check the time budget
        elapsed = time.time() - start_time
        if elapsed >= time_budget - 30:  # Reserve 30s for the final eval
            print(f"\n  Time budget nearly reached ({elapsed:.0f}/{time_budget}s). Stopping.")
            break

        # Get lambda for this step
        lambda_ = get_lambda(step, warmup_steps, lambda_schedule)

        # Get batch (restart the loader when it runs out)
        try:
            batch = next(loader_iter)
        except StopIteration:
            loader_iter = iter(train_loader)
            batch = next(loader_iter)

        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)

        # Set lambda on all BitLinear layers for this step
        for bl in bitlinears:
            bl.lambda_val = lambda_

        # Forward pass
        outputs = model(input_ids=input_ids, labels=labels, use_cache=False)

        loss = outputs.loss
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        train_loss += loss.item()
        step += 1
        steps_since_last_eval += 1

        # Periodic logging (every 100 steps, so the average is over 100 losses)
        if step % 100 == 0:
            avg_loss = train_loss / 100
            avg_ppl = math.exp(avg_loss)
            elapsed = time.time() - start_time
            print(f"  Step {step}/{total_steps} | Loss: {avg_loss:.4f} | PPL: {avg_ppl:.2f} | λ: {lambda_:.3f} | LR: {lr_scheduler.get_last_lr()[0]:.2e} | Time: {elapsed:.0f}s")
            train_loss = 0.0

        # Evaluation
        if steps_since_last_eval >= eval_every or step == total_steps:
            ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
            bpb = math.log2(ppl)
            elapsed = time.time() - start_time

            print(f"\n  [Eval] Step {step} | PPL: {ppl:.2f} | BPB: {bpb:.4f} | λ: {lambda_:.3f} | Time: {elapsed:.0f}s")

            if ppl < results["best_ppl"]:
                results["best_ppl"] = ppl
                results["best_step"] = step
                print("  *** New best PPL! ***")

            steps_since_last_eval = 0

    # Final evaluation
    print("\n--- Final Evaluation ---")
    final_ppl = evaluate_ppl(model, eval_samples, tokenizer, seq_length, batch_size, device)
    final_bpb = math.log2(final_ppl)
    results["final_ppl"] = final_ppl
    results["final_bpb"] = final_bpb
    results["steps_completed"] = step
    results["training_time"] = time.time() - start_time

    # Count the final eval toward best_ppl (the loop may exit on the time
    # budget before an in-loop eval has seen the last steps)
    if final_ppl < results["best_ppl"]:
        results["best_ppl"] = final_ppl
        results["best_step"] = step

    print(f"  Final PPL: {final_ppl:.2f}")
    print(f"  Final BPB: {final_bpb:.4f}")
    print(f"  Best PPL: {results['best_ppl']:.2f} (step {results['best_step']})")
    print(f"  Steps completed: {step}")
    print(f"  Training time: {results['training_time']:.0f}s")

    return results


# ---------------------------------------------------------------------------
# Results Logging
# ---------------------------------------------------------------------------

def log_result(results: dict, config: dict, results_file: str = "results.tsv"):
    """Append a result row to the TSV file (writing a header if it is new)."""
    file_exists = os.path.exists(results_file)

    with open(results_file, "a") as f:
        if not file_exists:
            # Write header
            headers = ["timestamp", "commit", "status", "val_ppl", "val_bpb",
                       "initial_ppl", "best_ppl", "best_step", "steps", "time_s",
                       "model", "lr", "warmup_steps", "total_steps", "lambda_schedule",
                       "group_size", "batch_size", "seq_length", "description"]
            f.write("\t".join(headers) + "\n")

        row = [
            time.strftime("%Y-%m-%d %H:%M:%S"),
            config.get("commit", ""),
            config.get("status", "success"),
            f"{results.get('final_ppl', 0):.4f}",
            f"{results.get('final_bpb', 0):.4f}",
            f"{results.get('initial_ppl', 0):.4f}",
            f"{results.get('best_ppl', 0):.4f}",
            str(results.get('best_step', 0)),
            str(results.get('steps_completed', 0)),
            f"{results.get('training_time', 0):.1f}",
            config.get("model", ""),
            str(config.get("lr", "")),
            str(config.get("warmup_steps", "")),
            str(config.get("total_steps", "")),
            config.get("lambda_schedule", ""),
            str(config.get("group_size", "")),
            str(config.get("batch_size", "")),
            str(config.get("seq_length", "")),
            config.get("description", ""),
        ]
        f.write("\t".join(row) + "\n")

    print(f"\nResult logged to {results_file}")


def get_git_commit() -> str:
    """Get the current git commit hash."""
    try:
        repo = git.Repo(search_parent_directories=True)
        return repo.head.object.hexsha[:8]
    except Exception:
        return "nogit"


# ---------------------------------------------------------------------------
# CLI Entry Point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Ternary QAT Training")
    parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name")
    parser.add_argument("--train-dataset", default=DEFAULT_TRAIN_DATASET, help="Training dataset")
    parser.add_argument("--eval-dataset", default=DEFAULT_EVAL_DATASET, help="Eval dataset")
    parser.add_argument("--eval-config", default=DEFAULT_EVAL_CONFIG, help="Eval dataset config")
    parser.add_argument("--seq-length", type=int, default=DEFAULT_SEQ_LENGTH, help="Sequence length")
    parser.add_argument("--batch-size", type=int, default=4, help="Batch size")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--warmup-steps", type=int, default=1000, help="Lambda warmup steps")
    parser.add_argument("--total-steps", type=int, default=5000, help="Total training steps")
    parser.add_argument("--eval-every", type=int, default=500, help="Eval every N steps")
    parser.add_argument("--lambda-schedule", default="linear", choices=["linear", "cosine", "exponential"], help="Lambda schedule")
    parser.add_argument("--group-size", type=int, default=128, help="Quantization group size")
    parser.add_argument("--max-samples", type=int, default=50000, help="Max training samples")
    parser.add_argument("--device", default="auto", help="Device (auto=cuda/mps/cpu)")
    parser.add_argument("--time-budget", type=float, default=300.0, help="Time budget in seconds")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--description", default="", help="Experiment description")
    parser.add_argument("--dry-run", action="store_true", help="Print config without training")

    args = parser.parse_args()

    config = {
        "commit": get_git_commit(),
        "model": args.model,
        "lr": args.lr,
        "warmup_steps": args.warmup_steps,
        "total_steps": args.total_steps,
        "lambda_schedule": args.lambda_schedule,
        "group_size": args.group_size,
        "batch_size": args.batch_size,
        "seq_length": args.seq_length,
        "description": args.description or f"λ={args.lambda_schedule}-{args.warmup_steps} lr={args.lr} gs={args.group_size}",
    }

    if args.dry_run:
        print("=== Configuration (dry run) ===")
        for k, v in config.items():
            print(f"  {k}: {v}")
        print(f"  time_budget: {args.time_budget}s")
        print(f"  device: {args.device}")
        exit(0)

    try:
        results = train(
            model_name=args.model,
            train_dataset_name=args.train_dataset,
            eval_dataset_name=args.eval_dataset,
            eval_config=args.eval_config,
            seq_length=args.seq_length,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            warmup_steps=args.warmup_steps,
            total_steps=args.total_steps,
            eval_every=args.eval_every,
            lambda_schedule=args.lambda_schedule,
            group_size=args.group_size,
            max_train_samples=args.max_samples,
            device=args.device,
            time_budget=args.time_budget,
            seed=args.seed,
        )

        config["status"] = "success"
        log_result(results, config)

    except Exception as e:
        print(f"\n!!! Training failed: {e}")
        import traceback
        traceback.print_exc()

        config["status"] = "failed"
        config["error"] = str(e)
        log_result({}, config)
        sys.exit(1)