# Fused Softmax + Top-K Kernel — Design Analysis

## Table of Contents

1. Architecture Overview
2. Memory Access Pattern
3. Warp-Level Optimization Strategy
4. Complexity Analysis
5. Comparison to Naive Implementation
6. Further Optimizations

## 1. Architecture Overview

### Block Assignment

```
Grid:  B × T blocks (one block per (b, t) position)
Block: 256 threads per block
```

### Five-Phase Pipeline (per block)

```
Phase 1:  Find max(logits[b,t,:])           → numerical stability anchor
Phase 2:  Compute Σexp(xᵢ - max)            → log-sum-exp denominator
Phase 3:  Compute softmax + collect top-K   → register-local buffers
Phase 4:  Merge local buffers → shared heap → global top-K
Phase 5:  Sort + write-back                 → output [B,T,K]
```

### Why Three Passes Over V?

You might wonder why we don't do this in one pass. The answer is numerical stability:

```
softmax(xᵢ) = exp(xᵢ) / Σⱼ exp(xⱼ)
```

Without knowing the max first, exp(xᵢ) can overflow for large logits (expf overflows float32 once xᵢ exceeds ~88.7). The standard trick is:

```
softmax(xᵢ) = exp(xᵢ - max) / Σⱼ exp(xⱼ - max)
```

This requires knowing max before computing any exponentials, which forces a second pass. A third pass (Phase 3) is needed because the normalized probabilities depend on the denominator Σⱼ exp(xⱼ - max), which is only known once Phase 2 completes.

Could we do it in one pass? Yes, with an online algorithm that tracks a running max and re-normalizes (see §6.6), but this adds complexity and potential numerical issues. The multi-pass approach is simpler, correct, and the extra V reads are coalesced.
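To make the phase structure concrete, here is a minimal per-block sketch under the assumptions above (256 threads, one block per (b, t) row). The shared-memory reduction stands in for the warp-shuffle version described in §3.1; the top-K bookkeeping is elided:

```cuda
#include <cfloat>

constexpr int BLOCK_THREADS = 256;

// Minimal block-wide reduction via shared memory; the real kernel uses
// the warp-shuffle reductions of §3.1 instead.
__device__ float block_reduce(float v, bool take_max) {
    __shared__ float s[BLOCK_THREADS];
    s[threadIdx.x] = v;
    __syncthreads();
    for (int stride = BLOCK_THREADS / 2; stride > 0; stride >>= 1) {
        if (threadIdx.x < stride) {
            float other = s[threadIdx.x + stride];
            s[threadIdx.x] = take_max ? fmaxf(s[threadIdx.x], other)
                                      : s[threadIdx.x] + other;
        }
        __syncthreads();
    }
    float result = s[0];
    __syncthreads();  // protect s[] before the next call reuses it
    return result;
}

__global__ void fused_softmax_topk_sketch(const float* __restrict__ logits, int V) {
    const float* row = logits + (size_t)blockIdx.x * V;  // one block per (b, t)

    // Phase 1: block-wide max over the row.
    float m = -FLT_MAX;
    for (int v = threadIdx.x; v < V; v += BLOCK_THREADS)
        m = fmaxf(m, row[v]);
    float row_max = block_reduce(m, /*take_max=*/true);

    // Phase 2: block-wide sum of exp(x - max).
    float acc = 0.0f;
    for (int v = threadIdx.x; v < V; v += BLOCK_THREADS)
        acc += expf(row[v] - row_max);
    float row_sum = block_reduce(acc, /*take_max=*/false);

    // Phase 3: normalized probabilities feed per-thread top-K buffers.
    for (int v = threadIdx.x; v < V; v += BLOCK_THREADS) {
        float prob = expf(row[v] - row_max) / row_sum;
        (void)prob;  // local_topk.insert(prob, v) in the real kernel
    }
    // Phases 4-5 (merge into shared heap, sort, write-back) omitted.
}
```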


## 2. Memory Access Pattern

### Global Memory Reads

| Phase   | Access Pattern                | Bytes Read | Coalesced? |
|---------|-------------------------------|------------|------------|
| Phase 1 | `row[tid], row[tid+256], ...` | 4V         | Yes        |
| Phase 2 | `row[tid], row[tid+256], ...` | 4V         | Yes        |
| Phase 3 | `row[tid], row[tid+256], ...` | 4V         | Yes        |
| Total   |                               | 12V        |            |

For V=50257: 12 × 50257 × 4 B ≈ 2.4 MB read per (b,t).

Coalescing analysis:

  • First iteration: threads 0-255 read row[0] through row[255] → perfectly coalesced into ~8-16 128-byte transactions (depending on alignment).
  • Subsequent iterations: threads read row[256] through row[511], etc. → also coalesced.
  • Stride within a thread (256 elements apart) doesn't affect coalescing — coalescing is about consecutive threads accessing consecutive addresses.

### Global Memory Writes

| Output            | Bytes Written |
|-------------------|---------------|
| `top_idx[B,T,K]`  | 4BTK          |
| `top_prob[B,T,K]` | 4BTK          |
| Total             | 8BTK          |

For B=1, T=1, K=256: 8 × 256 = 2048 B (negligible).

### Shared Memory Usage

| Buffer               | Size (K=256) | Access Pattern                           |
|----------------------|--------------|------------------------------------------|
| `s_warp_max[8]`      | 32 B         | Write: 8 threads, Read: warp 0           |
| `s_warp_sum[8]`      | 32 B         | Write: 8 threads, Read: warp 0           |
| `s_heap_vals[256]`   | 1024 B       | Write: all (init), Read/Write: thread 0  |
| `s_heap_idxs[256]`   | 1024 B       | Write: all (init), Read/Write: thread 0  |
| `s_stage_vals[512]`  | 2048 B       | Write: active warp, Read: thread 0       |
| `s_stage_idxs[512]`  | 2048 B       | Write: active warp, Read: thread 0       |
| Total                | 6208 B       |                                          |

Well within the default 48 KB per-block shared memory limit.
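A sketch of how these buffers might be declared inside the kernel, assuming 256 threads (8 warps) and LOCAL_K=16 per-thread buffers; the names follow the table above:

```cuda
constexpr int K          = 256;
constexpr int NUM_WARPS  = 8;             // 256 threads / 32 lanes
constexpr int LOCAL_K    = 16;            // per-thread buffer depth
constexpr int STAGE_SIZE = 32 * LOCAL_K;  // one warp's candidates = 512

// Inside the kernel body: static shared memory, 6208 B total at K=256.
__shared__ float s_warp_max[NUM_WARPS];     //   32 B: per-warp max partials
__shared__ float s_warp_sum[NUM_WARPS];     //   32 B: per-warp sum partials
__shared__ float s_heap_vals[K];            // 1024 B: running top-K probabilities
__shared__ int   s_heap_idxs[K];            // 1024 B: matching vocab indices
__shared__ float s_stage_vals[STAGE_SIZE];  // 2048 B: staging for the active warp
__shared__ int   s_stage_idxs[STAGE_SIZE];  // 2048 B
```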

### Register Usage (per thread)

| Variable                   | Count            |
|----------------------------|------------------|
| `LocalTopK<16>::vals`      | 16 floats = 64 B |
| `LocalTopK<16>::idxs`      | 16 ints = 64 B   |
| Loop counters, temporaries | ~10 registers    |
| Total                      | ~40 registers    |

With 256 threads/block and ~40 registers/thread: ~10,240 registers per block. On Ampere (64K registers/SM) that fits 6 blocks → 1536 resident threads → good occupancy.
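Assuming this budget, the kernel could pin it with a launch-bounds hint; a minimal sketch (the parameter list is illustrative, the 256/6 pair is the figure derived above):

```cuda
// Cap threads per block at 256 and ask the compiler to keep register
// pressure low enough for 6 resident blocks per SM.
__global__ void __launch_bounds__(256, 6)
fused_softmax_topk(const float* logits, int* top_idx, float* top_prob,
                   int B, int T, int V) {
    // ... kernel body ...
}
```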


## 3. Warp-Level Optimization Strategy

### 3.1 Shuffle-Based Reductions

Problem: Traditional reductions use shared memory + sync barriers.

Our approach: `__shfl_xor_sync` (warp shuffle) — data moves directly between thread registers within a warp: zero shared memory, zero global memory traffic.

```
warp_max(val):
    for offset in [16, 8, 4, 2, 1]:
        other = __shfl_xor_sync(mask, val, offset)
        val = max(val, other)
    return val
```

Latency: 5 shuffle operations × ~3 cycles ≈ 15 cycles per reduction, versus a shared-memory reduction at ~5 cycles per access plus barrier overhead (30+ cycles).
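In CUDA this pseudocode becomes a few lines; a sketch assuming a full-warp mask:

```cuda
// Butterfly reduction: after the 5 XOR-shuffle rounds, every lane
// in the warp holds the warp-wide maximum.
__device__ __forceinline__ float warp_max(float val) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        val = fmaxf(val, __shfl_xor_sync(0xffffffffu, val, offset));
    return val;
}

// The Phase-2 sum reduction is identical apart from the operator.
__device__ __forceinline__ float warp_sum(float val) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        val += __shfl_xor_sync(0xffffffffu, val, offset);
    return val;
}
```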

### 3.2 Warp-Level Merge Strategy

The merge of local top-K buffers into the shared heap uses a warp-by-warp strategy:

```
for each warp w in [0, 7]:
    if warp_id == w:
        write LOCAL_K entries to staging buffer
    __syncthreads()
    if tid == 0:
        merge staging into shared heap
    __syncthreads()
```

Why not let all threads merge concurrently? Concurrent heap mutations require atomics or locks, which serialize anyway and add overhead. The warp-by-warp approach:

- Uses only 2 barriers per warp (16 total)
- Lets thread 0 do all heap operations (no contention)
- Leaves the other threads idle during the merge (a small fraction of total work)

Alternative: warp-level merge within each warp. Each warp could merge its 32 threads' LOCAL_K entries into a warp-local top-K using shuffle operations, so that only 8 warp leaders contribute to the shared heap. This reduces heap insertions from 4096 to 8 × K = 2048 and is a valid optimization (see §6.1).
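For reference, the heap primitive thread 0 applies to each staged candidate might look like this sketch: a size-K min-heap over `s_heap_vals`/`s_heap_idxs` (names from §2), where only candidates beating the current minimum are sifted in:

```cuda
// Replace the heap minimum (the root) with a better candidate, then
// sift down to restore the min-heap property. O(log K) per call.
__device__ void heap_replace_min(float* vals, int* idxs, int K,
                                 float new_val, int new_idx) {
    if (new_val <= vals[0]) return;   // not better than the current minimum
    vals[0] = new_val;
    idxs[0] = new_idx;
    int i = 0;
    for (;;) {
        int l = 2 * i + 1, r = 2 * i + 2, smallest = i;
        if (l < K && vals[l] < vals[smallest]) smallest = l;
        if (r < K && vals[r] < vals[smallest]) smallest = r;
        if (smallest == i) break;
        float tv = vals[i]; vals[i] = vals[smallest]; vals[smallest] = tv;
        int   ti = idxs[i]; idxs[i] = idxs[smallest]; idxs[smallest] = ti;
        i = smallest;  // continue sifting from the swapped position
    }
}
```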

### 3.3 Block-Stride Loop for Large V

```cuda
for (int v = tid; v < V; v += BLOCK_THREADS) {
    // process row[v]
}
```

For V=50257, BLOCK_THREADS=256: each thread processes ⌈50257/256⌉ = 197 elements.

Benefits:

- Works for any V (no template parameter needed)
- Good load balancing (threads process nearly equal element counts)
- Every iteration is coalesced, not just the first (§2)

Trade-off: The 256-element stride gives little spatial locality within a single thread. However, for V=50K the entire row (~200 KB) fits comfortably in L2 on Ampere-class GPUs, so re-reading it across the three phases is served from cache.


## 4. Complexity Analysis

### 4.1 Bandwidth vs. Compute Bound

Parameters: B=1, T=1, V=50257, K=256

| Metric               | Value                                           |
|----------------------|-------------------------------------------------|
| Global memory reads  | 12 × 50257 × 4 B = 2.41 MB                      |
| Global memory writes | 8 × 256 = 2.05 KB                               |
| Shared memory ops    | ~32K (heap) + ~4K (staging) = ~36K              |
| `expf()` calls       | 2 × 50257 = 100,514                             |
| Comparisons          | 50257 × LOCAL_K ≈ 804K (local top-K inserts, worst case) |
| Heap sifts           | 4096 × log₂(256) = 32,768                       |

Bandwidth requirement: 2.41 MB per (b,t). On H100 (3.35 TB/s): 2.41 MB / 3.35 TB/s = 0.72 μs (theoretical minimum).

Compute requirement: 100,514 `expf()` calls at ≈50 cycles each → 5.0M cycles / 1.5 GHz ≈ 3.3 μs.

Verdict: COMPUTE-BOUND. The kernel is limited by `expf()` throughput, not memory bandwidth.

### 4.2 Scaling with V

| V    | Global Reads | `expf()` Calls | Bandwidth (μs) | Compute (μs) | Bound   |
|------|--------------|----------------|----------------|--------------|---------|
| 10K  | 480 KB       | 20K            | 0.14           | 0.67         | Compute |
| 50K  | 2.41 MB      | 100K           | 0.72           | 3.3          | Compute |
| 100K | 4.82 MB      | 200K           | 1.44           | 6.6          | Compute |
| 500K | 24.1 MB      | 1M             | 7.2            | 33           | Compute |
| 1M   | 48.2 MB      | 2M             | 14.4           | 66           | Compute |

The kernel remains compute-bound across all practical V values.

### 4.3 Scaling with K

| K    | Heap Ops        | Sort Ops | Impact      |
|------|-----------------|----------|-------------|
| 16   | 512 × 4 = 2K    | 256      | Negligible  |
| 64   | 4096 × 6 = 25K  | 4K       | Small       |
| 256  | 4096 × 8 = 33K  | 66K      | Moderate    |
| 1024 | 4096 × 10 = 41K | 1M       | Significant |

For K > 256, the heap operations and sort become noticeable. Consider:

- Increasing LOCAL_K to maintain the oversampling ratio
- Using a more efficient merge (warp-level top-K within each warp, §6.1)
- A parallel sort (bitonic sort across threads)

## 5. Comparison to Naive Implementation

### Naive Approach

```python
# Python pseudocode
probs = softmax(logits)             # materialize [B, T, V] in global memory
top_idx, top_prob = topk(probs, K)  # read [B, T, V], write [B, T, K]
```

### Comparison Table

| Metric              | Naive                         | Fused Kernel                 | Speedup |
|---------------------|-------------------------------|------------------------------|---------|
| Global reads        | 4V (logits) + 4V (probs) = 8V | 12V (logits × 3)             | 0.67×   |
| Global writes       | 4V (probs) + 8K (output)      | 8K (output only)             | ~V/2K × |
| Peak memory         | 4V + 8K                       | 8K                           | ~V/2K × |
| `expf()` calls      | V (softmax)                   | 2V (Phases 2 + 3)            | 0.5×    |
| Numerical stability | Depends on softmax impl       | Guaranteed (max subtraction) | —       |

### Key Insight: Memory Savings Dominate

For V=50257, K=256:

- Naive: writes 4 × 50257 ≈ 201 KB of softmax probabilities to global memory
- Fused: writes only 8 × 256 = 2 KB of output

The fused kernel reads 50% more (12V vs 8V) but never materializes the softmax matrix:

```
Naive traffic:  8V reads + (4V + 8K) writes = 12V + 8K
Fused traffic: 12V reads + 8K writes        = 12V + 8K
```

The raw byte counts match, but the balance still favors the fused kernel. Its extra reads re-visit a row that is already resident in L2 (§3.3), so its effective DRAM read traffic is closer to 4V, whereas the naive version's 4V of probability writes (and the 4V re-read in the top-k kernel) generally must travel to and from DRAM between two kernel launches. The fused kernel also avoids allocating the 4V-byte intermediate buffer entirely: for V=50257, that is 4 × 50257 ≈ 201 KB of peak memory saved per (b,t).

### When Naive Wins

The naive approach can be faster when:

1. V is small (V < 1024): the overhead of three passes isn't worth it
2. You need the full softmax for other operations (e.g., KL divergence)
3. Hardware has very high bandwidth relative to compute (e.g., HBM3)

### When Fused Wins

The fused kernel dominates when:

1. V is large (V > 10K): memory savings are significant
2. Memory is the bottleneck (e.g., mobile, edge devices)
3. You only need the top-K (common in LLM sampling)
4. Batch size is small (B=1): one block per (b,t) means no inter-block sync

## 6. Further Optimizations

### 6.1 Warp-Level Top-K Merge

Instead of merging all 4096 candidates through a single thread, each warp merges its 32 threads' LOCAL_K entries into a warp-local top-K using shuffle:

```cuda
// Each warp: 32 threads × LOCAL_K = 512 entries → top-K within the warp
// Use warp shuffle to find the top-K in O(K × WARP_SIZE) operations
// Then only 8 warp leaders contribute to the shared heap
```

Benefit: Reduces heap insertions from 4096 to 8 × K = 2048. Complexity: Moderate — requires a warp-level selection algorithm.
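One way to realize that O(K × WARP_SIZE) selection, sketched for one candidate per lane (a per-thread LOCAL_K buffer would wrap this in an outer loop); `warp_topk_emit` and its arguments are illustrative names, not kernel code:

```cuda
#include <cfloat>  // FLT_MAX

// Repeated argmax-by-shuffle: round k finds the k-th largest value
// held across the warp; the owning lane then retires its candidate.
__device__ void warp_topk_emit(float my_val, int my_idx, int k_out,
                               float* out_vals, int* out_idxs) {
    int lane = threadIdx.x & 31;
    for (int k = 0; k < k_out; ++k) {
        // Butterfly argmax over (value, lane) pairs, ties broken by lane.
        float best = my_val;
        int best_lane = lane;
        for (int offset = 16; offset > 0; offset >>= 1) {
            float o_val  = __shfl_xor_sync(0xffffffffu, best, offset);
            int   o_lane = __shfl_xor_sync(0xffffffffu, best_lane, offset);
            if (o_val > best || (o_val == best && o_lane < best_lane)) {
                best = o_val;
                best_lane = o_lane;
            }
        }
        // Fetch the winner's vocab index from its lane.
        int best_idx = __shfl_sync(0xffffffffu, my_idx, best_lane);
        if (lane == 0) {               // warp leader records the winner
            out_vals[k] = best;
            out_idxs[k] = best_idx;
        }
        if (lane == best_lane)         // owner retires its candidate
            my_val = -FLT_MAX;
    }
}
```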

### 6.2 Float16/BFloat16 Support

For LLM workloads, logits are often in FP16/BF16:

```cuda
// Use hexp()/h2exp() from cuda_fp16.h for half-precision exp
// __shfl_xor_sync has __half overloads for the warp reductions
// Promote to FP32 only for the final softmax computation
```

Benefit: Halves global memory traffic, roughly doubling effective bandwidth. Trade-off: Slight numerical precision loss (acceptable for top-K).
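A sketch of the FP16 load path under these assumptions (row stored as `__half`, 4-byte aligned, reductions carried out in FP32; shown for the Phase-1 max):

```cuda
#include <cfloat>
#include <cuda_fp16.h>

// Load __half2 pairs (halving the bytes per element vs. FP32),
// widen to float, and reduce in FP32 for accuracy.
__device__ float local_max_fp16(const __half* row, int V, int tid, int stride) {
    const __half2* row2 = reinterpret_cast<const __half2*>(row);
    float m = -FLT_MAX;
    for (int v = tid; v < V / 2; v += stride) {
        float2 x = __half22float2(row2[v]);  // widen both halves
        m = fmaxf(m, fmaxf(x.x, x.y));
    }
    // (a trailing element for odd V is omitted in this sketch)
    return m;
}
```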

### 6.3 Vectorized Memory Access

```cuda
// Read 4 floats at once (128-bit load)
float4 val = reinterpret_cast<const float4*>(&row[v])[0];
```

Benefit: 4× fewer memory instructions, better utilization of memory bandwidth. Constraint: the row pointer must be 16-byte aligned, and V must be divisible by 4 (or a scalar tail loop added).
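Folded into the block-stride loop from §3.3, with a scalar tail so the divisibility constraint disappears; `process()` is a placeholder for the per-element work, in the spirit of the loop above:

```cuda
// Vectorized body: each thread consumes 4 consecutive floats per step.
int num4 = V / 4;  // number of complete float4 groups
for (int v4 = tid; v4 < num4; v4 += BLOCK_THREADS) {
    float4 x = reinterpret_cast<const float4*>(row)[v4];
    process(x.x, 4 * v4 + 0);
    process(x.y, 4 * v4 + 1);
    process(x.z, 4 * v4 + 2);
    process(x.w, 4 * v4 + 3);
}
// Scalar tail: the remaining V % 4 elements.
for (int v = 4 * num4 + tid; v < V; v += BLOCK_THREADS) {
    process(row[v], v);
}
```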

### 6.4 Persistent Blocks for Large B×T

For large B×T, launch fewer blocks and have each block process multiple (b,t) rows:

```cuda
// Each block handles every gridDim.x-th row; all 256 threads
// cooperate on one (b, t) row at a time, exactly as before.
for (int bid = blockIdx.x; bid < B * T; bid += gridDim.x) {
    process_row(bid);
}
```

Benefit: Better occupancy, hides memory latency.

### 6.5 Read-Only Cached Loads

Route the logits reads through the read-only data cache (PTX `ld.global.nc`, a non-coherent load) via `__ldg`:

```cuda
// Compiler hint: these values won't be modified during the kernel
for (int v = tid; v < V; v += BLOCK_THREADS) {
    float val = __ldg(&row[v]);  // load via the read-only (non-coherent) cache
    // ...
}
```

Benefit: Better cache utilization across the three passes.

### 6.6 Single-Pass Online Algorithm

Track a running max and rescale as it changes:

```cuda
float local_max = -FLT_MAX;
float local_sum = 0.0f;
LocalTopK<LOCAL_K> local_topk;

for (int v = tid; v < V; v += BLOCK_THREADS) {
    float x = row[v];
    if (x > local_max) {
        // Rescale previous contributions to the new max
        local_sum *= expf(local_max - x);
        local_max = x;
        // All stored local_topk values need the same rescale
        // ... (complex)
    }
    float prob = expf(x - local_max);
    local_sum += prob;
    // ...
}
```

Benefit: Single pass over V (4V reads instead of 12V). Trade-off: Complex, potential numerical issues, rescaling overhead. Verdict: Not recommended unless V is extremely large (>1M).
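The same rescaling rule yields the merge step that per-thread partials would need at the end of the pass. A sketch of the standard online-softmax combine for two (max, sum) pairs (not code from this kernel):

```cuda
// Merge two partial softmax states: express both sums relative to the
// combined max so that neither rescale can overflow.
__device__ void merge_state(float m1, float s1, float m2, float s2,
                            float* m_out, float* s_out) {
    float m = fmaxf(m1, m2);
    *m_out = m;
    *s_out = s1 * expf(m1 - m) + s2 * expf(m2 - m);
}
```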


## Appendix: Kernel Instantiation

```cuda
// Launch for a GPT-2-sized vocabulary with large K (V=50257, K=256)
launch_fused_softmax_topk<256>(d_logits, d_top_idx, d_top_prob, 1, 1, 50257);

// Launch for GPT-2 with typical top-k sampling (V=50257, K=50)
launch_fused_softmax_topk<50>(d_logits, d_top_idx, d_top_prob, 1, 1, 50257);

// Launch for batched inference with a LLaMA-sized vocab (B=32, T=128, V=32000, K=128)
launch_fused_softmax_topk<128>(d_logits, d_top_idx, d_top_prob, 32, 128, 32000);
```