"""
Numerically stable forward & backward pass for Layer Normalization.
Forward:
y = gamma * (x - mean) / sqrt(var + eps) + beta
where mean and var are computed over the last dimension (D) independently
for each (b, t) position.
Reference derivation
--------------------
Let:
μ = (1/D) Σ_d x_d (mean over D)
σ² = (1/D) Σ_d (x_d - μ)² (variance over D)
σ̂ = sqrt(σ² + ε) (std with epsilon)
x̂ = (x - μ) / σ̂ (normalized x)
y = γ · x̂ + β (output)
Backward (upstream gradient dy arrives):
dβ = Σ_{b,t} dy (sum over batch & time)
dγ = Σ_{b,t} dy · x̂ (sum over batch & time)
For dx we chain through x̂, μ, σ̂:
dx̂ = dy · γ (element-wise)
dσ² = Σ_d [dx̂_d · (x_d - μ)] · (-½)(σ² + ε)^{-3/2}
= Σ_d [dx̂_d · (x_d - μ)] · (-1 / σ̂³)
dμ = Σ_d dx̂_d · (-1/σ̂) + dσ² · (-2/D) Σ_d (x_d - μ)
= Σ_d dx̂_d · (-1/σ̂) (second term = 0)
dx_d = dx̂_d / σ̂ + dσ² · (2/D)(x_d - μ) + dμ / D
After substitution and simplification (see analysis below):
dx = (1/σ̂) · [ dx̂ - (1/D)(x̂ · Σ_d dx̂_d · x̂_d + Σ_d dx̂_d) ]
Time complexity : O(B·T·D) — one pass over all elements
Memory complexity: O(B·T·D) for the output; intermediates σ̂ and x̂
are O(B·T) and O(B·T·D) respectively.
"""
import numpy as np


# ──────────────────────────────────────────────────────────────
# Forward pass — returns cache needed for backward
# ──────────────────────────────────────────────────────────────
def layer_norm_forward(x: np.ndarray,
                       gamma: np.ndarray,
                       beta: np.ndarray,
                       eps: float = 1e-5):
    """
    x     : (B, T, D)
    gamma : (D,)
    beta  : (D,)
    """
    # --- compute statistics over last dim ---
    # keepdims=True keeps the (B, T, 1) shape so the stats broadcast
    # against x below without explicit reshaping.
    mean = x.mean(axis=-1, keepdims=True)              # (B, T, 1)
    xc = x - mean                                      # (B, T, D) centered x

    # Numerical stability note: var is always >= 0 by construction
    # because xc = x - mean, so (xc)**2 >= 0. The eps guards against
    # the degenerate case where all elements in a row are identical.
    var = (xc ** 2).mean(axis=-1, keepdims=True)       # (B, T, 1)

    # NumPy has no native rsqrt, so we form 1/sqrt(var + eps) directly;
    # on hardware with a fused rsqrt instruction this would avoid the
    # intermediate sqrt → division round-off.
    rstd = 1.0 / np.sqrt(var + eps)                    # (B, T, 1) reciprocal std
    xhat = xc * rstd                                   # (B, T, D) normalized
    y = gamma * xhat + beta                            # (B, T, D)

    # Cache everything needed for backward
    cache = (xhat, rstd, gamma)
    return y, cache
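

# A minimal sanity check of the forward pass (an addition, not part of the
# original derivation): with gamma = 1 and beta = 0 the output equals x̂,
# so each (b, t) row should have mean ≈ 0 and variance ≈ 1 (short of ≈ 1
# by roughly eps/var). Shapes and tolerances below are illustrative choices.
def _check_forward_normalization(B=2, T=3, D=16, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((B, T, D))
    y, _ = layer_norm_forward(x, np.ones(D), np.zeros(D))
    assert np.allclose(y.mean(axis=-1), 0.0, atol=1e-12)
    assert np.allclose(y.var(axis=-1), 1.0, atol=1e-4)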
# ──────────────────────────────────────────────────────────────
# Backward pass — manually derived
# ──────────────────────────────────────────────────────────────
def layer_norm_backward(dy: np.ndarray, cache: tuple):
    """
    dy    : (B, T, D) upstream gradient
    cache : (xhat, rstd, gamma) from forward
    Returns: dx (B,T,D), dgamma (D,), dbeta (D,)
    """
    xhat, rstd, gamma = cache
    D = xhat.shape[-1]

    # ── dgamma, dbeta ──────────────────────────────────────────
    # Sum over all batch & time positions; pointwise over D.
    dbeta = dy.sum(axis=(0, 1))                        # (D,)
    dgamma = (dy * xhat).sum(axis=(0, 1))              # (D,)

    # ── dx ─────────────────────────────────────────────────────
    # Chain through y = gamma * xhat + beta:
    #   dxhat = dy * gamma
    dxhat = dy * gamma                                 # (B, T, D)

    # Direct implementation of the simplified formula:
    #
    #   dx = (1/σ̂) [ dxhat - xhat · mean(dxhat · xhat) - mean(dxhat) ]
    #
    # where the means are over the D dimension.
    #
    # Derivation:
    #   dxhat_d / σ̂ + (dvar)(2/D)(x_d-μ) + dμ/D
    #     = dxhat_d/σ̂ - (x̂_d/(D·σ̂)) Σ_j dxhat_j x̂_j - (1/(D·σ̂)) Σ_j dxhat_j
    #     = (1/σ̂)[ dxhat_d - x̂_d · (1/D)Σ_j dxhat_j x̂_j - (1/D)Σ_j dxhat_j ]
    #
    # This avoids forming σ² + ε separately and reuses xhat directly.

    # Inner products over D — these are O(B·T·D) but touch each element once
    proj = (dxhat * xhat).sum(axis=-1, keepdims=True)  # (B, T, 1)
    dxhat_sum = dxhat.sum(axis=-1, keepdims=True)      # (B, T, 1)
    dx = rstd * (dxhat
                 - xhat * proj / D
                 - dxhat_sum / D)                      # (B, T, D)
    return dx, dgamma, dbeta
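

# Cross-check sketch (an addition, not part of the original file): the same
# gradient written the long way, chaining explicitly through mean and
# variance exactly as in the module docstring, before the algebraic
# simplification. It takes the original x instead of the cache and recomputes
# the statistics it needs; the helper name is illustrative. Comparing its dx
# against layer_norm_backward with np.allclose confirms the simplification.
def _layer_norm_backward_unsimplified(dy, x, gamma, eps=1e-5):
    D = x.shape[-1]
    mean = x.mean(axis=-1, keepdims=True)
    xc = x - mean
    var = (xc ** 2).mean(axis=-1, keepdims=True)
    std = np.sqrt(var + eps)
    xhat = xc / std

    dbeta = dy.sum(axis=(0, 1))
    dgamma = (dy * xhat).sum(axis=(0, 1))

    dxhat = dy * gamma
    # dσ² and dμ follow the docstring formulas term by term.
    dvar = (dxhat * xc).sum(axis=-1, keepdims=True) * (-0.5) * std ** (-3)
    dmean = (dxhat * (-1.0 / std)).sum(axis=-1, keepdims=True) \
            + dvar * (-2.0 / D) * xc.sum(axis=-1, keepdims=True)
    dx = dxhat / std + dvar * (2.0 / D) * xc + dmean / D
    return dx, dgamma, dbeta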
# ──────────────────────────────────────────────────────────────
# Gradient check via finite differences
# ──────────────────────────────────────────────────────────────
def gradient_check(B=2, T=3, D=8, eps_fd=1e-5, tol=1e-4, seed=42):
    """Central finite-difference check on x, gamma, beta."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((B, T, D))
    gamma = rng.standard_normal(D) * 0.5 + 1.0
    beta = rng.standard_normal(D) * 0.1
    eps = 1e-5

    # ── analytic backward ──
    y, cache = layer_norm_forward(x, gamma, beta, eps=eps)
    dy = rng.standard_normal(y.shape)                  # random upstream grad
    dx, dgamma, dbeta = layer_norm_backward(dy, cache)

    # ── helper: scalar loss = sum(dy * y) ──
    def loss_fn(x_, g_, b_):
        y_, _ = layer_norm_forward(x_, g_, b_, eps=eps)
        return np.sum(dy * y_)

    # ── finite-difference for x ──
    dx_fd = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        x_plus = x.copy();  x_plus[idx] += eps_fd
        x_minus = x.copy(); x_minus[idx] -= eps_fd
        dx_fd[idx] = (loss_fn(x_plus, gamma, beta)
                      - loss_fn(x_minus, gamma, beta)) / (2 * eps_fd)
    err_x = np.max(np.abs(dx - dx_fd))
    rel_x = err_x / (np.max(np.abs(dx)) + 1e-12)

    # ── finite-difference for gamma ──
    dg_fd = np.zeros_like(gamma)
    for i in range(D):
        g_plus = gamma.copy();  g_plus[i] += eps_fd
        g_minus = gamma.copy(); g_minus[i] -= eps_fd
        dg_fd[i] = (loss_fn(x, g_plus, beta)
                    - loss_fn(x, g_minus, beta)) / (2 * eps_fd)
    err_g = np.max(np.abs(dgamma - dg_fd))
    rel_g = err_g / (np.max(np.abs(dgamma)) + 1e-12)

    # ── finite-difference for beta ──
    db_fd = np.zeros_like(beta)
    for i in range(D):
        b_plus = beta.copy();  b_plus[i] += eps_fd
        b_minus = beta.copy(); b_minus[i] -= eps_fd
        db_fd[i] = (loss_fn(x, gamma, b_plus)
                    - loss_fn(x, gamma, b_minus)) / (2 * eps_fd)
    err_b = np.max(np.abs(dbeta - db_fd))
    rel_b = err_b / (np.max(np.abs(dbeta)) + 1e-12)

    # ── report ──
    print("=" * 60)
    print("Gradient check (central finite differences, h={})".format(eps_fd))
    print("=" * 60)
    for name, abs_err, rel_err in [("dx", err_x, rel_x),
                                   ("dgamma", err_g, rel_g),
                                   ("dbeta", err_b, rel_b)]:
        ok = "PASS" if rel_err < tol else "FAIL"
        print(f"  {name:>6s}: max|err| = {abs_err:.2e}   "
              f"rel = {rel_err:.2e}   [{ok}]")
    print("=" * 60)
# ──────────────────────────────────────────────────────────────
# Complexity analysis & GPU fusion discussion
# ──────────────────────────────────────────────────────────────
def main():
    B, T, D = 4, 16, 64
    rng = np.random.default_rng(0)
    x = rng.standard_normal((B, T, D)).astype(np.float64)
    gamma = rng.standard_normal(D).astype(np.float64) * 0.5 + 1.0
    beta = rng.standard_normal(D).astype(np.float64) * 0.1

    y, cache = layer_norm_forward(x, gamma, beta)
    dy = rng.standard_normal(y.shape)
    dx, dg, db = layer_norm_backward(dy, cache)

    print(f"Forward output shape  : {y.shape}")
    print(f"Backward dx shape     : {dx.shape}")
    print(f"Backward dgamma shape : {dg.shape}")
    print(f"Backward dbeta shape  : {db.shape}")
    print()

    # Gradient check
    gradient_check(B=3, T=5, D=32)
    print()

    print_complexity_and_fusion(B, T, D)
def print_complexity_and_fusion(B, T, D):
    N = B * T            # number of independent normalizations
    M = N * D            # total elements

    print("=" * 60)
    print("COMPLEXITY ANALYSIS")
    print("=" * 60)
    print(f"  Problem size: B={B}, T={T}, D={D} → {N} vectors of dim {D}")
    print()
    print("  Forward pass:")
    print(f"    • mean  : O(M)   ({N} reductions of size {D})")
    print(f"    • var   : O(M)")
    print(f"    • rstd  : O(N)   (one rsqrt per vector)")
    print(f"    • xhat  : O(M)")
    print(f"    • y     : O(M)")
    print(f"    Total time  : O(M) = O(B·T·D)")
    print(f"    Extra memory: O(M) for xhat + O(N) for rstd")
    print()
    print("  Backward pass:")
    print(f"    • dbeta : O(M)   (sum reduction)")
    print(f"    • dgamma: O(M)   (sum reduction)")
    print(f"    • dx    : O(M)   (two D-wide reductions + elementwise)")
    print(f"    Total time  : O(M) = O(B·T·D)")
    print(f"    Extra memory: O(M) for dxhat (can be fused in-place)")
    print()

    print("=" * 60)
    print("NUMERICAL STABILITY DISCUSSION")
    print("=" * 60)
    print("""
 1. Division by near-zero σ̂:
    When all elements in a vector are identical, var = 0 and σ̂ = √ε.
    The epsilon (typically 1e-5) prevents division by zero. Using
    double precision (float64) for the gradient check gives ~1e-10
    residual; in float32 the residual is ~1e-4, which is acceptable.

 2. Catastrophic cancellation in xc = x - mean:
    If x values are large but close together (e.g., x ≈ 1e6 with
    σ ≈ 1e-3), then xc = x - mean loses relative precision in float32.
    Remedy: the two-pass algorithm (compute mean first, then centered
    sum of squares) is already used here, which is the standard
    approach. For extreme cases, a compensated (Kahan) summation or
    Welford's online algorithm can be used (a sketch follows after
    this function).

 3. Overflow in xc² or var:
    For very large values, squaring xc can overflow float16 or float32.
    The standard fix is to compute in float32 for float16 inputs, or
    use a scaled variant.

 4. Gradient explosion when σ̂ is very small:
    dx ∝ 1/σ̂, so tiny variance → large gradients. This is inherent
    to the operation and is typically handled by gradient clipping
    upstream. The epsilon also bounds 1/σ̂ ≤ 1/√ε.

 5. rstd computation:
    We compute 1/sqrt(var + eps) directly rather than forming
    sqrt(var + eps) and then dividing. On GPU, the rsqrt instruction
    is a single hardware instruction with correct rounding.
""")

    print("=" * 60)
    print("GPU FUSION STRATEGY")
    print("=" * 60)
    print("""
Goal: Fuse the entire backward pass into a single CUDA kernel that
loads each (B,T) row exactly once from global memory.

Kernel design (one thread-block per row of length D):
───────────────────────────────────────────────────────────────
Shared memory (per block, size ≈ 3·D·4 bytes for float32):
    smem_xhat[D]   — the normalized input
    smem_dxhat[D]  — dy * gamma
    smem_proj[1]   — scalar Σ dxhat_d · xhat_d
    smem_sum[1]    — scalar Σ dxhat_d

Steps inside the kernel (no global-memory round-trips between steps):
 1. Each thread loads one (or more) element(s) of dy and xhat.
    Compute dxhat_d = dy_d * gamma_d (gamma in constant mem or smem).
    Store dxhat_d and xhat_d into shared memory.
 2. Cooperative reduction across the D threads of the block:
        proj += dxhat_d * xhat_d
        sum  += dxhat_d
    Two warp-level reductions (or one Blelloch scan) give us the
    two scalars in O(log D) steps.
 3. Each thread computes:
        dx_d = rstd * (dxhat_d - xhat_d * proj / D - sum / D)
    and writes the result to global memory.
 4. Atomic adds to global dgamma[d] += dy_d * xhat_d
                          dbeta[d]  += dy_d
    (one per element; can be deferred to a second pass or done
    with block-level reduction + single atomic per block).

Memory traffic per row:
    Reads : dy (D) + xhat (D) + rstd (1)  =  2D + 1 elements
    Writes: dx (D) + dgamma accumulator + dbeta accumulator
    Total : ≈ 3D elements vs. ≈ 10D+ for an unfused implementation
    (which would read/write intermediates to global memory
    between each of the 4–5 separate kernel launches).

Additional optimizations:
    • Use warp-level shuffles (__shfl_down_sync) instead of shared
      memory for the reductions when D ≤ 32 (or D ≤ warpSize).
    • Vectorized loads (float4 / float2) to improve memory throughput.
    • For D values that don't divide evenly into warpSize, use a
      grid-stride loop with cooperative groups.
    • Fuse with the preceding or following elementwise op (GELU,
      residual add) to eliminate another global memory round-trip.
""")
if __name__ == "__main__":
    main()