diff --git a/whisper-mel-cuda.cu b/whisper-mel-cuda.cu
index cc44556f..8d674142 100644
--- a/whisper-mel-cuda.cu
+++ b/whisper-mel-cuda.cu
@@ -2,6 +2,9 @@
 #include "whisper-mel-cuda.hpp"
 #include "whisper.h"
 
+#include <ggml-cuda/common.cuh>
+#include <ggml-backend-impl.h>
+
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cufft.h>
@@ -16,16 +19,9 @@
 #pragma warning(disable: 4324) // added padding
 #endif
 
-#ifndef NDEBUG
-# define DO_CHECKS 1
-#else
-# define DO_CHECKS 0
-#endif
-
 namespace {
 
-#if DO_CHECKS
-const char* cufftGetErrorString(cufftResult_t res) {
+static const char* cufftGetErrorString(cufftResult_t res) {
     switch (res) {
     case CUFFT_SUCCESS: return "The cuFFT operation was successful";
     case CUFFT_INVALID_PLAN: return "cuFFT was passed an invalid plan handle";
@@ -48,19 +44,6 @@ const char* cufftGetErrorString(cufftResult_t res) {
     }
 }
 
-# define CUDA_CHECK_GEN(err, success, error_fn) \
-    do { \
-        auto err_ = (err); \
-        if (err_ != (success)) { \
-            fprintf(stderr, "%s %s:%d - %s\n", #err, __FILE__, __LINE__, error_fn(err_)); \
-        } \
-    } while (0)
-#else
-# define CUDA_CHECK_GEN(err, success, error_fn) err
-#endif
-
-#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
-#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublasGetStatusString)
 #define CUFFT_CHECK(err) CUDA_CHECK_GEN(err, CUFFT_SUCCESS, cufftGetErrorString)
 
 __global__ void k_fill_stft_input(
@@ -81,7 +64,7 @@ __global__ void k_fill_stft_input(
 }
 
 __global__ void k_calc_magnitudes(
-    const cuComplex* stft_out,
+    const cuComplex * stft_out,
     const int n_frames,
     float * magnitudes
 ) {
@@ -133,7 +116,7 @@ void fill_stft_input(
 }
 
 void calc_magnitudes(
-    const cuComplex* stft_out,
+    const cuComplex * stft_out,
     int n_frames,
     float * magnitudes,
     cudaStream_t stream
@@ -169,6 +152,7 @@ class mel_calc_cuda : public whisper_mel_calc {
     const int m_n_mel;
 
     ggml_backend_t m_backend = nullptr;
+    int m_device = -1;
     cudaStream_t m_stream = nullptr;
     cublasHandle_t m_cublas_handle = nullptr;
 
@@ -190,6 +174,18 @@ public:
         : m_n_mel(filters.n_mel)
         , m_backend(backend)
     {
+        ggml_backend_cuda_context* cuda_ctx = (ggml_backend_cuda_context*)m_backend->context;
+        m_device = cuda_ctx->device;
+
+        if (ggml_cuda_info().devices[m_device].cc < 600) {
+            // we've only tested on 6.0 and higher and we've had reports of crashes on 5.0:
+            // https://github.com/ggerganov/whisper.cpp/issues/2230
+            // to be safe, forbid anything below 6.0
+            throw std::runtime_error("CUDA compute capability 6.0 or higher is required");
+        }
+
+        ggml_cuda_set_device(m_device);
+
         if (filters.n_fft != WHISPER_N_FFT_HALF) {
             throw std::invalid_argument("MelFilters n_frames must be WHISPER_N_FFT_HALF");
         }
@@ -219,6 +215,7 @@ public:
     }
 
     ~mel_calc_cuda() {
+        ggml_cuda_set_device(m_device);
         CUDA_CHECK(cudaStreamSynchronize(m_stream));
         CUDA_CHECK(cudaStreamDestroy(m_stream));
         CUDA_CHECK(cudaFree(m_hann_window));
@@ -268,6 +265,7 @@ public:
     }
 
     virtual whisper_mel calculate(whisper_span<const float> samples, int /*n_threads*/) override {
+        ggml_cuda_set_device(m_device);
         ensure_working_areas(samples.len);
 
         const size_t mirror_pad = WHISPER_N_FFT / 2;
@@ -356,8 +354,11 @@ public:
 }
 
 whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters) {
-    if (filters.n_fft != WHISPER_N_FFT_HALF) {
+    try {
+        return new mel_calc_cuda(backend, filters);
+    }
+    catch (...) {
+        // TODO: log error (but for this we would have to expose the log state to be accessible here)
         return nullptr;
     }
-    return new mel_calc_cuda(backend, filters);
 }
diff --git a/whisper.cpp b/whisper.cpp
index a08f15ff..71eb0673 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -3170,13 +3170,18 @@ whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper
 #if defined(GGML_USE_CUDA) && !defined(GGML_USE_HIPBLAS)
     if (ggml_backend_is_cuda(backend)) {
         auto ret = whisper_mel_calc_create_cuda(backend, filters);
-        // run a warmup to avoid the first kernel launch overhead (thus we get the best perf even on the first run)
-        const float warmup[256] = {0};
-        ret->calculate({warmup, 256}, 1);
-        return ret;
-    } else
+        if (ret) {
+            // run a warmup to avoid the first kernel launch overhead (thus we get the best perf even on the first run)
+            const float warmup[256] = { 0 };
+            ret->calculate({ warmup, 256 }, 1);
+            return ret;
+        }
+    }
 #endif
-    return new mel_calc_cpu(backend, filters);
+
+    // a specialized mel_calc could not be created
+    // fall back to CPU
+    return new mel_calc_cpu(backend, filters);
 }
 
 // split text into tokens