perf: Batched q_norm/k_norm into single kernel dispatch #38

Closed
opened 2026-05-15 19:02:26 +02:00 by sleepy · 1 comment
Owner

Full attention layers (8 of 32) run per-head RMS norms: 16 q_norm + 4 k_norm = 20 separate dispatches per layer. Each dispatch uses only 256 threads for head_dim=256 elements, so that's 20 tiny kernels with terrible GPU occupancy.

Solution: Write a `batched_rms_norm_bf16` kernel that normalizes all heads in a single dispatch. Input: a contiguous buffer with all heads laid out back to back, plus num_heads and head_dim (elements per head) parameters. One dispatch handles all 20 norms.

This removes 20×8 = 160 per-head dispatches per decode step and adds back one batched dispatch per full attention layer, for a net saving of 152.
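
For reference, a minimal CPU sketch of what the batched kernel computes, in plain Python rather than the real bf16 GPU code. Assumptions not stated in the issue: the function names, the epsilon value, and a single norm weight vector of length head_dim shared by every head in the buffer. Each outer loop iteration stands in for one workgroup handling one head.

```python
import math

EPS = 1e-6  # assumed epsilon; the issue does not state one


def rms_norm(x, weight, eps=EPS):
    """Per-head reference: one call stands for one dispatch before the change."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def batched_rms_norm(buf, weight, num_heads, head_dim, eps=EPS):
    """All heads in one call: head h occupies buf[h*head_dim:(h+1)*head_dim]."""
    if len(buf) != num_heads * head_dim:
        raise ValueError("buffer size must equal num_heads * head_dim")
    if len(weight) != head_dim:
        raise ValueError("weight must have head_dim elements")
    out = [0.0] * len(buf)
    for head in range(num_heads):      # on the GPU: one workgroup per head
        base = head * head_dim
        sum_sq = 0.0
        for i in range(head_dim):      # on the GPU: a parallel reduction
            sum_sq += buf[base + i] * buf[base + i]
        inv_rms = 1.0 / math.sqrt(sum_sq / head_dim + eps)
        for i in range(head_dim):      # on the GPU: one thread per element
            out[base + i] = buf[base + i] * inv_rms * weight[i]
    return out


# Small shapes for illustration; the issue's real shapes are 16 or 4 heads
# with head_dim=256.
num_heads, head_dim = 4, 8
buf = [float((i * 7) % 11 - 5) for i in range(num_heads * head_dim)]
weight = [1.0] * head_dim
batched = batched_rms_norm(buf, weight, num_heads, head_dim)
```

Comparing `batched` against `rms_norm` applied to each head slice is a cheap correctness check before running a full coherence test.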

Author
Owner

Merged via PR #49 (squash). Coherence: token-identical to baseline. Benchmark: 43ms→40ms decode (-7%), 144 fewer kernel dispatches.
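
The measured 144 is below the issue's up-front estimate. A quick count of net savings for different numbers of batched dispatches per layer is sketched below. The helper name is hypothetical, and the two-dispatch case (one for q, one for k) is only an assumption that happens to match 144; the thread does not say how PR #49 splits the work.

```python
def dispatches_saved(full_attn_layers, norms_per_layer, batched_per_layer):
    """Net norm dispatches removed per decode step."""
    return full_attn_layers * (norms_per_layer - batched_per_layer)


NORMS_PER_LAYER = 16 + 4  # q_norm + k_norm dispatches per layer, from the issue

one_per_layer = dispatches_saved(8, NORMS_PER_LAYER, 1)  # 8 * 19 = 152
two_per_layer = dispatches_saved(8, NORMS_PER_LAYER, 2)  # 8 * 18 = 144
```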

Reference
sleepy/sleepy-llm#38