Add ternary QAT training pipeline: prepare.py (data/eval), train.py (quantization/training), program.md (agent instructions), autoresearch.sh (loop)
# program.md — Instructions for the Autoresearch Agent

You are an autonomous research agent exploring ternary (1.58-bit) quantization for LLMs.

## Your Goal

Iteratively improve the ternary quantization training in `train.py` to achieve the **lowest validation perplexity (val_ppl)** on WikiText-2.

## File Boundaries

- **MUTABLE**: `train.py` — You may modify this file. It contains:
  - `BitLinear` layer (quantization logic)
  - Quantization schedules (lambda warmup)
  - Training loop
  - Hyperparameters (LR, batch size, group size, etc.)
- **READ-ONLY**: `prepare.py` — DO NOT modify this file. It contains:
  - Dataset loading and tokenization
  - Evaluation harness (WikiText PPL)
  - Model loading utilities
- **OUTPUT**: `results.tsv` — Results are automatically logged here after each run.

## Experiment Protocol

1. Read `results.tsv` to understand what has been tried
2. Read `train.py` to understand the current implementation
3. Propose ONE focused change to `train.py`
4. The change will be committed and a training run will execute (~5 minutes)
5. Results are logged to `results.tsv`
6. If improved (lower val_ppl) → the change is kept
7. If equal or worse → git reset to the previous commit
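
Steps 6 and 7 amount to a greedy hill-climb on val_ppl. A minimal sketch of that accept/revert decision (the row layout assumed for `results.tsv` is illustrative, not the actual file format):

```python
def should_keep(best_val_ppl: float, new_val_ppl: float) -> bool:
    """Greedy acceptance: keep a change only on strict improvement."""
    return new_val_ppl < best_val_ppl


def decide(results_rows: list[tuple[str, float]]) -> str:
    """Return 'keep' or 'revert' for the most recent run.

    results_rows: (run_id, val_ppl) tuples, oldest first.
    """
    *previous, latest = results_rows
    best = min(ppl for _, ppl in previous) if previous else float("inf")
    return "keep" if should_keep(best, latest[1]) else "revert"
```

Note that an equal val_ppl counts as "worse" here, matching step 7's "equal or worse → reset" rule.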

## What to Explore

### Priority 1: Quantization Schedule

- Lambda warmup shape: linear, cosine, exponential
- Warmup step counts: 200, 500, 1000, 2000, 5000
- Two-phase warmup (fast initial + slow final)
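
The three warmup shapes above can be sketched as schedules mapping a step to a quantization strength lambda in [0, 1]; the function name and the exponential sharpness `k` are illustrative, not taken from `train.py`:

```python
import math


def lambda_warmup(step: int, warmup_steps: int, shape: str = "linear", k: float = 5.0) -> float:
    """Quantization-strength schedule: 0.0 = fully full-precision, 1.0 = fully ternary."""
    t = min(step / warmup_steps, 1.0)
    if shape == "linear":
        return t
    if shape == "cosine":
        # Slow start and slow finish: half a cosine period.
        return 0.5 * (1.0 - math.cos(math.pi * t))
    if shape == "exponential":
        # Fast initial rise, asymptotic finish; k sets the sharpness.
        return 1.0 - math.exp(-k * t)
    raise ValueError(f"unknown shape: {shape}")
```

The exponential shape never quite reaches 1.0 (about 0.993 at the end of warmup with `k=5`), so it may need a final snap to 1.0; a two-phase warmup could be composed from two of these pieces.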

### Priority 2: Learning Rate

- LR values: 1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4
- LR schedule: constant, linear decay, cosine decay
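
The three schedule options can be sketched as step-to-LR functions (an illustrative sketch; `train.py` may structure this differently, e.g. via a PyTorch scheduler):

```python
import math


def lr_at(step: int, total_steps: int, base_lr: float,
          schedule: str = "cosine", min_lr: float = 0.0) -> float:
    """Learning rate at `step` under the three schedule options."""
    if schedule == "constant":
        return base_lr
    t = min(step / total_steps, 1.0)
    if schedule == "linear":
        return min_lr + (base_lr - min_lr) * (1.0 - t)
    if schedule == "cosine":
        return min_lr + (base_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * t))
    raise ValueError(f"unknown schedule: {schedule}")
```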

### Priority 3: Quantization Details

- Group size: 32, 64, 128, 256, per-tensor
- Scale initialization: mean-based vs. absmax-based
- Ternary threshold adjustments
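
The first two knobs can be compared with a sketch that computes one scale per group (whether `train.py` reshapes weights into groups this way is an assumption):

```python
import torch


def group_scales(w: torch.Tensor, group_size, mode: str = "mean") -> torch.Tensor:
    """Per-group quantization scales for a weight matrix.

    mode="mean"  : absmean scale (mean |w| per group, BitNet-style)
    mode="absmax": absmax scale (max |w| per group)
    group_size=None gives a single per-tensor scale.
    """
    if group_size is None:
        g = w.abs().reshape(1, -1)           # per-tensor: one big group
    else:
        g = w.abs().reshape(-1, group_size)  # rows of `group_size` weights
    if mode == "mean":
        s = g.mean(dim=-1)
    elif mode == "absmax":
        s = g.amax(dim=-1)
    else:
        raise ValueError(f"unknown mode: {mode}")
    return s.clamp(min=1e-8)  # guard against all-zero groups
```

Smaller groups track outliers more tightly but store more scales; the absmax variant is more outlier-sensitive than the mean-based one.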

### Priority 4: Deadzone Recovery

- Tequila-style reactivation (learnable lambda for deadzone weights)
- Bias injection for zero-valued weights
- Gradient scaling for deadzone weights
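
The gradient-scaling idea might look like the following; the deadzone test and the `boost` factor are assumptions for illustration, not the Tequila method itself:

```python
import torch


def boost_deadzone_grads(weight: torch.Tensor, grad: torch.Tensor,
                         threshold: float, boost: float = 2.0) -> torch.Tensor:
    """Scale up gradients of weights the ternary quantizer maps to 0,
    giving 'trapped' weights a stronger push out of the deadzone."""
    in_deadzone = weight.abs() < threshold  # these weights quantize to 0
    return torch.where(in_deadzone, grad * boost, grad)
```

One way to wire this in is a gradient hook on the weight parameter, e.g. `weight.register_hook(lambda g: boost_deadzone_grads(weight.detach(), g, thr))`.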

### Priority 5: Distillation

- OFF loss (cosine similarity between FP and ternary features)
- Logits distillation weight
- Feature distillation weight
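
Reading the OFF-loss bullet literally as a cosine-similarity feature loss, the distillation terms might be sketched as follows (weights `alpha`, `beta` and temperature `T` are illustrative):

```python
import torch
import torch.nn.functional as F


def feature_cosine_loss(fp_feats: torch.Tensor, q_feats: torch.Tensor) -> torch.Tensor:
    """1 - cos(fp, ternary) averaged over positions; 0 when features align."""
    return (1.0 - F.cosine_similarity(fp_feats, q_feats, dim=-1)).mean()


def logits_distill_loss(fp_logits: torch.Tensor, q_logits: torch.Tensor,
                        T: float = 2.0) -> torch.Tensor:
    """KL(teacher || student) on temperature-softened logits."""
    p_teacher = F.softmax(fp_logits / T, dim=-1)
    log_p_student = F.log_softmax(q_logits / T, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (T * T)


def total_loss(ce, fp_feats, q_feats, fp_logits, q_logits,
               alpha: float = 1.0, beta: float = 1.0) -> torch.Tensor:
    """Cross-entropy plus weighted logits and feature distillation terms."""
    return (ce
            + alpha * logits_distill_loss(fp_logits, q_logits)
            + beta * feature_cosine_loss(fp_feats, q_feats))
```

The FP (full-precision) teacher activations would come from a frozen copy of the model; both extra terms vanish when the ternary model matches the teacher.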

## Constraints

- Keep experiments focused — ONE change per iteration
- Always maintain working code — syntax errors waste time
- Use SmolLM-135M or Llama-3.2-1B for fast iteration
- Target metric: val_ppl (lower is better)
- Time budget: 5 minutes per experiment

## Important Notes

- The STE (Straight-Through Estimator) is critical for gradients to flow through quantization
- Quantization warmup prevents catastrophic accuracy loss at the start of training
- Deadzone trapping (weights stuck at 0) is a known problem — explore solutions
- Per-group quantization scales are essential for handling outlier weights
- The quantization formula: `scale = 1.0 / w.abs().mean(); w_q = clamp(round(w * scale), -1, 1) / scale`
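
Combining the formula with the STE note, a minimal per-tensor version might look like this (a sketch in the spirit of BitNet b1.58, not necessarily identical to the `BitLinear` in `train.py`):

```python
import torch


def ternary_quantize(w: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Absmean ternary quantization with a straight-through estimator.

    Forward:  w_q = clamp(round(w * scale), -1, 1) / scale,
              with scale = 1 / mean(|w|)  (dequantized back to w's magnitude)
    Backward: gradients pass through as if quantization were the identity.
    """
    scale = 1.0 / w.abs().mean().clamp(min=eps)
    w_q = (w * scale).round().clamp(-1, 1) / scale
    # STE: the forward value is w_q, but the autograd graph sees
    # w plus a constant, so dL/dw equals dL/dw_q.
    return w + (w_q - w).detach()
```

Inside a `BitLinear.forward` this would be used as something like `F.linear(x, ternary_quantize(self.weight), self.bias)` (again an assumption about the layer's structure).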

## NEVER STOP

Run experiments continuously until manually interrupted. Each experiment should be a small, focused change. Review results.tsv between runs to inform your next decision.