# KV-Cache for Autoregressive Transformer Inference

A complete, framework-free implementation of KV-caching for autoregressive
transformer inference, built from scratch in Python/NumPy.

## Architecture Overview

```
┌─────────────────────────────────────────────────────────────┐
│                     Transformer Layer                       │
│                                                             │
│  Token IDs ──► Embedding ──► Q,K,V Projections              │
│                                   │                         │
│                     ┌─────────────┼──────────────┐          │
│                     ▼             ▼              ▼          │
│                  Q_new ──► K_new, V_new ──► Cache Write     │
│                    │                             │          │
│                    │        ┌────────────────────┘          │
│                    ▼        ▼                               │
│              ┌──────────────┐                               │
│              │  Attention   │  Q_new × (K_cached + K_new)   │
│              │ Computation  │  ──────────────────────────►  │
│              │ (read-only)  │  weights × (V_cached + V_new) │
│              └──────┬───────┘                               │
│                     ▼                                       │
│          Output Projection ──► LayerNorm ──► next layer     │
└─────────────────────────────────────────────────────────────┘
```

## Data Structure Layout

### Memory Format

Each layer maintains two pre-allocated tensors:

```
keys:   (B, H, S_max, D)  float32
values: (B, H, S_max, D)  float32
```

| Symbol | Meaning                       | Example (GPT-4 class) |
|--------|-------------------------------|-----------------------|
| B      | Batch size                    | 1–64                  |
| H      | Number of attention heads     | 32                    |
| S_max  | Maximum sequence length       | 8192–131072           |
| D      | Head dimension (d_model / H)  | 128                   |

**Why BHSD layout?**

The dimensions are ordered so that the head dimension (D) is innermost and
contiguous, and consecutive positions along the sequence axis (S) are only
D elements apart. This means:

1. **Append is a simple slice copy** — `cache[b, :, pos, :] = new_kv`
   writes H×D floats as H contiguous runs of D floats each.
2. **Attention matmul is efficient** — the inner `Q @ K^T` reads K along
   the S dimension, whose stride is just D elements.
3. **GPU-friendly** — maps directly to a CUDA tensor with no transposition
   needed between the write and read paths.

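Point 1 can be checked directly in NumPy. A small sketch with toy dimensions (this is the raw layout, not the `KVCache` API):

```python
import numpy as np

B, H, S_max, D = 2, 4, 16, 8
cache_k = np.zeros((B, H, S_max, D), dtype=np.float32)

# Appending one token's keys for batch element 0 is a single slice assignment:
new_k = np.ones((H, D), dtype=np.float32)
cache_k[0, :, 0, :] = new_k

# In BHSD order the last axis (D) is contiguous, and the S axis has stride
# D elements — D * 4 bytes for float32.
stride_s, stride_d = cache_k.strides[2], cache_k.strides[3]
```
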
### Auxiliary State

```
seq_lens: int[B] — valid prefix length per batch element
```

Positions `[..., :seq_lens[b], :]` contain valid data. Everything beyond
is garbage and must be masked out during attention.

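The masking can be expressed as a broadcast comparison against `seq_lens`. A sketch with illustrative variable names (the repo's actual masking lives in the attention code):

```python
import numpy as np

B, S_max = 2, 6
seq_lens = np.array([4, 2])          # valid prefix length per batch element

# valid[b, s] is True only where position s holds real data
valid = np.arange(S_max)[None, :] < seq_lens[:, None]     # (B, S_max)

# Garbage positions get -inf before softmax, so they receive zero weight
scores = np.zeros((B, S_max), dtype=np.float32)
masked = np.where(valid, scores, -np.inf)
```
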
## Update Logic Per Step

### Prefill Phase (processing the full prompt)

```
Input:  prompt tokens of length S
Output: cache filled with S key-value pairs

for each layer:
    Q, K, V = project(prompt_embeddings)     # (B, S, d_model) → 3× (B, S, d_model)
    K = reshape(K, (B, H, S, D))             # split into heads
    V = reshape(V, (B, H, S, D))
    cache.write(positions=[0, 1, ..., S-1], K, V)   # bulk write

    # Self-attention within the prompt (causal mask)
    attn_output = attention(Q, cache.read()) # O(S²) — one-time cost
```

### Decode Phase (one token at a time)

```
Input:  single new token
Output: logits for next-token prediction

for each layer:
    q_new, k_new, v_new = project(token_embedding)   # each (B, 1, d_model)
    k_new = reshape(k_new, (B, H, 1, D))
    v_new = reshape(v_new, (B, H, 1, D))

    # ── CACHE UPDATE: O(H·D) — write 1 token ──
    cache[pos] = (k_new, v_new)       # 2 × H × D floats

    # ── ATTENTION: O(S·H·D) — query vs ALL cached keys ──
    K_all, V_all = cache.read()       # (B, H, S+1, D)
    scores  = q_new @ K_all.T / √D    # (B, H, 1, S+1)
    weights = softmax(scores)
    output  = weights @ V_all         # (B, H, 1, D)
```

**Key insight**: Without caching, each decode step would recompute keys and
values for all S previous tokens and redo the full O(S²) attention. With
caching, a step is only O(S) — the single new query attends against the
already-cached keys/values.

## Attention Computation Using Cached Keys/Values

```
┌───────────┐      ┌───────────────────────────────────┐
│   Q_new   │      │  Cached K (all past tokens)       │
│  (1, D)   │  ×   │  (S_valid, D)                     │
│           │      │                                   │
│           │      │  [k₀] [k₁] [k₂] ... [k_{S-1}]     │
└─────┬─────┘      └────────────────┬──────────────────┘
      │                             │
      ▼                             ▼
     ┌────────────────────────────────────┐
     │  scores  = Q · K^T / √D            │ → (1, S_valid)
     │  weights = softmax(scores)         │ → (1, S_valid)
     │  output  = weights · V             │ → (1, D)
     └────────────────────────────────────┘
```

This computation is performed independently for each of the H heads and each
of the B batch elements.

## Memory Growth Analysis

### Linear Growth

The cache grows **linearly** with sequence length:

```
Memory per layer = 2 × B × H × S × D × sizeof(dtype)
                 = 2 × B × d_model × S × sizeof(dtype)
```

Applying the formula to a GPT-4-class model (32 layers, d_model=4096, FP32,
batch=1), i.e. 2 × 4096 × 4 bytes = 32 KB per token per layer:

| Seq Length | Per Layer (MB) | Total (MB)  | Total (GB) |
|------------|----------------|-------------|------------|
| 128        | 4.19           | 134.22      | 0.134      |
| 1,024      | 33.55          | 1,073.74    | 1.074      |
| 4,096      | 134.22         | 4,294.97    | 4.295      |
| 16,384     | 536.87         | 17,179.87   | 17.180     |
| 65,536     | 2,147.48       | 68,719.48   | 68.719     |
| 131,072    | 4,294.97       | 137,438.95  | 137.439    |

**Observation**: At 128K context with batch=1, FP32 storage needs well over
**100 GB** just for the KV cache — before accounting for model weights,
activations, or gradients.

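The formula is easy to turn into a helper for regenerating the table (a quick sanity-check script; sizes use decimal units, 10⁶ bytes per MB and 10⁹ per GB):

```python
def kv_cache_bytes(seq_len, n_layers=32, d_model=4096, batch=1, bytes_per_elt=4):
    """Total KV-cache size in bytes: K and V (factor 2) for every layer."""
    return 2 * batch * d_model * seq_len * bytes_per_elt * n_layers

gb_128k = kv_cache_bytes(131_072) / 1e9               # full cache, 128K context
mb_1k_per_layer = kv_cache_bytes(1024, n_layers=1) / 1e6   # one layer, 1K context
```
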
### FLOPs Savings

| Scenario                 | Without Cache | With Cache | Speedup |
|--------------------------|---------------|------------|---------|
| 1024 prompt + 100 decode | 4.2e14        | 2.0e12     | ~200×   |

Per decode step the uncached cost is O(S²) while the cached cost is O(S),
so the speedup grows linearly with sequence length.

## Optimizations

### 1. Paged Attention (Virtual Memory for KV Cache)

**Problem**: Pre-allocating `(B, H, S_max, D)` wastes memory for short
sequences and causes fragmentation when sequences finish at different
times.

**Solution**: Divide the cache into fixed-size blocks (pages):

```
Physical Memory:
┌────────┬────────┬────────┬────────┬────────┬────────┐
│ Block 0│ Block 1│ Block 2│ Block 3│ Block 4│  ...   │
│(H,B,D) │(H,B,D) │(H,B,D) │(H,B,D) │(H,B,D) │        │
└────────┴────────┴────────┴────────┴────────┴────────┘

Page Tables:
Seq 0: [0] → [3] → [1]      (3 blocks = 3 × BLOCK_SIZE tokens)
Seq 1: [2] → [4]            (2 blocks = 2 × BLOCK_SIZE tokens)
Seq 2: [5]                  (1 block)
Free:  [6, 7, 8, ...]
```

**Benefits**:
- Memory allocated only as needed (no S_max pre-allocation)
- Finished sequences free blocks immediately → higher throughput
- No external fragmentation
- Enables sharing of KV blocks across sequences (e.g., prefix caching)

**Implementation**: See `PagedKVCache` in `optimizations.py`.

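The page-table bookkeeping behind this picture can be sketched as a toy allocator (hypothetical class and method names — the real interface is `PagedKVCache` in `optimizations.py`):

```python
class ToyPageTable:
    """Maps each sequence to a list of physical block indices."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free = list(range(num_blocks))   # free-list of physical blocks
        self.pages = {}                       # seq_id -> [block, block, ...]

    def append_token(self, seq_id, pos):
        """Return (physical_block, offset) for token `pos`, allocating on demand."""
        blocks = self.pages.setdefault(seq_id, [])
        if pos % self.block_size == 0:        # crossed a block boundary
            blocks.append(self.free.pop(0))
        return blocks[pos // self.block_size], pos % self.block_size

    def release(self, seq_id):
        """Finished sequence: its blocks return to the free-list immediately."""
        self.free.extend(self.pages.pop(seq_id, []))

pt = ToyPageTable(num_blocks=8, block_size=4)
for pos in range(6):                          # sequence 0 grows to 6 tokens
    block, offset = pt.append_token(seq_id=0, pos=pos)

blocks_used = list(pt.pages[0])               # physical blocks backing seq 0
pt.release(0)
free_after = len(pt.free)
```
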
### 2. Chunked Prefill

**Problem**: Processing a 32K-token prompt requires a 32K×32K attention
matrix per head (over a billion floats ≈ 4 GB each in FP32) just for the
prefill.

**Solution**: Split the prompt into chunks of C tokens:

```
Prompt: [t₀, t₁, t₂, ..., t_{S-1}]   (S = 32K)

Chunk 0: [t₀..t_{C-1}]       → cache write → attention vs cache (0..C)
Chunk 1: [t_C..t_{2C-1}]     → cache write → attention vs cache (0..2C)
Chunk 2: [t_{2C}..t_{3C-1}]  → cache write → attention vs cache (0..3C)
...
```

Peak attention memory: O(C × S) instead of O(S²).

**Benefits**:
- Bounded peak memory regardless of prompt length
- Can interleave prefill chunks with decode steps from other sequences
- Better GPU utilization (uniform work items)

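The chunk boundaries themselves are a one-line slicing pattern (a minimal sketch; `prefill_chunks` is a hypothetical helper, not part of the repo):

```python
def prefill_chunks(S, C):
    """Yield (start, end) token ranges for chunked prefill of an S-token prompt."""
    for start in range(0, S, C):
        end = min(start + C, S)
        # chunk [start, end) is written to the cache, then attends to
        # everything cached so far, i.e. positions [0, end)
        yield start, end

chunks = list(prefill_chunks(S=10, C=4))
```
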
### 3. Cache Quantization (INT8 / INT4)

**Problem**: tens to hundreds of gigabytes of KV cache at 128K context is
unsustainable.

**Solution**: Quantize cached K/V to lower precision:

| Precision | Bytes/Element | Memory Savings | Typical Quality Loss |
|-----------|---------------|----------------|----------------------|
| FP32      | 4             | 1× (baseline)  | 0%                   |
| FP16      | 2             | 2×             | <0.1%                |
| INT8      | 1             | 4×             | <0.5%                |
| INT4      | 0.5           | 8×             | 1-3%                 |

Quantization is per-token and symmetric:
`scale[b,h,t] = max(|K[b,h,t,:]|) / (2^(bits-1) - 1)`.

```
Storage:
  k_quant: uint8   (B, H, S, D)  — or packed uint8 (B, H, S, D/2) for INT4
  k_scale: float32 (B, H, S)     — one scalar per token per head

Dequantize during attention:
  K_float = k_quant * k_scale    — in registers before matmul
```

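One common scheme — per-token symmetric INT8, so the signed range is [-127, 127] — round-trips like this in NumPy (a sketch; the repo's `_quantize_token()` may differ in details):

```python
import numpy as np

rng = np.random.default_rng(0)
k = rng.normal(size=(4, 128)).astype(np.float32)   # 4 tokens, D=128, one head

# one scale per token: max(|K|) maps to the top of the int8 range
scale = np.abs(k).max(axis=-1, keepdims=True) / 127.0      # (4, 1)
k_quant = np.round(k / scale).astype(np.int8)

# dequantize just before the attention matmul
k_float = k_quant.astype(np.float32) * scale

# rounding error is bounded by half a quantization step per token
max_err = np.abs(k_float - k).max()
```
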
**Benefits**:
- 4-8× memory reduction → longer contexts or larger batches
- Minimal quality loss for most tasks
- Hardware support on modern GPUs (FP8 on Hopper, INT8 on Ampere)

## GPU Execution Mapping

### Memory Hierarchy

```
┌──────────────────────────────────────────────┐
│         HBM (High Bandwidth Memory)          │
│  ┌──────────────────────────────────────┐    │
│  │ KV Cache: (B, H, S, D) per layer     │    │
│  │ ~10-70 GB for long contexts          │    │
│  └──────────────────────────────────────┘    │
│  ┌──────────────────────────────────────┐    │
│  │ Model Weights                        │    │
│  └──────────────────────────────────────┘    │
└──────────────────────┬───────────────────────┘
                       │ ~2-3 TB/s bandwidth
                       ▼
┌──────────────────────────────────────────────┐
│           Shared Memory (per SM)             │
│  ┌──────────────────────────────────────┐    │
│  │ Q tile:     (block_B, H, tile_S, D)  │    │
│  │ K tile:     (block_B, H, tile_S, D)  │    │
│  │ V tile:     (block_B, H, tile_S, D)  │    │
│  │ Score tile: (block_B, H, tile_S²)    │    │
│  └──────────────────────────────────────┘    │
│  ~48-164 KB per SM                           │
└──────────────────────┬───────────────────────┘
                       │ ~19 TB/s bandwidth
                       ▼
┌──────────────────────────────────────────────┐
│            Registers (per thread)            │
│  accumulator for QK^T, softmax, etc.         │
│  ~255 registers/thread                       │
└──────────────────────────────────────────────┘
```

### Kernel Mapping

| Operation   | CPU (this impl)             | GPU Kernel                                 |
|-------------|-----------------------------|--------------------------------------------|
| Cache write | `cache[b,:,pos,:] = new_kv` | `cudaMemcpyAsync` or block-level scatter   |
| Q×K^T       | `q @ k.T`                   | Batched GEMM (cuBLAS) or FlashAttention    |
| Softmax     | `_softmax(scores)`          | Online softmax (FlashAttention)            |
| Weights×V   | `weights @ v`               | GEMM (part of FlashAttention fused kernel) |
| Quantize    | `_quantize_token()`         | Block-reduce + scale + convert             |

### FlashAttention Integration

The attention computation in this codebase is the naive three-step version:

```
S = Q × K^T      # materialize full (S_q, S_kv) matrix
A = softmax(S)   # another (S_q, S_kv) matrix
O = A × V        # output
```

On GPU, **FlashAttention** fuses these three operations using an online
softmax, rescaling the running sum and accumulator whenever the running
maximum changes:

```
for each tile of Q:
    init: O = 0, m = -∞, l = 0
    for each tile of K, V:
        S_tile = Q_tile × K_tile^T          # in SRAM
        m_new  = max(m, rowmax(S_tile))
        P_tile = exp(S_tile - m_new)        # in SRAM
        α      = exp(m - m_new)             # rescale stats referenced to old m
        l_new  = l * α + rowsum(P_tile)
        O      = O * α + P_tile × V_tile    # accumulate (unnormalized)
        m, l   = m_new, l_new
    O = O / l
```

This keeps the O(S²) score matrix entirely in SRAM, avoiding HBM reads and
writes for the intermediate; the KV cache itself is still read tile-by-tile
from HBM.

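The tiled loop above can be verified against the naive three-step computation in NumPy (single head, illustrative tile size):

```python
import numpy as np

rng = np.random.default_rng(0)
S_q, S_kv, D, T = 4, 8, 16, 4        # T = KV tile size
Q = rng.normal(size=(S_q, D))
K = rng.normal(size=(S_kv, D))
V = rng.normal(size=(S_kv, D))

# Naive reference: full score matrix, softmax, weighted sum
S_full = Q @ K.T / np.sqrt(D)
A = np.exp(S_full - S_full.max(-1, keepdims=True))
ref = (A / A.sum(-1, keepdims=True)) @ V

# Online (FlashAttention-style) accumulation over KV tiles
O = np.zeros((S_q, D))
m = np.full((S_q, 1), -np.inf)       # running row max
l = np.zeros((S_q, 1))               # running softmax denominator
for t0 in range(0, S_kv, T):
    S_tile = Q @ K[t0:t0 + T].T / np.sqrt(D)
    m_new = np.maximum(m, S_tile.max(-1, keepdims=True))
    alpha = np.exp(m - m_new)        # rescale old stats to the new max
    P = np.exp(S_tile - m_new)
    l = l * alpha + P.sum(-1, keepdims=True)
    O = O * alpha + P @ V[t0:t0 + T]
    m = m_new
O = O / l                            # final normalization
```
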
### Paged Attention on GPU

The `PagedKVCache` page table translates to a GPU-side indirection:

```cuda
// CUDA pseudocode for paged attention; D, BLOCK_SIZE, S_MAX_LOCAL, and
// max_pages are compile-time constants, and dot()/load_query() are helpers.
__global__ void paged_attention(
    const float* Q,          // (B, H, 1, D) — new query
    const float* K_pool,     // (num_blocks, H, BLOCK_SIZE, D)
    const float* V_pool,
    const int*   page_table, // (B, max_pages_per_seq)
    const int*   seq_lens,   // (B,)
    float*       output      // (B, H, 1, D)
) {
    int b = blockIdx.y;
    int h = blockIdx.x;
    int S = seq_lens[b];

    // Load query into registers
    float q[D];
    load_query(q, Q, b, h);

    // Iterate over pages, gathering K from scattered physical blocks
    float score[S_MAX_LOCAL];
    int num_pages = (S + BLOCK_SIZE - 1) / BLOCK_SIZE;   // ceil(S / BLOCK_SIZE)
    for (int page = 0; page < num_pages; page++) {
        int phys_block = page_table[b * max_pages + page];
        for (int i = 0; i < BLOCK_SIZE && page * BLOCK_SIZE + i < S; i++) {
            const float* k = &K_pool[((phys_block * H + h) * BLOCK_SIZE + i) * D];
            score[page * BLOCK_SIZE + i] = dot(q, k) / sqrtf((float)D);
        }
    }
    // ... softmax, multiply by V, write output
}
```

## File Structure

```
kv/
├── README.md          ← you are here
├── kv_cache.py        ← core data structures + attention
├── optimizations.py   ← paged attention, chunked prefill, quantization
└── test_kv_cache.py   ← comprehensive test suite
```

## Running

```bash
python test_kv_cache.py
```

All tests run without any external dependencies beyond NumPy.

## Key Design Decisions

1. **Pre-allocation**: The base `KVCache` pre-allocates to `S_max` to
   avoid GPU memory allocation during inference (malloc is expensive).
   The `PagedKVCache` trades this for on-demand block allocation.

2. **No cross-contamination**: Each batch element maintains its own
   valid prefix via `seq_lens`. Attention never attends to garbage
   positions from other sequences.

3. **Separation of concerns**: Cache update (write) and attention
   (read) are decoupled. The caller controls when each happens,
   enabling chunked prefill and prefix sharing.

4. **Quantization at cache boundary**: K/V are computed in FP32,
   quantized on write, dequantized on read. This keeps the attention
   computation unchanged while reducing memory.