Tensor-parallel: Fix delayed AllReduce on Gemma-4 MoE (#22129)

* Fix delayed AllReduce on Gemma-4 MoE

Skip forward past nodes that don't consume the current one, and allow a chain of MULs.

* Check for all sources before skipping nodes

* Address review comments
This commit is contained in:
Gaurav Garg
2026-04-20 21:55:39 +05:30
committed by GitHub
parent fb19f94c71
commit fd6ae4ca1c
+38 -4
View File
@@ -1683,6 +1683,36 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
ggml_tensor * node = cgraph->nodes[id]; ggml_tensor * node = cgraph->nodes[id];
int32_t n_used = ggml_node_get_use_count(cgraph, id); int32_t n_used = ggml_node_get_use_count(cgraph, id);
// Skip MIRRORED nodes that don't consume node
auto skip_unrelated = [&]() {
while (id + 1 < cgraph->n_nodes) {
ggml_tensor * next = cgraph->nodes[id+1];
if (ggml_backend_meta_get_split_state(next, false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
break;
}
bool safe = true;
for (int s = 0; s < GGML_MAX_SRC; s++) {
if (next->src[s] == nullptr) {
continue;
}
if (next->src[s] == node) {
safe = false;
break;
}
if (ggml_backend_meta_get_split_state(next->src[s], false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
safe = false;
break;
}
}
if (!safe) {
break;
}
id++;
}
};
skip_unrelated();
if (id + 1 >= cgraph->n_nodes) { if (id + 1 >= cgraph->n_nodes) {
return idr; return idr;
} }
@@ -1697,10 +1727,12 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
n_used = ggml_node_get_use_count(cgraph, id); n_used = ggml_node_get_use_count(cgraph, id);
} }
} }
if (id + 1 >= cgraph->n_nodes) { // Chain of MULs with MIRRORED src[1]
return idr; while (true) {
} skip_unrelated();
{ if (id + 1 >= cgraph->n_nodes) {
return idr;
}
ggml_tensor * next = cgraph->nodes[id+1]; ggml_tensor * next = cgraph->nodes[id+1];
if (next->op == GGML_OP_MUL && next->src[0] == node && if (next->op == GGML_OP_MUL && next->src[0] == node &&
ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
@@ -1708,6 +1740,8 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
id++; id++;
idr = id; idr = id;
n_used = ggml_node_get_use_count(cgraph, id); n_used = ggml_node_get_use_count(cgraph, id);
} else {
break;
} }
} }