8e72eef09c
- Rename gamma to glm5 and model to minimax-m2.7 - Add model_comparison/ directory with head-to-head analyses - Sanitize all session.jsonl files: remove absolute paths and usernames - Remove __pycache__ artifacts - Add .gitignore
414 lines
16 KiB
Python
414 lines
16 KiB
Python
"""
|
|
GPU Execution Mapping for KV-Cache Systems
|
|
|
|
Documents how the KV-cache system maps to GPU hardware:
|
|
- Memory hierarchy (registers, shared mem, global mem, HBM)
|
|
- Kernel design for attention with cache
|
|
- CUDA optimization strategies
|
|
- Tensor Core utilization
|
|
"""
|
|
|
|
import numpy as np
|
|
from typing import Dict, List
|
|
|
|
|
|
# =============================================================================
# GPU MEMORY HIERARCHY REFERENCE
# =============================================================================
|
|
|
|
# Static reference table of GPU memory levels, keyed by level name and ordered
# from closest-to-the-core to farthest. Every value is a display string; nothing
# in this table is parsed or used in arithmetic.
# NOTE(review): some figures look doubtful and should be checked against
# NVIDIA's published H100 material before anyone relies on them: the SM register
# file is usually quoted as 64K 32-bit registers (256 KB), H100 shared memory as
# up to 228 KB per SM, and shared-memory / L2 latencies as considerably higher
# than the values below. Left unchanged here because other text in this file
# quotes the same numbers.
GPU_HIERARCHY = {
    "registers": {
        "size_per_sm": "64 KB",
        "latency": "1 cycle",
        "usage": "Thread-local variables, warp-level computation",
    },
    "shared_memory": {
        "size_per_sm": "166 KB (H100)",
        "latency": "1-3 cycles",
        "usage": "Tiling, cooperative loading, softmax intermediate",
    },
    "l2_cache": {
        "size": "50 MB (H100)",
        "latency": "~20 cycles",
        "usage": "Automatic caching of global memory accesses",
    },
    "hbm": {
        "size": "80 GB (H100)",
        "bandwidth": "3.35 TB/s (H100)",
        "latency": "~300-400 cycles",
        "usage": "Model weights, KV cache, activations",
    },
}
|
|
|
|
|
|
# =============================================================================
# KERNEL DESIGN: CACHED ATTENTION
# =============================================================================
|
|
|
|
def describe_cached_attention_kernel() -> Dict:
    """
    Describe the CUDA kernel design for cached attention.

    Kernel: cached_attention<<<grid, block>>>(Q, K_cache, V_cache, Out, ...)

    Thread block organization:
    - Each block handles one (batch, head) pair
    - Threads within a block cooperate on the matmul Q @ K^T

    Memory access pattern:
    1. Load Q tile into shared memory (small: 1 x head_dim)
    2. Stream K_cache tiles from global memory into shared memory
    3. Compute partial dot products in registers
    4. Accumulate scores in shared memory
    5. Softmax in shared memory
    6. Stream V_cache tiles and compute output

    Returns:
        A dict of display strings with the keys ``kernel_name``, ``grid``,
        ``block``, ``shared_memory_usage`` (dict), ``global_memory_accesses``
        (dict) and ``optimization_strategies`` (list). The last entry of
        ``shared_memory_usage`` is ``total_shared_per_block``.
    """
    # Byte counts below assume 2 bytes per fp16 element:
    #   q_tile 0.25 KB + k_tile 8 KB + v_tile 8 KB + score_tile 2 KB ~= 18 KB,
    # which is what the "~16-20 KB" total refers to.
    shared_memory_usage = {
        "q_tile": "1 x head_dim (e.g., 1 x 128 = 128 floats = 256 bytes fp16)",
        "k_tile": "BLOCK_Y x head_dim (e.g., 32 x 128 = 4096 floats = 8 KB fp16)",
        "v_tile": "BLOCK_Y x head_dim (same as K)",
        "score_tile": "BLOCK_X x BLOCK_Y (e.g., 32 x 32 = 1024 floats = 2 KB fp16)",
        "total_shared_per_block": "~16-20 KB (fits in 166 KB SM)",
    }

    # Element counts (not bytes) moved to/from global memory per decode step.
    global_memory_accesses = {
        "read_q": "batch * heads * 1 * head_dim (tiny)",
        "read_k_cache": "batch * heads * seq_len * head_dim (dominant)",
        "read_v_cache": "batch * heads * seq_len * head_dim (dominant)",
        "write_output": "batch * heads * 1 * head_dim (tiny)",
    }

    optimization_strategies = [
        "1. Coalesced global memory access: threads in a warp access consecutive addresses",
        "2. Tiled GEMM: process K/V in tiles that fit in shared memory",
        "3. Persistent kernels: keep blocks alive until all tiles processed",
        "4. Async copy (H100): use cp.async to overlap memory transfer with computation",
        "5. Tensor Cores: use WMMA or mma.sync for the matmul operations",
        "6. Fusion: fuse softmax with attention score computation",
    ]

    return {
        "kernel_name": "cached_attention",
        "grid": "(batch_size, num_heads, 1)",
        "block": "(BLOCK_X, BLOCK_Y) — e.g., (32, 32) for 1024 threads",
        "shared_memory_usage": shared_memory_usage,
        "global_memory_accesses": global_memory_accesses,
        "optimization_strategies": optimization_strategies,
    }
|
|
|
|
|
|
# =============================================================================
# TENSOR CORE UTILIZATION
# =============================================================================
|
|
|
|
def tensor_core_analysis(head_dim: int = 128, seq_len: int = 4096,
                         batch: int = 4, heads: int = 32,
                         peak_tflops: float = 1970,
                         bandwidth: float = 3.35e12,
                         elem_bytes: int = 2) -> Dict:
    """
    Estimate FLOPs, memory traffic and roofline bounds for one cached-attention
    decode step (a single query token attending over ``seq_len`` cached keys).

    Reference Tensor Core figures used as defaults (FP16):
    - MMA shape: M x N x K where M,N,K are multiples of 16
    - Peak throughput: ~1,970 TFLOPS
      NOTE(review): NVIDIA's H100 datasheet quotes roughly this number for FP16
      *with structured sparsity*; the dense figure is about half. Confirm which
      one is intended, or pass ``peak_tflops`` explicitly.
    - Each 16x16x16 MMA: 4096 multiply-accumulates (8192 FLOPs)

    Args:
        head_dim: Size of each attention head.
        seq_len: Number of cached positions attended over.
        batch: Batch size.
        heads: Number of attention heads.
        peak_tflops: Peak compute throughput in TFLOPS used for the compute
            bound. Must be positive.
        bandwidth: Memory bandwidth in bytes/s used for the memory bound.
            Must be positive.
        elem_bytes: Bytes per tensor element (2 for fp16).

    Returns:
        A dict of formatted strings (``flops_qk``, ``flops_av``,
        ``total_flops``, ``memory_traffic_mb``, ``arithmetic_intensity``,
        ``compute_bound_ms``, ``memory_bound_ms``, ``bound``) plus the numeric
        ``h100_peak_tflops`` and ``h100_bandwidth_tbps`` (TB/s) that were used.

    Raises:
        ValueError: If ``peak_tflops`` or ``bandwidth`` is not positive.
    """
    if peak_tflops <= 0:
        raise ValueError("peak_tflops must be positive")
    if bandwidth <= 0:
        raise ValueError("bandwidth must be positive")

    # Q @ K^T: (batch, heads, 1, head_dim) @ (batch, heads, head_dim, seq_len)
    # One multiply + one add per (head_dim, seq_len) pair -> factor of 2.
    flops_qk = 2 * batch * heads * 1 * head_dim * seq_len

    # Attn @ V: (batch, heads, 1, seq_len) @ (batch, heads, seq_len, head_dim)
    flops_av = 2 * batch * heads * 1 * seq_len * head_dim

    total_flops = flops_qk + flops_av

    # Memory traffic: Q and the output are one row per head; K and V are the
    # full cache and dominate.
    mem_q = batch * heads * 1 * head_dim * elem_bytes
    mem_k = batch * heads * seq_len * head_dim * elem_bytes
    mem_v = batch * heads * seq_len * head_dim * elem_bytes
    mem_out = batch * heads * 1 * head_dim * elem_bytes
    total_mem = mem_q + mem_k + mem_v + mem_out

    # Arithmetic intensity (FLOPs per byte). A zero-sized problem (any
    # dimension equal to 0) moves no bytes; report 0 instead of dividing by 0.
    intensity = total_flops / total_mem if total_mem else 0.0

    # Theoretical lower bounds on execution time from each roofline limit.
    compute_bound_s = total_flops / (peak_tflops * 1e12)
    memory_bound_s = total_mem / bandwidth

    return {
        "flops_qk": f"{flops_qk / 1e9:.2f} GFLOPs",
        "flops_av": f"{flops_av / 1e9:.2f} GFLOPs",
        "total_flops": f"{total_flops / 1e9:.2f} GFLOPs",
        "memory_traffic_mb": f"{total_mem / 1e6:.2f} MB",
        "arithmetic_intensity": f"{intensity:.2f} FLOPs/byte",
        "compute_bound_ms": f"{compute_bound_s * 1000:.4f} ms",
        "memory_bound_ms": f"{memory_bound_s * 1000:.4f} ms",
        "bound": "compute-bound" if compute_bound_s > memory_bound_s else "memory-bound",
        "h100_peak_tflops": peak_tflops,
        "h100_bandwidth_tbps": bandwidth / 1e12,
    }
|
|
|
|
|
|
# =============================================================================
# GPU EXECUTION PIPELINE
# =============================================================================
|
|
|
|
def describe_execution_pipeline() -> Dict:
    """
    Describe the full GPU execution pipeline for one generation step.

    Step 1: Embedding lookup
    - Input: token_id (batch, 1)
    - Operation: embedding[token_id] -> (batch, 1, dim)
    - GPU: Gathers from embedding table (random access, use shared mem tiling)

    Step 2: Positional encoding
    - Operation: x += pos_encoding[current_pos]
    - GPU: Simple element-wise add (fully parallel)

    Step 3: Per-layer forward pass (repeated L times)
    3a. LayerNorm
        - GPU: Parallel reduction for mean/var, then element-wise

    3b. QKV projection
        - GPU: 3 parallel GEMMs: x @ Wq, x @ Wk, x @ Wv
        - cuBLAS/CUTLASS GEMMs with small M (M=1)
          NOTE(review): gpu_tuning_guide() in this file says stock cuBLAS is
          not optimized for M=1; reconcile the two statements.

    3c. KV cache update
        - GPU: Simple copy to global memory (coalesced write)
        - cache_k[:, :, write_pos, :] = k[:, :, 0, :]

    3d. Cached attention
        - GPU: Custom kernel (see describe_cached_attention_kernel)
        - Two GEMMs + softmax, tiled for shared memory

    3e. Output projection
        - GPU: GEMM: attn_out @ Wo

    3f. MLP
        - GPU: Two GEMMs with activation fusion

    3g. Residual add + LayerNorm
        - GPU: Element-wise operations

    Step 4: LM head
    - GPU: GEMM: x @ W_lm -> logits (batch, vocab_size)

    Step 5: Sampling
    - GPU: Argmax or top-k sampling kernel
    - Can be done on CPU for small batch sizes

    Returns:
        A dict of display strings with the keys ``steps`` (list, one summary
        line per step above), ``bottleneck`` (str) and
        ``optimization_opportunities`` (list).
    """
    steps = [
        "1. Embedding lookup (gather)",
        "2. Positional encoding (element-wise add)",
        "3. Per-layer: LayerNorm -> QKV proj -> cache update -> attention -> MLP",
        "4. LM head (GEMM)",
        "5. Sampling (argmax/top-k)",
    ]

    optimization_opportunities = [
        "Operator fusion: merge LayerNorm + GEMM bias + activation",
        "Batched GEMM: process all layers' small GEMMs together",
        "Pipeline parallelism: overlap layers' computation",
        "FlashAttention-style tiling for the cached attention kernel",
        "Warp-specialized design: some warps load, some compute",
    ]

    return {
        "steps": steps,
        "bottleneck": "Cached attention (memory-bound for long sequences)",
        "optimization_opportunities": optimization_opportunities,
    }
|
|
|
|
|
|
# =============================================================================
# FLASH-ATTENTION-STYLE CACHED KERNEL
# =============================================================================
|
|
|
|
def describe_flash_attention_cached() -> Dict:
    """
    FlashAttention-style kernel adapted for cached attention.

    Key insight: instead of materializing the full (1 x seq_len) attention
    matrix, process K/V in tiles and accumulate softmax online.

    Algorithm (for one batch/head):
    1. Initialize: output = 0, m = -inf, l = 0 (online softmax state)
    2. For each K/V tile (size BLOCK):
       a. Compute S = Q @ K_tile^T (in shared memory)
       b. m_new = max(m, max(S))
       c. P = exp(S - m_new)
       d. scale = exp(m - m_new)   (rescales state kept under the old max)
       e. l = l * scale + sum(P)
       f. output = output * scale + P @ V_tile
       g. m = m_new
    3. output = output / l

    The running output stays unnormalized inside the loop and is divided by
    l exactly once at the end, so the per-tile rescale is exp(m - m_new) for
    both l and output.

    This avoids materializing the full attention matrix: the scores never
    have to be written to HBM and read back. K and V themselves are still
    read from HBM once each, i.e. O(seq_len * head_dim) traffic, and
    shared-memory use is bounded by the tile size rather than by seq_len.

    Returns:
        A dict of display strings with the keys ``name``, ``key_benefit``,
        ``hbm_traffic_reduction``, ``shared_memory``, ``complexity`` and
        ``implementation_notes`` (list).
    """
    implementation_notes = [
        "Requires careful numerical stability (online softmax)",
        # NOTE(review): a backward pass only exists in training; this note may
        # not apply to inference-time cached attention. Confirm.
        "Two-pass: forward pass accumulates, backward pass needs recompute",
        "For generation (single query), simpler than full FlashAttention",
        "Can use mma.sync for the tile GEMMs on H100",
    ]

    return {
        "name": "FlashAttention-style cached kernel",
        "key_benefit": "O(1) shared memory usage regardless of sequence length",
        # NOTE(review): the "4 reads -> ~2 reads" figure is not derived
        # anywhere in this file; verify before quoting it.
        "hbm_traffic_reduction": "Reduces from 4 reads to ~2 reads of K/V cache",
        "shared_memory": "Only needs BLOCK x head_dim tiles, not full seq_len",
        "complexity": "More complex kernel but 2-4x faster for long sequences",
        "implementation_notes": implementation_notes,
    }
|
|
|
|
|
|
# =============================================================================
# MULTI-GPU STRATEGIES
# =============================================================================
|
|
|
|
def describe_multi_gpu() -> Dict:
    """
    Multi-GPU strategies for large models with KV cache.

    Returns:
        A dict keyed by strategy name (``tensor_parallelism``,
        ``pipeline_parallelism``, ``sequence_parallelism``,
        ``expert_parallelism``). Each value is a dict of display strings that
        always has ``description``, ``kv_cache_impact`` and ``communication``;
        the remaining key differs per strategy (``scaling``, ``challenge``,
        ``best_for``, or none).
    """
    tensor_parallelism = {
        "description": "Split model weights across GPUs (Megatron-LM style)",
        # NOTE(review): the "scaling" entry below says "up to num_heads",
        # which suggests a split across heads rather than within head_dim;
        # confirm which is meant.
        "kv_cache_impact": "Each GPU holds its shard of K/V (split by head_dim)",
        "communication": "AllReduce in MLP, all-to-all in attention",
        "scaling": "Linear with num GPUs (up to num_heads)",
    }

    pipeline_parallelism = {
        "description": "Split layers across GPUs",
        "kv_cache_impact": "Each GPU holds K/V for its layer shard",
        "communication": "Send activations between stages",
        "challenge": "Bubble idle time; needs micro-batching",
    }

    sequence_parallelism = {
        "description": "Split sequence across GPUs (for prefill)",
        "kv_cache_impact": "Each GPU holds K/V for its sequence shard",
        "communication": "All-to-all for attention across sequence shards",
        "best_for": "Very long context prefill",
    }

    expert_parallelism = {
        "description": "For MoE models (Mixtral, Grok)",
        "kv_cache_impact": "KV cache is shared; only MLP experts are sharded",
        "communication": "All-to-all for expert routing",
    }

    return {
        "tensor_parallelism": tensor_parallelism,
        "pipeline_parallelism": pipeline_parallelism,
        "sequence_parallelism": sequence_parallelism,
        "expert_parallelism": expert_parallelism,
    }
|
|
|
|
|
|
# =============================================================================
# PRACTICAL GPU TUNING GUIDE
# =============================================================================
|
|
|
|
def gpu_tuning_guide() -> Dict:
    """
    Practical GPU tuning recommendations for KV-cache inference.

    Returns:
        A dict keyed by tuning area. Every value is a dict with exactly the
        keys ``problem``, ``solution`` and ``detail`` (all display strings).
    """
    return {
        "streaming_KV_cache": {
            "problem": "For long sequences, K/V cache reads dominate latency",
            "solution": "Use H100's copy engine (async copy) to stream tiles",
            "detail": "Overlap K/V loading with Q projection computation",
        },
        "small_batch_optimization": {
            "problem": "Single-token generation has tiny GEMMs (M=1)",
            "solution": "Use CUTLASS tiny GEMM kernels or custom kernels",
            "detail": "Standard cuBLAS is not optimized for M=1; use flashinfer or turbotransformers",
        },
        "continuous_batching": {
            "problem": "Variable generation lengths waste compute",
            "solution": "Run sequences at different stages simultaneously",
            "detail": "Some sequences in prefill, others in decode; schedule on GPU",
        },
        "kv_cache_quantization_on_gpu": {
            "problem": "Dequantization adds latency",
            "solution": "Use INT8 Tensor Cores (H100 supports INT8 MMA)",
            "detail": "Keep K/V in INT8, dequantize during the MMA instruction",
        },
        "cuda_graphs": {
            "problem": "Kernel launch overhead for small operations",
            "solution": "Record and replay CUDA graphs",
            "detail": "For fixed-shape generation, graphs eliminate launch overhead",
        },
    }
|
|
|
|
|
|
# =============================================================================
# PRINT GPU MAPPING REPORT
# =============================================================================
|
|
|
|
def print_gpu_report():
    """
    Print comprehensive GPU execution mapping report.

    Writes, in order, the memory hierarchy table, the cached-attention kernel
    design, a Tensor Core roofline estimate, the execution pipeline, the
    FlashAttention-style kernel notes, the multi-GPU strategies and the tuning
    guide to standard output. Returns nothing.
    """
    rule = "=" * 80
    print(rule)
    print("GPU EXECUTION MAPPING FOR KV-CACHE SYSTEM")
    print(rule)

    # Memory hierarchy
    print("\n--- GPU Memory Hierarchy ---\n")
    for level, info in GPU_HIERARCHY.items():
        print(f" {level:>15}:")
        for k, v in info.items():
            print(f" {k}: {v}")

    # Kernel design
    print("\n\n--- Cached Attention Kernel Design ---\n")
    kernel = describe_cached_attention_kernel()
    print(f" Kernel: {kernel['kernel_name']}")
    print(f" Grid: {kernel['grid']}")
    print(f" Block: {kernel['block']}")
    print("\n Shared Memory Usage:")
    shared = kernel["shared_memory_usage"]
    for k, v in shared.items():
        if k != "total_shared_per_block":
            print(f" {k}: {v}")
    # The total is printed last and looked up by name, so the output does not
    # depend on where the key sits in the dict.
    print(f" total_shared_per_block: {shared['total_shared_per_block']}")

    print("\n Optimization Strategies:")
    for s in kernel["optimization_strategies"]:
        print(f" {s}")

    # Tensor core analysis; the header is built from the same values that are
    # passed to the analysis so the two cannot drift apart.
    batch, heads, seq_len = 4, 32, 4096
    print(f"\n\n--- Tensor Core Utilization (batch={batch}, heads={heads}, seq={seq_len}) ---\n")
    tc = tensor_core_analysis(batch=batch, heads=heads, seq_len=seq_len)
    for k, v in tc.items():
        print(f" {k}: {v}")

    # Execution pipeline (step strings already carry their own numbering)
    print("\n\n--- Execution Pipeline ---\n")
    pipeline = describe_execution_pipeline()
    for step in pipeline["steps"]:
        print(f" {step}")
    print(f"\n Bottleneck: {pipeline['bottleneck']}")
    print("\n Optimization Opportunities:")
    for opt in pipeline["optimization_opportunities"]:
        print(f" - {opt}")

    # FlashAttention
    print("\n\n--- FlashAttention-Style Cached Kernel ---\n")
    flash = describe_flash_attention_cached()
    for k, v in flash.items():
        if isinstance(v, list):
            print(f" {k}:")
            for item in v:
                print(f" - {item}")
        else:
            print(f" {k}: {v}")

    # Multi-GPU
    print("\n\n--- Multi-GPU Strategies ---\n")
    multi = describe_multi_gpu()
    for strategy, info in multi.items():
        print(f" {strategy}:")
        for k, v in info.items():
            print(f" {k}: {v}")

    # Tuning guide
    print("\n\n--- GPU Tuning Guide ---\n")
    tuning = gpu_tuning_guide()
    for area, info in tuning.items():
        print(f" {area}:")
        print(f" Problem: {info['problem']}")
        print(f" Solution: {info['solution']}")
        print(f" Detail: {info['detail']}")
        print()

    print(rule)
|
|
|
|
|
|
# Script entry point: print the full report when run directly, not on import.
if __name__ == "__main__":
    print_gpu_report()
|