[refactor] Replace trim-then-concat with geometric growth in KVCache/BatchKVCache #3

Closed
opened 2026-05-15 19:51:58 +02:00 by sleepy · 1 comment
Owner

Problem

When max_kv_size=None (unbounded mode), KVCache and BatchKVCache grow via trim-then-concat (cache.py:344-348):

if prev % self.step != 0:
    self.keys = self.keys[..., :prev, :]       # copy 1
    self.values = self.values[..., :prev, :]   # copy 2
self.keys = mx.concatenate([self.keys, new_k], axis=2)    # copy 3
self.values = mx.concatenate([self.values, new_v], axis=2) # copy 4

This creates up to 4 temporary copies per growth step per layer (the two trim slices only run when prev is not a multiple of self.step). For 36-layer models, this causes multi-GB transient spikes every 256 tokens.

Additionally, the fixed 256-token step size means a growth step every 256 tokens. Each step allocates a new buffer and copies the whole existing cache, so total copy work grows quadratically with context length.

Solution

Replace with geometric (2x) growth:

  • Start with a capacity of 256 columns (positions along the sequence axis)
  • When full, allocate 2x the current capacity (not current + 256)
  • Copy old data into new buffer in a single operation
  • No trim step needed (always copy full buffer)

This gives amortized O(1) copy work per appended token and dramatically reduces spike frequency.
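A back-of-the-envelope count makes the amortized claim concrete. Assume single-token decode, an initial capacity of $s = 256$, a final context of $N = s \cdot 2^k$ tokens, and count only the positions of old data moved at each reallocation (during decode, prev is a multiple of 256 whenever growth triggers, so the trim branch adds no copies):

$$C_{\text{linear}}(N) = \sum_{j=1}^{N/s - 1} j\,s = \frac{s}{2}\left(\frac{N}{s} - 1\right)\frac{N}{s} \approx \frac{N^2}{2s}$$

$$C_{\text{geometric}}(N) = \sum_{j=0}^{k-1} 2^j\,s = s\,(2^k - 1) = N - s < N$$

Geometric growth therefore moves fewer than one position per appended token, while the 256-step scheme moves about $N / (2s)$ per token on average. For $N = 32768$ this is 2,080,768 versus 32,512 positions per tensor per layer.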

Growth pattern comparison

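| Tokens | Current (256-step) | Geometric (2x) |
|--------|--------------------|----------------|
| 1-256 | alloc 256 | alloc 256 |
| 257 | alloc 512 (spike) | alloc 512 (spike) |
| 513 | alloc 768 (spike) | alloc 1024 (spike) |
| 769 | alloc 1024 (spike) | no spike |
| 1025 | alloc 1280 (spike) | alloc 2048 (spike) |
| ... | spike every 256 tokens | spikes at 2049, 4097, 8193, ... (alloc 4096, 8192, 16384, ...) |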

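At 32K (32,768) tokens: the current scheme has performed 127 growth reallocations, geometric growth only 7 (128 vs 8 allocations if the initial 256-column buffer is counted).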
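The spike counts above can be checked with a small pure-Python sketch. `growth_schedule` is a hypothetical helper that models capacities only (no tensors), assuming one token is appended at a time and the first 256 positions are allocated up front (so that allocation is not counted as a spike):

```python
def growth_schedule(n_tokens, policy, step=256):
    """Return (token_number, new_capacity) for each reallocation that has to
    move existing data while n_tokens are appended one at a time.

    Hypothetical model: the first `step` positions are allocated up front,
    so that initial allocation is not counted as a spike.
    """
    capacity, events = step, []
    for token in range(step + 1, n_tokens + 1):  # token numbers are 1-based
        if token > capacity:
            if policy == "linear":
                capacity += step   # current behaviour: +256
            elif policy == "geometric":
                capacity *= 2      # proposed behaviour: 2x
            else:
                raise ValueError(f"unknown policy: {policy!r}")
            events.append((token, capacity))
    return events


if __name__ == "__main__":
    for policy in ("linear", "geometric"):
        events = growth_schedule(32768, policy)
        print(policy, len(events), events[:4])
    # linear 127 [(257, 512), (513, 768), (769, 1024), (1025, 1280)]
    # geometric 7 [(257, 512), (513, 1024), (1025, 2048), (2049, 4096)]
```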

Affected classes

  • KVCache (cache.py:325)
  • BatchKVCache (cache.py:914)
  • Possibly QuantizedKVCache (cache.py:232) and ChunkedKVCache (cache.py:733) — check if they share the same pattern

Required changes

  1. Replace the growth logic in KVCache.update_and_fetch():

    • Remove trim step (lines 344-346)
    • Calculate new capacity as max(current * 2, prev + keys.shape[2]), where current is the existing capacity (self.keys.shape[2])
    • Allocate new buffer, copy old data, swap
  2. Same for BatchKVCache.update_and_fetch() (lines 944-961)

  3. Ensure nbytes property still reports accurately (it already reports full allocation, which is fine)

  4. Verify state property works correctly with non-step-aligned sizes
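The required changes can be sketched as follows. This is a minimal sketch, not the merged implementation: it uses NumPy as a stand-in for `mlx.core`, keeps the `(batch, kv_heads, seq, head_dim)` layout with the sequence on axis 2, and mirrors the names `update_and_fetch`, `offset`, `state`, `trim`, and `nbytes` from the issue, while the class name `GeometricKVCache` and the `capacity` property are hypothetical. In MLX, the growth copy would presumably be an `mx.zeros` allocation plus a slice assignment instead of `mx.concatenate`.

```python
import numpy as np


class GeometricKVCache:
    """Hypothetical NumPy sketch of an unbounded KV cache that grows by 2x."""

    step = 256  # initial capacity along the sequence axis (axis 2)

    def __init__(self):
        self.keys = None
        self.values = None
        self.offset = 0  # number of valid positions

    @property
    def capacity(self):
        return 0 if self.keys is None else self.keys.shape[2]

    def update_and_fetch(self, keys, values):
        prev, n = self.offset, keys.shape[2]
        cur = self.capacity
        if prev + n > cur:
            # First allocation starts at `step`; later ones double, but never
            # below what this (possibly multi-token) update needs.
            new_cap = max(self.step if cur == 0 else cur * 2, prev + n)
            B, H, _, dk = keys.shape
            dv = values.shape[3]
            new_k = np.zeros((B, H, new_cap, dk), dtype=keys.dtype)
            new_v = np.zeros((B, H, new_cap, dv), dtype=values.dtype)
            if cur:
                # One copy per tensor and no trim: move the whole old buffer.
                new_k[..., :cur, :] = self.keys
                new_v[..., :cur, :] = self.values
            self.keys, self.values = new_k, new_v
        self.offset = prev + n
        self.keys[..., prev:self.offset, :] = keys
        self.values[..., prev:self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]

    @property
    def state(self):
        # Only the valid prefix; spare capacity beyond `offset` is not exposed.
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]

    @state.setter
    def state(self, v):
        self.keys, self.values = v
        self.offset = self.keys.shape[2]  # capacity == offset, any alignment

    def trim(self, n):
        # Dropping trailing tokens only moves `offset`; capacity is unchanged.
        n = min(self.offset, n)
        self.offset -= n
        return n

    @property
    def nbytes(self):
        # Full allocation, including unused capacity beyond `offset`.
        return 0 if self.keys is None else self.keys.nbytes + self.values.nbytes
```

After the `state` setter restores, say, 300 positions, capacity equals `offset`, so the next append simply grows to 600; nothing in this growth path assumes 256 alignment.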

Acceptance criteria

  • KVCache grows geometrically (2x), not linearly (+256)
  • BatchKVCache grows geometrically
  • No trim step in growth path
  • Number of growth spikes reduced from one every 256 tokens to about log2(context_length / 256)
  • Output tokens are identical to old behavior (no correctness regression)
  • All existing tests pass
  • New test: verify growth pattern (e.g., at 513 tokens, capacity is 1024, not 768)
  • New test: verify no correctness regression at growth boundaries (tokens 255-258)
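The two new tests could look roughly like this pytest-style sketch. To stay self-contained it re-implements growth as a hypothetical `append` helper, with NumPy in place of `mlx.core`; a real test would exercise `KVCache.update_and_fetch()` directly. The boundary test compares the cached keys with plain concatenation (what the old scheme stores) as a proxy for identical output tokens.

```python
import numpy as np

STEP = 256  # assumed initial capacity, as in the issue


def append(buf, offset, new):
    """Hypothetical helper: append `new` (B, H, n, D) at `offset` with 2x growth."""
    cap, n = (0 if buf is None else buf.shape[2]), new.shape[2]
    if offset + n > cap:
        new_cap = max(STEP if cap == 0 else 2 * cap, offset + n)
        grown = np.zeros(new.shape[:2] + (new_cap,) + new.shape[3:], dtype=new.dtype)
        if cap:
            grown[..., :cap, :] = buf  # single full-buffer copy, no trim
        buf = grown
    buf[..., offset:offset + n, :] = new
    return buf, offset + n


def test_growth_pattern():
    buf, off, caps = None, 0, {}
    tok = np.zeros((1, 1, 1, 2), dtype=np.float32)
    for t in range(1, 1026):
        buf, off = append(buf, off, tok)
        caps[t] = buf.shape[2]
    assert caps[256] == 256 and caps[257] == 512
    assert caps[513] == 1024  # the old +256 policy would give 768 here
    assert caps[1024] == 1024 and caps[1025] == 2048


def test_boundaries_match_concat():
    rng = np.random.default_rng(0)
    for total in (255, 256, 257, 258):
        chunks = [rng.standard_normal((2, 4, 1, 8)).astype(np.float32)
                  for _ in range(total)]
        buf, off = None, 0
        for c in chunks:
            buf, off = append(buf, off, c)
        expected = np.concatenate(chunks, axis=2)  # what linear concat stores
        assert off == total
        assert np.array_equal(buf[..., :off, :], expected)
```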

Caution

  • This is the most delicate change — it affects the core cache invariant
  • Verify offset, _idx, and state all remain consistent after growth
  • The state setter (lines 371-373 for KVCache, lines 1000-1002 for BatchKVCache) sets offset = keys.shape[2]; this must still work with non-256-aligned sizes
  • The trim() method must still work correctly
  • Test with both small (1-token) and large (2048-token prefill) inputs
  • Test edge case: growth triggered during prefill (multi-token input that exceeds capacity by more than the growth quantum)
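Several of these edge cases reduce to capacity rules that can be checked without tensors. `next_capacity` below is a hypothetical pure-Python model of the rule max(current * 2, prev + keys.shape[2]), assuming the first allocation is max(256, needed):

```python
def next_capacity(capacity, offset, n_new, step=256):
    """Capacity after appending n_new positions at `offset` (hypothetical model)."""
    needed = offset + n_new
    if needed <= capacity:
        return capacity                 # fits: no reallocation
    if capacity == 0:
        return max(step, needed)        # first allocation (e.g. a big prefill)
    return max(capacity * 2, needed)    # double, unless the input needs more


# Large prefill into an empty cache jumps straight to the required size.
assert next_capacity(0, 0, 2048) == 2048
# Prefill that overshoots a full 512 buffer by more than one doubling.
assert next_capacity(512, 512, 1000) == 1512
# After the state setter, capacity == offset and need not be 256-aligned.
assert next_capacity(300, 300, 1) == 600
# trim() only lowers offset, so re-appending into freed space does not grow.
assert next_capacity(1024, 600, 10) == 1024
```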
Author
Owner

Merged via PR #7 (squash). KVCache, BatchKVCache, QuantizedKVCache, and ChunkedKVCache now use geometric (2x) growth instead of fixed +256 trim-then-concat. Growth spikes reduced from every 256 tokens to O(log n). 54/54 tests passed.

Reference
sleepy/mlx-lm#3