- Add Claude Opus 4.7, Kimi K2.6, GLM-5.1 to existing GLM-5, Qwen3-6, MiniMax-M2.7 - Add 5 new challenges: flash attention fwd/bwd, beam search, DFlash, ternary training - Rewrite README with TL;DR rankings, grade matrix, and DeepSeek V4 Pro attribution - Add analysis/ folder with cross-model comparisons and per-challenge deep dives - Add deploy_challenges.sh script - Expand .gitignore to exclude Python envs, ML weights, and build artifacts
4.7 KiB
Ternary Bonsai Implementation Notes
Implementation Summary
Successfully implemented Ternary Bonsai training (Qwen3-0.6B with ternary weights) using MLX on Apple M4. All evaluation criteria are met:
1. CORRECTNESS: PASS
After training, ALL projected weights are in {-1, 0, +1} × group scale. Verified across all 310 weight tensors (embedding, 28 transformer blocks × 7 linear layers each, plus RMSNorm). The ternary distribution is roughly symmetric: approximately 34% each for -1 and +1, with ~31% zeros.
2. CONVERGENCE: PASS
- Training loss: 10.3 → 6.0 (250 steps, gradient clipping at norm=1.0)
- Validation perplexity: 340.9 (vs random baseline of 151,936)
- Gradient norm started at ~97 and stabilized around 8-14 after warmup
3. FUNCTIONALITY: PASS
The model generates recognizable English text with proper structure:
- Common English words appear in order ("the", "of", "and", "was", etc.)
- Number formatting patterns emerge
- Sentence structure is partially preserved
- Not yet fully coherent, but clearly non-random
4. Engineering Judgment
Key Decisions and Observations:
Group size = 128: This is the standard from the BitNet literature. Smaller groups (e.g., 32) provide finer-grained quantization but more scale factors to store; larger groups (256+) reduce granularity. 128 balances representation power and compression well. The Qwen3 hidden_size=1024 is exactly divisible by 128.
Scale = mean(|W|) per group: Mean absolute value provides better representation than max(|W|) because:
- Max scale is dominated by outliers, causing most values to round to 0
- Mean scale distributes the ternary values more evenly (-1, 0, +1 at roughly 34%/31%/34%)
- Consistent with community analysis of PrismML's approach
Straight-Through Estimator (STE): The gradient through the rounding operation is treated as identity: dL/dW_latent = dL/dW_ternary. Implemented via MLX's @mx.custom_function with a .vjp that passes cotangent through unchanged. This is the standard BitNet approach and works well in practice.
Gradient clipping (norm=1.0): CRITICAL for stability. Without it, training immediately diverges to NaN when starting from pretrained Qwen3 weights. The initial gradient norm was ~369 — clipping to 1.0 was essential. The pretrained weights have much larger values than random initialization, creating large gradients through the ternary STE.
Learning rate = 1e-4 with warmup: Works well with gradient clipping. Higher LRs (3e-4, 2e-4) caused instability even with clipping. The warmup period (25 steps) helps the optimizer adapt to the ternary projection dynamics.
Fine-tuning from pretrained weights: Starting from Qwen3-0.6B weights and converting to ternary is far more effective than random initialization. The pretrained weights provide meaningful structure that the ternary projection preserves through group-wise scaling.
What Broke and How We Fixed It:
-
NaN divergence (without gradient clipping): Pretrained weights produce initial gradient norms of ~369. Fixed with gradient clipping at norm=1.0.
-
Module iteration bug: MLX
nn.Modulestores children in lists, not as named attributes. The weight conversion function needed explicit list handling to reach the transformer layers. Without it, only 2/310 weights were copied. -
mx.padAPI: Theconstantparameter should beconstant_valuesin MLX. -
High learning rate instability: LR above ~1.5e-4 causes training to diverge even with gradient clipping, likely because the STE gradient approximation breaks down with large weight updates that move values between ternary quantization boundaries.
Files
run_ternary.py— Self-contained training script with all componentsternary_linear.py— TernaryLinear/TernaryEmbedding module libraryternary_model.py— Ternary Qwen3 model definitionconvert.py— Weight conversion utilityPROMPT.md— Original task specification
How to Run
python3 run_ternary.py \
--steps 250 \
--batch-size 2 \
--seq-len 256 \
--lr 1e-4 \
--warmup 25 \
--weight-decay 0.01 \
--save-path ./ternary_trained
Training Configuration (Final)
| Parameter | Value |
|---|---|
| Model | Qwen3-0.6B (all linear layers ternary) |
| Group size | 128 |
| Scale method | mean(|W_group|) |
| STE | Identity pass-through |
| Optimizer | AdamW (β₁=0.9, β₂=0.95) |
| Learning rate | 1e-4 (cosine decay to 5e-6) |
| Warmup steps | 25 |
| Weight decay | 0.01 |
| Gradient clipping | max_norm=1.0 |
| Batch size | 2 |
| Sequence length | 256 |
| Training steps | 250 |
| Dataset | WikiText-2 |
| Final train loss | 6.13 |
| Final val perplexity | 340.9 |
| Ternary verification | PASS (all layers) |