From 270b1e48dbdcb68679b86ccf073455c506907809 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 15 Nov 2023 15:52:06 +0200 Subject: [PATCH] cuda : sync llama.cpp fixes --- ggml-cuda.cu | 306 +++++++++++++++++++++++++++++++-------------------- ggml-cuda.h | 5 + whisper.cpp | 2 +- 3 files changed, 194 insertions(+), 119 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 058011a4..c0c9edd5 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -39,7 +39,6 @@ #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess -#define cudaDeviceGetMemPool hipDeviceGetMemPool #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t @@ -49,7 +48,6 @@ #define cudaEvent_t hipEvent_t #define cudaEventDestroy hipEventDestroy #define cudaFree hipFree -#define cudaFreeAsync hipFreeAsync #define cudaFreeHost hipHostFree #define cudaGetDevice hipGetDevice #define cudaGetDeviceCount hipGetDeviceCount @@ -57,7 +55,6 @@ #define cudaGetErrorString hipGetErrorString #define cudaGetLastError hipGetLastError #define cudaMalloc hipMalloc -#define cudaMallocFromPoolAsync hipMallocFromPoolAsync #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) #define cudaMemcpy hipMemcpy #define cudaMemcpy2DAsync hipMemcpy2DAsync @@ -66,9 +63,6 @@ #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost #define cudaMemcpyHostToDevice hipMemcpyHostToDevice #define cudaMemcpyKind hipMemcpyKind -#define cudaMemPool_t hipMemPool_t -#define cudaMemPoolAttrReleaseThreshold hipMemPoolAttrReleaseThreshold -#define cudaMemPoolSetAttribute hipMemPoolSetAttribute #define cudaMemset hipMemset #define cudaMemsetAsync hipMemsetAsync #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize @@ -94,6 +88,8 @@ #define CC_OFFSET_AMD 1000000 #define CC_RDNA2 (CC_OFFSET_AMD + 1030) +#define GGML_CUDA_MAX_NODES 8192 + // define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication // on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant // for large computational tasks. the drawback is that this requires some extra amount of VRAM: @@ -188,11 +184,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); do { \ cudaError_t err_ = (err); \ if (err_ != cudaSuccess) { \ - int dev_id; \ - cudaGetDevice(&dev_id); \ + int id; \ + cudaGetDevice(&id); \ fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ cudaGetErrorString(err_)); \ - fprintf(stderr, "current device: %d\n", dev_id); \ + fprintf(stderr, "current device: %d\n", id); \ exit(1); \ } \ } while (0) @@ -202,11 +198,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); do { \ cublasStatus_t err_ = (err); \ if (err_ != CUBLAS_STATUS_SUCCESS) { \ - int dev_id; \ - cudaGetDevice(&dev_id); \ + int id; \ + cudaGetDevice(&id); \ fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \ err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \ - fprintf(stderr, "current device: %d\n", dev_id); \ + fprintf(stderr, "current device: %d\n", id); \ exit(1); \ } \ } while (0) @@ -440,6 +436,8 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define CUDA_MUL_BLOCK_SIZE 256 #define CUDA_GELU_BLOCK_SIZE 256 #define CUDA_SILU_BLOCK_SIZE 256 +#define CUDA_RELU_BLOCK_SIZE 256 +#define CUDA_SQR_BLOCK_SIZE 256 #define CUDA_CPY_BLOCK_SIZE 32 #define CUDA_SCALE_BLOCK_SIZE 256 #define CUDA_CLAMP_BLOCK_SIZE 256 @@ -472,7 +470,6 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA #define MAX_STREAMS 8 static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr }; -static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr }; struct ggml_tensor_extra_gpu { void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors @@ -561,6 +558,24 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) { dst[i] = x[i] / (1.0f + expf(-x[i])); } +static __global__ void relu_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = fmaxf(x[i], 0); +} + +static __global__ void sqr_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = x[i] * x[i]; +} + static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { @@ -990,7 +1005,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - const int row = blockIdx.y*blockDim.y + threadIdx.y; + const int row = blockIdx.x*blockDim.y + threadIdx.y; if (row > nrows) return; const int num_blocks_per_row = ncols / QK_K; @@ -1094,7 +1109,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - const int row = blockIdx.y*blockDim.y + threadIdx.y; + const int row = blockIdx.x*blockDim.y + threadIdx.y; if (row > nrows) return; const int num_blocks_per_row = ncols / QK_K; @@ -1198,7 +1213,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - const int row = blockIdx.y*blockDim.y + threadIdx.y; + const int row = blockIdx.x*blockDim.y + threadIdx.y; if (row > nrows) return; const int num_blocks_per_row = ncols / QK_K; const int ib0 = row*num_blocks_per_row; @@ -1452,7 +1467,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - const int row = blockIdx.y*blockDim.y + threadIdx.y; + const int row = blockIdx.x*blockDim.y + threadIdx.y; if (row > nrows) return; const int num_blocks_per_row = ncols / QK_K; @@ -4262,7 +4277,7 @@ template static __global__ void template static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) { - const int row = blockIdx.y*blockDim.y + threadIdx.y; + const int row = blockIdx.x*blockDim.y + threadIdx.y; if (row >= nrows) { return; @@ -4302,7 +4317,7 @@ template static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { // qk = quantized weights per x block // qr = number of quantized weights per data value in x block - const int row = blockIdx.y*blockDim.y + threadIdx.y; + const int row = blockIdx.x*blockDim.y + threadIdx.y; if (row >= nrows) { return; @@ -4741,7 +4756,7 @@ static __global__ void im2col_f32_f16( int ofs0, int ofs1, int IW, int IH, int CHW, int s0, int s1, int p0, int p1, int d0, int d1) { const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0; - const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1; + const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1; const int offset_dst = (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW + @@ -4793,6 +4808,16 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_ silu_f32<<>>(x, dst, k); } +static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; + relu_f32<<>>(x, dst, k); +} + +static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE; + sqr_f32<<>>(x, dst, k); +} + static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { @@ -4901,7 +4926,8 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); @@ -4910,7 +4936,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); @@ -4919,7 +4945,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); @@ -4928,7 +4954,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); @@ -4937,7 +4963,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); @@ -4947,7 +4973,7 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f GGML_ASSERT(ncols % QK_K == 0); const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(32, ny, 1); dequantize_mul_mat_vec_q2_k<<>>(vx, y, dst, ncols, nrows); } @@ -4956,7 +4982,7 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f GGML_ASSERT(ncols % QK_K == 0); const int ny = 2 / K_QUANTS_PER_ITERATION; const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(32, ny, 1); dequantize_mul_mat_vec_q3_k<<>>(vx, y, dst, ncols, nrows); } @@ -4965,7 +4991,7 @@ static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, f GGML_ASSERT(ncols % QK_K == 0); const int ny = 2 / K_QUANTS_PER_ITERATION; const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(32, ny, 1); dequantize_mul_mat_vec_q4_k<<>>(vx, y, dst, ncols, nrows); } @@ -4980,7 +5006,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f GGML_ASSERT(ncols % QK_K == 0); const int ny = 2 / K_QUANTS_PER_ITERATION; const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(32, ny, 1); dequantize_mul_mat_vec_q6_k<<>>(vx, y, dst, ncols, nrows); } @@ -4988,7 +5014,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK4_0 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); @@ -4997,7 +5023,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK4_1 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); @@ -5006,7 +5032,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK5_0 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); @@ -5015,7 +5041,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK5_1 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); @@ -5024,7 +5050,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK8_0 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); @@ -5033,7 +5059,7 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); @@ -5042,7 +5068,7 @@ static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); @@ -5051,7 +5077,7 @@ static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); @@ -5060,7 +5086,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); @@ -5069,7 +5095,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); @@ -5088,7 +5114,7 @@ static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cu static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); dequantize_mul_mat_vec<1, 1, convert_f16> <<>>(vx, y, dst, ncols, nrows); @@ -5825,16 +5851,6 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { return ptr; } -static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) { - if (g_cudaMemPools[id] == nullptr) { - return ggml_cuda_pool_malloc(size, actual_size); - } - void *ptr; - CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream)); - *actual_size = size; - return ptr; -} - static void ggml_cuda_pool_free(void * ptr, size_t size) { scoped_spin_lock lock(g_cuda_pool_lock); int id; @@ -5852,12 +5868,10 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) { CUDA_CHECK(cudaFree(ptr)); } +static bool g_cublas_loaded = false; -static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) { - if (g_cudaMemPools[id] == nullptr) { - return ggml_cuda_pool_free(ptr, actual_size); - } - CUDA_CHECK(cudaFreeAsync(ptr, stream)); +bool ggml_cublas_loaded(void) { + return g_cublas_loaded; } void ggml_init_cublas() { @@ -5872,7 +5886,12 @@ void ggml_init_cublas() { CUDA_CHECK(cudaDeviceSynchronize()); #endif - CUDA_CHECK(cudaGetDeviceCount(&g_device_count)); + if (cudaGetDeviceCount(&g_device_count) != cudaSuccess) { + initialized = true; + g_cublas_loaded = false; + return; + } + GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES); int64_t total_vram = 0; #if defined(GGML_CUDA_FORCE_MMQ) @@ -5914,19 +5933,13 @@ void ggml_init_cublas() { // create cublas handle CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id])); CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH)); - - // configure memory pool - cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id); - if (err == cudaSuccess) { - size_t treshold = UINT64_MAX; - CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold)); - } } // configure logging to stdout // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); initialized = true; + g_cublas_loaded = true; } } @@ -6193,6 +6206,34 @@ inline void ggml_cuda_op_silu( (void) src1_dd; } +inline void ggml_cuda_op_relu( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_sqr( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + sqr_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + inline void ggml_cuda_op_norm( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { @@ -6514,7 +6555,7 @@ inline void ggml_cuda_op_mul_mat_cublas( const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type); GGML_ASSERT(to_fp16_cuda != nullptr); size_t ne = row_diff*ne00; - src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream); + src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as); to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream); } const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16; @@ -6525,12 +6566,12 @@ inline void ggml_cuda_op_mul_mat_cublas( const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); GGML_ASSERT(to_fp16_cuda != nullptr); size_t ne = src1_ncols*ne10; - src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream); + src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as); to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream); } const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16; - size_t dst_f16_as = 0; - half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream); + size_t dst_as = 0; + half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as); const half alpha_f16 = 1.0f; const half beta_f16 = 0.0f; @@ -6548,15 +6589,14 @@ inline void ggml_cuda_op_mul_mat_cublas( const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream); - if (dst_f16_as != 0) { - ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream); - } + ggml_cuda_pool_free(dst_f16, dst_as); if (src0_as != 0) { - ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream); + ggml_cuda_pool_free(src0_as_f16, src0_as); } + if (src1_as != 0) { - ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream); + ggml_cuda_pool_free(src1_as_f16, src1_as); } } else { @@ -6566,7 +6606,7 @@ inline void ggml_cuda_op_mul_mat_cublas( if (src0->type != GGML_TYPE_F32) { const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); GGML_ASSERT(to_fp32_cuda != nullptr); - src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT + src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream); } const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32; @@ -6583,7 +6623,7 @@ inline void ggml_cuda_op_mul_mat_cublas( &beta, dst_dd_i, ldc)); if (src0_as != 0) { - ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream); + ggml_cuda_pool_free(src0_ddq_as_f32, src0_as); } } @@ -7008,6 +7048,8 @@ static void ggml_cuda_op_mul_mat( int64_t row_low[GGML_CUDA_MAX_DEVICES]; int64_t row_high[GGML_CUDA_MAX_DEVICES]; + int used_devices = 0; + for (int64_t id = 0; id < g_device_count; ++id) { // by default, use all rows row_low[id] = 0; @@ -7035,6 +7077,8 @@ static void ggml_cuda_op_mul_mat( continue; } + used_devices++; + const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device; const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device; @@ -7045,22 +7089,21 @@ static void ggml_cuda_op_mul_mat( src0_dd[id] = (char *) src0_extra->data_device[id]; } else { const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0); - src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream); + src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]); } if (src1_on_device && src1_is_contiguous) { src1_ddf[id] = (float *) src1_extra->data_device[id]; } else { - src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream); + src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]); } if (convert_src1_to_q8_1) { - const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs; - src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream); + src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]); if (src1_on_device && src1_is_contiguous) { quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream); - // CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaGetLastError()); } } @@ -7068,18 +7111,18 @@ static void ggml_cuda_op_mul_mat( dst_dd[id] = (float *) dst_extra->data_device[id]; } else { const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst); - dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id, stream); + dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]); } } // if multiple devices are used they need to wait for the main device // here an event is recorded that signals that the main device has finished calculating the input data - if (split && g_device_count > 1) { + if (split && used_devices > 1) { CUDA_CHECK(ggml_cuda_set_device(g_main_device)); CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0])); } - const int64_t src1_col_stride = split && g_device_count > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11; + const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11; for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) { const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0; const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride; @@ -7194,6 +7237,27 @@ static void ggml_cuda_op_mul_mat( } } + for (int64_t id = 0; id < g_device_count; ++id) { + if ((!split && id != g_main_device) || row_low[id] == row_high[id]) { + continue; + } + CUDA_CHECK(ggml_cuda_set_device(id)); + + // free buffers again when done + if (src0_as[id] > 0) { + ggml_cuda_pool_free(src0_dd[id], src0_as[id]); + } + if (src1_asf[id] > 0) { + ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]); + } + if (src1_asq[id] > 0) { + ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]); + } + if (dst_as[id] > 0) { + ggml_cuda_pool_free(dst_dd[id], dst_as[id]); + } + } + // main device waits for all other devices to be finished if (split && g_device_count > 1) { int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE; @@ -7201,6 +7265,9 @@ static void ggml_cuda_op_mul_mat( CUDA_CHECK(ggml_cuda_set_device(g_main_device)); for (int64_t id = 0; id < g_device_count; ++id) { + if (row_low[id] == row_high[id]) { + continue; + } for (int64_t is = 0; is < is_max; ++is) { CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0)); } @@ -7211,21 +7278,6 @@ static void ggml_cuda_op_mul_mat( CUDA_CHECK(ggml_cuda_set_device(g_main_device)); CUDA_CHECK(cudaDeviceSynchronize()); } - - for (int64_t id = 0; id < g_device_count; ++id) { - if (src0_as[id] > 0) { - ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]); - } - if (src1_asf[id] > 0) { - ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]); - } - if (src1_asq[id] > 0) { - ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]); - } - if (dst_as[id] > 0) { - ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]); - } - } } static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -7252,6 +7304,14 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu); } +static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu); +} + +static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr); +} + static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm); } @@ -7261,6 +7321,8 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src } bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { + if (!g_cublas_loaded) return false; + const int64_t ne10 = src1->ne[0]; const int64_t ne0 = dst->ne[0]; @@ -7412,11 +7474,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const GGML_ASSERT(to_fp16_cuda != nullptr); size_t src1_as = 0; - half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream); + half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as); to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream); size_t dst_as = 0; - half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream); + half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as); GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); @@ -7470,8 +7532,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const size_t ptrs_src_s = 0; size_t ptrs_dst_s = 0; - ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream); - ptrs_dst = ( void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream); + ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s); + ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s); dim3 block_dims(ne13, ne12); k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( @@ -7484,6 +7546,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const dst->nb[2], dst->nb[3], r2, r3); CUDA_CHECK(cudaGetLastError()); + CUBLAS_CHECK( cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, @@ -7495,30 +7558,29 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const CUBLAS_GEMM_DEFAULT_TENSOR_OP)); if (ptrs_src_s != 0) { - ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream); + ggml_cuda_pool_free(ptrs_src, ptrs_src_s); } if (ptrs_dst_s != 0) { - ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream); + ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s); } } #endif const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream); - if (src1_as != 0) { - ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream); - } - if (dst_as != 0) { - ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream); - } + + ggml_cuda_pool_free(src1_as_f16, src1_as); + ggml_cuda_pool_free(dst_f16, dst_as); } static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool all_on_device = - (src0->backend == GGML_BACKEND_GPU) && + (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && (src1->backend == GGML_BACKEND_GPU) && ( dst->backend == GGML_BACKEND_GPU); + const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; + int64_t min_compute_capability = INT_MAX; for (int64_t id = 0; id < g_device_count; ++id) { if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) { @@ -7540,13 +7602,13 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { // KQ single-batch ggml_cuda_mul_mat_vec_p021(src0, src1, dst); - } else if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + } else if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // KQV single-batch ggml_cuda_mul_mat_vec_nc(src0, src1, dst); - } else if (all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) { + } else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) { // KQ + KQV multi-batch ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst); } else if (src0->type == GGML_TYPE_F32) { @@ -7667,7 +7729,7 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi); } -void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col); } @@ -7782,11 +7844,11 @@ static size_t g_temp_tensor_extra_index = 0; static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() { if (g_temp_tensor_extras == nullptr) { - g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE]; + g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES]; } size_t alloc_index = g_temp_tensor_extra_index; - g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE; + g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES; ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index]; memset(extra, 0, sizeof(*extra)); @@ -7953,6 +8015,8 @@ void ggml_cuda_free_scratch() { } bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { + if (!g_cublas_loaded) return false; + ggml_cuda_func_t func; const bool any_on_device = tensor->backend == GGML_BACKEND_GPU || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) @@ -7995,6 +8059,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ case GGML_UNARY_OP_SILU: func = ggml_cuda_silu; break; + case GGML_UNARY_OP_RELU: + func = ggml_cuda_relu; + break; default: return false; } break; @@ -8013,6 +8080,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ case GGML_OP_SCALE: func = ggml_cuda_scale; break; + case GGML_OP_SQR: + func = ggml_cuda_sqr; + break; case GGML_OP_CLAMP: if (!any_on_device) { return false; @@ -8105,11 +8175,11 @@ struct ggml_backend_buffer_context_cuda { ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() { if (temp_tensor_extras == nullptr) { - temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE]; + temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES]; } size_t alloc_index = temp_tensor_extra_index; - temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE; + temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES; ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index]; memset(extra, 0, sizeof(*extra)); diff --git a/ggml-cuda.h b/ggml-cuda.h index 57adc9cf..528e66c3 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -17,7 +17,12 @@ extern "C" { #define GGML_CUDA_MAX_DEVICES 16 +// Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`. GGML_API void ggml_init_cublas(void); + +// Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`. +GGML_API bool ggml_cublas_loaded(void); + GGML_API void * ggml_cuda_host_malloc(size_t size); GGML_API void ggml_cuda_host_free(void * ptr); diff --git a/whisper.cpp b/whisper.cpp index e2bfa41e..a3e0fbd0 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -1061,7 +1061,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params // initialize the backends #ifdef GGML_USE_CUBLAS - if (params.use_gpu) { + if (params.use_gpu && ggml_cublas_loaded()) { WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__); backend_gpu = ggml_backend_cuda_init(); if (!backend_gpu) {