ggml : add CUDA support for ggml_conv

Georgi Gerganov
2023-11-10 15:10:27 +02:00
parent c99e290a7f
commit 81506268ba
4 changed files with 172 additions and 982 deletions

ggml-cuda.cu

@@ -4476,6 +4476,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
     *dsti = __float2half(*xi);
 }
 
+static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
+    const half * xi = (const half *) cxi;
+    half * dsti = (half *) cdsti;
+
+    *dsti = *xi;
+}
+
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
                                    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@@ -4729,6 +4736,17 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
     dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }
 
+static __global__ void im2col_f32_f16(const float * x, half * dst, int ofs0, int ofs1, int IW, int IH, int CHW, int s0, int s1, int p0, int p1, int d0, int d1) {
+    int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
+    int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
+
+    __syncthreads();
+    if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
+        int offset_dst = (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW;
+        int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
+        dst[offset_dst + (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z)] = __float2half(x[offset_src + iih * IW + iiw]);
+    }
+}
+
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
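For orientation, here is a CPU-side paraphrase of the index mapping the new im2col kernel implements (written for this note, not code from the commit): the destination is an (N*OH*OW) x (IC*KH*KW) matrix of patches, with rows indexed by (n, oh, ow) and columns by (ic, kh, kw); out-of-range (padding) positions are skipped rather than written. Strides assume contiguous F32 input.

    // Reference im2col (illustrative sketch; F32 output instead of F16).
    static void im2col_ref(const float * x, float * dst,
                           int N, int IC, int IH, int IW, int OH, int OW,
                           int KH, int KW, int s0, int s1, int p0, int p1, int d0, int d1) {
        const int CHW = IC * KH * KW;
        for (int in = 0; in < N;  in++) {
        for (int ic = 0; ic < IC; ic++) {
        for (int oh = 0; oh < OH; oh++) {
        for (int ow = 0; ow < OW; ow++) {
        for (int kh = 0; kh < KH; kh++) {
        for (int kw = 0; kw < KW; kw++) {
            const int iiw = ow * s0 + kw * d0 - p0; // same formulas as the kernel
            const int iih = oh * s1 + kh * d1 - p1;
            if (iih >= 0 && iih < IH && iiw >= 0 && iiw < IW) {
                const int row = (in * OH + oh) * OW + ow; // (n, oh, ow)
                const int col = (ic * KH + kh) * KW + kw; // (ic, kh, kw)
                dst[row * CHW + col] = x[((in * IC + ic) * IH + iih) * IW + iiw];
            }
        }}}}}}
    }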
@@ -5618,6 +5636,16 @@ static void ggml_cpy_f32_f16_cuda(
         (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
 }
 
+static void ggml_cpy_f16_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
 static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
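The F16-to-F16 copy path becomes reachable once im2col lands, since im2col produces F16 tensors on the device and graph-level copies can then see F16 on both sides. A hypothetical fragment that would exercise this launcher (the helper, names, and shapes are illustrative, and I am assuming GGML_OP_CONT lowers to the same CUDA copy path as GGML_OP_CPY):

    // Hypothetical helper (not from the commit): produce a transposed,
    // contiguous copy of the F16 im2col result -- the ggml_cont goes
    // through the new F16 -> F16 device copy.
    static struct ggml_tensor * im2col_transposed(
            struct ggml_context * ctx,
            struct ggml_tensor  * w,     // F16 conv kernel
            struct ggml_tensor  * inp) { // F32 input data
        struct ggml_tensor * col = ggml_im2col(ctx, w, inp, 1, 0, 1, 0, 1, 0, false); // F16
        return ggml_cont(ctx, ggml_permute(ctx, col, 1, 0, 2, 3));  // F16 -> F16 copy
    }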
@@ -5701,6 +5729,16 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }
 
+static void im2col_f32_f16_cuda(const float * x, half * dst,
+    int OH, int IW, int IH,
+    int OW, int IC,
+    int KH, int KW, int N, int ofs0, int ofs1,
+    int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
+
+    dim3 block_nums(IC, OH, OW);
+    dim3 block_dims(N, KH, KW);
+    im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+}
+
 // buffer pool for cuda
 #define MAX_CUDA_BUFFERS 256
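One observation about this launch shape (mine, not stated in the commit): each block covers one (ic, oh, ow) triple and each thread one (n, kh, kw) triple, so a block holds N*KH*KW threads. CUDA caps a block at 1024 threads, which is comfortable for Whisper's conv layers (N = 1, KH = 1, KW = 3) but would need a different decomposition for large batch or kernel sizes. A hypothetical guard making that limit explicit:

    // Not in the commit: the block size is N*KH*KW threads,
    // and CUDA limits a block to 1024 threads.
    GGML_ASSERT(N * KH * KW <= 1024);
    dim3 block_nums(IC, OH, OW); // blockIdx  = (ic, oh, ow)
    dim3 block_dims(N, KH, KW);  // threadIdx = (n,  kh, kw)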
@@ -6483,7 +6521,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
             src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
         }
-        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
+        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
 
         size_t dst_f16_as = 0;
         half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
@@ -6659,6 +6697,45 @@ inline void ggml_cuda_op_alibi(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_im2col(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+    const int64_t N  = src1->ne[is_2D ? 3 : 2];
+    const int64_t IC = src1->ne[is_2D ? 2 : 1];
+    const int64_t IH = is_2D ? src1->ne[1] : 1;
+    const int64_t IW = src1->ne[0];
+
+    const int64_t KH = is_2D ? src0->ne[1] : 1;
+    const int64_t KW = src0->ne[0];
+
+    const int64_t OH = is_2D ? dst->ne[2] : 1;
+    const int64_t OW = dst->ne[1];
+
+    im2col_f32_f16_cuda(src1_dd, (half *) dst_dd,
+        OH, IW, IH, OW, IC, KH, KW, N,
+        src1->nb[is_2D ? 3 : 2] / 4, // nb is byte offset, src is type float32
+        src1->nb[is_2D ? 2 : 1] / 4, // nb is byte offset, src is type float32
+        s0, s1, p0, p1, d0, d1, main_stream);
+
+    (void) src0;
+    (void) src0_dd;
+}
+
 inline void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
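Note that OH and OW are simply read back from dst->ne: the output extents were already computed on the host when the im2col node was created, using the standard dilated-convolution size formula (ggml implements this in an internal helper in ggml.c, ggml_calc_conv_output_size, if memory serves). For reference:

    // Standard dilated-convolution output size.
    static int64_t conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
        return (ins + 2*p - d*(ks - 1) - 1)/s + 1;
    }
    // so that, for the 2D case:
    //   OW = conv_output_size(IW, KW, s0, p0, d0);
    //   OH = conv_output_size(IH, KH, s1, p1, d1);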
@@ -7549,6 +7626,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
                               ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+                              ne10, ne11, nb10, nb11, nb12, main_stream);
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -7580,6 +7660,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }
 
+void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
+}
+
 static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
@@ -7943,6 +8027,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_ALIBI:
             func = ggml_cuda_alibi;
             break;
+        case GGML_OP_IM2COL:
+            func = ggml_cuda_im2col;
+            break;
         default:
             return false;
     }

ggml.c (1025 lines changed)

File diff suppressed because it is too large.

ggml.h (19 lines changed)

@@ -403,13 +403,8 @@ extern "C" {
         GGML_OP_ROPE_BACK,
         GGML_OP_ALIBI,
         GGML_OP_CLAMP,
-        GGML_OP_CONV_1D,
-        GGML_OP_CONV_1D_STAGE_0,  // internal
-        GGML_OP_CONV_1D_STAGE_1,  // internal
         GGML_OP_CONV_TRANSPOSE_1D,
-        GGML_OP_CONV_2D,
-        GGML_OP_CONV_2D_STAGE_0,  // internal
-        GGML_OP_CONV_2D_STAGE_1,  // internal
+        GGML_OP_IM2COL,
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
@@ -1398,6 +1393,18 @@ extern "C" {
             float                 min,
             float                 max);
 
+    GGML_API struct ggml_tensor * ggml_im2col(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   s0,
+            int                   s1,
+            int                   p0,
+            int                   p1,
+            int                   d0,
+            int                   d1,
+            bool                  is_2D);
+
     GGML_API struct ggml_tensor * ggml_conv_1d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
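With im2col exposed as its own op, a convolution decomposes into im2col followed by a single matrix multiplication, which is what lets the whole conv run on the GPU. A rough sketch of the 1D case, simplified from what ggml_conv_1d now does in ggml.c (a is the kernel [K, IC, OC], b is the data [L, IC, N]; treat this as a paraphrase, not the exact implementation):

    static struct ggml_tensor * conv_1d_via_im2col(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,   // kernel, F16
            struct ggml_tensor  * b,   // data,   F32
            int s0, int p0, int d0) {
        // gather patches: F16 result with shape [IC*K, OL, N]
        struct ggml_tensor * col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false);

        // one matmul covers all output channels: [IC*K, N*OL] x [IC*K, OC]
        struct ggml_tensor * res = ggml_mul_mat(ctx,
                ggml_reshape_2d(ctx, col, col->ne[0], col->ne[2]*col->ne[1]),
                ggml_reshape_2d(ctx, a,   a->ne[0]*a->ne[1], a->ne[2]));

        return ggml_reshape_3d(ctx, res, col->ne[1], a->ne[2], col->ne[2]); // [OL, OC, N]
    }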

whisper.cpp

@@ -588,7 +588,6 @@ struct whisper_kv_cache {
 };
 
 struct whisper_model_data {
-    ggml_backend_buffer_t buffer_conv; // TODO: tmp until GPU support for conv
     ggml_backend_buffer_t buffer_main;
 };
@@ -827,9 +826,8 @@ struct whisper_context {
         return backend_gpu ? backend_gpu : backend_cpu;
     }
 
-    // TODO: always on CPU until we have a GPU support for conv
     ggml_backend_t backend_conv() const {
-        return backend_cpu;
+        return backend_gpu ? backend_gpu : backend_cpu;
     }
 
     ggml_backend_t backend_main() const {
@@ -1408,31 +1406,20 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         size_t size_main = 0;
 
         for (const auto & t : model.tensors) {
-            if (t.first.find("conv") != std::string::npos) {
-                size_conv += ggml_nbytes(t.second) + ggml_tensor_overhead();
-            } else {
-                size_main += ggml_nbytes(t.second) + ggml_tensor_overhead();
-            }
+            size_main += ggml_nbytes(t.second) + ggml_tensor_overhead();
         }
 
-        model.data->buffer_conv = ggml_backend_alloc_buffer(wctx.backend_conv(), size_conv);
         model.data->buffer_main = ggml_backend_alloc_buffer(wctx.backend_main(), size_main);
 
-        WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend_conv()), size_conv / 1024.0 / 1024.0);
         WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend_main()), size_main / 1024.0 / 1024.0);
     }
 
-    ggml_allocr * alloc_conv = ggml_allocr_new_from_buffer(model.data->buffer_conv);
     ggml_allocr * alloc_main = ggml_allocr_new_from_buffer(model.data->buffer_main);
 
     // allocate tensors in the backend buffers
     {
         for (const auto & t : model.tensors) {
-            if (t.first.find("conv") != std::string::npos) {
-                ggml_allocr_alloc(alloc_conv, t.second);
-            } else {
-                ggml_allocr_alloc(alloc_main, t.second);
-            }
+            ggml_allocr_alloc(alloc_main, t.second);
         }
     }
@@ -1496,9 +1483,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             return false;
         }
 
-        const bool is_conv = name.find("conv") != std::string::npos;
-
-        ggml_backend * backend = is_conv ? wctx.backend_conv() : wctx.backend_main();
+        ggml_backend * backend = wctx.backend_main();
 
         //printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
@@ -1532,7 +1517,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
-    ggml_allocr_free(alloc_conv);
     ggml_allocr_free(alloc_main);
 
     wctx.t_load_us = ggml_time_us() - t_start_us;
@@ -3273,7 +3257,6 @@ void whisper_free(struct whisper_context * ctx) {
         ggml_free(ctx->model.ctx);
     }
 
     if (ctx->model.data) {
-        ggml_backend_buffer_free(ctx->model.data->buffer_conv);
         ggml_backend_buffer_free(ctx->model.data->buffer_main);
 
         delete ctx->model.data;