feat: Multi-token prediction (MTP) speculative decoding #48

Open
opened 2026-05-15 19:09:57 +02:00 by sleepy · 0 comments
Owner

Overview

Qwen3.5-4B has native MTP support (mtp_num_hidden_layers: 1 in original config). Implement speculative decoding using the MTP head to predict draft tokens, then verify against the main model.

BLOCKED until coherent 37+ tok/s baseline is reached.

Current State

  • Config: original had mtp_num_hidden_layers: 1, mtp_use_dedicated_embeddings: false. Currently disabled (set to 0) in active config.
  • Weights missing: The MLX-bf16 conversion stripped the MTP tensors. Need to re-download the original model or convert with MTP weights preserved.
  • Codebase has scaffold: config.zig parses mtp_num_hidden_layers, mtp.zig has MTPHead struct (CPU stub), mtp_engine.zig returns error.Unimplemented, mtp.metal has placeholder kernel, model.zig conditionally initializes but never loads weights.

Implementation Plan

Phase 1: Weight loading

  1. Obtain model with MTP weights (re-download original Qwen3.5-4B or re-convert with MTP tensors preserved)
  2. Add weight map entries for MTP tensors to weight_map.zig
  3. Wire up weight loading in weight_loader.zig
  4. Restore mtp_num_hidden_layers: 1 in config
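Before wiring up the loader, it may help to confirm that a given checkpoint actually contains MTP tensors. The sketch below reads only the JSON header of a safetensors file and lists tensor names containing a marker substring. The function name `list_mtp_tensors` and the default marker `"mtp"` are assumptions for illustration; the real tensor names in the Qwen3.5-4B checkpoint should be checked against the file itself.

```python
import json
import struct


def list_mtp_tensors(path, marker="mtp"):
    """Return sorted tensor names in a .safetensors file that contain `marker`.

    `marker` is a guess at how MTP tensors are named; adjust it after
    inspecting the actual checkpoint.
    """
    with open(path, "rb") as f:
        # A safetensors file starts with an 8-byte little-endian header length,
        # followed by that many bytes of JSON describing every tensor.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    # "__metadata__" is an optional free-form entry, not a tensor.
    return sorted(
        name for name in header if name != "__metadata__" and marker in name
    )
```

An empty result for the MLX-bf16 conversion and a non-empty one for the original download would confirm the "weights missing" diagnosis above.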

Phase 2: GPU MTP forward pass

  1. Port MTPHead.forward() to GPU BF16 (reuse existing matmul + RMS norm kernels)
  2. The MTP head is lightweight: a hidden-to-hidden projection followed by the output embedding (shared with the main model, or separately trained) to produce logits
  3. Expected cost: ~2ms per draft token (one matmul + norm vs full 32-layer decode at 42ms)
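As a CPU reference for checking the GPU port, the forward pass described in this phase can be sketched in plain Python. This follows only the description in the list above (RMS norm, one hidden-to-hidden projection, logits from the shared embedding); the order of operations, the function names, and the `eps` value are assumptions, and the real Qwen3.5 MTP layer may contain more than this.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """RMS-normalize vector x and scale element-wise by weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]


def matvec(matrix, x):
    """Multiply a row-major matrix (list of rows) by vector x."""
    return [sum(m * v for m, v in zip(row, x)) for row in matrix]


def mtp_head_forward(hidden, norm_weight, proj, embed):
    """Hypothetical MTP head: norm -> hidden-to-hidden projection -> logits.

    `embed` is the [vocab][hidden] embedding table, assumed shared with the
    main model (matching mtp_use_dedicated_embeddings: false).
    """
    h = matvec(proj, rms_norm(hidden, norm_weight))
    return matvec(embed, h)  # one logit per vocabulary entry


def argmax(values):
    """Index of the largest value (first one wins on ties)."""
    return max(range(len(values)), key=values.__getitem__)
```

Comparing `argmax(mtp_head_forward(...))` against the GPU result on a few fixed hidden states is one way to check the "produces valid logits" criterion.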

Phase 3: Speculative decoding engine

  1. Implement MTPEngine.draft_tokens(): run MTP head on last hidden state to produce N draft tokens
  2. Implement MTPEngine.verify_tokens(): run main model forward on draft tokens, compare logits
  3. Accept matching draft tokens, reject everything from the first mismatch onward, and re-encode from the last accepted token
  4. Integrate into engine.zig decode loop
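The draft/verify/accept loop above can be sketched with toy callables standing in for the models. Here `main_next` and `draft_next` are hypothetical functions that map a token list to the next token (greedy argmax); a real engine would verify all positions in one batched forward pass rather than calling the main model sequentially as this sketch does.

```python
def greedy_decode(main_next, prompt, n_new):
    """Baseline: one main-model call per generated token."""
    tokens = list(prompt)
    for _ in range(n_new):
        tokens.append(main_next(tokens))
    return tokens[len(prompt):]


def speculative_decode(main_next, draft_next, prompt, n_new, n_draft=1):
    """Greedy speculative decoding; output must match greedy_decode."""
    tokens = list(prompt)
    target_len = len(prompt) + n_new
    while len(tokens) < target_len:
        # Draft phase: the cheap head proposes n_draft tokens autoregressively.
        drafts = []
        for _ in range(n_draft):
            drafts.append(draft_next(tokens + drafts))
        # Verify phase: take the main model's choice at each position.
        accepted = []
        for i in range(n_draft + 1):
            target = main_next(tokens + accepted)
            accepted.append(target)
            # Stop after the bonus token, or at the first rejected draft
            # (the main model's token replaces the rejected one).
            if i == n_draft or drafts[i] != target:
                break
        tokens.extend(accepted)
    # The last step may overshoot; trim to the requested length.
    return tokens[len(prompt):target_len]
```

Because every appended token is the main model's own choice for its prefix, the output is identical to greedy decoding regardless of draft quality; only the number of steps changes.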

Phase 4: GPU pipelining

  1. Overlap MTP draft generation with main model verification (double-buffered command buffers)
  2. Batch verification: run main model on draft+1 tokens in parallel

Expected Speedup

With 1 MTP head producing 1 draft token:

  • Draft cost: ~2ms (single matmul)
  • Verify cost: ~42ms (full decode, same as baseline)
  • If the draft is accepted: 2 tokens / 44ms = ~45 tok/s (best case). At ~70% acceptance (plausible for a well-trained MTP head), the expected yield is 1.7 tokens per step, or ~39 tok/s
  • With 5 MTP heads: potentially 3-4 tokens accepted per step = 60-80 tok/s
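The arithmetic behind these estimates can be reproduced with a small model. This is a rough sketch under two assumptions not stated in the issue: each draft is accepted independently with the same probability, and verification cost stays at the single-decode figure regardless of how many drafts are checked. The function names are illustrative.

```python
def expected_tokens_per_step(accept_prob, n_draft):
    """Expected tokens emitted per verify step.

    The main model always contributes one token; draft i only counts if
    drafts 1..i were all accepted, which happens with probability p**i.
    """
    return sum(accept_prob ** i for i in range(n_draft + 1))


def tokens_per_second(accept_prob, n_draft, verify_ms=42.0, draft_ms=2.0):
    """Throughput estimate using the ~42ms verify and ~2ms draft costs above."""
    step_ms = verify_ms + n_draft * draft_ms
    return expected_tokens_per_step(accept_prob, n_draft) / step_ms * 1000.0
```

For example, `tokens_per_second(1.0, 1)` gives the best-case single-draft figure, and `tokens_per_second(0.7, 1)` the expected rate at 70% acceptance. Passing a lower `verify_ms` shows how the estimate shifts once the baseline improves.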

Acceptance Criteria

  • MTP weights load correctly from safetensors
  • GPU MTP forward pass produces valid logits
  • Speculative decoding produces identical output to greedy decoding
  • Effective throughput is at least 1.5x baseline (37 tok/s to 55+ tok/s)

Dependencies

  • Issue 34 (37+ tok/s baseline) must be met first
  • Issue 41 (GPU argmax) recommended for efficient draft verification
  • Issue 47 (pipeline decode) recommended for overlapping draft and verify
Reference
sleepy/sleepy-llm#48