# Ternary Quantization: Agent Instructions (program.md)

You are an autonomous research agent exploring ternary (1.58-bit) quantization for LLMs.

## Context

We are implementing a QAT (quantization-aware training) pipeline that replaces standard `nn.Linear` layers with `BitLinear` layers, which quantize weights to {-1, 0, +1} with per-group scales. The goal is to find hyperparameter configurations that minimize eval perplexity (PPL) on WikiText-2 after fine-tuning.

## File Boundary

- **MUTABLE**: `train.py`. You can modify this file to change quantization logic, loss functions, warmup schedules, deadzone recovery, etc.
- **READ-ONLY**: `prepare.py`. Data loading and tokenizer. Do not modify.
- **READ-ONLY**: `program.md`. These instructions. Do not modify.

## Current State

Check `results.tsv` to see previous experiment results. Each row has:

`step, lambda, train_loss, train_ppl, eval_ppl, eval_bpb, lr, time_s, best_ppl, q_neg1, q_zero, q_pos1`

## What You Can Experiment With

All experiments should be in `train.py`. Focus on:

1. **Quantization warmup schedule**: linear, cosine, exponential, step-wise
   - Current: `lambda_ = min(step / quant_warmup_steps, 1.0)` → try cosine: `0.5 * (1 - cos(pi * min(step / quant_warmup_steps, 1.0)))` (clamp the progress term so lambda stays at 1.0 after warmup instead of oscillating)
   - Try different warmup lengths: 500, 1000, 2000, 5000 steps
2. **Learning rate**: 1e-5, 2e-5, 5e-5, 1e-4, 2e-4
   - LR warmup steps: 50, 100, 200, 500
3. **Group size**: 64, 128, 256 (Bonsai uses 128)
4. **Activation quantization**: 8-bit vs 16-bit (no quant)
   - Try different activation quantization strategies (per-token, per-channel)
5. **Weight quantization function**:
   - Current: `scale = abs_mean(w)` → try `scale = abs_max(w)`
   - Try different deadzone thresholds (e.g., |w_norm| < 0.5 → 0)
6. **Deadzone recovery (Tequila-style)**:
   - Track the fraction of weights at 0; if it exceeds 40%, try reactivation
   - Repurpose deadzone weights as dynamic biases
7. **Gradient clipping**: 0.5, 1.0, 2.0, 5.0
8. **Batch size and seq length**: trade off memory vs gradient quality
   - On M4 Pro with 24GB RAM, be conservative

## Constraints

- **Device**: MPS (Apple Silicon). No CUDA.
- **Memory**: 24GB RAM. Use float32 (float16 breaks on MPS with cross-entropy).
- **Model**: SmolLM-135M (135M params). Don't change the model.
- **Dataset**: TinyStories (streaming). Don't change the dataset.
- **Eval**: WikiText-2 test split (pre-tokenized in `data/wikitext_eval.json`).
- **Keep it simple**: Changes should be reviewable diffs. Don't rewrite the whole file.

## Evaluation Metric

**Single metric**: `eval_ppl` (perplexity on WikiText-2). Lower is better.

Baseline: SmolLM-135M FP32 should be around 30-40 PPL on WikiText-2.

## Experiment Protocol

1. Read `results.tsv` to understand the current state
2. Propose ONE focused change to `train.py`
3. Run training: `python train.py --steps 100 --eval-every 25 --activation-bits 16`
4. Check if `eval_ppl` improved
5. If improved → keep the change
6. If worse → revert to the previous version
7. Repeat

## Notes

- Start with `--activation-bits 16` (no activation quant) to isolate weight quantization effects
- Gradually introduce activation quantization once weight quant works well
- The `quant=[-1:X 0:Y +1:Z]` stat shows the ternary distribution; aim for a balanced split, not all zeros
- Lambda warmup is critical: too fast causes a catastrophic accuracy drop, too slow yields no quantization benefit
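
## Reference Sketch

The warmup schedules, the abs-mean scale, the deadzone threshold, and the zero-fraction check listed under "What You Can Experiment With" can be sketched in plain Python as below. This is a minimal illustration, not the code in `train.py`: the function names (`lambda_linear`, `lambda_cosine`, `ternarize_group`, `blend_weights`, `zero_fraction`) are hypothetical, and the way lambda is applied (a linear blend between full-precision and quantized weights) is an assumption, since these instructions do not specify it.

```python
import math


def lambda_linear(step, quant_warmup_steps):
    """Linear warmup: ramps from 0 to 1, then stays at 1."""
    return min(step / quant_warmup_steps, 1.0)


def lambda_cosine(step, quant_warmup_steps):
    """Cosine warmup: slow start, slow finish, clamped at 1 after warmup."""
    progress = min(step / quant_warmup_steps, 1.0)
    return 0.5 * (1.0 - math.cos(math.pi * progress))


def ternarize_group(weights, threshold=0.5, eps=1e-8):
    """Quantize one weight group to codes in {-1, 0, +1} with an abs-mean scale.

    Weights whose normalized magnitude falls below `threshold` land in the
    deadzone and become 0; the rest keep only their sign.
    """
    scale = sum(abs(w) for w in weights) / len(weights) + eps
    codes = []
    for w in weights:
        w_norm = w / scale
        if abs(w_norm) < threshold:
            codes.append(0)
        else:
            codes.append(1 if w_norm > 0 else -1)
    return codes, scale


def blend_weights(weights, lambda_):
    """ASSUMPTION: lambda interpolates between full-precision and ternary weights."""
    codes, scale = ternarize_group(weights)
    return [(1.0 - lambda_) * w + lambda_ * c * scale
            for w, c in zip(weights, codes)]


def zero_fraction(codes):
    """Fraction of codes in the deadzone (compare against the 40% guideline)."""
    return codes.count(0) / len(codes)


if __name__ == "__main__":
    group = [0.9, -0.8, 0.05, -0.1, 0.4, -0.35]
    codes, scale = ternarize_group(group)
    print("codes:", codes, "scale:", scale, "zeros:", zero_fraction(codes))
    for step in (0, 250, 500, 1000, 2000):
        print(step, lambda_linear(step, 1000), lambda_cosine(step, 1000))
```

Early in warmup the cosine schedule sits below the linear one, so quantization is introduced more gently at the start. Whether that helps `eval_ppl` is exactly what the experiment protocol is meant to determine.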