Files
deep_pro_judge/glm5/dflash_verify/dflash.py
T
sleepy 45c3aad453 feat: expand to 6 models, 8 challenges; rewrite README with DeepSeek V4 Pro analysis
- Add Claude Opus 4.7, Kimi K2.6, GLM-5.1 to existing GLM-5, Qwen3-6, MiniMax-M2.7
- Add 5 new challenges: flash attention fwd/bwd, beam search, DFlash, ternary training
- Rewrite README with TL;DR rankings, grade matrix, and DeepSeek V4 Pro attribution
- Add analysis/ folder with cross-model comparisons and per-challenge deep dives
- Add deploy_challenges.sh script
- Expand .gitignore to exclude Python envs, ML weights, and build artifacts
2026-04-27 18:49:22 +02:00

545 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
DFlash-style Tree Attention Verification for Speculative Decoding.
Pure NumPy implementation.
Convention: logits[i] predicts the next token after position i.
To verify tree_tokens[i], we check the target's prediction at the
parent's position (or P-1 for root nodes).
"""
import numpy as np
# ── Utility functions ──────────────────────────────────────────────
def softmax(x, axis=-1):
    """Numerically stable softmax along *axis* (shift by the max first)."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)
def log_softmax(x, axis=-1):
    """Log-softmax along *axis* via the log-sum-exp trick for stability."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    lse = np.log(np.exp(shifted).sum(axis=axis, keepdims=True))
    return shifted - lse
def gelu(x):
    """Tanh approximation of the GELU activation."""
    inner = np.sqrt(2.0 / np.pi) * (x + 0.044715 * np.power(x, 3))
    return 0.5 * x * (1.0 + np.tanh(inner))
def sinusoidal_pe(max_len, d):
    """Sinusoidal positional encodings ("Attention Is All You Need" style).

    Args:
        max_len: number of positions to encode.
        d: embedding dimension. Odd values are supported: the extra
           frequency only fills a sin column.
    Returns:
        (max_len, d) float array; even columns are sin, odd columns cos.
    """
    pe = np.zeros((max_len, d))
    pos = np.arange(max_len)[:, None]
    div = np.exp(np.arange(0, d, 2) * -(np.log(10000.0) / d))
    pe[:, 0::2] = np.sin(pos * div)
    # BUGFIX: for odd d there are ceil(d/2) frequencies but only d//2 cos
    # columns; trimming avoids a shape-mismatch crash (even d unchanged).
    pe[:, 1::2] = np.cos(pos * div[: d // 2])
    return pe
# ── Model components ───────────────────────────────────────────────
class LayerNorm:
    """Layer normalization over the last axis with learnable gain and bias."""
    def __init__(self, d, eps=1e-5):
        self.g = np.ones(d)    # gain, initialized to identity
        self.b = np.zeros(d)   # bias, initialized to zero
        self.eps = eps         # numerical floor inside the sqrt
    def __call__(self, x):
        centered = x - x.mean(-1, keepdims=True)
        scale = np.sqrt(x.var(-1, keepdims=True) + self.eps)
        return self.g * (centered / scale) + self.b
class Linear:
    """Dense layer: He-style scaled Gaussian weights, zero bias."""
    def __init__(self, d_in, d_out, rng):
        self.w = rng.randn(d_in, d_out) * np.sqrt(2.0 / d_in)
        self.b = np.zeros(d_out)
    def __call__(self, x):
        return np.dot(x, self.w) + self.b
class TransformerBlock:
    """Post-norm transformer block: multi-head self-attention + GELU MLP."""
    def __init__(self, d, nh, d_ff, rng):
        self.nh = nh           # number of attention heads
        self.dh = d // nh      # per-head dimension
        # NOTE: creation order fixes the RNG draw order; do not reorder.
        self.wq = Linear(d, d, rng)
        self.wk = Linear(d, d, rng)
        self.wv = Linear(d, d, rng)
        self.wo = Linear(d, d, rng)
        self.ff1 = Linear(d, d_ff, rng)
        self.ff2 = Linear(d_ff, d, rng)
        self.ln1 = LayerNorm(d)
        self.ln2 = LayerNorm(d)
    def __call__(self, x, mask_add=None):
        """Apply the block to x of shape (S, d); mask_add is an additive
        (S, S) attention mask broadcast over heads, or None."""
        seq_len = x.shape[0]
        def split_heads(proj):
            # (S, d) -> (nh, S, dh)
            return proj(x).reshape(seq_len, self.nh, self.dh).transpose(1, 0, 2)
        q = split_heads(self.wq)
        k = split_heads(self.wk)
        v = split_heads(self.wv)
        scores = np.matmul(q, k.transpose(0, 2, 1)) / np.sqrt(self.dh)
        if mask_add is not None:
            scores = scores + mask_add[None]
        context = np.matmul(softmax(scores, -1), v)
        attn_out = self.wo(context.transpose(1, 0, 2).reshape(seq_len, -1))
        # Post-norm residual connections around attention and MLP.
        x = self.ln1(x + attn_out)
        return self.ln2(x + self.ff2(gelu(self.ff1(x))))
class MinimalLM:
    """Single-layer transformer language model in pure NumPy."""
    def __init__(self, vocab_size=1000, d=64, nh=4, d_ff=256, seed=42):
        rng = np.random.RandomState(seed)
        self.V = vocab_size
        # NOTE: RNG draw order (emb, block, head) is part of the model
        # identity for a given seed; do not reorder.
        self.emb = rng.randn(vocab_size, d) * 0.02
        self.pe = sinusoidal_pe(512, d)
        self.block = TransformerBlock(d, nh, d_ff, rng)
        self.ln_f = LayerNorm(d)
        self.head = Linear(d, vocab_size, rng)
    def forward(self, tokens, mask_add=None):
        """Return (len(tokens), V) logits; logits[i] predicts token i+1."""
        seq_len = len(tokens)
        hidden = self.block(self.emb[tokens] + self.pe[:seq_len], mask_add)
        return self.head(self.ln_f(hidden))
    def greedy_generate(self, prompt, n):
        """Autoregressive greedy decoding: append the argmax token n times."""
        seq = list(prompt)
        for _ in range(n):
            seq.append(int(np.argmax(self.forward(seq)[-1])))
        return seq
# ── Mask builders ──────────────────────────────────────────────────
def build_causal_mask(L):
    """Standard causal (lower-triangular) additive attention mask."""
    mask = np.full((L, L), -np.inf)
    mask[np.tril_indices(L)] = 0.0
    return mask
def build_tree_mask(P, tree_parents):
    """
    Build the additive tree-attention mask for DFlash verification.

    Args:
        P: number of prompt tokens.
        tree_parents: parent index per tree node (-1 for roots).
    Returns:
        (P+N, P+N) array with N = len(tree_parents); 0.0 where attention
        is allowed, -inf where it is blocked.
    Rules (from spec):
        a) Prompt tokens attend causally to each other.
        b) All tree nodes attend to ALL prompt tokens.
        c) Every position attends to itself.
        d) Each tree node attends to its ancestors in the tree.
        e) No attendance to siblings, cousins, or other branches.
    """
    N = len(tree_parents)
    total = P + N
    allowed = np.zeros((total, total), dtype=bool)
    # a) causal attention within the prompt
    rows, cols = np.tril_indices(P)
    allowed[rows, cols] = True
    # b) every tree node sees the whole prompt
    allowed[P:, :P] = True
    # c) self-attention on the full diagonal
    allowed[np.arange(total), np.arange(total)] = True
    # d) walk each node's parent chain to the root; anything not marked
    #    here stays blocked, which enforces rule e).
    for node in range(N):
        anc = tree_parents[node]
        while anc != -1:
            allowed[P + node, P + anc] = True
            anc = tree_parents[anc]
    return np.where(allowed, 0.0, -np.inf)
# ── Verification / acceptance ─────────────────────────────────────
def _ancestors(i, tree_parents):
out = []
c = tree_parents[i]
while c != -1:
out.append(c)
c = tree_parents[c]
return out
def verify_and_accept(prompt_tokens, tree_tokens, tree_parents, model,
                      temperature=0):
    """
    Run one tree-verification cycle.

    The target model scores the whole tree in a single forward pass (via
    the tree mask), then we walk the nodes in topological order and follow
    the single path whose tokens match the target's greedy predictions:

      * accept on-path  -> extend the path and keep going;
      * reject on-path  -> emit the target's own token and stop the cycle;
      * reject off-path -> mark rejected (its whole subtree is skipped);
      * accept off-path -> recorded but contributes nothing to the output;
      * if no on-path rejection occurs, emit one bonus token predicted
        from the last path position.

    Note: `temperature` is part of the signature but currently unused —
    only greedy (temperature=0) verification is implemented.

    Returns the list of tokens to append to the generated sequence.
    """
    P = len(prompt_tokens)
    N = len(tree_tokens)
    mask = build_tree_mask(P, tree_parents)
    logits = model.forward(list(prompt_tokens) + list(tree_tokens), mask)

    emitted = []
    path_end = -1   # index of the last accepted on-path node (-1 = none yet)
    dead = set()    # rejected nodes; their descendants are invalid
    for node in range(N):
        # Rule 4a: skip anything underneath a rejected ancestor.
        if not dead.isdisjoint(_ancestors(node, tree_parents)):
            dead.add(node)
            continue
        parent = tree_parents[node]
        # The target's prediction for this node lives at the parent's
        # position (or at the last prompt token for root nodes).
        pos = P - 1 if parent == -1 else P + parent
        prediction = int(np.argmax(logits[pos]))
        on_path = parent == path_end
        if tree_tokens[node] == prediction:
            if on_path:
                emitted.append(tree_tokens[node])
                path_end = node
        else:
            dead.add(node)
            if on_path:
                # Target disagrees on the active path: emit its token, stop.
                emitted.append(prediction)
                return emitted
    # Every node processed without an on-path rejection: bonus token from
    # the final path position (or the prompt if nothing was accepted).
    bonus = P - 1 if path_end == -1 else P + path_end
    emitted.append(int(np.argmax(logits[bonus])))
    return emitted
def _verify_detailed(prompt_tokens, tree_tokens, tree_parents, model):
    """Like verify_and_accept but returns internals for testing.

    Returns:
        (accepted, rejected, skipped_by_ancestor, decisions) where
        `decisions` is a list of (tag, node_index, extra) tuples recording
        how each tree node was classified, in processing order.
    """
    P = len(prompt_tokens)
    N = len(tree_tokens)
    full = list(prompt_tokens) + list(tree_tokens)
    mask = build_tree_mask(P, tree_parents)
    logits = model.forward(full, mask)
    accepted = []
    path_end = -1   # last accepted on-path node index (-1 = none yet)
    rejected = set()
    skipped_by_ancestor = set()
    decisions = []
    for i in range(N):
        anc = _ancestors(i, tree_parents)
        # Rule 4a: a node under any rejected ancestor is skipped outright,
        # even if its own token would have matched the target.
        if any(a in rejected for a in anc):
            rejected.add(i)
            skipped_by_ancestor.add(i)
            decisions.append(("skipped_ancestor", i, anc))
            continue
        parent = tree_parents[i]
        # Verify against the target's prediction at the parent's position
        # (P-1, the last prompt token, for root nodes).
        logit_pos = (P - 1) if parent == -1 else (P + parent)
        target_pred = int(np.argmax(logits[logit_pos]))
        on_path = parent == path_end
        if tree_tokens[i] == target_pred:
            if on_path:
                accepted.append(tree_tokens[i])
                path_end = i
                decisions.append(("accepted_path", i, target_pred))
            else:
                # Correct token on a side branch: recorded, not emitted.
                decisions.append(("accepted_branch", i, target_pred))
        else:
            rejected.add(i)
            if on_path:
                # Target disagrees on the active path: emit its token and
                # end the cycle immediately (no bonus token).
                accepted.append(target_pred)
                decisions.append(("rejected_path", i, target_pred))
                return accepted, rejected, skipped_by_ancestor, decisions
            else:
                decisions.append(("rejected_branch", i, target_pred))
    # No on-path rejection: bonus token from the end of the accepted path
    # (or from the prompt when nothing was accepted).
    bonus_pos = (P - 1) if path_end == -1 else (P + path_end)
    accepted.append(int(np.argmax(logits[bonus_pos])))
    return accepted, rejected, skipped_by_ancestor, decisions
def speculative_generate(model, prompt, max_new_tokens, draft_fn):
    """Full generation loop using tree speculative decoding.

    Each cycle asks `draft_fn(tokens)` for a (tree_tokens, tree_parents)
    pair, verifies it against `model`, and appends the accepted tokens —
    never exceeding `max_new_tokens` in total.
    """
    tokens = list(prompt)
    produced = 0
    while produced < max_new_tokens:
        tree_tokens, tree_parents = draft_fn(tokens)
        if not tree_tokens:
            # Draft offered nothing: fall back to one greedy target step.
            tokens.append(int(np.argmax(model.forward(tokens)[-1])))
            produced += 1
            continue
        for tok in verify_and_accept(tokens, tree_tokens, tree_parents,
                                     model):
            if produced >= max_new_tokens:
                break
            tokens.append(tok)
            produced += 1
    return tokens
# ── Draft helpers ──────────────────────────────────────────────────
def _make_draft_fn(model, depth=2, n_wrong_branches=2):
"""Draft fn: correct main chain from target + wrong branches off node 0."""
def draft_fn(current):
chain = []
tmp = list(current)
for _ in range(depth):
logits = model.forward(tmp)
chain.append(int(np.argmax(logits[-1])))
tmp.append(chain[-1])
tt = [chain[0]]
tp = [-1]
for k in range(1, depth):
tt.append(chain[k])
tp.append(k - 1)
for w in range(n_wrong_branches):
tt.append((chain[0] + 5 + w * 7) % model.V)
tp.append(0)
return tt, tp
return draft_fn
# ── Tests ──────────────────────────────────────────────────────────
def test_tree_mask_correctness():
    """Verify tree mask structure matches spec rules a-e."""
    print("=" * 60)
    print("TEST 0 TREE MASK CORRECTNESS")
    print("=" * 60)
    # Fixture: 3 prompt tokens, 4-node tree
    #   node0 (root) -> node1, node2; node1 -> node3
    P = 3
    tree_parents = [-1, 0, 0, 1]
    mask = build_tree_mask(P, tree_parents)
    T = P + len(tree_parents)
    # Rule a) prompt-to-prompt attention is strictly causal.
    for i in range(P):
        for j in range(P):
            assert (mask[i, j] == 0.0) == (j <= i), \
                f"Rule a) causal broken at ({i},{j})"
    # Rule b) every tree node attends to every prompt token.
    for i in range(P, T):
        for j in range(P):
            assert mask[i, j] == 0.0, \
                f"Rule b) tree node {i} can't attend prompt {j}"
    # Rule c) self-attention everywhere on the diagonal.
    for i in range(T):
        assert mask[i, i] == 0.0, f"Rule c) self-attention broken at {i}"
    # Rules d/e) within the tree: exactly ancestors plus self, nothing else.
    ancestors_of = {0: [], 1: [0], 2: [0], 3: [1, 0]}
    for i in range(len(tree_parents)):
        gi = P + i
        for j in range(len(tree_parents)):
            gj = P + j
            expect = (j in ancestors_of[i]) or (j == i)
            actual = mask[gi, gj] == 0.0
            assert actual == expect, (
                f"Rule d/e) node {i}->node {j}: expected={expect} got={actual}")
    print(" Rules a-e verified on 4-node tree.")
    print(" PASSED\n")
def test_basic():
    """Test 1 (BASIC): prompt=[10,20,30], 3 root nodes, no depth-2, temp=0.
    Must match autoregressive greedy EXACTLY."""
    print("=" * 60)
    print("TEST 1 BASIC — 3 root nodes, temperature=0")
    print("=" * 60)
    model = MinimalLM(seed=42)
    prompt = [10, 20, 30]
    # Reference: plain autoregressive greedy decoding of 6 tokens.
    ref = model.greedy_generate(prompt, 6)
    # Draft tree: the true greedy token plus two deliberately wrong roots.
    logits0 = model.forward(prompt)
    t0 = int(np.argmax(logits0[-1]))
    tree_tokens = [t0, (t0 + 5) % 1000, (t0 + 10) % 1000]
    tree_parents = [-1, -1, -1]
    acc = verify_and_accept(prompt, tree_tokens, tree_parents, model)
    print(f" prompt = {prompt}")
    print(f" tree_tokens = {tree_tokens}")
    print(f" tree_parents = {tree_parents}")
    print(f" accepted = {acc}")
    print(f" autoregressive = {ref}")
    # A single cycle must be a prefix-exact match with the greedy reference.
    assert acc == ref[len(prompt): len(prompt) + len(acc)], \
        f"Single-cycle mismatch"
    def draft_flat(cur):
        # Flat 3-root draft: greedy token + two offset tokens (always wrong).
        lg = model.forward(cur)
        tk = int(np.argmax(lg[-1]))
        return [tk, (tk + 5) % 1000, (tk + 10) % 1000], [-1, -1, -1]
    spec = speculative_generate(model, prompt, 6, draft_flat)
    assert spec == ref, f"MISMATCH\n spec={spec}\n ref ={ref}"
    print(f" speculative = {spec}")
    print(" PASSED\n")
def test_subtree_invalidation():
    """Test 2 (SUBTREE INVALIDATION):
    A depth-1 node is REJECTED, and its depth-2 child WOULD have matched
    the target model's prediction, but is correctly SKIPPED by rule 4a.
    Tree layout:
        root0 (accepted) ── child0 (on main chain)
        └─ root1 (rejected) ── child1 (would match, but skipped)
    We verify:
        1. child1's token matches what the target would predict via root1.
        2. child1 is in the skipped_by_ancestor set.
        3. Output matches autoregressive greedy.
    """
    print("=" * 60)
    print("TEST 2 SUBTREE INVALIDATION")
    print("=" * 60)
    tested_configs = []
    for seed, prompt, wrong_offset in [
        (42, [10, 20, 30], 5),
        (99, [5, 15, 25], 7),
        (7, [100, 200, 300], 13),
        (314, [42], 9),
    ]:
        model = MinimalLM(seed=seed)
        P = len(prompt)
        # t0 / t1: the target's true greedy continuation (depth 0 and 1).
        logits0 = model.forward(prompt)
        t0 = int(np.argmax(logits0[-1]))
        wrong_root = (t0 + wrong_offset) % model.V
        logits_t0 = model.forward(prompt + [t0])
        t1 = int(np.argmax(logits_t0[-1]))
        # Probe pass: find what the target would predict AFTER the wrong
        # root, so child1 can be given a token that WOULD match.
        dummy_tt = [t0, t1, wrong_root, 0]
        dummy_tp = [-1, 0, 0, 2]
        dummy_mask = build_tree_mask(P, dummy_tp)
        dummy_logits = model.forward(prompt + dummy_tt, dummy_mask)
        t1_given_wrong = int(np.argmax(dummy_logits[P + 2]))
        # Real tree: node 2 is wrong, node 3 (its child) would match.
        tree_tokens = [t0, t1, wrong_root, t1_given_wrong]
        tree_parents = [-1, 0, 0, 2]
        acc, rejected, skipped, decisions = _verify_detailed(
            prompt, tree_tokens, tree_parents, model)
        ref = model.greedy_generate(prompt, len(acc))
        # 3. output must still match pure greedy decoding.
        assert acc == ref[P: P + len(acc)], (
            f"seed={seed} output mismatch: acc={acc} ref={ref[P:]}")
        # 2. node 2 rejected, node 3 skipped purely via its ancestor.
        assert 2 in rejected, f"seed={seed}: root1 (node 2) not rejected"
        assert 3 in skipped, (
            f"seed={seed}: child1 (node 3) not skipped by ancestor")
        assert tree_tokens[3] == t1_given_wrong, "construction error"
        # 1. sanity: node 3's token really matches the target's prediction
        # at its parent's position — it is skipped only by rule 4a.
        parent_of_3 = tree_parents[3]
        logit_pos_3 = (P - 1) if parent_of_3 == -1 else (P + parent_of_3)
        would_match = tree_tokens[3] == int(np.argmax(dummy_logits[logit_pos_3]))
        print(f" seed={seed:3d} prompt={prompt}")
        print(f" t0={t0} wrong_root={wrong_root} t1={t1} "
              f"child_of_wrong={t1_given_wrong}")
        print(f" node3 would match target: {would_match}")
        print(f" node3 skipped by ancestor: {3 in skipped}")
        print(f" output matches autoregressive: True")
        tested_configs.append(seed)
    print(f"\n Tested {len(tested_configs)} configs: {tested_configs}")
    print(" PASSED\n")
def test_multi_step():
    """Test 3 (MULTI-STEP): 3+ consecutive verification cycles.
    Accepted tokens from cycle N become the prompt for cycle N+1."""
    print("=" * 60)
    print("TEST 3 MULTI-STEP (3+ verification cycles)")
    print("=" * 60)
    prompt, n_tokens = [10, 20, 30], 10
    for seed in (42, 7, 123, 999, 0):
        model = MinimalLM(seed=seed)
        expected = model.greedy_generate(prompt, n_tokens)
        actual = speculative_generate(
            model, prompt, n_tokens, _make_draft_fn(model, depth=2))
        assert actual == expected, (
            f"seed={seed} MISMATCH\n spec={actual}\n ref ={expected}")
        print(f" seed={seed:3d} match=True "
              f"tokens={expected[len(prompt):len(prompt)+6]}...")
    print(" PASSED\n")
def test_golden():
    """THE GOLDEN TEST: speculative == autoregressive for many configs.
    At temperature=0, tree speculative decoding MUST produce EXACTLY
    the same output sequence as autoregressive greedy decoding."""
    print("=" * 60)
    print("GOLDEN TEST")
    print("=" * 60)
    # Cross-product sweep: prompts x seeds x draft depths (60 configs).
    prompts = [[10, 20, 30], [1], [100, 200], list(range(5, 15))]
    seeds = [42, 7, 123, 0, 999]
    depths = [1, 2, 3]
    n_configs = 0
    fails = []
    for seed in seeds:
        model = MinimalLM(seed=seed)
        for prompt in prompts:
            for depth in depths:
                ref = model.greedy_generate(prompt, 12)
                # Deeper drafts also get more wrong branches.
                draft_fn = _make_draft_fn(model, depth=depth,
                                          n_wrong_branches=depth)
                spec = speculative_generate(model, prompt, 12, draft_fn)
                n_configs += 1
                if spec != ref:
                    fails.append((seed, prompt[:3], depth))
    # Report every failing config before asserting, for easier debugging.
    if fails:
        for s, p, d in fails:
            print(f" FAIL seed={s} prompt={p}.. depth={d}")
        assert False, f"{len(fails)}/{n_configs} configs FAILED"
    else:
        print(f" {n_configs} configurations: ALL PASSED")
    print(" GOLDEN TEST PASSED\n")
if __name__ == "__main__":
    # Run the suite in order of increasing integration depth.
    for _suite in (test_tree_mask_correctness, test_basic,
                   test_subtree_invalidation, test_multi_step, test_golden):
        _suite()
    print("=" * 60)
    print("ALL TESTS PASSED")
    print("=" * 60)