Files
deep_pro_judge/qwen36/KV_cache system.md
T
sleepy 8e72eef09c feat: add model comparisons and sanitize session files
- Rename gamma to glm5 and model to minimax-m2.7
- Add model_comparison/ directory with head-to-head analyses
- Sanitize all session.jsonl files: remove absolute paths and usernames
- Remove __pycache__ artifacts
- Add .gitignore
2026-04-23 11:16:01 +02:00

17 KiB
Raw Blame History

PROMPT: Implement an efficient KV-cache system for autoregressive transformer inference from scratch.

Requirements:

  1. Support incremental decoding (one token at a time).
  2. Avoid recomputing attention for past tokens.
  3. Handle:
    • multi-head attention
    • batching with variable sequence lengths
  4. Provide:
    • data structure layout (memory format)
    • update logic per step
    • attention computation using cached keys/values

Additionally:

  • Analyze memory growth over long sequences.
  • Propose at least two optimizations (e.g., paged attention, chunking, compression).
  • Explain how this would map to GPU execution.

Do not use any frameworks.

MODEL:

Here's a thinking process:

  1. Understand User Request:
  • Goal: Implement an efficient KV-cache system for autoregressive transformer inference from scratch.
  • Requirements:
    • Support incremental decoding (one token at a time)
    • Avoid recomputing attention for past tokens
    • Handle multi-head attention
    • Handle batching with variable sequence lengths
    • Provide: data structure layout, update logic per step, attention computation using cached K/V
  • Additional:
    • Analyze memory growth over long sequences
    • Propose at least two optimizations (e.g., paged attention, chunking, compression)
    • Explain GPU execution mapping
  • Constraint: Do not use any frameworks (pure Python/CUDA-like pseudocode or conceptual implementation, but I'll stick to clear, framework-agnostic code/pseudocode with numpy-like operations for clarity, while emphasizing the algorithmic structure).
  1. Deconstruct the Problem:
  • KV-Cache Basics: In autoregressive decoding, each new token attends to all previous tokens. Recomputing Q, K, V for past tokens is wasteful. We cache K and V tensors for each layer, each head, and each sequence position.
  • Data Structure: Needs to store K and V for each layer, each head, each batch item, up to max sequence length. Format: [batch_size, num_heads, seq_len, head_dim] or [batch_size, seq_len, num_heads, head_dim]. I'll choose [batch_size, num_heads, seq_len, head_dim] as it's common and efficient for attention computation.
  • Batching with Variable Lengths: Need a way to handle different sequence lengths per batch item. Common approaches: padding + attention mask, or packed sequences with offsets. I'll use padding + mask for simplicity, but mention packed sequences as an optimization.
  • Incremental Decoding: At each step, compute Q for the new token, retrieve cached K/V, compute attention, cache new K/V.
  • Attention Computation: Standard scaled dot-product attention with mask to handle variable lengths and prevent attending to future tokens (though in decoding, it's naturally causal).
  • Memory Growth: KV cache grows linearly with sequence length. Formula: 2 * num_layers * batch_size * num_heads * seq_len * head_dim * dtype_size.
  • Optimizations: PagedAttention (vLLM), quantization/compression, chunked/continuous batching, sliding window, eviction policies.
  • GPU Mapping: Memory coalescing, shared memory for attention, kernel fusion, async memory transfers, handling variable lengths via warp-level primitives or block-level tiling.
  1. Design Data Structure & Layout:
  • Let's define:
    • B: batch size
    • L: max sequence length
    • H: number of attention heads
    • D: head dimension
    • N: number of transformer layers
  • KV Cache per layer: K_cache[B, H, L, D], V_cache[B, H, L, D]
  • To handle variable lengths efficiently, I'll add a seq_lens[B] array and an attention_mask[B, 1, 1, L] (or [B, L] broadcasted).
  • Memory layout: Contiguous in memory for GPU efficiency. [B, H, L, D] allows efficient tiling in attention kernels.
  1. Update Logic per Step:
  • Input: x_new of shape [B, 1, hidden_dim]
  • Compute Q, K, V for the new token: Q = x_new @ W_Q, K = x_new @ W_K, V = x_new @ W_V
  • Reshape to [B, H, 1, D]
  • Update cache: K_cache[:, :, seq_pos:seq_pos+1, :] = K, same for V
  • Compute attention: scores = Q @ K_cache.transpose(..., 2, 3) / sqrt(D)
  • Apply mask: scores += mask (where mask is -inf for invalid positions)
  • attn_weights = softmax(scores)
  • output = attn_weights @ V_cache
  • Reshape and combine heads, apply output projection.
  1. Attention Computation with Cache:
  • Need to handle the mask carefully for variable lengths.
  • Mask shape: [B, 1, 1, L] where mask[b, 0, 0, j] = 0 if j < seq_lens[b] else -inf
  • Actually, during decoding, we only attend to positions 0 to seq_lens[b]-1. The new token is at seq_lens[b], so after update, seq_lens[b] += 1.
  • I'll write clear pseudocode/numpy-like code that's framework-agnostic.
  1. Memory Growth Analysis:
  • Formula: Memory = 2 * N * B * H * L * D * sizeof(dtype)
  • Example: Llama-2-7B: N=32, H=32, D=128, B=1, L=2048, dtype=float16 (2 bytes)
  • 2 * 32 * 1 * 32 * 2048 * 128 * 2 ≈ 8.5 GB
  • Grows linearly with L and B. Becomes bottleneck for long contexts or large batches.
  1. Optimizations:
  • PagedAttention (vLLM): Divide KV cache into fixed-size blocks (e.g., 32 tokens). Use a page table to map logical sequence positions to physical blocks. Enables memory compaction, supports branching, reduces fragmentation.
  • Quantization/Compression: Store KV cache in INT8/INT4 instead of FP16/BF16. Reduces memory by 2x-4x. Can be done per-channel or per-token.
  • Sliding Window Attention: Only cache last W tokens. Older tokens are evicted. Good for local dependencies.
  • Chunked/Continuous Batching: Interleave prefill and decode steps, dynamically schedule tokens to maximize GPU utilization.
  • I'll detail PagedAttention and Quantization as requested.
  1. GPU Execution Mapping:
  • Memory access patterns: [B, H, L, D] layout allows coalesced reads when processing heads in parallel.
  • Attention kernel: Typically uses tiling (e.g., 128x128 blocks) with shared memory for Q, K, V tiles.
  • Mask application: Warp-level primitives or predicated execution.
  • Cache update: Simple strided copy, highly parallel.
  • Variable lengths: Use block-level tiling where each block handles a subset of sequences, or use a unified attention kernel that handles padding via masks. Modern GPUs use warp-level matrix instructions (WMMA/TMA) for efficiency.
  • I'll explain how the layout maps to GPU memory hierarchy, kernel design, and synchronization.
  1. Structure the Response:
  • Introduction
  • Data Structure Layout
  • Update Logic & Attention Computation (with code/pseudocode)
  • Memory Growth Analysis
  • Optimizations (PagedAttention, Quantization)
  • GPU Execution Mapping
  • Conclusion
  • Ensure no frameworks are used, pure algorithmic/mathematical description with clear pseudocode.

I'll write it carefully, checking each requirement.

  • Incremental decoding: Yes, step-by-step update.
  • Avoid recomputing: Cache K/V, only compute Q for new token.
  • Multi-head: Handled via H dimension.
  • Batching with variable lengths: Handled via seq_lens and attention mask.
  • Data structure layout: Explicitly defined.
  • Update logic: Step-by-step.
  • Attention computation: With mask and cache.
  • Memory growth: Formula + example.
  • Optimizations: Two detailed.
  • GPU mapping: Memory hierarchy, kernel design, coalescing, tiling.

All requirements met. I'll draft it now. Self-Correction/Refinement during thought:

  • The attention mask for variable lengths during decoding: Actually, during decoding, each sequence has its own current length. The mask should be [B, 1, 1, L] where valid positions are 0 to seq_lens[b]-1. After computing attention, we update seq_lens[b] += 1.
  • I should clarify that the cache is pre-allocated to max length, but only valid up to seq_lens[b].
  • For GPU mapping, I'll mention TMA (Tensor Memory Accelerator) for Hopper, but keep it general enough.
  • I'll use clear mathematical notation and pseudocode that's framework-agnostic.
  • I'll make sure the code doesn't import numpy/torch, just uses array-like operations conceptually.

Structure:

  1. Data Structure & Memory Layout
  2. Incremental Decoding & Update Logic
  3. Attention Computation with Cache
  4. Memory Growth Analysis
  5. Optimizations (PagedAttention, Quantization)
  6. GPU Execution Mapping
  7. Summary

Let's write it out carefully. I'll ensure the pseudocode is clear, the math is precise, and the explanations are thorough. No frameworks, pure algorithmic description. All constraints satisfied. Proceed. Output matches the response. Self-Correction/Verification during drafting:

  • Check batch variable length handling: I'll use a seq_lens array and a causal+padding mask. The mask will be applied during attention. This is standard.
  • Check multi-head: Explicitly handled via H dimension and head-wise operations.
  • Check incremental: Only new token's Q,K,V computed, cache updated at seq_pos.
  • Memory growth: Linear with L and B. Formula correct.
  • Optimizations: PagedAttention and Quantization explained with mechanics.
  • GPU mapping: Coalesced access, tiling, shared memory, warp-level primitives, async copies. All good. I'll generate. Output generation. [Done] Self-Correction/Note during thought: I should make sure the pseudocode doesn't accidentally imply framework usage. I'll use explicit array indexing and mathematical operations. I'll also clarify that in practice, this would be implemented in C/CUDA, but the algorithm is framework-agnostic. I'll structure it clearly. All requirements are explicitly addressed. Ready. Output matches the final response. [Proceeds]

Here is a complete, framework-agnostic design for an efficient KV-cache system tailored for autoregressive transformer inference.


1. Data Structure & Memory Layout

Core Dimensions

  • B: batch size
  • L: maximum context length
  • N: number of transformer layers
  • H: number of attention heads
  • D: head dimension (hidden_dim / H)
  • dtype: numerical precision (e.g., FP16/BF16)

Cache Layout

For each layer l ∈ [0, N-1], we maintain two contiguous tensors:

K_cache[l][b, h, p, d]  ∈ ^(B × H × L × D)
V_cache[l][b, h, p, d]  ∈ ^(B × H × L × D)
  • b: batch index
  • h: head index
  • p: sequence position (0 to L-1)
  • d: head dimension index

Memory Format Choice: [B, H, L, D] is preferred over [B, L, H, D] because:

  1. Attention computation iterates over heads independently, making H the second dimension cache-friendly for head-parallel kernels.
  2. The innermost dimension D aligns with SIMD/vector registers, enabling coalesced memory access during Q @ K^T and attn @ V.

Auxiliary State

  • seq_lens[B]: current valid length per sequence
  • pos_ptr[B]: next insertion index (equals seq_lens[b] before update)
  • mask[B, 1, 1, L]: attention mask (0 for valid, -∞ for padding/future)

2. Incremental Decoding & Update Logic

At decoding step t, we receive a single new token per sequence: x_new ∈ ^(B × 1 × hidden_dim).

Step-by-Step Update (per layer l)

# 1. Project new token to Q, K, V
Q = x_new @ W_Q[l]  # (B, 1, hidden_dim)
K = x_new @ W_K[l]  # (B, 1, hidden_dim)
V = x_new @ W_V[l]  # (B, 1, hidden_dim)

# 2. Reshape to multi-head format
Q = reshape(Q, (B, 1, H, D)).transpose(0, 2, 1, 3)  # (B, H, 1, D)
K = reshape(K, (B, 1, H, D)).transpose(0, 2, 1, 3)  # (B, H, 1, D)
V = reshape(V, (B, 1, H, D)).transpose(0, 2, 1, 3)  # (B, H, 1, D)

# 3. Write new K, V into cache at current position
for b in range(B):
    p = pos_ptr[b]
    K_cache[l][b, :, p, :] = K[b, :, 0, :]
    V_cache[l][b, :, p, :] = V[b, :, 0, :]
    pos_ptr[b] += 1

# 4. Compute attention using cached K, V
scores = matmul(Q, K_cache[l].transpose(0, 1, 3, 2)) / sqrt(D)  # (B, H, 1, L)
scores = scores + mask[b]  # apply padding/causal mask
attn_weights = softmax(scores, dim=-1)
output = matmul(attn_weights, V_cache[l])  # (B, H, 1, D)

# 5. Reshape and project back
output = output.transpose(0, 2, 1, 3).reshape(B, 1, hidden_dim)
output = output @ W_O[l]

Key Properties:

  • Only Q, K, V for the new token are computed. Past K, V are reused.
  • pos_ptr ensures each sequence writes to its correct offset, naturally handling variable lengths.
  • The mask prevents attending to padded positions and enforces causality.

3. Attention Computation with Cache

The attention operation is:

Attention(Q, K_cache, V_cache, mask) = softmax( (Q @ K_cache^T) / √D + mask ) @ V_cache

Mask Construction for Variable Lengths:

mask[b, 0, 0, p] = 0.0 if p < seq_lens[b] else -INF

During decoding, seq_lens[b] increases by 1 each step. The mask is typically precomputed or updated incrementally using a simple comparison kernel.

Numerical Stability:

  • Apply max(scores, dim=-1, keepdim=True) subtraction before softmax to prevent overflow.
  • Use FP32 accumulation for scores and attn_weights even if cache is FP16.

4. Memory Growth Analysis

KV-cache memory scales linearly with sequence length and batch size:

Memory = 2 × N × B × H × L × D × sizeof(dtype)

Example (Llama-2-7B, FP16):

  • N=32, H=32, D=128, B=1, L=2048, sizeof(FP16)=2
  • Memory = 2 × 32 × 1 × 32 × 2048 × 128 × 2 ≈ 8.5 GB

Implications:

  • Doubling context length doubles cache memory.
  • Batch size B multiplies memory linearly, limiting throughput for long sequences.
  • Cache dominates GPU memory for L > 1024, often exceeding model weights.

5. Optimizations

A. PagedAttention (Memory Virtualization)

Problem: Contiguous allocation wastes memory due to fragmentation and unused padding. Branching (e.g., speculative decoding) requires copying entire caches.

Solution:

  • Divide KV cache into fixed-size blocks (e.g., 32 tokens/block).
  • Maintain a page table per sequence mapping logical positions to physical block IDs.
  • Allocate blocks on-demand from a global pool.
  • Supports memory compaction, zero-copy branching, and eliminates padding waste.

Memory Savings: Up to 30-50% reduction in practice, especially with variable-length batches and speculative decoding.

B. KV Cache Quantization

Problem: FP16/BF16 cache is memory-heavy. Attention scores are robust to low-precision K/V.

Solution:

  • Quantize K_cache and V_cache to INT8 or INT4 per-channel or per-token.
  • Store scale/zero-point metadata per head or per block.
  • Dequantize on-the-fly during attention computation using fused kernels.

Trade-offs:

  • INT8: ~2× memory reduction, <0.5% perplexity drop.
  • INT4: ~4× reduction, requires careful calibration or learned scales.
  • Compatible with PagedAttention (quantized blocks).

6. GPU Execution Mapping

Memory Hierarchy & Access Patterns

  • Global Memory: [B, H, L, D] layout enables coalesced reads when threads in a warp process consecutive D elements.
  • Shared Memory: Attention kernels tile Q (1×D) and K/V (L×D) into shared memory. Typical tile: 128×128 or 64×128.
  • Registers: Per-thread accumulation of Q @ K^T and attn @ V fragments.

Kernel Design

  1. Cache Update Kernel:

    • Strided copy from register/shared memory to global memory at pos_ptr[b].
    • Highly parallel, latency-bound. Can be fused with projection kernel.
  2. Attention Kernel:

    • Each thread block handles one (b, h) pair.
    • Uses warp-level matrix instructions (WMMA/TMA on Hopper) for Q @ K^T.
    • Mask applied via predicated execution or warp-level ballot.
    • Softmax computed in two passes (max then exp/sum) to avoid synchronization.
    • attn @ V computed in shared memory with double-buffering for latency hiding.
  3. Variable Length Handling:

    • Blocks process sequences with similar lengths together (batch scheduling).
    • Warp-level primitives skip invalid positions using seq_lens[b].
    • Modern GPUs use Tensor Memory Accelerator (TMA) for async, coalesced cache fetches.

Synchronization & Throughput

  • No cross-sequence synchronization needed; each (b, h) is independent.
  • Pipeline parallelism: While layer l computes attention, layer l+1 can prefetch next token's projections.
  • Memory bandwidth is the bottleneck; quantization and PagedAttention directly improve effective bandwidth utilization.

Summary

Component Design Choice Rationale
Layout [B, H, L, D] Coalesced access, head-parallelism, register-friendly
Update Incremental write at pos_ptr O(1) per token, no recomputation
Attention Masked softmax(QK^T/√D)V Causal + variable-length safe
Memory Growth Linear in L × B Inevitable without compression/eviction
Optimizations PagedAttention, Quantization Fragmentation reduction, 2-4× memory savings
GPU Mapping Tiled shared memory, TMA, warp-level masking Maximizes occupancy, hides latency, avoids divergence

This design forms the foundation of modern inference engines (vLLM, TensorRT-LLM, SGLang). Implementation in C/CUDA would fuse projection, cache update, and attention into a single kernel per layer, leveraging async memory copies and register tiling to achieve near-theoretical bandwidth utilization.