sleepy 45c3aad453 feat: expand to 6 models, 8 challenges; rewrite README with DeepSeek V4 Pro analysis
- Add Claude Opus 4.7, Kimi K2.6, GLM-5.1 to existing GLM-5, Qwen3-6, MiniMax-M2.7
- Add 5 new challenges: flash attention fwd/bwd, beam search, DFlash, ternary training
- Rewrite README with TL;DR rankings, grade matrix, and DeepSeek V4 Pro attribution
- Add analysis/ folder with cross-model comparisons and per-challenge deep dives
- Add deploy_challenges.sh script
- Expand .gitignore to exclude Python envs, ML weights, and build artifacts
2026-04-27 18:49:22 +02:00


Implement the TREE ATTENTION VERIFICATION and ACCEPTANCE/REJECTION
algorithm for DFlash-style speculative decoding, in pure NumPy.
BACKGROUND:
Speculative decoding uses a fast draft model to propose candidate tokens,
then the target model verifies them in parallel. Standard speculative
decoding uses a linear chain of candidates. DFlash uses a TREE of
candidates — each candidate token can have multiple children, forming
a tree of possible futures. The target model verifies all tree nodes
in one forward pass using a tree-structured attention mask.
SETUP:
You are given:
- A minimal target model (you write it: 1 transformer layer, ~64 dim,
1000 vocab, random weights). It's small but structurally correct.
- A draft model mock that produces a FIXED tree of tokens per step.
You don't need to implement a real draft model — just pass in the
tree tokens and structure as test input.
REQUIREMENTS:
1. TREE DATA STRUCTURE:
A tree step is defined by:
- tree_tokens: list[int] of length N — token IDs at each tree node
- tree_parents: list[int] of length N — parent index for each node
(-1 for root nodes, which are children of the last prompt token)
- tree_children: list[list[int]] — child indices for each node
Nodes are indexed 0..N-1 in topological order (a parent always
appears before its children). Root nodes are at depth 1 (their
logical "parent" is the last prompt token).
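`tree_children` and node depths are fully determined by `tree_parents`, so an implementation can derive them instead of trusting a third input. The following is a minimal sketch that assumes the topological-order convention above. The helper name `children_and_depths` is hypothetical.

```python
def children_and_depths(tree_parents):
    """Derive tree_children and node depths from parent pointers.

    Assumes nodes are in topological order (parent index < child index);
    roots (parent -1) are at depth 1, matching the spec.
    """
    n = len(tree_parents)
    children = [[] for _ in range(n)]
    depths = [0] * n
    for i, p in enumerate(tree_parents):
        if p == -1:
            depths[i] = 1
        else:
            if not 0 <= p < i:
                raise ValueError(f"node {i}: parent {p} breaks topological order")
            children[p].append(i)
            depths[i] = depths[p] + 1
    return children, depths

# Two roots; node 2 hangs under root 0, node 3 under node 2.
children, depths = children_and_depths([-1, -1, 0, 2])
print(children)  # [[2], [], [3], []]
print(depths)    # [1, 1, 2, 3]
```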
2. TREE ATTENTION MASK CONSTRUCTION:
Given P prompt tokens and N tree nodes, the full sequence for the
verification pass is [prompt_0, ..., prompt_{P-1}, tree_0, ..., tree_{N-1}].
Length = P + N.
Build a boolean attention mask M of shape (P+N, P+N) where M[i, j] = True
means position i CAN attend to position j:
RULES:
a) Prompt tokens attend causally to each other: for 0 <= i < P,
0 <= j < P: M[i, j] = (j <= i)
b) ALL tree nodes attend to ALL prompt tokens: for P <= i < P+N,
0 <= j < P: M[i, j] = True
c) Each tree node attends to ITSELF: M[i, i] = True for all i
d) A tree node attends to its ANCESTORS in the tree (transitively):
if node k is an ancestor of tree node i, then M[P + i, P + k] = True
(both indices are global positions; tree node k sits at P + k).
Find ancestors by following parent pointers up to a root.
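e) A tree node does NOT attend to siblings, cousins, or the
descendants of other branches, and prompt positions never attend to
tree positions: every entry not set by rules a) to d) is False.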
Masked-out positions get score = -inf before softmax.
The mask is converted to additive form: mask_add[i, j] = 0 if allowed,
-inf if disallowed.
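Rules a) to e) map directly onto a few NumPy operations. The sketch below shows one possible `build_tree_mask`. Its signature follows the deliverables, and `to_additive` is a hypothetical helper name. It assumes `tree_parents` is topologically ordered, with -1 for roots.

```python
import numpy as np

def build_tree_mask(prompt_len, tree_parents):
    """Boolean (P+N, P+N) mask; M[i, j] = True means position i may attend to j."""
    P, N = prompt_len, len(tree_parents)
    M = np.zeros((P + N, P + N), dtype=bool)
    # Rule a: causal attention inside the prompt (lower triangle incl. diagonal)
    M[:P, :P] = np.tril(np.ones((P, P), dtype=bool))
    # Rule b: every tree node sees the whole prompt
    M[P:, :P] = True
    for i in range(N):
        # Rule c: each tree node attends to itself
        M[P + i, P + i] = True
        # Rule d: follow parent pointers up to the root
        k = tree_parents[i]
        while k != -1:
            M[P + i, P + k] = True
            k = tree_parents[k]
    # Rule e is implicit: every entry not set above stays False
    return M

def to_additive(M):
    """0.0 where attention is allowed, -inf where it is masked out."""
    return np.where(M, 0.0, -np.inf)

# P = 2 prompt tokens; two roots, and node 2 is a child of root 0
M = build_tree_mask(2, [-1, -1, 0])
print(M.astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [1 1 0 1 0]
#  [1 1 1 0 1]]
```

A parent always precedes its children, so an equivalent O(N) variant copies the parent's row and then sets the diagonal. For a root, the row to copy is row P-1, which is exactly the full prompt. The pointer-chasing loop above simply stays closer to the wording of rule d).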
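3. VERIFICATION FORWARD PASS:
- Concatenate prompt embeddings + tree node embeddings into a single
tensor of shape (P+N, d_model)
- Position ids follow tree depth, not concatenation order: prompt token j
gets position j, and tree node i gets position P + depth(i) - 1, so
siblings share a position. If the model uses positional embeddings,
giving node i position P + i instead breaks the equivalence with
autoregressive decoding.
- Run ONE forward pass through the target model's transformer block
with the tree attention mask applied
- The model returns logits for each position in the concatenated sequence
- We only need the logits at the last prompt position (index P-1, which
predicts the token in the root nodes' slot) and at the tree node
positions (indices P..P+N-1, where index P+k predicts the token that
follows node k)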
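The setup can be satisfied with a single pre-LN block that has one attention head, a ReLU MLP, and learned absolute position embeddings, all with random weights. These architectural choices, the class layout, the `positions` argument, and the helper `tree_positions` are assumptions, not requirements of the task. The demo checks the property the verification pass relies on. With the tree mask and depth-based position ids, the logits at a tree node should equal the logits of a plain causal pass over the prompt plus that node's root-to-node path.

```python
import numpy as np

def log_softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def layer_norm(x, eps=1e-5):
    # No learned gain/bias in this sketch
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

class MinimalLM:
    """One pre-LN transformer block, single attention head, random weights."""

    def __init__(self, vocab=1000, d_model=64, max_pos=512, seed=0):
        rng = np.random.default_rng(seed)
        def w(fan_in, fan_out):
            return rng.normal(0.0, fan_in ** -0.5, (fan_in, fan_out))
        self.tok_emb = rng.normal(0.0, 1.0, (vocab, d_model))
        self.pos_emb = rng.normal(0.0, 1.0, (max_pos, d_model))  # assumed learned absolute positions
        self.Wq, self.Wk, self.Wv, self.Wo = (w(d_model, d_model) for _ in range(4))
        self.W1, self.W2 = w(d_model, 4 * d_model), w(4 * d_model, d_model)
        self.W_out = w(d_model, vocab)

    def __call__(self, tokens, mask, positions=None):
        """tokens: (L,) ids; mask: (L, L) bool, True = may attend; returns (L, vocab)."""
        tokens = np.asarray(tokens)
        if positions is None:
            positions = np.arange(len(tokens))
        x = self.tok_emb[tokens] + self.pos_emb[positions]
        h = layer_norm(x)
        q, k, v = h @ self.Wq, h @ self.Wk, h @ self.Wv
        scores = q @ k.T / np.sqrt(q.shape[-1])
        scores = scores + np.where(mask, 0.0, -np.inf)  # additive mask
        x = x + np.exp(log_softmax(scores)) @ v @ self.Wo
        x = x + np.maximum(layer_norm(x) @ self.W1, 0.0) @ self.W2  # ReLU MLP
        return layer_norm(x) @ self.W_out

def tree_positions(P, tree_parents):
    """Position ids: 0..P-1 for the prompt, P + depth - 1 for each tree node."""
    depth = []
    for p in tree_parents:
        depth.append(1 if p == -1 else depth[p] + 1)
    return np.concatenate([np.arange(P), P - 1 + np.array(depth, dtype=int)])

model = MinimalLM()
prompt, tree_tokens, tree_parents = [10, 20, 30], [5, 6, 7], [-1, -1, 0]
P, L = len(prompt), len(prompt) + len(tree_tokens)
mask = np.tril(np.ones((L, L), dtype=bool))
mask[P + 1, P + 0] = False  # node 1 must not see its sibling, node 0
mask[P + 2, P + 1] = False  # node 2 must not see node 1 (not its ancestor)
logits = model(prompt + tree_tokens, mask, tree_positions(P, tree_parents))
# Node 2 (path 5 -> 7) must match a plain causal pass over prompt + [5, 7]
plain = model(prompt + [5, 7], np.tril(np.ones((5, 5), dtype=bool)))
print(np.allclose(logits[P + 2], plain[-1]))  # True
```

If the tree nodes were given positions P+i instead, node 1 in the demo would sit at position 4 rather than 3. Its logits would then generally no longer match the plain pass.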
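4. ACCEPTANCE/REJECTION SAMPLING:
For each tree node i in topological order (0..N-1):
a) If node i's parent is not ACCEPTED (it was REJECTED, or skipped
because one of ITS ancestors was), or a sibling of node i was already
ACCEPTED:
→ SKIP this node and mark it as REJECTED (subtree invalidation)
→ Continue to next node
Root nodes have no parent node; they are only skipped once another
root has been accepted.
b) Get the target model's logits at the position of node i's PARENT:
P-1 (the last prompt token) for a root node, P + tree_parents[i]
otherwise. The logits at P+i predict the token AFTER node i, so
comparing them with tree_tokens[i] is an off-by-one bug.
Convert to log-probabilities via log_softmax
The target's greedy prediction = argmax(log_probs)
c) The draft model proposed token = tree_tokens[i]
d) ACCEPTANCE CHECK (greedy mode, temperature=0):
If tree_tokens[i] == target_greedy_prediction:
→ ACCEPT. Keep tree_tokens[i]. Its children become eligible.
Else:
→ REJECT node i.
→ INVALIDATE its entire subtree (all descendants of node i
will be skipped in subsequent steps due to rule 4a)
→ Its remaining siblings are still checked, because they are
alternative candidates for the same slot
e) The ACCEPTED nodes form a single root-to-node path; their tokens, in
order, are the accepted tokens of this verification step. Then append
one more token: the target's greedy prediction at the position of the
deepest accepted node (P-1 if nothing was accepted). It is the
correction when every draft candidate for the next slot was rejected,
or a free bonus token when the path ends at a leaf. Either way, every
verification step yields at least one new token.
CRITICAL: The subtree invalidation at step (a) is the most common bug.
Rejecting node i means ALL its descendants are invalid, even if they
would have matched the target's predictions. They were generated
conditioned on node i being correct, which turned out false.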
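The following is a minimal greedy sketch of the acceptance/rejection step. It assumes the caller passes `verify_logits = logits[P-1:]`. Row 0 then holds the target's output at the last prompt position, and row 1+k holds its output at tree node k. Each node is checked against the prediction at its PARENT's position, because the output at a position predicts the token that follows it. Only temperature=0 is handled. In the demo tree, root 0 is wrong, and its child (node 2) would match on its own but must still be skipped.

```python
import numpy as np

def accept_reject(tree_tokens, tree_parents, verify_logits, temperature=0.0):
    """Greedy tree verification. Returns (accepted_tokens, new_token).

    verify_logits has N+1 rows: row 0 is the target output at position P-1,
    row 1+k is its output at tree node k (position P+k).
    """
    if temperature != 0:
        raise NotImplementedError("only greedy verification is sketched here")
    N = len(tree_tokens)
    # argmax of log_softmax equals argmax of the raw logits (per-row shift)
    greedy = np.argmax(verify_logits, axis=-1)
    status = ["pending"] * N   # becomes "accepted", "rejected" or "skipped"
    accepted_child = {}        # parent index (-1 = last prompt token) -> accepted child
    for i in range(N):         # topological order
        p = tree_parents[i]
        # Rule (a): parent not accepted, or a sibling already accepted -> skip
        if (p != -1 and status[p] != "accepted") or p in accepted_child:
            status[i] = "skipped"
            continue
        # Rule (b): the PARENT's row (p + 1) predicts node i's slot
        if int(tree_tokens[i]) == int(greedy[p + 1]):
            status[i] = "accepted"
            accepted_child[p] = i
        else:
            status[i] = "rejected"
    # Accepted nodes form one root-to-node path; walk it from the top
    path, node = [], accepted_child.get(-1)
    while node is not None:
        path.append(node)
        node = accepted_child.get(node)
    accepted = [int(tree_tokens[k]) for k in path]
    # new_token: target prediction after the deepest accepted node (or the prompt)
    last_row = path[-1] + 1 if path else 0
    return accepted, int(greedy[last_row])

# Roots: node 0 (token 3) and node 1 (token 4); node 2 (token 7) under node 0,
# node 3 (token 8) under node 1, node 4 (token 9) under node 3.
preds = [4, 7, 8, 1, 2, 5]  # greedy prediction per row of verify_logits
verify_logits = np.zeros((6, 10))
verify_logits[np.arange(6), preds] = 1.0
print(accept_reject([3, 4, 7, 8, 9], [-1, -1, 0, 1, 3], verify_logits))
# ([4, 8], 2)
```

Node 2's token 7 equals the prediction at its parent's row. It is skipped anyway, because node 0 was rejected.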
5. FULL GENERATION LOOP:
```
generated_tokens = list(prompt)
while len(generated_tokens) < max_tokens:
    P = len(generated_tokens)
    # Draft model produces a tree (mocked: you pass it in)
    tree_tokens, tree_parents = draft_model(generated_tokens)
    # Build the (P+N, P+N) tree attention mask
    mask = build_tree_mask(P, tree_parents)
    # ONE target forward pass on [generated_tokens | tree_tokens]
    # (tree node i uses position id P + depth(i) - 1)
    logits = target_model(generated_tokens + tree_tokens, mask)
    # Keep the last prompt position (it verifies the roots) and all tree positions
    verify_logits = logits[P - 1:]
    # Acceptance/rejection returns the accepted draft tokens plus the
    # target's own next token (correction or bonus token)
    accepted, new_token = accept_reject(tree_tokens, tree_parents,
                                        verify_logits, temperature=0)
    # At least one token is appended per cycle, so no fallback pass is needed
    generated_tokens.extend(accepted)
    generated_tokens.append(new_token)
# The last cycle may overshoot; trim before comparing with autoregressive output
generated_tokens = generated_tokens[:max_tokens]
```
6. DELIVERABLES:
- Function build_tree_mask(prompt_len, tree_parents) → mask array (P+N, P+N)
- Function verify_and_accept(prompt_tokens, tree_tokens, tree_parents,
target_model, temperature) → (accepted_tokens, new_token)
- A MinimalLM class (or equivalent) for the target model
- Test 1 (BASIC): prompt=[10, 20, 30], tree with 3 root nodes (no depth-2),
temperature=0. Compare generated sequence against autoregressive
greedy decoding. Must match EXACTLY.
- Test 2 (SUBTREE INVALIDATION): Construct a tree where a depth-1
node is REJECTED but its depth-2 children WOULD have been accepted
(if processed independently). Verify the depth-2 children are
correctly SKIPPED and the output matches autoregressive.
- Test 3 (MULTI-STEP): Run 3 consecutive verification cycles where the
prompt for cycle N+1 is the prompt of cycle N extended by cycle N's
accepted tokens and new token. Verify the full generated sequence
matches autoregressive greedy decoding of the same length.
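Tests 2 and 3 can be exercised end to end without the transformer by swapping in a toy target that still respects the attention mask: position i predicts the sum of the token ids it may attend to, modulo the vocabulary. `toy_target`, `mock_draft`, `verify_greedy`, and `speculative_decode` are hypothetical names used only in this sketch. The toy ignores position ids, so it does not exercise the depth-based positions a real model needs. The acceptance step is written as a walk from the roots. Under greedy decoding, this walk selects the same path as checking nodes in topological order with subtree invalidation.

```python
import numpy as np

VOCAB = 1000

def toy_target(tokens, mask):
    """Hypothetical stand-in for the target model that respects the mask:
    position i predicts (sum of the token ids it may attend to) % VOCAB."""
    tokens = np.asarray(tokens, dtype=np.int64)
    nxt = (mask.astype(np.int64) @ tokens) % VOCAB
    logits = np.zeros((len(tokens), VOCAB))
    logits[np.arange(len(tokens)), nxt] = 1.0
    return logits

def build_tree_mask(P, parents):
    M = np.zeros((P + len(parents),) * 2, dtype=bool)
    M[:P, :P] = np.tril(np.ones((P, P), dtype=bool))
    for i, p in enumerate(parents):
        # Copy the parent's row (P + p == P - 1 for a root), then attend to self
        M[P + i] = M[P + p]
        M[P + i, P + i] = True
    return M

def verify_greedy(tokens, tree_tokens, parents):
    """Returns (accepted draft tokens, target's next token)."""
    P = len(tokens)
    greedy = toy_target(tokens + tree_tokens, build_tree_mask(P, parents)).argmax(-1)
    accepted, cur = [], -1  # cur == -1 stands for the last prompt token
    while True:
        pred = int(greedy[P + cur])  # output at cur predicts its children's slot
        match = [i for i, p in enumerate(parents) if p == cur and tree_tokens[i] == pred]
        if not match:  # every candidate rejected, or cur is a leaf
            return accepted, pred
        cur = match[0]  # descend; subtrees of rejected siblings are never visited
        accepted.append(tree_tokens[cur])

def greedy_decode(prompt, max_tokens):
    """Autoregressive baseline: one causal pass per generated token."""
    tokens = list(prompt)
    while len(tokens) < max_tokens:
        L = len(tokens)
        tokens.append(int(toy_target(tokens, np.tril(np.ones((L, L), dtype=bool)))[-1].argmax()))
    return tokens

def mock_draft(tokens):
    """Hypothetical draft: a wrong root, the right root, then a right and a wrong child."""
    s = sum(tokens) % VOCAB
    return [(s + 1) % VOCAB, s, (2 * s) % VOCAB, (4 * s + 1) % VOCAB], [-1, -1, 1, 2]

def speculative_decode(prompt, draft, cycles):
    tokens = list(prompt)
    for _ in range(cycles):
        tree_tokens, parents = draft(tokens)
        accepted, new_token = verify_greedy(tokens, tree_tokens, parents)
        tokens += accepted + [new_token]  # becomes the prompt of the next cycle
    return tokens

# Test 2: root 61 is wrong; its child 121 (= 60 + 61) would match on its own
print(verify_greedy([10, 20, 30], [61, 121, 60, 120, 999], [-1, 0, -1, 2, 3]))
# ([60, 120], 240)

# Test 3: three cycles, compared with autoregressive greedy decoding
out = speculative_decode([10, 20, 30], mock_draft, cycles=3)
print(out == greedy_decode([10, 20, 30], len(out)))  # True
```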
THE GOLDEN TEST: for temperature=0, speculative decoding MUST produce
EXACTLY the same output sequence as autoregressive greedy decoding of
the same target model. Any deviation is a bug in the implementation.
Use only NumPy. No PyTorch, JAX, TensorFlow, or autograd.