From 06f05e71c12f1cf86147e0f775fe73655c633829 Mon Sep 17 00:00:00 2001
From: Kaloyan Nikolov
Date: Thu, 30 Apr 2026 22:38:37 +0200
Subject: [PATCH] [metal] wire contiguous Q4_0 kernel into dispatch (#29)

---
 ggml/src/ggml-metal/ggml-metal-device.cpp |  4 +-
 ggml/src/ggml-metal/ggml-metal.metal      | 97 ++++++++++++++++++++++
 2 files changed, 100 insertions(+), 1 deletion(-)

diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index d211bf79f..4abec6739 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -732,6 +732,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
     int nr1 = 1; // number of src1 rows per threadgroup
 
     size_t smem = 0; // shared memory
+    bool contig = false;
 
     const ggml_type tsrc0 = op->src[0]->type;
     const ggml_type tsrc1 = op->src[1]->type;
@@ -766,6 +767,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
             {
                 nsg = N_SG_Q4_0;
                 nr0 = N_R0_Q4_0;
+                contig = ne00 >= 256;
             } break;
         case GGML_TYPE_Q4_1:
             {
@@ -877,7 +879,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
         }
     };
 
-    snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
+    snprintf(base, 256, "kernel_mul_mv_%s_%s%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), contig ? "_c" : "", suffix);
     snprintf(name, 256, "%s_nsg=%d", base, nsg);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index c372eaede..0f34a5382 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -3533,6 +3533,103 @@ kernel void kernel_mul_mv_q4_0_f32(
     mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
+// Q4_0 kernel with contiguous per-lane weight reads.
+// Two lanes share each block (il = 0 or 8): a lane reads 4 contiguous
+// uint16_t packs starting at byte il of qs. NOTE: block_q4_0 is 18 bytes
+// (half d + 16 quant bytes), so qs is only 2-byte aligned and wider
+// (uint32_t) packed loads would be misaligned for every other block.
+kernel void kernel_mul_mv_q4_0_f32_c(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short NSG = FC_mul_mv_nsg;
+
+    constexpr short NR0 = N_R0_Q4_0;
+    constexpr short NW  = N_SIMDWIDTH;
+
+    const int nb = args.ne00/QK4_0;
+
+    const int r0 = (tgpig.x*NSG + sgitg)*NR0;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;
+
+    device const float * y = (device const float *) (src1 + offset1);
+
+    device const block_q4_0 * ax[NR0];
+    FOR_UNROLL (int row = 0; row < NR0; ++row) {
+        const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+        ax[row] = (device const block_q4_0 *) ((device char *) src0 + offset0);
+    }
+
+    float sumf[NR0] = { 0.f };
+
+    const short ix = tiisg/2;     // which block inside the NW/2-wide window
+    const short il = (tiisg%2)*8; // byte offset into qs: the lane pair splits the block
+
+    device const float * yb = y + ix*QK4_0 + il;
+
+    for (int ib = ix; ib < nb; ib += NW/2) {
+        float sumy[2] = { 0.f, 0.f };
+
+        // pre-scale y so nibbles masked in place need no extra shift:
+        // a nibble at bit 4k contributes nibble*2^(4k), hence the 1/2^(4k) factors
+        float yl[16];
+
+        FOR_UNROLL (short i = 0; i < 8; i += 2) {
+            sumy[0] += yb[i + 0] + yb[i + 1];
+            yl[i + 0] = yb[i + 0];
+            yl[i + 1] = yb[i + 1]/256.f;
+
+            sumy[1] += yb[i + 16] + yb[i + 17];
+            yl[i + 8] = yb[i + 16]/16.f;
+            yl[i + 9] = yb[i + 17]/4096.f;
+        }
+
+        FOR_UNROLL (short row = 0; row < NR0; row++) {
+            const float d = ax[row][ib].d;
+
+            // 4 contiguous uint16_t packs covering bytes il..il+7 of qs
+            device const uint16_t * qs = (device const uint16_t *) (ax[row][ib].qs) + il/2;
+
+            float acc = 0.f;
+
+            FOR_UNROLL (short t = 0; t < 4; ++t) {
+                const uint16_t u = qs[t];
+
+                acc += yl[2*t + 0] * (u & 0x000F); // weight il + 2t      (low  nibble, even byte)
+                acc += yl[2*t + 1] * (u & 0x0F00); // weight il + 2t + 1  (low  nibble, odd  byte)
+                acc += yl[2*t + 8] * (u & 0x00F0); // weight il + 2t + 16 (high nibble, even byte)
+                acc += yl[2*t + 9] * (u & 0xF000); // weight il + 2t + 17 (high nibble, odd  byte)
+            }
+
+            // q4_0 stores nibble - 8: apply the offset once per block with the full sumy
+            sumf[row] += d * (acc - 8.f*(sumy[0] + sumy[1]));
+        }
+
+        yb += QK4_0*(NW/2);
+    }
+
+    device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
+
+    for (int row = 0; row < NR0; ++row) {
+        const float tot = simd_sum(sumf[row]);
+
+        if (tiisg == 0 && r0 + row < args.ne01) {
+            dst_f32[r0 + row] = tot;
+        }
+    }
+}
+
 kernel void kernel_mul_mv_q4_1_f32(
     constant ggml_metal_kargs_mul_mv & args,
     device const char * src0,