diff --git a/ggml/src/ggml-metal/ggml-metal-common.cpp b/ggml/src/ggml-metal/ggml-metal-common.cpp index 2eb9820bf..44e10b6b8 100644 --- a/ggml/src/ggml-metal/ggml-metal-common.cpp +++ b/ggml/src/ggml-metal/ggml-metal-common.cpp @@ -394,19 +394,29 @@ void ggml_graph_optimize(ggml_cgraph * gf) { // fuse only ops that start with these operations // can be expanded when needed if (node.op() == GGML_OP_ADD || + node.op() == GGML_OP_SUB || + node.op() == GGML_OP_MUL || + node.op() == GGML_OP_DIV || node.op() == GGML_OP_NORM || node.op() == GGML_OP_RMS_NORM) { ops[0] = node.op(); int f = i + 1; while (f < n && f < i + MAX_FUSE) { - // conservatively allow fusing only these ops - // can be expanded when needed - if (gf->nodes[f]->op != GGML_OP_ADD && - gf->nodes[f]->op != GGML_OP_MUL && - gf->nodes[f]->op != GGML_OP_NORM && - gf->nodes[f]->op != GGML_OP_RMS_NORM) { - break; + // bin ops (ADD/SUB/MUL/DIV) must be same type to fuse + // NORM/RMS_NORM can chain with MUL/ADD + if (node.op() == GGML_OP_ADD || + node.op() == GGML_OP_SUB || + node.op() == GGML_OP_MUL || + node.op() == GGML_OP_DIV) { + if (gf->nodes[f]->op != node.op()) break; + } else { + if (gf->nodes[f]->op != GGML_OP_ADD && + gf->nodes[f]->op != GGML_OP_MUL && + gf->nodes[f]->op != GGML_OP_NORM && + gf->nodes[f]->op != GGML_OP_RMS_NORM) { + break; + } } ops[f - i] = gf->nodes[f]->op; f++; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 5fa162c87..d4a798e14 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -3118,19 +3118,15 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { int n_fuse = 1; - // c[0] = add(a, b[0]) - // c[1] = add(c[0], b[1]) - // c[2] = add(c[1], b[2]) + // c[0] = op(a, b[0]) + // c[1] = op(c[0], b[1]) + // c[2] = op(c[1], b[2]) // ... if (use_fusion) { - fops[0] = GGML_OP_ADD; - fops[1] = GGML_OP_ADD; - fops[2] = GGML_OP_ADD; - fops[3] = GGML_OP_ADD; - fops[4] = GGML_OP_ADD; - fops[5] = GGML_OP_ADD; - fops[6] = GGML_OP_ADD; - fops[7] = GGML_OP_ADD; + ggml_op cur_op = op->op; + for (int i = 0; i < 8; ++i) { + fops[i] = cur_op; + } // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops // across splits. idx_end indicates the last node in the current split @@ -3165,7 +3161,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { ++n_fuse; if (debug_fusion > 1 && n_fuse > 1) { - GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse); + GGML_LOG_DEBUG("%s: fuse: %s x %d\n", __func__, ggml_op_name(cur_op), n_fuse); } }