diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 0f34a5382..3c42730b4 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3575,49 +3575,36 @@ kernel void kernel_mul_mv_q4_0_f32_c( const short il = (tiisg % (NW / 16)) * 8; const int ib0 = ix; + const uint q_off = il / 8; + device const float * yb = y + ib0 * QK4_0 + il; for (int ib = ib0; ib < nb; ib += 16) { - float sumy[2] = {0.f, 0.f}; + float sumy = 0.f; FOR_UNROLL (short i = 0; i < 8; i += 2) { - sumy[0] += yb[i + 0] + yb[i + 1]; - float yl0 = yb[i + 0]; - float yl1 = yb[i + 1] / 256.f; + sumy += yb[i + 0] + yb[i + 1] + yb[i + 16] + yb[i + 17]; + } - sumy[1] += yb[i + 16] + yb[i + 17]; - float yl8 = yb[i + 16] / 16.f; - float yl9 = yb[i + 17] / 4096.f; + FOR_UNROLL (short row = 0; row < NR0; row++) { + const float d = ax[row][ib].d; + device const uint32_t * qs = (device const uint32_t *) (ax[row][ib].qs); - FOR_UNROLL (short row = 0; row < NR0; row++) { - const float d = ax[row][ib].d; + const uint32_t q0 = qs[q_off]; + const uint32_t q1 = qs[q_off + 2]; - device const uint32_t * qs = (device const uint32_t *) (ax[row][ib].qs); + float acc = 0.f; - float a0 = 0.f, a1 = 0.f, a2 = 0.f, a3 = 0.f; + FOR_UNROLL (short i = 0; i < 8; i += 2) { + const uint ni = i / 2; - a0 += yl0 * (qs[0] & 0x0000000F); - a1 += yl1 * (qs[0] & 0x00000F00); - a2 += yl8 * (qs[0] & 0x0000F000); - a3 += yl9 * (qs[0] & 0x000F0000); - - a0 += yl0 * (qs[1] & 0x0000000F); - a1 += yl1 * (qs[1] & 0x00000F00); - a2 += yl8 * (qs[1] & 0x0000F000); - a3 += yl9 * (qs[1] & 0x000F0000); - - a0 += yl0 * (qs[2] & 0x0000000F); - a1 += yl1 * (qs[2] & 0x00000F00); - a2 += yl8 * (qs[2] & 0x0000F000); - a3 += yl9 * (qs[2] & 0x000F0000); - - a0 += yl0 * (qs[3] & 0x0000000F); - a1 += yl1 * (qs[3] & 0x00000F00); - a2 += yl8 * (qs[3] & 0x0000F000); - a3 += yl9 * (qs[3] & 0x000F0000); - - sumf[row] += d * (sumy[0] + sumy[1]) * -8.f + d * (a0 + a1 + a2 + a3); + acc += ((q0 >> (4 * ni)) & 0xF) * yb[i + 0] + + ((q0 >> (4 * (ni + 1))) & 0xF) * yb[i + 1] + + ((q1 >> (4 * ni)) & 0xF) * yb[i + 16] + + ((q1 >> (4 * (ni + 1))) & 0xF) * yb[i + 17]; } + + sumf[row] += d * (acc + sumy * -8.f); } yb += QK4_0 * 16;