Ternary Quantization — Agent Instructions (program.md)
You are an autonomous research agent exploring ternary (1.58-bit) quantization for LLMs.
Context
We are implementing a QAT (quantization-aware training) pipeline that replaces standard nn.Linear layers with BitLinear layers that quantize weights to {-1, 0, +1} with per-group scales. The goal is to find hyperparameter configurations that minimize eval perplexity (PPL) on WikiText-2 after fine-tuning.
File Boundary
- MUTABLE: `train.py`. You can modify this file to change quantization logic, loss functions, warmup schedules, deadzone recovery, etc.
- READ-ONLY: `prepare.py`. Data loading and tokenizer. Do not modify.
- READ-ONLY: `program.md`. These instructions. Do not modify.
Current State
Check `results.tsv` to see previous experiment results. Each row has the columns:
`step, lambda, train_loss, train_ppl, eval_ppl, eval_bpb, lr, time_s, best_ppl, q_neg1, q_zero, q_pos1`
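A minimal sketch for pulling the best run out of `results.tsv` with the standard library. It assumes the file is tab-separated with a header row naming the columns above (the header row is an assumption), and that rows without an `eval_ppl` value should be skipped. The helper name `best_row` is hypothetical.

```python
import csv
from pathlib import Path
from typing import Optional, TextIO


def best_row(f: TextIO) -> Optional[dict]:
    """Return the row with the lowest eval_ppl, or None if no row has one.

    Assumes a tab-separated file whose header row names the columns.
    """
    rows = [r for r in csv.DictReader(f, delimiter="\t") if r.get("eval_ppl")]
    if not rows:
        return None
    return min(rows, key=lambda r: float(r["eval_ppl"]))


if __name__ == "__main__":
    path = Path("results.tsv")
    if path.exists():
        with path.open(newline="") as f:
            row = best_row(f)
        if row is not None:
            print(f"best eval_ppl={row['eval_ppl']} at step {row['step']}")
```

Filtering on a non-empty `eval_ppl` keeps the sketch working if some rows only log training metrics between evaluations.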
What You Can Experiment With
All experiments go in `train.py`. Focus on:
- Quantization warmup schedule: linear, cosine, exponential, step-wise
  - Current: `lambda_ = min(step / quant_warmup_steps, 1.0)`. Try cosine: `0.5 * (1 - cos(pi * min(step / warmup, 1.0)))` (clamped so lambda stays at 1 after warmup instead of oscillating)
  - Try different warmup lengths: 500, 1000, 2000, 5000 steps
- Learning rate: 1e-5, 2e-5, 5e-5, 1e-4, 2e-4
  - LR warmup steps: 50, 100, 200, 500
- Group size: 64, 128, 256 (Bonsai uses 128)
- Activation quantization: 8-bit vs. 16-bit (no quantization)
  - Try different activation quantization strategies (per-token, per-channel)
- Weight quantization function:
  - Current: `scale = abs_mean(w)`. Try `scale = abs_max(w)`.
  - Try different deadzone thresholds (e.g., `|w_norm| < 0.5 → 0`)
- Deadzone recovery (Tequila-style):
  - Track the fraction of weights at 0; if it exceeds 40%, try reactivation
  - Repurpose deadzone weights as dynamic biases
- Gradient clipping: 0.5, 1.0, 2.0, 5.0
- Batch size and sequence length: trade off memory vs. gradient quality
  - On the M4 Pro with 24GB RAM, be conservative
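The four warmup shapes listed above can be sketched as functions of `step` that rise from 0 to 1 and then stay at 1. Only the linear form is quoted from `train.py`; the exponential rate `k` and the number of stairs `n_steps` are hypothetical parameters chosen for illustration, as are the function names. All four clamp `step / warmup` to 1 so lambda never drops back after warmup.

```python
import math


def lambda_linear(step: int, warmup: int) -> float:
    # Current schedule: ramps linearly from 0 to 1 over `warmup` steps.
    return min(step / warmup, 1.0)


def lambda_cosine(step: int, warmup: int) -> float:
    # Half-cosine ramp: slow start, fast middle, slow finish; clamped at 1.
    t = min(step / warmup, 1.0)
    return 0.5 * (1.0 - math.cos(math.pi * t))


def lambda_exponential(step: int, warmup: int, k: float = 5.0) -> float:
    # Hypothetical exponential ramp, normalized to hit exactly 1 at `warmup`.
    # Larger k keeps lambda near 0 for longer, then rises sharply.
    t = min(step / warmup, 1.0)
    return (math.exp(k * t) - 1.0) / (math.exp(k) - 1.0)


def lambda_stepwise(step: int, warmup: int, n_steps: int = 4) -> float:
    # Hypothetical staircase: lambda jumps by 1/n_steps every warmup/n_steps steps.
    t = min(step / warmup, 1.0)
    return math.floor(t * n_steps) / n_steps
```

The cosine and exponential shapes both spend more of the warmup at small lambda than the linear ramp does, which is the property to probe if fast quantization is what hurts accuracy.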
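A NumPy sketch of per-group ternary weight quantization, to make the group-size, scale, and deadzone knobs above concrete. It assumes groups are contiguous runs of `group_size` elements along each row's input dimension and that `in_features` is divisible by `group_size`; `ternary_quantize` and `dequantize` are hypothetical names, not the ones in `train.py`. With `scale_fn="abs_mean"` and `threshold=0.5`, the rule matches rounding `w / scale` and clipping to [-1, 1] (the absmean scheme of BitNet b1.58, which applies it per tensor), apart from exact ties at ±0.5. Switching to `abs_max` makes each scale larger, so more weights land in the deadzone.

```python
import numpy as np


def ternary_quantize(w: np.ndarray, group_size: int = 128,
                     scale_fn: str = "abs_mean", threshold: float = 0.5,
                     eps: float = 1e-8):
    """Quantize a 2-D weight matrix to {-1, 0, +1} with one scale per group.

    Returns (q, scale): q has w's shape with values in {-1, 0, +1};
    scale has shape (out_features, in_features // group_size, 1).
    """
    out_f, in_f = w.shape
    assert in_f % group_size == 0, "in_features must be divisible by group_size"
    g = w.reshape(out_f, in_f // group_size, group_size)

    if scale_fn == "abs_mean":
        scale = np.abs(g).mean(axis=-1, keepdims=True)
    elif scale_fn == "abs_max":
        scale = np.abs(g).max(axis=-1, keepdims=True)
    else:
        raise ValueError(f"unknown scale_fn: {scale_fn}")
    scale = scale + eps  # avoid division by zero for all-zero groups

    w_norm = g / scale
    # Deadzone: |w_norm| below the threshold maps to 0, the rest to its sign.
    q = np.where(np.abs(w_norm) < threshold, 0.0, np.sign(w_norm))
    return q.reshape(out_f, in_f), scale


def dequantize(q: np.ndarray, scale: np.ndarray, group_size: int) -> np.ndarray:
    # Multiply each group of ternary codes back by its scale.
    out_f, in_f = q.shape
    g = q.reshape(out_f, in_f // group_size, group_size)
    return (g * scale).reshape(out_f, in_f)
```

Smaller groups give each scale fewer weights to fit, at the cost of storing more scales per layer.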
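To compare the activation quantization granularities listed above, here is a NumPy sketch of symmetric absmax fake-quantization (quantize, then dequantize) to 8 bits. It assumes activations shaped `(tokens, channels)`: per-token uses one scale per row, per-channel one scale per column. The function name `fake_quant` is hypothetical, and `train.py` may use a different scheme (for example asymmetric ranges or clipping).

```python
import numpy as np


def fake_quant(x: np.ndarray, axis: int, bits: int = 8,
               eps: float = 1e-8) -> np.ndarray:
    """Symmetric absmax fake-quantization of x along `axis`.

    axis=-1 reduces over channels, giving one scale per token (per-token).
    axis=0 reduces over tokens, giving one scale per channel (per-channel).
    """
    qmax = 2 ** (bits - 1) - 1  # 127 for 8 bits
    scale = np.abs(x).max(axis=axis, keepdims=True) / qmax + eps
    q = np.clip(np.round(x / scale), -qmax - 1, qmax)
    return q * scale
```

When a single channel carries outliers, per-token scales are dominated by that channel and every other value in the row is quantized coarsely; per-channel scales confine the damage to the outlier column.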
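One way to track the deadzone for the recovery item above, sketched under assumptions: the fractions mirror the `q_neg1`, `q_zero`, and `q_pos1` columns, pooled over all quantized layers, but whether `train.py` reports fractions or percentages, and how it aggregates layers, is not specified here. The 40% trigger comes from the list above; the function names are hypothetical, and the Tequila-style reactivation itself is left out because its details are not given.

```python
from typing import Iterable

import numpy as np


def ternary_distribution(qs: Iterable[np.ndarray]) -> dict:
    """Fractions of ternary codes at -1, 0, +1, pooled over several layers."""
    counts = {-1: 0, 0: 0, 1: 0}
    total = 0
    for q in qs:
        total += q.size
        for v in counts:
            counts[v] += int(np.count_nonzero(q == v))
    if total == 0:
        raise ValueError("no weights given")
    return {"q_neg1": counts[-1] / total, "q_zero": counts[0] / total,
            "q_pos1": counts[1] / total}


def format_quant_stat(d: dict) -> str:
    # Mirrors the quant=[-1:X 0:Y +1:Z] log line; fraction formatting is assumed.
    return f"quant=[-1:{d['q_neg1']:.3f} 0:{d['q_zero']:.3f} +1:{d['q_pos1']:.3f}]"


def needs_reactivation(d: dict, max_zero_frac: float = 0.40) -> bool:
    # Trigger deadzone recovery when too many weights are stuck at 0.
    return d["q_zero"] > max_zero_frac
```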
Constraints
- Device: MPS (Apple Silicon). No CUDA.
- Memory: 24GB RAM. Use float32 (float16 breaks on MPS with cross-entropy).
- Model: SmolLM-135M (135M params). Don't change the model.
- Dataset: TinyStories (streaming). Don't change the dataset.
- Eval: WikiText-2 test split (pre-tokenized in data/wikitext_eval.json).
- Keep it simple: Changes should be reviewable diffs. Don't rewrite the whole file.
Evaluation Metric
Single metric: eval_ppl (perplexity on WikiText-2). Lower is better.
Baseline: SmolLM-135M FP32 should be around 30-40 PPL on WikiText-2.
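For reference, perplexity is the exponential of the mean per-token negative log-likelihood (in nats), and bits per byte (the `eval_bpb` column) divides the total NLL, converted to bits, by the number of UTF-8 bytes of the evaluated text instead of by the token count; equivalently, `bpb = log2(ppl) * n_tokens / n_bytes`. A minimal sketch, assuming the evaluation loop accumulates the summed token NLL plus token and byte counts; the function name is hypothetical and `train.py` may compute `eval_bpb` differently.

```python
import math


def ppl_and_bpb(total_nll_nats: float, n_tokens: int, n_bytes: int):
    """Return (perplexity, bits per byte) from a summed NLL in nats."""
    ppl = math.exp(total_nll_nats / n_tokens)
    bpb = total_nll_nats / (math.log(2) * n_bytes)
    return ppl, bpb
```

In these terms, the 30-40 PPL baseline corresponds to a mean NLL of roughly ln 30 ≈ 3.40 to ln 40 ≈ 3.69 nats per token.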
Experiment Protocol
- Read `results.tsv` to understand the current state.
- Propose ONE focused change to `train.py`.
- Run training: `python train.py --steps 100 --eval-every 25 --activation-bits 16`
- Check whether `eval_ppl` improved.
- If improved → keep the change.
- If worse → revert to the previous version.
- Repeat.
Notes
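- Start with `--activation-bits 16` (no activation quantization) to isolate weight quantization effects.
- Gradually introduce activation quantization once weight quantization works well.
- The `quant=[-1:X 0:Y +1:Z]` stat shows the ternary distribution: aim for a balanced distribution, not all zeros.
- Lambda warmup is critical: too fast causes a catastrophic accuracy drop; too slow yields no quantization benefit. With the linear schedule and `--steps 100`, any `quant_warmup_steps` above 100 never reaches lambda = 1, so the final `eval_ppl` reflects a partially quantized model.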
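These instructions do not show how `train.py` applies lambda, so the following is an assumption based on a common pattern in BitNet-style fine-tuning: the forward pass uses `w + lambda_ * (Q(w) - w)`, i.e. full-precision weights at lambda = 0 and fully ternary weights at lambda = 1. In PyTorch the `(Q(w) - w)` term is typically wrapped in `.detach()`, so gradients reach `w` as if quantization were the identity (a straight-through estimator). The NumPy sketch below shows only the forward interpolation, with a hypothetical per-tensor absmean quantizer.

```python
import numpy as np


def absmean_ternary(w: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Hypothetical per-tensor absmean ternary quantizer, returned dequantized.
    scale = np.abs(w).mean() + eps
    return np.clip(np.round(w / scale), -1, 1) * scale


def blended_weight(w: np.ndarray, lambda_: float) -> np.ndarray:
    """Interpolate between full-precision (lambda_=0) and ternary (lambda_=1).

    In a PyTorch forward pass the same expression would typically be
    w + lambda_ * (quant(w) - w).detach(), so the backward pass sees identity.
    """
    return w + lambda_ * (absmean_ternary(w) - w)
```

Under this reading, a warmup that is too fast jumps the weights to their ternary values before training can compensate, while one that is too slow leaves the evaluated model close to full precision.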