ggml-cuda: refactor fusion code (#22468)

* ggml-cuda: refactor fusion code

* apply formatting + make env variable truthy
This commit is contained in:
Aman Gupta
2026-04-29 16:19:33 +08:00
committed by GitHub
parent b5c4227dc6
commit 3142f1dbb9
+355 -348
View File
@@ -3640,6 +3640,357 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
return false; return false;
} }
// try and fuse nodes and return the number of nodes to skip
static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, int i) {
static bool disable_fusion = getenv("GGML_CUDA_DISABLE_FUSION") != nullptr && std::atoi(getenv("GGML_CUDA_DISABLE_FUSION"));
if (disable_fusion) {
return 0;
}
ggml_tensor * node = cgraph->nodes[i];
//topk-moe
if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
ggml_cuda_topk_moe_args args;
const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
std::vector<ggml_op> ops;
if (can_fuse) {
const ggml_tensor * logits = node->src[0];
ggml_tensor * weights = nullptr;
ggml_tensor * ids = nullptr;
const ggml_tensor * bias = nullptr;
const ggml_tensor * clamp = nullptr;
const ggml_tensor * scale = nullptr;
if (!args.delayed_softmax) {
ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
int out_nodes[2]; // nodes which can't be elided
if (args.prob_bias) {
bias = cgraph->nodes[i + 2]->src[1];
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS });
out_nodes[0] = i + 4;
ids = cgraph->nodes[i + 4];
} else {
ops.insert(ops.end(),
{ gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS });
out_nodes[0] = i + 3;
ids = cgraph->nodes[i + 3];
}
if (args.norm) {
ops.insert(ops.end(),
{ GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, GGML_OP_RESHAPE });
clamp = cgraph->nodes[i + ops.size() - 3];
}
if (args.scale) {
ops.insert(ops.end(), { GGML_OP_SCALE });
scale = cgraph->nodes[i + ops.size() - 1];
}
weights = cgraph->nodes[i + ops.size() - 1];
out_nodes[1] = i + ops.size() - 1;
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(node, logits, weights, ids) &&
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/true)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
return ops.size() - 1;
}
} else if (!args.norm && !args.prob_bias) {
//special case gpt-oss, no norm, no bias.
ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
weights = cgraph->nodes[i + 5];
ids = cgraph->nodes[i + 1];
const ggml_tensor * softmax = cgraph->nodes[i + 4];
int out_nodes[2] = { i + 1, i + 5 };
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) &&
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/true)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
return ops.size() - 1;
}
}
}
}
//RoPE + view + set-rows
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
ggml_tensor * rope = cgraph->nodes[i];
ggml_tensor * set_rows = cgraph->nodes[i + 2];
ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
return 2;
}
// multi-(add or mul)
if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) {
int n_fuse = 0;
ggml_op ops[8];
std::fill(ops, ops + 8, node->op);
for (; n_fuse <= 6; ++n_fuse) {
if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
break;
}
if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
break;
}
if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
break;
}
}
n_fuse++;
if (n_fuse > 1) {
ggml_tensor fused_node;
memcpy(&fused_node, node, sizeof(ggml_tensor));
for (int j = 0; j < n_fuse - 1; ++j) {
fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
}
fused_node.data = cgraph->nodes[i + n_fuse - 1]->data;
if (node->op == GGML_OP_ADD) {
ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse);
} else {
ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse);
}
return n_fuse - 1;
}
}
bool fused_mul_mat_vec = false;
int fused_node_count = 0;
// gate + glu + up
for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
ggml_tensor * glu = cgraph->nodes[i + 4];
ggml_tensor * gate_bias_n = glu->src[0];
ggml_tensor * up_bias_n = glu->src[1];
//we don't assume the order for {gate, up}. Instead infer it from the bias tensor
ggml_tensor * gate_n = nullptr;
ggml_tensor * up_n = nullptr;
if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
gate_n = cgraph->nodes[i];
up_n = cgraph->nodes[i + 2];
} else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
gate_n = cgraph->nodes[i + 2];
up_n = cgraph->nodes[i];
} else {
continue;
}
auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
if (op_bias == GGML_OP_ADD) {
if (bias_node->src[0] == mul_node) {
return bias_node->src[1];
}
if (bias_node->src[1] == mul_node) {
return bias_node->src[0];
}
return (ggml_tensor *) nullptr;
}
GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
GGML_ASSERT(bias_node->src[0] == mul_node);
return bias_node->src[1];
};
ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op);
ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
if (!up_bias_tensor || !gate_bias_tensor) {
continue;
}
// we don't support repeating adds
if (bias_op == GGML_OP_ADD && (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) ||
!ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) {
continue;
}
const ggml_tensor * src0 = up_n->src[0];
const ggml_tensor * src1 = up_n->src[1];
const ggml_tensor * ids = up_n->src[2];
if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
ggml_cuda_mm_fusion_args_host fusion_data{};
fusion_data.gate = gate_n->src[0];
fusion_data.x_bias = up_bias_tensor;
fusion_data.gate_bias = gate_bias_tensor;
fusion_data.glu_op = ggml_get_glu_op(glu);
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
fused_mul_mat_vec = true;
fused_node_count = 5;
break;
}
if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
ggml_cuda_mm_fusion_args_host fusion_data{};
fusion_data.gate = gate_n->src[0];
fusion_data.x_bias = up_bias_tensor;
fusion_data.gate_bias = gate_bias_tensor;
fusion_data.glu_op = ggml_get_glu_op(glu);
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
fused_mul_mat_vec = true;
fused_node_count = 5;
break;
}
} else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
ggml_tensor * glu = cgraph->nodes[i + 2];
ggml_tensor * gate = glu->src[0];
ggml_tensor * up = glu->src[1];
bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1]) ||
(gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
if (!ok) {
continue;
}
const ggml_tensor * src0 = up->src[0];
const ggml_tensor * src1 = up->src[1];
const ggml_tensor * ids = up->src[2];
if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
ggml_cuda_mm_fusion_args_host fusion_data{};
fusion_data.gate = gate->src[0];
fusion_data.glu_op = ggml_get_glu_op(glu);
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
fused_mul_mat_vec = true;
fused_node_count = 3;
break;
}
if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
ggml_cuda_mm_fusion_args_host fusion_data{};
fusion_data.gate = gate->src[0];
fusion_data.glu_op = ggml_get_glu_op(glu);
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
fused_mul_mat_vec = true;
fused_node_count = 3;
break;
}
}
}
if (fused_mul_mat_vec) {
return fused_node_count - 1;
}
fused_mul_mat_vec = false;
fused_node_count = 0;
// gate + add + glu + up + add
for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
continue;
}
ggml_tensor * mm_node = cgraph->nodes[i];
ggml_tensor * bias_node = cgraph->nodes[i + 1];
ggml_tensor * bias_tensor = nullptr;
if (bias_op == GGML_OP_ADD) {
if (bias_node->src[0] == mm_node) {
bias_tensor = bias_node->src[1];
} else if (bias_node->src[1] == mm_node) {
bias_tensor = bias_node->src[0];
} else {
continue;
}
} else {
if (bias_node->src[0] != mm_node) {
continue;
}
bias_tensor = bias_node->src[1];
}
const ggml_tensor * src0 = mm_node->src[0];
const ggml_tensor * src1 = mm_node->src[1];
const ggml_tensor * ids = mm_node->src[2];
if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
continue;
}
if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) {
continue;
}
ggml_cuda_mm_fusion_args_host fusion_data{};
fusion_data.x_bias = bias_tensor;
if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
fused_mul_mat_vec = true;
fused_node_count = 2;
break;
}
if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
fused_mul_mat_vec = true;
fused_node_count = 2;
break;
}
}
if (fused_mul_mat_vec) {
return fused_node_count - 1;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD }, {})) {
ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]);
return 2;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i + 1]);
return 1;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1]);
return 1;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) ||
ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) ||
ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) {
ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i + 1]);
return 1;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) {
ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i + 1]);
return 1;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i + 2], node);
return 2;
}
return 0;
}
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) { static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
bool graph_evaluated_or_captured = false; bool graph_evaluated_or_captured = false;
@@ -3786,355 +4137,11 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
continue; continue;
} }
// start of fusion operations int nodes_to_skip = ggml_cuda_try_fuse(cuda_ctx, cgraph, i);
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) {
ggml_cuda_topk_moe_args args;
if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX || if (nodes_to_skip != 0) {
cgraph->nodes[i]->op == GGML_OP_ARGSORT) { i += nodes_to_skip;
const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args); continue;
std::vector<ggml_op> ops;
if (can_fuse) {
const ggml_tensor * logits = node->src[0];
ggml_tensor * weights = nullptr;
ggml_tensor * ids = nullptr;
const ggml_tensor * bias = nullptr;
const ggml_tensor * clamp = nullptr;
const ggml_tensor * scale = nullptr;
if (!args.delayed_softmax) {
ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
int out_nodes[2]; // nodes which can't be elided
if (args.prob_bias) {
bias = cgraph->nodes[i + 2]->src[1];
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS });
out_nodes[0] = i + 4;
ids = cgraph->nodes[i + 4];
} else {
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS });
out_nodes[0] = i + 3;
ids = cgraph->nodes[i + 3];
}
if (args.norm) {
ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
GGML_OP_DIV, GGML_OP_RESHAPE });
clamp = cgraph->nodes[i + ops.size() - 3];
}
if (args.scale) {
ops.insert(ops.end(), { GGML_OP_SCALE });
scale = cgraph->nodes[i + ops.size() - 1];
}
weights = cgraph->nodes[i + ops.size() - 1];
out_nodes[1] = i + ops.size() - 1;
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(node, logits, weights, ids) &&
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;
}
} else if (!args.norm && !args.prob_bias) {
//special case gpt-oss, no norm, no bias.
ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
weights = cgraph->nodes[i + 5];
ids = cgraph->nodes[i + 1];
const ggml_tensor * softmax = cgraph->nodes[i + 4];
int out_nodes[2] = { i + 1, i + 5 };
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) &&
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;
}
}
}
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
ggml_tensor * rope = cgraph->nodes[i];
ggml_tensor * set_rows = cgraph->nodes[i + 2];
ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
i += 2;
continue;
}
if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) {
int n_fuse = 0;
ggml_op ops[8];
std::fill(ops, ops + 8, node->op);
for (; n_fuse <= 6; ++n_fuse){
if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
break;
}
if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
break;
}
if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
break;
}
}
n_fuse++;
if (n_fuse > 1) {
ggml_tensor fused_node;
memcpy(&fused_node, node, sizeof(ggml_tensor));
for (int j = 0; j < n_fuse - 1; ++j) {
fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
}
fused_node.data = cgraph->nodes[i + n_fuse - 1]->data;
if (node->op == GGML_OP_ADD) {
ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse);
} else {
ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse);
}
i += n_fuse - 1;
continue;
}
}
bool fused_mul_mat_vec = false;
int fused_node_count = 0;
for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
ggml_tensor * glu = cgraph->nodes[i + 4];
ggml_tensor * gate_bias_n = glu->src[0];
ggml_tensor * up_bias_n = glu->src[1];
//we don't assume the order for {gate, up}. Instead infer it from the bias tensor
ggml_tensor * gate_n = nullptr;
ggml_tensor * up_n = nullptr;
if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
gate_n = cgraph->nodes[i];
up_n = cgraph->nodes[i + 2];
} else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
gate_n = cgraph->nodes[i + 2];
up_n = cgraph->nodes[i];
} else {
continue;
}
auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
if (op_bias == GGML_OP_ADD) {
if (bias_node->src[0] == mul_node) {
return bias_node->src[1];
}
if (bias_node->src[1] == mul_node) {
return bias_node->src[0];
}
return (ggml_tensor *) nullptr;
}
GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
GGML_ASSERT(bias_node->src[0] == mul_node);
return bias_node->src[1];
};
ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op);
ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
if (!up_bias_tensor || !gate_bias_tensor) {
continue;
}
// we don't support repeating adds
if (bias_op == GGML_OP_ADD &&
(!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) ||
!ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) {
continue;
}
const ggml_tensor * src0 = up_n->src[0];
const ggml_tensor * src1 = up_n->src[1];
const ggml_tensor * ids = up_n->src[2];
if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
ggml_cuda_mm_fusion_args_host fusion_data{};
fusion_data.gate = gate_n->src[0];
fusion_data.x_bias = up_bias_tensor;
fusion_data.gate_bias = gate_bias_tensor;
fusion_data.glu_op = ggml_get_glu_op(glu);
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
fused_mul_mat_vec = true;
fused_node_count = 5;
break;
}
if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
ggml_cuda_mm_fusion_args_host fusion_data{};
fusion_data.gate = gate_n->src[0];
fusion_data.x_bias = up_bias_tensor;
fusion_data.gate_bias = gate_bias_tensor;
fusion_data.glu_op = ggml_get_glu_op(glu);
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
fused_mul_mat_vec = true;
fused_node_count = 5;
break;
}
} else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
ggml_tensor * glu = cgraph->nodes[i + 2];
ggml_tensor * gate = glu->src[0];
ggml_tensor * up = glu->src[1];
bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1])
|| (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
if (!ok) continue;
const ggml_tensor * src0 = up->src[0];
const ggml_tensor * src1 = up->src[1];
const ggml_tensor * ids = up->src[2];
if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
ggml_cuda_mm_fusion_args_host fusion_data{};
fusion_data.gate = gate->src[0];
fusion_data.glu_op = ggml_get_glu_op(glu);
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
fused_mul_mat_vec = true;
fused_node_count = 3;
break;
}
if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
ggml_cuda_mm_fusion_args_host fusion_data{};
fusion_data.gate = gate->src[0];
fusion_data.glu_op = ggml_get_glu_op(glu);
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
fused_mul_mat_vec = true;
fused_node_count = 3;
break;
}
}
}
if (fused_mul_mat_vec) {
i += fused_node_count - 1;
continue;
}
fused_mul_mat_vec = false;
fused_node_count = 0;
for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
continue;
}
ggml_tensor * mm_node = cgraph->nodes[i];
ggml_tensor * bias_node = cgraph->nodes[i + 1];
ggml_tensor * bias_tensor = nullptr;
if (bias_op == GGML_OP_ADD) {
if (bias_node->src[0] == mm_node) {
bias_tensor = bias_node->src[1];
} else if (bias_node->src[1] == mm_node) {
bias_tensor = bias_node->src[0];
} else {
continue;
}
} else {
if (bias_node->src[0] != mm_node) {
continue;
}
bias_tensor = bias_node->src[1];
}
const ggml_tensor * src0 = mm_node->src[0];
const ggml_tensor * src1 = mm_node->src[1];
const ggml_tensor * ids = mm_node->src[2];
if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
continue;
}
if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) {
continue;
}
ggml_cuda_mm_fusion_args_host fusion_data{};
fusion_data.x_bias = bias_tensor;
if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
fused_mul_mat_vec = true;
fused_node_count = 2;
break;
}
if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
fused_mul_mat_vec = true;
fused_node_count = 2;
break;
}
}
if (fused_mul_mat_vec) {
i += fused_node_count - 1;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
i += 2;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) ||
ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) ||
ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) {
ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) {
ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
i += 2;
ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
continue;
}
} }
#ifndef NDEBUG #ifndef NDEBUG
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));