# NOTE: scraped-page metadata (commit hash, line/size counts, language badge)
# removed from the top of this file; it was not part of the Python source.
"""
End-to-End KV-Cache Demo

Demonstrates:
1. Building a small transformer with KV-cache
2. Prefill phase (prompt processing)
3. Incremental generation (one token at a time)
4. Variable-length batching
5. Memory tracking
6. Optimization comparisons
"""

import numpy as np
import sys
import os

# Ensure we can import from the project
sys.path.insert(0, os.path.dirname(__file__))

# Core cache data structures
from kv_cache import KVCache, CacheConfig, BatchedKVCache
# Attention kernels (full, cached, masking, numerically-stable softmax)
from attention import (
    scaled_dot_product_attention,
    cached_attention,
    build_causal_mask,
    softmax_stable,
)
# Small decoder-only transformer used in the end-to-end demos
from transformer import TransformerDecoder, TransformerDecoderLayer
# Cache optimization strategies compared in demos 5-8
from optimizations import (
    PagedKVCache, PageConfig,
    QuantizedKVCache,
    ChunkedPrefill,
    compare_strategies,
)
# Analytical memory models (demo 9)
from memory_analysis import (
    ModelSpec, compute_model_memory, compute_kv_cache_memory,
    find_max_context, compare_model_sizes,
)
# GPU roofline analysis (demo 10)
from gpu_mapping import tensor_core_analysis, print_gpu_report


def demo_basic_kv_cache():
    """Demo 1: Exercise the raw KV cache — allocate, append, read back."""
    print("=" * 70)
    print("DEMO 1: Basic KV Cache Operations")
    print("=" * 70)

    config = CacheConfig(
        batch_size=2,
        num_heads=4,
        head_dim=16,
        max_seq_len=64,
        dtype=np.float32,
    )
    cache = KVCache(config)

    print(f"\nCache shape: {cache.cache_k.shape}")
    print(f" (batch={config.batch_size}, heads={config.num_heads}, "
          f"max_seq={config.max_seq_len}, head_dim={config.head_dim})")
    print(f"Allocated: {cache.memory_allocated_bytes:,} bytes")

    # Feed ten single-token K/V pairs, mimicking autoregressive decoding.
    np.random.seed(42)

    def _fake_projection():
        # Small-magnitude random tensor standing in for real model output.
        return np.random.randn(2, 4, 1, 16).astype(np.float32) * 0.01

    for _ in range(10):
        # Arguments evaluate left-to-right: K is drawn before V each step.
        cache.update(_fake_projection(), _fake_projection())

    print(f"\nAfter 10 steps:")
    print(f" Write position: {cache.write_pos}")
    print(f" Sequence lengths: {cache.lengths}")
    print(f" Memory used: {cache.memory_used_bytes:,} bytes")

    # Read the populated region back out.
    k_cached, v_cached = cache.get_all()
    print(f" Cached K shape: {k_cached.shape}")
    print(f" Cached V shape: {v_cached.shape}")

    # Ten decode steps must yield exactly ten cached positions.
    assert k_cached.shape == (2, 4, 10, 16)
    assert v_cached.shape == (2, 4, 10, 16)
    print("\n ✓ Data integrity verified")


def demo_cached_attention():
    """Demo 2: Check cached attention against an explicit reference.

    Fills a cache with random K/V history, runs ``cached_attention`` for a
    single new query token, then cross-checks the output against a manual
    softmax(QK^T * scale)V computed with einsum over the same cache.
    """
    print("\n" + "=" * 70)
    print("DEMO 2: Cached Attention Computation")
    print("=" * 70)

    batch, heads, head_dim = 2, 4, 16
    seq_len = 8
    scale = 1.0 / np.sqrt(head_dim)

    np.random.seed(123)

    # Populate a cache with `seq_len` tokens of random history.
    cache = KVCache(CacheConfig(batch_size=batch, num_heads=heads,
                                head_dim=head_dim, max_seq_len=64))
    for _ in range(seq_len):
        # K is drawn before V, matching the cache's (k, v) argument order.
        cache.update(
            np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 0.01,
            np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 0.01,
        )

    # Query vector for the newly arrived token.
    q = np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 0.01

    output = cached_attention(q, cache, scale)
    print(f"\nQuery shape: {q.shape}")
    print(f"Cached K shape: {cache.cache_k.shape} (used: {cache.write_pos} tokens)")
    print(f"Output shape: {output.shape}")

    # Reference computation over the full cached history.
    k_all, v_all = cache.get_all()
    scores = np.einsum("bhqd,bhkd->bhqk", q, k_all) * scale
    attn = softmax_stable(scores, axis=-1)
    manual_output = np.einsum("bhqk,bhkd->bhqd", attn, v_all)

    diff = np.max(np.abs(output - manual_output))
    print(f"Max difference from manual: {diff:.2e}")
    assert diff < 1e-5, f"Attention mismatch: {diff}"
    print(" ✓ Cached attention matches manual computation")

    # Peek at one head's attention distribution; rows must sum to ~1.
    print(f"\nAttention weights (batch=0, head=0):")
    print(f" {attn[0, 0, 0, :].round(3)}")
    print(f" Sum: {attn[0, 0, 0, :].sum():.4f} (should be ~1.0)")


def demo_full_transformer():
    """Demo 3: Full decoder with KV-cache — prefill, then decode tokens."""
    print("\n" + "=" * 70)
    print("DEMO 3: Full Transformer with KV-Cache")
    print("=" * 70)

    # Deliberately tiny model so the demo runs instantly on CPU.
    model = TransformerDecoder(
        num_layers=2,
        dim=64,
        num_heads=4,
        mlp_hidden=128,
        vocab_size=1000,
        max_seq_len=128,
        batch_size=2,
        dtype=np.float32,
        seed=42,
    )

    # Two prompts padded to a common length; token id 0 acts as padding.
    prompt = np.array([[10, 20, 30, 40, 50],
                       [15, 25, 35, 45, 0]], dtype=np.int32)  # 0 = pad
    lengths = np.array([5, 4], dtype=np.int32)

    print(f"\nPrompt tokens: {prompt.shape}")
    print(f" Sequence 0: {prompt[0]} (length={lengths[0]})")
    print(f" Sequence 1: {prompt[1]} (length={lengths[1]})")

    # Prefill: process the whole prompt at once, populating the cache.
    hidden = model.prefill(prompt, lengths=lengths)
    print(f"\nAfter prefill:")
    print(f" Hidden shape: {hidden.shape}")
    print(f" Cache write position: {model.cache.caches[0].write_pos}")

    # Decode: sample five tokens autoregressively from the cached state.
    print(f"\nGenerating 5 tokens...")
    generated = model.generate(prompt, num_tokens=5, temperature=0.8, top_k=50,
                               lengths=lengths)
    for step_no, tokens in enumerate(generated, start=1):
        print(f" Step {step_no}: {tokens}")

    # Summarize parameter vs. cache memory footprint.
    report = model.memory_report()
    print(f"\nMemory Report:")
    for key, val in report.items():
        # Floats get fixed-point formatting; everything else prints raw.
        print(f" {key}: {val:.4f}" if isinstance(val, float) else f" {key}: {val}")


def demo_variable_length_batching():
    """Demo 4: Variable-length batching via one small cache per sequence.

    Each batch item keeps its own single-sequence ``KVCache`` sized to the
    longest sequence, so attention for item *b* only sees its own
    ``seq_lengths[b]`` cached positions.

    Fix: the original allocated an unused 4-sequence ``CacheConfig``/
    ``KVCache`` pair before seeding the RNG; removing it changes nothing
    observable (the seed comes after), so it is dropped here.
    """
    print("\n" + "=" * 70)
    print("DEMO 4: Variable-Length Batching")
    print("=" * 70)

    batch_size = 4

    np.random.seed(99)

    # Simulate sequences of different lengths
    # Seq 0: 8 tokens, Seq 1: 5 tokens, Seq 2: 10 tokens, Seq 3: 3 tokens
    seq_lengths = [8, 5, 10, 3]
    max_len = max(seq_lengths)

    print("\nSimulating variable-length batch:")
    # Each batch item has its own cache (simplified: use separate caches)
    per_seq_caches = [KVCache(CacheConfig(
        batch_size=1, num_heads=4, head_dim=16,
        max_seq_len=max_len, dtype=np.float32
    )) for _ in range(batch_size)]

    # Fill each sequence's cache up to its own length.
    for b, length in enumerate(seq_lengths):
        for t in range(length):
            k = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
            v = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
            per_seq_caches[b].update(k, v)

    # Query for each sequence at its current position
    scale = 1.0 / np.sqrt(16)
    for b in range(batch_size):
        q = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
        k_cached, _ = per_seq_caches[b].get_all()

        # Attention for this batch item; width equals that item's length.
        scores = np.einsum("bhqd,bhkd->bhqk", q, k_cached) * scale
        attn = softmax_stable(scores, axis=-1)

        # Show which positions are attended to
        print(f"\n Sequence {b} (length={seq_lengths[b]}):")
        print(f" Attention: {attn[0, 0, 0, :].round(3)}")


def demo_paged_attention():
    """Demo 5: Paged KV cache with vLLM-style block tables."""
    print("\n" + "=" * 70)
    print("DEMO 5: Paged Attention (vLLM-style)")
    print("=" * 70)

    config = PageConfig(
        block_size=4,
        num_pages=16,
        batch_size=2,
        num_heads=4,
        head_dim=16,
        dtype=np.float32,
    )
    paged = PagedKVCache(config)

    print(f"\nPage config:")
    print(f" Block size: {config.block_size} tokens")
    print(f" Pages per sequence: {config.num_pages}")
    print(f" Max tokens per sequence: {config.num_pages * config.block_size}")
    print(f" Allocated: {paged.memory_allocated_bytes:,} bytes")

    np.random.seed(77)

    def _append_random_tokens(seq_idx, count):
        # Append `count` single-token K/V pairs to sequence `seq_idx`,
        # mapping each linear position to (block, offset) coordinates.
        for t in range(count):
            k = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
            v = np.random.randn(1, 4, 1, 16).astype(np.float32) * 0.01
            block_idx, offset = divmod(t, config.block_size)
            paged.append_token(seq_idx, k, v, block_idx, offset)

    # Sequence 0: 12 tokens → 3 blocks of 4.
    print(f"\nFilling sequence 0 with 12 tokens...")
    _append_random_tokens(0, 12)
    print(f" Blocks allocated: {paged.num_blocks[0]}")
    print(f" Page table: {paged.page_tables[0, :paged.num_blocks[0]]}")

    # Sequence 1: 8 tokens → 2 blocks of 4.
    print(f"\nFilling sequence 1 with 8 tokens...")
    _append_random_tokens(1, 8)
    print(f" Blocks allocated: {paged.num_blocks[1]}")
    print(f" Page table: {paged.page_tables[1, :paged.num_blocks[1]]}")

    # Gather each sequence back into contiguous tensors and report usage.
    k0, v0 = paged.get_sequence_contiguous(0, num_tokens=12)
    k1, v1 = paged.get_sequence_contiguous(1, num_tokens=8)
    print(f"\n Seq 0 K shape: {k0.shape}")
    print(f" Seq 1 K shape: {k1.shape}")

    print(f"\n Memory used: {paged.memory_used_bytes:,} bytes")
    print(f" Utilization: {paged.memory_utilization():.1%}")


def demo_quantized_cache():
    """Demo 6: Int8-quantized KV cache and its round-trip error."""
    print("\n" + "=" * 70)
    print("DEMO 6: Quantized KV Cache (int8)")
    print("=" * 70)

    batch, heads, head_dim, max_seq = 2, 4, 16, 32
    cache = QuantizedKVCache(batch, heads, head_dim, max_seq, dtype=np.float32)

    np.random.seed(55)

    # Populate ten positions with small random K/V tensors (K drawn
    # before V each step, matching the update argument order).
    for _ in range(10):
        cache.update(
            np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 0.1,
            np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 0.1,
        )

    # Dequantize everything stored so far.
    k_deq, v_deq = cache.get()
    print(f"\nQuantized cache (10 tokens):")
    print(f" Dequantized K shape: {k_deq.shape}")
    print(f" Dequantized V shape: {v_deq.shape}")

    # Savings are relative: int8 halves fp16 storage, but the per-position
    # scale factors claw back part of that.
    print(f" Memory savings vs fp32: {cache.memory_savings_vs_fp32:.1%}")
    print(f" Memory savings vs fp16: {cache.memory_savings_vs_fp16:.1%} (per-pos scales overhead)")

    # Round-trip one more token at unit scale to expose quantization error
    # (larger values quantize with better relative fidelity in int8).
    k_orig = np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 1.0
    v_orig = np.random.randn(batch, heads, 1, head_dim).astype(np.float32) * 1.0
    cache.update(k_orig, v_orig)
    k_deq_single, _ = cache.get(start=10, end=11)

    # Both tensors are (batch, heads, 1, head_dim), so they compare directly.
    print(f" k_orig shape: {k_orig.shape}, k_deq shape: {k_deq_single.shape}")
    error = np.max(np.abs(k_orig - k_deq_single))
    rel_error = error / (np.max(np.abs(k_orig)) + 1e-8)
    print(f" Max absolute error (one token): {error:.6f}")
    print(f" Max relative error: {rel_error:.4f}")
    print(f" → Per-position quantization has high overhead; production uses")
    print(f" shared per-channel scales for ~50% memory savings with <1% error")


def demo_chunked_prefill():
    """Demo 7: Chunked prefill.

    Processes a 12-token prompt in chunks of 4, verifies the chunked result
    matches full causal attention, then reports the peak-memory savings for
    a realistic 4096-token sequence.

    Fix: the original re-imported ``scaled_dot_product_attention`` and
    ``build_causal_mask`` inside the function even though both are already
    imported at module level; the redundant local import is removed.
    """
    print("\n" + "=" * 70)
    print("DEMO 7: Chunked Prefill")
    print("=" * 70)

    chunker = ChunkedPrefill(chunk_size=4)

    batch, heads, seq, head_dim = 1, 4, 12, 16
    scale = 1.0 / np.sqrt(head_dim)

    np.random.seed(33)
    q = np.random.randn(batch, heads, seq, head_dim).astype(np.float32) * 0.01
    k = np.random.randn(batch, heads, seq, head_dim).astype(np.float32) * 0.01
    v = np.random.randn(batch, heads, seq, head_dim).astype(np.float32) * 0.01

    # Chunked attention
    output_chunked = chunker.compute_attention_chunked(q, k, v, scale)

    # Full attention (for comparison), using the module-level imports.
    causal = build_causal_mask(seq, dtype=np.float32)
    output_full = scaled_dot_product_attention(
        q, k, v, scale, mask=causal[None, None, :, :]
    )

    diff = np.max(np.abs(output_chunked - output_full))
    print(f"\nChunk size: {chunker.chunk_size}")
    print(f"Sequence length: {seq}")
    print(f"Chunks: {(seq + chunker.chunk_size - 1) // chunker.chunk_size}")
    print(f"Max difference from full attention: {diff:.2e}")
    assert diff < 1e-5, f"Chunked attention mismatch: {diff}"
    print(" ✓ Chunked attention matches full attention")

    # Memory comparison at production scale.
    mem = ChunkedPrefill.peak_memory_comparison(seq_len=4096, chunk_size=512)
    print(f"\nMemory comparison (seq=4096, chunk=512):")
    print(f" Full attention matrix: {mem['full_attention_mb']:.0f} MB")
    print(f" Chunked peak: {mem['chunked_peak_attention_mb']:.0f} MB")
    print(f" Savings: {mem['savings_ratio']:.1f}x")


def demo_optimization_comparison():
    """Demo 8: Tabulate memory cost of each cache optimization strategy."""
    print("\n" + "=" * 70)
    print("DEMO 8: Optimization Strategy Comparison")
    print("=" * 70)

    results = compare_strategies(
        batch_size=4, num_heads=32, head_dim=128,
        max_seq_len=4096, num_layers=32
    )

    print(f"\nConfiguration: batch=4, heads=32, head_dim=128, "
          f"seq=4096, layers=32\n")

    header = f"{'Strategy':<25} {'Per Layer(MB)':>14} {'Total(GB)':>10} {'Notes':<25}"
    print(header)
    print("-" * len(header))

    for name, data in results.items():
        # Annotate with whichever comparison metric the strategy reports.
        if "savings_vs_fp16" in data:
            notes = f"{data['savings_vs_fp16']:.0%} savings"
        elif "overhead_vs_naive" in data:
            notes = f"{data['overhead_vs_naive']:.3f}x overhead"
        else:
            notes = ""
        print(f"{name:<25} {data['per_layer_mb']:>14.1f} {data['total_mb']/1024:>10.2f} "
              f"{notes:<25}")


def demo_memory_analysis():
    """Demo 9: How KV-cache memory grows with context length and model size."""
    print("\n" + "=" * 70)
    print("DEMO 9: Memory Growth Analysis")
    print("=" * 70)

    # Cross-model comparison table (fp16 weights and cache).
    comparisons = compare_model_sizes()

    print("\nModel Size Comparison (fp16):\n")
    header = f"{'Model':<20} {'Params(GB)':>10} {'KV@1K':>8} {'KV@8K':>8} {'KV@32K':>8} {'MaxCtx(H100)':>12}"
    print(header)
    print("-" * len(header))
    for name, data in comparisons.items():
        print(f"{name:<20} {data['params_gb']:>10.1f} {data['kv_1k_gb']:>8.2f} "
              f"{data['kv_8k_gb']:>8.2f} {data['kv_32k_gb']:>8.2f} "
              f"{data['max_context_H100']:>12,}")

    # Growth curve for a 7B-shaped model (32 layers, dim 4096, 32 heads).
    spec = ModelSpec(num_layers=32, dim=4096, num_heads=32, head_dim=128)
    model_mem = compute_model_memory(spec, np.float16)

    print(f"\n\n7B Model Memory Growth (batch=1, fp16):\n")
    print(f" Model params: {model_mem['total_params_gb']:.1f} GB")
    print()

    seq_lens = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
    print(f" {'Seq Len':>8} {'KV(GB)':>8} {'Total(GB)':>10} {'KV%':>6}")
    print(f" {'-'*40}")
    for sl in seq_lens:
        kv = compute_kv_cache_memory(1, sl, spec, np.float16)
        total = kv["total_gb"] + model_mem["total_params_gb"]
        pct = kv["total_gb"] / total * 100
        print(f" {sl:>8,} {kv['total_gb']:>8.2f} {total:>10.2f} {pct:>5.1f}%")

    # Largest context that fits within each GPU's memory budget (GB).
    print(f"\n\nMax Context by GPU (7B model, batch=1):\n")
    gpus = {"RTX 4090": 24, "A100-40GB": 40, "A100-80GB": 80, "H100-80GB": 80}
    for gpu, gpu_gb in gpus.items():
        ctx = find_max_context(spec, gpu_gb, batch_size=1)
        print(f" {gpu:<15}: {ctx:>8,} tokens")


def demo_gpu_tensor_cores():
    """Demo 10: Roofline-style tensor-core analysis for attention shapes."""
    print("\n" + "=" * 70)
    print("DEMO 10: GPU Tensor Core Analysis")
    print("=" * 70)

    scenarios = [
        {"batch": 1, "heads": 32, "seq": 1024, "label": "Short context"},
        {"batch": 1, "heads": 32, "seq": 8192, "label": "Long context"},
        {"batch": 4, "heads": 32, "seq": 4096, "label": "Batched"},
    ]

    for cfg in scenarios:
        tc = tensor_core_analysis(
            batch=cfg["batch"], heads=cfg["heads"], seq_len=cfg["seq"]
        )
        # Report FLOPs, traffic, arithmetic intensity, and which resource
        # (compute vs. memory bandwidth) bounds the workload.
        print(f"\n {cfg['label']} (batch={cfg['batch']}, seq={cfg['seq']}):")
        print(f" Total FLOPs: {tc['total_flops']}")
        print(f" Memory traffic: {tc['memory_traffic_mb']}")
        print(f" Arithmetic intensity: {tc['arithmetic_intensity']}")
        print(f" Compute bound: {tc['compute_bound_ms']}")
        print(f" Memory bound: {tc['memory_bound_ms']}")
        print(f" → {tc['bound']}")


def main():
    """Run every demo in sequence, continuing past individual failures."""
    print("\n" + "█" * 70)
    print(" KV-CACHE SYSTEM FOR AUTOREGRESSIVE TRANSFORMER INFERENCE")
    print(" Pure NumPy Implementation — No Frameworks")
    print("█" * 70)

    demos = [
        ("Basic KV Cache", demo_basic_kv_cache),
        ("Cached Attention", demo_cached_attention),
        ("Full Transformer", demo_full_transformer),
        ("Variable-Length Batching", demo_variable_length_batching),
        ("Paged Attention", demo_paged_attention),
        ("Quantized Cache", demo_quantized_cache),
        ("Chunked Prefill", demo_chunked_prefill),
        ("Optimization Comparison", demo_optimization_comparison),
        ("Memory Analysis", demo_memory_analysis),
        ("GPU Tensor Cores", demo_gpu_tensor_cores),
    ]

    # One demo failing should not stop the rest — report and move on.
    for title, demo_fn in demos:
        try:
            demo_fn()
        except Exception as exc:
            print(f"\n ✗ {title} failed: {exc}")
            import traceback
            traceback.print_exc()

    print("\n" + "█" * 70)
    print(" ALL DEMOS COMPLETE")
    print("█" * 70 + "\n")


# Script entry point: run all demos when executed directly.
if __name__ == "__main__":
    main()