"""
Efficient KV-Cache System for Autoregressive Transformer Inference
This module implements a complete KV-cache system from scratch, including:
- Data structures for cached keys and values
- Multi-head attention with cached KV pairs
- Batching support with variable sequence lengths
- Memory analysis and optimization strategies
"""
from __future__ import annotations
import numpy as np
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Dict
from enum import Enum
import math
# =============================================================================
# Part 1: Core Data Structures and Memory Layout
# =============================================================================
class MemoryFormat(Enum):
"""
Memory layout options for KV-cache storage.
Different layouts have different trade-offs for memory contiguity,
access patterns, and cache efficiency.
"""
# [batch, heads, seq_len, dim] - Standard format
# Good for prefill, less efficient for decode attention
BHSD = "batch_heads_seq_dim"
    # [batch, seq_len, heads, dim] - Natural output layout of the QKV projections,
    # before heads are transposed to the front (used by e.g. FlashAttention kernels)
    BSHD = "batch_seq_heads_dim"
# [batch, seq_len, heads, dim] with physical memory pages
# Optimized for paged attention
PAGED = "paged"
# [heads, batch, seq_len, dim] - Good for sequence parallelism
HBSD = "heads_batch_seq_dim"
@dataclass
class CacheConfig:
"""Configuration for KV-cache behavior."""
max_batch_size: int = 32
max_seq_len: int = 4096
num_heads: int = 32
head_dim: int = 128
num_layers: int = 32
dtype_bytes: int = 2 # FP16
# Memory layout
memory_format: MemoryFormat = MemoryFormat.BSHD
# Block size for paged attention (must be power of 2)
block_size: int = 16
# Pre-allocation strategy
pre_allocate: bool = True
def memory_per_layer(self) -> int:
"""Bytes needed per layer for full sequence."""
return 2 * self.num_heads * self.max_seq_len * self.head_dim * self.dtype_bytes
def total_memory(self) -> int:
"""Total bytes for all layers."""
        return self.num_layers * self.memory_per_layer()
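# Quick sanity check (illustrative helper, not part of the cache API): with the
# defaults above (32 layers, 32 heads, head_dim 128, 4096 tokens, FP16) each layer
# holds 64 MiB of K/V, about 2 GiB across all layers for a single sequence.
def _example_cache_config_memory() -> None:
    config = CacheConfig()
    print(f"per layer: {config.memory_per_layer() / 1024**2:.0f} MiB, "
          f"all layers: {config.total_memory() / 1024**3:.2f} GiB")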
class KVCacheBlock:
"""
A single block in the KV-cache.
Memory Layout (BSHD format):
┌─────────────────────────────────────────────────┐
│ k_cache: [max_seq_len, num_heads, head_dim] │
│ v_cache: [max_seq_len, num_heads, head_dim] │
└─────────────────────────────────────────────────┘
Each block stores keys and values for a contiguous sequence segment.
"""
def __init__(self, config: CacheConfig, layer_idx: int):
self.config = config
self.layer_idx = layer_idx
self.block_size = config.block_size
# Physical memory allocation
self.k_data = np.zeros(
(self.block_size, config.num_heads, config.head_dim),
dtype=np.float16
)
self.v_data = np.zeros(
(self.block_size, config.num_heads, config.head_dim),
dtype=np.float16
)
# Logical state
self.seq_offset: int = 0 # Starting position in full sequence
self.seq_length: int = 0 # Current length (0 to block_size)
self.is_full: bool = False
# Access tracking for LRU/replacement
self.last_access_step: int = 0
self.access_count: int = 0
@property
def physical_size(self) -> Tuple[int, int, int]:
"""Return the shape of stored tensors."""
return self.k_data.shape
def write(self, k_chunk: np.ndarray, v_chunk: np.ndarray, position: int) -> None:
"""
Write a chunk of keys and values to this block.
Args:
k_chunk: [chunk_len, num_heads, head_dim]
v_chunk: [chunk_len, num_heads, head_dim]
position: Position in block to write to (0-indexed)
"""
chunk_len = k_chunk.shape[0]
assert chunk_len <= self.block_size - position, "Chunk exceeds block capacity"
self.k_data[position:position + chunk_len] = k_chunk
self.v_data[position:position + chunk_len] = v_chunk
        self.seq_length = max(self.seq_length, position + chunk_len)
        self.is_full = self.seq_length >= self.block_size
def read(self, start: int, length: int) -> Tuple[np.ndarray, np.ndarray]:
"""Read a chunk from this block."""
return (
self.k_data[start:start + length].copy(),
self.v_data[start:start + length].copy()
)
def clear(self) -> None:
"""Reset the block."""
self.k_data.fill(0)
self.v_data.fill(0)
self.seq_length = 0
self.is_full = False
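# Minimal usage sketch for KVCacheBlock (illustrative helper): write one 4-token
# chunk into a fresh block and read it back; shapes follow the per-block
# [seq, heads, dim] layout documented above.
def _example_kv_block_roundtrip() -> None:
    config = CacheConfig(num_heads=2, head_dim=8, block_size=16)
    block = KVCacheBlock(config, layer_idx=0)
    k = np.random.randn(4, config.num_heads, config.head_dim).astype(np.float16)
    v = np.random.randn(4, config.num_heads, config.head_dim).astype(np.float16)
    block.write(k, v, position=0)
    k_read, v_read = block.read(start=0, length=4)
    assert np.allclose(k_read, k) and np.allclose(v_read, v)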
class PagedKVCache:
"""
Paged KV-Cache with block-wise management.
Memory Layout:
┌────────────────────────────────────────────────────────────┐
│ Virtual Sequence │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │ Block 0 │ │ Block 1 │ │ Block 2 │ │ Block 3 │ ... │
│ │ (16 tok) │ │ (16 tok) │ │ (16 tok) │ │ (16 tok) │ │
│ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │
└────────────────────────────────────────────────────────────┘
│ Physical Memory │
│ ┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐ │
│ │ GPU │ │ GPU │ │ GPU │ │ CPU │ ... │
│ │ Block │ │ Block │ │ Block │ │ Block │ │
│ └────────┘ └────────┘ └────────┘ └────────┘ │
└────────────────────────────────────────────────────────────┘
Benefits:
- No wasted memory from max_seq_len pre-allocation
- Enables KV-cache offloading to CPU/disk
- Supports heterogeneous sequence lengths in batch
"""
def __init__(self, config: CacheConfig):
self.config = config
self.blocks: List[KVCacheBlock] = []
self.block_allocator: Dict[int, List[int]] = {} # batch_idx -> list of block indices
# Pre-allocate some blocks
if config.pre_allocate:
self._preallocate_blocks(num_blocks=config.max_batch_size * 2)
def _preallocate_blocks(self, num_blocks: int) -> None:
"""Pre-allocate blocks for efficiency."""
for _ in range(num_blocks):
# We don't know layer_idx yet; use -1 as placeholder
block = KVCacheBlock(self.config, layer_idx=-1)
self.blocks.append(block)
def allocate_block(self, layer_idx: int) -> int:
"""Allocate a new block, return its index."""
# Find an unused block or create new one
for idx, block in enumerate(self.blocks):
if block.seq_length == 0 and block.layer_idx == -1:
block.layer_idx = layer_idx
return idx
# No free block, create new one
block = KVCacheBlock(self.config, layer_idx=layer_idx)
self.blocks.append(block)
return len(self.blocks) - 1
def append(self, layer_idx: int, batch_idx: int,
k_new: np.ndarray, v_new: np.ndarray) -> None:
"""
Append new key/value to the cache for a specific batch element.
Args:
layer_idx: Which layer this cache belongs to
batch_idx: Which element in the batch
k_new: [1, num_heads, head_dim] - new key for single token
v_new: [1, num_heads, head_dim] - new value for single token
"""
if batch_idx not in self.block_allocator:
self.block_allocator[batch_idx] = []
block_indices = self.block_allocator[batch_idx]
# Find the last block or allocate new one
if block_indices:
last_block_idx = block_indices[-1]
last_block = self.blocks[last_block_idx]
if not last_block.is_full:
# Write to existing block
pos = last_block.seq_length
last_block.write(k_new, v_new, pos)
return
# Need new block
new_block_idx = self.allocate_block(layer_idx)
self.blocks[new_block_idx].write(k_new, v_new, position=0)
block_indices.append(new_block_idx)
def get_all_cached(self, batch_idx: int) -> Tuple[np.ndarray, np.ndarray]:
"""
Retrieve all cached keys and values for a batch element.
Returns:
k_all: [seq_len, num_heads, head_dim]
v_all: [seq_len, num_heads, head_dim]
"""
        if batch_idx not in self.block_allocator:
            return (np.zeros((0, self.config.num_heads, self.config.head_dim), dtype=np.float16),
                    np.zeros((0, self.config.num_heads, self.config.head_dim), dtype=np.float16))
# Calculate total sequence length
total_len = 0
for block_idx in self.block_allocator[batch_idx]:
total_len += self.blocks[block_idx].seq_length
# Gather all blocks
k_concat = []
v_concat = []
for block_idx in self.block_allocator[batch_idx]:
block = self.blocks[block_idx]
k_concat.append(block.k_data[:block.seq_length])
v_concat.append(block.v_data[:block.seq_length])
if k_concat:
return np.concatenate(k_concat, axis=0), np.concatenate(v_concat, axis=0)
        return (np.zeros((0, self.config.num_heads, self.config.head_dim), dtype=np.float16),
                np.zeros((0, self.config.num_heads, self.config.head_dim), dtype=np.float16))
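# Usage sketch for PagedKVCache (illustrative helper): append 20 single-token
# K/V pairs for one sequence and read the logical sequence back; with
# block_size=16 the tokens span two physical blocks.
def _example_paged_cache_append() -> None:
    config = CacheConfig(num_heads=2, head_dim=8, block_size=16, pre_allocate=False)
    cache = PagedKVCache(config)
    for _ in range(20):
        k = np.random.randn(1, config.num_heads, config.head_dim).astype(np.float16)
        v = np.random.randn(1, config.num_heads, config.head_dim).astype(np.float16)
        cache.append(layer_idx=0, batch_idx=0, k_new=k, v_new=v)
    k_all, v_all = cache.get_all_cached(batch_idx=0)
    assert k_all.shape == (20, config.num_heads, config.head_dim)
    assert len(cache.block_allocator[0]) == 2  # 16 + 4 tokens -> two blocks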
class FlatKVCache:
"""
Flat pre-allocated KV-cache for maximum performance.
Memory Layout:
┌──────────────────────────────────────────────────────────────────┐
│ Shape: [num_layers, max_batch, max_seq_len, 2, num_heads, head_dim] │
│ │
│ ┌─ Layer 0 ─────────────────────────────────────────────────┐ │
│ │ ┌─ Batch 0 ──────────────────────────────────────────┐ │ │
│ │ │ K: [max_seq_len, num_heads, head_dim] │ │ │
│ │ │ V: [max_seq_len, num_heads, head_dim] │ │ │
│ │ └──────────────────────────────────────────────────────┘ │ │
│ │ ┌─ Batch 1 ──────────────────────────────────────────┐ │ │
│ │ │ ... │ │ │
│ │ └──────────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ ┌─ Layer 1 ─────────────────────────────────────────────────┐ │
│ │ ... │ │
│ └─────────────────────────────────────────────────────────────┘ │
└──────────────────────────────────────────────────────────────────┘
Use this when:
- Sequence lengths are known or bounded
- Memory is sufficient
- Maximum throughput is required
"""
def __init__(self, config: CacheConfig):
self.config = config
# Flat allocation: [layers, batch, seq, 2, heads, dim]
shape = (
config.num_layers,
config.max_batch_size,
config.max_seq_len,
2, # K and V
config.num_heads,
config.head_dim
)
self.kv_data = np.zeros(shape, dtype=np.float16)
self.seq_lengths = np.zeros((config.num_layers, config.max_batch_size), dtype=np.int32)
def update(self, layer_idx: int, batch_idx: int,
seq_pos: int, k_new: np.ndarray, v_new: np.ndarray) -> None:
"""
Update cache with new token at specific position.
Args:
layer_idx: Which layer
batch_idx: Which batch element
seq_pos: Position in sequence (0-indexed)
k_new: [num_heads, head_dim]
v_new: [num_heads, head_dim]
"""
self.kv_data[layer_idx, batch_idx, seq_pos, 0] = k_new
self.kv_data[layer_idx, batch_idx, seq_pos, 1] = v_new
self.seq_lengths[layer_idx, batch_idx] = seq_pos + 1
def get_kv_slice(self, layer_idx: int, batch_idx: int,
start: int, end: int) -> Tuple[np.ndarray, np.ndarray]:
"""Get cached KV for a position range."""
return (
self.kv_data[layer_idx, batch_idx, start:end, 0].copy(),
self.kv_data[layer_idx, batch_idx, start:end, 1].copy()
)
def get_full_kv(self, layer_idx: int, batch_idx: int) -> Tuple[np.ndarray, np.ndarray]:
"""Get all cached KV for a batch element."""
seq_len = self.seq_lengths[layer_idx, batch_idx]
return (
self.kv_data[layer_idx, batch_idx, :seq_len, 0].copy(),
self.kv_data[layer_idx, batch_idx, :seq_len, 1].copy()
)
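# Usage sketch for FlatKVCache (illustrative helper): decode-style updates, one
# token per step, followed by a read of the full cached K/V for that slot.
def _example_flat_cache_decode() -> None:
    config = CacheConfig(max_batch_size=2, max_seq_len=8, num_heads=2,
                         head_dim=4, num_layers=1)
    cache = FlatKVCache(config)
    for pos in range(5):
        k = np.random.randn(config.num_heads, config.head_dim).astype(np.float16)
        v = np.random.randn(config.num_heads, config.head_dim).astype(np.float16)
        cache.update(layer_idx=0, batch_idx=0, seq_pos=pos, k_new=k, v_new=v)
    k_all, v_all = cache.get_full_kv(layer_idx=0, batch_idx=0)
    assert k_all.shape == (5, config.num_heads, config.head_dim)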
# =============================================================================
# Part 2: Multi-Head Attention with KV-Cache
# =============================================================================
class AttentionPattern(Enum):
"""Attention computation patterns."""
FULL = "full" # All tokens attend to all (prefill)
CAUSAL = "causal" # Causal masking (autoregressive)
SLIDING_WINDOW = "window" # Local window attention
class MultiHeadAttention:
"""
Multi-Head Attention with integrated KV-cache support.
Architecture:
┌──────────────────────────────────────────────────────────────────────┐
│ Multi-Head Attention │
│ │
│ Input: [batch, seq_len, d_model] │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Q = x @ W_q K = x @ W_k V = x @ W_v (Linear proj) │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌────────────┐ ┌────────────┐ ┌────────────┐ │
│ │ Q_heads │ │ K + CacheK │ │ V + CacheV │ │
│ │ [B,H,L,D] │ │ [B,H,L',D] │ │ [B,H,L',D] │ │
│ └────────────┘ └────────────┘ └────────────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Attention Score │ │
│ │ S = Q @ K^T / sqrt(d_k) + Mask │ │
│ │ │ │
│ │ [B, H, L, L'] where L' includes cached tokens │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Softmax │ │
│ │ A = softmax(S, dim=-1) │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Output │ │
│ │ O = A @ V │ │
│ │ [B, H, L, D] │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ out = O @ W_o (Output projection) │ │
│ └─────────────────────────────────────────────────────────────────┘ │
└──────────────────────────────────────────────────────────────────────┘
"""
def __init__(
self,
d_model: int,
num_heads: int,
head_dim: int,
dropout: float = 0.0,
bias: bool = True,
attention_pattern: AttentionPattern = AttentionPattern.CAUSAL
):
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = head_dim
self.scale = 1.0 / math.sqrt(head_dim)
self.pattern = attention_pattern
# Initialize weights (simplified - normally from trained model)
self.W_q = np.random.randn(d_model, num_heads * head_dim).astype(np.float32) * 0.02
self.W_k = np.random.randn(d_model, num_heads * head_dim).astype(np.float32) * 0.02
self.W_v = np.random.randn(d_model, num_heads * head_dim).astype(np.float32) * 0.02
self.W_o = np.random.randn(num_heads * head_dim, d_model).astype(np.float32) * 0.02
# Bias
if bias:
self.b_q = np.zeros(num_heads * head_dim, dtype=np.float32)
self.b_k = np.zeros(num_heads * head_dim, dtype=np.float32)
self.b_v = np.zeros(num_heads * head_dim, dtype=np.float32)
self.b_o = np.zeros(d_model, dtype=np.float32)
else:
self.b_q = self.b_k = self.b_v = self.b_o = None
def _project(self, x: np.ndarray, W: np.ndarray, b: Optional[np.ndarray]) -> np.ndarray:
"""Project input to Q/K/V space."""
# x: [batch, seq, d_model]
result = np.matmul(x, W)
if b is not None:
result += b
# Reshape to [batch, seq, heads, dim]
batch, seq, _ = result.shape
result = result.reshape(batch, seq, self.num_heads, self.head_dim)
# Transpose to [batch, heads, seq, dim] for efficient attention
return result.transpose(0, 2, 1, 3)
    def _create_causal_mask(self, seq_len: int, cache_len: int) -> np.ndarray:
        """Create an additive causal mask: 0 where attention is allowed, -1e9 where it is not."""
        total_len = seq_len + cache_len
        # Query i sits at absolute position cache_len + i and may attend to keys 0..cache_len + i
        # [1, 1, seq, total_len] - broadcastable over batch and heads
        mask = np.triu(np.full((seq_len, total_len), -1e9, dtype=np.float32), k=cache_len + 1)
        return mask[np.newaxis, np.newaxis]
def _apply_attention(
self,
Q: np.ndarray,
K: np.ndarray,
V: np.ndarray,
mask: Optional[np.ndarray] = None
) -> np.ndarray:
"""
Core attention computation: O = softmax(QK^T / sqrt(d_k))V
Args:
Q: [batch, heads, q_seq, dim]
K: [batch, heads, k_seq, dim]
V: [batch, heads, k_seq, dim]
mask: Optional attention mask
Returns:
output: [batch, heads, q_seq, dim]
"""
# Compute attention scores
# Q @ K^T: [batch, heads, q_seq, k_seq]
scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) * self.scale
# Apply mask
if mask is not None:
scores = scores + mask
# Softmax
# Subtract max for numerical stability
scores_max = scores.max(axis=-1, keepdims=True)
scores = np.exp(scores - scores_max)
scores_sum = scores.sum(axis=-1, keepdims=True)
attn_weights = scores / (scores_sum + 1e-8)
# Apply to values
output = np.matmul(attn_weights, V)
return output
def forward(
self,
x: np.ndarray,
kv_cache: Optional[Dict[int, Tuple[np.ndarray, np.ndarray]]] = None,
layer_idx: int = 0,
use_cache: bool = True
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, np.ndarray]]]:
"""
Forward pass with optional KV-cache.
Args:
x: [batch, seq_len, d_model] - input tensor
kv_cache: Dict mapping layer_idx -> (k_cache, v_cache) from previous layers
layer_idx: Current layer index
use_cache: Whether to read from/write to cache
Returns:
output: [batch, seq_len, d_model]
            kv_output: (k_full, v_full) including cached history, for caching
"""
batch, q_seq, _ = x.shape
# Project to Q, K, V
Q = self._project(x, self.W_q, self.b_q) # [B, H, Q_seq, D]
K = self._project(x, self.W_k, self.b_k) # [B, H, Q_seq, D]
V = self._project(x, self.W_v, self.b_v) # [B, H, Q_seq, D]
# Combine with cached K, V if available
if use_cache and kv_cache and layer_idx in kv_cache:
k_cache, v_cache = kv_cache[layer_idx]
# k_cache: [B, H, cache_seq, D]
K_full = np.concatenate([k_cache, K], axis=2)
V_full = np.concatenate([v_cache, V], axis=2)
else:
K_full = K
V_full = V
# Create mask
cache_len = 0
if use_cache and kv_cache and layer_idx in kv_cache:
cache_len = kv_cache[layer_idx][0].shape[2]
mask = None
if self.pattern == AttentionPattern.CAUSAL:
mask = self._create_causal_mask(q_seq, cache_len)
# Broadcast mask to batch and heads
mask = np.broadcast_to(mask, (batch, self.num_heads, q_seq, q_seq + cache_len))
# Compute attention
attn_output = self._apply_attention(Q, K_full, V_full, mask)
# Reshape and project output
# [B, H, Q_seq, D] -> [B, Q_seq, H, D] -> [B, Q_seq, H*D]
attn_output = attn_output.transpose(0, 2, 1, 3).reshape(batch, q_seq, -1)
output = np.matmul(attn_output, self.W_o)
if self.b_o is not None:
output += self.b_o
        # Return the full (cached + new) K, V so the caller's cache accumulates history
        kv_output = (K_full, V_full) if use_cache else None
return output, kv_output
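# Consistency sketch (illustrative helper): single-token decode against a KV
# cache should reproduce the last-position output of a full uncached forward
# pass over the whole sequence, since the cached K/V values are identical.
def _example_cached_decode_matches_full() -> None:
    np.random.seed(0)
    d_model, num_heads, head_dim = 32, 2, 16
    attn = MultiHeadAttention(d_model, num_heads, head_dim)
    prompt = np.random.randn(1, 5, d_model).astype(np.float32)
    new_tok = np.random.randn(1, 1, d_model).astype(np.float32)
    _, kv = attn.forward(prompt, kv_cache=None, layer_idx=0, use_cache=True)   # prefill
    out_cached, _ = attn.forward(new_tok, kv_cache={0: kv}, layer_idx=0, use_cache=True)
    full = np.concatenate([prompt, new_tok], axis=1)                           # reference
    out_full, _ = attn.forward(full, kv_cache=None, layer_idx=0, use_cache=False)
    assert np.allclose(out_cached[:, 0], out_full[:, -1], atol=1e-5)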
# =============================================================================
# Part 3: Batching with Variable Sequence Lengths
# =============================================================================
@dataclass
class BatchElement:
"""
Represents a single element in a batch.
For autoregressive decoding, each element has:
- Its own position in the sequence
- Its own KV-cache
- Potentially different sequence history
"""
batch_idx: int
prompt_len: int # Length of initial prompt (prefill)
current_len: int # Current sequence length
generated_tokens: List[int] = field(default_factory=list)
logits: Optional[np.ndarray] = None
finished: bool = False
    kv_cache_refs: Dict[int, List[int]] = field(default_factory=dict)  # layer -> block indices
class BatchedInferenceEngine:
"""
Manages batched inference with variable-length sequences.
Key challenges:
1. Different sequences have different lengths at each step
2. We only compute attention for valid (non-padding) positions
3. KV-caches must be indexed per-sequence
4. Finished sequences should stop consuming compute
Optimization strategies:
- Padding to longest sequence in batch
- Packing multiple sequences into one "pseudo-sequence"
- Dynamic batch scheduling based on length
"""
def __init__(self, config: CacheConfig, model: 'TransformerBlockStack'):
self.config = config
self.model = model
self.batch: List[BatchElement] = []
self.step: int = 0
# Initialize KV-caches
if config.pre_allocate:
self.kv_cache = FlatKVCache(config)
else:
self.kv_cache = PagedKVCache(config)
def add_to_batch(self, batch_idx: int, prompt_len: int) -> BatchElement:
"""Add an element to the batch."""
elem = BatchElement(
batch_idx=batch_idx,
prompt_len=prompt_len,
current_len=prompt_len
)
self.batch.append(elem)
return elem
def remove_from_batch(self, batch_idx: int) -> None:
"""Remove finished element from batch."""
self.batch = [e for e in self.batch if e.batch_idx != batch_idx]
def get_active_batch_size(self) -> int:
"""Get number of active (non-finished) elements."""
return sum(1 for e in self.batch if not e.finished)
def get_max_seq_len_in_batch(self) -> int:
"""Get maximum sequence length among active elements."""
        return max((e.current_len for e in self.batch if not e.finished), default=0)
def _create_packed_input(
self,
input_tokens: List[np.ndarray]
) -> Tuple[np.ndarray, np.ndarray]:
"""
Pack multiple sequences into one for efficient computation.
Uses position encoding to prevent attention between sequences.
Returns:
packed_input: [1, total_packed_len, d_model]
position_ids: [1, total_packed_len] - absolute positions
"""
# In practice, this would concatenate embeddings with position info
# For now, simplified version
max_len = max(t.shape[0] for t in input_tokens)
batch_size = len(input_tokens)
# Pad sequences
padded = []
for tokens in input_tokens:
if tokens.shape[0] < max_len:
                pad = np.zeros((max_len - tokens.shape[0], tokens.shape[1]), dtype=tokens.dtype)
tokens = np.concatenate([tokens, pad], axis=0)
padded.append(tokens)
return np.stack(padded, axis=0), np.arange(max_len)[np.newaxis]
def step_inference(
self,
input_embeddings: np.ndarray, # [batch, 1, d_model]
use_kv_cache: bool = True
) -> List[np.ndarray]:
"""
Run one step of batched inference.
Args:
input_embeddings: [active_batch, 1, d_model] for new tokens
use_kv_cache: Whether to use KV-caching
        Returns:
            List of model outputs (final hidden states) per element; None for finished elements
"""
batch_size = len(self.batch)
outputs = []
for i, elem in enumerate(self.batch):
            if elem.finished:
                outputs.append(None)  # finished elements produce no new output
continue
# Run model forward
output = self.model.forward(
input_embeddings[i:i+1],
use_kv_cache=use_kv_cache,
batch_idx=elem.batch_idx
)
elem.logits = output
elem.current_len += 1
outputs.append(output)
self.step += 1
return outputs
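# Bookkeeping sketch for BatchedInferenceEngine (illustrative helper). Note a
# limitation of this toy setup: TransformerBlockStack keeps a single KV store and
# ignores batch_idx, so outputs are not per-sequence correct; this only exercises
# the batch accounting (lengths, active counts), not isolated per-sequence caches.
def _example_batched_engine_bookkeeping() -> None:
    np.random.seed(0)
    config = CacheConfig(max_batch_size=2, max_seq_len=16, num_heads=2,
                         head_dim=8, num_layers=1)
    model = TransformerBlockStack(num_layers=1, d_model=16, num_heads=2,
                                  head_dim=8, hidden_dim=32)
    engine = BatchedInferenceEngine(config, model)
    engine.add_to_batch(batch_idx=0, prompt_len=3)
    engine.add_to_batch(batch_idx=1, prompt_len=5)
    outputs = engine.step_inference(np.random.randn(2, 1, 16).astype(np.float32))
    assert len(outputs) == 2
    assert engine.get_active_batch_size() == 2
    assert engine.get_max_seq_len_in_batch() == 6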
# =============================================================================
# Part 4: Complete Transformer Layer Integration
# =============================================================================
class TransformerBlock:
"""
Single transformer block with KV-cache support.
Architecture:
┌─────────────────────────────────────────────────────────────┐
│ Input x │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ Pre-Norm: x + LayerNorm(x) │ │
│ └──────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ Multi-Head Self-Attention │ │
│ │ + KV-Cache (on decode) │ │
│ └──────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ Residual: + │ │
│ └─────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ Pre-Norm: + LayerNorm(+) │ │
│ └──────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ FFN (SwiGLU or FFN2) │ │
│ └──────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ Residual: + │ │
│ └─────────────┘ │
│ │ │
│ ▼ │
│ Output │
└─────────────────────────────────────────────────────────────┘
"""
def __init__(
self,
d_model: int,
num_heads: int,
head_dim: int,
hidden_dim: int,
layer_idx: int
):
self.layer_idx = layer_idx
# Attention
self.attention = MultiHeadAttention(
d_model=d_model,
num_heads=num_heads,
head_dim=head_dim
)
# Layer norms (simplified RMSNorm)
self.norm1_weight = np.ones(d_model, dtype=np.float32)
self.norm2_weight = np.ones(d_model, dtype=np.float32)
# FFN
self.W_up = np.random.randn(d_model, hidden_dim).astype(np.float32) * 0.02
self.W_down = np.random.randn(hidden_dim, d_model).astype(np.float32) * 0.02
def forward(
self,
x: np.ndarray,
kv_cache: Optional[Dict] = None,
use_cache: bool = True
) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""
Forward pass through transformer block.
Returns:
output: Same shape as input
kv_output: (K, V) tensors for caching
"""
        # Pre-norm for attention (RMSNorm)
        x_normed = x * self.norm1_weight / (np.sqrt(np.mean(x * x, axis=-1, keepdims=True)) + 1e-8)
# Attention with cache
attn_out, kv_output = self.attention.forward(
x_normed, kv_cache, self.layer_idx, use_cache
)
# Residual
x = x + attn_out
        # Pre-norm for FFN (RMSNorm)
        x_normed = x * self.norm2_weight / (np.sqrt(np.mean(x * x, axis=-1, keepdims=True)) + 1e-8)
        # FFN with a ReLU nonlinearity (simplified stand-in for SwiGLU)
        ffn_out = np.matmul(np.maximum(np.matmul(x_normed, self.W_up), 0.0), self.W_down)
# Residual
x = x + ffn_out
return x, kv_output
class TransformerBlockStack:
"""
Stack of transformer blocks with shared KV-cache.
Manages:
- Layer-by-layer forward pass
    - A per-layer KV-cache store
- Efficient batched inference
"""
def __init__(
self,
num_layers: int,
d_model: int,
num_heads: int,
head_dim: int,
hidden_dim: int
):
self.num_layers = num_layers
self.d_model = d_model
# Create layers
self.layers: List[TransformerBlock] = []
for i in range(num_layers):
self.layers.append(
TransformerBlock(
d_model=d_model,
num_heads=num_heads,
head_dim=head_dim,
hidden_dim=hidden_dim,
layer_idx=i
)
)
# KV-cache store
self.kv_cache: Dict[int, Tuple[np.ndarray, np.ndarray]] = {}
# Final layer norm
self.final_norm = np.ones(d_model, dtype=np.float32)
def forward(
self,
x: np.ndarray,
use_kv_cache: bool = True,
batch_idx: int = 0
) -> np.ndarray:
"""
Forward pass through all layers.
Args:
x: [batch, seq, d_model]
use_kv_cache: Whether to read/write KV cache
batch_idx: Which batch element to cache for
Returns:
output: [batch, seq, d_model]
"""
for layer_idx, layer in enumerate(self.layers):
# Prepare KV cache for this layer
layer_cache = None
if use_kv_cache and layer_idx in self.kv_cache:
layer_cache = {layer_idx: self.kv_cache[layer_idx]}
# Forward through layer
x, kv_output = layer.forward(x, layer_cache, use_kv_cache)
# Store KV for caching
if use_kv_cache and kv_output is not None:
self.kv_cache[layer_idx] = kv_output
        # Final norm (RMSNorm)
        x = x * self.final_norm / (np.sqrt(np.mean(x * x, axis=-1, keepdims=True)) + 1e-8)
return x
def clear_cache(self) -> None:
"""Clear all cached KV pairs."""
self.kv_cache = {}
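# Decode sketch for TransformerBlockStack (illustrative helper): prefill a tiny
# 2-layer stack, then run two cached single-token decode steps and watch each
# layer's cached K grow by one position per step.
def _example_stack_decode_cache_growth() -> None:
    np.random.seed(0)
    stack = TransformerBlockStack(num_layers=2, d_model=16, num_heads=2,
                                  head_dim=8, hidden_dim=32)
    prompt = np.random.randn(1, 4, 16).astype(np.float32)
    stack.forward(prompt, use_kv_cache=True)                 # prefill: 4 positions cached
    assert stack.kv_cache[0][0].shape[2] == 4
    for _ in range(2):                                       # two single-token decode steps
        stack.forward(np.random.randn(1, 1, 16).astype(np.float32), use_kv_cache=True)
    assert stack.kv_cache[0][0].shape[2] == 6
    stack.clear_cache()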
# =============================================================================
# Part 5: Autoregressive Generation Loop
# =============================================================================
class KVCacheAwareGenerator:
"""
Autoregressive generator with KV-cache optimization.
Generation flow:
┌────────────────────────────────────────────────────────────────────┐
│ Prefill Phase │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ Prompt: "The cat" │ │
│ │ Process full sequence with attention │ │
│ │ Cache all K, V for each layer │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ Decode Phase │ │
│ │ │ │
│ │ Step 1: │ │
│ │ Input: [token_id=456] (single token) │ │
│ │ Cache: [K_cache, V_cache] from prompt │ │
│ │ Compute: Q(new) @ concat(K_cache, K_new) │ │
│ │ Output: next token logits │ │
│ │ Update: Append K_new, V_new to cache │ │
│ │ │ │
│ │ Step 2: │ │
│ │ Input: [new_token] │ │
│ │ Cache: Includes tokens from Step 1 │ │
│ │ Compute: ... │ │
│ │ ... repeat until EOS or max_len │ │
│ └──────────────────────────────────────────────────────────────┘ │
└────────────────────────────────────────────────────────────────────┘
"""
def __init__(self, config: CacheConfig):
self.config = config
# Create model
self.model = TransformerBlockStack(
num_layers=config.num_layers,
d_model=config.num_heads * config.head_dim,
num_heads=config.num_heads,
head_dim=config.head_dim,
hidden_dim=4 * config.num_heads * config.head_dim
)
# Token embeddings (simplified)
self.vocab_size = 32000
self.embed_dim = config.num_heads * config.head_dim
self.token_embedding = np.random.randn(self.vocab_size, self.embed_dim).astype(np.float32) * 0.02
# Output projection
self.lm_head = np.random.randn(self.embed_dim, self.vocab_size).astype(np.float32) * 0.02
# Inference engine
self.engine = BatchedInferenceEngine(config, self.model)
def _embed_tokens(self, token_ids: np.ndarray) -> np.ndarray:
"""Convert token IDs to embeddings."""
# token_ids: [batch, seq_len]
# output: [batch, seq_len, embed_dim]
return self.token_embedding[token_ids]
def _sample_token(self, logits: np.ndarray, temperature: float = 1.0) -> int:
"""Sample next token from logits."""
if temperature == 0:
return np.argmax(logits)
# Apply temperature
logits = logits / temperature
# Softmax
exp_logits = np.exp(logits - np.max(logits))
probs = exp_logits / np.sum(exp_logits)
# Sample
return np.random.choice(len(logits), p=probs)
def prefill(self, prompt_tokens: np.ndarray) -> None:
"""
Prefill phase: process full prompt and populate KV cache.
Args:
prompt_tokens: [batch, prompt_len] token IDs
"""
# Embed
embeddings = self._embed_tokens(prompt_tokens)
# Forward with cache population
self.model.forward(embeddings, use_kv_cache=True)
def decode_step(self, token_ids: np.ndarray) -> np.ndarray:
"""
Single decode step with cached attention.
Args:
token_ids: [batch, 1] - single token IDs
Returns:
next_token_logits: [batch, vocab_size]
"""
# Embed new tokens
embeddings = self._embed_tokens(token_ids)
# Forward with cache (only computes Q for new token)
hidden = self.model.forward(embeddings, use_kv_cache=True)
# Project to vocabulary
logits = np.matmul(hidden, self.lm_head)
return logits.squeeze(1) # [batch, vocab_size]
def generate(
self,
prompt_tokens: np.ndarray,
max_new_tokens: int = 100,
temperature: float = 1.0,
eos_token: int = 2
) -> List[List[int]]:
"""
Generate tokens autoregressively with KV-cache.
Args:
prompt_tokens: [batch, prompt_len]
max_new_tokens: Maximum tokens to generate
temperature: Sampling temperature (0 = greedy)
eos_token: End-of-sequence token ID
Returns:
List of generated sequences (excluding prompts)
"""
batch_size = prompt_tokens.shape[0]
        # Prefill phase: cache all prompt tokens except the last one
        if prompt_tokens.shape[1] > 1:
            self.prefill(prompt_tokens[:, :-1])
        # Decode phase: feed the last prompt token as the first decode input so
        # its K/V enter the cache exactly once
        generated = [[] for _ in range(batch_size)]
        current_tokens = prompt_tokens[:, -1:]
for step in range(max_new_tokens):
# Get logits
logits = self.decode_step(current_tokens)
# Sample tokens
next_tokens = [self._sample_token(logits[i], temperature)
for i in range(batch_size)]
# Store generated
for i in range(batch_size):
generated[i].append(next_tokens[i])
# Update for next step
current_tokens = np.array(next_tokens)[:, np.newaxis]
# Check for EOS
if all(t == eos_token for t in next_tokens):
break
# Clear cache after generation
self.model.clear_cache()
return generated
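# End-to-end sketch for KVCacheAwareGenerator (illustrative helper; weights are
# random, so the generated ids are meaningless). This only exercises the
# prefill/decode plumbing and the cache lifecycle.
def _example_generate() -> None:
    config = CacheConfig(max_batch_size=1, max_seq_len=32, num_heads=2,
                         head_dim=8, num_layers=2)
    gen = KVCacheAwareGenerator(config)
    out = gen.generate(np.array([[1, 2, 3]]), max_new_tokens=4, temperature=0.0)
    assert len(out) == 1 and 1 <= len(out[0]) <= 4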
# =============================================================================
# Part 6: Memory Analysis
# =============================================================================
class MemoryAnalyzer:
"""
Analyzes memory growth patterns in KV-cache.
Memory Growth Model:
For a transformer with L layers, H heads, and D head dimension:
Key memory formulas:
1. Single Sequence Memory:
M_seq = 2 * L * S * H * D * bytes_per_element
where:
- 2 = K and V tensors
- S = sequence length
- L = number of layers
- H = number of attention heads
- D = head dimension
2. Per-token Memory Addition:
M_token = 2 * L * H * D * bytes_per_element
Each new token adds constant memory (O(1) per token)
3. Attention Complexity Growth:
- Prefill: O(S^2) for full attention
- Decode step: O(S) for cached attention
       - The cache turns each decode step from O(S^2) recomputation into O(S) work,
         an O(S) per-step speedup
4. Memory for Batch with Variable Lengths:
M_batch = sum(M_seq_i for i in batch) + M_padding
5. Long Sequence Analysis:
At S=4096, L=32, H=32, D=128, FP16:
M_seq = 2 * 32 * 4096 * 32 * 128 * 2 bytes
≈ 2 GB per sequence!
Without cache: each decode step recomputes attention over 4096 tokens
With cache: each decode step only computes Q for new token
"""
def __init__(self, config: CacheConfig):
self.config = config
def memory_per_layer(self, seq_len: int) -> int:
"""Calculate memory for one layer at given sequence length."""
return 2 * seq_len * self.config.num_heads * self.config.head_dim * self.config.dtype_bytes
def memory_full_sequence(self) -> int:
"""Calculate memory for full max sequence."""
return self.config.memory_per_layer() * self.config.num_layers
    def memory_growth_rate(self) -> int:
        """Memory added per new token, across all layers."""
        return self.memory_per_layer(seq_len=1) * self.config.num_layers
def estimate_latency(self, seq_len: int, is_decode: bool = False) -> float:
"""
Estimate latency based on sequence length.
Args:
seq_len: Current sequence length
is_decode: True if this is a decode step (vs prefill)
Returns:
Estimated relative latency
"""
if is_decode:
# Decode: O(S) per step due to cache
return seq_len
else:
# Prefill: O(S^2) for full attention
return seq_len * seq_len
def print_memory_analysis(self) -> None:
"""Print comprehensive memory analysis."""
print("=" * 70)
print("KV-CACHE MEMORY ANALYSIS")
print("=" * 70)
config = self.config
print(f"\nConfiguration:")
print(f" - Num layers: {config.num_layers}")
print(f" - Num heads: {config.num_heads}")
print(f" - Head dim: {config.head_dim}")
print(f" - Max seq len: {config.max_seq_len}")
print(f" - Batch size: {config.max_batch_size}")
print(f" - Dtype: FP{config.dtype_bytes * 8}")
print(f"\nMemory per layer at max seq:")
for seq_len in [512, 1024, 2048, 4096, 8192, 16384]:
mem = self.memory_per_layer(seq_len)
print(f" seq_len={seq_len:>6}: {mem / 1024**2:>8.2f} MB")
print(f"\nMemory per token (growth rate):")
per_token = self.memory_growth_rate()
print(f" {per_token / 1024**2:.4f} MB per token")
print(f"\nTotal memory for full configuration:")
total = self.memory_full_sequence()
print(f" Single sequence: {total / 1024**3:.2f} GB")
print(f" Full batch ({config.max_batch_size}): {total * config.max_batch_size / 1024**3:.2f} GB")
print(f"\nLatency comparison (relative units):")
print(f" {'Seq Len':>10} | {'Prefill':>10} | {'Decode (cached)':>15} | {'Speedup':>10}")
print(f" {'-'*10}-+-{'-'*10}-+-{'-'*15}-+-{'-'*10}")
for seq_len in [128, 512, 1024, 2048, 4096]:
prefill = self.estimate_latency(seq_len, is_decode=False)
decode = self.estimate_latency(seq_len, is_decode=True)
speedup = prefill / decode if decode > 0 else 0
print(f" {seq_len:>10} | {prefill:>10.0f} | {decode:>15.0f} | {speedup:>10.1f}x")
print("\n" + "=" * 70)
# =============================================================================
# Part 7: Optimization Strategies
# =============================================================================
class OptimizationStrategies:
"""
Optimization strategies for KV-cache.
This module documents and implements key optimizations:
1. PAGED ATTENTION
─────────────────────────────────────────────────────────────────────
Problem: Pre-allocating max_seq_len for each sequence wastes memory
when actual sequences are much shorter.
Solution: Split KV-cache into fixed-size blocks (pages).
┌──────────────────────────────────────────────────────────────┐
│ Instead of: │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ [max_seq_len=4096, ...] (most unused!) │ │
│ └──────────────────────────────────────────────────────┘ │
│ │
│ Use blocks: │
│ ┌────┐ ┌────┐ ┌────┐ ┌────┐ │
│ │ 16 │ │ 16 │ │ 16 │ │ 16 │ ... (allocate on demand) │
│ └────┘ └────┘ └────┘ └────┘ │
└──────────────────────────────────────────────────────────────┘
Benefits:
- 50-80% memory reduction for variable-length batches
- Enables efficient KV offloading
- Supports speculative decoding with variable-length candidates
Implementation: See PagedKVCache class above
─────────────────────────────────────────────────────────────────────
2. CHUNKED ATTENTION / FLASH ATTENTION
─────────────────────────────────────────────────────────────────────
Problem: Attention score matrix S = Q @ K^T grows as O(S^2) in memory.
For long sequences, this doesn't fit in SRAM/registers.
Solution: Process attention in chunks that fit in fast memory.
┌──────────────────────────────────────────────────────────────┐
│ Chunked attention algorithm: │
│ │
│ for chunk_q in Q_chunks: │
│ for chunk_k in K_chunks: │
│ S_chunk = chunk_q @ chunk_k.T │
│ P_chunk = softmax(S_chunk) │
│ O_chunk = P_chunk @ chunk_v │
│ accumulate(O_chunk) │
│ │
│ Key insight: Only need O(chunk_size^2) intermediate memory │
└──────────────────────────────────────────────────────────────┘
Benefits:
- Enables processing of arbitrarily long sequences
- ~2-4x speedup from better memory locality
- Flash Attention: IO-aware chunking for GPU memory hierarchy
─────────────────────────────────────────────────────────────────────
3. KV CACHE COMPRESSION
─────────────────────────────────────────────────────────────────────
Problem: Storing full precision KV-cache for all tokens is expensive.
Solutions:
a) Key-Value Quantization:
- Store K,V in INT8/FP8 instead of FP16/BF16
- Use per-channel or per-token scaling
- Typical compression: 2x (FP16→INT8)
b) Sparse KV-cache:
- Store only "important" tokens (high attention score)
- Use locality-sensitive hashing for attention grouping
       - Keep only the top-k attended entries per head per token
c) Token Selection:
- Heavy Head Hypothesis: Some attention heads are key
- H2O: Keep tokens with highest attention flow
- StreamingLLM: Keep recent + "sink" tokens
d) Temporal Compression:
- Differential coding: store delta from previous
- Run-length encoding for repeated patterns
─────────────────────────────────────────────────────────────────────
4. PARALLEL EXECUTION STRATEGIES
─────────────────────────────────────────────────────────────────────
a) Tensor Parallelism:
- Split heads across GPUs: Q, K, V distributed
- All-reduce for attention output
b) Pipeline Parallelism:
- Different layers on different devices
- Micro-batching to reduce pipeline bubbles
- KV-cache passed between stages
c) Sequence Parallelism:
- Split sequence across attention heads
- Ring attention for long sequences
- Each GPU handles subset of heads
─────────────────────────────────────────────────────────────────────
5. SPECULATIVE DECODING
─────────────────────────────────────────────────────────────────────
Problem: Autoregressive decoding is sequential (N steps for N tokens).
Solution: Draft multiple tokens speculatively, verify in parallel.
┌──────────────────────────────────────────────────────────────┐
│ Draft model (smaller): │
│ Generates K candidate tokens in parallel │
│ │
│ Target model (larger): │
│ Verifies all K tokens in ONE forward pass │
│ Accept/reject based on logits │
│ │
│ Speedup: ~2-4x when drafts are good │
│ Key: KV-cache is reused for verification │
└──────────────────────────────────────────────────────────────┘
"""
@staticmethod
def demonstrate_paged_attention() -> None:
"""Demonstrate memory savings from paged attention."""
print("\n" + "=" * 70)
print("PAGED ATTENTION DEMONSTRATION")
print("=" * 70)
# Example batch with variable lengths
batch_sequences = [127, 256, 512, 1024, 2048, 4096]
block_size = 16
num_heads = 32
head_dim = 128
print(f"\nBlock size: {block_size} tokens")
print(f"Heads: {num_heads}, Head dim: {head_dim}")
print(f"\n{'Sequence':>10} | {'Tokens':>6} | {'Blocks':>6} | {'Paged Mem':>12} | {'Flat Mem':>12} | {'Savings':>10}")
print("-" * 70)
        max_seq_len = 4096  # a flat cache pre-allocates this many positions per sequence
        for seq_len in batch_sequences:
            # Paged: ceiling division for blocks, allocated on demand
            num_blocks = (seq_len + block_size - 1) // block_size
            paged_mem = num_blocks * block_size * num_heads * head_dim * 2 * 2  # bytes (K+V, FP16)
            flat_mem = max_seq_len * num_heads * head_dim * 2 * 2
            savings = (1 - paged_mem / flat_mem) * 100 if flat_mem > 0 else 0
            print(f"{seq_len:>10} | {seq_len:>6} | {num_blocks:>6} | "
                  f"{paged_mem/1024**2:>10.2f}MB | {flat_mem/1024**2:>10.2f}MB | {savings:>9.1f}%")
        # Total comparison
        total_tokens = sum(batch_sequences)
        total_paged = sum((s + block_size - 1) // block_size * block_size for s in batch_sequences)
        total_flat = max_seq_len * len(batch_sequences)
        print(f"\n{'TOTAL':>10} | {total_tokens:>6} | {'':<6} | "
              f"{total_paged*num_heads*head_dim*2*2/1024**2:>10.2f}MB | "
              f"{total_flat*num_heads*head_dim*2*2/1024**2:>10.2f}MB | "
              f"{(1-total_paged/total_flat)*100:>9.1f}%")
@staticmethod
def demonstrate_quantization() -> None:
"""Demonstrate KV-cache quantization savings."""
print("\n" + "=" * 70)
print("KV-CACHE QUANTIZATION DEMONSTRATION")
print("=" * 70)
config = CacheConfig(
max_batch_size=32,
max_seq_len=4096,
num_heads=32,
head_dim=128,
num_layers=32,
dtype_bytes=2 # FP16
)
print(f"\nSequence length: {config.max_seq_len}")
print(f"Layers: {config.num_layers}, Heads: {config.num_heads}, Head dim: {config.head_dim}")
print(f"\n{'Format':>15} | {'Bytes/Element':>15} | {'Total Memory':>15} | {'Compression':>12}")
print("-" * 70)
formats = [
("FP32", 4),
("FP16", 2),
("BF16", 2),
("INT8", 1),
("INT4", 0.5),
("INT2", 0.25),
]
base_mem = 2 * config.num_layers * config.max_seq_len * config.num_heads * config.head_dim * 2 # FP16 base
for name, bytes_per_elem in formats:
mem = 2 * config.num_layers * config.max_seq_len * config.num_heads * config.head_dim * bytes_per_elem
ratio = base_mem / mem if mem > 0 else 0
print(f"{name:>15} | {bytes_per_elem:>15.2f} | {mem/1024**3:>13.2f} GB | {ratio:>10.1f}x")
@staticmethod
def demonstrate_chunked_attention() -> None:
"""Demonstrate chunked attention computation."""
print("\n" + "=" * 70)
print("CHUNKED ATTENTION (FLASH ATTENTION STYLE)")
print("=" * 70)
# Simulate chunked attention
seq_len = 4096
head_dim = 128
chunk_size = 256
print(f"\nSequence length: {seq_len}")
print(f"Head dimension: {head_dim}")
print(f"Chunk size: {chunk_size}")
print(f"\nMemory comparison:")
# Full attention: S x S matrix
full_attention_mem = seq_len * seq_len * 4 # 32-bit floats for scores
print(f" Full attention: {full_attention_mem/1024**2:.2f} MB for attention scores")
# Chunked attention: only chunk_size^2 per chunk
chunk_attention_mem = chunk_size * chunk_size * 4
print(f" Chunked attention: {chunk_attention_mem/1024**2:.2f} MB per chunk")
print(f" Memory reduction: {full_attention_mem/chunk_attention_mem:.0f}x")
print(f"\n Number of chunks: {seq_len // chunk_size}")
print(f" Processing: chunks processed sequentially, results accumulated")
print(f" Note: Total compute is same, but peak memory is reduced")
# =============================================================================
# Part 8: GPU Execution Mapping
# =============================================================================
class GPUExecutionMapper:
"""
Maps KV-cache operations to GPU execution model.
GPU Memory Hierarchy:
┌─────────────────────────────────────────────────────────────────┐
│ HBM (High Bandwidth Memory) │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ KV-cache (main storage) │ │
│ │ - Flat format for current active sequences │ │
│ │ - Offloaded to CPU/disk for finished sequences │ │
    │ Capacity: 80GB H100, 40-80GB A100 │ │
│ └──────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
│ │ │
│ │ ( PCIe / NVLink ) │
│ ▼ │
┌─────────────────────────────────────────────────────────────────┐
│ SM (Streaming Multiprocessor) │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ Shared Memory / L1 Cache │ │
│ │ - Chunk of KV for current computation │ │
│ │ - Tile size: 16-64KB per SM │ │
│ │ Capacity: 192KB A100, 228KB H100 │ │
│ └──────────────────────────────────────────────────────────┘ │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ Registers │ │
│ │ - Q, K, V tiles for current warp │ │
    │ - Per-thread: up to 255 registers (64K 32-bit registers per SM) │ │
│ └──────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
CUDA Kernel Design for KV-Cache Attention:
1. KERNEL: kvcache_update
─────────────────────────────────────────────────────────────────
Purpose: Store new K, V to global memory
Thread grid: (batch * num_heads, head_dim / warp_size)
Each thread writes one element
    __global__ void kvcache_update(
        const half* k_new,     // [batch, heads, 1, dim]
        const half* v_new,     // [batch, heads, 1, dim]
        half* k_cache,         // [batch, heads, max_seq, dim]
        half* v_cache,
        int* seq_lengths,      // [num_layers, batch] current lengths
        int batch_idx, int layer_idx,
        int batch_size, int num_heads, int max_seq, int dim
    ) {
        int h = blockIdx.x % num_heads;
        int d = threadIdx.x;
        int pos = seq_lengths[layer_idx * batch_size + batch_idx];
        size_t dst = ((size_t)(batch_idx * num_heads + h) * max_seq + pos) * dim + d;
        size_t src = (size_t)(batch_idx * num_heads + h) * dim + d;
        k_cache[dst] = k_new[src];
        v_cache[dst] = v_new[src];
        // A single thread bumps the length after all writes for this token:
        // if (h == 0 && d == 0) seq_lengths[layer_idx * batch_size + batch_idx] = pos + 1;
    }
2. KERNEL: attention_with_cache
─────────────────────────────────────────────────────────────────
Purpose: Compute attention using cached K, V
Thread grid: (batch, num_heads, num_chunks)
Uses shared memory for tile storage
Key optimizations:
- Load K, V tiles from global to shared memory
- Compute QK^T in shared memory
- Apply causal mask inline
- Softmax with warp-level reduction
- Store results to output tensor
3. KERNEL: attention_batch_update
─────────────────────────────────────────────────────────────────
Purpose: Prefill phase - process multiple tokens at once
Different from decode: processes S tokens instead of 1
No causal masking needed (can compute full attention)
GPU Memory Access Patterns:
┌──────────────────────────────────────────────────────────────────┐
│ KV-Cache Access Pattern During Decode │
│ │
│ Q (new): k_cache[b, h, pos, :] → strided access, pos = S-1 │
│ K (cached): k_cache[b, h, 0:S, :] → strided, S reads │
│ V (cached): v_cache[b, h, 0:S, :] → strided, S reads │
│ │
│ For batch B, heads H, dim D: │
│ - Q load: B * H * D elements (1 position) │
│ - K load: B * H * S * D elements (full cache) │
│ - V load: B * H * S * D elements (full cache) │
│ │
│ Memory bandwidth: O(B * H * S * D) per decode step │
│ This is the bottleneck! │
└──────────────────────────────────────────────────────────────────┘
Optimization Strategies for GPU:
1. Memory Coalescing:
- Organize cache as [batch, heads, seq, dim] for coalesced access
- Each thread accesses consecutive elements
2. Tensor Core Usage:
- Use WMMA (Warp Matrix Multiply Accumulate) for attention
- Gemm operations: Q @ K^T, attention @ V
3. Asynchronous Execution:
- Overlap KV-cache updates with next forward pass
- Use CUDA streams for independent operations
4. Persistent Kernels:
- Keep kernel resident for low-latency decode
- Reduces kernel launch overhead
"""
@staticmethod
def print_gpu_execution_analysis() -> None:
"""Print detailed GPU execution analysis."""
print("\n" + "=" * 70)
print("GPU EXECUTION MAPPING FOR KV-CACHE")
print("=" * 70)
# GPU specifications
gpus = {
"A100": {"memory_gb": 80, "bandwidth_gbps": 2039, "sm_count": 108},
"H100": {"memory_gb": 80, "bandwidth_gbps": 3350, "sm_count": 132},
"L40S": {"memory_gb": 48, "bandwidth_gbps": 900, "sm_count": 84},
}
config = CacheConfig(
max_batch_size=32,
max_seq_len=4096,
num_heads=32,
head_dim=128,
num_layers=32
)
        elems_per_layer = config.max_seq_len * config.num_heads * config.head_dim
        bytes_per_layer = elems_per_layer * config.dtype_bytes * 2  # K and V
total_bytes = bytes_per_layer * config.num_layers
print(f"\nKV-cache size (full config): {total_bytes / 1024**3:.2f} GB")
print(f"\nGPU Specifications:")
for gpu_name, specs in gpus.items():
cache_fit = "YES" if total_bytes < specs["memory_gb"] * 1024**3 else "NO"
bw_time = total_bytes / (specs["bandwidth_gbps"] * 1e9) * 1000
print(f"\n {gpu_name}:")
print(f" Memory: {specs['memory_gb']} GB")
print(f" Bandwidth: {specs['bandwidth_gbps']} GB/s")
print(f" SMs: {specs['sm_count']}")
print(f" KV-cache fits: {cache_fit}")
print(f" Time to read full cache: {bw_time:.2f} ms")
print("\n" + "-" * 70)
print("EXECUTION PIPELINE")
print("-" * 70)
print("""
┌──────────────────────────────────────────────────────────────────────┐
│ DECODE STEP EXECUTION │
│ │
│ Step N: │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ 1. Load Q (current token) from hidden states │ │
│ │ ← Compute previous layer │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ 2. Load KV cache from HBM │ │
│ │ [batch * heads * seq * dim] ← main memory │ │
│ │ Latency: ~1-2 μs per access (H100) │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ 3. Compute Q @ K^T (attention scores) │ │
│ │ Tensor cores: 256x128x128 GEMM per batch element │ │
│ │ Output: [batch, heads, 1, seq] │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ 4. Softmax with causal mask │ │
│ │ Warp-level reduction │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ 5. Compute softmax(QK^T) @ V │ │
│ │ Second GEMM: [batch, heads, 1, dim] │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ 6. Update KV cache (write new K, V) │ │
│ │ → HBM async write │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ 7. Output projection + FFN │ │
│ │ Standard transformer layers │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ Step N+1 ready │
└──────────────────────────────────────────────────────────────────────┘
""")
print("-" * 70)
print("CRITICAL BOTTLENECKS AND SOLUTIONS")
print("-" * 70)
bottlenecks = [
("Memory Bandwidth",
"Full KV cache read every decode step",
"Paged attention + async copy, KV compression"),
("Attention O(S²)",
"Score matrix grows with sequence",
"Flash Attention chunking, sparse attention"),
("Synchronization",
"Wait for KV update before next step",
"Lookahead decoding, speculative execution"),
("Memory Capacity",
"Can't cache all sequences indefinitely",
"KV offloading, selective caching"),
]
for name, problem, solution in bottlenecks:
print(f"\n{name}:")
print(f" Problem: {problem}")
print(f" Solution: {solution}")
# =============================================================================
# Part 9: Demonstration and Testing
# =============================================================================
def run_demo():
"""Run comprehensive demonstration of KV-cache system."""
print("=" * 70)
print("KV-CACHE SYSTEM DEMONSTRATION")
print("=" * 70)
# Configuration
config = CacheConfig(
max_batch_size=8,
max_seq_len=512, # Reduced for demo
num_heads=4, # Reduced for demo
head_dim=32, # Reduced for demo
num_layers=2, # Reduced for demo
dtype_bytes=2
)
# Memory analysis
analyzer = MemoryAnalyzer(config)
analyzer.print_memory_analysis()
# Optimizations
OptimizationStrategies.demonstrate_paged_attention()
OptimizationStrategies.demonstrate_quantization()
OptimizationStrategies.demonstrate_chunked_attention()
# GPU execution
GPUExecutionMapper.print_gpu_execution_analysis()
# Functional test
print("\n" + "=" * 70)
print("FUNCTIONAL TEST")
print("=" * 70)
# Create small model for testing
test_config = CacheConfig(
max_batch_size=4,
max_seq_len=64,
num_heads=2,
head_dim=16,
num_layers=2,
dtype_bytes=2
)
generator = KVCacheAwareGenerator(test_config)
# Simple test
prompt = np.array([[1, 2, 3, 4, 5]]) # Batch of 1, seq of 5
print(f"\nInput prompt: shape {prompt.shape}")
# Prefill
generator.prefill(prompt)
print("Prefill complete - KV cache populated")
# Decode steps
print("\nDecode steps:")
for i in range(3):
next_token = np.array([[prompt[0, -1] + i + 1]]) # Simulated next token
logits = generator.decode_step(next_token)
print(f" Step {i+1}: output shape {logits.shape}, max logit at position {np.argmax(logits)}")
print("\nTest completed successfully!")
print("=" * 70)
if __name__ == "__main__":
run_demo()