Metal kernels for GPU linear attention (conv1d, delta rule, norm/gate) #35

Open
opened 2026-05-15 12:20:34 +02:00 by sleepy · 0 comments

Objective

Write two Metal compute kernels for linear-attention decode (seq_len=1), so that the convolution state and the recurrent state are updated entirely on the GPU and no CPU-side state management remains.

Kernel 1: conv1d_state_bf16

  • Input: mixed_qkv[conv_dim=6656] (bf16), conv_state[(conv_k-1)*conv_dim] (bf16), conv_weight[conv_dim*conv_k] (bf16)
  • Output: conv_out[conv_dim] (bf16), updated conv_state
  • Ops: shift conv_state left by conv_dim, store mixed_qkv at end, conv1d dot product, apply SiLU
  • conv_k=4, conv_dim=6656
  • See attention.zig:433-462 for the exact CPU algorithm
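The following is a minimal CPU reference sketch for `conv1d_state_bf16`, intended for checking kernel output. It rests on several assumptions that attention.zig:433-462 should confirm:

- conv_state is time-major, with row t at `conv_state[t*conv_dim + c]` and the oldest row first.
- conv_weight is channel-major, as `conv_weight[c*conv_k + t]`, with the last tap applied to the newest input.
- There is no bias term.

The sketch uses Python floats instead of bf16, and the function name `conv1d_state_step` is hypothetical. Because channels are independent, one GPU thread per channel needs no barriers. However, each thread must read all conv_k-1 history values before writing the shifted state, because the dot product still needs the oldest value that the shift discards.

```python
import math

def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))

def conv1d_state_step(mixed_qkv, conv_state, conv_weight, conv_k):
    """One decode step of a causal depthwise conv1d; updates conv_state in place.

    Assumed layouts (hypothetical, to be checked against attention.zig:433-462):
      conv_state[t * conv_dim + c], t = 0..conv_k-2, oldest time step first
      conv_weight[c * conv_k + t],  t = conv_k-1 multiplies the newest input
    """
    conv_dim = len(mixed_qkv)
    assert len(conv_state) == (conv_k - 1) * conv_dim
    assert len(conv_weight) == conv_dim * conv_k
    conv_out = [0.0] * conv_dim
    for c in range(conv_dim):  # channels are independent: one GPU thread each
        # Read the full history for this channel before overwriting any of it.
        window = [conv_state[t * conv_dim + c] for t in range(conv_k - 1)]
        window.append(mixed_qkv[c])
        acc = sum(window[t] * conv_weight[c * conv_k + t] for t in range(conv_k))
        conv_out[c] = silu(acc)
        # Shift left by one time step; the new input becomes the last row.
        for t in range(conv_k - 1):
            conv_state[t * conv_dim + c] = window[t + 1]
    return conv_out
```

With conv_k=4, a channel whose weights are `[0, 0, 0, 1]` passes the newest input straight through SiLU. A channel with all-ones weights sums the last four inputs.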

Kernel 2: linear_attn_delta_rule_bf16

  • Input: conv_out[conv_dim] (bf16), z[value_dim=4096] (bf16), a_proj[32] (bf16), b_proj[32] (bf16), A_log[32] (f32), dt_bias[32] (f32), norm_weight[head_v_dim=128] (bf16)
  • Inout: recurrent_state[num_v_heads*head_k_dim*head_v_dim = 32*128*128 = 524288] (f32)
  • Output: normed_out[value_dim=4096] (bf16)
  • Params: num_k_heads=16, num_v_heads=32, head_k_dim=128, head_v_dim=128, eps=1e-6
  • One threadgroup per value head (32 threadgroups, 128 threads each)
  • Per threadgroup (head h):
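    1. Extract q and k from conv_out using the head-group mapping src_h = h / (num_v_heads/num_k_heads) (integer division; with 32 value heads and 16 key heads, each key head serves two value heads)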
    2. Extract v from conv_out
    3. beta = sigmoid(b_proj[h]), g = -exp(A_log[h]) * softplus(a_proj[h] + dt_bias[h])
    4. RMSNorm q and k (eps=1e-6), then scale q by 1/head_k_dim and k by 1/sqrt(head_k_dim)
    5. State decay: S[d*128+e] *= exp(g)
    6. kv_mem[e] = sum_d(S[d*128+e] * k[d])
    7. delta[e] = (v[e] - kv_mem[e]) * beta
    8. S[d*128+e] += k[d] * delta[e]
    9. out[e] = sum_d(S[d*128+e] * q[d])
    10. RMSNorm out, then normed_out[h*128+e] = out[e] * norm_weight[e] * silu(z[h*128+e])
  • See attention.zig:488-625 for the exact CPU algorithm
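The following is a minimal per-head Python reference for steps 1 to 10, written one column at a time to mirror a 128-thread threadgroup. It rests on several assumptions that attention.zig:488-625 should confirm:

- conv_out is laid out as [q | k | v], where q and k are each num_k_heads*head_k_dim wide.
- recurrent_state is head-major, so head h starts at `h*head_k_dim*head_v_dim`.
- The softplus uses a large-x cutoff.

All function names are hypothetical, and values are Python floats rather than bf16.

In this column-per-thread arrangement, steps 5 to 9 need no cross-thread reduction. Thread e walks d over its own column of S, and for a fixed d the reads `S[d*128+e]` are contiguous across threads. Only the RMSNorms in steps 4 and 10 need a threadgroup reduction.

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def softplus(x: float) -> float:
    # Above x = 20 the correction log1p(exp(-x)) is below 3e-9; returning x avoids exp() overflow.
    return x if x > 20.0 else math.log1p(math.exp(x))

def silu(x: float) -> float:
    return x * sigmoid(x)

def rms_norm(x, eps=1e-6):
    inv = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv for v in x]

def delta_rule_head(h, conv_out, z, a_proj, b_proj, A_log, dt_bias, norm_weight, S,
                    num_k_heads, num_v_heads, dk, dv, eps=1e-6):
    """Steps 1-10 for value head h; updates S (the f32 recurrent_state) in place."""
    key_dim = num_k_heads * dk
    src_h = h // (num_v_heads // num_k_heads)          # 1. shared key head
    q = conv_out[src_h * dk:(src_h + 1) * dk]
    k = conv_out[key_dim + src_h * dk:key_dim + (src_h + 1) * dk]
    v_off = 2 * key_dim + h * dv                       # 2. assumed [q | k | v] layout
    v = conv_out[v_off:v_off + dv]
    beta = sigmoid(b_proj[h])                          # 3. write strength and decay rate
    g = -math.exp(A_log[h]) * softplus(a_proj[h] + dt_bias[h])
    q = [x / dk for x in rms_norm(q, eps)]             # 4. scaled RMSNorm
    k = [x / math.sqrt(dk) for x in rms_norm(k, eps)]
    base = h * dk * dv                                 # assumed head-major state layout
    decay = math.exp(g)
    out = [0.0] * dv
    for e in range(dv):                                # one GPU thread per column e
        col = [S[base + d * dv + e] * decay for d in range(dk)]    # 5. decay
        kv_mem = sum(col[d] * k[d] for d in range(dk))             # 6. read with k
        delta = (v[e] - kv_mem) * beta                             # 7. error term
        for d in range(dk):                                        # 8. rank-1 update
            col[d] += k[d] * delta
            S[base + d * dv + e] = col[d]
        out[e] = sum(col[d] * q[d] for d in range(dk))             # 9. read with q
    out = rms_norm(out, eps)                                       # 10. norm + gate
    return [out[e] * norm_weight[e] * silu(z[h * dv + e]) for e in range(dv)]
```

A useful check on a kernel is the delta-rule property. With beta = 1 and k of unit L2 norm, reading the updated state back with k returns v exactly. The scaling in step 4 makes k unit-norm, up to eps.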

Files to create/modify

  • CREATE: src/metal/kernels/linear_attention.metal
  • MODIFY: src/metal/dispatch.zig — add dispatch functions
  • MODIFY: build.zig — add linear_attention.metal to kernel compilation

Acceptance

  • Both kernels compile into the metallib
  • dispatch.zig has set_conv1d_state_bf16() and set_linear_attn_delta_rule_bf16()
  • zig build passes
  • zig build test passes (91/91)
  • Push branch feat/35-gpu-linear-attention-kernels to Forgejo

Constraints

  • BF16 throughout (use the bfloat type from bfloat.metal)
  • Threadgroup shared memory for q, k, v, kv_mem, delta, out (6 arrays * 128 elements * 2 bytes for bf16 = 1536 bytes, ~1.5 KB per threadgroup)
  • recurrent_state can stay f32 to preserve precision during accumulation
  • Do NOT modify model.zig or attention.zig (separate issue for wiring)
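A small Python sketch can illustrate the BF16/f32 split in these constraints. bfloat16 is the upper 16 bits of an IEEE-754 f32 (8 exponent bits, 7 explicit mantissa bits), so near 1.0 its spacing is 2^-7 = 0.0078125.

The sketch rounds to nearest even. Whether the bfloat type in bfloat.metal rounds or truncates is an assumption to verify. The helper names `to_bf16` and `accumulate` are hypothetical.

The sketch shows why the recurrent state should not be accumulated in bf16. Adding 1e-3 to a value of 1.0 falls below half a bf16 ulp, so the increment is lost on every step.

```python
import struct

def to_bf16(x: float) -> float:
    """Round x (first narrowed to f32) to the nearest bfloat16, ties to even.

    NaN handling is omitted in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1                       # low bit of the kept half, for ties-to-even
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000    # round, then drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def accumulate(n: int, step: float, quantize) -> float:
    """Add `step` to 1.0 n times, re-quantizing after every addition."""
    acc = 1.0
    for _ in range(n):
        acc = quantize(acc + step)
    return acc

# Shared-memory budget from the constraint above: q, k, v, kv_mem, delta, out as bf16.
SHARED_BYTES_BF16 = 6 * 128 * 2
```

Here `accumulate(1000, 1e-3, to_bf16)` stays at 1.0, while the same loop without quantization reaches about 2.0.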
Reference
sleepy/sleepy-llm#35