mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-08-23 08:21:12 +02:00
ggml : Q2k interleaving implementation - x86/x64 SIMD (llama/14373)
* Initial Q2_K Block Interleaving Implementation * Addressed review comments and clean up of the code * Post rebase fixes * Initial CI/CD fixes * Update declarations in arch-fallback.h * Changes for GEMV Q2_K in arch-fallback.h * Enable repacking only on AVX-512 machines * Update comments in repack.cpp * Address q2k comments --------- Co-authored-by: Manogna-Sree <elisetti.manognasree@multicorewareinc.com>
This commit is contained in:
committed by
Georgi Gerganov
parent
78668cb8d1
commit
1c6cb7df47
@@ -37,17 +37,21 @@
|
|||||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||||
|
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||||
|
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||||
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
|
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
|
||||||
// repack.cpp
|
// repack.cpp
|
||||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||||
|
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||||
|
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||||
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
|
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
|
||||||
// repack.cpp
|
// repack.cpp
|
||||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||||
@@ -72,11 +76,13 @@
|
|||||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||||
|
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||||
|
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||||
#elif defined(__loongarch64)
|
#elif defined(__loongarch64)
|
||||||
// quants.c
|
// quants.c
|
||||||
@@ -92,11 +98,13 @@
|
|||||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||||
|
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||||
|
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||||
#elif defined(__riscv)
|
#elif defined(__riscv)
|
||||||
// quants.c
|
// quants.c
|
||||||
@@ -119,10 +127,12 @@
|
|||||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||||
|
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||||
|
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||||
#elif defined(__s390x__)
|
#elif defined(__s390x__)
|
||||||
// quants.c
|
// quants.c
|
||||||
@@ -147,11 +157,13 @@
|
|||||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||||
|
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||||
|
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||||
#elif defined(__wasm__)
|
#elif defined(__wasm__)
|
||||||
// quants.c
|
// quants.c
|
||||||
@@ -175,10 +187,12 @@
|
|||||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||||
|
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||||
|
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||||
#endif
|
#endif
|
||||||
|
File diff suppressed because it is too large
Load Diff
@@ -412,6 +412,82 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||||
|
const int qk = QK_K;
|
||||||
|
const int nb = n / qk;
|
||||||
|
const int ncols_interleaved = 8;
|
||||||
|
const int blocklen = 8;
|
||||||
|
|
||||||
|
assert (n % qk == 0);
|
||||||
|
assert (nc % ncols_interleaved == 0);
|
||||||
|
|
||||||
|
UNUSED(s);
|
||||||
|
UNUSED(bs);
|
||||||
|
UNUSED(vx);
|
||||||
|
UNUSED(vy);
|
||||||
|
UNUSED(nr);
|
||||||
|
UNUSED(nc);
|
||||||
|
UNUSED(nb);
|
||||||
|
UNUSED(ncols_interleaved);
|
||||||
|
UNUSED(blocklen);
|
||||||
|
|
||||||
|
float sumf[8];
|
||||||
|
float sum_minf[8];
|
||||||
|
int sumi1,sumi2,sumi3,sumi4;
|
||||||
|
int sumi;
|
||||||
|
|
||||||
|
const block_q8_K * a_ptr = (const block_q8_K *)vy;
|
||||||
|
for(int x = 0; x < nc / ncols_interleaved; x++) {
|
||||||
|
const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
|
||||||
|
for (int j = 0; j < ncols_interleaved; j++) {
|
||||||
|
sumf[j] = 0.0;
|
||||||
|
sum_minf[j] = 0.0;
|
||||||
|
}
|
||||||
|
for (int l = 0; l < nb; l++) {
|
||||||
|
for (int k = 0; k < (qk / (4 * blocklen)); k++) {
|
||||||
|
const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ;
|
||||||
|
const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;
|
||||||
|
const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;
|
||||||
|
const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;
|
||||||
|
for (int j = 0; j < ncols_interleaved; j++) {
|
||||||
|
sumi1 = 0;
|
||||||
|
sumi2 = 0;
|
||||||
|
sumi3 = 0;
|
||||||
|
sumi4 = 0;
|
||||||
|
sumi = 0;
|
||||||
|
int offset = ((k / 2) % 2) + j * 2;
|
||||||
|
for (int i = 0; i < blocklen; ++i){
|
||||||
|
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3);
|
||||||
|
const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3);
|
||||||
|
const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3);
|
||||||
|
const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3);
|
||||||
|
sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i]);
|
||||||
|
sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 32]);
|
||||||
|
sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 64]);
|
||||||
|
sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 96]);
|
||||||
|
|
||||||
|
sumi1 = sumi1 * (scales_0[offset] & 0xF);
|
||||||
|
sumi2 = sumi2 * (scales_1[offset] & 0xF);
|
||||||
|
sumi3 = sumi3 * (scales_2[offset] & 0xF);
|
||||||
|
sumi4 = sumi4 * (scales_3[offset] & 0xF);
|
||||||
|
sumi += sumi1 + sumi2 + sumi3 + sumi4;
|
||||||
|
}
|
||||||
|
sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for(int sb = 0; sb < 8; sb++) {
|
||||||
|
const uint8_t *mins = b_ptr[l].scales + sb * 16;
|
||||||
|
for(int j = 0; j < ncols_interleaved; j++){
|
||||||
|
sum_minf[j] += ((mins[j * 2] >> 4) * a_ptr[l].bsums[sb * 2] + (mins[(j * 2)+ 1] >> 4) * a_ptr[l].bsums[sb * 2 + 1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int j = 0; j < ncols_interleaved; j++) {
|
||||||
|
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||||
const int qk = QK8_0;
|
const int qk = QK8_0;
|
||||||
const int nb = n / qk;
|
const int nb = n / qk;
|
||||||
@@ -711,6 +787,97 @@ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||||
|
const int qk = QK_K;
|
||||||
|
const int nb = n / qk;
|
||||||
|
const int ncols_interleaved = 8;
|
||||||
|
const int blocklen = 8;
|
||||||
|
|
||||||
|
assert (n % qk == 0);
|
||||||
|
assert (nr % 4 == 0);
|
||||||
|
assert (nc % ncols_interleaved == 0);
|
||||||
|
|
||||||
|
UNUSED(s);
|
||||||
|
UNUSED(bs);
|
||||||
|
UNUSED(vx);
|
||||||
|
UNUSED(vy);
|
||||||
|
UNUSED(nr);
|
||||||
|
UNUSED(nc);
|
||||||
|
UNUSED(nb);
|
||||||
|
UNUSED(ncols_interleaved);
|
||||||
|
UNUSED(blocklen);
|
||||||
|
|
||||||
|
float sumf[4][8];
|
||||||
|
float sum_minf[4][8];
|
||||||
|
int sumi1, sumi2, sumi3, sumi4;
|
||||||
|
int sumi;
|
||||||
|
|
||||||
|
for (int y = 0; y < nr / 4; y++) {
|
||||||
|
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||||
|
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||||
|
const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
|
||||||
|
for (int m = 0; m < 4; m++) {
|
||||||
|
for (int j = 0; j < ncols_interleaved; j++) {
|
||||||
|
sumf[m][j] = 0.0;
|
||||||
|
sum_minf[m][j] = 0.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int l = 0; l < nb; l++) {
|
||||||
|
for (int k = 0; k < (qk / (4 * blocklen)); k++) {
|
||||||
|
|
||||||
|
const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ;
|
||||||
|
const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;
|
||||||
|
const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;
|
||||||
|
const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;
|
||||||
|
for (int m = 0; m < 4; m++) {
|
||||||
|
for (int j = 0; j < ncols_interleaved; j++) {
|
||||||
|
sumi1 = 0;
|
||||||
|
sumi2 = 0;
|
||||||
|
sumi3 = 0;
|
||||||
|
sumi4 = 0;
|
||||||
|
sumi = 0;
|
||||||
|
int offset = ((k / 2) % 2) + j * 2;
|
||||||
|
for (int i = 0; i < blocklen; ++i){
|
||||||
|
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3);
|
||||||
|
const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3);
|
||||||
|
const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3);
|
||||||
|
const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3);
|
||||||
|
sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i]);
|
||||||
|
sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
|
||||||
|
sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 256]);
|
||||||
|
sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 384]);
|
||||||
|
sumi1 = sumi1 * (scales_0[offset] & 0xF);
|
||||||
|
sumi2 = sumi2 * (scales_1[offset] & 0xF);
|
||||||
|
sumi3 = sumi3 * (scales_2[offset] & 0xF);
|
||||||
|
sumi4 = sumi4 * (scales_3[offset] & 0xF);
|
||||||
|
sumi += sumi1 + sumi2 + sumi3 + sumi4;
|
||||||
|
}
|
||||||
|
sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for(int sb = 0; sb < 8; sb++) {
|
||||||
|
const uint8_t *mins = b_ptr[l].scales + sb * 16;
|
||||||
|
for(int m = 0; m < 4; m++) {
|
||||||
|
const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
|
||||||
|
for(int j = 0; j < ncols_interleaved; j++) {
|
||||||
|
int mins_prod = ((mins[j * 2] >> 4) * bsums[0] + (mins[(j * 2)+ 1] >> 4) * bsums[1]);
|
||||||
|
sum_minf[m][j] += (mins_prod) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int m = 0; m < 4; m++) {
|
||||||
|
for (int j = 0; j < ncols_interleaved; j++) {
|
||||||
|
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||||
const int qk = QK8_0;
|
const int qk = QK8_0;
|
||||||
const int nb = n / qk;
|
const int nb = n / qk;
|
||||||
@@ -914,6 +1081,50 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_interleave) {
|
||||||
|
block_q2_Kx8 out;
|
||||||
|
|
||||||
|
// Delta(scale) and dmin values of the eight Q2_K structures are copied onto the output interleaved structure
|
||||||
|
for (int i = 0; i < 8; i++) {
|
||||||
|
out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < 8; i++) {
|
||||||
|
out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int end = QK_K * 2 / blck_size_interleave;
|
||||||
|
|
||||||
|
// Interleave Q2_K quants by taking 8 bytes at a time
|
||||||
|
for (int i = 0; i < end; ++i) {
|
||||||
|
int src_id = i % 8;
|
||||||
|
int src_offset = (i / 8) * blck_size_interleave;
|
||||||
|
int dst_offset = i * blck_size_interleave;
|
||||||
|
|
||||||
|
uint64_t elems;
|
||||||
|
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
|
||||||
|
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
// The below logic is designed so as to unpack and rearrange scales and mins values in Q2_K
|
||||||
|
// Currently the Q2_K structure has 16 scales and 16 mins packed in 16 bytes ( 4 bits for each value)
|
||||||
|
// The output Q2_Kx8 structure has 128 bytes for storing scales and mins
|
||||||
|
// Every 16 byte is packed such that it contains scales and mins for corresponding sub blocks from Q2_K structure
|
||||||
|
// For eg - First 16 bytes contains 16 scales and 16 mins - each of first and second sub blocks from different Q2_K structures
|
||||||
|
|
||||||
|
for(int i = 0; i < 128; i++){
|
||||||
|
|
||||||
|
// Index for selecting which q2k super block
|
||||||
|
int src1 = (i % 16) / 2;
|
||||||
|
// Index for selecting scale
|
||||||
|
int src2 = ((i / 16) * 2) + (i % 2);
|
||||||
|
|
||||||
|
out.scales[i] = in[src1].scales[src2];
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||||
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
||||||
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
|
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
|
||||||
@@ -975,6 +1186,37 @@ static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block
|
|||||||
GGML_UNUSED(data_size);
|
GGML_UNUSED(data_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||||
|
GGML_ASSERT(t->type == GGML_TYPE_Q2_K);
|
||||||
|
GGML_ASSERT(interleave_block == 8);
|
||||||
|
constexpr int nrows_interleaved = 8;
|
||||||
|
|
||||||
|
block_q2_Kx8 * dst = (block_q2_Kx8*)t->data;
|
||||||
|
const block_q2_K * src = (const block_q2_K*) data;
|
||||||
|
block_q2_K dst_tmp[8];
|
||||||
|
int nrow = ggml_nrows(t);
|
||||||
|
int nblocks = t->ne[0] / QK_K;
|
||||||
|
|
||||||
|
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K));
|
||||||
|
|
||||||
|
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int b = 0; b < nrow; b += nrows_interleaved) {
|
||||||
|
for (int64_t x = 0; x < nblocks; x++) {
|
||||||
|
for (int i = 0; i < nrows_interleaved; i++ ) {
|
||||||
|
dst_tmp[i] = src[x + i * nblocks];
|
||||||
|
}
|
||||||
|
*dst++ = make_block_q2_Kx8(dst_tmp, interleave_block);
|
||||||
|
}
|
||||||
|
src += nrows_interleaved * nblocks;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
GGML_UNUSED(data_size);
|
||||||
|
}
|
||||||
|
|
||||||
static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||||
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
||||||
GGML_ASSERT(interleave_block == 8);
|
GGML_ASSERT(interleave_block == 8);
|
||||||
@@ -1095,6 +1337,10 @@ template <> int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * da
|
|||||||
return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
|
return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||||
|
return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
|
||||||
|
}
|
||||||
|
|
||||||
template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||||
return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
|
return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
|
||||||
}
|
}
|
||||||
@@ -1124,6 +1370,10 @@ template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
|
|||||||
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||||
|
ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||||
|
}
|
||||||
|
|
||||||
template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||||
ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||||
}
|
}
|
||||||
@@ -1148,6 +1398,10 @@ template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
|
|||||||
ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||||
|
ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||||
|
}
|
||||||
|
|
||||||
template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||||
ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||||
}
|
}
|
||||||
@@ -1421,6 +1675,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
|||||||
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
|
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
|
||||||
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
|
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
|
||||||
|
|
||||||
|
// instance for Q2
|
||||||
|
static const ggml::cpu::repack::tensor_traits<block_q2_K, 8, 8, GGML_TYPE_Q8_K> q2_K_8x8_q8_K;
|
||||||
|
|
||||||
// instance for IQ4
|
// instance for IQ4
|
||||||
static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
|
static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
|
||||||
|
|
||||||
@@ -1446,6 +1703,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
|||||||
return &q4_K_8x8_q8_K;
|
return &q4_K_8x8_q8_K;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if (cur->type == GGML_TYPE_Q2_K) {
|
||||||
|
if (ggml_cpu_has_avx512()) {
|
||||||
|
if (cur->ne[1] % 8 == 0) {
|
||||||
|
return &q2_K_8x8_q8_K;
|
||||||
|
}
|
||||||
|
}
|
||||||
} else if (cur->type == GGML_TYPE_IQ4_NL) {
|
} else if (cur->type == GGML_TYPE_IQ4_NL) {
|
||||||
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
||||||
if (cur->ne[1] % 4 == 0) {
|
if (cur->ne[1] % 4 == 0) {
|
||||||
|
@@ -44,7 +44,14 @@ struct block_q4_Kx8 {
|
|||||||
};
|
};
|
||||||
|
|
||||||
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
|
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
|
||||||
|
struct block_q2_Kx8 {
|
||||||
|
ggml_half d[8]; // super-block scale for quantized scales
|
||||||
|
ggml_half dmin[8]; // super-block scale for quantized mins
|
||||||
|
uint8_t scales[128]; // scales and mins, quantized with 4 bits
|
||||||
|
uint8_t qs[512]; // 2--bit quants
|
||||||
|
};
|
||||||
|
|
||||||
|
static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
|
||||||
struct block_q8_Kx4 {
|
struct block_q8_Kx4 {
|
||||||
float d[4]; // delta
|
float d[4]; // delta
|
||||||
int8_t qs[QK_K * 4]; // quants
|
int8_t qs[QK_K * 4]; // quants
|
||||||
@@ -71,11 +78,13 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||||||
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
|
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
|
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
|
|
||||||
// Native implementations
|
// Native implementations
|
||||||
@@ -86,11 +95,13 @@ void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||||||
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
|
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
|
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||||
|
|
||||||
#if defined(__cplusplus)
|
#if defined(__cplusplus)
|
||||||
|
Reference in New Issue
Block a user