fix(metal): correct Q4_0 contiguous kernel nibble extraction (#29) #39
@@ -3575,49 +3575,36 @@ kernel void kernel_mul_mv_q4_0_f32_c(
|
||||
const short il = (tiisg % (NW / 16)) * 8;
|
||||
const int ib0 = ix;
|
||||
|
||||
const uint q_off = il / 8;
|
||||
|
||||
device const float * yb = y + ib0 * QK4_0 + il;
|
||||
|
||||
for (int ib = ib0; ib < nb; ib += 16) {
|
||||
float sumy[2] = {0.f, 0.f};
|
||||
float sumy = 0.f;
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 8; i += 2) {
|
||||
sumy[0] += yb[i + 0] + yb[i + 1];
|
||||
float yl0 = yb[i + 0];
|
||||
float yl1 = yb[i + 1] / 256.f;
|
||||
sumy += yb[i + 0] + yb[i + 1] + yb[i + 16] + yb[i + 17];
|
||||
}
|
||||
|
||||
sumy[1] += yb[i + 16] + yb[i + 17];
|
||||
float yl8 = yb[i + 16] / 16.f;
|
||||
float yl9 = yb[i + 17] / 4096.f;
|
||||
FOR_UNROLL (short row = 0; row < NR0; row++) {
|
||||
const float d = ax[row][ib].d;
|
||||
device const uint32_t * qs = (device const uint32_t *) (ax[row][ib].qs);
|
||||
|
||||
FOR_UNROLL (short row = 0; row < NR0; row++) {
|
||||
const float d = ax[row][ib].d;
|
||||
const uint32_t q0 = qs[q_off];
|
||||
const uint32_t q1 = qs[q_off + 2];
|
||||
|
||||
device const uint32_t * qs = (device const uint32_t *) (ax[row][ib].qs);
|
||||
float acc = 0.f;
|
||||
|
||||
float a0 = 0.f, a1 = 0.f, a2 = 0.f, a3 = 0.f;
|
||||
FOR_UNROLL (short i = 0; i < 8; i += 2) {
|
||||
const uint ni = i / 2;
|
||||
|
||||
a0 += yl0 * (qs[0] & 0x0000000F);
|
||||
a1 += yl1 * (qs[0] & 0x00000F00);
|
||||
a2 += yl8 * (qs[0] & 0x0000F000);
|
||||
a3 += yl9 * (qs[0] & 0x000F0000);
|
||||
|
||||
a0 += yl0 * (qs[1] & 0x0000000F);
|
||||
a1 += yl1 * (qs[1] & 0x00000F00);
|
||||
a2 += yl8 * (qs[1] & 0x0000F000);
|
||||
a3 += yl9 * (qs[1] & 0x000F0000);
|
||||
|
||||
a0 += yl0 * (qs[2] & 0x0000000F);
|
||||
a1 += yl1 * (qs[2] & 0x00000F00);
|
||||
a2 += yl8 * (qs[2] & 0x0000F000);
|
||||
a3 += yl9 * (qs[2] & 0x000F0000);
|
||||
|
||||
a0 += yl0 * (qs[3] & 0x0000000F);
|
||||
a1 += yl1 * (qs[3] & 0x00000F00);
|
||||
a2 += yl8 * (qs[3] & 0x0000F000);
|
||||
a3 += yl9 * (qs[3] & 0x000F0000);
|
||||
|
||||
sumf[row] += d * (sumy[0] + sumy[1]) * -8.f + d * (a0 + a1 + a2 + a3);
|
||||
acc += ((q0 >> (4 * ni)) & 0xF) * yb[i + 0]
|
||||
+ ((q0 >> (4 * (ni + 1))) & 0xF) * yb[i + 1]
|
||||
+ ((q1 >> (4 * ni)) & 0xF) * yb[i + 16]
|
||||
+ ((q1 >> (4 * (ni + 1))) & 0xF) * yb[i + 17];
|
||||
}
|
||||
|
||||
sumf[row] += d * (acc + sumy * -8.f);
|
||||
}
|
||||
|
||||
yb += QK4_0 * 16;
|
||||
|
||||
Reference in New Issue
Block a user