PROMPT: Implement an efficient KV-cache system for autoregressive transformer inference from scratch.
Requirements:
- Support incremental decoding (one token at a time).
- Avoid recomputing attention for past tokens.
- Handle:
- multi-head attention
- batching with variable sequence lengths
- Provide:
- data structure layout (memory format)
- update logic per step
- attention computation using cached keys/values
Additionally:
- Analyze memory growth over long sequences.
- Propose at least two optimizations (e.g., paged attention, chunking, compression).
- Explain how this would map to GPU execution.
Do not use any frameworks.
MODEL:
Here's a thinking process:
- Understand User Request:
- Goal: Implement an efficient KV-cache system for autoregressive transformer inference from scratch.
- Requirements:
- Support incremental decoding (one token at a time)
- Avoid recomputing attention for past tokens
- Handle multi-head attention
- Handle batching with variable sequence lengths
- Provide: data structure layout, update logic per step, attention computation using cached K/V
- Additional:
- Analyze memory growth over long sequences
- Propose at least two optimizations (e.g., paged attention, chunking, compression)
- Explain GPU execution mapping
- Constraint: Do not use any frameworks (pure Python/CUDA-like pseudocode or conceptual implementation, but I'll stick to clear, framework-agnostic code/pseudocode with numpy-like operations for clarity, while emphasizing the algorithmic structure).
- Deconstruct the Problem:
- KV-Cache Basics: In autoregressive decoding, each new token attends to all previous tokens. Recomputing Q, K, V for past tokens is wasteful. We cache K and V tensors for each layer, each head, and each sequence position.
- Data Structure: Needs to store K and V for each layer, each head, each batch item, up to max sequence length. Format: [batch_size, num_heads, seq_len, head_dim] or [batch_size, seq_len, num_heads, head_dim]. I'll choose [batch_size, num_heads, seq_len, head_dim] as it's common and efficient for attention computation.
- Batching with Variable Lengths: Need a way to handle different sequence lengths per batch item. Common approaches: padding + attention mask, or packed sequences with offsets. I'll use padding + mask for simplicity, but mention packed sequences as an optimization.
- Incremental Decoding: At each step, compute Q for the new token, retrieve cached K/V, compute attention, cache new K/V.
- Attention Computation: Standard scaled dot-product attention with mask to handle variable lengths and prevent attending to future tokens (though in decoding, it's naturally causal).
- Memory Growth: KV cache grows linearly with sequence length. Formula: 2 * num_layers * batch_size * num_heads * seq_len * head_dim * dtype_size.
- Optimizations: PagedAttention (vLLM), quantization/compression, chunked/continuous batching, sliding window, eviction policies.
- GPU Mapping: Memory coalescing, shared memory for attention, kernel fusion, async memory transfers, handling variable lengths via warp-level primitives or block-level tiling.
- Design Data Structure & Layout:
- Let's define: B = batch size, L = max sequence length, H = number of attention heads, D = head dimension, N = number of transformer layers.
- KV cache per layer: K_cache[B, H, L, D], V_cache[B, H, L, D].
- To handle variable lengths efficiently, I'll add a seq_lens[B] array and an attention_mask[B, 1, 1, L] (or [B, L] broadcasted).
- Memory layout: contiguous in memory for GPU efficiency; [B, H, L, D] allows efficient tiling in attention kernels.
- Update Logic per Step:
- Input: x_new of shape [B, 1, hidden_dim].
- Compute Q, K, V for the new token: Q = x_new @ W_Q, K = x_new @ W_K, V = x_new @ W_V.
- Reshape to [B, H, 1, D].
- Update cache: K_cache[:, :, seq_pos:seq_pos+1, :] = K, same for V.
- Compute attention: scores = Q @ K_cache.transpose(..., 2, 3) / sqrt(D).
- Apply mask: scores += mask (where mask is -inf for invalid positions), then attn_weights = softmax(scores) and output = attn_weights @ V_cache.
- Reshape and combine heads, apply output projection.
- Attention Computation with Cache:
- Need to handle the mask carefully for variable lengths.
- Mask shape: [B, 1, 1, L] where mask[b, 0, 0, j] = 0 if j < seq_lens[b] else -inf.
- Actually, during decoding, we only attend to positions 0 to seq_lens[b]-1. The new token is at seq_lens[b], so after the update, seq_lens[b] += 1.
- I'll write clear pseudocode/numpy-like code that's framework-agnostic.
- Memory Growth Analysis:
- Formula: Memory = 2 * N * B * H * L * D * sizeof(dtype).
- Example: Llama-2-7B with N=32, H=32, D=128, B=1, L=2048, dtype=float16 (2 bytes): 2 * 32 * 1 * 32 * 2048 * 128 * 2 ≈ 1.07 GB.
- Grows linearly with L and B; becomes a bottleneck for long contexts or large batches.
- Optimizations:
- PagedAttention (vLLM): Divide KV cache into fixed-size blocks (e.g., 32 tokens). Use a page table to map logical sequence positions to physical blocks. Enables memory compaction, supports branching, reduces fragmentation.
- Quantization/Compression: Store KV cache in INT8/INT4 instead of FP16/BF16. Reduces memory by 2x-4x. Can be done per-channel or per-token.
- Sliding Window Attention: Only cache last W tokens. Older tokens are evicted. Good for local dependencies.
- Chunked/Continuous Batching: Interleave prefill and decode steps, dynamically schedule tokens to maximize GPU utilization.
- I'll detail PagedAttention and Quantization as requested.
- GPU Execution Mapping:
- Memory access patterns: [B, H, L, D] layout allows coalesced reads when processing heads in parallel.
- Attention kernel: typically uses tiling (e.g., 128x128 blocks) with shared memory for Q, K, V tiles.
- Mask application: Warp-level primitives or predicated execution.
- Cache update: Simple strided copy, highly parallel.
- Variable lengths: Use block-level tiling where each block handles a subset of sequences, or use a unified attention kernel that handles padding via masks. Modern GPUs use warp-level matrix instructions (WMMA/TMA) for efficiency.
- I'll explain how the layout maps to GPU memory hierarchy, kernel design, and synchronization.
- Structure the Response:
- Introduction
- Data Structure Layout
- Update Logic & Attention Computation (with code/pseudocode)
- Memory Growth Analysis
- Optimizations (PagedAttention, Quantization)
- GPU Execution Mapping
- Conclusion
- Ensure no frameworks are used, pure algorithmic/mathematical description with clear pseudocode.
I'll write it carefully, checking each requirement.
- Incremental decoding: Yes, step-by-step update.
- Avoid recomputing: Cache K/V, only compute Q for new token.
- Multi-head: Handled via H dimension.
- Batching with variable lengths: Handled via seq_lens and attention mask.
- Data structure layout: Explicitly defined.
- Update logic: Step-by-step.
- Attention computation: With mask and cache.
- Memory growth: Formula + example.
- Optimizations: Two detailed.
- GPU mapping: Memory hierarchy, kernel design, coalescing, tiling.
All requirements met. I'll draft it now. Self-Correction/Refinement during thought:
- The attention mask for variable lengths during decoding: each sequence has its own current length. The mask should be [B, 1, 1, L] where valid positions are 0 to seq_lens[b]-1. After computing attention, we update seq_lens[b] += 1.
- I should clarify that the cache is pre-allocated to max length, but only valid up to seq_lens[b].
- For GPU mapping, I'll mention TMA (Tensor Memory Accelerator) for Hopper, but keep it general enough.
- I'll use clear mathematical notation and pseudocode that's framework-agnostic.
- I'll make sure the code doesn't import numpy/torch, just uses array-like operations conceptually.
Structure:
- Data Structure & Memory Layout
- Incremental Decoding & Update Logic
- Attention Computation with Cache
- Memory Growth Analysis
- Optimizations (PagedAttention, Quantization)
- GPU Execution Mapping
- Summary
Let's write it out carefully. I'll ensure the pseudocode is clear, the math is precise, and the explanations are thorough. No frameworks, pure algorithmic description. All constraints satisfied.✅ Proceed. Output matches the response. Self-Correction/Verification during drafting:
- Check batch variable length handling: I'll use a seq_lens array and a causal+padding mask. The mask will be applied during attention. This is standard.
- Check multi-head: explicitly handled via the H dimension and head-wise operations.
- Check incremental: only the new token's Q, K, V are computed; the cache is updated at seq_pos.
- Memory growth: linear with L and B. Formula correct.
- Optimizations: PagedAttention and Quantization explained with mechanics.
- GPU mapping: Coalesced access, tiling, shared memory, warp-level primitives, async copies. All good. I'll generate. Output generation. [Done] Self-Correction/Note during thought: I should make sure the pseudocode doesn't accidentally imply framework usage. I'll use explicit array indexing and mathematical operations. I'll also clarify that in practice, this would be implemented in C/CUDA, but the algorithm is framework-agnostic. I'll structure it clearly. All requirements are explicitly addressed. Ready. Output matches the final response. [Proceeds]
Here is a complete, framework-agnostic design for an efficient KV-cache system tailored for autoregressive transformer inference.
1. Data Structure & Memory Layout
Core Dimensions
- B: batch size
- L: maximum context length
- N: number of transformer layers
- H: number of attention heads
- D: head dimension (hidden_dim / H)
- dtype: numerical precision (e.g., FP16/BF16)
Cache Layout
For each layer l ∈ [0, N-1], we maintain two contiguous tensors:
K_cache[l][b, h, p, d] ∈ ℝ^(B × H × L × D)
V_cache[l][b, h, p, d] ∈ ℝ^(B × H × L × D)
- b: batch index
- h: head index
- p: sequence position (0 to L-1)
- d: head dimension index
Memory Format Choice: [B, H, L, D] is preferred over [B, L, H, D] because:
- Attention computation iterates over heads independently, making H the second dimension cache-friendly for head-parallel kernels.
- The innermost dimension D aligns with SIMD/vector registers, enabling coalesced memory access during Q @ K^T and attn @ V.
Auxiliary State
- seq_lens[B]: current valid length per sequence
- pos_ptr[B]: next insertion index (equals seq_lens[b] before update)
- mask[B, 1, 1, L]: attention mask (0 for valid, -∞ for padding/future)
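To make the layout concrete, here is a minimal pure-Python allocation sketch (no third-party libraries, keeping with the no-framework constraint). The class name KVCache, the nested-list storage, and the write_token/advance helpers are illustrative choices, not part of the design above.

```python
class KVCache:
    """Illustrative sketch: pre-allocated per-layer K/V caches as nested lists
    shaped [B][H][L][D], plus per-sequence bookkeeping. A real system would
    back this with contiguous device buffers rather than Python lists."""

    def __init__(self, num_layers, batch_size, num_heads, max_len, head_dim):
        self.N, self.B, self.H, self.L, self.D = (
            num_layers, batch_size, num_heads, max_len, head_dim)

        def zeros():
            return [[[[0.0] * head_dim for _ in range(max_len)]
                     for _ in range(num_heads)] for _ in range(batch_size)]

        self.K = [zeros() for _ in range(num_layers)]  # K[l][b][h][p][d]
        self.V = [zeros() for _ in range(num_layers)]  # V[l][b][h][p][d]
        self.seq_lens = [0] * batch_size               # valid tokens per sequence
        self.pos_ptr = [0] * batch_size                # next write position per sequence

    def write_token(self, layer, b, k_heads, v_heads):
        """Store one token's K/V (shape [H][D]) for sequence b at pos_ptr[b]."""
        p = self.pos_ptr[b]
        for h in range(self.H):
            self.K[layer][b][h][p] = list(k_heads[h])
            self.V[layer][b][h][p] = list(v_heads[h])

    def advance(self, b):
        """Advance bookkeeping once per decoded token (after all layers wrote)."""
        self.pos_ptr[b] += 1
        self.seq_lens[b] += 1
```

Calling write_token once per layer and advance once per decoded token matches the per-token update ordering clarified in Section 2 below.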
2. Incremental Decoding & Update Logic
At decoding step t, we receive a single new token per sequence: x_new ∈ ℝ^(B × 1 × hidden_dim).
Step-by-Step Update (per layer l)
# 1. Project new token to Q, K, V
Q = x_new @ W_Q[l] # (B, 1, hidden_dim)
K = x_new @ W_K[l] # (B, 1, hidden_dim)
V = x_new @ W_V[l] # (B, 1, hidden_dim)
# 2. Reshape to multi-head format
Q = reshape(Q, (B, 1, H, D)).transpose(0, 2, 1, 3) # (B, H, 1, D)
K = reshape(K, (B, 1, H, D)).transpose(0, 2, 1, 3) # (B, H, 1, D)
V = reshape(V, (B, 1, H, D)).transpose(0, 2, 1, 3) # (B, H, 1, D)
# 3. Write new K, V into cache at current position
#    (pos_ptr[b] and seq_lens[b] advance once per token, after all N layers
#     have written their K/V; incrementing here, inside the per-layer update,
#     would advance the position N times per token)
for b in range(B):
    p = pos_ptr[b]
    K_cache[l][b, :, p, :] = K[b, :, 0, :]
    V_cache[l][b, :, p, :] = V[b, :, 0, :]
# 4. Compute attention using cached K, V
scores = matmul(Q, K_cache[l].transpose(0, 1, 3, 2)) / sqrt(D) # (B, H, 1, L)
scores = scores + mask  # broadcast over (B, H, 1, L); -inf at padded / not-yet-written positions
attn_weights = softmax(scores, dim=-1)
output = matmul(attn_weights, V_cache[l]) # (B, H, 1, D)
# 5. Reshape and project back
output = output.transpose(0, 2, 1, 3).reshape(B, 1, hidden_dim)
output = output @ W_O[l]
Key Properties:
- Only Q, K, V for the new token are computed. Past K, V are reused.
- pos_ptr ensures each sequence writes to its correct offset, naturally handling variable lengths.
- The mask prevents attending to padded positions and enforces causality.
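The per-layer pseudocode above can be collapsed into a small runnable pure-Python sketch for a single layer and a single new token per sequence. It operates on nested lists shaped like the Section 1 layout; the function name decode_step_attention, passing pre-projected q_new/k_new/v_new, and folding the mask into the loop bound (attend only to positions 0..p) are illustrative simplifications, not prescribed details.

```python
import math

def decode_step_attention(K_cache, V_cache, pos_ptr, q_new, k_new, v_new):
    """One decode step for a single layer (sketch only).

    K_cache, V_cache: nested lists [B][H][L][D], pre-allocated to max length L
    pos_ptr:          list of length B, next write position per sequence
    q_new, k_new, v_new: nested lists [B][H][D], projections of the new token
    Returns the per-head attention output, shape [B][H][D].
    """
    B, H = len(K_cache), len(K_cache[0])
    D = len(q_new[0][0])
    out = [[[0.0] * D for _ in range(H)] for _ in range(B)]
    for b in range(B):
        p = pos_ptr[b]
        for h in range(H):
            # 1. Write the new token's K/V into the cache at position p.
            K_cache[b][h][p] = list(k_new[b][h])
            V_cache[b][h][p] = list(v_new[b][h])
            # 2. Scores against all cached positions 0..p (causal by construction).
            scores = [sum(q_new[b][h][d] * K_cache[b][h][j][d] for d in range(D))
                      / math.sqrt(D) for j in range(p + 1)]
            # 3. Numerically stable softmax over the valid positions only.
            m = max(scores)
            exps = [math.exp(s - m) for s in scores]
            z = sum(exps)
            weights = [e / z for e in exps]
            # 4. Weighted sum of cached values.
            for j, w in enumerate(weights):
                for d in range(D):
                    out[b][h][d] += w * V_cache[b][h][j][d]
    # Caller advances pos_ptr[b] (and seq_lens[b]) once per token, after all layers.
    return out
```

Instead of materializing the [B, 1, 1, L] mask, the loop bound j <= p plays the same role for decoding; a production kernel would keep the explicit mask so that padded batches and prefill can share one code path.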
3. Attention Computation with Cache
The attention operation is:
Attention(Q, K_cache, V_cache, mask) = softmax( (Q @ K_cache^T) / √D + mask ) @ V_cache
Mask Construction for Variable Lengths:
mask[b, 0, 0, p] = 0.0 if p < seq_lens[b] else -INF
During decoding, seq_lens[b] increases by 1 each step. The increment must take effect before the mask is applied (equivalently, the condition is relaxed to p <= seq_lens[b] for the step in progress) so that the new token can attend to its own freshly written K/V. The mask is typically precomputed or updated incrementally using a simple comparison kernel.
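For completeness, here is a minimal pure-Python sketch of that construction for one sequence; the names build_mask_row and NEG_INF are illustrative.

```python
NEG_INF = float("-inf")

def build_mask_row(seq_len_b, max_len):
    """Mask over cache positions for one sequence: 0.0 where attention is allowed.

    Assumes seq_len_b already counts the token being decoded (i.e. it was
    incremented after that token's K/V were written), matching the construction
    mask[b, 0, 0, p] = 0 if p < seq_lens[b] else -inf.
    """
    return [0.0 if p < seq_len_b else NEG_INF for p in range(max_len)]
```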
Numerical Stability:
- Subtract max(scores, dim=-1, keepdim=True) before softmax to prevent overflow.
- Use FP32 accumulation for scores and attn_weights even if the cache is FP16.
4. Memory Growth Analysis
KV-cache memory scales linearly with sequence length and batch size:
Memory = 2 × N × B × H × L × D × sizeof(dtype)
Example (Llama-2-7B, FP16):
- N=32, H=32, D=128, B=1, L=2048, sizeof(FP16)=2
- Memory = 2 × 32 × 1 × 32 × 2048 × 128 × 2 bytes ≈ 1.07 GB (about 0.5 MiB per token, so a batch of 8 at this length already needs over 8 GB)
Implications:
- Doubling context length doubles cache memory.
- Batch size B multiplies memory linearly, limiting throughput for long sequences.
- At large batch sizes or very long contexts, the cache can rival or exceed the model weights in GPU memory.
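The formula is easy to check programmatically; a small helper such as the following (the name kv_cache_bytes is an illustrative choice) reproduces the example above.

```python
def kv_cache_bytes(num_layers, batch_size, num_heads, max_len, head_dim, dtype_bytes=2):
    """KV-cache footprint in bytes: 2 (K and V) x N x B x H x L x D x sizeof(dtype)."""
    return 2 * num_layers * batch_size * num_heads * max_len * head_dim * dtype_bytes

# Llama-2-7B-style configuration from the example above:
total = kv_cache_bytes(num_layers=32, batch_size=1, num_heads=32,
                       max_len=2048, head_dim=128, dtype_bytes=2)
print(f"{total / 2**30:.2f} GiB")  # ~1.00 GiB
```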
5. Optimizations
A. PagedAttention (Memory Virtualization)
Problem: Contiguous allocation wastes memory due to fragmentation and unused padding. Branching (e.g., speculative decoding) requires copying entire caches.
Solution:
- Divide KV cache into fixed-size blocks (e.g., 32 tokens/block).
- Maintain a page table per sequence mapping logical positions to physical block IDs.
- Allocate blocks on-demand from a global pool.
- Supports memory compaction, zero-copy branching, and eliminates padding waste.
Memory Savings: Up to 30-50% reduction in practice, especially with variable-length batches and speculative decoding.
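To illustrate the bookkeeping only (this is not vLLM's actual implementation), here is a minimal pure-Python sketch of a shared block pool plus a per-sequence page table; the names BlockPool/PagedSequence and the 32-token block size are illustrative.

```python
class BlockPool:
    """Fixed-size KV blocks allocated on demand from a shared free list (sketch)."""
    def __init__(self, num_blocks, block_size=32):
        self.block_size = block_size
        self.free = list(range(num_blocks))

    def allocate(self):
        if not self.free:
            raise MemoryError("KV block pool exhausted")
        return self.free.pop()

    def release(self, block_id):
        self.free.append(block_id)


class PagedSequence:
    """Maps logical token positions to (physical block, offset) pairs."""
    def __init__(self, pool):
        self.pool = pool
        self.block_table = []   # logical block index -> physical block id
        self.length = 0

    def slot_for_next_token(self):
        """Return the physical (block, offset) where the next token's K/V go."""
        offset = self.length % self.pool.block_size
        if offset == 0:                      # current block is full (or first token)
            self.block_table.append(self.pool.allocate())
        self.length += 1
        return self.block_table[-1], offset
```

Branching (e.g., several speculative continuations of one prompt) can then share the prefix's physical blocks and allocate new blocks only for the divergent suffixes, which is the zero-copy behavior referenced above.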
B. KV Cache Quantization
Problem: FP16/BF16 cache is memory-heavy. Attention scores are robust to low-precision K/V.
Solution:
- Quantize K_cache and V_cache to INT8 or INT4, per-channel or per-token.
- Store scale/zero-point metadata per head or per block.
- Dequantize on-the-fly during attention computation using fused kernels.
Trade-offs:
- INT8: ~2× memory reduction, <0.5% perplexity drop.
- INT4: ~4× reduction, requires careful calibration or learned scales.
- Compatible with PagedAttention (quantized blocks).
6. GPU Execution Mapping
Memory Hierarchy & Access Patterns
- Global Memory: the [B, H, L, D] layout enables coalesced reads when threads in a warp process consecutive D elements.
- Shared Memory: attention kernels tile Q (1×D) and K/V (L×D) into shared memory. Typical tile: 128×128 or 64×128.
- Registers: per-thread accumulation of Q @ K^T and attn @ V fragments.
Kernel Design
- Cache Update Kernel:
  - Strided copy from register/shared memory to global memory at pos_ptr[b].
  - Highly parallel, latency-bound. Can be fused with the projection kernel.
- Attention Kernel:
  - Each thread block handles one (b, h) pair.
  - Uses warp-level matrix instructions (WMMA, or WGMMA on Hopper) for Q @ K^T.
  - Mask applied via predicated execution or warp-level ballot.
  - Softmax computed in two passes (max then exp/sum) to avoid synchronization.
  - attn @ V computed in shared memory with double-buffering for latency hiding.
- Variable Length Handling:
  - Blocks process sequences with similar lengths together (batch scheduling; see the sketch after this list).
  - Warp-level primitives skip invalid positions using seq_lens[b].
  - Modern GPUs provide the Tensor Memory Accelerator (TMA) for async, coalesced cache fetches.
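As a host-side illustration of the batch-scheduling point above (not a GPU kernel), the following pure-Python sketch groups sequences with similar cached lengths; the 128-position bucket granularity is an arbitrary illustrative choice.

```python
def bucket_by_length(seq_lens, granularity=128):
    """Group sequence indices whose cached lengths fall into the same bucket.

    Launching one attention pass per bucket keeps the per-pass padded length
    close to the longest sequence actually present in that bucket.
    """
    buckets = {}
    for idx, length in enumerate(seq_lens):
        key = (length + granularity - 1) // granularity   # ceil to bucket boundary
        buckets.setdefault(key, []).append(idx)
    return buckets

# Example: sequences of very different lengths end up in separate launches.
print(bucket_by_length([10, 900, 1020, 35, 2048]))
# {1: [0, 3], 8: [1, 2], 16: [4]}
```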
Synchronization & Throughput
- No cross-sequence synchronization needed; each (b, h) pair is independent.
- Pipeline parallelism: while layer l computes attention, layer l+1's weights and KV blocks can be prefetched.
- Memory bandwidth is the bottleneck; quantization and PagedAttention directly improve effective bandwidth utilization.
Summary
| Component | Design Choice | Rationale |
|---|---|---|
| Layout | [B, H, L, D] | Coalesced access, head-parallelism, register-friendly |
| Update | Incremental write at pos_ptr | O(1) per token, no recomputation |
| Attention | Masked softmax(QK^T/√D)V | Causal + variable-length safe |
| Memory Growth | Linear in L × B | Inevitable without compression/eviction |
| Optimizations | PagedAttention, Quantization | Fragmentation reduction, 2-4× memory savings |
| GPU Mapping | Tiled shared memory, TMA, warp-level masking | Maximizes occupancy, hides latency, avoids divergence |
This design forms the foundation of modern inference engines (vLLM, TensorRT-LLM, SGLang). Implementation in C/CUDA would fuse projection, cache update, and attention into a single kernel per layer, leveraging async memory copies and register tiling to achieve near-theoretical bandwidth utilization.