[metal] extend bin op fusion to MUL/SUB/DIV chains (#28) #38
@@ -394,19 +394,29 @@ void ggml_graph_optimize(ggml_cgraph * gf) {
|
|||||||
// fuse only ops that start with these operations
|
// fuse only ops that start with these operations
|
||||||
// can be expanded when needed
|
// can be expanded when needed
|
||||||
if (node.op() == GGML_OP_ADD ||
|
if (node.op() == GGML_OP_ADD ||
|
||||||
|
node.op() == GGML_OP_SUB ||
|
||||||
|
node.op() == GGML_OP_MUL ||
|
||||||
|
node.op() == GGML_OP_DIV ||
|
||||||
node.op() == GGML_OP_NORM ||
|
node.op() == GGML_OP_NORM ||
|
||||||
node.op() == GGML_OP_RMS_NORM) {
|
node.op() == GGML_OP_RMS_NORM) {
|
||||||
ops[0] = node.op();
|
ops[0] = node.op();
|
||||||
|
|
||||||
int f = i + 1;
|
int f = i + 1;
|
||||||
while (f < n && f < i + MAX_FUSE) {
|
while (f < n && f < i + MAX_FUSE) {
|
||||||
// conservatively allow fusing only these ops
|
// bin ops (ADD/SUB/MUL/DIV) must be same type to fuse
|
||||||
// can be expanded when needed
|
// NORM/RMS_NORM can chain with MUL/ADD
|
||||||
if (gf->nodes[f]->op != GGML_OP_ADD &&
|
if (node.op() == GGML_OP_ADD ||
|
||||||
gf->nodes[f]->op != GGML_OP_MUL &&
|
node.op() == GGML_OP_SUB ||
|
||||||
gf->nodes[f]->op != GGML_OP_NORM &&
|
node.op() == GGML_OP_MUL ||
|
||||||
gf->nodes[f]->op != GGML_OP_RMS_NORM) {
|
node.op() == GGML_OP_DIV) {
|
||||||
break;
|
if (gf->nodes[f]->op != node.op()) break;
|
||||||
|
} else {
|
||||||
|
if (gf->nodes[f]->op != GGML_OP_ADD &&
|
||||||
|
gf->nodes[f]->op != GGML_OP_MUL &&
|
||||||
|
gf->nodes[f]->op != GGML_OP_NORM &&
|
||||||
|
gf->nodes[f]->op != GGML_OP_RMS_NORM) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
ops[f - i] = gf->nodes[f]->op;
|
ops[f - i] = gf->nodes[f]->op;
|
||||||
f++;
|
f++;
|
||||||
|
|||||||
@@ -3118,19 +3118,15 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|||||||
|
|
||||||
int n_fuse = 1;
|
int n_fuse = 1;
|
||||||
|
|
||||||
// c[0] = add(a, b[0])
|
// c[0] = op(a, b[0])
|
||||||
// c[1] = add(c[0], b[1])
|
// c[1] = op(c[0], b[1])
|
||||||
// c[2] = add(c[1], b[2])
|
// c[2] = op(c[1], b[2])
|
||||||
// ...
|
// ...
|
||||||
if (use_fusion) {
|
if (use_fusion) {
|
||||||
fops[0] = GGML_OP_ADD;
|
ggml_op cur_op = op->op;
|
||||||
fops[1] = GGML_OP_ADD;
|
for (int i = 0; i < 8; ++i) {
|
||||||
fops[2] = GGML_OP_ADD;
|
fops[i] = cur_op;
|
||||||
fops[3] = GGML_OP_ADD;
|
}
|
||||||
fops[4] = GGML_OP_ADD;
|
|
||||||
fops[5] = GGML_OP_ADD;
|
|
||||||
fops[6] = GGML_OP_ADD;
|
|
||||||
fops[7] = GGML_OP_ADD;
|
|
||||||
|
|
||||||
// note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops
|
// note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops
|
||||||
// across splits. idx_end indicates the last node in the current split
|
// across splits. idx_end indicates the last node in the current split
|
||||||
@@ -3165,7 +3161,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|||||||
++n_fuse;
|
++n_fuse;
|
||||||
|
|
||||||
if (debug_fusion > 1 && n_fuse > 1) {
|
if (debug_fusion > 1 && n_fuse > 1) {
|
||||||
GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
|
GGML_LOG_DEBUG("%s: fuse: %s x %d\n", __func__, ggml_op_name(cur_op), n_fuse);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user