ggml : 4-bit WASM SIMD support for Q4_0

This commit is contained in:
Georgi Gerganov 2023-02-26 22:15:09 +02:00
parent d76eb894e3
commit df37e2b5ff
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

142
ggml.c
View File

@ -407,6 +407,45 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
#else
#error "not implemented for QK"
#endif
#elif defined(__wasm_simd128__)
#if QK == 32
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
v128_t srcv [8];
v128_t asrcv[8];
v128_t amaxv[8];
for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l);
for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]);
for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]);
for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]);
amax = MAX(
MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)),
MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3)));
const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0/d : 0.0;
pd[i] = d;
for (int l = 0; l < 8; l++) {
const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
pp[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
}
memcpy(pb + i*16, pp, sizeof(pp));
}
#else
#error "not implemented for QK"
#endif
#else
// scalar
for (int i = 0; i < nb; i++) {
@ -1216,9 +1255,6 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
//printf("p_0: %d %d %d %d %d %d %d %d\n", vgetq_lane_s16(p_0, 0), vgetq_lane_s16(p_0, 1), vgetq_lane_s16(p_0, 2), vgetq_lane_s16(p_0, 3), vgetq_lane_s16(p_0, 4), vgetq_lane_s16(p_0, 5), vgetq_lane_s16(p_0, 6), vgetq_lane_s16(p_0, 7));
//printf("p_1: %d %d %d %d %d %d %d %d\n", vgetq_lane_s16(p_1, 0), vgetq_lane_s16(p_1, 1), vgetq_lane_s16(p_1, 2), vgetq_lane_s16(p_1, 3), vgetq_lane_s16(p_1, 4), vgetq_lane_s16(p_1, 5), vgetq_lane_s16(p_1, 6), vgetq_lane_s16(p_1, 7));
// scalar
#if defined(__ARM_FEATURE_QRDMX)
sum0 += d0_0*d1_0*vaddvq_s16(p_0);
@ -1230,35 +1266,93 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
}
sumf = sum0 + sum1;
#else
#error "not implemented for QK"
#endif
#elif defined(__wasm_simd128__)
#if QK == 32
// wasm simd
float sum0 = 0.0f;
float sum1 = 0.0f;
//printf("sumf SIMD = %f\n", sumf);
for (int i = 0; i < nb; i += 2) {
const float d0_0 = pd0[i + 0];
const float d0_1 = pd0[i + 1];
const float d1_0 = pd1[i + 0];
const float d1_1 = pd1[i + 1];
//// scalar
//sumf = 0.0f;
//for (int i = 0; i < nb; i++) {
// const float d0 = pd0[i];
// const float d1 = pd1[i];
const uint8_t * restrict p0 = pb0 + i*16;
const uint8_t * restrict p1 = pb1 + i*16;
// const uint8_t * restrict p0 = pb0 + i*QK/2;
// const uint8_t * restrict p1 = pb1 + i*QK/2;
const v128_t m4b = wasm_u8x16_splat(0xf);
const v128_t s8b = wasm_i8x16_splat(0x8);
// for (int j = 0; j < QK/2; j++) {
// const uint8_t v0 = p0[j];
// const uint8_t v1 = p1[j];
const v128_t v0_0 = wasm_v128_load(p0);
const v128_t v0_1 = wasm_v128_load(p0 + 16);
const v128_t v1_0 = wasm_v128_load(p1);
const v128_t v1_1 = wasm_v128_load(p1 + 16);
// const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
// const float f1 = d0*((int8_t) (v0 >> 4) - 8);
// 4-bit -> 8-bit
const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
// const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
// const float f3 = d1*((int8_t) (v1 >> 4) - 8);
const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
// sumf += f0*f2 + f1*f3;
// }
//}
//printf("sumf scalar = %f\n", sumf);
//printf("--------\n");
const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
const v128_t v1_1l = wasm_v128_and(v1_1, m4b);
//exit(0);
const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);
// sub 8
const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);
const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);
const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);
// dot product into int16x8_t
const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));
const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));
const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));
const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);
const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
sum0 += d0_0*d1_0*(
wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
sum1 += d0_1*d1_1*(
wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
}
sumf = sum0 + sum1;
#else
#error "not implemented for QK"
#endif