Move linear attention layers to GPU (24/32 layers on CPU) #33
Problem
24 of the 32 Qwen3.5 layers are linear attention, and they currently run on CPU during decode. Every decode step therefore does a GPU→CPU→GPU round trip for these layers, which is the main speed bottleneck.
Current: ~0.6 tok/s decode (linear attention on CPU)
Target: 37+ tok/s (need GPU linear attention)
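Expressed as per-token latency budgets, the gap between current and target is roughly 60x. A back-of-the-envelope check using only the two throughput figures above:

```python
def ms_per_token(tok_per_s: float) -> float:
    """Convert a decode throughput in tokens/second to per-token latency in milliseconds."""
    return 1000.0 / tok_per_s


current_tps = 0.6   # current decode rate stated in this issue
target_tps = 37.0   # target decode rate stated in this issue

current_ms = ms_per_token(current_tps)       # ~1666.7 ms per token
target_ms = ms_per_token(target_tps)         # ~27.0 ms per token
required_speedup = target_tps / current_tps  # ~61.7x

print(f"current: {current_ms:.1f} ms/token")
print(f"target:  {target_ms:.1f} ms/token")
print(f"required speedup: {required_speedup:.1f}x")
```

A ~27 ms per-token budget leaves little room for 24 per-step device round trips on top of the actual compute.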
LinearAttentionLayer.forward_gpu() already exists and handles seq_len=1 with persistent GPU state buffers; it just needs to be wired into the decode loop.
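Why persistent state buffers make a seq_len=1 path practical: in decode, linear attention reduces to a recurrence on a fixed-size state, so nothing grows with context length and the state can stay resident on one device. The sketch below is plain, ungated linear attention for a single head in pure Python. The class and method names are hypothetical and are not the repo's LinearAttentionLayer API; Qwen3.5's actual linear-attention layers may use a gated or delta-rule update rather than this simple additive one.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


class LinearAttnDecodeState:
    """Hypothetical persistent state for one head of plain (ungated) linear attention.

    S has shape [d_k][d_v]. It is allocated once and updated in place on every
    decode step, so the per-token work does not depend on context length.
    """

    def __init__(self, d_k: int, d_v: int) -> None:
        self.S: Matrix = [[0.0] * d_v for _ in range(d_k)]

    def step(self, q: Vector, k: Vector, v: Vector) -> Vector:
        """One seq_len=1 decode step: S += k v^T, then o = S^T q."""
        d_k, d_v = len(self.S), len(self.S[0])
        # In-place rank-1 update of the persistent state.
        for i in range(d_k):
            ki = k[i]
            row = self.S[i]
            for j in range(d_v):
                row[j] += ki * v[j]
        # Readout: o[j] = sum_i q[i] * S[i][j]
        return [sum(q[i] * self.S[i][j] for i in range(d_k)) for j in range(d_v)]


state = LinearAttnDecodeState(d_k=2, d_v=2)
o1 = state.step(q=[1.0, 0.0], k=[1.0, 0.0], v=[2.0, 3.0])
o2 = state.step(q=[1.0, 1.0], k=[0.0, 1.0], v=[1.0, 1.0])
print(o1, o2)  # [2.0, 3.0] [3.0, 4.0]
```

In a GPU implementation the same update and readout would be a rank-1 add and a matrix-vector product on a preallocated device tensor, which is what removes the per-layer GPU→CPU→GPU round trip.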
Acceptance
Max 2 attempts.
Merged hybrid GPU+CPU linear attention (85b7a77): GPU matmuls for the 3 big projections (37M ops/layer), CPU for state management. ~2x overall speedup (0.67→1.4 tok/s). Full GPU linear attention is deferred; the remaining CPU state management is now the bottleneck.
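Amdahl's law gives a rough sense of why the remaining CPU work now dominates and why the 37 tok/s target still needs full GPU linear attention. This sketch uses only the throughput figures reported here and treats per-token decode time as one pool split into an accelerated fraction p and an untouched remainder, which is a simplification.

```python
def amdahl_speedup(p: float, s: float) -> float:
    """Overall speedup when a fraction p of runtime is sped up by a factor s (Amdahl's law)."""
    return 1.0 / ((1.0 - p) + p / s)


def min_fraction_for(overall: float) -> float:
    """Smallest fraction p that must be accelerated to reach `overall`, even as s -> infinity."""
    return 1.0 - 1.0 / overall


before_tps, after_tps, target_tps = 0.67, 1.4, 37.0  # figures from this issue

achieved = after_tps / before_tps    # speedup delivered by the hybrid change
remaining = target_tps / after_tps   # speedup still needed to reach the target
p_needed = min_fraction_for(remaining)

print(f"achieved: {achieved:.2f}x, still needed: {remaining:.1f}x")
print(f"share of current decode time that must be accelerated: >= {p_needed:.1%}")
# Accelerating 90% of the current time by 100x still falls well short:
print(f"p=0.90, s=100 -> {amdahl_speedup(0.90, 100.0):.1f}x")
```

Under these assumptions, at least ~96% of the current per-token time has to move off the slow path, so any CPU-side state management left in the loop caps the achievable speedup.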