Final Challenge: Flash Attention Backward Pass (Tiled, Recompute)

Why this challenge

The forward pass of Flash Attention has been implemented correctly by all models tested so far. The backward pass is the real test — 5-10x harder, with subtle interactions between tiling, recomputation, and the softmax gradient. PyTorch's own autograd gets this wrong without careful torch.compile handling. Three of the five major open-source Flash Attention ports (xformers early, vLLM's first kernel, and llama.cpp's first attempt) shipped with gradient bugs that passed forward correctness but failed backward.

This challenge:

  • Runs on your M4 MacBook Pro (~200-400 MB, not GB)
  • Takes ~5-10 seconds for the gradient check
  • Catches incorrect implementations that "look right" in the forward
  • Is directly relevant to LLM training (every training framework uses Flash Attention)
  • Tests the exact capability gap between your local model and frontier models

The key trap: the dsoftmax formula is dS = P * (dP - rowsum(P * dP)). The rowsum is over the KEY dimension, and P must be the recomputed softmax from the stored logsumexp. Getting ANY of these details wrong produces gradients that look plausible but fail finite-difference verification.

The prompt

Implement the BACKWARD pass of tiled (Flash) attention using online softmax
recomputation, from scratch in NumPy.

You already have a forward pass (include it or write a minimal one). The forward
pass MUST store only these intermediates per (B, H) head:
  - O:    (N, D)  — attention output
  - L:    (N,)    — logsumexp per query row: L_i = m_i + log(l_i)
    where m_i is the final row max and l_i is the final row sum of exps
  - Q, K, V: the original inputs (required for recomputation)

The forward MUST NOT store the full (N, N) attention matrix or softmax matrix.
It MAY process Q and K/V in tiles of size T and use the online softmax recurrence.

BACKWARD PASS REQUIREMENTS:

1. RECOMPUTATION:
   Given dO (upstream gradient, same shape as O), Q, K, V, O, and L, compute:
     dQ: (N, D) — gradient w.r.t. queries
     dK: (N, D) — gradient w.r.t. keys
     dV: (N, D) — gradient w.r.t. values
   
   The backward pass must NOT materialize the full (N, N) attention or
   softmax matrix either. It recomputes softmax probabilities P on-the-fly
   from the stored L and locally recomputed S = Q @ K^T / sqrt(D).

2. GRADIENT FORMULAS (for a single N×D head, no batching yet):
   Let scale = 1/sqrt(D). For each tile interaction between Q_tile and K_tile:
   
   a) Recompute local attention scores: S = Q_tile @ K_tile^T * scale
   b) Recompute local softmax: P = exp(S - L_query[:, None])
      (L_query are the logsumexp values for the query rows in this tile,
       broadcast against the key dimension)
   c) Compute local dV contribution: dV += P^T @ dO_tile
   d) Compute local dP: dP = dO_tile @ V_tile^T
   e) Compute local dS via the softmax gradient:
        dS = P * (dP - rowsum(P * dP))   where rowsum is over the KEY axis
      IMPORTANT: P * dP is elementwise. rowsum sums over the last axis (keys).
      The subtraction broadcasts: rowsum(P*dP) has shape (T_q, 1), subtracted
      from dP which is (T_q, T_kv), then multiplied elementwise by P.
   f) Compute local dQ contribution: dQ += dS @ K_tile * scale
   g) Compute local dK contribution: dK += dS^T @ Q_tile * scale
      (the scale from step (a) propagates into both dQ and dK, since S already
       includes it)

3. TILING:
   The backward pass should also use tiling to avoid materializing full matrices.
   Process Q in tiles, and for each Q tile, iterate over KV tiles to recompute
   P, dP, dS and accumulate dQ, dK, dV. This mirrors the forward pass structure.

4. BATCHING:
   Extend the above to handle (B, H, N, D) tensors. The L tensor becomes
   (B, H, N). The tile loops can be per-(b,h) or batched — either is acceptable.

5. NUMERICAL STABILITY:
   - The stored L values already incorporate the row max, so P = exp(S - L)
     is numerically stable (arguments ≤ 0).
   - The dsoftmax formula involves computing (dP - rowsum(P * dP)). If dP has
     large values, the subtraction can suffer cancellation; this is inherent to
     softmax and stays benign if the rowsum (and the tests) run in float64.
   - Ensure no division by zero or log of negative numbers.

6. CORRECTNESS VERIFICATION:
   Compare your backward pass output against numerical gradients (central
   finite differences) for a small test case (N=64, D=32, tile_size=16).
   Also compare against the naive full-materialized backward (which computes
   the full attention matrix).

Deliver:
- Function flash_attention_fwd(Q, K, V, tile_size, causal=True)
  → returns O (B,H,N,D) and cache dict with {'O': O, 'L': L, 'Q': Q, 'K': K, 'V': V}
- Function flash_attention_bwd(dO, cache, tile_size, causal=True)
  → returns dQ, dK, dV, each (B,H,N,D)
- Gradient check test: (B=1, H=1, N=64, D=32, T=16, causal=True)
  → compare bwd output vs central finite differences, assert relative error < 1e-5
- Correctness test: (B=2, H=4, N=256, D=64, T=64, causal=True)
  → compare bwd output vs naive full-materialized backward, assert rel error < 1e-4
- Memory test: (B=1, H=1, N=4096, D=64, T=128, causal=True)
  → verify peak memory stays well below N² × 8 bytes, the size of one (N, N)
    float64 matrix (use tracemalloc)

Use only NumPy. No PyTorch, JAX, TensorFlow, or autograd.
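
For the correctness comparison in requirement 6, the naive full-materialized reference can be a dozen lines of NumPy. The sketch below is one way to write it (the helper name naive_attention_fwd_bwd is illustrative, not part of the required deliverables); it is O(N²) in memory, which is exactly what the tiled backward must avoid, but it is short enough to trust.

import numpy as np

def naive_attention_fwd_bwd(Q, K, V, dO, causal=True):
    """Full-materialized reference backward (O(N^2) memory), for testing only."""
    B, H, N, D = Q.shape
    scale = 1.0 / np.sqrt(D)
    S = np.einsum('bhqd,bhkd->bhqk', Q, K) * scale          # (B, H, N, N)
    if causal:
        mask = np.triu(np.ones((N, N), dtype=bool), k=1)    # True above the diagonal
        S = np.where(mask, -np.inf, S)
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)                       # softmax over keys
    O = np.einsum('bhqk,bhkd->bhqd', P, V)

    dV = np.einsum('bhqk,bhqd->bhkd', P, dO)                 # P^T @ dO
    dP = np.einsum('bhqd,bhkd->bhqk', dO, V)                 # dO @ V^T
    dS = P * (dP - (P * dP).sum(axis=-1, keepdims=True))     # rowsum over keys
    dQ = np.einsum('bhqk,bhkd->bhqd', dS, K) * scale
    dK = np.einsum('bhqk,bhqd->bhkd', dS, Q) * scale
    return O, dQ, dK, dV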

How the trap works

The dsoftmax formula in Step 2e is where most implementations go wrong:

# CORRECT (what you should write):
dS = P * (dP - (P * dP).sum(axis=-1, keepdims=True))

# WRONG (very common — wrong axis):
dS = P * (dP - (P * dP).sum(axis=-2, keepdims=True))

# WRONG (forgets to multiply by P):
dS = dP - (P * dP).sum(axis=-1, keepdims=True)

# WRONG (divides instead of subtracts):
dS = P * dP / (P * dP).sum(axis=-1, keepdims=True)

# WRONG (uses dO instead of dP):
dS = P * (dP - (P * dO).sum(axis=-1, keepdims=True))

All of these produce dQ, dK, dV values that "look like gradients" — they have reasonable magnitudes and shapes — but fail finite-difference verification.
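
A quick way to convince yourself which variant is correct, independent of any attention code, is to check the closed-form dS against the explicit softmax Jacobian diag(p) - p p^T on a single random row; the correct formula is exactly that Jacobian applied to dP. A minimal sketch (variable names are illustrative):

import numpy as np

rng = np.random.default_rng(0)
s = rng.standard_normal(8)        # one row of attention scores
dP_row = rng.standard_normal(8)   # one row of upstream dP

p = np.exp(s - s.max())
p /= p.sum()                      # softmax of the row

dS_formula = p * (dP_row - (p * dP_row).sum())   # the CORRECT variant above
J = np.diag(p) - np.outer(p, p)                  # explicit softmax Jacobian
dS_jacobian = J @ dP_row

assert np.allclose(dS_formula, dS_jacobian)      # the wrong variants break this identity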

Additional trap: the stored L format

The forward pass stores L = m + log(l). To recompute P:

P = exp(S - L[:, None])  # S is (T_q, T_kv), L is (T_q,)

In the online recurrence, l is the running sum of exp(S - m), so log(l) alone is missing the row max m. If the forward stores l (or log(l)) instead of L = m + log(l), the recomputed P = exp(S - stored_value) is off by a factor of exp(m) per row, which cascades into incorrect dV, dP, dS, and ultimately wrong dQ, dK, dV. The gradient check catches this immediately.
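
The difference is easy to see numerically. A minimal sketch (variable names illustrative) showing that exp(S - L) with L = m + log(l) reproduces the stable softmax, while exp(S - log(l)) does not:

import numpy as np

rng = np.random.default_rng(0)
S_row = rng.standard_normal(16) * 5.0   # one row of (scaled) attention scores

m = S_row.max()
l = np.exp(S_row - m).sum()             # what the online recurrence accumulates
L = m + np.log(l)                       # what the forward must store

P_correct = np.exp(S_row - L)
P_wrong = np.exp(S_row - np.log(l))     # forgot m: off by a factor of exp(m)

reference = np.exp(S_row - m) / l       # plain numerically stable softmax
assert np.allclose(P_correct, reference)
assert not np.allclose(P_wrong, reference)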

Reference implementation skeleton

import numpy as np

def flash_attention_fwd(Q, K, V, tile_size, causal=True):
    B, H, N, D = Q.shape
    scale = 1.0 / np.sqrt(D)
    T = tile_size
    
    O = np.zeros_like(Q)
    L = np.full((B, H, N), -np.inf)
    
    for b in range(B):
        for h in range(H):
            # ... standard tiled forward with online softmax ...
            # At the end of processing all KV tiles for a Q tile:
            #   O[b, h, q_s:q_e, :] = O_acc / l[:, None]
            #   L[b, h, q_s:q_e] = m + np.log(l)
    
    cache = {'O': O, 'L': L, 'Q': Q, 'K': K, 'V': V}
    return O, cache


def flash_attention_bwd(dO, cache, tile_size, causal=True):
    O = cache['O']
    L = cache['L']
    Q = cache['Q']
    K = cache['K']
    V = cache['V']
    
    B, H, N, D = Q.shape
    scale = 1.0 / np.sqrt(D)
    T = tile_size
    
    dQ = np.zeros_like(Q)
    dK = np.zeros_like(K)
    dV = np.zeros_like(V)
    
    for b in range(B):
        for h in range(H):
            # ... tiled backward pass ...
            # For each Q_tile (q_s:q_e) × KV_tile (k_s:k_e):
            #   S = Q_tile @ K_tile^T * scale
            #   P = exp(S - L_query[:, None])
            #   dV_tile += P^T @ dO_tile
            #   dP = dO_tile @ V_tile^T
            #   dS = P * (dP - (P * dP).sum(axis=-1, keepdims=True))
            #   dQ_tile += dS @ K_tile * scale
            #   dK_tile += dS^T @ Q_tile * scale
    
    return dQ, dK, dV
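
The per-(b, h) forward loop that the skeleton elides might look like the sketch below; this is one possible structure rather than the reference answer, and the helper name _fwd_one_head is illustrative.

def _fwd_one_head(Qh, Kh, Vh, T, scale, causal):
    """Tiled forward for one (N, D) head using the online softmax recurrence."""
    N, D = Qh.shape
    Oh = np.zeros((N, D))
    Lh = np.zeros(N)
    for q_s in range(0, N, T):
        q_e = min(q_s + T, N)
        Qt = Qh[q_s:q_e]
        m = np.full(q_e - q_s, -np.inf)        # running row max
        l = np.zeros(q_e - q_s)                # running sum of exp(S - m)
        O_acc = np.zeros((q_e - q_s, D))
        for k_s in range(0, N, T):
            k_e = min(k_s + T, N)
            if causal and k_s > q_e - 1:
                break                          # this and all later KV tiles are fully masked
            S = Qt @ Kh[k_s:k_e].T * scale
            if causal:
                q_idx = (q_s + np.arange(q_e - q_s))[:, None]
                k_idx = (k_s + np.arange(k_e - k_s))[None, :]
                S = np.where(k_idx > q_idx, -np.inf, S)
            m_new = np.maximum(m, S.max(axis=1))
            alpha = np.exp(m - m_new)          # rescale the previous accumulators
            P = np.exp(S - m_new[:, None])
            l = l * alpha + P.sum(axis=1)
            O_acc = O_acc * alpha[:, None] + P @ Vh[k_s:k_e]
            m = m_new
        Oh[q_s:q_e] = O_acc / l[:, None]
        Lh[q_s:q_e] = m + np.log(l)            # logsumexp stored for the backward
    return Oh, Lh

The forward skeleton's loop body then reduces to O[b, h], L[b, h] = _fwd_one_head(Q[b, h], K[b, h], V[b, h], T, scale, causal).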

Test code that catches the bugs

def test_gradient_check():
    """Compare backward against central finite differences."""
    np.random.seed(42)
    B, H, N, D = 1, 1, 64, 32
    T = 16
    
    Q = np.random.randn(B, H, N, D).astype(np.float64)
    K = np.random.randn(B, H, N, D).astype(np.float64)
    V = np.random.randn(B, H, N, D).astype(np.float64)
    dO = np.random.randn(B, H, N, D).astype(np.float64)
    
    # Forward + backward
    O, cache = flash_attention_fwd(Q, K, V, T, causal=True)
    dQ, dK, dV = flash_attention_bwd(dO, cache, T, causal=True)
    
    # Finite difference check for dV (dQ and dK are more expensive)
    eps = 1e-5
    dV_fd = np.zeros_like(V)
    for b in range(B):
        for h in range(H):
            for i in range(N):
                for j in range(D):
                    V_plus = V.copy()
                    V_minus = V.copy()
                    V_plus[b, h, i, j] += eps
                    V_minus[b, h, i, j] -= eps
                    O_plus, _ = flash_attention_fwd(Q, K, V_plus, T, causal=True)
                    O_minus, _ = flash_attention_fwd(Q, K, V_minus, T, causal=True)
                    loss_plus = (dO * O_plus).sum()
                    loss_minus = (dO * O_minus).sum()
                    dV_fd[b, h, i, j] = (loss_plus - loss_minus) / (2 * eps)
    
    rel_err = np.abs(dV - dV_fd).max() / np.abs(dV_fd).max()
    print(f"dV relative error vs finite diff: {rel_err:.2e}")
    assert rel_err < 1e-5, f"dV gradient check FAILED: {rel_err:.2e}"
    
    # Spot-check dQ and dK at a few random positions
    for name, grad, tensor in [('dQ', dQ, Q), ('dK', dK, K)]:
        b, h, i, j = np.random.randint(0, B), np.random.randint(0, H), \
                      np.random.randint(0, N), np.random.randint(0, D)
        tensor_plus = tensor.copy()
        tensor_minus = tensor.copy()
        tensor_plus[b, h, i, j] += eps
        tensor_minus[b, h, i, j] -= eps
        O_plus, _ = flash_attention_fwd(
            Q if name != 'dQ' else tensor_plus,
            K if name != 'dK' else tensor_plus, V, T, causal=True
        )
        O_minus, _ = flash_attention_fwd(
            Q if name != 'dQ' else tensor_minus,
            K if name != 'dK' else tensor_minus, V, T, causal=True
        )
        loss_plus = (dO * O_plus).sum()
        loss_minus = (dO * O_minus).sum()
        fd_val = (loss_plus - loss_minus) / (2 * eps)
        rel = abs(grad[b, h, i, j] - fd_val) / (abs(fd_val) + 1e-10)
        print(f"{name}[{b},{h},{i},{j}] rel error: {rel:.2e}")
        assert rel < 1e-5, f"{name} gradient check FAILED at [{b},{h},{i},{j}]: {rel:.2e}"
    
    print("Gradient check PASSED\n")

Why this will separate models

Aspect              | What good models do                             | What weak models do
--------------------|-------------------------------------------------|--------------------------------------------------------------
dsoftmax axis       | Sum over the last axis (keys)                   | Sum over the wrong axis, or forget keepdims
dsoftmax formula    | P * (dP - rowsum(P*dP))                         | Forget to multiply by P, or use dO instead of dP
Stored intermediate | Store L = m + log(l) for stable recomputation   | Store the wrong intermediate, causing P recomputation errors
Tile accumulation   | Accumulate dQ, dK, dV ACROSS tiles              | Overwrite instead of accumulating
Causal mask in bwd  | Skip entirely masked Q tile × KV tile pairs     | Include masked tiles → incorrect dK from -inf scores
Memory              | Never materialize (N, N) in the backward either | Allocate an (N, N) dS array
Gradient check      | Passes at 1e-5                                  | Fails; the gradients "look right" but are wrong

Grading rubric

Check                                           | Weight | What it catches
------------------------------------------------|--------|-------------------------------------------
dV matches finite differences at 1e-5           | 30%    | Basic backward correctness
dQ spot-check matches finite diff at 1e-5       | 25%    | Correct dS and dQ accumulation
dK spot-check matches finite diff at 1e-5       | 25%    | Correct dS transpose and dK accumulation
N=4096 test: peak memory well below N² × 8 B    | 10%    | No full matrix materialized in backward
Causal masking handled correctly in bwd         | 10%    | Fully masked tile pairs are skipped