# deep_pro_judge/opus47_1m/beam_search/beam_search.py

"""Batched beam search decoder for autoregressive generation in pure NumPy."""
import numpy as np
def log_softmax(x, axis=-1):
m = np.max(x, axis=axis, keepdims=True)
shifted = x - m
return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))
class TinyLM:
"""Random-weight 1-block transformer. Correctness of decoding is the
point — the model itself produces meaningless logits."""
def __init__(self, vocab_size=1000, d_model=64, seed=0):
rng = np.random.default_rng(seed)
self.vocab_size = vocab_size
self.d_model = d_model
s = 1.0 / np.sqrt(d_model)
self.embed = rng.standard_normal((vocab_size, d_model)) * s
self.Wq = rng.standard_normal((d_model, d_model)) * s
self.Wk = rng.standard_normal((d_model, d_model)) * s
self.Wv = rng.standard_normal((d_model, d_model)) * s
self.Wo = rng.standard_normal((d_model, d_model)) * s
self.W1 = rng.standard_normal((d_model, 4 * d_model)) * s
self.W2 = rng.standard_normal((4 * d_model, d_model)) * s
self.lm_head = rng.standard_normal((d_model, vocab_size)) * s
def forward(self, token_ids):
# token_ids: (N, T) -> last-position logits (N, V)
x = self.embed[token_ids]
Q = x @ self.Wq
K = x @ self.Wk
V = x @ self.Wv
scores = Q @ K.transpose(0, 2, 1) / np.sqrt(self.d_model)
T = scores.shape[-1]
mask = np.triu(np.ones((T, T), dtype=bool), k=1)
scores = np.where(mask, -1e9, scores)
attn = np.exp(scores - scores.max(-1, keepdims=True))
attn = attn / attn.sum(-1, keepdims=True)
h = (attn @ V) @ self.Wo
x = x + h
h2 = np.maximum(0, x @ self.W1) @ self.W2
x = x + h2
return x[:, -1, :] @ self.lm_head
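# Illustrative shape check (not part of the tests below): a batch of
# two 5-token sequences yields one next-token logit row per sequence.
#   logits = TinyLM().forward(np.zeros((2, 5), dtype=np.int64))
#   logits.shape == (2, 1000)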
def batched_beam_search(model, prompts, beam_width, max_new_tokens,
alpha=0.6, eos_token_id=0):
"""Beam search over multiple prompts, returning K best generations each.
Returns: list of length B; each element is a list of up to K dicts
{tokens, score, logprob, finished} sorted by length-penalized
score descending.
Why finished beams are NOT removed from the pool:
A beam that hits EOS early may have a high length-penalized score
(its short length means a small denominator). If we drop it from
the candidate pool the moment it finishes, an unfinished beam
with worse cumulative logprob can win simply because we never let
the finished beam compete. Keeping finished beams in the pool —
and ranking by length-penalized score — lets early-stoppers
legitimately defend their lead. (See test_eos_retention.)
"""
K = beam_width
B = len(prompts)
# Per batch item: list of beam dicts with
# tokens : full token list (prompt + generated)
# gen : generated-only token list (prompt does NOT count)
# logprob : raw accumulated logprob (never modified by length penalty)
# finished : True iff this beam has emitted EOS
state = [[{
"tokens": list(p),
"gen": [],
"logprob": 0.0,
"finished": False,
}] for p in prompts]
def lp_score(b):
L = len(b["gen"])
if L == 0:
# Only the initial beam has L=0; never compared against others.
return b["logprob"]
return b["logprob"] / (L ** alpha)
for _ in range(max_new_tokens):
# Stop early if every batch item already holds K finished beams.
if all(len(beams) >= K and all(b["finished"] for b in beams)
for beams in state):
break
# Gather every unfinished beam across all batch items.
active = [] # (batch_idx, beam_idx)
for bi, beams in enumerate(state):
for ki, b in enumerate(beams):
if not b["finished"]:
active.append((bi, ki))
if not active:
break
        # One forward call per active beam. Prompt lengths can differ
        # across batch items, so per-beam calls keep this simple and
        # correct, trading a few extra forward passes for clarity.
active_logps = []
for (bi, ki) in active:
tokens = state[bi][ki]["tokens"]
arr = np.array([tokens], dtype=np.int64)
logits = model.forward(arr)[0] # (V,)
active_logps.append(log_softmax(logits))
# For each batch item, build the candidate pool and pick top K.
for bi in range(B):
beams = state[bi]
pool = []
# Carry finished beams forward — they MUST stay eligible for
# selection so they compete against new candidates by
# length-penalized score. See module docstring on why.
for b in beams:
if b["finished"]:
pool.append(b)
            # Expand each unfinished beam with its top-2K next-token
            # candidates. Using 2K rather than K guarantees the pool
            # still holds K viable continuations even if some of the
            # top candidates end in EOS.
for active_idx, (abi, aki) in enumerate(active):
if abi != bi:
continue
b = beams[aki]
lp = active_logps[active_idx] # (V,)
m = min(2 * K, lp.shape[0])
# argpartition gives unsorted top-m; that's fine because we
# re-sort the whole pool below.
top_idx = np.argpartition(-lp, m - 1)[:m]
for tok_idx in top_idx:
tok = int(tok_idx)
new_logprob = b["logprob"] + float(lp[tok])
pool.append({
"tokens": b["tokens"] + [tok],
"gen": b["gen"] + [tok],
"logprob": new_logprob,
"finished": (tok == eos_token_id),
})
# Rank all pool entries (finished + new candidates) by the
# length-penalized score and keep the top K.
pool.sort(key=lp_score, reverse=True)
state[bi] = pool[:K]
# Final result: sort once more, return generated tokens only.
results = []
for beams in state:
beams_sorted = sorted(beams, key=lp_score, reverse=True)
results.append([
{
"tokens": b["gen"],
"score": lp_score(b),
"logprob": b["logprob"],
"finished": b["finished"],
}
for b in beams_sorted
])
return results
# --------------------------------------------------------------------------
# Tests
# --------------------------------------------------------------------------
def greedy_decode(model, prompt, max_new_tokens, eos_token_id):
tokens = list(prompt)
gen = []
for _ in range(max_new_tokens):
arr = np.array([tokens], dtype=np.int64)
logits = model.forward(arr)[0]
tok = int(np.argmax(logits))
tokens.append(tok)
gen.append(tok)
if tok == eos_token_id:
break
return gen
def test_greedy_equivalence():
"""Test 1: K=1, alpha=0 must equal greedy decoding."""
model = TinyLM(seed=42)
prompt = [3, 14, 159]
eos = 0
max_new = 12
greedy = greedy_decode(model, prompt, max_new, eos)
beam_results = batched_beam_search(
model, [prompt], beam_width=1, max_new_tokens=max_new,
alpha=0.0, eos_token_id=eos,
)
beam_tokens = beam_results[0][0]["tokens"]
assert beam_tokens == greedy, (
f"Beam (K=1, alpha=0) diverged from greedy:\n"
f" greedy = {greedy}\n beam = {beam_tokens}"
)
print(f"Test 1 OK — greedy == beam(K=1, alpha=0): {greedy}")
def test_per_batch_independence():
"""Test 2: beams from one prompt must not affect another prompt's
results. Run prompt-0 alone vs in a batch with prompt-1; the
prompt-0 result must be identical."""
model = TinyLM(seed=7)
p0 = [11, 22, 33]
p1 = [44, 55, 66, 77, 88]
eos = 0
K = 3
max_new = 8
solo = batched_beam_search(
model, [p0], beam_width=K, max_new_tokens=max_new,
alpha=0.6, eos_token_id=eos,
)[0]
together = batched_beam_search(
model, [p0, p1], beam_width=K, max_new_tokens=max_new,
alpha=0.6, eos_token_id=eos,
)
assert len(together) == 2
assert len(together[0]) <= K and len(together[1]) <= K
solo_seqs = [tuple(b["tokens"]) for b in solo]
batch_seqs = [tuple(b["tokens"]) for b in together[0]]
assert solo_seqs == batch_seqs, (
f"Per-batch independence violated:\n"
f" solo = {solo_seqs}\n in-batch = {batch_seqs}"
)
# Sanity: prompt-1's beams should be different from prompt-0's.
other_seqs = [tuple(b["tokens"]) for b in together[1]]
assert other_seqs != batch_seqs
print(f"Test 2 OK — prompt-0 results identical solo vs batched (K={K}).")
class _EOSMockModel:
"""Hand-crafted forward pass for the EOS retention test.
Step 1 (first call): produces logits whose softmax gives
logp(eos) = -3.0
logp(tok 1) = -4.0
logp(other) ≈ -6.977
Step 2 (second call): logits whose softmax gives
logp(tok 1) = -1.0 (the survivor extends with this)
logp(other) ≈ -7.365 (eos included, so beam stays unfinished)
"""
def __init__(self, eos_token=0, vocab_size=1000):
self.eos = eos_token
self.V = vocab_size
self.calls = 0
def forward(self, token_ids):
N = token_ids.shape[0]
logits = np.zeros((N, self.V))
if self.calls == 0:
# Distribute mass so e^logits sums to ~1, making logits == logp.
# p(eos)=e^-3, p(1)=e^-4, rest split: (1 - e^-3 - e^-4)/(V-2)
other_p = (1.0 - np.exp(-3.0) - np.exp(-4.0)) / (self.V - 2)
other_lp = float(np.log(other_p))
logits[:, :] = other_lp
logits[:, self.eos] = -3.0
logits[:, 1] = -4.0
else:
other_p = (1.0 - np.exp(-1.0)) / (self.V - 1)
other_lp = float(np.log(other_p))
logits[:, :] = other_lp
logits[:, 1] = -1.0 # winning continuation
# eos stays at other_lp ≈ -7.365 → not picked first
self.calls += 1
return logits
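# Sanity note on the mock: in both branches exp(logits) sums to 1 by
# construction, so log_softmax returns the logits unchanged (up to
# float rounding) and the stated log-probabilities are exact.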
def test_eos_retention():
"""Test 3: the critical EOS-retention test.
Step 1: beam A emits EOS → logprob -3, len 1, finished.
beam B continues with token 1 → logprob -4, len 1.
Step 2: beam B extends with token 1 again → logprob -5, len 2.
Length-penalized scores (alpha=0.6):
A: -3 / 1^0.6 = -3.000
        B: -5 / 2^0.6 ≈ -3.299
A must win. A buggy implementation that drops finished beams from
the pool would return B as the top result.
"""
eos = 0
K = 2
max_new = 2
model = _EOSMockModel(eos_token=eos, vocab_size=1000)
results = batched_beam_search(
model, [[42]], beam_width=K, max_new_tokens=max_new,
alpha=0.6, eos_token_id=eos,
)
top = results[0][0]
runner_up = results[0][1]
assert top["tokens"] == [eos], (
f"EOS beam was not the winner. Got tokens={top['tokens']}. "
f"This indicates finished beams were wrongly dropped from the pool."
)
assert top["finished"] is True
assert abs(top["logprob"] - (-3.0)) < 1e-6, top["logprob"]
assert abs(top["score"] - (-3.0)) < 1e-6, top["score"]
assert runner_up["tokens"] == [1, 1]
assert runner_up["finished"] is False
assert abs(runner_up["logprob"] - (-5.0)) < 1e-6, runner_up["logprob"]
expected_runner_score = -5.0 / (2.0 ** 0.6)
assert abs(runner_up["score"] - expected_runner_score) < 1e-6
print(
f"Test 3 OK — EOS beam (score={top['score']:.4f}) correctly beat "
f"unfinished beam (score={runner_up['score']:.4f})."
)
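# A minimal usage sketch, not wired into the test runner below: it
# decodes two made-up prompts with the random-weight TinyLM purely to
# show the shape of the result structure. Call demo() manually.
def demo():
    model = TinyLM(seed=0)
    results = batched_beam_search(
        model, [[1, 2, 3], [4, 5]], beam_width=3, max_new_tokens=5,
        alpha=0.6, eos_token_id=0,
    )
    for bi, beams in enumerate(results):
        for rank, b in enumerate(beams):
            print(f"prompt {bi} beam {rank}: tokens={b['tokens']} "
                  f"score={b['score']:.4f} finished={b['finished']}")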
if __name__ == "__main__":
test_greedy_equivalence()
test_per_batch_independence()
test_eos_retention()
print("\nAll tests passed.")