From eebf6bc0bd3e557bb085a5f9d373fefab372161f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9my=20O?= Date: Fri, 7 Mar 2025 12:54:22 +0100 Subject: [PATCH] ggml-cpu: faster AVX2 variant for IQ1_M (llama/12216) --- ggml/src/ggml-cpu/ggml-cpu-quants.c | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 2ae66591..8c7dbd1c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -11718,9 +11718,12 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const #elif defined __AVX2__ - const __m256i mask = _mm256_set1_epi16(2 * 0x7); + const __m256i mask = _mm256_set1_epi16(0x7); const __m256i mone = _mm256_set1_epi16(1); const __m256i mone8 = _mm256_set1_epi8(1); + const __m256i mtwo8 = _mm256_set1_epi8(2); + // VPSHUFB cannot cross 128-bit lanes so odd shifts go to upper half. + const __m256i scales_shift = _mm256_set_epi64x(9, 3, 6, 0); __m256 accum1 = _mm256_setzero_ps(); __m256 accum2 = _mm256_setzero_ps(); @@ -11732,6 +11735,14 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const const uint16_t * sc = (const uint16_t *)x[i].scales; scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + // Extract 3-bit scales (16 values) + __m256i scales = _mm256_set1_epi64x(*(const uint64_t*)sc); + scales = _mm256_srlv_epi64(scales, scales_shift); + scales = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scales, mask), 1), mone); + + // Indices to repeat each scale 8 times. + __m256i scales_idx1 = _mm256_set1_epi16(0x0100); + __m256i scales_idx2 = _mm256_add_epi8(scales_idx1, _mm256_set1_epi8(8)); __m256i sumi1 = _mm256_setzero_si256(); __m256i sumi2 = _mm256_setzero_si256(); @@ -11777,11 +11788,12 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const const __m256i dot3 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_1, delta1)); const __m256i dot4 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_2, delta2)); - __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 2), _mm_set1_epi16(sc[ib/2] << 1)); - __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 8), _mm_set1_epi16(sc[ib/2] >> 5)); + __m256i scale1 = _mm256_shuffle_epi8(scales, scales_idx1); + __m256i scale2 = _mm256_shuffle_epi8(scales, scales_idx2); + + scales_idx1 = _mm256_add_epi8(scales_idx1, mtwo8); + scales_idx2 = _mm256_add_epi8(scales_idx2, mtwo8); - scale1 = _mm256_add_epi16(_mm256_and_si256(scale1, mask), mone); - scale2 = _mm256_add_epi16(_mm256_and_si256(scale2, mask), mone); const __m256i p1 = _mm256_madd_epi16(dot1, scale1); const __m256i p2 = _mm256_madd_epi16(dot2, scale2); const __m256i p3 = _mm256_madd_epi16(dot3, scale1);