mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-08-18 22:45:45 +02:00
ggml : add CUDA support for ggml_conv
This commit is contained in:
89
ggml-cuda.cu
89
ggml-cuda.cu
@@ -4476,6 +4476,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
|
|||||||
*dsti = __float2half(*xi);
|
*dsti = __float2half(*xi);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
|
||||||
|
const half * xi = (const half *) cxi;
|
||||||
|
half * dsti = (half *) cdsti;
|
||||||
|
|
||||||
|
*dsti = *xi;
|
||||||
|
}
|
||||||
|
|
||||||
template <cpy_kernel_t cpy_1>
|
template <cpy_kernel_t cpy_1>
|
||||||
static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
|
static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
|
||||||
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
|
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
|
||||||
@@ -4729,6 +4736,17 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
|
|||||||
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
|
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __global__ void im2col_f32_f16(const float* x, half* dst, int ofs0, int ofs1, int IW,int IH,int CHW,int s0,int s1,int p0,int p1,int d0,int d1) {
|
||||||
|
int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
|
||||||
|
int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
|
||||||
|
__syncthreads();
|
||||||
|
if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
|
||||||
|
int offset_dst = (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW;
|
||||||
|
int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
|
||||||
|
dst[offset_dst + (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z)] = __float2half(x[offset_src + iih * IW + iiw]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template<int qk, int qr, dequantize_kernel_t dq>
|
template<int qk, int qr, dequantize_kernel_t dq>
|
||||||
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
|
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
|
||||||
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
||||||
@@ -5618,6 +5636,16 @@ static void ggml_cpy_f32_f16_cuda(
|
|||||||
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
|
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_cpy_f16_f16_cuda(
|
||||||
|
const char * cx, char * cdst, const int ne,
|
||||||
|
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
|
||||||
|
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
|
||||||
|
|
||||||
|
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||||
|
cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||||
|
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
|
||||||
|
}
|
||||||
|
|
||||||
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
|
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
|
||||||
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
|
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
|
||||||
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
|
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
|
||||||
@@ -5701,6 +5729,16 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
|
|||||||
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
|
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void im2col_f32_f16_cuda(const float* x, half* dst,
|
||||||
|
int OH, int IW, int IH,
|
||||||
|
int OW, int IC,
|
||||||
|
int KH, int KW, int N, int ofs0, int ofs1,
|
||||||
|
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
|
||||||
|
dim3 block_nums(IC, OH, OW);
|
||||||
|
dim3 block_dims(N, KH, KW);
|
||||||
|
im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
|
||||||
|
}
|
||||||
|
|
||||||
// buffer pool for cuda
|
// buffer pool for cuda
|
||||||
#define MAX_CUDA_BUFFERS 256
|
#define MAX_CUDA_BUFFERS 256
|
||||||
|
|
||||||
@@ -6483,7 +6521,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
|||||||
src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
|
src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
|
||||||
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
|
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
|
||||||
}
|
}
|
||||||
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
|
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
|
||||||
size_t dst_f16_as = 0;
|
size_t dst_f16_as = 0;
|
||||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
|
half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
|
||||||
|
|
||||||
@@ -6659,6 +6697,45 @@ inline void ggml_cuda_op_alibi(
|
|||||||
(void) src1_dd;
|
(void) src1_dd;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline void ggml_cuda_op_im2col(
|
||||||
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
|
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||||
|
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
||||||
|
|
||||||
|
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
||||||
|
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
|
||||||
|
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
|
||||||
|
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
|
||||||
|
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
|
||||||
|
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
|
||||||
|
|
||||||
|
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
|
||||||
|
|
||||||
|
const int64_t N = src1->ne[is_2D ? 3 : 2];
|
||||||
|
const int64_t IC = src1->ne[is_2D ? 2 : 1];
|
||||||
|
const int64_t IH = is_2D ? src1->ne[1] : 1;
|
||||||
|
const int64_t IW = src1->ne[0];
|
||||||
|
|
||||||
|
const int64_t KH = is_2D ? src0->ne[1] : 1;
|
||||||
|
const int64_t KW = src0->ne[0];
|
||||||
|
|
||||||
|
const int64_t OH = is_2D ? dst->ne[2] : 1;
|
||||||
|
const int64_t OW = dst->ne[1];
|
||||||
|
|
||||||
|
|
||||||
|
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
|
||||||
|
OH, IW, IH, OW, IC, KH, KW, N,
|
||||||
|
src1->nb[is_2D ? 3 : 2] / 4, // nb is byte offset, src is type float32
|
||||||
|
src1->nb[is_2D ? 2 : 1] / 4, // nb is byte offset, src is type float32
|
||||||
|
s0, s1, p0, p1, d0, d1, main_stream);
|
||||||
|
|
||||||
|
(void) src0;
|
||||||
|
(void) src0_dd;
|
||||||
|
}
|
||||||
|
|
||||||
inline void ggml_cuda_op_diag_mask_inf(
|
inline void ggml_cuda_op_diag_mask_inf(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||||
@@ -7549,6 +7626,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
|
|||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||||
ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
|
ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
|
||||||
ne10, ne11, nb10, nb11, nb12, main_stream);
|
ne10, ne11, nb10, nb11, nb12, main_stream);
|
||||||
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||||
|
ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
|
||||||
|
ne10, ne11, nb10, nb11, nb12, main_stream);
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
|
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
|
||||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||||
@@ -7580,6 +7660,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
|
|||||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
|
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
(void) src0;
|
(void) src0;
|
||||||
(void) src1;
|
(void) src1;
|
||||||
@@ -7943,6 +8027,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
|||||||
case GGML_OP_ALIBI:
|
case GGML_OP_ALIBI:
|
||||||
func = ggml_cuda_alibi;
|
func = ggml_cuda_alibi;
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_IM2COL:
|
||||||
|
func = ggml_cuda_im2col;
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
19
ggml.h
19
ggml.h
@@ -403,13 +403,8 @@ extern "C" {
|
|||||||
GGML_OP_ROPE_BACK,
|
GGML_OP_ROPE_BACK,
|
||||||
GGML_OP_ALIBI,
|
GGML_OP_ALIBI,
|
||||||
GGML_OP_CLAMP,
|
GGML_OP_CLAMP,
|
||||||
GGML_OP_CONV_1D,
|
|
||||||
GGML_OP_CONV_1D_STAGE_0, // internal
|
|
||||||
GGML_OP_CONV_1D_STAGE_1, // internal
|
|
||||||
GGML_OP_CONV_TRANSPOSE_1D,
|
GGML_OP_CONV_TRANSPOSE_1D,
|
||||||
GGML_OP_CONV_2D,
|
GGML_OP_IM2COL,
|
||||||
GGML_OP_CONV_2D_STAGE_0, // internal
|
|
||||||
GGML_OP_CONV_2D_STAGE_1, // internal
|
|
||||||
GGML_OP_CONV_TRANSPOSE_2D,
|
GGML_OP_CONV_TRANSPOSE_2D,
|
||||||
GGML_OP_POOL_1D,
|
GGML_OP_POOL_1D,
|
||||||
GGML_OP_POOL_2D,
|
GGML_OP_POOL_2D,
|
||||||
@@ -1398,6 +1393,18 @@ extern "C" {
|
|||||||
float min,
|
float min,
|
||||||
float max);
|
float max);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_im2col(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b,
|
||||||
|
int s0,
|
||||||
|
int s1,
|
||||||
|
int p0,
|
||||||
|
int p1,
|
||||||
|
int d0,
|
||||||
|
int d1,
|
||||||
|
bool is_2D);
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_conv_1d(
|
GGML_API struct ggml_tensor * ggml_conv_1d(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
|
21
whisper.cpp
21
whisper.cpp
@@ -588,7 +588,6 @@ struct whisper_kv_cache {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct whisper_model_data {
|
struct whisper_model_data {
|
||||||
ggml_backend_buffer_t buffer_conv; // TODO: tmp until GPU support for conv
|
|
||||||
ggml_backend_buffer_t buffer_main;
|
ggml_backend_buffer_t buffer_main;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -827,9 +826,8 @@ struct whisper_context {
|
|||||||
return backend_gpu ? backend_gpu : backend_cpu;
|
return backend_gpu ? backend_gpu : backend_cpu;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: always on CPU until we have a GPU support for conv
|
|
||||||
ggml_backend_t backend_conv() const {
|
ggml_backend_t backend_conv() const {
|
||||||
return backend_cpu;
|
return backend_gpu ? backend_gpu : backend_cpu;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_backend_t backend_main() const {
|
ggml_backend_t backend_main() const {
|
||||||
@@ -1408,31 +1406,20 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||||||
size_t size_main = 0;
|
size_t size_main = 0;
|
||||||
|
|
||||||
for (const auto & t : model.tensors) {
|
for (const auto & t : model.tensors) {
|
||||||
if (t.first.find("conv") != std::string::npos) {
|
|
||||||
size_conv += ggml_nbytes(t.second) + ggml_tensor_overhead();
|
|
||||||
} else {
|
|
||||||
size_main += ggml_nbytes(t.second) + ggml_tensor_overhead();
|
size_main += ggml_nbytes(t.second) + ggml_tensor_overhead();
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
model.data->buffer_conv = ggml_backend_alloc_buffer(wctx.backend_conv(), size_conv);
|
|
||||||
model.data->buffer_main = ggml_backend_alloc_buffer(wctx.backend_main(), size_main);
|
model.data->buffer_main = ggml_backend_alloc_buffer(wctx.backend_main(), size_main);
|
||||||
|
|
||||||
WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend_conv()), size_conv / 1024.0 / 1024.0);
|
|
||||||
WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend_main()), size_main / 1024.0 / 1024.0);
|
WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend_main()), size_main / 1024.0 / 1024.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_allocr * alloc_conv = ggml_allocr_new_from_buffer(model.data->buffer_conv);
|
|
||||||
ggml_allocr * alloc_main = ggml_allocr_new_from_buffer(model.data->buffer_main);
|
ggml_allocr * alloc_main = ggml_allocr_new_from_buffer(model.data->buffer_main);
|
||||||
|
|
||||||
// allocate tensors in the backend buffers
|
// allocate tensors in the backend buffers
|
||||||
{
|
{
|
||||||
for (const auto & t : model.tensors) {
|
for (const auto & t : model.tensors) {
|
||||||
if (t.first.find("conv") != std::string::npos) {
|
|
||||||
ggml_allocr_alloc(alloc_conv, t.second);
|
|
||||||
} else {
|
|
||||||
ggml_allocr_alloc(alloc_main, t.second);
|
ggml_allocr_alloc(alloc_main, t.second);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1496,9 +1483,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const bool is_conv = name.find("conv") != std::string::npos;
|
ggml_backend * backend = wctx.backend_main();
|
||||||
|
|
||||||
ggml_backend * backend = is_conv ? wctx.backend_conv() : wctx.backend_main();
|
|
||||||
|
|
||||||
//printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
|
//printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
|
||||||
|
|
||||||
@@ -1532,7 +1517,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_allocr_free(alloc_conv);
|
|
||||||
ggml_allocr_free(alloc_main);
|
ggml_allocr_free(alloc_main);
|
||||||
|
|
||||||
wctx.t_load_us = ggml_time_us() - t_start_us;
|
wctx.t_load_us = ggml_time_us() - t_start_us;
|
||||||
@@ -3273,7 +3257,6 @@ void whisper_free(struct whisper_context * ctx) {
|
|||||||
ggml_free(ctx->model.ctx);
|
ggml_free(ctx->model.ctx);
|
||||||
}
|
}
|
||||||
if (ctx->model.data) {
|
if (ctx->model.data) {
|
||||||
ggml_backend_buffer_free(ctx->model.data->buffer_conv);
|
|
||||||
ggml_backend_buffer_free(ctx->model.data->buffer_main);
|
ggml_backend_buffer_free(ctx->model.data->buffer_main);
|
||||||
|
|
||||||
delete ctx->model.data;
|
delete ctx->model.data;
|
||||||
|
Reference in New Issue
Block a user