feat: add model comparisons and sanitize session files
- Rename gamma to glm5 and model to minimax-m2.7
- Add model_comparison/ directory with head-to-head analyses
- Sanitize all session.jsonl files: remove absolute paths and usernames
- Remove __pycache__ artifacts
- Add .gitignore
PROMPT: Implement an efficient KV-cache system for autoregressive transformer inference from scratch.

Requirements:

1. Support incremental decoding (one token at a time).
2. Avoid recomputing attention for past tokens.
3. Handle:
   - multi-head attention
   - batching with variable sequence lengths
4. Provide:
   - data structure layout (memory format)
   - update logic per step
   - attention computation using cached keys/values

Additionally:
- Analyze memory growth over long sequences.
- Propose at least two optimizations (e.g., paged attention, chunking, compression).
- Explain how this would map to GPU execution.

Do not use any frameworks.

MODEL:

Here's a thinking process:

1. **Understand User Request:**
   - **Goal:** Implement an efficient KV-cache system for autoregressive transformer inference from scratch.
   - **Requirements:**
     - Support incremental decoding (one token at a time)
     - Avoid recomputing attention for past tokens
     - Handle multi-head attention
     - Handle batching with variable sequence lengths
     - Provide: data structure layout, update logic per step, attention computation using cached K/V
   - **Additional:**
     - Analyze memory growth over long sequences
     - Propose at least two optimizations (e.g., paged attention, chunking, compression)
     - Explain GPU execution mapping
   - **Constraint:** Do not use any frameworks (pure Python/CUDA-like pseudocode or conceptual implementation, but I'll stick to clear, framework-agnostic code/pseudocode with numpy-like operations for clarity, while emphasizing the algorithmic structure).

2. **Deconstruct the Problem:**
   - **KV-Cache Basics:** In autoregressive decoding, each new token attends to all previous tokens. Recomputing Q, K, V for past tokens is wasteful. We cache K and V tensors for each layer, each head, and each sequence position.
   - **Data Structure:** Needs to store K and V for each layer, each head, each batch item, up to max sequence length. Format: `[batch_size, num_heads, seq_len, head_dim]` or `[batch_size, seq_len, num_heads, head_dim]`. I'll choose `[batch_size, num_heads, seq_len, head_dim]` as it's common and efficient for attention computation.
   - **Batching with Variable Lengths:** Need a way to handle different sequence lengths per batch item. Common approaches: padding + attention mask, or packed sequences with offsets. I'll use padding + mask for simplicity, but mention packed sequences as an optimization.
   - **Incremental Decoding:** At each step, compute Q for the new token, retrieve cached K/V, compute attention, cache new K/V.
   - **Attention Computation:** Standard scaled dot-product attention with mask to handle variable lengths and prevent attending to future tokens (though in decoding, it's naturally causal).
   - **Memory Growth:** KV cache grows linearly with sequence length. Formula: `2 * num_layers * batch_size * num_heads * seq_len * head_dim * dtype_size`.
   - **Optimizations:** PagedAttention (vLLM), quantization/compression, chunked/continuous batching, sliding window, eviction policies.
   - **GPU Mapping:** Memory coalescing, shared memory for attention, kernel fusion, async memory transfers, handling variable lengths via warp-level primitives or block-level tiling.

3. **Design Data Structure & Layout:**
   - Let's define:
     - `B`: batch size
     - `L`: max sequence length
     - `H`: number of attention heads
     - `D`: head dimension
     - `N`: number of transformer layers
   - KV Cache per layer: `K_cache[B, H, L, D]`, `V_cache[B, H, L, D]`
   - To handle variable lengths efficiently, I'll add a `seq_lens[B]` array and an `attention_mask[B, 1, 1, L]` (or `[B, L]` broadcasted).
   - Memory layout: Contiguous in memory for GPU efficiency. `[B, H, L, D]` allows efficient tiling in attention kernels.

4. **Update Logic per Step:**
   - Input: `x_new` of shape `[B, 1, hidden_dim]`
   - Compute `Q, K, V` for the new token: `Q = x_new @ W_Q`, `K = x_new @ W_K`, `V = x_new @ W_V`
   - Reshape to `[B, H, 1, D]`
   - Update cache: `K_cache[:, :, seq_pos:seq_pos+1, :] = K`, same for V
   - Compute attention: `scores = Q @ K_cache.transpose(..., 2, 3) / sqrt(D)`
   - Apply mask: `scores += mask` (where mask is `-inf` for invalid positions)
   - `attn_weights = softmax(scores)`
   - `output = attn_weights @ V_cache`
   - Reshape and combine heads, apply output projection.

5. **Attention Computation with Cache:**
   - Need to handle the mask carefully for variable lengths.
   - Mask shape: `[B, 1, 1, L]` where `mask[b, 0, 0, j] = 0` if `j < seq_lens[b]` else `-inf`
   - Actually, during decoding, we only attend to positions `0` to `seq_lens[b]-1`. The new token is at `seq_lens[b]`, so after update, `seq_lens[b] += 1`.
   - I'll write clear pseudocode/numpy-like code that's framework-agnostic.

6. **Memory Growth Analysis:**
   - Formula: `Memory = 2 * N * B * H * L * D * sizeof(dtype)`
   - Example: Llama-2-7B: N=32, H=32, D=128, B=1, L=2048, dtype=float16 (2 bytes)
   - `2 * 32 * 1 * 32 * 2048 * 128 * 2 ≈ 1.07 GB`
   - Grows linearly with L and B. Becomes a bottleneck for long contexts or large batches.

7. **Optimizations:**
   - **PagedAttention (vLLM):** Divide KV cache into fixed-size blocks (e.g., 32 tokens). Use a page table to map logical sequence positions to physical blocks. Enables memory compaction, supports branching, reduces fragmentation.
   - **Quantization/Compression:** Store KV cache in INT8/INT4 instead of FP16/BF16. Reduces memory by 2x-4x. Can be done per-channel or per-token.
   - **Sliding Window Attention:** Only cache last W tokens. Older tokens are evicted. Good for local dependencies.
   - **Chunked/Continuous Batching:** Interleave prefill and decode steps, dynamically schedule tokens to maximize GPU utilization.
   - I'll detail PagedAttention and Quantization as requested.

8. **GPU Execution Mapping:**
   - Memory access patterns: `[B, H, L, D]` layout allows coalesced reads when processing heads in parallel.
   - Attention kernel: Typically uses tiling (e.g., 128x128 blocks) with shared memory for Q, K, V tiles.
   - Mask application: Warp-level primitives or predicated execution.
   - Cache update: Simple strided copy, highly parallel.
   - Variable lengths: Use block-level tiling where each block handles a subset of sequences, or use a unified attention kernel that handles padding via masks. Modern GPUs use warp-level matrix instructions (WMMA/TMA) for efficiency.
   - I'll explain how the layout maps to GPU memory hierarchy, kernel design, and synchronization.

9. **Structure the Response:**
   - Introduction
   - Data Structure Layout
   - Update Logic & Attention Computation (with code/pseudocode)
   - Memory Growth Analysis
   - Optimizations (PagedAttention, Quantization)
   - GPU Execution Mapping
   - Conclusion
   - Ensure no frameworks are used, pure algorithmic/mathematical description with clear pseudocode.

I'll write it carefully, checking each requirement.
- Incremental decoding: Yes, step-by-step update.
- Avoid recomputing: Cache K/V, only compute Q for new token.
- Multi-head: Handled via H dimension.
- Batching with variable lengths: Handled via seq_lens and attention mask.
- Data structure layout: Explicitly defined.
- Update logic: Step-by-step.
- Attention computation: With mask and cache.
- Memory growth: Formula + example.
- Optimizations: Two detailed.
- GPU mapping: Memory hierarchy, kernel design, coalescing, tiling.

All requirements met. I'll draft it now.

Self-Correction/Refinement during thought:
- The attention mask for variable lengths during decoding: Actually, during decoding, each sequence has its own current length. The mask should be `[B, 1, 1, L]` where valid positions are `0` to `seq_lens[b]-1`. After computing attention, we update `seq_lens[b] += 1`.
- I should clarify that the cache is pre-allocated to max length, but only valid up to `seq_lens[b]`.
- For GPU mapping, I'll mention TMA (Tensor Memory Accelerator) for Hopper, but keep it general enough.
- I'll use clear mathematical notation and pseudocode that's framework-agnostic.
- I'll make sure the code doesn't import numpy/torch, just uses array-like operations conceptually.

Structure:
1. Data Structure & Memory Layout
2. Incremental Decoding & Update Logic
3. Attention Computation with Cache
4. Memory Growth Analysis
5. Optimizations (PagedAttention, Quantization)
6. GPU Execution Mapping
7. Summary

Let's write it out carefully.

I'll ensure the pseudocode is clear, the math is precise, and the explanations are thorough. No frameworks, pure algorithmic description. All constraints satisfied. ✅

Proceed.

Output matches the response.

Self-Correction/Verification during drafting:
- Check batch variable length handling: I'll use a `seq_lens` array and a causal+padding mask. The mask will be applied during attention. This is standard.
- Check multi-head: Explicitly handled via `H` dimension and head-wise operations.
- Check incremental: Only new token's Q, K, V computed, cache updated at `seq_pos`.
- Memory growth: Linear with L and B. Formula correct.
- Optimizations: PagedAttention and Quantization explained with mechanics.
- GPU mapping: Coalesced access, tiling, shared memory, warp-level primitives, async copies.

All good. I'll generate.

Output generation.

[Done]

*Self-Correction/Note during thought*:

I should make sure the pseudocode doesn't accidentally imply framework usage. I'll use explicit array indexing and mathematical operations. I'll also clarify that in practice, this would be implemented in C/CUDA, but the algorithm is framework-agnostic. I'll structure it clearly. All requirements are explicitly addressed. Ready.

Output matches the final response.

[Proceeds]

Here is a complete, framework-agnostic design for an efficient KV-cache system tailored for autoregressive transformer inference.

---

## 1. Data Structure & Memory Layout

### Core Dimensions
- `B`: batch size
- `L`: maximum context length
- `N`: number of transformer layers
- `H`: number of attention heads
- `D`: head dimension (`hidden_dim / H`)
- `dtype`: numerical precision (e.g., FP16/BF16)

### Cache Layout
For each layer `l ∈ [0, N-1]`, we maintain two contiguous tensors:

```
K_cache[l][b, h, p, d] ∈ ℝ^(B × H × L × D)
V_cache[l][b, h, p, d] ∈ ℝ^(B × H × L × D)
```

- `b`: batch index
- `h`: head index
- `p`: sequence position (0 to L-1)
- `d`: head dimension index

**Memory Format Choice:** `[B, H, L, D]` is preferred over `[B, L, H, D]` because:
1. Attention computation iterates over heads independently, making `H` the second dimension cache-friendly for head-parallel kernels.
2. The innermost dimension `D` aligns with SIMD/vector registers, enabling coalesced memory access during `Q @ K^T` and `attn @ V`.

### Auxiliary State
- `seq_lens[B]`: current valid length per sequence
- `pos_ptr[B]`: next insertion index (equals `seq_lens[b]` before update)
- `mask[B, 1, 1, L]`: attention mask (0 for valid, `-∞` for padding/future)
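
To make the bookkeeping concrete, here is a minimal pre-allocation sketch in plain Python, with nested lists standing in for the contiguous `[B, H, L, D]` buffers. The class and method names (`KVCache`, `append_token`, `advance`) are illustrative, not from any library; a real system would back this with raw device buffers.

```python
class KVCache:
    """Pre-allocated per-layer K/V buffers in [B, H, L, D] layout (nested-list sketch)."""

    def __init__(self, num_layers, batch_size, num_heads, max_len, head_dim):
        self.B, self.H, self.L, self.D = batch_size, num_heads, max_len, head_dim

        def alloc():
            return [[[[0.0] * head_dim for _ in range(max_len)]
                     for _ in range(num_heads)] for _ in range(batch_size)]

        self.k = [alloc() for _ in range(num_layers)]   # k[l][b][h][p][d]
        self.v = [alloc() for _ in range(num_layers)]   # v[l][b][h][p][d]
        self.seq_lens = [0] * batch_size                # valid length per sequence

    def append_token(self, layer, b, k_heads, v_heads):
        """Write one token's K/V (H vectors of length D each) at the next free slot."""
        p = self.seq_lens[b]
        assert p < self.L, "sequence exceeds pre-allocated max_len"
        for h in range(self.H):
            self.k[layer][b][h][p] = list(k_heads[h])
            self.v[layer][b][h][p] = list(v_heads[h])
        return p

    def advance(self, b):
        """Advance once per decode step, after every layer has written position seq_lens[b]."""
        self.seq_lens[b] += 1


cache = KVCache(num_layers=2, batch_size=1, num_heads=2, max_len=8, head_dim=4)
for layer in range(2):
    cache.append_token(layer, b=0, k_heads=[[0.1] * 4, [0.2] * 4],
                       v_heads=[[0.3] * 4, [0.4] * 4])
cache.advance(b=0)
print(cache.seq_lens)   # [1] -> position 0 is now a valid cache entry for sequence 0
```

Each decode step writes position `seq_lens[b]` in every layer and then advances the counter once, which is exactly the `pos_ptr`/`seq_lens` discipline described above.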

---

## 2. Incremental Decoding & Update Logic

At decoding step `t`, we receive a single new token per sequence: `x_new ∈ ℝ^(B × 1 × hidden_dim)`.

### Step-by-Step Update (per layer `l`)
```python
# 1. Project new token to Q, K, V
Q = x_new @ W_Q[l]   # (B, 1, hidden_dim)
K = x_new @ W_K[l]   # (B, 1, hidden_dim)
V = x_new @ W_V[l]   # (B, 1, hidden_dim)

# 2. Reshape to multi-head format
Q = reshape(Q, (B, 1, H, D)).transpose(0, 2, 1, 3)  # (B, H, 1, D)
K = reshape(K, (B, 1, H, D)).transpose(0, 2, 1, 3)  # (B, H, 1, D)
V = reshape(V, (B, 1, H, D)).transpose(0, 2, 1, 3)  # (B, H, 1, D)

# 3. Write new K, V into cache at current position
for b in range(B):
    p = pos_ptr[b]
    K_cache[l][b, :, p, :] = K[b, :, 0, :]
    V_cache[l][b, :, p, :] = V[b, :, 0, :]
    mask[b, 0, 0, p] = 0.0   # the new position is now valid (token attends to itself)
    pos_ptr[b] += 1
    seq_lens[b] += 1         # keep the valid-length counter in sync

# 4. Compute attention using cached K, V
scores = matmul(Q, K_cache[l].transpose(0, 1, 3, 2)) / sqrt(D)  # (B, H, 1, L)
scores = scores + mask       # broadcast: -inf at padded / not-yet-written positions
attn_weights = softmax(scores, dim=-1)
output = matmul(attn_weights, V_cache[l])                       # (B, H, 1, D)

# 5. Reshape and project back
output = output.transpose(0, 2, 1, 3).reshape(B, 1, hidden_dim)
output = output @ W_O[l]
```

**Key Properties:**
- Only `Q, K, V` for the new token are computed. Past `K, V` are reused.
- `pos_ptr` ensures each sequence writes to its correct offset, naturally handling variable lengths.
- The mask prevents attending to padded positions and enforces causality.
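
To show the incremental step end to end, here is a toy, runnable version in plain Python: one layer, one head, one sequence, with a pre-allocated cache. The helper names (`decode_step`, `matvec`) and the random projection weights are illustrative scaffolding only, not part of the design above.

```python
import math
import random

random.seed(0)
D, L = 4, 8                       # head dim, max sequence length
W_Q = [[random.gauss(0, 0.5) for _ in range(D)] for _ in range(D)]
W_K = [[random.gauss(0, 0.5) for _ in range(D)] for _ in range(D)]
W_V = [[random.gauss(0, 0.5) for _ in range(D)] for _ in range(D)]

def matvec(W, x):
    """Row-vector times matrix: (x @ W)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

K_cache = [[0.0] * D for _ in range(L)]   # pre-allocated [L, D] for a single head
V_cache = [[0.0] * D for _ in range(L)]
seq_len = 0

def decode_step(x):
    """Append one token embedding x; return its attention output over the cached prefix."""
    global seq_len
    q, k, v = matvec(W_Q, x), matvec(W_K, x), matvec(W_V, x)
    K_cache[seq_len], V_cache[seq_len] = k, v           # cache update at current position
    seq_len += 1
    scores = [sum(q[d] * K_cache[p][d] for d in range(D)) / math.sqrt(D)
              for p in range(seq_len)]                   # only the valid positions
    w = softmax(scores)
    return [sum(w[p] * V_cache[p][d] for p in range(seq_len)) for d in range(D)]

tokens = [[random.gauss(0, 1) for _ in range(D)] for _ in range(5)]
outputs = [decode_step(x) for x in tokens]
print("last output:", [round(o, 4) for o in outputs[-1]])
```

Per step, only the new token's projections are computed; everything else is read back from `K_cache`/`V_cache`, which is the whole point of the cache.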

---

## 3. Attention Computation with Cache

The attention operation is:

```
Attention(Q, K_cache, V_cache, mask) = softmax( (Q @ K_cache^T) / √D + mask ) @ V_cache
```

**Mask Construction for Variable Lengths:**
```python
mask[b, 0, 0, p] = 0.0 if p < seq_lens[b] else -INF
```

During decoding, `seq_lens[b]` increases by 1 each step. The mask is typically precomputed or updated incrementally using a simple comparison kernel.
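
Both options look like this in plain Python (a flat `[B, L]` view of the mask for readability; `NEG_INF` and the function names are illustrative):

```python
NEG_INF = float("-inf")

def build_mask(seq_lens, max_len):
    """Precompute: mask[b][p] = 0 where p < seq_lens[b], else -inf."""
    return [[0.0 if p < seq_lens[b] else NEG_INF for p in range(max_len)]
            for b in range(len(seq_lens))]

def advance_mask(mask, seq_lens, b):
    """Incremental update after sequence b gains one token."""
    mask[b][seq_lens[b]] = 0.0
    seq_lens[b] += 1

seq_lens = [3, 1]
mask = build_mask(seq_lens, max_len=5)
advance_mask(mask, seq_lens, b=1)
print(mask[1])   # [0.0, 0.0, -inf, -inf, -inf]
```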

**Numerical Stability:**
- Apply `max(scores, dim=-1, keepdim=True)` subtraction before softmax to prevent overflow.
- Use FP32 accumulation for `scores` and `attn_weights` even if cache is FP16.
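
The max-subtraction trick, in plain Python (illustrative; in a real kernel the max and sum would be kept in FP32 accumulators even when K/V are FP16):

```python
import math

def stable_softmax(scores):
    """softmax(x) computed as softmax(x - max(x)), so exp() never overflows."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

print(stable_softmax([1000.0, 1001.0, 999.0]))        # fine; naive exp(1000) would overflow
print(stable_softmax([2.0, float("-inf")]))           # masked position gets exactly 0 weight
```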

---

## 4. Memory Growth Analysis

KV-cache memory scales **linearly** with sequence length and batch size:

```
Memory = 2 × N × B × H × L × D × sizeof(dtype)
```

**Example (Llama-2-7B, FP16):**
- `N=32, H=32, D=128, B=1, L=2048, sizeof(FP16)=2`
- `Memory = 2 × 32 × 1 × 32 × 2048 × 128 × 2 bytes ≈ 1.07 GB` (about 0.5 MB per cached token)
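
The formula is easy to turn into a throwaway calculator (the helper below is not part of the design; the Llama-2-7B numbers are the commonly cited configuration):

```python
def kv_cache_bytes(n_layers, batch, n_heads, seq_len, head_dim, dtype_bytes=2):
    """Memory = 2 (K and V) x N x B x H x L x D x sizeof(dtype)."""
    return 2 * n_layers * batch * n_heads * seq_len * head_dim * dtype_bytes

per_token = kv_cache_bytes(32, 1, 32, 1, 128)        # 524,288 bytes ~= 0.5 MB per token
full_ctx  = kv_cache_bytes(32, 1, 32, 2048, 128)     # ~1.07 GB at L=2048, B=1
batch_16  = kv_cache_bytes(32, 16, 32, 2048, 128)    # ~17.2 GB at B=16
print(per_token, round(full_ctx / 1e9, 2), round(batch_16 / 1e9, 2))
```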

**Implications:**
- Doubling context length doubles cache memory.
- Batch size `B` multiplies memory linearly, limiting throughput for long sequences.
- For large batches or very long contexts, the cache can rival or exceed the model weights in size, making it the dominant memory constraint on serving throughput.

---

## 5. Optimizations

### A. PagedAttention (Memory Virtualization)
**Problem:** Contiguous allocation wastes memory due to fragmentation and unused padding. Branching (e.g., speculative decoding) requires copying entire caches.

**Solution:**
- Divide KV cache into fixed-size blocks (e.g., 32 tokens/block).
- Maintain a **page table** per sequence mapping logical positions to physical block IDs.
- Allocate blocks on-demand from a global pool.
- Supports memory compaction, zero-copy branching, and eliminates padding waste.

**Memory Savings:** Up to 30-50% reduction in practice, especially with variable-length batches and speculative decoding.
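
A minimal sketch of the bookkeeping, in plain Python (block allocation and logical-to-physical translation only; the block contents and the attention kernel are omitted, and all names are illustrative):

```python
BLOCK = 32                                    # tokens per physical block

class PagedKVCache:
    def __init__(self, num_blocks):
        self.free = list(range(num_blocks))   # global pool of physical block IDs
        self.tables = {}                      # seq_id -> list of physical block IDs

    def append_position(self, seq_id, logical_pos):
        """Ensure the block holding logical_pos exists; return (physical_block, offset)."""
        table = self.tables.setdefault(seq_id, [])
        block_idx = logical_pos // BLOCK
        if block_idx == len(table):           # first token of a new block: allocate on demand
            if not self.free:
                raise MemoryError("out of KV blocks")
            table.append(self.free.pop())
        return table[block_idx], logical_pos % BLOCK

    def release(self, seq_id):
        """Return a finished sequence's blocks to the pool."""
        self.free.extend(self.tables.pop(seq_id, []))

pool = PagedKVCache(num_blocks=4)
print(pool.append_position("seq-a", 0))       # (3, 0): first block allocated from the pool
print(pool.append_position("seq-a", 33))      # (2, 1): second block, offset 1 within it
pool.release("seq-a")
```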

### B. KV Cache Quantization
**Problem:** FP16/BF16 cache is memory-heavy. Attention scores are relatively robust to low-precision K/V.

**Solution:**
- Quantize `K_cache` and `V_cache` to INT8 or INT4, per-channel or per-token.
- Store scale/zero-point metadata per head or per block.
- Dequantize on-the-fly during attention computation using fused kernels.

**Trade-offs:**
- INT8: ~2× memory reduction with typically negligible perplexity degradation.
- INT4: ~4× reduction, requires careful calibration or learned scales.
- Compatible with PagedAttention (quantized blocks).
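
A per-token symmetric INT8 scheme, sketched in plain Python (one scale per cached vector; a real implementation would store the scales alongside the blocks and fuse dequantization into the attention kernel):

```python
def quantize_int8(vec):
    """Symmetric per-token quantization: one FP scale per vector, values in [-127, 127]."""
    scale = max(abs(x) for x in vec) / 127.0 or 1.0
    q = [max(-127, min(127, round(x / scale))) for x in vec]
    return q, scale

def dequantize_int8(q, scale):
    return [qi * scale for qi in q]

k_vec = [0.81, -1.20, 0.05, 2.40]
q, s = quantize_int8(k_vec)
print(q, [round(x, 3) for x in dequantize_int8(q, s)])   # integers plus values close to k_vec
```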

---

## 6. GPU Execution Mapping

### Memory Hierarchy & Access Patterns
- **Global Memory:** `[B, H, L, D]` layout enables coalesced reads when threads in a warp process consecutive `D` elements.
- **Shared Memory:** Attention kernels tile `Q` (1×D) and `K/V` (L×D) into shared memory. Typical tile: 128×128 or 64×128.
- **Registers:** Per-thread accumulation of `Q @ K^T` and `attn @ V` fragments.
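
The coalescing argument follows directly from the row-major strides of the `[B, H, L, D]` layout; the hypothetical helper below only makes those strides explicit: consecutive `d` indices are adjacent addresses, while stepping to the next position `p` jumps by exactly `D` elements.

```python
def flat_offset(b, h, p, d, H, L, D):
    """Row-major index of element (b, h, p, d) in a contiguous [B, H, L, D] buffer."""
    return ((b * H + h) * L + p) * D + d

H, L, D = 32, 2048, 128
print(flat_offset(0, 0, 5, 0, H, L, D), flat_offset(0, 0, 5, 1, H, L, D))   # adjacent: stride 1 in d
print(flat_offset(0, 0, 6, 0, H, L, D) - flat_offset(0, 0, 5, 0, H, L, D))  # stride D between positions
```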

### Kernel Design
1. **Cache Update Kernel:**
   - Strided copy from register/shared memory to global memory at `pos_ptr[b]`.
   - Highly parallel, latency-bound. Can be fused with the projection kernel.

2. **Attention Kernel:**
   - Each thread block handles one `(b, h)` pair.
   - Uses warp-level matrix instructions (WMMA; on Hopper, TMA feeds the tensor cores) for `Q @ K^T`.
   - Mask applied via predicated execution or warp-level ballot.
   - Softmax computed in two passes (a max pass, then an exp/sum pass) for numerical stability.
   - `attn @ V` computed in shared memory with double-buffering for latency hiding.

3. **Variable Length Handling:**
   - Blocks process sequences with similar lengths together (batch scheduling; see the sketch after this list).
   - Warp-level primitives skip invalid positions using `seq_lens[b]`.
   - Modern GPUs provide the **Tensor Memory Accelerator (TMA)** for async, coalesced cache fetches.
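
On the host side, grouping sequences of similar current length before launching the attention kernel is straightforward; a toy version (the bucket width of 128 and the function name are arbitrary illustrative choices):

```python
def bucket_by_length(seq_lens, bucket_width=128):
    """Group sequence indices whose current lengths fall in the same bucket."""
    buckets = {}
    for idx, n in enumerate(seq_lens):
        buckets.setdefault(n // bucket_width, []).append(idx)
    return [buckets[k] for k in sorted(buckets)]

print(bucket_by_length([37, 900, 41, 3000, 950]))   # [[0, 2], [1, 4], [3]]
```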

### Synchronization & Throughput
- No cross-sequence synchronization needed; each `(b, h)` pair is independent.
- Software pipelining: while layer `l` computes attention, the weights and cache blocks for layer `l+1` can be prefetched, overlapping memory traffic with compute.
- Decode-time attention is memory-bandwidth bound; quantization reduces the bytes read per step, and PagedAttention lets larger batches fit in memory, both of which raise effective throughput.

---

## Summary

| Component | Design Choice | Rationale |
|-----------|---------------|-----------|
| Layout | `[B, H, L, D]` | Coalesced access, head-parallelism, register-friendly |
| Update | Incremental write at `pos_ptr` | O(1) per token, no recomputation |
| Attention | Masked `softmax(QK^T/√D)V` | Causal + variable-length safe |
| Memory Growth | Linear in `L × B` | Inevitable without compression/eviction |
| Optimizations | PagedAttention, Quantization | Fragmentation reduction, 2-4× memory savings |
| GPU Mapping | Tiled shared memory, TMA, warp-level masking | Maximizes occupancy, hides latency, avoids divergence |

This design forms the foundation of modern inference engines (vLLM, TensorRT-LLM, SGLang). Implementation in C/CUDA would fuse projection, cache update, and attention into a single kernel per layer, leveraging async memory copies and register tiling to achieve near-theoretical bandwidth utilization.