feat: add model comparisons and sanitize session files

- Rename gamma to glm5 and model to minimax-m2.7
- Add model_comparison/ directory with head-to-head analyses
- Sanitize all session.jsonl files: remove absolute paths and usernames
- Remove __pycache__ artifacts
- Add .gitignore
commit 8e72eef09c (2026-04-23 11:16:01 +02:00)
62 changed files with 18469 additions and 0 deletions
@@ -0,0 +1,42 @@
All gradient checks pass. The implementation is complete. Here's a summary:
## Summary
**File:** `./minimax-m2.7/backwards/layer_norm_numpy.py`
### Key Components
1. **Forward Pass** (`layer_norm_forward`):
- Computes mean, variance, and standard deviation over the feature dimension (D)
- Normalizes: `x_norm = (x - mean) / sqrt(var + eps)`
- Scales and shifts: `y = gamma * x_norm + beta`
- Caches intermediates: `x_centered`, `x_norm`, `mean`, `std`, `gamma`
2. **Backward Pass** (`layer_norm_backward`):
- **d_gamma**: `sum(dy * x_norm)` over (B, T)
- **d_beta**: `sum(dy)` over (B, T)
- **dx**: `(dz - mean(dz) - x_norm * mean(dz * x_norm)) / std`
where `dz = dy * gamma`. The consolidated formula avoids materializing the full `D × D` Jacobian.
3. **Gradient Check**:
- Uses central finite differences: `(f(x+h) - f(x-h)) / 2h`
- Fixed array flattening to use views instead of copies
- Spot-check for large tensors (>100k elements)
4. **Numerical Stability**:
- `eps` prevents division by zero: `std = sqrt(var + eps)`
- Catastrophic cancellation analysis documented
- Pairwise summation suggestion for fp16
5. **Complexity**:
- **Time**: O(BTD) for both forward and backward
- **Space**: O(BTD) for activations (training), O(1) for inference
6. **GPU Fusion Design**:
- Grid of `(B × T)` blocks, each handling one `(b,t)` position
- Three phases: mean reduction, variance reduction, normalize+output
- Warp-level shuffle reductions for efficiency
- Single kernel replaces 4-5 separate kernels
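The forward and backward formulas summarized above can be condensed into a short NumPy sketch (illustrative only; function name and the `eps` value are ours, not taken from the file):

```python
import numpy as np

def layer_norm_fwd_bwd(x, gamma, beta, dy, eps=1e-5):
    """Forward pass plus the consolidated backward formulas from the summary."""
    # Forward: normalize over the feature dimension D
    mean = x.mean(axis=-1, keepdims=True)
    std = np.sqrt(((x - mean) ** 2).mean(axis=-1, keepdims=True) + eps)
    x_norm = (x - mean) / std
    y = gamma * x_norm + beta
    # Backward: dz = dy * gamma, then the consolidated dx formula
    dz = dy * gamma
    dx = (dz - dz.mean(axis=-1, keepdims=True)
          - x_norm * (dz * x_norm).mean(axis=-1, keepdims=True)) / std
    d_gamma = (dy * x_norm).sum(axis=(0, 1))
    d_beta = dy.sum(axis=(0, 1))
    return y, dx, d_gamma, d_beta
```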
@@ -0,0 +1,17 @@
Implement a numerically stable backward pass for layer normalization from scratch in NumPy.
Constraints:
- Input: x of shape (B, T, D)
- Parameters: gamma, beta of shape (D,)
- Forward:
y = gamma * (x - mean) / sqrt(var + eps) + beta
Requirements:
1. Derive and implement gradients w.r.t. x, gamma, beta manually (no autodiff).
2. Avoid redundant recomputation — reuse intermediates where possible.
3. Ensure numerical stability (discuss where instability can occur).
4. Provide a gradient check using finite differences.
5. Analyze time and memory complexity.
6. Explain how you would fuse this into a single kernel for GPU execution.
Do not use PyTorch, TensorFlow, JAX, or autograd.
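Requirement 4 reduces to a generic central-difference helper; a minimal sketch (our naming, not prescribed by the prompt):

```python
import numpy as np

def numerical_grad(f, x, h=1e-5):
    """Central difference (f(x+h) - f(x-h)) / 2h, one coordinate at a time."""
    grad = np.zeros_like(x)
    flat_x, flat_g = x.reshape(-1), grad.reshape(-1)  # views into x and grad
    for i in range(flat_x.size):
        orig = flat_x[i]
        flat_x[i] = orig + h
        f_plus = f(x)
        flat_x[i] = orig - h
        f_minus = f(x)
        flat_x[i] = orig  # restore before moving to the next coordinate
        flat_g[i] = (f_plus - f_minus) / (2 * h)
    return grad
```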
@@ -0,0 +1,961 @@
"""
Layer Normalization from scratch in NumPy
==========================================
Numerically stable backward pass implementation with gradient checking.
Author: pi coding assistant
Date: 2026-04-22
"""
import numpy as np
from typing import Tuple, Dict, Optional
import time
import copy
# =============================================================================
# Numerical constants
# =============================================================================
DEFAULT_EPS = 1e-8
# =============================================================================
# Helper functions
# =============================================================================
def logsumexp(x: np.ndarray, axis: int = -1, keepdims: bool = True) -> np.ndarray:
"""Numerically stable log-sum-exp."""
max_x = np.max(x, axis=axis, keepdims=True)
return max_x + np.log(np.exp(x - max_x).sum(axis=axis, keepdims=True))
# =============================================================================
# Layer Normalization Forward Pass
# =============================================================================
def layer_norm_forward(
x: np.ndarray,
gamma: np.ndarray,
beta: np.ndarray,
eps: float = DEFAULT_EPS
) -> Tuple[np.ndarray, Dict]:
"""
Layer Norm forward pass.
y = gamma * (x - mean) / sqrt(var + eps) + beta
Args:
x: Input tensor of shape (B, T, D)
gamma: Scale parameter of shape (D,)
beta: Bias parameter of shape (D,)
eps: Small constant for numerical stability
Returns:
y: Normalized output of shape (B, T, D)
cache: Dictionary of intermediates for backward pass
"""
B, T, D = x.shape
# Compute mean over feature dimension
# mean[b, t] = (1/D) * sum_d x[b, t, d]
mean = np.mean(x, axis=-1, keepdims=True) # (B, T, 1)
# Compute variance over feature dimension
# var[b, t] = (1/D) * sum_d (x[b, t, d] - mean[b, t])^2
x_centered = x - mean # (B, T, D)
var = np.mean(x_centered ** 2, axis=-1, keepdims=True) # (B, T, 1)
# Compute standard deviation with eps for numerical stability
# std >= sqrt(eps) > 0, preventing division by zero
std = np.sqrt(var + eps) # (B, T, 1)
# Normalize
x_norm = x_centered / std # (B, T, D)
# Scale and shift
y = gamma * x_norm + beta # (B, T, D)
# Cache intermediates for backward pass
# We store only what we need to avoid recomputation
cache = {
'x': x, # Original input (needed for gradient check)
'x_centered': x_centered,
'x_norm': x_norm, # Normalized values (needed for d_gamma)
'mean': mean, # (B, T, 1)
'var': var, # (B, T, 1)
'std': std, # (B, T, 1)
'gamma': gamma, # Needed for gradient computation
'beta': beta, # Needed for gradient check
'eps': eps,
'B': B,
'T': T,
'D': D
}
return y, cache
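# -----------------------------------------------------------------------------
# Property sketch (added, illustrative): with gamma = 1 and beta = 0, every
# (b, t) row of x_norm has mean ~0 and variance ~1. The check below recomputes
# the normalization inline so it can run standalone.
# -----------------------------------------------------------------------------
def _check_forward_properties(B: int = 2, T: int = 3, D: int = 16, seed: int = 0) -> bool:
    import numpy as np  # already imported at module top; repeated for self-containment
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((B, T, D))
    mean = x.mean(axis=-1, keepdims=True)
    std = np.sqrt(((x - mean) ** 2).mean(axis=-1, keepdims=True) + 1e-8)
    x_norm = (x - mean) / std
    # mean is exactly 0 up to roundoff; variance is 1 up to O(eps)
    return (np.allclose(x_norm.mean(axis=-1), 0.0, atol=1e-12)
            and np.allclose(x_norm.var(axis=-1), 1.0, atol=1e-6))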
# =============================================================================
# Layer Normalization Backward Pass
# =============================================================================
def layer_norm_backward(
dy: np.ndarray,
cache: Dict
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Layer Norm backward pass.
Derivation:
-----------
Let:
μ = mean[x] (over axis=-1)
σ² = var[x] (over axis=-1)
σ = sqrt(σ² + eps)
x̄ = (x - μ) / σ (normalized)
y = γ * x̄ + β
Then:
∂L/∂γ = sum(dy * x̄) over (B, T)
∂L/∂β = sum(dy) over (B, T)
    For ∂L/∂x, write dz_i = γ_i * ∂L/∂y_i (= ∂L/∂x̄_i):
        ∂L/∂x_i = dz_i / σ                                   (direct path)
                  - (1/D) * Σ_j dz_j / σ                     (path through μ)
                  - x̄_i * (1/D) * Σ_j dz_j x̄_j / σ           (path through σ²)
        where the σ² path uses
            ∂L/∂σ² = -(1/(2σ²)) * Σ_j dz_j x̄_j   and   ∂σ²/∂x_i = (2/D)(x_i - μ)
    Consolidating:
        ∂L/∂x_i = (1/σ) * [dz_i - mean(dz) - x̄_i * mean(dz * x̄)]
Args:
dy: Upstream gradient of shape (B, T, D)
cache: Forward pass intermediates
Returns:
dx: Gradient w.r.t. input x, shape (B, T, D)
d_gamma: Gradient w.r.t. gamma, shape (D,)
d_beta: Gradient w.r.t. beta, shape (D,)
"""
x = cache['x']
x_centered = cache['x_centered']
x_norm = cache['x_norm']
mean = cache['mean']
std = cache['std']
gamma = cache['gamma']
eps = cache['eps']
B, T, D = cache['B'], cache['T'], cache['D']
# -------------------------------------------------------------------------
# 1. Compute gradients w.r.t. gamma and beta
# -------------------------------------------------------------------------
# d_gamma[d] = sum_{b,t} dy[b,t,d] * x_norm[b,t,d]
d_gamma = np.sum(dy * x_norm, axis=(0, 1)) # (D,)
# d_beta[d] = sum_{b,t} dy[b,t,d]
d_beta = np.sum(dy, axis=(0, 1)) # (D,)
# -------------------------------------------------------------------------
# 2. Compute gradient w.r.t. normalized input
# -------------------------------------------------------------------------
# dz = dy * gamma (chain rule: y = gamma * x_norm + beta)
# Note: We can compute this and reuse in dx computation
dz = dy * gamma # (B, T, D)
# -------------------------------------------------------------------------
# 3. Compute gradient w.r.t. x
# -------------------------------------------------------------------------
#
# From the derivation:
# dx = (dz - mean(dz) - x_norm * mean(dz * x_norm)) / std
#
# This comes from applying the chain rule considering:
# - Direct dependence of x on x_norm through (x - mean) / std
# - Indirect dependence through mean and std
#
# Key insight: We compute the two reduction terms efficiently:
# - mean(dz) = (1/D) * sum(dz, axis=-1, keepdims=True)
# - mean(dz * x_norm) = (1/D) * sum(dz * x_norm, axis=-1, keepdims=True)
#
# Compute reduction terms (these are O(BTD) each)
sum_dz = np.sum(dz, axis=-1, keepdims=True) # (B, T, 1)
sum_dz_xnorm = np.sum(dz * x_norm, axis=-1, keepdims=True) # (B, T, 1)
# Compute dx using the consolidated formula
# dx = (dz - (sum_dz / D) - x_norm * (sum_dz_xnorm / D)) / std
dx = (dz - sum_dz / D - x_norm * sum_dz_xnorm / D) / std
# -------------------------------------------------------------------------
# Numerical stability analysis:
# ---------------------------
# 1. Division by std: We use std = sqrt(var + eps), so std >= sqrt(eps) > 0
# Example: eps = 1e-8 => std >= 1e-4, so no division by zero
#
# 2. Division by D: D is typically 512-4096, so this is stable
#
# 3. The formula (dz - mean(dz) - x_norm * mean(dz * x_norm)) / std
# When std is very small, the gradient can be large, but this is
# mathematically correct - small std means large normalization effect
#
# 4. For extreme stability, we could use the two-pass formula:
# dx = (dz - mean(dz) - x_norm * mean(dz * x_norm))
# dx = dx / std
# This avoids any intermediate overflow/underflow in the subtraction
#
# 5. Alternative numerically stable computation using centering trick:
# temp = dz / std
# dx = temp - x_norm * (sum(temp * x_norm) / D)
# dx = dx - sum(dx) / D (but this is less efficient)
#
# 6. Catastrophic cancellation can occur in: (dz - mean(dz))
# When dz is roughly constant across D, mean(dz) ≈ dz, causing
# cancellation. However, this is exactly when dx should be small,
# so the cancellation is benign (relative error is small).
#
# 7. The x_norm * mean(dz * x_norm) term can also suffer from cancellation
# when mean(dz * x_norm) ≈ 0, but again this is when the term is small.
#
# 8. For fp16 or extreme cases, consider pairwise summation for reductions
# and/or higher precision accumulators.
#
# -------------------------------------------------------------------------
return dx, d_gamma, d_beta
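# -----------------------------------------------------------------------------
# Sanity sketch (added, illustrative): the consolidated dx formula should match
# multiplying dz by the explicit (D x D) Jacobian of x_norm w.r.t. x on a tiny
# single row. Everything is recomputed inline so the check runs standalone.
# -----------------------------------------------------------------------------
def _check_consolidated_vs_jacobian(D: int = 6, seed: int = 0) -> float:
    """Return max |dx_consolidated - dz @ J| for one random row (expect ~1e-15)."""
    import numpy as np  # already imported at module top; repeated for self-containment
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(D)
    gamma = rng.standard_normal(D)
    dy = rng.standard_normal(D)
    eps = 1e-8
    mean = x.mean()
    std = np.sqrt(((x - mean) ** 2).mean() + eps)
    x_norm = (x - mean) / std
    dz = dy * gamma
    # Consolidated formula (as in layer_norm_backward, for one row):
    dx = (dz - dz.mean() - x_norm * (dz * x_norm).mean()) / std
    # Explicit Jacobian: d x_norm[j] / d x[i]
    #   = ((i == j) - 1/D) / std - x_norm[j] * x_norm[i] / (D * std)
    J = (np.eye(D) - 1.0 / D) / std - np.outer(x_norm, x_norm) / (D * std)
    return float(np.max(np.abs(dx - dz @ J)))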
# =============================================================================
# Layer Norm Module (combines forward and backward)
# =============================================================================
class LayerNorm:
"""
Layer Normalization module with manual gradient computation.
Forward: y = gamma * (x - mean) / sqrt(var + eps) + beta
Backward: Computes gradients w.r.t. x, gamma, beta
"""
def __init__(self, normalized_shape: int, eps: float = DEFAULT_EPS):
"""
Args:
normalized_shape: Dimension D of the feature space
eps: Epsilon for numerical stability in sqrt(var + eps)
"""
self.normalized_shape = normalized_shape
self.eps = eps
        # Initialize gamma (scale) and beta (shift) parameters.
        # Identity initialization: the layer starts as a pure normalization.
        self.gamma = np.ones(normalized_shape)   # Scale initialized to 1
        self.beta = np.zeros(normalized_shape)   # Shift initialized to 0
# Storage for gradients
self.d_gamma = None
self.d_beta = None
def forward(self, x: np.ndarray) -> Tuple[np.ndarray, Dict]:
"""Forward pass."""
assert x.shape[-1] == self.normalized_shape, \
f"Expected last dimension {self.normalized_shape}, got {x.shape[-1]}"
return layer_norm_forward(x, self.gamma, self.beta, self.eps)
def backward(self, dy: np.ndarray, cache: Dict) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Backward pass.
Args:
dy: Upstream gradient of shape (B, T, D)
cache: Forward pass cache
Returns:
dx: Gradient w.r.t. input x
d_gamma: Gradient w.r.t. gamma
d_beta: Gradient w.r.t. beta
"""
dx, d_gamma, d_beta = layer_norm_backward(dy, cache)
self.d_gamma = d_gamma
self.d_beta = d_beta
return dx, d_gamma, d_beta
def parameters(self) -> Tuple[np.ndarray, np.ndarray]:
"""Return (gamma, beta)."""
return self.gamma, self.beta
def gradients(self) -> Tuple[np.ndarray, np.ndarray]:
"""Return (d_gamma, d_beta)."""
return self.d_gamma, self.d_beta
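# Example usage (illustrative sketch, not part of the original module):
#     ln = LayerNorm(64)
#     x = np.random.randn(8, 16, 64)
#     y, cache = ln.forward(x)
#     dx, d_gamma, d_beta = ln.backward(np.ones_like(y), cache)
#     d_gamma, d_beta = ln.gradients()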
# =============================================================================
# Gradient Checking via Finite Differences
# =============================================================================
def compute_numerical_gradient_gamma(
x: np.ndarray,
gamma: np.ndarray,
beta: np.ndarray,
dy: np.ndarray,
h: float = 1e-5
) -> np.ndarray:
"""
Compute numerical gradient for gamma using finite differences.
Args:
x: Input tensor (B, T, D)
gamma: Scale parameter (D,)
beta: Bias parameter (D,)
dy: Upstream gradient (B, T, D)
h: Step size
Returns:
Numerical gradient for gamma (D,)
"""
D = len(gamma)
num_grad = np.zeros(D)
for i in range(D):
# Save original value
original = gamma[i]
# f(gamma + h)
gamma[i] = original + h
y_plus, _ = layer_norm_forward(x, gamma, beta)
loss_plus = np.sum(y_plus * dy)
# f(gamma - h)
gamma[i] = original - h
y_minus, _ = layer_norm_forward(x, gamma, beta)
loss_minus = np.sum(y_minus * dy)
# Central difference
num_grad[i] = (loss_plus - loss_minus) / (2 * h)
# Restore original
gamma[i] = original
return num_grad
def compute_numerical_gradient_beta(
x: np.ndarray,
gamma: np.ndarray,
beta: np.ndarray,
dy: np.ndarray,
h: float = 1e-5
) -> np.ndarray:
"""
Compute numerical gradient for beta using finite differences.
Args:
x: Input tensor (B, T, D)
gamma: Scale parameter (D,)
beta: Bias parameter (D,)
dy: Upstream gradient (B, T, D)
h: Step size
Returns:
Numerical gradient for beta (D,)
"""
D = len(beta)
num_grad = np.zeros(D)
for i in range(D):
# Save original value
original = beta[i]
# f(beta + h)
beta[i] = original + h
y_plus, _ = layer_norm_forward(x, gamma, beta)
loss_plus = np.sum(y_plus * dy)
# f(beta - h)
beta[i] = original - h
y_minus, _ = layer_norm_forward(x, gamma, beta)
loss_minus = np.sum(y_minus * dy)
# Central difference
num_grad[i] = (loss_plus - loss_minus) / (2 * h)
# Restore original
beta[i] = original
return num_grad
def compute_numerical_gradient_x(
x: np.ndarray,
gamma: np.ndarray,
beta: np.ndarray,
dy: np.ndarray,
h: float = 1e-5,
max_elements: int = 100000
) -> Tuple[np.ndarray, np.ndarray]:
"""
Compute numerical gradient for x using finite differences.
For large tensors, uses spot check.
Returns both the numerical gradient AND the numerical gradient for spot-checked elements.
Args:
x: Input tensor (B, T, D) - will be restored to original values
gamma: Scale parameter (D,)
beta: Bias parameter (D,)
dy: Upstream gradient (B, T, D)
h: Step size
max_elements: Maximum elements to check (spot check if larger)
Returns:
Tuple of (num_grad, spot_check_mask) where spot_check_mask marks checked elements
"""
B, T, D = x.shape
total_elements = B * T * D
# Save original x values
orig_x = x.copy()
if total_elements <= max_elements:
# Full gradient check
num_grad = np.zeros_like(x)
# Use reshape to get a view, not a copy
x_flat = x.reshape(-1)
num_grad_flat = num_grad.reshape(-1)
for i in range(total_elements):
if (i + 1) % 10000 == 0:
print(f" Progress: {i+1}/{total_elements}")
original = x_flat[i]
# f(x + h)
x_flat[i] = original + h
y_plus, _ = layer_norm_forward(x, gamma, beta)
loss_plus = np.sum(y_plus * dy)
# f(x - h)
x_flat[i] = original - h
y_minus, _ = layer_norm_forward(x, gamma, beta)
loss_minus = np.sum(y_minus * dy)
# Central difference
num_grad_flat[i] = (loss_plus - loss_minus) / (2 * h)
# Restore
x_flat[i] = original
# Restore x to original values
x[:] = orig_x
return num_grad, np.ones((B, T, D), dtype=bool) # All elements checked
else:
# Spot check
print(f" Spot checking {max_elements} random elements...")
n_samples = max_elements
num_grad = np.zeros_like(x)
spot_checked = np.zeros((B, T, D), dtype=bool)
indices = [tuple(np.random.randint(b) for b in (B, T, D)) for _ in range(n_samples)]
for idx in indices:
bi, ti, di = idx
original = x[bi, ti, di]
spot_checked[bi, ti, di] = True
# f(x + h)
x[bi, ti, di] = original + h
y_plus, _ = layer_norm_forward(x, gamma, beta)
loss_plus = np.sum(y_plus * dy)
# f(x - h)
x[bi, ti, di] = original - h
y_minus, _ = layer_norm_forward(x, gamma, beta)
loss_minus = np.sum(y_minus * dy)
num_grad[bi, ti, di] = (loss_plus - loss_minus) / (2 * h)
# Restore
x[bi, ti, di] = original
# Restore x to original values
x[:] = orig_x
return num_grad, spot_checked
def gradient_check(
x: np.ndarray,
gamma: np.ndarray,
beta: np.ndarray,
dy: np.ndarray,
rtol: float = 1e-4,
atol: float = 1e-5,
verbose: bool = True
) -> Dict[str, bool]:
"""
Perform gradient check for all parameters.
Checks: |analytical - numerical| <= atol + rtol * |numerical|
Args:
x: Input tensor
gamma: Scale parameter
beta: Bias parameter
dy: Upstream gradient
rtol: Relative tolerance
atol: Absolute tolerance
verbose: Print detailed results
Returns:
Dictionary of pass/fail for each parameter
"""
results = {}
# Store originals
orig_gamma = gamma.copy()
orig_beta = beta.copy()
orig_x = x.copy()
# -------------------------------------------------------------------------
# Forward pass to get analytical gradients
# -------------------------------------------------------------------------
y, cache = layer_norm_forward(x, gamma, beta)
dx_analytical, d_gamma_analytical, d_beta_analytical = layer_norm_backward(dy, cache)
# -------------------------------------------------------------------------
# Check gradient w.r.t. gamma
# -------------------------------------------------------------------------
if verbose:
print("\n" + "="*60)
print("GRADIENT CHECK: gamma")
print("="*60)
# Reset gamma to original
gamma[:] = orig_gamma
d_gamma_numerical = compute_numerical_gradient_gamma(x, gamma, beta, dy)
# Compare
diff = np.abs(d_gamma_analytical - d_gamma_numerical)
tolerance = atol + rtol * np.abs(d_gamma_numerical)
passed = np.all(diff <= tolerance)
if verbose:
print(f"Analytical gradient shape: {d_gamma_analytical.shape}")
print(f"Numerical gradient shape: {d_gamma_numerical.shape}")
print(f"Max absolute difference: {np.max(diff):.2e}")
print(f"Max relative tolerance: {np.max(tolerance):.2e}")
print(f"Mean analytical: {np.mean(np.abs(d_gamma_analytical)):.6e}")
print(f"Mean numerical: {np.mean(np.abs(d_gamma_numerical)):.6e}")
print(f"\nGradient check: {'PASSED ✓' if passed else 'FAILED ✗'}")
results['gamma'] = passed
# -------------------------------------------------------------------------
# Check gradient w.r.t. beta
# -------------------------------------------------------------------------
if verbose:
print("\n" + "="*60)
print("GRADIENT CHECK: beta")
print("="*60)
# Reset beta to original
beta[:] = orig_beta
d_beta_numerical = compute_numerical_gradient_beta(x, gamma, beta, dy)
diff = np.abs(d_beta_analytical - d_beta_numerical)
tolerance = atol + rtol * np.abs(d_beta_numerical)
passed = np.all(diff <= tolerance)
if verbose:
print(f"Analytical gradient shape: {d_beta_analytical.shape}")
print(f"Numerical gradient shape: {d_beta_numerical.shape}")
print(f"Max absolute difference: {np.max(diff):.2e}")
print(f"Max relative tolerance: {np.max(tolerance):.2e}")
print(f"Mean analytical: {np.mean(np.abs(d_beta_analytical)):.6e}")
print(f"Mean numerical: {np.mean(np.abs(d_beta_numerical)):.6e}")
print(f"\nGradient check: {'PASSED ✓' if passed else 'FAILED ✗'}")
results['beta'] = passed
# -------------------------------------------------------------------------
# Check gradient w.r.t. x
# -------------------------------------------------------------------------
if verbose:
print("\n" + "="*60)
print("GRADIENT CHECK: x (input)")
print("="*60)
# Reset x to original
x[:] = orig_x
d_x_numerical, spot_checked = compute_numerical_gradient_x(x, gamma, beta, dy)
# Only check elements that were numerically computed
diff = np.abs(dx_analytical[spot_checked] - d_x_numerical[spot_checked])
tolerance = atol + rtol * np.abs(d_x_numerical[spot_checked])
passed = np.all(diff <= tolerance)
if verbose:
print(f"Analytical gradient shape: {dx_analytical.shape}")
print(f"Numerical gradient shape: {d_x_numerical.shape}")
print(f"Elements checked: {np.sum(spot_checked)} / {spot_checked.size}")
if np.any(spot_checked):
print(f"Max absolute difference: {np.max(diff):.2e}")
print(f"Max relative tolerance: {np.max(tolerance):.2e}")
print(f"Mean analytical (checked): {np.mean(np.abs(dx_analytical[spot_checked])):.6e}")
print(f"Mean numerical (checked): {np.mean(np.abs(d_x_numerical[spot_checked])):.6e}")
print(f"\nGradient check: {'PASSED ✓' if passed else 'FAILED ✗'}")
results['x'] = passed
# Restore originals
gamma[:] = orig_gamma
beta[:] = orig_beta
x[:] = orig_x
return results
# =============================================================================
# Complexity Analysis
# =============================================================================
def analyze_complexity():
"""Print complexity analysis for layer norm forward and backward."""
print("""
╔══════════════════════════════════════════════════════════════════════════════╗
║ COMPLEXITY ANALYSIS ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ Input: x ∈ ^(B×T×D) ║
║ Parameters: γ, β ∈ ^D ║
║ ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ FORWARD PASS ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ Operation │ Work (FLOPs) │ Memory ║
║ ────────────────────────────────────────────────────────────────────────── ║
║ mean (reduction) │ O(BTD) │ - ║
║ x - mean (broadcast) │ O(BTD) │ O(BTD) ║
║ var (reduction) │ O(BTD) │ - ║
║ sqrt(var + eps) │ O(BT) │ - ║
║ divide by std │ O(BTD) │ - ║
║ gamma * x_norm │ O(BTD) │ - ║
║ add beta │ O(BTD) │ O(BTD) output ║
║ ────────────────────────────────────────────────────────────────────────── ║
║ TOTAL │ 5×O(BTD) │ O(BTD) ║
║ ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ BACKWARD PASS ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ Operation │ Work (FLOPs) │ Memory ║
║ ────────────────────────────────────────────────────────────────────────── ║
║ d_gamma = sum(dy*x_norm)│ O(BTD) │ O(D) ║
║ d_beta = sum(dy) │ O(BTD) │ O(D) ║
║ dz = dy * gamma │ O(BTD) │ O(BTD) (can be avoided) ║
║ sum(dz) │ O(BTD) │ - ║
║ sum(dz * x_norm) │ O(BTD) │ - ║
║ dx computation │ O(BTD) │ O(BTD) output ║
║ ────────────────────────────────────────────────────────────────────────── ║
║ TOTAL │ 5×O(BTD) │ O(BTD) + O(D) ║
║ ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ SUMMARY ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ Time Complexity: O(BTD) for both forward and backward ║
║ Space Complexity: O(BTD) for storing activations (during training) ║
    ║                     O(1) extra during inference (no cache is stored)         ║
║ ║
║ Cache efficiency: We store x_centered, x_norm, mean, std ║
║ These are O(BTD) total, reused across all gradient comps ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════╝
""")
# =============================================================================
# GPU Kernel Fusion Design
# =============================================================================
def explain_gpu_fusion():
"""Explain GPU kernel fusion for layer normalization."""
print("""
╔══════════════════════════════════════════════════════════════════════════════╗
║ GPU KERNEL FUSION DESIGN ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ CURRENT SEPARATE KERNEL APPROACH: ║
║ ──────────────────────────────────── ║
║ Kernel 1: Compute mean (reduction over D) ║
║ Kernel 2: Compute variance (reduction over D) ║
║ Kernel 3: Normalize (element-wise) ║
║ Kernel 4: Scale and shift (element-wise) ║
║ Kernel 5: Backward kernels (x, gamma, beta) ║
║ ║
║ Issues with separate kernels: ║
║ • Multiple kernel launches (overhead) ║
║ • Data movement between global memory passes ║
║ • Can't use persistent threads for reduction efficiency ║
║ ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ FUSED KERNEL DESIGN (Forward): ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ Grid: (B × T) blocks, each handling one (b,t) position ║
║ Block: 256-512 threads ║
║ ║
║ PHASE 1: Load and compute local sum ║
║ ───────────────────────────────── ║
║ • Each thread loads x[b,t,d] into shared memory ║
║ • Compute partial sum using warp-level reduction ║
║ • Single thread writes mean to __shared__ ║
║ ║
║ PHASE 2: Compute variance locally ║
║ ───────────────────────────────── ║
║ • Re-load x with loaded mean ║
║ • Compute (x-mean)² and partial variance ║
║ • Reduce to get variance ║
║ ║
║ PHASE 3: Normalize and output ║
║ ───────────────────────────────── ║
║ • All threads compute: y = gamma * (x-mean) / sqrt(var+eps) + beta ║
║ • Write to output (fully coalesced) ║
║ ║
║ MEMORY ACCESS PATTERN: ║
║ • Each block reads contiguous D elements (coalesced) ║
║ • Use shared memory for intermediate results ║
║ • Output writes are also coalesced ║
║ ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ FUSED KERNEL DESIGN (Backward): ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ Key insight: dz = dy * gamma can be merged into the computation ║
║ ║
║ Grid: (B × T) blocks ║
║ Block: 256-512 threads ║
║ ║
║ SHARED MEMORY STRUCTURE: ║
║ [x_norm_0, x_norm_1, ..., x_norm_{D-1}] ║
║ [dz_0, dz_1, ..., dz_{D-1}] ║
║ ║
║ ALGORITHM: ║
║ 1. Load x_norm and compute local dz = dy * gamma ║
║ 2. Reduce to get sum(dz) and sum(dz * x_norm) ║
║ 3. Second pass to compute dx using the formula: ║
║ dx = (dz - mean(dz) - x_norm * mean(dz*x_norm)) / std ║
║ ║
║ REDUCTION OPTIMIZATIONS: ║
║ • Warp-level shuffle reductions (no shared memory needed) ║
║ • Block-level using shared memory with tree reduction ║
║ • Use block-level primitives for final reduction ║
║ ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ BENEFITS OF FUSION: ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ ✓ Reduced kernel launch overhead (1 kernel vs 4-5) ║
║ ✓ Better memory bandwidth utilization (single read, single write) ║
║ ✓ Improved cache locality (data stays in registers/shared mem) ║
║ ✓ Only loads x once, computes mean and var from same data ║
║ ✓ Backward can reuse cached values from forward (if memory allows) ║
║ ✓ Lower register pressure allows for larger block sizes ║
║ ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ CUDA KERNEL SKETCH (Pseudo-code): ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
    ║    __global__ void layer_norm_fwd(float* y, const float* x,                  ║
    ║                                   const float* gamma, const float* beta,     ║
    ║                                   int B, int T, int D, float eps) {          ║
    ║                                                                              ║
    ║        __shared__ float partial[32];    // one partial sum per warp          ║
    ║        int tid = threadIdx.x;                                                ║
    ║        int row = blockIdx.x;            // one (b, t) row per block          ║
    ║                                                                              ║
    ║        // Phase 1: mean (strided load + warp/block reduction)                ║
    ║        float sum = 0.0f;                                                     ║
    ║        for (int d = tid; d < D; d += blockDim.x)                             ║
    ║            sum += x[row * D + d];                                            ║
    ║        sum = warpReduceSum(sum);                                             ║
    ║        if (tid % 32 == 0) partial[tid / 32] = sum;                           ║
    ║        __syncthreads();                                                      ║
    ║        if (tid == 0) {                  // second reduction across warps     ║
    ║            float total = 0.0f;                                               ║
    ║            for (int w = 0; w < blockDim.x / 32; w++) total += partial[w];    ║
    ║            partial[0] = total;                                               ║
    ║        }                                                                     ║
    ║        __syncthreads();                                                      ║
    ║        float mean = partial[0] / D;                                          ║
    ║        __syncthreads();                 // before partial[] is reused        ║
    ║                                                                              ║
    ║        // Phase 2: variance (same reduction pattern as Phase 1)              ║
    ║        sum = 0.0f;                                                           ║
    ║        for (int d = tid; d < D; d += blockDim.x) {                           ║
    ║            float diff = x[row * D + d] - mean;                               ║
    ║            sum += diff * diff;                                               ║
    ║        }                                                                     ║
    ║        // ... warp + cross-warp reduction into partial[0], as above ...      ║
    ║        float var = partial[0] / D;                                           ║
    ║        float inv_std = rsqrtf(var + eps);                                    ║
    ║                                                                              ║
    ║        // Phase 3: normalize and write output (coalesced)                    ║
    ║        for (int d = tid; d < D; d += blockDim.x) {                           ║
    ║            float x_norm = (x[row * D + d] - mean) * inv_std;                 ║
    ║            y[row * D + d] = gamma[d] * x_norm + beta[d];                     ║
    ║        }                                                                     ║
    ║    }                                                                         ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════╝
""")
# =============================================================================
# Benchmark and Tests
# =============================================================================
def benchmark():
"""Benchmark forward and backward passes."""
print("\n" + "="*70)
print("BENCHMARKING LAYER NORMALIZATION")
print("="*70)
# Test different shapes
shapes = [
(32, 128, 256), # Small
(64, 128, 512), # Medium (BERT-base hidden)
(32, 512, 768), # BERT-base
(16, 512, 1024), # Larger
]
results = []
for B, T, D in shapes:
print(f"\nShape: (B={B}, T={T}, D={D})")
print("-" * 40)
# Create random inputs
np.random.seed(42)
x = np.random.randn(B, T, D).astype(np.float64)
gamma = np.random.randn(D).astype(np.float64)
beta = np.random.randn(D).astype(np.float64)
dy = np.random.randn(B, T, D).astype(np.float64)
# Forward benchmark
n_iters = 100
times = []
for _ in range(n_iters):
start = time.perf_counter()
y, cache = layer_norm_forward(x, gamma, beta)
end = time.perf_counter()
times.append((end - start) * 1000) # ms
fwd_time = np.mean(times)
fwd_std = np.std(times)
# Backward benchmark
times = []
for _ in range(n_iters):
start = time.perf_counter()
dx, d_gamma, d_beta = layer_norm_backward(dy, cache)
end = time.perf_counter()
times.append((end - start) * 1000) # ms
bwd_time = np.mean(times)
bwd_std = np.std(times)
        # Throughput in billions of elements processed per second
        elements = B * T * D
        fwd_throughput = elements / (fwd_time / 1000) / 1e9  # Gelem/s
        bwd_throughput = elements / (bwd_time / 1000) / 1e9  # Gelem/s
        print(f"Forward:  {fwd_time:.3f} ± {fwd_std:.3f} ms ({fwd_throughput:.1f} Gelem/s)")
        print(f"Backward: {bwd_time:.3f} ± {bwd_std:.3f} ms ({bwd_throughput:.1f} Gelem/s)")
print(f"Total: {fwd_time + bwd_time:.3f} ms")
results.append({
'shape': (B, T, D),
'fwd_time': fwd_time,
'bwd_time': bwd_time,
'elements': elements
})
return results
def run_gradient_checks():
"""Run gradient checks on various shapes."""
print("\n" + "="*70)
print("RUNNING GRADIENT CHECKS")
print("="*70)
shapes = [
(2, 4, 8), # Tiny
(4, 8, 16), # Small
(8, 16, 32), # Medium-small
]
all_passed = True
for B, T, D in shapes:
print(f"\n{'='*60}")
print(f"Testing shape (B={B}, T={T}, D={D})")
print(f"{'='*60}")
# Create inputs
np.random.seed(123)
x = np.random.randn(B, T, D).astype(np.float64)
gamma = np.random.randn(D).astype(np.float64)
beta = np.random.randn(D).astype(np.float64)
dy = np.random.randn(B, T, D).astype(np.float64)
# Run gradient check
results = gradient_check(x, gamma, beta, dy, verbose=True)
for name, passed in results.items():
status = "✓ PASS" if passed else "✗ FAIL"
print(f" {name:>8}: {status}")
all_passed = all_passed and passed
print("\n" + "="*70)
if all_passed:
print("ALL GRADIENT CHECKS PASSED ✓")
else:
print("SOME GRADIENT CHECKS FAILED ✗")
print("="*70)
return all_passed
# =============================================================================
# Main
# =============================================================================
if __name__ == "__main__":
print("="*70)
print("LAYER NORMALIZATION: FORWARD AND BACKWARD FROM SCRATCH")
print("="*70)
# Print complexity analysis
analyze_complexity()
# Explain GPU fusion
explain_gpu_fusion()
# Run gradient checks
run_gradient_checks()
# Benchmark
benchmark()