"""
KV-Cache for Autoregressive Transformer Inference
===================================================
Memory layout
-------------
Each layer stores two tensors:
keys: shape (B, H, S_max, D) — float32
values: shape (B, H, S_max, D) — float32
Where:
B = batch size
H = number of attention heads
S_max = pre-allocated max sequence length
D = head dimension (d_model / H)
The layout is BHSD (batch, head, seq, dim) which is contiguous along
the sequence axis — ideal for appending one token at a time and for
the inner attention matmul.
A companion `seq_lens: list[int]` (length B) tracks how many positions
are valid in each batch element. Positions beyond seq_lens[b] contain
garbage and must never participate in attention.
No external frameworks are used. All kernels are pure-NumPy for
correctness; the design maps 1:1 to CUDA kernels (see README).
"""
from __future__ import annotations
import math
from typing import List, Tuple, Optional
import numpy as np
# ──────────────────────────────────────────────────────────────────────
# 1. DATA STRUCTURE
# ──────────────────────────────────────────────────────────────────────
class KVCache:
"""
Pre-allocated KV cache for one transformer layer.
Physical storage
~~~~~~~~~~~~~~~~
Two numpy arrays allocated once at construction:
self.k_cache (B, H, S_max, D) float32
self.v_cache (B, H, S_max, D) float32
    An auxiliary Python list `self.seq_lens` (length B, int) records how many
    token positions are live for each sequence in the batch.
On GPU the same layout would be backed by a single cudaMalloc per
layer. The B-H-S-D ordering keeps the S-dimension stride == D,
making the per-token write a simple 3D slice copy.
"""
def __init__(
self,
batch_size: int,
max_seq_len: int,
num_heads: int,
head_dim: int,
dtype: np.dtype = np.float32,
):
self.B = batch_size
self.S_max = max_seq_len
self.H = num_heads
self.D = head_dim
self.dtype = dtype
shape = (batch_size, num_heads, max_seq_len, head_dim)
self.k_cache = np.zeros(shape, dtype=dtype)
self.v_cache = np.zeros(shape, dtype=dtype)
# seq_lens[b] = number of valid positions for batch element b
self.seq_lens: List[int] = [0] * batch_size
# ── helpers ──────────────────────────────────────────────────────
def _check_batch(self, token_k: np.ndarray) -> None:
"""Validate shape of incoming key/value tensors."""
# token_k expected: (B, H, T, D) where T is the number of new tokens
assert token_k.ndim == 4
assert token_k.shape[0] == self.B
assert token_k.shape[1] == self.H
assert token_k.shape[3] == self.D
# ── core update ──────────────────────────────────────────────────
def update(
self,
new_k: np.ndarray,
new_v: np.ndarray,
positions: Optional[List[int]] = None,
) -> None:
"""
Write new key/value vectors into the cache.
Parameters
----------
new_k, new_v : ndarray, shape (B, H, T, D)
Keys and values for T new tokens. In incremental decoding T=1.
positions : list[int] | None
Explicit write offsets per batch element. When *None* the
tokens are appended right after the current `seq_lens[b]`.
"""
        self._check_batch(new_k)
        self._check_batch(new_v)
        assert new_k.shape == new_v.shape, "new_k and new_v must have identical shapes"
        T = new_k.shape[2]  # number of new tokens (1 for decode, S for prefill)
for b in range(self.B):
pos = positions[b] if positions is not None else self.seq_lens[b]
assert pos + T <= self.S_max, (
f"batch {b}: pos {pos} + {T} tokens would exceed S_max={self.S_max}"
)
# ---- the actual write: a slice copy into pre-allocated memory ----
self.k_cache[b, :, pos : pos + T, :] = new_k[b]
self.v_cache[b, :, pos : pos + T, :] = new_v[b]
# advance sequence pointers
for b in range(self.B):
base = positions[b] if positions is not None else self.seq_lens[b]
self.seq_lens[b] = base + T
# ── retrieval (used by attention) ────────────────────────────────
def get_kv(
self, batch_idx: int
) -> Tuple[np.ndarray, np.ndarray]:
"""
Return (keys, values) for a single batch element, trimmed to the
valid prefix: shapes (H, S_valid, D) each.
"""
s = self.seq_lens[batch_idx]
return self.k_cache[batch_idx, :, :s, :], self.v_cache[batch_idx, :, :s, :]
def get_full_kv(self) -> Tuple[List[np.ndarray], List[np.ndarray]]:
"""Return per-batch (keys, values) lists, each entry (H, S_valid, D)."""
ks, vs = [], []
for b in range(self.B):
k, v = self.get_kv(b)
ks.append(k)
vs.append(v)
return ks, vs
# ── bookkeeping ──────────────────────────────────────────────────
def reset(self) -> None:
self.k_cache[:] = 0
self.v_cache[:] = 0
self.seq_lens = [0] * self.B
def memory_bytes(self) -> int:
return self.k_cache.nbytes + self.v_cache.nbytes
def __repr__(self) -> str:
return (
f"KVCache(B={self.B}, H={self.H}, S_max={self.S_max}, "
f"D={self.D}, seq_lens={self.seq_lens}, "
f"mem={self.memory_bytes() / 1e6:.1f} MB)"
)
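# Illustrative usage sketch. The `_demo_kv_cache` helper below is not part of
# the original API; it only demonstrates the prefill-then-decode write pattern
# and the trimmed read path of KVCache.
def _demo_kv_cache() -> None:
    cache = KVCache(batch_size=2, max_seq_len=16, num_heads=4, head_dim=8)
    # prefill: write 5 prompt tokens per sequence in a single update call
    prompt_k = np.random.randn(2, 4, 5, 8).astype(np.float32)
    prompt_v = np.random.randn(2, 4, 5, 8).astype(np.float32)
    cache.update(prompt_k, prompt_v)
    assert cache.seq_lens == [5, 5]
    # decode: append one token per sequence
    step_k = np.random.randn(2, 4, 1, 8).astype(np.float32)
    step_v = np.random.randn(2, 4, 1, 8).astype(np.float32)
    cache.update(step_k, step_v)
    assert cache.seq_lens == [6, 6]
    # read back the valid prefix for batch element 0: views of shape (H, S_valid, D)
    k, v = cache.get_kv(0)
    assert k.shape == v.shape == (4, 6, 8)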
# ──────────────────────────────────────────────────────────────────────
# 2. MULTI-HEAD ATTENTION USING THE CACHE
# ──────────────────────────────────────────────────────────────────────
def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
"""Numerically-stable softmax."""
x_max = np.max(x, axis=axis, keepdims=True)
e_x = np.exp(x - x_max)
return e_x / np.sum(e_x, axis=axis, keepdims=True)
def _scaled_dot_product_attention(
    q: np.ndarray,
    k: np.ndarray,
    v: np.ndarray,
    mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Single-head attention.
    q: (S_q, D)   k: (S_kv, D)   v: (S_kv, D)
    mask: optional boolean array of shape (S_q, S_kv); positions that are
          False are excluded from attention (set to -inf before the softmax).
    returns: (S_q, D)
    """
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = q @ k.T * scale                     # (S_q, S_kv)
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)
    weights = _softmax(scores, axis=-1)          # (S_q, S_kv)
    return weights @ v                           # (S_q, D)
def multi_head_attention_with_cache(
q_new: np.ndarray,
cache: KVCache,
w_q: np.ndarray,
w_k: np.ndarray,
w_v: np.ndarray,
w_o: np.ndarray,
) -> np.ndarray:
"""
Multi-head attention that *reads* from the KV cache but does NOT
update it — the caller decides when to write.
Parameters
----------
q_new : ndarray, shape (B, T, d_model)
Query representations for the T new tokens.
cache : KVCache
The key/value cache for this layer (already updated).
    w_q, w_k, w_v : ndarray, shape (d_model, d_model)
        Projection weight matrices. Only `w_q` is applied here; `w_k` and
        `w_v` are accepted for interface symmetry, since keys and values
        are read from the already-updated cache.
    w_o : ndarray, shape (d_model, d_model)
        Output projection matrix.
Returns
-------
output : ndarray, shape (B, T, d_model)
"""
B = cache.B
H = cache.H
D = cache.D
d_model = H * D
T = q_new.shape[1]
# project queries — same for every batch element
q_proj = (q_new @ w_q).reshape(B, T, H, D) # (B, T, H, D)
outputs = np.empty((B, T, d_model), dtype=q_new.dtype)
for b in range(B):
k_cached, v_cached = cache.get_kv(b) # (H, S_valid, D) each
S_valid = cache.seq_lens[b]
assert S_valid > 0, f"batch {b}: cache is empty"
out_heads = np.empty((T, H, D), dtype=q_new.dtype)
for h in range(H):
# q: (T, D), k: (S_valid, D), v: (S_valid, D)
q_h = q_proj[b, :, h, :] # (T, D)
k_h = k_cached[h] # (S_valid, D)
v_h = v_cached[h] # (S_valid, D)
out_heads[:, h, :] = _scaled_dot_product_attention(q_h, k_h, v_h)
# concatenate heads and apply output projection
out_heads = out_heads.reshape(T, d_model)
outputs[b] = out_heads @ w_o
return outputs
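# Illustrative sketch. The `_demo_attention_read` helper is not in the original
# file; it shows the intended calling convention: the caller projects and writes
# K/V into the cache first, then performs the attention read in one call.
def _demo_attention_read() -> None:
    B, T, H, D = 1, 3, 2, 4
    d_model = H * D
    cache = KVCache(batch_size=B, max_seq_len=8, num_heads=H, head_dim=D)
    w_q, w_k, w_v, w_o = [
        np.random.randn(d_model, d_model).astype(np.float32) for _ in range(4)
    ]
    x = np.random.randn(B, T, d_model).astype(np.float32)  # hidden states
    # caller writes K/V (in BHSD layout) before reading attention
    k = (x @ w_k).reshape(B, T, H, D).transpose(0, 2, 1, 3)  # (B, H, T, D)
    v = (x @ w_v).reshape(B, T, H, D).transpose(0, 2, 1, 3)
    cache.update(k, v)
    out = multi_head_attention_with_cache(x, cache, w_q, w_k, w_v, w_o)
    assert out.shape == (B, T, d_model)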
# ──────────────────────────────────────────────────────────────────────
# 3. MASKED BATCHED ATTENTION (variable seq lens in one batch)
# ──────────────────────────────────────────────────────────────────────
def multi_head_attention_batched(
q_new: np.ndarray,
cache: KVCache,
w_q: np.ndarray,
w_k: np.ndarray,
w_v: np.ndarray,
w_o: np.ndarray,
) -> np.ndarray:
"""
Batched MHA that correctly handles *variable sequence lengths*.
We build a causal mask of shape (B, T, S_max_padded) that zeros out
positions belonging to other sequences (in the packed sense) or
future tokens. Because we store per-batch caches separately this
simplifies to per-element attention (no cross-contamination), but
this function shows the masking technique that a GPU kernel would
use when sequences are packed into a shared tensor.
"""
B = cache.B
H = cache.H
D = cache.D
d_model = H * D
T = q_new.shape[1]
q_proj = (q_new @ w_q).reshape(B, T, H, D)
outputs = np.empty((B, T, d_model), dtype=q_new.dtype)
for b in range(B):
k_cached, v_cached = cache.get_kv(b)
S_valid = cache.seq_lens[b]
if S_valid == 0:
raise ValueError(f"batch {b}: cache is empty — call update first")
        # causal mask: query i sits at absolute position S_valid - T + i and may
        # only attend to cached positions <= that index (all True when T == 1)
        mask = np.arange(S_valid)[None, :] <= (S_valid - T + np.arange(T))[:, None]
        out_heads = np.empty((T, H, D), dtype=q_new.dtype)
        for h in range(H):
            q_h = q_proj[b, :, h, :]
            k_h = k_cached[h]
            v_h = v_cached[h]
            out_heads[:, h, :] = _scaled_dot_product_attention(q_h, k_h, v_h, mask)
outputs[b] = out_heads.reshape(T, d_model) @ w_o
return outputs
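# Illustrative mask check. The `_demo_causal_mask` helper is not in the original
# file; it spells out the mask built above for S_valid = 5 cached positions and
# T = 3 new queries, which then sit at absolute positions 2, 3 and 4.
def _demo_causal_mask() -> None:
    S_valid, T = 5, 3
    mask = np.arange(S_valid)[None, :] <= (S_valid - T + np.arange(T))[:, None]
    expected = np.array([
        [True, True, True, False, False],  # query at position 2
        [True, True, True, True, False],   # query at position 3
        [True, True, True, True, True],    # query at position 4
    ])
    assert np.array_equal(mask, expected)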
# ──────────────────────────────────────────────────────────────────────
# 4. INCREMENTAL DECODER (end-to-end usage example)
# ──────────────────────────────────────────────────────────────────────
class IncrementalDecoder:
"""
Minimal transformer decoder with L layers and KV caching.
Demonstrates the full lifecycle:
prefill → fill cache with the entire prompt
decode → generate one token at a time using the cache
"""
def __init__(
self,
d_model: int,
num_heads: int,
num_layers: int,
max_seq_len: int,
vocab_size: int,
dtype: np.dtype = np.float32,
):
        self.d_model = d_model
        self.H = num_heads
        self.D = d_model // num_heads
        self.L = num_layers
        self.max_seq_len = max_seq_len  # needed by _init_caches
        self.dtype = dtype
        # ---- weight matrices (simple scaled-Gaussian init) ----
        scale = 2.0 / d_model
self.w_embed = (np.random.randn(vocab_size, d_model) * scale).astype(dtype)
self.w_q = [
(np.random.randn(d_model, d_model) * scale).astype(dtype)
for _ in range(num_layers)
]
self.w_k = [
(np.random.randn(d_model, d_model) * scale).astype(dtype)
for _ in range(num_layers)
]
self.w_v = [
(np.random.randn(d_model, d_model) * scale).astype(dtype)
for _ in range(num_layers)
]
self.w_o = [
(np.random.randn(d_model, d_model) * scale).astype(dtype)
for _ in range(num_layers)
]
self.w_out = (np.random.randn(d_model, vocab_size) * scale).astype(dtype)
# ---- one KV cache per layer ----
self.caches: List[KVCache] = []
def _init_caches(self, batch_size: int) -> None:
self.caches = [
KVCache(batch_size, self.max_seq_len, self.H, self.D, self.dtype)
for _ in range(self.L)
]
# ---- layer norm (simplified) ----
@staticmethod
def _layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
return (x - mean) / np.sqrt(var + eps)
def forward_step(
self,
token_ids: np.ndarray,
caches: List[KVCache],
is_prefill: bool = False,
) -> np.ndarray:
"""
One forward step.
token_ids : int array, shape (B,) for decode or (B, T) for prefill
caches : list of KVCache, one per layer
Returns logits (B, vocab_size) — always only for the *last* token.
"""
if token_ids.ndim == 1:
token_ids = token_ids[:, None] # (B, 1)
B, T = token_ids.shape
hidden = self.w_embed[token_ids] # (B, T, d_model)
for layer_idx in range(self.L):
            # ---- project K, V (Q is projected inside the attention helper) ----
            k = (hidden @ self.w_k[layer_idx]).reshape(B, T, self.H, self.D)
            v = (hidden @ self.w_v[layer_idx]).reshape(B, T, self.H, self.D)
# ---- update cache (write K, V) ----
caches[layer_idx].update(
k.transpose(0, 2, 1, 3), # (B, H, T, D)
v.transpose(0, 2, 1, 3),
)
            # ---- attention read (causal mask keeps prefill with T > 1 autoregressive) ----
            attn_out = multi_head_attention_batched(
                hidden, caches[layer_idx],
                self.w_q[layer_idx],
                self.w_k[layer_idx],
                self.w_v[layer_idx],
                self.w_o[layer_idx],
            )
hidden = self._layer_norm(hidden + attn_out)
# project last position to vocab
logits = hidden[:, -1, :] @ self.w_out # (B, vocab_size)
return logits
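# End-to-end usage sketch. The `_demo_incremental_decoding` helper and its
# greedy-argmax sampling loop are illustrative additions, not part of the
# original file: prefill fills every layer's cache with the prompt, then each
# decode step feeds back a single token.
def _demo_incremental_decoding() -> None:
    dec = IncrementalDecoder(
        d_model=32, num_heads=4, num_layers=2, max_seq_len=64, vocab_size=100
    )
    dec._init_caches(batch_size=1)
    prompt = np.array([[3, 17, 42, 7]])            # (B=1, T=4) token ids
    logits = dec.forward_step(prompt, dec.caches, is_prefill=True)
    next_token = logits.argmax(axis=-1)            # (B,) greedy pick
    for _ in range(5):                             # generate 5 more tokens
        logits = dec.forward_step(next_token, dec.caches)
        next_token = logits.argmax(axis=-1)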
# ──────────────────────────────────────────────────────────────────────
# 5. MEMORY ANALYSIS
# ──────────────────────────────────────────────────────────────────────
def memory_analysis(
num_layers: int,
num_heads: int,
head_dim: int,
batch_size: int,
seq_len: int,
bytes_per_element: int = 4,
) -> dict:
"""
Analyse KV-cache memory consumption.
Returns a dict with per-layer and total memory in bytes / MB.
"""
per_token_per_layer = 2 * num_heads * head_dim * bytes_per_element # K + V
per_layer_bytes = per_token_per_layer * batch_size * seq_len
total_bytes = per_layer_bytes * num_layers
return {
"per_token_per_layer_B": per_token_per_layer,
"per_layer_bytes": per_layer_bytes,
"per_layer_MB": per_layer_bytes / 1e6,
"total_bytes": total_bytes,
"total_MB": total_bytes / 1e6,
"total_GB": total_bytes / 1e9,
"params": {
"num_layers": num_layers,
"num_heads": num_heads,
"head_dim": head_dim,
"batch_size": batch_size,
"seq_len": seq_len,
"bytes_per_element": bytes_per_element,
},
}
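# Worked example (comment only; the 7B-class configuration is an assumed,
# illustrative choice): 32 layers, 32 heads, head_dim 128, batch 1, a
# 4096-token context, float32 elements.
#   per token per layer : 2 * 32 * 128 * 4  = 32,768 B  (32 KiB)
#   per layer           : 32,768 * 4096     ≈ 134.2 MB
#   total (32 layers)   : 134.2 MB * 32     ≈ 4.29 GB
# i.e. memory_analysis(32, 32, 128, 1, 4096)["total_GB"] ≈ 4.295.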
def memory_growth_table(
num_layers: int = 32,
num_heads: int = 32,
head_dim: int = 128,
batch_size: int = 1,
seq_lens: Optional[List[int]] = None,
) -> str:
"""
Pretty-print a table of KV-cache memory vs sequence length.
"""
if seq_lens is None:
seq_lens = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]
lines = []
lines.append(f"{'Seq Len':>10} | {'Per Layer (MB)':>15} | {'Total (MB)':>12} | {'Total (GB)':>12}")
lines.append("-" * 60)
for s in seq_lens:
info = memory_analysis(num_layers, num_heads, head_dim, batch_size, s)
lines.append(
f"{s:>10} | {info['per_layer_MB']:>15.2f} | {info['total_MB']:>12.2f} | {info['total_GB']:>12.3f}"
)
return "\n".join(lines)