[feature] MTP speculative decoding with concurrent streams #9

Closed
opened 2026-05-03 21:24:42 +02:00 by sleepy · 2 comments
sleepy commented 2026-05-03 21:24:42 +02:00 (Migrated from localhost:18431)

Problem:
MTP speculative decoding only works on the single-request fast path. When concurrent streams are active, requests fall back to the scheduler (the non-MTP batched path) and lose the ~1.5x speedup.

Current behavior:

  • Single request with MTP ON: ~36 tok/s (fast path in engine/batched.py:491-537)
  • Concurrent requests: fall back to scheduler, ~23 tok/s, MTP not used

Goal:
Investigate and implement MTP support for concurrent streams so multiple users can benefit from MTP speedup simultaneously.

Acceptance criteria:

  • MTP works with 2+ concurrent streams
  • No regression in single-stream MTP throughput
  • No regression in non-MTP batched throughput
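
The two "no regression" criteria can be checked mechanically against the figures quoted above. The following is a minimal sketch, not part of the project: the helper names and the 5% tolerance are assumptions chosen for illustration, and only the ~36 tok/s and ~23 tok/s baselines come from this issue.

```python
def speedup(fast_tok_s: float, slow_tok_s: float) -> float:
    """Ratio of two throughputs, e.g. MTP fast path vs. batched path."""
    return fast_tok_s / slow_tok_s


def has_regressed(baseline_tok_s: float, measured_tok_s: float,
                  tolerance: float = 0.05) -> bool:
    """True if the measurement is more than `tolerance` below the baseline.

    The 5% default is an assumed noise margin, not a project setting.
    """
    return measured_tok_s < baseline_tok_s * (1.0 - tolerance)


# Baselines quoted in this issue (approximate figures).
SINGLE_STREAM_MTP = 36.0
BATCHED_NON_MTP = 23.0

print(f"MTP speedup: {speedup(SINGLE_STREAM_MTP, BATCHED_NON_MTP):.2f}x")
```

The ratio of the two quoted baselines is about 1.57, which is consistent with the "~1.5x" figure in the problem statement.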
sleepy commented 2026-05-03 21:30:52 +02:00 (Migrated from localhost:18431)

Investigation Complete

Root cause: MTP and continuous batching have incompatible execution patterns. MTP needs 2-3 sequential forward passes with variable-length inputs, per-request rollback state, and a separate MTP KV cache. Batched generation uses 1 uniform forward pass per step with BatchKVCache. These cannot be unified without rewriting the entire pipeline.
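
To illustrate why a draft-verify loop does not fit a uniform one-token-per-step batch, here is a toy sketch of the general technique with a single draft token. It is not the omlx implementation: `target_next` and `draft_next` are made-up stand-ins for the main model and the MTP head, and the "rollback" is only indicated by a comment. The point is that each step emits a variable number of tokens (1 or 2), which is per-request state a uniform batch step has no room for.

```python
def target_next(prefix: list[int]) -> int:
    """Toy 'main model': a deterministic next-token rule."""
    return (sum(prefix) * 7 + len(prefix)) % 11


def draft_next(prefix: list[int]) -> int:
    """Toy 'MTP head': agrees with the target except at every third position."""
    guess = target_next(prefix)
    return guess if len(prefix) % 3 else (guess + 1) % 11


def greedy_decode(prompt: list[int], n_new: int) -> list[int]:
    """Reference: one token per step, no speculation."""
    tokens = list(prompt)
    for _ in range(n_new):
        tokens.append(target_next(tokens))
    return tokens


def speculative_decode(prompt: list[int], n_new: int) -> tuple[list[int], list[int]]:
    """Draft-verify loop; returns the tokens and how many each step emitted."""
    tokens = list(prompt)
    emitted_per_step = []
    while len(tokens) - len(prompt) < n_new:
        draft = draft_next(tokens)
        # Verification. In a real model, one forward pass over
        # [last token, draft] would yield both predictions used below.
        verified = target_next(tokens)
        if draft == verified:
            # Accept the draft and take the bonus token that follows it.
            bonus = target_next(tokens + [draft])
            tokens += [draft, bonus]
            emitted_per_step.append(2)
        else:
            # Reject: keep the target's token. A real cache would roll back
            # the entry written for the rejected draft here.
            tokens.append(verified)
            emitted_per_step.append(1)
    return tokens[: len(prompt) + n_new], emitted_per_step


out, steps = speculative_decode([3], 9)
print(steps)
```

The output matches plain greedy decoding token for token; only the number of steps needed to produce it changes.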

Additional bug found: The current MTP fast path blocks the asyncio event loop (synchronous for loop in async def), preventing all other request processing during generation.
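
The event-loop bug can be reproduced in isolation. The sketch below is an assumption-laden stand-in, not the code in `batched.py`: `generate_tokens_blocking` simulates a synchronous decode loop with `time.sleep`, and a heartbeat coroutine counts how often the loop gets to run other work. Offloading with `asyncio.to_thread` is shown as one possible remedy.

```python
import asyncio
import time


def generate_tokens_blocking(n_tokens: int, step_seconds: float) -> list[int]:
    """Hypothetical stand-in for a synchronous decode loop."""
    tokens = []
    for i in range(n_tokens):
        time.sleep(step_seconds)  # simulates one forward pass
        tokens.append(i)
    return tokens


async def heartbeat(ticks: list[float], interval: float) -> None:
    """Records timestamps; it only runs while the event loop is free."""
    while True:
        ticks.append(time.monotonic())
        await asyncio.sleep(interval)


async def run(offload: bool) -> int:
    """Return how many heartbeat ticks occurred during generation."""
    ticks: list[float] = []
    hb = asyncio.create_task(heartbeat(ticks, 0.01))
    await asyncio.sleep(0)  # let the heartbeat record its first tick
    before = len(ticks)
    if offload:
        # Worker thread: the event loop keeps serving other coroutines.
        await asyncio.to_thread(generate_tokens_blocking, 20, 0.01)
    else:
        # Synchronous loop inside `async def`: nothing else can run.
        generate_tokens_blocking(20, 0.01)
    during = len(ticks) - before
    hb.cancel()
    return during


if __name__ == "__main__":
    print("ticks while blocking:", asyncio.run(run(offload=False)))
    print("ticks while offloaded:", asyncio.run(run(offload=True)))
```

In the blocking case the heartbeat records no ticks during generation, which is the same starvation other requests would experience.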

Recommended approach: Hybrid switching

  • 1 active request → MTP draft-verify loop (~1.5x speedup)
  • >1 active requests → standard batched generation (all requests)
  • Batch drains to 1 → switch back to MTP

This matches real-world usage: MTP matters most for single-request interactive chat; concurrent scenarios prioritize throughput over per-request latency.
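
The switching rule can be written as a pure function of the number of active requests. This is a minimal sketch under stated assumptions: `Mode`, `select_mode`, and `transitions` are illustrative names, and treating zero active requests as batched (idle) is a choice made here, not something the issue specifies.

```python
from enum import Enum


class Mode(Enum):
    MTP = "mtp"
    BATCHED = "batched"


def select_mode(num_active: int) -> Mode:
    """Exactly one active request runs MTP; anything else runs batched."""
    return Mode.MTP if num_active == 1 else Mode.BATCHED


def transitions(active_counts: list[int]) -> list[tuple[int, Mode, Mode]]:
    """Return (step index, old mode, new mode) for every mode switch."""
    switches = []
    previous = None
    for step, count in enumerate(active_counts):
        mode = select_mode(count)
        if previous is not None and mode is not previous:
            switches.append((step, previous, mode))
        previous = mode
    return switches


# A second and third request arrive, then the batch drains back to one.
trace = [1, 1, 2, 3, 2, 1, 1]
for step, old, new in transitions(trace):
    print(f"step {step}: {old.value} -> {new.value}")
# prints:
# step 2: mtp -> batched
# step 5: batched -> mtp
```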

Key implementation steps:

  1. Remove MTP fast path from batched.py, route all requests through scheduler
  2. Add Scheduler._mtp_step() — single-request MTP loop on the executor thread
  3. Add MTP↔batched transition logic when concurrency changes
  4. Track MTP state (caches, draft token, phase) on Request
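
Steps 2 to 4 could be sketched roughly as follows. Everything here is hypothetical scaffolding: only the names `Scheduler`, `_mtp_step`, and `Request` appear in this issue, while the fields of `MTPState`, the `step` and `_batched_step` methods, and the logging are assumptions. In particular, simply discarding speculative state when leaving MTP mode is a simplification; a real implementation would have to decide what to do with an unverified draft token.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class MTPState:
    """Per-request speculative state (field names are illustrative)."""
    draft_token: Optional[int] = None
    phase: str = "draft"                       # "draft" or "verify"
    cache: list = field(default_factory=list)  # placeholder for the MTP KV cache


@dataclass
class Request:
    request_id: str
    tokens: list = field(default_factory=list)
    mtp: Optional[MTPState] = None  # None while the request runs batched


class Scheduler:
    def __init__(self) -> None:
        self.active: list[Request] = []
        self.log: list[str] = []

    def step(self) -> None:
        """Run one scheduling step in whichever mode the concurrency allows."""
        if len(self.active) == 1:
            req = self.active[0]
            if req.mtp is None:
                req.mtp = MTPState()  # batched -> MTP transition
            self._mtp_step(req)
        elif self.active:
            for req in self.active:
                req.mtp = None        # MTP -> batched: drop speculative state
            self._batched_step()

    def _mtp_step(self, req: Request) -> None:
        # Placeholder for the single-request draft-verify loop.
        self.log.append(f"mtp:{req.request_id}")

    def _batched_step(self) -> None:
        # Placeholder for one uniform batched forward pass.
        self.log.append(f"batched:{len(self.active)}")
```

Keeping the MTP state on the request rather than on the scheduler means a transition only needs to create or clear one field per request.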

Approaches considered and rejected:

  • A (parallel MTP + batched): Impossible — single Metal stream, shared model weights, global mutable state
  • B (integrated mixed MTP+standard): Hard — requires variable-length batched inputs, BatchMTPCache, BatchArraysCache with rollback, per-sequence n_confirmed
  • C (time-slice exclusive): Possible but no throughput benefit — stalls all other requests during MTP
sleepy commented 2026-05-03 22:51:03 +02:00 (Migrated from localhost:18431)

Merged via squash (PR #11). Hybrid MTP switching implemented: single-request → MTP (~16.8 tok/s), concurrent → batched (~8.9 tok/s), auto-transitions between modes. Fix for Approach B tracked in #10.
Reference
sleepy/omlx#9