[perf] Vectorized bf16x4 GEMV kernel for M=1 decode (#40) #51
Summary
New matmul_bf16_m1_vec4 kernel using vectorized bf16x4 loads (packed_bfloat4) with a 2-SG reduction. Replaces scalar bf16 loads with 4-wide vector loads for better memory coalescing.
Benchmark
Output: token-identical to baseline (verified A/B).
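The access pattern and two-stage reduction described in the Summary can be modeled on the host for A/B checks. The following is a minimal C++ sketch, not the kernel from this PR. It assumes row-major bf16 weights, float32 activations, K divisible by 4, and a simdgroup width of 32 (so 2 SGs = 64 threads per threadgroup). All function names here (gemv_bf16_m1_vec4_model, gemv_scalar_reference, and the bf16 helpers) are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// bf16 is the upper 16 bits of an IEEE-754 float32.
static inline float bf16_to_f32(uint16_t h) {
    uint32_t bits = static_cast<uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f;
}

// Truncating conversion (no rounding); sufficient for building test data.
static inline uint16_t f32_to_bf16_trunc(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof bits);
    return static_cast<uint16_t>(bits >> 16);
}

constexpr int kSimdWidth = 32;      // assumed simdgroup width
constexpr int kNumSimdgroups = 2;   // "2-SG" from the PR description
constexpr int kThreads = kSimdWidth * kNumSimdgroups;  // 64 threads/TG

// Hypothetical host model: one "threadgroup" per output row.
std::vector<float> gemv_bf16_m1_vec4_model(const std::vector<uint16_t>& W,
                                           const std::vector<float>& x,
                                           int N, int K) {
    assert(K % 4 == 0);
    const int k4 = K / 4;  // number of 4-wide chunks per row
    std::vector<float> y(N);
    for (int row = 0; row < N; ++row) {
        const uint16_t* w_row = W.data() + static_cast<size_t>(row) * K;

        // Stage 1: each thread accumulates 4-wide chunks, strided by the
        // threadgroup size so neighbouring threads read neighbouring chunks.
        float lane_acc[kThreads] = {0.0f};
        for (int tid = 0; tid < kThreads; ++tid) {
            for (int i = tid; i < k4; i += kThreads) {
                for (int j = 0; j < 4; ++j) {
                    lane_acc[tid] += bf16_to_f32(w_row[4 * i + j]) * x[4 * i + j];
                }
            }
        }

        // Stage 2: reduce within each simdgroup (stands in for a SIMD sum).
        float sg_partial[kNumSimdgroups] = {0.0f};
        for (int sg = 0; sg < kNumSimdgroups; ++sg) {
            for (int lane = 0; lane < kSimdWidth; ++lane) {
                sg_partial[sg] += lane_acc[sg * kSimdWidth + lane];
            }
        }

        // Stage 3: cross-SG reduction (threadgroup shared memory on the GPU).
        float total = 0.0f;
        for (int sg = 0; sg < kNumSimdgroups; ++sg) total += sg_partial[sg];
        y[row] = total;
    }
    return y;
}

// Plain scalar baseline with one bf16 load per multiply-add.
std::vector<float> gemv_scalar_reference(const std::vector<uint16_t>& W,
                                         const std::vector<float>& x,
                                         int N, int K) {
    std::vector<float> y(N, 0.0f);
    for (int row = 0; row < N; ++row) {
        for (int k = 0; k < K; ++k) {
            y[row] += bf16_to_f32(W[static_cast<size_t>(row) * K + k]) * x[k];
        }
    }
    return y;
}
```

The model and the scalar baseline sum in different orders, so with arbitrary float inputs their results can differ in the last bits. Exact equality is only guaranteed when every partial sum is exactly representable, for example with small integer test values.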
Changes
matmul.metal: New matmul_bf16_m1_vec4 kernel with bf16x4 loads, 2-SG (64 threads/TG), threadgroup shared memory for cross-SG reduction
dispatch.zig: New set_matmul_bf16_m1_vec4 dispatch function
model.zig: forward_decode() switched to vec4 kernel
Remaining gap
Target is 37 tok/s (27 ms). Currently at 24.6 tok/s (35 ms). Further optimization is needed.
Closes #40