8e72eef09c
- Rename gamma to glm5 and model to minimax-m2.7 - Add model_comparison/ directory with head-to-head analyses - Sanitize all session.jsonl files: remove absolute paths and usernames - Remove __pycache__ artifacts - Add .gitignore
268 lines
9.1 KiB
Python
268 lines
9.1 KiB
Python
"""
|
|
Memory Growth Analysis for KV-Cache Systems
|
|
|
|
Analyzes how memory consumption scales with:
|
|
- Sequence length
|
|
- Batch size
|
|
- Number of heads
|
|
- Model dimension
|
|
- Number of layers
|
|
|
|
Provides formulas, visualizations, and practical limits.
|
|
"""
|
|
|
|
import numpy as np
|
|
from typing import Dict, List, Tuple
|
|
from dataclasses import dataclass
|
|
|
|
|
|
@dataclass
|
|
class ModelSpec:
|
|
"""Specification of a transformer model."""
|
|
num_layers: int
|
|
dim: int
|
|
num_heads: int
|
|
head_dim: int
|
|
vocab_size: int = 32000
|
|
mlp_hidden_mult: float = 4.0 / 3 # GPT-style
|
|
|
|
|
|
def compute_model_memory(spec: ModelSpec, dtype=np.float16) -> Dict[str, float]:
|
|
"""
|
|
Compute total model parameter memory.
|
|
|
|
Per layer:
|
|
- Wq, Wk, Wv: 3 * dim * dim
|
|
- Wo: dim * dim
|
|
- MLP fc1: dim * hidden
|
|
- MLP fc2: hidden * dim
|
|
- LayerNorm: 2 * dim (weight + bias)
|
|
- Embedding: vocab_size * dim (shared with LM head)
|
|
|
|
Total per layer (excluding shared embedding):
|
|
4 * dim² + 2 * dim * hidden + 2 * dim
|
|
"""
|
|
elem = np.dtype(dtype).itemsize
|
|
hidden = int(spec.dim * spec.mlp_hidden_mult)
|
|
|
|
per_layer = (
|
|
4 * spec.dim * spec.dim + # Wq, Wk, Wv, Wo
|
|
2 * spec.dim * hidden + # MLP fc1, fc2
|
|
2 * spec.dim # LayerNorm params
|
|
) * elem
|
|
|
|
embedding = spec.vocab_size * spec.dim * elem
|
|
|
|
return {
|
|
"per_layer_bytes": per_layer,
|
|
"per_layer_mb": per_layer / (1024 * 1024),
|
|
"embedding_mb": embedding / (1024 * 1024),
|
|
"total_params_mb": (per_layer * spec.num_layers + embedding) / (1024 * 1024),
|
|
"total_params_gb": (per_layer * spec.num_layers + embedding) / (1024 ** 3),
|
|
}
|
|
|
|
|
|
def compute_kv_cache_memory(
|
|
batch_size: int,
|
|
seq_len: int,
|
|
spec: ModelSpec,
|
|
dtype=np.float16,
|
|
) -> Dict[str, float]:
|
|
"""
|
|
Compute KV cache memory for a given batch and sequence length.
|
|
|
|
Per layer: 2 * batch * heads * seq * head_dim * elem_bytes
|
|
(factor of 2 for K and V)
|
|
"""
|
|
elem = np.dtype(dtype).itemsize
|
|
per_layer = 2 * batch_size * spec.num_heads * seq_len * spec.head_dim * elem
|
|
total = per_layer * spec.num_layers
|
|
|
|
return {
|
|
"per_layer_bytes": per_layer,
|
|
"per_layer_mb": per_layer / (1024 * 1024),
|
|
"total_bytes": total,
|
|
"total_mb": total / (1024 * 1024),
|
|
"total_gb": total / (1024 ** 3),
|
|
"per_token_per_layer_bytes": 2 * spec.num_heads * spec.head_dim * elem,
|
|
"growth_rate_mb_per_token": (
|
|
2 * batch_size * spec.num_heads * spec.head_dim * elem * spec.num_layers
|
|
) / (1024 * 1024),
|
|
}
|
|
|
|
|
|
def analyze_memory_growth(spec: ModelSpec, batch_sizes: List[int] = None,
|
|
seq_lengths: List[int] = None,
|
|
dtype=np.float16) -> Dict:
|
|
"""
|
|
Comprehensive memory growth analysis.
|
|
|
|
Returns analysis for various batch sizes and sequence lengths.
|
|
"""
|
|
if batch_sizes is None:
|
|
batch_sizes = [1, 2, 4, 8, 16, 32]
|
|
if seq_lengths is None:
|
|
seq_lengths = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
|
|
|
|
model_mem = compute_model_memory(spec, dtype)
|
|
|
|
results = {
|
|
"model": model_mem,
|
|
"spec": {
|
|
"num_layers": spec.num_layers,
|
|
"dim": spec.dim,
|
|
"num_heads": spec.num_heads,
|
|
"head_dim": spec.head_dim,
|
|
"dtype": str(dtype),
|
|
},
|
|
"kv_cache": {},
|
|
}
|
|
|
|
for bs in batch_sizes:
|
|
for sl in seq_lengths:
|
|
kv = compute_kv_cache_memory(bs, sl, spec, dtype)
|
|
key = f"bs{bs}_sl{sl}"
|
|
results["kv_cache"][key] = {
|
|
"batch_size": bs,
|
|
"seq_len": sl,
|
|
"kv_cache_gb": kv["total_gb"],
|
|
"total_system_gb": kv["total_gb"] + model_mem["total_params_gb"],
|
|
"kv_fraction": kv["total_gb"] / (kv["total_gb"] + model_mem["total_params_gb"]),
|
|
}
|
|
|
|
return results
|
|
|
|
|
|
def find_max_context(spec: ModelSpec, gpu_memory_gb: float = 80,
|
|
batch_size: int = 1, dtype=np.float16) -> int:
|
|
"""
|
|
Find the maximum context length that fits in GPU memory.
|
|
|
|
GPU memory = model_params + kv_cache + activation_overhead
|
|
|
|
We estimate activation overhead as ~2x model params (conservative).
|
|
"""
|
|
model_mem = compute_model_memory(spec, dtype)
|
|
model_gb = model_mem["total_params_gb"]
|
|
|
|
# Reserve for activations and other overhead (~2x model params)
|
|
activation_gb = model_gb * 2
|
|
|
|
# Remaining for KV cache
|
|
kv_budget_gb = gpu_memory_gb - model_gb - activation_gb
|
|
if kv_budget_gb <= 0:
|
|
return 0
|
|
|
|
elem = np.dtype(dtype).itemsize
|
|
bytes_per_token = (2 * batch_size * spec.num_heads * spec.head_dim * elem *
|
|
spec.num_layers)
|
|
|
|
max_tokens = int(kv_budget_gb * (1024 ** 3) / bytes_per_token)
|
|
return max_tokens
|
|
|
|
|
|
def compare_model_sizes() -> Dict[str, dict]:
|
|
"""
|
|
Analyze memory for several well-known model sizes.
|
|
"""
|
|
models = {
|
|
"Llama-2-7B": ModelSpec(num_layers=32, dim=4096, num_heads=32, head_dim=128),
|
|
"Llama-2-13B": ModelSpec(num_layers=40, dim=5120, num_heads=40, head_dim=128),
|
|
"Llama-2-70B": ModelSpec(num_layers=80, dim=8192, num_heads=64, head_dim=128),
|
|
"Llama-3-8B": ModelSpec(num_layers=32, dim=4096, num_heads=32, head_dim=128),
|
|
"Mistral-7B": ModelSpec(num_layers=32, dim=4096, num_heads=32, head_dim=128),
|
|
"GPT-4-class": ModelSpec(num_layers=100, dim=12288, num_heads=96, head_dim=128),
|
|
}
|
|
|
|
results = {}
|
|
for name, spec in models.items():
|
|
model_mem = compute_model_memory(spec, np.float16)
|
|
|
|
# KV cache for batch=1, various lengths
|
|
kv_1k = compute_kv_cache_memory(1, 1024, spec, np.float16)
|
|
kv_8k = compute_kv_cache_memory(1, 8192, spec, np.float16)
|
|
kv_32k = compute_kv_cache_memory(1, 32768, spec, np.float16)
|
|
|
|
results[name] = {
|
|
"params_gb": model_mem["total_params_gb"],
|
|
"kv_1k_gb": kv_1k["total_gb"],
|
|
"kv_8k_gb": kv_8k["total_gb"],
|
|
"kv_32k_gb": kv_32k["total_gb"],
|
|
"max_context_H100": find_max_context(spec, gpu_memory_gb=80, batch_size=1),
|
|
"max_context_A100_40": find_max_context(spec, gpu_memory_gb=40, batch_size=1),
|
|
"max_context_A100_80": find_max_context(spec, gpu_memory_gb=80, batch_size=1),
|
|
}
|
|
|
|
return results
|
|
|
|
|
|
def print_analysis():
|
|
"""Print a comprehensive memory analysis report."""
|
|
print("=" * 80)
|
|
print("KV-CACHE MEMORY GROWTH ANALYSIS")
|
|
print("=" * 80)
|
|
|
|
# Model size comparison
|
|
print("\n--- Model Size Comparison (fp16) ---\n")
|
|
comparisons = compare_model_sizes()
|
|
header = f"{'Model':<20} {'Params(GB)':>10} {'KV@1K':>10} {'KV@8K':>10} {'KV@32K':>10} {'MaxCtx(H100)':>12}"
|
|
print(header)
|
|
print("-" * len(header))
|
|
for name, data in comparisons.items():
|
|
print(f"{name:<20} {data['params_gb']:>10.1f} {data['kv_1k_gb']:>10.2f} "
|
|
f"{data['kv_8k_gb']:>10.2f} {data['kv_32k_gb']:>10.2f} "
|
|
f"{data['max_context_H100']:>12,d}")
|
|
|
|
# Growth analysis for a 7B model
|
|
print("\n\n--- Detailed Growth: 7B Model (batch=1, fp16) ---\n")
|
|
spec_7b = ModelSpec(num_layers=32, dim=4096, num_heads=32, head_dim=128)
|
|
model_mem = compute_model_memory(spec_7b, np.float16)
|
|
|
|
seq_lens = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
|
|
print(f"{'Seq Len':>10} {'KV Cache(GB)':>14} {'Total(GB)':>12} {'KV Fraction':>12}")
|
|
print("-" * 52)
|
|
for sl in seq_lens:
|
|
kv = compute_kv_cache_memory(1, sl, spec_7b, np.float16)
|
|
total = kv["total_gb"] + model_mem["total_params_gb"]
|
|
frac = kv["total_gb"] / total
|
|
print(f"{sl:>10,} {kv['total_gb']:>14.2f} {total:>12.2f} {frac:>12.1%}")
|
|
|
|
# Batch size impact
|
|
print("\n\n--- Batch Size Impact (seq_len=4096, fp16) ---\n")
|
|
batch_sizes = [1, 2, 4, 8, 16, 32]
|
|
print(f"{'Batch':>6} {'KV Cache(GB)':>14} {'Growth/Token(MB)':>18}")
|
|
print("-" * 40)
|
|
for bs in batch_sizes:
|
|
kv = compute_kv_cache_memory(bs, 4096, spec_7b, np.float16)
|
|
print(f"{bs:>6} {kv['total_gb']:>14.2f} {kv['growth_rate_mb_per_token']:>18.4f}")
|
|
|
|
# Per-token cost
|
|
print("\n\n--- Per-Token Memory Cost ---\n")
|
|
kv_one = compute_kv_cache_memory(1, 1, spec_7b, np.float16)
|
|
per_token = kv_one["total_bytes"]
|
|
print(f" Per token (all layers): {per_token:,} bytes = {per_token/1024:.1f} KB")
|
|
print(f" Per token per layer: {kv_one['per_token_per_layer_bytes']:,} bytes")
|
|
print(f" At 32K context: {per_token * 32768 / (1024**3):.2f} GB")
|
|
|
|
# GPU memory limits
|
|
print("\n\n--- Maximum Context Lengths by GPU ---\n")
|
|
gpus = {
|
|
"RTX 4090": 24,
|
|
"A100-40GB": 40,
|
|
"A100-80GB": 80,
|
|
"H100-80GB": 80,
|
|
"H100-96GB (SXM)": 96,
|
|
}
|
|
print(f"{'GPU':<20} {'Max Context (bs=1)':>20} {'Max Context (bs=4)':>20}")
|
|
print("-" * 62)
|
|
for gpu, mem in gpus.items():
|
|
ctx_1 = find_max_context(spec_7b, mem, batch_size=1)
|
|
ctx_4 = find_max_context(spec_7b, mem, batch_size=4)
|
|
print(f"{gpu:<20} {ctx_1:>20,} {ctx_4:>20,}")
|
|
|
|
print("\n" + "=" * 80)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
print_analysis()
|