Files
deep_pro_judge/glm5/flash_attention_bwd/flash_attention.py
T
sleepy 45c3aad453 feat: expand to 6 models, 8 challenges; rewrite README with DeepSeek V4 Pro analysis
- Add Claude Opus 4.7, Kimi K2.6, GLM-5.1 to existing GLM-5, Qwen3-6, MiniMax-M2.7
- Add 5 new challenges: flash attention fwd/bwd, beam search, DFlash, ternary training
- Rewrite README with TL;DR rankings, grade matrix, and DeepSeek V4 Pro attribution
- Add analysis/ folder with cross-model comparisons and per-challenge deep dives
- Add deploy_challenges.sh script
- Expand .gitignore to exclude Python envs, ML weights, and build artifacts
2026-04-27 18:49:22 +02:00

328 lines
12 KiB
Python

import numpy as np
def flash_attention_fwd(Q, K, V, tile_size, causal=True):
    """Tiled attention forward pass using the online-softmax recurrence.

    Scores are processed in ``tile_size x tile_size`` blocks so the full
    (N, N) attention matrix is never materialized.  Accumulation is done in
    float64; the output is cast back to the input dtype.

    Returns ``(O, cache)`` where ``cache`` holds the inputs, the output, and
    the per-row logsumexp ``L = m + log(l)`` needed by the backward pass.
    """
    B, H, N, D = Q.shape
    inv_sqrt_d = 1.0 / np.sqrt(D)
    in_dtype = Q.dtype
    out = np.zeros((B, H, N, D), dtype=np.float64)
    row_max = np.full((B, H, N), -np.inf, dtype=np.float64)
    row_sum = np.zeros((B, H, N), dtype=np.float64)
    for b in range(B):
        for h in range(H):
            q_all = Q[b, h].astype(np.float64)
            k_all = K[b, h].astype(np.float64)
            v_all = V[b, h].astype(np.float64)
            for qs in range(0, N, tile_size):
                qe = min(qs + tile_size, N)
                nq = qe - qs
                q_blk = q_all[qs:qe]
                # Running softmax statistics for this query tile.
                m_run = np.full(nq, -np.inf, dtype=np.float64)
                l_run = np.zeros(nq, dtype=np.float64)
                acc = np.zeros((nq, D), dtype=np.float64)
                for ks in range(0, N, tile_size):
                    ke = min(ks + tile_size, N)
                    if causal and ks >= qe:
                        continue  # key tile lies entirely above the diagonal
                    scores = (q_blk @ k_all[ks:ke].T) * inv_sqrt_d
                    if causal:
                        above_diag = np.arange(ks, ke)[None, :] > np.arange(qs, qe)[:, None]
                        scores = np.where(above_diag, -np.inf, scores)
                    m_next = np.maximum(m_run, scores.max(axis=-1))
                    # Rescale previous partial sums to the new running max.
                    correction = np.exp(m_run - m_next)
                    probs = np.exp(scores - m_next[:, None])
                    if causal:
                        probs = np.where(above_diag, 0.0, probs)
                    l_run = correction * l_run + probs.sum(axis=-1)
                    acc = correction[:, None] * acc + probs @ v_all[ks:ke]
                    m_run = m_next
                out[b, h, qs:qe] = acc / l_run[:, None]
                row_max[b, h, qs:qe] = m_run
                row_sum[b, h, qs:qe] = l_run
    logsumexp = row_max + np.log(row_sum)
    out_cast = out.astype(in_dtype)
    cache = {'O': out_cast, 'L': logsumexp, 'Q': Q, 'K': K, 'V': V}
    return out_cast, cache
def flash_attention_bwd(dO, cache, tile_size, causal=True):
    """Tiled attention backward pass.

    Recomputes each probability tile from the cached logsumexp ``L`` instead
    of storing the (N, N) softmax, then applies the standard softmax-Jacobian
    gradient formulas.  Accumulates in float64 and returns ``(dQ, dK, dV)``
    cast to the input dtype.
    """
    Q, K, V = cache['Q'], cache['K'], cache['V']
    O, L = cache['O'], cache['L']
    B, H, N, D = Q.shape
    inv_sqrt_d = 1.0 / np.sqrt(D)
    in_dtype = Q.dtype
    dQ = np.zeros((B, H, N, D), dtype=np.float64)
    dK = np.zeros((B, H, N, D), dtype=np.float64)
    dV = np.zeros((B, H, N, D), dtype=np.float64)
    for b in range(B):
        for h in range(H):
            for qs in range(0, N, tile_size):
                qe = min(qs + tile_size, N)
                q_blk = Q[b, h, qs:qe].astype(np.float64)
                do_blk = dO[b, h, qs:qe].astype(np.float64)
                o_blk = O[b, h, qs:qe].astype(np.float64)
                lse_blk = L[b, h, qs:qe].astype(np.float64)
                # delta_i = <dO_i, O_i>, the softmax-Jacobian correction term.
                delta = (do_blk * o_blk).sum(axis=-1, keepdims=True)
                dq_blk = np.zeros((qe - qs, D), dtype=np.float64)
                for ks in range(0, N, tile_size):
                    ke = min(ks + tile_size, N)
                    if causal and ks >= qe:
                        continue  # key tile lies entirely above the diagonal
                    k_blk = K[b, h, ks:ke].astype(np.float64)
                    v_blk = V[b, h, ks:ke].astype(np.float64)
                    scores = (q_blk @ k_blk.T) * inv_sqrt_d
                    if causal:
                        above_diag = np.arange(ks, ke)[None, :] > np.arange(qs, qe)[:, None]
                        scores = np.where(above_diag, -np.inf, scores)
                    # exp(S - L) reconstructs the normalized probabilities.
                    probs = np.exp(scores - lse_blk[:, None])
                    if causal:
                        probs = np.where(above_diag, 0.0, probs)
                    dV[b, h, ks:ke] += probs.T @ do_blk
                    d_probs = do_blk @ v_blk.T
                    d_scores = probs * (d_probs - delta)
                    if causal:
                        d_scores = np.where(above_diag, 0.0, d_scores)
                    dq_blk += d_scores @ k_blk * inv_sqrt_d
                    dK[b, h, ks:ke] += d_scores.T @ q_blk * inv_sqrt_d
                dQ[b, h, qs:qe] = dq_blk
    return dQ.astype(in_dtype), dK.astype(in_dtype), dV.astype(in_dtype)
def naive_attention_fwd(Q, K, V, causal=True):
    """Reference attention forward that materializes the full (N, N) softmax.

    Computed per (batch, head) in float64 with the usual max-shift for
    numerical stability; the result is cast back to Q's dtype.
    """
    B, H, N, D = Q.shape
    inv_sqrt_d = 1.0 / np.sqrt(D)
    out = np.zeros((B, H, N, D), dtype=np.float64)
    for b in range(B):
        for h in range(H):
            scores = (Q[b, h].astype(np.float64) @ K[b, h].astype(np.float64).T) * inv_sqrt_d
            if causal:
                above_diag = np.triu(np.ones((N, N), dtype=bool), k=1)
                scores = np.where(above_diag, -np.inf, scores)
            shifted = np.exp(scores - scores.max(axis=-1, keepdims=True))
            probs = shifted / shifted.sum(axis=-1, keepdims=True)
            if causal:
                probs = np.where(above_diag, 0.0, probs)
            out[b, h] = probs @ V[b, h].astype(np.float64)
    return out.astype(Q.dtype)
def naive_attention_bwd(Q, K, V, dO, causal=True):
    """Reference attention backward that materializes the full (N, N) softmax.

    Recomputes the forward per (batch, head) via ``naive_attention_fwd``,
    then applies the softmax-Jacobian gradient formulas in float64.  Returns
    ``(dQ, dK, dV)`` cast to Q's dtype.
    """
    B, H, N, D = Q.shape
    inv_sqrt_d = 1.0 / np.sqrt(D)
    q64 = Q.astype(np.float64)
    k64 = K.astype(np.float64)
    v64 = V.astype(np.float64)
    do64 = dO.astype(np.float64)
    o64 = naive_attention_fwd(Q, K, V, causal=causal).astype(np.float64)
    dQ = np.zeros((B, H, N, D), dtype=np.float64)
    dK = np.zeros((B, H, N, D), dtype=np.float64)
    dV = np.zeros((B, H, N, D), dtype=np.float64)
    for b in range(B):
        for h in range(H):
            scores = (q64[b, h] @ k64[b, h].T) * inv_sqrt_d
            if causal:
                above_diag = np.triu(np.ones((N, N), dtype=bool), k=1)
                scores = np.where(above_diag, -np.inf, scores)
            shifted = np.exp(scores - scores.max(axis=-1, keepdims=True))
            probs = shifted / shifted.sum(axis=-1, keepdims=True)
            if causal:
                probs = np.where(above_diag, 0.0, probs)
            dV[b, h] = probs.T @ do64[b, h]
            d_probs = do64[b, h] @ v64[b, h].T
            # delta_i = <dO_i, O_i>, the softmax-Jacobian correction term.
            delta = (do64[b, h] * o64[b, h]).sum(axis=-1, keepdims=True)
            d_scores = probs * (d_probs - delta)
            if causal:
                d_scores = np.where(above_diag, 0.0, d_scores)
            dQ[b, h] = d_scores @ k64[b, h] * inv_sqrt_d
            dK[b, h] = d_scores.T @ q64[b, h] * inv_sqrt_d
    return dQ.astype(Q.dtype), dK.astype(Q.dtype), dV.astype(Q.dtype)
if __name__ == '__main__':
    # Self-test harness: sanity check, finite-difference gradient check,
    # agreement with the naive reference, and a peak-memory check.
    np.random.seed(42)
    # Quick forward + backward sanity check (small)
    print("=== Forward/backward sanity check (small) ===")
    B_c, H_c, N_c, D_c, T_c = 1, 1, 8, 4, 4
    Qc = np.random.randn(B_c, H_c, N_c, D_c)
    Kc = np.random.randn(B_c, H_c, N_c, D_c)
    Vc = np.random.randn(B_c, H_c, N_c, D_c)
    Of, cf = flash_attention_fwd(Qc, Kc, Vc, T_c, causal=True)
    On = naive_attention_fwd(Qc, Kc, Vc, causal=True)
    fwd_err = np.max(np.abs(Of - On) / (np.abs(On) + 1e-8))
    print(f"Forward max relative error: {fwd_err:.2e}")
    dOc = np.random.randn(*Of.shape)
    dQf, dKf, dVf = flash_attention_bwd(dOc, cf, T_c, causal=True)
    dQn, dKn, dVn = naive_attention_bwd(Qc, Kc, Vc, dOc, causal=True)
    for name, fg, ng in [('dQ', dQf, dQn), ('dK', dKf, dKn), ('dV', dVf, dVn)]:
        rel = np.max(np.abs(fg - ng) / (np.abs(ng) + 1e-8))
        print(f" {name} max rel error vs naive: {rel:.2e}")
    print()
    # ── Test 1: Gradient check with finite differences ──
    print("Test 1: Gradient check (B=1, H=1, N=64, D=32, T=16, causal=True)")
    B1, H1, N1, D1, T1 = 1, 1, 64, 32, 16
    Q1 = np.random.randn(B1, H1, N1, D1)
    K1 = np.random.randn(B1, H1, N1, D1)
    V1 = np.random.randn(B1, H1, N1, D1)
    O1, cache1 = flash_attention_fwd(Q1, K1, V1, T1, causal=True)
    dO1 = np.random.randn(*O1.shape)
    dQ1, dK1, dV1 = flash_attention_bwd(dO1, cache1, T1, causal=True)
    dQ1n, dK1n, dV1n = naive_attention_bwd(Q1, K1, V1, dO1, causal=True)
    for name, fg, ng in [('dQ', dQ1, dQ1n), ('dK', dK1, dK1n), ('dV', dV1, dV1n)]:
        rel = np.max(np.abs(fg.astype(np.float64) - ng.astype(np.float64)) / (np.abs(ng.astype(np.float64)) + 1e-8))
        print(f" {name} flash vs naive: {rel:.2e}")
    eps = 1e-5
    print(" Computing dV via finite differences...")
    # Central-difference estimate of the gradient at EVERY coordinate of V.
    dV_fd = np.zeros_like(V1, dtype=np.float64)
    for idx in np.ndindex(V1.shape):
        V_plus = V1.copy(); V_plus[idx] += eps
        V_minus = V1.copy(); V_minus[idx] -= eps
        O_plus, _ = flash_attention_fwd(Q1, K1, V_plus, T1, causal=True)
        O_minus, _ = flash_attention_fwd(Q1, K1, V_minus, T1, causal=True)
        dV_fd[idx] = ((O_plus - O_minus) * dO1).sum() / (2 * eps)
    rel_err_dV = np.max(np.abs(dV1.astype(np.float64) - dV_fd) / (np.abs(dV_fd) + 1e-8))
    print(f" dV max relative error vs finite differences: {rel_err_dV:.2e}")
    assert rel_err_dV < 1e-5, f"dV relative error {rel_err_dV} >= 1e-5"
    print(" Computing dQ via finite differences (spot-check)...")
    # dQ and dK are only spot-checked at 10 random coordinates to bound runtime.
    rand_idx_Q = np.random.choice(N1 * D1, size=10, replace=False)
    dQ_fd = np.zeros_like(Q1, dtype=np.float64)
    for flat_idx in rand_idx_Q:
        idx = np.unravel_index(flat_idx, Q1.shape)
        Q_plus = Q1.copy(); Q_plus[idx] += eps
        Q_minus = Q1.copy(); Q_minus[idx] -= eps
        O_p, _ = flash_attention_fwd(Q_plus, K1, V1, T1, causal=True)
        O_m, _ = flash_attention_fwd(Q_minus, K1, V1, T1, causal=True)
        dQ_fd[idx] = ((O_p - O_m) * dO1).sum() / (2 * eps)
    rel_err_dQ = np.max(np.abs(dQ1.astype(np.float64).ravel()[rand_idx_Q] - dQ_fd.ravel()[rand_idx_Q]) /
                        (np.abs(dQ_fd.ravel()[rand_idx_Q]) + 1e-8))
    print(f" dQ spot-check relative error: {rel_err_dQ:.2e}")
    assert rel_err_dQ < 1e-5, f"dQ relative error {rel_err_dQ} >= 1e-5"
    print(" Computing dK via finite differences (spot-check)...")
    rand_idx_K = np.random.choice(N1 * D1, size=10, replace=False)
    dK_fd = np.zeros_like(K1, dtype=np.float64)
    for flat_idx in rand_idx_K:
        idx = np.unravel_index(flat_idx, K1.shape)
        K_plus = K1.copy(); K_plus[idx] += eps
        K_minus = K1.copy(); K_minus[idx] -= eps
        O_p, _ = flash_attention_fwd(Q1, K_plus, V1, T1, causal=True)
        O_m, _ = flash_attention_fwd(Q1, K_minus, V1, T1, causal=True)
        dK_fd[idx] = ((O_p - O_m) * dO1).sum() / (2 * eps)
    rel_err_dK = np.max(np.abs(dK1.astype(np.float64).ravel()[rand_idx_K] - dK_fd.ravel()[rand_idx_K]) /
                        (np.abs(dK_fd.ravel()[rand_idx_K]) + 1e-8))
    print(f" dK spot-check relative error: {rel_err_dK:.2e}")
    assert rel_err_dK < 1e-5, f"dK relative error {rel_err_dK} >= 1e-5"
    print(" PASSED\n")
    # ── Test 2: Vs naive backward ──
    print("Test 2: Vs naive backward (B=2, H=4, N=256, D=64, T=64, causal=True)")
    B2, H2, N2, D2, T2 = 2, 4, 256, 64, 64
    # float32 inputs exercise the dtype round-trip (float64 accumulation inside).
    Q2 = np.random.randn(B2, H2, N2, D2).astype(np.float32)
    K2 = np.random.randn(B2, H2, N2, D2).astype(np.float32)
    V2 = np.random.randn(B2, H2, N2, D2).astype(np.float32)
    O2, cache2 = flash_attention_fwd(Q2, K2, V2, T2, causal=True)
    O2_naive = naive_attention_fwd(Q2, K2, V2, causal=True)
    fwd_err = np.max(np.abs(O2 - O2_naive) / (np.abs(O2_naive) + 1e-6))
    print(f" Forward max relative error: {fwd_err:.2e}")
    dO2 = np.random.randn(*O2.shape).astype(np.float32)
    dQ2, dK2, dV2 = flash_attention_bwd(dO2, cache2, T2, causal=True)
    dQ2n, dK2n, dV2n = naive_attention_bwd(Q2, K2, V2, dO2, causal=True)
    for name, flash_grad, naive_grad in [('dQ', dQ2, dQ2n), ('dK', dK2, dK2n), ('dV', dV2, dV2n)]:
        rel = np.max(np.abs(flash_grad.astype(np.float64) - naive_grad.astype(np.float64)) /
                     (np.abs(naive_grad.astype(np.float64)) + 1e-6))
        print(f" {name} max relative error vs naive: {rel:.2e}")
        assert rel < 1e-4, f"{name} relative error {rel} >= 1e-4"
    print(" PASSED\n")
    # ── Test 3: Memory ──
    print("Test 3: Memory (B=1, H=1, N=4096, D=64, T=128, causal=True)")
    import tracemalloc
    import gc
    B3, H3, N3, D3, T3 = 1, 1, 4096, 64, 128
    Q3 = np.random.randn(B3, H3, N3, D3).astype(np.float32)
    K3 = np.random.randn(B3, H3, N3, D3).astype(np.float32)
    V3 = np.random.randn(B3, H3, N3, D3).astype(np.float32)
    gc.collect()
    tracemalloc.start()
    O3, cache3 = flash_attention_fwd(Q3, K3, V3, T3, causal=True)
    dO3 = np.ones_like(O3)
    dQ3, dK3, dV3 = flash_attention_bwd(dO3, cache3, T3, causal=True)
    gc.collect()
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    # Baseline: one materialized (N, N) matrix at 4 bytes per element.
    full_attn_mem = N3 * N3 * 4
    ratio = peak / full_attn_mem
    print(f" Current memory: {current / 1e6:.2f} MB")
    print(f" Peak memory: {peak / 1e6:.2f} MB")
    print(f" Full (N,N) matrix would be: {full_attn_mem / 1e6:.2f} MB")
    print(f" Ratio: {ratio:.2%}")
    assert ratio < 0.20, f"Memory ratio {ratio:.2%} >= 20%"
    print(" PASSED\n")
    print("All tests passed!")