From f79d0d4f74200711724da0751d222e97e07dcbab Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Mon, 11 Mar 2024 07:51:49 +0100 Subject: [PATCH] Better 1.5 bit quantization (llama/5971) * Trying blocvks of 16 for IQ1_S - seems slightly better * iq1s_blocks16: Adjust scale fudge factor to 1.125 * iq1s_blocks16: going to blocks of 32 with 2048 lattice points, so same bpw. This is even better than blocks of 16. Should I try blocks of 64? But to keep the same bpw, when I go to 4096 lattice points, I need to remove blocks alltogether and just have superblocks of 256 weights. * iq1s_blocks16: Use 2* as sigma2 in weight adjustment * iq1s_blocks16: scalar and AVX2 dot products * iq1s_blocks16: CUDA dot product * iq1s_blocks16: Metal works, Neon does not Metal works but TG is dog slow (35 t/s). PP is OKish (493 t/s). Not seeing the bug in the Neon implementation for now. * iq1s_blocks16: fixed Neon * iq1s_blocks16: very slightly faster TG on Metal Still pathetic at 37 t/s * iq1s_blocks16: speedup Metal by packing codebook into uint32_t's * Formatting * iq1s_blocks16: uint32_t codebook is also better in CUDA TG-128 is now 204 t/s up from 194 t/s. PP-512 is 5890 t/s, so significantly better than other quants * iq1s_blocks16: slightly faster Neon dot product * iq1s_blocks16: faster AVX2 dot product * iq1s_blocks16: adjust to ggml-common.h --------- Co-authored-by: Iwan Kawrakow --- ggml-cuda.cu | 62 +++--- ggml-metal.metal | 66 +++--- ggml-quants.c | 510 ++++++++++++++++++++++++++++------------------- ggml-quants.h | 4 +- 4 files changed, 378 insertions(+), 264 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c207ff87..d2945d3c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -565,8 +565,8 @@ static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N #define QI1_S (QK_K / (4*QR1_S)) typedef struct { half d; - uint8_t qs[QK_K/8]; - uint8_t scales[QK_K/16]; + uint8_t qs[QK_K/8]; + uint16_t qh[QK_K/32]; } block_iq1_s; static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding"); @@ -1722,11 +1722,22 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_ const int il = tid/8; // 0...3 const int ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; - const int i8 = 4*ib+il; - uint8_t h = x[i].scales[i8/2] >> 4*(i8%2); - const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5))); - const float d = (float)x[i].d * (2*(h & 7) + 1); - for (int j = 0; j < 8; ++j) y[j] = d * grid[j]; + const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 0xf) + 1); +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int grid32[2]; const int8_t * q = (const int8_t *)grid32; + grid32[0] = *((const int *)(iq1s_grid_gpu + (x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)))); + grid32[1] = __vsub4((grid32[0] >> 4) & 0x0f0f0f0f, 0x01010101); + grid32[0] = __vsub4(grid32[0] & 0x0f0f0f0f, 0x01010101); + for (int j = 0; j < 8; ++j) { + y[j] = d * q[j]; + } +#else + const uint8_t * grid = (const uint8_t *)(iq1s_grid_gpu + (x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8))); + for (int j = 0; j < 4; ++j) { + y[j+0] = d * ((grid[j] & 0xf) - 1); + y[j+4] = d * ((grid[j] >> 4) - 1); + } +#endif #else assert(false); #endif @@ -4538,44 +4549,33 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( #endif } - static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { #if QK_K == 256 const block_iq1_s * bq1 = (const block_iq1_s *) vbq; const int ib32 = iqs; - int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; - const uint8_t h1 = bq1->scales[2*ib32+0]; - const uint8_t h2 = bq1->scales[2*ib32+1]; + int sumi = 0; #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const int * q8 = (const int *)bq8_1[ib32].qs; - const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5))); - const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1))); - const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5))); - const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1))); - for (int j = 0; j < 2; ++j) { - sumi1 = __dp4a(q8[j+0], grid1[j], sumi1); - sumi2 = __dp4a(q8[j+2], grid2[j], sumi2); - sumi3 = __dp4a(q8[j+4], grid3[j], sumi3); - sumi4 = __dp4a(q8[j+6], grid4[j], sumi4); + for (int l = 0; l < 4; ++l) { + const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8))); + int grid0 = __vsub4(grid[0] & 0x0f0f0f0f, 0x01010101); + int grid1 = __vsub4((grid[0] >> 4) & 0x0f0f0f0f, 0x01010101); + sumi = __dp4a(q8[2*l+1], grid1, __dp4a(q8[2*l+0], grid0, sumi)); } #else const int8_t * q8 = bq8_1[ib32].qs; - const int8_t * grid1 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5))); - const int8_t * grid2 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1))); - const int8_t * grid3 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5))); - const int8_t * grid4 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1))); - for (int j = 0; j < 8; ++j) { - sumi1 += q8[j+ 0] * grid1[j]; - sumi2 += q8[j+ 8] * grid2[j]; - sumi3 += q8[j+16] * grid3[j]; - sumi4 += q8[j+24] * grid4[j]; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8))); + for (int j = 0; j < 4; ++j) { + sumi += q8[j] * ((grid[j] & 0xf) - 1) + q8[j+4] * ((grid[j] >> 4) - 1); + } + q8 += 8; } #endif const float d = (float)bq1->d * __low2float(bq8_1[ib32].ds); - return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) + - sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1)); + return d * sumi * (2*(bq1->qh[ib32] >> 12) + 1); #else assert(false); return 0.f; diff --git a/ggml-metal.metal b/ggml-metal.metal index 50185ae4..912822ee 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2595,8 +2595,8 @@ typedef struct { typedef struct { half d; - uint8_t qs[QK_K/8]; - uint8_t scales[QK_K/16]; + uint8_t qs[QK_K/8]; + uint16_t qh[QK_K/32]; } block_iq1_s; // Non-linear quants @@ -4338,48 +4338,53 @@ void kernel_mul_mv_iq1_s_f32_impl( device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - float yl[16]; + float yl[32]; float sumf[N_DST]={0.f}, all_sum; const int nb32 = nb * (QK_K / 32); - const int ix = tiisg/2; - const int il = tiisg%2; + const int ix = tiisg; - device const float * y4 = y + 32 * ix + 16 * il; + device const float * y4 = y + 32 * ix; - for (int ib32 = ix; ib32 < nb32; ib32 += 16) { + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - for (int i = 0; i < 16; ++i) { + float sumy = 0; + for (int i = 0; i < 32; ++i) { yl[i] = y4[i]; + sumy += yl[i]; } const int ibl = ib32 / (QK_K / 32); const int ib = ib32 % (QK_K / 32); device const block_iq1_s * xr = x + ibl; - device const uint8_t * qs = xr->qs + 4 * ib + 2 * il; - device const uint8_t * sc = xr->scales + 2 * ib + il; - device const half * dh = &xr->d; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint16_t * qh = xr->qh + ib; + device const half * dh = &xr->d; for (int row = 0; row < N_DST; row++) { - constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5))); - constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1))); + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); + constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); + constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); - float2 sum = {0}; - for (int j = 0; j < 8; ++j) { - sum[0] += yl[j+ 0] * grid1[j]; - sum[1] += yl[j+ 8] * grid2[j]; + float sum = 0; + for (int j = 0; j < 4; ++j) { + sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) + + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) + + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); } - sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1)); + sumf[row] += (float)dh[0] * (sum - sumy) * (2*(qh[0] >> 12) + 1); dh += nb*sizeof(block_iq1_s)/2; qs += nb*sizeof(block_iq1_s); - sc += nb*sizeof(block_iq1_s); + qh += nb*sizeof(block_iq1_s)/2; } - y4 += 16 * 32; + y4 += 32 * 32; } for (int row = 0; row < N_DST; ++row) { @@ -5066,16 +5071,19 @@ void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & template void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) { // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; const float d = xb->d; - device const uint8_t * qs = xb->qs + 2*il; - device const uint8_t * sc = xb->scales + il; - const float dl1 = d * (2*(sc[0] & 7) + 1); - const float dl2 = d * (2*((sc[0] >> 4) & 7) + 1); - constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5))); - constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1))); - for (int i = 0; i < 8; ++i) { - reg[i/4+0][i%4] = dl1 * grid1[i]; - reg[i/4+2][i%4] = dl2 * grid2[i]; + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint16_t * qh = xb->qh; + const float dl = d * (2*(qh[ib32] >> 12) + 1); + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | (((qh[ib32] >> (6*il+0)) & 7) << 8))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | (((qh[ib32] >> (6*il+3)) & 7) << 8))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) - dl; + reg[1][i] = dl * (grid1[i] >> 4) - dl; + reg[2][i] = dl * (grid2[i] & 0xf) - dl; + reg[3][i] = dl * (grid2[i] >> 4) - dl; } } diff --git a/ggml-quants.c b/ggml-quants.c index 42d8a5d8..f9a3d9fd 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3449,39 +3449,22 @@ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, in assert(k % QK_K == 0); const int nb = k / QK_K; - float db[4]; - uint16_t idx[4]; - //const int8_t * grid[4]; - for (int i = 0; i < nb; i++) { const float d = GGML_FP16_TO_FP32(x[i].d); - const uint8_t * sc = x[i].scales; - const uint8_t * qs = x[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; - for (int i8 = 0; i8 < QK_K/8; i8 += 4) { - idx[0] = qs[0] | ((sc[0] & 0x08) << 5); - idx[1] = qs[1] | ((sc[0] & 0x80) << 1); - idx[2] = qs[2] | ((sc[1] & 0x08) << 5); - idx[3] = qs[3] | ((sc[1] & 0x80) << 1); - //grid[0] = (const int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5))); - //grid[1] = (const int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1))); - //grid[2] = (const int8_t *)(iq1s_grid + (qs[2] | ((sc[1] & 0x08) << 5))); - //grid[3] = (const int8_t *)(iq1s_grid + (qs[3] | ((sc[1] & 0x80) << 1))); - db[0] = d * (2*(sc[0] & 7) + 1); - db[1] = d * (2*((sc[0] >> 4) & 7) + 1); - db[2] = d * (2*(sc[1] & 7) + 1); - db[3] = d * (2*((sc[1] >> 4) & 7) + 1); + for (int ib = 0; ib < QK_K/32; ++ib) { + const float dl = d * (2*(qh[ib] >> 12) + 1); for (int l = 0; l < 4; ++l) { - const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]); + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8))); for (int j = 0; j < 8; ++j) { - //y[j] = db[l] * grid[l][j]; - y[j] = db[l] * grid[j]; + y[j] = dl * grid[j]; } y += 8; } qs += 4; - sc += 2; } } } @@ -9587,113 +9570,72 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void const int nb = n / QK_K; - // TODO: implement for QK_K = 64 -#if defined __ARM_NEON && QK_K == 256 +#if defined __ARM_NEON - const uint8x16_t m8 = vdupq_n_u8(0x08); - const uint8x16_t m7 = vdupq_n_u8(0x07); - const uint8x16_t m1 = vdupq_n_u8(0x01); - const int32x4_t vzero = vdupq_n_s32(0); - - uint16_t gindex[8]; - uint16x8x2_t vindex; - int8x16x4_t q1b; + ggml_int8x16x4_t q1b; ggml_int8x16x4_t q8b; - uint16x8x4_t scales; - int32x4x2_t sumi; - int32x4x2_t dotq; float sumf = 0; for (int i = 0; i < nb; ++i) { - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * sc = x[i].scales; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; - sumi.val[0] = sumi.val[1] = vzero; + int sumi1 = 0, sumi2 = 0; - for (int i128 = 0; i128 < QK_K/128; ++i128) { - const uint8x16_t ql = vld1q_u8(qs); qs += 16; - const uint8x8_t tm1 = vld1_u8 (sc); sc += 8; - const uint8x8_t tm2 = vshr_n_u8(tm1, 4); - const uint8x16_t qh = vcombine_u8(vzip1_u8(tm1, tm2), vzip2_u8(tm1, tm2)); - const uint8x16_t hbit = vandq_u8(qh, m8); - vindex.val[0] = vorrq_u16(vmovl_u8(vget_low_u8 (ql)), vshlq_n_u16(vmovl_u8(vget_low_u8 (hbit)), 5)); - vindex.val[1] = vorrq_u16(vmovl_u8(vget_high_u8(ql)), vshlq_n_u16(vmovl_u8(vget_high_u8(hbit)), 5)); - const uint8x16_t scales8 = vorrq_u8(vshlq_n_u8(vandq_u8(qh, m7), 1), m1); - scales.val[0] = vmovl_u8(vget_low_u8 (scales8)); - scales.val[1] = vmovl_u8(vget_high_u8 (scales8)); + for (int ib = 0; ib < QK_K/32; ib += 2) { - for (int l = 0; l < 2; ++l) { - vst1q_u16(gindex+0, vindex.val[l]); - q1b.val[0] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[0])), vld1_s8((const void *)(iq1s_grid+gindex[1]))); - q1b.val[1] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[2])), vld1_s8((const void *)(iq1s_grid+gindex[3]))); - q1b.val[2] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[4])), vld1_s8((const void *)(iq1s_grid+gindex[5]))); - q1b.val[3] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[6])), vld1_s8((const void *)(iq1s_grid+gindex[7]))); - q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700))))); + q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700))))); + q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700))))); + q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700))))); + qs += 8; - dotq.val[0] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(vzero, q1b.val[1], q8b.val[1])); - dotq.val[1] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(vzero, q1b.val[3], q8b.val[3])); + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]); + + sumi1 += vaddvq_s32(p1) * (2*(qh[ib+0] >> 12) + 1); + sumi2 += vaddvq_s32(p2) * (2*(qh[ib+1] >> 12) + 1); - sumi.val[0] = vmlaq_s32(sumi.val[0], dotq.val[0], vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales.val[l])))); - sumi.val[1] = vmlaq_s32(sumi.val[1], dotq.val[1], vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales.val[l])))); - } } - sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * vaddvq_s32(vaddq_s32(sumi.val[0], sumi.val[1])); + sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2); } *s = sumf; - // TODO: implement for QK_K = 64 -#elif defined __AVX2__ && QK_K == 256 - - const __m128i m8 = _mm_set1_epi8(0x08); - const __m128i m7 = _mm_set1_epi8(0x07); - const __m128i m1 = _mm_set1_epi8(0x01); - const __m128i shuffle_h = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0); - const __m128i shuffle_s[4] = { - _mm_set_epi32(0x03030303, 0x02020202, 0x01010101, 0x00000000), - _mm_set_epi32(0x07070707, 0x06060606, 0x05050505, 0x04040404), - _mm_set_epi32(0x0b0b0b0b, 0x0a0a0a0a, 0x09090909, 0x08080808), - _mm_set_epi32(0x0f0f0f0f, 0x0e0e0e0e, 0x0d0d0d0d, 0x0c0c0c0c) - }; - - uint64_t aux64; - - typedef union m256i_uint16 { - __m256i reg; - uint16_t s[16]; - } m256i_uint16_t; - - m256i_uint16_t v_gindex; +#elif defined __AVX2__ __m256 accum = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * sc = x[i].scales; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; __m256i sumi = _mm256_setzero_si256(); - for (int i128 = 0; i128 < QK_K/128; ++i128) { - const __m128i ql = _mm_loadu_si128((const __m128i*)qs); qs += 16; - memcpy(&aux64, sc, 8); sc += 8; - const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h); - const __m256i hbit = _mm256_cvtepu8_epi16(_mm_and_si128(qh, m8)); - v_gindex.reg = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5)); - const __m128i scales = _mm_or_si128(_mm_slli_epi16(_mm_and_si128(qh, m7), 1), m1); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], + iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]); + const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], + iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]); + qs += 8; + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - for (int i32 = 0; i32 < 4; ++i32) { - const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q1b = _mm256_set_epi64x(iq1s_grid[v_gindex.s[4*i32+3]], iq1s_grid[v_gindex.s[4*i32+2]], - iq1s_grid[v_gindex.s[4*i32+1]], iq1s_grid[v_gindex.s[4*i32+0]]); - const __m256i dot = mul_add_epi8(q1b, q8b); - const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32])); - const __m256i p = _mm256_madd_epi16(s16, dot); - sumi = _mm256_add_epi32(sumi, p); - } + const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); + const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*(qh[ib+0] >> 12) + 1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*(qh[ib+1] >> 12) + 1)); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2)); } accum = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)), _mm256_cvtepi32_ps(sumi), accum); @@ -9704,35 +9646,26 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void #else - int db[4]; - uint16_t idx[4]; - float sumf = 0; - for (int i = 0; i < nb; ++i) { + for (int i = 0; i < nb; i++) { - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * sc = x[i].scales; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; int sumi = 0; - for (int i32 = 0; i32 < QK_K/32; ++i32) { - idx[0] = qs[0] | ((sc[0] & 0x08) << 5); - idx[1] = qs[1] | ((sc[0] & 0x80) << 1); - idx[2] = qs[2] | ((sc[1] & 0x08) << 5); - idx[3] = qs[3] | ((sc[1] & 0x80) << 1); - db[0] = (2*(sc[0] & 7) + 1); - db[1] = (2*((sc[0] >> 4) & 7) + 1); - db[2] = (2*(sc[1] & 7) + 1); - db[3] = (2*((sc[1] >> 4) & 7) + 1); + for (int ib = 0; ib < QK_K/32; ++ib) { + const int ls = 2*(qh[ib] >> 12) + 1; + int lsum = 0; for (int l = 0; l < 4; ++l) { - const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]); - int suml = 0; - for (int j = 0; j < 8; ++j) suml += q8[j] * grid[j]; - sumi += db[l] * suml; + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8))); + for (int j = 0; j < 8; ++j) { + lsum += q8[j] * grid[j]; + } q8 += 8; } + sumi += ls * lsum; qs += 4; - sc += 2; } sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * sumi; @@ -9996,7 +9929,7 @@ static inline int iq2_grid_size(enum ggml_type type) { GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ2_S); return type == GGML_TYPE_IQ2_XXS ? 256 : type == GGML_TYPE_IQ2_XS ? 512 : - type == GGML_TYPE_IQ1_S ? 512 : 1024; + type == GGML_TYPE_IQ1_S ? NGRID_IQ1S : 1024; } static int iq2_compare_func(const void * left, const void * right) { @@ -10063,39 +9996,135 @@ void iq2xs_init_impl(enum ggml_type type) { 40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048, 42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690, }; - static const uint16_t kgrid_1bit_512[512] = { - 10, 33, 41, 85, 132, 134, 160, 162, 277, 337, 340, 345, 357, 405, 516, 545, - 553, 598, 641, 650, 681, 1042, 1044, 1097, 1169, 1176, 1320, 1345, 1365, 1378, 1434, 1444, - 1545, 1617, 1642, 1685, 2053, 2080, 2089, 2133, 2176, 2182, 2208, 2214, 2306, 2384, 2393, 2440, - 2453, 2581, 2664, 2690, 2721, 4117, 4161, 4182, 4184, 4261, 4357, 4369, 4372, 4377, 4390, 4422, - 4432, 4437, 4449, 4457, 4485, 4497, 4505, 4629, 4677, 4696, 4774, 5205, 5217, 5225, 5386, 5397, - 5409, 5445, 5457, 5460, 5461, 5462, 5465, 5472, 5477, 5525, 5545, 5650, 5668, 5717, 5729, 5769, - 5777, 6212, 6234, 6244, 6293, 6424, 6482, 6485, 6502, 6505, 6529, 6538, 6565, 6656, 6682, 6788, - 6806, 6820, 8218, 8224, 8226, 8232, 8277, 8326, 8354, 8469, 8521, 8530, 8549, 8596, 8737, 8794, - 9221, 9253, 9348, 9369, 9380, 9474, 9557, 9633, 9732, 9753, 9793, 9830, 9862, 9880, 10240, 10272, - 10282, 10321, 10406, 10517, 10530, 10566, 10585, 10645, 10896, 16466, 16468, 16473, 16485, 16646, 16660, 16665, - 16725, 16793, 16806, 16914, 16969, 16977, 16996, 17028, 17057, 17408, 17416, 17434, 17493, 17512, 17578, 17685, - 17696, 17733, 17745, 17748, 17749, 17750, 17753, 17765, 17794, 17813, 17946, 17984, 18005, 18072, 18453, 18529, - 18569, 18722, 18756, 18762, 18773, 18794, 18833, 18853, 18945, 19026, 19033, 19077, 20489, 20497, 20500, 20517, - 20565, 20586, 20610, 20633, 20757, 20769, 20776, 20805, 20817, 20820, 20821, 20822, 20825, 20837, 20864, 20872, - 20885, 20896, 21002, 21029, 21077, 21146, 21510, 21525, 21573, 21585, 21588, 21589, 21590, 21593, 21605, 21653, - 21665, 21765, 21777, 21780, 21781, 21782, 21785, 21797, 21825, 21828, 21829, 21830, 21833, 21840, 21841, 21842, - 21844, 21846, 21848, 21849, 21850, 21857, 21860, 21861, 21862, 21865, 21893, 21905, 21908, 21909, 21910, 21913, - 21925, 22024, 22037, 22085, 22097, 22100, 22101, 22102, 22105, 22117, 22165, 22545, 22566, 22568, 22594, 22608, - 22613, 22676, 22697, 22793, 22805, 22853, 22865, 22868, 22869, 22870, 22873, 22885, 22933, 22946, 23046, 23072, - 23125, 23209, 24597, 24640, 24665, 24673, 24725, 24833, 24840, 24869, 24917, 24934, 24965, 25001, 25108, 25110, - 25152, 25184, 25192, 25234, 25616, 25618, 25625, 25685, 25704, 25738, 25744, 25770, 25877, 25897, 25925, 25937, - 25940, 25941, 25942, 25945, 25957, 25986, 26005, 26186, 26197, 26276, 26632, 26634, 26725, 26757, 26770, 26885, - 26965, 26976, 26986, 27032, 27153, 27174, 27200, 27208, 27240, 27269, 27282, 27290, 32778, 32800, 32802, 32808, - 32810, 32853, 32904, 32922, 32930, 32932, 33105, 33110, 33112, 33125, 33157, 33280, 33288, 33301, 33312, 33320, - 33424, 33797, 33829, 33858, 34068, 34133, 34146, 34176, 34217, 34306, 34342, 34441, 34454, 34468, 34832, 34918, - 34965, 34984, 35094, 35137, 35161, 35208, 35232, 35332, 35338, 35368, 35429, 36932, 36934, 36953, 37009, 37125, - 37136, 37138, 37145, 37157, 37205, 37220, 37258, 37290, 37444, 37446, 37465, 37478, 37525, 37905, 37968, 37973, - 38040, 38054, 38145, 38154, 38165, 38180, 38186, 38213, 38225, 38228, 38229, 38230, 38233, 38245, 38293, 38485, - 38504, 38530, 38938, 38985, 38993, 39012, 39040, 39173, 39192, 39253, 39265, 39301, 39316, 39322, 39442, 39497, - 39504, 39590, 40970, 40984, 40992, 41002, 41045, 41120, 41128, 41237, 41289, 41297, 41317, 41364, 41366, 41514, - 41557, 41633, 41989, 42021, 42056, 42068, 42074, 42113, 42242, 42265, 42274, 42325, 42340, 42402, 42501, 42512, - 42533, 42624, 42632, 42666, 43040, 43093, 43106, 43168, 43176, 43264, 43286, 43345, 43429, 43590, 43618, 43680, + static const uint16_t kgrid_1bit_2048[NGRID_IQ1S] = { + 0, 2, 5, 8, 10, 17, 21, 32, 34, 40, 42, 69, 81, 84, 86, 101, + 128, 130, 136, 138, 149, 160, 162, 168, 170, 260, 261, 273, 276, 278, 281, 282, + 293, 321, 326, 329, 338, 341, 346, 353, 356, 358, 360, 389, 401, 404, 406, 421, + 512, 514, 520, 522, 533, 544, 546, 552, 554, 581, 593, 601, 612, 617, 640, 642, + 648, 650, 657, 661, 665, 672, 674, 680, 682, 1041, 1044, 1046, 1061, 1089, 1097, 1109, + 1114, 1124, 1125, 1169, 1177, 1189, 1281, 1284, 1285, 1286, 1301, 1304, 1306, 1321, 1344, 1349, + 1354, 1360, 1361, 1364, 1365, 1366, 1369, 1376, 1378, 1381, 1384, 1386, 1409, 1425, 1429, 1432, + 1434, 1441, 1444, 1445, 1446, 1449, 1556, 1561, 1601, 1604, 1616, 1618, 1621, 1624, 1632, 1633, + 1638, 1641, 1669, 1681, 1684, 1689, 2048, 2050, 2056, 2058, 2069, 2080, 2082, 2088, 2090, 2117, + 2129, 2134, 2149, 2176, 2178, 2184, 2186, 2197, 2208, 2210, 2216, 2218, 2309, 2321, 2324, 2329, + 2340, 2341, 2369, 2384, 2385, 2389, 2401, 2404, 2409, 2449, 2452, 2454, 2457, 2469, 2560, 2562, + 2568, 2570, 2581, 2592, 2594, 2600, 2602, 2629, 2641, 2649, 2657, 2661, 2688, 2690, 2693, 2696, + 2698, 2709, 2720, 2722, 2728, 2730, 4112, 4113, 4116, 4121, 4132, 4133, 4161, 4164, 4176, 4181, + 4184, 4193, 4196, 4197, 4201, 4241, 4244, 4246, 4257, 4261, 4353, 4356, 4358, 4361, 4368, 4370, + 4373, 4376, 4385, 4388, 4393, 4421, 4426, 4432, 4433, 4434, 4436, 4437, 4438, 4441, 4448, 4453, + 4484, 4498, 4501, 4513, 4516, 4625, 4628, 4630, 4645, 4672, 4678, 4681, 4690, 4693, 4696, 4698, + 4708, 4710, 4741, 4753, 4756, 4758, 4773, 5121, 5126, 5129, 5140, 5141, 5144, 5145, 5153, 5158, + 5185, 5189, 5190, 5192, 5194, 5201, 5204, 5205, 5206, 5209, 5218, 5221, 5224, 5252, 5257, 5264, + 5268, 5269, 5272, 5273, 5274, 5281, 5284, 5285, 5289, 5378, 5381, 5386, 5393, 5396, 5397, 5398, + 5401, 5408, 5410, 5413, 5416, 5418, 5441, 5444, 5445, 5446, 5457, 5458, 5460, 5461, 5462, 5465, + 5466, 5473, 5476, 5477, 5478, 5481, 5504, 5506, 5508, 5509, 5512, 5514, 5520, 5521, 5524, 5525, + 5526, 5529, 5530, 5536, 5538, 5541, 5633, 5636, 5637, 5638, 5653, 5654, 5656, 5658, 5665, 5670, + 5696, 5698, 5700, 5701, 5704, 5706, 5713, 5717, 5718, 5720, 5721, 5729, 5732, 5733, 5736, 5737, + 5738, 5766, 5770, 5778, 5781, 5796, 5801, 6161, 6166, 6181, 6209, 6212, 6214, 6217, 6224, 6229, + 6232, 6234, 6240, 6241, 6244, 6246, 6249, 6277, 6289, 6292, 6309, 6416, 6418, 6421, 6426, 6433, + 6437, 6466, 6468, 6469, 6472, 6481, 6484, 6485, 6486, 6489, 6490, 6496, 6501, 6506, 6537, 6545, + 6546, 6549, 6552, 6561, 6566, 6569, 6665, 6678, 6692, 6694, 6724, 6726, 6729, 6736, 6738, 6741, + 6744, 6753, 6758, 6761, 6789, 6801, 6806, 6810, 8192, 8194, 8200, 8202, 8213, 8224, 8226, 8229, + 8232, 8234, 8261, 8273, 8281, 8289, 8293, 8320, 8322, 8328, 8330, 8341, 8352, 8354, 8357, 8360, + 8362, 8453, 8465, 8468, 8473, 8485, 8514, 8516, 8521, 8533, 8536, 8538, 8545, 8548, 8549, 8550, + 8581, 8592, 8598, 8601, 8613, 8705, 8712, 8714, 8721, 8725, 8736, 8738, 8744, 8746, 8773, 8785, + 8790, 8793, 8805, 8833, 8840, 8842, 8849, 8853, 8864, 8866, 8872, 8874, 9221, 9236, 9238, 9241, + 9253, 9284, 9285, 9286, 9289, 9298, 9301, 9304, 9306, 9318, 9349, 9361, 9364, 9369, 9377, 9381, + 9481, 9493, 9505, 9513, 9536, 9541, 9544, 9553, 9556, 9557, 9561, 9570, 9573, 9576, 9609, 9616, + 9620, 9621, 9624, 9626, 9633, 9636, 9638, 9641, 9733, 9744, 9746, 9753, 9765, 9793, 9801, 9813, + 9824, 9825, 9833, 9860, 9862, 9872, 9882, 10240, 10242, 10248, 10250, 10261, 10272, 10274, 10280, 10282, + 10309, 10321, 10324, 10341, 10368, 10370, 10376, 10378, 10400, 10402, 10408, 10410, 10505, 10513, 10516, 10521, + 10533, 10566, 10569, 10578, 10581, 10593, 10596, 10598, 10601, 10629, 10640, 10646, 10649, 10660, 10661, 10752, + 10754, 10760, 10762, 10784, 10786, 10792, 10794, 10821, 10833, 10838, 10841, 10853, 10880, 10882, 10888, 10890, + 10901, 10912, 10914, 10920, 10922, 16389, 16401, 16406, 16421, 16457, 16466, 16469, 16472, 16474, 16481, 16484, + 16486, 16532, 16537, 16545, 16550, 16640, 16641, 16644, 16646, 16649, 16658, 16661, 16662, 16664, 16666, 16673, + 16678, 16681, 16709, 16712, 16714, 16721, 16724, 16725, 16726, 16729, 16730, 16741, 16744, 16746, 16769, 16772, + 16774, 16784, 16786, 16789, 16800, 16801, 16802, 16901, 16913, 16916, 16918, 16933, 16961, 16978, 16981, 16986, + 16996, 17001, 17033, 17044, 17061, 17409, 17429, 17433, 17449, 17477, 17480, 17482, 17489, 17492, 17493, 17494, + 17505, 17506, 17509, 17512, 17514, 17537, 17542, 17545, 17552, 17554, 17557, 17568, 17569, 17577, 17665, 17666, + 17669, 17674, 17681, 17684, 17685, 17686, 17689, 17696, 17701, 17706, 17729, 17732, 17733, 17734, 17737, 17744, + 17745, 17748, 17749, 17750, 17752, 17753, 17761, 17764, 17765, 17766, 17769, 17794, 17796, 17797, 17800, 17809, + 17812, 17813, 17814, 17817, 17818, 17829, 17832, 17834, 17921, 17925, 17929, 17940, 17941, 17944, 17946, 17953, + 17956, 17961, 17984, 17986, 17989, 17992, 18000, 18001, 18002, 18005, 18006, 18009, 18018, 18021, 18024, 18049, + 18053, 18058, 18068, 18069, 18081, 18084, 18086, 18437, 18449, 18453, 18458, 18469, 18498, 18505, 18512, 18517, + 18520, 18529, 18532, 18534, 18537, 18565, 18577, 18580, 18582, 18585, 18597, 18689, 18693, 18694, 18698, 18704, + 18708, 18709, 18712, 18721, 18724, 18726, 18752, 18757, 18762, 18769, 18770, 18772, 18773, 18774, 18777, 18784, + 18786, 18789, 18790, 18794, 18822, 18825, 18834, 18837, 18838, 18840, 18849, 18852, 18854, 18857, 18966, 19012, + 19014, 19017, 19029, 19032, 19034, 19044, 19049, 19092, 19109, 20481, 20484, 20485, 20486, 20489, 20498, 20501, + 20506, 20513, 20516, 20521, 20544, 20549, 20552, 20561, 20564, 20565, 20566, 20569, 20581, 20584, 20614, 20617, + 20629, 20632, 20640, 20641, 20646, 20649, 20741, 20744, 20745, 20746, 20753, 20756, 20757, 20758, 20760, 20761, + 20768, 20773, 20774, 20776, 20778, 20801, 20804, 20805, 20806, 20809, 20816, 20817, 20818, 20820, 20821, 20822, + 20824, 20825, 20826, 20833, 20836, 20837, 20838, 20841, 20866, 20869, 20881, 20884, 20885, 20886, 20889, 20896, + 20901, 20906, 20993, 20998, 21010, 21013, 21018, 21025, 21028, 21058, 21061, 21066, 21073, 21076, 21077, 21078, + 21081, 21090, 21093, 21125, 21136, 21138, 21141, 21145, 21146, 21156, 21508, 21509, 21521, 21524, 21525, 21526, + 21528, 21529, 21537, 21541, 21544, 21546, 21569, 21572, 21573, 21574, 21577, 21578, 21584, 21585, 21588, 21589, + 21590, 21592, 21593, 21594, 21601, 21602, 21604, 21605, 21606, 21609, 21632, 21640, 21642, 21649, 21652, 21653, + 21654, 21657, 21665, 21668, 21669, 21674, 21761, 21762, 21764, 21765, 21766, 21769, 21776, 21777, 21778, 21780, + 21781, 21782, 21785, 21786, 21793, 21796, 21797, 21798, 21801, 21824, 21825, 21826, 21828, 21829, 21830, 21832, + 21833, 21840, 21841, 21842, 21844, 21845, 21846, 21848, 21849, 21850, 21856, 21857, 21860, 21861, 21862, 21864, + 21865, 21866, 21889, 21892, 21893, 21897, 21898, 21904, 21905, 21908, 21909, 21910, 21912, 21913, 21921, 21924, + 21925, 21926, 21929, 22016, 22017, 22018, 22020, 22022, 22024, 22025, 22033, 22036, 22037, 22040, 22041, 22048, + 22049, 22050, 22052, 22053, 22054, 22056, 22057, 22081, 22085, 22086, 22088, 22089, 22090, 22096, 22097, 22098, + 22100, 22101, 22102, 22104, 22105, 22106, 22113, 22116, 22117, 22121, 22146, 22149, 22150, 22152, 22153, 22154, + 22161, 22165, 22170, 22178, 22181, 22182, 22184, 22185, 22532, 22533, 22534, 22537, 22544, 22549, 22552, 22561, + 22570, 22597, 22600, 22602, 22609, 22612, 22613, 22614, 22616, 22617, 22624, 22626, 22628, 22629, 22658, 22665, + 22672, 22674, 22677, 22680, 22689, 22697, 22785, 22786, 22789, 22794, 22801, 22804, 22805, 22806, 22809, 22821, + 22849, 22852, 22853, 22854, 22857, 22864, 22865, 22866, 22868, 22869, 22870, 22872, 22873, 22874, 22881, 22884, + 22885, 22886, 22889, 22913, 22917, 22921, 22929, 22932, 22933, 22934, 22936, 22937, 22949, 23044, 23048, 23061, + 23066, 23072, 23077, 23078, 23081, 23109, 23112, 23113, 23121, 23125, 23126, 23128, 23129, 23138, 23141, 23144, + 23146, 23169, 23178, 23186, 23189, 23190, 23192, 23194, 23201, 24581, 24596, 24598, 24601, 24613, 24644, 24656, + 24661, 24662, 24664, 24666, 24673, 24676, 24678, 24681, 24705, 24726, 24741, 24833, 24836, 24838, 24841, 24850, + 24853, 24865, 24866, 24870, 24873, 24901, 24905, 24913, 24917, 24918, 24921, 24933, 24934, 24938, 24964, 24970, + 24978, 24981, 24993, 24998, 25001, 25105, 25110, 25113, 25152, 25153, 25158, 25173, 25174, 25176, 25184, 25221, + 25233, 25238, 25253, 25617, 25618, 25621, 25622, 25626, 25633, 25638, 25641, 25664, 25666, 25669, 25672, 25674, + 25681, 25684, 25685, 25686, 25689, 25690, 25696, 25698, 25701, 25732, 25733, 25737, 25744, 25746, 25748, 25749, + 25750, 25752, 25754, 25761, 25764, 25769, 25861, 25864, 25866, 25873, 25877, 25878, 25881, 25924, 25925, 25926, + 25929, 25936, 25937, 25940, 25941, 25942, 25945, 25953, 25956, 25957, 25958, 25961, 25990, 25993, 25994, 26001, + 26005, 26006, 26009, 26010, 26018, 26021, 26022, 26024, 26114, 26121, 26133, 26144, 26150, 26152, 26153, 26176, + 26181, 26184, 26186, 26193, 26196, 26197, 26198, 26200, 26202, 26208, 26213, 26216, 26240, 26242, 26245, 26250, + 26260, 26262, 26264, 26265, 26272, 26276, 26278, 26282, 26646, 26649, 26661, 26689, 26706, 26709, 26714, 26721, + 26729, 26757, 26769, 26776, 26790, 26881, 26884, 26896, 26901, 26913, 26916, 26918, 26921, 26944, 26945, 26949, + 26950, 26952, 26961, 26964, 26965, 26966, 26969, 26976, 26981, 26986, 27010, 27012, 27018, 27029, 27041, 27044, + 27045, 27049, 27153, 27158, 27160, 27201, 27204, 27209, 27216, 27221, 27224, 27226, 27236, 27237, 27241, 27270, + 27284, 27288, 27290, 27302, 32768, 32770, 32776, 32778, 32800, 32802, 32808, 32810, 32837, 32848, 32849, 32852, + 32854, 32857, 32869, 32896, 32898, 32904, 32906, 32917, 32928, 32930, 32936, 32938, 33029, 33041, 33044, 33046, + 33049, 33061, 33089, 33092, 33097, 33104, 33106, 33109, 33110, 33112, 33113, 33124, 33126, 33129, 33157, 33161, + 33172, 33174, 33177, 33189, 33280, 33282, 33288, 33290, 33301, 33312, 33314, 33320, 33322, 33361, 33364, 33369, + 33381, 33408, 33410, 33416, 33418, 33429, 33440, 33442, 33448, 33450, 33812, 33817, 33857, 33860, 33873, 33877, + 33882, 33889, 33892, 33897, 33940, 33945, 34049, 34057, 34066, 34069, 34074, 34086, 34089, 34112, 34113, 34117, + 34120, 34129, 34132, 34133, 34134, 34137, 34138, 34149, 34150, 34152, 34154, 34177, 34180, 34182, 34185, 34192, + 34194, 34197, 34200, 34214, 34321, 34326, 34329, 34341, 34369, 34372, 34377, 34378, 34384, 34389, 34393, 34394, + 34401, 34406, 34410, 34437, 34449, 34458, 34468, 34816, 34818, 34824, 34826, 34837, 34848, 34850, 34856, 34858, + 34881, 34885, 34897, 34900, 34905, 34917, 34921, 34944, 34946, 34952, 34954, 34965, 34976, 34978, 34984, 34986, + 35077, 35078, 35089, 35092, 35094, 35109, 35137, 35140, 35142, 35145, 35152, 35154, 35157, 35162, 35169, 35172, + 35205, 35222, 35225, 35237, 35328, 35330, 35336, 35338, 35349, 35360, 35362, 35368, 35370, 35397, 35409, 35412, + 35414, 35456, 35458, 35464, 35466, 35477, 35488, 35490, 35496, 35498, 36869, 36881, 36886, 36888, 36889, 36901, + 36929, 36934, 36937, 36949, 36952, 36954, 36969, 36970, 36997, 37009, 37012, 37014, 37017, 37029, 37121, 37124, + 37126, 37129, 37136, 37141, 37144, 37146, 37153, 37156, 37158, 37161, 37184, 37189, 37200, 37201, 37204, 37205, + 37206, 37209, 37218, 37221, 37252, 37254, 37266, 37269, 37272, 37281, 37284, 37286, 37289, 37381, 37393, 37396, + 37401, 37413, 37444, 37446, 37449, 37456, 37458, 37461, 37464, 37478, 37481, 37509, 37524, 37526, 37545, 37889, + 37892, 37894, 37904, 37909, 37912, 37926, 37952, 37962, 37969, 37972, 37973, 37974, 37976, 37977, 37984, 37985, + 37986, 37989, 38020, 38022, 38034, 38036, 38037, 38040, 38049, 38057, 38144, 38149, 38152, 38154, 38160, 38161, + 38164, 38165, 38166, 38169, 38177, 38181, 38185, 38186, 38209, 38212, 38213, 38214, 38217, 38224, 38225, 38226, + 38228, 38229, 38230, 38232, 38233, 38234, 38241, 38244, 38245, 38246, 38249, 38273, 38277, 38280, 38289, 38290, + 38292, 38293, 38294, 38297, 38298, 38304, 38306, 38309, 38312, 38314, 38401, 38404, 38416, 38421, 38425, 38432, + 38438, 38441, 38469, 38472, 38473, 38481, 38482, 38485, 38486, 38489, 38501, 38504, 38530, 38532, 38537, 38538, + 38546, 38548, 38549, 38564, 38566, 38569, 38917, 38934, 38937, 38949, 38977, 38982, 38992, 38994, 38997, 38998, + 39002, 39012, 39013, 39045, 39057, 39062, 39065, 39077, 39172, 39174, 39177, 39184, 39186, 39189, 39192, 39194, + 39200, 39201, 39204, 39206, 39232, 39234, 39237, 39240, 39242, 39249, 39252, 39253, 39254, 39257, 39266, 39269, + 39270, 39274, 39297, 39300, 39312, 39314, 39317, 39322, 39329, 39334, 39429, 39445, 39461, 39492, 39494, 39497, + 39504, 39509, 39512, 39521, 39557, 39569, 39572, 39573, 39574, 40960, 40962, 40968, 40970, 40981, 40992, 40994, + 41000, 41002, 41029, 41041, 41044, 41046, 41049, 41088, 41090, 41096, 41098, 41109, 41120, 41122, 41128, 41130, + 41221, 41225, 41233, 41236, 41238, 41241, 41242, 41286, 41289, 41297, 41301, 41304, 41306, 41313, 41316, 41349, + 41360, 41362, 41366, 41369, 41474, 41480, 41482, 41488, 41497, 41506, 41512, 41514, 41541, 41553, 41558, 41561, + 41573, 41600, 41602, 41608, 41610, 41621, 41632, 41634, 41640, 41642, 42009, 42021, 42049, 42052, 42064, 42068, + 42069, 42072, 42074, 42081, 42085, 42086, 42088, 42089, 42117, 42246, 42249, 42256, 42258, 42261, 42264, 42278, + 42281, 42306, 42309, 42321, 42324, 42325, 42326, 42329, 42341, 42346, 42369, 42372, 42373, 42374, 42377, 42386, + 42389, 42392, 42501, 42513, 42518, 42522, 42529, 42533, 42564, 42566, 42570, 42578, 42581, 42582, 42584, 42592, + 42594, 42630, 42640, 42645, 42646, 42649, 42657, 42660, 42662, 43008, 43010, 43016, 43018, 43040, 43042, 43048, + 43050, 43089, 43092, 43094, 43097, 43136, 43138, 43144, 43146, 43157, 43168, 43170, 43176, 43178, 43269, 43284, + 43289, 43297, 43301, 43329, 43344, 43349, 43354, 43361, 43366, 43369, 43408, 43414, 43520, 43522, 43528, 43530, + 43552, 43554, 43560, 43562, 43601, 43604, 43606, 43648, 43650, 43656, 43658, 43669, 43680, 43682, 43688, 43690, }; static const uint16_t kgrid_2bit_1024[1024] = { 0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70, @@ -10169,7 +10198,7 @@ void iq2xs_init_impl(enum ggml_type type) { const int nwant = type == GGML_TYPE_IQ1_S ? 3 : type == GGML_TYPE_IQ2_S ? 1 : 2; const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 : type == GGML_TYPE_IQ2_XS ? kgrid_2bit_512 : - type == GGML_TYPE_IQ1_S ? kgrid_1bit_512 : kgrid_2bit_1024; + type == GGML_TYPE_IQ1_S ? kgrid_1bit_2048 : kgrid_2bit_1024; uint64_t * kgrid_q2xs; int * kmap_q2xs; uint16_t * kneighbors_q2xs; @@ -11408,12 +11437,70 @@ static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const u return grid_index; } +static int iq1_find_best_neighbour2(const uint16_t * restrict neighbours, const uint64_t * restrict grid, + const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L, int ngrid) { + int num_neighbors = neighbours[0]; + GGML_ASSERT(num_neighbors > 0); + float best_score = FLT_MAX; + int grid_index = -1; + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float d2 = 0; + for (int i = 0; i < 8; ++i) { + float q = (pg[i] - 3)/2; + float w = weight[i]; + float diff = scale*q - xval[i]; + d2 += w*diff*diff; + } + if (d2 < best_score) { + best_score = d2; + grid_index = neighbours[j]; + } + } + if (grid_index < 0) { + for (int i = 0; i < ngrid; ++i) { + const int8_t * grid_i = (const int8_t *)(grid + i); + float d2 = 0; + for (int j = 0; j < 8; ++j) { + float w = weight[j]; + float q = (grid_i[j] - 3)/2; + float diff = scale*q - xval[i]; + d2 += w*diff*diff; + } + if (d2 < best_score) { + best_score = d2; + grid_index = i; + } + } + } + if (grid_index < 0) { + printf("Oops, did not find grid point\n"); + printf("Have %d neighbours\n", num_neighbors); + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 8; ++i) { + float q = (pg[i] - 3)/2; + float w = weight[i]; + sumqx += w*q*xval[i]; + sumq2 += w*q*q; + } + printf(" neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2); + } + } + GGML_ASSERT(grid_index >= 0); + const int8_t * pg = (const int8_t *)(grid + grid_index); + for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2; + return grid_index; +} + static int iq1_sort_helper(const void * left, const void * right) { const float * l = left; const float * r = right; return *l < *r ? -1 : *l > *r ? 1 : 0; } +#define IQ1S_BLOCK_SIZE 32 static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) { const int gindex = iq2_data_index(GGML_TYPE_IQ1_S); @@ -11432,37 +11519,37 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy block_iq1_s * y = vy; - float scales[QK_K/8]; - float weight[8]; - int8_t L[8]; - float sumx[9]; - float sumw[9]; - float pairs[16]; + float scales[QK_K/IQ1S_BLOCK_SIZE]; + float weight[IQ1S_BLOCK_SIZE]; + int8_t L[IQ1S_BLOCK_SIZE]; + float sumx[IQ1S_BLOCK_SIZE+1]; + float sumw[IQ1S_BLOCK_SIZE+1]; + float pairs[2*IQ1S_BLOCK_SIZE]; int * idx = (int *)(pairs + 1); - uint8_t hbit[QK_K/8]; + uint16_t index[IQ1S_BLOCK_SIZE/8]; for (int ibl = 0; ibl < nbl; ++ibl) { y[ibl].d = GGML_FP32_TO_FP16(0.f); memset(y[ibl].qs, 0, QK_K/8); - memset(y[ibl].scales, 0, QK_K/16); + memset(y[ibl].qh, 0, QK_K/16); float max_scale = 0; const float * xbl = x + QK_K*ibl; float sumx2 = 0; for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; - float sigma2 = sumx2/QK_K; + float sigma2 = 2*sumx2/QK_K; - for (int ib = 0; ib < QK_K/8; ++ib) { - const float * xb = xbl + 8*ib; - const float * qw = quant_weights + QK_K*ibl + 8*ib; - for (int i = 0; i < 8; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + for (int ib = 0; ib < QK_K/IQ1S_BLOCK_SIZE; ++ib) { + const float * xb = xbl + IQ1S_BLOCK_SIZE*ib; + const float * qw = quant_weights + QK_K*ibl + IQ1S_BLOCK_SIZE*ib; + for (int i = 0; i < IQ1S_BLOCK_SIZE; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); float max = fabsf(xb[0]); - for (int i = 1; i < 8; ++i) max = MAX(max, fabsf(xb[i])); + for (int i = 1; i < IQ1S_BLOCK_SIZE; ++i) max = MAX(max, fabsf(xb[i])); if (!max) { scales[ib] = 0; - memset(L, 1, 8); + memset(L, 1, IQ1S_BLOCK_SIZE); continue; } // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem. @@ -11471,14 +11558,14 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale // for each possible and score for each split. - for (int j = 0; j < 8; ++j) { + for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) { pairs[2*j] = xb[j]; idx[2*j] = j; } - qsort(pairs, 8, 2*sizeof(float), iq1_sort_helper); + qsort(pairs, IQ1S_BLOCK_SIZE, 2*sizeof(float), iq1_sort_helper); { sumx[0] = sumw[0] = 0; - for (int j = 0; j < 8; ++j) { + for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) { int i = idx[2*j]; sumx[j+1] = sumx[j] + weight[i]*xb[i]; sumw[j+1] = sumw[j] + weight[i]; @@ -11486,10 +11573,10 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy } float best_score = 0, scale = max; int besti1 = 0, besti2 = 0; - for (int i1 = 0; i1 <= 8; ++i1) { - for (int i2 = i1; i2 <= 8; ++i2) { - float sumqx = -(sumx[i1] - sumx[0]) + (sumx[8] - sumx[i2]); - float sumq2 = (sumw[i1] - sumw[0]) + (sumw[8] - sumw[i2]); + for (int i1 = 0; i1 <= IQ1S_BLOCK_SIZE; ++i1) { + for (int i2 = i1; i2 <= IQ1S_BLOCK_SIZE; ++i2) { + float sumqx = -(sumx[i1] - sumx[0]) + (sumx[IQ1S_BLOCK_SIZE] - sumx[i2]); + float sumq2 = (sumw[i1] - sumw[0]) + (sumw[IQ1S_BLOCK_SIZE] - sumw[i2]); if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { scale = sumqx/sumq2; best_score = scale*sumqx; besti1 = i1; besti2 = i2; @@ -11498,23 +11585,43 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy } for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; - for (int j = besti2; j < 8; ++j) L[idx[2*j]] = 2; + for (int j = besti2; j < IQ1S_BLOCK_SIZE; ++j) L[idx[2*j]] = 2; if (scale < 0) { - for (int j = 0; j < 8; ++j) L[j] = 2 - L[j]; + for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) L[j] = 2 - L[j]; scale = -scale; } - // Now we check if the solution found above corresponds to a grid point and, if not, use a neighbouring - // grid point that minimizes SSD. - uint16_t u = 0; - for (int j = 0; j < 8; ++j) u |= (L[j] << 2*j); - int grid_index = kmap_q2xs[u]; - if (grid_index < 0) { - const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; - grid_index = iq1_find_best_neighbour(neighbours, kgrid_q2xs, xb, weight, &scale, L, NGRID_IQ2XXS); - GGML_ASSERT(grid_index >= 0); + bool all_on_grid = true; + for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) { + uint16_t u = 0; + for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + all_on_grid = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, L + 8*k, NGRID_IQ1S); + GGML_ASSERT(grid_index >= 0); + } + index[k] = grid_index; } - y[ibl].qs[ib] = grid_index & 255; - hbit[ib] = grid_index >> 8; + if (!all_on_grid) { + float sumqx = 0, sumq2 = 0; + for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]); + for (int j = 0; j < 8; ++j) { + float w = weight[8*k + j]; + float q = (pg[j] - 3)/2; + sumqx += w*q*xb[8*k+j]; + sumq2 += w*q*q; + } + } + if (sumqx > 0 && sumq2 > 0) scale = sumqx/sumq2; + } + uint16_t h = 0; + for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) { + y[ibl].qs[(IQ1S_BLOCK_SIZE/8)*ib + k] = index[k] & 255; + h |= (index[k] >> 8) << 3*k; + } + y[ibl].qh[ib] = h; GGML_ASSERT(scale >= 0); scales[ib] = scale; max_scale = MAX(max_scale, scale); @@ -11525,14 +11632,13 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy continue; } - float d = max_scale/15; - y[ibl].d = GGML_FP32_TO_FP16(d*1.085f); // 1.085f is another fudge factor. Don't ask me why it is needed. + float d = max_scale/31; + y[ibl].d = GGML_FP32_TO_FP16(d*1.125f); // 1.085f is another fudge factor. Don't ask me why it is needed. float id = 1/d; - for (int ib = 0; ib < QK_K/8; ++ib) { + for (int ib = 0; ib < QK_K/IQ1S_BLOCK_SIZE; ++ib) { int l = nearest_int(0.5f*(id*scales[ib]-1)); - l = MAX(0, MIN(7, l)); - if (hbit[ib]) l |= 8; - y[ibl].scales[ib/2] |= (l << 4*(ib%2)); + l = MAX(0, MIN(15, l)); + y[ibl].qh[ib] |= (l << 12); } } } diff --git a/ggml-quants.h b/ggml-quants.h index 47dd5285..74aabf41 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -217,8 +217,8 @@ static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N typedef struct { ggml_fp16_t d; - uint8_t qs[QK_K/8]; - uint8_t scales[QK_K/16]; + uint8_t qs[QK_K/8]; + uint16_t qh[QK_K/32]; } block_iq1_s; static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");