fix(metal): correct Q4_0 contiguous kernel nibble extraction (#29) #39

Closed
sleepy wants to merge 3 commits from fix/29-q40-contig-reads into master
Showing only changes of commit 31ce8b1ae5 - Show all commits
+19 -32
View File
@@ -3575,49 +3575,36 @@ kernel void kernel_mul_mv_q4_0_f32_c(
const short il = (tiisg % (NW / 16)) * 8;
const int ib0 = ix;
const uint q_off = il / 8;
device const float * yb = y + ib0 * QK4_0 + il;
for (int ib = ib0; ib < nb; ib += 16) {
float sumy[2] = {0.f, 0.f};
float sumy = 0.f;
FOR_UNROLL (short i = 0; i < 8; i += 2) {
sumy[0] += yb[i + 0] + yb[i + 1];
float yl0 = yb[i + 0];
float yl1 = yb[i + 1] / 256.f;
sumy += yb[i + 0] + yb[i + 1] + yb[i + 16] + yb[i + 17];
}
sumy[1] += yb[i + 16] + yb[i + 17];
float yl8 = yb[i + 16] / 16.f;
float yl9 = yb[i + 17] / 4096.f;
FOR_UNROLL (short row = 0; row < NR0; row++) {
const float d = ax[row][ib].d;
device const uint32_t * qs = (device const uint32_t *) (ax[row][ib].qs);
FOR_UNROLL (short row = 0; row < NR0; row++) {
const float d = ax[row][ib].d;
const uint32_t q0 = qs[q_off];
const uint32_t q1 = qs[q_off + 2];
device const uint32_t * qs = (device const uint32_t *) (ax[row][ib].qs);
float acc = 0.f;
float a0 = 0.f, a1 = 0.f, a2 = 0.f, a3 = 0.f;
FOR_UNROLL (short i = 0; i < 8; i += 2) {
const uint ni = i / 2;
a0 += yl0 * (qs[0] & 0x0000000F);
a1 += yl1 * (qs[0] & 0x00000F00);
a2 += yl8 * (qs[0] & 0x0000F000);
a3 += yl9 * (qs[0] & 0x000F0000);
a0 += yl0 * (qs[1] & 0x0000000F);
a1 += yl1 * (qs[1] & 0x00000F00);
a2 += yl8 * (qs[1] & 0x0000F000);
a3 += yl9 * (qs[1] & 0x000F0000);
a0 += yl0 * (qs[2] & 0x0000000F);
a1 += yl1 * (qs[2] & 0x00000F00);
a2 += yl8 * (qs[2] & 0x0000F000);
a3 += yl9 * (qs[2] & 0x000F0000);
a0 += yl0 * (qs[3] & 0x0000000F);
a1 += yl1 * (qs[3] & 0x00000F00);
a2 += yl8 * (qs[3] & 0x0000F000);
a3 += yl9 * (qs[3] & 0x000F0000);
sumf[row] += d * (sumy[0] + sumy[1]) * -8.f + d * (a0 + a1 + a2 + a3);
acc += ((q0 >> (4 * ni)) & 0xF) * yb[i + 0]
+ ((q0 >> (4 * (ni + 1))) & 0xF) * yb[i + 1]
+ ((q1 >> (4 * ni)) & 0xF) * yb[i + 16]
+ ((q1 >> (4 * (ni + 1))) & 0xF) * yb[i + 17];
}
sumf[row] += d * (acc + sumy * -8.f);
}
yb += QK4_0 * 16;