[refactor] Replace trim-then-concat with geometric growth in KVCache/BatchKVCache #3
Problem
When max_kv_size=None (unbounded mode), KVCache and BatchKVCache grow via trim-then-concat (cache.py:344-348).
This creates 4 temporary copies per growth step per layer. For 36-layer models, this causes multi-GB transient spikes every 256 tokens.
Additionally, the fixed 256-token step size means frequent growth steps. Each step requires allocation + copy, which is expensive.
Solution
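Replace with geometric (2x) growth: when the buffer is full, grow it to max(current * 2, prev + keys.shape[2]) instead of adding a fixed 256 tokens.
This gives amortized O(1) copy cost per appended token and reduces the number of growth spikes from one every 256 tokens to O(log n) in the sequence length.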
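The amortized claim can be checked with a small pure-Python model. This is only a sketch: it tracks capacities and counts copied entries rather than touching real arrays, the names elements_copied, fixed_step, and geometric are hypothetical, and the 256-token floor on the first allocation is an assumption, not something taken from cache.py. It also models one logical copy per growth step and ignores the extra temporaries described under Problem.

```python
def elements_copied(n_tokens, grow):
    """Append n_tokens one at a time and return how many existing
    entries had to be copied into freshly allocated buffers."""
    capacity = 0
    size = 0
    copied = 0
    for _ in range(n_tokens):
        needed = size + 1
        if needed > capacity:
            # Buffer is full: allocate a larger one and copy what we have.
            capacity = grow(capacity, needed)
            copied += size
        size += 1
    return copied


def fixed_step(capacity, needed, step=256):
    # Current behaviour as described in the issue: grow by a fixed step.
    return max(capacity + step, needed)


def geometric(capacity, needed, initial=256):
    # Proposed behaviour: double, but never allocate less than is needed.
    # The `initial` floor is an assumption made for this sketch.
    return max(capacity * 2, needed, initial)


n = 32 * 1024
fixed_copies = elements_copied(n, fixed_step)
geometric_copies = elements_copied(n, geometric)
print(fixed_copies, geometric_copies)
```

Under these assumptions the fixed-step policy copies about 2.08 million entries to reach 32K tokens, while the geometric policy copies fewer than 32K in total, which is less than one copy per appended token.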
Growth pattern comparison
At 32K tokens: current has 128 spikes, geometric has 7.
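The comparison can be reproduced with a short counting sketch. The function name count_regrowths and the assumed 256-token initial buffer are illustrative and not taken from cache.py. The issue does not say whether the initial allocation counts as a spike, so this sketch counts only re-allocations after the first buffer exists; counting the initial allocation as well adds one to each figure.

```python
def count_regrowths(target_tokens, policy, initial=256, step=256):
    """Count re-allocations needed to grow from `initial` capacity
    until the buffer can hold `target_tokens`."""
    capacity = initial
    regrowths = 0
    while capacity < target_tokens:
        if policy == "fixed":
            capacity += step      # fixed +256 growth
        else:
            capacity *= 2         # geometric (2x) growth
        regrowths += 1
    return regrowths


fixed = count_regrowths(32 * 1024, "fixed")
geometric = count_regrowths(32 * 1024, "geometric")
print(fixed, geometric)  # 127 7
```

The fixed-step count grows linearly with the target length, while the geometric count grows with its base-2 logarithm.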
Affected classes
- KVCache (cache.py:325)
- BatchKVCache (cache.py:914)
- QuantizedKVCache (cache.py:232) and ChunkedKVCache (cache.py:733): check if they share the same pattern
Required changes
- Replace the growth logic in KVCache.update_and_fetch(): max(current * 2, prev + keys.shape[2])
- Same for BatchKVCache.update_and_fetch() (lines 944-961)
- Ensure the nbytes property still reports accurately (it already reports full allocation, which is fine)
- Verify the state property works correctly with non-step-aligned sizes
Acceptance criteria
Caution
- offset, _idx, and state all remain consistent after growth
- The state setter (lines 371-373 for KVCache, lines 1000-1002 for BatchKVCache) sets offset = keys.shape[2]; this must still work with non-256-aligned sizes
- The trim() method must still work correctly
Merged via PR #7 (squash). KVCache, BatchKVCache, QuantizedKVCache, and ChunkedKVCache now use geometric (2x) growth instead of fixed +256 trim-then-concat. Growth spikes reduced from every 256 tokens to O(log n). 54/54 tests passed.