CUDA: fuse muls (#21665)
This commit is contained in:
@@ -472,6 +472,36 @@ void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_fused_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
|
||||||
|
GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);
|
||||||
|
|
||||||
|
switch (n_fuse) {
|
||||||
|
case 2:
|
||||||
|
ggml_cuda_op_fused_binbcast_impl<op_mul, 2>(ctx, dst);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
ggml_cuda_op_fused_binbcast_impl<op_mul, 3>(ctx, dst);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
ggml_cuda_op_fused_binbcast_impl<op_mul, 4>(ctx, dst);
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
ggml_cuda_op_fused_binbcast_impl<op_mul, 5>(ctx, dst);
|
||||||
|
break;
|
||||||
|
case 6:
|
||||||
|
ggml_cuda_op_fused_binbcast_impl<op_mul, 6>(ctx, dst);
|
||||||
|
break;
|
||||||
|
case 7:
|
||||||
|
ggml_cuda_op_fused_binbcast_impl<op_mul, 7>(ctx, dst);
|
||||||
|
break;
|
||||||
|
case 8:
|
||||||
|
ggml_cuda_op_fused_binbcast_impl<op_mul, 8>(ctx, dst);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(false && "Unsupported n_fuse value");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
|
||||||
|
|||||||
@@ -9,3 +9,4 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|||||||
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);
|
void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);
|
||||||
|
void ggml_cuda_op_fused_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);
|
||||||
|
|||||||
@@ -3758,10 +3758,10 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (node->op == GGML_OP_ADD) {
|
if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) {
|
||||||
int n_fuse = 0;
|
int n_fuse = 0;
|
||||||
ggml_op ops[8];
|
ggml_op ops[8];
|
||||||
std::fill(ops, ops + 8, GGML_OP_ADD);
|
std::fill(ops, ops + 8, node->op);
|
||||||
|
|
||||||
for (; n_fuse <= 6; ++n_fuse){
|
for (; n_fuse <= 6; ++n_fuse){
|
||||||
if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
|
if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
|
||||||
@@ -3778,13 +3778,17 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
|||||||
n_fuse++;
|
n_fuse++;
|
||||||
|
|
||||||
if (n_fuse > 1) {
|
if (n_fuse > 1) {
|
||||||
ggml_tensor fused_add_node;
|
ggml_tensor fused_node;
|
||||||
memcpy(&fused_add_node, node, sizeof(ggml_tensor));
|
memcpy(&fused_node, node, sizeof(ggml_tensor));
|
||||||
for (int j = 0; j < n_fuse - 1; ++j) {
|
for (int j = 0; j < n_fuse - 1; ++j) {
|
||||||
fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
|
fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
|
||||||
|
}
|
||||||
|
fused_node.data = cgraph->nodes[i + n_fuse - 1]->data;
|
||||||
|
if (node->op == GGML_OP_ADD) {
|
||||||
|
ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse);
|
||||||
|
} else {
|
||||||
|
ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse);
|
||||||
}
|
}
|
||||||
fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data;
|
|
||||||
ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse);
|
|
||||||
i += n_fuse - 1;
|
i += n_fuse - 1;
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
|
|||||||
Reference in New Issue
Block a user