mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-05-01 22:54:53 +02:00
cuda: Add Q5_1, Q5_0, Q4_1 and Q4_0 to F32 conversion support. (llama/12000)
This commit is contained in:
parent
b1385e9aa9
commit
98dab49b9a
@ -1,4 +1,5 @@
|
|||||||
#include "cpy.cuh"
|
#include "cpy.cuh"
|
||||||
|
#include "dequantize.cuh"
|
||||||
|
|
||||||
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
|
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
|
||||||
|
|
||||||
@ -82,13 +83,14 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
|
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
|
||||||
const block_q8_0 * xi = (const block_q8_0 *) cxi;
|
float * cdstf = (float *)(cdsti);
|
||||||
float * dsti = (float *) cdsti;
|
|
||||||
|
|
||||||
const float d = (float)xi->d;
|
#pragma unroll
|
||||||
|
for (int j = 0; j < QK8_0; j += 2) {
|
||||||
for (int j = 0; j < QK8_0; j++) {
|
dfloat2 dq;
|
||||||
dsti[j] = xi->qs[j] * d;
|
dequantize_q8_0(cxi, 0, j, dq);
|
||||||
|
*(cdstf + j) = dq.x;
|
||||||
|
*(cdstf + j + 1) = dq.y;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -225,6 +227,18 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
|
|||||||
memcpy(dsti->qh, &qh, sizeof(qh));
|
memcpy(dsti->qh, &qh, sizeof(qh));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<dequantize_kernel_t dequant, int qk>
|
||||||
|
static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
|
||||||
|
float * cdstf = (float *)(cdsti);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < qk/2; j++) {
|
||||||
|
dfloat2 dq;
|
||||||
|
dequant(cxi, 0, j, dq);
|
||||||
|
*(cdstf + j) = dq.x;
|
||||||
|
*(cdstf + j + qk/2) = dq.y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
|
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
|
||||||
if (x <= val[0]) return 0;
|
if (x <= val[0]) return 0;
|
||||||
@ -387,6 +401,19 @@ static void ggml_cpy_f32_q4_0_cuda(
|
|||||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_cpy_q4_0_f32_cuda(
|
||||||
|
const char * cx, char * cdst, const int ne,
|
||||||
|
const int ne00, const int ne01, const int ne02,
|
||||||
|
const int nb00, const int nb01, const int nb02,
|
||||||
|
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||||
|
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
const int num_blocks = ne;
|
||||||
|
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
|
||||||
|
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||||
|
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_cpy_f32_q4_1_cuda(
|
static void ggml_cpy_f32_q4_1_cuda(
|
||||||
const char * cx, char * cdst, const int ne,
|
const char * cx, char * cdst, const int ne,
|
||||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||||
@ -398,6 +425,19 @@ static void ggml_cpy_f32_q4_1_cuda(
|
|||||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_cpy_q4_1_f32_cuda(
|
||||||
|
const char * cx, char * cdst, const int ne,
|
||||||
|
const int ne00, const int ne01, const int ne02,
|
||||||
|
const int nb00, const int nb01, const int nb02,
|
||||||
|
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||||
|
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
const int num_blocks = ne;
|
||||||
|
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
|
||||||
|
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||||
|
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_cpy_f32_q5_0_cuda(
|
static void ggml_cpy_f32_q5_0_cuda(
|
||||||
const char * cx, char * cdst, const int ne,
|
const char * cx, char * cdst, const int ne,
|
||||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||||
@ -409,6 +449,19 @@ static void ggml_cpy_f32_q5_0_cuda(
|
|||||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_cpy_q5_0_f32_cuda(
|
||||||
|
const char * cx, char * cdst, const int ne,
|
||||||
|
const int ne00, const int ne01, const int ne02,
|
||||||
|
const int nb00, const int nb01, const int nb02,
|
||||||
|
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||||
|
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
const int num_blocks = ne;
|
||||||
|
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
|
||||||
|
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||||
|
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_cpy_f32_q5_1_cuda(
|
static void ggml_cpy_f32_q5_1_cuda(
|
||||||
const char * cx, char * cdst, const int ne,
|
const char * cx, char * cdst, const int ne,
|
||||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||||
@ -420,6 +473,19 @@ static void ggml_cpy_f32_q5_1_cuda(
|
|||||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_cpy_q5_1_f32_cuda(
|
||||||
|
const char * cx, char * cdst, const int ne,
|
||||||
|
const int ne00, const int ne01, const int ne02,
|
||||||
|
const int nb00, const int nb01, const int nb02,
|
||||||
|
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||||
|
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
const int num_blocks = ne;
|
||||||
|
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
|
||||||
|
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||||
|
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_cpy_f32_iq4_nl_cuda(
|
static void ggml_cpy_f32_iq4_nl_cuda(
|
||||||
const char * cx, char * cdst, const int ne,
|
const char * cx, char * cdst, const int ne,
|
||||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||||
@ -488,14 +554,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
|||||||
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
|
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
||||||
|
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||||
|
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
|
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
||||||
|
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||||
|
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||||
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
|
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
||||||
|
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||||
|
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||||
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||||
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
|
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||||
|
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||||
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||||
@ -524,14 +601,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
|||||||
return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
|
return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||||
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
|
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
|
||||||
|
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
||||||
|
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||||
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
|
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
|
||||||
|
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
||||||
|
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||||
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
|
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
|
||||||
|
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
||||||
|
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||||
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
|
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||||
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
|
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
|
||||||
|
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||||
|
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
|
||||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||||
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
||||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||||
|
@ -3075,15 +3075,27 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user