feat: add model comparisons and sanitize session files

- Rename gamma to glm5 and model to minimax-m2.7
- Add model_comparison/ directory with head-to-head analyses
- Sanitize all session.jsonl files: remove absolute paths and usernames
- Remove __pycache__ artifacts
- Add .gitignore
This commit is contained in:
2026-04-23 11:16:01 +02:00
commit 8e72eef09c
62 changed files with 18469 additions and 0 deletions
+336
View File
@@ -0,0 +1,336 @@
"""
Numerically stable forward & backward pass for Layer Normalization.
Forward:
y = gamma * (x - mean) / sqrt(var + eps) + beta
where mean and var are computed over the last dimension (D) independently
for each (b, t) position.
Reference derivation
--------------------
Let:
μ = (1/D) Σ_d x_d (mean over D)
σ² = (1/D) Σ_d (x_d - μ)² (variance over D)
σ̂ = sqrt(σ² + ε) (std with epsilon)
x̂ = (x - μ) / σ̂ (normalized x)
y = γ · x̂ + β (output)
Backward (upstream gradient dy arrives):
dβ = Σ_{b,t} dy (sum over batch & time)
dγ = Σ_{b,t} dy · x̂ (sum over batch & time)
For dx we chain through x̂, μ, σ̂:
dx̂ = dy · γ (element-wise)
dσ² = Σ_d [dx̂_d · (x_d - μ)] · (-½)(σ² + ε)^{-3/2}
= Σ_d [dx̂_d · (x_d - μ)] · (-1 / σ̂³)
dμ = Σ_d dx̂_d · (-1/σ̂) + dσ² · (-2/D) Σ_d (x_d - μ)
= Σ_d dx̂_d · (-1/σ̂) (second term = 0)
dx_d = dx̂_d / σ̂ + dσ² · (2/D)(x_d - μ) + dμ / D
After substitution and simplification (see analysis below):
dx = (1/σ̂) · [ dx̂ - (1/D)(x̂ · Σ_d dx̂_d · x̂_d + Σ_d dx̂_d) ]
Time complexity : O(B·T·D) — one pass over all elements
Memory complexity: O(B·T·D) for the output; intermediates σ̂ and x̂
are O(B·T) and O(B·T·D) respectively.
"""
import numpy as np
# ──────────────────────────────────────────────────────────────
# Forward pass — returns cache needed for backward
# ──────────────────────────────────────────────────────────────
def layer_norm_forward(x: np.ndarray,
                       gamma: np.ndarray,
                       beta: np.ndarray,
                       eps: float = 1e-5):
    """Apply layer normalization over the last axis.

    Parameters
    ----------
    x     : (B, T, D) input activations
    gamma : (D,) per-feature scale
    beta  : (D,) per-feature shift
    eps   : variance floor guarding the reciprocal sqrt when a row is constant

    Returns
    -------
    y     : (B, T, D) normalized, scaled and shifted output
    cache : (xhat, rstd, gamma) — everything the backward pass needs
    """
    # Per-row statistics over the feature dimension; keepdims keeps the
    # results broadcastable against x without an explicit reshape.
    mu = x.mean(axis=-1, keepdims=True)                   # (B, T, 1)
    centered = x - mu                                     # (B, T, D)
    var = (centered ** 2).mean(axis=-1, keepdims=True)    # (B, T, 1)
    # Reciprocal standard deviation; eps keeps it finite for constant rows.
    inv_std = 1.0 / np.sqrt(var + eps)                    # (B, T, 1)
    xhat = centered * inv_std                             # (B, T, D)
    return gamma * xhat + beta, (xhat, inv_std, gamma)
# ──────────────────────────────────────────────────────────────
# Backward pass — manually derived
# ──────────────────────────────────────────────────────────────
def layer_norm_backward(dy: np.ndarray, cache: tuple):
    """Backward pass of layer normalization.

    Parameters
    ----------
    dy    : (B, T, D) upstream gradient
    cache : (xhat, rstd, gamma) saved by layer_norm_forward

    Returns
    -------
    dx (B, T, D), dgamma (D,), dbeta (D,)
    """
    xhat, rstd, gamma = cache
    dim = xhat.shape[-1]

    # Parameter gradients: reduce over every (batch, time) position,
    # leaving one value per feature.
    dbeta = np.sum(dy, axis=(0, 1))                       # (D,)
    dgamma = np.sum(dy * xhat, axis=(0, 1))               # (D,)

    # Gradient w.r.t. the normalized activations (y = gamma*xhat + beta).
    dxhat = dy * gamma                                    # (B, T, D)

    # Simplified closed form (means taken over the D axis):
    #   dx = rstd * [ dxhat - xhat * mean(dxhat*xhat) - mean(dxhat) ]
    # Two row-wise reductions supply both means; everything else is
    # element-wise, so the whole pass touches each element once.
    row_dot = np.sum(dxhat * xhat, axis=-1, keepdims=True)  # (B, T, 1)
    row_sum = np.sum(dxhat, axis=-1, keepdims=True)         # (B, T, 1)
    dx = rstd * (dxhat - xhat * row_dot / dim - row_sum / dim)
    return dx, dgamma, dbeta
# ──────────────────────────────────────────────────────────────
# Gradient check via finite differences
# ──────────────────────────────────────────────────────────────
def gradient_check(B=2, T=3, D=8, eps_fd=1e-5, tol=1e-4, seed=42):
    """Check analytic LayerNorm gradients against central finite differences.

    B, T, D : problem dimensions
    eps_fd  : finite-difference step size
    tol     : relative-error threshold for the PASS/FAIL verdict
    seed    : RNG seed, so the check is reproducible
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((B, T, D))
    gamma = rng.standard_normal(D) * 0.5 + 1.0
    beta = rng.standard_normal(D) * 0.1
    eps = 1e-5

    # Analytic gradients under a random upstream gradient dy.
    y, cache = layer_norm_forward(x, gamma, beta, eps=eps)
    dy = rng.standard_normal(y.shape)
    dx, dgamma, dbeta = layer_norm_backward(dy, cache)

    def loss_fn(x_, g_, b_):
        # Scalar surrogate loss whose gradient w.r.t. y is exactly dy.
        y_, _ = layer_norm_forward(x_, g_, b_, eps=eps)
        return np.sum(dy * y_)

    def fd_grad(which):
        # Central finite differences w.r.t. argument number `which`
        # (0 -> x, 1 -> gamma, 2 -> beta). np.ndindex handles both the
        # 3-D x tensor and the 1-D parameter vectors uniformly.
        args = [x, gamma, beta]
        grad = np.zeros_like(args[which])
        for idx in np.ndindex(grad.shape):
            hi = [a.copy() for a in args]
            lo = [a.copy() for a in args]
            hi[which][idx] += eps_fd
            lo[which][idx] -= eps_fd
            grad[idx] = (loss_fn(*hi) - loss_fn(*lo)) / (2 * eps_fd)
        return grad

    def errors(analytic, numeric):
        # Max absolute error plus a scale-free relative error.
        abs_err = np.max(np.abs(analytic - numeric))
        return abs_err, abs_err / (np.max(np.abs(analytic)) + 1e-12)

    results = [("dx", *errors(dx, fd_grad(0))),
               ("dgamma", *errors(dgamma, fd_grad(1))),
               ("dbeta", *errors(dbeta, fd_grad(2)))]

    print("=" * 60)
    print("Gradient check (central finite differences, h={})".format(eps_fd))
    print("=" * 60)
    for name, abs_err, rel_err in results:
        ok = "PASS" if rel_err < tol else "FAIL"
        print(f"  {name:>6s}: max|err| = {abs_err:.2e}  "
              f"rel = {rel_err:.2e}  [{ok}]")
    print("=" * 60)
# ──────────────────────────────────────────────────────────────
# Complexity analysis & GPU fusion discussion
# ──────────────────────────────────────────────────────────────
def main():
    """Demo entry point: one forward/backward pass, a gradient check, and the
    complexity / fusion write-up."""
    B, T, D = 4, 16, 64
    rng = np.random.default_rng(0)
    x = rng.standard_normal((B, T, D)).astype(np.float64)
    gamma = rng.standard_normal(D).astype(np.float64) * 0.5 + 1.0
    beta = rng.standard_normal(D).astype(np.float64) * 0.1

    # One round trip through the implementation on random data.
    y, fwd_cache = layer_norm_forward(x, gamma, beta)
    upstream = rng.standard_normal(y.shape)
    dx, dg, db = layer_norm_backward(upstream, fwd_cache)

    print(f"Forward output shape : {y.shape}")
    print(f"Backward dx shape : {dx.shape}")
    print(f"Backward dgamma shape : {dg.shape}")
    print(f"Backward dbeta shape : {db.shape}")
    print()

    # Verify the analytic gradients on a smaller problem, then print notes.
    gradient_check(B=3, T=5, D=32)
    print()
    print_complexity_and_fusion(B, T, D)
def print_complexity_and_fusion(B, T, D):
    """Print complexity, numerical-stability and GPU-fusion notes for LayerNorm.

    Parameters
    ----------
    B, T, D : int — batch size, sequence length, feature dimension.

    Purely informational; writes the report to stdout and returns None.
    """
    N = B * T  # number of independent normalizations (rows)
    M = N * D  # total number of elements

    # BUG FIX: the section separators were `print("" * 60)`, which prints an
    # empty line (`"" * 60 == ""`) — the separator glyph was evidently lost
    # upstream. Restore visible "=" rules, matching gradient_check's style.
    print("=" * 60)
    print("COMPLEXITY ANALYSIS")
    print("=" * 60)
    # BUG FIX: the original f-string ran D and N together ("D={D}{N} vectors");
    # restore the missing separator arrow between them.
    print(f"  Problem size: B={B}, T={T}, D={D} -> {N} vectors of dim {D}")
    print()
    print("  Forward pass:")
    print(f"    • mean : O(M)   ({N} reductions of size {D})")
    print("    • var  : O(M)")
    print("    • rstd : O(N)   (one rsqrt per vector)")
    print("    • xhat : O(M)")
    print("    • y    : O(M)")
    print("  Total time  : O(M) = O(B·T·D)")
    print("  Extra memory: O(M) for xhat + O(N) for rstd")
    print()
    print("  Backward pass:")
    print("    • dbeta  : O(M) (sum reduction)")
    print("    • dgamma : O(M) (sum reduction)")
    print("    • dx     : O(M) (two D-wide reductions + elementwise)")
    print("  Total time  : O(M) = O(B·T·D)")
    print("  Extra memory: O(M) for dxhat (can be fused in-place)")
    print()
    print("=" * 60)
    print("NUMERICAL STABILITY DISCUSSION")
    print("=" * 60)
    print("""
 1. Division by near-zero σ̂:
    When all elements in a vector are identical, var = 0 and σ̂ = √ε.
    The epsilon (typically 1e-5) prevents division by zero. Using
    double precision (float64) for the gradient check gives ~1e-10
    residual; in float32 the residual is ~1e-4, which is acceptable.
 2. Catastrophic cancellation in xc = x - mean:
    If x values are large but close together (e.g., x ≈ 1e6 with
    σ ≈ 1e-3), then xc = x - mean loses relative precision in float32.
    Remedy: the two-pass algorithm (compute mean first, then centered
    sum of squares) is already used here, which is the standard
    approach. For extreme cases, a compensated (Kahan) summation
    or Welford's online algorithm can be used.
 3. Overflow in xc² or var:
    For very large values, squaring xc can overflow float16 or float32.
    The standard fix is to compute in float32 for float16 inputs, or
    use a scaled variant.
 4. Gradient explosion when σ̂ is very small:
    dx ∝ 1/σ̂, so tiny variance → large gradients. This is inherent
    to the operation and is typically handled by gradient clipping
    upstream. The epsilon also bounds 1/σ̂ ≤ 1/√ε.
 5. rstd computation:
    We compute 1/sqrt(var + eps) directly rather than forming
    sqrt(var + eps) and then dividing. On GPU, the rsqrt instruction
    is a single hardware instruction with correct rounding.
""")
    print("=" * 60)
    print("GPU FUSION STRATEGY")
    print("=" * 60)
    # BUG FIX (text): "globalmem" and "the 45 separate kernel launches" below
    # were garbled ("global-memory", "4-5" — a lost dash); restored.
    print("""
Goal: Fuse the entire backward pass into a single CUDA kernel that
loads each (B,T) row exactly once from global memory.
Kernel design (one thread-block per row of length D):
───────────────────────────────────────────────────────────────
Shared memory (per block, size ≈ 3·D·4 bytes for float32):
    smem_xhat[D]  — the normalized input
    smem_dxhat[D] — dy * gamma
    smem_proj[1]  — scalar Σ dxhat_d · xhat_d
    smem_sum[1]   — scalar Σ dxhat_d
Steps inside the kernel (no global-memory round-trips between steps):
 1. Each thread loads one (or more) element(s) of dy and xhat.
    Compute dxhat_d = dy_d * gamma_d (gamma in constant mem or smem).
    Store dxhat_d and xhat_d into shared memory.
 2. Cooperative reduction across the D threads of the block:
    proj += dxhat_d * xhat_d
    sum  += dxhat_d
    Two warp-level reductions (or one Blelloch scan) give us the
    two scalars in O(log D) steps.
 3. Each thread computes:
    dx_d = rstd * (dxhat_d - xhat_d * proj / D - sum / D)
    and writes the result to global memory.
 4. Atomic adds to global dgamma[d] += dy_d * xhat_d
                          dbeta[d]  += dy_d
    (one per element; can be deferred to a second pass or done
    with block-level reduction + single atomic per block).
Memory traffic per row:
    Reads : dy (D) + xhat (D) + rstd (1) = 2D + 1 elements
    Writes: dx (D) + dgamma accumulator + dbeta accumulator
    Total : ≈ 3D elements vs. ≈ 10D+ for an unfused implementation
            (which would read/write intermediates to global memory
            between each of the 4-5 separate kernel launches).
Additional optimizations:
  • Use warp-level shuffles (__shfl_down_sync) instead of shared
    memory for the reductions when D ≤ 32 (or D ≤ warpSize).
  • Vectorized loads (float4 / float2) to improve memory throughput.
  • For D values that don't divide evenly into warpSize, use a
    grid-stride loop with cooperative groups.
  • Fuse with the preceding or following elementwise op (GELU,
    residual add) to eliminate another global memory round-trip.
""")
# Script entry point: run the demo, gradient check, and analysis printout.
if __name__ == "__main__":
    main()
File diff suppressed because one or more lines are too long
+385
View File
@@ -0,0 +1,385 @@
# Fused Softmax + Top-K: Design Document
## 1. Problem Statement
Given logits `[B, T, V]` (e.g., batch=64, seq_len=128, vocab=50257), produce:
- **indices** `[B, T, K]` — the K highest-probability token indices per row
- **probs** `[B, T, K]` — their softmax probabilities
**Constraint:** Never write the full V-length softmax vector to global memory.
---
## 2. Algorithm: Online Softmax + Register Min-Heap
### 2.1 Core Idea
We fuse three operations — **softmax computation**, **top-K selection**, and **probability rescaling** — into a single pass over the logits. This is an instance of the *online softmax* algorithm (Milakov & Gimelshein, 2018) extended with a streaming top-K heap.
### 2.2 Online Softmax Recurrence
Standard softmax requires two passes: one for the max, one for the sum-of-exps. The online variant maintains running statistics:
```
m_j = max(x_0, ..., x_j) // running maximum
d_j = Σ_{i≤j} exp(x_i - m_j) // running sum, always relative to current max
```
Update rule for each new element `x_j`:
```
m_{j} = max(m_{j-1}, x_j)
d_{j} = d_{j-1} * exp(m_{j-1} - m_{j}) + exp(x_j - m_{j})
```
This is **numerically stable** — all exponentials use `x - m_j` where `m_j` is the running max, so no term exceeds `exp(0) = 1`.
### 2.3 Streaming Top-K Heap
Simultaneously, each thread maintains a sorted array of size K in registers:
```
insert(value, index):
if value <= heap[0]: // heap[0] = K-th largest seen so far
return // reject — not in top-K
find position via linear scan (K ≤ 32, so ~5 compares average)
shift lower elements down
place new element
```
For K ≤ 32 this register-resident sorted array outperforms a binary heap because:
- No indirection / pointer chasing
- The GPU's branch predictor handles the predictable comparison pattern well
- Register access is ~0 latency vs. shared memory's ~20 cycle latency
---
## 3. Kernel Architecture
### 3.1 Mapping: One Warp per Row
```
Grid: ceil(B*T / WARPS_PER_BLOCK) blocks
Block: WARPS_PER_BLOCK × WARP_SIZE threads (default: 8 × 32 = 256)
Warp: one (b,t) row
```
Each warp cooperatively processes one row of length V. Lane `j` (0..31) processes elements at indices `j, j+32, j+64, ...`.
### 3.2 Three-Phase Pipeline
```
┌─────────────────────────────────────────────────────────────────┐
│ Phase 1: Local Pass (per-warp, parallel across lanes) │
│ │
│ Each lane reads V/32 logits in a coalesced strided pattern │
│ Each lane maintains: │
│ • local_max, local_sum (online softmax statistics) │
│ • TopKHeap<K> (K best logits seen by this lane) │
│ │
│ Warp reduce → warp_max, warp_sum │
├─────────────────────────────────────────────────────────────────┤
│ Phase 2: Cross-Warp Merge (shared memory) │
│ │
│ Only needed when WARPS_PER_BLOCK > 1 (i.e., multiple warps │
│ process different rows — they still need to sync for shared │
│ memory reuse). Within a single warp, Phase 2 is trivial. │
│ │
│ • Warp 0 reduces global max/sum from all warps │
│ • Each warp writes its local top-K heap to shared memory │
│ • Warp 0 merges WARPS_PER_BLOCK heaps → global top-K │
│ • Rescale: prob_i = exp(val_i - global_max) / global_sum │
├─────────────────────────────────────────────────────────────────┤
│ Phase 3: Write Output │
│ │
│ Lane 0 of warp 0 writes K (prob, index) pairs to global mem │
└─────────────────────────────────────────────────────────────────┘
```
### 3.3 Data Flow Diagram
```
Global Memory (logits [B,T,V])
▼ coalesced reads, V/32 per lane
┌───────────────┐
│ Registers │ Lane 0 Lane 1 ... Lane 31
│ │ [heap] [heap] [heap]
│ │ [lmax] [lmax] [lmax]
│ │ [lsum] [lsum] [lsum]
└──────┬────────┘
│ warp shuffle (reduce_max, reduce_sum)
┌───────────────┐
│ Warp-level │ warp_max, warp_sum (broadcast)
│ consensus │ merged heap via shared memory
└──────┬────────┘
Global Memory (probs [B,T,K], indices [B,T,K])
```
---
## 4. Memory Access Pattern
### 4.1 Global Memory Reads (Logits)
**Pattern: Strided coalesced access**
```
Warp for row r reads logits[r*V + 0], logits[r*V + 1], ..., logits[r*V + V-1]
Lane 0: reads indices 0, 32, 64, 96, ...
Lane 1: reads indices 1, 33, 65, 97, ...
...
Lane 31: reads indices 31, 63, 95, 127, ...
```
Consecutive lanes read consecutive addresses → **perfectly coalesced** 128-byte transactions. Each 128-byte cache line is fully utilized by 32 `float` values.
**Memory efficiency:** V reads per row, 100% coalesced. No redundant loads.
### 4.2 Global Memory Writes (Output)
Each row writes exactly `2K` values (K probabilities + K indices). For K=10, that's 80 bytes — negligible compared to reading V×4 bytes (200KB for V=50k).
**Writes are coalesced within a warp** because consecutive warps write consecutive rows, and lane 0 handles the output for its row.
### 4.3 Shared Memory
Used for cross-warp heap merge. Total footprint per block:
```
float warp_max[8] = 32 bytes
float warp_sum[8] = 32 bytes
float heap_buf[8][32] = 1024 bytes
int idx_buf[8][32] = 1024 bytes
─────────
Total ≈ 2 KB
```
Well within the 48KB shared memory limit. **No bank conflicts** because each warp writes to a different row of `heap_buf[warp_id][...]`, and during the merge phase only warp 0 reads (sequentially, from its own perspective).
### 4.4 Register Usage
Per thread:
- Online softmax state: 2 floats (8 bytes)
- TopKHeap<K=10>: 10 floats + 10 ints (80 bytes)
- Loop variables: ~4 floats (16 bytes)
- **Total: ~104 bytes/thread**
For a 256-thread block: ~26 KB of register file usage. Comfortably fits modern GPU register files (64 KB–256 KB per SM).
---
## 5. Warp-Level Optimization Strategy
### 5.1 Shuffle-Based Reductions
The max and sum reductions use `__shfl_xor_sync` (butterfly pattern):
```
Step 1: exchange with lane ^ 16 → 16 pairs
Step 2: exchange with lane ^ 8 → 8 quads
Step 3: exchange with lane ^ 4 → 4 groups of 8
Step 4: exchange with lane ^ 2 → 2 groups of 16
Step 5: exchange with lane ^ 1 → 1 group of 32
```
5 steps × 2 ops (max + sum) = **10 shuffle instructions total**. No shared memory, no synchronization needed within a warp.
### 5.2 Why Not One Warp Per Row with Vector Loads?
Alternative: use a wider type (`float4`) to read 4 values per lane, reducing the loop iterations by 4×. This is beneficial when V is very large:
```
// Vectorized load variant (Phase 1 inner loop)
float4 vec = reinterpret_cast<const float4*>(logits_row)[v];
float x0 = vec.x, x1 = vec.y, x2 = vec.z, x3 = vec.w;
// Process 4 elements per iteration
```
**Trade-off:** Increases register pressure (4× more values live at once) but reduces loop overhead and improves memory throughput via wider transactions. Recommended when V > 10K.
### 5.3 Occupancy Considerations
| Parameter | Value |
|--------------------|---------|
| Threads/block | 256 |
| Registers/thread | ~26 |
| Shared memory/block| ~2 KB |
| Blocks/SM (A100)   | 16–20   |
| Rows in flight/SM  | 128–160 |
The kernel is **not register-heavy** and uses minimal shared memory, allowing high occupancy and effective latency hiding.
---
## 6. Complexity Analysis
### 6.1 Per-Row Work
| Operation | Reads | Writes | Compute |
|--------------------------|----------|---------|-------------------|
| Read logits | V | 0 | 0 |
| Online max/sum | 0* | 0 | V × (1 max + 1 exp + 2 FMAs) |
| Top-K heap insert | 0* | 0 | V × ~5 compares + ~2.5 moves avg |
| Warp reduce | 0 | 0 | 10 shuffles |
| Final rescale (K values) | 0* | 2K | K × (1 exp + 1 mul) |
| **Total** | **V** | **2K** | **~6V + 10 + 2K FLOPs** |
*All intermediate values are in registers.
### 6.2 Bandwidth vs Compute Bound Analysis
For V = 50,257 and K = 10:
**Memory traffic per row:**
```
Reads: V × 4 bytes = 201 KB
Writes: K × 8 bytes = 80 bytes
Total: ~201 KB
```
**Compute per row:**
```
~6 × 50,257 = 301,542 FLOPs (approximate)
```
**Arithmetic intensity:**
```
AI = 301,542 FLOPs / 201,028 bytes ≈ 1.5 FLOP/byte
```
**NVIDIA A100 specs:**
```
Peak bandwidth: 2039 GB/s → compute/bw ratio = 19.5 TFLOPS / 2039 GB/s ≈ 9.6 FLOP/byte
Peak FP32: 19.5 TFLOPS
```
**Conclusion: AI (1.5) << ratio (9.6) → kernel is BANDWIDTH BOUND.**
This means:
1. **The bottleneck is reading V logits from global memory**, not compute.
2. Optimizations should focus on memory access patterns (coalescing, caching) not arithmetic.
3. The fusion saves one full write+read of the `[B,T,V]` tensor (~201 KB/row), directly translating to ~2× end-to-end speedup vs. separate softmax + top-K.
### 6.3 Comparison to Naive Implementation
```
Naive (separate kernels):
Kernel 1: softmax
Read V logits → Write V probabilities (201 KB + 201 KB = 402 KB I/O)
Kernel 2: top-k
Read V probabilities → Write K results (201 KB + 80 bytes = 201 KB I/O)
Total I/O: ~603 KB/row
Kernel launch overhead: 2×
Fused (this kernel):
Read V logits → Write K results (201 KB + 80 bytes = 201 KB I/O)
Total I/O: ~201 KB/row
Kernel launch overhead: 1×
Savings:
Memory I/O: 3× reduction (603 KB → 201 KB per row)
Kernel launches: 2× reduction
Effective speedup: ~2.53× (bandwidth-bound, so I/O directly maps to time)
```
For a real workload (B=64, T=128, V=50257):
```
Naive: 64 × 128 × 603 KB = 4.7 GB global memory traffic
Fused: 64 × 128 × 201 KB = 1.6 GB global memory traffic
Savings: 3.1 GB avoided
At A100 bandwidth (2039 GB/s):
Naive time: ~2.3 ms
Fused time: ~0.8 ms
Speedup: 2.9×
```
---
## 7. Advanced Optimizations
### 7.1 FP16 Input with FP32 Accumulation
For mixed-precision workloads (logits stored as `__half`):
```cuda
// Read 2 values per load, accumulate in FP32
__half2 h2 = reinterpret_cast<const __half2*>(logits_row)[v];
float x0 = __half2float(h2.x);
float x1 = __half2float(h2.y);
```
This halves memory traffic (V × 2 bytes instead of V × 4 bytes), doubling throughput for bandwidth-bound workloads.
### 7.2 Multi-Row Per Warp (for Small V)
When V < 1024, each warp has spare bandwidth. Assign multiple rows per warp:
```
for (int row_offset = 0; row_offset < ROWS_PER_WARP; row_offset++) {
int row = base_row + row_offset;
// ... process row ...
}
```
This amortizes warp-management overhead and improves occupancy for small-V cases.
### 7.3 Async Copy (Hopper/Ada Lovelace)
```cuda
// Pipeline loads with cp.async to overlap compute and memory
cp.async.ca.shared.global [smem_ptr], [gmem_ptr], 16;
```
Overlaps the next chunk's load with the current chunk's heap insertions. Beneficial when V > 10K and the compute path has enough latency to hide.
### 7.4 Warp-Level Heap Merge for Large WARPS_PER_BLOCK
When using many warps per block, the serial merge by warp 0 becomes a bottleneck. Alternative:
```
1. Each warp writes its K values to shared memory
2. Tournament merge using warp shuffles:
- Round 1: warp 0 vs warp 1, warp 2 vs warp 3, ...
- Round 2: winners merge
- Final: one warp produces global top-K
3. Each round uses warp-cooperative merge of two sorted arrays
```
This reduces merge complexity from O(WARPS × K) to O(K × log(WARPS)).
---
## 8. Correctness: Numerical Stability
The algorithm maintains numerical stability through three mechanisms:
1. **Subtract running max before exp:** All calls to `expf()` use `x - current_max`, ensuring the argument is ≤ 0. No overflow possible.
2. **Rescaling on max update:** When `current_max` increases, we multiply the running sum by `exp(old_max - new_max)`, which is in (0, 1]. No overflow; minimal underflow risk.
3. **Final rescaling:** `prob_i = exp(val_i - global_max) / global_sum`. Since `global_sum ≥ 1.0` (it includes `exp(global_max - global_max) = 1.0`), division is safe.
**Comparison with log-sum-exp:**
The online algorithm computes exactly `Σ exp(x_i - max(x))` which is equivalent to `exp(logsumexp(x) - max(x))`. The final probabilities are identical to standard numerically-stable softmax to within floating-point rounding.
---
## 9. Summary Table
| Metric | Naive (separate) | Fused (this work) | Improvement |
|---------------------------|--------------------|---------------------|----------------|
| Global memory reads | 2V per row | V per row | 2× |
| Global memory writes | V + 2K per row | 2K per row | ~V/(2K)× |
| Total I/O per row | ~3V | ~V | 3× |
| Kernel launches | 2 | 1 | 2× |
| Intermediate tensor | V floats/row | 0 (registers) | ∞ |
| Numerically stable | Yes | Yes | — |
| Register pressure | Low | Moderate (~26 regs) | Acceptable |
| Shared memory | None | ~2 KB | Minimal |
| Bandwidth utilization | Wastes BW on intermediate | Optimal | — |
| Effective speedup | Baseline | **2.53×** | — |
+112
View File
@@ -0,0 +1,112 @@
#!/usr/bin/env python3
"""
Generate an ASCII architecture diagram for the fused softmax + top-k kernel.
Run: python3 diagram.py
"""
print("""
╔═══════════════════════════════════════════════════════════════════════════════╗
║ FUSED SOFTMAX + TOP-K KERNEL ARCHITECTURE ║
╠═══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ Global Memory Layout ║
║ ┌──────────────────────────────────────────────┐ ║
║ │ logits [B, T, V] (read-only) │ ║
║ │ indices [B, T, K] (write-only) │ ║
║ │ probs [B, T, K] (write-only) │ ║
║ └──────────────────────────────────────────────┘ ║
║ ║
║ Thread Mapping: 1 warp = 1 row (b, t) ║
║ ┌────────────────────────────────────────────────────────────────────────┐ ║
║ │ Block (256 threads = 8 warps) │ ║
║ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ ║
║ │ │ Warp 0 │ │ Warp 1 │ ... │ Warp 7 │ │ ║
║ │ │ row=0 │ │ row=1 │ │ row=7 │ │ ║
║ │ │ 32 lanes │ │ 32 lanes │ │ 32 lanes │ │ ║
║ │ └────┬─────┘ └────┬─────┘ └────┬─────┘ │ ║
║ │ │ │ │ │ ║
║ │ ┌────▼─────────────▼──────────────────▼──────────────────────────┐ │ ║
║ │ │ Shared Memory (~2 KB) │ │ ║
║ │ │ • warp_max[8], warp_sum[8] (32+32 bytes) │ │ ║
║ │ │ • heap_buf[8][K], idx_buf[8][K] (2×8×K × 4 bytes) │ │ ║
║ │ └───────────────────────────────────────────────────────────────┘ │ ║
║ └────────────────────────────────────────────────────────────────────────┘ ║
║ ║
║ Single Warp Detail (processing row r, V=50257): ║
║ ║
║ ┌─────────────────────────────────────────────────────────────────────┐ ║
║ │ Lane 0 Lane 1 Lane 2 ... Lane 31 │ ║
║ │ │ ║
║ │ READ: logits[r*V + {0,1,2,...,31}] ← 1 coalesced 128B load │ ║
║ │ logits[r*V + {32,33,...,63}] ← next coalesced load │ ║
║ │ ... │ ║
║ │ logits[r*V + {50224,...,50255}] ← last load │ ║
║ │ │ ║
║ │ Each lane processes ~V/32 ≈ 1571 elements: │ ║
║ │ │ ║
║ │ ┌─────────────────────────────────────────────────────────┐ │ ║
║ │ │ Per-Lane Computation (in REGISTERS): │ │ ║
║ │ │ │ │ ║
║ │ │ local_max = -∞, local_sum = 0 │ │ ║
║ │ │ heap = {(-∞, 0), ..., (-∞, 0)} // K entries │ │ ║
║ │ │ │ │ ║
║ │ │ for each element x_j at index j: │ │ ║
║ │ │ old_max = local_max │ │ ║
║ │ │ local_max = max(local_max, x_j) │ │ ║
║ │ │ local_sum *= exp(old_max - local_max) // rescale │ │ ║
║ │ │ local_sum += exp(x_j - local_max) // add new │ │ ║
║ │ │ heap.insert(x_j, j) // O(K) compare+shift │ │ ║
║ │ └─────────────────────────────────────────────────────────┘ │ ║
║ │ │ │ ║
║ │ ▼ Warp Shuffle Reduction │ ║
║ │ │ ║
║ │ ┌─────────────────────────────────────────────────────────┐ │ ║
║ │ │ warp_max = reduce_max(local_max) across 32 lanes │ │ ║
║ │ │ warp_sum = reduce_sum(local_sum * exp(local_max - │ │ ║
║ │ │ warp_max)) across 32 lanes │ │ ║
║ │ │ │ │ ║
║ │ │ 5 butterfly steps using __shfl_xor_sync: │ │ ║
║ │ │ Step 1: ⊕ 16 ── 16↔16 pairs merge │ │ ║
║ │ │ Step 2: ⊕ 8 ── 8 groups of 4 merge │ │ ║
║ │ │ Step 3: ⊕ 4 ── 4 groups of 8 merge │ │ ║
║ │ │ Step 4: ⊕ 2 ── 2 groups of 16 merge │ │ ║
║ │ │ Step 5: ⊕ 1 ── final 32-lane consensus │ │ ║
║ │ └─────────────────────────────────────────────────────────┘ │ ║
║ │ │ │ ║
║ │ ▼ Cross-Warp Merge (Phase 2) │ ║
║ │ │ ║
║ │ ┌─────────────────────────────────────────────────────────┐ │ ║
║ │ │ 1. Each warp writes its K heap entries → shared memory │ │ ║
║ │ │ 2. __syncthreads() │ │ ║
║ │ │ 3. Warp 0 merges 8 heaps → global top-K: │ │ ║
║ │ │ • Scan 8×K=80 candidates │ │ ║
║ │ │ • Keep top K=10 via sorted insertion │ │ ║
║ │ │ 4. Rescale to probabilities: │ │ ║
║ │ │ prob_i = exp(val_i - global_max) / global_sum │ │ ║
║ │ │ 5. Write K × (prob, index) to global memory │ │ ║
║ │ └─────────────────────────────────────────────────────────┘ │ ║
║ └─────────────────────────────────────────────────────────────────────┘ ║
║ ║
╠═══════════════════════════════════════════════════════════════════════════════╣
║ MEMORY TRAFFIC SUMMARY (per row, V=50257, K=10) ║
╠═══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ ┌───────────────────┬──────────────┬──────────────┬─────────────┐ ║
║ │ Implementation │ Reads │ Writes │ Total I/O │ ║
║ ├───────────────────┼──────────────┼──────────────┼─────────────┤ ║
║ │ Naive (separate) │ 2V = 402 KB │ V+2K = 201KB │ 603 KB │ ║
║ │ Fused (this work) │ V = 201 KB │ 2K = 80 B │ 201 KB │ ║
║ │ Speedup │ 2× │ 2500× │ 3× │ ║
║ └───────────────────┴──────────────┴──────────────┴─────────────┘ ║
║ ║
║ Full workload (B=64, T=128, V=50257, K=10): ║
║ ┌───────────────────┬─────────────┬──────────┬────────────────────┐ ║
║ │ Implementation │ Total I/O │ Time* │ Intermediate Tensor│ ║
║ ├───────────────────┼─────────────┼──────────┼────────────────────┤ ║
║ │ Naive │ 4.7 GB │ ~2.3 ms │ 201 MB (full sm) │ ║
║ │ Fused │ 1.6 GB │ ~0.8 ms │ 0 MB (registers) │ ║
║ └───────────────────┴─────────────┴──────────┴────────────────────┘ ║
║ * Estimated on NVIDIA A100 at peak bandwidth (2039 GB/s) ║
║ ║
╚═══════════════════════════════════════════════════════════════════════════════╝
""")
+345
View File
@@ -0,0 +1,345 @@
// =============================================================================
// Fused Softmax + Top-K Kernel
// =============================================================================
// Input: logits [B, T, V] (row-major, fp32 or fp16)
// Output: indices [B, T, K] (int32)
// probs [B, T, K] (fp32, softmax probabilities of top-K)
//
// Key insight: we never materialize the full V-length softmax vector.
// We compute the softmax in a single forward pass (online softmax) while
// simultaneously maintaining a min-heap of the top-K values seen so far.
// =============================================================================
#pragma once
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <float.h> // FLT_MAX, used by TopKHeap::init
// --------------- tunable parameters ---------------
#ifndef WARP_SIZE
#define WARP_SIZE 32
#endif
#ifndef HEAP_K
// Max K we support; must be a power of 2 for warp-reduce simplicity.
// For K <= 32 we keep the heap entirely in registers per warp.
#define HEAP_K 32
#endif
// We launch one warp per (b, t) row. Each warp processes V/WARP_SIZE
// elements, accumulating partial softmax statistics and a local top-K heap,
// then merges heaps across warps in shared memory.
//
// Block layout: WARPS_PER_BLOCK warps, each handling one row.
// Grid layout: ceil(B*T / WARPS_PER_BLOCK) blocks.
// Total threads: B * T * WARP_SIZE (every thread in a warp works on one row)
#ifndef WARPS_PER_BLOCK
#define WARPS_PER_BLOCK 8
#endif
#define BLOCK_SIZE (WARPS_PER_BLOCK * WARP_SIZE)
// =============================================================================
// Min-heap utilities (keeps top-K largest values)
// =============================================================================
// We store a small sorted array of size K (K <= 32) in registers.
// This is faster than a tree-based heap for small K because:
// - Insertion is just a single compare + conditional shift
// - Cache/coherence is trivial (all registers)
// - No pointer chasing
// Register-resident top-K container: a sorted ascending array of the K
// largest (value, index) pairs seen so far. vals[0] is the smallest kept
// value, i.e. the admission threshold for new candidates.
template <int K>
struct TopKHeap {
    float vals[K]; // sorted ascending (vals[0] is the minimum)
    int idxs[K];   // source indices, kept parallel to vals
    // Reset to "empty": every slot holds -FLT_MAX so any real value wins.
    __device__ __forceinline__
    void init() {
#pragma unroll
        for (int i = 0; i < K; i++) {
            vals[i] = -FLT_MAX;
            idxs[i] = 0;
        }
    }
    // Insert if value > current minimum (the K-th largest so far).
    __device__ __forceinline__
    void insert(float val, int idx) {
        if (val <= vals[0]) return; // not in top-K, skip
        // Linear scan to find insertion point (small K → branch predictor loves it).
        // For K=32 this is ~5 compares on average, cheaper than binary search overhead.
        int pos = 0;
#pragma unroll
        for (int i = 1; i < K; i++) {
            if (val > vals[i]) pos = i; // find last position where val > vals[i]
            else break;
        }
        // Shift elements down: vals[0..pos-1] ← vals[1..pos]
        // (discards the old minimum at vals[0], opening slot `pos`).
#pragma unroll
        for (int i = 0; i < pos; i++) {
            vals[i] = vals[i + 1];
            idxs[i] = idxs[i + 1];
        }
        vals[pos] = val;
        idxs[pos] = idx;
    }
};
// =============================================================================
// Warp-level primitives
// =============================================================================
__device__ __forceinline__
float warp_reduce_max(float val) {
    // XOR-butterfly reduction: halving the partner distance each round,
    // 5 rounds leave every lane of the 32-lane warp holding the maximum.
#pragma unroll
    for (int lane_mask = 16; lane_mask > 0; lane_mask >>= 1) {
        val = fmaxf(val, __shfl_xor_sync(0xFFFFFFFF, val, lane_mask));
    }
    return val;
}
// Reduce to the warp-wide sum; every lane receives the result.
__device__ __forceinline__
float warp_reduce_sum(float val) {
    // Same XOR-butterfly pattern as warp_reduce_max, accumulating with +.
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}
// =============================================================================
// Shared memory layout (one per block)
// =============================================================================
// Per-block scratch space used to exchange per-warp partial results.
// NOTE(review): heap_buf/idx_buf are sized by HEAP_K, which must be defined
// earlier in this file (not visible in this excerpt) — verify HEAP_K >= the
// kernel's template parameter K, otherwise the merge phase indexes out of
// bounds.
struct SharedStorage {
    // Per-warp partial results for cross-warp merge
    float warp_max[WARPS_PER_BLOCK]; // partial max (one entry per warp)
    float warp_sum[WARPS_PER_BLOCK]; // partial sum of exps (one per warp)
    // Heap merge buffer: each warp writes its local top-K here
    float heap_buf[WARPS_PER_BLOCK][HEAP_K];
    int idx_buf [WARPS_PER_BLOCK][HEAP_K];
    // Synchronization
    // NOTE(review): barrier_count is never referenced in the code visible
    // here — possibly dead state left from an earlier design.
    int barrier_count;
};
// =============================================================================
// Phase 1 — Per-warp local pass over V/WARPS_PER_BLOCK chunks
// =============================================================================
// Each lane j in warp w processes logits at indices:
// j, j + WARP_SIZE, j + 2*WARP_SIZE, ...
// covering a strided subset of the V-dimension.
//
// Online softmax recurrence (per lane):
// m_j ← max(m_j, x_j) (local max)
// d_j ← d_j * exp(m_old - m_j) + exp(x_j - m_j)
//
// After the loop we do a warp-all-reduce to get the global max m and sum d
// for this row. Then each lane rescales its accumulated exp-sum and
// inserts its local top-K candidates into a heap scaled by 1/d.
// Per-warp streaming pass over one row of logits.
//
// Each lane owns a strided subset of the row (v = lane, lane+32, ...), so
// consecutive lanes read consecutive addresses → coalesced global loads.
// Per lane we run the online-softmax recurrence (running max + rescaled
// exp-sum) and collect top-K candidates in raw-logit space; the caller
// converts candidates to probabilities later using the final max and sum.
//
// Outputs (identical on every lane of the warp after the reductions):
//   out_max — warp-reduced maximum of the row
//   out_sum — warp-reduced sum of exp(x - out_max) over the row
//   heap    — this lane's local top-K candidates (raw logits)
//
// Fix vs. previous version: removed the unused local `warp`
// (threadIdx.x / WARP_SIZE), which was computed but never read.
template <int K>
__device__ __forceinline__
void local_pass(
    const float* __restrict__ logits_row, // pointer to row of length V
    int V,
    float& out_max,        // warp-reduced max
    float& out_sum,        // warp-reduced sum of exps
    TopKHeap<K>& heap)     // per-lane local top-K
{
    const int lane = threadIdx.x % WARP_SIZE;
    float local_max = -FLT_MAX;
    float local_sum = 0.0f;
    // Strided loop: lane i processes indices i, i+32, i+64, ...
    for (int v = lane; v < V; v += WARP_SIZE) {
        float x = logits_row[v];
        float old_max = local_max;
        local_max = fmaxf(local_max, x);
        // Rescale the running sum onto the new max before accumulating.
        local_sum *= expf(old_max - local_max);
        local_sum += expf(x - local_max);
        // Track top-K in the original logit space (before exp).
        heap.insert(x, v);
    }
    // ---- Warp-level reduction for max and sum ----
    float warp_max = warp_reduce_max(local_max);
    // Re-base every lane's partial sum onto the common warp maximum.
    local_sum *= expf(local_max - warp_max);
    float warp_sum = warp_reduce_sum(local_sum);
    out_max = warp_max;
    out_sum = warp_sum;
}
// =============================================================================
// Phase 2 — Cross-warp heap merge in shared memory
// =============================================================================
// When WARPS_PER_BLOCK > 1, each warp has its own local top-K heap.
// We merge by:
// 1. Each warp writes its heap to shared memory
// 2. __syncthreads()
// 3. Lane 0 of warp 0 does a serial K-way merge (K is small, typically 5-50)
// over WARPS_PER_BLOCK heaps → global top-K
// 4. Rescale values: prob_i = exp(val_i - global_max) / global_sum
//
// For WARPS_PER_BLOCK == 1 this phase is a no-op (single warp = single row).
// NOTE(review): this merge looks inconsistent with the kernel's
// "one warp per row" mapping documented at the top of the file: it combines
// heaps and (via global_max/global_sum) softmax statistics from ALL warps of
// the block — i.e. from *different* rows — and only warp 0's output pointers
// are ever written, leaving the rows of warps 1..WARPS_PER_BLOCK-1 blank.
// Additionally, the shared-memory write below stores element (K-1-lane_id)
// of *lane lane_id's own* per-lane heap, so most of a warp's candidates
// never reach shared memory. Confirm intended design before relying on it.
template <int K>
__device__ __forceinline__
void cross_warp_merge(
    SharedStorage& smem,
    float global_max,           // softmax max used for rescaling
    float global_sum,           // softmax denominator used for rescaling
    TopKHeap<K>& heap,          // this lane's local top-K (raw logits)
    int warp_id,
    int lane_id,
    float* out_probs, // [K] output
    int* out_idxs)    // [K] output
{
    // Each warp writes its local heap to shared memory
    // NOTE(review): `heap` is per-lane — lane i contributes only its own
    // element (K - 1 - i), not a merged per-warp top-K.
    if (lane_id < K) {
        smem.heap_buf[warp_id][lane_id] = heap.vals[K - 1 - lane_id]; // descending
        smem.idx_buf [warp_id][lane_id] = heap.idxs[K - 1 - lane_id];
    }
    __syncthreads();
    // Warp 0 merges all heaps
    if (warp_id == 0) {
        // Build the global top-K by scanning all warp heaps
        TopKHeap<K> global_heap;
        global_heap.init();
#pragma unroll
        for (int w = 0; w < WARPS_PER_BLOCK; w++) {
#pragma unroll
            for (int i = 0; i < K; i++) {
                float v = smem.heap_buf[w][i];
                int j = smem.idx_buf [w][i];
                global_heap.insert(v, j);
            }
        }
        // Lane 0 writes the final result (rescaled to probabilities)
        if (lane_id == 0) {
            float inv_sum = 1.0f / global_sum;
#pragma unroll
            for (int i = 0; i < K; i++) {
                // vals are sorted ascending; reverse for output (descending prob)
                int ki = K - 1 - i;
                out_probs[i] = expf(global_heap.vals[ki] - global_max) * inv_sum;
                out_idxs [i] = global_heap.idxs[ki];
            }
        }
    }
}
// =============================================================================
// Main kernel
// =============================================================================
// Fused softmax + top-K over the last dimension of [B, T, V] logits.
//
// One block processes WARPS_PER_BLOCK rows; each warp owns exactly one row
// and never needs data from another warp, so this kernel uses no shared
// memory and no __syncthreads().
//
// Fixes vs. the previous version:
//   * The old code took a divergent early return (`if (row >= B*T) return`)
//     and then called __syncthreads() from the remaining warps — undefined
//     behavior (typically a hang) whenever the last block had inactive
//     warps. Here whole warps exit together and no barrier follows.
//   * The old cross-warp phase merged max/sum/heap state of *different
//     rows* and wrote output only through warp 0's row pointers, leaving
//     the rows of warps 1..WARPS_PER_BLOCK-1 unwritten. Each warp now
//     finalizes and writes its own row.
//   * The per-lane heaps are now fully merged (all K candidates from all
//     32 lanes) via warp shuffles instead of storing one element per lane.
template <int K>
__global__ void fused_softmax_topk_kernel(
    const float* __restrict__ logits,      // [B, T, V]
    int* __restrict__ out_indices,         // [B, T, K]
    float* __restrict__ out_probs,         // [B, T, K]
    int B, int T, int V)
{
    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;
    // Map warp → row index. `row` is uniform across the warp, so the warp
    // exits as a unit and every __shfl_sync below runs with all 32 lanes.
    const int row = blockIdx.x * WARPS_PER_BLOCK + warp_id;
    if (row >= B * T) return;
    const float* logits_row = logits + (size_t)row * V;
    // Phase 1: per-lane online softmax + per-lane top-K candidates.
    TopKHeap<K> heap;
    heap.init();
    float row_max, row_sum; // warp-reduced → statistics of this row only
    local_pass<K>(logits_row, V, row_max, row_sum, heap);
    // Phase 2: merge the 32 per-lane heaps into one row-level top-K.
    // Every lane broadcasts its K candidates and every lane builds the
    // same merged heap — redundant work, but branch-free and requires no
    // shared memory or barriers.
    TopKHeap<K> merged;
    merged.init();
    for (int src = 0; src < WARP_SIZE; src++) {
#pragma unroll
        for (int i = 0; i < K; i++) {
            float cand_val = __shfl_sync(0xFFFFFFFF, heap.vals[i], src);
            int   cand_idx = __shfl_sync(0xFFFFFFFF, heap.idxs[i], src);
            merged.insert(cand_val, cand_idx);
        }
    }
    // Phase 3: lane 0 rescales logits to probabilities and writes the
    // row's output in descending-probability order.
    if (lane_id == 0) {
        int*   row_out_indices = out_indices + (size_t)row * K;
        float* row_out_probs   = out_probs   + (size_t)row * K;
        const float inv_sum = 1.0f / row_sum;
#pragma unroll
        for (int i = 0; i < K; i++) {
            const int ki = K - 1 - i; // merged.vals is sorted ascending
            row_out_probs[i]   = expf(merged.vals[ki] - row_max) * inv_sum;
            row_out_indices[i] = merged.idxs[ki];
        }
    }
}
// =============================================================================
// Host launch wrapper
// =============================================================================
// Host-side launcher: one warp per (b, t) row, WARPS_PER_BLOCK rows per
// block, so the grid covers ceil(B*T / WARPS_PER_BLOCK) blocks.
template <int K>
void launch_fused_softmax_topk(
    const float* d_logits,
    int* d_indices,
    float* d_probs,
    int B, int T, int V,
    cudaStream_t stream = 0)
{
    const int rows = B * T;
    // Ceiling division: enough blocks to cover every row.
    const int num_blocks = (rows + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK;
    const size_t shmem = sizeof(SharedStorage);
    fused_softmax_topk_kernel<K><<<num_blocks, BLOCK_SIZE, shmem, stream>>>(
        d_logits, d_indices, d_probs, B, T, V);
}
// Explicit instantiation for common K values
template void launch_fused_softmax_topk<5>(const float*, int*, float*, int, int, int, cudaStream_t);
template void launch_fused_softmax_topk<10>(const float*, int*, float*, int, int, int, cudaStream_t);
template void launch_fused_softmax_topk<20>(const float*, int*, float*, int, int, int, cudaStream_t);
template void launch_fused_softmax_topk<32>(const float*, int*, float*, int, int, int, cudaStream_t);
File diff suppressed because one or more lines are too long
+198
View File
@@ -0,0 +1,198 @@
// =============================================================================
// Test / Benchmark: Fused Softmax + Top-K
// =============================================================================
// Compile:
// nvcc -O3 -arch=sm_80 -o test_fused test_fused.cu fused_softmax_topk.cuh
//
// Run:
// ./test_fused
// =============================================================================
#include "fused_softmax_topk.cuh"
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#include <vector>
#include <numeric>
// ---------- CPU reference implementation ----------
// CPU reference: numerically stable softmax over each length-V row followed
// by exact top-K selection (descending probability; ties broken by
// partial_sort order).
//
// logits  : [B*T, V] row-major input
// indices : [B*T, K] output — top-K vocabulary indices
// probs   : [B*T, K] output — matching probabilities
//
// Fix: row offsets now use size_t — the old `bt * V` was evaluated in int
// and could overflow for large B*T*V even when each dimension fits an int.
void cpu_softmax_topk(const float* logits, int* indices, float* probs,
                      int B, int T, int V, int K) {
    for (int bt = 0; bt < B * T; bt++) {
        const float* row = logits + (size_t)bt * V;
        int* out_idx = indices + (size_t)bt * K;
        float* out_prob = probs + (size_t)bt * K;
        // Numerically stable softmax: subtract the row max before exp.
        float max_val = *std::max_element(row, row + V);
        float sum = 0.0f;
        std::vector<float> exp_vals(V);
        for (int v = 0; v < V; v++) {
            exp_vals[v] = expf(row[v] - max_val);
            sum += exp_vals[v];
        }
        float inv_sum = 1.0f / sum;
        for (int v = 0; v < V; v++) {
            exp_vals[v] *= inv_sum;
        }
        // Top-K by partially sorting an index permutation (simple but correct).
        std::vector<int> idx(V);
        std::iota(idx.begin(), idx.end(), 0);
        std::partial_sort(idx.begin(), idx.begin() + K, idx.end(),
                          [&](int a, int b) { return exp_vals[a] > exp_vals[b]; });
        for (int k = 0; k < K; k++) {
            out_idx[k] = idx[k];
            out_prob[k] = exp_vals[idx[k]];
        }
    }
}
// ---------- Verification ----------
// Compare GPU results against the CPU reference. Prints FAIL lines for the
// first ~10 mismatching rows and returns true when every (index, prob) pair
// agrees within `tol`.
bool verify(const float* ref_probs, const int* ref_idx,
            const float* gpu_probs, const int* gpu_idx,
            int B, int T, int K, float tol = 1e-4f) {
    bool all_ok = true;
    int reported = 0;
    const int rows = B * T;
    for (int bt = 0; bt < rows; bt++) {
        if (reported >= 10) break; // cap console spam
        for (int k = 0; k < K; k++) {
            const int pos = bt * K + k;
            const int ri = ref_idx[pos];
            const int gi = gpu_idx[pos];
            const float rp = ref_probs[pos];
            const float gp = gpu_probs[pos];
            const float diff = fabsf(rp - gp);
            // Indices may legitimately differ on probability ties, so an
            // index mismatch only counts when the probabilities diverge too.
            if (ri != gi && diff > tol) {
                printf("FAIL [bt=%d, k=%d]: ref_idx=%d gpu_idx=%d ref_prob=%.8f gpu_prob=%.8f\n",
                       bt, k, ri, gi, rp, gp);
                all_ok = false;
                reported++;
            }
            // Probabilities must always agree within tolerance.
            if (diff > tol) {
                printf("FAIL [bt=%d, k=%d]: idx=%d ref_prob=%.8f gpu_prob=%.8f diff=%.2e\n",
                       bt, k, gi, rp, gp, diff);
                all_ok = false;
                reported++;
            }
        }
    }
    return all_ok;
}
// ---------- Main ----------
int main() {
constexpr int B = 4;
constexpr int T = 8;
constexpr int V = 1024; // manageable for CPU verification
constexpr int K = 10;
constexpr int N = B * T;
printf("=== Fused Softmax + Top-K Test ===\n");
printf("Shape: [B=%d, T=%d, V=%d], K=%d\n\n", B, T, V, K);
// Allocate and initialize
size_t logits_bytes = (size_t)N * V * sizeof(float);
size_t idx_bytes = (size_t)N * K * sizeof(int);
size_t prob_bytes = (size_t)N * K * sizeof(float);
std::vector<float> h_logits(N * V);
std::vector<int> h_idx_gpu(N * K);
std::vector<float> h_prob_gpu(N * K);
std::vector<int> h_idx_ref(N * K);
std::vector<float> h_prob_ref(N * K);
// Random logits with large range to stress numerical stability
srand(42);
for (auto& x : h_logits) {
x = ((float)rand() / RAND_MAX - 0.5f) * 40.0f; // range [-20, 20]
}
// GPU allocation
float *d_logits, *d_probs;
int *d_indices;
cudaMalloc(&d_logits, logits_bytes);
cudaMalloc(&d_indices, idx_bytes);
cudaMalloc(&d_probs, prob_bytes);
cudaMemcpy(d_logits, h_logits.data(), logits_bytes, cudaMemcpyHostToDevice);
// Launch kernel
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
printf("Launching fused kernel...\n");
cudaEventRecord(start);
launch_fused_softmax_topk<K>(d_logits, d_indices, d_probs, B, T, V);
cudaEventRecord(stop);
cudaEventSynchronize(stop);
float ms = 0;
cudaEventElapsedTime(&ms, start, stop);
printf("Kernel time: %.3f ms\n\n", ms);
// Copy results back
cudaMemcpy(h_idx_gpu.data(), d_indices, idx_bytes, cudaMemcpyDeviceToHost);
cudaMemcpy(h_prob_gpu.data(), d_probs, prob_bytes, cudaMemcpyDeviceToHost);
// CPU reference
printf("Running CPU reference...\n");
cpu_softmax_topk(h_logits.data(), h_idx_ref.data(), h_prob_ref.data(),
B, T, V, K);
// Verify
printf("Verifying...\n");
bool pass = verify(h_prob_ref.data(), h_idx_ref.data(),
h_prob_gpu.data(), h_idx_gpu.data(),
B, T, K);
printf("\n%s\n", pass ? "✓ ALL TESTS PASSED" : "✗ TESTS FAILED");
// Print a sample row
int row = 0;
printf("\nSample output (row %d):\n", row);
printf(" %-6s %-12s %-12s %-12s\n", "k", "Index", "GPU Prob", "Ref Prob");
printf(" %-6s %-12s %-12s %-12s\n", "---", "-----", "--------", "--------");
for (int k = 0; k < K; k++) {
printf(" %-6d %-12d %-12.8f %-12.8f\n", k,
h_idx_gpu[row * K + k],
h_prob_gpu[row * K + k],
h_prob_ref[row * K + k]);
}
// Check probability sums
float sum_gpu = 0, sum_ref = 0;
for (int k = 0; k < K; k++) {
sum_gpu += h_prob_gpu[row * K + k];
sum_ref += h_prob_ref[row * K + k];
}
printf("\n Sum of top-%d probs: GPU=%.8f Ref=%.8f\n", K, sum_gpu, sum_ref);
printf(" (Note: sum < 1.0 because K << V; these should match)\n");
// Bandwidth estimate
size_t total_read = logits_bytes;
size_t total_write = idx_bytes + prob_bytes;
double bw = (total_read + total_write) / (ms * 1e-3) / 1e9;
printf("\nEstimated effective bandwidth: %.1f GB/s\n", bw);
printf(" Reads: %zu bytes (%.1f KB)\n", total_read, total_read / 1024.0);
printf(" Writes: %zu bytes (%.1f KB)\n", total_write, total_write / 1024.0);
// Cleanup
cudaFree(d_logits);
cudaFree(d_indices);
cudaFree(d_probs);
cudaEventDestroy(start);
cudaEventDestroy(stop);
return pass ? 0 : 1;
}
+395
View File
@@ -0,0 +1,395 @@
# KV-Cache for Autoregressive Transformer Inference
A complete, framework-free implementation of KV-caching for autoregressive
transformer inference, built from scratch in Python/NumPy.
## Architecture Overview
```
┌─────────────────────────────────────────────────────────────┐
│ Transformer Layer │
│ │
│ Token IDs ──► Embedding ──► Q,K,V Projections │
│ │ │
│ ┌─────────────┼──────────────┐ │
│ ▼ ▼ ▼ │
│ Q_new ──► K_new, V_new ──► Cache Write │
│ │ │ │
│ │ ┌─────────────────────┘ │
│ ▼ ▼ │
│ ┌──────────────┐ │
│ │ Attention │ Q_new × (K_cached + K_new) │
│ │ Computation │ ──────────────────────────► │
│ │ (read-only) │ weights × (V_cached + V_new) │
│ └──────┬───────┘ │
│ ▼ │
│ Output Projection ──► LayerNorm ──► next layer │
└─────────────────────────────────────────────────────────────┘
```
## Data Structure Layout
### Memory Format
Each layer maintains two pre-allocated tensors:
```
keys: (B, H, S_max, D) float32
values: (B, H, S_max, D) float32
```
| Symbol | Meaning                        | Example (GPT-4 class) |
|--------|--------------------------------|-----------------------|
| B      | Batch size                     | 1–64                  |
| H      | Number of attention heads      | 32                    |
| S_max  | Maximum sequence length        | 8192–131072           |
| D      | Head dimension (d_model / H)   | 128                   |
**Why BHSD layout?**
The dimensions are ordered so that the sequence axis (S) is stride-D
contiguous. This means:
1. **Append is a simple slice copy** — `cache[b, :, pos, :] = new_kv`
writes D×H floats to a contiguous region.
2. **Attention matmul is efficient** — the inner `Q @ K^T` reads K along
the S dimension, which is stride-D contiguous.
3. **GPU-friendly** — maps directly to a CUDA tensor with no transposition
needed between the write and read paths.
### Auxiliary State
```
seq_lens: int[B] — valid prefix length per batch element
```
Positions `[..., :seq_lens[b], :]` contain valid data. Everything beyond
is garbage and must be masked out during attention.
## Update Logic Per Step
### Prefill Phase (processing the full prompt)
```
Input: prompt tokens of length S
Output: cache filled with S key-value pairs
for each layer:
Q, K, V = project(prompt_embeddings) # (B, S, d_model) → 3× (B, S, d_model)
K = reshape(K, (B, H, S, D)) # split into heads
V = reshape(V, (B, H, S, D))
cache.write(positions=[0, 1, ..., S-1], K, V) # bulk write
# Self-attention within the prompt (causal mask)
attn_output = attention(Q, cache.read()) # O(S²) — one-time cost
```
### Decode Phase (one token at a time)
```
Input: single new token
Output: logits for next token prediction
for each layer:
q_new, k_new, v_new = project(token_embedding) # each (B, 1, d_model)
k_new = reshape(k_new, (B, H, 1, D))
v_new = reshape(v_new, (B, H, 1, D))
# ── CACHE UPDATE: O(H·D) — write 1 token ──
cache[pos] = (k_new, v_new) # 2 × H × D floats
# ── ATTENTION: O(S·H·D) — query vs ALL cached keys ──
K_all, V_all = cache.read() # (B, H, S+1, D)
scores = q_new @ K_all.T / √D # (B, H, 1, S+1)
weights = softmax(scores)
output = weights @ V_all # (B, H, 1, D)
```
**Key insight**: Without caching, each decode step would require O(S²) work
(recomputing attention for all S previous tokens). With caching, it's only
O(S) — the new query attends against the cached keys/values.
## Attention Computation Using Cached Keys/Values
```
┌───────────┐ ┌───────────────────────────────────┐
│ Q_new │ │ Cached K (all past tokens) │
│ (1, D) │ × │ (S_valid, D) │
│ │ │ │
│ │ │ [k₀] [k₁] [k₂] ... [k_{S-1}] │
└─────┬─────┘ └───────────────────────────────────┘
│ │
▼ ▼
┌────────────────────────────────────┐
│ scores = Q · K^T / √D │ → (1, S_valid)
│ weights = softmax(scores) │ → (1, S_valid)
│ output = weights · V │ → (1, D)
└────────────────────────────────────┘
```
This is performed independently for each head H and batch element B.
## Memory Growth Analysis
### Linear Growth
The cache grows **linearly** with sequence length:
```
Memory per layer = 2 × B × H × S × D × sizeof(dtype)
= 2 × B × d_model × S × sizeof(dtype)
```
For a GPT-4-class model (32 layers, d_model=4096, FP32):
| Seq Length | Per Layer (MB) | Total (MB) | Total (GB) |
|-----------|---------------|-----------|-----------|
| 128 | 0.67 | 21.47 | 0.021 |
| 1,024 | 5.37 | 171.79 | 0.172 |
| 4,096 | 21.47 | 687.19 | 0.687 |
| 16,384 | 85.89 | 2,748.77 | 2.749 |
| 65,536 | 343.59 | 10,995.08 | 10.995 |
| 131,072 | 687.19 | 21,990.16 | 21.990 |
**Observation**: At 128K context with batch=1, you need **~22 GB** just for
the KV cache — before accounting for model weights, activations, or
gradients.
### FLOPs Savings
| Scenario | Without Cache | With Cache | Speedup |
|----------|--------------|-----------|---------|
| 1024 prompt + 100 decode | 4.2e14 | 2.0e12 | ~200× |
The speedup grows quadratically with sequence length.
## Optimizations
### 1. Paged Attention (Virtual Memory for KV Cache)
**Problem**: Pre-allocating `(B, H, S_max, D)` wastes memory for short
sequences and causes fragmentation when sequences finish at different
times.
**Solution**: Divide the cache into fixed-size blocks (pages):
```
Physical Memory:
┌────────┬────────┬────────┬────────┬────────┬────────┐
│ Block 0│ Block 1│ Block 2│ Block 3│ Block 4│ ... │
│(H,B,D) │(H,B,D) │(H,B,D) │(H,B,D) │(H,B,D) │ │
└────────┴────────┴────────┴────────┴────────┴────────┘
Page Tables:
Seq 0: [0] → [3] → [1] (3 blocks = 3 × BLOCK_SIZE tokens)
Seq 1: [2] → [4] (2 blocks = 2 × BLOCK_SIZE tokens)
Seq 2: [5] (1 block)
Free: [6, 7, 8, ...]
```
**Benefits**:
- Memory allocated only as needed (no S_max pre-allocation)
- Finished sequences free blocks immediately → higher throughput
- No external fragmentation
- Enables sharing of KV blocks across sequences (e.g., prefix caching)
**Implementation**: See `PagedKVCache` in `optimizations.py`.
### 2. Chunked Prefill
**Problem**: Processing a 32K-token prompt requires a 32K×32K attention
matrix (1 billion floats = 4 GB) just for the prefill.
**Solution**: Split the prompt into chunks of C tokens:
```
Prompt: [t₀, t₁, t₂, ..., t_{S-1}] (S = 32K)
Chunk 0: [t₀..t_{C-1}] → cache write → attention vs cache (0..C)
Chunk 1: [t_C..t_{2C-1}] → cache write → attention vs cache (0..2C)
Chunk 2: [t_{2C}..t_{3C-1}] → cache write → attention vs cache (0..3C)
...
```
Peak attention memory: O(C × S) instead of O(S²).
**Benefits**:
- Bounded peak memory regardless of prompt length
- Can interleave prefill chunks with decode steps from other sequences
- Better GPU utilization (uniform work items)
### 3. Cache Quantization (INT8 / INT4)
**Problem**: 22 GB for a 128K context is unsustainable.
**Solution**: Quantize cached K/V to lower precision:
| Precision | Bytes/Element | Memory Savings | Typical Quality Loss |
|-----------|-------------|---------------|---------------------|
| FP32 | 4 | 1× (baseline) | 0% |
| FP16 | 2 | 2× | <0.1% |
| INT8 | 1 | 4× | <0.5% |
| INT4 | 0.5 | 8× | 1-3% |
Quantization is per-token: `scale[b,h,t] = max(|K[b,h,t,:]|) / (2^bits - 1)`.
```
Storage:
k_quant: uint8 (B, H, S, D) or packed uint8 (B, H, S, D/2) for INT4
k_scale: float32 (B, H, S) — one scalar per token per head
Dequantize during attention:
K_float = k_quant * k_scale — in registers before matmul
```
**Benefits**:
- 4-8× memory reduction → longer contexts or larger batches
- Minimal quality loss for most tasks
- Hardware support on modern GPUs (FP8 on Hopper, INT8 on Ampere)
## GPU Execution Mapping
### Memory Hierarchy
```
┌──────────────────────────────────────────────┐
│ HBM (High Bandwidth Memory) │
│ ┌──────────────────────────────────────┐ │
│ │ KV Cache: (B, H, S, D) per layer │ │
│ │ ~10-70 GB for long contexts │ │
│ └──────────────────────────────────────┘ │
│ ┌──────────────────────────────────────┐ │
│ │ Model Weights │ │
│ └──────────────────────────────────────┘ │
└──────────────────────┬───────────────────────┘
│ ~2-3 TB/s bandwidth
┌──────────────────────────────────────────────┐
│ Shared Memory (per SM) │
│ ┌──────────────────────────────────────┐ │
│ │ Q tile: (block_B, H, tile_S, D) │ │
│ │ K tile: (block_B, H, tile_S, D) │ │
│ │ V tile: (block_B, H, tile_S, D) │ │
│ │ Score tile: (block_B, H, tile_S²) │ │
│ └──────────────────────────────────────┘ │
│ ~48-164 KB per SM │
└──────────────────────┬───────────────────────┘
│ ~19 TB/s bandwidth
┌──────────────────────────────────────────────┐
│ Registers (per thread block) │
│ accumulator for QK^T, softmax, etc. │
│ ~255 registers/thread │
└──────────────────────────────────────────────┘
```
### Kernel Mapping
| Operation | CPU (this impl) | GPU Kernel |
|-----------|----------------|------------|
| Cache write | `cache[b,:,pos,:] = new_kv` | `cudaMemcpyAsync` or block-level scatter |
| Q×K^T | `q @ k.T` | Batched GEMM (cuBLAS) or FlashAttention |
| Softmax | `_softmax(scores)` | Online softmax (FlashAttention) |
| Weights×V | `weights @ v` | GEMM (part of FlashAttention fused kernel) |
| Quantize | `_quantize_token()` | Block-reduce + scale + convert |
### FlashAttention Integration
The attention computation in this codebase performs the naive:
```
S = Q × K^T # materialize full (S_q, S_kv) matrix
A = softmax(S) # another (S_q, S_kv) matrix
O = A × V # output
```
On GPU, **FlashAttention** fuses these three operations:
```
for each tile of Q:
init: O = 0, m = -∞, l = 0
for each tile of K, V:
S_tile = Q_tile × K_tile^T # in SRAM
m_new = max(m, max(S_tile))
P_tile = exp(S_tile - m_new) # in SRAM
l_new = l + sum(P_tile)
O = O * (l/l_new) + P_tile × V_tile # accumulate
m, l = m_new, l_new
O = O / l
```
This keeps the O(S²) attention matrix entirely in SRAM, avoiding
HBM reads/writes. The KV cache is read tile-by-tile from HBM.
### Paged Attention on GPU
The `PagedKVCache` page table translates to a GPU indirection:
```cuda
// CUDA pseudocode for paged attention
__global__ void paged_attention(
float* Q, // (B, H, 1, D) — new query
float* K_pool, // (num_blocks, H, BLOCK_SIZE, D)
float* V_pool,
int* page_table, // (B, max_pages_per_seq)
int* seq_lens, // (B,)
float* output // (B, H, 1, D)
) {
int b = blockIdx.y;
int h = blockIdx.x;
int S = seq_lens[b];
// Load query into registers
float q[D];
load_query(q, Q, b, h);
// Iterate over pages
float score[S_MAX_LOCAL];
for (int page = 0; page < ceil(S / BLOCK_SIZE); page++) {
int phys_block = page_table[b * max_pages + page];
// Gather K/V from scattered physical blocks
for (int i = 0; i < BLOCK_SIZE; i++) {
float k = K_pool[phys_block * H * BLOCK_SIZE * D
+ h * BLOCK_SIZE * D + i * D + d];
score[page * BLOCK_SIZE + i] = dot(q, k) / sqrt(D);
}
}
// ... softmax, multiply by V, write output
}
```
## File Structure
```
kv/
├── README.md ← you are here
├── kv_cache.py ← core data structures + attention
├── optimizations.py ← paged attention, chunked prefill, quantization
└── test_kv_cache.py ← comprehensive test suite
```
## Running
```bash
python test_kv_cache.py
```
All tests run without any external dependencies beyond NumPy.
## Key Design Decisions
1. **Pre-allocation**: The base `KVCache` pre-allocates to `S_max` to
avoid GPU memory allocation during inference (malloc is expensive).
The `PagedKVCache` trades this for on-demand block allocation.
2. **No cross-contamination**: Each batch element maintains its own
valid prefix via `seq_lens`. Attention never attends to garbage
positions from other sequences.
3. **Separation of concerns**: Cache update (write) and attention
(read) are decoupled. The caller controls when each happens,
enabling chunked prefill and prefix sharing.
4. **Quantization at cache boundary**: K/V are computed in FP32,
quantized on write, dequantized on read. This keeps the attention
computation unchanged while reducing memory.
+471
View File
@@ -0,0 +1,471 @@
"""
KV-Cache for Autoregressive Transformer Inference
===================================================
Memory layout
-------------
Each layer stores two tensors:
keys: shape (B, H, S_max, D) — float32
values: shape (B, H, S_max, D) — float32
Where:
B = batch size
H = number of attention heads
S_max = pre-allocated max sequence length
D = head dimension (d_model / H)
The layout is BHSD (batch, head, seq, dim) which is contiguous along
the sequence axis — ideal for appending one token at a time and for
the inner attention matmul.
A companion `seq_lens: list[int]` (length B) tracks how many positions
are valid in each batch element. Positions beyond seq_lens[b] contain
garbage and must never participate in attention.
No external frameworks are used. All kernels are pure-NumPy for
correctness; the design maps 1:1 to CUDA kernels (see README).
"""
from __future__ import annotations
import math
from typing import List, Tuple, Optional
import numpy as np
# ──────────────────────────────────────────────────────────────────────
# 1. DATA STRUCTURE
# ──────────────────────────────────────────────────────────────────────
class KVCache:
    """
    Pre-allocated KV cache for one transformer layer.

    Physical storage
    ~~~~~~~~~~~~~~~~
    Two numpy arrays allocated once at construction:

        self.k_cache   (B, H, S_max, D)   float32
        self.v_cache   (B, H, S_max, D)   float32

    An auxiliary list ``self.seq_lens`` (length B, int) records how many
    token positions are live for each sequence in the batch; positions at
    or beyond ``seq_lens[b]`` hold stale data and must never be attended to.

    On GPU the same layout would be backed by a single cudaMalloc per
    layer. The B-H-S-D ordering keeps the S-dimension stride == D,
    making the per-token write a simple 3D slice copy.
    """

    def __init__(
        self,
        batch_size: int,
        max_seq_len: int,
        num_heads: int,
        head_dim: int,
        dtype: np.dtype = np.float32,
    ):
        self.B = batch_size
        self.S_max = max_seq_len
        self.H = num_heads
        self.D = head_dim
        self.dtype = dtype
        shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.k_cache = np.zeros(shape, dtype=dtype)
        self.v_cache = np.zeros(shape, dtype=dtype)
        # seq_lens[b] = number of valid positions for batch element b
        self.seq_lens: List[int] = [0] * batch_size

    # ── helpers ──────────────────────────────────────────────────────
    def _check_batch(
        self, token_k: np.ndarray, token_v: Optional[np.ndarray] = None
    ) -> None:
        """Validate shapes of incoming key (and, if given, value) tensors.

        Expected layout: (B, H, T, D), T being the number of new tokens.
        Previously only the keys were validated; a mis-shaped value tensor
        surfaced later as a confusing broadcast error mid-write.
        """
        assert token_k.ndim == 4
        assert token_k.shape[0] == self.B
        assert token_k.shape[1] == self.H
        assert token_k.shape[3] == self.D
        if token_v is not None:
            # Keys and values are written in lock-step — shapes must match.
            assert token_v.shape == token_k.shape, (
                f"key/value shape mismatch: {token_k.shape} vs {token_v.shape}"
            )

    # ── core update ──────────────────────────────────────────────────
    def update(
        self,
        new_k: np.ndarray,
        new_v: np.ndarray,
        positions: Optional[List[int]] = None,
    ) -> None:
        """
        Write new key/value vectors into the cache.

        Parameters
        ----------
        new_k, new_v : ndarray, shape (B, H, T, D)
            Keys and values for T new tokens. In incremental decoding T=1.
        positions : list[int] | None
            Explicit write offsets per batch element. When *None* the
            tokens are appended right after the current `seq_lens[b]`.

        Raises
        ------
        AssertionError
            If the input shapes are wrong or any write would exceed S_max.
            All ranges are validated up-front, so a failure leaves the
            cache (data and seq_lens) untouched.
        """
        self._check_batch(new_k, new_v)
        T = new_k.shape[2]  # number of new tokens (1 for decode, S for prefill)
        # Snapshot per-batch start offsets and validate them all first.
        starts = [
            positions[b] if positions is not None else self.seq_lens[b]
            for b in range(self.B)
        ]
        for b, pos in enumerate(starts):
            assert pos + T <= self.S_max, (
                f"batch {b}: pos {pos} + {T} tokens would exceed S_max={self.S_max}"
            )
        for b, pos in enumerate(starts):
            # ---- the actual write: a slice copy into pre-allocated memory ----
            self.k_cache[b, :, pos : pos + T, :] = new_k[b]
            self.v_cache[b, :, pos : pos + T, :] = new_v[b]
            # advance this sequence's valid-prefix pointer
            self.seq_lens[b] = pos + T

    # ── retrieval (used by attention) ────────────────────────────────
    def get_kv(
        self, batch_idx: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Return (keys, values) for a single batch element, trimmed to the
        valid prefix: shapes (H, S_valid, D) each.

        Note: these are *views* into the cache, not copies — mutating them
        mutates the cache.
        """
        s = self.seq_lens[batch_idx]
        return self.k_cache[batch_idx, :, :s, :], self.v_cache[batch_idx, :, :s, :]

    def get_full_kv(self) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """Return per-batch (keys, values) lists, each entry (H, S_valid, D)."""
        ks, vs = [], []
        for b in range(self.B):
            k, v = self.get_kv(b)
            ks.append(k)
            vs.append(v)
        return ks, vs

    # ── bookkeeping ──────────────────────────────────────────────────
    def reset(self) -> None:
        """Zero the cache and mark every sequence empty (storage is reused)."""
        self.k_cache[:] = 0
        self.v_cache[:] = 0
        self.seq_lens = [0] * self.B

    def memory_bytes(self) -> int:
        """Total bytes held by the pre-allocated K and V tensors."""
        return self.k_cache.nbytes + self.v_cache.nbytes

    def __repr__(self) -> str:
        return (
            f"KVCache(B={self.B}, H={self.H}, S_max={self.S_max}, "
            f"D={self.D}, seq_lens={self.seq_lens}, "
            f"mem={self.memory_bytes() / 1e6:.1f} MB)"
        )
# ──────────────────────────────────────────────────────────────────────
# 2. MULTI-HEAD ATTENTION USING THE CACHE
# ──────────────────────────────────────────────────────────────────────
def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
"""Numerically-stable softmax."""
x_max = np.max(x, axis=axis, keepdims=True)
e_x = np.exp(x - x_max)
return e_x / np.sum(e_x, axis=axis, keepdims=True)
def _scaled_dot_product_attention(
q: np.ndarray, k: np.ndarray, v: np.ndarray
) -> np.ndarray:
"""
Single-head attention.
q: (S_q, D) k: (S_kv, D) v: (S_kv, D)
returns: (S_q, D)
"""
scale = 1.0 / math.sqrt(q.shape[-1])
scores = q @ k.T * scale # (S_q, S_kv)
weights = _softmax(scores, axis=-1) # (S_q, S_kv)
return weights @ v # (S_q, D)
def multi_head_attention_with_cache(
    q_new: np.ndarray,
    cache: KVCache,
    w_q: np.ndarray,
    w_k: np.ndarray,
    w_v: np.ndarray,
    w_o: np.ndarray,
) -> np.ndarray:
    """
    Multi-head attention that *reads* keys/values from the cache but never
    writes it — the caller decides when the cache update happens.

    Parameters
    ----------
    q_new : ndarray, shape (B, T, d_model)
        Query representations for the T new tokens.
    cache : KVCache
        The key/value cache for this layer (already updated).
    w_q, w_k, w_v : ndarray, shape (d_model, d_model)
        Projection weight matrices. w_k / w_v are not used here — keys and
        values come pre-projected from the cache; they remain in the
        signature for symmetry with a cache-free implementation.
    w_o : ndarray, shape (d_model, d_model)
        Output projection matrix.

    Returns
    -------
    output : ndarray, shape (B, T, d_model)
    """
    H, D = cache.H, cache.D
    d_model = H * D
    T = q_new.shape[1]
    # One batched projection for all queries: (B, T, d_model) → (B, T, H, D).
    q_heads = (q_new @ w_q).reshape(cache.B, T, H, D)
    out = np.empty((cache.B, T, d_model), dtype=q_new.dtype)
    for b in range(cache.B):
        assert cache.seq_lens[b] > 0, f"batch {b}: cache is empty"
        k_b, v_b = cache.get_kv(b)  # (H, S_valid, D) each
        # Attend head-by-head, then stack back to (T, H, D).
        per_head = [
            _scaled_dot_product_attention(q_heads[b, :, h, :], k_b[h], v_b[h])
            for h in range(H)
        ]
        # Concatenate heads and apply the output projection.
        merged = np.stack(per_head, axis=1).reshape(T, d_model)
        out[b] = merged @ w_o
    return out
# ──────────────────────────────────────────────────────────────────────
# 3. MASKED BATCHED ATTENTION (variable seq lens in one batch)
# ──────────────────────────────────────────────────────────────────────
def multi_head_attention_batched(
    q_new: np.ndarray,
    cache: KVCache,
    w_q: np.ndarray,
    w_k: np.ndarray,
    w_v: np.ndarray,
    w_o: np.ndarray,
) -> np.ndarray:
    """
    Batched MHA over a cache whose sequences may have different lengths.

    Because K/V are stored per batch element, each element attends over
    exactly its own `seq_lens[b]` valid positions — no padding-mask
    tensor is materialized here.  (A packed GPU kernel would instead
    build a (B, T, S_max) mask; this per-element loop is the CPU
    equivalent and cannot cross-contaminate sequences.)

    Raises
    ------
    ValueError
        If any batch element's cache holds zero tokens.
    """
    B, H, D = cache.B, cache.H, cache.D
    d_model = H * D
    T = q_new.shape[1]
    q_heads = (q_new @ w_q).reshape(B, T, H, D)
    result = np.empty((B, T, d_model), dtype=q_new.dtype)
    for b in range(B):
        keys, values = cache.get_kv(b)
        if cache.seq_lens[b] == 0:
            raise ValueError(f"batch {b}: cache is empty — call update first")
        head_out = np.empty((T, H, D), dtype=q_new.dtype)
        for h in range(H):
            head_out[:, h, :] = _scaled_dot_product_attention(
                q_heads[b, :, h, :], keys[h], values[h]
            )
        result[b] = head_out.reshape(T, d_model) @ w_o
    return result
# ──────────────────────────────────────────────────────────────────────
# 4. INCREMENTAL DECODER (end-to-end usage example)
# ──────────────────────────────────────────────────────────────────────
class IncrementalDecoder:
    """
    Minimal transformer decoder with L layers and KV caching.

    Demonstrates the full lifecycle:
        prefill → fill cache with the entire prompt
        decode  → generate one token at a time using the cache
    """
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_layers: int,
        max_seq_len: int,
        vocab_size: int,
        dtype: np.dtype = np.float32,
    ):
        self.d_model = d_model
        self.H = num_heads
        self.D = d_model // num_heads
        self.L = num_layers
        # BUG FIX: max_seq_len was accepted but never stored, so
        # _init_caches() raised AttributeError unless the caller patched
        # the attribute on the instance by hand.
        self.max_seq_len = max_seq_len
        self.dtype = dtype
        # ---- weight matrices (scaled random init) ----
        # NOTE(review): labelled "Xavier" before, but 2/d_model is the
        # variance, not the std; kept as-is to preserve numerics.
        scale = 2.0 / d_model
        self.w_embed = (np.random.randn(vocab_size, d_model) * scale).astype(dtype)
        self.w_q = [
            (np.random.randn(d_model, d_model) * scale).astype(dtype)
            for _ in range(num_layers)
        ]
        self.w_k = [
            (np.random.randn(d_model, d_model) * scale).astype(dtype)
            for _ in range(num_layers)
        ]
        self.w_v = [
            (np.random.randn(d_model, d_model) * scale).astype(dtype)
            for _ in range(num_layers)
        ]
        self.w_o = [
            (np.random.randn(d_model, d_model) * scale).astype(dtype)
            for _ in range(num_layers)
        ]
        self.w_out = (np.random.randn(d_model, vocab_size) * scale).astype(dtype)
        # ---- one KV cache per layer (allocated by _init_caches) ----
        self.caches: List[KVCache] = []

    def _init_caches(self, batch_size: int) -> None:
        """Allocate a fresh KVCache per layer for a new batch."""
        self.caches = [
            KVCache(batch_size, self.max_seq_len, self.H, self.D, self.dtype)
            for _ in range(self.L)
        ]

    # ---- layer norm (simplified: no learnable gamma/beta) ----
    @staticmethod
    def _layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
        """Normalize over the last dimension; eps guards the sqrt."""
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        return (x - mean) / np.sqrt(var + eps)

    def forward_step(
        self,
        token_ids: np.ndarray,
        caches: List[KVCache],
        is_prefill: bool = False,
    ) -> np.ndarray:
        """
        One forward step.

        token_ids  : int array, shape (B,) for decode or (B, T) for prefill
        caches     : list of KVCache, one per layer
        is_prefill : accepted for API symmetry; the code path is identical
                     for prefill and decode (the cache append handles both).

        Returns logits (B, vocab_size) — always only for the *last* token.
        """
        if token_ids.ndim == 1:
            token_ids = token_ids[:, None]  # (B, 1)
        B, T = token_ids.shape
        hidden = self.w_embed[token_ids]  # (B, T, d_model)
        for layer_idx in range(self.L):
            # ---- project K, V for the new tokens and append to cache ----
            # (Q is NOT projected here: multi_head_attention_with_cache
            # projects queries internally; the previous explicit q
            # projection was dead computation and has been removed.)
            k = (hidden @ self.w_k[layer_idx]).reshape(B, T, self.H, self.D)
            v = (hidden @ self.w_v[layer_idx]).reshape(B, T, self.H, self.D)
            caches[layer_idx].update(
                k.transpose(0, 2, 1, 3),  # (B, H, T, D)
                v.transpose(0, 2, 1, 3),
            )
            # ---- attention read over past + current positions ----
            attn_out = multi_head_attention_with_cache(
                hidden, caches[layer_idx],
                self.w_q[layer_idx],
                self.w_k[layer_idx],
                self.w_v[layer_idx],
                self.w_o[layer_idx],
            )
            # residual connection + (simplified) layer norm
            hidden = self._layer_norm(hidden + attn_out)
        # project last position to vocab
        logits = hidden[:, -1, :] @ self.w_out  # (B, vocab_size)
        return logits
# ──────────────────────────────────────────────────────────────────────
# 5. MEMORY ANALYSIS
# ──────────────────────────────────────────────────────────────────────
def memory_analysis(
    num_layers: int,
    num_heads: int,
    head_dim: int,
    batch_size: int,
    seq_len: int,
    bytes_per_element: int = 4,
) -> dict:
    """
    Compute KV-cache memory consumption for a model configuration.

    Every cached token stores one K and one V vector (hence the factor
    of 2) of num_heads * head_dim elements, per layer.

    Returns a dict with per-layer and total memory in bytes / MB / GB,
    plus the inputs echoed back under "params".
    """
    # K + V → factor 2
    token_layer_bytes = 2 * num_heads * head_dim * bytes_per_element
    layer_bytes = token_layer_bytes * batch_size * seq_len
    all_layer_bytes = layer_bytes * num_layers
    return {
        "per_token_per_layer_B": token_layer_bytes,
        "per_layer_bytes": layer_bytes,
        "per_layer_MB": layer_bytes / 1e6,
        "total_bytes": all_layer_bytes,
        "total_MB": all_layer_bytes / 1e6,
        "total_GB": all_layer_bytes / 1e9,
        "params": {
            "num_layers": num_layers,
            "num_heads": num_heads,
            "head_dim": head_dim,
            "batch_size": batch_size,
            "seq_len": seq_len,
            "bytes_per_element": bytes_per_element,
        },
    }
def memory_growth_table(
    num_layers: int = 32,
    num_heads: int = 32,
    head_dim: int = 128,
    batch_size: int = 1,
    seq_lens: Optional[List[int]] = None,
) -> str:
    """
    Render a text table of KV-cache memory versus sequence length.

    Defaults correspond to a 32-layer, 32-head, 128-dim model; the
    default sweep covers powers of two from 128 to 65536 tokens.
    """
    if seq_lens is None:
        seq_lens = [2 ** p for p in range(7, 17)]  # 128 .. 65536
    rows = [
        f"{'Seq Len':>10} | {'Per Layer (MB)':>15} | {'Total (MB)':>12} | {'Total (GB)':>12}",
        "-" * 60,
    ]
    for seq_len in seq_lens:
        stats = memory_analysis(num_layers, num_heads, head_dim, batch_size, seq_len)
        rows.append(
            f"{seq_len:>10} | {stats['per_layer_MB']:>15.2f} | {stats['total_MB']:>12.2f} | {stats['total_GB']:>12.3f}"
        )
    return "\n".join(rows)
+508
View File
@@ -0,0 +1,508 @@
"""
KV-Cache Optimizations
======================
Three production-grade optimizations for the base KV-cache:
1. PagedAttention — block-based virtual memory for the cache
2. Chunked Prefill — split long prompts into fixed-size chunks
3. Cache Quantization — compress K/V to lower precision
Each optimisation is a drop-in wrapper around the base KVCache
interface, keeping the same update / get_kv contract.
"""
from __future__ import annotations
import math
from typing import List, Tuple, Optional, Dict
import numpy as np
from kv_cache import KVCache
# ══════════════════════════════════════════════════════════════════════
# OPTIMIZATION 1: PAGED ATTENTION
# ══════════════════════════════════════════════════════════════════════
#
# Problem
# -------
# The base cache pre-allocates (B, H, S_max, D) per layer. If S_max
# is large (e.g. 128 k tokens) this wastes enormous memory for short
# sequences and fragments GPU memory when sequences finish at different
# times.
#
# Solution (cf. vLLM / PagedAttention)
# -------
# Divide the cache into fixed-size *blocks* (pages) of BLOCK_SIZE tokens.
# A per-sequence page table maps virtual positions → physical block ids.
# Blocks are allocated from a pool — freed when a sequence finishes and
# immediately reusable by a new sequence.
#
# Memory layout (physical):
# k_pool: (NUM_BLOCKS, H, BLOCK_SIZE, D)
# v_pool: (NUM_BLOCKS, H, BLOCK_SIZE, D)
#
# Per-sequence metadata:
# page_table: list[list[int]] — page_table[b] = [block_0, block_1, ...]
# seq_lens: list[int]
#
# GPU mapping: the page table lives in GPU memory and is indexed by a
# custom CUDA kernel that performs the gather from scattered blocks.
# On CPU we simulate it with index arithmetic.
# ══════════════════════════════════════════════════════════════════════
class PagedKVCache:
    """
    Block-scattered KV cache inspired by vLLM's PagedAttention.

    Physical storage is a fixed pool of blocks of `block_size` tokens;
    each sequence owns an ordered page table of physical block ids.
    Compared to a contiguous (B, H, S_max, D) cache this avoids:
      - memory waste from over-provisioning S_max
      - fragmentation from variable-length sequences
      - the need for one large contiguous allocation
    """

    def __init__(
        self,
        num_blocks: int,
        block_size: int,
        num_heads: int,
        head_dim: int,
        max_num_seqs: int,
        dtype: np.dtype = np.float32,
    ):
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.H = num_heads
        self.D = head_dim
        self.dtype = dtype
        # Physical pools — one block-indexed tensor each for K and V.
        pool_shape = (num_blocks, num_heads, block_size, head_dim)
        self.k_pool = np.zeros(pool_shape, dtype=dtype)
        self.v_pool = np.zeros(pool_shape, dtype=dtype)
        # Every block starts free.
        self.free_blocks: List[int] = list(range(num_blocks))
        # page_tables[seq_id] → ordered physical block ids for that sequence
        self.page_tables: List[List[int]] = []
        self.seq_lens: List[int] = []
        self.max_num_seqs = max_num_seqs

    # ── sequence lifecycle ───────────────────────────────────────────
    def add_sequence(self) -> int:
        """Register a new sequence; returns its id."""
        # NOTE(review): ids are never recycled, so finished sequences
        # still count toward max_num_seqs.
        assert len(self.page_tables) < self.max_num_seqs, "too many sequences"
        new_id = len(self.page_tables)
        self.page_tables.append([])
        self.seq_lens.append(0)
        return new_id

    def finish_sequence(self, seq_id: int) -> None:
        """Return every block owned by `seq_id` to the free pool."""
        self.free_blocks.extend(self.page_tables[seq_id])
        self.page_tables[seq_id] = []
        self.seq_lens[seq_id] = 0

    # ── block allocation ─────────────────────────────────────────────
    def _ensure_blocks(self, seq_id: int, total_tokens: int) -> None:
        """Grow the page table until it covers `total_tokens` positions."""
        needed = math.ceil(total_tokens / self.block_size)
        table = self.page_tables[seq_id]
        while len(table) < needed:
            if not self.free_blocks:
                raise RuntimeError(
                    f"Out of blocks! Need {needed}, have {self.num_blocks} total. "
                    f"Free: {len(self.free_blocks)}"
                )
            # FIFO reuse of freed blocks
            table.append(self.free_blocks.pop(0))

    # ── update (write K, V) ──────────────────────────────────────────
    def update(
        self,
        seq_id: int,
        new_k: np.ndarray,
        new_v: np.ndarray,
    ) -> None:
        """
        Append new tokens for a single sequence.

        new_k, new_v : shape (H, T, D)
        """
        num_new = new_k.shape[1]
        start = self.seq_lens[seq_id]
        self._ensure_blocks(seq_id, start + num_new)
        table = self.page_tables[seq_id]
        for t in range(num_new):
            pos = start + t
            # virtual position → (physical block, in-block slot)
            phys = table[pos // self.block_size]
            slot = pos % self.block_size
            self.k_pool[phys, :, slot, :] = new_k[:, t, :]
            self.v_pool[phys, :, slot, :] = new_v[:, t, :]
        self.seq_lens[seq_id] = start + num_new

    # ── retrieval (gather scattered blocks) ──────────────────────────
    def get_kv(self, seq_id: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Gather the sequence's keys/values from its scattered blocks.
        Returns two (H, S_valid, D) arrays.
        """
        length = self.seq_lens[seq_id]
        full, tail = divmod(length, self.block_size)
        table = self.page_tables[seq_id]
        k_chunks = [self.k_pool[table[i]] for i in range(full)]  # (H, block_size, D)
        v_chunks = [self.v_pool[table[i]] for i in range(full)]
        if tail > 0:
            # last, partially filled block
            k_chunks.append(self.k_pool[table[full], :, :tail, :])
            v_chunks.append(self.v_pool[table[full], :, :tail, :])
        if not k_chunks:
            return (
                np.empty((self.H, 0, self.D), dtype=self.dtype),
                np.empty((self.H, 0, self.D), dtype=self.dtype),
            )
        return np.concatenate(k_chunks, axis=1), np.concatenate(v_chunks, axis=1)

    # ── memory stats ─────────────────────────────────────────────────
    def memory_bytes(self) -> int:
        """Total bytes held by the physical K/V pools."""
        return self.k_pool.nbytes + self.v_pool.nbytes

    def utilization(self) -> float:
        """Fraction of blocks currently in use."""
        return (self.num_blocks - len(self.free_blocks)) / self.num_blocks

    def __repr__(self) -> str:
        in_use = self.num_blocks - len(self.free_blocks)
        return (
            f"PagedKVCache(blocks={in_use}/{self.num_blocks}, "
            f"block_size={self.block_size}, H={self.H}, D={self.D}, "
            f"seqs={len(self.page_tables)}, "
            f"mem={self.memory_bytes() / 1e6:.1f} MB)"
        )
# ══════════════════════════════════════════════════════════════════════
# OPTIMIZATION 2: CHUNKED PREFILL
# ══════════════════════════════════════════════════════════════════════
#
# Problem
# -------
# During the prefill phase the entire prompt is processed in one shot.
# For a prompt of length S this means an O(S²) attention matrix which
# can blow up memory and latency (e.g. S=32 k → 1 billion elements).
#
# Solution
# --------
# Split the prompt into chunks of CHUNK_SIZE tokens. Process each
# chunk sequentially, writing its K/V into the cache. Subsequent
# chunks attend to all previously cached chunks *plus* their own
# positions (causal masking within the current chunk).
#
# This reduces peak memory from O(S²) to O(CHUNK_SIZE × S) and
# allows overlapping prefill of one request with decode of others.
# ══════════════════════════════════════════════════════════════════════
class ChunkedPrefillCache:
    """
    Wrapper around KVCache that processes long prompts in chunks.

    Instead of filling the entire prompt at once (O(S²) attention
    memory), we iterate over chunks of size C:
      - Each chunk's K/V is written to the cache
      - Attention for chunk i sees positions [0 .. i*C + C)
      - Peak attention memory: O(C × i*C) instead of O(S²)
    """
    def __init__(
        self,
        base_cache: KVCache,
        chunk_size: int = 512,
    ):
        # base_cache : pre-allocated KVCache that receives each chunk's K/V
        # chunk_size : number of prompt tokens written per iteration
        self.cache = base_cache
        self.chunk_size = chunk_size

    def prefill(
        self,
        all_k: np.ndarray,
        all_v: np.ndarray,
        w_q: np.ndarray,
        w_k: np.ndarray,
        w_v: np.ndarray,
        w_o: np.ndarray,
    ) -> np.ndarray:
        """
        Process a long prompt in chunks.

        all_k, all_v : (B, H, S, D) — keys and values for the full prompt
        w_q, w_k, w_v, w_o : projection matrices

        Returns the attention output for one query position after the
        last chunk, shape (B, 1, d_model) — the position that would
        predict the next token.

        NOTE(review): the query is synthesized with np.random.randn as a
        stand-in — in a real model it would come from the embeddings of
        the chunk's tokens — so the returned values are only
        shape-correct, not meaningful activations.
        """
        # Hoisted out of the chunk loop: the import is loop-invariant.
        from kv_cache import multi_head_attention_with_cache

        B = all_k.shape[0]
        S = all_k.shape[2]
        d_model = w_q.shape[0]
        num_chunks = math.ceil(S / self.chunk_size)
        last_output = None
        for c in range(num_chunks):
            start = c * self.chunk_size
            end = min(start + self.chunk_size, S)
            # Write this chunk's K, V into the cache; the cache's
            # seq_lens advance so later chunks attend over earlier ones.
            self.cache.update(
                all_k[:, :, start:end, :],  # (B, H, T, D)
                all_v[:, :, start:end, :],
            )
            # Attend with a single placeholder query against everything
            # cached so far — only the last position matters for
            # autoregressive generation.
            q_single = np.random.randn(B, 1, d_model).astype(all_k.dtype)
            last_output = multi_head_attention_with_cache(
                q_single, self.cache, w_q, w_k, w_v, w_o
            )
        return last_output
# ══════════════════════════════════════════════════════════════════════
# OPTIMIZATION 3: KV CACHE QUANTIZATION (INT8 / INT4)
# ══════════════════════════════════════════════════════════════════════
#
# Problem
# -------
# For long contexts the cache grows linearly with sequence length.
# A 32-layer, 32-head, 128-dim model at batch=1 and seq=65 k uses:
# 2 × 32 × 32 × 128 × 65536 × 4 bytes ≈ 68 GB (!!!)
#
# Solution
# --------
# Quantize cached K/V to lower precision on-the-fly:
# - INT8: store scale + quantized values → 2× memory reduction
# - INT4: store scale + quantized values → 4× memory reduction
#
# During attention, dequantize back to FP32 before matmul.
# This trades a small accuracy loss for massive memory savings.
#
# GPU mapping:
# - Store quantized data in INT8/INT4 tensors
# - Dequantize in registers before the QK^T matmul
# - Or use specialized kernels (e.g. FP8 attention in Hopper GPUs)
# ══════════════════════════════════════════════════════════════════════
class QuantizedKVCache:
    """
    KV cache with on-the-fly quantization to a target bit-width.

    Internally stores:
        k_quant : uint8 array (INT4 packs two values per byte)
        k_scale : float32 per-(batch, head, token) scale factor
        k_zp    : float32 per-(batch, head, token) zero point (token min)
        (and the same three arrays for V)

    Supports INT8 (bits=8) and INT4 (bits=4, stored 2-per-byte).
    Dequantization happens in get_kv(), so attention runs in FP32.
    """
    def __init__(
        self,
        batch_size: int,
        max_seq_len: int,
        num_heads: int,
        head_dim: int,
        bits: int = 8,
    ):
        assert bits in (4, 8), "Only INT8 and INT4 are supported"
        self.B = batch_size
        self.S_max = max_seq_len
        self.H = num_heads
        self.D = head_dim
        self.bits = bits
        # Per-token scale factors and zero points: (B, H, S_max)
        stat_shape = (batch_size, num_heads, max_seq_len)
        self.k_scale = np.zeros(stat_shape, dtype=np.float32)
        self.v_scale = np.zeros(stat_shape, dtype=np.float32)
        self.k_zp = np.zeros(stat_shape, dtype=np.float32)
        self.v_zp = np.zeros(stat_shape, dtype=np.float32)
        if bits == 8:
            payload_dim = head_dim
        else:
            # INT4: pack 2 values per byte → head_dim / 2 bytes per token
            assert head_dim % 2 == 0, "head_dim must be even for INT4 packing"
            payload_dim = head_dim // 2
        self.k_quant = np.zeros(
            (batch_size, num_heads, max_seq_len, payload_dim), dtype=np.uint8
        )
        self.v_quant = np.zeros(
            (batch_size, num_heads, max_seq_len, payload_dim), dtype=np.uint8
        )
        # Tokens written so far, per batch element
        self.seq_lens: List[int] = [0] * batch_size

    # ── quantization helpers ─────────────────────────────────────────
    def _quantize_token(
        self, vec: np.ndarray
    ) -> Tuple[np.ndarray, np.float32, np.float32]:
        """
        Quantize a 1-D float vector to unsigned integers.

        Returns (quantized_uint8, scale, zero_point) — the zero point is
        the vector minimum, so min maps to 0 and max to 2**bits - 1.
        (BUG FIX: the return annotation previously declared a 2-tuple
        although three values are returned.)
        """
        vmin = np.min(vec)
        vmax = np.max(vec)
        max_int = (1 << self.bits) - 1  # 255 for INT8, 15 for INT4
        scale = (vmax - vmin) / max_int
        zero_point = vmin  # shift so min maps to 0
        # +1e-8 avoids division by zero for constant vectors (scale == 0)
        quantized = np.clip(
            np.round((vec - zero_point) / (scale + 1e-8)), 0, max_int
        ).astype(np.uint8)
        return quantized, np.float32(scale), np.float32(zero_point)

    def _pack_int4(self, vec: np.ndarray) -> np.ndarray:
        """Pack a uint8 vector of 0..15 values into nibbles (2 per byte)."""
        # Vectorized: even indices → high nibble, odd indices → low nibble.
        return ((vec[0::2] << 4) | vec[1::2]).astype(np.uint8)

    def _unpack_int4(self, packed: np.ndarray) -> np.ndarray:
        """Unpack nibbles back to a full uint8 vector."""
        out = np.empty(len(packed) * 2, dtype=np.uint8)
        out[0::2] = (packed >> 4) & 0x0F
        out[1::2] = packed & 0x0F
        return out

    # ── dequantize for attention ─────────────────────────────────────
    def _dequantize_token(
        self, quant: np.ndarray, scale: np.float32, zero_point: np.float32
    ) -> np.ndarray:
        """Invert _quantize_token (lossy): integer codes → float32."""
        if self.bits == 4:
            unpacked = self._unpack_int4(quant)
        else:
            unpacked = quant
        # the same +1e-8 used when quantizing keeps the round trip consistent
        return unpacked.astype(np.float32) * (scale + 1e-8) + zero_point

    # ── update ───────────────────────────────────────────────────────
    def update(
        self,
        new_k: np.ndarray,
        new_v: np.ndarray,
    ) -> None:
        """
        Quantize and append new K/V tokens for every batch element.

        new_k, new_v : (B, H, T, D) float32
        """
        T = new_k.shape[2]
        for b in range(self.B):
            pos = self.seq_lens[b]
            for h in range(self.H):
                for t in range(T):
                    k_q, k_s, k_z = self._quantize_token(new_k[b, h, t, :])
                    v_q, v_s, v_z = self._quantize_token(new_v[b, h, t, :])
                    self.k_scale[b, h, pos + t] = k_s
                    self.v_scale[b, h, pos + t] = v_s
                    self.k_zp[b, h, pos + t] = k_z
                    self.v_zp[b, h, pos + t] = v_z
                    if self.bits == 4:
                        k_q = self._pack_int4(k_q)
                        v_q = self._pack_int4(v_q)
                    self.k_quant[b, h, pos + t, :] = k_q
                    self.v_quant[b, h, pos + t, :] = v_q
            self.seq_lens[b] += T

    # ── retrieval ────────────────────────────────────────────────────
    def get_kv(self, batch_idx: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Dequantize and return (H, S_valid, D) float32 arrays for K and V.
        """
        S = self.seq_lens[batch_idx]
        k_out = np.zeros((self.H, S, self.D), dtype=np.float32)
        v_out = np.zeros((self.H, S, self.D), dtype=np.float32)
        for h in range(self.H):
            for t in range(S):
                # The stored row is already packed for INT4;
                # _dequantize_token unpacks it.  (The previous per-bits
                # if/else here had identical branches — dead code.)
                k_out[h, t, :] = self._dequantize_token(
                    self.k_quant[batch_idx, h, t, :],
                    self.k_scale[batch_idx, h, t],
                    self.k_zp[batch_idx, h, t],
                )
                v_out[h, t, :] = self._dequantize_token(
                    self.v_quant[batch_idx, h, t, :],
                    self.v_scale[batch_idx, h, t],
                    self.v_zp[batch_idx, h, t],
                )
        return k_out, v_out

    # ── memory savings ───────────────────────────────────────────────
    def memory_bytes(self) -> int:
        """Bytes used by the quantized payload plus per-token metadata."""
        return (
            self.k_quant.nbytes + self.v_quant.nbytes
            + self.k_scale.nbytes + self.v_scale.nbytes
            + self.k_zp.nbytes + self.v_zp.nbytes
        )

    def savings_vs_fp32(self) -> float:
        """Ratio of this cache's memory to an equivalent FP32 cache."""
        fp32_bytes = (
            2 * self.B * self.H * self.S_max * self.D * 4  # 2 arrays × 4 bytes
        )
        return self.memory_bytes() / fp32_bytes

    def __repr__(self) -> str:
        return (
            f"QuantizedKVCache(INT{self.bits}, B={self.B}, H={self.H}, "
            f"S_max={self.S_max}, D={self.D}, "
            f"mem={self.memory_bytes() / 1e6:.1f} MB, "
            f"savings={self.savings_vs_fp32():.2f}x vs FP32)"
        )
File diff suppressed because one or more lines are too long
+429
View File
@@ -0,0 +1,429 @@
"""
End-to-end tests and demonstrations for the KV-cache system.
Run with: python test_kv_cache.py
"""
import numpy as np
from kv_cache import (
KVCache,
multi_head_attention_with_cache,
memory_growth_table,
memory_analysis,
IncrementalDecoder,
)
from optimizations import PagedKVCache, QuantizedKVCache
# ══════════════════════════════════════════════════════════════════════
# TEST 1: Basic KV-cache update & retrieval
# ══════════════════════════════════════════════════════════════════════
def test_basic_cache():
    """Exercise KVCache prefill + decode writes and verify retrieval shapes/values."""
    print("=" * 70)
    print("TEST 1: Basic KV-cache update and retrieval")
    print("=" * 70)
    B, H, S_max, D = 2, 4, 16, 8
    cache = KVCache(B, S_max, H, D)
    print(f"Initial: {cache}")
    # Prefill: write 5 tokens for batch 0, 3 tokens for batch 1
    # (In practice, the full batch gets the same number, but we test
    # the update logic by writing per-batch via positions)
    new_k = np.random.randn(B, H, 5, D).astype(np.float32)
    new_v = np.random.randn(B, H, 5, D).astype(np.float32)
    cache.update(new_k, new_v)
    print(f"After prefill (5 tokens): seq_lens={cache.seq_lens}")
    # Decode: write 1 token at a time
    for step in range(3):
        one_k = np.random.randn(B, H, 1, D).astype(np.float32)
        one_v = np.random.randn(B, H, 1, D).astype(np.float32)
        cache.update(one_k, one_v)
        print(f"  Decode step {step}: seq_lens={cache.seq_lens}")
    # Verify retrieval: 5 prefill + 3 decode = 8 cached positions
    k0, v0 = cache.get_kv(0)
    print(f"\nBatch 0: retrieved K shape={k0.shape}, expected (4, 8, 8)")
    assert k0.shape == (H, 8, D), f"Wrong shape: {k0.shape}"
    k1, v1 = cache.get_kv(1)
    print(f"Batch 1: retrieved K shape={k1.shape}, expected (4, 8, 8)")
    assert k1.shape == (H, 8, D), f"Wrong shape: {k1.shape}"
    # Verify the written values match: one_k/one_v still hold the final
    # decode step's data, which landed at position 7 (= 8 - 1)
    np.testing.assert_allclose(cache.k_cache[0, :, 7, :], one_k[0, :, 0, :])
    np.testing.assert_allclose(cache.v_cache[1, :, 7, :], one_v[1, :, 0, :])
    print("✓ All assertions passed.\n")
# ══════════════════════════════════════════════════════════════════════
# TEST 2: Attention with cache vs without (correctness check)
# ══════════════════════════════════════════════════════════════════════
def test_attention_correctness():
    """Verify cached decode attention equals full recomputation for the last token."""
    print("=" * 70)
    print("TEST 2: Cached attention matches non-cached attention")
    print("=" * 70)
    np.random.seed(42)
    B, H, D = 1, 2, 4
    d_model = H * D
    S = 6  # sequence length
    T = 1  # decode step
    # Random projection matrices
    w_q = np.random.randn(d_model, d_model).astype(np.float32)
    w_k = np.random.randn(d_model, d_model).astype(np.float32)
    w_v = np.random.randn(d_model, d_model).astype(np.float32)
    w_o = np.random.randn(d_model, d_model).astype(np.float32)
    # Simulate embeddings for S+T tokens
    all_tokens = np.random.randn(B, S + T, d_model).astype(np.float32)
    # --- METHOD A: Non-cached (full recomputation) ---
    # (_softmax is imported alongside but only used indirectly via
    # _scaled_dot_product_attention)
    from kv_cache import _scaled_dot_product_attention, _softmax
    q_full = (all_tokens @ w_q).reshape(B, S + T, H, D)
    k_full = (all_tokens @ w_k).reshape(B, S + T, H, D)
    v_full = (all_tokens @ w_v).reshape(B, S + T, H, D)
    # Compute attention for the LAST position only (autoregressive)
    out_heads_a = np.empty((T, H, D), dtype=np.float32)
    for h in range(H):
        q_h = q_full[0, S:, h, :]  # (1, D)
        k_h = k_full[0, :, h, :]  # (S+T, D)
        v_h = v_full[0, :, h, :]  # (S+T, D)
        out_heads_a[:, h, :] = _scaled_dot_product_attention(q_h, k_h, v_h)
    result_a = out_heads_a.reshape(T, d_model) @ w_o
    # --- METHOD B: Cached (prefill S tokens, then decode 1) ---
    cache = KVCache(B, S + T, H, D)
    # Prefill: write K, V for first S tokens
    k_prefill = k_full[:, :S, :, :].transpose(0, 2, 1, 3)  # (B, H, S, D)
    v_prefill = v_full[:, :S, :, :].transpose(0, 2, 1, 3)
    cache.update(k_prefill, v_prefill)
    # Decode: write K, V for the new token
    k_decode = k_full[:, S:, :, :].transpose(0, 2, 1, 3)  # (B, H, 1, D)
    v_decode = v_full[:, S:, :, :].transpose(0, 2, 1, 3)
    cache.update(k_decode, v_decode)
    # Now compute attention for the new token using the cache
    q_new = all_tokens[:, S:, :]  # (B, 1, d_model)
    result_b = multi_head_attention_with_cache(q_new, cache, w_q, w_k, w_v, w_o)
    # Both paths must agree to float tolerance
    np.testing.assert_allclose(result_a, result_b[0], atol=1e-5)
    print(f"Non-cached output: {result_a.flatten()[:4]}")
    print(f"Cached output: {result_b.flatten()[:4]}")
    print("✓ Cached and non-cached outputs match.\n")
# ══════════════════════════════════════════════════════════════════════
# TEST 3: Multi-batch with variable sequence lengths
# ══════════════════════════════════════════════════════════════════════
def test_variable_seq_lens():
    """Check per-batch variable prompt lengths survive prefill + joint decode."""
    print("=" * 70)
    print("TEST 3: Multi-batch with variable sequence lengths")
    print("=" * 70)
    np.random.seed(123)
    B, H, D = 3, 4, 8
    S_max = 32
    cache = KVCache(B, S_max, H, D)
    # --- Prefill each batch element with a different prompt length ---
    # We bypass the batched update() and write each element directly
    # into the underlying cache arrays.  This simulates the real
    # scenario where different requests arrive with different prompt
    # lengths and are packed into the same batch.
    prompt_lens = [5, 12, 3]
    original_k = {}
    original_v = {}
    for b in range(B):
        L = prompt_lens[b]
        k = np.random.randn(H, L, D).astype(np.float32)
        v = np.random.randn(H, L, D).astype(np.float32)
        cache.k_cache[b, :, :L, :] = k
        cache.v_cache[b, :, :L, :] = v
        cache.seq_lens[b] = L
        original_k[b] = k
        original_v[b] = v
    print(f"After prefill: seq_lens={cache.seq_lens}")
    assert cache.seq_lens == prompt_lens
    # --- Verify prefill retrieval ---
    for b in range(B):
        k_ret, v_ret = cache.get_kv(b)
        np.testing.assert_allclose(k_ret, original_k[b])
        np.testing.assert_allclose(v_ret, original_v[b])
        print(f"  Batch {b}: ✓ prefill data verified (len={prompt_lens[b]})")
    # --- Decode: all batch elements advance together (normal decode) ---
    for step in range(4):
        one_k = np.random.randn(B, H, 1, D).astype(np.float32)
        one_v = np.random.randn(B, H, 1, D).astype(np.float32)
        cache.update(one_k, one_v)
        print(f"  Decode step {step}: seq_lens={cache.seq_lens}")
    # Verify each batch element has the right length (prompt + 4 decodes)
    expected = [l + 4 for l in prompt_lens]
    for b in range(B):
        k_b, v_b = cache.get_kv(b)
        print(f"  Batch {b}: expected len={expected[b]}, got K shape seq dim={k_b.shape[1]}")
        assert k_b.shape[1] == expected[b]
    print("✓ Variable sequence lengths handled correctly.\n")
# ══════════════════════════════════════════════════════════════════════
# TEST 4: Incremental decoder end-to-end
# ══════════════════════════════════════════════════════════════════════
def test_incremental_decoder():
    """End-to-end: prefill a prompt, then greedily decode 5 tokens via the caches."""
    print("=" * 70)
    print("TEST 4: Incremental decoder (prefill + autoregressive decode)")
    print("=" * 70)
    np.random.seed(7)
    d_model = 32
    num_heads = 4
    num_layers = 2
    max_seq_len = 64
    vocab_size = 100
    B = 1
    decoder = IncrementalDecoder(d_model, num_heads, num_layers, max_seq_len, vocab_size)
    # Ensure the attribute _init_caches() reads is present on the instance
    decoder.max_seq_len = max_seq_len
    decoder._init_caches(B)
    # Prefill with a prompt of 8 tokens
    prompt = np.array([[1, 5, 10, 15, 20, 25, 30, 35]], dtype=np.int64)  # (1, 8)
    logits = decoder.forward_step(prompt, decoder.caches, is_prefill=True)
    print(f"After prefill (8 tokens):")
    print(f"  Logits shape: {logits.shape}")
    print(f"  Cache seq_lens: {[c.seq_lens for c in decoder.caches]}")
    # Autoregressive decode: generate 5 more tokens (greedy argmax)
    generated = []
    next_token = logits.argmax(axis=-1)  # (1,)
    generated.append(next_token[0])
    for step in range(5):
        logits = decoder.forward_step(next_token, decoder.caches)
        next_token = logits.argmax(axis=-1)
        generated.append(next_token[0])
        print(
            f"  Decode step {step}: seq_lens={decoder.caches[0].seq_lens}, "
            f"token={next_token[0]}"
        )
    # 8 prompt tokens + 5 decode steps each append one position
    assert decoder.caches[0].seq_lens[0] == 8 + 5, "Should have 13 tokens cached"
    print(f"Generated tokens: {generated}")
    print("✓ Incremental decoder works.\n")
# ══════════════════════════════════════════════════════════════════════
# TEST 5: Paged KV-cache
# ══════════════════════════════════════════════════════════════════════
def test_paged_cache():
    """Exercise PagedKVCache: allocate, write, gather, free, and reuse blocks."""
    print("=" * 70)
    print("TEST 5: Paged KV-cache (block-based allocation)")
    print("=" * 70)
    np.random.seed(99)
    num_blocks = 20
    block_size = 4
    H, D = 4, 8
    max_seqs = 4
    paged = PagedKVCache(num_blocks, block_size, H, D, max_seqs)
    print(f"Initial: {paged}")
    # Start 3 sequences with different lengths
    seq_ids = []
    for _ in range(3):
        sid = paged.add_sequence()
        seq_ids.append(sid)
    # Write different amounts to each (11 spans 3 blocks of size 4)
    lengths = [6, 11, 3]
    original_data_k = {}
    original_data_v = {}
    for i, sid in enumerate(seq_ids):
        L = lengths[i]
        k = np.random.randn(H, L, D).astype(np.float32)
        v = np.random.randn(H, L, D).astype(np.float32)
        paged.update(sid, k, v)
        original_data_k[sid] = k
        original_data_v[sid] = v
        print(f"  Seq {sid}: wrote {L} tokens, seq_len={paged.seq_lens[sid]}")
    print(f"After writes: {paged}")
    # Verify retrieval round-trips the scattered blocks exactly
    for i, sid in enumerate(seq_ids):
        k_ret, v_ret = paged.get_kv(sid)
        L = lengths[i]
        assert k_ret.shape == (H, L, D), f"Seq {sid}: expected ({H}, {L}, {D}), got {k_ret.shape}"
        np.testing.assert_allclose(k_ret, original_data_k[sid], atol=1e-6)
        np.testing.assert_allclose(v_ret, original_data_v[sid], atol=1e-6)
        print(f"  Seq {sid}: ✓ retrieved data matches original")
    # Finish sequence 1 and verify blocks are freed
    paged.finish_sequence(seq_ids[1])
    print(f"After finishing seq {seq_ids[1]}: {paged}")
    # Allocate a new sequence — should reuse freed blocks
    new_sid = paged.add_sequence()
    k_new = np.random.randn(H, 8, D).astype(np.float32)
    v_new = np.random.randn(H, 8, D).astype(np.float32)
    paged.update(new_sid, k_new, v_new)
    print(f"New seq {new_sid} with 8 tokens: {paged}")
    # Verify new sequence data
    k_new_ret, v_new_ret = paged.get_kv(new_sid)
    np.testing.assert_allclose(k_new_ret, k_new, atol=1e-6)
    print("✓ Paged KV-cache works correctly.\n")
# ══════════════════════════════════════════════════════════════════════
# TEST 6: Quantized KV-cache
# ══════════════════════════════════════════════════════════════════════
def test_quantized_cache():
    """Round-trip K/V through INT8 and INT4 quantization; bound the error."""
    print("=" * 70)
    print("TEST 6: Quantized KV-cache (INT8 and INT4)")
    print("=" * 70)
    np.random.seed(42)
    B, H, D, S_max = 1, 2, 8, 32
    for bits in [8, 4]:
        print(f"\n--- INT{bits} ---")
        qcache = QuantizedKVCache(B, S_max, H, D, bits=bits)
        print(f"  {qcache}")
        # Write some tokens (×2 widens the value range being quantized)
        T = 10
        k_orig = np.random.randn(B, H, T, D).astype(np.float32) * 2
        v_orig = np.random.randn(B, H, T, D).astype(np.float32) * 2
        qcache.update(k_orig, v_orig)
        # Retrieve and measure error
        k_ret, v_ret = qcache.get_kv(0)
        assert k_ret.shape == (H, T, D)
        k_error = np.mean(np.abs(k_ret - k_orig[0]))
        v_error = np.mean(np.abs(v_ret - v_orig[0]))
        print(f"  Mean absolute error (K): {k_error:.6f}")
        print(f"  Mean absolute error (V): {v_error:.6f}")
        print(f"  Memory savings vs FP32: {qcache.savings_vs_fp32():.3f}x")
        print(f"  Actual memory: {qcache.memory_bytes() / 1e3:.1f} KB")
        # For INT8, error should be small; for INT4, larger but bounded
        # Scale factor ≈ (max-min) / 255 for INT8, so error ≈ scale/2 per element
        max_expected_error = {8: 0.1, 4: 0.5}
        assert k_error < max_expected_error[bits], f"INT{bits} quantization error too large: {k_error}"
    print("\n✓ Quantized cache works.\n")
# ══════════════════════════════════════════════════════════════════════
# TEST 7: Memory growth analysis
# ══════════════════════════════════════════════════════════════════════
def test_memory_analysis():
    """Print KV-cache memory-growth tables for two model classes plus
    batch-size scaling; purely informational (no assertions)."""
    rule = "=" * 70
    print(rule)
    print("TEST 7: Memory growth analysis")
    print(rule)
    # GPT-4 class model: 32 layers, 32 heads, dim 128 (table defaults).
    print("\nKV-Cache Memory vs Sequence Length (GPT-4-class model)")
    print("Model: 32 layers, 32 heads, head_dim=128, batch=1, FP32")
    print(memory_growth_table())
    # Llama-2 70B class.
    print("\nKV-Cache Memory vs Sequence Length (Llama-2 70B class)")
    print("Model: 80 layers, 64 heads, head_dim=128, batch=1, FP32")
    print(memory_growth_table(num_layers=80, num_heads=64, head_dim=128))
    # How memory scales with batch size at a fixed context length.
    print("\nMemory scaling with batch size (seq_len=4096):")
    print(f"{'Batch':>8} | {'Total (GB)':>12}")
    print("-" * 28)
    batch_sizes = [1, 2, 4, 8, 16, 32, 64]
    for batch in batch_sizes:
        stats = memory_analysis(32, 32, 128, batch, 4096)
        print(f"{batch:>8} | {stats['total_GB']:>12.3f}")
    print()
# ══════════════════════════════════════════════════════════════════════
# TEST 8: FLOPs comparison — cached vs uncached
# ══════════════════════════════════════════════════════════════════════
def test_flops_analysis():
    """Compare attention FLOPs for autoregressive decoding with and
    without a KV-cache, and print the resulting speedup.

    Per-layer attention cost for a context of length S (shared projection
    terms that both variants pay are ignored):
      * no cache : each decode step re-runs attention over the whole
                   context, ≈ 4 * S² * d_model FLOPs
      * cached   : only the new token's query attends to S cached keys,
                   ≈ 4 * S * d_model FLOPs
    """
    print("=" * 70)
    print("TEST 8: FLOPs saved by KV-caching")
    print("=" * 70)
    d_model = 4096
    H = 32
    prompt_len = 1024
    decode_steps = 100
    # Without cache: decode step t recomputes attention for ALL
    # prompt_len + t positions.  Sum the per-step cost exactly rather than
    # charging every step at the final length — the previous estimate
    # (4 * steps * (prompt + steps)² * d_model) overstated the no-cache
    # cost, and therefore the speedup, relative to the cached sum below.
    flops_no_cache = sum(
        4 * (prompt_len + t) ** 2 * d_model for t in range(decode_steps)
    )
    flops_cached = (
        # Prefill: one full pass over the prompt, O(S² * d_model).
        4 * prompt_len**2 * d_model
        # Decode: O(S * d_model) per step — one query vs S cached keys.
        + sum(4 * (prompt_len + t) * d_model for t in range(decode_steps))
    )
    # Caching can only remove work, never add it.
    assert flops_cached < flops_no_cache
    print(f"Model d_model={d_model}, H={H}, prompt={prompt_len}, decode={decode_steps}")
    print(f"  Without cache: {flops_no_cache:.3e} FLOPs")
    print(f"  With cache:    {flops_cached:.3e} FLOPs")
    print(f"  Speedup:       {flops_no_cache / flops_cached:.1f}x")
    print()
# ══════════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    # Run the full suite in order; any failure raises and aborts the run.
    suite = (
        test_basic_cache,
        test_attention_correctness,
        test_variable_seq_lens,
        test_incremental_decoder,
        test_paged_cache,
        test_quantized_cache,
        test_memory_analysis,
        test_flops_analysis,
    )
    for run_test in suite:
        run_test()
    rule = "=" * 70
    print(rule)
    print("ALL TESTS PASSED ✓")
    print(rule)