"""
GPU Execution Mapping for KV-Cache Systems
Documents how the KV-cache system maps to GPU hardware:
- Memory hierarchy (registers, shared mem, global mem, HBM)
- Kernel design for attention with cache
- CUDA optimization strategies
- Tensor Core utilization
"""
import numpy as np
from typing import Dict, List
# =============================================================================
# GPU MEMORY HIERARCHY REFERENCE
# =============================================================================
GPU_HIERARCHY = {
"registers": {
"size_per_sm": "64 KB",
"latency": "1 cycle",
"usage": "Thread-local variables, warp-level computation",
},
"shared_memory": {
"size_per_sm": "166 KB (H100)",
"latency": "1-3 cycles",
"usage": "Tiling, cooperative loading, softmax intermediate",
},
"l2_cache": {
"size": "50 MB (H100)",
"latency": "~20 cycles",
"usage": "Automatic caching of global memory accesses",
},
"hbm": {
"size": "80 GB (H100)",
"bandwidth": "3.35 TB/s (H100)",
"latency": "~300-400 cycles",
"usage": "Model weights, KV cache, activations",
},
}
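# Illustrative helper (an assumption, not part of the kernels described below):
# estimates the shared-memory footprint of one thread block's tiles using the
# tile shapes from describe_cached_attention_kernel, so the per-SM budget quoted
# above can be sanity-checked numerically.
def estimate_shared_memory_per_block(block: int = 32, head_dim: int = 128,
                                     elem_bytes: int = 2) -> Dict:
    """Rough shared-memory budget for one (batch, head) block of cached attention."""
    q_tile = 1 * head_dim * elem_bytes        # single query row
    k_tile = block * head_dim * elem_bytes    # one K tile
    v_tile = block * head_dim * elem_bytes    # one V tile
    score_tile = block * block * elem_bytes   # partial attention scores
    total = q_tile + k_tile + v_tile + score_tile
    limit = 228 * 1024                        # H100 shared memory per SM (see above)
    return {"total_bytes": total, "total_kb": total / 1024, "fits_in_one_sm": total <= limit}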
# =============================================================================
# KERNEL DESIGN: CACHED ATTENTION
# =============================================================================
def describe_cached_attention_kernel():
"""
Describe the CUDA kernel for cached attention.
Kernel: cached_attention<<<grid, block>>>(Q, K_cache, V_cache, Out, ...)
Thread block organization:
- Each block handles one (batch, head) pair
- Threads within a block cooperate on the matmul Q @ K^T
Memory access pattern:
1. Load Q tile into shared memory (small: 1 x head_dim)
2. Stream K_cache tiles from global memory into shared memory
3. Compute partial dot products in registers
4. Accumulate scores in shared memory
5. Softmax in shared memory
6. Stream V_cache tiles and compute output
"""
description = {
"kernel_name": "cached_attention",
"grid": "(batch_size, num_heads, 1)",
"block": "(BLOCK_X, BLOCK_Y) — e.g., (32, 32) for 1024 threads",
"shared_memory_usage": {
"q_tile": "1 x head_dim (e.g., 1 x 128 = 128 floats = 512 bytes fp16)",
"k_tile": "BLOCK_Y x head_dim (e.g., 32 x 128 = 4096 floats = 8 KB fp16)",
"v_tile": "BLOCK_Y x head_dim (same as K)",
"score_tile": "BLOCK_X x BLOCK_Y (e.g., 32 x 32 = 1024 floats = 4 KB fp16)",
"total_shared_per_block": "~16-20 KB (fits in 166 KB SM)",
},
"global_memory_accesses": {
"read_q": "batch * heads * 1 * head_dim (tiny)",
"read_k_cache": "batch * heads * seq_len * head_dim (dominant)",
"read_v_cache": "batch * heads * seq_len * head_dim (dominant)",
"write_output": "batch * heads * 1 * head_dim (tiny)",
},
"optimization_strategies": [
"1. Coalesced global memory access: threads in a warp access consecutive addresses",
"2. Tiled GEMM: process K/V in tiles that fit in shared memory",
"3. Persistent kernels: keep blocks alive until all tiles processed",
"4. Async copy (H100): use cp.async to overlap memory transfer with computation",
"5. Tensor Cores: use WMMA or mma.sync for the matmul operations",
"6. Fusion: fuse softmax with attention score computation",
],
}
return description
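# Functional reference for the kernel described above (a NumPy sketch, not the
# CUDA implementation): computes single-query attention against the K/V cache
# with the shapes used throughout this file, Q (batch, heads, 1, head_dim) and
# K/V caches (batch, heads, seq_len, head_dim).
def cached_attention_reference(q: np.ndarray, k_cache: np.ndarray,
                               v_cache: np.ndarray) -> np.ndarray:
    """Naive single-query attention over the KV cache."""
    head_dim = q.shape[-1]
    # Q @ K^T over cached positions -> (batch, heads, 1, seq_len)
    scores = (q @ k_cache.transpose(0, 1, 3, 2)) / np.sqrt(head_dim)
    # Numerically stable softmax along the sequence axis
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    # Attn @ V -> (batch, heads, 1, head_dim)
    return weights @ v_cache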
# =============================================================================
# TENSOR CORE UTILIZATION
# =============================================================================
def tensor_core_analysis(head_dim: int = 128, seq_len: int = 4096,
batch: int = 4, heads: int = 32) -> Dict:
"""
Analyze Tensor Core utilization for cached attention.
H100 Tensor Core specs (FP16):
- MMA shape: M x N x K where M,N,K are multiples of 16
    - Peak throughput: ~990 TFLOPS dense FP16 (~1,979 TFLOPS with 2:4 sparsity)
    - Each 16x16x16 MMA instruction: 4,096 multiply-accumulates (~8,192 FLOPs)
"""
# Q @ K^T: (batch, heads, 1, head_dim) @ (batch, heads, head_dim, seq_len)
# FLOPs per (batch, head): 2 * 1 * head_dim * seq_len
flops_qk = 2 * batch * heads * 1 * head_dim * seq_len
# Attn @ V: (batch, heads, 1, seq_len) @ (batch, heads, seq_len, head_dim)
flops_av = 2 * batch * heads * 1 * seq_len * head_dim
total_flops = flops_qk + flops_av
# Memory traffic
elem_bytes = 2 # fp16
mem_q = batch * heads * 1 * head_dim * elem_bytes
mem_k = batch * heads * seq_len * head_dim * elem_bytes
mem_v = batch * heads * seq_len * head_dim * elem_bytes
mem_out = batch * heads * 1 * head_dim * elem_bytes
total_mem = mem_q + mem_k + mem_v + mem_out
# Arithmetic intensity (FLOPs per byte)
intensity = total_flops / total_mem
# H100 peak
    h100_peak_tflops = 990  # dense FP16 Tensor Core (sparsity-accelerated peak is ~1,979)
h100_bandwidth = 3.35e12 # bytes/s
# Theoretical time bounds
compute_bound_s = total_flops / (h100_peak_tflops * 1e12)
memory_bound_s = total_mem / h100_bandwidth
return {
"flops_qk": f"{flops_qk / 1e9:.2f} GFLOPs",
"flops_av": f"{flops_av / 1e9:.2f} GFLOPs",
"total_flops": f"{total_flops / 1e9:.2f} GFLOPs",
"memory_traffic_mb": f"{total_mem / 1e6:.2f} MB",
"arithmetic_intensity": f"{intensity:.2f} FLOPs/byte",
"compute_bound_ms": f"{compute_bound_s * 1000:.4f} ms",
"memory_bound_ms": f"{memory_bound_s * 1000:.4f} ms",
"bound": "compute-bound" if compute_bound_s > memory_bound_s else "memory-bound",
"h100_peak_tflops": h100_peak_tflops,
"h100_bandwidth_tbps": h100_bandwidth / 1e12,
}
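# Usage sketch for the analysis above: the roofline ridge point is the
# arithmetic intensity at which compute time equals memory time. The ~990
# dense-FP16 TFLOPS and 3.35 TB/s figures are the same assumptions used in
# tensor_core_analysis; at roughly 1 FLOP/byte, decode-time cached attention
# sits far below the ridge (~295 FLOPs/byte), i.e. it is memory-bound.
def roofline_ridge_point(peak_tflops: float = 990.0,
                         bandwidth_bytes_per_s: float = 3.35e12) -> float:
    """Arithmetic intensity (FLOPs/byte) where compute time equals memory time."""
    return (peak_tflops * 1e12) / bandwidth_bytes_per_s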
# =============================================================================
# GPU EXECUTION PIPELINE
# =============================================================================
def describe_execution_pipeline():
"""
Describe the full GPU execution pipeline for one generation step.
Step 1: Embedding lookup
- Input: token_id (batch, 1)
- Operation: embedding[token_id] -> (batch, 1, dim)
- GPU: Gathers from embedding table (random access, use shared mem tiling)
Step 2: Positional encoding
- Operation: x += pos_encoding[current_pos]
- GPU: Simple element-wise add (fully parallel)
Step 3: Per-layer forward pass (repeated L times)
3a. LayerNorm
- GPU: Parallel reduction for mean/var, then element-wise
3b. QKV projection
- GPU: 3 parallel GEMMs: x @ Wq, x @ Wk, x @ Wv
        - cuBLAS/CUTLASS GEMV-style paths; with M=1 these GEMMs are memory-bound on the weights
3c. KV cache update
- GPU: Simple copy to global memory (coalesced write)
- cache_k[:, :, write_pos, :] = k[:, :, 0, :]
3d. Cached attention
- GPU: Custom kernel (see describe_cached_attention_kernel)
- Two GEMMs + softmax, tiled for shared memory
3e. Output projection
- GPU: GEMM: attn_out @ Wo
3f. MLP
- GPU: Two GEMMs with activation fusion
3g. Residual add + LayerNorm
- GPU: Element-wise operations
Step 4: LM head
- GPU: GEMM: x @ W_lm -> logits (batch, vocab_size)
Step 5: Sampling
- GPU: Argmax or top-k sampling kernel
- Can be done on CPU for small batch sizes
"""
return {
"steps": [
"1. Embedding lookup (gather)",
"2. Positional encoding (element-wise add)",
"3. Per-layer: LayerNorm -> QKV proj -> cache update -> attention -> MLP",
"4. LM head (GEMM)",
"5. Sampling (argmax/top-k)",
],
"bottleneck": "Cached attention (memory-bound for long sequences)",
"optimization_opportunities": [
"Operator fusion: merge LayerNorm + GEMM bias + activation",
"Batched GEMM: process all layers' small GEMMs together",
"Pipeline parallelism: overlap layers' computation",
"FlashAttention-style tiling for the cached attention kernel",
"Warp-specialized design: some warps load, some compute",
],
}
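# Minimal sketch of step 3c (KV cache update), assuming the cache layout
# (batch, heads, max_seq_len, head_dim) used elsewhere in this file. On the GPU
# this is one coalesced write into global memory; here it is plain NumPy.
def update_kv_cache(cache_k: np.ndarray, cache_v: np.ndarray,
                    k: np.ndarray, v: np.ndarray, write_pos: int) -> None:
    """Write the new token's K/V (batch, heads, 1, head_dim) into slot write_pos."""
    cache_k[:, :, write_pos, :] = k[:, :, 0, :]
    cache_v[:, :, write_pos, :] = v[:, :, 0, :]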
# =============================================================================
# FLASH-ATTENTION-STYLE CACHED KERNEL
# =============================================================================
def describe_flash_attention_cached():
"""
FlashAttention-style kernel adapted for cached attention.
Key insight: instead of materializing the full (1 x seq_len) attention
matrix, process K/V in tiles and accumulate softmax online.
Algorithm (for one batch/head):
1. Initialize: output = 0, m = -inf, l = 0 (online softmax state)
2. For each K/V tile (size BLOCK):
a. Compute S = Q @ K_tile^T (in shared memory)
b. m_new = max(m, max(S))
c. l = l * exp(m - m_new) + sum(exp(S - m_new))
        d. output = output * exp(m - m_new) + exp(S - m_new) @ V_tile
e. m = m_new
3. output = output / l
    This avoids materializing the attention row in HBM: K and V are each
    streamed from HBM exactly once, and both GEMMs plus the softmax are fused
    into a single kernel.
"""
return {
"name": "FlashAttention-style cached kernel",
"key_benefit": "O(1) shared memory usage regardless of sequence length",
"hbm_traffic_reduction": "Reduces from 4 reads to ~2 reads of K/V cache",
"shared_memory": "Only needs BLOCK x head_dim tiles, not full seq_len",
"complexity": "More complex kernel but 2-4x faster for long sequences",
"implementation_notes": [
"Requires careful numerical stability (online softmax)",
"Two-pass: forward pass accumulates, backward pass needs recompute",
"For generation (single query), simpler than full FlashAttention",
"Can use mma.sync for the tile GEMMs on H100",
],
}
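# NumPy sketch of the online-softmax loop described above, for one (batch, head)
# pair and a single query row. The tile size and function name are illustrative
# assumptions; the real kernel runs this loop per thread block with the tiles
# staged in shared memory.
def flash_cached_attention_reference(q: np.ndarray, k_cache: np.ndarray,
                                     v_cache: np.ndarray,
                                     block: int = 128) -> np.ndarray:
    """Tiled cached attention with online softmax. q: (head_dim,), caches: (seq_len, head_dim)."""
    head_dim = q.shape[-1]
    out = np.zeros(head_dim)
    m = -np.inf   # running max of the attention scores
    l = 0.0       # running softmax denominator
    for start in range(0, k_cache.shape[0], block):
        k_tile = k_cache[start:start + block]
        v_tile = v_cache[start:start + block]
        s = (k_tile @ q) / np.sqrt(head_dim)   # partial scores for this tile
        m_new = max(m, float(s.max()))
        p = np.exp(s - m_new)                  # tile probabilities (unnormalized)
        scale = np.exp(m - m_new)              # rescale previously accumulated state
        out = out * scale + p @ v_tile
        l = l * scale + p.sum()
        m = m_new
    return out / l
# Sanity check: for random q (head_dim,) and K/V (seq_len, head_dim), this matches
# softmax(K @ q / sqrt(head_dim)) @ V computed in one shot.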
# =============================================================================
# MULTI-GPU STRATEGIES
# =============================================================================
def describe_multi_gpu():
"""
Multi-GPU strategies for large models with KV cache.
"""
return {
"tensor_parallelism": {
"description": "Split model weights across GPUs (Megatron-LM style)",
"kv_cache_impact": "Each GPU holds its shard of K/V (split by head_dim)",
"communication": "AllReduce in MLP, all-to-all in attention",
"scaling": "Linear with num GPUs (up to num_heads)",
},
"pipeline_parallelism": {
"description": "Split layers across GPUs",
"kv_cache_impact": "Each GPU holds K/V for its layer shard",
"communication": "Send activations between stages",
"challenge": "Bubble idle time; needs micro-batching",
},
"sequence_parallelism": {
"description": "Split sequence across GPUs (for prefill)",
"kv_cache_impact": "Each GPU holds K/V for its sequence shard",
"communication": "All-to-all for attention across sequence shards",
"best_for": "Very long context prefill",
},
"expert_parallelism": {
"description": "For MoE models (Mixtral, Grok)",
"kv_cache_impact": "KV cache is shared; only MLP experts are sharded",
"communication": "All-to-all for expert routing",
},
}
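# Back-of-envelope sketch of the tensor-parallel KV-cache split described above:
# heads are sharded across GPUs, so each GPU holds num_heads / tp of the cache.
# The default model shape parameters here are illustrative assumptions.
def kv_cache_bytes_per_gpu(num_layers: int = 32, num_heads: int = 32,
                           head_dim: int = 128, seq_len: int = 4096,
                           batch: int = 4, elem_bytes: int = 2,
                           tensor_parallel: int = 1) -> int:
    """Total K+V cache bytes held by one GPU under head-sharded tensor parallelism."""
    heads_per_gpu = num_heads // tensor_parallel
    # Factor 2 accounts for both K and V
    return 2 * num_layers * batch * heads_per_gpu * seq_len * head_dim * elem_bytes
# Example: kv_cache_bytes_per_gpu(tensor_parallel=4) / 1e9 -> ~2.1 GB per GPU.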
# =============================================================================
# PRACTICAL GPU TUNING GUIDE
# =============================================================================
def gpu_tuning_guide():
"""
Practical GPU tuning recommendations for KV-cache inference.
"""
return {
"streaming_KV_cache": {
"problem": "For long sequences, K/V cache reads dominate latency",
"solution": "Use H100's copy engine (async copy) to stream tiles",
"detail": "Overlap K/V loading with Q projection computation",
},
"small_batch_optimization": {
"problem": "Single-token generation has tiny GEMMs (M=1)",
"solution": "Use CUTLASS tiny GEMM kernels or custom kernels",
"detail": "Standard cuBLAS is not optimized for M=1; use flashinfer or turbotransformers",
},
"continuous_batching": {
"problem": "Variable generation lengths waste compute",
"solution": "Run sequences at different stages simultaneously",
"detail": "Some sequences in prefill, others in decode; schedule on GPU",
},
"kv_cache_quantization_on_gpu": {
"problem": "Dequantization adds latency",
"solution": "Use INT8 Tensor Cores (H100 supports INT8 MMA)",
"detail": "Keep K/V in INT8, dequantize during the MMA instruction",
},
"cuda_graphs": {
"problem": "Kernel launch overhead for small operations",
"solution": "Record and replay CUDA graphs",
"detail": "For fixed-shape generation, graphs eliminate launch overhead",
},
}
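# Sketch of the INT8 KV-cache path from the tuning guide: symmetric per-head
# absolute-max scaling, storing INT8 values plus one scale per head. The exact
# scaling granularity (per-head here) is an assumption; per-channel or per-token
# scales are also common in practice.
def quantize_kv_int8(x: np.ndarray) -> tuple:
    """Quantize a (heads, seq_len, head_dim) K or V tensor to INT8 with per-head scales."""
    scale = np.abs(x).max(axis=(1, 2), keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)   # avoid division by zero for all-zero heads
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize_kv_int8(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Recover an approximate float tensor from INT8 values and per-head scales."""
    return q.astype(np.float32) * scale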
# =============================================================================
# PRINT GPU MAPPING REPORT
# =============================================================================
def print_gpu_report():
"""Print comprehensive GPU execution mapping report."""
print("=" * 80)
print("GPU EXECUTION MAPPING FOR KV-CACHE SYSTEM")
print("=" * 80)
# Memory hierarchy
print("\n--- GPU Memory Hierarchy ---\n")
for level, info in GPU_HIERARCHY.items():
print(f" {level:>15}:")
for k, v in info.items():
print(f" {k}: {v}")
# Kernel design
print("\n\n--- Cached Attention Kernel Design ---\n")
kernel = describe_cached_attention_kernel()
print(f" Kernel: {kernel['kernel_name']}")
print(f" Grid: {kernel['grid']}")
print(f" Block: {kernel['block']}")
print("\n Shared Memory Usage:")
for k, v in kernel["shared_memory_usage"].items():
if k != "total_shared_per_block":
print(f" {k}: {v}")
    print(f"  total_shared_per_block: "
          f"{kernel['shared_memory_usage']['total_shared_per_block']}")
print("\n Optimization Strategies:")
for s in kernel["optimization_strategies"]:
print(f" {s}")
# Tensor core analysis
print("\n\n--- Tensor Core Utilization (batch=4, heads=32, seq=4096) ---\n")
tc = tensor_core_analysis(batch=4, heads=32, seq_len=4096)
for k, v in tc.items():
print(f" {k}: {v}")
# Execution pipeline
print("\n\n--- Execution Pipeline ---\n")
pipeline = describe_execution_pipeline()
    for step in pipeline["steps"]:
print(f" {step}")
print(f"\n Bottleneck: {pipeline['bottleneck']}")
print("\n Optimization Opportunities:")
for opt in pipeline["optimization_opportunities"]:
print(f" - {opt}")
# FlashAttention
print("\n\n--- FlashAttention-Style Cached Kernel ---\n")
flash = describe_flash_attention_cached()
for k, v in flash.items():
if isinstance(v, list):
print(f" {k}:")
for item in v:
print(f" - {item}")
else:
print(f" {k}: {v}")
# Multi-GPU
print("\n\n--- Multi-GPU Strategies ---\n")
multi = describe_multi_gpu()
for strategy, info in multi.items():
print(f" {strategy}:")
for k, v in info.items():
print(f" {k}: {v}")
# Tuning guide
print("\n\n--- GPU Tuning Guide ---\n")
tuning = gpu_tuning_guide()
for area, info in tuning.items():
print(f" {area}:")
print(f" Problem: {info['problem']}")
print(f" Solution: {info['solution']}")
print(f" Detail: {info['detail']}")
print()
print("=" * 80)
if __name__ == "__main__":
print_gpu_report()