feat: add model comparisons and sanitize session files
- Rename gamma to glm5 and model to minimax-m2.7
- Add model_comparison/ directory with head-to-head analyses
- Sanitize all session.jsonl files: remove absolute paths and usernames
- Remove __pycache__ artifacts
- Add .gitignore
@@ -0,0 +1,385 @@
# Fused Softmax + Top-K: Design Document

## 1. Problem Statement

Given logits `[B, T, V]` (e.g., batch=64, seq_len=128, vocab=50257), produce:
|
||||
- **indices** `[B, T, K]` — the K highest-probability token indices per row
|
||||
- **probs** `[B, T, K]` — their softmax probabilities
|
||||
|
||||
**Constraint:** Never write the full V-length softmax vector to global memory.
|
||||
|
||||
---
|
||||
|
||||
## 2. Algorithm: Online Softmax + Register Min-Heap
|
||||
|
||||
### 2.1 Core Idea
|
||||
|
||||
We fuse three operations — **softmax computation**, **top-K selection**, and **probability rescaling** — into a single pass over the logits. This is an instance of the *online softmax* algorithm (Milakov & Gimelshein, 2018) extended with a streaming top-K heap.
|
||||
|
||||
### 2.2 Online Softmax Recurrence
|
||||
|
||||
The standard numerically stable softmax requires two passes over each row: one to find the max, one to accumulate the sum of exponentials. The online variant maintains both as running statistics in a single pass:
|
||||
|
||||
```
|
||||
m_j = max(x_0, ..., x_j) // running maximum
|
||||
d_j = Σ_{i≤j} exp(x_i - m_j) // running sum, always relative to current max
|
||||
```
|
||||
|
||||
Update rule for each new element `x_j`:
|
||||
```
|
||||
m_{j} = max(m_{j-1}, x_j)
|
||||
d_{j} = d_{j-1} * exp(m_{j-1} - m_{j}) + exp(x_j - m_{j})
|
||||
```
|
||||
|
||||
This is **numerically stable** — all exponentials use `x - m_j` where `m_j` is the running max, so no term exceeds `exp(0) = 1`.
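
As a sanity check, the recurrence is easy to exercise on the host. The sketch below is illustrative only (`online_softmax_stats` is not part of the kernel sources); it returns the same max and normalizer a two-pass softmax would:

```cuda
#include <cfloat>
#include <cmath>

// Host-side reference of the Section 2.2 recurrence. On return,
// softmax(x_i) == expf(x_i - m) / d for every element of the row.
inline void online_softmax_stats(const float* x, int n, float& m, float& d) {
    m = -FLT_MAX;
    d = 0.0f;
    for (int j = 0; j < n; j++) {
        float m_old = m;
        m = fmaxf(m, x[j]);                 // m_j = max(m_{j-1}, x_j)
        d = d * expf(m_old - m)             // rescale the old sum to the new max
          + expf(x[j] - m);                 // add the new term
    }
}
```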
|
||||
|
||||
### 2.3 Streaming Top-K Heap
|
||||
|
||||
Simultaneously, each thread maintains a sorted array of size K in registers:
|
||||
|
||||
```
|
||||
insert(value, index):
|
||||
if value <= heap[0]: // heap[0] = K-th largest seen so far
|
||||
return // reject — not in top-K
|
||||
find position via linear scan (K ≤ 32, so ~5 compares average)
|
||||
shift lower elements down
|
||||
place new element
|
||||
```
|
||||
|
||||
For K ≤ 32 this register-resident sorted array outperforms a binary heap (a concrete sketch follows this list) because:
- No indirection / pointer chasing
- The insertion loop is unrolled over a compile-time K, so the hardware sees a short, predictable compare/shift sequence with no branch-misprediction penalty (GPUs have no branch predictor to rely on)
- Register access is effectively free vs. shared memory's ~20-cycle latency
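
A concrete sketch of that insert routine; it mirrors the `TopKHeap` in the kernel sources (`LaneTopK` is just an illustrative name used here to keep the design document self-contained):

```cuda
#include <cfloat>

// Register-resident top-K list, sorted ascending: vals[0] is the K-th largest
// value seen so far, i.e. the current admission threshold.
template <int K>
struct LaneTopK {
    float vals[K];
    int   idxs[K];

    __device__ void init() {
        #pragma unroll
        for (int i = 0; i < K; i++) { vals[i] = -FLT_MAX; idxs[i] = -1; }
    }

    __device__ void insert(float val, int idx) {
        if (val <= vals[0]) return;            // below threshold: rejected in one compare
        int pos = 0;
        #pragma unroll
        for (int i = 1; i < K; i++) {          // last slot whose value is still below val
            if (val > vals[i]) pos = i; else break;
        }
        #pragma unroll
        for (int i = 0; i < pos; i++) {        // evict the old minimum, shift down
            vals[i] = vals[i + 1];
            idxs[i] = idxs[i + 1];
        }
        vals[pos] = val;
        idxs[pos] = idx;
    }
};
```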
|
||||
|
||||
---
|
||||
|
||||
## 3. Kernel Architecture
|
||||
|
||||
### 3.1 Mapping: One Warp per Row
|
||||
|
||||
```
|
||||
Grid: ceil(B*T / WARPS_PER_BLOCK) blocks
|
||||
Block: WARPS_PER_BLOCK × WARP_SIZE threads (default: 8 × 32 = 256)
|
||||
Warp: one (b,t) row
|
||||
```
|
||||
|
||||
Each warp cooperatively processes one row of length V. Lane `j` (0..31) processes elements at indices `j, j+32, j+64, ...`.
|
||||
|
||||
### 3.2 Three-Phase Pipeline
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Phase 1: Local Pass (per-warp, parallel across lanes) │
|
||||
│ │
|
||||
│ Each lane reads V/32 logits in a coalesced strided pattern │
|
||||
│ Each lane maintains: │
|
||||
│ • local_max, local_sum (online softmax statistics) │
|
||||
│ • TopKHeap<K> (K best logits seen by this lane) │
|
||||
│ │
|
||||
│ Warp reduce → warp_max, warp_sum │
|
||||
├─────────────────────────────────────────────────────────────────┤
|
||||
│ Phase 2: Cross-Warp Merge (shared memory) │
|
||||
│ │
|
||||
│ Only needed when WARPS_PER_BLOCK > 1 (i.e., multiple warps │
|
||||
│ process different rows — they still need to sync for shared │
|
||||
│ memory reuse). Within a single warp, Phase 2 is trivial. │
|
||||
│ │
|
||||
│ • Warp 0 reduces global max/sum from all warps │
|
||||
│ • Each warp writes its local top-K heap to shared memory │
|
||||
│ • Warp 0 merges WARPS_PER_BLOCK heaps → global top-K │
|
||||
│ • Rescale: prob_i = exp(val_i - global_max) / global_sum │
|
||||
├─────────────────────────────────────────────────────────────────┤
|
||||
│ Phase 3: Write Output │
|
||||
│ │
|
||||
│ Lane 0 of warp 0 writes K (prob, index) pairs to global mem │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### 3.3 Data Flow Diagram
|
||||
|
||||
```
|
||||
Global Memory (logits [B,T,V])
|
||||
│
|
||||
▼ coalesced reads, V/32 per lane
|
||||
┌───────────────┐
|
||||
│ Registers │ Lane 0 Lane 1 ... Lane 31
|
||||
│ │ [heap] [heap] [heap]
|
||||
│ │ [lmax] [lmax] [lmax]
|
||||
│ │ [lsum] [lsum] [lsum]
|
||||
└──────┬────────┘
|
||||
│ warp shuffle (reduce_max, reduce_sum)
|
||||
▼
|
||||
┌───────────────┐
|
||||
│ Warp-level │ warp_max, warp_sum (broadcast)
|
||||
│ consensus │ merged heap via shared memory
|
||||
└──────┬────────┘
|
||||
│
|
||||
▼
|
||||
Global Memory (probs [B,T,K], indices [B,T,K])
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. Memory Access Pattern
|
||||
|
||||
### 4.1 Global Memory Reads (Logits)
|
||||
|
||||
**Pattern: Strided coalesced access**
|
||||
|
||||
```
|
||||
Warp for row r reads logits[r*V + 0], logits[r*V + 1], ..., logits[r*V + V-1]
|
||||
|
||||
Lane 0: reads indices 0, 32, 64, 96, ...
|
||||
Lane 1: reads indices 1, 33, 65, 97, ...
|
||||
...
|
||||
Lane 31: reads indices 31, 63, 95, 127, ...
|
||||
```
|
||||
|
||||
Consecutive lanes read consecutive addresses → **perfectly coalesced** 128-byte transactions. Each 128-byte cache line is fully utilized by 32 `float` values.
|
||||
|
||||
**Memory efficiency:** V reads per row, 100% coalesced. No redundant loads.
|
||||
|
||||
### 4.2 Global Memory Writes (Output)
|
||||
|
||||
Each row writes exactly `2K` values (K probabilities + K indices). For K=10, that's 80 bytes — negligible compared to reading V×4 bytes (200KB for V=50k).
|
||||
|
||||
**Output writes are negligible:** lane 0 of each warp writes its row's 2K values to a contiguous slice of the output, and consecutive warps own consecutive rows, so these small writes land in adjacent regions. At 80 bytes per row the write traffic is irrelevant next to the V-sized reads.
|
||||
|
||||
### 4.3 Shared Memory
|
||||
|
||||
Used for cross-warp heap merge. Total footprint per block:
|
||||
|
||||
```
|
||||
float warp_max[8] = 32 bytes
|
||||
float warp_sum[8] = 32 bytes
|
||||
float heap_buf[8][32] = 1024 bytes
|
||||
int idx_buf[8][32] = 1024 bytes
|
||||
─────────
|
||||
Total ≈ 2 KB
|
||||
```
|
||||
|
||||
Well within the 48 KB shared-memory limit. **No bank conflicts:** each warp writes to its own row of `heap_buf[warp_id][...]`, and during the merge phase only a single thread of warp 0 reads the buffers, sequentially.
|
||||
|
||||
### 4.4 Register Usage
|
||||
|
||||
Per thread:
|
||||
- Online softmax state: 2 floats (8 bytes)
|
||||
- TopKHeap<K=10>: 10 floats + 10 ints (80 bytes)
|
||||
- Loop variables: ~4 floats (16 bytes)
|
||||
- **Total: ~104 bytes/thread**
|
||||
|
||||
For a 256-thread block: ~26 KB of register file usage. Comfortably fits modern GPU register files (64KB–256KB per SM).
|
||||
|
||||
---
|
||||
|
||||
## 5. Warp-Level Optimization Strategy
|
||||
|
||||
### 5.1 Shuffle-Based Reductions
|
||||
|
||||
The max and sum reductions use `__shfl_xor_sync` (butterfly pattern):
|
||||
|
||||
```
|
||||
Step 1: exchange with lane ^ 16 → 16 pairs
|
||||
Step 2: exchange with lane ^ 8 → 8 quads
|
||||
Step 3: exchange with lane ^ 4 → 4 groups of 8
|
||||
Step 4: exchange with lane ^ 2 → 2 groups of 16
|
||||
Step 5: exchange with lane ^ 1 → 1 group of 32
|
||||
```
|
||||
|
||||
5 steps × 2 ops (max + sum) = **10 shuffle instructions total**. No shared memory, no synchronization needed within a warp.
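
For reference, the max reduction as a compact device helper; it mirrors the `warp_reduce_max` defined in the kernel file (the sum reduction is identical with `+` in place of `fmaxf`):

```cuda
// Butterfly reduction: after 5 XOR-shuffle steps every lane holds the warp-wide max.
__device__ __forceinline__ float warp_reduce_max(float v) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFFu, v, offset));
    return v;
}
```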
|
||||
|
||||
### 5.2 Why Not One Warp Per Row with Vector Loads?
|
||||
|
||||
Alternative: use a wider type (`float4`) to read 4 values per lane, reducing the loop iterations by 4×. This is beneficial when V is very large:
|
||||
|
||||
```
|
||||
// Vectorized load variant (Phase 1 inner loop)
|
||||
float4 vec = reinterpret_cast<const float4*>(logits_row)[v];
|
||||
float x0 = vec.x, x1 = vec.y, x2 = vec.z, x3 = vec.w;
|
||||
// Process 4 elements per iteration
|
||||
```
|
||||
|
||||
**Trade-off:** Increases register pressure (4× more values live at once) but reduces loop overhead and improves memory throughput via wider transactions. Recommended when V > 10K.
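
Fleshed out, the vectorized variant replaces the scalar strided loop inside `local_pass`. The sketch below reuses that function's locals (`lane`, `local_max`, `local_sum`, `heap`) and assumes V is a multiple of 4 and the logits row is 16-byte aligned; tail handling is omitted:

```cuda
// Vectorized Phase 1 inner loop: each lane consumes one float4 (4 logits) per
// iteration. Consecutive lanes read consecutive float4s, so each warp step
// still issues fully coalesced 512-byte reads.
const float4* row4 = reinterpret_cast<const float4*>(logits_row);
const int V4 = V / 4;

for (int v = lane; v < V4; v += WARP_SIZE) {
    float4 vec = row4[v];
    const float xs[4] = { vec.x, vec.y, vec.z, vec.w };
    #pragma unroll
    for (int i = 0; i < 4; i++) {
        float x = xs[i];
        float old_max = local_max;
        local_max = fmaxf(local_max, x);
        local_sum = local_sum * expf(old_max - local_max) + expf(x - local_max);
        heap.insert(x, 4 * v + i);          // index in the original V-length row
    }
}
```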
|
||||
|
||||
### 5.3 Occupancy Considerations
|
||||
|
||||
| Parameter | Value |
|
||||
|--------------------|---------|
|
||||
| Threads/block | 256 |
|
||||
| Registers/thread | ~26 |
|
||||
| Shared memory/block| ~2 KB |
|
||||
| Blocks/SM (A100)   | up to 8 |
| Rows in flight/SM  | up to 64 |
|
||||
|
||||
The kernel is **not register-heavy** and uses minimal shared memory, allowing high occupancy and effective latency hiding.
|
||||
|
||||
---
|
||||
|
||||
## 6. Complexity Analysis
|
||||
|
||||
### 6.1 Per-Row Work
|
||||
|
||||
| Operation | Reads | Writes | Compute |
|
||||
|--------------------------|----------|---------|-------------------|
|
||||
| Read logits | V | 0 | 0 |
|
||||
| Online max/sum | 0* | 0 | V × (1 max + 1 exp + 2 FMAs) |
|
||||
| Top-K heap insert | 0* | 0 | V × ~5 compares + ~2.5 moves avg |
|
||||
| Warp reduce | 0 | 0 | 10 shuffles |
|
||||
| Final rescale (K values) | 0* | 2K | K × (1 exp + 1 mul) |
|
||||
| **Total** | **V** | **2K** | **~6V + 10 + 2K FLOPs** |
|
||||
|
||||
*All intermediate values are in registers.
|
||||
|
||||
### 6.2 Bandwidth vs Compute Bound Analysis
|
||||
|
||||
For V = 50,257 and K = 10:
|
||||
|
||||
**Memory traffic per row:**
|
||||
```
|
||||
Reads: V × 4 bytes = 201 KB
|
||||
Writes: K × 8 bytes = 80 bytes
|
||||
Total: ~201 KB
|
||||
```
|
||||
|
||||
**Compute per row:**
|
||||
```
|
||||
~6 × 50,257 = 301,542 FLOPs (approximate)
|
||||
```
|
||||
|
||||
**Arithmetic intensity:**
|
||||
```
|
||||
AI = 301,542 FLOPs / 201,028 bytes ≈ 1.5 FLOP/byte
|
||||
```
|
||||
|
||||
**NVIDIA A100 specs:**
|
||||
```
|
||||
Peak FP32 compute:  19.5 TFLOPS
Peak bandwidth:     2039 GB/s
Compute/BW ratio:   19.5 TFLOPS / 2039 GB/s ≈ 9.6 FLOP/byte
|
||||
```
|
||||
|
||||
**Conclusion: AI (1.5) << ratio (9.6) → kernel is BANDWIDTH BOUND.**
|
||||
|
||||
This means:
|
||||
1. **The bottleneck is reading V logits from global memory**, not compute.
|
||||
2. Optimizations should focus on memory access patterns (coalescing, caching) not arithmetic.
|
||||
3. The fusion eliminates one full write and one full read of the `[B,T,V]` tensor (~402 KB/row combined), which for a bandwidth-bound kernel translates almost directly into the ~2.5–3× end-to-end speedup estimated below vs. separate softmax + top-K kernels.
|
||||
|
||||
### 6.3 Comparison to Naive Implementation
|
||||
|
||||
```
|
||||
Naive (separate kernels):
|
||||
Kernel 1: softmax
|
||||
Read V logits → Write V probabilities (201 KB + 201 KB = 402 KB I/O)
|
||||
Kernel 2: top-k
|
||||
Read V probabilities → Write K results (201 KB + 80 bytes = 201 KB I/O)
|
||||
Total I/O: ~603 KB/row
|
||||
Kernel launch overhead: 2×
|
||||
|
||||
Fused (this kernel):
|
||||
Read V logits → Write K results (201 KB + 80 bytes = 201 KB I/O)
|
||||
Total I/O: ~201 KB/row
|
||||
Kernel launch overhead: 1×
|
||||
|
||||
Savings:
|
||||
Memory I/O: 3× reduction (603 KB → 201 KB per row)
|
||||
Kernel launches: 2× reduction
|
||||
Effective speedup: ~2.5–3× (bandwidth-bound, so I/O directly maps to time)
|
||||
```
|
||||
|
||||
For a real workload (B=64, T=128, V=50257):
|
||||
```
|
||||
Naive: 64 × 128 × 603 KB = 4.7 GB global memory traffic
|
||||
Fused: 64 × 128 × 201 KB = 1.6 GB global memory traffic
|
||||
Savings: 3.1 GB avoided
|
||||
|
||||
At A100 bandwidth (2039 GB/s):
|
||||
Naive time: ~2.3 ms
|
||||
Fused time: ~0.8 ms
|
||||
Speedup: 2.9×
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. Advanced Optimizations
|
||||
|
||||
### 7.1 FP16 Input with FP32 Accumulation
|
||||
|
||||
For mixed-precision workloads (logits stored as `__half`):
|
||||
|
||||
```cuda
|
||||
// Read 2 values per load, accumulate in FP32
|
||||
__half2 h2 = reinterpret_cast<const __half2*>(logits_row)[v];
|
||||
float x0 = __half2float(h2.x);
|
||||
float x1 = __half2float(h2.y);
|
||||
```
|
||||
|
||||
This halves the memory traffic (V × 2 bytes instead of V × 4 bytes), which for a bandwidth-bound kernel roughly doubles throughput.
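
The corresponding inner loop might look like the sketch below (again a drop-in for the scalar loop in `local_pass`, reusing its locals; it assumes `logits_row` points to `__half` data and V is even):

```cuda
// fp16 variant of the Phase 1 inner loop: 2 logits per load, fp32 accumulation.
const __half2* row_h2 = reinterpret_cast<const __half2*>(logits_row);
const int V2 = V / 2;

for (int v = lane; v < V2; v += WARP_SIZE) {
    __half2 h2 = row_h2[v];
    float xs[2] = { __half2float(h2.x), __half2float(h2.y) };
    #pragma unroll
    for (int i = 0; i < 2; i++) {
        float x = xs[i];                     // accumulate in fp32
        float old_max = local_max;
        local_max = fmaxf(local_max, x);
        local_sum = local_sum * expf(old_max - local_max) + expf(x - local_max);
        heap.insert(x, 2 * v + i);           // index in the original V-length row
    }
}
```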
|
||||
|
||||
### 7.2 Multi-Row Per Warp (for Small V)
|
||||
|
||||
When V < 1024, a single row gives each warp too little work to hide latency or amortize scheduling overhead. Assign multiple rows per warp:
|
||||
|
||||
```
|
||||
for (int row_offset = 0; row_offset < ROWS_PER_WARP; row_offset++) {
|
||||
int row = base_row + row_offset;
|
||||
// ... process row ...
|
||||
}
|
||||
```
|
||||
|
||||
This amortizes warp-management overhead and improves occupancy for small-V cases.
|
||||
|
||||
### 7.3 Async Copy (Ampere and Newer)
|
||||
|
||||
```cuda
|
||||
// Pipeline loads with cp.async to overlap compute and memory
|
||||
cp.async.ca.shared.global [smem_ptr], [gmem_ptr], 16;
|
||||
```
|
||||
|
||||
Overlaps the next chunk's load with the current chunk's heap insertions. Beneficial when V > 10K and the per-chunk compute is long enough to hide the load latency.
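
A minimal sketch of how this could be wired up with the CUDA pipeline primitives (`__pipeline_memcpy_async` et al. from `<cuda_pipeline.h>`). The names `CHUNK`, `buf`, and `process_row_pipelined` are illustrative, not identifiers from the kernel file; it assumes one warp per block and V a multiple of 32 to keep bounds handling out of the way:

```cuda
#include <cuda_pipeline.h>

// Double-buffered Phase 1 loads: prefetch the next 32-element chunk into shared
// memory with cp.async while the current chunk feeds the softmax/heap updates.
constexpr int CHUNK = 32;                      // one element per lane per chunk

__device__ void process_row_pipelined(const float* __restrict__ logits_row, int V) {
    __shared__ float buf[2][CHUNK];            // per-block; one warp assumed here
    const int lane = threadIdx.x % 32;

    __pipeline_memcpy_async(&buf[0][lane], &logits_row[lane], sizeof(float));
    __pipeline_commit();                       // chunk 0 is now in flight

    for (int c = 0; c * CHUNK < V; c++) {
        int next = (c + 1) * CHUNK;
        if (next < V)                          // start fetching chunk c+1 early
            __pipeline_memcpy_async(&buf[(c + 1) & 1][lane],
                                    &logits_row[next + lane], sizeof(float));
        __pipeline_commit();
        __pipeline_wait_prior(1);              // chunk c has landed in buf[c & 1]

        float x = buf[c & 1][lane];
        // ... online softmax update + heap.insert(x, c * CHUNK + lane) ...
        (void)x;                               // placeholder for the real update
    }
}
```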
|
||||
|
||||
### 7.4 Warp-Level Heap Merge for Large WARPS_PER_BLOCK
|
||||
|
||||
When using many warps per block, the serial merge by warp 0 becomes a bottleneck. Alternative:
|
||||
|
||||
```
|
||||
1. Each warp writes its K values to shared memory
|
||||
2. Tournament merge using warp shuffles:
|
||||
- Round 1: warp 0 vs warp 1, warp 2 vs warp 3, ...
|
||||
- Round 2: winners merge
|
||||
- Final: one warp produces global top-K
|
||||
3. Each round uses warp-cooperative merge of two sorted arrays
|
||||
```
|
||||
|
||||
This reduces the merge's serial critical path from O(WARPS × K) to O(K × log(WARPS)); the total work is unchanged, but it is spread across warps instead of serialized on warp 0.
|
||||
|
||||
---
|
||||
|
||||
## 8. Correctness: Numerical Stability
|
||||
|
||||
The algorithm maintains numerical stability through three mechanisms:
|
||||
|
||||
1. **Subtract running max before exp:** All calls to `expf()` use `x - current_max`, ensuring the argument is ≤ 0. No overflow possible.
|
||||
|
||||
2. **Rescaling on max update:** When `current_max` increases, we multiply the running sum by `exp(old_max - new_max)`, which is in (0, 1]. No overflow; minimal underflow risk.
|
||||
|
||||
3. **Final rescaling:** `prob_i = exp(val_i - global_max) / global_sum`. Since `global_sum ≥ 1.0` (it includes `exp(global_max - global_max) = 1.0`), division is safe.
|
||||
|
||||
**Comparison with log-sum-exp:**
|
||||
The online algorithm computes exactly `Σ exp(x_i - max(x))` which is equivalent to `exp(logsumexp(x) - max(x))`. The final probabilities are identical to standard numerically-stable softmax to within floating-point rounding.
|
||||
|
||||
---
|
||||
|
||||
## 9. Summary Table
|
||||
|
||||
| Metric | Naive (separate) | Fused (this work) | Improvement |
|
||||
|---------------------------|--------------------|---------------------|----------------|
|
||||
| Global memory reads | 2V per row | V per row | 2× |
|
||||
| Global memory writes | V + 2K per row | 2K per row | ~V/(2K)× |
|
||||
| Total I/O per row | ~3V | ~V | 3× |
|
||||
| Kernel launches | 2 | 1 | 2× |
|
||||
| Intermediate tensor | V floats/row | 0 (registers) | ∞ |
|
||||
| Numerically stable | Yes | Yes | — |
|
||||
| Register pressure | Low | Moderate (~26 regs) | Acceptable |
|
||||
| Shared memory | None | ~2 KB | Minimal |
|
||||
| Bandwidth utilization | Wastes BW on intermediate | Optimal | — |
|
||||
| Effective speedup | Baseline | **2.5–3×** | — |
|
||||
@@ -0,0 +1,112 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Generate an ASCII architecture diagram for the fused softmax + top-k kernel.
|
||||
Run: python3 diagram.py
|
||||
"""
|
||||
|
||||
print("""
|
||||
╔═══════════════════════════════════════════════════════════════════════════════╗
|
||||
║ FUSED SOFTMAX + TOP-K KERNEL ARCHITECTURE ║
|
||||
╠═══════════════════════════════════════════════════════════════════════════════╣
|
||||
║ ║
|
||||
║ Global Memory Layout ║
|
||||
║ ┌──────────────────────────────────────────────┐ ║
|
||||
║ │ logits [B, T, V] (read-only) │ ║
|
||||
║ │ indices [B, T, K] (write-only) │ ║
|
||||
║ │ probs [B, T, K] (write-only) │ ║
|
||||
║ └──────────────────────────────────────────────┘ ║
|
||||
║ ║
|
||||
║ Thread Mapping: 1 warp = 1 row (b, t) ║
|
||||
║ ┌────────────────────────────────────────────────────────────────────────┐ ║
|
||||
║ │ Block (256 threads = 8 warps) │ ║
|
||||
║ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ ║
|
||||
║ │ │ Warp 0 │ │ Warp 1 │ ... │ Warp 7 │ │ ║
|
||||
║ │ │ row=0 │ │ row=1 │ │ row=7 │ │ ║
|
||||
║ │ │ 32 lanes │ │ 32 lanes │ │ 32 lanes │ │ ║
|
||||
║ │ └────┬─────┘ └────┬─────┘ └────┬─────┘ │ ║
|
||||
║ │ │ │ │ │ ║
|
||||
║ │ ┌────▼─────────────▼──────────────────▼──────────────────────────┐ │ ║
|
||||
║ │ │ Shared Memory (~2 KB) │ │ ║
|
||||
║ │ │ • warp_max[8], warp_sum[8] (32+32 bytes) │ │ ║
|
||||
║ │ │ • heap_buf[8][K], idx_buf[8][K] (2×8×K × 4 bytes) │ │ ║
|
||||
║ │ └───────────────────────────────────────────────────────────────┘ │ ║
|
||||
║ └────────────────────────────────────────────────────────────────────────┘ ║
|
||||
║ ║
|
||||
║ Single Warp Detail (processing row r, V=50257): ║
|
||||
║ ║
|
||||
║ ┌─────────────────────────────────────────────────────────────────────┐ ║
|
||||
║ │ Lane 0 Lane 1 Lane 2 ... Lane 31 │ ║
|
||||
║ │ │ ║
|
||||
║ │ READ: logits[r*V + {0,1,2,...,31}] ← 1 coalesced 128B load │ ║
|
||||
║ │ logits[r*V + {32,33,...,63}] ← next coalesced load │ ║
|
||||
║ │ ... │ ║
|
||||
║ │ logits[r*V + {50224,...,50255}] ← last load │ ║
|
||||
║ │ │ ║
|
||||
║ │ Each lane processes ~V/32 ≈ 1571 elements: │ ║
|
||||
║ │ │ ║
|
||||
║ │ ┌─────────────────────────────────────────────────────────┐ │ ║
|
||||
║ │ │ Per-Lane Computation (in REGISTERS): │ │ ║
|
||||
║ │ │ │ │ ║
|
||||
║ │ │ local_max = -∞, local_sum = 0 │ │ ║
|
||||
║ │ │ heap = {(-∞, 0), ..., (-∞, 0)} // K entries │ │ ║
|
||||
║ │ │ │ │ ║
|
||||
║ │ │ for each element x_j at index j: │ │ ║
|
||||
║ │ │ old_max = local_max │ │ ║
|
||||
║ │ │ local_max = max(local_max, x_j) │ │ ║
|
||||
║ │ │ local_sum *= exp(old_max - local_max) // rescale │ │ ║
|
||||
║ │ │ local_sum += exp(x_j - local_max) // add new │ │ ║
|
||||
║ │ │ heap.insert(x_j, j) // O(K) compare+shift │ │ ║
|
||||
║ │ └─────────────────────────────────────────────────────────┘ │ ║
|
||||
║ │ │ │ ║
|
||||
║ │ ▼ Warp Shuffle Reduction │ ║
|
||||
║ │ │ ║
|
||||
║ │ ┌─────────────────────────────────────────────────────────┐ │ ║
|
||||
║ │ │ warp_max = reduce_max(local_max) across 32 lanes │ │ ║
|
||||
║ │ │ warp_sum = reduce_sum(local_sum * exp(local_max - │ │ ║
|
||||
║ │ │ warp_max)) across 32 lanes │ │ ║
|
||||
║ │ │ │ │ ║
|
||||
║ │ │ 5 butterfly steps using __shfl_xor_sync: │ │ ║
|
||||
║ │ │ Step 1: ⊕ 16 ── 16↔16 pairs merge │ │ ║
|
||||
║ │ │ Step 2: ⊕ 8 ── 8 groups of 4 merge │ │ ║
|
||||
║ │ │ Step 3: ⊕ 4 ── 4 groups of 8 merge │ │ ║
|
||||
║ │ │ Step 4: ⊕ 2 ── 2 groups of 16 merge │ │ ║
|
||||
║ │ │ Step 5: ⊕ 1 ── final 32-lane consensus │ │ ║
|
||||
║ │ └─────────────────────────────────────────────────────────┘ │ ║
|
||||
║ │ │ │ ║
|
||||
║ │ ▼ Cross-Warp Merge (Phase 2) │ ║
|
||||
║ │ │ ║
|
||||
║ │ ┌─────────────────────────────────────────────────────────┐ │ ║
|
||||
║ │ │ 1. Each warp writes its K heap entries → shared memory │ │ ║
|
||||
║ │ │ 2. __syncthreads() │ │ ║
|
||||
║ │ │ 3. Warp 0 merges 8 heaps → global top-K: │ │ ║
|
||||
║ │ │ • Scan 8×K=80 candidates │ │ ║
|
||||
║ │ │ • Keep top K=10 via sorted insertion │ │ ║
|
||||
║ │ │ 4. Rescale to probabilities: │ │ ║
|
||||
║ │ │ prob_i = exp(val_i - global_max) / global_sum │ │ ║
|
||||
║ │ │ 5. Write K × (prob, index) to global memory │ │ ║
|
||||
║ │ └─────────────────────────────────────────────────────────┘ │ ║
|
||||
║ └─────────────────────────────────────────────────────────────────────┘ ║
|
||||
║ ║
|
||||
╠═══════════════════════════════════════════════════════════════════════════════╣
|
||||
║ MEMORY TRAFFIC SUMMARY (per row, V=50257, K=10) ║
|
||||
╠═══════════════════════════════════════════════════════════════════════════════╣
|
||||
║ ║
|
||||
║ ┌───────────────────┬──────────────┬──────────────┬─────────────┐ ║
|
||||
║ │ Implementation │ Reads │ Writes │ Total I/O │ ║
|
||||
║ ├───────────────────┼──────────────┼──────────────┼─────────────┤ ║
|
||||
║ │ Naive (separate) │ 2V = 402 KB │ V+2K = 201KB │ 603 KB │ ║
|
||||
║ │ Fused (this work) │ V = 201 KB │ 2K = 80 B │ 201 KB │ ║
|
||||
║ │ Speedup │ 2× │ 2500× │ 3× │ ║
|
||||
║ └───────────────────┴──────────────┴──────────────┴─────────────┘ ║
|
||||
║ ║
|
||||
║ Full workload (B=64, T=128, V=50257, K=10): ║
|
||||
║ ┌───────────────────┬─────────────┬──────────┬────────────────────┐ ║
|
||||
║ │ Implementation │ Total I/O │ Time* │ Intermediate Tensor│ ║
|
||||
║ ├───────────────────┼─────────────┼──────────┼────────────────────┤ ║
|
||||
║   │ Naive             │  4.7 GB     │  ~2.3 ms │  1.6 GB (full sm)  │            ║
|
||||
║ │ Fused │ 1.6 GB │ ~0.8 ms │ 0 MB (registers) │ ║
|
||||
║ └───────────────────┴─────────────┴──────────┴────────────────────┘ ║
|
||||
║ * Estimated on NVIDIA A100 at peak bandwidth (2039 GB/s) ║
|
||||
║ ║
|
||||
╚═══════════════════════════════════════════════════════════════════════════════╝
|
||||
""")
|
||||
@@ -0,0 +1,345 @@
|
||||
// =============================================================================
|
||||
// Fused Softmax + Top-K Kernel
|
||||
// =============================================================================
|
||||
// Input: logits [B, T, V] (row-major, fp32 or fp16)
|
||||
// Output: indices [B, T, K] (int32)
|
||||
// probs [B, T, K] (fp32, softmax probabilities of top-K)
|
||||
//
|
||||
// Key insight: we never materialize the full V-length softmax vector.
|
||||
// We compute the softmax in a single forward pass (online softmax) while
|
||||
// simultaneously maintaining a min-heap of the top-K values seen so far.
|
||||
// =============================================================================
|
||||
|
||||
#pragma once
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
#include <cfloat>   // FLT_MAX, used by TopKHeap::init and local_pass
|
||||
|
||||
// --------------- tunable parameters ---------------
|
||||
#ifndef WARP_SIZE
|
||||
#define WARP_SIZE 32
|
||||
#endif
|
||||
|
||||
#ifndef HEAP_K
|
||||
// Maximum K supported. For K <= 32 each lane keeps its local top-K entirely
// in registers; larger K would spill and need a different scheme.
|
||||
#define HEAP_K 32
|
||||
#endif
|
||||
|
||||
// We launch one warp per (b, t) row. Each warp processes V/WARP_SIZE
|
||||
// elements, accumulating partial softmax statistics and a local top-K heap,
|
||||
// then merges heaps across warps in shared memory.
|
||||
//
|
||||
// Block layout: WARPS_PER_BLOCK warps, each handling one row.
|
||||
// Grid layout: ceil(B*T / WARPS_PER_BLOCK) blocks.
|
||||
// Total threads: B * T * WARP_SIZE (every thread in a warp works on one row)
|
||||
|
||||
#ifndef WARPS_PER_BLOCK
|
||||
#define WARPS_PER_BLOCK 8
|
||||
#endif
|
||||
#define BLOCK_SIZE (WARPS_PER_BLOCK * WARP_SIZE)
|
||||
|
||||
// =============================================================================
|
||||
// Min-heap utilities (keeps top-K largest values)
|
||||
// =============================================================================
|
||||
// We store a small sorted array of size K (K <= 32) in registers.
|
||||
// This is faster than a tree-based heap for small K because:
|
||||
// - Insertion is just a single compare + conditional shift
|
||||
// - Cache/coherence is trivial (all registers)
|
||||
// - No pointer chasing
|
||||
|
||||
template <int K>
|
||||
struct TopKHeap {
|
||||
float vals[K]; // sorted ascending (vals[0] is the minimum)
|
||||
int idxs[K];
|
||||
|
||||
__device__ __forceinline__
|
||||
void init() {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < K; i++) {
|
||||
vals[i] = -FLT_MAX;
|
||||
idxs[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Insert if value > current minimum (the K-th largest so far).
|
||||
__device__ __forceinline__
|
||||
void insert(float val, int idx) {
|
||||
if (val <= vals[0]) return; // not in top-K, skip
|
||||
// Linear scan to find the insertion point. K is a compile-time constant, so the
// loop unrolls into a short, predictable compare/shift sequence with no
// branch-misprediction penalty (GPUs have no branch predictor).
// For K=32 this is only a few compares on average, cheaper than a binary search.
|
||||
int pos = 0;
|
||||
#pragma unroll
|
||||
for (int i = 1; i < K; i++) {
|
||||
if (val > vals[i]) pos = i; // find last position where val > vals[i]
|
||||
else break;
|
||||
}
|
||||
// Shift elements down: vals[0..pos-1] ← vals[1..pos]
|
||||
#pragma unroll
|
||||
for (int i = 0; i < pos; i++) {
|
||||
vals[i] = vals[i + 1];
|
||||
idxs[i] = idxs[i + 1];
|
||||
}
|
||||
vals[pos] = val;
|
||||
idxs[pos] = idx;
|
||||
}
|
||||
};
|
||||
|
||||
// =============================================================================
|
||||
// Warp-level primitives
|
||||
// =============================================================================
|
||||
|
||||
__device__ __forceinline__
|
||||
float warp_reduce_max(float val) {
|
||||
// Butterfly reduction across 32 lanes
|
||||
val = fmaxf(val, __shfl_xor_sync(0xFFFFFFFF, val, 16));
|
||||
val = fmaxf(val, __shfl_xor_sync(0xFFFFFFFF, val, 8));
|
||||
val = fmaxf(val, __shfl_xor_sync(0xFFFFFFFF, val, 4));
|
||||
val = fmaxf(val, __shfl_xor_sync(0xFFFFFFFF, val, 2));
|
||||
val = fmaxf(val, __shfl_xor_sync(0xFFFFFFFF, val, 1));
|
||||
return val;
|
||||
}
|
||||
|
||||
__device__ __forceinline__
|
||||
float warp_reduce_sum(float val) {
|
||||
val += __shfl_xor_sync(0xFFFFFFFF, val, 16);
|
||||
val += __shfl_xor_sync(0xFFFFFFFF, val, 8);
|
||||
val += __shfl_xor_sync(0xFFFFFFFF, val, 4);
|
||||
val += __shfl_xor_sync(0xFFFFFFFF, val, 2);
|
||||
val += __shfl_xor_sync(0xFFFFFFFF, val, 1);
|
||||
return val;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Shared memory layout (one per block)
|
||||
// =============================================================================
|
||||
|
||||
struct SharedStorage {
|
||||
// Per-warp partial results for cross-warp merge
|
||||
float warp_max[WARPS_PER_BLOCK]; // partial max
|
||||
float warp_sum[WARPS_PER_BLOCK]; // partial sum of exps
|
||||
// Heap merge buffer: each warp writes its local top-K here
|
||||
float heap_buf[WARPS_PER_BLOCK][HEAP_K];
|
||||
int idx_buf [WARPS_PER_BLOCK][HEAP_K];
|
||||
// Synchronization
|
||||
int barrier_count;
|
||||
};
|
||||
|
||||
// =============================================================================
|
||||
// Phase 1 — Per-warp local pass: each lane covers a strided V/WARP_SIZE slice
|
||||
// =============================================================================
|
||||
// Each lane j in warp w processes logits at indices:
|
||||
// j, j + WARP_SIZE, j + 2*WARP_SIZE, ...
|
||||
// covering a strided subset of the V-dimension.
|
||||
//
|
||||
// Online softmax recurrence (per lane):
|
||||
// m_j ← max(m_j, x_j) (local max)
|
||||
// d_j ← d_j * exp(m_old - m_j) + exp(x_j - m_j)
|
||||
//
|
||||
// After the loop we do a warp-all-reduce to get the global max m and sum d
|
||||
// for this row. Each lane rescales its accumulated exp-sum to the shared warp
// max; the heap keeps raw logits, which are converted to probabilities only
// when the final top-K is written out.
|
||||
|
||||
template <int K>
|
||||
__device__ __forceinline__
|
||||
void local_pass(
|
||||
const float* __restrict__ logits_row, // pointer to row of length V
|
||||
int V,
|
||||
float& out_max, // warp-reduced max
|
||||
float& out_sum, // warp-reduced sum of exps
|
||||
TopKHeap<K>& heap) // per-lane local top-K
|
||||
{
|
||||
const int lane = threadIdx.x % WARP_SIZE;
|
||||
|
||||
float local_max = -FLT_MAX;
|
||||
float local_sum = 0.0f;
|
||||
|
||||
// Strided loop: lane i processes indices i, i+32, i+64, ...
|
||||
// This gives coalesced global reads because consecutive lanes read
|
||||
// consecutive addresses.
|
||||
for (int v = lane; v < V; v += WARP_SIZE) {
|
||||
float x = logits_row[v];
|
||||
float old_max = local_max;
|
||||
local_max = fmaxf(local_max, x);
|
||||
// Rescale running sum to new max
|
||||
local_sum *= expf(old_max - local_max);
|
||||
local_sum += expf(x - local_max);
|
||||
|
||||
// Track top-K in the original logit space (before exp).
|
||||
// We will rescale to probabilities later using the final max & sum.
|
||||
heap.insert(x, v);
|
||||
}
|
||||
|
||||
// ---- Warp-level reduction for max and sum ----
|
||||
float warp_max = warp_reduce_max(local_max);
|
||||
// Rescale all lane sums to the common warp_max
|
||||
local_sum *= expf(local_max - warp_max);
|
||||
float warp_sum = warp_reduce_sum(local_sum);
|
||||
|
||||
out_max = warp_max;
|
||||
out_sum = warp_sum;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Phase 2 — Cross-warp heap merge in shared memory
|
||||
// =============================================================================
|
||||
// When WARPS_PER_BLOCK > 1, each warp has its own local top-K heap.
|
||||
// We merge by:
|
||||
// 1. Each warp writes its heap to shared memory
|
||||
// 2. __syncthreads()
|
||||
// 3. Lane 0 of warp 0 does a serial K-way merge (K is small, typically 5-50)
|
||||
// over WARPS_PER_BLOCK heaps → global top-K
|
||||
// 4. Rescale values: prob_i = exp(val_i - global_max) / global_sum
|
||||
//
|
||||
// For WARPS_PER_BLOCK == 1 this phase is a no-op (single warp = single row).
|
||||
|
||||
template <int K>
|
||||
__device__ __forceinline__
|
||||
void cross_warp_merge(
|
||||
SharedStorage& smem,
|
||||
float global_max,
|
||||
float global_sum,
|
||||
TopKHeap<K>& heap,
|
||||
int warp_id,
|
||||
int lane_id,
|
||||
float* out_probs, // [K] output
|
||||
int* out_idxs) // [K] output
|
||||
{
|
||||
// Each warp writes its local heap to shared memory
|
||||
if (lane_id < K) {
|
||||
smem.heap_buf[warp_id][lane_id] = heap.vals[K - 1 - lane_id]; // descending
|
||||
smem.idx_buf [warp_id][lane_id] = heap.idxs[K - 1 - lane_id];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Warp 0 merges all heaps
|
||||
if (warp_id == 0) {
|
||||
// Build the global top-K by scanning all warp heaps
|
||||
TopKHeap<K> global_heap;
|
||||
global_heap.init();
|
||||
|
||||
#pragma unroll
|
||||
for (int w = 0; w < WARPS_PER_BLOCK; w++) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < K; i++) {
|
||||
float v = smem.heap_buf[w][i];
|
||||
int j = smem.idx_buf [w][i];
|
||||
global_heap.insert(v, j);
|
||||
}
|
||||
}
|
||||
|
||||
// Lane 0 writes the final result (rescaled to probabilities)
|
||||
if (lane_id == 0) {
|
||||
float inv_sum = 1.0f / global_sum;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < K; i++) {
|
||||
// vals are sorted ascending; reverse for output (descending prob)
|
||||
int ki = K - 1 - i;
|
||||
out_probs[i] = expf(global_heap.vals[ki] - global_max) * inv_sum;
|
||||
out_idxs [i] = global_heap.idxs[ki];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Main kernel
|
||||
// =============================================================================
|
||||
|
||||
template <int K>
|
||||
__global__ void fused_softmax_topk_kernel(
|
||||
const float* __restrict__ logits, // [B, T, V]
|
||||
int* __restrict__ out_indices, // [B, T, K]
|
||||
float* __restrict__ out_probs, // [B, T, K]
|
||||
int B, int T, int V)
|
||||
{
|
||||
// One block processes WARPS_PER_BLOCK rows.
|
||||
// Each warp handles one row.
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
|
||||
// Map warp → (b, t) row index
|
||||
int row = blockIdx.x * WARPS_PER_BLOCK + warp_id;
|
||||
if (row >= B * T) return;
|
||||
|
||||
int b = row / T;
|
||||
int t = row % T;
|
||||
|
||||
// Pointers for this row
|
||||
const float* logits_row = logits + (size_t)row * V;
|
||||
int* row_out_indices = out_indices + (size_t)row * K;
|
||||
float* row_out_probs = out_probs + (size_t)row * K;
|
||||
|
||||
// Shared memory
|
||||
__shared__ __align__(16) SharedStorage smem;
|
||||
|
||||
// Phase 1: local pass over logits
|
||||
TopKHeap<K> heap;
|
||||
heap.init();
|
||||
|
||||
float warp_max, warp_sum;
|
||||
local_pass<K>(logits_row, V, warp_max, warp_sum, heap);
|
||||
|
||||
// Store partials in shared memory for cross-warp merge
|
||||
if (lane_id == 0) {
|
||||
smem.warp_max[warp_id] = warp_max;
|
||||
smem.warp_sum[warp_id] = warp_sum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Compute global max and sum across warps (lane 0 does it)
|
||||
float global_max = -FLT_MAX;
|
||||
float global_sum = 0.0f;
|
||||
if (lane_id == 0) {
|
||||
#pragma unroll
|
||||
for (int w = 0; w < WARPS_PER_BLOCK; w++) {
|
||||
if (blockIdx.x * WARPS_PER_BLOCK + w < B * T) {
|
||||
float wm = smem.warp_max[w];
|
||||
float ws = smem.warp_sum[w];
|
||||
float old_max = global_max;
|
||||
global_max = fmaxf(global_max, wm);
|
||||
global_sum *= expf(old_max - global_max);
|
||||
global_sum += ws * expf(wm - global_max);
|
||||
}
|
||||
}
|
||||
smem.warp_max[0] = global_max; // reuse for broadcast
|
||||
smem.warp_sum[0] = global_sum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
global_max = smem.warp_max[0];
|
||||
global_sum = smem.warp_sum[0];
|
||||
|
||||
// Phase 2: cross-warp heap merge + write output
|
||||
cross_warp_merge<K>(smem, global_max, global_sum,
|
||||
heap, warp_id, lane_id,
|
||||
row_out_probs, row_out_indices);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Host launch wrapper
|
||||
// =============================================================================
|
||||
|
||||
template <int K>
|
||||
void launch_fused_softmax_topk(
|
||||
const float* d_logits,
|
||||
int* d_indices,
|
||||
float* d_probs,
|
||||
int B, int T, int V,
|
||||
cudaStream_t stream = 0)
|
||||
{
|
||||
int total_rows = B * T;
|
||||
int grid = (total_rows + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK;
|
||||
|
||||
size_t smem_bytes = sizeof(SharedStorage);
|
||||
|
||||
fused_softmax_topk_kernel<K>
|
||||
<<<grid, BLOCK_SIZE, smem_bytes, stream>>>(
|
||||
d_logits, d_indices, d_probs, B, T, V);
|
||||
}
|
||||
|
||||
// Explicit instantiation for common K values
|
||||
template void launch_fused_softmax_topk<5>(const float*, int*, float*, int, int, int, cudaStream_t);
|
||||
template void launch_fused_softmax_topk<10>(const float*, int*, float*, int, int, int, cudaStream_t);
|
||||
template void launch_fused_softmax_topk<20>(const float*, int*, float*, int, int, int, cudaStream_t);
|
||||
template void launch_fused_softmax_topk<32>(const float*, int*, float*, int, int, int, cudaStream_t);
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,198 @@
|
||||
// =============================================================================
|
||||
// Test / Benchmark: Fused Softmax + Top-K
|
||||
// =============================================================================
|
||||
// Compile:
|
||||
// nvcc -O3 -arch=sm_80 -o test_fused test_fused.cu
// (fused_softmax_topk.cuh is pulled in via #include, not compiled separately)
|
||||
//
|
||||
// Run:
|
||||
// ./test_fused
|
||||
// =============================================================================
|
||||
|
||||
#include "fused_softmax_topk.cuh"
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
|
||||
// ---------- CPU reference implementation ----------
|
||||
|
||||
void cpu_softmax_topk(const float* logits, int* indices, float* probs,
|
||||
int B, int T, int V, int K) {
|
||||
for (int bt = 0; bt < B * T; bt++) {
|
||||
const float* row = logits + bt * V;
|
||||
int* out_idx = indices + bt * K;
|
||||
float* out_prob = probs + bt * K;
|
||||
|
||||
// Numerically stable softmax
|
||||
float max_val = *std::max_element(row, row + V);
|
||||
float sum = 0.0f;
|
||||
std::vector<float> exp_vals(V);
|
||||
for (int v = 0; v < V; v++) {
|
||||
exp_vals[v] = expf(row[v] - max_val);
|
||||
sum += exp_vals[v];
|
||||
}
|
||||
float inv_sum = 1.0f / sum;
|
||||
for (int v = 0; v < V; v++) {
|
||||
exp_vals[v] *= inv_sum;
|
||||
}
|
||||
|
||||
// Top-K by sorting (simple but correct)
|
||||
std::vector<int> idx(V);
|
||||
std::iota(idx.begin(), idx.end(), 0);
|
||||
std::partial_sort(idx.begin(), idx.begin() + K, idx.end(),
|
||||
[&](int a, int b) { return exp_vals[a] > exp_vals[b]; });
|
||||
|
||||
for (int k = 0; k < K; k++) {
|
||||
out_idx[k] = idx[k];
|
||||
out_prob[k] = exp_vals[idx[k]];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Verification ----------
|
||||
|
||||
bool verify(const float* ref_probs, const int* ref_idx,
|
||||
const float* gpu_probs, const int* gpu_idx,
|
||||
int B, int T, int K, float tol = 1e-4f) {
|
||||
bool ok = true;
|
||||
int failures = 0;
|
||||
for (int bt = 0; bt < B * T && failures < 10; bt++) {
|
||||
for (int k = 0; k < K; k++) {
|
||||
int ri = ref_idx[bt * K + k];
|
||||
int gi = gpu_idx[bt * K + k];
|
||||
float rp = ref_probs[bt * K + k];
|
||||
float gp = gpu_probs[bt * K + k];
|
||||
|
||||
// Index must match (probabilities might have ties, but for random data they won't)
|
||||
if (ri != gi) {
|
||||
// Check if probability is close (might be a tie)
|
||||
if (fabsf(rp - gp) > tol) {
|
||||
printf("FAIL [bt=%d, k=%d]: ref_idx=%d gpu_idx=%d ref_prob=%.8f gpu_prob=%.8f\n",
|
||||
bt, k, ri, gi, rp, gp);
|
||||
ok = false;
|
||||
failures++;
|
||||
}
|
||||
}
|
||||
|
||||
// Probability must match
|
||||
if (fabsf(rp - gp) > tol) {
|
||||
printf("FAIL [bt=%d, k=%d]: idx=%d ref_prob=%.8f gpu_prob=%.8f diff=%.2e\n",
|
||||
bt, k, gi, rp, gp, fabsf(rp - gp));
|
||||
ok = false;
|
||||
failures++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return ok;
|
||||
}
|
||||
|
||||
// ---------- Main ----------
|
||||
|
||||
int main() {
|
||||
constexpr int B = 4;
|
||||
constexpr int T = 8;
|
||||
constexpr int V = 1024; // manageable for CPU verification
|
||||
constexpr int K = 10;
|
||||
constexpr int N = B * T;
|
||||
|
||||
printf("=== Fused Softmax + Top-K Test ===\n");
|
||||
printf("Shape: [B=%d, T=%d, V=%d], K=%d\n\n", B, T, V, K);
|
||||
|
||||
// Allocate and initialize
|
||||
size_t logits_bytes = (size_t)N * V * sizeof(float);
|
||||
size_t idx_bytes = (size_t)N * K * sizeof(int);
|
||||
size_t prob_bytes = (size_t)N * K * sizeof(float);
|
||||
|
||||
std::vector<float> h_logits(N * V);
|
||||
std::vector<int> h_idx_gpu(N * K);
|
||||
std::vector<float> h_prob_gpu(N * K);
|
||||
std::vector<int> h_idx_ref(N * K);
|
||||
std::vector<float> h_prob_ref(N * K);
|
||||
|
||||
// Random logits with large range to stress numerical stability
|
||||
srand(42);
|
||||
for (auto& x : h_logits) {
|
||||
x = ((float)rand() / RAND_MAX - 0.5f) * 40.0f; // range [-20, 20]
|
||||
}
|
||||
|
||||
// GPU allocation
|
||||
float *d_logits, *d_probs;
|
||||
int *d_indices;
|
||||
cudaMalloc(&d_logits, logits_bytes);
|
||||
cudaMalloc(&d_indices, idx_bytes);
|
||||
cudaMalloc(&d_probs, prob_bytes);
|
||||
|
||||
cudaMemcpy(d_logits, h_logits.data(), logits_bytes, cudaMemcpyHostToDevice);
|
||||
|
||||
// Launch kernel
|
||||
cudaEvent_t start, stop;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop);
|
||||
|
||||
printf("Launching fused kernel...\n");
|
||||
cudaEventRecord(start);
|
||||
launch_fused_softmax_topk<K>(d_logits, d_indices, d_probs, B, T, V);
|
||||
cudaEventRecord(stop);
|
||||
cudaEventSynchronize(stop);
|
||||
|
||||
float ms = 0;
|
||||
cudaEventElapsedTime(&ms, start, stop);
|
||||
printf("Kernel time: %.3f ms\n\n", ms);
|
||||
|
||||
// Copy results back
|
||||
cudaMemcpy(h_idx_gpu.data(), d_indices, idx_bytes, cudaMemcpyDeviceToHost);
|
||||
cudaMemcpy(h_prob_gpu.data(), d_probs, prob_bytes, cudaMemcpyDeviceToHost);
|
||||
|
||||
// CPU reference
|
||||
printf("Running CPU reference...\n");
|
||||
cpu_softmax_topk(h_logits.data(), h_idx_ref.data(), h_prob_ref.data(),
|
||||
B, T, V, K);
|
||||
|
||||
// Verify
|
||||
printf("Verifying...\n");
|
||||
bool pass = verify(h_prob_ref.data(), h_idx_ref.data(),
|
||||
h_prob_gpu.data(), h_idx_gpu.data(),
|
||||
B, T, K);
|
||||
|
||||
printf("\n%s\n", pass ? "✓ ALL TESTS PASSED" : "✗ TESTS FAILED");
|
||||
|
||||
// Print a sample row
|
||||
int row = 0;
|
||||
printf("\nSample output (row %d):\n", row);
|
||||
printf(" %-6s %-12s %-12s %-12s\n", "k", "Index", "GPU Prob", "Ref Prob");
|
||||
printf(" %-6s %-12s %-12s %-12s\n", "---", "-----", "--------", "--------");
|
||||
for (int k = 0; k < K; k++) {
|
||||
printf(" %-6d %-12d %-12.8f %-12.8f\n", k,
|
||||
h_idx_gpu[row * K + k],
|
||||
h_prob_gpu[row * K + k],
|
||||
h_prob_ref[row * K + k]);
|
||||
}
|
||||
|
||||
// Check probability sums
|
||||
float sum_gpu = 0, sum_ref = 0;
|
||||
for (int k = 0; k < K; k++) {
|
||||
sum_gpu += h_prob_gpu[row * K + k];
|
||||
sum_ref += h_prob_ref[row * K + k];
|
||||
}
|
||||
printf("\n Sum of top-%d probs: GPU=%.8f Ref=%.8f\n", K, sum_gpu, sum_ref);
|
||||
printf(" (Note: sum < 1.0 because K << V; these should match)\n");
|
||||
|
||||
// Bandwidth estimate
|
||||
size_t total_read = logits_bytes;
|
||||
size_t total_write = idx_bytes + prob_bytes;
|
||||
double bw = (total_read + total_write) / (ms * 1e-3) / 1e9;
|
||||
printf("\nEstimated effective bandwidth: %.1f GB/s\n", bw);
|
||||
printf(" Reads: %zu bytes (%.1f KB)\n", total_read, total_read / 1024.0);
|
||||
printf(" Writes: %zu bytes (%.1f KB)\n", total_write, total_write / 1024.0);
|
||||
|
||||
// Cleanup
|
||||
cudaFree(d_logits);
|
||||
cudaFree(d_indices);
|
||||
cudaFree(d_probs);
|
||||
cudaEventDestroy(start);
|
||||
cudaEventDestroy(stop);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}