"""
KV-Cache Optimizations
Implements three major optimization strategies:
1. Paged Attention — non-contiguous memory allocation (inspired by vLLM)
2. Quantization — reduced precision for cached K/V
3. Chunked Prefill — processing long prompts in chunks to limit peak memory
"""
import numpy as np
from typing import Optional, Tuple, List, Dict
from dataclasses import dataclass, field
from kv_cache import CacheConfig
# =============================================================================
# 1. PAGED ATTENTION
# =============================================================================
@dataclass
class PageConfig:
"""Configuration for paged KV cache."""
block_size: int = 16 # tokens per block
num_pages: int = 256 # total pages per sequence
batch_size: int = 4
num_heads: int = 32
head_dim: int = 128
dtype: np.dtype = np.float16
class PagedKVCache:
"""
Paged KV Cache — inspired by vLLM's PagedAttention.
Instead of a contiguous [batch, heads, max_seq, head_dim] buffer,
memory is divided into fixed-size blocks (pages). Each sequence
maintains a page table mapping logical block indices to physical pages.
Benefits:
- Zero memory fragmentation: blocks are allocated on demand
- Supports speculative decoding and branching
- Enables sharing of common prefixes (prefix caching)
- No need to pre-allocate max_seq_len
Memory layout:
        physical_pages_k: (num_pages * batch_size, num_heads, block_size, head_dim) [for K]
physical_pages_v: same shape [for V]
page_tables: (batch_size, max_blocks) — maps logical block -> physical page index
"""
def __init__(self, config: PageConfig):
self.config = config
self.batch_size = config.batch_size
self.num_heads = config.num_heads
self.head_dim = config.head_dim
self.block_size = config.block_size
self.num_pages = config.num_pages
self.dtype = config.dtype
# Physical page pool (shared across all sequences)
# Each page holds: (num_heads, block_size, head_dim)
page_shape = (config.num_pages * config.batch_size,
config.num_heads, config.block_size, config.head_dim)
self.physical_pages_k = np.zeros(page_shape, dtype=self.dtype)
self.physical_pages_v = np.zeros(page_shape, dtype=self.dtype)
# Page table per sequence: logical_block_idx -> physical_page_idx
max_blocks = config.num_pages
self.page_tables = np.full(
(config.batch_size, max_blocks), -1, dtype=np.int32
)
# Number of allocated blocks per sequence
self.num_blocks = np.zeros(config.batch_size, dtype=np.int32)
# Free page pool (global, shared)
total_pages = config.num_pages * config.batch_size
self.free_list = np.arange(total_pages, dtype=np.int32)
self.free_ptr = 0 # index into free_list
def _alloc_page(self) -> int:
"""Allocate one physical page from the free pool."""
if self.free_ptr >= len(self.free_list):
raise MemoryError("Paged KV cache out of memory")
page_idx = self.free_list[self.free_ptr]
self.free_ptr += 1
return page_idx
def _free_page(self, page_idx: int):
"""Return a physical page to the free pool."""
self.free_list[self.free_ptr - 1] = page_idx
self.free_ptr -= 1
def reset(self):
"""Reset cache for a new generation."""
self.physical_pages_k[...] = 0
self.physical_pages_v[...] = 0
self.page_tables[...] = -1
self.num_blocks[...] = 0
self.free_ptr = 0
def append_token(self, batch_idx: int, keys: np.ndarray,
values: np.ndarray, logical_block: int,
offset_in_block: int):
"""
Append one token to a specific logical block.
Args:
batch_idx: batch item index
keys: (1, num_heads, 1, head_dim)
values: (1, num_heads, 1, head_dim)
logical_block: which logical block to write to
offset_in_block: position within the block (0..block_size-1)
"""
# Check if physical page is allocated for this logical block
phys_page = self.page_tables[batch_idx, logical_block]
if phys_page == -1:
# Allocate new physical page
phys_page = self._alloc_page()
self.page_tables[batch_idx, logical_block] = phys_page
if logical_block + 1 > self.num_blocks[batch_idx]:
self.num_blocks[batch_idx] = logical_block + 1
# Write to physical page
self.physical_pages_k[phys_page, :, offset_in_block, :] = keys[0, :, 0, :]
self.physical_pages_v[phys_page, :, offset_in_block, :] = values[0, :, 0, :]
def get_sequence(self, batch_idx: int,
start_block: int = 0,
end_block: int = None) -> Tuple[np.ndarray, np.ndarray]:
"""
Retrieve K and V for a sequence, gathering from physical pages.
Returns:
k: (num_heads, total_tokens, head_dim)
v: (num_heads, total_tokens, head_dim)
"""
if end_block is None:
end_block = self.num_blocks[batch_idx]
blocks = end_block - start_block
total_tokens = blocks * self.block_size
k_out = np.zeros(
(self.num_heads, total_tokens, self.head_dim), dtype=self.dtype
)
v_out = np.zeros(
(self.num_heads, total_tokens, self.head_dim), dtype=self.dtype
)
for i in range(start_block, end_block):
phys_page = self.page_tables[batch_idx, i]
if phys_page == -1:
break
block_idx = i - start_block
token_start = block_idx * self.block_size
token_end = token_start + self.block_size
k_out[:, token_start:token_end, :] = self.physical_pages_k[phys_page]
v_out[:, token_start:token_end, :] = self.physical_pages_v[phys_page]
return k_out, v_out
def get_sequence_contiguous(self, batch_idx: int,
num_tokens: int = None) -> Tuple[np.ndarray, np.ndarray]:
"""
Get K, V as contiguous arrays for attention computation.
Returns:
k: (1, num_heads, num_tokens, head_dim)
v: (1, num_heads, num_tokens, head_dim)
"""
        if num_tokens is None:
            num_tokens = self.num_blocks[batch_idx] * self.block_size
        k, v = self.get_sequence(batch_idx)
        # Trim to the requested length, then add a leading batch axis:
        # (num_heads, num_tokens, head_dim) -> (1, num_heads, num_tokens, head_dim)
        return k[None, :, :num_tokens, :], v[None, :, :num_tokens, :]
@property
def memory_allocated_bytes(self) -> int:
elem_bytes = np.dtype(self.dtype).itemsize
total_pages = self.num_pages * self.batch_size
page_bytes = self.num_heads * self.block_size * self.head_dim * elem_bytes
return 2 * total_pages * page_bytes # K + V
@property
def memory_used_bytes(self) -> int:
"""Bytes actually used (allocated blocks only)."""
elem_bytes = np.dtype(self.dtype).itemsize
total_used_blocks = np.sum(self.num_blocks)
page_bytes = self.num_heads * self.block_size * self.head_dim * elem_bytes
return 2 * total_used_blocks * page_bytes
def memory_utilization(self) -> float:
"""Fraction of allocated memory actually used."""
alloc = self.memory_allocated_bytes
if alloc == 0:
return 0.0
return self.memory_used_bytes / alloc
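# Minimal usage sketch (illustrative, not part of the original module's API):
# append a few decode-step tokens to one sequence of a PagedKVCache and read
# them back. Shapes follow the append_token() docstring above; the _demo_* name
# and the toy configuration are assumptions for the example only.
def _demo_paged_kv_cache():
    cfg = PageConfig(block_size=16, num_pages=8, batch_size=2,
                     num_heads=4, head_dim=8)
    cache = PagedKVCache(cfg)
    num_tokens = 20  # spans two 16-token logical blocks
    for pos in range(num_tokens):
        k = np.random.randn(1, cfg.num_heads, 1, cfg.head_dim).astype(cfg.dtype)
        v = np.random.randn(1, cfg.num_heads, 1, cfg.head_dim).astype(cfg.dtype)
        cache.append_token(batch_idx=0, keys=k, values=v,
                           logical_block=pos // cfg.block_size,
                           offset_in_block=pos % cfg.block_size)
    k_seq, v_seq = cache.get_sequence_contiguous(0, num_tokens=num_tokens)
    print("K shape:", k_seq.shape)  # (1, num_heads, num_tokens, head_dim)
    print(f"pool utilization: {cache.memory_utilization():.2%}")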
# =============================================================================
# 2. QUANTIZED KV CACHE
# =============================================================================
class QuantizedKVCache:
"""
Quantized KV Cache — stores K and V in reduced precision.
Strategy: per-channel (per-head-dim) int8 quantization.
- Each head-dimension channel has its own scale and zero-point
- Dequantize on-the-fly during attention computation
Memory savings: float16 (16-bit) -> int8 (8-bit) = 2x reduction
Plus metadata overhead: 2 scales per channel (K and V) in float16
For head_dim=128:
- Original: 128 * 16 = 2048 bits per token per head
- Quantized: 128 * 8 + 2 * 128 * 16 = 1024 + 4096 = 5120 bits
- But scales are shared across all tokens, so per-token: 128 * 8 = 1024 bits
- Net savings: ~50%
"""
def __init__(self, batch_size: int, num_heads: int, head_dim: int,
max_seq_len: int, dtype=np.float16):
self.batch_size = batch_size
self.num_heads = num_heads
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.dtype = dtype
self.write_pos = 0
# Quantized storage: int8
shape = (batch_size, num_heads, max_seq_len, head_dim)
self.cache_k_int8 = np.zeros(shape, dtype=np.int8)
self.cache_v_int8 = np.zeros(shape, dtype=np.int8)
        # One scale and zero-point per token per head (computed over head_dim)
        scale_shape = (batch_size, num_heads, max_seq_len, 1)
self.k_scales = np.ones(scale_shape, dtype=dtype)
self.k_zeros = np.zeros(scale_shape, dtype=dtype)
self.v_scales = np.ones(scale_shape, dtype=dtype)
self.v_zeros = np.zeros(scale_shape, dtype=dtype)
def reset(self):
self.cache_k_int8[...] = 0
self.cache_v_int8[...] = 0
self.k_scales[...] = 1.0
self.k_zeros[...] = 0.0
self.v_scales[...] = 1.0
self.v_zeros[...] = 0.0
self.write_pos = 0
def _quantize(self, x: np.ndarray, axis: int = -1) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
        Quantize to int8 with a per-token affine transform along `axis`: x ≈ scale * q + zero.
Returns quantized values, scales, and zero-points.
"""
x_f = x.astype(np.float32)
# Per-channel min/max
x_min = np.min(x_f, axis=axis, keepdims=True)
x_max = np.max(x_f, axis=axis, keepdims=True)
# Avoid division by zero
x_range = x_max - x_min
x_range = np.where(x_range < 1e-6, 1.0, x_range)
        # Scale: map [-128, 127] onto [x_min, x_max]
        scale = x_range / 255.0
        zero = x_min + 128.0 * scale  # zero-point: q = -128 reconstructs x_min
        # Quantize (clip before the int8 cast to avoid overflow wrap-around)
        x_quant = np.round((x_f - zero) / scale)
        x_quant = np.clip(x_quant, -128, 127).astype(np.int8)
return x_quant, scale.astype(self.dtype), zero.astype(self.dtype)
def _dequantize(self, x_int8: np.ndarray, scale: np.ndarray,
zero: np.ndarray) -> np.ndarray:
"""Dequantize int8 back to float: x = scale * q + zero."""
return (x_int8.astype(np.float32) * scale + zero).astype(self.dtype)
def update(self, keys: np.ndarray, values: np.ndarray,
seqlen_offset: int = None):
"""
Quantize and store K, V.
Args:
keys: (batch, heads, 1, head_dim)
values: (batch, heads, 1, head_dim)
"""
if seqlen_offset is None:
seqlen_offset = self.write_pos
pos = seqlen_offset
# Quantize K
k_q, k_s, k_z = self._quantize(keys, axis=-1)
self.cache_k_int8[:, :, pos, :] = k_q[:, :, 0, :]
self.k_scales[:, :, pos:pos+1, :] = k_s
self.k_zeros[:, :, pos:pos+1, :] = k_z
# Quantize V
v_q, v_s, v_z = self._quantize(values, axis=-1)
self.cache_v_int8[:, :, pos, :] = v_q[:, :, 0, :]
self.v_scales[:, :, pos:pos+1, :] = v_s
self.v_zeros[:, :, pos:pos+1, :] = v_z
self.write_pos = pos + 1
def get(self, start: int = 0, end: int = None) -> Tuple[np.ndarray, np.ndarray]:
"""Get dequantized K, V."""
if end is None:
end = self.write_pos
k_int = self.cache_k_int8[:, :, start:end, :]
v_int = self.cache_v_int8[:, :, start:end, :]
# Dequantize using scales and zero-points from each position
k_deq = self._dequantize(k_int, self.k_scales[:, :, start:end, :],
self.k_zeros[:, :, start:end, :])
v_deq = self._dequantize(v_int, self.v_scales[:, :, start:end, :],
self.v_zeros[:, :, start:end, :])
return k_deq, v_deq
@property
def memory_allocated_bytes(self) -> int:
"""Total allocated memory including quantization metadata.
Includes: int8 K + int8 V + fp scales (K+V) + fp zero-points (K+V)
"""
elem_int8 = np.dtype(np.int8).itemsize
elem_fp = np.dtype(self.dtype).itemsize
        n = self.batch_size * self.num_heads * self.max_seq_len * self.head_dim
        k_v_bytes = 2 * n * elem_int8  # int8 K + V
        # One scale + one zero-point per token per head, for both K and V
        n_meta = self.batch_size * self.num_heads * self.max_seq_len
        meta_bytes = 4 * n_meta * elem_fp
        return k_v_bytes + meta_bytes
@property
def memory_savings_vs_fp16(self) -> float:
"""Fraction of memory saved vs. full fp16 cache.
Note: with per-position scales in fp32, this may be negative.
For real savings, use fp16 scales or shared (per-channel) scales.
"""
elem_fp16 = np.dtype(np.float16).itemsize
fp16_bytes = 2 * self.batch_size * self.num_heads * self.max_seq_len * self.head_dim * elem_fp16
return 1.0 - self.memory_allocated_bytes / fp16_bytes
@property
def memory_savings_vs_fp32(self) -> float:
"""Fraction of memory saved vs. full fp32 cache."""
elem_fp32 = np.dtype(np.float32).itemsize
fp32_bytes = 2 * self.batch_size * self.num_heads * self.max_seq_len * self.head_dim * elem_fp32
return 1.0 - self.memory_allocated_bytes / fp32_bytes
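# Minimal round-trip sketch (illustrative, not part of the original API):
# quantize one decode step of K/V, dequantize it, and report the reconstruction
# error plus the memory-savings property defined above. The _demo_* name and
# the toy sizes are assumptions for the example only.
def _demo_quantized_kv_cache():
    cache = QuantizedKVCache(batch_size=1, num_heads=4, head_dim=64,
                             max_seq_len=128)
    k = np.random.randn(1, 4, 1, 64).astype(np.float16)
    v = np.random.randn(1, 4, 1, 64).astype(np.float16)
    cache.update(k, v)
    k_deq, _ = cache.get()
    err = np.max(np.abs(k.astype(np.float32) - k_deq.astype(np.float32)))
    print(f"max abs K reconstruction error: {err:.4f}")  # small vs. ~[-3, 3] values
    print(f"savings vs fp16: {cache.memory_savings_vs_fp16:.1%}")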
# =============================================================================
# 3. CHUNKED PREFILL
# =============================================================================
class ChunkedPrefill:
"""
Chunked Prefill — process long prompts in chunks to limit peak memory.
During prefill with very long prompts (e.g., 32K tokens), computing
full attention O(n²) requires materializing a (n, n) attention matrix,
which can exceed GPU memory.
Chunked prefill processes the prompt in chunks of size C:
- Chunk 0: tokens [0, C) — full causal attention within chunk
- Chunk 1: tokens [C, 2C) — attend to all previous tokens + causal within chunk
- ...
    Chunk i attends to (i+1)*C keys with C queries, so its score matrix is
    (C, (i+1)*C): peak memory is O(C*n) instead of O(n²), while total work
    remains O(n²).
The KV cache is updated incrementally after each chunk.
"""
def __init__(self, chunk_size: int = 512):
self.chunk_size = chunk_size
def compute_attention_chunked(
self,
q_all: np.ndarray,
k_all: np.ndarray,
v_all: np.ndarray,
scale: float,
dtype=np.float32,
) -> np.ndarray:
"""
Compute causal attention in chunks.
Args:
q_all: (batch, heads, seq, head_dim)
k_all: (batch, heads, seq, head_dim)
v_all: (batch, heads, seq, head_dim)
scale: 1 / sqrt(head_dim)
Returns:
output: (batch, heads, seq, head_dim)
"""
batch, heads, seq, head_dim = q_all.shape
output = np.zeros((batch, heads, seq, head_dim), dtype=dtype)
num_chunks = (seq + self.chunk_size - 1) // self.chunk_size
for chunk_idx in range(num_chunks):
start = chunk_idx * self.chunk_size
end = min(start + self.chunk_size, seq)
chunk_len = end - start
# Current chunk's Q
q_chunk = q_all[:, :, start:end, :] # (batch, heads, chunk_len, head_dim)
# Keys and values up to current position (causal)
k_prefix = k_all[:, :, :end, :] # (batch, heads, end, head_dim)
v_prefix = v_all[:, :, :end, :]
q_f = q_chunk.astype(dtype)
k_f = k_prefix.astype(dtype)
v_f = v_prefix.astype(dtype)
# Q @ K^T: (batch, heads, chunk_len, end)
scores = np.einsum("bhqd,bhkd->bhqk", q_f, k_f) * scale
# Causal mask: query at position p can only attend to keys at position <= p
# Query positions (absolute): start..end-1
# Key positions (absolute): 0..end-1
q_positions = np.arange(start, end) # (chunk_len,)
k_positions = np.arange(end) # (end,)
# Allowed: q_pos >= k_pos (causal)
            causal_mask = q_positions[:, None] >= k_positions[None, :]  # (chunk_len, end) bool
            causal_mask = np.where(causal_mask, 0.0, -np.inf).astype(dtype)
scores = scores + causal_mask[None, None, :, :]
# Softmax
attn_weights = self._softmax_stable(scores, axis=-1)
# Attn @ V
chunk_output = np.einsum("bhqk,bhkd->bhqd", attn_weights, v_f)
output[:, :, start:end, :] = chunk_output
return output
@staticmethod
def _softmax_stable(x: np.ndarray, axis: int = -1) -> np.ndarray:
x_max = np.max(x, axis=axis, keepdims=True)
exp_x = np.exp(x - x_max)
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
@staticmethod
def peak_memory_comparison(seq_len: int, chunk_size: int,
head_dim: int = 128) -> dict:
"""
Compare peak memory usage between full and chunked prefill.
The dominant memory is the attention score matrix.
"""
# Full prefill: attention matrix is (seq_len, seq_len) in float32
full_attention_bytes = seq_len * seq_len * 4 # float32
# Chunked prefill: attention matrix is (chunk_size, seq_len) at most
# The last chunk sees all previous tokens
max_chunk_attention = chunk_size * seq_len * 4
return {
"seq_len": seq_len,
"chunk_size": chunk_size,
"full_attention_mb": full_attention_bytes / (1024 * 1024),
"chunked_peak_attention_mb": max_chunk_attention / (1024 * 1024),
"savings_ratio": full_attention_bytes / max(chunk_chunk_attention := chunk_size * seq_len * 4, 1),
}
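# Equivalence sketch (illustrative, not part of the original module): chunked
# causal attention should match a one-shot full causal attention on the same
# Q/K/V up to float32 rounding. The reference below is written inline for the
# check only; the _demo_* name is hypothetical.
def _demo_chunked_prefill():
    rng = np.random.default_rng(0)
    batch, heads, seq, head_dim = 1, 2, 40, 16
    q = rng.standard_normal((batch, heads, seq, head_dim)).astype(np.float32)
    k = rng.standard_normal((batch, heads, seq, head_dim)).astype(np.float32)
    v = rng.standard_normal((batch, heads, seq, head_dim)).astype(np.float32)
    scale = 1.0 / np.sqrt(head_dim)
    chunked = ChunkedPrefill(chunk_size=16).compute_attention_chunked(q, k, v, scale)
    # Reference: materialize the full (seq, seq) causal score matrix in one shot
    scores = np.einsum("bhqd,bhkd->bhqk", q, k) * scale
    scores = scores + np.triu(np.full((seq, seq), -np.inf), k=1)[None, None, :, :]
    weights = ChunkedPrefill._softmax_stable(scores, axis=-1)
    reference = np.einsum("bhqk,bhkd->bhqd", weights, v)
    print("max |chunked - full|:", np.max(np.abs(chunked - reference)))  # ~1e-6
    # Peak-memory model from peak_memory_comparison() for a long prompt:
    print(ChunkedPrefill.peak_memory_comparison(seq_len=32768, chunk_size=512))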
# =============================================================================
# 4. HYBRID: PAGED + QUANTIZED
# =============================================================================
class HybridKVCache:
"""
Combines paged attention with quantization for maximum memory efficiency.
- Paged allocation eliminates fragmentation
- Quantization reduces per-token storage by ~50%
- Together: can handle 2-4x longer contexts in the same memory
"""
def __init__(self, page_config: PageConfig):
self.page_config = page_config
self.paged = PagedKVCache(page_config)
self.quantized = QuantizedKVCache(
batch_size=page_config.batch_size,
num_heads=page_config.num_heads,
head_dim=page_config.head_dim,
max_seq_len=page_config.num_pages * page_config.block_size,
dtype=page_config.dtype,
)
def reset(self):
self.paged.reset()
self.quantized.reset()
@property
def total_memory_saved(self) -> float:
"""Combined memory savings vs. naive contiguous fp16 cache."""
return self.quantized.memory_savings_vs_fp16
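# Usage sketch (illustrative): the hybrid cache simply composes the two
# structures above; its reported savings come from the quantized component.
# The _demo_* name and configuration are assumptions for the example only.
def _demo_hybrid_kv_cache():
    cfg = PageConfig(block_size=16, num_pages=16, batch_size=1,
                     num_heads=4, head_dim=64)
    hybrid = HybridKVCache(cfg)
    print(f"savings vs fp16: {hybrid.total_memory_saved:.1%}")
    hybrid.reset()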
# =============================================================================
# COMPARISON ANALYSIS
# =============================================================================
def compare_strategies(batch_size: int = 4, num_heads: int = 32,
head_dim: int = 128, max_seq_len: int = 4096,
num_layers: int = 32) -> Dict[str, dict]:
"""
Compare memory usage across different KV-cache strategies.
"""
elem_fp16 = 2 # bytes per float16 element
elem_fp32 = 4
elem_int8 = 1
base_tokens = batch_size * num_heads * max_seq_len * head_dim
base_bytes_per_layer = 2 * base_tokens * elem_fp16 # K + V
results = {}
# 1. Naive contiguous fp16
results["naive_fp16"] = {
"description": "Contiguous fp16 cache",
"per_layer_mb": base_bytes_per_layer / (1024 * 1024),
"total_mb": base_bytes_per_layer * num_layers / (1024 * 1024),
"per_token_per_layer_bytes": 2 * num_heads * head_dim * elem_fp16,
}
# 2. Contiguous fp32
base_bytes_fp32 = 2 * base_tokens * elem_fp32
results["naive_fp32"] = {
"description": "Contiguous fp32 cache",
"per_layer_mb": base_bytes_fp32 / (1024 * 1024),
"total_mb": base_bytes_fp32 * num_layers / (1024 * 1024),
"per_token_per_layer_bytes": 2 * num_heads * head_dim * elem_fp32,
}
# 3. Quantized int8 (with fp16 scales)
# Per-token: int8 data + shared fp16 scales per channel
quant_data = base_tokens * elem_int8 * 2 # K + V int8
quant_scales = batch_size * num_heads * head_dim * elem_fp16 * 2 # shared scales
quant_total = quant_data + quant_scales
results["quantized_int8"] = {
"description": "Int8 quantized with fp16 scales",
"per_layer_mb": quant_total / (1024 * 1024),
"total_mb": quant_total * num_layers / (1024 * 1024),
"savings_vs_fp16": 1.0 - quant_total / base_bytes_per_layer,
}
# 4. Paged (no fragmentation waste)
block_size = 16
blocks_per_seq = (max_seq_len + block_size - 1) // block_size
# Paged has slight overhead from block alignment
padded_tokens = batch_size * blocks_per_seq * block_size * num_heads * head_dim
paged_bytes = 2 * padded_tokens * elem_fp16
results["paged"] = {
"description": "Paged attention (block_size=16)",
"per_layer_mb": paged_bytes / (1024 * 1024),
"total_mb": paged_bytes * num_layers / (1024 * 1024),
"overhead_vs_naive": paged_bytes / base_bytes_per_layer,
}
# 5. Paged + Quantized
paged_quant_data = padded_tokens * elem_int8 * 2
paged_quant_scales = batch_size * num_heads * head_dim * elem_fp16 * 2
paged_quant_total = paged_quant_data + paged_quant_scales
results["paged_quantized"] = {
"description": "Paged + int8 quantized",
"per_layer_mb": paged_quant_total / (1024 * 1024),
"total_mb": paged_quant_total * num_layers / (1024 * 1024),
"savings_vs_fp16": 1.0 - paged_quant_total / base_bytes_per_layer,
}
return results
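# Quick usage sketch (illustrative, not part of the original module): print each
# strategy's total footprint for the default configuration (batch 4, 32 heads,
# head_dim 128, 4K context, 32 layers). The _demo_* name is an assumption.
def _demo_compare_strategies():
    for name, stats in compare_strategies().items():
        print(f"{name:18s} {stats['description']:35s} "
              f"total = {stats['total_mb']:9.1f} MB")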