llama : enable chunked fused GDN path (#20340)

* llama : enable chunked fused GDN path

* models : avoid Q and K repeats when using fused GDA

* cont : fix comment

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* cont : fix the fix

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* cont : fix

* metal : add GDN kernel (#20361)

* metal : add Metal backend for GGML_OP_GATED_DELTA_NET

Add a fused Metal kernel for the gated delta net recurrence op
(#19504), enabling GPU-accelerated inference for DeltaNet-based
models (Qwen3.5, etc.) on Apple Silicon.

Supports both GDA (scalar gate) and KDA (per-row gate) modes
with head_size 64 and 128. Unsupported configurations (head_size
32, non-contiguous tensors) gracefully fall back to CPU.

Performance: Qwen3.5-0.8B Q4_K_M on M4 Max
  tg128: 170 -> 213 t/s (+25%)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* metal : validate contiguity of all input tensors in supports_op

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* metal : add algorithm equivalence comment for GDA decay path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* cont : unslop + optimize

* cont : clean-up

---------

Co-authored-by: Paul Flynn <paul@arkavo.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

* CUDA: AR gated delta net improvements (#20391)

* Add FastDiv to gated_delta_net_cuda

* Shard columns across warps

This reduces register pressure (avoids spill for S_v = 128) and gives
the warp-scheduler more CTAs to schedule (thus hiding data-access
latencies).

* Remove unneded include in gated_delta_net.cu

* Improve comments

* Apply code-formating

* Make sharding HIP-compatible

1. Use ggml_cuda_get_physical_warp_size() to determine warp size flexibly
2. Add test with partial warp to test sum reduction on CUDA

* Remove fastdiv_s64, as we can treat neqk1 and rq3 as uint32_t

* Rename variables

* Enable GDN also for prefill, move TODO for chunked_GDN

* Actually remove the TODO from 206890897546bd16602c3b79394fd5ea09ef199f

* Get warp size at runtime

warp_size is not known at compile time in hip host code.

* Don't expose ggml_cuda_get_physical_warp_size on host

---------

Co-authored-by: uvos <devnull@uvos.xyz>

* llama : refactor llm_build_delta_net_base API

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
Co-authored-by: Paul Flynn <paul@arkavo.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Oliver Simons <osimons@nvidia.com>
Co-authored-by: uvos <devnull@uvos.xyz>
This commit is contained in:
Georgi Gerganov
2026-03-11 22:46:40 +02:00
committed by GitHub
parent f90bd1dd84
commit d28961d81e
20 changed files with 674 additions and 165 deletions
+70 -25
View File
@@ -151,7 +151,8 @@ llama_context::llama_context(
cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;
cparams.fused_gdn_ar = true;
cparams.fused_gdn_ch = false; // TODO: implement
cparams.fused_gdn_ch = true;
cparams.auto_fgdn = true;
// with causal attention, the batch size is limited by the context size
cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
@@ -462,37 +463,81 @@ void llama_context::sched_reserve() {
cparams.auto_fa = false;
}
if (cparams.fused_gdn_ar) {
auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
if (!gf) {
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check");
}
if (cparams.auto_fgdn) {
LLAMA_LOG_INFO("%s: resolving fused Gated Delta Net support:\n", __func__);
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDNAR) + 1;
bool gdn_device_mismatch = false;
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
ggml_tensor * n = ggml_graph_node(gf, i);
if (n->op != GGML_OP_GATED_DELTA_NET) {
continue;
if (cparams.fused_gdn_ar) {
auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
if (!gf) {
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (autoregressive)");
}
ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDNAR "-", prefix_len) == 0);
const int il = std::stoi(n->name + prefix_len);
ggml_backend_dev_t device_kv = model.dev_layer(il);
if (device_gdn != device_kv) {
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
"is assigned to device %s (usually due to missing support)\n",
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
gdn_device_mismatch = true;
break;
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_AR) + 1;
bool gdn_device_mismatch = false;
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
ggml_tensor * n = ggml_graph_node(gf, i);
if (n->op != GGML_OP_GATED_DELTA_NET) {
continue;
}
ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_AR "-", prefix_len) == 0);
const int il = std::stoi(n->name + prefix_len);
ggml_backend_dev_t device_kv = model.dev_layer(il);
if (device_gdn != device_kv) {
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
"is assigned to device %s (usually due to missing support)\n",
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
gdn_device_mismatch = true;
break;
}
}
if (gdn_device_mismatch) {
cparams.fused_gdn_ar = false;
LLAMA_LOG_WARN("%s: fused Gated Delta Net (autoregressive) not supported, set to disabled\n", __func__);
} else {
LLAMA_LOG_INFO("%s: fused Gated Delta Net (autoregressive) enabled\n", __func__);
}
}
if (gdn_device_mismatch) {
cparams.fused_gdn_ar = false;
LLAMA_LOG_WARN("%s: fused Gated Delta Net not supported, set to disabled\n", __func__);
if (cparams.fused_gdn_ch) {
// more than one token in the batch per sequence in order to take the chunked path
auto * gf = graph_reserve(16*n_seqs, n_seqs, n_outputs, mctx.get(), true);
if (!gf) {
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)");
}
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_CH) + 1;
bool gdn_device_mismatch = false;
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
ggml_tensor * n = ggml_graph_node(gf, i);
if (n->op != GGML_OP_GATED_DELTA_NET) {
continue;
}
ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_CH "-", prefix_len) == 0);
const int il = std::stoi(n->name + prefix_len);
ggml_backend_dev_t device_kv = model.dev_layer(il);
if (device_gdn != device_kv) {
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
"is assigned to device %s (usually due to missing support)\n",
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
gdn_device_mismatch = true;
break;
}
}
if (gdn_device_mismatch) {
cparams.fused_gdn_ch = false;
LLAMA_LOG_WARN("%s: fused Gated Delta Net (chunked) not supported, set to disabled\n", __func__);
} else {
LLAMA_LOG_INFO("%s: fused Gated Delta Net (chunked) enabled\n", __func__);
}
}
cparams.auto_fgdn = false;
}
// reserve worst-case graph