Fused Softmax + Top-K Kernel — Design Analysis
Table of Contents
- Architecture Overview
- Memory Access Pattern
- Warp-Level Optimization Strategy
- Complexity Analysis
- Comparison to Naive Implementation
- Further Optimizations
1. Architecture Overview
Block Assignment
Grid: B × T blocks (one block per (b, t) position)
Block: 256 threads per block
Five-Phase Pipeline (per block)
Phase 1: Find max(logits[b,t,:]) → numerical stability anchor
Phase 2: Compute Σexp(xᵢ - max) → softmax denominator
Phase 3: Compute softmax + collect top-K → register-local buffers
Phase 4: Merge local buffers → shared heap → global top-K
Phase 5: Sort + write-back → output [B,T,K]
Why Three Passes Over V?
You might wonder why we don't do this in one pass. The answer is numerical stability:
softmax(xᵢ) = exp(xᵢ) / Σⱼ exp(xⱼ)
Without knowing the max first, exp(xᵢ) can overflow for large logits. The standard
trick is:
softmax(xᵢ) = exp(xᵢ - max) / Σⱼ exp(xⱼ - max)
This requires knowing max before computing any exponential, and the full denominator before any probability is final, hence three passes over V (max reduction, sum of exponentials, then softmax computation with top-K selection).
Could we do it in one pass? Yes, with an online algorithm that tracks a running max and rescales the running sum (see §6.6), but this adds complexity and potential numerical issues. The three-pass approach is simpler and correct, and the extra reads of V are coalesced.
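The pass ordering can be sketched as a host-side C++ reference for a single row. This is a CPU illustration rather than the kernel itself, and the name `stable_softmax` is an assumption made for this example.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// CPU reference for one row, mirroring Phases 1-3 (hypothetical helper).
std::vector<float> stable_softmax(const std::vector<float>& logits) {
    assert(!logits.empty());
    // Pass 1: row maximum, the stability anchor.
    const float row_max = *std::max_element(logits.begin(), logits.end());
    // Pass 2: denominator; every exponent is <= 0, so exp cannot overflow.
    float denom = 0.0f;
    for (float x : logits) denom += std::exp(x - row_max);
    // Pass 3: normalized probabilities.
    std::vector<float> probs(logits.size());
    for (std::size_t i = 0; i < logits.size(); ++i)
        probs[i] = std::exp(logits[i] - row_max) / denom;
    return probs;
}
```

Logits around 1000 would overflow `exp` without the subtraction; with it, the largest exponent is exactly 0.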
2. Memory Access Pattern
Global Memory Reads
| Phase | Access Pattern | Bytes Read | Coalesced? |
|---|---|---|---|
| Phase 1 | row[tid], row[tid+256], ... | 4V | ✅ Every iteration |
| Phase 2 | row[tid], row[tid+256], ... | 4V | ✅ Every iteration |
| Phase 3 | row[tid], row[tid+256], ... | 4V | ✅ Every iteration |
| Total | | 12V | |
For V=50257: 12V = 12 × 50257 B ≈ 603 KB read per (b,t).
Coalescing analysis:
- First iteration: threads 0-255 read row[0] through row[255] → perfectly coalesced into ~8-16 128-byte transactions (depending on alignment).
- Subsequent iterations: threads read row[256] through row[511], etc. → also coalesced.
- Stride within a thread (256 elements apart) doesn't affect coalescing, because coalescing is about consecutive threads accessing consecutive addresses.
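The 8-16 range can be reproduced with a simplified host-side model that counts the 128-byte segments each warp touches. The segment size, the per-warp accounting, and both function names are assumptions of this sketch; real hardware may count transactions differently.

```cpp
#include <cassert>
#include <cstddef>

// 128-byte segments touched when one warp's 32 lanes read 32 consecutive
// floats starting at byte offset `base` (simplified model).
int segments_per_warp_load(std::size_t base) {
    assert(base % sizeof(float) == 0);  // floats are naturally aligned
    const std::size_t kSegment = 128;
    const std::size_t first = base / kSegment;
    const std::size_t last = (base + 32 * sizeof(float) - 1) / kSegment;
    return static_cast<int>(last - first + 1);
}

// Sum over the 8 warps of a 256-thread block for one loop iteration.
int segments_per_block_iteration(std::size_t row_base) {
    int total = 0;
    for (std::size_t warp = 0; warp < 8; ++warp)
        total += segments_per_warp_load(row_base + warp * 32 * sizeof(float));
    return total;
}
```

In this model a 128-byte-aligned row costs one segment per warp, while any other offset makes every warp straddle two.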
Global Memory Writes
| Output | Bytes Written |
|---|---|
| top_idx[B,T,K] | 4·B·T·K |
| top_prob[B,T,K] | 4·B·T·K |
| Total | 8·B·T·K |
For B=1, T=1, K=256: 8 × 256 = 2048 B (negligible).
Shared Memory Usage
| Buffer | Size (K=256) | Access Pattern |
|---|---|---|
| s_warp_max[8] | 32 B | Write: 8 threads, Read: warp 0 |
| s_warp_sum[8] | 32 B | Write: 8 threads, Read: warp 0 |
| s_heap_vals[256] | 1024 B | Write: all (init), Read/Write: thread 0 |
| s_heap_idxs[256] | 1024 B | Write: all (init), Read/Write: thread 0 |
| s_stage_vals[512] | 2048 B | Write: active warp, Read: thread 0 |
| s_stage_idxs[512] | 2048 B | Write: active warp, Read: thread 0 |
| Total | 6208 B | |
Well within the default 48 KB static shared memory limit per block.
Register Usage (per thread)
| Variable | Count |
|---|---|
| LocalTopK<16>::vals | 16 floats = 64 B |
| LocalTopK<16>::idxs | 16 ints = 64 B |
| Loop counters, temporaries | ~10 registers |
| Total | ~40 registers |
With 256 threads/block and 40 registers/thread: 10,240 registers per block. On Ampere (64K registers/SM): fits 6 blocks → 1536 threads → good occupancy.
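The shared-memory total and the register-limited block count can be checked with a small host-side calculation. Both helper names are illustrative assumptions, and the register estimate ignores the hardware's allocation granularity, so treat it as a rough bound.

```cpp
#include <cassert>
#include <cstddef>

// Static shared memory per block, following the buffer table above.
constexpr std::size_t shared_bytes(std::size_t k, std::size_t num_warps,
                                   std::size_t stage_entries) {
    return 2 * num_warps * sizeof(float)                   // s_warp_max, s_warp_sum
         + k * (sizeof(float) + sizeof(int))               // s_heap_vals, s_heap_idxs
         + stage_entries * (sizeof(float) + sizeof(int));  // s_stage_vals, s_stage_idxs
}

// Resident blocks per SM if registers were the only limit.
int blocks_per_sm_by_registers(int regs_per_thread, int threads_per_block,
                               int regs_per_sm) {
    assert(regs_per_thread > 0 && threads_per_block > 0);
    return regs_per_sm / (regs_per_thread * threads_per_block);
}

static_assert(shared_bytes(256, 8, 512) == 6208, "matches the table total");
```

Because `shared_bytes` grows linearly in K, the same function shows how much headroom remains below 48 KB for larger K.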
3. Warp-Level Optimization Strategy
3.1 Shuffle-Based Reductions
Problem: Traditional reductions use shared memory + sync barriers.
Our approach: __shfl_xor_sync (warp shuffle) — data moves directly between
thread registers within a warp, zero shared memory, zero global memory.
    warp_max(val):
        for offset in [16, 8, 4, 2, 1]:
            other = __shfl_xor_sync(mask, val, offset)
            val = max(val, other)
        return val
Latency: 5 shuffle operations × ~3 cycles ≈ 15 cycles per reduction, versus ~30+ cycles for a shared-memory reduction (~5 cycles per access plus barrier overhead).
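The xor butterfly can be simulated on the host to see why every lane ends up holding the warp-wide maximum after five steps. In this sketch an array element stands in for each lane's register, and `simulate_warp_max` is a hypothetical name, not part of the kernel.

```cpp
#include <algorithm>
#include <array>
#include <cassert>

// Host-side model of warp_max: lanes[i] plays the role of lane i's register.
std::array<float, 32> simulate_warp_max(std::array<float, 32> lanes) {
    for (int offset = 16; offset >= 1; offset /= 2) {
        // Snapshot first: on the GPU all lanes exchange values simultaneously.
        const std::array<float, 32> prev = lanes;
        for (int lane = 0; lane < 32; ++lane)
            lanes[lane] = std::max(prev[lane], prev[lane ^ offset]);
    }
    // The butterfly leaves the same result in every lane, so no broadcast is needed.
    assert(lanes[0] == lanes[31]);
    return lanes;
}
```

Each step halves the number of distinct partial results, which is why log₂(32) = 5 shuffles suffice.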
3.2 Warp-Level Merge Strategy
The merge of local top-K buffers into the shared heap uses a warp-by-warp strategy:
    for each warp w in [0, 7]:
        if warp_id == w:
            write LOCAL_K entries to staging buffer
        __syncthreads()
        if tid == 0:
            merge staging into shared heap
        __syncthreads()
Why not all threads merge concurrently? Concurrent heap mutations require atomics or locks, which serialize anyway and add overhead. The warp-by-warp approach:
- Uses only 2 barriers per warp (16 total)
- Thread 0 does all heap operations (no contention)
- Other threads are idle during merge (but this is a small fraction of total work)
Alternative: warp-level merge within each warp. Each warp could merge its 32 threads' LOCAL_K entries into a warp-local top-K using shuffle operations, then only 8 warp leaders contribute to the shared heap. This reduces heap insertions from 4096 to 8×K = 2048. This is a valid optimization (see §6).
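The serial merge that thread 0 performs can be sketched on the host with a size-K min-heap. This uses `std::priority_queue` in place of the kernel's shared-memory heap, and the `Candidate` alias and `merge_top_k` name are assumptions made for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

using Candidate = std::pair<float, int>;  // (value, vocabulary index)

// Keep the k largest staged candidates; the heap root is the current k-th largest.
std::vector<Candidate> merge_top_k(const std::vector<Candidate>& staged, std::size_t k) {
    assert(k > 0);
    std::priority_queue<Candidate, std::vector<Candidate>, std::greater<Candidate>> heap;
    for (const Candidate& c : staged) {
        if (heap.size() < k) {
            heap.push(c);
        } else if (c.first > heap.top().first) {
            heap.pop();  // evict the smallest survivor
            heap.push(c);
        }
    }
    std::vector<Candidate> out;
    while (!heap.empty()) {
        out.push_back(heap.top());
        heap.pop();
    }
    std::reverse(out.begin(), out.end());  // descending by value
    return out;
}
```

Most candidates fail the comparison against the root and cost a single compare, so the full log₂(K) sift is paid only on replacement.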
3.3 Block-Stride Loop for Large V
    for (int v = tid; v < V; v += BLOCK_THREADS) {
        // process row[v]
    }
For V=50257, BLOCK_THREADS=256: each thread processes ⌈50257/256⌉ = 197 elements.
Benefits:
- Works for any V (no template parameter needed)
- Good load balancing (threads process nearly equal elements)
- First iteration is coalesced; subsequent iterations are also coalesced
Trade-off: Each thread's own accesses are 256 elements apart, so a thread gets no spatial locality from its own loads. However, for V=50K the entire row (~200 KB) fits comfortably in L2 (40 MB on A100), so the re-reads in Phases 2 and 3 are likely to be served from L2.
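The load-balancing claim is easy to verify on the host by running the same loop bounds for each thread index. The helper name `elements_for_thread` is hypothetical.

```cpp
#include <cassert>

// Number of row elements visited by thread `tid` in the strided loop above.
int elements_for_thread(int tid, int V, int block_threads) {
    assert(tid >= 0 && tid < block_threads);
    int count = 0;
    for (int v = tid; v < V; v += block_threads) ++count;
    return count;
}
```

For V=50257 and 256 threads, 50257 = 196 × 256 + 81, so threads 0-80 process 197 elements and the rest process 196: the imbalance is at most one element.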
4. Complexity Analysis
4.1 Bandwidth vs. Compute Bound
Parameters: B=1, T=1, V=50257, K=256
| Metric | Value |
|---|---|
| Global memory reads | 3 passes × 50257 × 4 B = 603 KB |
| Global memory writes | 8 × 256 = 2.05 KB |
| Shared memory ops | ~32K (heap) + ~4K (staging) = ~36K |
| expf() calls | 2 × 50257 = 100,514 |
| Comparisons | ≤ 50257 × LOCAL_K = 50257 × 16 ≈ 804K (local top-K inserts) |
| Heap sifts | 4096 × log₂(256) = 32,768 |
Bandwidth requirement: 603 KB per (b,t). On H100 (3.35 TB/s): 603 KB / 3.35 TB/s ≈ 0.18 μs (theoretical minimum).
Compute requirement: 100,514 expf() calls, or about 393 per thread across the block's 256 threads. On H100 (assuming a ~1.5 GHz clock), taking expf ≈ 50 cycles and assuming a thread's calls do not overlap: 393 × 50 ≈ 19.6K cycles / 1.5 GHz ≈ 13 μs.
Verdict: COMPUTE-BOUND. The kernel is limited by expf() work, not memory bandwidth.
4.2 Scaling with V
| V | Global Reads | expf() calls | Bandwidth (μs) | Compute (μs) | Bound |
|---|---|---|---|---|---|
| 10K | 120 KB | 20K | 0.04 | 2.6 | Compute |
| 50K | 603 KB | 100K | 0.18 | 13 | Compute |
| 100K | 1.2 MB | 200K | 0.36 | 26 | Compute |
| 500K | 6.0 MB | 1M | 1.8 | 130 | Compute |
| 1M | 12 MB | 2M | 3.6 | 260 | Compute |
The kernel remains compute-bound across all practical V values.
4.3 Scaling with K
| K | Heap ops | Sort ops | Impact |
|---|---|---|---|
| 16 | 4096 × 4 = 16K | 256 | Negligible |
| 64 | 4096 × 6 = 25K | 4K | Small |
| 256 | 4096 × 8 = 33K | 66K | Moderate |
| 1024 | 4096 × 10 = 41K | 1M | Significant |
For K > 256, the heap operations and sort become noticeable. Consider:
- Increasing LOCAL_K to maintain oversampling ratio
- Using a more efficient merge (warp-level top-K within each warp)
- Parallel sort (bitonic sort across threads)
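The growth rates in the table can be recomputed with two small helpers, assuming (as the table's figures suggest) that every one of the 4096 candidates pays a full log₂(K) sift and that the final sort is quadratic in K. Both function names are illustrative.

```cpp
#include <cassert>

// Worst-case sift steps for pushing `candidates` entries through a K-entry heap.
long heap_sift_steps(long candidates, long k) {
    assert(candidates >= 0 && k >= 1);
    long depth = 0;
    while ((1L << depth) < k) ++depth;  // ceil(log2(k)) without floating point
    return candidates * depth;
}

// Step count for a quadratic final sort, the rate the Sort ops column follows.
long quadratic_sort_steps(long k) {
    assert(k >= 0);
    return k * k;
}
```

The heap term grows only logarithmically in K, so for large K the quadratic sort is the part worth replacing first, for example with the bitonic sort mentioned above.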
5. Comparison to Naive Implementation
Naive Approach
# Python pseudocode
probs = softmax(logits) # Materialize [B, T, V] in global memory
top_idx, top_prob = topk(probs, K) # Read [B, T, V], write [B, T, K]
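For validating a fused kernel against this baseline, a host-side C++ version of the unfused path for a single row might look like the following. The `TopK` struct and `naive_softmax_topk` name are assumptions of this sketch, and it uses `std::partial_sort` rather than any particular library's top-K routine.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <vector>

struct TopK {
    std::vector<int> idx;     // vocabulary indices, highest probability first
    std::vector<float> prob;  // matching softmax probabilities
};

// Unfused baseline for one row: materialize all V probabilities, then select K.
TopK naive_softmax_topk(const std::vector<float>& logits, std::size_t k) {
    assert(k > 0 && k <= logits.size());
    const float row_max = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float denom = 0.0f;
    for (std::size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - row_max);
        denom += probs[i];
    }
    for (float& p : probs) p /= denom;  // the [V] intermediate the fused kernel avoids

    std::vector<int> order(logits.size());
    std::iota(order.begin(), order.end(), 0);
    std::partial_sort(order.begin(), order.begin() + static_cast<std::ptrdiff_t>(k),
                      order.end(),
                      [&](int a, int b) { return probs[a] > probs[b]; });
    TopK out;
    for (std::size_t i = 0; i < k; ++i) {
        out.idx.push_back(order[i]);
        out.prob.push_back(probs[order[i]]);
    }
    return out;
}
```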
Comparison Table
| Metric | Naive | Fused Kernel | Speedup |
|---|---|---|---|
| Global reads | 4V (logits) + 4V (probs) = 8V | 12V (logits × 3) | 0.67× |
| Global writes | 4V (probs) + 8K (output) | 8K (output only) | ≈ V/(2K) × |
| Peak memory | 4V + 8K | 8K | ≈ V/(2K) × |
| expf() calls | V (softmax) | 2V (phase 2 + 3) | 0.5× |
| Numerical stability | Depends on softmax impl | Guaranteed (max subtraction) | — |
Key Insight: Memory Savings Dominate
For V=50257, K=256:
- Naive: writes 4 × 50257 = 201 KB of softmax probabilities to global memory
- Fused: writes only 8 × 256 = 2 KB of output
The fused kernel reads 50% more (12V vs 8V bytes) but never writes the softmax matrix. Counting reads and writes together:
    Naive traffic: 8V reads + (4V + 8K) writes = 12V + 8K
    Fused traffic: 12V reads + 8K writes = 12V + 8K
Total traffic is the same to first order. The fused kernel's gain is that it skips the 4V-byte write of intermediate probabilities and never allocates the buffer that holds them, while its two extra read passes revisit a row that was just loaded and can be served from L2 (see §3.3).
For V=50257, K=256: the avoided write is 4 × 50257 B ≈ 201 KB per (b,t).
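The per-(b,t) byte counts from the comparison table can be tallied with a few lines of host-side C++, keeping reads and writes separate so either side can be inspected. The `Traffic` struct and both function names are assumptions of this sketch.

```cpp
#include <cassert>
#include <cstdint>

struct Traffic {
    std::int64_t reads;   // bytes read from global memory
    std::int64_t writes;  // bytes written to global memory
};

// Unfused path: read logits, write probs, read probs, write top-K output.
Traffic naive_traffic(std::int64_t V, std::int64_t K) {
    assert(V > 0 && K > 0);
    return {4 * V + 4 * V, 4 * V + 8 * K};
}

// Fused path: three passes over the logits, write top-K output only.
Traffic fused_traffic(std::int64_t V, std::int64_t K) {
    assert(V > 0 && K > 0);
    return {3 * 4 * V, 8 * K};
}
```

For V=50257 and K=256 the write counts differ by exactly 4V = 201,028 bytes, the size of the intermediate probability row.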
When Naive Wins
The naive approach can be faster when:
- V is small (V < 1024): the overhead of 3 passes isn't worth it
- You need the full softmax for other operations (e.g., KL divergence)
- Hardware has very high bandwidth relative to compute (e.g., HBM3)
When Fused Wins
The fused kernel dominates when:
- V is large (V > 10K): memory savings are significant
- Memory is the bottleneck (e.g., mobile, edge devices)
- You only need top-K (common in LLM sampling)
- Batch size is small (B=1): one block per (b,t) means no inter-block sync
6. Further Optimizations
6.1 Warp-Level Top-K Merge (Recommended)
Instead of merging all 4096 candidates through a single thread, each warp merges its 32 threads' LOCAL_K entries into a warp-local top-K using shuffle:
// Each warp: 32 threads × LOCAL_K = 512 entries → top-K within warp
// Use warp shuffle to find top-K in O(K × WARP_SIZE) operations
// Then only 8 warp leaders contribute to shared heap
Benefit: Reduces heap insertions from 4096 to 8 × K = 2048. Complexity: Moderate — requires warp-level selection algorithm.
6.2 Float16/BFloat16 Support
For LLM workloads, logits are often in FP16/BF16:
    // Use hexp() (or h2exp() on packed __half2) for half-precision exp
    // Use __shfl_xor_sync with half-precision values
    // Promote to FP32 only for final softmax computation
Benefit: Halves global memory traffic and can up to double load throughput. Trade-off: Slight numerical precision loss (acceptable for top-K).
6.3 Vectorized Memory Access
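    // Read 4 floats at once (128-bit load); &row[v] must be 16-byte aligned
    float4 val = reinterpret_cast<const float4*>(&row[v])[0];
Benefit: 4× fewer memory instructions, better utilization of memory bandwidth. Constraint: every row must start on a 16-byte boundary, which for a densely packed [B,T,V] tensor means V must be divisible by 4; otherwise the unaligned elements need scalar loads.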
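Whether a given row can use 128-bit loads depends on where it starts in memory, which a short host-side check makes concrete. The sketch assumes a densely packed row-major float tensor whose base pointer is itself 16-byte aligned; the helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// True when row `row_index` of a packed [rows, V] float matrix starts on a
// 16-byte boundary (base pointer assumed 16-byte aligned).
bool row_is_float4_aligned(std::int64_t row_index, std::int64_t V) {
    assert(row_index >= 0 && V > 0);
    return (row_index * V * static_cast<std::int64_t>(sizeof(float))) % 16 == 0;
}

struct Vec4Split {
    std::int64_t vec4_loads;   // full 4-float loads
    std::int64_t scalar_tail;  // leftover elements needing scalar loads
};

// Splits an aligned row of V floats into vector loads plus a scalar tail.
Vec4Split split_row(std::int64_t V) {
    assert(V >= 0);
    return {V / 4, V % 4};
}
```

With V=50257 only every fourth row is aligned, so a GPT-2-sized vocabulary would need padding or a scalar fallback, whereas V=32000 is aligned for every row.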
6.4 Persistent Blocks for Large B×T
For large B×T, launch fewer blocks and have each block process multiple (b,t):
    // One block handles one row at a time, so stride over rows by the grid size
    for (int row = blockIdx.x; row < B * T; row += gridDim.x) {
        process(row);
    }
Benefit: Better occupancy, hides memory latency.
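One common way to distribute rows over persistent blocks is to start each block at its own block index and advance by the number of launched blocks. The host-side sketch below lists the rows one block would visit under that scheme; the function name is an assumption.

```cpp
#include <cassert>
#include <vector>

// Rows visited by block `block_id` when `grid_blocks` persistent blocks
// share `rows` (b,t) positions, striding by the grid size.
std::vector<int> rows_for_block(int block_id, int grid_blocks, int rows) {
    assert(block_id >= 0 && block_id < grid_blocks);
    std::vector<int> visited;
    for (int row = block_id; row < rows; row += grid_blocks)
        visited.push_back(row);
    return visited;
}
```

With B × T = 32 × 128 = 4096 rows and, say, 108 blocks (an assumed grid size, one block per SM on an A100), each block handles 37 or 38 rows and every row is visited exactly once.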
6.5 Read-Only Cache Loads
Route the logits reads through the read-only data cache with __ldg(), which compiles to ld.global.nc (non-coherent load) and is available on compute capability 3.5 and later:
    // Safe only because the logits are never written during the kernel
    #pragma unroll
    for (int v = tid; v < V; v += BLOCK_THREADS) {
        float val = __ldg(&row[v]);  // read-only (non-coherent) load
        // ...
    }
Benefit: May improve cache hit rates across the three passes.
6.6 Single-Pass Online Algorithm
Track a running max and re-normalize:
    float local_max = -FLT_MAX;
    float local_sum = 0.0f;
    LocalTopK<LOCAL_K> local_topk;
    for (int v = tid; v < V; v += BLOCK_THREADS) {
        float x = row[v];
        if (x > local_max) {
            // Rescale the running sum to the new max instead of discarding it
            local_sum *= expf(local_max - x);
            local_max = x;
        }
        local_sum += expf(x - local_max);
        // Track top-K on the raw logit x (softmax preserves order) and
        // convert the K survivors to probabilities once max and sum are final
        // ...
    }
    // Per-thread (local_max, local_sum) pairs must then be merged across the
    // block, rescaling each sum by expf(local_max - block_max)
Benefit: Single pass over V (4V reads instead of 12V). Trade-off: Complex, potential numerical issues, re-normalization overhead. Verdict: Not recommended unless V is extremely large (>1M).
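The running-max bookkeeping is easier to judge with a host-side sketch of the online update and of the merge needed to combine per-thread partial results. The `OnlineState` struct and both function names are assumptions of this example, and the code assumes finite inputs.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>

struct OnlineState {
    float max = -std::numeric_limits<float>::infinity();
    float sum = 0.0f;  // sum of exp(x - max) over the values seen so far
};

// Fold one value into the state, rescaling the sum when the max increases.
void update_state(OnlineState& s, float x) {
    assert(std::isfinite(x));
    if (x > s.max) {
        s.sum *= std::exp(s.max - x);  // exp(-inf) = 0 on the first value
        s.max = x;
    }
    s.sum += std::exp(x - s.max);
}

// Combine two partial states, as a cross-thread reduction would.
OnlineState merge_states(const OnlineState& a, const OnlineState& b) {
    OnlineState out;
    out.max = std::max(a.max, b.max);
    if (std::isinf(out.max)) return out;  // both inputs were empty
    out.sum = a.sum * std::exp(a.max - out.max) + b.sum * std::exp(b.max - out.max);
    return out;
}
```

The merged state matches what a separate max pass followed by a sum pass would produce, up to floating-point rounding.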
Appendix: Kernel Instantiation
// Launch for a GPT-2-sized vocabulary (V=50257, K=256)
launch_fused_softmax_topk<256>(d_logits, d_top_idx, d_top_prob, 1, 1, 50257);
// Launch for GPT-2 (V=50257, K=50)
launch_fused_softmax_topk<50>(d_logits, d_top_idx, d_top_prob, 1, 1, 50257);
// Launch for batched inference (B=32, T=128, V=32000, K=128)
launch_fused_softmax_topk<128>(d_logits, d_top_idx, d_top_prob, 32, 128, 32000);