llama : enable chunked fused GDN path (#20340)
* llama : enable chunked fused GDN path * models : avoid Q and K repeats when using fused GDA * cont : fix comment Co-authored-by: Aman Gupta <amangupta052@gmail.com> * cont : fix the fix Co-authored-by: Aman Gupta <amangupta052@gmail.com> * cont : fix * metal : add GDN kernel (#20361) * metal : add Metal backend for GGML_OP_GATED_DELTA_NET Add a fused Metal kernel for the gated delta net recurrence op (#19504), enabling GPU-accelerated inference for DeltaNet-based models (Qwen3.5, etc.) on Apple Silicon. Supports both GDA (scalar gate) and KDA (per-row gate) modes with head_size 64 and 128. Unsupported configurations (head_size 32, non-contiguous tensors) gracefully fall back to CPU. Performance: Qwen3.5-0.8B Q4_K_M on M4 Max tg128: 170 -> 213 t/s (+25%) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * metal : validate contiguity of all input tensors in supports_op Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * metal : add algorithm equivalence comment for GDA decay path Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * cont : unslop + optimize * cont : clean-up --------- Co-authored-by: Paul Flynn <paul@arkavo.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> * CUDA: AR gated delta net improvements (#20391) * Add FastDiv to gated_delta_net_cuda * Shard columns across warps This reduces register pressure (avoids spill for S_v = 128) and gives the warp-scheduler more CTAs to schedule (thus hiding data-access latencies). * Remove unneded include in gated_delta_net.cu * Improve comments * Apply code-formating * Make sharding HIP-compatible 1. Use ggml_cuda_get_physical_warp_size() to determine warp size flexibly 2. Add test with partial warp to test sum reduction on CUDA * Remove fastdiv_s64, as we can treat neqk1 and rq3 as uint32_t * Rename variables * Enable GDN also for prefill, move TODO for chunked_GDN * Actually remove the TODO from 206890897546bd16602c3b79394fd5ea09ef199f * Get warp size at runtime warp_size is not known at compile time in hip host code. * Don't expose ggml_cuda_get_physical_warp_size on host --------- Co-authored-by: uvos <devnull@uvos.xyz> * llama : refactor llm_build_delta_net_base API --------- Co-authored-by: Aman Gupta <amangupta052@gmail.com> Co-authored-by: Paul Flynn <paul@arkavo.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Oliver Simons <osimons@nvidia.com> Co-authored-by: uvos <devnull@uvos.xyz>
This commit is contained in:
@@ -406,6 +406,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
|
||||
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
|
||||
// if head keys and value keys are different, repeat to force tensors into matching shapes
|
||||
// TODO: avoid repeats for fused GDN, needs broadcast configuration for GDN op [TAG_GGML_GDN_BCAST]
|
||||
if (num_k_heads != num_v_heads) {
|
||||
GGML_ASSERT(num_v_heads % num_k_heads == 0);
|
||||
int64_t repeat_factor = num_v_heads / num_k_heads;
|
||||
@@ -431,13 +432,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
|
||||
cb(k_conv, "k_conv_predelta", il);
|
||||
cb(v_conv, "v_conv_predelta", il);
|
||||
|
||||
// Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
|
||||
std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
|
||||
if (n_seq_tokens == 1) {
|
||||
attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
} else {
|
||||
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
}
|
||||
auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
|
||||
ggml_tensor * output = attn_out.first;
|
||||
ggml_tensor * new_state = attn_out.second;
|
||||
cb(output, "attn_output", il);
|
||||
|
||||
Reference in New Issue
Block a user