feat: add model comparisons and sanitize session files

- Rename gamma to glm5 and model to minimax-m2.7
- Add model_comparison/ directory with head-to-head analyses
- Sanitize all session.jsonl files: remove absolute paths and usernames
- Remove __pycache__ artifacts
- Add .gitignore
commit 8e72eef09c (2026-04-23 11:16:01 +02:00): 62 changed files, 18,469 insertions(+)

# Fused Softmax + Top-K Kernel — Design Analysis
## Table of Contents
1. [Architecture Overview](#1-architecture-overview)
2. [Memory Access Pattern](#2-memory-access-pattern)
3. [Warp-Level Optimization Strategy](#3-warp-level-optimization-strategy)
4. [Complexity Analysis](#4-complexity-analysis)
5. [Comparison to Naive Implementation](#5-comparison-to-naive-implementation)
6. [Further Optimizations](#6-further-optimizations)
---
## 1. Architecture Overview
### Block Assignment
```
Grid: B × T blocks (one block per (b, t) position)
Block: 256 threads per block
```
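For concreteness, a minimal host-side launch matching this layout (a sketch: the kernel symbol name here is hypothetical; the repo's public entry point is the `launch_fused_softmax_topk<K>` wrapper shown in the appendix):
```cuda
// One block per (b, t) row, 256 threads per block.
dim3 grid(B * T);
dim3 block(256);
fused_softmax_topk_kernel<256><<<grid, block>>>(  // kernel name hypothetical
    d_logits, d_top_idx, d_top_prob, B, T, V);
```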
### Five-Phase Pipeline (per block)
```
Phase 1: Find max(logits[b,t,:]) → numerical stability anchor
Phase 2: Compute Σexp(xᵢ - max) → log-sum-exp denominator
Phase 3: Compute softmax + collect top-K → register-local buffers
Phase 4: Merge local buffers → shared heap → global top-K
Phase 5: Sort + write-back → output [B,T,K]
```
### Why Three Passes Over V?
You might wonder why we don't do this in one pass. The answer is **numerical stability**:
```
softmax(xᵢ) = exp(xᵢ) / Σⱼ exp(xⱼ)
```
Without knowing the max first, `exp(xᵢ)` can overflow for large logits. The standard
trick is:
```
softmax(xᵢ) = exp(xᵢ - max) / Σⱼ exp(xⱼ - max)
```
This requires knowing `max` before any exponent is computed, and the denominator
Σⱼ exp(xⱼ - max) before any probability can be normalized, hence the separate max,
sum, and softmax passes (phases 1-3).
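A quick host-side illustration of why the shift matters (plain C++, builds with nvcc or any C++ compiler): `FLT_MAX ≈ 3.4e38`, so `expf` overflows for arguments above ~88.7, while the shifted argument is always ≤ 0.
```cuda
#include <cstdio>
#include <cmath>

int main() {
    float logit = 90.0f, row_max = 92.0f;
    printf("expf(90)      = %g\n", expf(logit));            // inf: overflow
    printf("expf(90 - 92) = %g\n", expf(logit - row_max));  // ~0.135: stable
    return 0;
}
```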
**Could we do it in one pass?** Yes, with an online algorithm that tracks a running
max and re-normalizes, but this adds complexity and potential numerical issues. The
three-pass approach is simpler, correct, and the extra V reads are coalesced.
---
## 2. Memory Access Pattern
### Global Memory Reads
| Phase | Access Pattern | Bytes Read | Coalesced? |
|-------|---------------|------------|------------|
| Phase 1 | `row[tid], row[tid+256], ...` | 4V | ✅ All iterations |
| Phase 2 | `row[tid], row[tid+256], ...` | 4V | ✅ All iterations |
| Phase 3 | `row[tid], row[tid+256], ...` | 4V | ✅ All iterations |
| **Total** | | **12V** | |
For V=50257: **12 × 50257 × 4B ≈ 2.4 MB read per (b,t)**.
**Coalescing analysis:**
- First iteration: threads 0-255 read `row[0]` through `row[255]` → perfectly coalesced
  into 8 128-byte transactions (one per warp; one extra if the row is misaligned).
- Subsequent iterations: threads read `row[256]` through `row[511]`, etc. → also coalesced.
- Stride within a thread (256 elements apart) doesn't affect coalescing — coalescing
is about **consecutive threads accessing consecutive addresses**.
### Global Memory Writes
| Output | Bytes Written |
|--------|--------------|
| `top_idx[B,T,K]` | 4BK |
| `top_prob[B,T,K]` | 4BK |
| **Total** | **8BK** |
For B=1, T=1, K=256: **8 × 256 = 2048 B** (negligible).
### Shared Memory Usage
| Buffer | Size (K=256) | Access Pattern |
|--------|-------------|----------------|
| `s_warp_max[8]` | 32 B | Write: 8 threads, Read: warp 0 |
| `s_warp_sum[8]` | 32 B | Write: 8 threads, Read: warp 0 |
| `s_heap_vals[256]` | 1024 B | Write: all (init), Read/Write: thread 0 |
| `s_heap_idxs[256]` | 1024 B | Write: all (init), Read/Write: thread 0 |
| `s_stage_vals[512]` | 2048 B | Write: active warp, Read: thread 0 |
| `s_stage_idxs[512]` | 2048 B | Write: active warp, Read: thread 0 |
| **Total** | **6208 B** | |
Well within the 48 KB default per-block shared memory limit (and far below the per-SM capacity).
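The table translates to roughly the following declarations inside the kernel body (a sketch: array names are taken from the table, the kernel name and template parameter are hypothetical, and the byte counts assume 8 warps, K = 256, LOCAL_K = 16):
```cuda
template <int K>                                       // K = 256 in the numbers above
__global__ void fused_softmax_topk_kernel(/* ... */) { // kernel name hypothetical
    __shared__ float s_warp_max[8];       //   32 B: per-warp partial maxima
    __shared__ float s_warp_sum[8];       //   32 B: per-warp partial exp-sums
    __shared__ float s_heap_vals[K];      // 1024 B: block-wide top-K heap (values)
    __shared__ int   s_heap_idxs[K];      // 1024 B: block-wide top-K heap (indices)
    __shared__ float s_stage_vals[512];   // 2048 B: one warp's 32 × LOCAL_K candidates
    __shared__ int   s_stage_idxs[512];   // 2048 B: matching vocabulary indices
    // ... phases 1-5 ...
}
```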
### Register Usage (per thread)
| Variable | Count |
|----------|-------|
| `LocalTopK<16>::vals` | 16 floats = 64 B |
| `LocalTopK<16>::idxs` | 16 ints = 64 B |
| Loop counters, temporaries | ~10 registers |
| **Total** | **~40 registers** |
With 256 threads/block and 40 registers/thread: 10,240 registers per block.
On Ampere (64K registers/SM): fits 6 blocks → 1536 threads → good occupancy.
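To lock in that occupancy if register allocation drifts, the kernel can carry a `__launch_bounds__` hint (a sketch, continuing the hypothetical signature above):
```cuda
// Budget registers for at least 6 resident blocks of 256 threads per SM.
template <int K>
__global__ void __launch_bounds__(256, 6)
fused_softmax_topk_kernel(const float* logits, int* top_idx, float* top_prob,
                          int B, int T, int V);
```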
---
## 3. Warp-Level Optimization Strategy
### 3.1 Shuffle-Based Reductions
**Problem:** Traditional reductions use shared memory + sync barriers.
**Our approach:** `__shfl_xor_sync` (warp shuffle) — data moves directly between
thread registers within a warp, zero shared memory, zero global memory.
```cuda
__device__ __forceinline__ float warp_max(float val) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        float other = __shfl_xor_sync(0xffffffff, val, offset);
        val = fmaxf(val, other);
    }
    return val;
}
```
**Latency:** 5 shuffle operations × ~3 cycles = ~15 cycles per reduction.
**vs. shared memory:** ~5 cycles per access + barrier overhead = ~30+ cycles.
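The phase-2 sum reduction follows the same butterfly pattern; a minimal companion sketch:
```cuda
// Warp-wide sum of exp(x - max) partials, entirely in registers.
__device__ __forceinline__ float warp_sum(float val) {
    for (int offset = 16; offset > 0; offset >>= 1)
        val += __shfl_xor_sync(0xffffffff, val, offset);
    return val;
}
```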
### 3.2 Warp-Level Merge Strategy
The merge of local top-K buffers into the shared heap uses a **warp-by-warp** strategy:
```
for each warp w in [0, 7]:
if warp_id == w:
write LOCAL_K entries to staging buffer
__syncthreads()
if tid == 0:
merge staging into shared heap
__syncthreads()
```
**Why not all threads merge concurrently?** Concurrent heap mutations require
atomics or locks, which serialize anyway and add overhead. The warp-by-warp
approach:
- Uses only 2 barriers per warp (16 total)
- Thread 0 does all heap operations (no contention)
- Other threads are idle during merge (but this is a small fraction of total work)
**Alternative: warp-level merge within each warp.** Each warp could merge its 32
threads' LOCAL_K entries into a warp-local top-K using shuffle operations, then
only 8 warp leaders contribute to the shared heap. This reduces heap insertions
from 4096 to 8×K = 2048. **This is a valid optimization** (see §6).
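In CUDA the warp-by-warp merge looks roughly like this (a sketch; `heap_push_or_replace` and the index variables are illustrative names, not the kernel's actual identifiers):
```cuda
for (int w = 0; w < NUM_WARPS; ++w) {
    if (warp_id == w) {
        // This warp's 32 threads dump their LOCAL_K candidates into staging.
        int base = lane_id * LOCAL_K;
        for (int i = 0; i < LOCAL_K; ++i) {
            s_stage_vals[base + i] = local_topk.vals[i];
            s_stage_idxs[base + i] = local_topk.idxs[i];
        }
    }
    __syncthreads();                       // staging buffer is complete
    if (tid == 0) {
        // Fold the 512 staged candidates into the block-wide top-K min-heap.
        for (int i = 0; i < 32 * LOCAL_K; ++i)
            heap_push_or_replace(s_heap_vals, s_heap_idxs, K,   // illustrative helper
                                 s_stage_vals[i], s_stage_idxs[i]);
    }
    __syncthreads();                       // heap update visible to all warps
}
```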
### 3.3 Grid-Stride Loop for Large V
```cuda
for (int v = tid; v < V; v += BLOCK_THREADS) {
// process row[v]
}
```
For V=50257, BLOCK_THREADS=256: each thread processes ⌈50257/256⌉ = 197 elements.
**Benefits:**
- Works for any V (no template parameter needed)
- Good load balancing (threads process nearly equal elements)
- First iteration is coalesced; subsequent iterations are also coalesced
**Trade-off:** Strided access within a thread gives little per-thread spatial locality.
However, for V=50K the entire row is only ~200 KB, which fits comfortably in L2
(40 MB on A100, 50 MB on H100), so re-reading it across the three phases is served from L2.
---
## 4. Complexity Analysis
### 4.1 Bandwidth vs. Compute Bound
**Parameters:** B=1, T=1, V=50257, K=256
| Metric | Value |
|--------|-------|
| Global memory reads | 12 × 50257 × 4B = **2.41 MB** |
| Global memory writes | 8 × 256 = **2.05 KB** |
| Shared memory ops | ~32K (heap) + ~4K (staging) = **~36K** |
| expf() calls | 2 × 50257 = **100,514** |
| Comparisons | 50257 × LOCAL_K (=16) ≈ **0.8M** (local top-K inserts, worst case) |
| Heap sifts | 4096 × log₂(256) = **32,768** |
**Bandwidth requirement:** 2.41 MB per (b,t).
On H100 (3.35 TB/s): 2.41 MB / 3.35 TB/s = **0.72 μs** (theoretical minimum).
**Compute requirement:** 100,514 expf() calls.
On H100, expf runs on the special-function units at far lower throughput than FP32 FMAs; at ≈50 cycles each this is ≈5.0M cycles of SFU work, which at the utilization assumed here comes to roughly **3.3 μs**.
**Verdict: COMPUTE-BOUND.** The kernel is limited by expf() throughput, not memory bandwidth.
### 4.2 Scaling with V
| V | Global Reads | expf() calls | Bandwidth (μs) | Compute (μs) | Bound |
|---|-------------|-------------|----------------|---------------|-------|
| 10K | 480 KB | 20K | 0.14 | 0.67 | Compute |
| 50K | 2.41 MB | 100K | 0.72 | 3.3 | Compute |
| 100K | 4.82 MB | 200K | 1.44 | 6.6 | Compute |
| 500K | 24.1 MB | 1M | 7.2 | 33 | Compute |
| 1M | 48.2 MB | 2M | 14.4 | 66 | Compute |
The kernel remains compute-bound across all practical V values.
### 4.3 Scaling with K
| K | Heap ops | Sort ops | Impact |
|---|----------|----------|--------|
| 16 | 512 × 4 = 2K | 256 | Negligible |
| 64 | 4096 × 6 = 25K | 4K | Small |
| 256 | 4096 × 8 = 33K | 66K | Moderate |
| 1024 | 4096 × 10 = 41K | 1M | Significant |
For K > 256, the heap operations and sort become noticeable. Consider:
- Increasing LOCAL_K to maintain oversampling ratio
- Using a more efficient merge (warp-level top-K within each warp)
- Parallel sort (bitonic sort across threads)
---
## 5. Comparison to Naive Implementation
### Naive Approach
```python
import torch

# PyTorch reference: materializes the full [B, T, V] probability tensor
probs = torch.softmax(logits, dim=-1)             # read [B,T,V], write [B,T,V]
top_prob, top_idx = torch.topk(probs, K, dim=-1)  # read [B,T,V], write 2×[B,T,K]
```
### Comparison Table
| Metric | Naive | Fused Kernel | Speedup |
|--------|-------|-------------|---------|
| **Global reads** | 4V (logits) + 4V (probs) = **8V** | **12V** (logits × 3) | 0.67× |
| **Global writes** | 4V (probs) + 8K (output) | **8K** (output only) | **~V/2K ×** |
| **Peak memory** | 4V + 8K | 8K | **~V/2K ×** |
| **expf() calls** | V (softmax) | 2V (phase 2 + 3) | 0.5× |
| **Numerical stability** | Depends on softmax impl | Guaranteed (max subtraction) | — |
### Key Insight: Memory Savings Dominate
For V=50257, K=256:
- **Naive:** writes 4 × 50257 = **201 KB** of softmax probabilities to global memory
- **Fused:** writes only 8 × 256 = **2 KB** of output
The fused kernel reads 50% more (12V vs 8V) but **avoids writing the entire softmax
matrix**. Counting raw traffic the two are nominally even:
```
Naive:  8V reads + (4V + 8K) writes = 12V + 8K
Fused: 12V reads +        8K writes = 12V + 8K
```
The difference is where the bytes go: the fused kernel's extra reads revisit the *same*
row, which stays L2-resident for V ≈ 50K (see §3.3), while the naive version's 4V
intermediate must round-trip through DRAM. With the row in L2, the fused kernel's
effective DRAM traffic falls toward 4V + 8K, saving up to 4V - 8K ≈ 199 KB per (b,t),
and it never allocates the [B,T,V] probability buffer.
### When Naive Wins
The naive approach can be faster when:
1. **V is small** (V < 1024): the overhead of 3 passes isn't worth it
2. **You need the full softmax** for other operations (e.g., KL divergence)
3. **Hardware has very high bandwidth** relative to compute (e.g., HBM3)
### When Fused Wins
The fused kernel dominates when:
1. **V is large** (V > 10K): memory savings are significant
2. **Memory is the bottleneck** (e.g., mobile, edge devices)
3. **You only need top-K** (common in LLM sampling)
4. **Batch size is small** (B=1): one block per (b,t) means no inter-block sync
---
## 6. Further Optimizations
### 6.1 Warp-Level Top-K Merge (Recommended)
Instead of merging all 4096 candidates through a single thread, each warp
merges its 32 threads' LOCAL_K entries into a warp-local top-K using shuffle:
```cuda
// Each warp: 32 threads × LOCAL_K = 512 entries → top-K within warp
// Use warp shuffle to find top-K in O(K × WARP_SIZE) operations
// Then only 8 warp leaders contribute to shared heap
```
**Benefit:** Reduces heap insertions from 4096 to 8 × K = 2048.
**Complexity:** Moderate — requires warp-level selection algorithm.
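One way to realize this (a sketch only, not the current kernel's code): keep each lane's LOCAL_K candidates sorted descending, then run K rounds of a shuffle reduction over the lanes' current heads; the winning lane advances to its next candidate. The matching vocabulary index travels through a second shuffle, omitted here for brevity.
```cuda
int pos = 0;                                   // next unconsumed candidate in this lane
for (int k = 0; k < K; ++k) {
    float best      = (pos < LOCAL_K) ? local_topk.vals[pos] : -FLT_MAX;
    int   best_lane = lane_id;
    // Butterfly reduction over (value, lane); ties broken by lane id so
    // every lane agrees on a single winner.
    for (int offset = 16; offset > 0; offset >>= 1) {
        float v = __shfl_xor_sync(0xffffffff, best, offset);
        int   l = __shfl_xor_sync(0xffffffff, best_lane, offset);
        if (v > best || (v == best && l < best_lane)) { best = v; best_lane = l; }
    }
    if (lane_id == best_lane) ++pos;           // winner pops its head
    // The warp leader records 'best' as the warp's k-th entry.
}
```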
### 6.2 Float16/BFloat16 Support
For LLM workloads, logits are often in FP16/BF16:
```cuda
// Use hexp() / h2exp() from <cuda_fp16.h> for half-precision exp
// Use __shfl_xor_sync with half-precision values
// Promote to FP32 only for final softmax computation
```
**Benefit:** half the global-memory traffic, roughly 2× the effective read throughput.
**Trade-off:** Slight numerical precision loss (acceptable for top-K).
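A sketch of the read path with half-precision logits (assumes the logits buffer is `__half*` and V is even; the conversion helper comes from `<cuda_fp16.h>`):
```cuda
#include <cuda_fp16.h>   // at file scope

// Inside the phase-1 loop: load two logits per instruction, promote to FP32.
const __half2* row2 = reinterpret_cast<const __half2*>(row);
for (int v = tid; v < V / 2; v += BLOCK_THREADS) {
    float2 x = __half22float2(row2[v]);      // convert the pair to FP32
    local_max = fmaxf(local_max, fmaxf(x.x, x.y));
}
```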
### 6.3 Vectorized Memory Access
```cuda
// Read 4 floats at once (128-bit load)
float4 val = reinterpret_cast<const float4*>(&row[v])[0];
```
**Benefit:** 4× fewer memory instructions, better utilization of memory bandwidth.
**Constraint:** the row pointer must be 16-byte aligned, and a scalar tail loop is needed for the V % 4 remainder.
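A sketch of the vectorized phase-1 read with a scalar tail for the remainder (assumes `row` is 16-byte aligned, which holds for a `cudaMalloc`'d buffer whose row stride is a multiple of 4 floats):
```cuda
const float4* row4 = reinterpret_cast<const float4*>(row);
int V4 = V / 4;
for (int v = tid; v < V4; v += BLOCK_THREADS) {
    float4 x = row4[v];                                  // one 128-bit load
    local_max = fmaxf(local_max, fmaxf(fmaxf(x.x, x.y), fmaxf(x.z, x.w)));
}
for (int v = 4 * V4 + tid; v < V; v += BLOCK_THREADS) {  // V % 4 tail
    local_max = fmaxf(local_max, row[v]);
}
```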
### 6.4 Persistent Blocks for Large B×T
For large B×T, launch fewer blocks and have each block process multiple (b,t):
```cuda
// Each block walks the (b, t) rows with a grid-stride: all 256 threads
// cooperate on one row at a time, then the block moves gridDim.x rows ahead.
for (int bid = blockIdx.x; bid < B * T; bid += gridDim.x) {
    process(bid);   // the existing per-row pipeline (phases 1-5)
}
```
**Benefit:** Better occupancy, hides memory latency.
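Sizing the persistent grid can be driven by the occupancy API rather than a hard-coded constant (a sketch; the kernel symbol name is hypothetical):
```cuda
int device = 0, num_sms = 0, blocks_per_sm = 0;
cudaGetDevice(&device);
cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device);
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
    &blocks_per_sm, fused_softmax_topk_kernel<256>, 256, /*dynSmemBytes=*/0);
int total = B * T;
int grid  = num_sms * blocks_per_sm;
if (grid > total) grid = total;              // never launch more blocks than rows
fused_softmax_topk_kernel<256><<<grid, 256>>>(d_logits, d_top_idx, d_top_prob, B, T, V);
```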
### 6.5 Read-Only (Non-Coherent) Loads
Route the logits reads through `__ldg()` (which lowers to `ld.global.nc`) so they use the read-only data path:
```cuda
// Compiler hint: these values won't be modified
#pragma unroll
for (int v = tid; v < V; v += BLOCK_THREADS) {
float val = __ldg(&row[v]); // cacheable load
// ...
}
```
**Benefit:** Better L2 cache utilization across the three passes.
### 6.6 Single-Pass Online Algorithm
Track a running max and re-normalize:
```cuda
float local_max = -FLT_MAX;
float local_sum = 0.0f;
LocalTopK<LOCAL_K> local_topk;           // stores raw logits; normalize at the end
for (int v = tid; v < V; v += BLOCK_THREADS) {
    float x = row[v];
    if (x > local_max) {
        // Online-softmax correction: rescale the running sum to the new max
        local_sum *= expf(local_max - x);
        local_max = x;
    }
    local_sum += expf(x - local_max);
    // insert (x, v) into local_topk as a raw logit; probabilities can only be
    // formed once the block-wide max and sum are known
}
```
**Benefit:** Single pass over V (4V reads instead of 12V).
**Trade-off:** Complex, potential numerical issues, re-normalization overhead.
**Verdict:** Not recommended unless V is extremely large (>1M).
---
## Appendix: Kernel Instantiation
```cuda
// Launch for a GPT-2-sized vocabulary (V=50257, K=256)
launch_fused_softmax_topk<256>(d_logits, d_top_idx, d_top_prob, 1, 1, 50257);
// Launch for GPT-2 (V=50257, K=50)
launch_fused_softmax_topk<50>(d_logits, d_top_idx, d_top_prob, 1, 1, 50257);
// Launch for batched inference (B=32, T=128, V=32000, K=128)
launch_fused_softmax_topk<128>(d_logits, d_top_idx, d_top_prob, 32, 128, 32000);
```