# Ternary Quantization — Agent Instructions (program.md)
You are an autonomous research agent exploring ternary (1.58-bit) quantization for LLMs.
## Context
We are implementing a QAT (quantization-aware training) pipeline that replaces standard `nn.Linear` layers with `BitLinear` layers that quantize weights to {-1, 0, +1} with per-group scales. The goal is to find hyperparameter configurations that minimize eval perplexity (PPL) on WikiText-2 after fine-tuning.
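For orientation, here is a minimal sketch of a `BitLinear` layer. It rests on several assumptions that may not match `train.py`. Each group is a contiguous run of `group_size` input columns within a weight row. Gradients pass through the quantizer via a straight-through estimator (STE). A `lambda_` attribute blends the full-precision and ternary weights during warmup. The `scale_mode` and `threshold` arguments are hypothetical knobs for the weight-quantization experiments listed later in this file.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BitLinear(nn.Linear):
    """Hypothetical sketch of a ternary QAT linear layer (not the train.py code).

    Weights are split into groups of `group_size` input columns per output row,
    quantized to {-1, 0, +1}, and multiplied by a per-group scale.
    """

    def __init__(self, in_features, out_features, bias=True,
                 group_size=128, scale_mode="absmean", threshold=0.5):
        super().__init__(in_features, out_features, bias=bias)
        assert in_features % group_size == 0, "in_features must split into whole groups"
        self.group_size = group_size
        self.scale_mode = scale_mode  # "absmean" or "absmax" (hypothetical option)
        self.threshold = threshold    # |w / scale| < threshold -> 0 (deadzone)
        self.lambda_ = 0.0            # 0 = full precision, 1 = fully ternary

    def quantize(self, w):
        """Return (dequantized weights, ternary codes), both shaped like w."""
        out_f, in_f = w.shape
        g = w.reshape(out_f, in_f // self.group_size, self.group_size)
        if self.scale_mode == "absmax":
            scale = g.abs().amax(dim=-1, keepdim=True)
        else:
            scale = g.abs().mean(dim=-1, keepdim=True)
        scale = scale.clamp(min=1e-8)  # avoid division by zero for all-zero groups
        w_norm = g / scale
        # sign gives +/-1; the mask zeroes everything inside the deadzone
        q = torch.sign(w_norm) * (w_norm.abs() >= self.threshold)
        return (q * scale).reshape(out_f, in_f), q.reshape(out_f, in_f)

    def forward(self, x):
        w_q, _ = self.quantize(self.weight)
        # STE: the forward pass sees w_q, the backward pass treats quantization as identity
        w_ste = self.weight + (w_q - self.weight).detach()
        # Warmup blend between latent full-precision and ternary weights
        w = (1.0 - self.lambda_) * self.weight + self.lambda_ * w_ste
        return F.linear(x, w, self.bias)
```

With `threshold=0.5` this roughly matches round-and-clip ternarization (ties at exactly 0.5 aside). Because the blend has an identity gradient with respect to `self.weight`, the latent weights receive full gradients at every value of `lambda_`.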
## File Boundary
- **MUTABLE**: `train.py` — you can modify this file to change quantization logic, loss functions, warmup schedules, deadzone recovery, etc.
- **READ-ONLY**: `prepare.py` — data loading and tokenizer. Do not modify.
- **READ-ONLY**: `program.md` — these instructions. Do not modify.
## Current State
Check `results.tsv` to see previous experiment results. Each row has these tab-separated columns:
`step, lambda, train_loss, train_ppl, eval_ppl, eval_bpb, lr, time_s, best_ppl, q_neg1, q_zero, q_pos1`
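For step 1 of the experiment protocol, a small helper can pick out the best previous row. This sketch makes two assumptions. First, `results.tsv` has a header row with exactly the column names above. Second, rows without an evaluation leave `eval_ppl` empty or non-numeric. `best_run` is a hypothetical name.

```python
import csv
import math


def best_run(path="results.tsv"):
    """Return the row (as a dict) with the lowest eval_ppl, or None.

    Rows whose eval_ppl is missing, empty, non-numeric, or NaN are skipped.
    """
    best = None
    with open(path, newline="") as f:
        for row in csv.DictReader(f, delimiter="\t"):
            try:
                ppl = float(row["eval_ppl"])
            except (KeyError, TypeError, ValueError):
                continue  # no eval at this step
            if math.isnan(ppl):
                continue
            if best is None or ppl < float(best["eval_ppl"]):
                best = row
    return best
```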
## What You Can Experiment With
All experiments should be in `train.py`. Focus on:
1. **Quantization warmup schedule**: linear, cosine, exponential, step-wise
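   - Current linear ramp: `lambda_ = min(step / quant_warmup_steps, 1.0)` → try cosine: `0.5 * (1 - cos(pi * min(step / quant_warmup_steps, 1.0)))` (clamp the ratio; without it, lambda oscillates and falls back to 0 at twice the warmup length)
   - Try different warmup lengths: 500, 1000, 2000, 5000 steps (the run must be longer than the warmup, so raise `--steps` above the 100 used in the protocol command; otherwise lambda never reaches 1.0)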
2. **Learning rate**: 1e-5, 2e-5, 5e-5, 1e-4, 2e-4
   - LR warmup steps: 50, 100, 200, 500
3. **Group size**: 64, 128, 256 (Bonsai uses 128)
4. **Activation quantization**: 8-bit vs 16-bit (no quant)
   - Try different activation quantization strategies (per-token, per-channel)
5. **Weight quantization function**:
   - Current: `scale = abs_mean(w)` per group → try `scale = abs_max(w)`
   - Try different deadzone thresholds: zero out weights with |w_norm| < t, where `w_norm = w / scale` (plain round-and-clip corresponds to t = 0.5)
6. **Deadzone recovery (Tequila-style)**:
   - Track fraction of weights at 0; if > 40%, try reactivation
   - Repurpose deadzone weights as dynamic biases
7. **Gradient clipping**: 0.5, 1.0, 2.0, 5.0
8. **Batch size and seq length**: trade off memory vs gradient quality
   - On M4 Pro with 24GB RAM, be conservative
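The warmup-schedule variants from item 1 can share one function. This is a sketch: the `exponential` and `step` shapes below are one possible reading of those names. `lambda_schedule`, `n_stages`, and `k` are hypothetical names, not existing `train.py` options. Every variant starts at 0, reaches exactly 1.0 at `warmup_steps`, and stays there.

```python
import math


def lambda_schedule(step, warmup_steps, kind="linear", n_stages=4, k=5.0):
    """Quantization strength lambda_ in [0, 1] at a given step (sketch)."""
    if warmup_steps <= 0:
        return 1.0
    t = min(max(step / warmup_steps, 0.0), 1.0)  # clamped warmup progress
    if kind == "linear":
        return t
    if kind == "cosine":
        return 0.5 * (1.0 - math.cos(math.pi * t))
    if kind == "exponential":
        # Normalized exponential ramp: slow start, fast finish, exactly 1 at t = 1
        return (math.exp(k * t) - 1.0) / (math.exp(k) - 1.0)
    if kind == "step":
        # n_stages equal jumps; the final jump lands on 1.0 at t = 1
        return math.floor(t * n_stages) / n_stages
    raise ValueError(f"unknown schedule: {kind}")
```

In the training loop this would replace the linear formula, e.g. `lambda_ = lambda_schedule(step, quant_warmup_steps, kind="cosine")`.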
## Constraints
- **Device**: MPS (Apple Silicon). No CUDA.
- **Memory**: 24GB RAM. Use float32 (float16 breaks on MPS with cross-entropy).
- **Model**: SmolLM-135M (135M params). Don't change the model.
- **Dataset**: TinyStories (streaming). Don't change the dataset.
- **Eval**: WikiText-2 test split (pre-tokenized in `data/wikitext_eval.json`).
- **Keep it simple**: Changes should be reviewable diffs. Don't rewrite the whole file.
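A minimal sketch of the device and dtype setup these constraints imply. `pick_device` and `prepare_module` are hypothetical helper names. The CPU fallback is an addition so the snippet also runs off Apple Silicon.

```python
import torch
import torch.nn as nn


def pick_device() -> torch.device:
    # MPS is the target device (no CUDA); fall back to CPU if MPS is unavailable
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def prepare_module(module: nn.Module) -> nn.Module:
    # Cast parameters and buffers to float32 (float16 is avoided on MPS here)
    # and move them to the chosen device
    return module.to(device=pick_device(), dtype=torch.float32)
```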
## Evaluation Metric
**Single metric**: `eval_ppl` (perplexity on WikiText-2). Lower is better.
Baseline: SmolLM-135M FP32 should be around 30-40 PPL on WikiText-2.
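Perplexity is the exponential of the mean per-token cross-entropy in nats, so `eval_ppl` and `eval_bpb` could both be derived from one summed loss. The sketch assumes `eval_bpb` means bits per UTF-8 byte of the evaluated text; check how `train.py` actually defines it. `ppl_and_bpb` is a hypothetical name.

```python
import math


def ppl_and_bpb(total_nll_nats, n_tokens, n_bytes):
    """Convert summed token cross-entropy (natural log) into PPL and bits/byte.

    total_nll_nats: sum of -log p(token) over all predicted eval tokens
    n_tokens: number of predicted tokens
    n_bytes: UTF-8 byte length of the text those tokens cover (assumed definition)
    """
    ppl = math.exp(total_nll_nats / n_tokens)
    bpb = total_nll_nats / (math.log(2) * n_bytes)
    return ppl, bpb
```

For example, a mean loss of ln(35) ≈ 3.555 nats/token corresponds to a PPL of 35. Average the loss over all tokens before exponentiating: averaging per-batch perplexities gives a number that is never lower, by Jensen's inequality.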
## Experiment Protocol
1. Read `results.tsv` to understand current state
2. Propose ONE focused change to `train.py`
3. Run training: `python train.py --steps 100 --eval-every 25 --activation-bits 16`
4. Check whether `eval_ppl` improved on the best previous result in `results.tsv`
5. If improved → keep the change
6. If worse → revert to previous version
7. Repeat
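Steps 4 to 6 amount to a keep-or-revert decision on `train.py`. This is a minimal sketch. It assumes the agent copies `train.py` to a backup file (hypothetically `train.py.bak`) before each edit. `keep_or_revert` is a hypothetical helper.

```python
import shutil
from pathlib import Path


def keep_or_revert(new_ppl, best_ppl, target="train.py", backup="train.py.bak"):
    """Keep the edit only if eval_ppl strictly improved; otherwise restore the backup.

    Returns the best eval_ppl to carry into the next iteration.
    """
    if best_ppl is None or new_ppl < best_ppl:
        Path(backup).unlink(missing_ok=True)  # edit accepted: drop the backup
        return new_ppl
    shutil.copyfile(backup, target)           # edit rejected: restore previous version
    Path(backup).unlink()
    return best_ppl
```

A NaN `eval_ppl` from a diverged run does not compare as an improvement, so the edit is reverted. If `train.py` is tracked in git, `git checkout -- train.py` is an equivalent way to revert.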
## Notes
- Start with `--activation-bits 16` (no activation quant) to isolate weight quantization effects
- Gradually introduce activation quantization once weight quant works well
- The `quant=[-1:X 0:Y +1:Z]` stat shows the ternary distribution — aim for balanced, not all zeros
- Lambda warmup is critical — too fast = catastrophic accuracy drop, too slow = no quantization benefit
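The `quant=[-1:X 0:Y +1:Z]` stat and the `q_neg1`, `q_zero`, `q_pos1` columns can be computed from the ternary codes. The same zero fraction drives the 40% reactivation trigger from the deadzone-recovery item. This sketch assumes X, Y, Z are fractions; the actual log may print counts or percentages. The helper names are hypothetical.

```python
from collections import Counter


def ternary_stats(codes):
    """Fractions of -1, 0, +1 among ternary weight codes.

    `codes` is any iterable of ints in {-1, 0, 1}, e.g. q.flatten().tolist()
    for a ternary code tensor q.
    """
    counts = Counter(int(c) for c in codes)
    n = sum(counts.values())
    if n == 0:
        raise ValueError("no codes")
    return {v: counts.get(v, 0) / n for v in (-1, 0, 1)}


def format_quant(stats):
    # Mirrors the quant=[-1:X 0:Y +1:Z] stat, with X, Y, Z printed as fractions
    return f"quant=[-1:{stats[-1]:.2f} 0:{stats[0]:.2f} +1:{stats[1]:.2f}]"


def needs_reactivation(stats, max_zero_frac=0.40):
    # Deadzone-recovery trigger: too many weights stuck at 0
    return stats[0] > max_zero_frac
```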