[feature] Per-sequence MTP speculative decoding in BatchGenerator #1

Open
opened 2026-05-03 23:46:39 +02:00 by sleepy · 2 comments
sleepy commented 2026-05-03 23:46:39 +02:00 (Migrated from localhost:18431)

Goal

Implement per-sequence MTP speculative decoding inside BatchGenerator._step() so MTP and non-MTP requests can coexist in the same batched forward pass.

Background

MTP (Multi-Token Prediction) speculative decoding uses a draft step (the model's mtp_forward method) to predict the next token, then verifies the draft with the main model in a 2-token forward pass. On acceptance, 2 tokens advance per step (up to a 2x speedup). On rejection, 1 token advances (no penalty beyond the extra forward pass).
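A minimal sketch of the accept/reject rule, assuming greedy sampling and that the 2-token verification forward yields one logit row per input position. The helper names `verify_draft` and `argmax` are hypothetical, not mlx-lm APIs, and non-greedy sampling would need a proper speculative-sampling acceptance test instead of an equality check. Note that the rejected branch emits the model's own prediction, not the confirmed token.

```python
from typing import List, Sequence


def argmax(logits: Sequence[float]) -> int:
    """Index of the largest logit (greedy sampling)."""
    return max(range(len(logits)), key=lambda i: logits[i])


def verify_draft(draft: int, logits: List[Sequence[float]]) -> List[int]:
    """Greedy accept/reject for one MTP verification step.

    `logits` holds the main model's outputs for the 2-token input
    [confirmed, draft]: logits[0] predicts the token after `confirmed`,
    logits[1] predicts the token after `draft`.
    """
    predicted = argmax(logits[0])
    if predicted == draft:
        # Draft accepted: emit the draft plus the model's next prediction (2 tokens).
        return [draft, argmax(logits[1])]
    # Draft rejected: emit the model's prediction, not the confirmed token (1 token).
    return [predicted]
```

For example, with `logits = [[0.1, 0.2, 0.9], [0.8, 0.1, 0.0]]`, a draft of `2` is accepted and yields `[2, 0]`, while a draft of `1` is rejected and yields `[2]`.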

Currently MTP only works when batch_size=1. The hybrid switching approach (MTP solo, batch concurrent) works but loses MTP speedup under load.

Why the First Attempt Failed

A per-sequence cache extract + copy-back approach was attempted and rejected by review:

  1. Cache corruption: BatchKVCache uses a shared _idx write cursor. Per-sequence cache extraction and copy-back breaks this invariant: accepted sequences get duplicate KV entries, and rejected sequences get cross-contaminated entries.
  2. Double processing: Accepted draft tokens get KV stored twice (once in verification, once in main forward).
  3. Wrong rejection output: Rejected drafts return the confirmed token instead of the model's prediction.

Architecture: Approach C - True Mixed MTP-in-Batch

Core Insight

MTP verification is a 2-token forward pass on a per-sequence basis. To make this work in batch, the model forward must accept variable-length inputs (1 token for standard sequences, 2 tokens for MTP sequences) and produce per-sequence output handling.

Required Changes

1. Variable-length input support in GenerationBatch._step()

Current: inputs shape is (batch_size,) reshaped to (batch_size, 1) - uniform 1-token input.

Required: Build a ragged input tensor where MTP sequences get 2 tokens (confirmed + draft) and standard sequences get 1 token. This requires:

  • inputs as a list of arrays (one per sequence) with variable lengths
  • Right-padding to max_length for the batch forward pass
  • Per-sequence lengths array to track actual token counts
  • Cache prepare(lengths=..., right_padding=...) call (already exists for prompt processing, needs to be wired for generation)
  • finalize() after forward to roll back padding

Reference: PromptProcessingBatch.prompt() (lines 1361-1376) already does this pattern for variable-length prompts. The same pattern needs to work in generation.
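A pure-Python sketch of the input-building step, assuming a pad id of 0 (hypothetical). In the real code the padded rows would become an `mx.array` of shape `(batch_size, max_length)`, and the lengths and per-row padding would feed the cache `prepare(lengths=..., right_padding=...)` call mentioned above; the exact meaning `prepare` expects for `right_padding` is an assumption here.

```python
from typing import List, Tuple


def right_pad(
    inputs: List[List[int]], pad_id: int = 0
) -> Tuple[List[List[int]], List[int], List[int]]:
    """Right-pad per-sequence token lists to the longest row.

    Returns (padded_rows, lengths, right_padding): lengths[i] is the number
    of real tokens in row i and right_padding[i] the number of pad slots
    appended to it, so the cache can ignore them and finalize() can roll
    them back.
    """
    lengths = [len(row) for row in inputs]
    max_length = max(lengths)
    right_padding = [max_length - n for n in lengths]
    padded = [row + [pad_id] * pad for row, pad in zip(inputs, right_padding)]
    return padded, lengths, right_padding
```

With two MTP rows (`[confirmed, draft]`) and two standard rows (`[token]`), `right_pad([[11, 12], [21], [31, 32], [41]])` gives rows `[[11, 12], [21, 0], [31, 32], [41, 0]]`, lengths `[2, 1, 2, 1]` and padding `[0, 1, 0, 1]`.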

2. Per-sequence cache isolation

Problem: BatchKVCache has a shared _idx cursor. MTP sequences need independent cache state for verification.

Solution: For MTP sequences, maintain a shadow cache (per-sequence KVCache) that:

  • Is updated during the 2-token verification forward pass
  • On acceptance: merge into the main BatchKVCache at the correct position
  • On rejection: discard entirely (no main cache contamination)

The merge must NOT use mx.copyto with shared _idx. Instead:

  • Track per-sequence write positions independently
  • Use BatchKVCache.offset[i] to determine where each sequence's tokens should go
  • Write at offset[i], then increment offset[i] (not _idx)
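A toy sketch of the per-row merge rule, assuming KV entries can be modeled as plain Python objects. `ToyBatchKV` and `resolve_shadow` are hypothetical names, not mlx-lm APIs; a real implementation would write `mx.array` slices for every layer.

```python
from typing import List


class ToyBatchKV:
    """Toy stand-in for a batch KV cache with per-row write positions.

    offset[i] is where row i's next entry goes; there is deliberately
    no shared _idx cursor.
    """

    def __init__(self, batch_size: int, capacity: int):
        self.keys: List[List[object]] = [[None] * capacity for _ in range(batch_size)]
        self.offset = [0] * batch_size

    def write(self, row: int, entries: List[str]) -> None:
        start = self.offset[row]
        end = start + len(entries)
        if end > len(self.keys[row]):
            raise ValueError("cache row is full")
        # Write at offset[row], then advance only that row's offset.
        self.keys[row][start:end] = entries
        self.offset[row] = end

    def view(self, row: int) -> List[object]:
        # A row's valid entries end at its own offset.
        return self.keys[row][: self.offset[row]]


def resolve_shadow(cache: ToyBatchKV, row: int, shadow: List[str], accepted: bool) -> None:
    """Merge a per-sequence shadow cache on acceptance, drop it on rejection."""
    if accepted:
        cache.write(row, shadow)
    # On rejection nothing is written, so the main cache is never contaminated.
    shadow.clear()
```

Because accepted shadow entries are written exactly once here, the main forward must not store KV for the same tokens again, which is the double-processing failure described earlier.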

Alternative: Use ArraysCache-style per-sequence cache management (already used for GatedDeltaNet layers). Each sequence maintains its own cache, and the batch forward pass assembles/disassembles.

3. Model forward with per-sequence kwargs

Required: Model __call__ must accept:

  • return_hidden=True - needed for MTP draft generation (already supported)
  • n_confirmed=N - tells model how many tokens to confirm (already supported in Qwen3.5/3.6)
  • Per-sequence mtp_enabled flag - model should skip MTP layers for non-MTP sequences

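Current model support: Qwen3.5/3.6 already handles n_confirmed and return_hidden, and the mtp_forward method exists. No model changes are needed, provided the per-sequence mtp_enabled filtering is done in _step() by calling mtp_forward() only for MTP sequences.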
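A sketch of honoring a per-sequence `mtp_enabled` flag outside the model: select the MTP rows from the batch's hidden states and run the draft step only on those. `draft_for_mtp_rows` is a hypothetical helper, and `draft_fn` stands in for `model.mtp_forward()`, whose real signature is not assumed here.

```python
from typing import Callable, Dict, List, Sequence


def draft_for_mtp_rows(
    hidden: Sequence[Sequence[float]],
    mtp_enabled: Sequence[bool],
    draft_fn: Callable[[List[Sequence[float]]], List[int]],
) -> Dict[int, int]:
    """Run the draft step only on rows whose sequence has MTP enabled.

    Returns {row_index: draft_token} for the MTP rows only.
    """
    rows = [i for i, enabled in enumerate(mtp_enabled) if enabled]
    if not rows:
        return {}
    # One call over the selected MTP rows; standard rows are never touched.
    drafts = draft_fn([hidden[i] for i in rows])
    return dict(zip(rows, drafts))
```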

4. MTP speculative logic in _step()

New _step() flow:

1. For each MTP sequence in "has_draft" phase:
   a. Build 2-token input (confirmed + draft)
   b. Run verification forward with shadow cache
   c. Accept/reject based on logprob comparison
   d. On accept: merge shadow cache, advance sequence by 2 tokens
   e. On reject: discard shadow cache, advance sequence by 1 token (use model prediction)

2. Build batch input:
   - MTP sequences: 1 token (next token after verification)
   - Standard sequences: 1 token
   - Pad to uniform length

3. Run main forward pass for all sequences

4. For each MTP sequence in "idle" phase:
   a. Use hidden state from main forward
   b. Call model.mtp_forward() to generate draft
   c. Store draft token, set phase to "has_draft"

5. Return outputs (MTP sequences may have 2 tokens, standard have 1)
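The phase bookkeeping in this flow can be checked with a toy, pure-Python simulation. Everything here is hypothetical: `target_next` stands in for the main model's greedy prediction, `toy_draft` for `model.mtp_forward()`, and there is no KV cache or padding. It is simplified so that each sequence gets exactly one 1- or 2-token forward per step, and MTP sequences redraft right after every forward instead of returning to a separate idle step.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

VOCAB = 50


def target_next(token: int) -> int:
    """Toy main model: its greedy next token is a fixed function of the last token."""
    return (token * 7 + 3) % VOCAB


def toy_draft(token: int) -> int:
    """Toy MTP draft: guesses the token after `token`, wrong whenever token % 4 == 0."""
    guess = target_next(token)
    return guess if token % 4 else (guess + 1) % VOCAB


@dataclass
class Seq:
    tokens: List[int]            # emitted tokens; tokens[-1] has not been fed yet
    mtp: bool = False
    draft: Optional[int] = None  # None means phase "idle"


def toy_step(seqs: List[Seq]) -> Dict[int, List[int]]:
    """One scheduler step: every sequence gets exactly one 1- or 2-token forward."""
    out: Dict[int, List[int]] = {}
    for i, s in enumerate(seqs):
        last = s.tokens[-1]
        if s.mtp and s.draft is not None:
            # Phase "has_draft": verify the 2-token input [last, draft].
            predicted = target_next(last)
            if predicted == s.draft:
                emitted = [s.draft, target_next(s.draft)]  # accept: 2 tokens
            else:
                emitted = [predicted]  # reject: model's prediction, 1 token
        else:
            emitted = [target_next(last)]  # standard row or "idle": 1-token forward
        s.tokens.extend(emitted)
        if s.mtp:
            # Draft the token after the newest emitted token for the next step.
            s.draft = toy_draft(s.tokens[-1])
        out[i] = emitted
    return out
```

Since emitted tokens never depend on whether a draft was accepted, every sequence's output must match plain greedy decoding, which is the property the Testing section asks for; only the number of tokens per step differs.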

5. BatchGenerator integration

BatchGenerator needs to:

  • Accept mtp_enabled and mtp_caches in insert()
  • Pass them to GenerationBatch.__init__()
  • Handle variable-length outputs from _step()
  • Track per-sequence token counts (MTP sequences may advance faster)
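A minimal sketch of the per-request bookkeeping, using a hypothetical `TokenTracker` keyed by request id. One assumption beyond the list above: because an MTP step can emit 2 tokens, the tracker clips each step's output to the request's remaining `max_tokens` budget so a request cannot overshoot its limit.

```python
from collections import defaultdict
from typing import Dict, List


class TokenTracker:
    """Accumulates variable-length per-step outputs per request (hypothetical)."""

    def __init__(self, max_tokens: Dict[str, int]):
        self.max_tokens = max_tokens
        self.tokens: Dict[str, List[int]] = defaultdict(list)

    def update(self, step_outputs: Dict[str, List[int]]) -> List[str]:
        """Append each request's 1 or 2 new tokens; return requests that just finished."""
        finished = []
        for uid, new_tokens in step_outputs.items():
            budget = self.max_tokens[uid] - len(self.tokens[uid])
            # An MTP step may emit 2 tokens; never exceed the request's limit.
            self.tokens[uid].extend(new_tokens[:budget])
            if len(self.tokens[uid]) >= self.max_tokens[uid]:
                finished.append(uid)
        return finished
```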

Files to Modify

| File | Changes |
|------|---------|
| mlx_lm/generate.py | GenerationBatch.__init__, _step(), filter(), extend() |
| mlx_lm/models/cache.py | BatchKVCache per-sequence write support, shadow cache merge |
| mlx_lm/generate.py | BatchGenerator.insert() - accept MTP parameters |

Acceptance Criteria

  • MTP and non-MTP sequences run in same _step() call
  • No cache corruption (verified by running a batch in which every sequence uses MTP)
  • No throughput regression for non-MTP sequences
  • MTP speedup preserved (~1.5-2x for MTP sequences)
  • Clean fallback on draft rejection
  • Works with greedy and non-greedy sampling
  • Works with logits processors

Prerequisites

  • mlx-lm commit 97a62ef (base MTP support for single-sequence)
  • Qwen3.5/3.6 model with mtp_forward method
  • BatchKVCache.extract() method for per-sequence cache extraction

Testing

Write a test that:

  1. Creates a batch with 2 MTP sequences and 2 standard sequences
  2. Runs 100 _step() calls
  3. Verifies: no cache corruption, correct token outputs, MTP sequences advance faster
  4. Compares output tokens against non-batched MTP generation (should be identical)
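A sketch of the assertions such a test could make, given token lists collected from the batched run and from non-batched generation. `check_batched_against_reference` is a hypothetical helper; it assumes greedy sampling, where token equality with the reference also serves as a practical proxy for "no cache corruption" (corrupted KV entries would be expected to change the attention output and hence the tokens). With non-greedy sampling, identical outputs would additionally require matching per-sequence random streams.

```python
from typing import Dict, List


def check_batched_against_reference(
    batched: Dict[str, List[int]],
    reference: Dict[str, List[int]],
    mtp_ids: List[str],
    standard_ids: List[str],
    n_steps: int,
) -> None:
    """`batched`: request id -> tokens from n_steps _step() calls.
    `reference`: request id -> tokens from non-batched generation."""
    for uid, tokens in batched.items():
        # Same tokens as the non-batched run (reference may be longer).
        assert tokens == reference[uid][: len(tokens)], f"{uid} diverged"
    for uid in standard_ids:
        # Standard rows advance exactly one token per step.
        assert len(batched[uid]) == n_steps, f"{uid} advanced {len(batched[uid])}"
    for uid in mtp_ids:
        # MTP rows advance one or two tokens per step.
        assert n_steps <= len(batched[uid]) <= 2 * n_steps, f"{uid} out of range"
    if mtp_ids:
        total_mtp = sum(len(batched[uid]) for uid in mtp_ids)
        assert total_mtp > n_steps * len(mtp_ids), "MTP rows did not advance faster"
```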
sleepy commented 2026-05-04 00:39:38 +02:00 (Migrated from localhost:18431)

Reverted after review. Critical issues found:

  1. Cache corruption: per-sequence cache extract/copy-back incompatible with BatchKVCache shared _idx cursor
  2. Accepted sequences double-processed: draft KV stored twice, offset inflated
  3. Rejection output wrong: returns confirmed token instead of model prediction

The fundamental problem: MTP verification requires a 2-token forward pass per sequence, while the batch assumes uniform 1-token inputs. Per-sequence cache operations break batch cache invariants.

Viable approaches:

  • A) Keep hybrid switching (MTP-only when batch=1) — current working approach
  • B) Run MTP and batch sequentially in omlx — works but adds overhead (no net speedup)
  • C) True mixed: require model to handle variable-length inputs (1 or 2 tokens per seq) with per-sequence cache isolation — major mlx-lm architecture change
Owner

Decision: Blocked in favor of mixed mode (Approach A)

After investigation and a failed implementation attempt (branch feature/1-mixed-mtp-batch, commit ecd900b), per-sequence MTP in batch is blocked in favor of the mixed-mode architecture in omlx (PRs #63, #64).

Why blocked

The dummy-token-padding approach has a fatal flaw: rejected drafts cannot be rolled back because BatchKVCache._idx is a shared scalar. The rejected draft KV entries stay in the cache permanently, corrupting future attention. This is the same failure mode as in the first attempt.

The fix would require per-row _idx in BatchKVCache (~200+ lines of cache.py, high risk of subtle attention bugs). The throughput gain is marginal (~10-20% step overhead savings) and only matters at high MTP concurrency (8+ simultaneous MTP requests).
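The rollback problem can be reproduced with a toy cache that, like BatchKVCache as described here, keeps one shared write cursor for all rows. `SharedCursorCache` is hypothetical and ignores padding and attention masks.

```python
from typing import List


class SharedCursorCache:
    """Toy batch cache whose rows share one scalar write cursor (_idx)."""

    def __init__(self, batch_size: int):
        self.rows: List[List[str]] = [[] for _ in range(batch_size)]
        self._idx = 0  # every row is treated as valid up to _idx

    def update(self, entries: List[List[str]]) -> None:
        # A padded batch writes the same number of slots into every row.
        width = len(entries[0])
        assert all(len(e) == width for e in entries), "batch must be rectangular"
        for row, new in zip(self.rows, entries):
            row.extend(new)
        self._idx += width

    def trim(self, n: int) -> None:
        # Rolling back can only move the shared cursor, i.e. for every row at once.
        self._idx -= n
        for row in self.rows:
            del row[self._idx:]

    def valid(self, row: int) -> List[str]:
        return self.rows[row][: self._idx]
```

If row 0 accepts its draft and row 1 rejects it, leaving `_idx` alone keeps row 1's rejected draft visible to attention, while `trim(1)` also deletes row 0's accepted draft. Only a per-row cursor can serve both rows.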

Current approach: mixed mode (omlx scheduler)

  • Standard requests: BatchGenerator._step() with uniform [B, 1] inputs
  • MTP requests: separate _mtp_step() with per-sequence KVCache (trivial rollback on rejection)
  • ~10-20% step overhead from the extra small forward pass
  • Zero cache.py changes
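A minimal sketch of the mixed-mode split, with hypothetical names (`partition`, `ToySeqCache`, `mtp_verify_rollback`). In mlx-lm the per-sequence cache would be a `KVCache`, whose `trim(n)` drops the last n positions; that is what makes rollback on rejection trivial, since no other request shares the cursor.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class ToySeqCache:
    """Toy per-sequence cache; trim(n) drops the newest n entries."""
    entries: List[str] = field(default_factory=list)

    def update(self, new: List[str]) -> None:
        self.entries.extend(new)

    def trim(self, n: int) -> int:
        n = min(n, len(self.entries))
        del self.entries[len(self.entries) - n:]
        return n


@dataclass
class Request:
    uid: str
    mtp: bool
    cache: ToySeqCache = field(default_factory=ToySeqCache)


def partition(requests: List[Request]) -> Tuple[List[Request], List[Request]]:
    """Split active requests into the uniform [B, 1] batch and the per-sequence MTP path."""
    standard = [r for r in requests if not r.mtp]
    mtp = [r for r in requests if r.mtp]
    return standard, mtp


def mtp_verify_rollback(req: Request, confirmed: str, draft: str, accepted: bool) -> None:
    # The 2-token verification writes both entries into this request's own cache.
    req.cache.update([confirmed, draft])
    if not accepted:
        # A rejected draft is dropped without touching any other request.
        req.cache.trim(1)
```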

When to revisit

If MTP throughput under high concurrency becomes a real bottleneck:

  1. First try: batch MTP requests together in a separate mini-batch forward (no cache.py changes)
  2. Last resort: per-row _idx refactor in BatchKVCache (this issue)

Reference

  • Architecture comparison: sleepy/omlx#10 (https://git.kokoham.com/sleepy/omlx/issues/10)
  • Failed attempt: branch feature/1-mixed-mtp-batch on this repo
Reference
sleepy/mlx-lm#1