Ternary Quantization — Agent Instructions (program.md)

You are an autonomous research agent exploring ternary (1.58-bit) quantization for LLMs.

Context

We are implementing a QAT (quantization-aware training) pipeline that replaces standard nn.Linear layers with BitLinear layers that quantize weights to {-1, 0, +1} with per-group scales. The goal is to find hyperparameter configurations that minimize eval perplexity (PPL) on WikiText-2 after fine-tuning.
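
For orientation, a minimal sketch of what such a layer can look like. The class name and the abs_mean per-group scale come from this document; the straight-through estimator and the default group_size are illustrative assumptions, not necessarily what train.py does:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BitLinear(nn.Linear):
    """Ternary-weight linear layer with per-group scales (sketch only;
    the real implementation lives in train.py)."""

    def __init__(self, in_features, out_features, bias=True, group_size=128):
        super().__init__(in_features, out_features, bias)
        self.group_size = group_size  # weight count must divide evenly

    def ternarize(self, w):
        g = w.reshape(-1, self.group_size)
        # Per-group scale: mean absolute value (the abs_mean scheme).
        scale = g.abs().mean(dim=1, keepdim=True).clamp(min=1e-8)
        # Round-and-clip to {-1, 0, +1} codes, then rescale.
        q = (g / scale).round().clamp(-1, 1) * scale
        return q.reshape(w.shape)

    def forward(self, x):
        w_q = self.ternarize(self.weight)
        # Straight-through estimator (an assumption): quantized forward
        # pass, gradients flow to the full-precision weights.
        w = self.weight + (w_q - self.weight).detach()
        return F.linear(x, w, self.bias)
```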

File Boundary

  • MUTABLE: train.py — you can modify this file to change quantization logic, loss functions, warmup schedules, deadzone recovery, etc.
  • READ-ONLY: prepare.py — data loading and tokenizer. Do not modify.
  • READ-ONLY: program.md — these instructions. Do not modify.

Current State

Check results.tsv to see previous experiment results. Each row has: step, lambda, train_loss, train_ppl, eval_ppl, eval_bpb, lr, time_s, best_ppl, q_neg1, q_zero, q_pos1
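
For example, a quick way to pull the best run out of results.tsv (assuming a tab-separated file with a header row matching the columns above):

```python
import csv

with open("results.tsv") as f:
    rows = list(csv.DictReader(f, delimiter="\t"))

# Lowest eval perplexity seen so far, and the settings that produced it.
best = min(rows, key=lambda r: float(r["eval_ppl"]))
print(best["step"], best["eval_ppl"], best["lr"])
```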

What You Can Experiment With

All experiments should be in train.py. Focus on:

  1. Quantization warmup schedule: linear, cosine, exponential, step-wise (see the schedule sketch after this list)

    • Current linear ramp: lambda_ = min(step / quant_warmup_steps, 1.0) → try cosine: 0.5 * (1 - cos(pi * min(step / quant_warmup_steps, 1.0)))
    • Try different warmup lengths: 500, 1000, 2000, 5000 steps
  2. Learning rate: 1e-5, 2e-5, 5e-5, 1e-4, 2e-4

    • LR warmup steps: 50, 100, 200, 500
  3. Group size: 64, 128, 256 (Bonsai uses 128)

  4. Activation quantization: 8-bit vs 16-bit (no quant)

    • Try different activation quantization strategies (per-token vs per-channel scales; see the fake-quant sketch after this list)
  5. Weight quantization function (see the quantizer sketch after this list):

    • Current: scale = abs_mean(w) → try scale = abs_max(w)
    • Try different deadzone thresholds (e.g., |w_norm| < 0.5 → 0)
  6. Deadzone recovery (Tequila-style; see the recovery sketch after this list):

    • Track fraction of weights at 0; if > 40%, try reactivation
    • Repurpose deadzone weights as dynamic biases
  7. Gradient clipping: 0.5, 1.0, 2.0, 5.0

  8. Batch size and seq length: trade off memory vs gradient quality

    • On M4 Pro with 24GB RAM, be conservative
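
For item 1, a sketch of the four schedule shapes. The function name and the exponential rate constant are illustrative; train.py may structure this differently:

```python
import math

def lambda_schedule(step, warmup, kind="linear"):
    """Quantization strength in [0, 1] as a function of training step."""
    t = min(step / warmup, 1.0)
    if kind == "linear":
        return t
    if kind == "cosine":
        return 0.5 * (1.0 - math.cos(math.pi * t))
    if kind == "exponential":
        return 1.0 - math.exp(-5.0 * t)  # rate constant 5.0 is a guess
    if kind == "step":
        return 1.0 if t >= 1.0 else math.floor(4 * t) / 4  # quarter jumps
    raise ValueError(f"unknown schedule: {kind}")
```

A common way to apply such a schedule is to blend full-precision and quantized weights, e.g. w_eff = (1 - lambda_) * w + lambda_ * quantize(w), if that matches what lambda_ does in train.py.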
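For item 4, a sketch of symmetric 8-bit fake-quantization with per-token or per-channel scales. It assumes activations of shape (batch, seq, hidden) and a straight-through estimator; both are assumptions:

```python
import torch

def fake_quant_act_8bit(x, per_token=True):
    """Symmetric int8 fake-quant with a straight-through estimator."""
    # Per-token: one scale per position (reduce over hidden);
    # per-channel: one scale per hidden unit (reduce over positions).
    dim = -1 if per_token else -2
    scale = x.abs().amax(dim=dim, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = (x / scale).round().clamp(-128, 127) * scale
    return x + (x_q - x).detach()
```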
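For item 5, the two scale choices and the deadzone threshold in one place. This operates on a single weight group and is illustrative:

```python
import torch

def ternary_codes(w, scale_mode="abs_mean", deadzone=0.5):
    """Map one weight group to {-1, 0, +1} codes plus a scale."""
    if scale_mode == "abs_mean":
        scale = w.abs().mean().clamp(min=1e-8)
    else:  # "abs_max"
        scale = w.abs().max().clamp(min=1e-8)
    w_norm = w / scale
    # |w_norm| below the deadzone threshold snaps to 0.
    codes = torch.sign(w_norm) * (w_norm.abs() >= deadzone)
    return codes, scale
```

Note that abs_mean with deadzone 0.5 reproduces plain round-and-clip behavior; raising the deadzone pushes more weights to 0.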
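For item 6, one possible reactivation trigger. This is a guess at the mechanism (the 40% threshold comes from item 6, the nudge is invented), not a statement of what Tequila actually does:

```python
import torch

@torch.no_grad()
def maybe_reactivate(weight, codes, zero_cap=0.40, nudge=1e-3):
    """If too many codes are stuck at 0, push the dead weights'
    full-precision copies away from zero so they can re-enter {-1, +1}."""
    dead = codes == 0
    if dead.float().mean().item() > zero_cap:
        weight[dead] += nudge * torch.sign(torch.randn_like(weight[dead]))
```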

Constraints

  • Device: MPS (Apple Silicon). No CUDA.
  • Memory: 24GB RAM. Use float32 (float16 breaks on MPS with cross-entropy).
  • Model: SmolLM-135M (135M params). Don't change the model.
  • Dataset: TinyStories (streaming). Don't change the dataset.
  • Eval: WikiText-2 test split (pre-tokenized in data/wikitext_eval.json).
  • Keep it simple: Changes should be reviewable diffs. Don't rewrite the whole file.

Evaluation Metric

Single metric: eval_ppl (perplexity on WikiText-2). Lower is better. Baseline: SmolLM-135M FP32 should be around 30-40 PPL on WikiText-2.
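
For reference, the usual relationship between the two eval columns, assuming eval_ppl is the exponential of the mean per-token cross-entropy (check train.py for the exact definitions):

```python
import math

def ppl_and_bpb(total_nll_nats, n_tokens, n_bytes):
    ppl = math.exp(total_nll_nats / n_tokens)       # perplexity
    bpb = total_nll_nats / (math.log(2) * n_bytes)  # bits per byte
    return ppl, bpb
```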

Experiment Protocol

  1. Read results.tsv to understand current state
  2. Propose ONE focused change to train.py
  3. Run training: python train.py --steps 100 --eval-every 25 --activation-bits 16
  4. Check if eval_ppl improved
  5. If improved → keep the change
  6. If worse → revert to previous version
  7. Repeat

Notes

  • Start with --activation-bits 16 (no activation quantization) to isolate weight-quantization effects
  • Gradually introduce activation quantization once weight quant works well
  • The quant=[-1:X 0:Y +1:Z] stat shows the ternary distribution — aim for a balanced spread, not all zeros (see the snippet below)
  • Lambda warmup is critical — too fast = catastrophic accuracy drop, too slow = no quantization benefit
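
A small helper for reading that distribution off a tensor of ternary codes (illustrative; the real logging lives in train.py):

```python
import torch

def ternary_distribution(codes):
    """Fractions of {-1, 0, +1} codes: the q_neg1 / q_zero / q_pos1 columns."""
    n = codes.numel()
    return {name: (codes == k).sum().item() / n
            for k, name in [(-1, "q_neg1"), (0, "q_zero"), (1, "q_pos1")]}
```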