perf: Batched q_norm/k_norm into single kernel dispatch #38
Full attention layers (8 of 32) run per-head RMS norms: 16 q_norm + 4 k_norm = 20 separate dispatches per layer. Each dispatch only uses 256 threads for head_dim=256 elements. That's 20 tiny kernels with terrible GPU occupancy.
Solution: Write a batched_rms_norm_bf16 kernel that normalizes all heads in a single dispatch. Input: a contiguous buffer with all heads laid out, a num_heads parameter, and the per-head head_dim. One dispatch handles all 20 norms.
This replaces 20×8 = 160 kernel dispatches per decode step with 8 (one per full attention layer), a saving of 152.
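For reference, the batching idea can be sketched on the CPU in plain Python. This is not the actual kernel, only a minimal model of it under assumptions: the eps value, the per-head weight list (head_weights), and the helper names rms_norm and batched_rms_norm are all hypothetical, and float arithmetic stands in for bf16. The outer loop over heads corresponds to what a GPU would run as one thread group per head inside a single dispatch.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Reference RMS norm for one head: x / sqrt(mean(x^2) + eps) * weight.

    eps=1e-6 is an assumed default, not taken from the issue.
    """
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def batched_rms_norm(buf, head_weights, num_heads, head_dim, eps=1e-6):
    """Normalize every head of a contiguous buffer in one call.

    buf holds num_heads * head_dim values, head after head.
    head_weights[h] is the weight vector for head h (hypothetical layout:
    the q_norm weight repeated for the q heads, then the k_norm weight).
    """
    if len(buf) != num_heads * head_dim:
        raise ValueError("buf must hold num_heads * head_dim elements")
    out = [0.0] * len(buf)
    # On a GPU, each iteration would be one thread group of the same dispatch.
    for head in range(num_heads):
        start = head * head_dim
        end = start + head_dim
        out[start:end] = rms_norm(buf[start:end], head_weights[head], eps)
    return out


# Shapes from the issue: 16 q heads + 4 k heads, head_dim = 256.
NUM_Q_HEADS, NUM_K_HEADS, HEAD_DIM = 16, 4, 256
num_heads = NUM_Q_HEADS + NUM_K_HEADS

# Synthetic input and weights, for illustration only.
buf = [math.sin(0.01 * i) for i in range(num_heads * HEAD_DIM)]
q_weight = [1.0] * HEAD_DIM
k_weight = [0.5] * HEAD_DIM
head_weights = [q_weight] * NUM_Q_HEADS + [k_weight] * NUM_K_HEADS

# One batched call ...
batched = batched_rms_norm(buf, head_weights, num_heads, HEAD_DIM)

# ... versus 20 separate per-head calls (the pre-change dispatch pattern).
per_head = []
for h in range(num_heads):
    per_head.extend(rms_norm(buf[h * HEAD_DIM:(h + 1) * HEAD_DIM], head_weights[h]))

assert batched == per_head
```

Because every head is normalized independently, the batched call produces the same values as the 20 separate calls; only the number of launches changes.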
Merged via PR #49 (squash). Coherence: token-identical to baseline. Benchmark: 43ms→40ms decode (-7%), 144 fewer kernel dispatches.