feat: add model comparisons and sanitize session files

- Rename gamma to glm5 and model to minimax-m2.7
- Add model_comparison/ directory with head-to-head analyses
- Sanitize all session.jsonl files: remove absolute paths and usernames
- Remove __pycache__ artifacts
- Add .gitignore
commit 8e72eef09c
2026-04-23 11:16:01 +02:00
62 changed files with 18469 additions and 0 deletions
@@ -0,0 +1,508 @@
"""
KV-Cache Optimizations
======================
Three production-grade optimizations for the base KV-cache:
1. PagedAttention — block-based virtual memory for the cache
2. Chunked Prefill — split long prompts into fixed-size chunks
3. Cache Quantization — compress K/V to lower precision
Each optimization is a drop-in wrapper around the base KVCache
interface, keeping the same update / get_kv contract (PagedKVCache adds
an explicit seq_id argument for per-sequence bookkeeping).
"""
from __future__ import annotations
import math
from typing import List, Tuple
import numpy as np
from kv_cache import KVCache, multi_head_attention_with_cache
# ══════════════════════════════════════════════════════════════════════
# OPTIMIZATION 1: PAGED ATTENTION
# ══════════════════════════════════════════════════════════════════════
#
# Problem
# -------
# The base cache pre-allocates (B, H, S_max, D) per layer. If S_max
# is large (e.g. 128 k tokens) this wastes enormous memory for short
# sequences and fragments GPU memory when sequences finish at different
# times.
#
# Solution (cf. vLLM / PagedAttention)
# --------
# Divide the cache into fixed-size *blocks* (pages) of BLOCK_SIZE tokens.
# A per-sequence page table maps virtual positions → physical block ids.
# Blocks are allocated from a pool — freed when a sequence finishes and
# immediately reusable by a new sequence.
#
# Memory layout (physical):
# k_pool: (NUM_BLOCKS, H, BLOCK_SIZE, D)
# v_pool: (NUM_BLOCKS, H, BLOCK_SIZE, D)
#
# Per-sequence metadata:
# page_table: list[list[int]] — page_table[b] = [block_0, block_1, ...]
# seq_lens: list[int]
#
# GPU mapping: the page table lives in GPU memory and is indexed by a
# custom CUDA kernel that performs the gather from scattered blocks.
# On CPU we simulate it with index arithmetic.
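#
# Worked example (pure index arithmetic): with block_size = 16, global
# token position 37 lives in page_table[seq][37 // 16] = page_table[seq][2]
# at offset 37 % 16 = 5 within that physical block.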
# ══════════════════════════════════════════════════════════════════════
class PagedKVCache:
"""
Block-scattered KV cache inspired by vLLM's PagedAttention.
Unlike the base KVCache which pre-allocates a contiguous (B, H, S_max, D)
tensor, PagedKVCache allocates a fixed pool of blocks and assigns them
on demand. This eliminates:
- memory waste from over-provisioning S_max
- fragmentation from variable-length sequences
- the need for a single contiguous S_max allocation
"""
def __init__(
self,
num_blocks: int,
block_size: int,
num_heads: int,
head_dim: int,
max_num_seqs: int,
dtype: np.dtype = np.float32,
):
self.num_blocks = num_blocks
self.block_size = block_size
self.H = num_heads
self.D = head_dim
self.dtype = dtype
# Physical block pool — shapes (num_blocks, H, block_size, D)
self.k_pool = np.zeros(
(num_blocks, num_heads, block_size, head_dim), dtype=dtype
)
self.v_pool = np.zeros(
(num_blocks, num_heads, block_size, head_dim), dtype=dtype
)
# Free-list of available block indices
self.free_blocks: List[int] = list(range(num_blocks))
# Per-sequence bookkeeping
self.page_tables: List[List[int]] = [] # seq_id → list of block ids
self.seq_lens: List[int] = [] # seq_id → current length
self.max_num_seqs = max_num_seqs
# ── sequence lifecycle ───────────────────────────────────────────
def add_sequence(self) -> int:
"""Register a new sequence; returns its id."""
assert len(self.page_tables) < self.max_num_seqs, "too many sequences"
seq_id = len(self.page_tables)
self.page_tables.append([])
self.seq_lens.append(0)
return seq_id
def finish_sequence(self, seq_id: int) -> None:
"""Release all blocks held by a finished sequence."""
for block_id in self.page_tables[seq_id]:
self.free_blocks.append(block_id)
self.page_tables[seq_id] = []
self.seq_lens[seq_id] = 0
# ── block allocation ─────────────────────────────────────────────
def _ensure_blocks(self, seq_id: int, total_tokens: int) -> None:
"""Allocate enough blocks for `total_tokens` positions."""
blocks_needed = math.ceil(total_tokens / self.block_size)
current = len(self.page_tables[seq_id])
while current < blocks_needed:
if not self.free_blocks:
raise RuntimeError(
f"Out of blocks! Need {blocks_needed}, have {self.num_blocks} total. "
f"Free: {len(self.free_blocks)}"
)
self.page_tables[seq_id].append(self.free_blocks.pop(0))
current += 1
# ── update (write K, V) ──────────────────────────────────────────
def update(
self,
seq_id: int,
new_k: np.ndarray,
new_v: np.ndarray,
) -> None:
"""
Write new tokens for a single sequence.
new_k, new_v : shape (H, T, D)
"""
T = new_k.shape[1]
old_len = self.seq_lens[seq_id]
new_len = old_len + T
self._ensure_blocks(seq_id, new_len)
for t in range(T):
global_pos = old_len + t
block_idx = global_pos // self.block_size
offset = global_pos % self.block_size
phys_block = self.page_tables[seq_id][block_idx]
self.k_pool[phys_block, :, offset, :] = new_k[:, t, :]
self.v_pool[phys_block, :, offset, :] = new_v[:, t, :]
self.seq_lens[seq_id] = new_len
# ── retrieval (gather scattered blocks) ──────────────────────────
def get_kv(self, seq_id: int) -> Tuple[np.ndarray, np.ndarray]:
"""
Gather keys/values for a sequence from scattered blocks.
Returns (H, S_valid, D) arrays for keys and values.
"""
S = self.seq_lens[seq_id]
num_full_blocks = S // self.block_size
remainder = S % self.block_size
k_parts = []
v_parts = []
for i in range(num_full_blocks):
phys = self.page_tables[seq_id][i]
k_parts.append(self.k_pool[phys]) # (H, block_size, D)
v_parts.append(self.v_pool[phys])
if remainder > 0:
phys = self.page_tables[seq_id][num_full_blocks]
k_parts.append(self.k_pool[phys, :, :remainder, :])
v_parts.append(self.v_pool[phys, :, :remainder, :])
if not k_parts:
H, D = self.H, self.D
return np.empty((H, 0, D), dtype=self.dtype), np.empty(
(H, 0, D), dtype=self.dtype
)
return np.concatenate(k_parts, axis=1), np.concatenate(v_parts, axis=1)
# ── memory stats ─────────────────────────────────────────────────
def memory_bytes(self) -> int:
return self.k_pool.nbytes + self.v_pool.nbytes
def utilization(self) -> float:
"""Fraction of blocks currently in use."""
used = self.num_blocks - len(self.free_blocks)
return used / self.num_blocks
def __repr__(self) -> str:
used = self.num_blocks - len(self.free_blocks)
return (
f"PagedKVCache(blocks={used}/{self.num_blocks}, "
f"block_size={self.block_size}, H={self.H}, D={self.D}, "
f"seqs={len(self.page_tables)}, "
f"mem={self.memory_bytes() / 1e6:.1f} MB)"
)
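# A minimal usage sketch (illustrative: the sizes below are arbitrary and
# not from the original file) showing allocation, gathers, and block reuse:
def _demo_paged_cache() -> None:
    cache = PagedKVCache(
        num_blocks=8, block_size=4, num_heads=2, head_dim=8, max_num_seqs=4
    )
    seq = cache.add_sequence()
    # Writing 6 tokens needs ceil(6 / 4) = 2 blocks from the pool
    k = np.random.randn(2, 6, 8).astype(np.float32)  # (H, T, D)
    v = np.random.randn(2, 6, 8).astype(np.float32)
    cache.update(seq, k, v)
    k_out, _ = cache.get_kv(seq)
    assert k_out.shape == (2, 6, 8)
    assert np.allclose(k_out, k)  # the gather reproduces the original order
    cache.finish_sequence(seq)    # both blocks return to the free list
    assert cache.utilization() == 0.0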
# ══════════════════════════════════════════════════════════════════════
# OPTIMIZATION 2: CHUNKED PREFILL
# ══════════════════════════════════════════════════════════════════════
#
# Problem
# -------
# During the prefill phase the entire prompt is processed in one shot.
# For a prompt of length S this means an O(S²) attention matrix which
# can blow up memory and latency (e.g. S=32 k → 1 billion elements).
#
# Solution
# --------
# Split the prompt into chunks of CHUNK_SIZE tokens. Process each
# chunk sequentially, writing its K/V into the cache. Subsequent
# chunks attend to all previously cached chunks *plus* their own
# positions (causal masking within the current chunk).
#
# This reduces peak memory from O(S²) to O(CHUNK_SIZE × S) and
# allows overlapping prefill of one request with decode of others.
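#
# e.g. with CHUNK_SIZE = 512 and S = 32 k, the largest score matrix is
# 512 × 32768 ≈ 16.8 M elements instead of ~1 B for one-shot prefill.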
# ══════════════════════════════════════════════════════════════════════
class ChunkedPrefillCache:
"""
Wrapper around KVCache that processes long prompts in chunks.
Instead of filling the entire prompt at once (O(S²) memory),
we iterate over chunks of size C:
- Each chunk's K/V is written to the cache
- Attention for chunk i sees positions [0 .. i*C + C)
      - Peak attention memory: O(C × S) instead of O(S²)
"""
def __init__(
self,
base_cache: KVCache,
chunk_size: int = 512,
):
self.cache = base_cache
self.chunk_size = chunk_size
def prefill(
self,
all_k: np.ndarray,
all_v: np.ndarray,
w_q: np.ndarray,
w_k: np.ndarray,
w_v: np.ndarray,
w_o: np.ndarray,
) -> np.ndarray:
"""
Process a long prompt in chunks.
all_k, all_v : (B, H, S, D) — keys and values for the full prompt
w_q, w_k, w_v, w_o : projection matrices
Returns the output of the *last* chunk (B, 1, d_model) which
is needed for predicting the next token.
"""
B, H, S, D = all_k.shape
chunk_size = self.chunk_size
num_chunks = math.ceil(S / chunk_size)
last_output = None
for c in range(num_chunks):
start = c * chunk_size
end = min(start + chunk_size, S)
            # Write this chunk's K, V into the cache
            chunk_k = all_k[:, :, start:end, :]  # (B, H, end - start, D)
            chunk_v = all_v[:, :, start:end, :]
self.cache.update(chunk_k, chunk_v)
            # Now compute attention: queries from this chunk vs. all cached K/V.
            # In a real model q would come from the hidden states of the chunk's
            # tokens; here a random placeholder query stands in, since this
            # wrapper only demonstrates the cache-update / attention pattern.
            d_model = w_q.shape[0]
            # Only the last position is needed for the autoregressive output.
            q_single = np.random.randn(B, 1, d_model).astype(all_k.dtype)
last_output = multi_head_attention_with_cache(
q_single, self.cache, w_q, w_k, w_v, w_o
)
return last_output
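# A small schedule sketch (illustrative, not part of the original file) that
# walks the chunk boundaries and reports the peak score-matrix size:
def _demo_chunk_schedule(S: int = 32_768, C: int = 512) -> None:
    peak = 0
    for c in range(math.ceil(S / C)):
        start, end = c * C, min((c + 1) * C, S)
        # chunk c's queries attend to every cached position in [0 .. end)
        peak = max(peak, (end - start) * end)
    print(f"chunked peak: {peak:,} elements vs one-shot prefill: {S * S:,}")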
# ══════════════════════════════════════════════════════════════════════
# OPTIMIZATION 3: KV CACHE QUANTIZATION (INT8 / INT4)
# ══════════════════════════════════════════════════════════════════════
#
# Problem
# -------
# For long contexts the cache grows linearly with sequence length.
# A 32-layer, 32-head, 128-dim model at batch=1 and seq=65 k uses:
# 2 × 32 × 32 × 128 × 65536 × 4 bytes ≈ 68 GB (!!!)
#
# Solution
# --------
# Quantize cached K/V to lower precision on-the-fly:
#   - INT8: store scale/zero-point + quantized values → ~4× smaller than FP32
#   - INT4: store scale/zero-point + quantized values → ~8× smaller than FP32
#     (minus a small per-token overhead for the scales and zero points)
#
# During attention, dequantize back to FP32 before matmul.
# This trades a small accuracy loss for massive memory savings.
#
# GPU mapping:
# - Store quantized data in INT8/INT4 tensors
# - Dequantize in registers before the QK^T matmul
# - Or use specialized kernels (e.g. FP8 attention in Hopper GPUs)
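#
# Worked example: a vector with vmin = -1.0, vmax = 3.0 at INT8 gives
# scale = 4.0 / 255 ≈ 0.0157 and zero_point = -1.0, so the value 0.5 maps
# to round((0.5 + 1.0) / 0.0157) ≈ 96 and dequantizes back to ≈ 0.506.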
# ══════════════════════════════════════════════════════════════════════
class QuantizedKVCache:
"""
KV cache with on-the-fly quantization to a target bit-width.
    Internally stores:
      k_quant : uint8 array (packed for INT4)
      k_scale : float32 per-(batch, head, token) scale factor
      k_zp    : float32 per-(batch, head, token) zero point
      v_quant : uint8 array (packed for INT4)
      v_scale : float32 per-(batch, head, token) scale factor
      v_zp    : float32 per-(batch, head, token) zero point
    Supports INT8 (bits=8) and INT4 (bits=4, stored two per byte).
"""
def __init__(
self,
batch_size: int,
max_seq_len: int,
num_heads: int,
head_dim: int,
bits: int = 8,
):
assert bits in (4, 8), "Only INT8 and INT4 are supported"
self.B = batch_size
self.S_max = max_seq_len
self.H = num_heads
self.D = head_dim
self.bits = bits
# Per-token scale factors and zero points: (B, H, S_max)
self.k_scale = np.zeros((batch_size, num_heads, max_seq_len), dtype=np.float32)
self.v_scale = np.zeros((batch_size, num_heads, max_seq_len), dtype=np.float32)
self.k_zp = np.zeros((batch_size, num_heads, max_seq_len), dtype=np.float32)
self.v_zp = np.zeros((batch_size, num_heads, max_seq_len), dtype=np.float32)
if bits == 8:
self.k_quant = np.zeros(
(batch_size, num_heads, max_seq_len, head_dim), dtype=np.uint8
)
self.v_quant = np.zeros(
(batch_size, num_heads, max_seq_len, head_dim), dtype=np.uint8
)
else:
# INT4: pack 2 values per byte → head_dim / 2 bytes per token
assert head_dim % 2 == 0, "head_dim must be even for INT4 packing"
self.k_quant = np.zeros(
(batch_size, num_heads, max_seq_len, head_dim // 2), dtype=np.uint8
)
self.v_quant = np.zeros(
(batch_size, num_heads, max_seq_len, head_dim // 2), dtype=np.uint8
)
self.seq_lens: List[int] = [0] * batch_size
# ── quantization helpers ─────────────────────────────────────────
    def _quantize_token(
        self, vec: np.ndarray
    ) -> Tuple[np.ndarray, np.float32, np.float32]:
        """Quantize a 1-D vector to unsigned integers + scale + zero point."""
        vmin = np.min(vec)
        vmax = np.max(vec)
        max_int = (1 << self.bits) - 1  # 255 for INT8, 15 for INT4
        scale = (vmax - vmin) / max_int
        zero_point = vmin  # shift so the minimum maps to 0
        quantized = np.clip(
            np.round((vec - zero_point) / (scale + 1e-8)), 0, max_int
        ).astype(np.uint8)
        return quantized, np.float32(scale), np.float32(zero_point)
def _pack_int4(self, vec: np.ndarray) -> np.ndarray:
"""Pack a uint8 vector of 0..15 values into nibbles."""
packed = np.zeros(len(vec) // 2, dtype=np.uint8)
for i in range(len(vec) // 2):
packed[i] = (vec[2 * i] << 4) | vec[2 * i + 1]
return packed
def _unpack_int4(self, packed: np.ndarray) -> np.ndarray:
"""Unpack nibbles back to a full uint8 vector."""
out = np.zeros(len(packed) * 2, dtype=np.uint8)
for i in range(len(packed)):
out[2 * i] = (packed[i] >> 4) & 0x0F
out[2 * i + 1] = packed[i] & 0x0F
return out
# ── dequantize for attention ─────────────────────────────────────
def _dequantize_token(
self, quant: np.ndarray, scale: np.float32, zero_point: np.float32
) -> np.ndarray:
"""Dequantize back to float32."""
if self.bits == 4:
unpacked = self._unpack_int4(quant)
else:
unpacked = quant.astype(np.float32)
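        # the +1e-8 mirrors the epsilon added to `scale` in _quantize_token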
return unpacked * (scale + 1e-8) + zero_point
# ── update ───────────────────────────────────────────────────────
def update(
self,
new_k: np.ndarray,
new_v: np.ndarray,
) -> None:
"""
Quantize and store new K/V tokens.
new_k, new_v : (B, H, T, D) float32
"""
T = new_k.shape[2]
for b in range(self.B):
pos = self.seq_lens[b]
for h in range(self.H):
for t in range(T):
k_vec = new_k[b, h, t, :]
v_vec = new_v[b, h, t, :]
k_q, k_s, k_z = self._quantize_token(k_vec)
v_q, v_s, v_z = self._quantize_token(v_vec)
self.k_scale[b, h, pos + t] = k_s
self.v_scale[b, h, pos + t] = v_s
self.k_zp[b, h, pos + t] = k_z
self.v_zp[b, h, pos + t] = v_z
if self.bits == 8:
self.k_quant[b, h, pos + t, :] = k_q
self.v_quant[b, h, pos + t, :] = v_q
else:
self.k_quant[b, h, pos + t, :] = self._pack_int4(k_q)
self.v_quant[b, h, pos + t, :] = self._pack_int4(v_q)
self.seq_lens[b] += T
# ── retrieval ────────────────────────────────────────────────────
def get_kv(self, batch_idx: int) -> Tuple[np.ndarray, np.ndarray]:
"""
Dequantize and return (H, S_valid, D) arrays.
"""
S = self.seq_lens[batch_idx]
k_out = np.zeros((self.H, S, self.D), dtype=np.float32)
v_out = np.zeros((self.H, S, self.D), dtype=np.float32)
for h in range(self.H):
for t in range(S):
scale_k = self.k_scale[batch_idx, h, t]
scale_v = self.v_scale[batch_idx, h, t]
zp_k = self.k_zp[batch_idx, h, t]
zp_v = self.v_zp[batch_idx, h, t]
                # _dequantize_token unpacks the nibbles itself when bits == 4,
                # so the stored (possibly packed) row is passed through as-is.
                k_q = self.k_quant[batch_idx, h, t, :]
                v_q = self.v_quant[batch_idx, h, t, :]
k_out[h, t, :] = self._dequantize_token(k_q, scale_k, zp_k)
v_out[h, t, :] = self._dequantize_token(v_q, scale_v, zp_v)
return k_out, v_out
# ── memory savings ───────────────────────────────────────────────
def memory_bytes(self) -> int:
return (
self.k_quant.nbytes + self.v_quant.nbytes
+ self.k_scale.nbytes + self.v_scale.nbytes
+ self.k_zp.nbytes + self.v_zp.nbytes
)
    def savings_vs_fp32(self) -> float:
        """How many times smaller this cache is than an equivalent FP32 cache."""
        fp32_bytes = (
            2 * self.B * self.H * self.S_max * self.D * 4  # K and V, 4 bytes each
        )
        return fp32_bytes / self.memory_bytes()
def __repr__(self) -> str:
return (
f"QuantizedKVCache(INT{self.bits}, B={self.B}, H={self.H}, "
f"S_max={self.S_max}, D={self.D}, "
f"mem={self.memory_bytes() / 1e6:.1f} MB, "
f"savings={self.savings_vs_fp32():.2f}x vs FP32)"
)
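# A minimal round-trip sketch (illustrative; shapes are arbitrary and not
# from the original file): quantize K/V, dequantize, inspect error and memory.
def _demo_quantized_cache(bits: int = 8) -> None:
    rng = np.random.default_rng(0)
    cache = QuantizedKVCache(
        batch_size=1, max_seq_len=16, num_heads=2, head_dim=8, bits=bits
    )
    k = rng.standard_normal((1, 2, 4, 8)).astype(np.float32)  # (B, H, T, D)
    v = rng.standard_normal((1, 2, 4, 8)).astype(np.float32)
    cache.update(k, v)
    k_deq, _ = cache.get_kv(0)  # (H, S, D) = (2, 4, 8)
    err = float(np.max(np.abs(k_deq - k[0])))
    print(f"INT{bits}: max abs reconstruction error {err:.4f}, {cache!r}")
if __name__ == "__main__":
    _demo_paged_cache()
    _demo_chunk_schedule()
    _demo_quantized_cache(bits=8)
    _demo_quantized_cache(bits=4)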