# KV-Cache for Autoregressive Transformer Inference
A complete, framework-free implementation of KV-caching for autoregressive
transformer inference, built from scratch in Python/NumPy.
## Architecture Overview
```
┌─────────────────────────────────────────────────────────────┐
│ Transformer Layer │
│ │
│ Token IDs ──► Embedding ──► Q,K,V Projections │
│ │ │
│ ┌─────────────┼──────────────┐ │
│ ▼ ▼ ▼ │
│ Q_new ──► K_new, V_new ──► Cache Write │
│ │ │ │
│ │ ┌─────────────────────┘ │
│ ▼ ▼ │
│ ┌──────────────┐ │
│ │ Attention │ Q_new × (K_cached + K_new) │
│ │ Computation │ ──────────────────────────► │
│ │ (read-only) │ weights × (V_cached + V_new) │
│ └──────┬───────┘ │
│ ▼ │
│ Output Projection ──► LayerNorm ──► next layer │
└─────────────────────────────────────────────────────────────┘
```
## Data Structure Layout
### Memory Format
Each layer maintains two pre-allocated tensors:
```
keys: (B, H, S_max, D) float32
values: (B, H, S_max, D) float32
```
| Symbol | Meaning | Example (GPT-4 class) |
|--------|--------------------------------|----------------------|
| B | Batch size | 1–64 |
| H | Number of attention heads | 32 |
| S_max | Maximum sequence length | 8192–131072 |
| D | Head dimension (d_model / H) | 128 |
**Why BHSD layout?**
The dimensions are ordered so that the sequence axis (S) is stride-D
contiguous. This means:
1. **Append is a simple slice copy** — `cache[b, :, pos, :] = new_kv`
   writes H×D floats, one contiguous run of D floats per head.
2. **Attention matmul is efficient** — the inner `Q @ K^T` reads K along
the S dimension, which is stride-D contiguous.
3. **GPU-friendly** — maps directly to a CUDA tensor with no transposition
needed between the write and read paths.
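As a quick illustration of points 1 and 2, a C-contiguous NumPy array in this layout has exactly these strides (the sizes below are made-up toy values, not the codebase's defaults):

```python
import numpy as np

B, H, S_max, D = 2, 4, 16, 8                      # toy sizes for illustration
cache = np.zeros((B, H, S_max, D), dtype=np.float32)

# Strides in bytes: D is innermost (contiguous), S has stride D elements = D*4 bytes.
print(cache.strides)                              # (2048, 512, 32, 4)

# Appending one token (point 1): H runs of D consecutive floats.
new_kv = np.ones((H, D), dtype=np.float32)
cache[0, :, 3, :] = new_kv
```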
### Auxiliary State
```
seq_lens: int[B] — valid prefix length per batch element
```
Positions `[..., :seq_lens[b], :]` contain valid data. Everything beyond
is garbage and must be masked out during attention.
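Putting the two pieces together, a minimal NumPy sketch of the structure (illustrative names such as `SimpleKVCache`; the actual class in `kv_cache.py` may differ):

```python
import numpy as np

class SimpleKVCache:
    """Pre-allocated BHSD key/value tensors plus a valid-prefix length per sequence."""
    def __init__(self, batch, heads, max_seq, head_dim, dtype=np.float32):
        self.keys = np.zeros((batch, heads, max_seq, head_dim), dtype=dtype)
        self.values = np.zeros((batch, heads, max_seq, head_dim), dtype=dtype)
        self.seq_lens = np.zeros(batch, dtype=np.int64)   # valid prefix per batch element

    def append(self, k_new, v_new):
        """k_new, v_new: (B, H, 1, D) — write one new token per batch element."""
        for b in range(k_new.shape[0]):
            pos = self.seq_lens[b]
            self.keys[b, :, pos, :] = k_new[b, :, 0, :]
            self.values[b, :, pos, :] = v_new[b, :, 0, :]
            self.seq_lens[b] += 1

    def valid(self, b):
        """Only the valid prefix for batch element b: two (H, S_valid, D) views."""
        s = self.seq_lens[b]
        return self.keys[b, :, :s, :], self.values[b, :, :s, :]
```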
## Update Logic Per Step
### Prefill Phase (processing the full prompt)
```
Input: prompt tokens of length S
Output: cache filled with S key-value pairs
for each layer:
Q, K, V = project(prompt_embeddings) # (B, S, d_model) → 3× (B, S, d_model)
K = reshape(K, (B, H, S, D)) # split into heads
V = reshape(V, (B, H, S, D))
cache.write(positions=[0, 1, ..., S-1], K, V) # bulk write
# Self-attention within the prompt (causal mask)
attn_output = attention(Q, cache.read()) # O(S²) — one-time cost
```
### Decode Phase (one token at a time)
```
Input: single new token
Output: logits for next token prediction
for each layer:
q_new, k_new, v_new = project(token_embedding) # each (B, 1, d_model)
k_new = reshape(k_new, (B, H, 1, D))
v_new = reshape(v_new, (B, H, 1, D))
# ── CACHE UPDATE: O(H·D) — write 1 token ──
cache[pos] = (k_new, v_new) # 2 × H × D floats
# ── ATTENTION: O(S·H·D) — query vs ALL cached keys ──
K_all, V_all = cache.read() # (B, H, S+1, D)
scores = q_new @ K_all.T / √D # (B, H, 1, S+1)
weights = softmax(scores)
output = weights @ V_all # (B, H, 1, D)
```
**Key insight**: Without caching, each decode step would require O(S²) work
(recomputing attention for all S previous tokens). With caching, it's only
O(S) — the new query attends against the cached keys/values.
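Continuing the illustrative `SimpleKVCache` sketch from above (not the exact code in `kv_cache.py`), one decode step looks roughly like this:

```python
import numpy as np

def decode_step(cache, q_new, k_new, v_new):
    """q_new, k_new, v_new: (B, H, 1, D). Returns the attention output (B, H, 1, D)."""
    cache.append(k_new, v_new)                    # cache update: O(H·D) per sequence
    B, H, _, D = q_new.shape
    out = np.empty_like(q_new)
    for b in range(B):
        K, V = cache.valid(b)                     # (H, S_valid, D), read-only
        scores = q_new[b] @ K.transpose(0, 2, 1) / np.sqrt(D)   # (H, 1, S_valid)
        scores -= scores.max(axis=-1, keepdims=True)            # numerical stability
        w = np.exp(scores)
        w /= w.sum(axis=-1, keepdims=True)        # softmax over cached positions
        out[b] = w @ V                            # (H, 1, D)
    return out
```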
## Attention Computation Using Cached Keys/Values
```
┌───────────┐ ┌───────────────────────────────────┐
│ Q_new │ │ Cached K (all past tokens) │
│ (1, D) │ × │ (S_valid, D) │
│ │ │ │
│ │ │ [k₀] [k₁] [k₂] ... [k_{S-1}] │
└─────┬─────┘ └───────────────────────────────────┘
│ │
▼ ▼
┌────────────────────────────────────┐
│ scores = Q · K^T / √D │ → (1, S_valid)
│ weights = softmax(scores) │ → (1, S_valid)
│ output = weights · V │ → (1, D)
└────────────────────────────────────┘
```
This computation is performed independently for each of the H heads and each of the B batch elements.
## Memory Growth Analysis
### Linear Growth
The cache grows **linearly** with sequence length:
```
Memory per layer = 2 × B × H × S × D × sizeof(dtype)
= 2 × B × d_model × S × sizeof(dtype)
```
For a GPT-4-class model (32 layers, d_model=4096, FP32):
| Seq Length | Per Layer (MB) | Total (MB) | Total (GB) |
|-----------|---------------|-----------|-----------|
| 128 | 0.67 | 21.47 | 0.021 |
| 1,024 | 5.37 | 171.79 | 0.172 |
| 4,096 | 21.47 | 687.19 | 0.687 |
| 16,384 | 85.89 | 2,748.77 | 2.749 |
| 65,536 | 343.59 | 10,995.08 | 10.995 |
| 131,072 | 687.19 | 21,990.16 | 21.990 |
**Observation**: At 128K context with batch=1, you need **~22 GB** just for
the KV cache — before accounting for model weights, activations, or
gradients.
### FLOPs Savings
| Scenario | Without Cache | With Cache | Speedup |
|----------|--------------|-----------|---------|
| 1024 prompt + 100 decode | 4.2e14 | 2.0e12 | ~200× |
The speedup grows with sequence length: each cached decode step does O(S) attention work instead of the O(S²) required to recompute attention over the full prefix.
## Optimizations
### 1. Paged Attention (Virtual Memory for KV Cache)
**Problem**: Pre-allocating `(B, H, S_max, D)` wastes memory for short
sequences and causes fragmentation when sequences finish at different
times.
**Solution**: Divide the cache into fixed-size blocks (pages):
```
Physical Memory:
┌────────┬────────┬────────┬────────┬────────┬────────┐
│ Block 0│ Block 1│ Block 2│ Block 3│ Block 4│ ... │
│(H,B,D) │(H,B,D) │(H,B,D) │(H,B,D) │(H,B,D) │ │
└────────┴────────┴────────┴────────┴────────┴────────┘
Page Tables:
Seq 0: [0] → [3] → [1] (3 blocks = 3 × BLOCK_SIZE tokens)
Seq 1: [2] → [4] (2 blocks = 2 × BLOCK_SIZE tokens)
Seq 2: [5] (1 block)
Free: [6, 7, 8, ...]
```
**Benefits**:
- Memory allocated only as needed (no S_max pre-allocation)
- Finished sequences free blocks immediately → higher throughput
- No external fragmentation
- Enables sharing of KV blocks across sequences (e.g., prefix caching)
**Implementation**: See `PagedKVCache` in `optimizations.py`.
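For intuition, here is a minimal sketch of the block-table idea (illustrative only; the real `PagedKVCache` interface in `optimizations.py` may look different):

```python
import numpy as np

BLOCK_SIZE = 16  # tokens per block (illustrative)

class TinyPagedCache:
    def __init__(self, num_blocks, heads, head_dim):
        # One shared physical pool; blocks are handed out to sequences on demand.
        self.k_pool = np.zeros((num_blocks, heads, BLOCK_SIZE, head_dim), np.float32)
        self.v_pool = np.zeros_like(self.k_pool)
        self.free = list(range(num_blocks))   # free-block list
        self.page_table = {}                  # seq_id -> list of physical block ids
        self.seq_lens = {}                    # seq_id -> number of valid tokens

    def append(self, seq_id, k_tok, v_tok):
        """k_tok, v_tok: (H, D) — append one token's key/value for one sequence."""
        pos = self.seq_lens.setdefault(seq_id, 0)
        pages = self.page_table.setdefault(seq_id, [])
        if pos % BLOCK_SIZE == 0:             # first token, or current block is full
            pages.append(self.free.pop())     # allocate a new physical block
        blk, off = pages[-1], pos % BLOCK_SIZE
        self.k_pool[blk, :, off, :] = k_tok
        self.v_pool[blk, :, off, :] = v_tok
        self.seq_lens[seq_id] = pos + 1

    def release(self, seq_id):
        """A finished sequence returns its blocks to the free list immediately."""
        self.free.extend(self.page_table.pop(seq_id, []))
        self.seq_lens.pop(seq_id, None)
```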
### 2. Chunked Prefill
**Problem**: Processing a 32K-token prompt naively materializes a 32K×32K attention
score matrix per head (about 1 billion floats ≈ 4 GB in FP32) just for the prefill.
**Solution**: Split the prompt into chunks of C tokens:
```
Prompt: [t₀, t₁, t₂, ..., t_{S-1}] (S = 32K)
Chunk 0: [t₀..t_{C-1}] → cache write → attention vs cache (0..C)
Chunk 1: [t_C..t_{2C-1}] → cache write → attention vs cache (0..2C)
Chunk 2: [t_{2C}..t_{3C-1}] → cache write → attention vs cache (0..3C)
...
```
Peak attention memory: O(C × S) instead of O(S²).
**Benefits**:
- Bounded peak memory regardless of prompt length
- Can interleave prefill chunks with decode steps from other sequences
- Better GPU utilization (uniform work items)
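A self-contained NumPy sketch of the chunking loop (the chunk size `C` and variable names are assumptions for illustration; the output matches a full causal prefill):

```python
import numpy as np

def chunked_prefill(Q, K, V, C=512):
    """Q, K, V: (B, H, S, D) for the whole prompt. Peak score memory is O(C · S)."""
    B, H, S, D = Q.shape
    k_cache = np.zeros_like(K)
    v_cache = np.zeros_like(V)
    outputs = []
    for start in range(0, S, C):
        end = min(start + C, S)
        # Cache write for this chunk (bulk slice copy).
        k_cache[:, :, start:end, :] = K[:, :, start:end, :]
        v_cache[:, :, start:end, :] = V[:, :, start:end, :]
        # Attention: chunk queries vs. all cached keys so far (positions 0..end).
        q = Q[:, :, start:end, :]
        scores = q @ k_cache[:, :, :end, :].transpose(0, 1, 3, 2) / np.sqrt(D)
        # Causal mask inside the chunk: a query at position t only sees keys <= t.
        q_pos = np.arange(start, end)[:, None]
        k_pos = np.arange(end)[None, :]
        scores = np.where(k_pos <= q_pos, scores, -np.inf)
        scores -= scores.max(axis=-1, keepdims=True)
        w = np.exp(scores)
        w /= w.sum(axis=-1, keepdims=True)
        outputs.append(w @ v_cache[:, :, :end, :])   # (B, H, end-start, D)
    return np.concatenate(outputs, axis=2)           # (B, H, S, D)
```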
### 3. Cache Quantization (INT8 / INT4)
**Problem**: 22 GB for a 128K context is unsustainable.
**Solution**: Quantize cached K/V to lower precision:
| Precision | Bytes/Element | Memory Savings | Typical Quality Loss |
|-----------|-------------|---------------|---------------------|
| FP32 | 4 | 1× (baseline) | 0% |
| FP16 | 2 | 2× | <0.1% |
| INT8 | 1 | 4× | <0.5% |
| INT4 | 0.5 | 8× | 1-3% |
Quantization is per-token: `scale[b,h,t] = max(|K[b,h,t,:]|) / (2^bits - 1)`.
```
Storage:
k_quant: uint8 (B, H, S, D) or packed uint8 (B, H, S, D/2) for INT4
k_scale: float32 (B, H, S) — one scalar per token per head
Dequantize during attention:
K_float = k_quant * k_scale — in registers before matmul
```
**Benefits**:
- 4-8× memory reduction → longer contexts or larger batches
- Minimal quality loss for most tasks
- Hardware support on modern GPUs (FP8 on Hopper, INT8 on Ampere)
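A rough NumPy sketch of per-token quantization at the cache boundary (a symmetric INT8 variant; rounding details and the INT4 packing are simplified relative to whatever `optimizations.py` actually does):

```python
import numpy as np

def quantize_per_token(K):
    """K: (B, H, S, D) float32 → int8 codes plus one float32 scale per token per head."""
    scale = np.abs(K).max(axis=-1, keepdims=True) / 127.0   # (B, H, S, 1)
    scale = np.maximum(scale, 1e-8)                          # avoid divide-by-zero
    k_quant = np.clip(np.round(K / scale), -127, 127).astype(np.int8)
    return k_quant, scale.astype(np.float32)

def dequantize(k_quant, scale):
    """Reconstruct float32 keys right before the attention matmul."""
    return k_quant.astype(np.float32) * scale
```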
## GPU Execution Mapping
### Memory Hierarchy
```
┌──────────────────────────────────────────────┐
│ HBM (High Bandwidth Memory) │
│ ┌──────────────────────────────────────┐ │
│ │ KV Cache: (B, H, S, D) per layer │ │
│ │ ~10-70 GB for long contexts │ │
│ └──────────────────────────────────────┘ │
│ ┌──────────────────────────────────────┐ │
│ │ Model Weights │ │
│ └──────────────────────────────────────┘ │
└──────────────────────┬───────────────────────┘
│ ~2-3 TB/s bandwidth
┌──────────────────────────────────────────────┐
│ Shared Memory (per SM) │
│ ┌──────────────────────────────────────┐ │
│ │ Q tile: (block_B, H, tile_S, D) │ │
│ │ K tile: (block_B, H, tile_S, D) │ │
│ │ V tile: (block_B, H, tile_S, D) │ │
│ │ Score tile: (block_B, H, tile_S²) │ │
│ └──────────────────────────────────────┘ │
│ ~48-164 KB per SM │
└──────────────────────┬───────────────────────┘
│ ~19 TB/s bandwidth
┌──────────────────────────────────────────────┐
│ Registers (per thread block) │
│ accumulator for QK^T, softmax, etc. │
│ ~255 registers/thread │
└──────────────────────────────────────────────┘
```
### Kernel Mapping
| Operation | CPU (this impl) | GPU Kernel |
|-----------|----------------|------------|
| Cache write | `cache[b,:,pos,:] = new_kv` | `cudaMemcpyAsync` or block-level scatter |
| Q×K^T | `q @ k.T` | Batched GEMM (cuBLAS) or FlashAttention |
| Softmax | `_softmax(scores)` | Online softmax (FlashAttention) |
| Weights×V | `weights @ v` | GEMM (part of FlashAttention fused kernel) |
| Quantize | `_quantize_token()` | Block-reduce + scale + convert |
### FlashAttention Integration
The attention computation in this codebase performs the naive three-step sequence:
```
S = Q × K^T # materialize full (S_q, S_kv) matrix
A = softmax(S) # another (S_q, S_kv) matrix
O = A × V # output
```
On GPU, **FlashAttention** fuses these three operations:
```
for each tile of Q:
    init: O = 0, m = -∞, l = 0
    for each tile of K, V:
        S_tile = Q_tile × K_tile^T          # in SRAM
        m_new  = max(m, rowmax(S_tile))
        P_tile = exp(S_tile - m_new)        # in SRAM
        α      = exp(m - m_new)             # rescale factor for the old running stats
        l      = l·α + rowsum(P_tile)
        O      = O·α + P_tile × V_tile      # unnormalized accumulator
        m      = m_new
    O = O / l
```
This avoids ever materializing the full O(S²) attention matrix in HBM: score tiles are
computed and consumed in SRAM, and the KV cache is streamed tile-by-tile from HBM.
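The recurrence can be checked directly in NumPy (single head, whole Q at once; a sketch of the algorithm, not of FlashAttention's fused CUDA kernel):

```python
import numpy as np

def online_softmax_attention(Q, K, V, tile=64):
    """Q: (S_q, D); K, V: (S_kv, D). Tiles over K/V with a running max and normalizer,
    so the full (S_q, S_kv) score matrix is never materialized."""
    S_q, D = Q.shape
    O = np.zeros((S_q, D))
    m = np.full((S_q, 1), -np.inf)     # running row-wise max
    l = np.zeros((S_q, 1))             # running normalizer
    for start in range(0, K.shape[0], tile):
        Kt, Vt = K[start:start + tile], V[start:start + tile]
        S_tile = Q @ Kt.T / np.sqrt(D)                       # this tile's scores
        m_new = np.maximum(m, S_tile.max(axis=-1, keepdims=True))
        alpha = np.exp(m - m_new)                            # rescale old statistics
        P = np.exp(S_tile - m_new)
        l = l * alpha + P.sum(axis=-1, keepdims=True)
        O = O * alpha + P @ Vt                               # unnormalized accumulator
        m = m_new
    return O / l                       # equals softmax(QKᵀ/√D) @ V
```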
### Paged Attention on GPU
The `PagedKVCache` page table translates to a GPU indirection:
```cuda
// CUDA pseudocode for paged attention
__global__ void paged_attention(
float* Q, // (B, H, 1, D) — new query
float* K_pool, // (num_blocks, H, BLOCK_SIZE, D)
float* V_pool,
int* page_table, // (B, max_pages_per_seq)
int* seq_lens, // (B,)
float* output // (B, H, 1, D)
) {
int b = blockIdx.y;
int h = blockIdx.x;
int S = seq_lens[b];
// Load query into registers
float q[D];
load_query(q, Q, b, h);
// Iterate over pages
    float score[S_MAX_LOCAL];
    int num_pages = (S + BLOCK_SIZE - 1) / BLOCK_SIZE;   // ceil(S / BLOCK_SIZE)
    for (int page = 0; page < num_pages; page++) {
        int phys_block = page_table[b * max_pages + page];
        // Gather keys from the scattered physical block and score them against q
        for (int i = 0; i < BLOCK_SIZE; i++) {
            int t = page * BLOCK_SIZE + i;
            if (t >= S) break;                           // last page may be partially filled
            const float* k = &K_pool[phys_block * H * BLOCK_SIZE * D
                                     + h * BLOCK_SIZE * D + i * D];
            score[t] = dot(q, k) / sqrtf(D);
        }
    }
// ... softmax, multiply by V, write output
}
```
## File Structure
```
kv/
├── README.md ← you are here
├── kv_cache.py ← core data structures + attention
├── optimizations.py ← paged attention, chunked prefill, quantization
└── test_kv_cache.py ← comprehensive test suite
```
## Running
```bash
python test_kv_cache.py
```
All tests run without any external dependencies beyond NumPy.
## Key Design Decisions
1. **Pre-allocation**: The base `KVCache` pre-allocates to `S_max` to
avoid GPU memory allocation during inference (malloc is expensive).
The `PagedKVCache` trades this for on-demand block allocation.
2. **No cross-contamination**: Each batch element maintains its own
   valid prefix via `seq_lens`. Attention is masked so it never reads
   uninitialized positions beyond that prefix or data belonging to another sequence.
3. **Separation of concerns**: Cache update (write) and attention
(read) are decoupled. The caller controls when each happens,
enabling chunked prefill and prefix sharing.
4. **Quantization at cache boundary**: K/V are computed in FP32,
quantized on write, dequantized on read. This keeps the attention
computation unchanged while reducing memory.