diff --git a/ggml/src/ggml-metal/ggml-metal-common.cpp b/ggml/src/ggml-metal/ggml-metal-common.cpp
index 2eb9820bf..44e10b6b8 100644
--- a/ggml/src/ggml-metal/ggml-metal-common.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-common.cpp
@@ -394,19 +394,29 @@ void ggml_graph_optimize(ggml_cgraph * gf) {
         // fuse only ops that start with these operations
         // can be expanded when needed
         if (node.op() == GGML_OP_ADD ||
+            node.op() == GGML_OP_SUB ||
+            node.op() == GGML_OP_MUL ||
+            node.op() == GGML_OP_DIV ||
             node.op() == GGML_OP_NORM ||
             node.op() == GGML_OP_RMS_NORM) {
             ops[0] = node.op();
 
             int f = i + 1;
             while (f < n && f < i + MAX_FUSE) {
-                // conservatively allow fusing only these ops
-                // can be expanded when needed
-                if (gf->nodes[f]->op != GGML_OP_ADD &&
-                    gf->nodes[f]->op != GGML_OP_MUL &&
-                    gf->nodes[f]->op != GGML_OP_NORM &&
-                    gf->nodes[f]->op != GGML_OP_RMS_NORM) {
-                    break;
+                // bin ops (ADD/SUB/MUL/DIV) must be same type to fuse
+                // NORM/RMS_NORM can chain with MUL/ADD
+                if (node.op() == GGML_OP_ADD ||
+                    node.op() == GGML_OP_SUB ||
+                    node.op() == GGML_OP_MUL ||
+                    node.op() == GGML_OP_DIV) {
+                    if (gf->nodes[f]->op != node.op()) break;
+                } else {
+                    if (gf->nodes[f]->op != GGML_OP_ADD &&
+                        gf->nodes[f]->op != GGML_OP_MUL &&
+                        gf->nodes[f]->op != GGML_OP_NORM &&
+                        gf->nodes[f]->op != GGML_OP_RMS_NORM) {
+                        break;
+                    }
                 }
                 ops[f - i] = gf->nodes[f]->op;
                 f++;
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index d211bf79f..4abec6739 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -732,6 +732,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
     int nr1 = 1;     // number of src1 rows per threadgroup
 
     size_t smem = 0; // shared memory
+    bool contig = false;
 
     const ggml_type tsrc0 = op->src[0]->type;
     const ggml_type tsrc1 = op->src[1]->type;
@@ -766,6 +767,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
             {
                 nsg = N_SG_Q4_0;
                 nr0 = N_R0_Q4_0;
+                contig = ne00 >= 256;
             } break;
         case GGML_TYPE_Q4_1:
             {
@@ -877,7 +879,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
         }
     };
 
-    snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
+    snprintf(base, 256, "kernel_mul_mv_%s_%s%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), contig ? "_c" : "", suffix);
     snprintf(name, 256, "%s_nsg=%d", base, nsg);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 5fa162c87..d4a798e14 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -3118,19 +3118,15 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
 
     int n_fuse = 1;
 
-    // c[0] = add(a, b[0])
-    // c[1] = add(c[0], b[1])
-    // c[2] = add(c[1], b[2])
+    // c[0] = op(a, b[0])
+    // c[1] = op(c[0], b[1])
+    // c[2] = op(c[1], b[2])
     // ...
     if (use_fusion) {
-        fops[0] = GGML_OP_ADD;
-        fops[1] = GGML_OP_ADD;
-        fops[2] = GGML_OP_ADD;
-        fops[3] = GGML_OP_ADD;
-        fops[4] = GGML_OP_ADD;
-        fops[5] = GGML_OP_ADD;
-        fops[6] = GGML_OP_ADD;
-        fops[7] = GGML_OP_ADD;
+        ggml_op cur_op = op->op;
+        for (int i = 0; i < 8; ++i) {
+            fops[i] = cur_op;
+        }
 
         // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops
         // across splits. idx_end indicates the last node in the current split
@@ -3165,7 +3161,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
             ++n_fuse;
 
             if (debug_fusion > 1 && n_fuse > 1) {
-                GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
+                GGML_LOG_DEBUG("%s: fuse: %s x %d\n", __func__, ggml_op_name(cur_op), n_fuse);
             }
         }
 
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index c372eaede..3c42730b4 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -3533,6 +3533,94 @@ kernel void kernel_mul_mv_q4_0_f32(
     mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
+// Q4_0 kernel with contiguous uint32_t weight reads (MLX-style)
+// Each thread reads 2 contiguous uint32_t packs (8 bytes) per block instead
+// of strided sub-word reads, improving memory coalescing on Apple GPUs.
+kernel void kernel_mul_mv_q4_0_f32_c(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        threadgroup char * shmem [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short NSG = FC_mul_mv_nsg;
+
+    constexpr short NR0 = N_R0_Q4_0;
+    constexpr short NW = N_SIMDWIDTH;
+
+    const int nb = args.ne00 / QK4_0;
+
+    const int r0 = (tgpig.x * NSG + sgitg) * NR0;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const uint i12 = im % args.ne12;
+    const uint i13 = im / args.ne12;
+
+    const uint64_t offset1 = r1 * args.nb11 + (i12) * args.nb12 + (i13) * args.nb13;
+
+    device const float * y = (device const float *) (src1 + offset1);
+
+    device const block_q4_0 * ax[NR0];
+    FOR_UNROLL (int row = 0; row < NR0; ++row) {
+        const uint64_t offset0 = (r0 + row) * args.nb01 + (i12 / args.r2) * args.nb02 + (i13 / args.r3) * args.nb03;
+        ax[row] = (device const block_q4_0 *) ((device char *) src0 + offset0);
+    }
+
+    float sumf[NR0] = {0.f};
+
+    const short ix = (tiisg / (NW / 16));
+    const short il = (tiisg % (NW / 16)) * 8;
+    const int ib0 = ix;
+
+    const uint q_off = il / 4; // first uint32 pack (bytes il..il+3) handled by this thread
+
+    device const float * yb = y + ib0 * QK4_0 + il;
+
+    for (int ib = ib0; ib < nb; ib += 16) {
+        float sumy = 0.f;
+
+        FOR_UNROLL (short i = 0; i < 8; i += 2) {
+            sumy += yb[i + 0] + yb[i + 1] + yb[i + 16] + yb[i + 17];
+        }
+
+        FOR_UNROLL (short row = 0; row < NR0; row++) {
+            const float d = ax[row][ib].d;
+            device const uint32_t * qs = (device const uint32_t *) (ax[row][ib].qs);
+
+            const uint32_t q0 = qs[q_off + 0]; // bytes il+0..il+3: values il+0..il+3 (lo nibbles), il+16..il+19 (hi)
+            const uint32_t q1 = qs[q_off + 1]; // bytes il+4..il+7: values il+4..il+7 (lo nibbles), il+20..il+23 (hi)
+
+            float acc = 0.f;
+
+            FOR_UNROLL (short i = 0; i < 4; ++i) {
+                const uint sh = 8 * i; // one byte = (lo, hi) nibble pair per iteration
+
+                acc += ((q0 >> (sh + 0)) & 0xF) * yb[i +  0]
+                     + ((q0 >> (sh + 4)) & 0xF) * yb[i + 16]
+                     + ((q1 >> (sh + 0)) & 0xF) * yb[i +  4]
+                     + ((q1 >> (sh + 4)) & 0xF) * yb[i + 20];
+            }
+
+            sumf[row] += d * (acc + sumy * -8.f);
+        }
+
+        yb += QK4_0 * 16;
+    }
+
+    device float * dst_f32 = (device float *) dst + im * args.ne0 * args.ne1 + r1 * args.ne0;
+
+    for (int row = 0; row < NR0; ++row) {
+        const float tot = simd_sum(sumf[row]);
+
+        if (tiisg == 0 && r0 + row < args.ne01) {
+            dst_f32[r0 + row] = tot;
+        }
+    }
+}
+
 kernel void kernel_mul_mv_q4_1_f32(
         constant ggml_metal_kargs_mul_mv & args,
         device const char * src0,