ggml : resolve sync conflicst (ggml/0)

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-07-27 17:17:23 +03:00
parent c7ea4fd235
commit ef6dcf0d0c
2 changed files with 54 additions and 68 deletions

View File

@ -4187,30 +4187,25 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
acc_3 = __lsx_vfadd_s(p3_d, acc_3); acc_3 = __lsx_vfadd_s(p3_d, acc_3);
} }
*s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
#endif
#else for (; ib < nb; ++ib) {
// scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
int sumi0 = 0; int sumi0 = 0;
int sumi1 = 0; int sumi1 = 0;
for (int j = 0; j < qk/2; ++j) { for (int j = 0; j < qk/2; ++j) {
const int v0 = (x[i].qs[j] & 0x0F) - 8; const int v0 = (x[ib].qs[j] & 0x0F) - 8;
const int v1 = (x[i].qs[j] >> 4) - 8; const int v1 = (x[ib].qs[j] >> 4) - 8;
sumi0 += (v0 * y[i].qs[j]); sumi0 += (v0 * y[ib].qs[j]);
sumi1 += (v1 * y[i].qs[j + qk/2]); sumi1 += (v1 * y[ib].qs[j + qk/2]);
} }
int sumi = sumi0 + sumi1; int sumi = sumi0 + sumi1;
sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
} }
*s = sumf; *s = sumf;
#endif
} }
void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
@ -4479,36 +4474,34 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
acc = __lasx_xvfmadd_s( d0d1, xy, acc ); acc = __lasx_xvfmadd_s( d0d1, xy, acc );
} }
*s = hsum_float_8(acc) + summs; sumf = hsum_float_8(acc) + summs;
#endif
#else for (; ib < nb; ++ib) {
// scalar int sumi0 = 0;
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
int sumi0 = 0
int sumi1 = 0; int sumi1 = 0;
for (int j = 0; j < qk/2; ++j) { for (int j = 0; j < qk/2; ++j) {
const int v0 = (x[i].qs[j] & 0x0F); const int v0 = (x[ib].qs[j] & 0x0F);
const int v1 = (x[i].qs[j] >> 4); const int v1 = (x[ib].qs[j] >> 4);
sumi0 += (v0 * y[i].qs[j]); sumi0 += (v0 * y[ib].qs[j]);
sumi1 += (v1 * y[i].qs[j + qk/2]); sumi1 += (v1 * y[ib].qs[j + qk/2]);
} }
int sumi = sumi0 + sumi1; int sumi = sumi0 + sumi1;
sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s); sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
} }
*s = sumf; *s = sumf;
#endif
} }
void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0; const int qk = QK8_0;
const int nb = n / qk; const int nb = n / qk;
int ib = 0;
float sumf = 0;
assert(n % qk == 0); assert(n % qk == 0);
assert(qk == QK5_0); assert(qk == QK5_0);
assert(nrc == 1); assert(nrc == 1);
@ -4830,15 +4823,11 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
acc = __lasx_xvfmadd_s(d, q, acc); acc = __lasx_xvfmadd_s(d, q, acc);
} }
*s = hsum_float_8(acc); sumf = hsum_float_8(acc);
#endif
#else for (; ib < nb; ++ib) {
// scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
uint32_t qh; uint32_t qh;
memcpy(&qh, x[i].qh, sizeof(qh)); memcpy(&qh, x[ib].qh, sizeof(qh));
int sumi0 = 0; int sumi0 = 0;
int sumi1 = 0; int sumi1 = 0;
@ -4847,25 +4836,27 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
const int32_t x0 = (int8_t)(((x[i].qs[j] & 0x0F) | xh_0) - 16); const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
const int32_t x1 = (int8_t)(((x[i].qs[j] >> 4) | xh_1) - 16); const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16);
sumi0 += (x0 * y[i].qs[j]); sumi0 += (x0 * y[ib].qs[j]);
sumi1 += (x1 * y[i].qs[j + qk/2]); sumi1 += (x1 * y[ib].qs[j + qk/2]);
} }
int sumi = sumi0 + sumi1; int sumi = sumi0 + sumi1;
sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi;
} }
*s = sumf; *s = sumf;
#endif
} }
void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_1; const int qk = QK8_1;
const int nb = n / qk; const int nb = n / qk;
int ib = 0;
float sumf = 0;
assert(n % qk == 0); assert(n % qk == 0);
assert(qk == QK5_1); assert(qk == QK5_1);
assert(nrc == 1); assert(nrc == 1);
@ -5188,33 +5179,29 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
float summs = 0.0f; float summs = 0.0f;
// Main loop // Main loop
for (int i = 0; i < nb; i++) { for (; ib < nb; ++ib) {
const __m256 dx = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[i].d)); const __m256 dx = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d));
summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s); summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
__m256i qx = bytes_from_nibbles_32(x[i].qs); __m256i qx = bytes_from_nibbles_32(x[ib].qs);
__m256i bxhi = bytes_from_bits_32(x[i].qh); __m256i bxhi = bytes_from_bits_32(x[ib].qh);
bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10)); bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10));
qx = __lasx_xvor_v(qx, bxhi); qx = __lasx_xvor_v(qx, bxhi);
const __m256 dy = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[i].d)); const __m256 dy = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib].d));
const __m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0); const __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
const __m256 q = mul_sum_us8_pairs_float(qx, qy); const __m256 q = mul_sum_us8_pairs_float(qx, qy);
acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc); acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc);
} }
*s = hsum_float_8(acc) + summs; sumf = hsum_float_8(acc) + summs;
#endif
#else for (; ib < nb; ++ib) {
// scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
uint32_t qh; uint32_t qh;
memcpy(&qh, x[i].qh, sizeof(qh)); memcpy(&qh, x[ib].qh, sizeof(qh));
int sumi0 = 0; int sumi0 = 0;
int sumi1 = 0; int sumi1 = 0;
@ -5223,19 +5210,18 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0; const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
const int32_t x1 = (x[i].qs[j] >> 4) | xh_1; const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1;
sumi0 += (x0 * y[i].qs[j]); sumi0 += (x0 * y[ib].qs[j]);
sumi1 += (x1 * y[i].qs[j + qk/2]); sumi1 += (x1 * y[ib].qs[j + qk/2]);
} }
int sumi = sumi0 + sumi1; int sumi = sumi0 + sumi1;
sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s); sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
} }
*s = sumf; *s = sumf;
#endif
} }
void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {

View File

@ -14765,21 +14765,21 @@ static void ggml_compute_forward_pool_1d_sk_p0(
switch (op) { switch (op) {
case GGML_OP_POOL_AVG: drow[i] = 0; break; case GGML_OP_POOL_AVG: drow[i] = 0; break;
case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break; case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
} }
for (int ki = 0; ki < k; ++ki) { for (int ki = 0; ki < k; ++ki) {
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]); const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
switch (op) { switch (op) {
case GGML_OP_POOL_AVG: drow[i] += srow_j; break; case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break; case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
} }
++j; ++j;
} }
switch (op) { switch (op) {
case GGML_OP_POOL_AVG: drow[i] /= k; break; case GGML_OP_POOL_AVG: drow[i] /= k; break;
case GGML_OP_POOL_MAX: break; case GGML_OP_POOL_MAX: break;
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
} }
} }
@ -14848,7 +14848,7 @@ static void ggml_compute_forward_pool_2d(
switch (op) { switch (op) {
case GGML_OP_POOL_AVG: *out = 0; break; case GGML_OP_POOL_AVG: *out = 0; break;
case GGML_OP_POOL_MAX: *out = -FLT_MAX; break; case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
} }
const int ix = offset0 + ox * s0; const int ix = offset0 + ox * s0;
@ -14864,14 +14864,14 @@ static void ggml_compute_forward_pool_2d(
switch (op) { switch (op) {
case GGML_OP_POOL_AVG: *out += srow_j; break; case GGML_OP_POOL_AVG: *out += srow_j; break;
case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break; case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
} }
} }
} }
switch (op) { switch (op) {
case GGML_OP_POOL_AVG: *out /= ka; break; case GGML_OP_POOL_AVG: *out /= ka; break;
case GGML_OP_POOL_MAX: break; case GGML_OP_POOL_MAX: break;
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
} }
} }
} }