perf: Fuse RMS norm into matmul for decode #43
Every layer starts with: RMS norm hidden → matmul(hidden, weight). These are always paired.
For decode (M=1), the RMS norm reads 2560 bf16 (5KB) and writes 2560 bf16. The matmul then reads the same 2560 bf16. Fusing them eliminates one 5KB read+write per norm-matmul pair.
Saves ~4 intermediate write+read round trips per layer (input_norm→qkv, post_norm→gate, etc.) × 32 layers = 128 fewer kernel dispatches and, at 5KB written plus 5KB re-read per pair, roughly 1.25MB (128 × 10KB) less memory traffic per decode step.
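For reference, the norm-then-matmul pairing described above can be sketched in plain Python. This is only an illustration of the math, not the project's kernel: the function names, the `EPS` value, the learned gain vector `gamma`, and the use of Python floats instead of bf16 are all assumptions. The unfused path materializes the normalized vector; the fused path folds the scale into each dot product so that vector is never written.

```python
import math

EPS = 1e-6  # assumed epsilon; the PR does not state one


def rms_norm(hidden, gamma):
    """Unfused step 1: normalize and write out a full intermediate vector."""
    inv_rms = 1.0 / math.sqrt(sum(x * x for x in hidden) / len(hidden) + EPS)
    return [x * inv_rms * g for x, g in zip(hidden, gamma)]


def gemv(weight, vec):
    """Unfused step 2: one dot product per output row (the M=1 matmul)."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]


def fused_norm_gemv(weight, hidden, gamma):
    """Fused: the normalized vector is never materialized."""
    inv_rms = 1.0 / math.sqrt(sum(x * x for x in hidden) / len(hidden) + EPS)
    return [
        inv_rms * sum(w * g * x for w, g, x in zip(row, gamma, hidden))
        for row in weight
    ]


# Tiny hypothetical shapes (hidden size 4, 3 output rows) for illustration.
hidden = [0.5, -1.0, 2.0, 0.25]
gamma = [1.0, 0.5, 2.0, 1.5]
weight = [
    [1.0, 0.0, -1.0, 2.0],
    [0.5, 0.5, 0.5, 0.5],
    [-2.0, 1.0, 0.0, 3.0],
]

unfused = gemv(weight, rms_norm(hidden, gamma))
fused = fused_norm_gemv(weight, hidden, gamma)
results_match = all(
    math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
    for a, b in zip(unfused, fused)
)
```

Note that this sketch computes `inv_rms` once and reuses it for every output row, which is only possible when one unit of execution sees all the rows.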
Approach rejected. Fusing the norm into a row-per-TG GEMV is counterproductive for M=1 decode. Each threadgroup computes one output row, so the RMS norm is recomputed once per output row, N times in total (152960x for lm_head). Result: 162ms vs 34ms baseline (4.7x SLOWER). Fusion only pays off when multiple output rows share a threadgroup and can therefore share one norm computation; our M=1 GEMV with 1-row-per-TG rules that out. Closing as wontfix for the current architecture.
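A rough way to see the redundancy is to count how many elements go through the sum-of-squares reduction under each layout. The sketch below is a simplified counting model, not a measurement: the function name is hypothetical, the 8-rows-per-threadgroup case is a made-up comparison point, and it assumes every threadgroup redoes the full reduction with no sharing across threadgroups. It does not attempt to reproduce the 162ms vs 34ms timing.

```python
def norm_reduction_elements(hidden_size, out_rows, rows_per_tg, fused):
    """Elements read for sum-of-squares reductions in one norm + GEMV."""
    if not fused:
        # Separate norm kernel: the reduction runs once for the whole vector.
        return hidden_size
    # Fused kernel: every threadgroup recomputes the reduction for itself.
    num_threadgroups = -(-out_rows // rows_per_tg)  # ceiling division
    return num_threadgroups * hidden_size


K = 2560      # hidden size quoted in the description
N = 152960    # lm_head output rows quoted in the comment

baseline = norm_reduction_elements(K, N, rows_per_tg=1, fused=False)
fused_1row = norm_reduction_elements(K, N, rows_per_tg=1, fused=True)
fused_8rows = norm_reduction_elements(K, N, rows_per_tg=8, fused=True)  # hypothetical layout

redundancy = fused_1row // baseline  # how many times the norm is recomputed
print(baseline, fused_1row, fused_8rows, redundancy)
```

Under this model the one-row-per-TG fused kernel repeats the reduction once per output row, while sharing a threadgroup across several rows divides that repetition by the number of rows shared.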