deep_pro_judge/opus47_1m/flash_attention_bwd/flash_attention.py
Last commit 45c3aad453 by sleepy (2026-04-27 18:49:22 +02:00):
feat: expand to 6 models, 8 challenges; rewrite README with DeepSeek V4 Pro analysis
- Add Claude Opus 4.7, Kimi K2.6, GLM-5.1 to existing GLM-5, Qwen3-6, MiniMax-M2.7
- Add 5 new challenges: flash attention fwd/bwd, beam search, DFlash, ternary training
- Rewrite README with TL;DR rankings, grade matrix, and DeepSeek V4 Pro attribution
- Add analysis/ folder with cross-model comparisons and per-challenge deep dives
- Add deploy_challenges.sh script
- Expand .gitignore to exclude Python envs, ML weights, and build artifacts


import numpy as np


def flash_attention_fwd(Q, K, V, tile_size, causal=True):
    """Tiled forward pass with online softmax. Stores only O, L, Q, K, V."""
    B, H, N, D = Q.shape
    scale = 1.0 / np.sqrt(D)
    T = tile_size
    O = np.zeros_like(Q)
    L = np.zeros((B, H, N), dtype=Q.dtype)
    for b in range(B):
        for h in range(H):
            Qbh, Kbh, Vbh = Q[b, h], K[b, h], V[b, h]
            for q_start in range(0, N, T):
                q_end = min(q_start + T, N)
                Tq = q_end - q_start
                Q_tile = Qbh[q_start:q_end]
                m_i = np.full((Tq,), -np.inf, dtype=Q.dtype)
                l_i = np.zeros((Tq,), dtype=Q.dtype)
                O_i = np.zeros((Tq, D), dtype=Q.dtype)
                for k_start in range(0, N, T):
                    k_end = min(k_start + T, N)
                    if causal and k_start >= q_end:
                        # Key block lies entirely above the diagonal: fully masked, skip it.
                        continue
                    K_tile = Kbh[k_start:k_end]
                    V_tile = Vbh[k_start:k_end]
                    S = (Q_tile @ K_tile.T) * scale
                    if causal and k_end > q_start + 1:
                        q_idx = np.arange(q_start, q_end)[:, None]
                        k_idx = np.arange(k_start, k_end)[None, :]
                        mask = k_idx > q_idx
                        if mask.any():
                            S = np.where(mask, -np.inf, S)
                    m_block = S.max(axis=-1)
                    m_new = np.maximum(m_i, m_block)
                    with np.errstate(invalid='ignore'):
                        m_new_safe = np.where(np.isneginf(m_new), 0.0, m_new)
                        P = np.exp(S - m_new_safe[:, None])
                        m_i_safe = np.where(np.isneginf(m_i), 0.0, m_i)
                        alpha = np.where(
                            np.isneginf(m_i), 0.0,
                            np.exp(m_i_safe - m_new_safe),
                        )
                    # Online softmax: rescale the running sum and output by alpha = exp(m_old - m_new).
                    l_i = alpha * l_i + P.sum(axis=-1)
                    O_i = alpha[:, None] * O_i + P @ V_tile
                    m_i = m_new
                O[b, h, q_start:q_end] = O_i / l_i[:, None]
                L[b, h, q_start:q_end] = m_i + np.log(l_i)
    cache = {'O': O, 'L': L, 'Q': Q, 'K': K, 'V': V}
    return O, cache
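
# A minimal standalone sketch of the online-softmax / streaming-logsumexp recursion used
# above; the helper name and chunk size are illustrative, not part of the API.
def _online_logsumexp_sketch(scores, chunk=4):
    """Streaming logsumexp over a 1-D `scores` array, `chunk` entries at a time."""
    m, l = -np.inf, 0.0
    for start in range(0, scores.shape[0], chunk):
        s = scores[start:start + chunk]
        m_new = max(m, float(s.max()))
        l = np.exp(m - m_new) * l + np.exp(s - m_new).sum()  # rescale old sum, add new block
        m = m_new
    return m + np.log(l)  # agrees with np.log(np.exp(scores).sum()) up to rounding
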

def flash_attention_bwd(dO, cache, tile_size, causal=True):
    """Tiled backward pass with on-the-fly softmax recomputation from L."""
    Q, K, V, O, L = cache['Q'], cache['K'], cache['V'], cache['O'], cache['L']
    B, H, N, D = Q.shape
    scale = 1.0 / np.sqrt(D)
    T = tile_size
    dQ = np.zeros_like(Q)
    dK = np.zeros_like(K)
    dV = np.zeros_like(V)
    # Precompute D_i = sum_d O[i,d] * dO[i,d]; this equals sum_k P[i,k] * dP[i,k]
    # over ALL keys, which is what the softmax-gradient row sum requires.
    D_row = (O * dO).sum(axis=-1)  # (B, H, N)
    for b in range(B):
        for h in range(H):
            Qbh, Kbh, Vbh = Q[b, h], K[b, h], V[b, h]
            Lbh = L[b, h]
            dObh = dO[b, h]
            Dbh = D_row[b, h]
            for q_start in range(0, N, T):
                q_end = min(q_start + T, N)
                Q_tile = Qbh[q_start:q_end]
                dO_tile = dObh[q_start:q_end]
                L_q = Lbh[q_start:q_end]
                D_q = Dbh[q_start:q_end]
                dQ_tile = np.zeros_like(Q_tile)
                for k_start in range(0, N, T):
                    k_end = min(k_start + T, N)
                    if causal and k_start >= q_end:
                        continue
                    K_tile = Kbh[k_start:k_end]
                    V_tile = Vbh[k_start:k_end]
                    S = (Q_tile @ K_tile.T) * scale
                    if causal and k_end > q_start + 1:
                        q_idx = np.arange(q_start, q_end)[:, None]
                        k_idx = np.arange(k_start, k_end)[None, :]
                        mask = k_idx > q_idx
                        if mask.any():
                            S = np.where(mask, -np.inf, S)
                    # Recompute the softmax probabilities from the stored log-normalizer L.
                    with np.errstate(invalid='ignore'):
                        P = np.exp(S - L_q[:, None])
                    dV[b, h, k_start:k_end] += P.T @ dO_tile
                    dP = dO_tile @ V_tile.T
                    dS = P * (dP - D_q[:, None])
                    dQ_tile += (dS @ K_tile) * scale
                    dK[b, h, k_start:k_end] += (dS.T @ Q_tile) * scale
                dQ[b, h, q_start:q_end] = dQ_tile
    return dQ, dK, dV
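
# A minimal sketch of the softmax VJP identity the backward pass relies on,
# dS = P * (dP - sum_k P[:, k] * dP[:, k]), checked against the explicit Jacobian
# diag(p) - p p^T on a single score row; the helper name is illustrative.
def _softmax_vjp_sketch(s, dp):
    """Return (explicit-Jacobian VJP, fused-identity VJP); the two should agree."""
    p = np.exp(s - s.max())
    p /= p.sum()
    jac = np.diag(p) - np.outer(p, p)  # d softmax_i / d s_j = p_i * (delta_ij - p_j)
    return jac @ dp, p * (dp - (p * dp).sum())
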

# ----- Naive reference implementations for testing -----
def naive_forward(Q, K, V, causal=True):
    B, H, N, D = Q.shape
    scale = 1.0 / np.sqrt(D)
    S = np.einsum('bhid,bhjd->bhij', Q, K) * scale
    if causal:
        mask = np.triu(np.ones((N, N), dtype=bool), k=1)
        S = np.where(mask[None, None], -np.inf, S)
    S_max = S.max(axis=-1, keepdims=True)
    P = np.exp(S - S_max)
    P = P / P.sum(axis=-1, keepdims=True)
    O = np.einsum('bhij,bhjd->bhid', P, V)
    return O, P


def naive_backward(Q, K, V, dO, P):
    _, _, _, D = Q.shape
    scale = 1.0 / np.sqrt(D)
    dV = np.einsum('bhij,bhid->bhjd', P, dO)
    dP = np.einsum('bhid,bhjd->bhij', dO, V)
    rowsum = (P * dP).sum(axis=-1, keepdims=True)
    dS = P * (dP - rowsum)
    dQ = np.einsum('bhij,bhjd->bhid', dS, K) * scale
    dK = np.einsum('bhij,bhid->bhjd', dS, Q) * scale
    return dQ, dK, dV
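
# A minimal self-check sketch (tile size, shapes, and tolerance here are arbitrary choices):
# compare the tiled forward/backward against the naive reference on random inputs.
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    B, H, N, D, tile = 2, 3, 37, 16, 8  # N deliberately not a multiple of the tile size
    Q, K, V, dO = (rng.standard_normal((B, H, N, D)) for _ in range(4))
    O, cache = flash_attention_fwd(Q, K, V, tile_size=tile, causal=True)
    O_ref, P_ref = naive_forward(Q, K, V, causal=True)
    dQ, dK, dV = flash_attention_bwd(dO, cache, tile_size=tile, causal=True)
    dQ_ref, dK_ref, dV_ref = naive_backward(Q, K, V, dO, P_ref)
    for got, want in [(O, O_ref), (dQ, dQ_ref), (dK, dK_ref), (dV, dV_ref)]:
        assert np.allclose(got, want, atol=1e-6), 'tiled result deviates from naive reference'
    print('forward and backward match the naive reference')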