Mirror of https://github.com/ggerganov/whisper.cpp.git
whisper : use ggml-cuda in mel calc, set appropriate device (#2236)
* whisper : use ggml-cuda in mel calc, set appropriate device

* whisper : forbid cuda mel calc on devices with compute < 600, workaround for #2230
Commit b29b3b2924 (parent 420b6abc54)
whisper-mel-cuda.cu:

@@ -2,6 +2,9 @@
 #include "whisper-mel-cuda.hpp"
 #include "whisper.h"
 
+#include <ggml-cuda/common.cuh>
+#include <ggml-backend-impl.h>
+
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cufft.h>
@@ -16,16 +19,9 @@
 #pragma warning(disable: 4324) // added padding
 #endif
 
-#ifndef NDEBUG
-# define DO_CHECKS 1
-#else
-# define DO_CHECKS 0
-#endif
-
 namespace {
 
-#if DO_CHECKS
-const char* cufftGetErrorString(cufftResult_t res) {
+static const char* cufftGetErrorString(cufftResult_t res) {
     switch (res) {
     case CUFFT_SUCCESS: return "The cuFFT operation was successful";
     case CUFFT_INVALID_PLAN: return "cuFFT was passed an invalid plan handle";
@@ -48,19 +44,6 @@ const char* cufftGetErrorString(cufftResult_t res) {
     }
 }
 
-# define CUDA_CHECK_GEN(err, success, error_fn) \
-    do { \
-        auto err_ = (err); \
-        if (err_ != (success)) { \
-            fprintf(stderr, "%s %s:%d - %s\n", #err, __FILE__, __LINE__, error_fn(err_)); \
-        } \
-    } while (0)
-#else
-# define CUDA_CHECK_GEN(err, success, error_fn) err
-#endif
-
-#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
-#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublasGetStatusString)
 #define CUFFT_CHECK(err) CUDA_CHECK_GEN(err, CUFFT_SUCCESS, cufftGetErrorString)
 
 __global__ void k_fill_stft_input(
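Note: the deleted block does not remove functionality. With <ggml-cuda/common.cuh> now included, CUDA_CHECK_GEN, CUDA_CHECK, and CUBLAS_CHECK come from ggml's shared CUDA header, so only the cuFFT-specific CUFFT_CHECK still needs a local definition. For reference, a minimal self-contained sketch of the same check-macro pattern; CHECK_GEN is an illustrative stand-in, not the ggml macro:

#include <cstdio>
#include <cuda_runtime.h>

// generic "check a status code, print file:line on failure" macro;
// do/while(0) makes it behave like a single statement after if/else
#define CHECK_GEN(err, success, error_fn) \
    do { \
        auto err_ = (err); \
        if (err_ != (success)) { \
            fprintf(stderr, "%s %s:%d - %s\n", #err, __FILE__, __LINE__, error_fn(err_)); \
        } \
    } while (0)

int main() {
    // deliberately invalid call so the macro has something to report
    CHECK_GEN(cudaSetDevice(-1), cudaSuccess, cudaGetErrorString);
    return 0;
}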
@@ -81,7 +64,7 @@ __global__ void k_fill_stft_input(
 }
 
 __global__ void k_calc_magnitudes(
-    const cuComplex* stft_out,
+    const cuComplex * stft_out,
     const int n_frames,
     float * magnitudes
 ) {
@@ -133,7 +116,7 @@ void fill_stft_input(
 }
 
 void calc_magnitudes(
-    const cuComplex* stft_out,
+    const cuComplex * stft_out,
     int n_frames,
     float * magnitudes,
     cudaStream_t stream
@@ -169,6 +152,7 @@ class mel_calc_cuda : public whisper_mel_calc {
     const int m_n_mel;
 
     ggml_backend_t m_backend = nullptr;
+    int m_device = -1;
 
     cudaStream_t m_stream = nullptr;
     cublasHandle_t m_cublas_handle = nullptr;
@@ -190,6 +174,18 @@ public:
         : m_n_mel(filters.n_mel)
         , m_backend(backend)
     {
+        ggml_backend_cuda_context* cuda_ctx = (ggml_backend_cuda_context*)m_backend->context;
+        m_device = cuda_ctx->device;
+
+        if (ggml_cuda_info().devices[m_device].cc < 600) {
+            // we've only tested on 6.0 and higher and we've had reports of crashes on 5.0:
+            // https://github.com/ggerganov/whisper.cpp/issues/2230
+            // to be safe forbid anything below 6.0
+            throw std::runtime_error("CUDA compute capability 6.0 or higher is required");
+        }
+
+        ggml_cuda_set_device(m_device);
+
         if (filters.n_fft != WHISPER_N_FFT_HALF) {
             throw std::invalid_argument("MelFilters n_frames must be WHISPER_N_FFT_HALF");
         }
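Note: ggml encodes a device's compute capability as 100 * major + 10 * minor (so 6.1 becomes 610), which makes cc < 600 reject every device older than compute 6.0. A standalone equivalent of the gate using only the CUDA runtime API; device_supports_cuda_mel is an illustrative name, not a whisper.cpp symbol:

#include <cuda_runtime.h>

// true if the device meets the compute 6.0 floor used by the CUDA mel path
bool device_supports_cuda_mel(int device) {
    cudaDeviceProp prop;
    if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
        return false; // treat a failed query as "not supported"
    }
    const int cc = 100 * prop.major + 10 * prop.minor;
    return cc >= 600; // below 6.0 has crashed in the wild, see issue #2230
}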
@@ -219,6 +215,7 @@ public:
     }
 
     ~mel_calc_cuda() {
+        ggml_cuda_set_device(m_device);
         CUDA_CHECK(cudaStreamSynchronize(m_stream));
         CUDA_CHECK(cudaStreamDestroy(m_stream));
         CUDA_CHECK(cudaFree(m_hann_window));
@@ -268,6 +265,7 @@ public:
     }
 
     virtual whisper_mel calculate(whisper_span<const float> samples, int /*n_threads*/) override {
+        ggml_cuda_set_device(m_device);
         ensure_working_areas(samples.len);
 
         const size_t mirror_pad = WHISPER_N_FFT / 2;
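Note: the new ggml_cuda_set_device(m_device) calls in the constructor, destructor, and calculate() are the "set appropriate device" part of this commit. The CUDA runtime's current device is per-host-thread state, and the stream, cuBLAS handle, and cuFFT plan owned by mel_calc_cuda are only usable while m_device is current, so every entry point reselects it in case the caller (or ggml itself) switched devices in between. A RAII guard is another common way to express the same idea; the sketch below is an illustration, not what whisper.cpp does:

#include <cuda_runtime.h>

// select a device for the lifetime of a scope, then restore the caller's
class scoped_cuda_device {
    int m_prev = 0;
public:
    explicit scoped_cuda_device(int device) {
        cudaGetDevice(&m_prev); // remember the caller's current device
        cudaSetDevice(device);  // make ours current
    }
    ~scoped_cuda_device() {
        cudaSetDevice(m_prev);  // restore on scope exit
    }
};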
@@ -356,8 +354,11 @@ public:
 }
 
 whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters) {
-    if (filters.n_fft != WHISPER_N_FFT_HALF) {
+    try {
+        return new mel_calc_cuda(backend, filters);
+    }
+    catch (...) {
+        // TODO: log error (but for this we would have to expose the log state to be accessible here)
         return nullptr;
     }
-    return new mel_calc_cuda(backend, filters);
 }
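Note: the factory's contract changes here from "nullptr means invalid filters" to "nullptr means no CUDA mel calculator could be built": any constructor failure, including the new compute < 6.0 check, now surfaces as a null return instead of an uncaught exception. A sketch of the caller side under that contract, mirroring the whisper.cpp change below; create_mel_calc_with_fallback is an illustrative name:

// assumes the whisper internal declarations (whisper_mel_calc,
// whisper_mel_calc_create_cuda, mel_calc_cpu) are in scope
whisper_mel_calc * create_mel_calc_with_fallback(ggml_backend_t backend, const whisper_filters & filters) {
    whisper_mel_calc * mel = whisper_mel_calc_create_cuda(backend, filters);
    if (mel == nullptr) {
        // the CUDA path refused (old GPU, ctor threw) - use the CPU implementation
        mel = new mel_calc_cpu(backend, filters);
    }
    return mel;
}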
whisper.cpp (17 changed lines):
@@ -3170,13 +3170,18 @@ whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper
 #if defined(GGML_USE_CUDA) && !defined(GGML_USE_HIPBLAS)
     if (ggml_backend_is_cuda(backend)) {
         auto ret = whisper_mel_calc_create_cuda(backend, filters);
-        // run a warmup to avoid the first kernel launch overhead (thus we get the best perf even on the first run)
-        const float warmup[256] = {0};
-        ret->calculate({warmup, 256}, 1);
-        return ret;
-    } else
+        if (ret) {
+            // run a warmup to avoid the first kernel launch overhead (thus we get the best perf even on the first run)
+            const float warmup[256] = { 0 };
+            ret->calculate({ warmup, 256 }, 1);
+            return ret;
+        }
+    }
 #endif
-    return new mel_calc_cpu(backend, filters);
+
+    // a specialized mel_calc could not be created
+    // fall back to CPU
+    return new mel_calc_cpu(backend, filters);
 }
 
 // split text into tokens
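Note: the warmup survives the restructuring because the first kernel launch on a fresh CUDA context pays one-time costs (module loading, and in this path cuFFT plan setup), so running a tiny 256-sample calculate() once moves that cost out of the first real transcription. A toy measurement of the effect; the kernel and timing harness are placeholders, not whisper.cpp code:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void k_noop() {}

// time a single launch of an empty kernel with CUDA events
static float time_launch_ms() {
    cudaEvent_t beg, end;
    cudaEventCreate(&beg);
    cudaEventCreate(&end);
    cudaEventRecord(beg);
    k_noop<<<1, 1>>>();
    cudaEventRecord(end);
    cudaEventSynchronize(end);
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, beg, end);
    cudaEventDestroy(beg);
    cudaEventDestroy(end);
    return ms;
}

int main() {
    printf("first  launch: %.3f ms\n", time_launch_ms()); // pays one-time init
    printf("second launch: %.3f ms\n", time_launch_ms()); // just the launch
    return 0;
}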