Files
sleepy 8e72eef09c feat: add model comparisons and sanitize session files
- Rename gamma to glm5 and model to minimax-m2.7
- Add model_comparison/ directory with head-to-head analyses
- Sanitize all session.jsonl files: remove absolute paths and usernames
- Remove __pycache__ artifacts
- Add .gitignore
2026-04-23 11:16:01 +02:00

961 lines
44 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Layer Normalization from scratch in NumPy
==========================================
Numerically stable backward pass implementation with gradient checking.
Author: pi coding assistant
Date: 2026-04-22
"""
import numpy as np
from typing import Tuple, Dict, Optional
import time
import copy
# =============================================================================
# Numerical constants
# =============================================================================
# Epsilon added to the variance before sqrt so division by std is always safe.
# NOTE(review): 1e-8 is smaller than the 1e-5 default used by common frameworks
# (e.g. torch.nn.LayerNorm) — confirm this tighter value is intentional.
DEFAULT_EPS = 1e-8
# =============================================================================
# Helper functions
# =============================================================================
def logsumexp(x: np.ndarray, axis: int = -1, keepdims: bool = True) -> np.ndarray:
    """Numerically stable log-sum-exp along *axis*.

    Computes log(sum(exp(x))) without overflow by subtracting the per-slice
    maximum before exponentiating and adding it back afterwards.

    Args:
        x: Input array.
        axis: Axis to reduce over.
        keepdims: If True, the reduced axis is kept with size 1.

    Returns:
        Array of log-sum-exp values, with or without the reduced axis
        depending on *keepdims*.
    """
    # Shift by the max so exp() cannot overflow; the shift is added back.
    max_x = np.max(x, axis=axis, keepdims=True)
    result = max_x + np.log(np.exp(x - max_x).sum(axis=axis, keepdims=True))
    if not keepdims:
        # Bug fix: the original accepted `keepdims` but always kept the axis.
        result = np.squeeze(result, axis=axis)
    return result
# =============================================================================
# Layer Normalization Forward Pass
# =============================================================================
def layer_norm_forward(
    x: np.ndarray,
    gamma: np.ndarray,
    beta: np.ndarray,
    eps: float = DEFAULT_EPS
) -> Tuple[np.ndarray, Dict]:
    """Apply layer normalization over the last (feature) axis.

    Computes y = gamma * (x - mean) / sqrt(var + eps) + beta, where the mean
    and variance are taken per (b, t) position across the D features.

    Args:
        x: Input tensor of shape (B, T, D).
        gamma: Scale parameter of shape (D,).
        beta: Shift parameter of shape (D,).
        eps: Stabilizer added to the variance before the square root.

    Returns:
        A pair (y, cache): the normalized output of shape (B, T, D), and a
        dict of intermediates consumed by the backward pass / gradient check.
    """
    B, T, D = x.shape
    # Per-position statistics across features; keepdims=True keeps (B, T, 1)
    # so the broadcasts below line up with (B, T, D).
    mu = x.mean(axis=-1, keepdims=True)
    centered = x - mu
    sigma2 = (centered ** 2).mean(axis=-1, keepdims=True)
    # sqrt(var + eps) >= sqrt(eps) > 0, so the division is always safe.
    sigma = np.sqrt(sigma2 + eps)
    normalized = centered / sigma
    # Affine transform: per-feature scale and shift.
    out = gamma * normalized + beta
    # Intermediates for the backward pass (x and beta are kept for the
    # finite-difference gradient checker).
    cache = {
        'x': x,
        'x_centered': centered,
        'x_norm': normalized,
        'mean': mu,
        'var': sigma2,
        'std': sigma,
        'gamma': gamma,
        'beta': beta,
        'eps': eps,
        'B': B,
        'T': T,
        'D': D,
    }
    return out, cache
# =============================================================================
# Layer Normalization Backward Pass
# =============================================================================
def layer_norm_backward(
    dy: np.ndarray,
    cache: Dict
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Layer norm backward pass.

    With x_norm = (x - mean) / std and y = gamma * x_norm + beta, the
    parameter gradients are

        d_gamma = sum over (B, T) of dy * x_norm
        d_beta  = sum over (B, T) of dy

    and, writing dz = dy * gamma, the consolidated input gradient is

        dx = (dz - mean_D(dz) - x_norm * mean_D(dz * x_norm)) / std

    where mean_D averages over the feature axis. The two subtracted terms
    account for the dependence of the mean and variance on every element
    of x along that axis.

    Numerical notes:
      * std = sqrt(var + eps) >= sqrt(eps) > 0, so the division is safe.
      * Cancellation in (dz - mean_D(dz)) occurs exactly when dz is nearly
        constant across D — which is when dx is genuinely small, so the
        relative error stays benign.
      * For fp16 or extreme inputs, consider pairwise summation or higher
        precision accumulators for the two reductions.

    Args:
        dy: Upstream gradient of shape (B, T, D).
        cache: Intermediates produced by layer_norm_forward.

    Returns:
        dx: Gradient w.r.t. input x, shape (B, T, D).
        d_gamma: Gradient w.r.t. gamma, shape (D,).
        d_beta: Gradient w.r.t. beta, shape (D,).
    """
    # Only these four cache entries are needed here. (The original version
    # also loaded x, x_centered, mean, eps, B and T, which were never used.)
    x_norm = cache['x_norm']
    std = cache['std']
    gamma = cache['gamma']
    D = cache['D']
    # Parameter gradients: reduce over batch and time, leaving shape (D,).
    d_gamma = np.sum(dy * x_norm, axis=(0, 1))
    d_beta = np.sum(dy, axis=(0, 1))
    # Chain rule through the affine output y = gamma * x_norm + beta.
    dz = dy * gamma  # (B, T, D)
    # Feature-axis reductions reused by the consolidated formula, (B, T, 1).
    sum_dz = np.sum(dz, axis=-1, keepdims=True)
    sum_dz_xnorm = np.sum(dz * x_norm, axis=-1, keepdims=True)
    # dx = (dz - mean(dz) - x_norm * mean(dz * x_norm)) / std
    dx = (dz - sum_dz / D - x_norm * sum_dz_xnorm / D) / std
    return dx, d_gamma, d_beta
# =============================================================================
# Layer Norm Module (combines forward and backward)
# =============================================================================
class LayerNorm:
    """
    Layer Normalization module with manual gradient computation.

    Forward: y = gamma * (x - mean) / sqrt(var + eps) + beta
    Backward: Computes gradients w.r.t. x, gamma, beta
    """

    def __init__(self, normalized_shape: int, eps: float = DEFAULT_EPS):
        """
        Args:
            normalized_shape: Dimension D of the feature space
            eps: Epsilon for numerical stability in sqrt(var + eps)
        """
        self.normalized_shape = normalized_shape
        self.eps = eps
        # Standard LayerNorm initialization: gamma = 1 and beta = 0, so the
        # module starts as a pure normalization (identity affine transform).
        # (The original comment called this "Xavier initialization", which
        # is incorrect — no random fan-in/fan-out scaling is involved.)
        self.gamma = np.ones(normalized_shape)  # Scale initialized to 1
        self.beta = np.zeros(normalized_shape)  # Shift initialized to 0
        # Gradients populated by backward(); None until the first call.
        self.d_gamma = None
        self.d_beta = None

    def forward(self, x: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Forward pass. Raises AssertionError if x's last dim != D."""
        assert x.shape[-1] == self.normalized_shape, \
            f"Expected last dimension {self.normalized_shape}, got {x.shape[-1]}"
        return layer_norm_forward(x, self.gamma, self.beta, self.eps)

    def backward(self, dy: np.ndarray, cache: Dict) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Backward pass. Also stores d_gamma/d_beta on the instance.

        Args:
            dy: Upstream gradient of shape (B, T, D)
            cache: Forward pass cache
        Returns:
            dx: Gradient w.r.t. input x
            d_gamma: Gradient w.r.t. gamma
            d_beta: Gradient w.r.t. beta
        """
        dx, d_gamma, d_beta = layer_norm_backward(dy, cache)
        self.d_gamma = d_gamma
        self.d_beta = d_beta
        return dx, d_gamma, d_beta

    def parameters(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return (gamma, beta)."""
        return self.gamma, self.beta

    def gradients(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return (d_gamma, d_beta); both are None before backward() runs."""
        return self.d_gamma, self.d_beta
# =============================================================================
# Gradient Checking via Finite Differences
# =============================================================================
def compute_numerical_gradient_gamma(
    x: np.ndarray,
    gamma: np.ndarray,
    beta: np.ndarray,
    dy: np.ndarray,
    h: float = 1e-5
) -> np.ndarray:
    """Finite-difference gradient of L = sum(y * dy) w.r.t. gamma.

    Each gamma[i] is perturbed in place by +/-h (and restored afterwards),
    and the central-difference quotient approximates the derivative.

    Args:
        x: Input tensor (B, T, D)
        gamma: Scale parameter (D,) — temporarily modified, then restored
        beta: Bias parameter (D,)
        dy: Upstream gradient (B, T, D)
        h: Step size
    Returns:
        Numerical gradient for gamma (D,)
    """
    grad = np.zeros(len(gamma))
    for idx in range(len(gamma)):
        saved = gamma[idx]
        # Loss with gamma[idx] nudged up by h.
        gamma[idx] = saved + h
        up, _ = layer_norm_forward(x, gamma, beta)
        loss_up = np.sum(up * dy)
        # Loss with gamma[idx] nudged down by h.
        gamma[idx] = saved - h
        down, _ = layer_norm_forward(x, gamma, beta)
        loss_down = np.sum(down * dy)
        # Central difference: (L(+h) - L(-h)) / (2h).
        grad[idx] = (loss_up - loss_down) / (2 * h)
        # Put the parameter back before moving to the next component.
        gamma[idx] = saved
    return grad
def compute_numerical_gradient_beta(
    x: np.ndarray,
    gamma: np.ndarray,
    beta: np.ndarray,
    dy: np.ndarray,
    h: float = 1e-5
) -> np.ndarray:
    """Finite-difference gradient of L = sum(y * dy) w.r.t. beta.

    Each beta[i] is perturbed in place by +/-h (and restored afterwards),
    and the central-difference quotient approximates the derivative.

    Args:
        x: Input tensor (B, T, D)
        gamma: Scale parameter (D,)
        beta: Bias parameter (D,) — temporarily modified, then restored
        dy: Upstream gradient (B, T, D)
        h: Step size
    Returns:
        Numerical gradient for beta (D,)
    """
    grad = np.zeros(len(beta))
    for idx in range(len(beta)):
        saved = beta[idx]
        # Loss with beta[idx] nudged up by h.
        beta[idx] = saved + h
        up, _ = layer_norm_forward(x, gamma, beta)
        loss_up = np.sum(up * dy)
        # Loss with beta[idx] nudged down by h.
        beta[idx] = saved - h
        down, _ = layer_norm_forward(x, gamma, beta)
        loss_down = np.sum(down * dy)
        # Central difference: (L(+h) - L(-h)) / (2h).
        grad[idx] = (loss_up - loss_down) / (2 * h)
        # Put the parameter back before moving to the next component.
        beta[idx] = saved
    return grad
def compute_numerical_gradient_x(
    x: np.ndarray,
    gamma: np.ndarray,
    beta: np.ndarray,
    dy: np.ndarray,
    h: float = 1e-5,
    max_elements: int = 100000
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute numerical gradient for x using finite differences.
    For large tensors, uses spot check.

    Bug fix: the original perturbed elements through `x.reshape(-1)` and
    relied on the reshape being a *view*. For a non-contiguous `x`, NumPy
    silently returns a *copy*, so the perturbations never reached the
    forward pass and the numerical gradient came out as all zeros. Using
    `x.flat` (a flat iterator over the array itself, in C order) always
    writes through to `x`, regardless of memory layout.

    Args:
        x: Input tensor (B, T, D) - will be restored to original values
        gamma: Scale parameter (D,)
        beta: Bias parameter (D,)
        dy: Upstream gradient (B, T, D)
        h: Step size
        max_elements: Maximum elements to check (spot check if larger)
    Returns:
        Tuple of (num_grad, spot_check_mask) where spot_check_mask marks checked elements
    """
    B, T, D = x.shape
    total_elements = B * T * D
    # Save original x values so the caller's array can be fully restored.
    orig_x = x.copy()
    if total_elements <= max_elements:
        # Full gradient check: perturb every element.
        num_grad = np.zeros_like(x)
        for i in range(total_elements):
            if (i + 1) % 10000 == 0:
                print(f" Progress: {i+1}/{total_elements}")
            original = x.flat[i]
            # f(x + h)
            x.flat[i] = original + h
            y_plus, _ = layer_norm_forward(x, gamma, beta)
            loss_plus = np.sum(y_plus * dy)
            # f(x - h)
            x.flat[i] = original - h
            y_minus, _ = layer_norm_forward(x, gamma, beta)
            loss_minus = np.sum(y_minus * dy)
            # Central difference
            num_grad.flat[i] = (loss_plus - loss_minus) / (2 * h)
            # Restore this element before perturbing the next one.
            x.flat[i] = original
        # Restore x to original values
        x[:] = orig_x
        return num_grad, np.ones((B, T, D), dtype=bool)  # All elements checked
    else:
        # Spot check a random subset (indices may repeat; harmless, the
        # same value just gets recomputed).
        print(f" Spot checking {max_elements} random elements...")
        num_grad = np.zeros_like(x)
        spot_checked = np.zeros((B, T, D), dtype=bool)
        indices = [tuple(np.random.randint(b) for b in (B, T, D))
                   for _ in range(max_elements)]
        for bi, ti, di in indices:
            original = x[bi, ti, di]
            spot_checked[bi, ti, di] = True
            # f(x + h) — direct multi-index assignment always writes through.
            x[bi, ti, di] = original + h
            y_plus, _ = layer_norm_forward(x, gamma, beta)
            loss_plus = np.sum(y_plus * dy)
            # f(x - h)
            x[bi, ti, di] = original - h
            y_minus, _ = layer_norm_forward(x, gamma, beta)
            loss_minus = np.sum(y_minus * dy)
            num_grad[bi, ti, di] = (loss_plus - loss_minus) / (2 * h)
            # Restore
            x[bi, ti, di] = original
        # Restore x to original values
        x[:] = orig_x
        return num_grad, spot_checked
def gradient_check(
    x: np.ndarray,
    gamma: np.ndarray,
    beta: np.ndarray,
    dy: np.ndarray,
    rtol: float = 1e-4,
    atol: float = 1e-5,
    verbose: bool = True
) -> Dict[str, bool]:
    """
    Perform gradient check for all parameters.
    Checks: |analytical - numerical| <= atol + rtol * |numerical|

    Note: x, gamma and beta are perturbed in place by the numerical-gradient
    helpers but are restored to their original values before returning.

    Args:
        x: Input tensor
        gamma: Scale parameter
        beta: Bias parameter
        dy: Upstream gradient
        rtol: Relative tolerance
        atol: Absolute tolerance
        verbose: Print detailed results
    Returns:
        Dictionary of pass/fail for each parameter (keys 'gamma', 'beta', 'x')
    """
    results = {}
    # Store originals so the in-place perturbations can be undone at the end.
    orig_gamma = gamma.copy()
    orig_beta = beta.copy()
    orig_x = x.copy()
    # -------------------------------------------------------------------------
    # Forward pass to get analytical gradients
    # -------------------------------------------------------------------------
    y, cache = layer_norm_forward(x, gamma, beta)
    dx_analytical, d_gamma_analytical, d_beta_analytical = layer_norm_backward(dy, cache)
    # -------------------------------------------------------------------------
    # Check gradient w.r.t. gamma
    # -------------------------------------------------------------------------
    if verbose:
        print("\n" + "="*60)
        print("GRADIENT CHECK: gamma")
        print("="*60)
    # Reset gamma to original (defensive; helpers restore element-wise too)
    gamma[:] = orig_gamma
    d_gamma_numerical = compute_numerical_gradient_gamma(x, gamma, beta, dy)
    # Compare using the numpy.allclose-style criterion, element-wise
    diff = np.abs(d_gamma_analytical - d_gamma_numerical)
    tolerance = atol + rtol * np.abs(d_gamma_numerical)
    passed = np.all(diff <= tolerance)
    if verbose:
        print(f"Analytical gradient shape: {d_gamma_analytical.shape}")
        print(f"Numerical gradient shape: {d_gamma_numerical.shape}")
        print(f"Max absolute difference: {np.max(diff):.2e}")
        print(f"Max relative tolerance: {np.max(tolerance):.2e}")
        print(f"Mean analytical: {np.mean(np.abs(d_gamma_analytical)):.6e}")
        print(f"Mean numerical: {np.mean(np.abs(d_gamma_numerical)):.6e}")
        print(f"\nGradient check: {'PASSED ✓' if passed else 'FAILED ✗'}")
    results['gamma'] = passed
    # -------------------------------------------------------------------------
    # Check gradient w.r.t. beta
    # -------------------------------------------------------------------------
    if verbose:
        print("\n" + "="*60)
        print("GRADIENT CHECK: beta")
        print("="*60)
    # Reset beta to original
    beta[:] = orig_beta
    d_beta_numerical = compute_numerical_gradient_beta(x, gamma, beta, dy)
    diff = np.abs(d_beta_analytical - d_beta_numerical)
    tolerance = atol + rtol * np.abs(d_beta_numerical)
    passed = np.all(diff <= tolerance)
    if verbose:
        print(f"Analytical gradient shape: {d_beta_analytical.shape}")
        print(f"Numerical gradient shape: {d_beta_numerical.shape}")
        print(f"Max absolute difference: {np.max(diff):.2e}")
        print(f"Max relative tolerance: {np.max(tolerance):.2e}")
        print(f"Mean analytical: {np.mean(np.abs(d_beta_analytical)):.6e}")
        print(f"Mean numerical: {np.mean(np.abs(d_beta_numerical)):.6e}")
        print(f"\nGradient check: {'PASSED ✓' if passed else 'FAILED ✗'}")
    results['beta'] = passed
    # -------------------------------------------------------------------------
    # Check gradient w.r.t. x
    # -------------------------------------------------------------------------
    if verbose:
        print("\n" + "="*60)
        print("GRADIENT CHECK: x (input)")
        print("="*60)
    # Reset x to original
    x[:] = orig_x
    d_x_numerical, spot_checked = compute_numerical_gradient_x(x, gamma, beta, dy)
    # Only check elements that were numerically computed (spot-check mask)
    diff = np.abs(dx_analytical[spot_checked] - d_x_numerical[spot_checked])
    tolerance = atol + rtol * np.abs(d_x_numerical[spot_checked])
    passed = np.all(diff <= tolerance)
    if verbose:
        print(f"Analytical gradient shape: {dx_analytical.shape}")
        print(f"Numerical gradient shape: {d_x_numerical.shape}")
        print(f"Elements checked: {np.sum(spot_checked)} / {spot_checked.size}")
        if np.any(spot_checked):
            print(f"Max absolute difference: {np.max(diff):.2e}")
            print(f"Max relative tolerance: {np.max(tolerance):.2e}")
            print(f"Mean analytical (checked): {np.mean(np.abs(dx_analytical[spot_checked])):.6e}")
            print(f"Mean numerical (checked): {np.mean(np.abs(d_x_numerical[spot_checked])):.6e}")
        print(f"\nGradient check: {'PASSED ✓' if passed else 'FAILED ✗'}")
    results['x'] = passed
    # Restore originals so the caller's arrays are unchanged
    gamma[:] = orig_gamma
    beta[:] = orig_beta
    x[:] = orig_x
    return results
# =============================================================================
# Complexity Analysis
# =============================================================================
def analyze_complexity():
    """Print complexity analysis for layer norm forward and backward."""
    # Static report only — nothing is measured or computed here.
    print("""
╔══════════════════════════════════════════════════════════════════════════════╗
║ COMPLEXITY ANALYSIS ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ Input: x ∈ ^(B×T×D) ║
║ Parameters: γ, β ∈ ^D ║
║ ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ FORWARD PASS ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ Operation │ Work (FLOPs) │ Memory ║
║ ────────────────────────────────────────────────────────────────────────── ║
║ mean (reduction) │ O(BTD) │ - ║
║ x - mean (broadcast) │ O(BTD) │ O(BTD) ║
║ var (reduction) │ O(BTD) │ - ║
║ sqrt(var + eps) │ O(BT) │ - ║
║ divide by std │ O(BTD) │ - ║
║ gamma * x_norm │ O(BTD) │ - ║
║ add beta │ O(BTD) │ O(BTD) output ║
║ ────────────────────────────────────────────────────────────────────────── ║
║ TOTAL │ 5×O(BTD) │ O(BTD) ║
║ ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ BACKWARD PASS ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ Operation │ Work (FLOPs) │ Memory ║
║ ────────────────────────────────────────────────────────────────────────── ║
║ d_gamma = sum(dy*x_norm)│ O(BTD) │ O(D) ║
║ d_beta = sum(dy) │ O(BTD) │ O(D) ║
║ dz = dy * gamma │ O(BTD) │ O(BTD) (can be avoided) ║
║ sum(dz) │ O(BTD) │ - ║
║ sum(dz * x_norm) │ O(BTD) │ - ║
║ dx computation │ O(BTD) │ O(BTD) output ║
║ ────────────────────────────────────────────────────────────────────────── ║
║ TOTAL │ 5×O(BTD) │ O(BTD) + O(D) ║
║ ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ SUMMARY ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ Time Complexity: O(BTD) for both forward and backward ║
║ Space Complexity: O(BTD) for storing activations (during training) ║
║ O(BTD) during inference (no need to store) ║
║ ║
║ Cache efficiency: We store x_centered, x_norm, mean, std ║
║ These are O(BTD) total, reused across all gradient comps ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════╝
""")
# =============================================================================
# GPU Kernel Fusion Design
# =============================================================================
def explain_gpu_fusion():
    """Explain GPU kernel fusion for layer normalization."""
    # Static explanatory text only — the CUDA code below is a sketch inside
    # a string literal, it is never compiled or executed.
    print("""
╔══════════════════════════════════════════════════════════════════════════════╗
║ GPU KERNEL FUSION DESIGN ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ CURRENT SEPARATE KERNEL APPROACH: ║
║ ──────────────────────────────────── ║
║ Kernel 1: Compute mean (reduction over D) ║
║ Kernel 2: Compute variance (reduction over D) ║
║ Kernel 3: Normalize (element-wise) ║
║ Kernel 4: Scale and shift (element-wise) ║
║ Kernel 5: Backward kernels (x, gamma, beta) ║
║ ║
║ Issues with separate kernels: ║
║ • Multiple kernel launches (overhead) ║
║ • Data movement between global memory passes ║
║ • Can't use persistent threads for reduction efficiency ║
║ ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ FUSED KERNEL DESIGN (Forward): ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ Grid: (B × T) blocks, each handling one (b,t) position ║
║ Block: 256-512 threads ║
║ ║
║ PHASE 1: Load and compute local sum ║
║ ───────────────────────────────── ║
║ • Each thread loads x[b,t,d] into shared memory ║
║ • Compute partial sum using warp-level reduction ║
║ • Single thread writes mean to __shared__ ║
║ ║
║ PHASE 2: Compute variance locally ║
║ ───────────────────────────────── ║
║ • Re-load x with loaded mean ║
║ • Compute (x-mean)² and partial variance ║
║ • Reduce to get variance ║
║ ║
║ PHASE 3: Normalize and output ║
║ ───────────────────────────────── ║
║ • All threads compute: y = gamma * (x-mean) / sqrt(var+eps) + beta ║
║ • Write to output (fully coalesced) ║
║ ║
║ MEMORY ACCESS PATTERN: ║
║ • Each block reads contiguous D elements (coalesced) ║
║ • Use shared memory for intermediate results ║
║ • Output writes are also coalesced ║
║ ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ FUSED KERNEL DESIGN (Backward): ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ Key insight: dz = dy * gamma can be merged into the computation ║
║ ║
║ Grid: (B × T) blocks ║
║ Block: 256-512 threads ║
║ ║
║ SHARED MEMORY STRUCTURE: ║
║ [x_norm_0, x_norm_1, ..., x_norm_{D-1}] ║
║ [dz_0, dz_1, ..., dz_{D-1}] ║
║ ║
║ ALGORITHM: ║
║ 1. Load x_norm and compute local dz = dy * gamma ║
║ 2. Reduce to get sum(dz) and sum(dz * x_norm) ║
║ 3. Second pass to compute dx using the formula: ║
║ dx = (dz - mean(dz) - x_norm * mean(dz*x_norm)) / std ║
║ ║
║ REDUCTION OPTIMIZATIONS: ║
║ • Warp-level shuffle reductions (no shared memory needed) ║
║ • Block-level using shared memory with tree reduction ║
║ • Use block-level primitives for final reduction ║
║ ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ BENEFITS OF FUSION: ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ ✓ Reduced kernel launch overhead (1 kernel vs 4-5) ║
║ ✓ Better memory bandwidth utilization (single read, single write) ║
║ ✓ Improved cache locality (data stays in registers/shared mem) ║
║ ✓ Only loads x once, computes mean and var from same data ║
║ ✓ Backward can reuse cached values from forward (if memory allows) ║
║ ✓ Lower register pressure allows for larger block sizes ║
║ ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ CUDA KERNEL SKETCH (Pseudo-code): ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ __global__ void layer_norm_fwd(float* y, const float* x, ║
║ const float* gamma, const float* beta, ║
║ int B, int T, int D, float eps) {
║ ║
║ __shared__ float mean_smem[256]; // block-level mean ║
║ __shared__ float var_smem[256]; // block-level variance ║
║ __shared__ float std_smem[256]; // block-level std ║
║ ║
║ int tid = threadIdx.x; ║
║ int bid = blockIdx.x; ║
║ int D_blk = (D + blockDim.x - 1) / blockDim.x; ║
║ ║
║ // Phase 1: Load and compute mean ║
║ float sum = 0.0; ║
║ for (int i = 0; i < D_blk; i++) {
║ int idx = bid * D + i * blockDim.x + tid; ║
║ sum += x[idx]; ║
║ } ║
║ sum = warpReduceSum(sum); ║
║ if (tid % 32 == 0) mean_smem[tid / 32] = sum; ║
║ __syncthreads(); ║
║ ║
║ float mean = mean_smem[0] / D; ║
║ ║
║ // Phase 2: Compute variance ║
║ sum = 0.0; ║
║ for (int i = 0; i < D_blk; i++) {
║ int idx = bid * D + i * blockDim.x + tid; ║
║ float diff = x[idx] - mean; ║
║ sum += diff * diff; ║
║ } ║
║ sum = warpReduceSum(sum); ║
║ if (tid % 32 == 0) var_smem[tid / 32] = sum; ║
║ __syncthreads(); ║
║ ║
║ float var = var_smem[0] / D; ║
║ float std = sqrt(var + eps); ║
║ ║
║ // Phase 3: Normalize and output ║
║ for (int i = 0; i < D_blk; i++) {
║ int idx = bid * D + i * blockDim.x + tid; ║
║ float x_norm = (x[idx] - mean) / std; ║
║ y[idx] = gamma[idx % D] * x_norm + beta[idx % D]; ║
║ } ║
║ } ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════╝
""")
# =============================================================================
# Benchmark and Tests
# =============================================================================
def benchmark():
    """Benchmark layer-norm forward and backward passes over several shapes.

    For each (B, T, D) shape: runs 100 timed iterations of the forward pass
    and 100 of the backward pass, then prints mean ± std timings and an
    approximate memory throughput.

    Returns:
        List of dicts with keys 'shape', 'fwd_time', 'bwd_time' (milliseconds)
        and 'elements' (B*T*D).
    """
    print("\n" + "="*70)
    print("BENCHMARKING LAYER NORMALIZATION")
    print("="*70)
    # Test different shapes
    shapes = [
        (32, 128, 256),   # Small
        (64, 128, 512),   # Medium (BERT-base hidden)
        (32, 512, 768),   # BERT-base
        (16, 512, 1024),  # Larger
    ]
    results = []
    for B, T, D in shapes:
        print(f"\nShape: (B={B}, T={T}, D={D})")
        print("-" * 40)
        # Fixed seed so every shape is benchmarked on reproducible data.
        np.random.seed(42)
        x = np.random.randn(B, T, D).astype(np.float64)
        gamma = np.random.randn(D).astype(np.float64)
        beta = np.random.randn(D).astype(np.float64)
        dy = np.random.randn(B, T, D).astype(np.float64)
        # Forward benchmark
        n_iters = 100
        times = []
        for _ in range(n_iters):
            start = time.perf_counter()
            y, cache = layer_norm_forward(x, gamma, beta)
            end = time.perf_counter()
            times.append((end - start) * 1000)  # ms
        fwd_time = np.mean(times)
        fwd_std = np.std(times)
        # Backward benchmark (reuses the last forward's cache)
        times = []
        for _ in range(n_iters):
            start = time.perf_counter()
            dx, d_gamma, d_beta = layer_norm_backward(dy, cache)
            end = time.perf_counter()
            times.append((end - start) * 1000)  # ms
        bwd_time = np.mean(times)
        bwd_std = np.std(times)
        # Throughput. Bug fix: the original divided the *element* count by
        # time and labeled the result GB/s — that is G-elements/s. Scale by
        # the element size (8 bytes for float64) so the GB/s label is honest;
        # this counts one pass over the main tensor per call.
        elements = B * T * D
        n_bytes = elements * x.itemsize
        fwd_throughput = n_bytes / (fwd_time / 1000) / 1e9  # GB/s
        bwd_throughput = n_bytes / (bwd_time / 1000) / 1e9
        print(f"Forward: {fwd_time:.3f} ± {fwd_std:.3f} ms ({fwd_throughput:.1f} GB/s)")
        print(f"Backward: {bwd_time:.3f} ± {bwd_std:.3f} ms ({bwd_throughput:.1f} GB/s)")
        print(f"Total: {fwd_time + bwd_time:.3f} ms")
        results.append({
            'shape': (B, T, D),
            'fwd_time': fwd_time,
            'bwd_time': bwd_time,
            'elements': elements
        })
    return results
def run_gradient_checks():
    """Run gradient checks on various shapes."""
    print("\n" + "="*70)
    print("RUNNING GRADIENT CHECKS")
    print("="*70)
    # Small shapes keep the O(B*T*D) finite-difference loops fast.
    test_shapes = [
        (2, 4, 8),    # Tiny
        (4, 8, 16),   # Small
        (8, 16, 32),  # Medium-small
    ]
    overall_ok = True
    for B, T, D in test_shapes:
        print(f"\n{'='*60}")
        print(f"Testing shape (B={B}, T={T}, D={D})")
        print(f"{'='*60}")
        # Deterministic inputs for reproducible checks.
        np.random.seed(123)
        x = np.random.randn(B, T, D).astype(np.float64)
        gamma = np.random.randn(D).astype(np.float64)
        beta = np.random.randn(D).astype(np.float64)
        dy = np.random.randn(B, T, D).astype(np.float64)
        # Compare analytical vs numerical gradients for x, gamma, beta.
        outcome = gradient_check(x, gamma, beta, dy, verbose=True)
        for param_name, ok in outcome.items():
            marker = "✓ PASS" if ok else "✗ FAIL"
            print(f" {param_name:>8}: {marker}")
            overall_ok = overall_ok and ok
    print("\n" + "="*70)
    if overall_ok:
        print("ALL GRADIENT CHECKS PASSED ✓")
    else:
        print("SOME GRADIENT CHECKS FAILED ✗")
    print("="*70)
    return overall_ok
# =============================================================================
# Main
# =============================================================================
if __name__ == "__main__":
    print("="*70)
    print("LAYER NORMALIZATION: FORWARD AND BACKWARD FROM SCRATCH")
    print("="*70)
    # Print complexity analysis (static table, no computation)
    analyze_complexity()
    # Explain GPU fusion (static explanatory text)
    explain_gpu_fusion()
    # Run gradient checks (finite-difference verification; the slow part)
    run_gradient_checks()
    # Benchmark forward/backward timings over several shapes
    benchmark()