diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 0a3e071e5..94b19b82b 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -2300,9 +2300,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #if defined __AVX2__ - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m256i m2 = _mm256_set1_epi8(3); - const __m256i m32s = _mm256_set1_epi8(32); + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i m15 = _mm256_set1_epi8(15); __m256 acc = _mm256_setzero_ps(); @@ -2314,53 +2313,45 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const uint8_t * GGML_RESTRICT qh = x[i].qh; const int8_t * GGML_RESTRICT q8 = y[i].qs; + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m256i scales_16 = _mm256_cvtepi8_epi16(scales); + const __m256i q8sclsub = _mm256_slli_epi32(_mm256_madd_epi16(q8sums, scales_16), 5); __m256i sumi = _mm256_setzero_si256(); int is = 0; for (int j = 0; j < QK_K/128; ++j) { - - const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); - const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); - const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); - const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); - is += 4; - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; - const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); - const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); - const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); - const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m3), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, _mm256_set1_epi8(12)), 2); + const __m256i q4h_2 = _mm256_and_si256(q4bitsH, _mm256_set1_epi8(48)); + const __m256i q4h_3 = _mm256_srli_epi16(_mm256_and_si256(q4bitsH, _mm256_set1_epi8(-64)), 2); - const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); - const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); - const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); - const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m15), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m15), q4h_1); + const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m15), q4h_2); + const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m15), q4h_3); const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); - __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); - __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); - __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - p16_2 = _mm256_sub_epi16(p16_2, q8s_2); - p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + is += 4; p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); @@ -2372,6 +2363,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } + sumi = _mm256_sub_epi32(sumi, q8sclsub); acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); }