"""
End-to-End KV-Cache Demo
Demonstrates:
1. Building a small transformer with KV-cache
2. Prefill phase (prompt processing)
3. Incremental generation (one token at a time)
4. Variable-length batching
5. Memory tracking
6. Optimization comparisons
"""
import numpy as np
import sys
import os
# Ensure we can import from the project
sys.path.insert(0, os.path.dirname(__file__))
from kv_cache import KVCache, CacheConfig, BatchedKVCache
from attention import (
scaled_dot_product_attention,
cached_attention,
build_causal_mask,
softmax_stable,
)
from transformer import TransformerDecoder, TransformerDecoderLayer
from optimizations import (
PagedKVCache, PageConfig,
QuantizedKVCache,
ChunkedPrefill,
compare_strategies,
)
from memory_analysis import (
ModelSpec, compute_model_memory, compute_kv_cache_memory,
find_max_context, compare_model_sizes,
)
from gpu_mapping import tensor_core_analysis, print_gpu_report
def demo_basic_kv_cache():
"""Demo 1: Basic KV cache operations."""
print("=" * 70)
print("DEMO 1: Basic KV Cache Operations")
print("=" * 70)
config = CacheConfig(
batch_size=2,
num_heads=4,
head_dim=16,
max_seq_len=64,
dtype=np.float32,
)
cache = KVCache(config)
print(f"\nCache shape: {cache.cache_k.shape}")
print(f" (batch={config.batch_size}, heads={config.num_heads}, "
f"max_seq={config.max_seq_len}, head_dim={config.head_dim})")
print(f"Allocated: {cache.memory_allocated_bytes:,} bytes")
# Simulate generating tokens one at a time
np.random.seed(42)
for step in range(10):
# Simulate new K and V from the model
k_new = np.random.randn(2, 4, 1, 16).astype(np.float32) * 0.01
v_new = np.random.randn(2, 4, 1, 16).astype(np.float32) * 0.01
cache.update(k_new, v_new)
print(f"\nAfter 10 steps:")
print(f" Write position: {cache.write_pos}")
print(f" Sequence lengths: {cache.lengths}")
print(f" Memory used: {cache.memory_used_bytes:,} bytes")
# Retrieve cached data
k_cached, v_cached = cache.get_all()
print(f" Cached K shape: {k_cached.shape}")
print(f" Cached V shape: {v_cached.shape}")
# Verify data integrity
assert k_cached.shape == (2, 4, 10, 16)
assert v_cached.shape == (2, 4, 10, 16)
print("\n ✓ Data integrity verified")
def demo_cached_attention():
"""Demo 2: Cached attention computation."""
print("\n" + "=" * 70)
print("DEMO 2: Cached Attention Computation")
print("=" * 70)
batch, heads, head_dim = 2, 4, 16
seq_len = 8
scale = 1.0 / np.sqrt(head_dim)
np.random.seed(123)
# Build a cache with some history
config = CacheConfig(batch_size=batch, num_heads=heads,
head_dim=head_dim, max_seq_len=64)
cache = KVCache(config)
# Fill cache with random K, V
for i in range(seq_len):
k = np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 0.01
v = np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 0.01
cache.update(k, v)
# Current query (new token)
q = np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 0.01
# Cached attention
output = cached_attention(q, cache, scale)
print(f"\nQuery shape: {q.shape}")
print(f"Cached K shape: {cache.cache_k.shape} (used: {cache.write_pos} tokens)")
print(f"Output shape: {output.shape}")
# Verify against manual computation
k_all, v_all = cache.get_all()
scores = np.einsum("bhqd,bhkd->bhqk", q, k_all) * scale
attn = softmax_stable(scores, axis=-1)
manual_output = np.einsum("bhqk,bhkd->bhqd", attn, v_all)
diff = np.max(np.abs(output - manual_output))
print(f"Max difference from manual: {diff:.2e}")
assert diff < 1e-5, f"Attention mismatch: {diff}"
print(" ✓ Cached attention matches manual computation")
# Show attention weights for one batch/head
print(f"\nAttention weights (batch=0, head=0):")
print(f" {attn[0, 0, 0, :].round(3)}")
print(f" Sum: {attn[0, 0, 0, :].sum():.4f} (should be ~1.0)")
def demo_full_transformer():
"""Demo 3: Full transformer with KV-cache."""
print("\n" + "=" * 70)
print("DEMO 3: Full Transformer with KV-Cache")
print("=" * 70)
# Small model for demo
model = TransformerDecoder(
num_layers=2,
dim=64,
num_heads=4,
mlp_hidden=128,
vocab_size=1000,
max_seq_len=128,
batch_size=2,
dtype=np.float32,
seed=42,
)
# Create a prompt (padded to same length)
prompt = np.array([[10, 20, 30, 40, 50],
[15, 25, 35, 45, 0]], dtype=np.int32) # 0 = pad
lengths = np.array([5, 4], dtype=np.int32)
print(f"\nPrompt tokens: {prompt.shape}")
print(f" Sequence 0: {prompt[0]} (length={lengths[0]})")
print(f" Sequence 1: {prompt[1]} (length={lengths[1]})")
# Prefill
hidden = model.prefill(prompt, lengths=lengths)
print(f"\nAfter prefill:")
print(f" Hidden shape: {hidden.shape}")
print(f" Cache write position: {model.cache.caches[0].write_pos}")
# Generate tokens
print(f"\nGenerating 5 tokens...")
generated = model.generate(prompt, num_tokens=5, temperature=0.8, top_k=50,
lengths=lengths)
for i, tokens in enumerate(generated):
print(f" Step {i+1}: {tokens}")
# Memory report
report = model.memory_report()
print(f"\nMemory Report:")
for k, v in report.items():
if isinstance(v, float):
print(f" {k}: {v:.4f}")
else:
print(f" {k}: {v}")
def demo_variable_length_batching():
"""Demo 4: Variable-length batching."""
print("\n" + "=" * 70)
print("DEMO 4: Variable-Length Batching")
print("=" * 70)
batch_size = 4
config = CacheConfig(
batch_size=batch_size,
num_heads=4,
head_dim=16,
max_seq_len=32,
dtype=np.float32,
)
cache = KVCache(config)
np.random.seed(99)
# Simulate sequences of different lengths
# Seq 0: 8 tokens, Seq 1: 5 tokens, Seq 2: 10 tokens, Seq 3: 3 tokens
seq_lengths = [8, 5, 10, 3]
max_len = max(seq_lengths)
print("\nSimulating variable-length batch:")
# Each batch item has its own cache (simplified: use separate caches)
per_seq_caches = [KVCache(CacheConfig(
batch_size=1, num_heads=4, head_dim=16,
max_seq_len=max_len, dtype=np.float32
)) for _ in range(batch_size)]
for b, length in enumerate(seq_lengths):
for t in range(length):
k = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
v = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
per_seq_caches[b].update(k, v)
# Query for each sequence at its current position
scale = 1.0 / np.sqrt(16)
for b in range(batch_size):
q = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
k_cached, v_cached = per_seq_caches[b].get_all()
# Attention for this batch item
scores = np.einsum("bhqd,bhkd->bhqk", q, k_cached) * scale
attn = softmax_stable(scores, axis=-1)
# Show which positions are attended to
print(f"\n Sequence {b} (length={seq_lengths[b]}):")
print(f" Attention: {attn[0, 0, 0, :].round(3)}")
def demo_paged_attention():
"""Demo 5: Paged attention."""
print("\n" + "=" * 70)
print("DEMO 5: Paged Attention (vLLM-style)")
print("=" * 70)
config = PageConfig(
block_size=4,
num_pages=16,
batch_size=2,
num_heads=4,
head_dim=16,
dtype=np.float32,
)
paged = PagedKVCache(config)
print(f"\nPage config:")
print(f" Block size: {config.block_size} tokens")
print(f" Pages per sequence: {config.num_pages}")
print(f" Max tokens per sequence: {config.num_pages * config.block_size}")
print(f" Allocated: {paged.memory_allocated_bytes:,} bytes")
np.random.seed(77)
# Fill sequence 0 with 12 tokens (3 blocks)
print(f"\nFilling sequence 0 with 12 tokens...")
for t in range(12):
k = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
v = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
block_idx = t // config.block_size
offset = t % config.block_size
paged.append_token(0, k, v, block_idx, offset)
print(f" Blocks allocated: {paged.num_blocks[0]}")
print(f" Page table: {paged.page_tables[0, :paged.num_blocks[0]]}")
# Fill sequence 1 with 8 tokens (2 blocks)
print(f"\nFilling sequence 1 with 8 tokens...")
for t in range(8):
k = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
v = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
block_idx = t // config.block_size
offset = t % config.block_size
paged.append_token(1, k, v, block_idx, offset)
print(f" Blocks allocated: {paged.num_blocks[1]}")
print(f" Page table: {paged.page_tables[1, :paged.num_blocks[1]]}")
# Retrieve and verify
k0, v0 = paged.get_sequence_contiguous(0, num_tokens=12)
k1, v1 = paged.get_sequence_contiguous(1, num_tokens=8)
print(f"\n Seq 0 K shape: {k0.shape}")
print(f" Seq 1 K shape: {k1.shape}")
print(f"\n Memory used: {paged.memory_used_bytes:,} bytes")
print(f" Utilization: {paged.memory_utilization():.1%}")
def demo_quantized_cache():
"""Demo 6: Quantized KV cache."""
print("\n" + "=" * 70)
print("DEMO 6: Quantized KV Cache (int8)")
print("=" * 70)
batch, heads, head_dim, max_seq = 2, 4, 16, 32
cache = QuantizedKVCache(batch, heads, head_dim, max_seq, dtype=np.float32)
np.random.seed(55)
# Fill with random data
for t in range(10):
k = np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 0.1
v = np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 0.1
cache.update(k, v)
# Retrieve and compare
k_deq, v_deq = cache.get()
print(f"\nQuantized cache (10 tokens):")
print(f" Dequantized K shape: {k_deq.shape}")
print(f" Dequantized V shape: {v_deq.shape}")
# Compare with original (we need to re-quantize to compare)
# The quantization error depends on the data distribution
print(f" Memory savings vs fp32: {cache.memory_savings_vs_fp32:.1%}")
print(f" Memory savings vs fp16: {cache.memory_savings_vs_fp16:.1%} (per-pos scales overhead)")
# Show quantization error for one position
# Use larger values for better int8 quantization fidelity
k_orig = np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 1.0
v_orig = np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 1.0
cache.update(k_orig, v_orig)
k_deq_single, _ = cache.get(start=10, end=11)
# k_deq_single: (batch, heads, 1, head_dim), k_orig: (batch, heads, 1, head_dim)
print(f" k_orig shape: {k_orig.shape}, k_deq shape: {k_deq_single.shape}")
error = np.max(np.abs(k_orig - k_deq_single))
rel_error = error / (np.max(np.abs(k_orig)) + 1e-8)
print(f" Max absolute error (one token): {error:.6f}")
print(f" Max relative error: {rel_error:.4f}")
print(f" → Per-position quantization has high overhead; production uses")
print(f" shared per-channel scales for ~50% memory savings with <1% error")
def demo_chunked_prefill():
"""Demo 7: Chunked prefill."""
print("\n" + "=" * 70)
print("DEMO 7: Chunked Prefill")
print("=" * 70)
chunker = ChunkedPrefill(chunk_size=4)
batch, heads, seq, head_dim = 1, 4, 12, 16
scale = 1.0 / np.sqrt(head_dim)
np.random.seed(33)
q = np.random.randn(batch, heads, seq, head_dim).astype(np.float32) * 0.01
k = np.random.randn(batch, heads, seq, head_dim).astype(np.float32) * 0.01
v = np.random.randn(batch, heads, seq, head_dim).astype(np.float32) * 0.01
# Chunked attention
output_chunked = chunker.compute_attention_chunked(q, k, v, scale)
# Full attention (for comparison)
causal = build_causal_mask(seq, dtype=np.float32)
output_full = scaled_dot_product_attention(
q, k, v, scale, mask=causal[None, None, :, :]
)
diff = np.max(np.abs(output_chunked - output_full))
print(f"\nChunk size: {chunker.chunk_size}")
print(f"Sequence length: {seq}")
print(f"Chunks: {(seq + chunker.chunk_size - 1) // chunker.chunk_size}")
print(f"Max difference from full attention: {diff:.2e}")
assert diff < 1e-5, f"Chunked attention mismatch: {diff}"
print(" ✓ Chunked attention matches full attention")
# Memory comparison
mem = ChunkedPrefill.peak_memory_comparison(seq_len=4096, chunk_size=512)
print(f"\nMemory comparison (seq=4096, chunk=512):")
print(f" Full attention matrix: {mem['full_attention_mb']:.0f} MB")
print(f" Chunked peak: {mem['chunked_peak_attention_mb']:.0f} MB")
print(f" Savings: {mem['savings_ratio']:.1f}x")
def demo_optimization_comparison():
"""Demo 8: Optimization strategy comparison."""
print("\n" + "=" * 70)
print("DEMO 8: Optimization Strategy Comparison")
print("=" * 70)
results = compare_strategies(
batch_size=4, num_heads=32, head_dim=128,
max_seq_len=4096, num_layers=32
)
print(f"\nConfiguration: batch=4, heads=32, head_dim=128, "
f"seq=4096, layers=32\n")
header = f"{'Strategy':<25} {'Per Layer(MB)':>14} {'Total(GB)':>10} {'Notes':<25}"
print(header)
print("-" * len(header))
for name, data in results.items():
notes = ""
if "savings_vs_fp16" in data:
notes = f"{data['savings_vs_fp16']:.0%} savings"
elif "overhead_vs_naive" in data:
notes = f"{data['overhead_vs_naive']:.3f}x overhead"
print(f"{name:<25} {data['per_layer_mb']:>14.1f} {data['total_mb']/1024:>10.2f} "
f"{notes:<25}")
def demo_memory_analysis():
"""Demo 9: Memory growth analysis."""
print("\n" + "=" * 70)
print("DEMO 9: Memory Growth Analysis")
print("=" * 70)
# Compare model sizes
comparisons = compare_model_sizes()
print("\nModel Size Comparison (fp16):\n")
header = f"{'Model':<20} {'Params(GB)':>10} {'KV@1K':>8} {'KV@8K':>8} {'KV@32K':>8} {'MaxCtx(H100)':>12}"
print(header)
print("-" * len(header))
for name, data in comparisons.items():
print(f"{name:<20} {data['params_gb']:>10.1f} {data['kv_1k_gb']:>8.2f} "
f"{data['kv_8k_gb']:>8.2f} {data['kv_32k_gb']:>8.2f} "
f"{data['max_context_H100']:>12,}")
# Growth for 7B model
spec = ModelSpec(num_layers=32, dim=4096, num_heads=32, head_dim=128)
model_mem = compute_model_memory(spec, np.float16)
print(f"\n\n7B Model Memory Growth (batch=1, fp16):\n")
print(f" Model params: {model_mem['total_params_gb']:.1f} GB")
print()
seq_lens = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
print(f" {'Seq Len':>8} {'KV(GB)':>8} {'Total(GB)':>10} {'KV%':>6}")
print(f" {'-'*40}")
for sl in seq_lens:
kv = compute_kv_cache_memory(1, sl, spec, np.float16)
total = kv["total_gb"] + model_mem["total_params_gb"]
pct = kv["total_gb"] / total * 100
print(f" {sl:>8,} {kv['total_gb']:>8.2f} {total:>10.2f} {pct:>5.1f}%")
# GPU limits
print(f"\n\nMax Context by GPU (7B model, batch=1):\n")
gpus = {"RTX 4090": 24, "A100-40GB": 40, "A100-80GB": 80, "H100-80GB": 80}
for gpu, mem in gpus.items():
ctx = find_max_context(spec, mem, batch_size=1)
print(f" {gpu:<15}: {ctx:>8,} tokens")
def demo_gpu_tensor_cores():
"""Demo 10: GPU Tensor Core analysis."""
print("\n" + "=" * 70)
print("DEMO 10: GPU Tensor Core Analysis")
print("=" * 70)
configs = [
{"batch": 1, "heads": 32, "seq": 1024, "label": "Short context"},
{"batch": 1, "heads": 32, "seq": 8192, "label": "Long context"},
{"batch": 4, "heads": 32, "seq": 4096, "label": "Batched"},
]
for cfg in configs:
tc = tensor_core_analysis(
batch=cfg["batch"], heads=cfg["heads"], seq_len=cfg["seq"]
)
print(f"\n {cfg['label']} (batch={cfg['batch']}, seq={cfg['seq']}):")
print(f" Total FLOPs: {tc['total_flops']}")
print(f" Memory traffic: {tc['memory_traffic_mb']}")
print(f" Arithmetic intensity: {tc['arithmetic_intensity']}")
print(f" Compute bound: {tc['compute_bound_ms']}")
print(f" Memory bound: {tc['memory_bound_ms']}")
print(f"{tc['bound']}")
def main():
"""Run all demos."""
print("\n" + "" * 70)
print(" KV-CACHE SYSTEM FOR AUTOREGRESSIVE TRANSFORMER INFERENCE")
print(" Pure NumPy Implementation — No Frameworks")
print("" * 70)
demos = [
("Basic KV Cache", demo_basic_kv_cache),
("Cached Attention", demo_cached_attention),
("Full Transformer", demo_full_transformer),
("Variable-Length Batching", demo_variable_length_batching),
("Paged Attention", demo_paged_attention),
("Quantized Cache", demo_quantized_cache),
("Chunked Prefill", demo_chunked_prefill),
("Optimization Comparison", demo_optimization_comparison),
("Memory Analysis", demo_memory_analysis),
("GPU Tensor Cores", demo_gpu_tensor_cores),
]
for name, func in demos:
try:
func()
except Exception as e:
print(f"\n{name} failed: {e}")
import traceback
traceback.print_exc()
print("\n" + "" * 70)
print(" ALL DEMOS COMPLETE")
print("" * 70 + "\n")
if __name__ == "__main__":
main()