"""
Numerically Stable Layer Normalization Backward Pass — From Scratch in NumPy
Forward:
μ = mean(x, axis=-1) # (B, T)
σ² = var(x, axis=-1) # (B, T)
x_hat = (x - μ) / sqrt(σ² + ε) # (B, T, D)
y = γ · x_hat + β # (B, T, D)
Backward (given ∂L/∂y ≡ dy of shape (B, T, D)):
dγ = sum(dy · x_hat, axis=(0,1)) # (D,)
dβ = sum(dy, axis=(0,1)) # (D,)
dx = (1/N) · (σ²+ε)^(-1/2) · [
N·dy
- sum(dy, axis=-1)
- x_hat · sum(dy·x_hat, axis=-1)
] # (B, T, D)
where N = D (feature dimension).
Derivation sketch (see comments in code for full detail):
The normalization map x ↦ x_hat is a projection onto the unit sphere
(per position). Its Jacobian has the form:
∂x_hat_i / ∂x_j = (1/σ) · (δ_ij - 1/N - x_hat_i · x_hat_j / N)
Contracting with dy gives the compact formula above.
Numerical stability notes:
1. Variance computation: use the two-pass (Welford-style) algorithm
instead of E[x²] - E[x]² to avoid catastrophic cancellation.
2. The backward formula reuses x_hat (already computed in forward),
avoiding recomputing (x - μ) / σ.
3. All divisions go through σ = sqrt(σ² + ε) with ε > 0, so no
division-by-zero.
4. The term (σ²+ε)^(-1/2) is precomputed once and broadcast.
"""
import numpy as np
# ---------------------------------------------------------------------------
# Forward pass
# ---------------------------------------------------------------------------
def layer_norm_forward(x, gamma, beta, eps=1e-5):
    """
    Layer normalization forward pass.

    Parameters
    ----------
    x     : (B, T, D) — input
    gamma : (D,)      — scale
    beta  : (D,)      — shift
    eps   : float     — numerical stability constant

    Returns
    -------
    y     : (B, T, D) — output
    cache : dict      — intermediates for backward
    """
    B, T, D = x.shape

    # --- mean (B, T) ---
    mu = x.mean(axis=-1)                                  # (B, T)

    # --- variance via two-pass (numerically stable) ---
    # Pass 1: centered values
    x_centered = x - mu[..., np.newaxis]                  # (B, T, D)
    # Pass 2: variance of centered values
    var = np.mean(x_centered ** 2, axis=-1)               # (B, T)

    # --- normalization ---
    std_inv = 1.0 / np.sqrt(var + eps)                    # (B, T)
    x_hat = x_centered * std_inv[..., np.newaxis]         # (B, T, D)

    # --- affine ---
    y = gamma[np.newaxis, np.newaxis, :] * x_hat + beta[np.newaxis, np.newaxis, :]

    # Cache only what the backward pass needs — minimal memory footprint.
    # The backward formula uses x_hat, std_inv, and gamma. Nothing else.
    cache = {
        "x_hat": x_hat,      # (B, T, D) — normalized input
        "std_inv": std_inv,  # (B, T)    — 1/sqrt(var + eps)
        "gamma": gamma,      # (D,)      — scale parameter
        "D": D,              # scalar    — feature dimension
    }
    return y, cache
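# ---------------------------------------------------------------------------
# Illustration: why the two-pass variance matters
# ---------------------------------------------------------------------------
# A minimal, self-contained sketch (not used by the rest of this file) of the
# catastrophic cancellation described in stability note 1 of the module
# docstring: for data with a large mean and a small variance, the one-pass
# formula E[x²] - E[x]² subtracts two huge, nearly equal numbers, while the
# two-pass form used in layer_norm_forward stays accurate. The helper name and
# the constants below are illustrative choices, not part of the pipeline.
def _variance_cancellation_demo(offset=1e8, n=1024, seed=0):
    rng = np.random.default_rng(seed)
    x = offset + rng.standard_normal(n)         # true variance ≈ 1, mean ≈ 1e8
    naive = np.mean(x ** 2) - np.mean(x) ** 2   # one-pass: cancels catastrophically
    two_pass = np.mean((x - x.mean()) ** 2)     # center first, then square
    return naive, two_pass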
# ---------------------------------------------------------------------------
# Backward pass — numerically stable, no redundant recomputation
# ---------------------------------------------------------------------------
def layer_norm_backward(dy, cache):
    """
    Layer normalization backward pass.

    Given dy = ∂L/∂y of shape (B, T, D), compute gradients w.r.t.
    x, gamma, and beta.

    The key insight for numerical stability is to express dx entirely in
    terms of quantities already cached from the forward pass (x_hat,
    std_inv), avoiding any recomputation of (x - μ) or sqrt(var + ε).

    Parameters
    ----------
    dy    : (B, T, D) — upstream gradient
    cache : dict      — from forward pass

    Returns
    -------
    dx    : (B, T, D)
    dgamma: (D,)
    dbeta : (D,)
    """
    x_hat = cache["x_hat"]      # (B, T, D)
    std_inv = cache["std_inv"]  # (B, T)
    gamma = cache["gamma"]      # (D,)
    D = cache["D"]              # scalar
    B, T, _ = dy.shape

    # --- gradient w.r.t. gamma and beta (trivial) ---
    dgamma = np.sum(dy * x_hat, axis=(0, 1))   # (D,)
    dbeta = np.sum(dy, axis=(0, 1))            # (D,)

    # --- gradient w.r.t. x (the non-trivial part) ---
    #
    # Full derivation:
    #   y = γ · x_hat + β
    #   ∂L/∂x_hat = γ · dy
    #
    #   x_hat_i = (x_i - μ) / σ,  where σ = sqrt(var + ε)
    #
    #   ∂x_hat_i / ∂x_j = (1/σ) · (δ_ij - 1/D - x_hat_i · x_hat_j / D)
    #
    # Therefore:
    #   ∂L/∂x_j = Σ_i (∂L/∂x_hat_i) · ∂x_hat_i / ∂x_j
    #           = (1/σ) · [ Σ_i (γ·dy)_i · (δ_ij - 1/D - x_hat_i·x_hat_j/D) ]
    #           = (1/σ) · [ (γ·dy)_j - (1/D)·Σ_i(γ·dy)_i - x_hat_j·(1/D)·Σ_i(γ·dy)_i·x_hat_i ]
    #
    # Let g = γ · dy (elementwise)
    #   dx = (1/σ) · [ g - mean(g) - x_hat · mean(g · x_hat) ]
    #
    # This is the compact, numerically stable form. All terms are O(1) per
    # element after the two reductions (mean over D).
    g = gamma[np.newaxis, np.newaxis, :] * dy              # (B, T, D)

    # Two reductions over the feature dimension D
    g_mean = g.mean(axis=-1, keepdims=True)                # (B, T, 1)
    gx_mean = (g * x_hat).mean(axis=-1, keepdims=True)     # (B, T, 1)

    # Combine — std_inv broadcasts from (B, T) to (B, T, D)
    dx = std_inv[..., np.newaxis] * (g - g_mean - x_hat * gx_mean)
    return dx, dgamma, dbeta
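# ---------------------------------------------------------------------------
# Illustration: compact dx formula vs. the explicit Jacobian
# ---------------------------------------------------------------------------
# A hedged sanity-check sketch (not called by main) for the derivation in the
# comments above: it builds the full D×D Jacobian
#     ∂x_hat_i/∂x_j = (1/σ)(δ_ij - 1/D - x_hat_i·x_hat_j/D)
# for a single (b, t) position, contracts it with g = γ·dy, and compares the
# result to the vectorized formula used in layer_norm_backward. The helper
# name and argument layout here are illustrative assumptions.
def _check_row_jacobian(dy_row, x_hat_row, std_inv_bt, gamma):
    D = x_hat_row.shape[0]
    g = gamma * dy_row                    # ∂L/∂x_hat for this single row
    # Explicit Jacobian of the normalization map at this position
    J = std_inv_bt * (np.eye(D) - 1.0 / D - np.outer(x_hat_row, x_hat_row) / D)
    dx_explicit = J.T @ g                 # ∂L/∂x_j = Σ_i g_i · ∂x_hat_i/∂x_j
    # Compact form: (1/σ)·[ g - mean(g) - x_hat·mean(g·x_hat) ]
    dx_compact = std_inv_bt * (g - g.mean() - x_hat_row * np.mean(g * x_hat_row))
    return np.max(np.abs(dx_explicit - dx_compact))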
# ---------------------------------------------------------------------------
# Gradient check — finite differences
# ---------------------------------------------------------------------------
def numerical_gradient(f, param, delta=1e-5, **fixed_kwargs):
    """
    Compute numerical gradient of scalar function f w.r.t. param using
    central finite differences.

    f should take param as its first positional argument and return a scalar.
    """
    grad = np.zeros_like(param)
    # ravel() returns views here (param is contiguous), so writing to
    # flat_param perturbs param in place and f(param) sees the new value.
    flat_param = param.ravel()
    flat_grad = grad.ravel()
    for i in range(len(flat_param)):
        old_val = flat_param[i]
        flat_param[i] = old_val + delta
        f_plus = f(param, **fixed_kwargs)
        flat_param[i] = old_val - delta
        f_minus = f(param, **fixed_kwargs)
        flat_grad[i] = (f_plus - f_minus) / (2 * delta)
        flat_param[i] = old_val   # restore the original value
    return grad
def gradient_check(gamma, beta, x, eps=1e-5, delta=1e-5):
    """
    Verify analytical gradients against finite-difference numerical gradients.
    Returns a dict with relative errors for each parameter.
    """
    # Random upstream gradient
    dy = np.random.randn(*x.shape)

    # --- Analytical gradients ---
    y, cache = layer_norm_forward(x, gamma, beta, eps=eps)
    dx_analytical, dgamma_analytical, dbeta_analytical = layer_norm_backward(dy, cache)

    # --- Numerical gradients ---
    # Scalar surrogate loss L = sum(y * dy), so ∂L/∂y = dy and the analytical
    # gradients above are directly comparable to the finite differences below.
    def loss_x(x_arg):
        y_arg, _ = layer_norm_forward(x_arg, gamma, beta, eps=eps)
        return np.sum(y_arg * dy)

    def loss_gamma(gamma_arg):
        y_arg, _ = layer_norm_forward(x, gamma_arg, beta, eps=eps)
        return np.sum(y_arg * dy)

    def loss_beta(beta_arg):
        y_arg, _ = layer_norm_forward(x, gamma, beta_arg, eps=eps)
        return np.sum(y_arg * dy)

    dx_numerical = numerical_gradient(loss_x, x, delta=delta)
    dgamma_numerical = numerical_gradient(loss_gamma, gamma, delta=delta)
    dbeta_numerical = numerical_gradient(loss_beta, beta, delta=delta)

    # --- Relative errors ---
    def rel_error(a, b):
        denom = np.max(np.abs(a) + np.abs(b))
        if denom < 1e-12:
            return 0.0
        return np.max(np.abs(a - b)) / denom

    errors = {
        "dx": rel_error(dx_analytical, dx_numerical),
        "dgamma": rel_error(dgamma_analytical, dgamma_numerical),
        "dbeta": rel_error(dbeta_analytical, dbeta_numerical),
    }
    return errors
# ---------------------------------------------------------------------------
# Complexity analysis
# ---------------------------------------------------------------------------
def print_complexity_analysis(B, T, D):
    """
    Time and memory complexity of layer norm forward + backward.

    Notation: N = B·T·D (total elements), D = feature dim.

    FORWARD:
        ┌────────────────────────────┬────────┬──────────────────┐
        │ Operation                  │ FLOPs  │ Memory (extra)   │
        ├────────────────────────────┼────────┼──────────────────┤
        │ mean(x, axis=-1)           │ N      │ B·T              │
        │ x_centered = x - μ         │ N      │ B·T·D            │
        │ var = mean(x_centered²)    │ 2N     │ B·T              │
        │ std_inv = 1/sqrt(var+ε)    │ B·T    │ B·T              │
        │ x_hat = x_centered * σ⁻¹   │ N      │ B·T·D            │
        │ y = γ·x_hat + β            │ 2N     │ B·T·D (output)   │
        ├────────────────────────────┼────────┼──────────────────┤
        │ Total                      │ ~6N    │ ~3·B·T·D         │
        └────────────────────────────┴────────┴──────────────────┘

    BACKWARD:
        ┌────────────────────────────┬────────┬──────────────────┐
        │ Operation                  │ FLOPs  │ Memory (extra)   │
        ├────────────────────────────┼────────┼──────────────────┤
        │ g = γ · dy                 │ N      │ B·T·D            │
        │ g_mean = mean(g, axis=-1)  │ N      │ B·T              │
        │ gx_mean = mean(g·x_hat)    │ 2N     │ B·T              │
        │ dx = σ⁻¹·(g - g_mean - …)  │ 3N     │ B·T·D            │
        │ dgamma = sum(dy·x_hat)     │ 2N     │ D                │
        │ dbeta = sum(dy)            │ N      │ D                │
        ├────────────────────────────┼────────┼──────────────────┤
        │ Total                      │ ~9N    │ ~B·T·D           │
        └────────────────────────────┴────────┴──────────────────┘

    OVERALL:
        Time:   O(N) = O(B·T·D) — linear in total elements
        Memory: O(B·T·D)        — dominated by cached x_hat

    KEY OBSERVATIONS:
        • The backward pass is ~1.5× the forward pass in FLOPs.
        • Memory is dominated by caching x_hat (B·T·D floats).
        • The two-pass variance is O(N) extra FLOPs but essential for
          numerical stability — the naive E[x²] - E[x]² formula can lose
          nearly all significant digits when var ≪ mean².
    """
    N = B * T * D
    print(f"Complexity Analysis for B={B}, T={T}, D={D} (N={N:,} total elements)")
    print(f"  Forward FLOPs:    ~{6 * N:,}")
    print(f"  Backward FLOPs:   ~{9 * N:,}")
    print(f"  Total FLOPs:      ~{15 * N:,}")
    print(f"  Extra memory:     ~{3 * N * 4 / 1024 / 1024:.1f} MB forward working set (float32)")
    print(f"  Time complexity:  O(B·T·D)")
    print(f"  Space complexity: O(B·T·D)")
# ---------------------------------------------------------------------------
# GPU kernel fusion discussion
# ---------------------------------------------------------------------------
GPU_FUSION_DISCUSSION = """
GPU KERNEL FUSION FOR LAYER NORM
================================

1. FORWARD KERNEL (single kernel, no intermediate global memory writes):

   Thread block: one block per (b, t) position, D threads per block.
   Each thread handles one feature dimension d.

   Pseudocode (CUDA-style):
   ```
   __global__ void layer_norm_fwd(const float* __restrict__ x,
                                  const float* __restrict__ gamma,
                                  const float* __restrict__ beta,
                                  float* __restrict__ y,
                                  int B, int T, int D, float eps) {
       int bt = blockIdx.x;   // flattened (b, t)
       int d  = threadIdx.x;  // feature dimension

       // --- Parallel reduce: mean ---
       float sum = 0.0f;
       for (int i = d; i < D; i += blockDim.x)
           sum += x[bt * D + i];
       float mu = blockReduceSum(sum) / D;

       // --- Parallel reduce: variance (two-pass) ---
       float sum2 = 0.0f;
       for (int i = d; i < D; i += blockDim.x) {
           float diff = x[bt * D + i] - mu;
           sum2 += diff * diff;
       }
       float var = blockReduceSum(sum2) / D;
       float std_inv = rsqrtf(var + eps);  // hardware reciprocal sqrt

       // --- Write output (shown for blockDim.x >= D; otherwise loop as above) ---
       float x_hat = (x[bt * D + d] - mu) * std_inv;
       y[bt * D + d] = gamma[d] * x_hat + beta[d];

       // --- Cache x_hat for backward (write to pre-allocated buffer) ---
       // This is the ONLY intermediate that must survive to backward.
       // All other intermediates (mu, var, std_inv) are register-local.
   }
   ```

   Key fusion benefits:
   • x is read ONCE from global memory (not twice as in separate mean/var).
   • mu, var, std_inv live in registers/shared memory — zero global writes.
   • x_hat is written once to the cache buffer.
   • rsqrtf is a single hardware instruction on NVIDIA GPUs.

2. BACKWARD KERNEL (single kernel):

   Thread block: one block per (b, t), D threads per block.
   ```
   __global__ void layer_norm_bwd(const float* __restrict__ dy,
                                  const float* __restrict__ x_hat,
                                  const float* __restrict__ gamma,
                                  const float* __restrict__ std_inv,  // (B*T,) from forward
                                  float* __restrict__ dx,
                                  float* __restrict__ dgamma,         // (D,)
                                  float* __restrict__ dbeta,          // (D,)
                                  int D) {
       int bt = blockIdx.x;
       int d  = threadIdx.x;
       float g = gamma[d] * dy[bt * D + d];

       // --- Parallel reduce: mean(g) and mean(g * x_hat) ---
       float g_sum = 0.0f, gx_sum = 0.0f;
       for (int i = d; i < D; i += blockDim.x) {
           g_sum  += gamma[i] * dy[bt * D + i];
           gx_sum += gamma[i] * dy[bt * D + i] * x_hat[bt * D + i];
       }
       float g_mean  = blockReduceSum(g_sum)  / D;
       float gx_mean = blockReduceSum(gx_sum) / D;

       // --- Compute dx ---
       float x_hat_d = x_hat[bt * D + d];
       dx[bt * D + d] = std_inv[bt] * (g - g_mean - x_hat_d * gx_mean);

       // --- Atomic adds for dgamma, dbeta (accumulated over all (b, t)) ---
       float dy_d = dy[bt * D + d];
       atomicAdd(&dgamma[d], dy_d * x_hat_d);
       atomicAdd(&dbeta[d], dy_d);
   }
   ```

   Key fusion benefits:
   • dy and x_hat are read ONCE each.
   • The two reductions (g_mean, gx_mean) share the same loop — one pass.
   • dx is computed and written in the same thread that computed g.
   • dgamma/dbeta use atomicAdd (D is typically small enough that contention
     is manageable; alternatively, use a two-phase reduce).

3. MEMORY TRAFFIC COMPARISON:

   Naive (separate kernels):
     Forward:  read x (1×), write mu (1×), read x+mu (2×), write var (1×),
               read x+mu+var (3×), write x_hat (1×), read x_hat+γ+β (3×),
               write y (1×)  →  ~12 global memory accesses per element
     Backward: similar explosion

   Fused:
     Forward:  read x (1×), read γ+β (1×), write x_hat (1×), write y (1×)
               →  4 global memory accesses per element
     Backward: read dy (1×), read x_hat (1×), read γ (1×), write dx (1×),
               atomic dgamma+dbeta (1×)  →  5 global memory accesses per element

   The fused approach is ~2-3× faster in practice because memory bandwidth
   is the bottleneck for layer norm (it's an O(N) algorithm with O(N) memory).
   (A small helper after this string turns these per-element counts into rough
   byte estimates for a given shape.)

4. SHARED MEMORY OPTIMIZATION:

   For small D (≤ 1024), load the entire (b, t) slice into shared memory:
   ```
   __shared__ float s_x[1024], s_dy[1024], s_xhat[1024];
   // Cooperative load
   s_x[d] = x[bt * D + d];
   __syncthreads();
   // All subsequent ops use shared memory (L1-equivalent speed)
   ```
   This cuts global memory reads from 3 to 1 per kernel launch.

5. TENSOR CORE / WARP LEVEL:

   Layer norm doesn't benefit from tensor cores (no GEMM), but warp-level
   primitives (__shfl_down_sync) can replace shared memory for the parallel
   reductions when D ≤ 32, eliminating synchronization overhead entirely.
"""
# ---------------------------------------------------------------------------
# Main — run gradient check and analysis
# ---------------------------------------------------------------------------
def main():
    np.random.seed(42)
    B, T, D = 4, 8, 16
    x = np.random.randn(B, T, D).astype(np.float64)
    gamma = np.random.randn(D).astype(np.float64)
    beta = np.random.randn(D).astype(np.float64)

    print("=" * 70)
    print("LAYER NORMALIZATION — BACKWARD PASS (NUMPY, FROM SCRATCH)")
    print("=" * 70)

    # --- Forward ---
    y, cache = layer_norm_forward(x, gamma, beta)
    print(f"\nForward: x({x.shape}) → y({y.shape})")
    print(f"  y[0,0,:4] = {y[0, 0, :4]}")

    # --- Backward ---
    dy = np.random.randn(B, T, D).astype(np.float64)
    dx, dgamma, dbeta = layer_norm_backward(dy, cache)
    print(f"\nBackward: dy({dy.shape}) → dx({dx.shape}), dγ({dgamma.shape}), dβ({dbeta.shape})")

    # --- Gradient check ---
    print("\n" + "-" * 70)
    print("GRADIENT CHECK (central finite differences, δ=1e-5)")
    print("-" * 70)
    errors = gradient_check(gamma, beta, x)
    for name, err in errors.items():
        status = "✓ PASS" if err < 1e-6 else "✗ FAIL"
        print(f"  {name:8s} relative error: {err:.2e}   {status}")

    # --- Complexity ---
    print("\n" + "-" * 70)
    print("COMPLEXITY ANALYSIS")
    print("-" * 70)
    print_complexity_analysis(B, T, D)

    # --- GPU fusion discussion ---
    print("\n" + "-" * 70)
    print("GPU KERNEL FUSION STRATEGY")
    print("-" * 70)
    print(GPU_FUSION_DISCUSSION)

    # --- Numerical stability demo ---
    print("\n" + "-" * 70)
    print("NUMERICAL STABILITY DEMONSTRATION")
    print("-" * 70)
    print("""
Where instability can occur and how we handle it:

1. VARIANCE COMPUTATION
   Problem:  E[x²] - E[x]² loses precision when var ≪ mean²
             (catastrophic cancellation). Example: x = [1e8, 1e8+1, 1e8+2]
   Solution: TWO-PASS algorithm — center first, then compute variance.
             This is what we do: x_centered = x - μ, then var = mean(x_centered²).

2. DIVISION BY ZERO
   Problem:  var could be exactly 0 (all features identical).
   Solution: add ε (default 1e-5) inside sqrt: sqrt(var + ε).
             This is standard and numerically safe.

3. BACKWARD OVERFLOW
   Problem:  if std_inv is very large (var ≈ 0), dx could overflow.
   Solution: ε prevents std_inv from exceeding 1/sqrt(ε) ≈ 316 for ε = 1e-5.
             With float32, this is well within range.

4. ACCUMULATION ERROR IN REDUCTIONS
   Problem:  summing D values can accumulate floating-point error.
   Solution: NumPy uses pairwise summation (O(log D) error growth
             instead of O(D)). For extreme cases, use Kahan summation
             (a small sketch follows this function).

5. RECOMPUTATION AVOIDANCE
   Problem:  naive backward recomputes (x - μ) and sqrt(var + ε).
   Solution: cache x_hat and std_inv from forward. The backward
             formula uses ONLY these cached values + dy — zero recomputation.
""")
if __name__ == "__main__":
main()