diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl index e99108dc5..bc580aeeb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -296,13 +296,22 @@ vec2 get_dm_scale(uint ib, uint iqs) { const uint ib_k = ib / 8; const uint iqs_k = (ib % 8) * 8 + iqs; const uint is = iqs_k / 8; - u8vec2 scale_dm; - if (is < 4) { - scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F); - } else { - scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2), - (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); - } + + const uvec3 scales = uvec3(data_a_packed32[ib_k].scales[0], + data_a_packed32[ib_k].scales[1], + data_a_packed32[ib_k].scales[2]); + const uint scalesoffs = (is & 3) * 8; + + const uint scidx0 = (is < 4) ? 0 : 2; + const uint scidxshift0 = scalesoffs; + const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + const uint mbidx0 = (is < 4) ? 1 : 2; + const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4; + const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + + const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30)); + const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30)); + u8vec2 scale_dm = u8vec2(sc, mbyte); return FLOAT_TYPEV2(data_a_packed32[ib_k].dm) * FLOAT_TYPEV2(scale_dm); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 6e4a29d2f..735951689 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -201,19 +201,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec2 loadd = vec2(data_a[ib].dm); - const uint scidx0 = (is < 4) ? is : (is + 4); - const uint scidx1 = (is < 4) ? is : (is - 4); - const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint scidxshift1 = (is < 4) ? 0 : 2; - const uint mbidx0 = is + 4; - const uint mbidx1 = (is < 4) ? is + 4 : is; - const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; - const uint mbidxshift0 = (is < 4) ? 0 : 4; - const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint mbidxshift1 = (is < 4) ? 0 : 2; + const uvec3 scales = uvec3(data_a_packed32[ib].scales[0], + data_a_packed32[ib].scales[1], + data_a_packed32[ib].scales[2]); + const uint scalesoffs = (is & 3) * 8; - const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); - const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + const uint scidx0 = (is < 4) ? 0 : 2; + const uint scidxshift0 = scalesoffs; + const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + const uint mbidx0 = (is < 4) ? 1 : 2; + const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4; + const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + + const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30)); + const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30)); const float d = loadd.x * sc; const float m = -loadd.y * mbyte; @@ -237,19 +238,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec2 loadd = vec2(data_a[ib].dm); - const uint scidx0 = (is < 4) ? is : (is + 4); - const uint scidx1 = (is < 4) ? is : (is - 4); - const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint scidxshift1 = (is < 4) ? 0 : 2; - const uint mbidx0 = is + 4; - const uint mbidx1 = (is < 4) ? is + 4 : is; - const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; - const uint mbidxshift0 = (is < 4) ? 0 : 4; - const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint mbidxshift1 = (is < 4) ? 0 : 2; + const uvec3 scales = uvec3(data_a_packed32[ib].scales[0], + data_a_packed32[ib].scales[1], + data_a_packed32[ib].scales[2]); + const uint scalesoffs = (is & 3) * 8; - const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); - const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + const uint scidx0 = (is < 4) ? 0 : 2; + const uint scidxshift0 = scalesoffs; + const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + const uint mbidx0 = (is < 4) ? 1 : 2; + const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4; + const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + + const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30)); + const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30)); const float d = loadd.x * sc; const float m = -loadd.y * mbyte;