[scheduler] Hybrid MTP switching for concurrent streams (#9) #11
Summary
Test results
Files changed
- `omlx/scheduler.py`: +591 lines (`_mtp_step`, transitions, init/cleanup)
- `omlx/engine/batched.py`: -48 lines (removed old MTP fast path)
- `omlx/request.py`: +14 lines (MTP state fields)

PR Review: CHANGES_REQUESTED
Critical Bugs
1. `_mtp_prev_tokens` duplication in draft logits processing

Compared against mlx-lm's `mtp_generate_step` reference, `_mtp_prev_tokens` tracking has a systematic duplication bug affecting ALL three draft generation sites:

- Idle phase draft step (scheduler.py:1828-1832): After the idle phase sets `_mtp_prev_tokens = [confirmed_id]` (line 1754), the draft step concatenates `confirmed_tok` again. The reference expects `[last_prompt_tok, confirmed_id]`.
- Has_draft accept path (scheduler.py:1965-1969): After the loop adds `bonus_tok = toks[1]` to `_mtp_prev_tokens` (lines 1877-1881), the draft step concatenates `bonus_tok` again.
- Has_draft reject path (scheduler.py:2051-2055): After trimming (which removes `toks[1]`), `_mtp_prev_tokens` ends with `toks[0] = verify_pred`. Then `verify_pred` is concatenated again.

Impact: Repetition penalty is applied incorrectly (double-counted on the last token), causing degraded output quality when `repetition_penalty != 1.0`.

Root cause: `_mtp_prev_tokens` is already updated with the latest token before the draft step runs, but the draft step concatenates it again. The reference `mtp_generate_step` avoids this because `_step_mtp` receives `prev_tokens` as an argument that doesn't include the current `main_tok`; it adds `main_tok` itself.

Fix: The draft step should use `_mtp_prev_tokens` directly (it already contains the history up to the current point), not concatenate the confirmed/bonus/verify token again.

2. `_mtp_prev_tokens` missing prompt tokens

The reference tracks `prev_tokens` starting from the last prompt token (`y[0:1]`). The PR starts with `_mtp_prev_tokens = None` and only sets it to `[confirmed_id]` after the first backbone step. This means:

- an empty `mx.array([], mx.uint32)` is used instead of `[last_prompt_token]`
- `last_prompt_token` is missing from the history

As a result, the repetition penalty does not penalize prompt tokens for the first generated token.
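The history handling behind bugs 1 and 2 can be illustrated with plain Python lists standing in for `mx.array`. This is only a sketch of the idle-phase site as the review describes it: the function names are hypothetical, the token values are arbitrary, and none of this is the actual scheduler or mlx-lm code.

```python
from typing import List, Optional


def reference_history(last_prompt_tok: int, confirmed_id: int) -> List[int]:
    """History the review says the reference expects at the first draft step."""
    prev_tokens = [last_prompt_tok]        # seeded from the last prompt token
    return prev_tokens + [confirmed_id]    # main token appended exactly once


def pr_history(confirmed_id: int) -> List[int]:
    """History as the review describes the PR's idle-phase draft step."""
    prev_tokens: Optional[List[int]] = None  # starts unset, prompt token never recorded
    prev_tokens = [confirmed_id]             # set after the first backbone step
    return prev_tokens + [confirmed_id]      # draft step appends the same token again


def fixed_history(last_prompt_tok: int, confirmed_id: int) -> List[int]:
    """One possible fix: seed with the prompt token, update once, read as-is."""
    prev_tokens = [last_prompt_tok]
    prev_tokens = prev_tokens + [confirmed_id]  # single update after the backbone step
    return prev_tokens                          # draft step uses the history directly


if __name__ == "__main__":
    print(reference_history(7, 42))
    print(pr_history(42))
    print(fixed_history(7, 42))
```

With a last prompt token of 7 and a confirmed token of 42, the PR-style history holds 42 twice and never sees 7, while the fixed variant matches the reference.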
Minor Issues
3. `import math` / `import random` inside hot loop

Lines 1890-1891: These are imported on every non-greedy `_mtp_step` call. Move them to file-level imports.

4. `_specprefill_enabled` field in Request (request.py:201)

This is a SpecPrefill field lumped into the MTP state block. While it may need to exist (engine_core.py references it), it's not MTP-related and shouldn't be in the MTP comment block.

5. No test coverage

Zero tests for `_mtp_step`, `_mtp_init_request`, `_mtp_clear_request`, `_transition_to_mtp`, or the MTP↔batched transition logic. At minimum, add unit tests for:

- `_mtp_step` idle/has_draft phases with greedy + non-greedy sampling
- `_mtp_prev_tokens` tracking correctness

Verified OK
- `_transition_to_mtp`: `extract_cache` exists on patched BatchGenerator (verified in venv) ✓
- `cache_to_use` overwrites `_mtp_model_cache` correctly: external prefill processes tokens 0..N-2, idle phase processes N-1 ✓
- `_mtp_step` early returns are correct: idle phase always returns, no fallthrough from idle to bottom code ✓
- `n_confirmed=1` and rollback logic match reference ✓
- (`min(log_accept, 0)`) ✓
- `_schedule_waiting`: BatchGenerator stays valid after 1 insert ✓

Verdict: CHANGES_REQUESTED
The `_mtp_prev_tokens` duplication bug is a correctness issue that will produce wrong output when logits processors are active. It must be fixed before merge.
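One way to pin the fix down before merge is a small regression test over the three draft sites. The sketch below is an assumption-heavy illustration: `PrevTokenHistory` is a hypothetical stand-in for the per-request `_mtp_prev_tokens` state, the verify step is assumed to record both `toks[0]` and `toks[1]` before a rejection trims the last one, and the real tests would need to drive `_mtp_step` itself.

```python
from typing import List


class PrevTokenHistory:
    """Hypothetical stand-in for the per-request `_mtp_prev_tokens` state."""

    def __init__(self, last_prompt_tok: int) -> None:
        # Seed with the last prompt token, as the reference does with y[0:1].
        self.tokens: List[int] = [last_prompt_tok]

    def confirm(self, tok: int) -> None:
        # Record a confirmed token; called exactly once per token.
        self.tokens.append(tok)

    def trim(self, n: int) -> None:
        # Drop the last n tokens after a rejected draft.
        if n > 0:
            del self.tokens[-n:]

    def for_draft(self) -> List[int]:
        # The draft step reads the history as-is and appends nothing.
        return list(self.tokens)


def test_idle_phase_history() -> None:
    h = PrevTokenHistory(last_prompt_tok=7)
    h.confirm(42)                        # first backbone step
    assert h.for_draft() == [7, 42]      # [last_prompt_tok, confirmed_id]


def test_accept_path_history() -> None:
    h = PrevTokenHistory(last_prompt_tok=7)
    h.confirm(42)
    for tok in (50, 51):                 # toks[0] = verify_pred, toks[1] = bonus_tok
        h.confirm(tok)
    assert h.for_draft() == [7, 42, 50, 51]


def test_reject_path_history() -> None:
    h = PrevTokenHistory(last_prompt_tok=7)
    h.confirm(42)
    for tok in (50, 51):
        h.confirm(tok)
    h.trim(1)                            # rejected draft: drop toks[1]
    assert h.for_draft() == [7, 42, 50]  # ends with verify_pred exactly once


if __name__ == "__main__":
    test_idle_phase_history()
    test_accept_path_history()
    test_reject_path_history()
```

The functions follow the pytest naming convention, so they can be collected as-is or run directly as a script.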