"""
|
|
Tiled (Flash) Attention Forward Pass with Online Softmax
|
|
=========================================================
|
|
|
|
This implementation computes attention without materializing the full (N, N) attention matrix.
|
|
It uses the online softmax rescaling algorithm to maintain numerical stability.
|
|
|
|
Key concepts:
|
|
- Online softmax: Instead of computing exp(s_i) for all i and normalizing at the end,
|
|
we maintain running statistics (max and exp-sum) that get updated incrementally.
|
|
- Tiled computation: Q, K, V are processed in tiles to keep memory usage bounded.
|
|
"""
|
|
|
|
import numpy as np
|
|
|
|
|
|
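# A minimal sketch of the online-softmax recurrence on a plain 1-D vector,
# included for illustration only (the helper name `online_softmax_normalizer`
# is ours, not a library function). It maintains the same running
# (max, exp-sum) statistics used below and matches the direct two-pass
# softmax normalizer.
def online_softmax_normalizer(scores, tile_size):
    """Return (max, sum of exp(s - max)) computed one tile at a time."""
    m = -np.inf
    l = 0.0
    for start in range(0, len(scores), tile_size):
        tile = scores[start:start + tile_size]
        m_new = max(m, float(np.max(tile)))
        # Rescale the old exp-sum onto the new max, then add the tile's terms.
        l = l * np.exp(m - m_new) + float(np.sum(np.exp(tile - m_new)))
        m = m_new
    return m, l
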
def flash_attention_fwd(Q, K, V, tile_size, causal=True):
    """
    Compute tiled (Flash) attention using online softmax.

    Args:
        Q: (B, H, N, D) queries
        K: (B, H, N, D) keys
        V: (B, H, N, D) values
        tile_size: Size of tiles for blocked computation
        causal: If True, apply causal masking (query i can only attend to key j <= i)

    Returns:
        output: (B, H, N, D) attention output

    Why exp(m_old - m_new) and NOT exp(m_new - m_old)?
    --------------------------------------------------
    We maintain O = sum_i exp(s_i - m) * v_i and l = sum_i exp(s_i - m),
    where m is the running maximum. Per KV tile, the recurrence is:

        m_new      = maximum(m_old, row_maxes_from_S)
        correction = exp(m_old - m_new)
        O          = O * correction
        l          = l * correction
        l          = l + sum(exp(S - m_new))
        O          = O + exp(S - m_new) @ V_tile

    The previously accumulated terms are exp(s_j - m_old) * v_j, and the
    identity

        exp(s_j - m_old) * exp(m_old - m_new) = exp(s_j - m_new)

    shows that multiplying the accumulators by exp(m_old - m_new) converts
    the OLD terms to the m_new scale; the NEW tile's terms are then added
    already on that scale. Sanity checks:

    - m_new = m_old: correction = exp(0) = 1, nothing changes.
    - m_new > m_old: correction < 1, so the old terms shrink, exactly
      compensating for the larger maximum now being subtracted.
    - m_new < m_old cannot occur, since m_new = max(m_old, local_max).

    If exp(m_old - m_new) underflows to 0 because m_old << m_new, that is
    correct behavior: the old contributions are negligible relative to the
    new maximum, and the new contributions dominate.

    Numerical Stability Hazard at Tile Boundaries (Causal)
    ------------------------------------------------------
    With causal=True, the accumulators for each query tile start at
    m = -inf, l = 0, O = 0. An early KV tile can be entirely masked; then
    every entry of S is -inf, row_maxes_from_S = -inf, m_new = max(m, -inf)
    = m, and nothing is added.

    The hazard appears at the first KV tile containing valid (unmasked)
    positions:

    - m_new becomes finite
    - correction = exp(-inf - finite) = 0
    - O = O * 0 and l = l * 0

    Zeroing the accumulators here is harmless, because O and l were still 0
    (nothing valid had been accumulated while m was -inf). The genuinely
    dangerous case is m_old = m_new = -inf, where exp(m_old - m_new) =
    exp(-inf - (-inf)) = NaN; the code below special-cases it and keeps
    correction = 1 for those rows.
    """
    B, H, N, D = Q.shape
    scale = 1.0 / np.sqrt(D)

    output = np.zeros_like(Q)

    for b in range(B):
        for h in range(H):
            q = Q[b, h]
            k = K[b, h]
            v = V[b, h]

            for q_tile_start in range(0, N, tile_size):
                q_tile_end = min(q_tile_start + tile_size, N)
                q_tile = q[q_tile_start:q_tile_end]

                # Running row statistics: max, exp-sum, unnormalized output.
                m = np.full(q_tile.shape[0], -np.inf)
                l = np.zeros(q_tile.shape[0])
                O = np.zeros((q_tile.shape[0], D))

                for kv_tile_start in range(0, N, tile_size):
                    kv_tile_end = min(kv_tile_start + tile_size, N)

                    # Causal: skip KV tiles that lie entirely in the future.
                    if causal and kv_tile_start >= q_tile_end:
                        continue

                    k_tile = k[kv_tile_start:kv_tile_end]
                    v_tile = v[kv_tile_start:kv_tile_end]

                    S = q_tile @ k_tile.T * scale

                    if causal:
                        q_indices = np.arange(q_tile_start, q_tile_end)
                        k_indices = np.arange(kv_tile_start, kv_tile_end)
                        mask_invalid = k_indices[np.newaxis, :] > q_indices[:, np.newaxis]
                        S = np.where(mask_invalid, -np.inf, S)

                    row_maxes = np.max(S, axis=1, keepdims=True)

                    m_new = np.maximum(m.reshape(-1, 1), row_maxes)
                    # squeeze(axis=-1) rather than a bare squeeze(), which
                    # would collapse a single-row tile to a 0-d scalar.
                    m_new_flat = m_new.squeeze(axis=-1)

                    # Rows where m_old and m_new are both -inf would give
                    # exp(-inf - (-inf)) = NaN; keep correction = 1 for them.
                    m_old_is_neg_inf = m == -np.inf
                    m_new_is_neg_inf = m_new_flat == -np.inf
                    need_correction = ~(m_old_is_neg_inf & m_new_is_neg_inf)

                    correction = np.ones_like(m)
                    correction[need_correction] = np.exp(
                        m[need_correction] - m_new_flat[need_correction]
                    )

                    O = O * correction[:, np.newaxis]
                    l = l * correction

                    # exp(S - m_new), skipping rows whose max is still -inf
                    # (fully masked so far) to avoid NaN from -inf - (-inf).
                    exp_S_minus_m_new = np.zeros_like(S)
                    for i in range(S.shape[0]):
                        if not np.isinf(m_new_flat[i]):
                            exp_S_minus_m_new[i] = np.exp(S[i] - m_new_flat[i])

                    l = l + np.sum(exp_S_minus_m_new, axis=1)

                    P = exp_S_minus_m_new
                    O = O + P @ v_tile

                    m = m_new_flat

                output[b, h, q_tile_start:q_tile_end] = O / l[:, np.newaxis]

    return output

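# Illustrative self-check of the exp(m_old - m_new) rescaling described in the
# docstring above: process a single attention row in two KV tiles and compare
# against the one-shot softmax average. The helper `_two_tile_row_check` is a
# sketch for exposition only and is not used by the tests below.
def _two_tile_row_check(s, v, split):
    """s: (N,) scores, v: (N, D) values; returns softmax(s) @ v via two tiles."""
    m, l = -np.inf, 0.0
    O = np.zeros(v.shape[1])
    for lo, hi in ((0, split), (split, len(s))):
        m_new = max(m, float(np.max(s[lo:hi])))
        corr = np.exp(m - m_new)  # puts the OLD accumulators on the m_new scale
        O = O * corr + np.exp(s[lo:hi] - m_new) @ v[lo:hi]
        l = l * corr + float(np.sum(np.exp(s[lo:hi] - m_new)))
        m = m_new
    return O / l
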
def naive_attention(Q, K, V, causal=True):
    """Naive full-softmax attention for comparison."""
    B, H, N, D = Q.shape
    scale = 1.0 / np.sqrt(D)
    output = np.zeros_like(Q)

    for b in range(B):
        for h in range(H):
            q = Q[b, h]
            k = K[b, h]
            v = V[b, h]

            S = q @ k.T * scale

            if causal:
                mask = np.tril(np.ones((N, N), dtype=bool))
                S = np.where(mask, S, -np.inf)

            S_max = np.max(S, axis=1, keepdims=True)
            exp_S = np.exp(S - S_max)
            l = np.sum(exp_S, axis=1, keepdims=True)
            P = exp_S / l

            output[b, h] = P @ v

    return output

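# Standalone illustration of the -inf corner case handled by the
# need_correction mask in flash_attention_fwd: exp(-inf - finite) is exactly
# 0.0, but (-inf) - (-inf) is NaN, so the "no history yet and still fully
# masked" case must keep the correction at 1 explicitly. The helper name is
# ours, for demonstration only.
def _safe_correction(m_old, m_new):
    if np.isneginf(m_old) and np.isneginf(m_new):
        return 1.0  # nothing accumulated and nothing valid yet: leave O, l alone
    return float(np.exp(m_old - m_new))
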
if __name__ == "__main__":
|
|
import tracemalloc
|
|
|
|
print("=" * 60)
|
|
print("Test 1: B=1, H=1, N=256, D=64, tile_size=64, causal=True")
|
|
print("=" * 60)
|
|
|
|
np.random.seed(42)
|
|
B, H, N, D = 1, 1, 256, 64
|
|
tile_size = 64
|
|
|
|
Q = np.random.randn(B, H, N, D).astype(np.float64)
|
|
K = np.random.randn(B, H, N, D).astype(np.float64)
|
|
V = np.random.randn(B, H, N, D).astype(np.float64)
|
|
|
|
flash_out = flash_attention_fwd(Q, K, V, tile_size, causal=True)
|
|
naive_out = naive_attention(Q, K, V, causal=True)
|
|
|
|
rel_error = np.abs(flash_out - naive_out) / np.abs(naive_out)
|
|
max_rel_error = np.max(rel_error)
|
|
|
|
print(f"Flash attention output shape: {flash_out.shape}")
|
|
print(f"Naive attention output shape: {naive_out.shape}")
|
|
print(f"Max relative error: {max_rel_error:.6e}")
|
|
print(f"Relative error < 1e-4: {max_rel_error < 1e-4}")
|
|
|
|
assert max_rel_error < 1e-4, f"Relative error {max_rel_error} exceeds 1e-4"
|
|
print("PASSED!")
|
|
|
|
print()
|
|
print("=" * 60)
|
|
print("Test 2: B=2, H=8, N=4096, D=64, tile_size=128, causal=True")
|
|
print("=" * 60)
|
|
|
|
np.random.seed(42)
|
|
B, H, N, D = 2, 8, 4096, 64
|
|
tile_size = 128
|
|
|
|
Q = np.random.randn(B, H, N, D).astype(np.float64)
|
|
K = np.random.randn(B, H, N, D).astype(np.float64)
|
|
V = np.random.randn(B, H, N, D).astype(np.float64)
|
|
|
|
tracemalloc.start()
|
|
flash_out = flash_attention_fwd(Q, K, V, tile_size, causal=True)
|
|
current, peak = tracemalloc.get_traced_memory()
|
|
tracemalloc.stop()
|
|
|
|
print(f"Flash attention output shape: {flash_out.shape}")
|
|
print(f"Peak memory usage: {peak / 1024 / 1024:.2f} MB")
|
|
|
|
max_nn_size = N * N * 8
|
|
print(f"Size of (N, N) tensor would be: {max_nn_size / 1024 / 1024:.2f} MB")
|
|
print(f"Peak < size of (N,N) tensor: {peak < max_nn_size}")
|
|
|
|
print()
|
|
print("Memory analysis:")
|
|
print(f"- We process tiles of Q: ({tile_size}, D)")
|
|
print(f"- We process tiles of K,V: ({tile_size}, D)")
|
|
print(f"- We compute local scores S: ({tile_size}, {tile_size})")
|
|
print(f"- We NEVER allocate ({N}, {N}) which would be {N*N*8/1024/1024:.1f} MB")
|
|
print("- Maximum intermediate storage is O(tile_size * D + tile_size * tile_size)")
|
|
print(f"- With tile_size=128, D=64: max ~ {(128*64 + 128*128) * 8 / 1024:.1f} KB per tile")
|
|
print("PASSED - No (N,N) tensor allocation verified!")
|
|
|
|
print()
|
|
print("=" * 60)
|
|
print("Additional verification: correctness check on large input")
|
|
print("=" * 60)
|
|
|
|
np.random.seed(123)
|
|
B, H, N, D = 1, 1, 512, 32
|
|
tile_size = 64
|
|
|
|
Q = np.random.randn(B, H, N, D).astype(np.float64)
|
|
K = np.random.randn(B, H, N, D).astype(np.float64)
|
|
V = np.random.randn(B, H, N, D).astype(np.float64)
|
|
|
|
flash_out = flash_attention_fwd(Q, K, V, tile_size, causal=True)
|
|
naive_out = naive_attention(Q, K, V, causal=True)
|
|
|
|
rel_error = np.abs(flash_out - naive_out) / np.abs(naive_out)
|
|
max_rel_error = np.max(rel_error)
|
|
|
|
print(f"Max relative error on N=512: {max_rel_error:.6e}")
|
|
print(f"Relative error < 1e-4: {max_rel_error < 1e-4}")
|
|
assert max_rel_error < 1e-4, f"Relative error {max_rel_error} exceeds 1e-4"
|
|
print("PASSED!") |