Ternary Bonsai: Implementation Notes & Findings
Architecture
The implementation follows the Qwen3-0.6B architecture exactly, replacing all nn.Linear and nn.Embedding layers with ternary equivalents:
- Model: Qwen3-0.6B (28 layers, hidden_size=1024, 16 query heads, 8 KV heads, head_dim=128, intermediate_size=3072, vocab_size=151936)
- Ternary layers: Every linear layer (embeddings, Q/K/V/O projections, SwiGLU gate/up/down, LM head) uses ternary weights
- Full-precision layers: RMSNorm and attention scaling remain in float32
Key Implementation Details
Ternary Weight Projection (group_size=128)
Each weight matrix is divided into groups of 128 along the last dimension. For each group:
s = mean(|W_group|) # FP16 scale factor
W_q = clip(round(W / s), -1, 1) # Ternary indices {-1, 0, +1}
W_ternary = W_q * s # Effective weight
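The three formulas above can be sketched in NumPy as follows. This is a minimal illustration and not the repository's MLX code: the function name `ternarize`, the `eps` guard against all-zero groups, and the use of float32 scales (the notes store them as FP16) are assumptions made for the example.

```python
import numpy as np

def ternarize(w, group_size=128, eps=1e-8):
    """Project a 2-D weight matrix to grouped ternary values.

    Returns (w_q, scales, w_ternary): ternary indices, per-group scales,
    and the effective weights w_q * scale.
    """
    out_dim, in_dim = w.shape
    assert in_dim % group_size == 0, "last dimension must be a multiple of group_size"
    # Split the last dimension into groups of `group_size`.
    groups = w.reshape(out_dim, in_dim // group_size, group_size)
    # One absmean scale per group; eps (an assumption) avoids division by zero.
    scales = np.maximum(np.abs(groups).mean(axis=-1, keepdims=True), eps)
    # Round to the nearest integer and clip to {-1, 0, +1}.
    w_q = np.clip(np.round(groups / scales), -1, 1)
    w_ternary = (w_q * scales).reshape(out_dim, in_dim)
    return w_q.reshape(out_dim, in_dim).astype(np.int8), scales.squeeze(-1), w_ternary

rng = np.random.default_rng(0)
w = rng.normal(scale=0.02, size=(4, 256)).astype(np.float32)
w_q, scales, w_ternary = ternarize(w)
```

With a 4x256 matrix and group_size=128, each row gets two scale factors, so `scales` has shape (4, 2).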
Straight-Through Estimator (STE)
The non-differentiable rounding is handled via:
W_out = W + stop_gradient(W_ternary - W)
- Forward: Uses W_ternary (quantized weights)
- Backward: Gradient passes through W as identity (dL/dW = dL/dW_ternary)
This was verified to produce non-zero gradients in isolation (Tests 1-3 during debugging).
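Since NumPy has no autograd, the sketch below spells out by hand what the stop_gradient formulation does: the forward pass uses the ternary weights, and the gradient with respect to W_ternary is applied unchanged to the latent weights. The toy squared-error loss, the shapes, and the names `ste_loss_and_grad` and `ternarize` are assumptions for illustration, not the training script's code.

```python
import numpy as np

def ternarize(w, group_size=128):
    """Grouped absmean ternary projection (effective weights only)."""
    g = w.reshape(w.shape[0], -1, group_size)
    s = np.maximum(np.abs(g).mean(axis=-1, keepdims=True), 1e-8)
    return (np.clip(np.round(g / s), -1, 1) * s).reshape(w.shape)

def ste_loss_and_grad(w_latent, x, target):
    """Forward with ternary weights; backward treats d(W_ternary)/dW as identity."""
    w_t = ternarize(w_latent)
    err = x @ w_t.T - target            # shape (batch, out)
    loss = 0.5 * np.mean(err ** 2)
    grad_w_t = err.T @ x / err.size     # dL/dW_ternary, shape (out, in)
    # STE: reuse dL/dW_ternary as the gradient for the latent weights.
    return loss, grad_w_t

rng = np.random.default_rng(0)
w = rng.normal(scale=0.02, size=(8, 128))
x = rng.normal(size=(16, 128))
target = rng.normal(size=(16, 8))

loss, grad = ste_loss_and_grad(w, x, target)
w_updated = w - 1e-2 * grad             # the update lands on the latent weights
```

The optimizer only ever changes the latent float weights; the ternary values change when a latent weight crosses a rounding boundary on a later forward pass.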
Why group_size=128?
- Powers of 2 align well with GPU/accelerator memory access patterns
- 128 provides a good balance between quantization granularity and statistical stability of the scale factor
- Too small (e.g., 32): noisy scales, unstable training
- Too large (e.g., 256): scales can't adapt to local weight distributions
- PrismML confirmed group_size=128 in their GGUF format discussion
Why mean(|W|) for scale?
- mean(|W|) is more robust than max(|W|) because it's less sensitive to outliers
- With normally distributed weights, mean(|W|) ≈ 0.8 * std(W), giving a stable scale
- max(|W|) would compress most weights toward 0, losing expressivity
- BitNet b1.58 also uses absmean quantization, confirming this choice
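The 0.8 factor is the mean of a half-normal distribution. Assuming the weights are zero-mean Gaussian with standard deviation σ (an idealization; real weight groups are only approximately normal):

```latex
\mathbb{E}|W|
  = \int_{-\infty}^{\infty} |w| \, \frac{1}{\sigma\sqrt{2\pi}} \, e^{-w^2/(2\sigma^2)} \, dw
  = \frac{2}{\sigma\sqrt{2\pi}} \int_{0}^{\infty} w \, e^{-w^2/(2\sigma^2)} \, dw
  = \frac{2\sigma^2}{\sigma\sqrt{2\pi}}
  = \sigma\sqrt{\frac{2}{\pi}}
  \approx 0.798\,\sigma
```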
Training Procedure
Setup
- Load Qwen3-0.6B weights from HuggingFace (via mlx_lm)
- Create ternary model with identical architecture (TernaryLinear replacing nn.Linear)
- Copy pre-trained weights as latent float32 weights
- Ternary projection happens on every forward pass
Hyperparameters
- Optimizer: AdamW (betas=(0.9, 0.95), weight_decay=0.01)
- Learning rate: 5e-4 constant after 50-step linear warmup
- Batch size: 2 (limited by GPU memory with 0.6B float32 latent weights + optimizer state)
- Sequence length: 512
- Dataset: WikiText-2 (train: 2.5M tokens, val: 262K tokens)
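The learning-rate schedule listed above can be written as a small function. This is a sketch: the name `learning_rate` and the convention that the ramp starts at peak_lr / warmup_steps on step 0 (rather than at exactly 0) are assumptions, since the notes do not specify how the warmup is indexed.

```python
def learning_rate(step, peak_lr=5e-4, warmup_steps=50):
    """Linear warmup to peak_lr over warmup_steps, then constant."""
    if step < warmup_steps:
        # Ramp linearly; step 0 uses peak_lr / warmup_steps (assumed convention).
        return peak_lr * (step + 1) / warmup_steps
    return peak_lr

schedule = [learning_rate(s) for s in range(2000)]
```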
Results
2000-step Training Run
| Metric | Pre-training | Post-training |
|---|---|---|
| Loss | 13.81 | 5.14 |
| Perplexity | 995,563 | 232 |
| Ternary weights | {-1, 0, +1} | {-1, 0, +1} |
Eval perplexity trajectory:
- Step 500: 333
- Step 1000: 264
- Step 1500: 228
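Eval perplexity fell at every checkpoint through step 1500; the post-training figure in the table (232) is roughly level with the step-1500 value. With more training steps (5K-10K), perplexity might drop below 100, but this was not tested.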
Text Generation (after 2000 steps)
Prompt: "The most important thing about"
Output: "...the world . The first two days later , the first two days of
the first two days , the first two days of the first two days..."
The output shows learned patterns (English syntax, punctuation) but is repetitive due to limited training.
Weight Distribution
All ternary layers project correctly to {-1, 0, +1}:
- ~34.7% are -1
- ~30.9% are 0
- ~34.3% are +1
This matches the expected distribution for normally-distributed latent weights.
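The expected split can be checked on synthetic Gaussian weights. With s ≈ 0.8σ, a weight rounds to 0 when |W| < 0.5s ≈ 0.4σ, which covers about 31% of a normal distribution. The sketch below uses random samples as a stand-in for the latent weights; the matrix size and seed are arbitrary choices for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=(1024, 1024))       # synthetic stand-in for latent weights

groups = w.reshape(-1, 128)             # group_size=128 along the last dimension
s = np.abs(groups).mean(axis=1, keepdims=True)
w_q = np.clip(np.round(groups / s), -1, 1)

# Fraction of weights landing on each ternary value.
fractions = {v: float(np.mean(w_q == v)) for v in (-1, 0, 1)}
print(fractions)
```

The zero fraction should come out near 0.31, with the remainder split about evenly between -1 and +1, in line with the measured percentages above.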
Key Findings & Observations
1. Weight Copy: MLX Module Structure
Critical finding: MLX's nn.Module extends dict. Sub-modules and parameters are stored as dict entries (model['model'], model['embed_tokens']), NOT as __dict__ attributes. Our initial copy_weights using __dict__ silently failed, leaving all weights at zero. Fixed by iterating over model.keys() instead.
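The failure mode can be reproduced without MLX. In the sketch below, `DictModule` is a hypothetical stand-in for a dict-backed module (it is not MLX's actual class), and both copy helpers are illustrative names. It shows why iterating over `__dict__` copies nothing while iterating over `keys()` works.

```python
class DictModule(dict):
    """Hypothetical dict-backed module: attributes live in the dict itself."""

    def __setattr__(self, name, value):
        self[name] = value               # stored as a dict entry, not in __dict__

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

def copy_via_dunder_dict(src, dst):
    """Buggy approach: __dict__ is empty, so nothing is copied and no error is raised."""
    copied = 0
    for name, value in vars(src).items():
        setattr(dst, name, value)
        copied += 1
    return copied

def copy_via_keys(src, dst):
    """Working approach: iterate over the dict entries."""
    copied = 0
    for name in src.keys():
        dst[name] = src[name]
        copied += 1
    return copied

src = DictModule()
src.weight = [1.0, 2.0]
src.bias = [0.5]

n_bad = copy_via_dunder_dict(src, DictModule())
dst = DictModule()
n_good = copy_via_keys(src, dst)
```

Returning the number of copied entries, as these helpers do, is one cheap way to catch a copy that silently does nothing. A real model would also need to recurse into nested sub-modules.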
2. Ternarization Destroys Pre-trained Knowledge
When Qwen3-0.6B weights are ternarized, the model's loss jumps from ~2.5 (pre-trained) to ~14 (ternarized). This is expected: ternary weights at ~1.58 bits cannot represent the same information as 16-bit weights. The model must re-learn through the ternary constraint.
3. STE Works Correctly
The Straight-Through Estimator implementation via W + stop_gradient(W_ternary - W) produces correct non-zero gradients. We verified:
- Simple STE: gradient = [-2, 0, 2] (expected)
- W-dependent STE: non-zero gradients
- Full model: non-zero gradients for all layers
4. Training From Scratch vs Fine-tuning
PrismML trains from scratch, not from a pre-trained checkpoint. Our fine-tuning approach is fundamentally harder because:
- Pre-trained latent weights encode full-precision patterns
- The optimizer must simultaneously "unlearn" full-precision structure and learn ternary-friendly patterns
- Training from scratch with random init would likely converge faster to a good ternary solution
5. What Broke and How We Fixed It
| Issue | Cause | Fix |
|---|---|---|
| All-zero logits | copy_weights used __dict__ which misses MLX sub-modules | Use dict-style iteration (model.keys()) |
| Zero gradients (first attempt) | Weights were never actually loaded (same root cause) | Same fix |
| Slow convergence with cosine decay | LR decays to near-zero too quickly | Use constant LR after warmup |
| Noisy training loss | batch_size=2 gives high variance gradients | Acceptable for demo; gradient accumulation would help |
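Gradient accumulation, mentioned in the last row as a possible remedy, averages gradients over several micro-batches before one optimizer step. The NumPy sketch below uses a toy linear regression rather than the ternary model; `grad_mse` and `accumulated_grad` are illustrative names, not functions from train.py.

```python
import numpy as np

def grad_mse(w, x, y):
    """Gradient of 0.5 * mean((x @ w - y) ** 2) with respect to w."""
    err = x @ w - y
    return x.T @ err / len(x)

def accumulated_grad(w, x, y, micro_batch=2):
    """Average gradients over equal-sized micro-batches before a single update."""
    assert len(x) % micro_batch == 0, "micro-batches must divide the batch evenly"
    n_micro = len(x) // micro_batch
    total = np.zeros_like(w)
    for i in range(n_micro):
        sl = slice(i * micro_batch, (i + 1) * micro_batch)
        total += grad_mse(w, x[sl], y[sl])
    return total / n_micro

rng = np.random.default_rng(0)
w = rng.normal(size=8)
x = rng.normal(size=(16, 8))
y = rng.normal(size=16)

g_accum = accumulated_grad(w, x, y, micro_batch=2)   # eight micro-batches of 2
g_full = grad_mse(w, x, y)                           # one batch of 16
```

For a mean-reduced loss and equal-sized micro-batches, the averaged gradient matches the full-batch gradient, so the variance reduction comes at the cost of extra forward and backward passes and not extra memory.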
Files
- ternary_model.py: Ternary Bonsai model definition (TernaryLinear, TernaryEmbedding, full Qwen3 architecture)
- train.py: Training, evaluation, and verification script
- NOTES.md: This document