"""
|
|
KV-Cache Data Structures for Autoregressive Transformer Inference
|
|
|
|
Core memory layout:
|
|
cache_k[batch, head, seq_len, head_dim]
|
|
cache_v[batch, head, seq_len, head_dim]
|
|
|
|
This layout enables O(1) append per token and contiguous memory access
|
|
during attention computation (Q @ K^T scans along seq_len).
|
|
"""

from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
from numpy.typing import DTypeLike


@dataclass
class CacheConfig:
    """Configuration for a single layer's KV cache."""
    batch_size: int
    num_heads: int
    head_dim: int
    max_seq_len: int
    dtype: DTypeLike = np.float16

    @property
    def cache_bytes_per_layer(self) -> int:
        """Bytes for one layer's K + V cache at full allocation."""
        elem_bytes = np.dtype(self.dtype).itemsize
        one_side = self.batch_size * self.num_heads * self.max_seq_len * self.head_dim
        return 2 * one_side * elem_bytes  # factor of 2: K + V

    @property
    def cache_bytes_per_layer_per_token(self) -> int:
        """Bytes consumed per generated token per layer."""
        elem_bytes = np.dtype(self.dtype).itemsize
        return 2 * self.num_heads * self.head_dim * elem_bytes
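
# Worked example (illustrative numbers, not tied to any particular model):
# with batch_size=1, num_heads=32, head_dim=128, max_seq_len=4096, float16 (2 B),
# cache_bytes_per_layer = 2 * 1 * 32 * 4096 * 128 * 2 B = 64 MiB and
# cache_bytes_per_layer_per_token = 2 * 32 * 128 * 2 B = 16 KiB, so a 32-layer
# model would pre-allocate 2 GiB and grow by 512 KiB per generated token.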


class KVCache:
    """
    Standard contiguous KV cache for one transformer layer.

    Memory layout (row-major / C-contiguous):
        cache_k: (batch, num_heads, max_seq_len, head_dim)
        cache_v: (batch, num_heads, max_seq_len, head_dim)

    Why this layout:
        - batch first: enables batched GEMM on GPU
        - head second: allows parallel head computation
        - seq_len third: contiguous scan for Q @ K^T
        - head_dim last: inner-product dimension

    The cache is pre-allocated to max_seq_len. A `lengths` array tracks
    actual sequence lengths per batch item (for variable-length batching).
    """

    def __init__(self, config: CacheConfig):
        self.config = config
        self.batch_size = config.batch_size
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        self.max_seq_len = config.max_seq_len
        self.dtype = config.dtype

        # Pre-allocate full buffers (zero-initialized)
        shape = (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim)
        self.cache_k = np.zeros(shape, dtype=self.dtype)
        self.cache_v = np.zeros(shape, dtype=self.dtype)

        # Per-batch-item current sequence length
        self.lengths = np.zeros(self.batch_size, dtype=np.int32)

        # Write pointer: next position to write into
        self.write_pos = 0

    def reset(self):
        """Clear the cache for a new generation."""
        self.cache_k[...] = 0
        self.cache_v[...] = 0
        self.lengths[...] = 0
        self.write_pos = 0

    def update(self, keys: np.ndarray, values: np.ndarray,
               seqlen_offset: Optional[int] = None) -> None:
        """
        Append newly computed K and V to the cache.

        Args:
            keys: current step's K, shape (batch, num_heads, num_new, head_dim)
            values: current step's V, shape (batch, num_heads, num_new, head_dim)
            seqlen_offset: optional explicit write position (defaults to self.write_pos)

        During decoding num_new == 1, so the write position advances by 1 per
        call. For the initial prompt, num_new is the prompt length and all
        prompt tokens are written in a single call with seqlen_offset=0.
        """
        if seqlen_offset is None:
            seqlen_offset = self.write_pos

        pos = seqlen_offset
        num_new = keys.shape[2]
        if pos + num_new > self.max_seq_len:
            raise ValueError(
                f"KV cache overflow: writing {num_new} token(s) at position {pos} "
                f"exceeds max_seq_len={self.max_seq_len}"
            )
        self.cache_k[:, :, pos:pos + num_new, :] = keys
        self.cache_v[:, :, pos:pos + num_new, :] = values

        # Update per-batch-item lengths (all items advance together here)
        self.lengths[:] = pos + num_new

        self.write_pos = pos + num_new
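
    # Typical call pattern (shapes illustrative): one prefill call for the
    # prompt, then one call per decoded token:
    #   cache.update(k_prompt, v_prompt, seqlen_offset=0)  # (B, H, prompt_len, D)
    #   cache.update(k_step, v_step)                       # (B, H, 1, D)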

    def get(self, start: int = 0, end: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
        """
        Retrieve cached K and V slices.

        Returns:
            k: (batch, num_heads, end - start, head_dim)
            v: (batch, num_heads, end - start, head_dim)
        """
        if end is None:
            end = self.write_pos
        return (
            self.cache_k[:, :, start:end, :],
            self.cache_v[:, :, start:end, :],
        )

    def get_all(self) -> Tuple[np.ndarray, np.ndarray]:
        """Get all cached tokens so far (up to write_pos)."""
        return self.get(0, self.write_pos)

    @property
    def memory_used_bytes(self) -> int:
        """Actual bytes used (based on write_pos, not max allocation)."""
        elem_bytes = np.dtype(self.dtype).itemsize
        tokens = self.write_pos
        return 2 * self.batch_size * self.num_heads * tokens * self.head_dim * elem_bytes

    @property
    def memory_allocated_bytes(self) -> int:
        """Total pre-allocated bytes."""
        return self.config.cache_bytes_per_layer


class BatchedKVCache:
    """
    Manages KV caches across all layers of a transformer.

    A real model with L layers needs L separate KV caches. This class
    coordinates them and handles variable-length batching.
    """

    def __init__(self, num_layers: int, config: CacheConfig):
        self.num_layers = num_layers
        self.config = config
        self.caches = [KVCache(config) for _ in range(num_layers)]

    def reset(self):
        for cache in self.caches:
            cache.reset()

    def update(self, layer_idx: int, keys: np.ndarray, values: np.ndarray,
               seqlen_offset: Optional[int] = None):
        self.caches[layer_idx].update(keys, values, seqlen_offset)

    def get(self, layer_idx: int, start: int = 0, end: Optional[int] = None):
        return self.caches[layer_idx].get(start, end)

    @property
    def total_memory_allocated_bytes(self) -> int:
        return sum(c.memory_allocated_bytes for c in self.caches)

    @property
    def total_memory_used_bytes(self) -> int:
        return sum(c.memory_used_bytes for c in self.caches)

    def memory_report(self) -> dict:
        """Detailed memory breakdown."""
        tokens = self.caches[0].write_pos if self.caches else 0
        per_layer = self.config.cache_bytes_per_layer
        per_token_per_layer = self.config.cache_bytes_per_layer_per_token

        return {
            "num_layers": self.num_layers,
            "batch_size": self.config.batch_size,
            "num_heads": self.config.num_heads,
            "head_dim": self.config.head_dim,
            "max_seq_len": self.config.max_seq_len,
            "dtype": str(np.dtype(self.config.dtype)),
            "tokens_generated": tokens,
            "per_layer_allocated_mb": per_layer / (1024 * 1024),
            "total_allocated_mb": self.total_memory_allocated_bytes / (1024 * 1024),
            "total_used_mb": self.total_memory_used_bytes / (1024 * 1024),
            "growth_per_token_mb": (per_token_per_layer * self.num_layers) / (1024 * 1024),
        }
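

if __name__ == "__main__":
    # Minimal usage sketch (illustrative toy config, not a real model):
    # prefill an 8-token prompt, decode 4 steps, then print the memory report.
    cfg = CacheConfig(batch_size=2, num_heads=4, head_dim=16, max_seq_len=64)
    kv = BatchedKVCache(num_layers=3, config=cfg)

    rng = np.random.default_rng(0)

    def rand_kv(num_new: int) -> np.ndarray:
        """Random K or V tensor of shape (batch, heads, num_new, head_dim)."""
        return rng.standard_normal(
            (cfg.batch_size, cfg.num_heads, num_new, cfg.head_dim)
        ).astype(cfg.dtype)

    # Prefill: all prompt tokens in one call per layer.
    for layer in range(kv.num_layers):
        kv.update(layer, rand_kv(8), rand_kv(8), seqlen_offset=0)

    # Decode: one token per step per layer.
    for _ in range(4):
        for layer in range(kv.num_layers):
            kv.update(layer, rand_kv(1), rand_kv(1))

    k, v = kv.get(0)  # 8 prompt + 4 decoded tokens cached for layer 0
    assert k.shape == (2, 4, 12, 16) and v.shape == (2, 4, 12, 16)

    for name, value in kv.memory_report().items():
        print(f"{name}: {value}")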