# Ternary Quantization Research Plan

## Objective

Research and prototype ternary (1.58-bit) quantization for LLMs, exploring both quantization-aware training (QAT) and post-training quantization (PTQ) + fine-tuning pipelines. The goal is to take a pre-trained model, quantize it to ternary/2-bit weights, and recover accuracy through fine-tuning.

**Key methodology addition:** autonomous experiment iteration via Karpathy's autoresearch pattern, to accelerate hyperparameter and technique discovery.

---
## What "Bonsai" Actually Is

**Bonsai** is a family of commercially viable sub-2-bit LLMs developed by **PrismML** (not Microsoft/BitNet). There are two model families:

| Family | Weights | Sizes | Format |
|--------|---------|-------|--------|
| **Bonsai** | Binary {-1, +1} | 1.7B, 4B, 8B | Q1_0 (GGUF), MLX 1-bit |
| **Ternary-Bonsai** | Ternary {-1, 0, +1} | 1.7B, 4B, 8B | Q2_0 (GGUF), MLX 2-bit |

Key properties:

- Uses **group size 128** for quantization scales
- Llama architecture with Mistral tokenizer
- Trained **natively** at low bit-width (not PTQ from FP16)
- Inference via a llama.cpp fork (PrismML-Eng/llama.cpp) and MLX
- Models available on HuggingFace: `prism-ml/Bonsai-8B-gguf`, `prism-ml/Ternary-Bonsai-8B-gguf`

**Bonsai is NOT open-source training code** — only inference weights and demos are released. To replicate Bonsai-style results, you need to implement your own QAT pipeline.
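As a starting point for such a pipeline, group-wise ternary quantization with Bonsai's group size of 128 can be sketched as follows. This is a minimal NumPy sketch with illustrative names, not PrismML's actual code; the per-group mean-absolute scale is an assumption.

```python
import numpy as np

def ternary_quantize(w: np.ndarray, group_size: int = 128):
    """Quantize a flat weight array to {-1, 0, +1} with one FP scale per group.

    Illustrative scheme: per-group scale = mean |w|; weights are divided by
    the scale, rounded to the nearest integer, and clamped to [-1, 1].
    """
    w = w.reshape(-1, group_size)                   # (num_groups, group_size)
    scale = np.abs(w).mean(axis=1, keepdims=True)   # one scale per group of 128
    scale = np.maximum(scale, 1e-8)                 # guard against all-zero groups
    codes = np.clip(np.round(w / scale), -1, 1)     # ternary codes in {-1, 0, +1}
    return codes, scale

def dequantize(codes: np.ndarray, scale: np.ndarray) -> np.ndarray:
    return codes * scale

# Example: 256 weights -> 2 groups of 128 codes plus one scale per group
rng = np.random.default_rng(0)
w = rng.normal(size=256).astype(np.float32)
codes, scale = ternary_quantize(w)
```

Storing only the ternary codes plus one scale per 128 weights is what gives the ~2-bit formats (Q2_0, MLX 2-bit) their footprint.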

---

## Recommended Stack

### For Training / Fine-tuning

| Component | Recommendation | Rationale |
|-----------|----------------|-----------|
| **Framework** | PyTorch + HuggingFace Transformers | Widest ecosystem; ParetoQ and EfficientQAT both use it |
| **Training** | HuggingFace TRL / custom training loop | ParetoQ uses the vanilla HF Trainer; the HF blog uses Nanotron |
| **Quantization layer** | Custom `BitLinear` (see HF blog code) | Drop-in replacement for `nn.Linear` |
| **Dataset** | FineWeb-edu, RedPajama, or UltraFineWeb | Proven for ternary QAT (HF blog, Tequila, ParetoQ) |
| **Inference** | bitnet.cpp (Microsoft) or llama.cpp (Bonsai fork) | Optimized CPU/GPU kernels for ternary weights |

### For Quick Experiments

| Component | Recommendation | Rationale |
|-----------|----------------|-----------|
| **Small model** | Llama-3.2-1B or SmolLM-135M/360M | Fast iteration; ParetoQ has released quantized versions |
| **GPU** | Single A100-80GB or H100 | EfficientQAT trains 2-bit Llama-2-70B on one A100-80GB in 41 h |
| **Tokens** | 10B–100B for fine-tuning | HF blog: 10B tokens is competitive; 100B gets closer to the FP baseline |

---
## Core Technical Approaches

### Approach 1: Warmup Quantization Fine-tuning (HF Blog / Most Practical)

**Best for:** starting from a pretrained FP model, quantizing to ternary, and recovering accuracy via QAT fine-tuning.

```python
# Core idea: gradually introduce quantization over a warmup period.
# The .detach() trick is the straight-through estimator (STE): the forward
# pass sees the quantized values, the backward pass sees the FP values.
lambda_ = min(training_step / 1000, 1)  # linear warmup over 1000 steps

x_quant = x + lambda_ * (activation_quant(x) - x).detach()
w_quant = w + lambda_ * (weight_quant(w) - w).detach()
```

Key hyperparameters from the HF blog (Llama3-8B):

- **LR:** 1e-4 (critical — they experimented extensively)
- **Batch size:** 2M tokens
- **Dataset:** FineWeb-edu
- **Warmup steps:** 1000 (linear scheduler)
- **Weight quant:** `scale = 1.0 / w.abs().mean(); round(clamp(-1, 1))`
- **Activation quant:** 8-bit absmax per token
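The two quantizers in the bullets above can be sketched in PyTorch. The function names match the warmup snippet; the exact guard values (1e-5 clamps) are assumptions.

```python
import torch

def weight_quant(w: torch.Tensor) -> torch.Tensor:
    """Ternary weights: scale = 1 / mean|w|, then round and clamp to {-1, 0, +1}."""
    scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    return (w * scale).round().clamp(-1, 1) / scale   # returned dequantized, in FP

def activation_quant(x: torch.Tensor) -> torch.Tensor:
    """8-bit absmax quantization: one scale per token (along the last dimension)."""
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    return (x * scale).round().clamp(-128, 127) / scale
```

Both functions return dequantized FP tensors, so they drop directly into the `x + lambda_ * (...)` warmup expressions above.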

Results: WikiText PPL 12.2 after 10B tokens; surpasses Llama-1-7B on MMLU.

### Approach 2: ParetoQ-style QAT (Meta Research)

**Best for:** rigorous comparison across bit-widths; released training code.

```bash
# From their repo
torchrun train.py \
    --input_model_filename "meta-llama/Llama-3.2-1B" \
    --qat True --w_bits 2 \
    --learning_rate 2e-5 --bf16 True
```

Key insights:

- **2-bit and ternary sit on the Pareto frontier** for size vs. accuracy
- 3-bit and above stay close to the FP weight distribution; 2-bit and below change it drastically
- Scale initialization differs by bit-width (a critical detail in their code)
- Released MobileLLM-ParetoQ models: 125M–1.5B in 1/1.58/2/3/4-bit

### Approach 3: Two-Phase PTQ + Fine-tuning (EfficientQAT)

**Best for:** starting from PTQ and then recovering accuracy.

- Phase 1: block-wise training of all parameters (Block-AP)
- Phase 2: end-to-end training of only the quantization parameters (E2E-QP)

Supports **INT2** (uniform 2-bit quantization), not ternary {-1, 0, +1}.
### Approach 4: TernaryLLM-style Knowledge Distillation

**Best for:** maximum accuracy recovery via feature-level distillation.

- **DLT (Dual Learnable Ternarization):** learnable scale α + shift γ per layer
- **OFF loss:** cosine similarity between FP and ternary features (scale-invariant, outlier-friendly)
- Total objective: `L_total = L_label + ε·L_logits + δ·L_feat`
- Results: LLaMA-3-8B W1.58A16 outperforms W2A16 (DB-LLM) by 5.8 PPL on C4
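A minimal sketch of a cosine-based feature loss in the spirit of OFF, plus the combined objective from the bullet above. This is illustrative, not the paper's exact implementation.

```python
import torch
import torch.nn.functional as F

def off_feature_loss(fp_feat: torch.Tensor, q_feat: torch.Tensor) -> torch.Tensor:
    """1 - cos(fp, ternary) per token, averaged over the batch.

    Cosine similarity is invariant to feature magnitude, so outlier channels
    in the FP features cannot dominate the objective as they would under MSE.
    """
    return (1.0 - F.cosine_similarity(fp_feat, q_feat, dim=-1)).mean()

def total_loss(l_label, l_logits, l_feat, eps: float = 1.0, delta: float = 1.0):
    """L_total = L_label + eps * L_logits + delta * L_feat (weights are tunable)."""
    return l_label + eps * l_logits + delta * l_feat
```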
### Approach 5: Tequila (Deadzone Trapping Fix)

**Best for:** fixing the fundamental problem of ternary QAT weights getting stuck at 0.

- Problem: STE gives noisy gradients to deadzone weights, so they cannot escape the zero level
- Solution: repurpose deadzone weights as dynamic biases with a learnable reactivation parameter λ
- Forward: `Y = X·Q̂(W)·α + Σᵢ∈D λ·wᵢ`
- Results on LLaMA-3.2-1B (10B tokens): <1% gap to FP on ARC benchmarks
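The forward formula in the bullets above can be sketched as follows. This is one illustrative reading of the formula, not the paper's code; `lam` stands in for the learnable reactivation parameter λ and the zero-code mask stands in for the deadzone set D.

```python
import torch

def tequila_forward(x: torch.Tensor, w: torch.Tensor,
                    alpha: float, lam: float) -> torch.Tensor:
    """Sketch of Y = X * Qhat(W) * alpha + sum_{i in D} lam * w_i.

    Weights that quantize to 0 (the deadzone D) receive no useful STE signal;
    here their FP values are repurposed as a per-output-channel bias scaled by
    the learnable reactivation parameter lam.
    """
    scale = w.abs().mean().clamp(min=1e-5)
    q = torch.clamp(torch.round(w / scale), -1, 1)   # ternary codes Qhat(W)
    deadzone = (q == 0).to(w.dtype)                  # mask of the deadzone set D
    y = x @ (q * alpha).t()                          # ternary matmul with scale alpha
    bias = lam * (w * deadzone).sum(dim=1)           # deadzone weights as a bias term
    return y + bias
```

With `lam = 0` this reduces to a plain scaled ternary matmul; the deadzone term only activates as λ is learned.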

---

## Autonomous Experimentation: Karpathy's autoresearch Pattern

Karpathy's [autoresearch](https://github.com/karpathy/autoresearch) is an autonomous, AI-driven experiment loop: an agent iteratively modifies a training script, runs a short training job, evaluates the result, and either keeps or discards the change. The loop runs indefinitely until interrupted.

### Why This Matters for Ternary Quantization

Ternary quantization research has a **large, poorly understood hyperparameter space**: quantization schedules (λ warmup), LR schedules, group sizes, deadzone recovery thresholds, distillation loss weights (ε, δ), and architecture trade-offs. Manually grid-searching this space is impractical; the autoresearch pattern automates it.
### How It Works (adapted for our use case)

```
LOOP FOREVER:
  1. Agent reads current state of train.py and results.tsv
  2. Agent proposes a change (e.g., "try Tequila deadzone reactivation with λ=0.5")
  3. Agent modifies train.py and commits
  4. Run training for a fixed time budget (~5 min on a small model)
  5. Extract val_bpb / val_ppl from the output
  6. Log the result to results.tsv (commit, metric, memory, status, description)
  7. If improved → keep the commit
  8. If equal or worse → git reset to the previous commit
  9. Repeat
```
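The loop above could be driven by a small harness like this. A hypothetical sketch: the `val_ppl` stdout format, the file names, and the agent step are all assumptions, not autoresearch's actual interface.

```python
import subprocess

def run_experiment(timeout_s: int = 300):
    """Run train.py under a fixed time budget; parse the metric from stdout."""
    try:
        out = subprocess.run(["python", "train.py"], capture_output=True,
                             text=True, timeout=timeout_s)
    except subprocess.TimeoutExpired:
        return None                         # over budget counts as a failed run
    for line in out.stdout.splitlines():    # expect a line like "val_ppl 12.2"
        if line.startswith("val_ppl"):
            return float(line.split()[1])
    return None                             # crash or missing metric

def is_improvement(metric, best) -> bool:
    """Keep rule: strict improvement only; crashes and ties are discarded."""
    return metric is not None and metric < best

# Loop skeleton (the propose-and-commit step is the LLM agent's job):
#   best = float("inf")
#   while True:
#       agent_propose_and_commit()          # edit train.py, commit to git
#       metric = run_experiment()
#       if is_improvement(metric, best):
#           best = metric                   # keep the commit
#       else:
#           subprocess.run(["git", "reset", "--hard", "HEAD~1"])
```

Keeping the decision rule strict (no ties) is what prevents metric-neutral churn from accumulating in the git history.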

### Key Design Choices from autoresearch

| Decision | Rationale |
|----------|-----------|
| **Single mutable file** (`train.py`) | Keeps scope manageable; diffs are reviewable |
| **Fixed time budget** (5 min) | Experiments stay comparable regardless of model/architecture changes |
| **Single metric** (`val_bpb`) | Removes ambiguity about what "better" means |
| **Git-based version control** | Automatic rollback on failed experiments; full audit trail |
| **NEVER STOP directive** | Agent runs until manually stopped (e.g., overnight ≈ 100 experiments) |

### Adapting autoresearch for Ternary Quantization

Our adaptation differs from vanilla autoresearch in several ways:

1. **Metric:** `val_ppl` or `val_bpb` on WikiText/C4 instead of autoresearch's synthetic-data metric
2. **Base model:** start from a pretrained HF model (Llama-3.2-1B) rather than training from scratch
3. **Scope of mutations:** the agent can modify quantization layers, loss functions, warmup schedules, deadzone recovery, and distillation weights — not just architecture/hyperparameters
4. **Two-file boundary:** `train.py` (mutable — quantization logic + training loop) vs. `prepare.py` (read-only — data loading, tokenizer, evaluation)
5. **Longer runs:** full QAT fine-tuning needs 10B+ tokens, so the autoresearch loop handles **short ablation experiments** (5–15 min) to find the best hyperparameter combinations, and the winning config then gets a **full long run** outside the loop

### Proposed autoresearch Integration

```
┌─────────────────────────────────────────────┐
│ SHORT-LOOP (autoresearch agent, 5-min runs) │
│ - Quantization schedule shape               │
│ - Lambda warmup length                      │
│ - LR warmup vs constant                     │
│ - Deadzone recovery thresholds              │
│ - Distillation loss weights                 │
│ - Group size ablations                      │
│ → Outputs: best hyperparameter config       │
└──────────────────┬──────────────────────────┘
                   │ winning config
                   ▼
┌─────────────────────────────────────────────┐
│ LONG-RUN (manual, full QAT fine-tuning)     │
│ - 10B-100B token training                   │
│ - Full dataset (FineWeb-edu)                │
│ - Eval on WikiText + MMLU + ARC             │
│ → Outputs: production ternary model         │
└─────────────────────────────────────────────┘
```

The short-loop runs autonomously (e.g., overnight) to explore the hyperparameter space. Once a winning configuration emerges, run a full-scale fine-tune with those settings.

---
## Proposed POC Pipeline

### Phase 0: Infrastructure & autoresearch Setup (2–3 days)

1. Set up the autoresearch-style project structure:
   - `prepare.py` — data loading, tokenizer, evaluation (read-only)
   - `train.py` — model loading, `BitLinear` layer, quantization logic, training loop (mutable by the agent)
   - `program.md` — agent instructions specific to ternary quantization experimentation
   - `results.tsv` — experiment log
2. Clone the autoresearch repo as a reference; adapt its `prepare.py` patterns for our data pipeline
3. Set up the evaluation harness: WikiText PPL, optionally ARC/MMLU zero-shot
4. **Goal:** a working 5-minute training loop that loads Llama-3.2-1B, applies ternary quantization, and reports val_ppl

### Phase 1: Reproduce HF Blog Fine-tuning (1–2 weeks)

1. Take **Llama-3.2-1B** or **SmolLM-135M**
2. Implement a `BitLinear` layer with STE + warmup quantization
3. Fine-tune on FineWeb-edu (10B tokens) with lambda warmup
4. Evaluate on WikiText + zero-shot tasks
5. **Goal:** validate that the pipeline works; establish a baseline PPL
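Step 2, the `BitLinear` drop-in with STE, can be sketched as follows. This is a minimal sketch under assumed names: the lambda warmup and activation quantization are omitted for brevity, and a real pipeline would add both.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BitLinear(nn.Linear):
    """Drop-in nn.Linear that ternarizes its weights on the forward pass."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weight
        scale = 1.0 / w.abs().mean().clamp(min=1e-5)
        w_q = (w * scale).round().clamp(-1, 1) / scale
        w = w + (w_q - w).detach()          # STE: quantized forward, FP backward
        return F.linear(x, w, self.bias)

def ternarize(model: nn.Module) -> nn.Module:
    """Recursively replace every nn.Linear with a BitLinear sharing its params."""
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            bl = BitLinear(child.in_features, child.out_features,
                           bias=child.bias is not None)
            bl.weight = child.weight        # reuse the pretrained parameters
            if child.bias is not None:
                bl.bias = child.bias
            setattr(model, name, bl)
        else:
            ternarize(child)
    return model
```

Applied to a pretrained HF model, `ternarize(model)` would swap the projection layers in place before fine-tuning begins, while the shared parameters keep the FP weights trainable.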

### Phase 2: Autonomous Hyperparameter Search via autoresearch (1–2 weeks)

1. Launch the autoresearch agent with a `program.md` tuned for ternary quantization
2. The agent iteratively explores:
   - Quantization warmup schedules (linear, cosine, exponential)
   - Lambda warmup step counts (500, 1000, 2000, 5000)
   - Learning rates (1e-5 to 1e-3 grid)
   - Group sizes (64, 128, 256)
   - Deadzone recovery strategies (Tequila λ values, on/off)
   - Distillation loss weights (ε, δ)
3. Each experiment: ~5-minute run, automatic keep/discard
4. Review `results.tsv` after overnight runs; identify patterns
5. **Goal:** find an optimal hyperparameter configuration through autonomous search (~100–500 experiments)
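The three warmup schedule shapes in step 2 could be defined as follows (illustrative; the exponential rate constant of 5 is an assumption):

```python
import math

def lambda_schedule(step: int, warmup_steps: int = 1000,
                    shape: str = "linear") -> float:
    """Quantization strength lambda in [0, 1] as a function of the training step."""
    t = min(step / warmup_steps, 1.0)
    if shape == "linear":
        return t
    if shape == "cosine":
        return 0.5 * (1.0 - math.cos(math.pi * t))  # slow start and slow finish
    if shape == "exponential":
        return 1.0 - math.exp(-5.0 * t)             # fast early, approaches 1.0
    raise ValueError(f"unknown shape: {shape}")
```

Each shape plugs directly into the warmup expression `w + lambda_ * (weight_quant(w) - w).detach()` from Approach 1.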

### Phase 3: Recovery Technique Deep Dive (2–3 weeks)

1. Apply the winning autoresearch config as the baseline
2. Add **Tequila** deadzone reactivation to `BitLinear`
3. Try a **TernaryLLM**-style OFF distillation loss
4. Compare: warmup-only vs. warmup+Tequila vs. warmup+OFF vs. combined
5. Use the autoresearch short-loop to find optimal weights for each technique
6. **Goal:** find the best accuracy-recovery method; quantify each technique's contribution

### Phase 4: Full-Scale Fine-tuning (2–4 weeks)

1. Apply the winning recipe to the target model size (e.g., 7B–8B)
2. Scale to 100B tokens
3. Monitor the training loss curve against the FP16 baseline
4. Evaluate on the full benchmark suite (WikiText, C4, MMLU, ARC)
5. **Goal:** a ternary model within 10–15% of the FP16 baseline on key benchmarks

### Phase 5: Export & Inference (1 week)

1. Export to GGUF Q2_0 format (Bonsai-compatible) or bitnet.cpp I2_S
2. Benchmark inference speed against the FP16 baseline (tokens/sec, memory footprint)
3. Quantize activations for inference (INT8 activations)
4. **Goal:** a production-ready ternary model with measured speed/memory gains

---

## Key Risks & Mitigations

| Risk | Likelihood | Impact | Mitigation |
|------|-----------|--------|------------|
| **Catastrophic forgetting** during QAT | High | High | Use warmup quantization (lambda scheduling); start from an instruct-tuned model; use a diverse dataset |
| **Deadzone trapping** (weights stuck at 0) | Medium | High | Implement Tequila reactivation; use per-group quantization; let autoresearch explore λ values |
| **Training instability** at low bit-width | Medium | Medium | LR 1e-4 worked for the HF blog; ParetoQ uses 2e-5; autoresearch grid-searches on a small model first |
| **autoresearch agent wastes runs** on bad ideas | Low | Low | The keep/discard loop naturally prunes; the 5-min budget limits waste; `program.md` constrains the search space |
| **autoresearch metric not correlating** with full fine-tune results | Medium | High | Validate: run 3–5 winning configs as longer runs (30+ min) and check correlation before committing to a full run |
| **autoresearch agent breaks train.py** | Medium | Low | Git reset on failure; `prepare.py` is immutable; crashes are logged and skipped |

---

## References

| Resource | Link |
|----------|------|
| autoresearch (Karpathy) | https://github.com/karpathy/autoresearch |
| Bonsai / Ternary-Bonsai (PrismML) | https://huggingface.co/prism-ml |
| ParetoQ (Meta) | https://github.com/facebookresearch/ParetoQ |
| HF Blog: Ternary LLM Fine-tuning | https://huggingface.co/blog/ternary-llm |
| Tequila (Deadzone Trapping) | https://arxiv.org/abs/2506.18907 |
| TernaryLLM (Distillation) | https://arxiv.org/abs/2406.11943 |
| EfficientQAT (PTQ + Fine-tune) | https://github.com/microsoft/BrickFlow |
| ParetoQ MobileLLM Models | https://huggingface.co/collections/meta/pq-675198e3097f6a25e810eea2 |