[perf] Pre-allocated paged q4 KV cache to eliminate prefill memory spikes #51

Open
opened 2026-05-14 23:29:17 +02:00 by sleepy · 2 comments
Owner

Branch: feature/1-paged-q4-kv-cache (2 commits ahead of main). Also includes fix/rotation-q4-merge.

Adds a pre-allocated paged KV cache that quantizes keys and values to q4_0 immediately on write. Keys are Hadamard-rotated before quantization. This eliminates the fp16 KV memory peak.
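For reference, q4_0 (the ggml block format named above) stores each run of 32 values as one fp16 scale plus 32 four-bit codes. That is 18 bytes, or 4.5 bits per value. Below is a minimal pure-Python sketch of one block's round trip. It follows ggml's reference quantizer as we understand it; it is not code from this branch, and the function names are illustrative.

```python
import struct

QK4_0 = 32  # values per q4_0 block, as in ggml

def quantize_q4_0(block):
    """Quantize 32 floats into one 18-byte q4_0 block (fp16 scale + 16 bytes of 4-bit codes)."""
    assert len(block) == QK4_0
    # Signed value of largest magnitude; it maps to code 0 (i.e. -8 after the offset).
    amax = max(block, key=abs)
    d = amax / -8.0
    inv_d = 1.0 / d if d != 0.0 else 0.0
    # Offset by 8 so codes are unsigned 0..15; the argument is >= 0, so +0.5 then int() rounds.
    codes = [min(15, int(x * inv_d + 8.5)) for x in block]
    # ggml layout: low nibble holds element j, high nibble holds element j + 16.
    packed = bytes(codes[j] | (codes[j + 16] << 4) for j in range(QK4_0 // 2))
    return struct.pack("<e", d) + packed

def dequantize_q4_0(data):
    """Inverse of quantize_q4_0: x is approximately (code - 8) * d."""
    (d,) = struct.unpack("<e", data[:2])
    packed = data[2:2 + QK4_0 // 2]
    codes = [b & 0x0F for b in packed] + [b >> 4 for b in packed]
    return [(c - 8) * d for c in codes]

block = [i / 31 - 0.5 for i in range(QK4_0)]
data = quantize_q4_0(block)
restored = dequantize_q4_0(data)
print(len(data), max(abs(a - b) for a, b in zip(block, restored)))
```

The scale is computed in float but stored as fp16, so dequantized values carry that rounding as well as the 4-bit step error.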

Commits:

  • c92124f feat(cache): add PagedQ4KVCache with incremental q4_0 quantization
  • a200dd0 fix(mtp): auto-enable MTP speculative decoding for models with mtp_forward

Key files:

  • omlx/cache/paged_q4_cache.py (PagedQ4KVCache)
  • omlx/rotation_q4_cache.py (RotationQ4KVCache)
  • omlx/patches/rotation_q4_attention.py (Hadamard rotation attention patch)

Acceptance criteria:

  • PagedQ4KVCache works with continuous batching
  • RotationQ4KVCache has merge() for batch merging
  • KV memory usage reduced by ~44% vs fp16
  • No quality regression vs fp16 KV (perplexity within 1%)
  • Prefix cache save/restore works with q4 format

Related: #48 (Cache misses with q4 KV), #39 (Rewrite KV cache for memory efficiency)

Author
Owner

Superseded by current main. Main uses mlx-lm QuantizedKVCache directly (via _apply_quantized_kv) with BatchQuantizedKVCache for batching, Hadamard rotation via omlx/patches/quantized_attention.py, and merge support in _patched_merge_caches. The PagedQ4KVCache approach in this branch was an alternative that is now unnecessary. Related issues #48 (cache misses) and #39 (memory efficiency) are both closed.

Author
Owner

Rescoped: Pre-allocated paged q4 KV cache to eliminate prefill memory spikes

Problem

The current QuantizedKVCache (mlx-lm) grows dynamically during prefill: allocate an fp16 buffer → write tokens → quantize. The transient fp16 peaks this creates mean Q4 KV saves no peak memory versus fp16 at long contexts. At 32K it is actually worse:

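| Context | Q4 KV peak | fp16 KV peak | Savings (fp16 − Q4) |
|---------|------------|--------------|---------------------|
| 32K | 23.55 GB | 20.92 GB | -2.6 GB (worse!) |
| 128K | 31.86 GB | 31.59 GB | -0.3 GB |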

The q4 active memory after cleanup is lower (~19.7 GB, flat), but the fp16 intermediate held during the forward pass erases the savings.

Root Cause

KVCache.update_and_fetch() in mlx-lm stores fp16, then _apply_quantized_kv() converts after the fact. The fp16 KV exists simultaneously with the quantized copy during conversion.
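To make the root cause concrete, here is a back-of-the-envelope sketch of KV bytes under three regimes. Several inputs are assumptions, not measurements from this issue:

- The model shape (48 layers, 8 KV heads, head_dim 128) is hypothetical.
- q4 is counted as 4.5 bits per element, i.e. 4-bit codes plus one fp16 scale per 32 values, as in q4_0.
- The "fp16 then quantize" case assumes the worst case, where the entire fp16 KV is still resident while its q4 copy is written.

```python
GIB = 1024 ** 3

def kv_bytes_per_token(layers, n_kv_heads, head_dim, bits_per_elem):
    """Bytes of K plus V for one token across all layers."""
    elems = 2 * layers * n_kv_heads * head_dim  # factor 2: keys and values
    return elems * bits_per_elem / 8

def kv_peak_bytes(tokens, layers=48, n_kv_heads=8, head_dim=128):
    """KV-only peak for each regime (hypothetical model shape by default)."""
    fp16 = kv_bytes_per_token(layers, n_kv_heads, head_dim, 16) * tokens
    q4 = kv_bytes_per_token(layers, n_kv_heads, head_dim, 4.5) * tokens
    return {
        "fp16 cache": fp16,
        "fp16 then quantize (worst case)": fp16 + q4,  # both copies resident at once
        "direct q4 write": q4,                         # no fp16 copy ever held
    }

for name, b in kv_peak_bytes(32 * 1024).items():
    print(f"{name:32s} {b / GIB:6.2f} GiB")
```

In this worst case the quantized path peaks above plain fp16, which matches the direction of the 32K measurement above. The real peak also includes weights, activations and allocator caching, so the sketch does not reproduce the table's numbers.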

Solution: Pre-allocated paged q4 KV

Pre-allocate the full KV cache in q4 format at model load (or at prefill start). Write tokens directly to q4 pages during prefill — never hold fp16 KV at all.
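The pre-allocation idea can be sketched in pure Python. Everything here is hypothetical: the names `PagedQ4Pool`, `append`, `read` and `release` do not come from the branch or from mlx-lm. The sketch works in three parts:

- One buffer sized for the whole pool is allocated up front.
- A per-sequence page table maps token positions to fixed-size slots.
- Each incoming token vector is quantized to q4_0 and written straight into its slot, so no fp16 copy of the cache is kept.

A real implementation would presumably hold MLX arrays per layer and quantize whole prefill chunks on the GPU.

```python
import struct

QK = 32                     # values per q4_0 block
BLOCK_BYTES = 2 + QK // 2   # fp16 scale + 16 bytes of packed 4-bit codes

def q4_0_block(xs):
    """Quantize 32 floats into 18 bytes (ggml-style q4_0)."""
    amax = max(xs, key=abs)
    d = amax / -8.0
    inv = 1.0 / d if d else 0.0
    q = [min(15, int(x * inv + 8.5)) for x in xs]
    return struct.pack("<e", d) + bytes(q[j] | (q[j + 16] << 4) for j in range(16))

def q4_0_unblock(b):
    """Dequantize one 18-byte q4_0 block back to 32 floats."""
    (d,) = struct.unpack("<e", b[:2])
    codes = [v & 0xF for v in b[2:]] + [v >> 4 for v in b[2:]]
    return [(c - 8) * d for c in codes]

class PagedQ4Pool:
    """Hypothetical sketch: all KV storage is allocated up front, in q4 form only."""

    def __init__(self, num_pages, page_tokens, elems_per_token):
        assert elems_per_token % QK == 0
        self.page_tokens = page_tokens
        self.elems = elems_per_token
        self.token_bytes = elems_per_token // QK * BLOCK_BYTES
        # One allocation for the whole pool; writes never grow it.
        self.storage = bytearray(num_pages * page_tokens * self.token_bytes)
        self.free_pages = list(range(num_pages - 1, -1, -1))
        self.page_table = {}  # seq_id -> list of page ids
        self.lengths = {}     # seq_id -> tokens written

    def _offset(self, seq_id, pos):
        page = self.page_table[seq_id][pos // self.page_tokens]
        slot = page * self.page_tokens + pos % self.page_tokens
        return slot * self.token_bytes

    def append(self, seq_id, values):
        """Quantize one token's K (or V) vector straight into its page slot."""
        assert len(values) == self.elems
        pos = self.lengths.get(seq_id, 0)
        pages = self.page_table.setdefault(seq_id, [])
        if pos % self.page_tokens == 0:  # first token, or current page is full
            if not self.free_pages:
                raise MemoryError("KV page pool exhausted")
            pages.append(self.free_pages.pop())
        self.lengths[seq_id] = pos + 1
        off = self._offset(seq_id, pos)
        packed = b"".join(q4_0_block(values[i:i + QK]) for i in range(0, self.elems, QK))
        self.storage[off:off + self.token_bytes] = packed  # same length: no reallocation

    def read(self, seq_id, pos):
        """Dequantize the token at position pos (e.g. for attention)."""
        off = self._offset(seq_id, pos)
        raw = self.storage[off:off + self.token_bytes]
        out = []
        for i in range(0, self.token_bytes, BLOCK_BYTES):
            out.extend(q4_0_unblock(raw[i:i + BLOCK_BYTES]))
        return out

    def release(self, seq_id):
        """Return a finished sequence's pages to the free list."""
        self.free_pages.extend(self.page_table.pop(seq_id, []))
        self.lengths.pop(seq_id, None)
```

Because slot sizes are fixed, memory use is decided when the pool is created, not by how far prefill has progressed. Running out of pages is an explicit error rather than a reallocation.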

This is what llama.cpp does (attn-rot + paged q4_0):

  • Pre-allocate paged KV cache at max context length
  • Hadamard rotation on Q/K/V (we already have this)
  • Direct q4 quantization on write (no fp16 intermediate)
  • Result: zero memory spikes, flat memory profile from first token to last
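Why rotation helps 4-bit keys: an orthonormal Hadamard transform H satisfies (Hq)·(Hk) = q·k. Attention scores are therefore unchanged when queries and keys are rotated by the same H. Meanwhile, a single large outlier channel in k gets spread over all channels, which keeps per-block 4-bit scales small. Below is a minimal pure-Python sketch using the fast Walsh-Hadamard transform. It is illustrative only: the existing attention patch operates on MLX arrays and its details may differ, and the vectors are made-up data.

```python
import math

def fwht(x):
    """Orthonormal fast Walsh-Hadamard transform; len(x) must be a power of two."""
    n = len(x)
    assert n and n & (n - 1) == 0, "length must be a power of two"
    y = list(x)
    h = 1
    while h < n:
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                a, b = y[j], y[j + h]
                y[j], y[j + h] = a + b, a - b
        h *= 2
    s = 1.0 / math.sqrt(n)  # normalization makes the transform orthonormal (and its own inverse)
    return [v * s for v in y]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.1 * ((i * 5) % 7 - 3) for i in range(128)]
k = [0.05 * ((i * 3) % 11 - 5) for i in range(128)]
k[17] = 12.0  # one outlier channel: the case that inflates per-block 4-bit scales

qr, kr = fwht(q), fwht(k)
print(dot(q, k), dot(qr, kr))                # equal up to float rounding
print(max(map(abs, k)), max(map(abs, kr)))  # outlier spread across all channels
```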

Acceptance Criteria

  • KV cache pre-allocated in q4 format before prefill starts
  • Tokens quantized directly to q4 during prefill (no fp16 intermediate)
  • Peak memory at 128K context ≤ model_size + (128K × layers × kv_bytes_per_token_q4)
  • No memory growth spikes during prefill (flat profile like llama.cpp)
  • No quality regression vs current QuantizedKVCache (PPL within 0.1%)
  • Works with continuous batching
  • Works with prefix cache (SSD save/restore of q4 pages)
  • Works with Hadamard rotation (existing attention patch)
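The peak-memory criterion can be evaluated directly. This sketch reads kv_bytes_per_token_q4 as K plus V bytes for one token in one layer, since the formula multiplies by layers separately. The remaining inputs are assumptions:

- The model (16 GiB of weights, 48 layers, 8 KV heads, head_dim 128) is hypothetical.
- q4 is counted as 4.5 bits per element.

```python
GIB = 1024 ** 3

def kv_bytes_per_token(n_kv_heads, head_dim, bits_per_elem):
    """K + V bytes for one token in one layer."""
    return 2 * n_kv_heads * head_dim * bits_per_elem / 8

def peak_bound_bytes(model_bytes, context, layers, n_kv_heads, head_dim, bits_per_elem=4.5):
    """Acceptance bound: model_size + context * layers * kv_bytes_per_token."""
    return model_bytes + context * layers * kv_bytes_per_token(n_kv_heads, head_dim, bits_per_elem)

# Hypothetical model: 16 GiB of weights, 48 layers, 8 KV heads, head_dim 128.
cfg = dict(model_bytes=16 * GIB, context=128 * 1024, layers=48, n_kv_heads=8, head_dim=128)
q4_bound = peak_bound_bytes(**cfg)                      # 4.5 bits: q4 codes + fp16 scale per 32
fp16_equiv = peak_bound_bytes(**cfg, bits_per_elem=16)  # same formula with an fp16 cache
print(f"q4 bound {q4_bound / GIB:.2f} GiB, fp16 equivalent {fp16_equiv / GIB:.2f} GiB")
```

Read literally, the bound has no term for activations or allocator overhead, so a measured peak will only meet it if those are negligible or budgeted separately.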

Related

  • mlx-lm fork QuantizedKVCache (current approach: grows dynamically)
  • llama.cpp PR #21038 (attn-rot — pre-allocated paged q4)
  • QuaRot (arXiv 2404.00456) — Hadamard rotation for KV cache
  • Benchmark data showing the peak memory gap
sleepy changed title from [feature] Paged q4 KV cache with Hadamard rotation to [perf] Pre-allocated paged q4 KV cache to eliminate prefill memory spikes 2026-05-15 23:13:53 +02:00
sleepy reopened this issue 2026-05-15 23:13:53 +02:00
Reference
sleepy/omlx#51