From eeb79b026b006a4c57e7f49e5c63700526c052ee Mon Sep 17 00:00:00 2001 From: Kaloyan Nikolov Date: Thu, 30 Apr 2026 20:14:12 +0200 Subject: [PATCH 1/3] [metal] extend bin op fusion to MUL/SUB/DIV chains (#28) --- ggml/src/ggml-metal/ggml-metal-common.cpp | 24 ++++++++++++++++------- ggml/src/ggml-metal/ggml-metal-ops.cpp | 20 ++++++++----------- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-common.cpp b/ggml/src/ggml-metal/ggml-metal-common.cpp index 2eb9820bf..44e10b6b8 100644 --- a/ggml/src/ggml-metal/ggml-metal-common.cpp +++ b/ggml/src/ggml-metal/ggml-metal-common.cpp @@ -394,19 +394,29 @@ void ggml_graph_optimize(ggml_cgraph * gf) { // fuse only ops that start with these operations // can be expanded when needed if (node.op() == GGML_OP_ADD || + node.op() == GGML_OP_SUB || + node.op() == GGML_OP_MUL || + node.op() == GGML_OP_DIV || node.op() == GGML_OP_NORM || node.op() == GGML_OP_RMS_NORM) { ops[0] = node.op(); int f = i + 1; while (f < n && f < i + MAX_FUSE) { - // conservatively allow fusing only these ops - // can be expanded when needed - if (gf->nodes[f]->op != GGML_OP_ADD && - gf->nodes[f]->op != GGML_OP_MUL && - gf->nodes[f]->op != GGML_OP_NORM && - gf->nodes[f]->op != GGML_OP_RMS_NORM) { - break; + // bin ops (ADD/SUB/MUL/DIV) must be same type to fuse + // NORM/RMS_NORM can chain with MUL/ADD + if (node.op() == GGML_OP_ADD || + node.op() == GGML_OP_SUB || + node.op() == GGML_OP_MUL || + node.op() == GGML_OP_DIV) { + if (gf->nodes[f]->op != node.op()) break; + } else { + if (gf->nodes[f]->op != GGML_OP_ADD && + gf->nodes[f]->op != GGML_OP_MUL && + gf->nodes[f]->op != GGML_OP_NORM && + gf->nodes[f]->op != GGML_OP_RMS_NORM) { + break; + } } ops[f - i] = gf->nodes[f]->op; f++; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 5fa162c87..d4a798e14 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -3118,19 +3118,15 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { int n_fuse = 1; - // c[0] = add(a, b[0]) - // c[1] = add(c[0], b[1]) - // c[2] = add(c[1], b[2]) + // c[0] = op(a, b[0]) + // c[1] = op(c[0], b[1]) + // c[2] = op(c[1], b[2]) // ... if (use_fusion) { - fops[0] = GGML_OP_ADD; - fops[1] = GGML_OP_ADD; - fops[2] = GGML_OP_ADD; - fops[3] = GGML_OP_ADD; - fops[4] = GGML_OP_ADD; - fops[5] = GGML_OP_ADD; - fops[6] = GGML_OP_ADD; - fops[7] = GGML_OP_ADD; + ggml_op cur_op = op->op; + for (int i = 0; i < 8; ++i) { + fops[i] = cur_op; + } // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops // across splits. idx_end indicates the last node in the current split @@ -3165,7 +3161,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { ++n_fuse; if (debug_fusion > 1 && n_fuse > 1) { - GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse); + GGML_LOG_DEBUG("%s: fuse: %s x %d\n", __func__, ggml_op_name(cur_op), n_fuse); } } -- 2.39.5 (Apple Git-154) From 06f05e71c12f1cf86147e0f775fe73655c633829 Mon Sep 17 00:00:00 2001 From: Kaloyan Nikolov Date: Thu, 30 Apr 2026 22:38:37 +0200 Subject: [PATCH 2/3] [metal] wire contiguous Q4_0 kernel into dispatch (#29) --- ggml/src/ggml-metal/ggml-metal-device.cpp | 4 +- ggml/src/ggml-metal/ggml-metal.metal | 101 ++++++++++++++++++++++ 2 files changed, 104 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index d211bf79f..4abec6739 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -732,6 +732,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta int nr1 = 1; // number of src1 rows per threadgroup size_t smem = 0; // shared memory + bool contig = false; const ggml_type tsrc0 = op->src[0]->type; const ggml_type tsrc1 = op->src[1]->type; @@ -766,6 +767,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta { nsg = N_SG_Q4_0; nr0 = N_R0_Q4_0; + contig = ne00 >= 256; } break; case GGML_TYPE_Q4_1: { @@ -877,7 +879,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta } }; - snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); + snprintf(base, 256, "kernel_mul_mv_%s_%s%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), contig ? "_c" : "", suffix); snprintf(name, 256, "%s_nsg=%d", base, nsg); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index c372eaede..0f34a5382 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3533,6 +3533,107 @@ kernel void kernel_mul_mv_q4_0_f32( mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } +// Q4_0 kernel with contiguous uint32_t weight reads (MLX-style) +// Each thread reads 4 contiguous uint32_t packs per block instead of +// 8 strided uint16_t reads, improving memory coalescing on Apple GPUs. +kernel void kernel_mul_mv_q4_0_f32_c( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short NSG = FC_mul_mv_nsg; + + constexpr short NR0 = N_R0_Q4_0; + constexpr short NW = N_SIMDWIDTH; + + const int nb = args.ne00 / QK4_0; + + const int r0 = (tgpig.x * NSG + sgitg) * NR0; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const uint i12 = im % args.ne12; + const uint i13 = im / args.ne12; + + const uint64_t offset1 = r1 * args.nb11 + (i12) * args.nb12 + (i13) * args.nb13; + + device const float * y = (device const float *) (src1 + offset1); + + device const block_q4_0 * ax[NR0]; + FOR_UNROLL (int row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row) * args.nb01 + (i12 / args.r2) * args.nb02 + (i13 / args.r3) * args.nb03; + ax[row] = (device const block_q4_0 *) ((device char *) src0 + offset0); + } + + float sumf[NR0] = {0.f}; + + const short ix = (tiisg / (NW / 16)); + const short il = (tiisg % (NW / 16)) * 8; + const int ib0 = ix; + + device const float * yb = y + ib0 * QK4_0 + il; + + for (int ib = ib0; ib < nb; ib += 16) { + float sumy[2] = {0.f, 0.f}; + + FOR_UNROLL (short i = 0; i < 8; i += 2) { + sumy[0] += yb[i + 0] + yb[i + 1]; + float yl0 = yb[i + 0]; + float yl1 = yb[i + 1] / 256.f; + + sumy[1] += yb[i + 16] + yb[i + 17]; + float yl8 = yb[i + 16] / 16.f; + float yl9 = yb[i + 17] / 4096.f; + + FOR_UNROLL (short row = 0; row < NR0; row++) { + const float d = ax[row][ib].d; + + device const uint32_t * qs = (device const uint32_t *) (ax[row][ib].qs); + + float a0 = 0.f, a1 = 0.f, a2 = 0.f, a3 = 0.f; + + a0 += yl0 * (qs[0] & 0x0000000F); + a1 += yl1 * (qs[0] & 0x00000F00); + a2 += yl8 * (qs[0] & 0x0000F000); + a3 += yl9 * (qs[0] & 0x000F0000); + + a0 += yl0 * (qs[1] & 0x0000000F); + a1 += yl1 * (qs[1] & 0x00000F00); + a2 += yl8 * (qs[1] & 0x0000F000); + a3 += yl9 * (qs[1] & 0x000F0000); + + a0 += yl0 * (qs[2] & 0x0000000F); + a1 += yl1 * (qs[2] & 0x00000F00); + a2 += yl8 * (qs[2] & 0x0000F000); + a3 += yl9 * (qs[2] & 0x000F0000); + + a0 += yl0 * (qs[3] & 0x0000000F); + a1 += yl1 * (qs[3] & 0x00000F00); + a2 += yl8 * (qs[3] & 0x0000F000); + a3 += yl9 * (qs[3] & 0x000F0000); + + sumf[row] += d * (sumy[0] + sumy[1]) * -8.f + d * (a0 + a1 + a2 + a3); + } + } + + yb += QK4_0 * 16; + } + + device float * dst_f32 = (device float *) dst + im * args.ne0 * args.ne1 + r1 * args.ne0; + + for (int row = 0; row < NR0; ++row) { + const float tot = simd_sum(sumf[row]); + + if (tiisg == 0 && r0 + row < args.ne01) { + dst_f32[r0 + row] = tot; + } + } +} + kernel void kernel_mul_mv_q4_1_f32( constant ggml_metal_kargs_mul_mv & args, device const char * src0, -- 2.39.5 (Apple Git-154) From 31ce8b1ae5da5f40101e6f1cd49ceebd51c1e8b5 Mon Sep 17 00:00:00 2001 From: Kaloyan Nikolov Date: Fri, 1 May 2026 00:13:56 +0200 Subject: [PATCH 3/3] fix(metal): correct Q4_0 contiguous kernel nibble extraction - Extract all 8 nibbles per uint32_t with proper bit shifts - Use il-based offset for uint32_t selection (qs[il/8] and qs[il/8+2]) - Apply bias correction once per block instead of 4x accumulated --- ggml/src/ggml-metal/ggml-metal.metal | 51 +++++++++++----------------- 1 file changed, 19 insertions(+), 32 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 0f34a5382..3c42730b4 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3575,49 +3575,36 @@ kernel void kernel_mul_mv_q4_0_f32_c( const short il = (tiisg % (NW / 16)) * 8; const int ib0 = ix; + const uint q_off = il / 8; + device const float * yb = y + ib0 * QK4_0 + il; for (int ib = ib0; ib < nb; ib += 16) { - float sumy[2] = {0.f, 0.f}; + float sumy = 0.f; FOR_UNROLL (short i = 0; i < 8; i += 2) { - sumy[0] += yb[i + 0] + yb[i + 1]; - float yl0 = yb[i + 0]; - float yl1 = yb[i + 1] / 256.f; + sumy += yb[i + 0] + yb[i + 1] + yb[i + 16] + yb[i + 17]; + } - sumy[1] += yb[i + 16] + yb[i + 17]; - float yl8 = yb[i + 16] / 16.f; - float yl9 = yb[i + 17] / 4096.f; + FOR_UNROLL (short row = 0; row < NR0; row++) { + const float d = ax[row][ib].d; + device const uint32_t * qs = (device const uint32_t *) (ax[row][ib].qs); - FOR_UNROLL (short row = 0; row < NR0; row++) { - const float d = ax[row][ib].d; + const uint32_t q0 = qs[q_off]; + const uint32_t q1 = qs[q_off + 2]; - device const uint32_t * qs = (device const uint32_t *) (ax[row][ib].qs); + float acc = 0.f; - float a0 = 0.f, a1 = 0.f, a2 = 0.f, a3 = 0.f; + FOR_UNROLL (short i = 0; i < 8; i += 2) { + const uint ni = i / 2; - a0 += yl0 * (qs[0] & 0x0000000F); - a1 += yl1 * (qs[0] & 0x00000F00); - a2 += yl8 * (qs[0] & 0x0000F000); - a3 += yl9 * (qs[0] & 0x000F0000); - - a0 += yl0 * (qs[1] & 0x0000000F); - a1 += yl1 * (qs[1] & 0x00000F00); - a2 += yl8 * (qs[1] & 0x0000F000); - a3 += yl9 * (qs[1] & 0x000F0000); - - a0 += yl0 * (qs[2] & 0x0000000F); - a1 += yl1 * (qs[2] & 0x00000F00); - a2 += yl8 * (qs[2] & 0x0000F000); - a3 += yl9 * (qs[2] & 0x000F0000); - - a0 += yl0 * (qs[3] & 0x0000000F); - a1 += yl1 * (qs[3] & 0x00000F00); - a2 += yl8 * (qs[3] & 0x0000F000); - a3 += yl9 * (qs[3] & 0x000F0000); - - sumf[row] += d * (sumy[0] + sumy[1]) * -8.f + d * (a0 + a1 + a2 + a3); + acc += ((q0 >> (4 * ni)) & 0xF) * yb[i + 0] + + ((q0 >> (4 * (ni + 1))) & 0xF) * yb[i + 1] + + ((q1 >> (4 * ni)) & 0xF) * yb[i + 16] + + ((q1 >> (4 * (ni + 1))) & 0xF) * yb[i + 17]; } + + sumf[row] += d * (acc + sumy * -8.f); } yb += QK4_0 * 16; -- 2.39.5 (Apple Git-154)