""" GPU Execution Mapping for KV-Cache Systems Documents how the KV-cache system maps to GPU hardware: - Memory hierarchy (registers, shared mem, global mem, HBM) - Kernel design for attention with cache - CUDA optimization strategies - Tensor Core utilization """ import numpy as np from typing import Dict, List # ============================================================================= # GPU MEMORY HIERARCHY REFERENCE # ============================================================================= GPU_HIERARCHY = { "registers": { "size_per_sm": "64 KB", "latency": "1 cycle", "usage": "Thread-local variables, warp-level computation", }, "shared_memory": { "size_per_sm": "166 KB (H100)", "latency": "1-3 cycles", "usage": "Tiling, cooperative loading, softmax intermediate", }, "l2_cache": { "size": "50 MB (H100)", "latency": "~20 cycles", "usage": "Automatic caching of global memory accesses", }, "hbm": { "size": "80 GB (H100)", "bandwidth": "3.35 TB/s (H100)", "latency": "~300-400 cycles", "usage": "Model weights, KV cache, activations", }, } # ============================================================================= # KERNEL DESIGN: CACHED ATTENTION # ============================================================================= def describe_cached_attention_kernel(): """ Describe the CUDA kernel for cached attention. Kernel: cached_attention<<>>(Q, K_cache, V_cache, Out, ...) Thread block organization: - Each block handles one (batch, head) pair - Threads within a block cooperate on the matmul Q @ K^T Memory access pattern: 1. Load Q tile into shared memory (small: 1 x head_dim) 2. Stream K_cache tiles from global memory into shared memory 3. Compute partial dot products in registers 4. Accumulate scores in shared memory 5. Softmax in shared memory 6. Stream V_cache tiles and compute output """ description = { "kernel_name": "cached_attention", "grid": "(batch_size, num_heads, 1)", "block": "(BLOCK_X, BLOCK_Y) — e.g., (32, 32) for 1024 threads", "shared_memory_usage": { "q_tile": "1 x head_dim (e.g., 1 x 128 = 128 floats = 512 bytes fp16)", "k_tile": "BLOCK_Y x head_dim (e.g., 32 x 128 = 4096 floats = 8 KB fp16)", "v_tile": "BLOCK_Y x head_dim (same as K)", "score_tile": "BLOCK_X x BLOCK_Y (e.g., 32 x 32 = 1024 floats = 4 KB fp16)", "total_shared_per_block": "~16-20 KB (fits in 166 KB SM)", }, "global_memory_accesses": { "read_q": "batch * heads * 1 * head_dim (tiny)", "read_k_cache": "batch * heads * seq_len * head_dim (dominant)", "read_v_cache": "batch * heads * seq_len * head_dim (dominant)", "write_output": "batch * heads * 1 * head_dim (tiny)", }, "optimization_strategies": [ "1. Coalesced global memory access: threads in a warp access consecutive addresses", "2. Tiled GEMM: process K/V in tiles that fit in shared memory", "3. Persistent kernels: keep blocks alive until all tiles processed", "4. Async copy (H100): use cp.async to overlap memory transfer with computation", "5. Tensor Cores: use WMMA or mma.sync for the matmul operations", "6. Fusion: fuse softmax with attention score computation", ], } return description # ============================================================================= # TENSOR CORE UTILIZATION # ============================================================================= def tensor_core_analysis(head_dim: int = 128, seq_len: int = 4096, batch: int = 4, heads: int = 32) -> Dict: """ Analyze Tensor Core utilization for cached attention. H100 Tensor Core specs (FP16): - MMA shape: M x N x K where M,N,K are multiples of 16 - Peak throughput: ~1,970 TFLOPS (FP16 Tensor Core) - Each MMA instruction: 16x16x16 = 4096 FLOPs """ # Q @ K^T: (batch, heads, 1, head_dim) @ (batch, heads, head_dim, seq_len) # FLOPs per (batch, head): 2 * 1 * head_dim * seq_len flops_qk = 2 * batch * heads * 1 * head_dim * seq_len # Attn @ V: (batch, heads, 1, seq_len) @ (batch, heads, seq_len, head_dim) flops_av = 2 * batch * heads * 1 * seq_len * head_dim total_flops = flops_qk + flops_av # Memory traffic elem_bytes = 2 # fp16 mem_q = batch * heads * 1 * head_dim * elem_bytes mem_k = batch * heads * seq_len * head_dim * elem_bytes mem_v = batch * heads * seq_len * head_dim * elem_bytes mem_out = batch * heads * 1 * head_dim * elem_bytes total_mem = mem_q + mem_k + mem_v + mem_out # Arithmetic intensity (FLOPs per byte) intensity = total_flops / total_mem # H100 peak h100_peak_tflops = 1970 # FP16 Tensor Core h100_bandwidth = 3.35e12 # bytes/s # Theoretical time bounds compute_bound_s = total_flops / (h100_peak_tflops * 1e12) memory_bound_s = total_mem / h100_bandwidth return { "flops_qk": f"{flops_qk / 1e9:.2f} GFLOPs", "flops_av": f"{flops_av / 1e9:.2f} GFLOPs", "total_flops": f"{total_flops / 1e9:.2f} GFLOPs", "memory_traffic_mb": f"{total_mem / 1e6:.2f} MB", "arithmetic_intensity": f"{intensity:.2f} FLOPs/byte", "compute_bound_ms": f"{compute_bound_s * 1000:.4f} ms", "memory_bound_ms": f"{memory_bound_s * 1000:.4f} ms", "bound": "compute-bound" if compute_bound_s > memory_bound_s else "memory-bound", "h100_peak_tflops": h100_peak_tflops, "h100_bandwidth_tbps": h100_bandwidth / 1e12, } # ============================================================================= # GPU EXECUTION PIPELINE # ============================================================================= def describe_execution_pipeline(): """ Describe the full GPU execution pipeline for one generation step. Step 1: Embedding lookup - Input: token_id (batch, 1) - Operation: embedding[token_id] -> (batch, 1, dim) - GPU: Gathers from embedding table (random access, use shared mem tiling) Step 2: Positional encoding - Operation: x += pos_encoding[current_pos] - GPU: Simple element-wise add (fully parallel) Step 3: Per-layer forward pass (repeated L times) 3a. LayerNorm - GPU: Parallel reduction for mean/var, then element-wise 3b. QKV projection - GPU: 3 parallel GEMMs: x @ Wq, x @ Wk, x @ Wv - cuBLAS/cutlass: highly optimized for small M (M=1) 3c. KV cache update - GPU: Simple copy to global memory (coalesced write) - cache_k[:, :, write_pos, :] = k[:, :, 0, :] 3d. Cached attention - GPU: Custom kernel (see describe_cached_attention_kernel) - Two GEMMs + softmax, tiled for shared memory 3e. Output projection - GPU: GEMM: attn_out @ Wo 3f. MLP - GPU: Two GEMMs with activation fusion 3g. Residual add + LayerNorm - GPU: Element-wise operations Step 4: LM head - GPU: GEMM: x @ W_lm -> logits (batch, vocab_size) Step 5: Sampling - GPU: Argmax or top-k sampling kernel - Can be done on CPU for small batch sizes """ return { "steps": [ "1. Embedding lookup (gather)", "2. Positional encoding (element-wise add)", "3. Per-layer: LayerNorm -> QKV proj -> cache update -> attention -> MLP", "4. LM head (GEMM)", "5. Sampling (argmax/top-k)", ], "bottleneck": "Cached attention (memory-bound for long sequences)", "optimization_opportunities": [ "Operator fusion: merge LayerNorm + GEMM bias + activation", "Batched GEMM: process all layers' small GEMMs together", "Pipeline parallelism: overlap layers' computation", "FlashAttention-style tiling for the cached attention kernel", "Warp-specialized design: some warps load, some compute", ], } # ============================================================================= # FLASH-ATTENTION-STYLE CACHED KERNEL # ============================================================================= def describe_flash_attention_cached(): """ FlashAttention-style kernel adapted for cached attention. Key insight: instead of materializing the full (1 x seq_len) attention matrix, process K/V in tiles and accumulate softmax online. Algorithm (for one batch/head): 1. Initialize: output = 0, m = -inf, l = 0 (online softmax state) 2. For each K/V tile (size BLOCK): a. Compute S = Q @ K_tile^T (in shared memory) b. m_new = max(m, max(S)) c. l = l * exp(m - m_new) + sum(exp(S - m_new)) d. output = output * (l_old / l) + sum(exp(S - m_new) * V_tile) e. m = m_new 3. output = output / l This avoids materializing the full attention matrix and reduces HBM traffic from O(seq_len * head_dim) to O(seq_len * head_dim / BLOCK). """ return { "name": "FlashAttention-style cached kernel", "key_benefit": "O(1) shared memory usage regardless of sequence length", "hbm_traffic_reduction": "Reduces from 4 reads to ~2 reads of K/V cache", "shared_memory": "Only needs BLOCK x head_dim tiles, not full seq_len", "complexity": "More complex kernel but 2-4x faster for long sequences", "implementation_notes": [ "Requires careful numerical stability (online softmax)", "Two-pass: forward pass accumulates, backward pass needs recompute", "For generation (single query), simpler than full FlashAttention", "Can use mma.sync for the tile GEMMs on H100", ], } # ============================================================================= # MULTI-GPU STRATEGIES # ============================================================================= def describe_multi_gpu(): """ Multi-GPU strategies for large models with KV cache. """ return { "tensor_parallelism": { "description": "Split model weights across GPUs (Megatron-LM style)", "kv_cache_impact": "Each GPU holds its shard of K/V (split by head_dim)", "communication": "AllReduce in MLP, all-to-all in attention", "scaling": "Linear with num GPUs (up to num_heads)", }, "pipeline_parallelism": { "description": "Split layers across GPUs", "kv_cache_impact": "Each GPU holds K/V for its layer shard", "communication": "Send activations between stages", "challenge": "Bubble idle time; needs micro-batching", }, "sequence_parallelism": { "description": "Split sequence across GPUs (for prefill)", "kv_cache_impact": "Each GPU holds K/V for its sequence shard", "communication": "All-to-all for attention across sequence shards", "best_for": "Very long context prefill", }, "expert_parallelism": { "description": "For MoE models (Mixtral, Grok)", "kv_cache_impact": "KV cache is shared; only MLP experts are sharded", "communication": "All-to-all for expert routing", }, } # ============================================================================= # PRACTICAL GPU TUNING GUIDE # ============================================================================= def gpu_tuning_guide(): """ Practical GPU tuning recommendations for KV-cache inference. """ return { "streaming_KV_cache": { "problem": "For long sequences, K/V cache reads dominate latency", "solution": "Use H100's copy engine (async copy) to stream tiles", "detail": "Overlap K/V loading with Q projection computation", }, "small_batch_optimization": { "problem": "Single-token generation has tiny GEMMs (M=1)", "solution": "Use CUTLASS tiny GEMM kernels or custom kernels", "detail": "Standard cuBLAS is not optimized for M=1; use flashinfer or turbotransformers", }, "continuous_batching": { "problem": "Variable generation lengths waste compute", "solution": "Run sequences at different stages simultaneously", "detail": "Some sequences in prefill, others in decode; schedule on GPU", }, "kv_cache_quantization_on_gpu": { "problem": "Dequantization adds latency", "solution": "Use INT8 Tensor Cores (H100 supports INT8 MMA)", "detail": "Keep K/V in INT8, dequantize during the MMA instruction", }, "cuda_graphs": { "problem": "Kernel launch overhead for small operations", "solution": "Record and replay CUDA graphs", "detail": "For fixed-shape generation, graphs eliminate launch overhead", }, } # ============================================================================= # PRINT GPU MAPPING REPORT # ============================================================================= def print_gpu_report(): """Print comprehensive GPU execution mapping report.""" print("=" * 80) print("GPU EXECUTION MAPPING FOR KV-CACHE SYSTEM") print("=" * 80) # Memory hierarchy print("\n--- GPU Memory Hierarchy ---\n") for level, info in GPU_HIERARCHY.items(): print(f" {level:>15}:") for k, v in info.items(): print(f" {k}: {v}") # Kernel design print("\n\n--- Cached Attention Kernel Design ---\n") kernel = describe_cached_attention_kernel() print(f" Kernel: {kernel['kernel_name']}") print(f" Grid: {kernel['grid']}") print(f" Block: {kernel['block']}") print("\n Shared Memory Usage:") for k, v in kernel["shared_memory_usage"].items(): if k != "total_shared_per_block": print(f" {k}: {v}") print(f" {list(kernel['shared_memory_usage'].keys())[-1]}: " f"{list(kernel['shared_memory_usage'].values())[-1]}") print("\n Optimization Strategies:") for s in kernel["optimization_strategies"]: print(f" {s}") # Tensor core analysis print("\n\n--- Tensor Core Utilization (batch=4, heads=32, seq=4096) ---\n") tc = tensor_core_analysis(batch=4, heads=32, seq_len=4096) for k, v in tc.items(): print(f" {k}: {v}") # Execution pipeline print("\n\n--- Execution Pipeline ---\n") pipeline = describe_execution_pipeline() for i, step in enumerate(pipeline["steps"], 1): print(f" {step}") print(f"\n Bottleneck: {pipeline['bottleneck']}") print("\n Optimization Opportunities:") for opt in pipeline["optimization_opportunities"]: print(f" - {opt}") # FlashAttention print("\n\n--- FlashAttention-Style Cached Kernel ---\n") flash = describe_flash_attention_cached() for k, v in flash.items(): if isinstance(v, list): print(f" {k}:") for item in v: print(f" - {item}") else: print(f" {k}: {v}") # Multi-GPU print("\n\n--- Multi-GPU Strategies ---\n") multi = describe_multi_gpu() for strategy, info in multi.items(): print(f" {strategy}:") for k, v in info.items(): print(f" {k}: {v}") # Tuning guide print("\n\n--- GPU Tuning Guide ---\n") tuning = gpu_tuning_guide() for area, info in tuning.items(): print(f" {area}:") print(f" Problem: {info['problem']}") print(f" Solution: {info['solution']}") print(f" Detail: {info['detail']}") print() print("=" * 80) if __name__ == "__main__": print_gpu_report()