[scheduler] Hybrid MTP switching for concurrent streams (#9) #11

Merged
sleepy merged 1 commit from refs/pull/11/head into main 2026-05-03 22:50:53 +02:00
sleepy commented 2026-05-03 22:32:34 +02:00 (Migrated from localhost:18431)

Summary

  • First request uses MTP speculative decoding (~1.5x speedup)
  • When second request arrives, MTP request transitions into BatchGenerator for standard batched generation
  • When batch drains to 1, automatically transitions back to MTP mode
  • Removed old MTP fast path from batched.py (was blocking event loop)
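
The switching rule in the summary can be sketched as a small state machine. This is a toy model under stated assumptions, not the code in omlx/scheduler.py: `Mode`, `ToyScheduler`, `add_request` and `finish_request` are hypothetical names, and the real transitions also hand KV-cache state between the MTP path and BatchGenerator, which is omitted here.

```python
from enum import Enum


class Mode(Enum):
    IDLE = "idle"
    MTP = "mtp"          # single stream, speculative decoding
    BATCHED = "batched"  # two or more streams, standard batched generation


class ToyScheduler:
    """Hypothetical model of the mode-switching rule only."""

    def __init__(self):
        self.active = set()
        self.mode = Mode.IDLE

    def _update_mode(self):
        # The mode depends only on how many requests are active.
        n = len(self.active)
        if n == 0:
            self.mode = Mode.IDLE
        elif n == 1:
            self.mode = Mode.MTP
        else:
            self.mode = Mode.BATCHED

    def add_request(self, request_id):
        self.active.add(request_id)
        self._update_mode()

    def finish_request(self, request_id):
        self.active.discard(request_id)
        self._update_mode()


sched = ToyScheduler()
sched.add_request("a")       # first request: MTP
first = sched.mode
sched.add_request("b")       # second request arrives: batched
second = sched.mode
sched.finish_request("a")    # batch drains to 1: back to MTP
third = sched.mode
```

Deriving the mode from the active-request count keeps both transitions symmetric, which is the behaviour the "verified" entries under Test results describe.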

Test results

  • 183/183 unit tests passing
  • Single request MTP: 16.8 tok/s
  • Concurrent batched: 8.9 tok/s (2 requests)
  • MTP→batched transition: verified
  • Batched→MTP transition: verified

Files changed

  • omlx/scheduler.py: +591 lines (_mtp_step, transitions, init/cleanup)
  • omlx/engine/batched.py: -48 lines (removed old MTP fast path)
  • omlx/request.py: +14 lines (MTP state fields)
sleepy commented 2026-05-03 22:45:44 +02:00 (Migrated from localhost:18431)

PR Review: CHANGES_REQUESTED

Critical Bugs

1. _mtp_prev_tokens duplication in draft logits processing

Compared against mlx-lm's mtp_generate_step reference, _mtp_prev_tokens tracking has a systematic duplication bug affecting ALL three draft generation sites:

Idle phase draft step (scheduler.py:1828-1832):
After idle phase sets _mtp_prev_tokens = [confirmed_id] (line 1754), the draft step concatenates confirmed_tok again:

tokens_for_proc = mx.concatenate([request._mtp_prev_tokens, confirmed_tok.reshape(-1)])
# = [confirmed_id, confirmed_id]  ← DUPLICATE

Reference expects: [last_prompt_tok, confirmed_id]

Has_draft accept path (scheduler.py:1965-1969):
After the loop adds bonus_tok = toks[1] to _mtp_prev_tokens (lines 1877-1881), the draft step concatenates bonus_tok again:

tokens_for_proc = mx.concatenate([request._mtp_prev_tokens, bonus_tok.reshape(-1)])
# = [..., toks[1], toks[1]]  ← DUPLICATE

Has_draft reject path (scheduler.py:2051-2055):
After trimming (removes toks[1]), _mtp_prev_tokens ends with toks[0] = verify_pred. Then verify_pred is concatenated again:

tokens_for_proc = mx.concatenate([request._mtp_prev_tokens, verify_pred.reshape(-1)])
# = [..., verify_pred, verify_pred]  ← DUPLICATE

Impact: Repetition penalty is applied incorrectly (double-counted on last token), causing degraded output quality when repetition_penalty != 1.0.

Root cause: _mtp_prev_tokens is already updated with the latest token before the draft step runs, but the draft step concatenates it again. The reference mtp_generate_step avoids this because _step_mtp receives prev_tokens as an argument that doesn't include the current main_tok — it adds main_tok itself.

Fix: The draft step should use _mtp_prev_tokens directly (already contains the history up to the current point), not concatenate the confirmed/bonus/verify token again.
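
The duplication and the proposed fix can be illustrated with plain Python lists standing in for the mx arrays. This is a minimal sketch: `draft_history_buggy` and `draft_history_fixed` are hypothetical names, the token ids are made up, and the only assumption carried over from the review is that the history already ends with the latest token when the draft step runs.

```python
def draft_history_buggy(prev_tokens, latest_tok):
    # Mirrors the pattern flagged above: prev_tokens already ends with
    # latest_tok, and it is appended a second time.
    return prev_tokens + [latest_tok]


def draft_history_fixed(prev_tokens, latest_tok):
    # prev_tokens is already up to date, so it is passed through unchanged.
    # The assert documents the invariant the fix relies on.
    assert prev_tokens and prev_tokens[-1] == latest_tok
    return list(prev_tokens)


prev_tokens = [3, 9, 42]   # history, updated before the draft step
latest_tok = 42            # confirmed / bonus / verify token

buggy = draft_history_buggy(prev_tokens, latest_tok)
fixed = draft_history_fixed(prev_tokens, latest_tok)
```

Here `buggy` ends with the latest token twice, while `fixed` is identical to the tracked history. The same invariant check could be applied at all three draft sites.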

2. _mtp_prev_tokens missing prompt tokens

The reference tracks prev_tokens starting from the last prompt token (y[0:1]). The PR initializes _mtp_prev_tokens to None and only sets it to [confirmed_id] after the first backbone step. This means:

  • First backbone step's logits processor gets mx.array([], mx.uint32) instead of [last_prompt_token]
  • Draft step misses last_prompt_token from its history

As a result, the repetition penalty cannot penalize the last prompt token when the first generated token is sampled, and the draft step's history is one token short of the reference.
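
One way to seed the history so that it starts from the last prompt token is sketched below with plain lists. `init_prev_tokens` and `record_confirmed` are hypothetical helpers, not functions from the PR or from mlx-lm, and the token ids are illustrative.

```python
def init_prev_tokens(prompt_tokens):
    """Seed the history with the last prompt token (hypothetical helper)."""
    if not prompt_tokens:
        raise ValueError("prompt must contain at least one token")
    return list(prompt_tokens[-1:])


def record_confirmed(prev_tokens, confirmed_id):
    """Append a newly confirmed token and return a new list."""
    return prev_tokens + [confirmed_id]


prompt = [11, 12, 13]
history = init_prev_tokens(prompt)            # seen by the first backbone step
after_first = record_confirmed(history, 42)   # seen by the first draft step
```

With this seeding, the first backbone step receives the last prompt token rather than an empty array, and the first draft step receives the last prompt token followed by the confirmed token, which is the shape the reference expects.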

Minor Issues

3. import math / import random inside hot loop

Lines 1890-1891: These import statements execute on every non-greedy _mtp_step call. Move them to file-level imports.
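
A sketch of the suggested layout, with the imports at module level and the accept test in its own function. It assumes the usual speculative-sampling rule (accept with probability min(1, p_target / p_draft), computed in log space); `accept_draft` and its signature are hypothetical and not taken from scheduler.py.

```python
import math
import random


def accept_draft(logp_target, logp_draft, rng=random):
    """Return True if the draft token passes the speculative accept test."""
    log_accept = logp_target - logp_draft
    # Clamp at 0 so the acceptance probability never exceeds 1.
    p_accept = math.exp(min(log_accept, 0.0))
    return rng.random() < p_accept


# The target model is at least as confident as the draft: always accepted.
always = accept_draft(-0.5, -2.0)

# The target assigns a quarter of the draft's probability: accepted about 25%
# of the time. A seeded generator keeps the run reproducible.
rng = random.Random(0)
trials = 10_000
rate = sum(accept_draft(math.log(0.25), 0.0, rng) for _ in range(trials)) / trials
```

Passing the random source as a parameter also makes the non-greedy path straightforward to unit test with a seeded generator.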

4. _specprefill_enabled field in Request (request.py:201)

This is a SpecPrefill field lumped into the MTP state block. While it may need to exist (engine_core.py references it), it's not MTP-related and shouldn't be in the MTP comment block.

5. No test coverage

Zero tests for _mtp_step, _mtp_init_request, _mtp_clear_request, _transition_to_mtp, or the MTP↔batched transition logic. At minimum, unit tests for:

  • _mtp_step idle/has_draft phases with greedy + non-greedy sampling
  • Draft accept/reject paths
  • MTP→batched and batched→MTP transitions
  • _mtp_prev_tokens tracking correctness
  • Edge cases: stop token on confirmed/draft/bonus token, max_tokens hit
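
As a starting point for the edge-case tests, the stop-token and max_tokens cases can be written as a table of inputs and expected outputs. The sketch below is hypothetical throughout: `truncate_step_output` is an illustrative stand-in for whatever _mtp_step does with the tokens produced in one step, and it assumes a stop token ends the request without being emitted.

```python
def truncate_step_output(step_tokens, stop_ids, n_emitted, max_tokens):
    """Hypothetical helper: cut one step's tokens at a stop token or max_tokens.

    Returns (kept_tokens, finished). A stop token ends the request and is
    not emitted in this sketch.
    """
    kept = []
    for tok in step_tokens:
        if tok in stop_ids:
            return kept, True
        kept.append(tok)
        if n_emitted + len(kept) >= max_tokens:
            return kept, True
    return kept, False


STOP = {0}
cases = [
    # (step_tokens, n_emitted, max_tokens, expected)
    ([5, 6], 0, 10, ([5, 6], False)),   # plain accept: confirmed + bonus
    ([0, 6], 0, 10, ([], True)),        # stop token on the confirmed token
    ([5, 0], 0, 10, ([5], True)),       # stop token on the bonus token
    ([5, 6], 9, 10, ([5], True)),       # max_tokens hit mid-step
]
results = [truncate_step_output(t, STOP, n, m) for t, n, m, _ in cases]
```

Each row maps onto one of the edge cases listed above; the same table could drive a parametrized test against the real step function once it is callable in isolation.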

Verified OK

  • _transition_to_mtp: extract_cache exists on patched BatchGenerator (verified in venv) ✓
  • cache_to_use overwrites _mtp_model_cache correctly — external prefill processes tokens 0..N-2, idle phase processes N-1 ✓
  • _mtp_step early returns are correct — idle phase always returns, no fallthrough from idle to bottom code ✓
  • n_confirmed=1 and rollback logic match reference ✓
  • Accept probability calculation matches reference (with safe min(log_accept, 0)) ✓
  • VLM requests correctly excluded from MTP path ✓
  • Think prefix handling correct (no double-add) ✓
  • No security issues ✓
  • MTP→batched transition in _schedule_waiting: BatchGenerator stays valid after 1 insert ✓
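
The prompt split described in the cache_to_use item (external prefill over tokens 0..N-2, idle phase over token N-1) amounts to the following. `split_prompt_for_mtp` is a hypothetical helper used only to illustrate the index ranges.

```python
def split_prompt_for_mtp(prompt_tokens):
    """Split a prompt of N tokens into (tokens 0..N-2, token N-1)."""
    if not prompt_tokens:
        raise ValueError("prompt must contain at least one token")
    return list(prompt_tokens[:-1]), prompt_tokens[-1]


prefill_part, idle_tok = split_prompt_for_mtp([11, 12, 13, 14])
```

For a one-token prompt the prefill part is empty and the idle phase processes the only token.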

Verdict: CHANGES_REQUESTED

The _mtp_prev_tokens duplication bug is a correctness issue that will produce wrong output when logits processors are active. Must be fixed before merge.


Reference
sleepy/omlx!11