"""
End-to-end tests and demonstrations for the KV-cache system.
Run with: python test_kv_cache.py
"""

import numpy as np

from kv_cache import (
    KVCache,
    multi_head_attention_with_cache,
    memory_growth_table,
    memory_analysis,
    IncrementalDecoder,
)
from optimizations import PagedKVCache, QuantizedKVCache

# ══════════════════════════════════════════════════════════════════════
# TEST 1: Basic KV-cache update & retrieval
# ══════════════════════════════════════════════════════════════════════
def test_basic_cache():
    print("=" * 70)
    print("TEST 1: Basic KV-cache update and retrieval")
    print("=" * 70)
    B, H, S_max, D = 2, 4, 16, 8
    cache = KVCache(B, S_max, H, D)
    print(f"Initial: {cache}")

    # Prefill: write 5 prompt tokens for every batch element in one batched
    # update, then decode 3 more tokens one at a time (the two phases a real
    # serving loop goes through).
    new_k = np.random.randn(B, H, 5, D).astype(np.float32)
    new_v = np.random.randn(B, H, 5, D).astype(np.float32)
    cache.update(new_k, new_v)
    print(f"After prefill (5 tokens): seq_lens={cache.seq_lens}")

    # Decode: write 1 token at a time
    for step in range(3):
        one_k = np.random.randn(B, H, 1, D).astype(np.float32)
        one_v = np.random.randn(B, H, 1, D).astype(np.float32)
        cache.update(one_k, one_v)
        print(f" Decode step {step}: seq_lens={cache.seq_lens}")

    # Verify retrieval: 5 prefill + 3 decode = 8 cached positions per element
    k0, v0 = cache.get_kv(0)
    print(f"\nBatch 0: retrieved K shape={k0.shape}, expected (4, 8, 8)")
    assert k0.shape == (H, 8, D), f"Wrong shape: {k0.shape}"
    k1, v1 = cache.get_kv(1)
    print(f"Batch 1: retrieved K shape={k1.shape}, expected (4, 8, 8)")
    assert k1.shape == (H, 8, D), f"Wrong shape: {k1.shape}"

    # Verify the most recently written values landed at position 7
    np.testing.assert_allclose(cache.k_cache[0, :, 7, :], one_k[0, :, 0, :])
    np.testing.assert_allclose(cache.v_cache[1, :, 7, :], one_v[1, :, 0, :])
    print("✓ All assertions passed.\n")

# ══════════════════════════════════════════════════════════════════════
# TEST 2: Attention with cache vs without (correctness check)
# ══════════════════════════════════════════════════════════════════════
def test_attention_correctness():
    print("=" * 70)
    print("TEST 2: Cached attention matches non-cached attention")
    print("=" * 70)
    np.random.seed(42)
    B, H, D = 1, 2, 4
    d_model = H * D
    S = 6  # sequence length
    T = 1  # decode step

    # Random projection matrices
    w_q = np.random.randn(d_model, d_model).astype(np.float32)
    w_k = np.random.randn(d_model, d_model).astype(np.float32)
    w_v = np.random.randn(d_model, d_model).astype(np.float32)
    w_o = np.random.randn(d_model, d_model).astype(np.float32)

    # Simulate embeddings for S+T tokens
    all_tokens = np.random.randn(B, S + T, d_model).astype(np.float32)

    # --- METHOD A: Non-cached (full recomputation) ---
    from kv_cache import _scaled_dot_product_attention

    q_full = (all_tokens @ w_q).reshape(B, S + T, H, D)
    k_full = (all_tokens @ w_k).reshape(B, S + T, H, D)
    v_full = (all_tokens @ w_v).reshape(B, S + T, H, D)

    # Compute attention for the LAST position only (autoregressive)
    out_heads_a = np.empty((T, H, D), dtype=np.float32)
    for h in range(H):
        q_h = q_full[0, S:, h, :]  # (1, D)
        k_h = k_full[0, :, h, :]   # (S+T, D)
        v_h = v_full[0, :, h, :]   # (S+T, D)
        out_heads_a[:, h, :] = _scaled_dot_product_attention(q_h, k_h, v_h)
    result_a = out_heads_a.reshape(T, d_model) @ w_o

    # --- METHOD B: Cached (prefill S tokens, then decode 1) ---
    cache = KVCache(B, S + T, H, D)

    # Prefill: write K, V for first S tokens
    k_prefill = k_full[:, :S, :, :].transpose(0, 2, 1, 3)  # (B, H, S, D)
    v_prefill = v_full[:, :S, :, :].transpose(0, 2, 1, 3)
    cache.update(k_prefill, v_prefill)

    # Decode: write K, V for the new token
    k_decode = k_full[:, S:, :, :].transpose(0, 2, 1, 3)  # (B, H, 1, D)
    v_decode = v_full[:, S:, :, :].transpose(0, 2, 1, 3)
    cache.update(k_decode, v_decode)

    # Now compute attention for the new token using the cache
    q_new = all_tokens[:, S:, :]  # (B, 1, d_model)
    result_b = multi_head_attention_with_cache(q_new, cache, w_q, w_k, w_v, w_o)

    np.testing.assert_allclose(result_a, result_b[0], atol=1e-5)
    print(f"Non-cached output: {result_a.flatten()[:4]}")
    print(f"Cached output: {result_b.flatten()[:4]}")
    print("✓ Cached and non-cached outputs match.\n")

# ══════════════════════════════════════════════════════════════════════
# TEST 3: Multi-batch with variable sequence lengths
# ══════════════════════════════════════════════════════════════════════
def test_variable_seq_lens():
    print("=" * 70)
    print("TEST 3: Multi-batch with variable sequence lengths")
    print("=" * 70)
    np.random.seed(123)
    B, H, D = 3, 4, 8
    S_max = 32
    cache = KVCache(B, S_max, H, D)

    # --- Prefill each batch element with a different prompt length ---
    # We bypass the batched update() and write each element directly
    # into the underlying cache arrays. This simulates the real
    # scenario where different requests arrive with different prompt
    # lengths and are packed into the same batch.
    prompt_lens = [5, 12, 3]
    original_k = {}
    original_v = {}
    for b in range(B):
        L = prompt_lens[b]
        k = np.random.randn(H, L, D).astype(np.float32)
        v = np.random.randn(H, L, D).astype(np.float32)
        cache.k_cache[b, :, :L, :] = k
        cache.v_cache[b, :, :L, :] = v
        cache.seq_lens[b] = L
        original_k[b] = k
        original_v[b] = v
    print(f"After prefill: seq_lens={cache.seq_lens}")
    # list() keeps the comparison element-wise safe whether seq_lens is a
    # Python list or a NumPy array.
    assert list(cache.seq_lens) == prompt_lens

    # --- Verify prefill retrieval ---
    for b in range(B):
        k_ret, v_ret = cache.get_kv(b)
        np.testing.assert_allclose(k_ret, original_k[b])
        np.testing.assert_allclose(v_ret, original_v[b])
        print(f" Batch {b}: ✓ prefill data verified (len={prompt_lens[b]})")

    # --- Decode: all batch elements advance together (normal decode) ---
    for step in range(4):
        one_k = np.random.randn(B, H, 1, D).astype(np.float32)
        one_v = np.random.randn(B, H, 1, D).astype(np.float32)
        cache.update(one_k, one_v)
        print(f" Decode step {step}: seq_lens={cache.seq_lens}")

    # Verify each batch element has the right length
    expected = [l + 4 for l in prompt_lens]
    for b in range(B):
        k_b, v_b = cache.get_kv(b)
        print(f" Batch {b}: expected len={expected[b]}, got K shape seq dim={k_b.shape[1]}")
        assert k_b.shape[1] == expected[b]
    print("✓ Variable sequence lengths handled correctly.\n")

# ══════════════════════════════════════════════════════════════════════
# TEST 4: Incremental decoder end-to-end
# ══════════════════════════════════════════════════════════════════════
def test_incremental_decoder():
    print("=" * 70)
    print("TEST 4: Incremental decoder (prefill + autoregressive decode)")
    print("=" * 70)
    np.random.seed(7)
    d_model = 32
    num_heads = 4
    num_layers = 2
    max_seq_len = 64
    vocab_size = 100
    B = 1
    decoder = IncrementalDecoder(d_model, num_heads, num_layers, max_seq_len, vocab_size)
    decoder.max_seq_len = max_seq_len
    decoder._init_caches(B)

    # Prefill with a prompt of 8 tokens
    prompt = np.array([[1, 5, 10, 15, 20, 25, 30, 35]], dtype=np.int64)  # (1, 8)
    logits = decoder.forward_step(prompt, decoder.caches, is_prefill=True)
    print("After prefill (8 tokens):")
    print(f" Logits shape: {logits.shape}")
    print(f" Cache seq_lens: {[c.seq_lens for c in decoder.caches]}")

    # Autoregressive decode: generate 5 more tokens
    generated = []
    next_token = logits.argmax(axis=-1)  # (1,)
    generated.append(next_token[0])
    for step in range(5):
        logits = decoder.forward_step(next_token, decoder.caches)
        next_token = logits.argmax(axis=-1)
        generated.append(next_token[0])
        print(
            f" Decode step {step}: seq_lens={decoder.caches[0].seq_lens}, "
            f"token={next_token[0]}"
        )
    assert decoder.caches[0].seq_lens[0] == 8 + 5, "Should have 13 tokens cached"
    print(f"Generated tokens: {generated}")
    print("✓ Incremental decoder works.\n")

# ══════════════════════════════════════════════════════════════════════
# TEST 5: Paged KV-cache
# ══════════════════════════════════════════════════════════════════════
def test_paged_cache():
    print("=" * 70)
    print("TEST 5: Paged KV-cache (block-based allocation)")
    print("=" * 70)
    np.random.seed(99)
    num_blocks = 20
    block_size = 4
    H, D = 4, 8
    max_seqs = 4
    paged = PagedKVCache(num_blocks, block_size, H, D, max_seqs)
    print(f"Initial: {paged}")

    # Start 3 sequences with different lengths
    seq_ids = []
    for _ in range(3):
        sid = paged.add_sequence()
        seq_ids.append(sid)

    # Write different amounts to each
    lengths = [6, 11, 3]
    original_data_k = {}
    original_data_v = {}
    for i, sid in enumerate(seq_ids):
        L = lengths[i]
        k = np.random.randn(H, L, D).astype(np.float32)
        v = np.random.randn(H, L, D).astype(np.float32)
        paged.update(sid, k, v)
        original_data_k[sid] = k
        original_data_v[sid] = v
        print(f" Seq {sid}: wrote {L} tokens, seq_len={paged.seq_lens[sid]}")
    print(f"After writes: {paged}")

    # Verify retrieval
    for i, sid in enumerate(seq_ids):
        k_ret, v_ret = paged.get_kv(sid)
        L = lengths[i]
        assert k_ret.shape == (H, L, D), f"Seq {sid}: expected ({H}, {L}, {D}), got {k_ret.shape}"
        np.testing.assert_allclose(k_ret, original_data_k[sid], atol=1e-6)
        np.testing.assert_allclose(v_ret, original_data_v[sid], atol=1e-6)
        print(f" Seq {sid}: ✓ retrieved data matches original")

    # Finish sequence 1 and verify blocks are freed
    paged.finish_sequence(seq_ids[1])
    print(f"After finishing seq {seq_ids[1]}: {paged}")

    # Allocate a new sequence — should reuse freed blocks
    new_sid = paged.add_sequence()
    k_new = np.random.randn(H, 8, D).astype(np.float32)
    v_new = np.random.randn(H, 8, D).astype(np.float32)
    paged.update(new_sid, k_new, v_new)
    print(f"New seq {new_sid} with 8 tokens: {paged}")

    # Verify new sequence data
    k_new_ret, v_new_ret = paged.get_kv(new_sid)
    np.testing.assert_allclose(k_new_ret, k_new, atol=1e-6)
    print("✓ Paged KV-cache works correctly.\n")

# ══════════════════════════════════════════════════════════════════════
# TEST 6: Quantized KV-cache
# ══════════════════════════════════════════════════════════════════════
def test_quantized_cache():
    print("=" * 70)
    print("TEST 6: Quantized KV-cache (INT8 and INT4)")
    print("=" * 70)
    np.random.seed(42)
    B, H, D, S_max = 1, 2, 8, 32

    for bits in [8, 4]:
        print(f"\n--- INT{bits} ---")
        qcache = QuantizedKVCache(B, S_max, H, D, bits=bits)
        print(f" {qcache}")

        # Write some tokens
        T = 10
        k_orig = np.random.randn(B, H, T, D).astype(np.float32) * 2
        v_orig = np.random.randn(B, H, T, D).astype(np.float32) * 2
        qcache.update(k_orig, v_orig)

        # Retrieve and measure error
        k_ret, v_ret = qcache.get_kv(0)
        assert k_ret.shape == (H, T, D)
        k_error = np.mean(np.abs(k_ret - k_orig[0]))
        v_error = np.mean(np.abs(v_ret - v_orig[0]))
        print(f" Mean absolute error (K): {k_error:.6f}")
        print(f" Mean absolute error (V): {v_error:.6f}")
        print(f" Memory savings vs FP32: {qcache.savings_vs_fp32():.3f}x")
        print(f" Actual memory: {qcache.memory_bytes() / 1e3:.1f} KB")

        # For INT8 the error should be small; for INT4, larger but bounded.
        # Scale ≈ (max - min) / 255 for INT8 (or / 15 for INT4), so the
        # per-element error is at most ≈ scale / 2.
        max_expected_error = {8: 0.1, 4: 0.5}
        assert k_error < max_expected_error[bits], f"INT{bits} quantization error too large: {k_error}"
    print("\n✓ Quantized cache works.\n")

# ══════════════════════════════════════════════════════════════════════
# TEST 7: Memory growth analysis
# ══════════════════════════════════════════════════════════════════════
def test_memory_analysis():
    print("=" * 70)
    print("TEST 7: Memory growth analysis")
    print("=" * 70)

    # 7B-class model (Llama-7B-like): 32 layers, 32 heads, head_dim=128
    print("\nKV-Cache Memory vs Sequence Length (7B-class model)")
    print("Model: 32 layers, 32 heads, head_dim=128, batch=1, FP32")
    print(memory_growth_table())

    # Llama-2 70B class
    print("\nKV-Cache Memory vs Sequence Length (Llama-2 70B class)")
    print("Model: 80 layers, 64 heads, head_dim=128, batch=1, FP32")
    print(memory_growth_table(num_layers=80, num_heads=64, head_dim=128))

    # Batch scaling
    print("\nMemory scaling with batch size (seq_len=4096):")
    print(f"{'Batch':>8} | {'Total (GB)':>12}")
    print("-" * 28)
    for bs in [1, 2, 4, 8, 16, 32, 64]:
        info = memory_analysis(32, 32, 128, bs, 4096)
        print(f"{bs:>8} | {info['total_GB']:>12.3f}")
    print()

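# The sizing the TEST 7 tables should reflect (assuming memory_analysis counts
# K and V in FP32):
#   bytes ≈ 2 (K and V) * num_layers * num_heads * head_dim * seq_len * batch * 4
# e.g. 32 layers, 32 heads, head_dim=128, batch=1, seq_len=4096:
#   2 * 32 * 32 * 128 * 4096 * 1 * 4 bytes = 4,294,967,296 bytes = 4 GiB
def _sketch_kv_bytes(num_layers, num_heads, head_dim, batch, seq_len, bytes_per_el=4):
    """Straightforward KV-cache sizing formula (FP32 by default)."""
    return 2 * num_layers * num_heads * head_dim * seq_len * batch * bytes_per_el
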
# ══════════════════════════════════════════════════════════════════════
# TEST 8: FLOPs comparison — cached vs uncached
# ══════════════════════════════════════════════════════════════════════
def test_flops_analysis():
    print("=" * 70)
    print("TEST 8: FLOPs saved by KV-caching")
    print("=" * 70)
    d_model = 4096
    H = 32
    D = d_model // H
    prompt_len = 1024
    decode_steps = 100

    # Without a cache, every decode step recomputes attention over ALL positions.
    # Per layer, counting only the attention terms (the Q/K/V/O projections, at
    # 2 * S * d_model^2 FLOPs, are also re-paid for old tokens but are left out
    # of both estimates):
    #   2 * S^2 * d_model   (attention scores, Q @ K^T)  -- O(S^2)
    #   2 * S^2 * d_model   (weighted sum, weights @ V)
    #   ≈ 4 * S^2 * d_model per step
    # With a cache, each decode step attends from only the 1 new token:
    #   2 * S * d_model   (scores: 1 query against S cached keys)
    #   2 * S * d_model   (weighted sum over S cached values)
    #   ≈ 4 * S * d_model per step

    # Uncached: upper bound that charges every step the final context length
    flops_no_cache = 4 * decode_steps * (prompt_len + decode_steps) ** 2 * d_model
    flops_cached = (
        # Prefill once: O(S^2 * d_model)
        4 * prompt_len**2 * d_model
        # Decode: O(S * d_model) per step
        + sum(4 * (prompt_len + t) * d_model for t in range(decode_steps))
    )
    print(f"Model d_model={d_model}, H={H}, prompt={prompt_len}, decode={decode_steps}")
    print(f" Without cache: {flops_no_cache:.3e} FLOPs")
    print(f" With cache: {flops_cached:.3e} FLOPs")
    print(f" Speedup: {flops_no_cache / flops_cached:.1f}x")
    print()

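# TEST 8 worked through for these settings (per layer, dropping the shared
# factor 4 * d_model):
#   uncached ≈ 100 * 1124^2              ≈ 1.26e8
#   cached   ≈ 1024^2 + sum(1024..1123)  ≈ 1.05e6 + 1.07e5 ≈ 1.16e6
#   ratio    ≈ 109x, dominated by not re-running the prompt on every step.
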
# ══════════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    test_basic_cache()
    test_attention_correctness()
    test_variable_seq_lens()
    test_incremental_decoder()
    test_paged_cache()
    test_quantized_cache()
    test_memory_analysis()
    test_flops_analysis()
    print("=" * 70)
    print("ALL TESTS PASSED ✓")
    print("=" * 70)