Add ability to use importance matrix for all k-quants (llama/4930)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow 2024-01-14 16:21:12 +02:00 committed by Georgi Gerganov
parent f6614155e4
commit f904b31a7d
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 461 additions and 15 deletions

View File

@ -1244,7 +1244,8 @@ static inline int nearest_int(float fval) {
return (i & 0x007fffff) - 0x00400000;
}
static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) {
static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type,
const float * restrict qw) {
float max = 0;
float amax = 0;
for (int i = 0; i < n; ++i) {
@ -1270,14 +1271,13 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
rmse_type = -rmse_type;
return_early = true;
}
int weight_type = rmse_type%2;
float sumlx = 0;
float suml2 = 0;
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale * x[i]);
l = MAX(-nmax, MIN(nmax-1, l));
L[i] = l + nmax;
float w = weight_type == 1 ? x[i] * x[i] : 1;
float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
sumlx += w*x[i]*l;
suml2 += w*l*l;
}
@ -1293,7 +1293,7 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale * x[i]);
l = MAX(-nmax, MIN(nmax-1, l));
float w = weight_type == 1 ? x[i] * x[i] : 1;
float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
sumlx += w*x[i]*l;
suml2 += w*l*l;
}
@ -2089,6 +2089,112 @@ size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n
return (n/QK_K*sizeof(block_q3_K));
}
static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int n_per_row, const float * restrict quant_weights) {
#if QK_K != 256
(void)quant_weights;
quantize_row_q3_K_reference(x, y, n_per_row);
#else
assert(n_per_row % QK_K == 0);
const int nb = n_per_row / QK_K;
int8_t L[QK_K];
float scales[QK_K / 16];
float weight[16];
float sw[QK_K / 16];
int8_t Ls[QK_K / 16];
for (int i = 0; i < nb; i++) {
float sumx2 = 0;
for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
float sigma2 = 2*sumx2/QK_K;
for (int j = 0; j < QK_K/16; ++j) {
if (quant_weights) {
const float * qw = quant_weights ? quant_weights + QK_K * i + 16*j : NULL;
for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j+l]*x[16*j+l]);
} else {
for (int l = 0; l < 16; ++l) weight[l] = x[16*j+l]*x[16*j+l];
}
float sumw = 0;
for (int l = 0; l < 16; ++l) sumw += weight[l];
sw[j] = sumw;
scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight);
}
memset(y[i].scales, 0, 12);
float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw);
for (int j = 0; j < QK_K/16; ++j) {
int l = Ls[j];
if (j < 8) {
y[i].scales[j] = l & 0xF;
} else {
y[i].scales[j-8] |= ((l & 0xF) << 4);
}
l >>= 4;
y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
}
y[i].d = GGML_FP32_TO_FP16(d_block);
int8_t sc;
for (int j = 0; j < QK_K/16; ++j) {
sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
float d = GGML_FP16_TO_FP32(y[i].d) * sc;
if (!d) {
continue;
}
for (int ii = 0; ii < 16; ++ii) {
int l = nearest_int(x[16*j + ii]/d);
l = MAX(-4, MIN(3, l));
L[16*j + ii] = l + 4;
}
}
memset(y[i].hmask, 0, QK_K/8);
// We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
int m = 0;
uint8_t hm = 1;
for (int j = 0; j < QK_K; ++j) {
if (L[j] > 3) {
y[i].hmask[m] |= hm;
L[j] -= 4;
}
if (++m == QK_K/8) {
m = 0; hm <<= 1;
}
}
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
}
}
x += QK_K;
}
#endif
}
size_t quantize_q3_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
(void)hist;
int row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
if (!quant_weights) {
quantize_row_q3_K_reference(src, dst, nrow*n_per_row);
}
else {
char * qrow = (char *)dst;
for (int row = 0; row < nrow; ++row) {
quantize_row_q3_K_impl(src, (block_q3_K*)qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += row_size;
}
}
return nrow * row_size;
}
// ====================== 4-bit (de)-quantization
void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
@ -2254,6 +2360,108 @@ size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n
return (n/QK_K*sizeof(block_q4_K));
}
static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int n_per_row, const float * quant_weights) {
#if QK_K != 256
(void)quant_weights;
quantize_row_q4_K_reference(x, y, n_per_row);
#else
assert(n_per_row % QK_K == 0);
const int nb = n_per_row / QK_K;
uint8_t L[QK_K];
uint8_t Laux[32];
float weights[32];
float mins[QK_K/32];
float scales[QK_K/32];
for (int i = 0; i < nb; i++) {
float sum_x2 = 0;
for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
float sigma2 = sum_x2/QK_K;
float av_x = sqrtf(sigma2);
float max_scale = 0; // as we are deducting the min, scales are always positive
float max_min = 0;
for (int j = 0; j < QK_K/32; ++j) {
if (quant_weights) {
const float * qw = quant_weights + QK_K*i + 32*j;
for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]);
} else {
for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
}
scales[j] = make_qkx3_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
//scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
float scale = scales[j];
if (scale > max_scale) {
max_scale = scale;
}
float min = mins[j];
if (min > max_min) {
max_min = min;
}
}
float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
for (int j = 0; j < QK_K/32; ++j) {
uint8_t ls = nearest_int(inv_scale*scales[j]);
uint8_t lm = nearest_int(inv_min*mins[j]);
ls = MIN(63, ls);
lm = MIN(63, lm);
if (j < 4) {
y[i].scales[j] = ls;
y[i].scales[j+4] = lm;
} else {
y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
y[i].scales[j-4] |= ((ls >> 4) << 6);
y[i].scales[j-0] |= ((lm >> 4) << 6);
}
}
y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
uint8_t sc, m;
for (int j = 0; j < QK_K/32; ++j) {
get_scale_min_k4(j, y[i].scales, &sc, &m);
const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
if (!d) continue;
const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
for (int ii = 0; ii < 32; ++ii) {
int l = nearest_int((x[32*j + ii] + dm)/d);
l = MAX(0, MIN(15, l));
L[32*j + ii] = l;
}
}
uint8_t * q = y[i].qs;
for (int j = 0; j < QK_K; j += 64) {
for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
q += 32;
}
x += QK_K;
}
#endif
}
size_t quantize_q4_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
(void)hist;
int row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
if (!quant_weights) {
quantize_row_q4_K_reference(src, dst, nrow*n_per_row);
}
else {
char * qrow = (char *)dst;
for (int row = 0; row < nrow; ++row) {
quantize_row_q4_K_impl(src, (block_q4_K*)qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += row_size;
}
}
return nrow * row_size;
}
// ====================== 5-bit (de)-quantization
void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
@ -2349,7 +2557,7 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
#else
float max_scale = 0, amax = 0;
for (int j = 0; j < QK_K/16; ++j) {
scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1, NULL);
float abs_scale = fabsf(scales[j]);
if (abs_scale > amax) {
amax = abs_scale;
@ -2460,6 +2668,123 @@ size_t ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n
return (n/QK_K*sizeof(block_q5_K));
}
static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int n_per_row, const float * quant_weights) {
#if QK_K != 256
(void)quant_weights;
quantize_row_q5_K_reference(x, y, n_per_row);
#else
assert(n_per_row % QK_K == 0);
const int nb = n_per_row / QK_K;
uint8_t L[QK_K];
float mins[QK_K/32];
float scales[QK_K/32];
float weights[32];
uint8_t Laux[32];
for (int i = 0; i < nb; i++) {
float sum_x2 = 0;
for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
float sigma2 = sum_x2/QK_K;
float av_x = sqrtf(sigma2);
float max_scale = 0; // as we are deducting the min, scales are always positive
float max_min = 0;
for (int j = 0; j < QK_K/32; ++j) {
if (quant_weights) {
const float * qw = quant_weights + QK_K*i + 32*j;
for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]);
} else {
for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
}
scales[j] = make_qkx3_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
float scale = scales[j];
if (scale > max_scale) {
max_scale = scale;
}
float min = mins[j];
if (min > max_min) {
max_min = min;
}
}
float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
for (int j = 0; j < QK_K/32; ++j) {
uint8_t ls = nearest_int(inv_scale*scales[j]);
uint8_t lm = nearest_int(inv_min*mins[j]);
ls = MIN(63, ls);
lm = MIN(63, lm);
if (j < 4) {
y[i].scales[j] = ls;
y[i].scales[j+4] = lm;
} else {
y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
y[i].scales[j-4] |= ((ls >> 4) << 6);
y[i].scales[j-0] |= ((lm >> 4) << 6);
}
}
y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
uint8_t sc, m;
for (int j = 0; j < QK_K/32; ++j) {
get_scale_min_k4(j, y[i].scales, &sc, &m);
const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
if (!d) continue;
const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
for (int ii = 0; ii < 32; ++ii) {
int l = nearest_int((x[32*j + ii] + dm)/d);
l = MAX(0, MIN(31, l));
L[32*j + ii] = l;
}
}
uint8_t * restrict qh = y[i].qh;
uint8_t * restrict ql = y[i].qs;
memset(qh, 0, QK_K/8);
uint8_t m1 = 1, m2 = 2;
for (int n = 0; n < QK_K; n += 64) {
for (int j = 0; j < 32; ++j) {
int l1 = L[n + j];
if (l1 > 15) {
l1 -= 16; qh[j] |= m1;
}
int l2 = L[n + j + 32];
if (l2 > 15) {
l2 -= 16; qh[j] |= m2;
}
ql[j] = l1 | (l2 << 4);
}
m1 <<= 2; m2 <<= 2;
ql += 32;
}
x += QK_K;
}
#endif
}
size_t quantize_q5_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
(void)hist;
int row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
if (!quant_weights) {
quantize_row_q5_K_reference(src, dst, nrow*n_per_row);
}
else {
char * qrow = (char *)dst;
for (int row = 0; row < nrow; ++row) {
quantize_row_q5_K_impl(src, (block_q5_K*)qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += row_size;
}
}
return nrow * row_size;
}
// ====================== 6-bit (de)-quantization
void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) {
@ -2476,7 +2801,7 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
for (int ib = 0; ib < QK_K/16; ++ib) {
const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1);
const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL);
scales[ib] = scale;
const float abs_scale = fabsf(scale);
@ -2608,6 +2933,112 @@ size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t *
return (n/QK_K*sizeof(block_q6_K));
}
static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int n_per_row, const float * quant_weights) {
#if QK_K != 256
(void)quant_weights;
quantize_row_q6_K_reference(x, y, n_per_row);
#else
assert(n_per_row % QK_K == 0);
const int nb = n_per_row / QK_K;
int8_t L[QK_K];
float scales[QK_K/16];
//float weights[16];
for (int i = 0; i < nb; i++) {
//float sum_x2 = 0;
//for (int j = 0; j < QK_K; ++j) sum_x2 += x[j]*x[j];
//float sigma2 = sum_x2/QK_K;
float max_scale = 0;
float max_abs_scale = 0;
for (int ib = 0; ib < QK_K/16; ++ib) {
float scale;
if (quant_weights) {
const float * qw = quant_weights + QK_K*i + 16*ib;
//for (int j = 0; j < 16; ++j) weights[j] = qw[j] * sqrtf(sigma2 + x[16*ib + j]*x[16*ib + j]);
//scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, weights);
scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, qw);
} else {
scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL);
}
scales[ib] = scale;
const float abs_scale = fabsf(scale);
if (abs_scale > max_abs_scale) {
max_abs_scale = abs_scale;
max_scale = scale;
}
}
if (!max_abs_scale) {
memset(&y[i], 0, sizeof(block_q6_K));
y[i].d = GGML_FP32_TO_FP16(0.f);
x += QK_K;
continue;
}
float iscale = -128.f/max_scale;
y[i].d = GGML_FP32_TO_FP16(1/iscale);
for (int ib = 0; ib < QK_K/16; ++ib) {
y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
}
for (int j = 0; j < QK_K/16; ++j) {
float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
if (!d) {
continue;
}
for (int ii = 0; ii < 16; ++ii) {
int l = nearest_int(x[16*j + ii]/d);
l = MAX(-32, MIN(31, l));
L[16*j + ii] = l + 32;
}
}
uint8_t * restrict ql = y[i].ql;
uint8_t * restrict qh = y[i].qh;
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
const uint8_t q1 = L[j + l + 0] & 0xF;
const uint8_t q2 = L[j + l + 32] & 0xF;
const uint8_t q3 = L[j + l + 64] & 0xF;
const uint8_t q4 = L[j + l + 96] & 0xF;
ql[l+ 0] = q1 | (q3 << 4);
ql[l+32] = q2 | (q4 << 4);
qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
}
ql += 64;
qh += 32;
}
x += QK_K;
}
#endif
}
size_t quantize_q6_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
(void)hist;
int row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
if (!quant_weights) {
quantize_row_q6_K_reference(src, dst, nrow*n_per_row);
}
else {
char * qrow = (char *)dst;
for (int row = 0; row < nrow; ++row) {
quantize_row_q6_K_impl(src, (block_q6_K*)qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += row_size;
}
}
return nrow * row_size;
}
// ====================== "True" 2-bit (de)-quantization
static const uint64_t iq2xxs_grid[256] = {

View File

@ -249,4 +249,7 @@ void ggml_vec_dot_iq2_xs_q8_K (int n, float * restrict s, const void * restrict
size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
size_t quantize_q4_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
size_t quantize_q5_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
size_t quantize_q6_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);

28
ggml.c
View File

@ -18713,26 +18713,38 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
case GGML_TYPE_Q3_K:
{
GGML_ASSERT(start % QK_K == 0);
block_q3_K * block = (block_q3_K*)dst + start / QK_K;
result = ggml_quantize_q3_K(src + start, block, n, n, hist);
GGML_ASSERT(start % n_per_row == 0);
size_t start_row = start / n_per_row;
size_t row_size = ggml_row_size(type, n_per_row);
result = quantize_q3_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
GGML_ASSERT(result == row_size * nrows);
} break;
case GGML_TYPE_Q4_K:
{
GGML_ASSERT(start % QK_K == 0);
block_q4_K * block = (block_q4_K*)dst + start / QK_K;
result = ggml_quantize_q4_K(src + start, block, n, n, hist);
GGML_ASSERT(start % n_per_row == 0);
size_t start_row = start / n_per_row;
size_t row_size = ggml_row_size(type, n_per_row);
result = quantize_q4_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
GGML_ASSERT(result == row_size * nrows);
} break;
case GGML_TYPE_Q5_K:
{
GGML_ASSERT(start % QK_K == 0);
block_q5_K * block = (block_q5_K*)dst + start / QK_K;
result = ggml_quantize_q5_K(src + start, block, n, n, hist);
GGML_ASSERT(start % n_per_row == 0);
size_t start_row = start / n_per_row;
size_t row_size = ggml_row_size(type, n_per_row);
result = quantize_q5_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
GGML_ASSERT(result == row_size * nrows);
} break;
case GGML_TYPE_Q6_K:
{
GGML_ASSERT(start % QK_K == 0);
block_q6_K * block = (block_q6_K*)dst + start / QK_K;
result = ggml_quantize_q6_K(src + start, block, n, n, hist);
GGML_ASSERT(start % n_per_row == 0);
size_t start_row = start / n_per_row;
size_t row_size = ggml_row_size(type, n_per_row);
result = quantize_q6_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
GGML_ASSERT(result == row_size * nrows);
} break;
case GGML_TYPE_IQ2_XXS:
{