[feature] Mixed MTP + standard tokens in same batched step #10
Goal
Integrate MTP speculative decoding into the scheduler's batched step loop so that MTP and non-MTP requests can coexist in the same batch. This is the ideal long-term architecture. It is currently blocked by hybrid switching (issue #9), which only runs MTP when batch_size=1.
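For context, the hybrid switching rule described above can be sketched roughly as follows. Only the batch_size=1 condition comes from this issue. The `Request` record, the `choose_step_path` name, and the assumption that a request must have opted in to MTP are hypothetical and are not the scheduler's real API.

```python
from dataclasses import dataclass


@dataclass
class Request:
    # Hypothetical request record; not the scheduler's real type.
    uid: int
    mtp_enabled: bool


def choose_step_path(active: list[Request]) -> str:
    """Return which decode path one step would take under hybrid switching."""
    # MTP only runs when exactly one request is active
    # (and, by assumption here, that request opted in to MTP).
    if len(active) == 1 and active[0].mtp_enabled:
        return "mtp"
    # Any other situation falls back to the standard batched step.
    return "batched"
```

Under this rule, a second request arriving forces the first one off the MTP path, which is the limitation this issue set out to remove.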
Status
Blocked on mlx-lm fork issue #1: sleepy/mlx-lm#1
The mlx-lm BatchGenerator needs per-sequence MTP support (variable-length inputs, per-sequence cache isolation, shadow cache merge). Full spec is in the mlx-lm issue.
Why This Is Hard
_idx cursor)
GenerationBatch._step() assumes uniform [batch,1] input shape
Current State
What Needs to Happen
GenerationBatch._step() (see mlx-lm issue #1 for full spec)
BatchKVCache without breaking shared _idx
BatchGenerator.insert() to pass mtp_enabled and mtp_caches
_transition_to_mtp)
Acceptance Criteria
Decision: Blocked — mixed mode is the architecture
After investigating the per-sequence in-batch approach (mlx-lm issue #1) and attempting implementation, the decision is to keep mixed mode as the long-term architecture.
What we have (mixed mode, already merged)
_mtp_step
_mtp_step() with per-sequence KVCache
BatchGenerator._step() with shared BatchKVCache
Why not in-batch (this issue)
BatchKVCache._idx is a shared scalar: rejected MTP drafts cannot be rolled back per-sequence. Fix requires per-row _idx (~200+ lines of cache.py, high risk). The throughput gain (~10-20% step overhead) does not justify the risk.
If this needs revisiting
_idx in BatchKVCache (mlx-lm issue #1)
See mlx-lm issue #1 for full analysis.
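The rollback problem named under "Why not in-batch" can be illustrated with a toy model. This is a sketch under stated assumptions, not mlx-lm's cache: both classes are hypothetical, they store tokens in plain Python lists instead of key/value tensors, and `idx` only stands in for the role that `_idx` plays in `BatchKVCache`.

```python
class SharedCursorCache:
    """Toy batched cache with ONE write cursor shared by all rows."""

    def __init__(self, batch_size):
        self.rows = [[] for _ in range(batch_size)]
        self.idx = 0  # shared scalar cursor (stand-in for a shared _idx)

    def append(self, tokens_per_row):
        width = len(tokens_per_row[0])
        if any(len(t) != width for t in tokens_per_row):
            raise ValueError("a shared cursor needs a uniform step width")
        for row, toks in zip(self.rows, tokens_per_row):
            del row[self.idx:]  # overwrite anything past the cursor
            row.extend(toks)
        self.idx += width

    def rollback(self, n):
        # Only one rollback amount can be expressed for the whole batch.
        self.idx -= n

    def valid(self, i):
        return self.rows[i][: self.idx]


class PerRowCursorCache:
    """Toy batched cache with one cursor per row (the rejected fix)."""

    def __init__(self, batch_size):
        self.rows = [[] for _ in range(batch_size)]
        self.idx = [0] * batch_size

    def append(self, tokens_per_row):
        for i, toks in enumerate(tokens_per_row):
            del self.rows[i][self.idx[i]:]
            self.rows[i].extend(toks)
            self.idx[i] += len(toks)

    def rollback(self, rejected_per_row):
        # Each row discards only its own rejected draft tokens.
        for i, n in enumerate(rejected_per_row):
            self.idx[i] -= n

    def valid(self, i):
        return self.rows[i][: self.idx[i]]


# Two sequences each draft three tokens; row 0 accepts all, row 1 rejects two.
drafts = [["a1", "a2", "a3"], ["b1", "b2", "b3"]]
rejected = [0, 2]

shared = SharedCursorCache(2)
shared.append(drafts)
shared.rollback(max(rejected))  # row 0 loses its accepted a2 and a3

per_row = PerRowCursorCache(2)
per_row.append(drafts)
per_row.rollback(rejected)      # row 0 keeps all three, row 1 keeps b1
```

With the shared cursor, the only choices are to roll every row back by the largest rejection count (discarding accepted tokens, as here) or by the smallest (leaving rejected drafts in the cache). The per-row variant avoids both, at the cost of touching every place that reads the cursor.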