KV-Cache System for Autoregressive Transformer Inference

Pure NumPy implementation — no frameworks. Demonstrates the complete KV-cache pipeline from data structures through GPU mapping.

Architecture

┌─────────────────────────────────────────────────────────────────┐
│                    INFERENCE PIPELINE                            │
│                                                                 │
│  Prompt ──→ [Prefill] ──→ KV Cache populated ──→ [Generate]   │
│              O(n²) attn       O(1) per token      O(seq) attn   │
│                                                                 │
│  Per generation step:                                           │
│    1. Embed + positional encoding                               │
│    2. For each layer:                                           │
│       a. LayerNorm → QKV projection                             │
│       b. Store K,V in cache (append at write_pos)               │
│       c. Cached attention: Q @ K_cache^T → softmax → @ V_cache │
│       d. Output projection → MLP → residual                     │
│    3. LM head → logits → sample next token                      │
└─────────────────────────────────────────────────────────────────┘

Files

File                  Purpose
kv_cache.py           Core KV-cache data structures (KVCache, BatchedKVCache)
attention.py          Attention computation (standard, cached, GQA, masked)
transformer.py        Full transformer decoder layer + model with KV-cache integration
optimizations.py      Paged attention, quantization, chunked prefill
memory_analysis.py    Memory growth formulas, model size comparisons, GPU limits
gpu_mapping.py        GPU kernel design, Tensor Core analysis, multi-GPU strategies
demo.py               10 end-to-end demos exercising every component

1. Data Structure Layout

Memory Format

cache_k[batch, num_heads, max_seq_len, head_dim]    # float16
cache_v[batch, num_heads, max_seq_len, head_dim]    # float16
lengths[batch]                                      # int32 (actual seq len per item)
write_pos                                           # int (global write pointer)

Why this layout:

  • batch first → enables batched GEMM on GPU
  • heads second → parallel head computation
  • seq_len third → contiguous scan for Q @ K^T
  • head_dim last → inner product dimension, coalesced access
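
For concreteness, a minimal sketch of a cache with this layout is shown below; the real data structures live in kv_cache.py, and the class and method names here are illustrative rather than that module's API.

import numpy as np

class MinimalKVCache:
    """Illustrative cache using the [batch, heads, max_seq_len, head_dim] layout."""

    def __init__(self, batch, num_heads, max_seq_len, head_dim):
        shape = (batch, num_heads, max_seq_len, head_dim)
        self.cache_k = np.zeros(shape, dtype=np.float16)
        self.cache_v = np.zeros(shape, dtype=np.float16)
        self.lengths = np.zeros(batch, dtype=np.int32)   # actual seq len per item
        self.write_pos = 0                               # global write pointer

    def append(self, new_k, new_v):
        # new_k, new_v: (batch, heads, 1, head_dim) for one generated token
        pos = self.write_pos
        self.cache_k[:, :, pos, :] = new_k[:, :, 0, :]
        self.cache_v[:, :, pos, :] = new_v[:, :, 0, :]
        self.write_pos += 1
        self.lengths += 1                                # all items advance together here

    def get_all(self):
        # Return only the populated prefix: (batch, heads, seq_so_far, head_dim)
        return (self.cache_k[:, :, :self.write_pos, :],
                self.cache_v[:, :, :self.write_pos, :])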

Per-Token Memory Cost

For a 7B model (32 layers, 32 heads, head_dim=128, fp16):

Per token per layer: 2 × 32 × 128 × 2 bytes = 16 KB
Per token (all layers): 16 KB × 32 = 512 KB
At 32K context: 512 KB × 32,768 = 16 GB

2. Update Logic Per Step

# Each generation step:
pos = cache.write_pos
cache.cache_k[:, :, pos, :] = new_k[:, :, 0, :]   # O(1) write
cache.cache_v[:, :, pos, :] = new_v[:, :, 0, :]   # O(1) write
cache.write_pos += 1

The write is a simple memory copy — no computation needed. The cache grows by exactly 2 × heads × head_dim × elem_bytes per token per layer.

3. Attention Computation Using Cache

import numpy as np

# Retrieve all cached K, V
cached_k, cached_v = cache.get_all()  # (batch, heads, seq_so_far, head_dim)

# Q @ K^T: (batch, heads, 1, head_dim) × (batch, heads, seq, head_dim) → (batch, heads, 1, seq)
scores = np.einsum("bhqd,bhkd->bhqk", q, cached_k) / np.sqrt(head_dim)

# Softmax over the key axis (no mask needed: the cache only holds past tokens)
exp_scores = np.exp(scores - scores.max(axis=-1, keepdims=True))   # numerically stable
attn = exp_scores / exp_scores.sum(axis=-1, keepdims=True)

# Attn @ V: (batch, heads, 1, seq) × (batch, heads, seq, head_dim) → (batch, heads, 1, head_dim)
output = np.einsum("bhqk,bhkd->bhqd", attn, cached_v)

Key insight: During generation, the cache naturally enforces causality — it only contains past tokens, so no explicit mask is needed.
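
A quick self-contained check of that claim, with toy sizes and a single head:

import numpy as np

rng = np.random.default_rng(0)
seq, d = 8, 4
x_q, x_k, x_v = rng.standard_normal((3, seq, d))

def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

# Full causal attention over the whole sequence (explicit mask)
scores = x_q @ x_k.T / np.sqrt(d)
scores[np.triu(np.ones((seq, seq), dtype=bool), k=1)] = -np.inf
full = softmax(scores) @ x_v

# "Cached" attention for the last token: its query sees all stored keys, no mask
step = softmax(x_q[-1:] @ x_k.T / np.sqrt(d)) @ x_v

assert np.allclose(full[-1], step[0])   # identical result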

4. Memory Growth Analysis

Linear Growth Formula

KV_cache(bytes) = 2 × batch × layers × heads × seq_len × head_dim × elem_bytes
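
The same formula as a small helper (a sketch; memory_analysis.py has the full version):

def kv_cache_bytes(batch, layers, heads, seq_len, head_dim, elem_bytes=2):
    # 2× for K and V; elem_bytes=2 for fp16
    return 2 * batch * layers * heads * seq_len * head_dim * elem_bytes

# 7B-class config: 32 layers, 32 heads, head_dim=128, fp16
print(kv_cache_bytes(1, 32, 32, 32_768, 128) / 2**30)   # ≈ 16.0 GiB at 32K context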

7B Model (batch=1, fp16)

Context    KV Cache    Total (params + KV)    KV Fraction
256        0.12 GB     7.04 GB                1.8%
4,096      2.00 GB     8.91 GB                22.4%
8,192      4.00 GB     10.91 GB               36.7%
32,768     16.00 GB    22.91 GB               69.8%

Maximum Context by GPU (7B model, batch=1)

GPU                      Max Context
RTX 4090 (24 GB)         6,690 tokens
A100-40GB                39,458 tokens
A100-80GB / H100-80GB    121,378 tokens

Batch Size Impact

KV cache scales linearly with batch size. At batch=4, the 7B model on an A100-80GB can only handle ~30K context instead of 121K.

5. Optimizations

Optimization 1: Paged Attention (vLLM-style)

Problem: Contiguous allocation wastes memory when sequences have variable lengths. A batch with one 32K sequence and three 100-token sequences still allocates 32K for all.

Solution: Divide memory into fixed-size blocks (pages). Each sequence maintains a page table mapping logical blocks to physical pages.

Physical page pool: (total_pages, heads, block_size, head_dim)
Page table:         (batch, max_blocks) int array, mapping logical block index → physical page

Benefits:

  • Zero memory fragmentation
  • Supports speculative decoding and branching
  • Enables prefix caching (share common prefixes)
  • No need to pre-allocate max_seq_len

Trade-off: Page table indirection adds complexity to the attention kernel (gather from non-contiguous pages).
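
A sketch of that indirection in NumPy, using the pool and page-table shapes above (block size and function name are illustrative):

import numpy as np

def gather_kv_for_sequence(pool_k, pool_v, page_table, seq_len, block_size):
    """Reassemble one sequence's K/V from non-contiguous physical pages.

    pool_k, pool_v: (total_pages, heads, block_size, head_dim) physical pools
    page_table:     (max_blocks,) ints, logical block index -> physical page index
    """
    num_blocks = (seq_len + block_size - 1) // block_size
    pages = page_table[:num_blocks]                       # physical pages in logical order
    k = np.concatenate([pool_k[p] for p in pages], axis=1)[:, :seq_len, :]
    v = np.concatenate([pool_v[p] for p in pages], axis=1)[:, :seq_len, :]
    return k, v   # (heads, seq_len, head_dim)

A real kernel gathers pages inside the attention loop rather than materializing a contiguous copy; the sketch only shows the logical-to-physical mapping.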

Optimization 2: Quantization

Problem: fp16 KV cache dominates memory for long contexts.

Solution: Store K/V in int8 with per-channel affine dequantization: x ≈ scale × q + zero

int8 data: 1 byte per element (vs 2 for fp16)
fp16 scales + zeros: shared per channel (not per token)
Net savings: ~50% memory with <1% accuracy loss

Production approach: shared per-channel scales (not per-position) stored in fp16. The per-position variant in this codebase is there to demonstrate correctness, but it carries higher overhead.
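
A minimal per-channel int8 round trip matching the affine form above (a sketch, not the module's exact code):

import numpy as np

def quantize_per_channel(x):
    # x: (..., seq, head_dim); one scale/zero per channel, shared across positions
    x = x.astype(np.float32)
    lo = x.min(axis=-2, keepdims=True)
    hi = x.max(axis=-2, keepdims=True)
    scale = (hi - lo) / 255.0 + 1e-8                 # map [lo, hi] onto 256 int8 levels
    q = np.clip(np.round((x - lo) / scale) - 128, -128, 127).astype(np.int8)
    zero = lo + 128.0 * scale                        # so that x ≈ scale * q + zero
    return q, scale.astype(np.float16), zero.astype(np.float16)

def dequantize(q, scale, zero):
    return scale.astype(np.float32) * q.astype(np.float32) + zero.astype(np.float32)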

Optimization 3: Chunked Prefill

Problem: Processing a 32K prompt requires materializing a 32K × 32K attention matrix (4 GB in fp32).

Solution: Process the prompt in chunks of size C. Each chunk attends to all previous tokens + causal within chunk.

Peak memory: O(C × seq_len) instead of O(seq_len²)
For C=512, seq=4096: 8 MB vs 64 MB (8× savings)
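
A sketch of the chunked loop for a single head, assuming the prompt's K/V have already been projected (cache bookkeeping omitted):

import numpy as np

def chunked_prefill_attention(q, k, v, chunk=512):
    # q, k, v: (seq, head_dim) for one head; returns (seq, head_dim)
    seq, d = q.shape
    out = np.empty((seq, d), dtype=np.float32)
    for start in range(0, seq, chunk):
        end = min(start + chunk, seq)
        # Queries in this chunk attend to all earlier tokens plus causal-within-chunk
        scores = q[start:end].astype(np.float32) @ k[:end].astype(np.float32).T / np.sqrt(d)
        mask = np.triu(np.ones((end - start, end), dtype=bool), k=start + 1)
        scores[mask] = -np.inf
        scores -= scores.max(axis=-1, keepdims=True)
        p = np.exp(scores)
        out[start:end] = (p / p.sum(axis=-1, keepdims=True)) @ v[:end].astype(np.float32)
    return out

Peak score-matrix size per chunk is chunk × end ≤ C × seq_len, matching the bound above.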

Combined: Paged + Quantized

Together these give 2-4× memory reduction, enabling 2-4× longer contexts in the same GPU memory.

6. GPU Execution Mapping

Memory Hierarchy

Level            Size                        Latency            Usage
Registers        256 KB/SM (64K × 32-bit)    1 cycle            Thread-local, warp computation
Shared memory    228 KB/SM (H100)            1-3 cycles         Tiling, softmax intermediates
L2 cache         50 MB (H100)                ~20 cycles         Automatic global memory caching
HBM              80 GB (H100)                ~300-400 cycles    Model weights, KV cache, activations

Cached Attention Kernel Design

Grid: (batch_size, num_heads, 1)
Block: (32, 32) = 1024 threads

Shared memory per block (~16-20 KB):
  - Q tile: 1 × head_dim (512 bytes, fp32)
  - K tile: 32 × head_dim (8 KB fp16)
  - Score tile: 32 × 32 (4 KB, fp32 accumulators)

Optimization strategies:

  1. Coalesced global memory access (warp-level consecutive addresses)
  2. Tiled GEMM with shared memory
  3. Persistent kernels (keep blocks alive until all tiles processed)
  4. Async copy (H100 cp.async) to overlap memory transfer with computation
  5. Tensor Cores (mma.sync) for matmul operations
  6. Fusion: merge softmax with attention score computation

Arithmetic Intensity

For single-token generation (batch=1, heads=32, seq=4096):

  • FLOPs: 0.02 GFLOPs
  • Memory traffic: 16.79 MB
  • Arithmetic intensity: 1.0 FLOPs/byte
  • → Memory-bound (H100 peak: 1,970 TFLOPS, 3.35 TB/s)

The cached attention is fundamentally memory-bound — the bottleneck is reading the KV cache from HBM, not computation. This is why bandwidth-optimized kernels (FlashAttention-style tiling) matter more than compute optimization.
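
For intuition, a minimal NumPy sketch of that style of tiling for the single-query decode case: K/V are streamed in tiles while a running softmax state is maintained, so the full score row is never materialized at once.

import numpy as np

def tiled_cached_attention(q, k_cache, v_cache, tile=256):
    # q: (head_dim,); k_cache, v_cache: (seq, head_dim). Single head, single query.
    d = q.shape[-1]
    q32 = q.astype(np.float32)
    m = -np.inf                               # running max of scores
    denom = 0.0                               # running softmax denominator
    acc = np.zeros(d, dtype=np.float32)       # running weighted sum of V rows
    for start in range(0, k_cache.shape[0], tile):
        k = k_cache[start:start + tile].astype(np.float32)
        v = v_cache[start:start + tile].astype(np.float32)
        s = k @ q32 / np.sqrt(d)              # (tile,) scores for this tile
        m_new = max(m, float(s.max()))
        rescale = np.exp(m - m_new)           # rescale state accumulated so far
        p = np.exp(s - m_new)
        denom = denom * rescale + p.sum()
        acc = acc * rescale + p @ v
        m = m_new
    return acc / denom

Up to floating-point error this matches a one-shot softmax(q·Kᵀ/√d)·V, but the working set per step is a single tile that fits in fast on-chip memory.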

Multi-GPU Strategies

Strategy                KV Cache Impact                                 Best For
Tensor parallelism      K/V split across attention heads                Large models
Pipeline parallelism    Each GPU holds its layer shard's K/V            Very large models
Sequence parallelism    K/V split along the sequence dimension          Long context prefill
Expert parallelism      KV cache unchanged; only MLP experts sharded    MoE models

Running

# Run all 10 demos
python demo.py

# Run memory analysis standalone
python memory_analysis.py

# Run GPU mapping report
python gpu_mapping.py

Key Takeaways

  1. KV cache grows linearly with sequence length and batch size — it's the dominant memory cost for long-context inference.

  2. Generation is memory-bound — reading the KV cache from HBM dominates latency, not the attention computation itself.

  3. Paged attention eliminates fragmentation — critical for serving variable-length sequences efficiently.

  4. Quantization gives ~50% savings with minimal accuracy loss when using shared per-channel scales.

  5. FlashAttention-style tiling reduces HBM traffic by processing K/V in tiles that fit in shared memory, cutting memory bandwidth requirements by 2-4×.