[perf] Batched q_norm/k_norm into single kernel dispatches (#38) #49
Summary
Replaces 20 per-head RMS norm dispatches per full attention layer with 2 batched dispatches (one for all Q heads, one for all K heads).
Savings: 18 fewer dispatches × 8 full attention layers = 144 fewer kernel dispatches per decode step.
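To illustrate the batching idea, here is a minimal CPU reference sketch in C++, not the Metal kernel from this PR. Each iteration of the outer loop stands in for one threadgroup handling one head, and heads are located with a byte stride, as the Changes list describes for batched_rms_norm_bf16. The function name batched_rms_norm_ref, the in-place update, the float32 weights, and the truncating bf16 conversion are all assumptions made for this illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Dispatch arithmetic from the summary: (20 - 2) saved per layer, 8 layers.
static_assert((20 - 2) * 8 == 144, "dispatches saved per decode step");

// bf16 holds the upper 16 bits of an IEEE-754 float32.
static float bf16_to_f32(std::uint16_t b) {
    std::uint32_t bits = static_cast<std::uint32_t>(b) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f;
}

// Truncating conversion (no rounding); kept simple for the sketch.
static std::uint16_t f32_to_bf16(float f) {
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof bits);
    return static_cast<std::uint16_t>(bits >> 16);
}

// Hypothetical reference: normalizes n_heads heads in place.
// Head h starts at data + h * byte_stride and holds head_dim bf16 values;
// any bytes between heads (padding or other tensors) are left untouched.
void batched_rms_norm_ref(unsigned char* data, const float* weight,
                          std::size_t n_heads, std::size_t head_dim,
                          std::size_t byte_stride, float eps) {
    for (std::size_t h = 0; h < n_heads; ++h) {   // one "threadgroup" per head
        unsigned char* head = data + h * byte_stride;

        // Pass 1: sum of squares over the head.
        float sum_sq = 0.0f;
        for (std::size_t i = 0; i < head_dim; ++i) {
            std::uint16_t raw;
            std::memcpy(&raw, head + i * sizeof raw, sizeof raw);
            float x = bf16_to_f32(raw);
            sum_sq += x * x;
        }
        float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(head_dim) + eps);

        // Pass 2: scale by 1/rms and the per-element weight, write back as bf16.
        for (std::size_t i = 0; i < head_dim; ++i) {
            std::uint16_t raw;
            std::memcpy(&raw, head + i * sizeof raw, sizeof raw);
            std::uint16_t out = f32_to_bf16(bf16_to_f32(raw) * inv_rms * weight[i]);
            std::memcpy(head + i * sizeof out, &out, sizeof out);
        }
    }
}
```

A single call over all Q heads and a second call over all K heads correspond to the 2 batched dispatches; the stride lets both calls walk a larger interleaved buffer without copying heads out first.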
Changes
- batched_rms_norm_bf16 in rms_norm.metal: one threadgroup per head, handles strided layout via byte_stride parameter
- set_batched_rms_norm_bf16 in dispatch.zig
- model.zig: replaced per-head norm loops with 2 batched dispatches
Benchmark
Closes #38