Tensor-parallel: Fix delayed AllReduce on Gemma-4 MoE (#22129)
* Fix delayed AllReduce on Gemma-4 MoE Skip forward past nodes that don't consume the current one, and allow a chain of MULs. * Check for all sources before skipping nodes * Address review comments
This commit is contained in:
@@ -1683,6 +1683,36 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
|
|||||||
|
|
||||||
ggml_tensor * node = cgraph->nodes[id];
|
ggml_tensor * node = cgraph->nodes[id];
|
||||||
int32_t n_used = ggml_node_get_use_count(cgraph, id);
|
int32_t n_used = ggml_node_get_use_count(cgraph, id);
|
||||||
|
|
||||||
|
// Skip MIRRORED nodes that don't consume node
|
||||||
|
auto skip_unrelated = [&]() {
|
||||||
|
while (id + 1 < cgraph->n_nodes) {
|
||||||
|
ggml_tensor * next = cgraph->nodes[id+1];
|
||||||
|
if (ggml_backend_meta_get_split_state(next, false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
bool safe = true;
|
||||||
|
for (int s = 0; s < GGML_MAX_SRC; s++) {
|
||||||
|
if (next->src[s] == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (next->src[s] == node) {
|
||||||
|
safe = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (ggml_backend_meta_get_split_state(next->src[s], false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
|
||||||
|
safe = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!safe) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
id++;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
skip_unrelated();
|
||||||
if (id + 1 >= cgraph->n_nodes) {
|
if (id + 1 >= cgraph->n_nodes) {
|
||||||
return idr;
|
return idr;
|
||||||
}
|
}
|
||||||
@@ -1697,10 +1727,12 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
|
|||||||
n_used = ggml_node_get_use_count(cgraph, id);
|
n_used = ggml_node_get_use_count(cgraph, id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (id + 1 >= cgraph->n_nodes) {
|
// Chain of MULs with MIRRORED src[1]
|
||||||
return idr;
|
while (true) {
|
||||||
}
|
skip_unrelated();
|
||||||
{
|
if (id + 1 >= cgraph->n_nodes) {
|
||||||
|
return idr;
|
||||||
|
}
|
||||||
ggml_tensor * next = cgraph->nodes[id+1];
|
ggml_tensor * next = cgraph->nodes[id+1];
|
||||||
if (next->op == GGML_OP_MUL && next->src[0] == node &&
|
if (next->op == GGML_OP_MUL && next->src[0] == node &&
|
||||||
ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
|
ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
|
||||||
@@ -1708,6 +1740,8 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
|
|||||||
id++;
|
id++;
|
||||||
idr = id;
|
idr = id;
|
||||||
n_used = ggml_node_get_use_count(cgraph, id);
|
n_used = ggml_node_get_use_count(cgraph, id);
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user