mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-11-07 08:34:37 +01:00
24f0aa460b
* DRAFT: Introduction of CUDA Graphs to LLama.cpp * Fix issues raised in comments * Tidied to now only use CUDA runtime (not mixed with driver calls) * disable for multi-gpu and batch size > 1 * Disable CUDA graphs for old GPU arch and with env var * added missing CUDA_CHECKs * Addressed comments * further addressed comments * limit to GGML_ALLOW_CUDA_GRAPHS defined in llama.cpp cmake * Added more comprehensive graph node checking * With mechanism to fall back if graph capture fails * Revert "With mechanism to fall back if graph capture fails" This reverts commit eb9f15fb6fcb81384f732c4601a5b25c016a5143. * Fall back if graph capture fails and address other comments * - renamed GGML_ALLOW_CUDA_GRAPHS to GGML_CUDA_USE_GRAPHS - rename env variable to disable CUDA graphs to GGML_CUDA_DISABLE_GRAPHS - updated Makefile build to enable CUDA graphs - removed graph capture failure checking in ggml_cuda_error using a global variable to track this is not thread safe, but I am also not satisfied with checking an error by string if this is necessary to work around some issues with graph capture with eg. 
cuBLAS, we can pass the ggml_backend_cuda_context to the error checking macro and store the result in the context - fixed several resource leaks - fixed issue with zero node graphs - changed fixed size arrays to vectors - removed the count of number of evaluations before start capturing, and instead changed the capture mode to relaxed - removed the check for multiple devices so that it is still possible to use a single device, instead checks for split buffers to disable cuda graphs with -sm row - changed the op for checking batch size to GGML_OP_ADD, should be more reliable than GGML_OP_SOFT_MAX - code style fixes - things to look into - VRAM usage of the cudaGraphExec_t, if it is significant we may need to make it optional - possibility of using cudaStreamBeginCaptureToGraph to keep track of which ggml graph nodes correspond to which cuda graph nodes * fix build without cuda graphs * remove outdated comment * replace minimum cc value with a constant --------- Co-authored-by: slaren <slarengh@gmail.com>
825 lines
28 KiB
Plaintext
825 lines
28 KiB
Plaintext
#include "convert.cuh"
|
|
#include "dequantize.cuh"
|
|
|
|
#define CUDA_Q8_0_NE_ALIGN 2048
|
|
|
|
// Generic dequantization kernel driven by a per-format device callback.
//
// Every thread produces two output values: it calls dequantize_kernel once and
// stores the two components of the resulting dfloat2 into y.
//
// Template parameters:
//   qk                - divisor used to split the flat index into a block index and
//                       an in-block position
//   qr                - divisor applied to the in-block position to obtain the quant
//                       index; also selects the distance between the two stores
//                       (1 when qr == 1, qk/2 otherwise)
//   dequantize_kernel - device callback invoked as dequantize_kernel(vx, ib, iqs, v)
//   dst_t             - element type of the output buffer
//
// Parameters:
//   vx - quantized input, passed through untouched to dequantize_kernel
//   y  - output buffer
//   k  - upper bound on the flat index; threads whose index is >= k do nothing
//
// Launch layout: 1D grid, 1D blocks (see dequantize_block_cuda).
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
    // blockDim.x and blockIdx.x are 32-bit unsigned; widen BEFORE multiplying so the
    // product cannot wrap around for very large k (convert_unary does the same).
    const int64_t i = 2*((int64_t)blockDim.x*blockIdx.x + threadIdx.x);

    if (i >= k) {
        return;
    }

    const int64_t ib       = i/qk;                  // block index
    const int64_t iqs      = (i%qk)/qr;             // quant index
    const int64_t iybs     = i - i%qk;              // y block start index
    const int64_t y_offset = qr == 1 ? 1 : qk/2;    // distance between the two stored values

    // dequantize
    dfloat2 v;
    dequantize_kernel(vx, ib, iqs, v);

    y[iybs + iqs + 0]        = v.x;
    y[iybs + iqs + y_offset] = v.y;
}
|
|
|
|
// Specialized q8_0 -> half dequantization kernel (compiled only for
// __CUDA_ARCH__ >= CC_PASCAL; otherwise the body is NO_DEVICE_CODE).
//
// Each CUDA block handles CUDA_Q8_0_NE_ALIGN consecutive output values. The raw
// input for those values is first copied, one int per thread per iteration, into a
// __shared__ staging array; after a barrier the threads decode pairs of values and
// write them as half2.
//
// Template parameter:
//   need_check - when true, both loops test against k so a partially filled last
//                block neither reads nor writes past the end
//
// Parameters:
//   vx - quantized input (interpreted as block_q8_0 data)
//   y  - output buffer of half values
//   k  - number of output values
//
// Launch layout: 1D grid, WARP_SIZE threads per block (see
// dequantize_block_q8_0_f16_cuda); the loops below stride by WARP_SIZE.
template <bool need_check>
static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
#if __CUDA_ARCH__ >= CC_PASCAL
    // number of ints staged in shared memory per CUDA block
    constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;

    // blockIdx.x is a 32-bit unsigned value: widen it before multiplying so neither
    // the output offset nor the input offset can wrap around for very large k.
    const int64_t i0 = (int64_t)CUDA_Q8_0_NE_ALIGN*blockIdx.x;
    const int * x0 = ((const int *) vx) + (int64_t)blockIdx.x * nint;
    half2 * y2 = (half2 *) (y + i0);

    __shared__ int vals[nint];

    // stage the raw quantized data of this block in shared memory
#pragma unroll
    for (int ix0 = 0; ix0 < nint; ix0 += WARP_SIZE) {
        // byte offset of this thread's int vs. total byte size of the input;
        // `break` (not `return`) so every thread still reaches the barrier below
        if (need_check && i0*sizeof(block_q8_0)/QK8_0 + sizeof(int)*(ix0 + threadIdx.x) >= k*sizeof(block_q8_0)/QK8_0) {
            break;
        }

        const int ix = ix0 + threadIdx.x;
        vals[ix] = x0[ix];
    }

    __syncthreads();

    // decode two values per thread per iteration
#pragma unroll
    for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
        if (need_check && i0 + iy + 2*threadIdx.x >= k) {
            return;
        }

        // b0 points at the start of the staged block_q8_0: the half scale comes first,
        // the quants follow directly after it
        const half * b0 = ((const half *) vals) + (sizeof(block_q8_0)/sizeof(half)) * ((iy + 2*threadIdx.x)/QK8_0);
        const half d = *b0;
        const char2 qs = ((const char2 *) (b0 + 1))[threadIdx.x % (QK8_0/2)];

        y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d));
    }
#else
    GGML_UNUSED(vx);
    GGML_UNUSED(y);
    GGML_UNUSED(k);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= CC_PASCAL
}
|
|
|
|
// Dequantizes q4_0 data. One CUDA block covers 256 output values (8 block_q4_0
// entries); each thread decodes 4 bytes, i.e. 8 output values.
//
//   vx   - quantized input, interpreted as an array of block_q4_0
//   yy   - output buffer
//   nb32 - number of block_q4_0 entries; threads mapped past it exit early
//
// Launched with 32 threads per block (see dequantize_row_q4_0_cuda).
template<typename dst_t>
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
    const int64_t cuda_block = blockIdx.x;

    // assume 32 threads
    const int64_t lane  = threadIdx.x;
    const int64_t group = lane/8;              // which 4-byte slice of the quants
    const int64_t sub   = lane%8;              // which block_q4_0 inside this CUDA block
    const int64_t ib    = 8*cuda_block + sub;  // global block_q4_0 index
    if (ib >= nb32) {
        return;
    }

    dst_t * y = yy + 256*cuda_block + 32*sub + 4*group;

    const block_q4_0 * x = (const block_q4_0 *)vx + ib;
    const float d  = __half2float(x->d);
    const float dm = -8*d;

    const uint8_t * q = x->qs + 4*group;

    for (int l = 0; l < 4; ++l) {
        const uint8_t packed = q[l];
        y[l+ 0] = d * (packed & 0xF) + dm;   // low nibble
        y[l+16] = d * (packed >>  4) + dm;   // high nibble
    }
}
|
|
|
|
// Dequantizes q4_1 data. Same thread mapping as dequantize_block_q4_0: one CUDA
// block covers 256 output values, each thread decodes 4 bytes (8 output values).
//
//   vx   - quantized input, interpreted as an array of block_q4_1
//   yy   - output buffer
//   nb32 - number of block_q4_1 entries; threads mapped past it exit early
//
// Launched with 32 threads per block (see dequantize_row_q4_1_cuda).
template<typename dst_t>
static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
    const int64_t cuda_block = blockIdx.x;

    // assume 32 threads
    const int64_t lane  = threadIdx.x;
    const int64_t group = lane/8;
    const int64_t sub   = lane%8;
    const int64_t ib    = 8*cuda_block + sub;
    if (ib >= nb32) {
        return;
    }

    dst_t * y = yy + 256*cuda_block + 32*sub + 4*group;

    const block_q4_1 * x = (const block_q4_1 *)vx + ib;
    // dm holds two halves; .x multiplies the quant, .y is added afterwards
    const float2 dm = __half22float2(x->dm);

    const uint8_t * q = x->qs + 4*group;

    for (int l = 0; l < 4; ++l) {
        const uint8_t packed = q[l];
        y[l+ 0] = dm.x * (packed & 0xF) + dm.y;
        y[l+16] = dm.x * (packed >>  4) + dm.y;
    }
}
|
|
|
|
//================================== k-quants
|
|
|
|
// Dequantizes q2_K data; one CUDA block per block_q2_K (QK_K output values).
//
//   vx - quantized input, interpreted as an array of block_q2_K
//   yy - output buffer
//
// Launched with 64 threads per block when QK_K == 256 and 32 threads otherwise
// (see dequantize_row_q2_K_cuda).
template<typename dst_t>
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_q2_K * xb = (const block_q2_K *) vx + i;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t n  = tid/32;
    const int64_t l  = tid - 32*n;
    const int64_t is = 8*n + l/16;

    // one byte carries four 2-bit quants, decoded with shifts 0/2/4/6 below
    const uint8_t q = xb->qs[32*n + l];
    dst_t * y = yy + i*QK_K + 128*n;

    float dall = __low2half(xb->dm);
    float dmin = __high2half(xb->dm);
    y[l+ 0] = dall * (xb->scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (xb->scales[is+0] >> 4);
    y[l+32] = dall * (xb->scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (xb->scales[is+2] >> 4);
    y[l+64] = dall * (xb->scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (xb->scales[is+4] >> 4);
    y[l+96] = dall * (xb->scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (xb->scales[is+6] >> 4);
#else
    const int64_t is = tid/16;  // 0 or 1
    const int64_t il = tid%16;  // 0...15
    const uint8_t q = xb->qs[il] >> (2*is);
    dst_t * y = yy + i*QK_K + 16*is + il;
    float dall = __low2half(xb->dm);
    float dmin = __high2half(xb->dm);
    y[ 0] = dall * (xb->scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (xb->scales[is+0] >> 4);
    y[32] = dall * (xb->scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (xb->scales[is+2] >> 4);
#endif
}
|
|
|
|
// Dequantizes q3_K data; one CUDA block per block_q3_K (QK_K output values).
//
//   vx - quantized input, interpreted as an array of block_q3_K
//   yy - output buffer
//
// Launched with 64 threads per block when QK_K == 256 and 32 threads otherwise
// (see dequantize_row_q3_K_cuda).
template<typename dst_t>
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_q3_K * xb = (const block_q3_K *) vx + i;

#if QK_K == 256
    const int64_t r   = threadIdx.x/4;
    const int64_t tid = r/2;
    const int64_t is0 = r%2;
    const int64_t l0  = 16*is0 + 4*(threadIdx.x%4);
    const int64_t n   = tid / 4;
    const int64_t j   = tid - 4*n;

    uint8_t m     = 1 << (4*n + j);   // bit tested in hmask for this thread's values
    int64_t is    = 8*n + 2*j + is0;  // scale index
    int     shift = 2*j;              // position of the 2-bit quant inside each qs byte

    // rebuild the 6-bit scale: 4 bits from one scales byte, 2 more bits from another
    int8_t us = is <  4 ? (xb->scales[is-0] & 0xF) | (((xb->scales[is+8] >> 0) & 3) << 4) :
                is <  8 ? (xb->scales[is-0] & 0xF) | (((xb->scales[is+4] >> 2) & 3) << 4) :
                is < 12 ? (xb->scales[is-8] >>  4) | (((xb->scales[is+0] >> 4) & 3) << 4) :
                          (xb->scales[is-8] >>  4) | (((xb->scales[is-4] >> 6) & 3) << 4);
    float d_all = xb->d;
    float dl    = d_all * (us - 32);

    dst_t * y = yy + i*QK_K + 128*n + 32*j;
    const uint8_t * q  = xb->qs + 32*n;
    const uint8_t * hm = xb->hmask;

    for (int l = l0; l < l0+4; ++l) {
        y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
    }
#else
    const int64_t tid = threadIdx.x;
    const int64_t is  = tid/16;  // 0 or 1
    const int64_t il  = tid%16;  // 0...15
    const int64_t im  = il/8;    // 0...1
    const int64_t in  = il%8;    // 0...7

    dst_t * y = yy + i*QK_K + 16*is + il;

    const uint8_t q = xb->qs[il] >> (2*is);
    const uint8_t h = xb->hmask[in] >> (2*is + im);
    const float   d = (float)xb->d;

    if (is == 0) {
        y[ 0] = d * ((xb->scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
        y[32] = d * ((xb->scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
    } else {
        y[ 0] = d * ((xb->scales[0] >>  4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
        y[32] = d * ((xb->scales[1] >>  4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
    }
#endif
}
|
|
|
|
#if QK_K == 256
|
|
// Extracts the j-th (scale, min) pair from the packed byte array q.
//
//   j - index of the pair to unpack
//   q - packed scales/mins; bytes q[j], q[j+4] are read for j < 4,
//       bytes q[j-4], q[j], q[j+4] otherwise
//   d - receives the 6-bit scale
//   m - receives the 6-bit min
static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
    if (j >= 4) {
        // low 4 bits come from q[j+4], the top 2 bits from the upper bits of q[j-4] / q[j]
        d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
        m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
        return;
    }

    // first four pairs are stored directly in the low 6 bits
    d = q[j]     & 63;
    m = q[j + 4] & 63;
}
|
|
#endif
|
|
|
|
// Dequantizes q4_K data; one CUDA block per block_q4_K (QK_K output values).
//
//   vx - quantized input, interpreted as an array of block_q4_K
//   yy - output buffer
//
// Launched with 32 threads per block (see dequantize_row_q4_K_cuda).
template<typename dst_t>
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_q4_K * xb = (const block_q4_K *) vx + i;

#if QK_K == 256
    // assume 32 threads
    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8;
    const int64_t ir  = tid%8;
    const int64_t is  = 2*il;
    const int64_t n   = 4;     // bytes handled per thread

    dst_t * y = yy + i*QK_K + 64*il + n*ir;

    const float dall = __low2half(xb->dm);
    const float dmin = __high2half(xb->dm);

    const uint8_t * q = xb->qs + 32*il + n*ir;

    // scale/min pair for the low nibbles (is) and for the high nibbles (is + 1)
    uint8_t sc_lo, mn_lo;
    get_scale_min_k4(is + 0, xb->scales, sc_lo, mn_lo);
    const float d1 = dall * sc_lo;
    const float m1 = dmin * mn_lo;

    uint8_t sc_hi, mn_hi;
    get_scale_min_k4(is + 1, xb->scales, sc_hi, mn_hi);
    const float d2 = dall * sc_hi;
    const float m2 = dmin * mn_hi;

    for (int l = 0; l < n; ++l) {
        y[l + 0] = d1 * (q[l] & 0xF) - m1;
        y[l +32] = d2 * (q[l] >>  4) - m2;
    }
#else
    const int64_t tid = threadIdx.x;
    const uint8_t * q = xb->qs;
    dst_t * y = yy + i*QK_K;
    const float d = (float)xb->dm[0];
    const float m = (float)xb->dm[1];
    y[tid+ 0] = d * (xb->scales[0] & 0xF) * (q[tid] & 0xF) - m * (xb->scales[0] >> 4);
    y[tid+32] = d * (xb->scales[1] & 0xF) * (q[tid] >>  4) - m * (xb->scales[1] >> 4);
#endif
}
|
|
|
|
// Dequantizes q5_K data; one CUDA block per block_q5_K (QK_K output values).
//
//   vx - quantized input, interpreted as an array of block_q5_K
//   yy - output buffer
//
// Launched with 64 threads per block when QK_K == 256 and 32 threads otherwise
// (see dequantize_row_q5_K_cuda).
template<typename dst_t>
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_q5_K * xb = (const block_q5_K *) vx + i;

#if QK_K == 256
    // assume 64 threads
    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/16;   // il is in 0...3
    const int64_t ir  = tid%16;   // ir is in 0...15
    const int64_t is  = 2*il;     // is is in 0...6

    dst_t * y = yy + i*QK_K + 64*il + 2*ir;

    const float dall = __low2half(xb->dm);
    const float dmin = __high2half(xb->dm);

    const uint8_t * ql = xb->qs + 32*il + 2*ir;
    const uint8_t * qh = xb->qh + 2*ir;

    uint8_t sc, m;
    get_scale_min_k4(is + 0, xb->scales, sc, m);
    const float d1 = dall * sc;
    const float m1 = dmin * m;
    get_scale_min_k4(is + 1, xb->scales, sc, m);
    const float d2 = dall * sc;
    const float m2 = dmin * m;

    // bit of qh that supplies the 5th bit: hm_lo for the low nibbles, hm_hi for the high ones
    const uint8_t hm_lo = 1 << (2*il);
    const uint8_t hm_hi = hm_lo << 1;

    y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm_lo ? 16 : 0)) - m1;
    y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm_lo ? 16 : 0)) - m1;
    y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm_hi ? 16 : 0)) - m2;
    y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm_hi ? 16 : 0)) - m2;
#else
    const int64_t tid = threadIdx.x;
    const uint8_t q  = xb->qs[tid];
    const int64_t im = tid/8;   // 0...3
    const int64_t in = tid%8;   // 0...7
    const int64_t is = tid/16;  // 0 or 1
    const uint8_t h  = xb->qh[in] >> im;
    const float   d  = xb->d;
    dst_t * y = yy + i*QK_K + tid;
    y[ 0] = d * xb->scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
    y[32] = d * xb->scales[is+2] * ((q >>  4) - ((h >> 4) & 1 ? 0 : 16));
#endif
}
|
|
|
|
// Dequantizes q6_K data; one CUDA block per block_q6_K (QK_K output values).
//
//   vx - quantized input, interpreted as an array of block_q6_K
//   yy - output buffer
//
// Launched with 64 threads per block when QK_K == 256 and 32 threads otherwise
// (see dequantize_row_q6_K_cuda).
template<typename dst_t>
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_q6_K * xb = (const block_q6_K *) vx + i;

#if QK_K == 256
    // assume 64 threads
    const int64_t tid = threadIdx.x;
    const int64_t ip  = tid/32;        // ip is 0 or 1
    const int64_t il  = tid - 32*ip;   // 0...31
    const int64_t is  = 8*ip + il/16;

    dst_t * y = yy + i*QK_K + 128*ip + il;

    const float d = xb->d;

    const uint8_t * ql = xb->ql + 64*ip + il;
    const uint8_t   qh = xb->qh[32*ip + il];
    const int8_t  * sc = xb->scales + is;

    // each value: 4 low bits from ql combined with 2 high bits from qh, re-centred by 32
    y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
    y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
    y[64] = d * sc[4] * ((int8_t)((ql[ 0] >>  4) | (((qh >> 4) & 3) << 4)) - 32);
    y[96] = d * sc[6] * ((int8_t)((ql[32] >>  4) | (((qh >> 6) & 3) << 4)) - 32);
#else
    // assume 32 threads
    const int64_t tid = threadIdx.x;
    const int64_t ip  = tid/16;        // 0 or 1
    const int64_t il  = tid - 16*ip;   // 0...15

    dst_t * y = yy + i*QK_K + 16*ip + il;

    const float d = xb->d;

    const uint8_t   ql = xb->ql[16*ip + il];
    const uint8_t   qh = xb->qh[il] >> (2*ip);
    const int8_t  * sc = xb->scales;

    y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
    y[32] = d * sc[ip+2] * ((int8_t)((ql >>  4) | (((qh >> 4) & 3) << 4)) - 32);
#endif
}
|
|
|
|
// Dequantizes iq2_xxs data; one CUDA block per block_iq2_xxs, each thread writes
// 8 consecutive output values. Only implemented for QK_K == 256.
//
//   vx - quantized input, interpreted as an array of block_iq2_xxs
//   yy - output buffer
//
// Launched with 32 threads per block (see dequantize_row_iq2_xxs_cuda).
template<typename dst_t>
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_iq2_xxs * x = (const block_iq2_xxs *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8;   // 0...3
    const int64_t ib = tid%8;   // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;

    const uint16_t * q2   = x[i].qs + 4*ib;
    const uint8_t  * aux8 = (const uint8_t *)q2;                       // byte view: grid indices
    const uint8_t  * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
    const uint32_t  aux32 = q2[2] | (q2[3] << 16);                     // packed sign indices + scale

    const float   d     = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];

    for (int j = 0; j < 8; ++j) {
        y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
    }
#else
    NO_DEVICE_CODE;
#endif
}
|
|
|
|
// Dequantizes iq2_xs data; one CUDA block per block_iq2_xs, each thread writes
// 8 consecutive output values. Only implemented for QK_K == 256.
//
//   vx - quantized input, interpreted as an array of block_iq2_xs
//   yy - output buffer
//
// Launched with 32 threads per block (see dequantize_row_iq2_xs_cuda).
template<typename dst_t>
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_iq2_xs * x = (const block_iq2_xs *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8;   // 0...3
    const int64_t ib = tid%8;   // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;

    const uint16_t * q2 = x[i].qs + 4*ib;
    // low 9 bits select the grid entry, the remaining upper bits select the sign pattern
    const uint8_t * grid  = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
    const float     d     = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
    const uint8_t   signs = ksigns_iq2xs[q2[il] >> 9];

    for (int j = 0; j < 8; ++j) {
        y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
    }
#else
    NO_DEVICE_CODE;
#endif
}
|
|
|
|
// Dequantizes iq2_s data; one CUDA block per block_iq2_s, each thread writes
// 8 consecutive output values. Only implemented for QK_K == 256.
//
//   vx - quantized input, interpreted as an array of block_iq2_s
//   yy - output buffer
//
// Launched with 32 threads per block (see dequantize_row_iq2_s_cuda).
template<typename dst_t>
static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_iq2_s * x = (const block_iq2_s *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8;   // 0...3
    const int64_t ib = tid%8;   // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;

    // grid index: 8 bits from qs plus 2 extra bits (mask 0x300) taken from qh
    const uint8_t * grid  = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
    const float     d     = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
    // sign bytes live in the second part of qs, starting at offset QK_K/8
    const uint8_t   signs = x[i].qs[QK_K/8+4*ib+il];

    for (int j = 0; j < 8; ++j) {
        y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
    }
#else
    NO_DEVICE_CODE;
#endif
}
|
|
|
|
// Dequantizes iq3_xxs data; one CUDA block per block_iq3_xxs, each thread writes
// 8 consecutive output values (4 from each of two grid entries).
// Only implemented for QK_K == 256.
//
//   vx - quantized input, interpreted as an array of block_iq3_xxs
//   yy - output buffer
//
// Launched with 32 threads per block (see dequantize_row_iq3_xxs_cuda).
template<typename dst_t>
static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_iq3_xxs * x = (const block_iq3_xxs *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8;   // 0...3
    const int64_t ib = tid%8;   // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;

    const uint8_t  * q3  = x[i].qs + 8*ib;
    // packed sign indices + scale are stored after the first QK_K/4 bytes of qs
    const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;

    const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
    const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);

    const uint32_t aux32 = gas[0] | (gas[1] << 16);
    const float    d     = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
    const uint8_t  signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];

    for (int j = 0; j < 4; ++j) {
        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
    }
#else
    NO_DEVICE_CODE;
#endif
}
|
|
|
|
// Dequantizes iq3_s data; one CUDA block per block_iq3_s, each thread writes
// 8 consecutive output values (4 from each of two grid entries).
// Only implemented for QK_K == 256.
//
//   vx - quantized input, interpreted as an array of block_iq3_s
//   yy - output buffer
//
// Launched with 32 threads per block (see dequantize_row_iq3_s_cuda).
template<typename dst_t>
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_iq3_s * x = (const block_iq3_s *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8;   // 0...3
    const int64_t ib = tid%8;   // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;

    const uint8_t * qs = x[i].qs + 8*ib;
    // grid index: 8 bits from qs plus a 9th bit (mask 256) taken from qh
    const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
    const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));

    const float   d     = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
    const uint8_t signs = x[i].signs[4*ib + il];

    for (int j = 0; j < 4; ++j) {
        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
    }
#else
    NO_DEVICE_CODE;
#endif
}
|
|
|
|
// Dequantizes iq1_s data; one CUDA block per block_iq1_s, each thread writes
// 8 consecutive output values. Only implemented for QK_K == 256.
//
//   vx - quantized input, interpreted as an array of block_iq1_s
//   yy - output buffer
//
// Launched with 32 threads per block (see dequantize_row_iq1_s_cuda).
template<typename dst_t>
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_iq1_s * x = (const block_iq1_s *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8;   // 0...3
    const int64_t ib = tid%8;   // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;

    // bit 15 of qh selects the sign of the delta; bits 12..14 hold the scale
    const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
    const float d     = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);

    // one 32-bit grid word holds 8 nibbles; split them into 8 bytes viewed through q
    uint32_t grid32[2];
    const int8_t * q = (const int8_t *)grid32;
    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
    grid32[0] &= 0x0f0f0f0f;

    for (int j = 0; j < 8; ++j) {
        y[j] = d * (q[j] + delta);
    }
#else
    NO_DEVICE_CODE;
#endif
}
|
|
|
|
// Dequantizes iq1_m data; one CUDA block per block_iq1_m, each thread writes
// 8 consecutive output values. Only implemented for QK_K == 256.
//
//   vx - quantized input, interpreted as an array of block_iq1_m
//   yy - output buffer
//
// Launched with 32 threads per block (see dequantize_row_iq1_m_cuda).
template<typename dst_t>
static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_iq1_m * x = (const block_iq1_m *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8;   // 0...3
    const int64_t ib = tid%8;   // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;

    const uint16_t * sc = (const uint16_t *)x[i].scales;

    // the 16-bit block scale is scattered over the top nibbles of the four scale words
    iq1m_scale_t scale;
    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

    const int64_t ib16 = 2*ib + il/2;   // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
    const float d     = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
    const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;

    // one 32-bit grid word holds 8 nibbles; split them into 8 bytes viewed through q
    uint32_t grid32[2];
    const int8_t * q = (const int8_t *)grid32;
    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
    grid32[0] &= 0x0f0f0f0f;

    for (int j = 0; j < 8; ++j) {
        y[j] = d * (q[j] + delta);
    }
#else
    NO_DEVICE_CODE;
#endif
}
|
|
|
|
|
|
// Dequantizes iq4_nl data. One CUDA block covers QK_K output values, i.e.
// QK_K/QK4_NL consecutive block_iq4_nl entries; each thread decodes 4 bytes
// (8 output values) through the kvalues_iq4nl lookup table.
//
//   vx - quantized input, interpreted as an array of block_iq4_nl
//   yy - output buffer
//
// Launched with 32 threads per block (see dequantize_row_iq4_nl_cuda).
template<typename dst_t>
static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    // first block_iq4_nl handled by this CUDA block
    const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);

    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8;   // 0...3
    const int64_t ib  = tid%8;   // 0...7

    dst_t * y = yy + i*QK_K + 32*ib + 4*il;

    const block_iq4_nl * xb = x + ib;
    const uint8_t * q4 = xb->qs + 4*il;
    const float     d  = (float)xb->d;

    for (int j = 0; j < 4; ++j) {
        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
    }
}
|
|
|
|
#if QK_K != 64
|
|
// Dequantizes iq4_xs data; one CUDA block per block_iq4_xs (QK_K output values),
// each thread decodes 4 bytes (8 output values) through the kvalues_iq4nl table.
//
//   vx - quantized input, interpreted as an array of block_iq4_xs
//   yy - output buffer
//
// Launched with 32 threads per block (see dequantize_row_iq4_xs_cuda).
template<typename dst_t>
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_iq4_xs * xb = (const block_iq4_xs *)vx + i;

    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8;   // 0...3
    const int64_t ib  = tid%8;   // 0...7

    dst_t * y = yy + i*QK_K + 32*ib + 4*il;

    const uint8_t * q4 = xb->qs + 16*ib + 4*il;
    // 6-bit sub-block scale: 4 bits from scales_l, 2 bits from scales_h, re-centred by 32
    const float d = (float)xb->d * ((((xb->scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb->scales_h >> 2*ib) & 3) << 4)) - 32);

    for (int j = 0; j < 4; ++j) {
        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
    }
}
|
|
#endif
|
|
|
|
// Host-side launcher for the generic dequantize_block kernel.
//
//   vx     - quantized input (device pointer)
//   y      - output buffer (device pointer)
//   k      - number of output values
//   stream - CUDA stream the kernel is enqueued on
//
// The grid is sized by rounding k up to a multiple of 2*CUDA_DEQUANTIZE_BLOCK_SIZE,
// because every thread of the kernel writes two values.
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
    const int64_t values_per_block = 2*CUDA_DEQUANTIZE_BLOCK_SIZE;
    const int     num_blocks       = (k + values_per_block - 1) / values_per_block;

    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
|
|
|
|
// Host-side launcher for dequantize_block_q8_0_f16.
//
//   vx     - quantized q8_0 input (device pointer)
//   y      - half output buffer (device pointer)
//   k      - number of output values
//   stream - CUDA stream the kernel is enqueued on
//
// One CUDA block of WARP_SIZE threads is launched per CUDA_Q8_0_NE_ALIGN values.
// The bounds-checking kernel variant is only used when k is not a multiple of
// CUDA_Q8_0_NE_ALIGN.
static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
    const int  num_blocks   = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
    const bool has_partial  = k % CUDA_Q8_0_NE_ALIGN != 0;

    if (has_partial) {
        dequantize_block_q8_0_f16<true><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
    } else {
        dequantize_block_q8_0_f16<false><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
    }
}
|
|
|
|
// Launches dequantize_block_q2_K on `stream`: k / QK_K CUDA blocks (integer
// division), with 64 threads each when QK_K == 256 and 32 threads otherwise.
//
//   vx - quantized input (device pointer)
//   y  - output buffer (device pointer)
//   k  - number of output values
template<typename dst_t>
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
#if QK_K == 256
    const int threads_per_block = 64;
#else
    const int threads_per_block = 32;
#endif
    const int num_super_blocks = k / QK_K;

    dequantize_block_q2_K<<<num_super_blocks, threads_per_block, 0, stream>>>(vx, y);
}
|
|
|
|
// Launches dequantize_block_q3_K on `stream`: k / QK_K CUDA blocks (integer
// division), with 64 threads each when QK_K == 256 and 32 threads otherwise.
//
//   vx - quantized input (device pointer)
//   y  - output buffer (device pointer)
//   k  - number of output values
template<typename dst_t>
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
#if QK_K == 256
    const int threads_per_block = 64;
#else
    const int threads_per_block = 32;
#endif
    const int num_super_blocks = k / QK_K;

    dequantize_block_q3_K<<<num_super_blocks, threads_per_block, 0, stream>>>(vx, y);
}
|
|
|
|
// Launches dequantize_block_q4_0 on `stream`.
//
//   vx - quantized input (device pointer)
//   y  - output buffer (device pointer)
//   k  - number of output values
//
// The grid has one 32-thread CUDA block per 256 values (rounded up); the kernel
// receives k / 32 so it can skip threads past the last 32-value block.
template<typename dst_t>
static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int num_q_blocks    = k / 32;
    const int num_cuda_blocks = (k + 255) / 256;

    dequantize_block_q4_0<<<num_cuda_blocks, 32, 0, stream>>>(vx, y, num_q_blocks);
}
|
|
|
|
// Launches dequantize_block_q4_1 on `stream`.
//
//   vx - quantized input (device pointer)
//   y  - output buffer (device pointer)
//   k  - number of output values
//
// The grid has one 32-thread CUDA block per 256 values (rounded up); the kernel
// receives k / 32 so it can skip threads past the last 32-value block.
template<typename dst_t>
static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int num_q_blocks    = k / 32;
    const int num_cuda_blocks = (k + 255) / 256;

    dequantize_block_q4_1<<<num_cuda_blocks, 32, 0, stream>>>(vx, y, num_q_blocks);
}
|
|
|
|
// Launches dequantize_block_q4_K on `stream`: k / QK_K CUDA blocks (integer
// division) of 32 threads each.
//
//   vx - quantized input (device pointer)
//   y  - output buffer (device pointer)
//   k  - number of output values
template<typename dst_t>
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;

    dequantize_block_q4_K<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
|
|
|
|
// Launches dequantize_block_q5_K on `stream`: k / QK_K CUDA blocks (integer
// division), with 64 threads each when QK_K == 256 and 32 threads otherwise.
//
//   vx - quantized input (device pointer)
//   y  - output buffer (device pointer)
//   k  - number of output values
template<typename dst_t>
static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
#if QK_K == 256
    const int threads_per_block = 64;
#else
    const int threads_per_block = 32;
#endif
    const int num_super_blocks = k / QK_K;

    dequantize_block_q5_K<<<num_super_blocks, threads_per_block, 0, stream>>>(vx, y);
}
|
|
|
|
// Launches dequantize_block_q6_K on `stream`: k / QK_K CUDA blocks (integer
// division), with 64 threads each when QK_K == 256 and 32 threads otherwise.
//
//   vx - quantized input (device pointer)
//   y  - output buffer (device pointer)
//   k  - number of output values
template<typename dst_t>
static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
#if QK_K == 256
    const int threads_per_block = 64;
#else
    const int threads_per_block = 32;
#endif
    const int num_super_blocks = k / QK_K;

    dequantize_block_q6_K<<<num_super_blocks, threads_per_block, 0, stream>>>(vx, y);
}
|
|
|
|
// Launches dequantize_block_iq2_xxs on `stream`: k / QK_K CUDA blocks (integer
// division) of 32 threads each.
//
//   vx - quantized input (device pointer)
//   y  - output buffer (device pointer)
//   k  - number of output values
template<typename dst_t>
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;

    dequantize_block_iq2_xxs<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
|
|
|
|
// Launches dequantize_block_iq2_xs on `stream`: k / QK_K CUDA blocks (integer
// division) of 32 threads each.
//
//   vx - quantized input (device pointer)
//   y  - output buffer (device pointer)
//   k  - number of output values
template<typename dst_t>
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;

    dequantize_block_iq2_xs<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
|
|
|
|
// Launches dequantize_block_iq2_s on `stream`: k / QK_K CUDA blocks (integer
// division) of 32 threads each.
//
//   vx - quantized input (device pointer)
//   y  - output buffer (device pointer)
//   k  - number of output values
template<typename dst_t>
static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;

    dequantize_block_iq2_s<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
|
|
|
|
// Launches dequantize_block_iq3_xxs on `stream`: k / QK_K CUDA blocks (integer
// division) of 32 threads each.
//
//   vx - quantized input (device pointer)
//   y  - output buffer (device pointer)
//   k  - number of output values
template<typename dst_t>
static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;

    dequantize_block_iq3_xxs<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
|
|
|
|
// Launches dequantize_block_iq3_s on `stream`: k / QK_K CUDA blocks (integer
// division) of 32 threads each.
//
//   vx - quantized input (device pointer)
//   y  - output buffer (device pointer)
//   k  - number of output values
template<typename dst_t>
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;

    dequantize_block_iq3_s<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
|
|
|
|
// Launches dequantize_block_iq1_s on `stream`: k / QK_K CUDA blocks (integer
// division) of 32 threads each.
//
//   vx - quantized input (device pointer)
//   y  - output buffer (device pointer)
//   k  - number of output values
template<typename dst_t>
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;

    dequantize_block_iq1_s<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
|
|
|
|
// Launches dequantize_block_iq4_nl on `stream`: k rounded up to a multiple of
// QK_K determines the number of CUDA blocks, each with 32 threads.
//
//   vx - quantized input (device pointer)
//   y  - output buffer (device pointer)
//   k  - number of output values
template<typename dst_t>
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int num_super_blocks = (k + QK_K - 1) / QK_K;

    dequantize_block_iq4_nl<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
|
|
|
|
// Launches dequantize_block_iq1_m on `stream`: k / QK_K CUDA blocks (integer
// division) of 32 threads each.
//
//   vx - quantized input (device pointer)
//   y  - output buffer (device pointer)
//   k  - number of output values
template<typename dst_t>
static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;

    dequantize_block_iq1_m<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
|
|
|
|
// Launches the iq4_xs dequantization on `stream`: k rounded up to a multiple of
// QK_K determines the number of CUDA blocks, each with 32 threads.
//
//   vx - quantized input (device pointer)
//   y  - output buffer (device pointer)
//   k  - number of output values
//
// When QK_K == 64 the iq4_nl kernel is launched instead, since the dedicated
// iq4_xs kernel is only compiled for QK_K != 64.
template<typename dst_t>
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int num_super_blocks = (k + QK_K - 1) / QK_K;

#if QK_K == 64
    dequantize_block_iq4_nl<<<num_super_blocks, 32, 0, stream>>>(vx, y);
#else
    dequantize_block_iq4_xs<<<num_super_blocks, 32, 0, stream>>>(vx, y);
#endif
}
|
|
|
|
// Element-wise type conversion kernel: y[i] = x[i], relying on the implicit
// conversion from src_t to dst_t.
//
//   vx - input buffer, interpreted as an array of src_t
//   y  - output buffer
//   k  - number of elements; threads with an index >= k do nothing
//
// Launch layout: 1D grid, 1D blocks, one element per thread.
template <typename src_t, typename dst_t>
static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
    // blockDim.x is widened before the multiplication, so the index is computed in 64 bits
    const int64_t idx = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

    if (idx < k) {
        const src_t * src = (const src_t *) vx;
        y[idx] = src[idx];
    }
}
|
|
|
|
// Host-side launcher for convert_unary.
//
//   vx     - input buffer of src_t (device pointer)
//   y      - output buffer of dst_t (device pointer)
//   k      - number of elements to convert
//   stream - CUDA stream the kernel is enqueued on
//
// The grid is sized by rounding k up to a multiple of CUDA_DEQUANTIZE_BLOCK_SIZE.
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
    const int64_t threads_per_block = CUDA_DEQUANTIZE_BLOCK_SIZE;
    const int     num_blocks        = (k + threads_per_block - 1) / threads_per_block;

    convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
|
|
|
|
// Returns the host launcher that converts data of the given ggml type to fp16,
// or nullptr when the type has no conversion here.
//
// For GGML_TYPE_Q8_0 the specialized half kernel is chosen only when the current
// device's compute capability is at least CC_PASCAL; otherwise the generic
// dequantize_block path is used.
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
    to_fp16_cuda_t fn = nullptr;

    switch (type) {
        // simple block formats
        case GGML_TYPE_Q4_0:    fn = dequantize_row_q4_0_cuda;                               break;
        case GGML_TYPE_Q4_1:    fn = dequantize_row_q4_1_cuda;                               break;
        case GGML_TYPE_Q5_0:    fn = dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;   break;
        case GGML_TYPE_Q5_1:    fn = dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;   break;
        case GGML_TYPE_Q8_0:
            if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) {
                fn = dequantize_block_q8_0_f16_cuda;
            } else {
                fn = dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
            }
            break;

        // k-quants
        case GGML_TYPE_Q2_K:    fn = dequantize_row_q2_K_cuda;                               break;
        case GGML_TYPE_Q3_K:    fn = dequantize_row_q3_K_cuda;                               break;
        case GGML_TYPE_Q4_K:    fn = dequantize_row_q4_K_cuda;                               break;
        case GGML_TYPE_Q5_K:    fn = dequantize_row_q5_K_cuda;                               break;
        case GGML_TYPE_Q6_K:    fn = dequantize_row_q6_K_cuda;                               break;

        // i-quants
        case GGML_TYPE_IQ2_XXS: fn = dequantize_row_iq2_xxs_cuda;                            break;
        case GGML_TYPE_IQ2_XS:  fn = dequantize_row_iq2_xs_cuda;                             break;
        case GGML_TYPE_IQ2_S:   fn = dequantize_row_iq2_s_cuda;                              break;
        case GGML_TYPE_IQ3_XXS: fn = dequantize_row_iq3_xxs_cuda;                            break;
        case GGML_TYPE_IQ1_S:   fn = dequantize_row_iq1_s_cuda;                              break;
        case GGML_TYPE_IQ1_M:   fn = dequantize_row_iq1_m_cuda;                              break;
        case GGML_TYPE_IQ4_NL:  fn = dequantize_row_iq4_nl_cuda;                             break;
        case GGML_TYPE_IQ4_XS:  fn = dequantize_row_iq4_xs_cuda;                             break;
        case GGML_TYPE_IQ3_S:   fn = dequantize_row_iq3_s_cuda;                              break;

        // plain float source
        case GGML_TYPE_F32:     fn = convert_unary_cuda<float>;                              break;

        default:
            break;
    }

    return fn;
}
|
|
|
|
// Returns the host launcher that converts data of the given ggml type to fp32,
// or nullptr when the type has no conversion here.
to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
    to_fp32_cuda_t fn = nullptr;

    switch (type) {
        // simple block formats
        case GGML_TYPE_Q4_0:    fn = dequantize_row_q4_0_cuda;                               break;
        case GGML_TYPE_Q4_1:    fn = dequantize_row_q4_1_cuda;                               break;
        case GGML_TYPE_Q5_0:    fn = dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;   break;
        case GGML_TYPE_Q5_1:    fn = dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;   break;
        case GGML_TYPE_Q8_0:    fn = dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;   break;

        // k-quants
        case GGML_TYPE_Q2_K:    fn = dequantize_row_q2_K_cuda;                               break;
        case GGML_TYPE_Q3_K:    fn = dequantize_row_q3_K_cuda;                               break;
        case GGML_TYPE_Q4_K:    fn = dequantize_row_q4_K_cuda;                               break;
        case GGML_TYPE_Q5_K:    fn = dequantize_row_q5_K_cuda;                               break;
        case GGML_TYPE_Q6_K:    fn = dequantize_row_q6_K_cuda;                               break;

        // i-quants
        case GGML_TYPE_IQ2_XXS: fn = dequantize_row_iq2_xxs_cuda;                            break;
        case GGML_TYPE_IQ2_XS:  fn = dequantize_row_iq2_xs_cuda;                             break;
        case GGML_TYPE_IQ2_S:   fn = dequantize_row_iq2_s_cuda;                              break;
        case GGML_TYPE_IQ3_XXS: fn = dequantize_row_iq3_xxs_cuda;                            break;
        case GGML_TYPE_IQ1_S:   fn = dequantize_row_iq1_s_cuda;                              break;
        case GGML_TYPE_IQ1_M:   fn = dequantize_row_iq1_m_cuda;                              break;
        case GGML_TYPE_IQ4_NL:  fn = dequantize_row_iq4_nl_cuda;                             break;
        case GGML_TYPE_IQ4_XS:  fn = dequantize_row_iq4_xs_cuda;                             break;
        case GGML_TYPE_IQ3_S:   fn = dequantize_row_iq3_s_cuda;                              break;

        // half source
        case GGML_TYPE_F16:     fn = convert_unary_cuda<half>;                               break;

        default:
            break;
    }

    return fn;
}
|