Files
llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl
T
Reese Levine dd2914dc81 ggml-webgpu: support for SSM_SCAN and disable set_rows error checking (#22327)
* Implement ssm_scan

* Remove blocking in graph_compute and check for set rows

* Fix bindings

* Update op support
2026-04-25 09:18:15 +03:00

169 lines
5.5 KiB
WebGPU Shading Language

// The subgroups extension is only enabled when the shader is built with
// USE_SUBGROUP_REDUCTION; the same flag guards the subgroup built-ins and the
// subgroupAdd call used by main.
#ifdef USE_SUBGROUP_REDUCTION
enable subgroups;
#endif
// Uniform parameters for the SSM scan kernel. All offsets and strides are used
// directly as indices into array<f32> / array<i32> buffers, i.e. they are in
// elements, not bytes.
// NOTE(review): field order presumably has to match the host-side layout that
// fills this uniform buffer - confirm before reordering.
struct Params {
// Base element offsets added to every index into s_in, x, dt, A, B, C, ids
// and dst respectively.
offset_s: u32,
offset_x: u32,
offset_dt: u32,
offset_A: u32,
offset_B: u32,
offset_C: u32,
offset_ids: u32,
offset_dst: u32,
// s_in strides: stride_s1 multiplies i1, stride_s2 multiplies ir (head),
// stride_s3 multiplies the state slot read from ids.
stride_s1: u32,
stride_s2: u32,
stride_s3: u32,
// x strides: stride_x1 multiplies ir, stride_x2 the token, stride_x3 i3.
stride_x1: u32,
stride_x2: u32,
stride_x3: u32,
// dt strides: stride_dt1 multiplies the token, stride_dt2 multiplies i3.
stride_dt1: u32,
stride_dt2: u32,
// A is indexed with (tid % a_ne0) + ir * stride_A1, so a_ne0 == 1 makes every
// invocation read the same per-head value.
a_ne0: u32,
stride_A1: u32,
// B strides: stride_B1 multiplies the group g, stride_B2 the token,
// stride_B3 i3.
stride_B1: u32,
stride_B2: u32,
stride_B3: u32,
// C strides, used the same way as the B strides.
stride_C1: u32,
stride_C2: u32,
stride_C3: u32,
// Sizes used to decode the workgroup index and to lay out the outputs in dst.
d_state: u32,
d_inner: u32,
n_head: u32,
// n_head / n_group heads share one B/C group.
n_group: u32,
// Number of tokens scanned per sequence (bound of the token loop).
n_seq_tokens: u32,
// Not referenced by the shader code in this file.
n_seqs: u32,
// Element distance from offset_dst to where the final states are written in
// dst (the y outputs are written starting at offset_dst itself).
y_elems: u32,
};
// Resource bindings. Every storage buffer is declared read_write, although
// main only reads from bindings 0-6 and only writes to dst.
// Previous state; one value per invocation is loaded before the token loop.
@group(0) @binding(0) var<storage, read_write> s_in: array<f32>;
// Per-token input, multiplied by the softplus of dt.
@group(0) @binding(1) var<storage, read_write> x: array<f32>;
// Per-token dt values; softplus is applied to them in the shader.
@group(0) @binding(2) var<storage, read_write> dt: array<f32>;
// Decay coefficients; exp(dtsp * A) scales the previous state.
@group(0) @binding(3) var<storage, read_write> A: array<f32>;
// Per-group, per-token factors applied to x * dtsp in the state update.
@group(0) @binding(4) var<storage, read_write> B: array<f32>;
// Per-group, per-token factors multiplied with the state before reduction.
@group(0) @binding(5) var<storage, read_write> C: array<f32>;
// Signed indices selecting which slot of s_in to read; cast to u32 on use.
@group(0) @binding(6) var<storage, read_write> ids: array<i32>;
// Output: y values starting at offset_dst, final states after y_elems.
@group(0) @binding(7) var<storage, read_write> dst: array<f32>;
@group(0) @binding(8) var<uniform> params: Params;
// Per-tile staging of x * softplus(dt) and softplus(dt), one entry per token.
var<workgroup> shared_x_dt: array<f32, TOKENS_PER_TILE>;
var<workgroup> shared_dtsp: array<f32, TOKENS_PER_TILE>;
// Reduction scratch: a separate WG_SIZE-wide region for each token of a tile.
var<workgroup> shared_reduce: array<f32, TOKENS_PER_TILE * WG_SIZE>;
// Returns the index of the first shared_reduce element belonging to the given
// token of the current tile; each token owns WG_SIZE consecutive elements.
fn reduce_base(token_in_tile: u32) -> u32 {
    let region_start = WG_SIZE * token_in_tile;
    return region_start;
}
// SSM scan entry point.
//
// Each workgroup handles one (i1, ir, i3) triple decoded from its linear
// workgroup index, and each invocation owns the state element at index tid.
// For every token the invocation updates its state as
//     s = s_prev * exp(softplus(dt) * A) + B * x * softplus(dt)
// and the workgroup sums s * C over all invocations to produce one y value in
// dst. After the last token the final state is written to dst after y_elems.
//
// NOTE(review): tid is never bounds-checked, so this assumes WG_SIZE equals
// d_state - confirm against the host-side dispatch.
@compute @workgroup_size(WG_SIZE)
fn main(
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>
#ifdef USE_SUBGROUP_REDUCTION
, @builtin(subgroup_id) subgroup_id: u32,
@builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
@builtin(num_subgroups) num_subgroups: u32
#endif
) {
let tid = local_id.x;
// Flatten the 2D dispatch, then split the linear index into
// i1 (fastest, range d_inner), ir (range n_head) and i3 (remaining quotient).
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
let i1 = wg_linear % params.d_inner;
let head_seq = wg_linear / params.d_inner;
let ir = head_seq % params.n_head;
let i3 = head_seq / params.n_head;
// ids selects which slot of s_in supplies the previous state for i3.
let state_slot = u32(ids[params.offset_ids + i3]);
// Group of B/C rows shared by n_head / n_group consecutive heads.
let g = ir / (params.n_head / params.n_group);
let s_idx = params.offset_s + tid + i1 * params.stride_s1 + ir * params.stride_s2 + state_slot * params.stride_s3;
// Running state of this invocation, carried across all tokens.
var s_prev = s_in[s_idx];
// tid wraps modulo a_ne0, so a_ne0 == 1 makes all invocations share one value.
let A0 = A[params.offset_A + (tid % params.a_ne0) + ir * params.stride_A1];
// Walk the tokens in tiles of TOKENS_PER_TILE.
for (var token_base = 0u; token_base < params.n_seq_tokens; token_base += TOKENS_PER_TILE) {
// The first TOKENS_PER_TILE invocations each stage one token of the tile in
// workgroup memory.
// NOTE(review): assumes WG_SIZE >= TOKENS_PER_TILE, otherwise some tile
// entries are never staged - confirm.
if (tid < TOKENS_PER_TILE) {
let token = token_base + tid;
if (token < params.n_seq_tokens) {
let x_idx = params.offset_x + i1 + ir * params.stride_x1 + token * params.stride_x2 + i3 * params.stride_x3;
let dt_idx = params.offset_dt + ir + token * params.stride_dt1 + i3 * params.stride_dt2;
let dt0 = dt[dt_idx];
// softplus(dt0); for dt0 > 20 the input itself is used instead of
// log(1 + exp(dt0)).
let dtsp = select(log(1.0 + exp(dt0)), dt0, dt0 > 20.0);
shared_dtsp[tid] = dtsp;
shared_x_dt[tid] = x[x_idx] * dtsp;
}
}
// Staged values must be visible to every invocation before they are read.
workgroupBarrier();
for (var token_in_tile = 0u; token_in_tile < TOKENS_PER_TILE; token_in_tile++) {
let token = token_base + token_in_tile;
// Partial last tile: the condition does not depend on tid, so all
// invocations leave the loop together.
if (token >= params.n_seq_tokens) {
break;
}
let x_dt = shared_x_dt[token_in_tile];
let dA = exp(shared_dtsp[token_in_tile] * A0);
let reduce_idx = reduce_base(token_in_tile) + tid;
let b_idx = params.offset_B + tid + g * params.stride_B1 + token * params.stride_B2 + i3 * params.stride_B3;
let c_idx = params.offset_C + tid + g * params.stride_C1 + token * params.stride_C2 + i3 * params.stride_C3;
// State recurrence for this invocation's state element.
let s = s_prev * dA + B[b_idx] * x_dt;
s_prev = s;
#ifdef USE_SUBGROUP_REDUCTION
// Sum s * C within each subgroup; invocation 0 of every subgroup stores the
// partial sum at reduce_base(token_in_tile) + subgroup_id.
let subgroup_partial = subgroupAdd(s * C[c_idx]);
if (subgroup_invocation_id == 0u) {
shared_reduce[reduce_idx - tid + subgroup_id] = subgroup_partial;
}
#else
// Without subgroups every invocation stores its own product for the tree
// reduction below.
shared_reduce[reduce_idx] = s * C[c_idx];
#endif
// All contributions must be written before they are combined.
workgroupBarrier();
#ifdef USE_SUBGROUP_REDUCTION
// Invocation 0 adds up the per-subgroup partial sums and writes y.
if (tid == 0u) {
var sum = 0.0;
for (var sg = 0u; sg < num_subgroups; sg++) {
sum += shared_reduce[reduce_base(token_in_tile) + sg];
}
let y_idx =
params.offset_dst + i1 + ir * params.d_inner + token * (params.n_head * params.d_inner) +
i3 * (params.n_seq_tokens * params.n_head * params.d_inner);
dst[y_idx] = sum;
}
#else
// Pairwise tree reduction over this token's region, halving the active
// range each step with a barrier between steps.
// NOTE(review): sums every element only if WG_SIZE is a power of two -
// confirm.
for (var stride = WG_SIZE / 2u; stride > 0u; stride >>= 1u) {
if (tid < stride) {
shared_reduce[reduce_idx] += shared_reduce[reduce_idx + stride];
}
workgroupBarrier();
}
// The total ends up in the first element of the region; invocation 0
// writes it to y.
if (tid == 0u) {
let y_idx =
params.offset_dst + i1 + ir * params.d_inner + token * (params.n_head * params.d_inner) +
i3 * (params.n_seq_tokens * params.n_head * params.d_inner);
dst[y_idx] = shared_reduce[reduce_base(token_in_tile)];
}
#endif
// Synchronize the workgroup before moving on to the next token.
workgroupBarrier();
}
}
// Store the final state after the y outputs: y_elems elements past offset_dst,
// indexed by tid, i1 (range d_state stride), ir and i3.
let state_idx =
params.offset_dst + params.y_elems + tid + i1 * params.d_state + ir * (params.d_state * params.d_inner) +
i3 * (params.d_state * params.d_inner * params.n_head);
dst[state_idx] = s_prev;
}