# Ternary Quantization — Agent Instructions (program.md)
You are an autonomous research agent exploring ternary (1.58-bit) quantization for LLMs.
## Context
We are implementing a QAT (quantization-aware training) pipeline that replaces standard `nn.Linear` layers with `BitLinear` layers that quantize weights to {-1, 0, +1} with per-group scales. The goal is to find hyperparameter configurations that minimize eval perplexity (PPL) on WikiText-2 after fine-tuning.
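
A minimal sketch of such a quantizer (illustrative only: the name `ternary_quantize`, the default `group_size=128`, and the `deadzone` cutoff are assumptions, not necessarily what `train.py` currently does):

```python
import torch

def ternary_quantize(w: torch.Tensor, group_size: int = 128, deadzone: float = 0.5) -> torch.Tensor:
    """Ternarize weights to {-1, 0, +1} with a per-group abs-mean scale.

    Assumes w.numel() is divisible by group_size; groups are taken over the
    flattened weight, and values whose normalized magnitude falls below
    `deadzone` snap to 0.
    """
    g = w.reshape(-1, group_size)                         # (num_groups, group_size)
    scale = g.abs().mean(dim=1, keepdim=True).clamp_min(1e-8)
    q = torch.where((g / scale).abs() < deadzone,
                    torch.zeros_like(g),
                    torch.sign(g))                         # ternary codes
    w_q = (q * scale).reshape(w.shape)
    # Straight-through estimator: the forward pass sees w_q, gradients flow to w.
    return w + (w_q - w).detach()
```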
## File Boundary
- **MUTABLE**: `train.py` — you can modify this file to change quantization logic, loss functions, warmup schedules, deadzone recovery, etc.
- **READ-ONLY**: `prepare.py` — data loading and tokenizer. Do not modify.
- **READ-ONLY**: `program.md` — these instructions. Do not modify.
## Current State
Check `results.tsv` to see previous experiment results. Each row has:
`step, lambda, train_loss, train_ppl, eval_ppl, eval_bpb, lr, time_s, best_ppl, q_neg1, q_zero, q_pos1`
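
For example, the best run so far can be pulled out with a few lines of Python (assuming `results.tsv` is tab-separated and its first row is the header above):

```python
import csv

# Keep only rows that actually logged an eval result, then report the best one.
with open("results.tsv", newline="") as f:
    rows = [r for r in csv.DictReader(f, delimiter="\t")
            if r.get("eval_ppl") not in (None, "", "nan")]

if rows:
    best = min(rows, key=lambda r: float(r["eval_ppl"]))
    print(f"best eval_ppl so far: {float(best['eval_ppl']):.2f} (step {best['step']})")
```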
## What You Can Experiment With
All experiments should be in `train.py`. Focus on:
1. **Quantization warmup schedule**: linear, cosine, exponential, step-wise (a schedule sketch appears after this list)
- `lambda_ = min(step / quant_warmup_steps, 1.0)` → try cosine: `0.5 * (1 - cos(pi * step / warmup))`
- Try different warmup lengths: 500, 1000, 2000, 5000 steps
2. **Learning rate**: 1e-5, 2e-5, 5e-5, 1e-4, 2e-4
- LR warmup steps: 50, 100, 200, 500
3. **Group size**: 64, 128, 256 (Bonsai uses 128)
4. **Activation quantization**: 8-bit vs 16-bit (no quant)
- Try different activation quantization strategies (per-token, per-channel)
5. **Weight quantization function**:
- Current: `scale = abs_mean(w)` → try `scale = abs_max(w)`
- Try different deadzone thresholds (e.g., |w_norm| < 0.5 → 0)
6. **Deadzone recovery (Tequila-style)**:
- Track fraction of weights at 0; if > 40%, try reactivation
- Repurpose deadzone weights as dynamic biases
7. **Gradient clipping**: 0.5, 1.0, 2.0, 5.0
8. **Batch size and seq length**: trade off memory vs gradient quality
- On M4 Pro with 24GB RAM, be conservative
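
A sketch of the warmup schedules from item 1 (the helper name `quant_lambda` and the blending lines at the end are assumptions about how `train.py` applies lambda):

```python
import math

def quant_lambda(step: int, warmup: int, kind: str = "linear") -> float:
    """Quantization strength in [0, 1]: 0 = full precision, 1 = fully ternary."""
    t = min(step / max(warmup, 1), 1.0)
    if kind == "linear":
        return t
    if kind == "cosine":
        return 0.5 * (1.0 - math.cos(math.pi * t))
    if kind == "exponential":
        return 1.0 if t >= 1.0 else 1.0 - math.exp(-5.0 * t)
    if kind == "step":
        return float(t >= 0.5)
    raise ValueError(f"unknown schedule: {kind}")

# Assumed use in the forward pass, blending full-precision and ternary weights:
#   lam = quant_lambda(step, quant_warmup_steps, kind="cosine")
#   w_eff = (1 - lam) * w + lam * w_q
```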
## Constraints
- **Device**: MPS (Apple Silicon). No CUDA.
- **Memory**: 24GB RAM. Use float32 (float16 breaks on MPS with cross-entropy).
- **Model**: SmolLM-135M (135M params). Don't change the model.
- **Dataset**: TinyStories (streaming). Don't change the dataset.
- **Eval**: WikiText-2 test split (pre-tokenized in data/wikitext_eval.json).
- **Keep it simple**: Changes should be reviewable diffs. Don't rewrite the whole file.
## Evaluation Metric
**Single metric**: `eval_ppl` (perplexity on WikiText-2). Lower is better.
Baseline: SmolLM-135M FP32 should be around 30-40 PPL on WikiText-2.
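
For reference, `eval_ppl` is the exponentiated mean cross-entropy per token; the sketch below also shows one way `eval_bpb` can be derived, assuming the eval set's token-to-byte ratio is known (how `train.py` actually computes it may differ):

```python
import math

def ppl_from_nll(mean_nll_nats: float) -> float:
    """Perplexity = exp(mean cross-entropy per token, in nats)."""
    return math.exp(mean_nll_nats)

def bpb_from_nll(mean_nll_nats: float, tokens_per_byte: float) -> float:
    """Bits per byte: convert nats/token to bits/token, then scale by tokens per byte."""
    return (mean_nll_nats / math.log(2)) * tokens_per_byte
```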
## Experiment Protocol
1. Read `results.tsv` to understand current state
2. Propose ONE focused change to `train.py`
3. Run training: `python train.py --steps 100 --eval-every 25 --activation-bits 16`
4. Check if `eval_ppl` improved
5. If improved → keep the change
6. If worse → revert to previous version (see the helper sketch after this list)
7. Repeat
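
If the working tree is under git (an assumption about the setup), steps 5-6 reduce to a small hypothetical helper:

```python
import subprocess

def keep_or_revert(new_ppl: float, best_ppl: float) -> None:
    """Commit the modified train.py if eval_ppl improved, otherwise discard the change."""
    if new_ppl < best_ppl:
        subprocess.run(["git", "commit", "-am", f"keep: eval_ppl {new_ppl:.3f}"], check=True)
    else:
        subprocess.run(["git", "checkout", "--", "train.py"], check=True)
```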
## Notes
- Start with `activation-bits 16` (no activation quant) to isolate weight quantization effects
- Gradually introduce activation quantization once weight quant works well
- The `quant=[-1:X 0:Y +1:Z]` stat shows the ternary distribution — aim for balanced, not all zeros (it can be recomputed with the sketch below)
- Lambda warmup is critical — too fast = catastrophic accuracy drop, too slow = no quantization benefit
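
The distribution stat can be recomputed directly from the ternary codes (a sketch; assumes the quantizer's codes `q` in {-1, 0, +1} are accessible as a tensor):

```python
import torch

def ternary_fractions(q: torch.Tensor) -> dict[str, float]:
    """Fraction of weights at each ternary level (matches the q_neg1/q_zero/q_pos1 columns)."""
    n = q.numel()
    return {
        "q_neg1": (q == -1).sum().item() / n,
        "q_zero": (q == 0).sum().item() / n,
        "q_pos1": (q == 1).sum().item() / n,
    }
```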