[feature] Mixed MTP + standard tokens in the same batched step #10

Open
opened 2026-05-03 21:36:50 +02:00 by sleepy · 1 comment
sleepy commented 2026-05-03 21:36:50 +02:00 (Migrated from localhost:18431)

Goal

Integrate MTP speculative decoding into the scheduler's batched step loop so that MTP and non-MTP requests can coexist in the same batch. This is the ideal long-term architecture. Currently MTP is limited by hybrid switching (issue #9), which only runs MTP when batch_size=1.

Status

Blocked on mlx-lm fork issue #1: sleepy/mlx-lm#1 (https://git.kokoham.com/sleepy/mlx-lm/issues/1)

The mlx-lm BatchGenerator needs per-sequence MTP support (variable-length inputs, per-sequence cache isolation, shadow cache merge). Full spec is in the mlx-lm issue.
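Here is a minimal sketch of what "variable-length inputs" means for one mixed step. It assumes right-padding. A standard request feeds 1 token per step. An MTP request feeds 2 tokens, assumed here to be the pending token plus one MTP draft. The step input therefore becomes a [batch, 2] grid plus a per-row length. pad_step_inputs and PAD_ID are hypothetical names, not mlx-lm APIs.

```python
from typing import List, Tuple

PAD_ID = 0  # hypothetical padding token id


def pad_step_inputs(per_seq_tokens: List[List[int]]) -> Tuple[List[List[int]], List[int]]:
    """Right-pad 1-token (standard) and 2-token (MTP) step inputs into a rectangular grid.

    Returns the padded grid and the number of real tokens per row. A forward pass
    would need the lengths to mask padding and to advance each row's cache correctly.
    """
    if not per_seq_tokens:
        return [], []
    lengths = [len(toks) for toks in per_seq_tokens]
    if any(n not in (1, 2) for n in lengths):
        raise ValueError("each sequence must feed 1 (standard) or 2 (MTP) tokens")
    width = max(lengths)
    grid = [toks + [PAD_ID] * (width - len(toks)) for toks in per_seq_tokens]
    return grid, lengths


# Row 0 is a standard request (1 token); row 1 is an MTP request (token + draft).
grid, lengths = pad_step_inputs([[17], [42, 99]])
print(grid)     # [[17, 0], [42, 99]]
print(lengths)  # [1, 2]
```

The padded slot in row 0 is the crux. With one shared cache cursor it would still occupy a cache position, which is why per-sequence cache isolation is needed.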

Why This Is Hard

  • MTP needs a 2-token forward pass per sequence; batched generation runs one uniform pass with 1 token per sequence
  • MTP requires variable-length inputs within a batch (1 or 2 tokens per sequence)
  • Requires per-sequence cache isolation (BatchKVCache has a shared _idx cursor)
  • Requires per-sequence n_confirmed and return_hidden kwargs on the model's forward call
  • GenerationBatch._step() assumes a uniform [batch,1] input shape
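A toy cache can show the shared-cursor problem. This is a hypothetical simplification, not the real BatchKVCache. One scalar _idx is shared by every row, so all rows must advance and roll back by the same amount. As a result, a step where one MTP row accepts its draft while another rejects it cannot be expressed.

```python
from typing import List, Sequence


class SharedCursorCache:
    """Toy batched cache where ALL rows share one write cursor (_idx).

    Hypothetical simplification for illustration; string entries stand in for K/V vectors.
    """

    def __init__(self, batch: int, capacity: int):
        self.rows: List[list] = [[None] * capacity for _ in range(batch)]
        self._idx = 0  # one scalar cursor for the whole batch

    def append(self, per_row_entries: Sequence[Sequence[str]]) -> None:
        # A mixed 1-token/2-token step is rejected here, which is what forces padding.
        n = len(per_row_entries[0])
        if any(len(e) != n for e in per_row_entries):
            raise ValueError("shared _idx: every row must append the same number of entries")
        if self._idx + n > len(self.rows[0]):
            raise IndexError("cache is full")
        for row, entries in zip(self.rows, per_row_entries):
            row[self._idx:self._idx + n] = list(entries)
        self._idx += n

    def rollback(self, per_row_discard: Sequence[int]) -> None:
        # A scalar cursor can only move every row back by the same amount.
        if len(set(per_row_discard)) != 1:
            raise ValueError(f"cannot discard {list(per_row_discard)} with one shared _idx")
        self._idx -= per_row_discard[0]

    def valid(self, row: int) -> list:
        return self.rows[row][:self._idx]


cache = SharedCursorCache(batch=2, capacity=8)
cache.append([["a0"], ["b0"]])                        # standard step, 1 entry per row
cache.append([["a1", "a_draft"], ["b1", "b_draft"]])  # both rows feed token + MTP draft
try:
    cache.rollback([0, 1])  # row 0 accepted its draft, row 1 rejected it
except ValueError as err:
    print(err)  # cannot discard [0, 1] with one shared _idx
```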

Current State

  • Hybrid switching (issue #9) works: MTP runs when a request is alone, batched generation when requests are concurrent
  • First attempt at mixed MTP-in-batch failed review (cache corruption, double processing)
  • mlx-lm fork at https://git.kokoham.com/sleepy/mlx-lm has base MTP support (commit 97a62ef)

What Needs to Happen

  1. mlx-lm: Implement per-sequence MTP in GenerationBatch._step() (see mlx-lm issue #1 for the full spec)
  2. mlx-lm: Implement shadow cache merge into BatchKVCache without breaking shared _idx
  3. omlx: Update BatchGenerator.insert() to pass mtp_enabled and mtp_caches
  4. omlx: Remove hybrid switching code (MTP→batch transition, _transition_to_mtp)
  5. omlx: Pin to new mlx-lm commit

Acceptance Criteria

  • MTP and non-MTP requests run in the same batch step
  • No cache corruption (verified by batched MTP test)
  • No throughput regression for non-MTP requests
  • MTP speedup preserved (~1.5-2x for MTP sequences)
  • Clean fallback on draft rejection
  • Works with greedy and non-greedy sampling
  • Works with logits processors
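For the draft-rejection and sampling criteria, here is a sketch of verifying one MTP draft token. It assumes the main model returns logits (or probabilities) at two positions: after the pending token and after the draft. The greedy path accepts the draft only if it matches the main model's argmax. The non-greedy path uses the standard speculative-sampling rule: accept with probability min(1, p/q), otherwise resample from max(0, p - q). Whether omlx would use exactly this rule is an assumption, and the function names are hypothetical. In both paths, rejection means emitting one token and discarding one cache position, which is the rollback the cache has to support.

```python
import random
from typing import List, Sequence, Tuple


def argmax(xs: Sequence[float]) -> int:
    return max(range(len(xs)), key=lambda i: xs[i])


def verify_greedy(draft: int, logits: Sequence[Sequence[float]]) -> Tuple[List[int], int]:
    """Greedy check of one MTP draft token.

    logits[0]: main-model logits after the pending token (predicts the draft slot).
    logits[1]: main-model logits after the draft (used only if the draft is kept).
    Returns (tokens to emit, cache positions to roll back).
    """
    target = argmax(logits[0])
    if target == draft:
        # Draft confirmed: emit it plus the bonus token predicted after it.
        return [draft, argmax(logits[1])], 0
    # Draft rejected: emit the main model's token and drop the draft's cache slot.
    return [target], 1


def verify_sampled(
    draft: int,
    p: Sequence[float],       # main-model probabilities at the draft slot
    q: Sequence[float],       # MTP-head probabilities the draft was sampled from
    p_next: Sequence[float],  # main-model probabilities after the draft
    rng: random.Random,
) -> Tuple[List[int], int]:
    """Standard speculative-sampling acceptance for a single draft token."""
    if rng.random() < min(1.0, p[draft] / q[draft]):
        bonus = rng.choices(range(len(p_next)), weights=p_next)[0]
        return [draft, bonus], 0
    # Resample from the residual max(0, p - q); choices() normalizes the weights.
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    token = rng.choices(range(len(residual)), weights=residual)[0]
    return [token], 1


logits = [[0.0, 0.0, 0.0, 5.0], [2.0, 0.0, 0.0, 0.0]]
print(verify_greedy(3, logits))  # ([3, 0], 0): draft accepted, bonus token 0
print(verify_greedy(1, logits))  # ([3], 1): draft rejected, roll back one slot
```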
Owner

Decision: Blocked — mixed mode is the architecture

After investigating the per-sequence in-batch approach (mlx-lm issue #1) and attempting an implementation, the decision is to keep mixed mode as the long-term architecture.

What we have (mixed mode, already merged)

  • PR #63: Output parser integration into _mtp_step
  • PR #64: Mixed MTP + batched generation (phases 1-3)
  • MTP requests run in their own _mtp_step() with a per-sequence KVCache
  • Standard requests run in BatchGenerator._step() with a shared BatchKVCache
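Here is a hypothetical sketch of one scheduler tick in mixed mode as described above. Requests are partitioned into two groups. Each MTP request is stepped on its own, with its own per-sequence cache, and all standard requests share one batched forward. mixed_step, Request, and the two callbacks are illustrative stand-ins for _mtp_step() and BatchGenerator._step(), not their real signatures.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class Request:
    rid: str
    mtp_enabled: bool
    output: List[int] = field(default_factory=list)


def mixed_step(
    requests: List[Request],
    mtp_step: Callable[[Request], List[int]],          # stand-in for _mtp_step()
    batch_step: Callable[[List[Request]], List[int]],  # stand-in for BatchGenerator._step()
) -> Dict[str, List[int]]:
    """One hypothetical mixed-mode tick: MTP requests solo, standard requests batched."""
    mtp_reqs = [r for r in requests if r.mtp_enabled]
    std_reqs = [r for r in requests if not r.mtp_enabled]
    emitted: Dict[str, List[int]] = {}
    # Each MTP request owns its cache, so a rejected draft is rolled back
    # without touching any other sequence. May emit 1 or 2 tokens.
    for r in mtp_reqs:
        emitted[r.rid] = mtp_step(r)
    # All standard requests share one batched forward: exactly one token each.
    if std_reqs:
        for r, tok in zip(std_reqs, batch_step(std_reqs)):
            emitted[r.rid] = [tok]
    for r in requests:
        r.output.extend(emitted[r.rid])
    return emitted


def fake_mtp(r: Request) -> List[int]:
    return [1, 2]  # pretend the draft was accepted


def fake_batch(rs: List[Request]) -> List[int]:
    return [9 for _ in rs]


reqs = [Request("a", True), Request("b", False), Request("c", False)]
out = mixed_step(reqs, fake_mtp, fake_batch)
print(out)  # {'a': [1, 2], 'b': [9], 'c': [9]}
```

The per-request loop over MTP requests is where the extra step overhead comes from. The first revisit option, batching MTP requests into a separate mini-batch forward, would replace that loop with a single call.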

Why not in-batch (this issue)

BatchKVCache._idx is a shared scalar, so rejected MTP drafts cannot be rolled back per sequence. Fixing this requires a per-row _idx (~200+ lines of changes to cache.py, high risk). The expected throughput gain (roughly 10-20% less step overhead) does not justify the risk.
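For comparison, here is a toy cache with a per-row _idx, sketching the last-resort fix. It is hypothetical and ignores everything a real cache must also handle, such as attention masks, returned key/value slices, and trimming. It does show the core change. Each row advances by its own token count, so no padding is needed. Rollback can also discard a rejected draft in one row while keeping an accepted draft in another.

```python
from typing import List, Sequence


class PerRowCursorCache:
    """Toy batched cache with one write cursor per row (hypothetical per-row _idx)."""

    def __init__(self, batch: int, capacity: int):
        self.rows: List[list] = [[None] * capacity for _ in range(batch)]
        self._idx: List[int] = [0] * batch  # one cursor per sequence

    def append(self, per_row_entries: Sequence[Sequence[str]]) -> None:
        # Validate every row first so a failure leaves the cache unchanged.
        for b, entries in enumerate(per_row_entries):
            if self._idx[b] + len(entries) > len(self.rows[b]):
                raise IndexError(f"row {b} would overflow the cache")
        # Rows may append different counts (1 for standard, 2 for MTP); no padding.
        for b, entries in enumerate(per_row_entries):
            start, n = self._idx[b], len(entries)
            self.rows[b][start:start + n] = list(entries)
            self._idx[b] += n

    def rollback(self, per_row_discard: Sequence[int]) -> None:
        # Each row discards its own rejected positions independently.
        for b, n in enumerate(per_row_discard):
            if not 0 <= n <= self._idx[b]:
                raise ValueError(f"row {b}: cannot discard {n} entries")
            self._idx[b] -= n

    def valid(self, b: int) -> list:
        return self.rows[b][: self._idx[b]]


cache = PerRowCursorCache(batch=2, capacity=8)
cache.append([["a0"], ["b0"]])
cache.append([["a1", "a_draft"], ["b1", "b_draft"]])
cache.rollback([0, 1])  # row 0 keeps its draft, row 1 drops its rejected draft
print(cache.valid(0))  # ['a0', 'a1', 'a_draft']
print(cache.valid(1))  # ['b0', 'b1']
```

The toy hides the real cost. In an actual cache, presumably every code path that reads the cursor would also have to become per-row, which is where the risk lies.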

If this needs revisiting

  1. First: batch MTP requests together in a separate mini-batch forward pass
  2. Last resort: per-row _idx in BatchKVCache (mlx-lm issue #1)

See mlx-lm issue #1 for full analysis.

Reference
sleepy/omlx#10