feat: add model comparisons and sanitize session files
- Rename gamma to glm5 and model to minimax-m2.7
- Add model_comparison/ directory with head-to-head analyses
- Sanitize all session.jsonl files: remove absolute paths and usernames
- Remove __pycache__ artifacts
- Add .gitignore
@@ -0,0 +1,37 @@
All 10 demos pass cleanly. Here's a summary of what was built:

## Project Structure

```
kv/
├── kv_cache.py          # Core KV-cache data structures
├── attention.py         # Attention computation (standard, cached, GQA, masked)
├── transformer.py       # Full transformer decoder with KV-cache integration
├── optimizations.py     # Paged attention, quantization, chunked prefill
├── memory_analysis.py   # Memory growth formulas, GPU limits
├── gpu_mapping.py       # GPU kernel design, Tensor Core analysis
├── demo.py              # 10 end-to-end demos
└── README.md            # Comprehensive documentation
```

## What Each Demo Validates

| # | Demo | Validates |
|---|------|-----------|
| 1 | Basic KV Cache | Data structure layout, O(1) append, memory tracking |
| 2 | Cached Attention | Q @ K^T @ V using cache, matches manual computation |
| 3 | Full Transformer | Prefill + generation loop, variable-length batching |
| 4 | Variable-Length Batching | Per-sequence attention with different lengths |
| 5 | Paged Attention | Block allocation, page tables, non-contiguous memory |
| 6 | Quantized Cache | int8 quantization/dequantization with affine transform |
| 7 | Chunked Prefill | Matches full attention (4.56e-10 diff), 8× memory savings |
| 8 | Optimization Comparison | Side-by-side memory costs of all strategies |
| 9 | Memory Analysis | Model sizes, growth curves, max context per GPU |
| 10 | GPU Tensor Cores | Arithmetic intensity → all configs are memory-bound |

## Key Findings

- **KV cache is linear**: `2 × batch × layers × heads × seq × head_dim × elem_bytes`
- **7B model on H100-80GB**: max ~121K tokens (batch=1), ~30K (batch=4)
- **Generation is memory-bound**: 1.0 FLOPs/byte intensity — bottleneck is HBM bandwidth, not compute
- **Paged + quantized**: 2-4× memory reduction, enabling proportionally longer contexts
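
As a quick check of the growth formula, a minimal sketch (plain Python, using the 7B-class configuration assumed throughout: 32 layers, 32 heads, head_dim=128, fp16):

```python
def kv_cache_bytes(batch, layers, heads, seq, head_dim, elem_bytes=2):
    # factor of 2: both K and V are stored
    return 2 * batch * layers * heads * seq * head_dim * elem_bytes

gib = kv_cache_bytes(batch=1, layers=32, heads=32, seq=32_768, head_dim=128) / 1024**3
print(f"KV cache at 32K context, batch=1: {gib:.0f} GiB")  # -> 16 GiB
```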
@@ -0,0 +1,19 @@
Implement an efficient KV-cache system for autoregressive transformer inference from scratch.

Requirements:
1. Support incremental decoding (one token at a time).
2. Avoid recomputing attention for past tokens.
3. Handle:
   - multi-head attention
   - batching with variable sequence lengths
4. Provide:
   - data structure layout (memory format)
   - update logic per step
   - attention computation using cached keys/values

Additionally:
- Analyze memory growth over long sequences.
- Propose at least two optimizations (e.g., paged attention, chunking, compression).
- Explain how this would map to GPU execution.

Do not use any frameworks.
@@ -0,0 +1,246 @@
# KV-Cache System for Autoregressive Transformer Inference

Pure NumPy implementation — no frameworks. Demonstrates the complete KV-cache pipeline from data structures through GPU mapping.

## Architecture

```
┌───────────────────────────────────────────────────────────────────┐
│                        INFERENCE PIPELINE                          │
│                                                                    │
│  Prompt ──→ [Prefill] ──→ KV Cache populated ──→ [Generate]        │
│             O(n²) attn    O(1) per token         O(seq) attn       │
│                                                                    │
│  Per generation step:                                              │
│    1. Embed + positional encoding                                  │
│    2. For each layer:                                              │
│       a. LayerNorm → QKV projection                                │
│       b. Store K,V in cache (append at write_pos)                  │
│       c. Cached attention: Q @ K_cache^T → softmax → @ V_cache     │
│       d. Output projection → MLP → residual                        │
│    3. LM head → logits → sample next token                         │
└───────────────────────────────────────────────────────────────────┘
```

## Files

| File | Purpose |
|------|---------|
| `kv_cache.py` | Core KV-cache data structures (`KVCache`, `BatchedKVCache`) |
| `attention.py` | Attention computation (standard, cached, GQA, masked) |
| `transformer.py` | Full transformer decoder layer + model with KV-cache integration |
| `optimizations.py` | Paged attention, quantization, chunked prefill |
| `memory_analysis.py` | Memory growth formulas, model size comparisons, GPU limits |
| `gpu_mapping.py` | GPU kernel design, Tensor Core analysis, multi-GPU strategies |
| `demo.py` | 10 end-to-end demos exercising every component |

## 1. Data Structure Layout

### Memory Format

```
cache_k[batch, num_heads, max_seq_len, head_dim]   # float16
cache_v[batch, num_heads, max_seq_len, head_dim]   # float16
lengths[batch]                                     # int32 (actual seq len per item)
write_pos                                          # int (global write pointer)
```

**Why this layout:**
- `batch` first → enables batched GEMM on GPU
- `heads` second → parallel head computation
- `seq_len` third → contiguous scan for Q @ K^T
- `head_dim` last → inner product dimension, coalesced access
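
A quick way to see the contiguity argument in NumPy (a small sketch, not part of the codebase; `ndarray.strides` reports the byte step along each axis):

```python
import numpy as np

cache_k = np.zeros((2, 32, 4096, 128), dtype=np.float16)  # (batch, heads, seq, head_dim)
print(cache_k.strides)
# -> (33554432, 1048576, 256, 2): consecutive head_dim elements are 2 bytes apart
#    (coalesced inner products), and stepping along seq moves in fixed 256-byte
#    rows, giving the contiguous scan that Q @ K^T wants.
```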

### Per-Token Memory Cost

For a 7B model (32 layers, 32 heads, head_dim=128, fp16):

```
Per token per layer:     2 × 32 × 128 × 2 bytes = 16 KB
Per token (all layers):  16 KB × 32 = 512 KB
At 32K context:          512 KB × 32,768 = 16 GB
```

## 2. Update Logic Per Step

```python
# Each generation step:
pos = cache.write_pos
cache.cache_k[:, :, pos, :] = new_k[:, :, 0, :]  # O(1) write
cache.cache_v[:, :, pos, :] = new_v[:, :, 0, :]  # O(1) write
cache.write_pos += 1
```

The write is a simple memory copy — no computation needed. The cache grows by exactly `2 × heads × head_dim × elem_bytes` per token per layer.
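
A minimal sketch of the cache object these lines assume (illustrative only; the real `KVCache` in `kv_cache.py` adds per-sequence lengths, dtype handling, and memory accounting):

```python
import numpy as np

class MinimalKVCache:
    """Preallocated K/V buffers with an O(1) append per generated token."""

    def __init__(self, batch, num_heads, max_seq_len, head_dim, dtype=np.float16):
        self.cache_k = np.zeros((batch, num_heads, max_seq_len, head_dim), dtype=dtype)
        self.cache_v = np.zeros((batch, num_heads, max_seq_len, head_dim), dtype=dtype)
        self.write_pos = 0

    def update(self, new_k, new_v):
        # new_k, new_v: (batch, num_heads, 1, head_dim) for the current token
        self.cache_k[:, :, self.write_pos, :] = new_k[:, :, 0, :]
        self.cache_v[:, :, self.write_pos, :] = new_v[:, :, 0, :]
        self.write_pos += 1

    def get_all(self):
        # only the filled prefix: (batch, num_heads, write_pos, head_dim)
        return (self.cache_k[:, :, :self.write_pos, :],
                self.cache_v[:, :, :self.write_pos, :])
```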

## 3. Attention Computation Using Cache

```python
# Retrieve all cached K, V
cached_k, cached_v = cache.get_all()   # (batch, heads, seq_so_far, head_dim)

# Q @ K^T: (batch, heads, 1, head_dim) × (batch, heads, head_dim, seq)
scores = einsum("bhqd,bhkd->bhqk", q, cached_k) / sqrt(head_dim)

# Softmax (no mask needed — cache only has past tokens)
attn = softmax(scores, axis=-1)

# Attn @ V: (batch, heads, 1, seq) × (batch, heads, seq, head_dim)
output = einsum("bhqk,bhkd->bhqd", attn, cached_v)
```

**Key insight:** During generation, the cache naturally enforces causality — it only contains past tokens, so no explicit mask is needed.

## 4. Memory Growth Analysis

### Linear Growth Formula

```
KV_cache(bytes) = 2 × batch × layers × heads × seq_len × head_dim × elem_bytes
```

### 7B Model (batch=1, fp16)

| Context | KV Cache | Total (params + KV) | KV Fraction |
|---------|----------|---------------------|-------------|
| 256 | 0.12 GB | 7.04 GB | 1.8% |
| 4,096 | 2.00 GB | 8.91 GB | 22.4% |
| 8,192 | 4.00 GB | 10.91 GB | 36.7% |
| 32,768 | 16.00 GB | 22.91 GB | 69.8% |
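
The table follows directly from the linear formula; a short check, assuming the same 7B configuration and the ~6.91 GB weight figure used above:

```python
LAYERS, HEADS, HEAD_DIM, ELEM_BYTES = 32, 32, 128, 2   # 7B config, fp16
PARAMS_GB = 6.91                                       # weight footprint used in this table

def kv_gb(seq_len, batch=1):
    return 2 * batch * LAYERS * HEADS * seq_len * HEAD_DIM * ELEM_BYTES / 1024**3

for seq in (256, 4_096, 8_192, 32_768):
    kv = kv_gb(seq)
    total = kv + PARAMS_GB
    print(f"{seq:>6,}  KV {kv:5.2f} GB   total {total:5.2f} GB   KV share {kv / total:5.1%}")
```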

### Maximum Context by GPU (7B model, batch=1)

| GPU | Max Context |
|-----|-------------|
| RTX 4090 (24 GB) | 6,690 tokens |
| A100-40GB | 39,458 tokens |
| A100-80GB / H100-80GB | 121,378 tokens |

### Batch Size Impact

KV cache scales linearly with batch size. At batch=4, the 7B model on an A100-80GB can only handle ~30K context instead of 121K.

## 5. Optimizations

### Optimization 1: Paged Attention (vLLM-style)

**Problem:** Contiguous allocation wastes memory when sequences have variable lengths. A batch with one 32K sequence and three 100-token sequences still allocates 32K for all.

**Solution:** Divide memory into fixed-size blocks (pages). Each sequence maintains a page table mapping logical blocks to physical pages.

```
Physical page pool: (total_pages, heads, block_size, head_dim)
Page table:         (batch, max_blocks) → logical → physical mapping
```

**Benefits:**
- Zero memory fragmentation
- Supports speculative decoding and branching
- Enables prefix caching (share common prefixes)
- No need to pre-allocate max_seq_len

**Trade-off:** Page table indirection adds complexity to the attention kernel (gather from non-contiguous pages).
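
A minimal sketch of the page-table idea (names and methods are illustrative, not the `PagedKVCache` API in `optimizations.py`):

```python
import numpy as np

class TinyPagedKV:
    """Fixed-size physical pages plus a per-sequence page table."""

    def __init__(self, num_pages, block_size, num_heads, head_dim, dtype=np.float16):
        self.k_pages = np.zeros((num_pages, num_heads, block_size, head_dim), dtype=dtype)
        self.v_pages = np.zeros_like(self.k_pages)
        self.free_pages = list(range(num_pages))
        self.page_table = {}      # seq_id -> list of physical page ids (logical order)
        self.lengths = {}         # seq_id -> tokens written so far
        self.block_size = block_size

    def append(self, seq_id, k, v):
        """k, v: (num_heads, head_dim) for one new token of sequence seq_id."""
        pos = self.lengths.get(seq_id, 0)
        if pos % self.block_size == 0:                    # logical block full: grab a page
            self.page_table.setdefault(seq_id, []).append(self.free_pages.pop())
        page = self.page_table[seq_id][pos // self.block_size]
        self.k_pages[page, :, pos % self.block_size, :] = k
        self.v_pages[page, :, pos % self.block_size, :] = v
        self.lengths[seq_id] = pos + 1
```

Memory is claimed one block at a time per sequence, so a 100-token sequence never reserves a 32K slot.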

### Optimization 2: Quantization

**Problem:** fp16 KV cache dominates memory for long contexts.

**Solution:** Store K/V in int8 with per-channel affine dequantization: `x ≈ scale × q + zero`

```
int8 data:            1 byte per element (vs 2 for fp16)
fp16 scales + zeros:  shared per channel (not per token)
Net savings:          ~50% memory with <1% accuracy loss
```

**Production approach:** Shared per-channel scales (not per-position) stored in fp16. The per-position approach in this codebase is for correctness demonstration but has higher overhead.
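
A sketch of that shared per-channel variant (illustrative; `optimizations.py` implements the per-position flavor for clarity):

```python
import numpy as np

def quantize_k_per_channel(x):
    """x: (batch, heads, seq, head_dim) -> int8 codes plus per-channel fp16 scale/zero."""
    x = x.astype(np.float32)
    x_min = x.min(axis=2, keepdims=True)                 # one range per (batch, head, channel)
    x_max = x.max(axis=2, keepdims=True)
    scale = (x_max - x_min) / 255.0 + 1e-8
    zero = x_min
    codes = np.clip(np.round((x - zero) / scale) - 128, -128, 127).astype(np.int8)
    return codes, scale.astype(np.float16), zero.astype(np.float16)

def dequantize_k(codes, scale, zero):
    # x ≈ scale × q + zero, with the int8 code shifted back to its 0..255 position
    return (codes.astype(np.float32) + 128.0) * scale.astype(np.float32) + zero.astype(np.float32)

k = np.random.randn(1, 4, 256, 16).astype(np.float16)
codes, scale, zero = quantize_k_per_channel(k)
err = np.abs(dequantize_k(codes, scale, zero) - k.astype(np.float32)).max()
print(f"max abs reconstruction error: {err:.4f}")        # small vs. the value range
```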

### Optimization 3: Chunked Prefill

**Problem:** Processing a 32K prompt requires materializing a 32K × 32K attention matrix (4 GB in fp32).

**Solution:** Process the prompt in chunks of size C. Each chunk attends to all previous tokens + causal within chunk.

```
Peak memory: O(C × seq_len) instead of O(seq_len²)
For C=512, seq=4096: 8 MB vs 64 MB (8× savings)
```
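
A sketch of that chunked loop for a single head (illustrative, not the `ChunkedPrefill` API): each chunk of queries attends to every earlier key plus a causal mask within the chunk, so the peak score buffer is one `(chunk, seq)` block.

```python
import numpy as np

def chunked_causal_attention(q, k, v, scale, chunk=512):
    """q, k, v: (seq, head_dim). Never materializes more than a (chunk, seq) score block."""
    seq = q.shape[0]
    out = np.empty_like(q, dtype=np.float32)
    for start in range(0, seq, chunk):
        end = min(start + chunk, seq)
        q_c = q[start:end].astype(np.float32)       # queries in this chunk
        k_ctx = k[:end].astype(np.float32)          # every key up to the chunk end
        v_ctx = v[:end].astype(np.float32)
        scores = q_c @ k_ctx.T * scale              # (chunk, end): the peak buffer
        qpos = np.arange(start, end)[:, None]       # causal mask inside the chunk
        kpos = np.arange(end)[None, :]
        scores = np.where(kpos <= qpos, scores, -np.inf)
        scores -= scores.max(axis=-1, keepdims=True)
        w = np.exp(scores)
        w /= w.sum(axis=-1, keepdims=True)
        out[start:end] = w @ v_ctx
    return out
```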

### Combined: Paged + Quantized

Together these give 2-4× memory reduction, enabling 2-4× longer contexts in the same GPU memory.

## 6. GPU Execution Mapping

### Memory Hierarchy

| Level | Size | Latency | Usage |
|-------|------|---------|-------|
| Registers | 64 KB/SM | 1 cycle | Thread-local, warp computation |
| Shared memory | 166 KB/SM (H100) | 1-3 cycles | Tiling, softmax intermediates |
| L2 cache | 50 MB (H100) | ~20 cycles | Automatic global memory caching |
| HBM | 80 GB (H100) | ~300-400 cycles | Model weights, KV cache, activations |

### Cached Attention Kernel Design

```
Grid:  (batch_size, num_heads, 1)
Block: (32, 32) = 1024 threads

Shared memory per block (~16-20 KB):
  - Q tile:     1 × head_dim   (512 bytes fp16)
  - K tile:     32 × head_dim  (8 KB fp16)
  - Score tile: 32 × 32        (4 KB fp16)
```

**Optimization strategies:**
1. Coalesced global memory access (warp-level consecutive addresses)
2. Tiled GEMM with shared memory
3. Persistent kernels (keep blocks alive until all tiles processed)
4. Async copy (H100 `cp.async`) to overlap memory transfer with computation
5. Tensor Cores (`mma.sync`) for matmul operations
6. Fusion: merge softmax with attention score computation

### Arithmetic Intensity

For single-token generation (batch=1, heads=32, seq=1024):
- **FLOPs:** 0.02 GFLOPs
- **Memory traffic:** 16.79 MB
- **Arithmetic intensity:** 1.0 FLOPs/byte
- **→ Memory-bound** (H100 peak: 1,970 TFLOPS, 3.35 TB/s)

The cached attention is fundamentally memory-bound — the bottleneck is reading the KV cache from HBM, not computation. This is why bandwidth-optimized kernels (FlashAttention-style tiling) matter more than compute optimization.
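
The roofline arithmetic behind these numbers, as a small sketch using the H100 figures quoted above:

```python
def decode_attention_roofline(batch=1, heads=32, seq=1024, head_dim=128, elem_bytes=2):
    flops = 2 * 2 * batch * heads * seq * head_dim                        # Q@K^T plus attn@V
    bytes_moved = batch * heads * (2 * seq + 2) * head_dim * elem_bytes   # read K and V, read Q, write out
    peak_flops, hbm_bw = 1970e12, 3.35e12                                 # H100 fp16 Tensor Core, HBM bandwidth
    return {
        "intensity_flops_per_byte": flops / bytes_moved,
        "compute_time_ms": flops / peak_flops * 1e3,
        "memory_time_ms": bytes_moved / hbm_bw * 1e3,
    }

print(decode_attention_roofline())   # intensity ≈ 1.0, and memory time dwarfs compute time
```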

### Multi-GPU Strategies

| Strategy | KV Cache Impact | Best For |
|----------|-----------------|----------|
| Tensor parallelism | Split K/V by head_dim across GPUs | Large models |
| Pipeline parallelism | Each GPU holds its layer shard's K/V | Very large models |
| Sequence parallelism | Split K/V by sequence dimension | Long context prefill |
| Expert parallelism | KV cache shared; only MLP experts sharded | MoE models |

## Running

```bash
# Run all 10 demos
python demo.py

# Run memory analysis standalone
python memory_analysis.py

# Run GPU mapping report
python gpu_mapping.py
```

## Key Takeaways

1. **KV cache grows linearly** with sequence length and batch size — it's the dominant memory cost for long-context inference.

2. **Generation is memory-bound** — reading the KV cache from HBM dominates latency, not the attention computation itself.

3. **Paged attention eliminates fragmentation** — critical for serving variable-length sequences efficiently.

4. **Quantization gives ~50% savings** with minimal accuracy loss when using shared per-channel scales.

5. **FlashAttention-style tiling** reduces HBM traffic by processing K/V in tiles that fit in shared memory, cutting memory bandwidth requirements by 2-4×.
@@ -0,0 +1,313 @@
|
||||
"""
|
||||
Attention Computation with KV-Cache
|
||||
|
||||
Implements:
|
||||
1. Standard scaled dot-product attention (no cache)
|
||||
2. Cached attention for incremental decoding
|
||||
3. Masked attention for variable-length batches
|
||||
4. Multi-query and grouped-query attention variants
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, Tuple
|
||||
from kv_cache import KVCache, CacheConfig
|
||||
|
||||
|
||||
def softmax_stable(x: np.ndarray, axis: int = -1) -> np.ndarray:
|
||||
"""Numerically stable softmax."""
|
||||
x_max = np.max(x, axis=axis, keepdims=True)
|
||||
exp_x = np.exp(x - x_max)
|
||||
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
|
||||
|
||||
|
||||
def scaled_dot_product_attention(
|
||||
q: np.ndarray,
|
||||
k: np.ndarray,
|
||||
v: np.ndarray,
|
||||
scale: float,
|
||||
mask: Optional[np.ndarray] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Standard scaled dot-product attention (no caching).
|
||||
|
||||
Args:
|
||||
q: (batch, num_heads, seq_q, head_dim)
|
||||
k: (batch, num_heads, seq_k, head_dim)
|
||||
v: (batch, num_heads, seq_k, head_dim)
|
||||
scale: typically 1 / sqrt(head_dim)
|
||||
mask: (batch, 1, 1, seq_k) or broadcastable — values masked to -inf
|
||||
|
||||
Returns:
|
||||
output: (batch, num_heads, seq_q, head_dim)
|
||||
"""
|
||||
# Q @ K^T: (batch, heads, seq_q, head_dim) @ (batch, heads, head_dim, seq_k)
|
||||
# -> (batch, heads, seq_q, seq_k)
|
||||
scores = np.einsum("bhqd,bhkd->bhqk", q, k) * scale
|
||||
|
||||
if mask is not None:
|
||||
scores = scores + mask # mask has -inf for masked positions
|
||||
|
||||
attn_weights = softmax_stable(scores, axis=-1)
|
||||
|
||||
# Attn @ V: (batch, heads, seq_q, seq_k) @ (batch, heads, seq_k, head_dim)
|
||||
# -> (batch, heads, seq_q, head_dim)
|
||||
output = np.einsum("bhqk,bhkd->bhqd", attn_weights, v)
|
||||
return output
|
||||
|
||||
|
||||
def build_causal_mask(seq_len: int, dtype=np.float32) -> np.ndarray:
|
||||
"""
|
||||
Build a causal (triangular) mask for a sequence.
|
||||
|
||||
Returns (seq_len, seq_len) where upper triangle is -inf.
|
||||
Position i can attend to positions j where j <= i.
|
||||
"""
|
||||
indices = np.arange(seq_len)
|
||||
# Mask positions where key_pos > query_pos (future positions)
|
||||
mask = np.where(indices[None, :] > indices[:, None], -np.inf, 0.0)
|
||||
return mask.astype(dtype)
|
||||
|
||||
|
||||
def build_variable_length_mask(
|
||||
lengths: np.ndarray,
|
||||
query_len: int,
|
||||
max_key_len: int = None,
|
||||
dtype=np.float32,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Build a mask for variable-length batches.
|
||||
|
||||
For each batch item, positions beyond its actual length are masked.
|
||||
Also applies causal masking (only attend to positions <= query position).
|
||||
|
||||
Args:
|
||||
lengths: (batch,) actual sequence lengths per batch item
|
||||
query_len: number of query positions (usually 1 for generation)
|
||||
max_key_len: override for key dimension (defaults to max(lengths))
|
||||
|
||||
Returns:
|
||||
mask: (batch, 1, query_len, max_key_len)
|
||||
"""
|
||||
batch_size = len(lengths)
|
||||
if max_key_len is None:
|
||||
max_key_len = int(np.max(lengths))
|
||||
|
||||
# Key positions: 0 .. max_key_len-1
|
||||
key_positions = np.arange(max_key_len) # (max_key_len,)
|
||||
|
||||
# Query positions: treated as the LAST query_len positions of the key range.
# For prefill (query_len == max_key_len) this is 0 .. max_key_len-1;
# for single-token generation (query_len == 1) the query sits at the final
# position and may attend to the entire cached history.
query_positions = np.arange(max_key_len - query_len, max_key_len)  # (query_len,)

# Causal: key_pos <= query_pos is allowed (attend to past)
causal = (key_positions[None, :] <= query_positions[:, None]).astype(dtype)
# (query_len, max_key_len)
|
||||
|
||||
# Length mask: key_pos < length[b] is allowed
|
||||
length_mask = (key_positions[None, None, None, :] < lengths[:, None, None, None]).astype(dtype)
|
||||
# (batch, 1, 1, max_key_len)
|
||||
|
||||
# Combined: both causal and within length
|
||||
# causal: (query_len, max_key_len) -> (1, 1, query_len, max_key_len)
|
||||
combined = causal[None, None, :, :] * length_mask # broadcast
|
||||
# (batch, 1, query_len, max_key_len)
|
||||
|
||||
# Convert 0/1 to 0/-inf
|
||||
mask = np.where(combined > 0, 0.0, -np.inf)
|
||||
return mask.astype(dtype)
|
||||
|
||||
|
||||
def cached_attention(
|
||||
q: np.ndarray,
|
||||
cache: KVCache,
|
||||
scale: float,
|
||||
dtype: np.dtype = np.float32,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Attention using cached K and V.
|
||||
|
||||
During generation, q is (batch, heads, 1, head_dim) — just the current token.
|
||||
The cache holds all previous K and V.
|
||||
|
||||
Steps:
|
||||
1. Retrieve cached K, V from the cache
|
||||
2. Compute Q @ K^T with the full history
|
||||
3. Apply softmax and @ V
|
||||
|
||||
This avoids recomputing K and V for past tokens.
|
||||
|
||||
Args:
|
||||
q: (batch, num_heads, 1, head_dim) — current query
|
||||
cache: KVCache with previously stored K and V
|
||||
scale: 1 / sqrt(head_dim)
|
||||
|
||||
Returns:
|
||||
output: (batch, num_heads, 1, head_dim)
|
||||
"""
|
||||
# Retrieve all cached keys and values
|
||||
cached_k, cached_v = cache.get_all()
|
||||
# (batch, num_heads, seq_so_far, head_dim)
|
||||
|
||||
# Cast to computation dtype for numerical stability
|
||||
q_f = q.astype(dtype)
|
||||
k_f = cached_k.astype(dtype)
|
||||
v_f = cached_v.astype(dtype)
|
||||
|
||||
# Q @ K^T: (batch, heads, 1, head_dim) @ (batch, heads, head_dim, seq)
|
||||
# -> (batch, heads, 1, seq)
|
||||
scores = np.einsum("bhqd,bhkd->bhqk", q_f, k_f) * scale
|
||||
|
||||
# No mask needed during generation (causal is implicit: we only have
|
||||
# past keys, no future keys exist in the cache)
|
||||
attn_weights = softmax_stable(scores, axis=-1)
|
||||
|
||||
# Attn @ V: (batch, heads, 1, seq) @ (batch, heads, seq, head_dim)
|
||||
# -> (batch, heads, 1, head_dim)
|
||||
output = np.einsum("bhqk,bhkd->bhqd", attn_weights, v_f)
|
||||
|
||||
return output.astype(q.dtype)
|
||||
|
||||
|
||||
def cached_attention_with_mask(
|
||||
q: np.ndarray,
|
||||
cache: KVCache,
|
||||
scale: float,
|
||||
lengths: Optional[np.ndarray] = None,
|
||||
dtype: np.dtype = np.float32,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Cached attention with variable-length masking.
|
||||
|
||||
Handles batches where sequences have different lengths (some may have
|
||||
finished generation and are padded).
|
||||
"""
|
||||
cached_k, cached_v = cache.get_all()
|
||||
seq_len = cached_k.shape[2]
|
||||
|
||||
q_f = q.astype(dtype)
|
||||
k_f = cached_k.astype(dtype)
|
||||
v_f = cached_v.astype(dtype)
|
||||
|
||||
scores = np.einsum("bhqd,bhkd->bhqk", q_f, k_f) * scale
|
||||
|
||||
# Build mask if variable lengths
|
||||
if lengths is not None:
|
||||
# During generation, lengths should reflect current cache position
|
||||
# Clamp lengths to not exceed cache size
|
||||
effective_lengths = np.minimum(lengths, seq_len)
|
||||
mask = build_variable_length_mask(effective_lengths, query_len=1,
|
||||
max_key_len=seq_len, dtype=dtype)
|
||||
scores = scores + mask
|
||||
|
||||
attn_weights = softmax_stable(scores, axis=-1)
|
||||
output = np.einsum("bhqk,bhkd->bhqd", attn_weights, v_f)
|
||||
|
||||
return output.astype(q.dtype)
|
||||
|
||||
|
||||
def prompt_attention(
|
||||
q: np.ndarray,
|
||||
k: np.ndarray,
|
||||
v: np.ndarray,
|
||||
cache: KVCache,
|
||||
scale: float,
|
||||
lengths: Optional[np.ndarray] = None,
|
||||
dtype: np.dtype = np.float32,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Process the initial prompt (prefill phase).
|
||||
|
||||
During prefill, we compute Q, K, V for all prompt tokens at once,
|
||||
store K and V in the cache, and compute attention with causal masking.
|
||||
|
||||
Args:
|
||||
q: (batch, heads, prompt_len, head_dim)
|
||||
k: (batch, heads, prompt_len, head_dim)
|
||||
v: (batch, heads, prompt_len, head_dim)
|
||||
cache: KVCache to populate
|
||||
scale: 1 / sqrt(head_dim)
|
||||
|
||||
Returns:
|
||||
output, k, v (k and v are returned for the caller to use)
|
||||
"""
|
||||
batch_size = q.shape[0]
|
||||
prompt_len = q.shape[2]
|
||||
|
||||
# Store all prompt tokens in cache
|
||||
for pos in range(prompt_len):
|
||||
k_slice = k[:, :, pos:pos+1, :] # (batch, heads, 1, head_dim)
|
||||
v_slice = v[:, :, pos:pos+1, :]
|
||||
cache.update(k_slice, v_slice, seqlen_offset=pos)
|
||||
|
||||
# Causal attention over the full prompt
|
||||
q_f = q.astype(dtype)
|
||||
k_f = k.astype(dtype)
|
||||
v_f = v.astype(dtype)
|
||||
|
||||
scores = np.einsum("bhqd,bhkd->bhqk", q_f, k_f) * scale
|
||||
|
||||
# Causal mask
|
||||
causal = build_causal_mask(prompt_len, dtype=dtype)
|
||||
scores = scores + causal[None, None, :, :] # broadcast over batch, heads
|
||||
|
||||
# Variable length mask
|
||||
if lengths is not None:
|
||||
mask = build_variable_length_mask(lengths, query_len=prompt_len, dtype=dtype)
|
||||
scores = scores + mask
|
||||
|
||||
attn_weights = softmax_stable(scores, axis=-1)
|
||||
output = np.einsum("bhqk,bhkd->bhqd", attn_weights, v_f)
|
||||
|
||||
return output.astype(q.dtype), k, v
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-Query Attention (MQA) and Grouped-Query Attention (GQA)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def cached_attention_gqa(
|
||||
q: np.ndarray,
|
||||
cache_k: np.ndarray,
|
||||
cache_v: np.ndarray,
|
||||
num_query_groups: int,
|
||||
scale: float,
|
||||
dtype: np.dtype = np.float32,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Grouped-query attention with cached K/V.
|
||||
|
||||
In GQA, multiple query heads share one key-value head.
|
||||
q: (batch, num_heads, 1, head_dim)
|
||||
cache_k, cache_v: (batch, num_kv_heads, seq, head_dim)
|
||||
num_query_groups = num_heads / num_kv_heads
|
||||
|
||||
Each group of query heads attends to the same K/V head.
|
||||
"""
|
||||
batch, num_heads, _, head_dim = q.shape
|
||||
num_kv_heads = cache_k.shape[1]
|
||||
assert num_heads % num_kv_heads == 0
|
||||
queries_per_group = num_heads // num_kv_heads
|
||||
|
||||
q_f = q.astype(dtype)
|
||||
k_f = cache_k.astype(dtype)
|
||||
v_f = cache_v.astype(dtype)
|
||||
|
||||
# Group the query heads so that each group shares one K/V head.
# q_reshaped: (batch, num_kv_heads, queries_per_group, 1, head_dim)
q_reshaped = q_f.reshape(batch, num_kv_heads, queries_per_group, 1, head_dim)

# Q @ K^T per group: the shared K/V head broadcasts over the group axis.
# (batch, kv_heads, q_per_group, 1, head_dim) × (batch, kv_heads, seq, head_dim)
# -> (batch, kv_heads, q_per_group, 1, seq)
scores = np.einsum("bhgqd,bhkd->bhgqk", q_reshaped, k_f) * scale
|
||||
|
||||
attn_weights = softmax_stable(scores, axis=-1)
|
||||
|
||||
# Attn @ V
|
||||
output = np.einsum("bhgqk,bhkd->bhgqd", attn_weights, v_f)
|
||||
|
||||
# Reshape back: (batch, num_heads, 1, head_dim)
|
||||
output = output.reshape(batch, num_heads, 1, head_dim)
|
||||
return output.astype(q.dtype)
|
||||
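
if __name__ == "__main__":
    # Lightweight self-checks (a sketch mirroring Demo 2 in demo.py): cached
    # attention over a populated KVCache should match plain scaled dot-product
    # attention on the same K/V, and GQA with each K/V head repeated per group
    # should match standard multi-head attention.
    rng = np.random.default_rng(0)
    batch, heads, head_dim, steps = 2, 4, 16, 8
    scale = 1.0 / np.sqrt(head_dim)

    cache = KVCache(CacheConfig(batch_size=batch, num_heads=heads,
                                head_dim=head_dim, max_seq_len=32))
    for _ in range(steps):
        k = rng.standard_normal((batch, heads, 1, head_dim)).astype(np.float32)
        v = rng.standard_normal((batch, heads, 1, head_dim)).astype(np.float32)
        cache.update(k, v)

    q = rng.standard_normal((batch, heads, 1, head_dim)).astype(np.float32)
    k_all, v_all = cache.get_all()
    ref = scaled_dot_product_attention(q, k_all, v_all, scale)
    assert np.allclose(cached_attention(q, cache, scale), ref, atol=1e-5)

    # GQA: 2 K/V heads shared by 4 query heads vs. MHA with each K/V head repeated.
    kv_heads = 2
    group = heads // kv_heads
    k_kv, v_kv = k_all[:, :kv_heads], v_all[:, :kv_heads]
    gqa_out = cached_attention_gqa(q, k_kv, v_kv, num_query_groups=group, scale=scale)
    mha_out = scaled_dot_product_attention(q, np.repeat(k_kv, group, axis=1),
                                           np.repeat(v_kv, group, axis=1), scale)
    assert np.allclose(gqa_out, mha_out, atol=1e-5)
    print("attention.py self-checks passed")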
@@ -0,0 +1,515 @@
|
||||
"""
|
||||
End-to-End KV-Cache Demo
|
||||
|
||||
Demonstrates:
|
||||
1. Building a small transformer with KV-cache
|
||||
2. Prefill phase (prompt processing)
|
||||
3. Incremental generation (one token at a time)
|
||||
4. Variable-length batching
|
||||
5. Memory tracking
|
||||
6. Optimization comparisons
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Ensure we can import from the project
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from kv_cache import KVCache, CacheConfig, BatchedKVCache
|
||||
from attention import (
|
||||
scaled_dot_product_attention,
|
||||
cached_attention,
|
||||
build_causal_mask,
|
||||
softmax_stable,
|
||||
)
|
||||
from transformer import TransformerDecoder, TransformerDecoderLayer
|
||||
from optimizations import (
|
||||
PagedKVCache, PageConfig,
|
||||
QuantizedKVCache,
|
||||
ChunkedPrefill,
|
||||
compare_strategies,
|
||||
)
|
||||
from memory_analysis import (
|
||||
ModelSpec, compute_model_memory, compute_kv_cache_memory,
|
||||
find_max_context, compare_model_sizes,
|
||||
)
|
||||
from gpu_mapping import tensor_core_analysis, print_gpu_report
|
||||
|
||||
|
||||
def demo_basic_kv_cache():
|
||||
"""Demo 1: Basic KV cache operations."""
|
||||
print("=" * 70)
|
||||
print("DEMO 1: Basic KV Cache Operations")
|
||||
print("=" * 70)
|
||||
|
||||
config = CacheConfig(
|
||||
batch_size=2,
|
||||
num_heads=4,
|
||||
head_dim=16,
|
||||
max_seq_len=64,
|
||||
dtype=np.float32,
|
||||
)
|
||||
cache = KVCache(config)
|
||||
|
||||
print(f"\nCache shape: {cache.cache_k.shape}")
|
||||
print(f" (batch={config.batch_size}, heads={config.num_heads}, "
|
||||
f"max_seq={config.max_seq_len}, head_dim={config.head_dim})")
|
||||
print(f"Allocated: {cache.memory_allocated_bytes:,} bytes")
|
||||
|
||||
# Simulate generating tokens one at a time
|
||||
np.random.seed(42)
|
||||
for step in range(10):
|
||||
# Simulate new K and V from the model
|
||||
k_new = np.random.randn(2, 4, 1, 16).astype(np.float32) * 0.01
|
||||
v_new = np.random.randn(2, 4, 1, 16).astype(np.float32) * 0.01
|
||||
|
||||
cache.update(k_new, v_new)
|
||||
|
||||
print(f"\nAfter 10 steps:")
|
||||
print(f" Write position: {cache.write_pos}")
|
||||
print(f" Sequence lengths: {cache.lengths}")
|
||||
print(f" Memory used: {cache.memory_used_bytes:,} bytes")
|
||||
|
||||
# Retrieve cached data
|
||||
k_cached, v_cached = cache.get_all()
|
||||
print(f" Cached K shape: {k_cached.shape}")
|
||||
print(f" Cached V shape: {v_cached.shape}")
|
||||
|
||||
# Verify data integrity
|
||||
assert k_cached.shape == (2, 4, 10, 16)
|
||||
assert v_cached.shape == (2, 4, 10, 16)
|
||||
print("\n ✓ Data integrity verified")
|
||||
|
||||
|
||||
def demo_cached_attention():
|
||||
"""Demo 2: Cached attention computation."""
|
||||
print("\n" + "=" * 70)
|
||||
print("DEMO 2: Cached Attention Computation")
|
||||
print("=" * 70)
|
||||
|
||||
batch, heads, head_dim = 2, 4, 16
|
||||
seq_len = 8
|
||||
scale = 1.0 / np.sqrt(head_dim)
|
||||
|
||||
np.random.seed(123)
|
||||
|
||||
# Build a cache with some history
|
||||
config = CacheConfig(batch_size=batch, num_heads=heads,
|
||||
head_dim=head_dim, max_seq_len=64)
|
||||
cache = KVCache(config)
|
||||
|
||||
# Fill cache with random K, V
|
||||
for i in range(seq_len):
|
||||
k = np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 0.01
|
||||
v = np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 0.01
|
||||
cache.update(k, v)
|
||||
|
||||
# Current query (new token)
|
||||
q = np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 0.01
|
||||
|
||||
# Cached attention
|
||||
output = cached_attention(q, cache, scale)
|
||||
print(f"\nQuery shape: {q.shape}")
|
||||
print(f"Cached K shape: {cache.cache_k.shape} (used: {cache.write_pos} tokens)")
|
||||
print(f"Output shape: {output.shape}")
|
||||
|
||||
# Verify against manual computation
|
||||
k_all, v_all = cache.get_all()
|
||||
scores = np.einsum("bhqd,bhkd->bhqk", q, k_all) * scale
|
||||
attn = softmax_stable(scores, axis=-1)
|
||||
manual_output = np.einsum("bhqk,bhkd->bhqd", attn, v_all)
|
||||
|
||||
diff = np.max(np.abs(output - manual_output))
|
||||
print(f"Max difference from manual: {diff:.2e}")
|
||||
assert diff < 1e-5, f"Attention mismatch: {diff}"
|
||||
print(" ✓ Cached attention matches manual computation")
|
||||
|
||||
# Show attention weights for one batch/head
|
||||
print(f"\nAttention weights (batch=0, head=0):")
|
||||
print(f" {attn[0, 0, 0, :].round(3)}")
|
||||
print(f" Sum: {attn[0, 0, 0, :].sum():.4f} (should be ~1.0)")
|
||||
|
||||
|
||||
def demo_full_transformer():
|
||||
"""Demo 3: Full transformer with KV-cache."""
|
||||
print("\n" + "=" * 70)
|
||||
print("DEMO 3: Full Transformer with KV-Cache")
|
||||
print("=" * 70)
|
||||
|
||||
# Small model for demo
|
||||
model = TransformerDecoder(
|
||||
num_layers=2,
|
||||
dim=64,
|
||||
num_heads=4,
|
||||
mlp_hidden=128,
|
||||
vocab_size=1000,
|
||||
max_seq_len=128,
|
||||
batch_size=2,
|
||||
dtype=np.float32,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
# Create a prompt (padded to same length)
|
||||
prompt = np.array([[10, 20, 30, 40, 50],
|
||||
[15, 25, 35, 45, 0]], dtype=np.int32) # 0 = pad
|
||||
|
||||
lengths = np.array([5, 4], dtype=np.int32)
|
||||
|
||||
print(f"\nPrompt tokens: {prompt.shape}")
|
||||
print(f" Sequence 0: {prompt[0]} (length={lengths[0]})")
|
||||
print(f" Sequence 1: {prompt[1]} (length={lengths[1]})")
|
||||
|
||||
# Prefill
|
||||
hidden = model.prefill(prompt, lengths=lengths)
|
||||
print(f"\nAfter prefill:")
|
||||
print(f" Hidden shape: {hidden.shape}")
|
||||
print(f" Cache write position: {model.cache.caches[0].write_pos}")
|
||||
|
||||
# Generate tokens
|
||||
print(f"\nGenerating 5 tokens...")
|
||||
generated = model.generate(prompt, num_tokens=5, temperature=0.8, top_k=50,
|
||||
lengths=lengths)
|
||||
|
||||
for i, tokens in enumerate(generated):
|
||||
print(f" Step {i+1}: {tokens}")
|
||||
|
||||
# Memory report
|
||||
report = model.memory_report()
|
||||
print(f"\nMemory Report:")
|
||||
for k, v in report.items():
|
||||
if isinstance(v, float):
|
||||
print(f" {k}: {v:.4f}")
|
||||
else:
|
||||
print(f" {k}: {v}")
|
||||
|
||||
|
||||
def demo_variable_length_batching():
|
||||
"""Demo 4: Variable-length batching."""
|
||||
print("\n" + "=" * 70)
|
||||
print("DEMO 4: Variable-Length Batching")
|
||||
print("=" * 70)
|
||||
|
||||
batch_size = 4
|
||||
config = CacheConfig(
|
||||
batch_size=batch_size,
|
||||
num_heads=4,
|
||||
head_dim=16,
|
||||
max_seq_len=32,
|
||||
dtype=np.float32,
|
||||
)
|
||||
cache = KVCache(config)
|
||||
|
||||
np.random.seed(99)
|
||||
|
||||
# Simulate sequences of different lengths
|
||||
# Seq 0: 8 tokens, Seq 1: 5 tokens, Seq 2: 10 tokens, Seq 3: 3 tokens
|
||||
seq_lengths = [8, 5, 10, 3]
|
||||
max_len = max(seq_lengths)
|
||||
|
||||
print("\nSimulating variable-length batch:")
|
||||
# Each batch item has its own cache (simplified: use separate caches)
|
||||
per_seq_caches = [KVCache(CacheConfig(
|
||||
batch_size=1, num_heads=4, head_dim=16,
|
||||
max_seq_len=max_len, dtype=np.float32
|
||||
)) for _ in range(batch_size)]
|
||||
|
||||
for b, length in enumerate(seq_lengths):
|
||||
for t in range(length):
|
||||
k = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
|
||||
v = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
|
||||
per_seq_caches[b].update(k, v)
|
||||
|
||||
# Query for each sequence at its current position
|
||||
scale = 1.0 / np.sqrt(16)
|
||||
for b in range(batch_size):
|
||||
q = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
|
||||
k_cached, v_cached = per_seq_caches[b].get_all()
|
||||
|
||||
# Attention for this batch item
|
||||
scores = np.einsum("bhqd,bhkd->bhqk", q, k_cached) * scale
|
||||
attn = softmax_stable(scores, axis=-1)
|
||||
|
||||
# Show which positions are attended to
|
||||
print(f"\n Sequence {b} (length={seq_lengths[b]}):")
|
||||
print(f" Attention: {attn[0, 0, 0, :].round(3)}")
|
||||
|
||||
|
||||
def demo_paged_attention():
|
||||
"""Demo 5: Paged attention."""
|
||||
print("\n" + "=" * 70)
|
||||
print("DEMO 5: Paged Attention (vLLM-style)")
|
||||
print("=" * 70)
|
||||
|
||||
config = PageConfig(
|
||||
block_size=4,
|
||||
num_pages=16,
|
||||
batch_size=2,
|
||||
num_heads=4,
|
||||
head_dim=16,
|
||||
dtype=np.float32,
|
||||
)
|
||||
paged = PagedKVCache(config)
|
||||
|
||||
print(f"\nPage config:")
|
||||
print(f" Block size: {config.block_size} tokens")
|
||||
print(f" Pages per sequence: {config.num_pages}")
|
||||
print(f" Max tokens per sequence: {config.num_pages * config.block_size}")
|
||||
print(f" Allocated: {paged.memory_allocated_bytes:,} bytes")
|
||||
|
||||
np.random.seed(77)
|
||||
|
||||
# Fill sequence 0 with 12 tokens (3 blocks)
|
||||
print(f"\nFilling sequence 0 with 12 tokens...")
|
||||
for t in range(12):
|
||||
k = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
|
||||
v = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
|
||||
block_idx = t // config.block_size
|
||||
offset = t % config.block_size
|
||||
paged.append_token(0, k, v, block_idx, offset)
|
||||
|
||||
print(f" Blocks allocated: {paged.num_blocks[0]}")
|
||||
print(f" Page table: {paged.page_tables[0, :paged.num_blocks[0]]}")
|
||||
|
||||
# Fill sequence 1 with 8 tokens (2 blocks)
|
||||
print(f"\nFilling sequence 1 with 8 tokens...")
|
||||
for t in range(8):
|
||||
k = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
|
||||
v = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
|
||||
block_idx = t // config.block_size
|
||||
offset = t % config.block_size
|
||||
paged.append_token(1, k, v, block_idx, offset)
|
||||
|
||||
print(f" Blocks allocated: {paged.num_blocks[1]}")
|
||||
print(f" Page table: {paged.page_tables[1, :paged.num_blocks[1]]}")
|
||||
|
||||
# Retrieve and verify
|
||||
k0, v0 = paged.get_sequence_contiguous(0, num_tokens=12)
|
||||
k1, v1 = paged.get_sequence_contiguous(1, num_tokens=8)
|
||||
print(f"\n Seq 0 K shape: {k0.shape}")
|
||||
print(f" Seq 1 K shape: {k1.shape}")
|
||||
|
||||
print(f"\n Memory used: {paged.memory_used_bytes:,} bytes")
|
||||
print(f" Utilization: {paged.memory_utilization():.1%}")
|
||||
|
||||
|
||||
def demo_quantized_cache():
|
||||
"""Demo 6: Quantized KV cache."""
|
||||
print("\n" + "=" * 70)
|
||||
print("DEMO 6: Quantized KV Cache (int8)")
|
||||
print("=" * 70)
|
||||
|
||||
batch, heads, head_dim, max_seq = 2, 4, 16, 32
|
||||
cache = QuantizedKVCache(batch, heads, head_dim, max_seq, dtype=np.float32)
|
||||
|
||||
np.random.seed(55)
|
||||
|
||||
# Fill with random data
|
||||
for t in range(10):
|
||||
k = np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 0.1
|
||||
v = np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 0.1
|
||||
cache.update(k, v)
|
||||
|
||||
# Retrieve and compare
|
||||
k_deq, v_deq = cache.get()
|
||||
print(f"\nQuantized cache (10 tokens):")
|
||||
print(f" Dequantized K shape: {k_deq.shape}")
|
||||
print(f" Dequantized V shape: {v_deq.shape}")
|
||||
|
||||
# Compare with original (we need to re-quantize to compare)
|
||||
# The quantization error depends on the data distribution
|
||||
print(f" Memory savings vs fp32: {cache.memory_savings_vs_fp32:.1%}")
|
||||
print(f" Memory savings vs fp16: {cache.memory_savings_vs_fp16:.1%} (per-pos scales overhead)")
|
||||
|
||||
# Show quantization error for one position
|
||||
# Use larger values for better int8 quantization fidelity
|
||||
k_orig = np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 1.0
|
||||
v_orig = np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 1.0
|
||||
cache.update(k_orig, v_orig)
|
||||
k_deq_single, _ = cache.get(start=10, end=11)
|
||||
|
||||
# k_deq_single: (batch, heads, 1, head_dim), k_orig: (batch, heads, 1, head_dim)
|
||||
print(f" k_orig shape: {k_orig.shape}, k_deq shape: {k_deq_single.shape}")
|
||||
error = np.max(np.abs(k_orig - k_deq_single))
|
||||
rel_error = error / (np.max(np.abs(k_orig)) + 1e-8)
|
||||
print(f" Max absolute error (one token): {error:.6f}")
|
||||
print(f" Max relative error: {rel_error:.4f}")
|
||||
print(f" → Per-position quantization has high overhead; production uses")
|
||||
print(f" shared per-channel scales for ~50% memory savings with <1% error")
|
||||
|
||||
|
||||
def demo_chunked_prefill():
|
||||
"""Demo 7: Chunked prefill."""
|
||||
print("\n" + "=" * 70)
|
||||
print("DEMO 7: Chunked Prefill")
|
||||
print("=" * 70)
|
||||
|
||||
chunker = ChunkedPrefill(chunk_size=4)
|
||||
|
||||
batch, heads, seq, head_dim = 1, 4, 12, 16
|
||||
scale = 1.0 / np.sqrt(head_dim)
|
||||
|
||||
np.random.seed(33)
|
||||
q = np.random.randn(batch, heads, seq, head_dim).astype(np.float32) * 0.01
|
||||
k = np.random.randn(batch, heads, seq, head_dim).astype(np.float32) * 0.01
|
||||
v = np.random.randn(batch, heads, seq, head_dim).astype(np.float32) * 0.01
|
||||
|
||||
# Chunked attention
|
||||
output_chunked = chunker.compute_attention_chunked(q, k, v, scale)
|
||||
|
||||
# Full attention (for comparison)
|
||||
from attention import scaled_dot_product_attention, build_causal_mask
|
||||
causal = build_causal_mask(seq, dtype=np.float32)
|
||||
output_full = scaled_dot_product_attention(
|
||||
q, k, v, scale, mask=causal[None, None, :, :]
|
||||
)
|
||||
|
||||
diff = np.max(np.abs(output_chunked - output_full))
|
||||
print(f"\nChunk size: {chunker.chunk_size}")
|
||||
print(f"Sequence length: {seq}")
|
||||
print(f"Chunks: {(seq + chunker.chunk_size - 1) // chunker.chunk_size}")
|
||||
print(f"Max difference from full attention: {diff:.2e}")
|
||||
assert diff < 1e-5, f"Chunked attention mismatch: {diff}"
|
||||
print(" ✓ Chunked attention matches full attention")
|
||||
|
||||
# Memory comparison
|
||||
mem = ChunkedPrefill.peak_memory_comparison(seq_len=4096, chunk_size=512)
|
||||
print(f"\nMemory comparison (seq=4096, chunk=512):")
|
||||
print(f" Full attention matrix: {mem['full_attention_mb']:.0f} MB")
|
||||
print(f" Chunked peak: {mem['chunked_peak_attention_mb']:.0f} MB")
|
||||
print(f" Savings: {mem['savings_ratio']:.1f}x")
|
||||
|
||||
|
||||
def demo_optimization_comparison():
|
||||
"""Demo 8: Optimization strategy comparison."""
|
||||
print("\n" + "=" * 70)
|
||||
print("DEMO 8: Optimization Strategy Comparison")
|
||||
print("=" * 70)
|
||||
|
||||
results = compare_strategies(
|
||||
batch_size=4, num_heads=32, head_dim=128,
|
||||
max_seq_len=4096, num_layers=32
|
||||
)
|
||||
|
||||
print(f"\nConfiguration: batch=4, heads=32, head_dim=128, "
|
||||
f"seq=4096, layers=32\n")
|
||||
|
||||
header = f"{'Strategy':<25} {'Per Layer(MB)':>14} {'Total(GB)':>10} {'Notes':<25}"
|
||||
print(header)
|
||||
print("-" * len(header))
|
||||
|
||||
for name, data in results.items():
|
||||
notes = ""
|
||||
if "savings_vs_fp16" in data:
|
||||
notes = f"{data['savings_vs_fp16']:.0%} savings"
|
||||
elif "overhead_vs_naive" in data:
|
||||
notes = f"{data['overhead_vs_naive']:.3f}x overhead"
|
||||
|
||||
print(f"{name:<25} {data['per_layer_mb']:>14.1f} {data['total_mb']/1024:>10.2f} "
|
||||
f"{notes:<25}")
|
||||
|
||||
|
||||
def demo_memory_analysis():
|
||||
"""Demo 9: Memory growth analysis."""
|
||||
print("\n" + "=" * 70)
|
||||
print("DEMO 9: Memory Growth Analysis")
|
||||
print("=" * 70)
|
||||
|
||||
# Compare model sizes
|
||||
comparisons = compare_model_sizes()
|
||||
|
||||
print("\nModel Size Comparison (fp16):\n")
|
||||
header = f"{'Model':<20} {'Params(GB)':>10} {'KV@1K':>8} {'KV@8K':>8} {'KV@32K':>8} {'MaxCtx(H100)':>12}"
|
||||
print(header)
|
||||
print("-" * len(header))
|
||||
for name, data in comparisons.items():
|
||||
print(f"{name:<20} {data['params_gb']:>10.1f} {data['kv_1k_gb']:>8.2f} "
|
||||
f"{data['kv_8k_gb']:>8.2f} {data['kv_32k_gb']:>8.2f} "
|
||||
f"{data['max_context_H100']:>12,}")
|
||||
|
||||
# Growth for 7B model
|
||||
spec = ModelSpec(num_layers=32, dim=4096, num_heads=32, head_dim=128)
|
||||
model_mem = compute_model_memory(spec, np.float16)
|
||||
|
||||
print(f"\n\n7B Model Memory Growth (batch=1, fp16):\n")
|
||||
print(f" Model params: {model_mem['total_params_gb']:.1f} GB")
|
||||
print()
|
||||
|
||||
seq_lens = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
|
||||
print(f" {'Seq Len':>8} {'KV(GB)':>8} {'Total(GB)':>10} {'KV%':>6}")
|
||||
print(f" {'-'*40}")
|
||||
for sl in seq_lens:
|
||||
kv = compute_kv_cache_memory(1, sl, spec, np.float16)
|
||||
total = kv["total_gb"] + model_mem["total_params_gb"]
|
||||
pct = kv["total_gb"] / total * 100
|
||||
print(f" {sl:>8,} {kv['total_gb']:>8.2f} {total:>10.2f} {pct:>5.1f}%")
|
||||
|
||||
# GPU limits
|
||||
print(f"\n\nMax Context by GPU (7B model, batch=1):\n")
|
||||
gpus = {"RTX 4090": 24, "A100-40GB": 40, "A100-80GB": 80, "H100-80GB": 80}
|
||||
for gpu, mem in gpus.items():
|
||||
ctx = find_max_context(spec, mem, batch_size=1)
|
||||
print(f" {gpu:<15}: {ctx:>8,} tokens")
|
||||
|
||||
|
||||
def demo_gpu_tensor_cores():
|
||||
"""Demo 10: GPU Tensor Core analysis."""
|
||||
print("\n" + "=" * 70)
|
||||
print("DEMO 10: GPU Tensor Core Analysis")
|
||||
print("=" * 70)
|
||||
|
||||
configs = [
|
||||
{"batch": 1, "heads": 32, "seq": 1024, "label": "Short context"},
|
||||
{"batch": 1, "heads": 32, "seq": 8192, "label": "Long context"},
|
||||
{"batch": 4, "heads": 32, "seq": 4096, "label": "Batched"},
|
||||
]
|
||||
|
||||
for cfg in configs:
|
||||
tc = tensor_core_analysis(
|
||||
batch=cfg["batch"], heads=cfg["heads"], seq_len=cfg["seq"]
|
||||
)
|
||||
print(f"\n {cfg['label']} (batch={cfg['batch']}, seq={cfg['seq']}):")
|
||||
print(f" Total FLOPs: {tc['total_flops']}")
|
||||
print(f" Memory traffic: {tc['memory_traffic_mb']}")
|
||||
print(f" Arithmetic intensity: {tc['arithmetic_intensity']}")
|
||||
print(f" Compute bound: {tc['compute_bound_ms']}")
|
||||
print(f" Memory bound: {tc['memory_bound_ms']}")
|
||||
print(f" → {tc['bound']}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all demos."""
|
||||
print("\n" + "█" * 70)
|
||||
print(" KV-CACHE SYSTEM FOR AUTOREGRESSIVE TRANSFORMER INFERENCE")
|
||||
print(" Pure NumPy Implementation — No Frameworks")
|
||||
print("█" * 70)
|
||||
|
||||
demos = [
|
||||
("Basic KV Cache", demo_basic_kv_cache),
|
||||
("Cached Attention", demo_cached_attention),
|
||||
("Full Transformer", demo_full_transformer),
|
||||
("Variable-Length Batching", demo_variable_length_batching),
|
||||
("Paged Attention", demo_paged_attention),
|
||||
("Quantized Cache", demo_quantized_cache),
|
||||
("Chunked Prefill", demo_chunked_prefill),
|
||||
("Optimization Comparison", demo_optimization_comparison),
|
||||
("Memory Analysis", demo_memory_analysis),
|
||||
("GPU Tensor Cores", demo_gpu_tensor_cores),
|
||||
]
|
||||
|
||||
for name, func in demos:
|
||||
try:
|
||||
func()
|
||||
except Exception as e:
|
||||
print(f"\n ✗ {name} failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n" + "█" * 70)
|
||||
print(" ALL DEMOS COMPLETE")
|
||||
print("█" * 70 + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,413 @@
|
||||
"""
|
||||
GPU Execution Mapping for KV-Cache Systems
|
||||
|
||||
Documents how the KV-cache system maps to GPU hardware:
|
||||
- Memory hierarchy (registers, shared mem, global mem, HBM)
|
||||
- Kernel design for attention with cache
|
||||
- CUDA optimization strategies
|
||||
- Tensor Core utilization
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# GPU MEMORY HIERARCHY REFERENCE
|
||||
# =============================================================================
|
||||
|
||||
GPU_HIERARCHY = {
|
||||
"registers": {
|
||||
"size_per_sm": "64 KB",
|
||||
"latency": "1 cycle",
|
||||
"usage": "Thread-local variables, warp-level computation",
|
||||
},
|
||||
"shared_memory": {
|
||||
"size_per_sm": "166 KB (H100)",
|
||||
"latency": "1-3 cycles",
|
||||
"usage": "Tiling, cooperative loading, softmax intermediate",
|
||||
},
|
||||
"l2_cache": {
|
||||
"size": "50 MB (H100)",
|
||||
"latency": "~20 cycles",
|
||||
"usage": "Automatic caching of global memory accesses",
|
||||
},
|
||||
"hbm": {
|
||||
"size": "80 GB (H100)",
|
||||
"bandwidth": "3.35 TB/s (H100)",
|
||||
"latency": "~300-400 cycles",
|
||||
"usage": "Model weights, KV cache, activations",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# KERNEL DESIGN: CACHED ATTENTION
|
||||
# =============================================================================
|
||||
|
||||
def describe_cached_attention_kernel():
|
||||
"""
|
||||
Describe the CUDA kernel for cached attention.
|
||||
|
||||
Kernel: cached_attention<<<grid, block>>>(Q, K_cache, V_cache, Out, ...)
|
||||
|
||||
Thread block organization:
|
||||
- Each block handles one (batch, head) pair
|
||||
- Threads within a block cooperate on the matmul Q @ K^T
|
||||
|
||||
Memory access pattern:
|
||||
1. Load Q tile into shared memory (small: 1 x head_dim)
|
||||
2. Stream K_cache tiles from global memory into shared memory
|
||||
3. Compute partial dot products in registers
|
||||
4. Accumulate scores in shared memory
|
||||
5. Softmax in shared memory
|
||||
6. Stream V_cache tiles and compute output
|
||||
"""
|
||||
description = {
|
||||
"kernel_name": "cached_attention",
|
||||
"grid": "(batch_size, num_heads, 1)",
|
||||
"block": "(BLOCK_X, BLOCK_Y) — e.g., (32, 32) for 1024 threads",
|
||||
"shared_memory_usage": {
|
||||
"q_tile": "1 x head_dim (e.g., 1 x 128 = 128 floats = 512 bytes fp16)",
|
||||
"k_tile": "BLOCK_Y x head_dim (e.g., 32 x 128 = 4096 floats = 8 KB fp16)",
|
||||
"v_tile": "BLOCK_Y x head_dim (same as K)",
|
||||
"score_tile": "BLOCK_X x BLOCK_Y (e.g., 32 x 32 = 1024 floats = 4 KB fp16)",
|
||||
"total_shared_per_block": "~16-20 KB (fits in 166 KB SM)",
|
||||
},
|
||||
"global_memory_accesses": {
|
||||
"read_q": "batch * heads * 1 * head_dim (tiny)",
|
||||
"read_k_cache": "batch * heads * seq_len * head_dim (dominant)",
|
||||
"read_v_cache": "batch * heads * seq_len * head_dim (dominant)",
|
||||
"write_output": "batch * heads * 1 * head_dim (tiny)",
|
||||
},
|
||||
"optimization_strategies": [
|
||||
"1. Coalesced global memory access: threads in a warp access consecutive addresses",
|
||||
"2. Tiled GEMM: process K/V in tiles that fit in shared memory",
|
||||
"3. Persistent kernels: keep blocks alive until all tiles processed",
|
||||
"4. Async copy (H100): use cp.async to overlap memory transfer with computation",
|
||||
"5. Tensor Cores: use WMMA or mma.sync for the matmul operations",
|
||||
"6. Fusion: fuse softmax with attention score computation",
|
||||
],
|
||||
}
|
||||
return description
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TENSOR CORE UTILIZATION
|
||||
# =============================================================================
|
||||
|
||||
def tensor_core_analysis(head_dim: int = 128, seq_len: int = 4096,
|
||||
batch: int = 4, heads: int = 32) -> Dict:
|
||||
"""
|
||||
Analyze Tensor Core utilization for cached attention.
|
||||
|
||||
H100 Tensor Core specs (FP16):
|
||||
- MMA shape: M x N x K where M,N,K are multiples of 16
|
||||
- Peak throughput: ~1,970 TFLOPS (FP16 Tensor Core)
|
||||
- Each MMA instruction: 16x16x16 = 4096 FLOPs
|
||||
"""
|
||||
# Q @ K^T: (batch, heads, 1, head_dim) @ (batch, heads, head_dim, seq_len)
|
||||
# FLOPs per (batch, head): 2 * 1 * head_dim * seq_len
|
||||
flops_qk = 2 * batch * heads * 1 * head_dim * seq_len
|
||||
|
||||
# Attn @ V: (batch, heads, 1, seq_len) @ (batch, heads, seq_len, head_dim)
|
||||
flops_av = 2 * batch * heads * 1 * seq_len * head_dim
|
||||
|
||||
total_flops = flops_qk + flops_av
|
||||
|
||||
# Memory traffic
|
||||
elem_bytes = 2 # fp16
|
||||
mem_q = batch * heads * 1 * head_dim * elem_bytes
|
||||
mem_k = batch * heads * seq_len * head_dim * elem_bytes
|
||||
mem_v = batch * heads * seq_len * head_dim * elem_bytes
|
||||
mem_out = batch * heads * 1 * head_dim * elem_bytes
|
||||
total_mem = mem_q + mem_k + mem_v + mem_out
|
||||
|
||||
# Arithmetic intensity (FLOPs per byte)
|
||||
intensity = total_flops / total_mem
|
||||
|
||||
# H100 peak
|
||||
h100_peak_tflops = 1970 # FP16 Tensor Core
|
||||
h100_bandwidth = 3.35e12 # bytes/s
|
||||
|
||||
# Theoretical time bounds
|
||||
compute_bound_s = total_flops / (h100_peak_tflops * 1e12)
|
||||
memory_bound_s = total_mem / h100_bandwidth
|
||||
|
||||
return {
|
||||
"flops_qk": f"{flops_qk / 1e9:.2f} GFLOPs",
|
||||
"flops_av": f"{flops_av / 1e9:.2f} GFLOPs",
|
||||
"total_flops": f"{total_flops / 1e9:.2f} GFLOPs",
|
||||
"memory_traffic_mb": f"{total_mem / 1e6:.2f} MB",
|
||||
"arithmetic_intensity": f"{intensity:.2f} FLOPs/byte",
|
||||
"compute_bound_ms": f"{compute_bound_s * 1000:.4f} ms",
|
||||
"memory_bound_ms": f"{memory_bound_s * 1000:.4f} ms",
|
||||
"bound": "compute-bound" if compute_bound_s > memory_bound_s else "memory-bound",
|
||||
"h100_peak_tflops": h100_peak_tflops,
|
||||
"h100_bandwidth_tbps": h100_bandwidth / 1e12,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# GPU EXECUTION PIPELINE
|
||||
# =============================================================================
|
||||
|
||||
def describe_execution_pipeline():
|
||||
"""
|
||||
Describe the full GPU execution pipeline for one generation step.
|
||||
|
||||
Step 1: Embedding lookup
|
||||
- Input: token_id (batch, 1)
|
||||
- Operation: embedding[token_id] -> (batch, 1, dim)
|
||||
- GPU: Gathers from embedding table (random access, use shared mem tiling)
|
||||
|
||||
Step 2: Positional encoding
|
||||
- Operation: x += pos_encoding[current_pos]
|
||||
- GPU: Simple element-wise add (fully parallel)
|
||||
|
||||
Step 3: Per-layer forward pass (repeated L times)
|
||||
3a. LayerNorm
|
||||
- GPU: Parallel reduction for mean/var, then element-wise
|
||||
|
||||
3b. QKV projection
|
||||
- GPU: 3 parallel GEMMs: x @ Wq, x @ Wk, x @ Wv
|
||||
- cuBLAS/cutlass: highly optimized for small M (M=1)
|
||||
|
||||
3c. KV cache update
|
||||
- GPU: Simple copy to global memory (coalesced write)
|
||||
- cache_k[:, :, write_pos, :] = k[:, :, 0, :]
|
||||
|
||||
3d. Cached attention
|
||||
- GPU: Custom kernel (see describe_cached_attention_kernel)
|
||||
- Two GEMMs + softmax, tiled for shared memory
|
||||
|
||||
3e. Output projection
|
||||
- GPU: GEMM: attn_out @ Wo
|
||||
|
||||
3f. MLP
|
||||
- GPU: Two GEMMs with activation fusion
|
||||
|
||||
3g. Residual add + LayerNorm
|
||||
- GPU: Element-wise operations
|
||||
|
||||
Step 4: LM head
|
||||
- GPU: GEMM: x @ W_lm -> logits (batch, vocab_size)
|
||||
|
||||
Step 5: Sampling
|
||||
- GPU: Argmax or top-k sampling kernel
|
||||
- Can be done on CPU for small batch sizes
|
||||
"""
|
||||
return {
|
||||
"steps": [
|
||||
"1. Embedding lookup (gather)",
|
||||
"2. Positional encoding (element-wise add)",
|
||||
"3. Per-layer: LayerNorm -> QKV proj -> cache update -> attention -> MLP",
|
||||
"4. LM head (GEMM)",
|
||||
"5. Sampling (argmax/top-k)",
|
||||
],
|
||||
"bottleneck": "Cached attention (memory-bound for long sequences)",
|
||||
"optimization_opportunities": [
|
||||
"Operator fusion: merge LayerNorm + GEMM bias + activation",
|
||||
"Batched GEMM: process all layers' small GEMMs together",
|
||||
"Pipeline parallelism: overlap layers' computation",
|
||||
"FlashAttention-style tiling for the cached attention kernel",
|
||||
"Warp-specialized design: some warps load, some compute",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# FLASH-ATTENTION-STYLE CACHED KERNEL
|
||||
# =============================================================================
|
||||
|
||||
def describe_flash_attention_cached():
|
||||
"""
|
||||
FlashAttention-style kernel adapted for cached attention.
|
||||
|
||||
Key insight: instead of materializing the full (1 x seq_len) attention
|
||||
matrix, process K/V in tiles and accumulate softmax online.
|
||||
|
||||
Algorithm (for one batch/head):
|
||||
1. Initialize: output = 0, m = -inf, l = 0 (online softmax state)
|
||||
2. For each K/V tile (size BLOCK):
|
||||
a. Compute S = Q @ K_tile^T (in shared memory)
|
||||
b. m_new = max(m, max(S))
|
||||
c. l = l * exp(m - m_new) + sum(exp(S - m_new))
|
||||
d. output = output * exp(m - m_new) + exp(S - m_new) @ V_tile
e. m = m_new
3. output = output / l

This never materializes the full (1 x seq_len) score vector in HBM: scores
live in registers/shared memory one tile at a time, and K and V are each
streamed from HBM exactly once.
|
||||
"""
|
||||
return {
|
||||
"name": "FlashAttention-style cached kernel",
|
||||
"key_benefit": "O(1) shared memory usage regardless of sequence length",
|
||||
"hbm_traffic_reduction": "Reduces from 4 reads to ~2 reads of K/V cache",
|
||||
"shared_memory": "Only needs BLOCK x head_dim tiles, not full seq_len",
|
||||
"complexity": "More complex kernel but 2-4x faster for long sequences",
|
||||
"implementation_notes": [
|
||||
"Requires careful numerical stability (online softmax)",
|
||||
"Two-pass: forward pass accumulates, backward pass needs recompute",
|
||||
"For generation (single query), simpler than full FlashAttention",
|
||||
"Can use mma.sync for the tile GEMMs on H100",
|
||||
],
|
||||
}
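

# ---------------------------------------------------------------------------
# Reference sketch of the online-softmax recurrence documented above. Plain
# NumPy (using this module's `np` import), for one (batch, head) pair and a
# single query, so the tiled result can be checked against ordinary softmax
# attention. It is an illustration, not the CUDA kernel; `block` is an assumed
# tile size.
# ---------------------------------------------------------------------------

def online_softmax_attention_reference(q, k_cache, v_cache, scale, block=32):
    """Tiled cached attention for one query. q: (head_dim,), k/v: (seq, head_dim)."""
    out = np.zeros(q.shape[0], dtype=np.float64)    # unnormalized output accumulator
    m = -np.inf                                     # running max of scores
    l = 0.0                                         # running sum of exp(score - m)
    for start in range(0, k_cache.shape[0], block):
        k_tile = k_cache[start:start + block]
        v_tile = v_cache[start:start + block]
        s = (k_tile @ q) * scale                    # scores for this tile
        m_new = max(m, float(np.max(s)))
        alpha = np.exp(m - m_new) if np.isfinite(m) else 0.0  # rescale old state
        p = np.exp(s - m_new)
        l = l * alpha + float(np.sum(p))
        out = out * alpha + p @ v_tile
        m = m_new
    return out / l                                  # normalize once at the end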
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MULTI-GPU STRATEGIES
|
||||
# =============================================================================
|
||||
|
||||
def describe_multi_gpu():
|
||||
"""
|
||||
Multi-GPU strategies for large models with KV cache.
|
||||
"""
|
||||
return {
|
||||
"tensor_parallelism": {
|
||||
"description": "Split model weights across GPUs (Megatron-LM style)",
|
||||
"kv_cache_impact": "Each GPU holds its shard of K/V (split by head_dim)",
|
||||
"communication": "AllReduce in MLP, all-to-all in attention",
|
||||
"scaling": "Linear with num GPUs (up to num_heads)",
|
||||
},
|
||||
"pipeline_parallelism": {
|
||||
"description": "Split layers across GPUs",
|
||||
"kv_cache_impact": "Each GPU holds K/V for its layer shard",
|
||||
"communication": "Send activations between stages",
|
||||
"challenge": "Bubble idle time; needs micro-batching",
|
||||
},
|
||||
"sequence_parallelism": {
|
||||
"description": "Split sequence across GPUs (for prefill)",
|
||||
"kv_cache_impact": "Each GPU holds K/V for its sequence shard",
|
||||
"communication": "All-to-all for attention across sequence shards",
|
||||
"best_for": "Very long context prefill",
|
||||
},
|
||||
"expert_parallelism": {
|
||||
"description": "For MoE models (Mixtral, Grok)",
|
||||
"kv_cache_impact": "KV cache is shared; only MLP experts are sharded",
|
||||
"communication": "All-to-all for expert routing",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# PRACTICAL GPU TUNING GUIDE
|
||||
# =============================================================================
|
||||
|
||||
def gpu_tuning_guide():
|
||||
"""
|
||||
Practical GPU tuning recommendations for KV-cache inference.
|
||||
"""
|
||||
return {
|
||||
"streaming_KV_cache": {
|
||||
"problem": "For long sequences, K/V cache reads dominate latency",
|
||||
"solution": "Use H100's copy engine (async copy) to stream tiles",
|
||||
"detail": "Overlap K/V loading with Q projection computation",
|
||||
},
|
||||
"small_batch_optimization": {
|
||||
"problem": "Single-token generation has tiny GEMMs (M=1)",
|
||||
"solution": "Use CUTLASS tiny GEMM kernels or custom kernels",
|
||||
"detail": "Standard cuBLAS is not optimized for M=1; use flashinfer or turbotransformers",
|
||||
},
|
||||
"continuous_batching": {
|
||||
"problem": "Variable generation lengths waste compute",
|
||||
"solution": "Run sequences at different stages simultaneously",
|
||||
"detail": "Some sequences in prefill, others in decode; schedule on GPU",
|
||||
},
|
||||
"kv_cache_quantization_on_gpu": {
|
||||
"problem": "Dequantization adds latency",
|
||||
"solution": "Use INT8 Tensor Cores (H100 supports INT8 MMA)",
|
||||
"detail": "Keep K/V in INT8, dequantize during the MMA instruction",
|
||||
},
|
||||
"cuda_graphs": {
|
||||
"problem": "Kernel launch overhead for small operations",
|
||||
"solution": "Record and replay CUDA graphs",
|
||||
"detail": "For fixed-shape generation, graphs eliminate launch overhead",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# PRINT GPU MAPPING REPORT
|
||||
# =============================================================================
|
||||
|
||||
def print_gpu_report():
|
||||
"""Print comprehensive GPU execution mapping report."""
|
||||
print("=" * 80)
|
||||
print("GPU EXECUTION MAPPING FOR KV-CACHE SYSTEM")
|
||||
print("=" * 80)
|
||||
|
||||
# Memory hierarchy
|
||||
print("\n--- GPU Memory Hierarchy ---\n")
|
||||
for level, info in GPU_HIERARCHY.items():
|
||||
print(f" {level:>15}:")
|
||||
for k, v in info.items():
|
||||
print(f" {k}: {v}")
|
||||
|
||||
# Kernel design
|
||||
print("\n\n--- Cached Attention Kernel Design ---\n")
|
||||
kernel = describe_cached_attention_kernel()
|
||||
print(f" Kernel: {kernel['kernel_name']}")
|
||||
print(f" Grid: {kernel['grid']}")
|
||||
print(f" Block: {kernel['block']}")
|
||||
print("\n Shared Memory Usage:")
|
||||
for k, v in kernel["shared_memory_usage"].items():
|
||||
if k != "total_shared_per_block":
|
||||
print(f" {k}: {v}")
|
||||
print(f" {list(kernel['shared_memory_usage'].keys())[-1]}: "
|
||||
f"{list(kernel['shared_memory_usage'].values())[-1]}")
|
||||
|
||||
print("\n Optimization Strategies:")
|
||||
for s in kernel["optimization_strategies"]:
|
||||
print(f" {s}")
|
||||
|
||||
# Tensor core analysis
|
||||
print("\n\n--- Tensor Core Utilization (batch=4, heads=32, seq=4096) ---\n")
|
||||
tc = tensor_core_analysis(batch=4, heads=32, seq_len=4096)
|
||||
for k, v in tc.items():
|
||||
print(f" {k}: {v}")
|
||||
|
||||
# Execution pipeline
|
||||
print("\n\n--- Execution Pipeline ---\n")
|
||||
pipeline = describe_execution_pipeline()
|
||||
for i, step in enumerate(pipeline["steps"], 1):
|
||||
print(f" {step}")
|
||||
print(f"\n Bottleneck: {pipeline['bottleneck']}")
|
||||
print("\n Optimization Opportunities:")
|
||||
for opt in pipeline["optimization_opportunities"]:
|
||||
print(f" - {opt}")
|
||||
|
||||
# FlashAttention
|
||||
print("\n\n--- FlashAttention-Style Cached Kernel ---\n")
|
||||
flash = describe_flash_attention_cached()
|
||||
for k, v in flash.items():
|
||||
if isinstance(v, list):
|
||||
print(f" {k}:")
|
||||
for item in v:
|
||||
print(f" - {item}")
|
||||
else:
|
||||
print(f" {k}: {v}")
|
||||
|
||||
# Multi-GPU
|
||||
print("\n\n--- Multi-GPU Strategies ---\n")
|
||||
multi = describe_multi_gpu()
|
||||
for strategy, info in multi.items():
|
||||
print(f" {strategy}:")
|
||||
for k, v in info.items():
|
||||
print(f" {k}: {v}")
|
||||
|
||||
# Tuning guide
|
||||
print("\n\n--- GPU Tuning Guide ---\n")
|
||||
tuning = gpu_tuning_guide()
|
||||
for area, info in tuning.items():
|
||||
print(f" {area}:")
|
||||
print(f" Problem: {info['problem']}")
|
||||
print(f" Solution: {info['solution']}")
|
||||
print(f" Detail: {info['detail']}")
|
||||
print()
|
||||
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print_gpu_report()
|
||||
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
KV-Cache Data Structures for Autoregressive Transformer Inference
|
||||
|
||||
Core memory layout:
|
||||
cache_k[batch, head, seq_len, head_dim]
|
||||
cache_v[batch, head, seq_len, head_dim]
|
||||
|
||||
This layout enables O(1) append per token and contiguous memory access
|
||||
during attention computation (Q @ K^T scans along seq_len).
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheConfig:
|
||||
"""Configuration for a single layer's KV cache."""
|
||||
batch_size: int
|
||||
num_heads: int
|
||||
head_dim: int
|
||||
max_seq_len: int
|
||||
dtype: np.dtype = np.float16
|
||||
|
||||
@property
|
||||
def cache_bytes_per_layer(self) -> int:
|
||||
"""Bytes for one layer's K + V cache."""
|
||||
elem_bytes = np.dtype(self.dtype).itemsize
|
||||
one_side = self.batch_size * self.num_heads * self.max_seq_len * self.head_dim
|
||||
return 2 * one_side * elem_bytes # K + V
|
||||
|
||||
@property
|
||||
def cache_bytes_per_layer_per_token(self) -> int:
|
||||
"""Bytes consumed per generated token per layer."""
|
||||
elem_bytes = np.dtype(self.dtype).itemsize
|
||||
return 2 * self.num_heads * self.head_dim * elem_bytes
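

# Worked example (illustrative): sizes for a 7B-like layer at a 4096-token
# window, following the two properties above. The config values are
# assumptions, not tied to a specific checkpoint.
def _example_cache_config_sizes():
    cfg = CacheConfig(batch_size=1, num_heads=32, head_dim=128, max_seq_len=4096)
    # 2 (K+V) * 1 * 32 * 4096 * 128 * 2 bytes = 64 MiB per layer
    assert cfg.cache_bytes_per_layer == 64 * 1024 * 1024
    # 2 * 32 * 128 * 2 bytes = 16 KiB of new cache per generated token per layer
    assert cfg.cache_bytes_per_layer_per_token == 16 * 1024
    return cfg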
|
||||
|
||||
|
||||
class KVCache:
|
||||
"""
|
||||
Standard contiguous KV cache for one transformer layer.
|
||||
|
||||
Memory layout (row-major / C-contiguous):
|
||||
cache_k: (batch, num_heads, max_seq_len, head_dim)
|
||||
cache_v: (batch, num_heads, max_seq_len, head_dim)
|
||||
|
||||
Why this layout:
|
||||
- batch first: enables batched GEMM on GPU
|
||||
- head second: allows parallel head computation
|
||||
- seq_len third: contiguous scan for Q @ K^T
|
||||
- head_dim last: inner product dimension
|
||||
|
||||
The cache is pre-allocated to max_seq_len. A `lengths` array tracks
|
||||
actual sequence lengths per batch item (for variable-length batching).
|
||||
"""
|
||||
|
||||
def __init__(self, config: CacheConfig):
|
||||
self.config = config
|
||||
self.batch_size = config.batch_size
|
||||
self.num_heads = config.num_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.max_seq_len = config.max_seq_len
|
||||
self.dtype = config.dtype
|
||||
|
||||
# Pre-allocate full buffers (zero-initialized)
|
||||
shape = (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim)
|
||||
self.cache_k = np.zeros(shape, dtype=self.dtype)
|
||||
self.cache_v = np.zeros(shape, dtype=self.dtype)
|
||||
|
||||
# Per-batch-item current sequence length
|
||||
self.lengths = np.zeros(self.batch_size, dtype=np.int32)
|
||||
|
||||
# Write pointer: next position to write into
|
||||
self.write_pos = 0
|
||||
|
||||
def reset(self):
|
||||
"""Clear the cache for a new generation."""
|
||||
self.cache_k[...] = 0
|
||||
self.cache_v[...] = 0
|
||||
self.lengths[...] = 0
|
||||
self.write_pos = 0
|
||||
|
||||
def update(self, keys: np.ndarray, values: np.ndarray,
|
||||
seqlen_offset: int = None) -> None:
|
||||
"""
|
||||
Append newly computed K and V to the cache.
|
||||
|
||||
Args:
|
||||
keys: (batch, num_heads, 1, head_dim) — current step's K
|
||||
values: (batch, num_heads, 1, head_dim) — current step's V
|
||||
seqlen_offset: optional explicit write position (defaults to self.write_pos)
|
||||
|
||||
The write position advances by 1 each call during generation. Prompt tokens
are written during prefill by the prompt attention path rather than through
this single-position update.
|
||||
"""
|
||||
if seqlen_offset is None:
|
||||
seqlen_offset = self.write_pos
|
||||
|
||||
pos = seqlen_offset
|
||||
self.cache_k[:, :, pos, :] = keys[:, :, 0, :]
|
||||
self.cache_v[:, :, pos, :] = values[:, :, 0, :]
|
||||
|
||||
# Update per-batch-item lengths
|
||||
self.lengths[:] = pos + 1
|
||||
|
||||
self.write_pos = pos + 1
|
||||
|
||||
def get(self, start: int = 0, end: int = None) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Retrieve cached K and V slices.
|
||||
|
||||
Returns:
|
||||
k: (batch, num_heads, end-start, head_dim)
|
||||
v: (batch, num_heads, end-start, head_dim)
|
||||
"""
|
||||
if end is None:
|
||||
end = self.write_pos
|
||||
return (
|
||||
self.cache_k[:, :, start:end, :],
|
||||
self.cache_v[:, :, start:end, :],
|
||||
)
|
||||
|
||||
def get_all(self) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Get all cached tokens so far (up to write_pos)."""
|
||||
return self.get(0, self.write_pos)
|
||||
|
||||
@property
|
||||
def memory_used_bytes(self) -> int:
|
||||
"""Actual bytes used (based on write_pos, not max allocation)."""
|
||||
elem_bytes = np.dtype(self.dtype).itemsize
|
||||
tokens = self.write_pos
|
||||
return 2 * self.batch_size * self.num_heads * tokens * self.head_dim * elem_bytes
|
||||
|
||||
@property
|
||||
def memory_allocated_bytes(self) -> int:
|
||||
"""Total pre-allocated bytes."""
|
||||
return self.config.cache_bytes_per_layer
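

# Usage sketch: append one decode step of K/V and read the cache back.
# Shapes follow the update() contract above; values are illustrative.
def _example_kv_cache_roundtrip():
    cfg = CacheConfig(batch_size=2, num_heads=4, head_dim=8, max_seq_len=16)
    cache = KVCache(cfg)
    k = np.random.randn(2, 4, 1, 8).astype(cfg.dtype)
    v = np.random.randn(2, 4, 1, 8).astype(cfg.dtype)
    cache.update(k, v)                   # writes position 0, write_pos -> 1
    k_all, v_all = cache.get_all()       # (2, 4, 1, 8) each
    assert k_all.shape == (2, 4, 1, 8)
    assert cache.memory_used_bytes == 2 * 2 * 4 * 1 * 8 * 2
    return k_all, v_all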
|
||||
|
||||
|
||||
class BatchedKVCache:
|
||||
"""
|
||||
Manages KV caches across all layers of a transformer.
|
||||
|
||||
In a real model with L layers, we need L separate KV caches.
|
||||
This class coordinates them and handles variable-length batching.
|
||||
"""
|
||||
|
||||
def __init__(self, num_layers: int, config: CacheConfig):
|
||||
self.num_layers = num_layers
|
||||
self.config = config
|
||||
self.caches = [KVCache(config) for _ in range(num_layers)]
|
||||
|
||||
def reset(self):
|
||||
for cache in self.caches:
|
||||
cache.reset()
|
||||
|
||||
def update(self, layer_idx: int, keys: np.ndarray, values: np.ndarray,
|
||||
seqlen_offset: int = None):
|
||||
self.caches[layer_idx].update(keys, values, seqlen_offset)
|
||||
|
||||
def get(self, layer_idx: int, start: int = 0, end: int = None):
|
||||
return self.caches[layer_idx].get(start, end)
|
||||
|
||||
@property
|
||||
def total_memory_allocated_bytes(self) -> int:
|
||||
return sum(c.memory_allocated_bytes for c in self.caches)
|
||||
|
||||
@property
|
||||
def total_memory_used_bytes(self) -> int:
|
||||
return sum(c.memory_used_bytes for c in self.caches)
|
||||
|
||||
def memory_report(self) -> dict:
|
||||
"""Detailed memory breakdown."""
|
||||
elem_bytes = np.dtype(self.config.dtype).itemsize
|
||||
tokens = self.caches[0].write_pos if self.caches else 0
|
||||
per_layer = self.config.cache_bytes_per_layer
|
||||
per_token_per_layer = self.config.cache_bytes_per_layer_per_token
|
||||
|
||||
return {
|
||||
"num_layers": self.num_layers,
|
||||
"batch_size": self.config.batch_size,
|
||||
"num_heads": self.config.num_heads,
|
||||
"head_dim": self.config.head_dim,
|
||||
"max_seq_len": self.config.max_seq_len,
|
||||
"dtype": str(self.config.dtype),
|
||||
"tokens_generated": tokens,
|
||||
"per_layer_allocated_mb": per_layer / (1024 * 1024),
|
||||
"total_allocated_mb": self.total_memory_allocated_bytes / (1024 * 1024),
|
||||
"total_used_mb": self.total_memory_used_bytes / (1024 * 1024),
|
||||
"growth_per_token_mb": (per_token_per_layer * self.num_layers) / (1024 * 1024),
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
"""
|
||||
Memory Growth Analysis for KV-Cache Systems
|
||||
|
||||
Analyzes how memory consumption scales with:
|
||||
- Sequence length
|
||||
- Batch size
|
||||
- Number of heads
|
||||
- Model dimension
|
||||
- Number of layers
|
||||
|
||||
Provides formulas, visualizations, and practical limits.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Dict, List, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelSpec:
|
||||
"""Specification of a transformer model."""
|
||||
num_layers: int
|
||||
dim: int
|
||||
num_heads: int
|
||||
head_dim: int
|
||||
vocab_size: int = 32000
|
||||
mlp_hidden_mult: float = 4.0 / 3  # FFN hidden = mult * dim (note: GPT-style FFNs use 4x dim)
|
||||
|
||||
|
||||
def compute_model_memory(spec: ModelSpec, dtype=np.float16) -> Dict[str, float]:
|
||||
"""
|
||||
Compute total model parameter memory.
|
||||
|
||||
Per layer:
|
||||
- Wq, Wk, Wv: 3 * dim * dim
|
||||
- Wo: dim * dim
|
||||
- MLP fc1: dim * hidden
|
||||
- MLP fc2: hidden * dim
|
||||
- LayerNorm: 2 * dim (weight + bias)
|
||||
- Embedding: vocab_size * dim (shared with LM head)
|
||||
|
||||
Total per layer (excluding shared embedding):
|
||||
4 * dim² + 2 * dim * hidden + 2 * dim
|
||||
"""
|
||||
elem = np.dtype(dtype).itemsize
|
||||
hidden = int(spec.dim * spec.mlp_hidden_mult)
|
||||
|
||||
per_layer = (
|
||||
4 * spec.dim * spec.dim + # Wq, Wk, Wv, Wo
|
||||
2 * spec.dim * hidden + # MLP fc1, fc2
|
||||
2 * spec.dim # LayerNorm params
|
||||
) * elem
|
||||
|
||||
embedding = spec.vocab_size * spec.dim * elem
|
||||
|
||||
return {
|
||||
"per_layer_bytes": per_layer,
|
||||
"per_layer_mb": per_layer / (1024 * 1024),
|
||||
"embedding_mb": embedding / (1024 * 1024),
|
||||
"total_params_mb": (per_layer * spec.num_layers + embedding) / (1024 * 1024),
|
||||
"total_params_gb": (per_layer * spec.num_layers + embedding) / (1024 ** 3),
|
||||
}
|
||||
|
||||
|
||||
def compute_kv_cache_memory(
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
spec: ModelSpec,
|
||||
dtype=np.float16,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Compute KV cache memory for a given batch and sequence length.
|
||||
|
||||
Per layer: 2 * batch * heads * seq * head_dim * elem_bytes
|
||||
(factor of 2 for K and V)
|
||||
"""
|
||||
elem = np.dtype(dtype).itemsize
|
||||
per_layer = 2 * batch_size * spec.num_heads * seq_len * spec.head_dim * elem
|
||||
total = per_layer * spec.num_layers
|
||||
|
||||
return {
|
||||
"per_layer_bytes": per_layer,
|
||||
"per_layer_mb": per_layer / (1024 * 1024),
|
||||
"total_bytes": total,
|
||||
"total_mb": total / (1024 * 1024),
|
||||
"total_gb": total / (1024 ** 3),
|
||||
"per_token_per_layer_bytes": 2 * spec.num_heads * spec.head_dim * elem,
|
||||
"growth_rate_mb_per_token": (
|
||||
2 * batch_size * spec.num_heads * spec.head_dim * elem * spec.num_layers
|
||||
) / (1024 * 1024),
|
||||
}
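

# Worked example (illustrative): a 7B-like spec (32 layers, 32 heads,
# head_dim=128) at batch=1 and seq_len=4096 in fp16 needs 64 MiB of KV cache
# per layer, 2 GiB across all layers, and grows by 0.5 MiB per generated token.
def _example_kv_cache_budget():
    spec = ModelSpec(num_layers=32, dim=4096, num_heads=32, head_dim=128)
    kv = compute_kv_cache_memory(batch_size=1, seq_len=4096, spec=spec)
    assert round(kv["per_layer_mb"]) == 64
    assert round(kv["total_gb"], 2) == 2.0
    assert round(kv["growth_rate_mb_per_token"], 2) == 0.5
    return kv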
|
||||
|
||||
|
||||
def analyze_memory_growth(spec: ModelSpec, batch_sizes: List[int] = None,
|
||||
seq_lengths: List[int] = None,
|
||||
dtype=np.float16) -> Dict:
|
||||
"""
|
||||
Comprehensive memory growth analysis.
|
||||
|
||||
Returns analysis for various batch sizes and sequence lengths.
|
||||
"""
|
||||
if batch_sizes is None:
|
||||
batch_sizes = [1, 2, 4, 8, 16, 32]
|
||||
if seq_lengths is None:
|
||||
seq_lengths = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
|
||||
|
||||
model_mem = compute_model_memory(spec, dtype)
|
||||
|
||||
results = {
|
||||
"model": model_mem,
|
||||
"spec": {
|
||||
"num_layers": spec.num_layers,
|
||||
"dim": spec.dim,
|
||||
"num_heads": spec.num_heads,
|
||||
"head_dim": spec.head_dim,
|
||||
"dtype": str(dtype),
|
||||
},
|
||||
"kv_cache": {},
|
||||
}
|
||||
|
||||
for bs in batch_sizes:
|
||||
for sl in seq_lengths:
|
||||
kv = compute_kv_cache_memory(bs, sl, spec, dtype)
|
||||
key = f"bs{bs}_sl{sl}"
|
||||
results["kv_cache"][key] = {
|
||||
"batch_size": bs,
|
||||
"seq_len": sl,
|
||||
"kv_cache_gb": kv["total_gb"],
|
||||
"total_system_gb": kv["total_gb"] + model_mem["total_params_gb"],
|
||||
"kv_fraction": kv["total_gb"] / (kv["total_gb"] + model_mem["total_params_gb"]),
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def find_max_context(spec: ModelSpec, gpu_memory_gb: float = 80,
|
||||
batch_size: int = 1, dtype=np.float16) -> int:
|
||||
"""
|
||||
Find the maximum context length that fits in GPU memory.
|
||||
|
||||
GPU memory = model_params + kv_cache + activation_overhead
|
||||
|
||||
We estimate activation overhead as ~2x model params (conservative).
|
||||
"""
|
||||
model_mem = compute_model_memory(spec, dtype)
|
||||
model_gb = model_mem["total_params_gb"]
|
||||
|
||||
# Reserve for activations and other overhead (~2x model params)
|
||||
activation_gb = model_gb * 2
|
||||
|
||||
# Remaining for KV cache
|
||||
kv_budget_gb = gpu_memory_gb - model_gb - activation_gb
|
||||
if kv_budget_gb <= 0:
|
||||
return 0
|
||||
|
||||
elem = np.dtype(dtype).itemsize
|
||||
bytes_per_token = (2 * batch_size * spec.num_heads * spec.head_dim * elem *
|
||||
spec.num_layers)
|
||||
|
||||
max_tokens = int(kv_budget_gb * (1024 ** 3) / bytes_per_token)
|
||||
return max_tokens
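

# Worked example (illustrative): for the 7B-like spec above, an 80 GB GPU
# leaves roughly 59 GB for KV cache after ~3x model weights are reserved,
# which at ~0.5 MiB per token supports a context on the order of 120K tokens
# at batch size 1. The bounds below are deliberately loose.
def _example_max_context_h100():
    spec = ModelSpec(num_layers=32, dim=4096, num_heads=32, head_dim=128)
    max_ctx = find_max_context(spec, gpu_memory_gb=80, batch_size=1)
    assert 100_000 < max_ctx < 140_000
    return max_ctx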
|
||||
|
||||
|
||||
def compare_model_sizes() -> Dict[str, dict]:
|
||||
"""
|
||||
Analyze memory for several well-known model sizes.
|
||||
"""
|
||||
models = {
|
||||
"Llama-2-7B": ModelSpec(num_layers=32, dim=4096, num_heads=32, head_dim=128),
|
||||
"Llama-2-13B": ModelSpec(num_layers=40, dim=5120, num_heads=40, head_dim=128),
|
||||
"Llama-2-70B": ModelSpec(num_layers=80, dim=8192, num_heads=64, head_dim=128),
|
||||
"Llama-3-8B": ModelSpec(num_layers=32, dim=4096, num_heads=32, head_dim=128),
|
||||
"Mistral-7B": ModelSpec(num_layers=32, dim=4096, num_heads=32, head_dim=128),
|
||||
"GPT-4-class": ModelSpec(num_layers=100, dim=12288, num_heads=96, head_dim=128),
|
||||
}
|
||||
|
||||
results = {}
|
||||
for name, spec in models.items():
|
||||
model_mem = compute_model_memory(spec, np.float16)
|
||||
|
||||
# KV cache for batch=1, various lengths
|
||||
kv_1k = compute_kv_cache_memory(1, 1024, spec, np.float16)
|
||||
kv_8k = compute_kv_cache_memory(1, 8192, spec, np.float16)
|
||||
kv_32k = compute_kv_cache_memory(1, 32768, spec, np.float16)
|
||||
|
||||
results[name] = {
|
||||
"params_gb": model_mem["total_params_gb"],
|
||||
"kv_1k_gb": kv_1k["total_gb"],
|
||||
"kv_8k_gb": kv_8k["total_gb"],
|
||||
"kv_32k_gb": kv_32k["total_gb"],
|
||||
"max_context_H100": find_max_context(spec, gpu_memory_gb=80, batch_size=1),
|
||||
"max_context_A100_40": find_max_context(spec, gpu_memory_gb=40, batch_size=1),
|
||||
"max_context_A100_80": find_max_context(spec, gpu_memory_gb=80, batch_size=1),
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def print_analysis():
|
||||
"""Print a comprehensive memory analysis report."""
|
||||
print("=" * 80)
|
||||
print("KV-CACHE MEMORY GROWTH ANALYSIS")
|
||||
print("=" * 80)
|
||||
|
||||
# Model size comparison
|
||||
print("\n--- Model Size Comparison (fp16) ---\n")
|
||||
comparisons = compare_model_sizes()
|
||||
header = f"{'Model':<20} {'Params(GB)':>10} {'KV@1K':>10} {'KV@8K':>10} {'KV@32K':>10} {'MaxCtx(H100)':>12}"
|
||||
print(header)
|
||||
print("-" * len(header))
|
||||
for name, data in comparisons.items():
|
||||
print(f"{name:<20} {data['params_gb']:>10.1f} {data['kv_1k_gb']:>10.2f} "
|
||||
f"{data['kv_8k_gb']:>10.2f} {data['kv_32k_gb']:>10.2f} "
|
||||
f"{data['max_context_H100']:>12,d}")
|
||||
|
||||
# Growth analysis for a 7B model
|
||||
print("\n\n--- Detailed Growth: 7B Model (batch=1, fp16) ---\n")
|
||||
spec_7b = ModelSpec(num_layers=32, dim=4096, num_heads=32, head_dim=128)
|
||||
model_mem = compute_model_memory(spec_7b, np.float16)
|
||||
|
||||
seq_lens = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
|
||||
print(f"{'Seq Len':>10} {'KV Cache(GB)':>14} {'Total(GB)':>12} {'KV Fraction':>12}")
|
||||
print("-" * 52)
|
||||
for sl in seq_lens:
|
||||
kv = compute_kv_cache_memory(1, sl, spec_7b, np.float16)
|
||||
total = kv["total_gb"] + model_mem["total_params_gb"]
|
||||
frac = kv["total_gb"] / total
|
||||
print(f"{sl:>10,} {kv['total_gb']:>14.2f} {total:>12.2f} {frac:>12.1%}")
|
||||
|
||||
# Batch size impact
|
||||
print("\n\n--- Batch Size Impact (seq_len=4096, fp16) ---\n")
|
||||
batch_sizes = [1, 2, 4, 8, 16, 32]
|
||||
print(f"{'Batch':>6} {'KV Cache(GB)':>14} {'Growth/Token(MB)':>18}")
|
||||
print("-" * 40)
|
||||
for bs in batch_sizes:
|
||||
kv = compute_kv_cache_memory(bs, 4096, spec_7b, np.float16)
|
||||
print(f"{bs:>6} {kv['total_gb']:>14.2f} {kv['growth_rate_mb_per_token']:>18.4f}")
|
||||
|
||||
# Per-token cost
|
||||
print("\n\n--- Per-Token Memory Cost ---\n")
|
||||
kv_one = compute_kv_cache_memory(1, 1, spec_7b, np.float16)
|
||||
per_token = kv_one["total_bytes"]
|
||||
print(f" Per token (all layers): {per_token:,} bytes = {per_token/1024:.1f} KB")
|
||||
print(f" Per token per layer: {kv_one['per_token_per_layer_bytes']:,} bytes")
|
||||
print(f" At 32K context: {per_token * 32768 / (1024**3):.2f} GB")
|
||||
|
||||
# GPU memory limits
|
||||
print("\n\n--- Maximum Context Lengths by GPU ---\n")
|
||||
gpus = {
|
||||
"RTX 4090": 24,
|
||||
"A100-40GB": 40,
|
||||
"A100-80GB": 80,
|
||||
"H100-80GB": 80,
|
||||
"H100-96GB (SXM)": 96,
|
||||
}
|
||||
print(f"{'GPU':<20} {'Max Context (bs=1)':>20} {'Max Context (bs=4)':>20}")
|
||||
print("-" * 62)
|
||||
for gpu, mem in gpus.items():
|
||||
ctx_1 = find_max_context(spec_7b, mem, batch_size=1)
|
||||
ctx_4 = find_max_context(spec_7b, mem, batch_size=4)
|
||||
print(f"{gpu:<20} {ctx_1:>20,} {ctx_4:>20,}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print_analysis()
|
||||
@@ -0,0 +1,589 @@
|
||||
"""
|
||||
KV-Cache Optimizations
|
||||
|
||||
Implements three major optimization strategies:
|
||||
1. Paged Attention — non-contiguous memory allocation (inspired by vLLM)
|
||||
2. Quantization — reduced precision for cached K/V
|
||||
3. Chunked Prefill — processing long prompts in chunks to limit peak memory
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, Tuple, List, Dict
|
||||
from dataclasses import dataclass, field
|
||||
from kv_cache import CacheConfig
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 1. PAGED ATTENTION
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class PageConfig:
|
||||
"""Configuration for paged KV cache."""
|
||||
block_size: int = 16 # tokens per block
|
||||
num_pages: int = 256 # total pages per sequence
|
||||
batch_size: int = 4
|
||||
num_heads: int = 32
|
||||
head_dim: int = 128
|
||||
dtype: np.dtype = np.float16
|
||||
|
||||
|
||||
class PagedKVCache:
|
||||
"""
|
||||
Paged KV Cache — inspired by vLLM's PagedAttention.
|
||||
|
||||
Instead of a contiguous [batch, heads, max_seq, head_dim] buffer,
|
||||
memory is divided into fixed-size blocks (pages). Each sequence
|
||||
maintains a page table mapping logical block indices to physical pages.
|
||||
|
||||
Benefits:
|
||||
- Zero memory fragmentation: blocks are allocated on demand
|
||||
- Supports speculative decoding and branching
|
||||
- Enables sharing of common prefixes (prefix caching)
|
||||
- No need to pre-allocate max_seq_len
|
||||
|
||||
Memory layout:
|
||||
physical_pages_k: (num_pages * batch_size, num_heads, block_size, head_dim) [for K]
physical_pages_v: same shape [for V]
page_tables: (batch_size, num_pages) — maps logical block -> physical page index
|
||||
"""
|
||||
|
||||
def __init__(self, config: PageConfig):
|
||||
self.config = config
|
||||
self.batch_size = config.batch_size
|
||||
self.num_heads = config.num_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.block_size = config.block_size
|
||||
self.num_pages = config.num_pages
|
||||
self.dtype = config.dtype
|
||||
|
||||
# Physical page pool (shared across all sequences)
|
||||
# Each page holds: (num_heads, block_size, head_dim)
|
||||
page_shape = (config.num_pages * config.batch_size,
|
||||
config.num_heads, config.block_size, config.head_dim)
|
||||
self.physical_pages_k = np.zeros(page_shape, dtype=self.dtype)
|
||||
self.physical_pages_v = np.zeros(page_shape, dtype=self.dtype)
|
||||
|
||||
# Page table per sequence: logical_block_idx -> physical_page_idx
|
||||
max_blocks = config.num_pages
|
||||
self.page_tables = np.full(
|
||||
(config.batch_size, max_blocks), -1, dtype=np.int32
|
||||
)
|
||||
|
||||
# Number of allocated blocks per sequence
|
||||
self.num_blocks = np.zeros(config.batch_size, dtype=np.int32)
|
||||
|
||||
# Free page pool (global, shared)
|
||||
total_pages = config.num_pages * config.batch_size
|
||||
self.free_list = np.arange(total_pages, dtype=np.int32)
|
||||
self.free_ptr = 0 # index into free_list
|
||||
|
||||
def _alloc_page(self) -> int:
|
||||
"""Allocate one physical page from the free pool."""
|
||||
if self.free_ptr >= len(self.free_list):
|
||||
raise MemoryError("Paged KV cache out of memory")
|
||||
page_idx = self.free_list[self.free_ptr]
|
||||
self.free_ptr += 1
|
||||
return page_idx
|
||||
|
||||
def _free_page(self, page_idx: int):
|
||||
"""Return a physical page to the free pool."""
|
||||
self.free_list[self.free_ptr - 1] = page_idx
|
||||
self.free_ptr -= 1
|
||||
|
||||
def reset(self):
|
||||
"""Reset cache for a new generation."""
|
||||
self.physical_pages_k[...] = 0
|
||||
self.physical_pages_v[...] = 0
|
||||
self.page_tables[...] = -1
|
||||
self.num_blocks[...] = 0
|
||||
self.free_ptr = 0
|
||||
|
||||
def append_token(self, batch_idx: int, keys: np.ndarray,
|
||||
values: np.ndarray, logical_block: int,
|
||||
offset_in_block: int):
|
||||
"""
|
||||
Append one token to a specific logical block.
|
||||
|
||||
Args:
|
||||
batch_idx: batch item index
|
||||
keys: (1, num_heads, 1, head_dim)
|
||||
values: (1, num_heads, 1, head_dim)
|
||||
logical_block: which logical block to write to
|
||||
offset_in_block: position within the block (0..block_size-1)
|
||||
"""
|
||||
# Check if physical page is allocated for this logical block
|
||||
phys_page = self.page_tables[batch_idx, logical_block]
|
||||
|
||||
if phys_page == -1:
|
||||
# Allocate new physical page
|
||||
phys_page = self._alloc_page()
|
||||
self.page_tables[batch_idx, logical_block] = phys_page
|
||||
if logical_block + 1 > self.num_blocks[batch_idx]:
|
||||
self.num_blocks[batch_idx] = logical_block + 1
|
||||
|
||||
# Write to physical page
|
||||
self.physical_pages_k[phys_page, :, offset_in_block, :] = keys[0, :, 0, :]
|
||||
self.physical_pages_v[phys_page, :, offset_in_block, :] = values[0, :, 0, :]
|
||||
|
||||
def get_sequence(self, batch_idx: int,
|
||||
start_block: int = 0,
|
||||
end_block: int = None) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Retrieve K and V for a sequence, gathering from physical pages.
|
||||
|
||||
Returns:
|
||||
k: (num_heads, total_tokens, head_dim)
|
||||
v: (num_heads, total_tokens, head_dim)
|
||||
"""
|
||||
if end_block is None:
|
||||
end_block = self.num_blocks[batch_idx]
|
||||
|
||||
blocks = end_block - start_block
|
||||
total_tokens = blocks * self.block_size
|
||||
|
||||
k_out = np.zeros(
|
||||
(self.num_heads, total_tokens, self.head_dim), dtype=self.dtype
|
||||
)
|
||||
v_out = np.zeros(
|
||||
(self.num_heads, total_tokens, self.head_dim), dtype=self.dtype
|
||||
)
|
||||
|
||||
for i in range(start_block, end_block):
|
||||
phys_page = self.page_tables[batch_idx, i]
|
||||
if phys_page == -1:
|
||||
break
|
||||
block_idx = i - start_block
|
||||
token_start = block_idx * self.block_size
|
||||
token_end = token_start + self.block_size
|
||||
k_out[:, token_start:token_end, :] = self.physical_pages_k[phys_page]
|
||||
v_out[:, token_start:token_end, :] = self.physical_pages_v[phys_page]
|
||||
|
||||
return k_out, v_out
|
||||
|
||||
def get_sequence_contiguous(self, batch_idx: int,
|
||||
num_tokens: int = None) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Get K, V as contiguous arrays for attention computation.
|
||||
|
||||
Returns:
|
||||
k: (1, num_heads, num_tokens, head_dim)
|
||||
v: (1, num_heads, num_tokens, head_dim)
|
||||
"""
|
||||
if num_tokens is None:
|
||||
num_tokens = self.num_blocks[batch_idx] * self.block_size
|
||||
|
||||
k, v = self.get_sequence(batch_idx)
|
||||
# k: (num_heads, num_tokens, head_dim) -> (1, num_heads, num_tokens, head_dim)
|
||||
return k[None, ...], v[None, ...]
|
||||
|
||||
@property
|
||||
def memory_allocated_bytes(self) -> int:
|
||||
elem_bytes = np.dtype(self.dtype).itemsize
|
||||
total_pages = self.num_pages * self.batch_size
|
||||
page_bytes = self.num_heads * self.block_size * self.head_dim * elem_bytes
|
||||
return 2 * total_pages * page_bytes # K + V
|
||||
|
||||
@property
|
||||
def memory_used_bytes(self) -> int:
|
||||
"""Bytes actually used (allocated blocks only)."""
|
||||
elem_bytes = np.dtype(self.dtype).itemsize
|
||||
total_used_blocks = np.sum(self.num_blocks)
|
||||
page_bytes = self.num_heads * self.block_size * self.head_dim * elem_bytes
|
||||
return 2 * total_used_blocks * page_bytes
|
||||
|
||||
def memory_utilization(self) -> float:
|
||||
"""Fraction of allocated memory actually used."""
|
||||
alloc = self.memory_allocated_bytes
|
||||
if alloc == 0:
|
||||
return 0.0
|
||||
return self.memory_used_bytes / alloc
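

# Usage sketch: write a handful of tokens through the page table and note
# that only the touched blocks count as "used" memory. Shapes follow the
# append_token() contract; the tiny sizes are illustrative.
def _example_paged_cache():
    cfg = PageConfig(block_size=4, num_pages=8, batch_size=1,
                     num_heads=2, head_dim=8)
    cache = PagedKVCache(cfg)
    for t in range(6):  # 6 tokens -> logical blocks 0 and 1
        k = np.random.randn(1, 2, 1, 8).astype(cfg.dtype)
        v = np.random.randn(1, 2, 1, 8).astype(cfg.dtype)
        cache.append_token(batch_idx=0, keys=k, values=v,
                           logical_block=t // cfg.block_size,
                           offset_in_block=t % cfg.block_size)
    k_seq, v_seq = cache.get_sequence(0)   # (2 heads, 2 blocks * 4 tokens, 8)
    assert k_seq.shape == (2, 8, 8)
    assert cache.num_blocks[0] == 2
    return cache.memory_utilization()      # 2 of 8 pages used -> 0.25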
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 2. QUANTIZED KV CACHE
|
||||
# =============================================================================
|
||||
|
||||
class QuantizedKVCache:
|
||||
"""
|
||||
Quantized KV Cache — stores K and V in reduced precision.
|
||||
|
||||
Strategy: per-token (per head) int8 quantization over the head_dim channels.
- Each (batch, head, position) vector gets its own scale and zero-point
- Dequantize on-the-fly during attention computation

Memory savings: float16 (16-bit) -> int8 (8-bit) = 2x reduction on the data,
plus metadata overhead for the fp16 scales and zero-points.

For head_dim=128:
- Original: 128 * 16 = 2048 bits per token per head
- Quantized data: 128 * 8 = 1024 bits per token per head
- Metadata: one fp16 scale + one fp16 zero-point per token per head (32 bits)
  when stored compactly; this reference implementation broadcasts them to
  full head_dim width for simplicity (see memory_savings_vs_fp16)
- Net savings with compact metadata: ~50%
|
||||
"""
|
||||
|
||||
def __init__(self, batch_size: int, num_heads: int, head_dim: int,
|
||||
max_seq_len: int, dtype=np.float16):
|
||||
self.batch_size = batch_size
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.max_seq_len = max_seq_len
|
||||
self.dtype = dtype
|
||||
self.write_pos = 0
|
||||
|
||||
# Quantized storage: int8
|
||||
shape = (batch_size, num_heads, max_seq_len, head_dim)
|
||||
self.cache_k_int8 = np.zeros(shape, dtype=np.int8)
|
||||
self.cache_v_int8 = np.zeros(shape, dtype=np.int8)
|
||||
|
||||
# Per-channel scales and zero-points per position
|
||||
scale_shape = (batch_size, num_heads, max_seq_len, head_dim)
|
||||
self.k_scales = np.ones(scale_shape, dtype=dtype)
|
||||
self.k_zeros = np.zeros(scale_shape, dtype=dtype)
|
||||
self.v_scales = np.ones(scale_shape, dtype=dtype)
|
||||
self.v_zeros = np.zeros(scale_shape, dtype=dtype)
|
||||
|
||||
def reset(self):
|
||||
self.cache_k_int8[...] = 0
|
||||
self.cache_v_int8[...] = 0
|
||||
self.k_scales[...] = 1.0
|
||||
self.k_zeros[...] = 0.0
|
||||
self.v_scales[...] = 1.0
|
||||
self.v_zeros[...] = 0.0
|
||||
self.write_pos = 0
|
||||
|
||||
def _quantize(self, x: np.ndarray, axis: int = -1) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Quantize to int8 with a per-token affine transform: x ≈ scale * q + zero.
|
||||
|
||||
Returns quantized values, scales, and zero-points.
|
||||
"""
|
||||
x_f = x.astype(np.float32)
|
||||
# Min/max over the head_dim channels (one scale per token per head)
|
||||
x_min = np.min(x_f, axis=axis, keepdims=True)
|
||||
x_max = np.max(x_f, axis=axis, keepdims=True)
|
||||
|
||||
# Avoid division by zero
|
||||
x_range = x_max - x_min
|
||||
x_range = np.where(x_range < 1e-6, 1.0, x_range)
|
||||
|
||||
# Affine map x ≈ scale * q + zero with q in [-128, 127]; putting the
# zero-point at the range midpoint keeps q symmetric, so the existing
# dequantization (scale * q + zero) applies unchanged.
scale = x_range / 255.0
zero = (x_min + x_max) / 2.0  # zero-point at the range midpoint

# Quantize (clip before the int8 cast to avoid wrap-around)
x_centered = x_f - zero
x_quant = np.clip(np.round(x_centered / scale), -128, 127).astype(np.int8)
|
||||
|
||||
return x_quant, scale.astype(self.dtype), zero.astype(self.dtype)
|
||||
|
||||
def _dequantize(self, x_int8: np.ndarray, scale: np.ndarray,
|
||||
zero: np.ndarray) -> np.ndarray:
|
||||
"""Dequantize int8 back to float: x = scale * q + zero."""
|
||||
return (x_int8.astype(np.float32) * scale + zero).astype(self.dtype)
|
||||
|
||||
def update(self, keys: np.ndarray, values: np.ndarray,
|
||||
seqlen_offset: int = None):
|
||||
"""
|
||||
Quantize and store K, V.
|
||||
|
||||
Args:
|
||||
keys: (batch, heads, 1, head_dim)
|
||||
values: (batch, heads, 1, head_dim)
|
||||
"""
|
||||
if seqlen_offset is None:
|
||||
seqlen_offset = self.write_pos
|
||||
|
||||
pos = seqlen_offset
|
||||
|
||||
# Quantize K
|
||||
k_q, k_s, k_z = self._quantize(keys, axis=-1)
|
||||
self.cache_k_int8[:, :, pos, :] = k_q[:, :, 0, :]
|
||||
self.k_scales[:, :, pos:pos+1, :] = k_s
|
||||
self.k_zeros[:, :, pos:pos+1, :] = k_z
|
||||
|
||||
# Quantize V
|
||||
v_q, v_s, v_z = self._quantize(values, axis=-1)
|
||||
self.cache_v_int8[:, :, pos, :] = v_q[:, :, 0, :]
|
||||
self.v_scales[:, :, pos:pos+1, :] = v_s
|
||||
self.v_zeros[:, :, pos:pos+1, :] = v_z
|
||||
|
||||
self.write_pos = pos + 1
|
||||
|
||||
def get(self, start: int = 0, end: int = None) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Get dequantized K, V."""
|
||||
if end is None:
|
||||
end = self.write_pos
|
||||
|
||||
k_int = self.cache_k_int8[:, :, start:end, :]
|
||||
v_int = self.cache_v_int8[:, :, start:end, :]
|
||||
|
||||
# Dequantize using scales and zero-points from each position
|
||||
k_deq = self._dequantize(k_int, self.k_scales[:, :, start:end, :],
|
||||
self.k_zeros[:, :, start:end, :])
|
||||
v_deq = self._dequantize(v_int, self.v_scales[:, :, start:end, :],
|
||||
self.v_zeros[:, :, start:end, :])
|
||||
|
||||
return k_deq, v_deq
|
||||
|
||||
@property
|
||||
def memory_allocated_bytes(self) -> int:
|
||||
"""Total allocated memory including quantization metadata.
|
||||
|
||||
Includes: int8 K + int8 V + fp scales (K+V) + fp zero-points (K+V)
|
||||
"""
|
||||
elem_int8 = np.dtype(np.int8).itemsize
|
||||
elem_fp = np.dtype(self.dtype).itemsize
|
||||
n = self.batch_size * self.num_heads * self.max_seq_len * self.head_dim
|
||||
k_v_bytes = 2 * n * elem_int8 # int8 K + V
|
||||
meta_bytes = 4 * n * elem_fp # scales + zeros for K and V
|
||||
return k_v_bytes + meta_bytes
|
||||
|
||||
@property
|
||||
def memory_savings_vs_fp16(self) -> float:
|
||||
"""Fraction of memory saved vs. full fp16 cache.
|
||||
|
||||
Note: with per-position scales in fp32, this may be negative.
|
||||
For real savings, use fp16 scales or shared (per-channel) scales.
|
||||
"""
|
||||
elem_fp16 = np.dtype(np.float16).itemsize
|
||||
fp16_bytes = 2 * self.batch_size * self.num_heads * self.max_seq_len * self.head_dim * elem_fp16
|
||||
return 1.0 - self.memory_allocated_bytes / fp16_bytes
|
||||
|
||||
@property
|
||||
def memory_savings_vs_fp32(self) -> float:
|
||||
"""Fraction of memory saved vs. full fp32 cache."""
|
||||
elem_fp32 = np.dtype(np.float32).itemsize
|
||||
fp32_bytes = 2 * self.batch_size * self.num_heads * self.max_seq_len * self.head_dim * elem_fp32
|
||||
return 1.0 - self.memory_allocated_bytes / fp32_bytes
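

# Usage sketch: round-trip one decode step of K/V through the int8 cache and
# check the dequantized values stay close to the originals. The tolerance is
# a loose bound on the per-token quantization step (range / 255).
def _example_quantized_roundtrip():
    cache = QuantizedKVCache(batch_size=1, num_heads=2, head_dim=16,
                             max_seq_len=8)
    k = np.random.randn(1, 2, 1, 16).astype(np.float16)
    v = np.random.randn(1, 2, 1, 16).astype(np.float16)
    cache.update(k, v)
    k_deq, v_deq = cache.get()
    err = np.max(np.abs(k_deq.astype(np.float32) - k.astype(np.float32)))
    assert err < 0.1
    return k_deq, v_deq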
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 3. CHUNKED PREFILL
|
||||
# =============================================================================
|
||||
|
||||
class ChunkedPrefill:
|
||||
"""
|
||||
Chunked Prefill — process long prompts in chunks to limit peak memory.
|
||||
|
||||
During prefill with very long prompts (e.g., 32K tokens), computing
|
||||
full attention O(n²) requires materializing a (n, n) attention matrix,
|
||||
which can exceed GPU memory.
|
||||
|
||||
Chunked prefill processes the prompt in chunks of size C:
|
||||
- Chunk 0: tokens [0, C) — full causal attention within chunk
|
||||
- Chunk 1: tokens [C, 2C) — attend to all previous tokens + causal within chunk
|
||||
- ...
|
||||
|
||||
Each chunk's attention costs O(C * (i*C + C)) = O(i*C²) compute, but the peak
memory for the attention matrix is O(C * n) (the last chunk attends to all
previous tokens) instead of O(n²).
|
||||
|
||||
The KV cache is updated incrementally after each chunk.
|
||||
"""
|
||||
|
||||
def __init__(self, chunk_size: int = 512):
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
def compute_attention_chunked(
|
||||
self,
|
||||
q_all: np.ndarray,
|
||||
k_all: np.ndarray,
|
||||
v_all: np.ndarray,
|
||||
scale: float,
|
||||
dtype=np.float32,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute causal attention in chunks.
|
||||
|
||||
Args:
|
||||
q_all: (batch, heads, seq, head_dim)
|
||||
k_all: (batch, heads, seq, head_dim)
|
||||
v_all: (batch, heads, seq, head_dim)
|
||||
scale: 1 / sqrt(head_dim)
|
||||
|
||||
Returns:
|
||||
output: (batch, heads, seq, head_dim)
|
||||
"""
|
||||
batch, heads, seq, head_dim = q_all.shape
|
||||
output = np.zeros((batch, heads, seq, head_dim), dtype=dtype)
|
||||
|
||||
num_chunks = (seq + self.chunk_size - 1) // self.chunk_size
|
||||
|
||||
for chunk_idx in range(num_chunks):
|
||||
start = chunk_idx * self.chunk_size
|
||||
end = min(start + self.chunk_size, seq)
|
||||
chunk_len = end - start
|
||||
|
||||
# Current chunk's Q
|
||||
q_chunk = q_all[:, :, start:end, :] # (batch, heads, chunk_len, head_dim)
|
||||
|
||||
# Keys and values up to current position (causal)
|
||||
k_prefix = k_all[:, :, :end, :] # (batch, heads, end, head_dim)
|
||||
v_prefix = v_all[:, :, :end, :]
|
||||
|
||||
q_f = q_chunk.astype(dtype)
|
||||
k_f = k_prefix.astype(dtype)
|
||||
v_f = v_prefix.astype(dtype)
|
||||
|
||||
# Q @ K^T: (batch, heads, chunk_len, end)
|
||||
scores = np.einsum("bhqd,bhkd->bhqk", q_f, k_f) * scale
|
||||
|
||||
# Causal mask: query at position p can only attend to keys at position <= p
|
||||
# Query positions (absolute): start..end-1
|
||||
# Key positions (absolute): 0..end-1
|
||||
q_positions = np.arange(start, end) # (chunk_len,)
|
||||
k_positions = np.arange(end) # (end,)
|
||||
# Allowed: q_pos >= k_pos (causal)
|
||||
causal_mask = q_positions[:, None] >= k_positions[None, :]   # (chunk_len, end) boolean
causal_mask = np.where(causal_mask, 0.0, -np.inf).astype(dtype)
|
||||
|
||||
scores = scores + causal_mask[None, None, :, :]
|
||||
|
||||
# Softmax
|
||||
attn_weights = self._softmax_stable(scores, axis=-1)
|
||||
|
||||
# Attn @ V
|
||||
chunk_output = np.einsum("bhqk,bhkd->bhqd", attn_weights, v_f)
|
||||
output[:, :, start:end, :] = chunk_output
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def _softmax_stable(x: np.ndarray, axis: int = -1) -> np.ndarray:
|
||||
x_max = np.max(x, axis=axis, keepdims=True)
|
||||
exp_x = np.exp(x - x_max)
|
||||
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
|
||||
|
||||
@staticmethod
|
||||
def peak_memory_comparison(seq_len: int, chunk_size: int,
|
||||
head_dim: int = 128) -> dict:
|
||||
"""
|
||||
Compare peak memory usage between full and chunked prefill.
|
||||
|
||||
The dominant memory is the attention score matrix.
|
||||
"""
|
||||
# Full prefill: attention matrix is (seq_len, seq_len) in float32
|
||||
full_attention_bytes = seq_len * seq_len * 4 # float32
|
||||
|
||||
# Chunked prefill: attention matrix is (chunk_size, seq_len) at most
|
||||
# The last chunk sees all previous tokens
|
||||
max_chunk_attention = chunk_size * seq_len * 4
|
||||
|
||||
return {
|
||||
"seq_len": seq_len,
|
||||
"chunk_size": chunk_size,
|
||||
"full_attention_mb": full_attention_bytes / (1024 * 1024),
|
||||
"chunked_peak_attention_mb": max_chunk_attention / (1024 * 1024),
|
||||
"savings_ratio": full_attention_bytes / max(chunk_chunk_attention := chunk_size * seq_len * 4, 1),
|
||||
}
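

# Usage sketch: chunked prefill is numerically equivalent to full causal
# attention; running the same inputs with chunk_size equal to the sequence
# length (one big chunk) and with chunk_size=4 should agree to float precision.
def _example_chunked_matches_full():
    rng = np.random.default_rng(0)
    q = rng.standard_normal((1, 2, 16, 8)).astype(np.float32)
    k = rng.standard_normal((1, 2, 16, 8)).astype(np.float32)
    v = rng.standard_normal((1, 2, 16, 8)).astype(np.float32)
    scale = 1.0 / np.sqrt(8)
    full = ChunkedPrefill(chunk_size=16).compute_attention_chunked(q, k, v, scale)
    chunked = ChunkedPrefill(chunk_size=4).compute_attention_chunked(q, k, v, scale)
    assert np.allclose(full, chunked, atol=1e-6)
    return chunked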
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 4. HYBRID: PAGED + QUANTIZED
|
||||
# =============================================================================
|
||||
|
||||
class HybridKVCache:
|
||||
"""
|
||||
Combines paged attention with quantization for maximum memory efficiency.
|
||||
|
||||
- Paged allocation eliminates fragmentation
|
||||
- Quantization reduces per-token storage by ~50%
|
||||
- Together: can handle 2-4x longer contexts in the same memory

Note: this reference class keeps a paged cache and a quantized cache side by
side to expose both interfaces; a production hybrid would store the int8
data inside the pages themselves.
|
||||
"""
|
||||
|
||||
def __init__(self, page_config: PageConfig):
|
||||
self.page_config = page_config
|
||||
self.paged = PagedKVCache(page_config)
|
||||
self.quantized = QuantizedKVCache(
|
||||
batch_size=page_config.batch_size,
|
||||
num_heads=page_config.num_heads,
|
||||
head_dim=page_config.head_dim,
|
||||
max_seq_len=page_config.num_pages * page_config.block_size,
|
||||
dtype=page_config.dtype,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
self.paged.reset()
|
||||
self.quantized.reset()
|
||||
|
||||
@property
|
||||
def total_memory_saved(self) -> float:
|
||||
"""Combined memory savings vs. naive contiguous fp16 cache."""
|
||||
return self.quantized.memory_savings_vs_fp16
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# COMPARISON ANALYSIS
|
||||
# =============================================================================
|
||||
|
||||
def compare_strategies(batch_size: int = 4, num_heads: int = 32,
|
||||
head_dim: int = 128, max_seq_len: int = 4096,
|
||||
num_layers: int = 32) -> Dict[str, dict]:
|
||||
"""
|
||||
Compare memory usage across different KV-cache strategies.
|
||||
"""
|
||||
elem_fp16 = 2 # bytes per float16 element
|
||||
elem_fp32 = 4
|
||||
elem_int8 = 1
|
||||
|
||||
base_tokens = batch_size * num_heads * max_seq_len * head_dim
|
||||
base_bytes_per_layer = 2 * base_tokens * elem_fp16 # K + V
|
||||
|
||||
results = {}
|
||||
|
||||
# 1. Naive contiguous fp16
|
||||
results["naive_fp16"] = {
|
||||
"description": "Contiguous fp16 cache",
|
||||
"per_layer_mb": base_bytes_per_layer / (1024 * 1024),
|
||||
"total_mb": base_bytes_per_layer * num_layers / (1024 * 1024),
|
||||
"per_token_per_layer_bytes": 2 * num_heads * head_dim * elem_fp16,
|
||||
}
|
||||
|
||||
# 2. Contiguous fp32
|
||||
base_bytes_fp32 = 2 * base_tokens * elem_fp32
|
||||
results["naive_fp32"] = {
|
||||
"description": "Contiguous fp32 cache",
|
||||
"per_layer_mb": base_bytes_fp32 / (1024 * 1024),
|
||||
"total_mb": base_bytes_fp32 * num_layers / (1024 * 1024),
|
||||
"per_token_per_layer_bytes": 2 * num_heads * head_dim * elem_fp32,
|
||||
}
|
||||
|
||||
# 3. Quantized int8 (with fp16 scales)
|
||||
# Per-token: int8 data + shared fp16 scales per channel
|
||||
quant_data = base_tokens * elem_int8 * 2 # K + V int8
|
||||
quant_scales = batch_size * num_heads * head_dim * elem_fp16 * 2 # shared scales
|
||||
quant_total = quant_data + quant_scales
|
||||
results["quantized_int8"] = {
|
||||
"description": "Int8 quantized with fp16 scales",
|
||||
"per_layer_mb": quant_total / (1024 * 1024),
|
||||
"total_mb": quant_total * num_layers / (1024 * 1024),
|
||||
"savings_vs_fp16": 1.0 - quant_total / base_bytes_per_layer,
|
||||
}
|
||||
|
||||
# 4. Paged (no fragmentation waste)
|
||||
block_size = 16
|
||||
blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
# Paged has slight overhead from block alignment
|
||||
padded_tokens = batch_size * blocks_per_seq * block_size * num_heads * head_dim
|
||||
paged_bytes = 2 * padded_tokens * elem_fp16
|
||||
results["paged"] = {
|
||||
"description": "Paged attention (block_size=16)",
|
||||
"per_layer_mb": paged_bytes / (1024 * 1024),
|
||||
"total_mb": paged_bytes * num_layers / (1024 * 1024),
|
||||
"overhead_vs_naive": paged_bytes / base_bytes_per_layer,
|
||||
}
|
||||
|
||||
# 5. Paged + Quantized
|
||||
paged_quant_data = padded_tokens * elem_int8 * 2
|
||||
paged_quant_scales = batch_size * num_heads * head_dim * elem_fp16 * 2
|
||||
paged_quant_total = paged_quant_data + paged_quant_scales
|
||||
results["paged_quantized"] = {
|
||||
"description": "Paged + int8 quantized",
|
||||
"per_layer_mb": paged_quant_total / (1024 * 1024),
|
||||
"total_mb": paged_quant_total * num_layers / (1024 * 1024),
|
||||
"savings_vs_fp16": 1.0 - paged_quant_total / base_bytes_per_layer,
|
||||
}
|
||||
|
||||
return results
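

# Usage sketch: for the default 4096-token, 32-head configuration above, int8
# quantization roughly halves per-layer KV memory relative to fp16, while the
# paged layout only adds block-alignment overhead. Thresholds are loose.
def _example_strategy_comparison():
    results = compare_strategies()
    assert results["quantized_int8"]["savings_vs_fp16"] > 0.45
    assert results["paged"]["overhead_vs_naive"] >= 1.0
    return results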
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,397 @@
|
||||
"""
|
||||
Transformer Layer with KV-Cache Integration
|
||||
|
||||
Implements a complete decoder transformer layer that:
|
||||
- Computes Q, K, V projections
|
||||
- Stores K, V in the cache
|
||||
- Performs cached attention
|
||||
- Applies MLP with residual connections and layer norm
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, Tuple, List
|
||||
from kv_cache import KVCache, CacheConfig, BatchedKVCache
|
||||
from attention import (
|
||||
cached_attention,
|
||||
cached_attention_with_mask,
|
||||
prompt_attention,
|
||||
)
|
||||
|
||||
|
||||
class Linear:
|
||||
"""Simple linear layer (no framework)."""
|
||||
|
||||
def __init__(self, in_features: int, out_features: int,
|
||||
dtype=np.float32, seed: int = None):
|
||||
if seed is not None:
|
||||
np.random.seed(seed)
|
||||
# Kaiming initialization
|
||||
scale = np.sqrt(2.0 / in_features)
|
||||
self.weight = np.random.randn(out_features, in_features).astype(dtype) * scale
|
||||
self.bias = np.zeros(out_features, dtype=dtype)
|
||||
self.dtype = dtype
|
||||
|
||||
def forward(self, x: np.ndarray) -> np.ndarray:
|
||||
return (x @ self.weight.T + self.bias).astype(self.dtype)
|
||||
|
||||
|
||||
class LayerNorm:
|
||||
"""Layer normalization."""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-5, dtype=np.float32):
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
self.weight = np.ones(dim, dtype=dtype)
|
||||
self.bias = np.zeros(dim, dtype=dtype)
|
||||
self.dtype = dtype
|
||||
|
||||
def forward(self, x: np.ndarray) -> np.ndarray:
|
||||
x_f = x.astype(np.float32)
|
||||
mean = np.mean(x_f, axis=-1, keepdims=True)
|
||||
var = np.var(x_f, axis=-1, keepdims=True)
|
||||
x_norm = (x_f - mean) / np.sqrt(var + self.eps)
|
||||
return (x_norm * self.weight + self.bias).astype(self.dtype)
|
||||
|
||||
|
||||
class MLP:
|
||||
"""Feed-forward network: linear -> activation -> linear."""
|
||||
|
||||
def __init__(self, dim: int, hidden_dim: int, dtype=np.float32, seed: int = None):
|
||||
self.fc1 = Linear(dim, hidden_dim, dtype=dtype, seed=seed)
|
||||
self.fc2 = Linear(hidden_dim, dim, dtype=dtype, seed=seed + 1 if seed else None)
|
||||
self.dtype = dtype
|
||||
|
||||
def forward(self, x: np.ndarray) -> np.ndarray:
|
||||
h = self.fc1.forward(x)
|
||||
# GELU approximation
|
||||
h = h * (1 + np.tanh(np.sqrt(2 / np.pi) * (h + 0.044715 * h ** 3))) * 0.5
|
||||
return self.fc2.forward(h)
|
||||
|
||||
|
||||
class TransformerDecoderLayer:
|
||||
"""
|
||||
Single decoder transformer layer with KV-cache support.
|
||||
|
||||
Architecture:
|
||||
x -> LayerNorm -> Self-Attention -> Residual -> LayerNorm -> MLP -> Residual
|
||||
|
||||
Pre-norm variant (used by most modern models).
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, num_heads: int, mlp_hidden: int,
|
||||
dtype=np.float32, seed: int = None):
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = 1.0 / np.sqrt(self.head_dim)
|
||||
self.dtype = dtype
|
||||
|
||||
# Q, K, V projections
|
||||
self.wq = Linear(dim, dim, dtype=dtype, seed=seed)
|
||||
self.wk = Linear(dim, dim, dtype=dtype, seed=seed + 1 if seed else None)
|
||||
self.wv = Linear(dim, dim, dtype=dtype, seed=seed + 2 if seed else None)
|
||||
|
||||
# Output projection
|
||||
self.wo = Linear(dim, dim, dtype=dtype, seed=seed + 3 if seed else None)
|
||||
|
||||
# Normalizations
|
||||
self.norm1 = LayerNorm(dim, dtype=dtype)
|
||||
self.norm2 = LayerNorm(dim, dtype=dtype)
|
||||
|
||||
# MLP
|
||||
self.mlp = MLP(dim, mlp_hidden, dtype=dtype, seed=seed + 4 if seed else None)
|
||||
|
||||
def _to_heads(self, x: np.ndarray) -> np.ndarray:
|
||||
"""Reshape (batch, seq, dim) -> (batch, seq, heads, head_dim)."""
|
||||
batch, seq, _ = x.shape
|
||||
return x.reshape(batch, seq, self.num_heads, self.head_dim)
|
||||
|
||||
def _from_heads(self, x: np.ndarray) -> np.ndarray:
|
||||
"""Reshape (batch, seq, heads, head_dim) -> (batch, seq, dim)."""
|
||||
batch, seq, _, _ = x.shape
|
||||
return x.reshape(batch, seq, self.dim)
|
||||
|
||||
def forward_prefill(
|
||||
self,
|
||||
x: np.ndarray,
|
||||
cache: KVCache,
|
||||
lengths: Optional[np.ndarray] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Process the full prompt (prefill phase).
|
||||
|
||||
Args:
|
||||
x: (batch, prompt_len, dim)
|
||||
cache: KVCache to populate with K, V
|
||||
lengths: optional per-batch-item prompt lengths
|
||||
|
||||
Returns:
|
||||
output: (batch, prompt_len, dim)
|
||||
"""
|
||||
batch, seq_len, _ = x.shape
|
||||
|
||||
# Self-attention with residual
|
||||
residual = x
|
||||
x_norm = self.norm1.forward(x)
|
||||
|
||||
# Project to Q, K, V
|
||||
q = self.wq.forward(x_norm) # (batch, seq, dim)
|
||||
k = self.wk.forward(x_norm)
|
||||
v = self.wv.forward(x_norm)
|
||||
|
||||
# Reshape to multi-head
|
||||
q = self._to_heads(q).transpose(0, 2, 1, 3) # (batch, heads, seq, head_dim)
|
||||
k = self._to_heads(k).transpose(0, 2, 1, 3)
|
||||
v = self._to_heads(v).transpose(0, 2, 1, 3)
|
||||
|
||||
# Cached attention (stores K, V in cache)
|
||||
attn_out, _, _ = prompt_attention(
|
||||
q, k, v, cache, self.scale, lengths=lengths
|
||||
)
|
||||
# (batch, heads, seq, head_dim)
|
||||
|
||||
# Reshape and project output
|
||||
attn_out = attn_out.transpose(0, 2, 1, 3) # (batch, seq, heads, head_dim)
|
||||
attn_out = self._from_heads(attn_out) # (batch, seq, dim)
|
||||
attn_out = self.wo.forward(attn_out)
|
||||
|
||||
x = residual + attn_out
|
||||
|
||||
# MLP with residual
|
||||
residual = x
|
||||
x_norm = self.norm2.forward(x)
|
||||
mlp_out = self.mlp.forward(x_norm)
|
||||
x = residual + mlp_out
|
||||
|
||||
return x
|
||||
|
||||
def forward_generate(
|
||||
self,
|
||||
x: np.ndarray,
|
||||
cache: KVCache,
|
||||
lengths: Optional[np.ndarray] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Process one token (generation phase).
|
||||
|
||||
Args:
|
||||
x: (batch, 1, dim) — single token
|
||||
cache: KVCache with previous K, V
|
||||
lengths: optional per-batch-item sequence lengths
|
||||
|
||||
Returns:
|
||||
output: (batch, 1, dim)
|
||||
"""
|
||||
# Self-attention with residual
|
||||
residual = x
|
||||
x_norm = self.norm1.forward(x)
|
||||
|
||||
# Project to Q, K, V
|
||||
q = self.wq.forward(x_norm) # (batch, 1, dim)
|
||||
k = self.wk.forward(x_norm)
|
||||
v = self.wv.forward(x_norm)
|
||||
|
||||
# Reshape to multi-head
|
||||
q = self._to_heads(q).transpose(0, 2, 1, 3) # (batch, heads, 1, head_dim)
|
||||
k = self._to_heads(k).transpose(0, 2, 1, 3)
|
||||
v = self._to_heads(v).transpose(0, 2, 1, 3)
|
||||
|
||||
# Store K, V in cache
|
||||
cache.update(k, v)
|
||||
|
||||
# Cached attention
|
||||
if lengths is not None:
|
||||
attn_out = cached_attention_with_mask(
|
||||
q, cache, self.scale, lengths=lengths
|
||||
)
|
||||
else:
|
||||
attn_out = cached_attention(q, cache, self.scale)
|
||||
# (batch, heads, 1, head_dim)
|
||||
|
||||
# Reshape and project output
|
||||
attn_out = attn_out.transpose(0, 2, 1, 3) # (batch, 1, heads, head_dim)
|
||||
attn_out = self._from_heads(attn_out) # (batch, 1, dim)
|
||||
attn_out = self.wo.forward(attn_out)
|
||||
|
||||
x = residual + attn_out
|
||||
|
||||
# MLP with residual
|
||||
residual = x
|
||||
x_norm = self.norm2.forward(x)
|
||||
mlp_out = self.mlp.forward(x_norm)
|
||||
x = residual + mlp_out
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TransformerDecoder:
|
||||
"""
|
||||
Full transformer decoder with KV-cache management.
|
||||
|
||||
Orchestrates prefill and generation across all layers.
|
||||
"""
|
||||
|
||||
def __init__(self, num_layers: int, dim: int, num_heads: int,
|
||||
mlp_hidden: int, vocab_size: int, max_seq_len: int,
|
||||
batch_size: int = 1, dtype=np.float32, seed: int = 42):
|
||||
self.num_layers = num_layers
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.vocab_size = vocab_size
|
||||
self.dtype = dtype
|
||||
|
||||
# Embedding
|
||||
self.embedding = np.random.randn(vocab_size, dim).astype(dtype) * 0.02
|
||||
|
||||
# Positional encoding (learnable)
|
||||
self.pos_embedding = np.random.randn(max_seq_len, dim).astype(dtype) * 0.02
|
||||
|
||||
# Layers
|
||||
self.layers = [
|
||||
TransformerDecoderLayer(dim, num_heads, mlp_hidden,
|
||||
dtype=dtype, seed=seed + i * 100)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
|
||||
# Final normalization and LM head
|
||||
self.final_norm = LayerNorm(dim, dtype=dtype)
|
||||
self.lm_head_weight = self.embedding.T # weight tying
|
||||
|
||||
# KV cache
|
||||
cache_config = CacheConfig(
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
head_dim=self.head_dim,
|
||||
max_seq_len=max_seq_len,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.cache = BatchedKVCache(num_layers, cache_config)
|
||||
|
||||
def _add_positional_encoding(self, x: np.ndarray, start_pos: int = 0) -> np.ndarray:
|
||||
"""Add positional encoding to input embeddings."""
|
||||
batch, seq, _ = x.shape
|
||||
pos_enc = self.pos_embedding[start_pos:start_pos + seq]
|
||||
return (x + pos_enc[None, :, :]).astype(self.dtype)
|
||||
|
||||
def prefill(self, token_ids: np.ndarray,
|
||||
lengths: Optional[np.ndarray] = None) -> np.ndarray:
|
||||
"""
|
||||
Process the full prompt.
|
||||
|
||||
Args:
|
||||
token_ids: (batch, prompt_len) integer token IDs
|
||||
lengths: optional (batch,) actual lengths per batch item
|
||||
|
||||
Returns:
|
||||
hidden: (batch, prompt_len, dim) — hidden states after all layers
|
||||
"""
|
||||
batch, prompt_len = token_ids.shape
|
||||
|
||||
# Embed + positional encoding
|
||||
x = self.embedding[token_ids] # (batch, prompt_len, dim)
|
||||
x = self._add_positional_encoding(x, start_pos=0)
|
||||
|
||||
# Through all layers
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer.forward_prefill(x, self.cache.caches[i], lengths=lengths)
|
||||
|
||||
return x
|
||||
|
||||
def generate_step(
|
||||
self,
|
||||
token_ids: np.ndarray,
|
||||
lengths: Optional[np.ndarray] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Generate one token.
|
||||
|
||||
Args:
|
||||
token_ids: (batch, 1) — the token to process
|
||||
lengths: optional (batch,) current sequence lengths
|
||||
|
||||
Returns:
|
||||
logits: (batch, vocab_size) — output logits for next token
|
||||
"""
|
||||
batch = token_ids.shape[0]
|
||||
current_pos = self.cache.caches[0].write_pos - 1 # position of this token
|
||||
|
||||
# Embed + positional encoding
|
||||
x = self.embedding[token_ids] # (batch, 1, dim)
|
||||
x = self._add_positional_encoding(x, start_pos=current_pos)
|
||||
|
||||
# Through all layers
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer.forward_generate(x, self.cache.caches[i], lengths=lengths)
|
||||
|
||||
# Final norm + LM head
|
||||
x = self.final_norm.forward(x) # (batch, 1, dim)
|
||||
logits = x @ self.lm_head_weight # (batch, 1, vocab_size)
|
||||
return logits[:, 0, :] # (batch, vocab_size)
|
||||
|
||||
def generate(self, prompt_ids: np.ndarray, num_tokens: int,
|
||||
temperature: float = 1.0, top_k: int = None,
|
||||
lengths: Optional[np.ndarray] = None) -> List[int]:
|
||||
"""
|
||||
Full generation loop.
|
||||
|
||||
Args:
|
||||
prompt_ids: (batch, prompt_len) prompt token IDs
|
||||
num_tokens: number of tokens to generate
|
||||
temperature: sampling temperature
|
||||
top_k: top-k sampling
|
||||
lengths: optional per-batch-item prompt lengths
|
||||
|
||||
Returns:
|
||||
generated_ids: list of (batch,) token arrays
|
||||
"""
|
||||
# Reset cache
|
||||
self.cache.reset()
|
||||
|
||||
# Prefill
|
||||
self.prefill(prompt_ids, lengths=lengths)
|
||||
|
||||
# Get last token from prefill
|
||||
batch = prompt_ids.shape[0]
|
||||
last_tokens = prompt_ids[:, -1:] # (batch, 1)
|
||||
|
||||
# Track current lengths (start from prompt lengths)
|
||||
if lengths is not None:
|
||||
cur_lengths = lengths.copy()
|
||||
else:
|
||||
cur_lengths = np.full(batch, prompt_ids.shape[1], dtype=np.int32)
|
||||
|
||||
generated = []
|
||||
for step in range(num_tokens):
|
||||
logits = self.generate_step(last_tokens, lengths=cur_lengths)
|
||||
|
||||
# Apply temperature
|
||||
logits = logits / temperature
|
||||
|
||||
# Top-k filtering
|
||||
if top_k is not None:
|
||||
top_k_values = np.sort(logits, axis=-1)[:, -top_k:]
mask = logits < top_k_values[:, :1]  # threshold at the k-th largest logit
|
||||
logits = np.where(mask, -np.inf, logits)
|
||||
|
||||
# Softmax + sample
|
||||
probs = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
|
||||
probs = probs / np.sum(probs, axis=-1, keepdims=True)
|
||||
|
||||
# Sample
|
||||
sampled = np.array([
|
||||
np.random.choice(len(probs[b]), p=probs[b] / probs[b].sum())
|
||||
for b in range(batch)
|
||||
])
|
||||
|
||||
generated.append(sampled)
|
||||
last_tokens = sampled[:, None] # (batch, 1)
|
||||
|
||||
# Update lengths
|
||||
cur_lengths = cur_lengths + 1
|
||||
|
||||
return generated
|
||||
|
||||
def memory_report(self) -> dict:
|
||||
"""Get memory usage report."""
|
||||
return self.cache.memory_report()
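

# Usage sketch: run a tiny decoder end-to-end (prefill plus a few generated
# tokens) to exercise the cache plumbing. Sizes are deliberately tiny and the
# outputs are random because the weights are untrained; this relies on the
# prompt_attention / cached_attention helpers imported from attention.py.
def _example_tiny_generation():
    model = TransformerDecoder(num_layers=2, dim=32, num_heads=4,
                               mlp_hidden=64, vocab_size=100,
                               max_seq_len=64, batch_size=1)
    prompt = np.array([[1, 2, 3, 4]])
    generated = model.generate(prompt, num_tokens=5, temperature=1.0)
    assert len(generated) == 5
    return generated, model.memory_report()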
|
||||