8e72eef09c
- Rename gamma to glm5 and model to minimax-m2.7 - Add model_comparison/ directory with head-to-head analyses - Sanitize all session.jsonl files: remove absolute paths and usernames - Remove __pycache__ artifacts - Add .gitignore
430 lines
18 KiB
Python
430 lines
18 KiB
Python
"""
|
|
End-to-end tests and demonstrations for the KV-cache system.
|
|
|
|
Run with: python test_kv_cache.py
|
|
"""
|
|
|
|
import numpy as np
|
|
from kv_cache import (
|
|
KVCache,
|
|
multi_head_attention_with_cache,
|
|
memory_growth_table,
|
|
memory_analysis,
|
|
IncrementalDecoder,
|
|
)
|
|
from optimizations import PagedKVCache, QuantizedKVCache
|
|
|
|
|
|
# ══════════════════════════════════════════════════════════════════════
|
|
# TEST 1: Basic KV-cache update & retrieval
|
|
# ══════════════════════════════════════════════════════════════════════
|
|
|
|
def test_basic_cache():
    """KVCache round-trip: prefill, one-token decode steps, and retrieval."""
    print("=" * 70)
    print("TEST 1: Basic KV-cache update and retrieval")
    print("=" * 70)

    B, H, S_max, D = 2, 4, 16, 8
    cache = KVCache(B, S_max, H, D)
    print(f"Initial: {cache}")

    # Prefill: write 5 tokens for batch 0, 3 tokens for batch 1
    # (In practice, the full batch gets the same number, but we test
    # the update logic by writing per-batch via positions)
    prefill_shape = (B, H, 5, D)
    cache.update(
        np.random.randn(*prefill_shape).astype(np.float32),
        np.random.randn(*prefill_shape).astype(np.float32),
    )
    print(f"After prefill (5 tokens): seq_lens={cache.seq_lens}")

    # Decode: write 1 token at a time, remembering the most recent write so
    # we can verify it landed at the final cached position.
    last_k = last_v = None
    for step in range(3):
        last_k = np.random.randn(B, H, 1, D).astype(np.float32)
        last_v = np.random.randn(B, H, 1, D).astype(np.float32)
        cache.update(last_k, last_v)
        print(f" Decode step {step}: seq_lens={cache.seq_lens}")

    # Retrieval: 5 prefill + 3 decode tokens = 8 cached positions.
    k0, v0 = cache.get_kv(0)
    print(f"\nBatch 0: retrieved K shape={k0.shape}, expected (4, 8, 8)")
    assert k0.shape == (H, 8, D), f"Wrong shape: {k0.shape}"

    k1, v1 = cache.get_kv(1)
    print(f"Batch 1: retrieved K shape={k1.shape}, expected (4, 8, 8)")
    assert k1.shape == (H, 8, D), f"Wrong shape: {k1.shape}"

    # The last decode write should occupy slot 7 of the underlying arrays.
    np.testing.assert_allclose(cache.k_cache[0, :, 7, :], last_k[0, :, 0, :])
    np.testing.assert_allclose(cache.v_cache[1, :, 7, :], last_v[1, :, 0, :])
    print("✓ All assertions passed.\n")
|
|
|
|
|
|
# ══════════════════════════════════════════════════════════════════════
|
|
# TEST 2: Attention with cache vs without (correctness check)
|
|
# ══════════════════════════════════════════════════════════════════════
|
|
|
|
def test_attention_correctness():
    """Cached attention must match full-recompute attention exactly.

    Method A recomputes Q/K/V over all S+T tokens and attends from the last
    position; Method B prefills a KVCache with S tokens, appends the new
    token, and runs the cached attention path. Outputs must agree to 1e-5.
    """
    print("=" * 70)
    print("TEST 2: Cached attention matches non-cached attention")
    print("=" * 70)

    np.random.seed(42)
    B, H, D = 1, 2, 4
    d_model = H * D
    S = 6  # sequence length
    T = 1  # decode step

    # Random projection matrices
    w_q = np.random.randn(d_model, d_model).astype(np.float32)
    w_k = np.random.randn(d_model, d_model).astype(np.float32)
    w_v = np.random.randn(d_model, d_model).astype(np.float32)
    w_o = np.random.randn(d_model, d_model).astype(np.float32)

    # Simulate embeddings for S+T tokens
    all_tokens = np.random.randn(B, S + T, d_model).astype(np.float32)

    # --- METHOD A: Non-cached (full recomputation) ---
    # Only the attention primitive is needed here; the previously imported
    # _softmax was unused and has been dropped.
    from kv_cache import _scaled_dot_product_attention

    q_full = (all_tokens @ w_q).reshape(B, S + T, H, D)
    k_full = (all_tokens @ w_k).reshape(B, S + T, H, D)
    v_full = (all_tokens @ w_v).reshape(B, S + T, H, D)

    # Compute attention for the LAST position only (autoregressive)
    out_heads_a = np.empty((T, H, D), dtype=np.float32)
    for h in range(H):
        q_h = q_full[0, S:, h, :]  # (1, D)
        k_h = k_full[0, :, h, :]   # (S+T, D)
        v_h = v_full[0, :, h, :]   # (S+T, D)
        out_heads_a[:, h, :] = _scaled_dot_product_attention(q_h, k_h, v_h)
    result_a = out_heads_a.reshape(T, d_model) @ w_o

    # --- METHOD B: Cached (prefill S tokens, then decode 1) ---
    cache = KVCache(B, S + T, H, D)

    # Prefill: write K, V for first S tokens
    k_prefill = k_full[:, :S, :, :].transpose(0, 2, 1, 3)  # (B, H, S, D)
    v_prefill = v_full[:, :S, :, :].transpose(0, 2, 1, 3)
    cache.update(k_prefill, v_prefill)

    # Decode: write K, V for the new token
    k_decode = k_full[:, S:, :, :].transpose(0, 2, 1, 3)  # (B, H, 1, D)
    v_decode = v_full[:, S:, :, :].transpose(0, 2, 1, 3)
    cache.update(k_decode, v_decode)

    # Now compute attention for the new token using the cache
    q_new = all_tokens[:, S:, :]  # (B, 1, d_model)
    result_b = multi_head_attention_with_cache(q_new, cache, w_q, w_k, w_v, w_o)

    np.testing.assert_allclose(result_a, result_b[0], atol=1e-5)
    print(f"Non-cached output: {result_a.flatten()[:4]}")
    print(f"Cached output: {result_b.flatten()[:4]}")
    print("✓ Cached and non-cached outputs match.\n")
|
|
|
|
|
|
# ══════════════════════════════════════════════════════════════════════
|
|
# TEST 3: Multi-batch with variable sequence lengths
|
|
# ══════════════════════════════════════════════════════════════════════
|
|
|
|
def test_variable_seq_lens():
    """Each batch slot can hold a different prompt length; decode advances all."""
    print("=" * 70)
    print("TEST 3: Multi-batch with variable sequence lengths")
    print("=" * 70)

    np.random.seed(123)
    B, H, D = 3, 4, 8
    S_max = 32

    cache = KVCache(B, S_max, H, D)

    # --- Prefill each batch element with a different prompt length ---
    # We bypass the batched update() and write each element directly
    # into the underlying cache arrays. This simulates the real
    # scenario where different requests arrive with different prompt
    # lengths and are packed into the same batch.
    prompt_lens = [5, 12, 3]
    original_k, original_v = {}, {}

    for b, L in enumerate(prompt_lens):
        k = np.random.randn(H, L, D).astype(np.float32)
        v = np.random.randn(H, L, D).astype(np.float32)
        cache.k_cache[b, :, :L, :] = k
        cache.v_cache[b, :, :L, :] = v
        cache.seq_lens[b] = L
        original_k[b], original_v[b] = k, v

    print(f"After prefill: seq_lens={cache.seq_lens}")
    assert cache.seq_lens == prompt_lens

    # --- Verify prefill retrieval ---
    for b in range(B):
        k_ret, v_ret = cache.get_kv(b)
        np.testing.assert_allclose(k_ret, original_k[b])
        np.testing.assert_allclose(v_ret, original_v[b])
        print(f" Batch {b}: ✓ prefill data verified (len={prompt_lens[b]})")

    # --- Decode: all batch elements advance together (normal decode) ---
    for step in range(4):
        step_k = np.random.randn(B, H, 1, D).astype(np.float32)
        step_v = np.random.randn(B, H, 1, D).astype(np.float32)
        cache.update(step_k, step_v)
        print(f" Decode step {step}: seq_lens={cache.seq_lens}")

    # Verify each batch element has the right length
    expected = [n + 4 for n in prompt_lens]
    for b in range(B):
        k_b, v_b = cache.get_kv(b)
        print(f" Batch {b}: expected len={expected[b]}, got K shape seq dim={k_b.shape[1]}")
        assert k_b.shape[1] == expected[b]

    print("✓ Variable sequence lengths handled correctly.\n")
|
|
|
|
|
|
# ══════════════════════════════════════════════════════════════════════
|
|
# TEST 4: Incremental decoder end-to-end
|
|
# ══════════════════════════════════════════════════════════════════════
|
|
|
|
def test_incremental_decoder():
    """Prefill an 8-token prompt, then greedily decode 5 tokens end to end."""
    print("=" * 70)
    print("TEST 4: Incremental decoder (prefill + autoregressive decode)")
    print("=" * 70)

    np.random.seed(7)
    d_model, num_heads, num_layers = 32, 4, 2
    max_seq_len, vocab_size, B = 64, 100, 1

    decoder = IncrementalDecoder(d_model, num_heads, num_layers, max_seq_len, vocab_size)
    decoder.max_seq_len = max_seq_len
    decoder._init_caches(B)

    # Prefill with a prompt of 8 tokens
    prompt = np.array([[1, 5, 10, 15, 20, 25, 30, 35]], dtype=np.int64)  # (1, 8)
    logits = decoder.forward_step(prompt, decoder.caches, is_prefill=True)
    print("After prefill (8 tokens):")
    print(f" Logits shape: {logits.shape}")
    print(f" Cache seq_lens: {[c.seq_lens for c in decoder.caches]}")

    # Autoregressive decode: generate 5 more tokens greedily (argmax).
    next_token = logits.argmax(axis=-1)  # (1,)
    generated = [next_token[0]]

    for step in range(5):
        logits = decoder.forward_step(next_token, decoder.caches)
        next_token = logits.argmax(axis=-1)
        generated.append(next_token[0])
        print(
            f" Decode step {step}: seq_lens={decoder.caches[0].seq_lens}, "
            f"token={next_token[0]}"
        )

    assert decoder.caches[0].seq_lens[0] == 8 + 5, "Should have 13 tokens cached"
    print(f"Generated tokens: {generated}")
    print("✓ Incremental decoder works.\n")
|
|
|
|
|
|
# ══════════════════════════════════════════════════════════════════════
|
|
# TEST 5: Paged KV-cache
|
|
# ══════════════════════════════════════════════════════════════════════
|
|
|
|
def test_paged_cache():
    """Paged cache: write, read back, free a sequence, then reuse its blocks."""
    print("=" * 70)
    print("TEST 5: Paged KV-cache (block-based allocation)")
    print("=" * 70)

    np.random.seed(99)
    num_blocks, block_size = 20, 4
    H, D = 4, 8
    max_seqs = 4

    paged = PagedKVCache(num_blocks, block_size, H, D, max_seqs)
    print(f"Initial: {paged}")

    # Start 3 sequences with different lengths
    seq_ids = [paged.add_sequence() for _ in range(3)]

    # Write different amounts to each
    lengths = [6, 11, 3]
    original_data_k, original_data_v = {}, {}

    for sid, L in zip(seq_ids, lengths):
        k = np.random.randn(H, L, D).astype(np.float32)
        v = np.random.randn(H, L, D).astype(np.float32)
        paged.update(sid, k, v)
        original_data_k[sid] = k
        original_data_v[sid] = v
        print(f" Seq {sid}: wrote {L} tokens, seq_len={paged.seq_lens[sid]}")

    print(f"After writes: {paged}")

    # Verify retrieval
    for sid, L in zip(seq_ids, lengths):
        k_ret, v_ret = paged.get_kv(sid)
        assert k_ret.shape == (H, L, D), f"Seq {sid}: expected ({H}, {L}, {D}), got {k_ret.shape}"
        np.testing.assert_allclose(k_ret, original_data_k[sid], atol=1e-6)
        np.testing.assert_allclose(v_ret, original_data_v[sid], atol=1e-6)
        print(f" Seq {sid}: ✓ retrieved data matches original")

    # Finish sequence 1 and verify blocks are freed
    paged.finish_sequence(seq_ids[1])
    print(f"After finishing seq {seq_ids[1]}: {paged}")

    # Allocate a new sequence — should reuse freed blocks
    new_sid = paged.add_sequence()
    k_new = np.random.randn(H, 8, D).astype(np.float32)
    v_new = np.random.randn(H, 8, D).astype(np.float32)
    paged.update(new_sid, k_new, v_new)
    print(f"New seq {new_sid} with 8 tokens: {paged}")

    # Verify new sequence data
    k_new_ret, v_new_ret = paged.get_kv(new_sid)
    np.testing.assert_allclose(k_new_ret, k_new, atol=1e-6)
    print("✓ Paged KV-cache works correctly.\n")
|
|
|
|
|
|
# ══════════════════════════════════════════════════════════════════════
|
|
# TEST 6: Quantized KV-cache
|
|
# ══════════════════════════════════════════════════════════════════════
|
|
|
|
def test_quantized_cache():
    """INT8/INT4 caches: round-trip error must stay within bit-width bounds."""
    print("=" * 70)
    print("TEST 6: Quantized KV-cache (INT8 and INT4)")
    print("=" * 70)

    np.random.seed(42)
    B, H, D, S_max = 1, 2, 8, 32

    # For INT8, error should be small; for INT4, larger but bounded
    # Scale factor ≈ (max-min) / 255 for INT8, so error ≈ scale/2 per element
    max_expected_error = {8: 0.1, 4: 0.5}

    for bits in (8, 4):
        print(f"\n--- INT{bits} ---")
        qcache = QuantizedKVCache(B, S_max, H, D, bits=bits)
        print(f" {qcache}")

        # Write some tokens (scaled ×2 to exercise the quantization range)
        T = 10
        k_orig = np.random.randn(B, H, T, D).astype(np.float32) * 2
        v_orig = np.random.randn(B, H, T, D).astype(np.float32) * 2
        qcache.update(k_orig, v_orig)

        # Retrieve and measure reconstruction error
        k_ret, v_ret = qcache.get_kv(0)
        assert k_ret.shape == (H, T, D)

        k_error = np.mean(np.abs(k_ret - k_orig[0]))
        v_error = np.mean(np.abs(v_ret - v_orig[0]))
        print(f" Mean absolute error (K): {k_error:.6f}")
        print(f" Mean absolute error (V): {v_error:.6f}")
        print(f" Memory savings vs FP32: {qcache.savings_vs_fp32():.3f}x")
        print(f" Actual memory: {qcache.memory_bytes() / 1e3:.1f} KB")

        assert k_error < max_expected_error[bits], f"INT{bits} quantization error too large: {k_error}"

    print("\n✓ Quantized cache works.\n")
|
|
|
|
|
|
# ══════════════════════════════════════════════════════════════════════
|
|
# TEST 7: Memory growth analysis
|
|
# ══════════════════════════════════════════════════════════════════════
|
|
|
|
def test_memory_analysis():
    """Print KV-cache memory-growth tables and batch-size scaling numbers."""
    print("=" * 70)
    print("TEST 7: Memory growth analysis")
    print("=" * 70)

    # Default table: GPT-4 class model (32 layers, 32 heads, head_dim=128)
    print("\nKV-Cache Memory vs Sequence Length (GPT-4-class model)")
    print("Model: 32 layers, 32 heads, head_dim=128, batch=1, FP32")
    print(memory_growth_table())

    # Larger configuration: Llama-2 70B class
    print("\nKV-Cache Memory vs Sequence Length (Llama-2 70B class)")
    print("Model: 80 layers, 64 heads, head_dim=128, batch=1, FP32")
    print(memory_growth_table(num_layers=80, num_heads=64, head_dim=128))

    # Memory as a function of batch size at a fixed sequence length
    print("\nMemory scaling with batch size (seq_len=4096):")
    print(f"{'Batch':>8} | {'Total (GB)':>12}")
    print("-" * 28)
    for bs in (1, 2, 4, 8, 16, 32, 64):
        info = memory_analysis(32, 32, 128, bs, 4096)
        print(f"{bs:>8} | {info['total_GB']:>12.3f}")

    print()
|
|
|
|
|
|
# ══════════════════════════════════════════════════════════════════════
|
|
# TEST 8: FLOPs comparison — cached vs uncached
|
|
# ══════════════════════════════════════════════════════════════════════
|
|
|
|
def test_flops_analysis():
    """Compare estimated FLOPs for cached vs. uncached autoregressive decode.

    Pure arithmetic — no cache objects are exercised. Prints the estimated
    FLOP counts and the resulting speedup factor. The previously computed
    local ``D = d_model // H`` was never used and has been removed.
    """
    print("=" * 70)
    print("TEST 8: FLOPs saved by KV-caching")
    print("=" * 70)

    d_model = 4096
    H = 32  # head count, reported for context only
    prompt_len = 1024
    decode_steps = 100

    # Without cache: each decode step recomputes attention for ALL positions
    # FLOPs per attention step = 2 * S * d_model (Q projection)
    # + 2 * S * d_model * S (attention scores) -- O(S²)
    # + 2 * S * d_model * S (weighted sum)
    # ≈ 4 * S² * d_model per layer

    # With cache: each decode step only computes for 1 new token
    # FLOPs = 2 * d_model (Q projection for 1 token)
    # + 2 * S * d_model (Q * K^T for 1 query vs S keys)
    # + 2 * S * d_model (attention weights * V)
    # ≈ 4 * S * d_model per layer

    # Upper-bound estimate: every uncached step is charged at the final length.
    flops_no_cache = 4 * decode_steps * (prompt_len + decode_steps) ** 2 * d_model
    flops_cached = (
        # Prefill: O(S² * d_model)
        4 * prompt_len**2 * d_model
        # Decode: O(S * d_model) per step
        + sum(4 * (prompt_len + t) * d_model for t in range(decode_steps))
    )

    print(f"Model d_model={d_model}, H={H}, prompt={prompt_len}, decode={decode_steps}")
    print(f" Without cache: {flops_no_cache:.3e} FLOPs")
    print(f" With cache: {flops_cached:.3e} FLOPs")
    print(f" Speedup: {flops_no_cache / flops_cached:.1f}x")
    print()
|
|
|
|
|
|
# ══════════════════════════════════════════════════════════════════════
|
|
# MAIN
|
|
# ══════════════════════════════════════════════════════════════════════
|
|
|
|
if __name__ == "__main__":
|
|
test_basic_cache()
|
|
test_attention_correctness()
|
|
test_variable_seq_lens()
|
|
test_incremental_decoder()
|
|
test_paged_cache()
|
|
test_quantized_cache()
|
|
test_memory_analysis()
|
|
test_flops_analysis()
|
|
|
|
print("=" * 70)
|
|
print("ALL TESTS PASSED ✓")
|
|
print("=" * 70)
|