[feature] Per-sequence MTP speculative decoding in BatchGenerator #1
Goal
Implement per-sequence MTP speculative decoding inside `BatchGenerator._step()` so MTP and non-MTP requests can coexist in the same batched forward pass.

Background
MTP (Multi-Token Prediction) speculative decoding uses a draft model to predict the next token, then verifies it with the main model in a 2-token forward pass. On acceptance, 2 tokens advance per step (up to a 2x speedup). On rejection, 1 token advances (no penalty beyond the cost of the extra forward pass).
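The accept/reject rule can be sketched for a single sequence as below. This assumes greedy (argmax) verification; sampled verification would compare probabilities instead. `verify_draft` is a hypothetical helper for illustration, not a function in mlx_lm.

```python
from typing import List, Sequence


def verify_draft(draft: int, main_argmax: Sequence[int]) -> List[int]:
    """Greedy verification of one MTP draft token (illustrative sketch).

    The main model runs on the 2-token input [confirmed, draft]:
      main_argmax[0] = its greedy prediction for the position after `confirmed`
      main_argmax[1] = its greedy prediction for the position after `draft`
    Returns the tokens that advance this step (2 on acceptance, 1 on rejection).
    """
    if len(main_argmax) != 2:
        raise ValueError("expected predictions for both input positions")
    if main_argmax[0] == draft:
        # Accepted: the draft is what the main model would have emitted anyway,
        # and the prediction after the draft is a free second token.
        return [draft, main_argmax[1]]
    # Rejected: emit the main model's own token; the draft's KV entry is now stale.
    return [main_argmax[0]]


print(verify_draft(7, [7, 12]))  # [7, 12]
print(verify_draft(7, [9, 12]))  # [9]
```

The rejected branch is where the cache problem comes from: the draft token was already written to the KV cache during the 2-token pass and has to be rolled back for that sequence only.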
Currently MTP only works when `batch_size=1`. The hybrid switching approach (MTP when a request runs solo, plain batching under concurrency) works but loses the MTP speedup under load.
Why the First Attempt Failed
A per-sequence cache extract + copy-back approach was attempted and rejected in review:

`BatchKVCache` uses a shared `_idx` write cursor. Per-sequence cache extraction and copy-back break this invariant: accepted sequences get duplicate KV entries, and rejected sequences get cross-contamination.

Architecture: Approach C - True Mixed MTP-in-Batch
Core Insight
MTP verification is a 2-token forward pass on a per-sequence basis. To make this work in a batch, the model forward must accept variable-length inputs (1 token for standard sequences, 2 tokens for MTP sequences), and its outputs must be handled per sequence.
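A minimal sketch of such a ragged, right-padded input, using plain Python lists in place of MLX arrays. `build_ragged_inputs` and `PAD_ID` are hypothetical; the returned `lengths` and `right_padding` only illustrate the kind of values a call like `prepare(lengths=..., right_padding=...)` would need, and the exact semantics of those arguments in mlx_lm are not assumed here.

```python
from typing import List, Optional, Sequence, Tuple

PAD_ID = 0  # hypothetical pad token id; the real value depends on the tokenizer


def build_ragged_inputs(
    confirmed: Sequence[int],
    drafts: Sequence[Optional[int]],
) -> Tuple[List[List[int]], List[int], List[int]]:
    """Right-pad a mixed batch: MTP rows get [confirmed, draft], standard rows get [confirmed].

    Returns (padded rows, real token count per row, right padding per row).
    """
    if len(confirmed) != len(drafts):
        raise ValueError("confirmed and drafts must have one entry per sequence")
    rows = [[tok] if draft is None else [tok, draft] for tok, draft in zip(confirmed, drafts)]
    lengths = [len(row) for row in rows]
    max_length = max(lengths)
    right_padding = [max_length - n for n in lengths]
    padded = [row + [PAD_ID] * pad for row, pad in zip(rows, right_padding)]
    return padded, lengths, right_padding


# Rows 0 and 2 are MTP sequences with drafts 11 and 4; row 1 is a standard sequence.
padded, lengths, right_padding = build_ragged_inputs([5, 8, 3], [11, None, 4])
# padded == [[5, 11], [8, 0], [3, 4]], lengths == [2, 1, 2], right_padding == [0, 1, 0]
```

Row `i`'s last real position is `lengths[i] - 1`, so outputs at padded positions have to be ignored, and any KV entries written for padding have to be rolled back afterwards.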
Required Changes
1. Variable-length input support in `GenerationBatch._step()`

Current: the `inputs` shape is `(batch_size,)`, reshaped to `(batch_size, 1)`: a uniform 1-token input.

Required: Build a ragged input tensor where MTP sequences get 2 tokens (confirmed + draft) and standard sequences get 1 token. This requires:
- `inputs` as a list of arrays (one per sequence) with variable lengths
- Padding to `max_length` for the batch forward pass
- A `lengths` array to track actual token counts
- A `prepare(lengths=..., right_padding=...)` call (already exists for prompt processing, needs to be wired for generation)
- `finalize()` after the forward pass to roll back padding

Reference: `PromptProcessingBatch.prompt()` (lines 1361-1376) already uses this pattern for variable-length prompts. The same pattern needs to work in generation.

2. Per-sequence cache isolation
Problem: `BatchKVCache` has a shared `_idx` cursor. MTP sequences need independent cache state for verification.

Solution: For MTP sequences, maintain a shadow cache (a per-sequence `KVCache`) that is merged back into `BatchKVCache` at the correct position.

The merge must NOT use `mx.copyto` with the shared `_idx`. Instead:

- Use `BatchKVCache.offset[i]` to determine where each sequence's tokens should go
- Write at `offset[i]`, then increment `offset[i]` (not `_idx`)

Alternative: Use `ArraysCache`-style per-sequence cache management (already used for GatedDeltaNet layers). Each sequence maintains its own cache, and the batch forward pass assembles and disassembles the per-sequence caches.

3. Model forward with per-sequence kwargs
Required: The model's `__call__` must accept:

- `return_hidden=True` - needed for MTP draft generation (already supported)
- `n_confirmed=N` - tells the model how many tokens to confirm (already supported in Qwen3.5/3.6)
- An `mtp_enabled` flag - the model should skip MTP layers for non-MTP sequences

Current model support: Qwen3.5/3.6 already handles `n_confirmed` and `return_hidden`. The `mtp_forward` method exists. No model changes needed.

4. MTP speculative logic in `_step()`
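The per-row bookkeeping that this step has to perform after one mixed forward pass can be sketched as follows, assuming greedy verification and right-padded inputs. The helper `mixed_step_outcome` is hypothetical and is not the planned `_step()` implementation. Its second return value is the number of KV entries each row wrote but must discard: padding for standard rows, and padding plus the rejected draft for rejected MTP rows.

```python
from typing import List, Optional, Sequence, Tuple


def mixed_step_outcome(
    drafts: Sequence[Optional[int]],
    lengths: Sequence[int],
    main_argmax: Sequence[Sequence[int]],
) -> Tuple[List[List[int]], List[int]]:
    """Per-row result of one mixed forward pass (sketch; greedy verification assumed).

    drafts[i] is None for standard rows, else the MTP draft fed at position 1.
    lengths[i] is the number of real (non-pad) tokens in row i (1 or 2).
    main_argmax[i][p] is the main model's greedy prediction after position p of row i.
    Returns (emitted tokens per row, KV entries to trim per row).
    """
    max_length = max(lengths)
    emitted: List[List[int]] = []
    trims: List[int] = []
    for draft, n, preds in zip(drafts, lengths, main_argmax):
        if draft is not None and n != 2:
            raise ValueError("MTP rows must carry [confirmed, draft]")
        pad = max_length - n  # padded positions also wrote KV entries
        if draft is None:
            emitted.append([preds[0]])
            trims.append(pad)
        elif preds[0] == draft:
            emitted.append([draft, preds[1]])  # accepted: 2 tokens advance
            trims.append(pad)
        else:
            emitted.append([preds[0]])  # rejected: 1 token advances
            trims.append(pad + 1)  # also drop the rejected draft's KV entry
    return emitted, trims


emitted, trims = mixed_step_outcome(
    drafts=[11, None, 4],
    lengths=[2, 1, 2],
    main_argmax=[[11, 20], [9, 0], [7, 30]],
)
# emitted == [[11, 20], [9], [7]], trims == [0, 1, 1]
```

Because `trims` differs from row to row, a single shared write cursor cannot express this rollback.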
_step()flow:5. BatchGenerator integration
BatchGeneratorneeds to:mtp_enabledandmtp_cachesininsert()GenerationBatch.__init__()_step()Files to Modify
| File | Changes |
|---|---|
| `mlx_lm/generate.py` | `GenerationBatch.__init__`, `_step()`, `filter()`, `extend()` |
| `mlx_lm/models/cache.py` | `BatchKVCache` per-sequence write support, shadow cache merge |
| `mlx_lm/generate.py` | `BatchGenerator.insert()` - accept MTP parameters |

Acceptance Criteria
`_step()` call

Prerequisites
- `97a62ef` (base MTP support for single-sequence decoding)
- `mtp_forward` method
- `BatchKVCache.extract()` method for per-sequence cache extraction

Testing
Write a test that:
`_step()` calls

Reverted after review. Critical issues found:
`_idx` cursor

The fundamental problem: MTP verification requires a 2-token forward pass per sequence, while the batch assumes a uniform 1-token input. Per-sequence cache operations break batch cache invariants.
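To make the invariant concrete, here is a toy cache that keeps one write offset per row instead of a shared scalar cursor. `PerRowOffsetCache` is hypothetical: it stores token ids rather than key/value tensors and ignores attention masking, which a real per-row cache would also have to respect. It is not `BatchKVCache`'s actual layout.

```python
from typing import List, Optional


class PerRowOffsetCache:
    """Toy batched cache with one write offset per row (illustrative only)."""

    def __init__(self, batch_size: int, capacity: int) -> None:
        self.slots: List[List[Optional[int]]] = [[None] * capacity for _ in range(batch_size)]
        self.offset: List[int] = [0] * batch_size

    def write(self, row: int, tokens: List[int]) -> None:
        start = self.offset[row]
        end = start + len(tokens)
        if end > len(self.slots[row]):
            raise IndexError("cache row is full")
        self.slots[row][start:end] = tokens
        self.offset[row] = end  # advance this row only, not a shared cursor

    def trim(self, row: int, n: int) -> None:
        if n < 0 or n > self.offset[row]:
            raise ValueError("cannot trim more entries than the row holds")
        # Rollback only moves this row's offset; stale slots get overwritten later.
        self.offset[row] -= n

    def valid(self, row: int) -> List[Optional[int]]:
        return self.slots[row][: self.offset[row]]


cache = PerRowOffsetCache(batch_size=2, capacity=8)
cache.write(0, [5, 11])  # row 0: MTP draft 11 accepted
cache.write(1, [3, 4])   # row 1: MTP draft 4 rejected
cache.trim(1, 1)         # roll back only row 1's draft entry
# cache.valid(0) == [5, 11], cache.valid(1) == [3], cache.offset == [2, 1]
```

With a single shared `_idx`, both rows would sit at index 2 after this step, and dropping row 1's rejected draft would also drop row 0's accepted token.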
Viable approaches:
Decision: Blocked in favor of mixed mode (Approach A)
After investigation and a failed implementation attempt (branch `feature/1-mixed-mtp-batch`, commit `ecd900b`), per-sequence MTP in batch is blocked in favor of the mixed-mode architecture in omlx (PRs #63, #64).

Why blocked
The dummy-token-padding approach has a fatal flaw: rejected drafts cannot be rolled back because `BatchKVCache._idx` is a shared scalar. The rejected draft KV entries stay in the cache permanently, corrupting future attention. This is the same failure mode as in the first attempt.

The fix would require a per-row `_idx` in `BatchKVCache` (~200+ lines in cache.py, with a high risk of subtle attention bugs). The throughput gain is marginal (~10-20% savings in step overhead) and only matters at high MTP concurrency (8+ simultaneous MTP requests).

Current approach: mixed mode (omlx scheduler)
- `BatchGenerator._step()` with uniform `[B, 1]` inputs
- `_mtp_step()` with a per-sequence `KVCache` (trivial rollback on rejection)

When to revisit
If MTP throughput under high concurrency becomes a real bottleneck:
- Per-row `_idx` refactor in `BatchKVCache` (this issue)

Reference
- Branch `feature/1-mixed-mtp-batch` on this repo