# Commit 8e72eef09c: rename gamma to glm5 and model to minimax-m2.7; add
# model_comparison/ directory with head-to-head analyses; sanitize
# session.jsonl files (remove absolute paths and usernames); remove
# __pycache__ artifacts; add .gitignore.
"""
|
|
KV-Cache Optimizations
|
|
|
|
Implements three major optimization strategies:
|
|
1. Paged Attention — non-contiguous memory allocation (inspired by vLLM)
|
|
2. Quantization — reduced precision for cached K/V
|
|
3. Chunked Prefill — processing long prompts in chunks to limit peak memory
|
|
"""
|
|
|
|
import numpy as np
|
|
from typing import Optional, Tuple, List, Dict
|
|
from dataclasses import dataclass, field
|
|
from kv_cache import CacheConfig
|
|
|
|
|
|
# =============================================================================
|
|
# 1. PAGED ATTENTION
|
|
# =============================================================================
|
|
|
|
@dataclass
class PageConfig:
    """Configuration for paged KV cache.

    Sizing note: the physical page pool created by PagedKVCache holds
    ``num_pages * batch_size`` pages in total, shared across sequences.
    """

    block_size: int = 16   # tokens stored per page (block)
    num_pages: int = 256   # total pages per sequence
    batch_size: int = 4    # number of sequences cached concurrently
    num_heads: int = 32    # attention heads per layer
    head_dim: int = 128    # per-head embedding dimension
    dtype: np.dtype = np.float16  # storage dtype for cached K/V
|
|
|
|
|
|
class PagedKVCache:
    """
    Paged KV Cache — inspired by vLLM's PagedAttention.

    Instead of a contiguous [batch, heads, max_seq, head_dim] buffer,
    memory is divided into fixed-size blocks (pages). Each sequence
    maintains a page table mapping logical block indices to physical pages.

    Benefits:
    - Zero memory fragmentation: blocks are allocated on demand
    - Supports speculative decoding and branching
    - Enables sharing of common prefixes (prefix caching)
    - No need to pre-allocate max_seq_len

    Memory layout:
        physical_pages_k: (num_pages * batch_size, num_heads, block_size, head_dim)
        physical_pages_v: same shape
        page_tables: (batch_size, num_pages) — maps logical block -> physical page index
    """

    def __init__(self, config: PageConfig):
        """Allocate the shared physical page pool and per-sequence page tables."""
        self.config = config
        self.batch_size = config.batch_size
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        self.block_size = config.block_size
        self.num_pages = config.num_pages
        self.dtype = config.dtype

        # Physical page pool (shared across all sequences).
        # Each page holds: (num_heads, block_size, head_dim)
        page_shape = (config.num_pages * config.batch_size,
                      config.num_heads, config.block_size, config.head_dim)
        self.physical_pages_k = np.zeros(page_shape, dtype=self.dtype)
        self.physical_pages_v = np.zeros(page_shape, dtype=self.dtype)

        # Page table per sequence: logical_block_idx -> physical_page_idx.
        # -1 marks an unallocated logical block.
        max_blocks = config.num_pages
        self.page_tables = np.full(
            (config.batch_size, max_blocks), -1, dtype=np.int32
        )

        # Number of allocated blocks per sequence (high-water mark,
        # advanced in append_token).
        self.num_blocks = np.zeros(config.batch_size, dtype=np.int32)

        # Free page pool (global, shared): entries free_list[free_ptr:]
        # are the pages still available; entries before free_ptr are stale.
        total_pages = config.num_pages * config.batch_size
        self.free_list = np.arange(total_pages, dtype=np.int32)
        self.free_ptr = 0  # index into free_list

    def _alloc_page(self) -> int:
        """Allocate one physical page from the free pool.

        Raises:
            MemoryError: when every page in the shared pool is in use.
        """
        if self.free_ptr >= len(self.free_list):
            raise MemoryError("Paged KV cache out of memory")
        page_idx = self.free_list[self.free_ptr]
        self.free_ptr += 1
        return page_idx

    def _free_page(self, page_idx: int) -> None:
        """Return a physical page to the free pool.

        Writes into slot free_ptr - 1 (a stale record of a previously
        allocated page) and then decrements free_ptr, so page_idx becomes
        the next page handed out by _alloc_page. Must only be called for
        pages that were actually allocated, otherwise the free list stops
        being a permutation of all page indices.
        """
        self.free_list[self.free_ptr - 1] = page_idx
        self.free_ptr -= 1

    def reset(self):
        """Reset cache for a new generation.

        Note: free_list is not re-sorted; after a mix of alloc/free it is
        still a permutation of all page ids, so resetting free_ptr to 0
        is sufficient.
        """
        self.physical_pages_k[...] = 0
        self.physical_pages_v[...] = 0
        self.page_tables[...] = -1
        self.num_blocks[...] = 0
        self.free_ptr = 0

    def append_token(self, batch_idx: int, keys: np.ndarray,
                     values: np.ndarray, logical_block: int,
                     offset_in_block: int):
        """
        Append one token to a specific logical block, allocating a
        physical page on first touch.

        Args:
            batch_idx: batch item index
            keys: (1, num_heads, 1, head_dim)
            values: (1, num_heads, 1, head_dim)
            logical_block: which logical block to write to
            offset_in_block: position within the block (0..block_size-1)
        """
        # Check if a physical page is already mapped for this logical block
        phys_page = self.page_tables[batch_idx, logical_block]

        if phys_page == -1:
            # First write to this logical block: allocate a new physical page
            phys_page = self._alloc_page()
            self.page_tables[batch_idx, logical_block] = phys_page
            # Advance the per-sequence block high-water mark
            if logical_block + 1 > self.num_blocks[batch_idx]:
                self.num_blocks[batch_idx] = logical_block + 1

        # Write the single token's K/V into the physical page
        self.physical_pages_k[phys_page, :, offset_in_block, :] = keys[0, :, 0, :]
        self.physical_pages_v[phys_page, :, offset_in_block, :] = values[0, :, 0, :]

    def get_sequence(self, batch_idx: int,
                     start_block: int = 0,
                     end_block: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
        """
        Retrieve K and V for a sequence, gathering from physical pages.

        Output length is always a whole number of blocks; a partially
        filled final block contributes its zero-initialized tail. If an
        unallocated block is encountered mid-range, gathering stops and
        the remaining rows stay zero.

        Returns:
            k: (num_heads, total_tokens, head_dim)
            v: (num_heads, total_tokens, head_dim)
        """
        if end_block is None:
            end_block = self.num_blocks[batch_idx]

        blocks = end_block - start_block
        total_tokens = blocks * self.block_size

        k_out = np.zeros(
            (self.num_heads, total_tokens, self.head_dim), dtype=self.dtype
        )
        v_out = np.zeros(
            (self.num_heads, total_tokens, self.head_dim), dtype=self.dtype
        )

        # Gather each logical block from its mapped physical page
        for i in range(start_block, end_block):
            phys_page = self.page_tables[batch_idx, i]
            if phys_page == -1:
                break
            block_idx = i - start_block
            token_start = block_idx * self.block_size
            token_end = token_start + self.block_size
            k_out[:, token_start:token_end, :] = self.physical_pages_k[phys_page]
            v_out[:, token_start:token_end, :] = self.physical_pages_v[phys_page]

        return k_out, v_out

    def get_sequence_contiguous(self, batch_idx: int,
                                num_tokens: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
        """
        Get K, V as contiguous arrays for attention computation.

        Note: num_tokens is currently only used to default-size the
        request; the gather itself always returns all allocated blocks.

        Returns:
            k: (1, num_heads, num_tokens, head_dim)
            v: (1, num_heads, num_tokens, head_dim)
        """
        if num_tokens is None:
            num_tokens = self.num_blocks[batch_idx] * self.block_size

        k, v = self.get_sequence(batch_idx)
        # k: (num_heads, num_tokens, head_dim) -> (1, num_heads, num_tokens, head_dim)
        return k[None, ...], v[None, ...]

    @property
    def memory_allocated_bytes(self) -> int:
        """Total bytes reserved for the physical page pool (K + V)."""
        elem_bytes = np.dtype(self.dtype).itemsize
        total_pages = self.num_pages * self.batch_size
        page_bytes = self.num_heads * self.block_size * self.head_dim * elem_bytes
        return 2 * total_pages * page_bytes  # K + V

    @property
    def memory_used_bytes(self) -> int:
        """Bytes actually used (allocated blocks only)."""
        elem_bytes = np.dtype(self.dtype).itemsize
        total_used_blocks = np.sum(self.num_blocks)
        page_bytes = self.num_heads * self.block_size * self.head_dim * elem_bytes
        return 2 * total_used_blocks * page_bytes

    def memory_utilization(self) -> float:
        """Fraction of allocated memory actually used (0.0 when pool is empty)."""
        alloc = self.memory_allocated_bytes
        if alloc == 0:
            return 0.0
        return self.memory_used_bytes / alloc
|
|
|
|
|
|
# =============================================================================
|
|
# 2. QUANTIZED KV CACHE
|
|
# =============================================================================
|
|
|
|
class QuantizedKVCache:
    """
    Quantized KV Cache — stores K and V in reduced precision (int8).

    Strategy: affine int8 quantization with one (scale, zero-point) pair
    per (batch, head, position), reduced over the head_dim axis:

        x ≈ scale * q + zero,   q in [-128, 127]

    Values are dequantized on-the-fly when the cache is read, so attention
    itself still runs in floating point.

    Memory: the int8 payload is 2x smaller than fp16, but this reference
    implementation stores scales/zero-points broadcast across head_dim per
    position (see memory_allocated_bytes), so metadata dominates and the
    net savings vs. fp16 can be negative. Real deployments share scales
    across positions or channels.
    """

    def __init__(self, batch_size: int, num_heads: int, head_dim: int,
                 max_seq_len: int, dtype=np.float16):
        """
        Args:
            batch_size: number of sequences
            num_heads: attention heads
            head_dim: per-head embedding dimension
            max_seq_len: maximum cached sequence length
            dtype: floating dtype for scales/zeros and dequantized output
        """
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.dtype = dtype
        self.write_pos = 0  # next sequence position to write

        # Quantized storage: int8
        shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.cache_k_int8 = np.zeros(shape, dtype=np.int8)
        self.cache_v_int8 = np.zeros(shape, dtype=np.int8)

        # Per-position scales and zero-points (broadcast across head_dim)
        scale_shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.k_scales = np.ones(scale_shape, dtype=dtype)
        self.k_zeros = np.zeros(scale_shape, dtype=dtype)
        self.v_scales = np.ones(scale_shape, dtype=dtype)
        self.v_zeros = np.zeros(scale_shape, dtype=dtype)

    def reset(self):
        """Clear cached data and quantization metadata for a new generation."""
        self.cache_k_int8[...] = 0
        self.cache_v_int8[...] = 0
        self.k_scales[...] = 1.0
        self.k_zeros[...] = 0.0
        self.v_scales[...] = 1.0
        self.v_zeros[...] = 0.0
        self.write_pos = 0

    def _quantize(self, x: np.ndarray, axis: int = -1) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Quantize to int8 with a per-slice affine transform: x ≈ scale * q + zero.

        The zero-point is shifted so codes span the full signed int8 range
        (x_min -> -128, x_max -> 127).

        Returns:
            (int8 codes, scales, zero-points); scales/zeros keep dims along
            `axis` and are cast to self.dtype.
        """
        x_f = x.astype(np.float32)
        # Per-slice min/max along the reduction axis
        x_min = np.min(x_f, axis=axis, keepdims=True)
        x_max = np.max(x_f, axis=axis, keepdims=True)

        # Avoid division by zero for constant slices
        x_range = x_max - x_min
        x_range = np.where(x_range < 1e-6, 1.0, x_range)

        scale = x_range / 255.0
        # Shift zero-point by 128*scale so codes land in [-128, 127].
        # BUG FIX: the previous version produced codes in [0, 255] and cast
        # to int8 BEFORE clipping, so the upper half of each range wrapped
        # negative (e.g. x_max -> code 255 -> int8 -1), corrupting values.
        zero = x_min + 128.0 * scale

        q = np.round((x_f - zero) / scale)
        # Clip BEFORE the int8 cast to prevent wraparound
        x_quant = np.clip(q, -128, 127).astype(np.int8)

        return x_quant, scale.astype(self.dtype), zero.astype(self.dtype)

    def _dequantize(self, x_int8: np.ndarray, scale: np.ndarray,
                    zero: np.ndarray) -> np.ndarray:
        """Dequantize int8 back to float: x = scale * q + zero."""
        return (x_int8.astype(np.float32) * scale + zero).astype(self.dtype)

    def update(self, keys: np.ndarray, values: np.ndarray,
               seqlen_offset: int = None):
        """
        Quantize and store one token's K, V at seqlen_offset (defaults to
        the running write position).

        Args:
            keys: (batch, heads, 1, head_dim)
            values: (batch, heads, 1, head_dim)
            seqlen_offset: explicit position; None appends at write_pos
        """
        if seqlen_offset is None:
            seqlen_offset = self.write_pos

        pos = seqlen_offset

        # Quantize and store K (scales/zeros broadcast across head_dim)
        k_q, k_s, k_z = self._quantize(keys, axis=-1)
        self.cache_k_int8[:, :, pos, :] = k_q[:, :, 0, :]
        self.k_scales[:, :, pos:pos+1, :] = k_s
        self.k_zeros[:, :, pos:pos+1, :] = k_z

        # Quantize and store V
        v_q, v_s, v_z = self._quantize(values, axis=-1)
        self.cache_v_int8[:, :, pos, :] = v_q[:, :, 0, :]
        self.v_scales[:, :, pos:pos+1, :] = v_s
        self.v_zeros[:, :, pos:pos+1, :] = v_z

        self.write_pos = pos + 1

    def get(self, start: int = 0, end: int = None) -> Tuple[np.ndarray, np.ndarray]:
        """Get dequantized K, V for positions [start, end) (end defaults to write_pos)."""
        if end is None:
            end = self.write_pos

        k_int = self.cache_k_int8[:, :, start:end, :]
        v_int = self.cache_v_int8[:, :, start:end, :]

        # Dequantize using the scales and zero-points stored per position
        k_deq = self._dequantize(k_int, self.k_scales[:, :, start:end, :],
                                 self.k_zeros[:, :, start:end, :])
        v_deq = self._dequantize(v_int, self.v_scales[:, :, start:end, :],
                                 self.v_zeros[:, :, start:end, :])

        return k_deq, v_deq

    @property
    def memory_allocated_bytes(self) -> int:
        """Total allocated memory including quantization metadata.

        Includes: int8 K + int8 V + fp scales (K+V) + fp zero-points (K+V).
        """
        elem_int8 = np.dtype(np.int8).itemsize
        elem_fp = np.dtype(self.dtype).itemsize
        n = self.batch_size * self.num_heads * self.max_seq_len * self.head_dim
        k_v_bytes = 2 * n * elem_int8   # int8 K + V
        meta_bytes = 4 * n * elem_fp    # scales + zeros for K and V
        return k_v_bytes + meta_bytes

    @property
    def memory_savings_vs_fp16(self) -> float:
        """Fraction of memory saved vs. a full fp16 cache.

        Note: with per-position scales stored at full head_dim width, this
        is negative. Real savings require shared (per-channel) scales.
        """
        elem_fp16 = np.dtype(np.float16).itemsize
        fp16_bytes = 2 * self.batch_size * self.num_heads * self.max_seq_len * self.head_dim * elem_fp16
        return 1.0 - self.memory_allocated_bytes / fp16_bytes

    @property
    def memory_savings_vs_fp32(self) -> float:
        """Fraction of memory saved vs. a full fp32 cache."""
        elem_fp32 = np.dtype(np.float32).itemsize
        fp32_bytes = 2 * self.batch_size * self.num_heads * self.max_seq_len * self.head_dim * elem_fp32
        return 1.0 - self.memory_allocated_bytes / fp32_bytes
|
|
|
|
|
|
# =============================================================================
|
|
# 3. CHUNKED PREFILL
|
|
# =============================================================================
|
|
|
|
class ChunkedPrefill:
    """
    Chunked Prefill — process long prompts in chunks to limit peak memory.

    During prefill with very long prompts (e.g., 32K tokens), computing
    full attention requires materializing an (n, n) score matrix, which
    can exceed memory.

    Chunked prefill processes the prompt in chunks of size C:
      - Chunk 0: tokens [0, C) — full causal attention within chunk
      - Chunk 1: tokens [C, 2C) — attend to all previous tokens + causal
        within chunk
      - ...

    Peak score-matrix memory drops from O(n²) to O(C·n): the last chunk
    attends to the whole prefix, so its matrix is (C, n) at most.
    """

    def __init__(self, chunk_size: int = 512):
        """chunk_size: number of query tokens processed per chunk."""
        self.chunk_size = chunk_size

    def compute_attention_chunked(
        self,
        q_all: np.ndarray,
        k_all: np.ndarray,
        v_all: np.ndarray,
        scale: float,
        dtype=np.float32,
    ) -> np.ndarray:
        """
        Compute causal attention in chunks.

        Args:
            q_all: (batch, heads, seq, head_dim)
            k_all: (batch, heads, seq, head_dim)
            v_all: (batch, heads, seq, head_dim)
            scale: typically 1 / sqrt(head_dim)
            dtype: accumulation dtype for scores and output

        Returns:
            output: (batch, heads, seq, head_dim)
        """
        batch, heads, seq, head_dim = q_all.shape
        output = np.zeros((batch, heads, seq, head_dim), dtype=dtype)

        num_chunks = (seq + self.chunk_size - 1) // self.chunk_size

        for chunk_idx in range(num_chunks):
            start = chunk_idx * self.chunk_size
            end = min(start + self.chunk_size, seq)

            # Queries for this chunk; keys/values for the whole causal prefix
            q_f = q_all[:, :, start:end, :].astype(dtype)
            k_f = k_all[:, :, :end, :].astype(dtype)
            v_f = v_all[:, :, :end, :].astype(dtype)

            # Q @ K^T: (batch, heads, chunk_len, end)
            scores = np.einsum("bhqd,bhkd->bhqk", q_f, k_f) * scale

            # Causal mask in ABSOLUTE positions: a query at position p may
            # only attend to keys at positions <= p.
            q_positions = np.arange(start, end)   # (chunk_len,)
            k_positions = np.arange(end)          # (end,)
            allowed = q_positions[:, None] >= k_positions[None, :]
            scores = scores + np.where(allowed, 0.0, -np.inf)[None, None, :, :]

            attn_weights = self._softmax_stable(scores, axis=-1)

            # Attn @ V, written back into the chunk's output slice
            output[:, :, start:end, :] = np.einsum(
                "bhqk,bhkd->bhqd", attn_weights, v_f
            )

        return output

    @staticmethod
    def _softmax_stable(x: np.ndarray, axis: int = -1) -> np.ndarray:
        """Numerically stable softmax (subtracts the row max before exp)."""
        x_max = np.max(x, axis=axis, keepdims=True)
        exp_x = np.exp(x - x_max)
        return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

    @staticmethod
    def peak_memory_comparison(seq_len: int, chunk_size: int,
                               head_dim: int = 128) -> dict:
        """
        Compare peak score-matrix memory between full and chunked prefill.

        The dominant prefill allocation is the fp32 attention matrix:
        (seq_len, seq_len) for full prefill vs. at most
        (chunk_size, seq_len) for chunked prefill (the last chunk sees
        the entire prefix). head_dim is kept for interface compatibility
        but does not affect the score-matrix sizes.
        """
        full_attention_bytes = seq_len * seq_len * 4  # float32
        max_chunk_attention = chunk_size * seq_len * 4

        return {
            "seq_len": seq_len,
            "chunk_size": chunk_size,
            "full_attention_mb": full_attention_bytes / (1024 * 1024),
            "chunked_peak_attention_mb": max_chunk_attention / (1024 * 1024),
            # max(..., 1) guards the degenerate chunk_size == 0 / seq_len == 0 case
            "savings_ratio": full_attention_bytes / max(max_chunk_attention, 1),
        }
|
|
|
|
|
|
# =============================================================================
|
|
# 4. HYBRID: PAGED + QUANTIZED
|
|
# =============================================================================
|
|
|
|
class HybridKVCache:
    """
    Combines paged attention with quantization for maximum memory efficiency.

    - Paged allocation eliminates fragmentation
    - Quantization reduces per-token storage by ~50%
    - Together: can handle 2-4x longer contexts in the same memory

    Holds a PagedKVCache and a QuantizedKVCache side by side, both sized
    from the same PageConfig.
    """

    def __init__(self, page_config: PageConfig):
        """Build both backing caches from a single PageConfig."""
        self.page_config = page_config
        # The quantized cache must cover every token the page pool can hold
        total_token_capacity = page_config.num_pages * page_config.block_size
        self.paged = PagedKVCache(page_config)
        self.quantized = QuantizedKVCache(
            batch_size=page_config.batch_size,
            num_heads=page_config.num_heads,
            head_dim=page_config.head_dim,
            max_seq_len=total_token_capacity,
            dtype=page_config.dtype,
        )

    def reset(self):
        """Clear both backing caches for a new generation."""
        self.quantized.reset()
        self.paged.reset()

    @property
    def total_memory_saved(self) -> float:
        """Combined memory savings vs. naive contiguous fp16 cache."""
        return self.quantized.memory_savings_vs_fp16
|
|
|
|
|
|
# =============================================================================
|
|
# COMPARISON ANALYSIS
|
|
# =============================================================================
|
|
|
|
def compare_strategies(batch_size: int = 4, num_heads: int = 32,
                       head_dim: int = 128, max_seq_len: int = 4096,
                       num_layers: int = 32) -> Dict[str, dict]:
    """
    Compare memory usage across different KV-cache strategies.

    Args:
        batch_size, num_heads, head_dim, max_seq_len: cache geometry
        num_layers: transformer depth (total = per-layer * num_layers)

    Returns:
        Mapping of strategy name to a dict of memory statistics
        (per-layer MB, total MB, and strategy-specific ratios).
    """
    mb = 1024 * 1024
    fp16_size = 2  # bytes per float16 element
    fp32_size = 4
    int8_size = 1

    elems = batch_size * num_heads * max_seq_len * head_dim
    fp16_layer_bytes = 2 * elems * fp16_size  # K + V

    report: Dict[str, dict] = {}

    # 1. Naive contiguous fp16 baseline
    report["naive_fp16"] = {
        "description": "Contiguous fp16 cache",
        "per_layer_mb": fp16_layer_bytes / mb,
        "total_mb": fp16_layer_bytes * num_layers / mb,
        "per_token_per_layer_bytes": 2 * num_heads * head_dim * fp16_size,
    }

    # 2. Contiguous fp32
    fp32_layer_bytes = 2 * elems * fp32_size
    report["naive_fp32"] = {
        "description": "Contiguous fp32 cache",
        "per_layer_mb": fp32_layer_bytes / mb,
        "total_mb": fp32_layer_bytes * num_layers / mb,
        "per_token_per_layer_bytes": 2 * num_heads * head_dim * fp32_size,
    }

    # 3. Quantized int8: int8 payload + fp16 scales shared across positions
    int8_payload = elems * int8_size * 2  # K + V int8
    shared_scales = batch_size * num_heads * head_dim * fp16_size * 2
    int8_layer_bytes = int8_payload + shared_scales
    report["quantized_int8"] = {
        "description": "Int8 quantized with fp16 scales",
        "per_layer_mb": int8_layer_bytes / mb,
        "total_mb": int8_layer_bytes * num_layers / mb,
        "savings_vs_fp16": 1.0 - int8_layer_bytes / fp16_layer_bytes,
    }

    # 4. Paged: slight overhead from rounding the sequence up to whole blocks
    block_size = 16
    blocks_per_seq = (max_seq_len + block_size - 1) // block_size  # ceil
    padded_elems = batch_size * blocks_per_seq * block_size * num_heads * head_dim
    paged_layer_bytes = 2 * padded_elems * fp16_size
    report["paged"] = {
        "description": "Paged attention (block_size=16)",
        "per_layer_mb": paged_layer_bytes / mb,
        "total_mb": paged_layer_bytes * num_layers / mb,
        "overhead_vs_naive": paged_layer_bytes / fp16_layer_bytes,
    }

    # 5. Paged + quantized: block-aligned int8 payload plus shared scales
    hybrid_layer_bytes = padded_elems * int8_size * 2 + shared_scales
    report["paged_quantized"] = {
        "description": "Paged + int8 quantized",
        "per_layer_mb": hybrid_layer_bytes / mb,
        "total_mb": hybrid_layer_bytes * num_layers / mb,
        "savings_vs_fp16": 1.0 - hybrid_layer_bytes / fp16_layer_bytes,
    }

    return report
|