feat: add model comparisons and sanitize session files

- Rename gamma to glm5 and model to minimax-m2.7
- Add model_comparison/ directory with head-to-head analyses
- Sanitize all session.jsonl files: remove absolute paths and usernames
- Remove __pycache__ artifacts
- Add .gitignore
commit 8e72eef09c (2026-04-23 11:16:01 +02:00)
62 changed files with 18469 additions and 0 deletions
Here's a summary of everything delivered:
---
## Deliverables
### Files
| File | Purpose |
|------|---------|
| `fused_softmax_topk.cu` | **Production kernel** — complete, correct, well-commented |
| `fused_softmax_topk_v2.cu` | **Optimized kernel** — vectorized loads, warp-level merge |
| `ANALYSIS.md` | **Full design analysis** — memory, complexity, comparisons |
| `benchmark.cu` | **Correctness + perf harness** — CPU reference, timing, scaling |
---
### Architecture at a Glance
```
Grid: B×T blocks (one per position)
Block: 256 threads
Phase 1: max-reduction over V → warp shuffle → shared mem
Phase 2: Σexp(xᵢ - max) → warp shuffle → shared mem
Phase 3: softmax + local top-K (regs) → LOCAL_K=16 per thread
Phase 4: merge → shared min-heap → warp-by-warp staging
Phase 5: sort + write-back → thread 0, O(K²)
```
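The phase pipeline above can be sketched as a host-side reference, in the spirit of the CPU reference harness in `benchmark.cu`. This is a minimal sketch only: the function name `softmax_topk_ref` and the use of `std::priority_queue` are illustrative choices, not the repo's code.

```cpp
#include <algorithm>
#include <cmath>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

// CPU reference for one (b, t) position: numerically stable softmax followed
// by top-K selection via a size-K min-heap, mirroring the kernel's phases:
// max -> sum of exp(x - max) -> softmax + top-K -> sort descending.
std::vector<std::pair<float, int>> softmax_topk_ref(const std::vector<float>& x, int K) {
    // Phase 1: max over the vocabulary
    float mx = *std::max_element(x.begin(), x.end());
    // Phase 2: sum of exp(x - max) for numerical stability
    double sum = 0.0;
    for (float v : x) sum += std::exp(v - mx);
    // Phases 3-4: softmax per element, kept only if it survives a size-K min-heap
    using Entry = std::pair<float, int>;  // (probability, index)
    std::priority_queue<Entry, std::vector<Entry>, std::greater<Entry>> heap;
    for (int i = 0; i < static_cast<int>(x.size()); ++i) {
        float p = static_cast<float>(std::exp(x[i] - mx) / sum);
        if (static_cast<int>(heap.size()) < K) heap.push({p, i});
        else if (p > heap.top().first) { heap.pop(); heap.push({p, i}); }
    }
    // Phase 5: drain the heap and sort by descending probability
    std::vector<Entry> out;
    while (!heap.empty()) { out.push_back(heap.top()); heap.pop(); }
    std::sort(out.begin(), out.end(),
              [](const Entry& a, const Entry& b) { return a.first > b.first; });
    return out;
}
```

A reference like this is what the benchmark harness would compare the kernel's (value, index) output against.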
### Key Design Decisions
| Constraint | Solution |
|---|---|
| **No full softmax in global mem** | Softmax values computed in registers, only top-K written out |
| **Numerical stability** | Two-pass: find max first, then `exp(x - max)` |
| **Large V (50k+)** | Grid-stride loops; each thread processes ~196 elements |
| **Shared memory** | Min-heap (2 KB) + warp staging (4 KB) ≈ 6 KB total |
| **Warp optimization** | `__shfl_xor_sync` for reductions (zero shared mem, ~15 cycles) |
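The warp-shuffle row is worth a concrete picture. A CUDA warp can't execute here, so this is a CPU simulation of the `__shfl_xor_sync` butterfly pattern; the function name and the fixed warp size of 32 are assumptions for illustration.

```cpp
#include <algorithm>
#include <array>

// CPU simulation of the butterfly (XOR) reduction the kernel performs with
// __shfl_xor_sync: at each step with offset d, lane i exchanges values with
// lane i ^ d. After log2(32) = 5 steps, every lane holds the warp-wide max,
// with no shared memory involved.
std::array<float, 32> warp_max_butterfly(std::array<float, 32> lanes) {
    for (int d = 16; d >= 1; d >>= 1) {
        std::array<float, 32> next = lanes;
        for (int i = 0; i < 32; ++i)
            next[i] = std::max(lanes[i], lanes[i ^ d]);  // partner chosen by XOR
        lanes = next;
    }
    return lanes;
}
```

The same butterfly works for any associative operator, which is why the kernel can reuse it for both the max reduction (Phase 1) and the sum reduction (Phase 2).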
### Complexity (V=50257, K=256)
| Metric | Value |
|---|---|
| Global reads | **12V bytes ≈ 0.6 MB** (3 passes, coalesced) |
| Global writes | **8K bytes = 2 KB** (output only) |
| expf() calls | **2V = 100K** |
| Bound | **Compute-bound** (expf throughput, not bandwidth) |
### vs. Naive (softmax → topk)
| Metric | Naive | Fused | Win |
|---|---|---|---|
| Global writes | 4V + 8K | **8K** | **~100× less** |
| Peak memory | 4V + 8K | **8K** | **~100× less** |
| Global reads | 8V | 12V | 1.5× more (fused reads more) |
| expf() calls | V | 2V | 2× more (fused computes more) |
**Net: the fused kernel trades 50% more reads for ~100× fewer writes.** For V≈50K this removes ~200 KB of global write traffic per position and never materializes the full softmax in global memory; the write savings and the smaller peak footprint dominate.
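The byte arithmetic behind these ratios can be checked directly, under the assumptions used throughout: 4-byte fp32 logits and 8-byte (value, index) output pairs. The helper names below are hypothetical, for illustration only.

```cpp
#include <cstdint>

// Per-position global memory traffic in bytes, assuming 4-byte fp32 logits
// and 8-byte (value, index) output pairs.
struct Traffic { std::uint64_t reads, writes; };

// Naive pipeline: softmax materialized to global memory, then top-K over it.
Traffic naive_traffic(std::uint64_t V, std::uint64_t K) {
    return {8 * V, 4 * V + 8 * K};   // reads: 8V bytes; writes: full softmax + top-K
}

// Fused kernel: three coalesced read passes, writes the top-K pairs only.
Traffic fused_traffic(std::uint64_t V, std::uint64_t K) {
    return {12 * V, 8 * K};
}
```

For V=50257 and K=256 this gives roughly 203 KB of naive writes against 2 KB fused, about a 100× reduction, while the fused reads grow by exactly 1.5×.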
### Further Optimizations (documented in ANALYSIS.md §6)
1. **Warp-level top-K merge** — reduce heap insertions from 4096 → 2048
2. **FP16/BF16** — 2× bandwidth savings, acceptable precision loss
3. **Vectorized loads (float4)** — already in v2, 4× fewer memory instructions
4. **Persistent blocks** — for large B×T, better occupancy
5. **Async copy (Hopper+)** — plus `__ldg()` for L2 cache reuse across passes
6. **Single-pass online** — tracks running max, re-normalizes (complex, for V > 1M)
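Item 6 hinges on the running-max rescaling trick: whenever a new maximum appears, the accumulated sum is renormalized by `exp(m_old - m_new)`. A minimal CPU sketch of that online normalizer, assuming fp32 inputs (not the repo's implementation):

```cpp
#include <cmath>
#include <utility>
#include <vector>

// Single-pass online softmax normalizer: track a running max m and a running
// sum s of exp(x - m). When a new max arrives, rescale the old sum by
// exp(m_old - m_new) before counting the new element. Produces the same
// (max, sum) pair as the kernel's two-pass Phases 1-2 in one sweep.
std::pair<float, double> online_max_sum(const std::vector<float>& x) {
    float m = -INFINITY;
    double s = 0.0;
    for (float v : x) {
        if (v > m) {
            s = s * std::exp(static_cast<double>(m) - v) + 1.0;  // rescale, add v itself
            m = v;
        } else {
            s += std::exp(static_cast<double>(v) - m);
        }
    }
    return {m, s};
}
```

The extra `exp` per max update is why the summary flags this variant as complex and worthwhile mainly for very large V, where eliminating a full read pass pays for it.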