whisper : auto-grow working areas for mel_calc_cuda (#2227)

* whisper : auto-grow working areas for mel_calc_cuda, fixes #2226

* whisper : only calculate mel spectrogram on GPU if audio is <= 5 min
Borislav Stanimirov 2024-06-10 21:51:32 +03:00 committed by GitHub
parent c2bdb960cd
commit 20c542c713
3 changed files with 70 additions and 30 deletions
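
The core of the change: instead of sizing the cuFFT and CUB working areas once in the constructor for a fixed 30-second maximum, the calculator now grows them on demand and keeps the largest allocation seen so far. A minimal sketch of that grow-only pattern, using an illustrative host-memory stand-in (the GrowOnlyBuffer name and plain malloc/free are not from the codebase, which allocates with cudaMallocAsync on a stream):

    #include <cstdlib>
    #include <cstddef>

    // Grow-only working area: reallocate only when the requested size
    // exceeds the current capacity, so repeated calls with typical sizes
    // reuse the same allocation.
    struct GrowOnlyBuffer {
        void * data = nullptr;
        size_t capacity = 0;

        void ensure(size_t nbytes) {
            if (nbytes <= capacity) {
                return; // already big enough: nothing to do
            }
            free(data);
            data = malloc(nbytes);
            capacity = nbytes;
        }

        ~GrowOnlyBuffer() { free(data); }
    };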


@@ -145,17 +145,6 @@ void calc_magnitudes(
 constexpr auto LOG_MEL_PREFIX_SIZE = 256;
 
-size_t get_log_mel_temp_storage_size() {
-    constexpr auto maxPaddedSamples = 2 * WHISPER_N_SAMPLES + WHISPER_N_FFT;
-    constexpr auto maxFrames = 1 + (maxPaddedSamples - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
-    constexpr auto maxMels = 160;
-    size_t nbytes = 0;
-    float * temp = nullptr;
-    cub::DeviceReduce::Max(nullptr, nbytes, temp, temp, maxFrames * maxMels);
-    return nbytes + LOG_MEL_PREFIX_SIZE;
-}
-
 void calc_log_mel(
     const float * mel_data,
     int n_mel,
@@ -186,11 +175,14 @@ class mel_calc_cuda : public whisper_mel_calc {
     float * m_hann_window = nullptr;
-    float * m_filters = nullptr;
+    // max samples for which we have allocated memory for the temp working areas below (cufft, log_mel)
+    int m_n_max_samples = 0;
+
     size_t m_cufft_workspace_size = 0;
     void * m_cufft_workspace = nullptr;
 
+    float * m_filters = nullptr;
+
     size_t m_log_mel_temp_storage_size = 0;
     void * m_log_mel_temp_storage = nullptr;
 public:
@@ -215,14 +207,6 @@ public:
         CUDA_CHECK(cudaMemcpyAsync(m_hann_window, hw.data, hw.len * sizeof(float), cudaMemcpyHostToDevice, m_stream));
     }
 
-    // create working area
-    {
-        constexpr auto maxPaddedSamples = 2 * WHISPER_N_SAMPLES + WHISPER_N_FFT;
-        constexpr auto maxFrames = 1 + (maxPaddedSamples - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
-        CUFFT_CHECK(cufftEstimate1d(WHISPER_N_FFT, CUFFT_R2C, maxFrames, &m_cufft_workspace_size));
-        CUDA_CHECK(cudaMallocAsync(&m_cufft_workspace, m_cufft_workspace_size, m_stream));
-    }
-
     // fill filters
     {
         auto & f = filters.data;
@@ -230,10 +214,8 @@ public:
         CUDA_CHECK(cudaMemcpyAsync(m_filters, f.data(), f.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream));
     }
 
-    {
-        m_log_mel_temp_storage_size = get_log_mel_temp_storage_size();
-        CUDA_CHECK(cudaMallocAsync(&m_log_mel_temp_storage, m_log_mel_temp_storage_size, m_stream));
-    }
+    // preallocate working areas enough for the most common cases (<= 30s)
+    ensure_working_areas(WHISPER_N_SAMPLES);
 }
 
 ~mel_calc_cuda() {
@@ -245,7 +227,49 @@ public:
         CUDA_CHECK(cudaFree(m_log_mel_temp_storage));
     }
 
-    virtual whisper_mel calculate(whisper_span<const float> samples, int /*n_threads*/) const override {
+    void ensure_working_areas(int n_samples) {
+        if (n_samples <= m_n_max_samples) {
+            return;
+        }
+
+        const auto max_padded_samples = n_samples + WHISPER_N_SAMPLES + WHISPER_N_FFT;
+        const auto max_frames = 1 + (max_padded_samples - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
+
+        // cufft workspace
+        {
+            if (m_cufft_workspace) {
+                CUDA_CHECK(cudaFree(m_cufft_workspace));
+                m_cufft_workspace_size = 0;
+                m_cufft_workspace = nullptr;
+            }
+
+            CUFFT_CHECK(cufftEstimate1d(WHISPER_N_FFT, CUFFT_R2C, max_frames, &m_cufft_workspace_size));
+            CUDA_CHECK(cudaMallocAsync(&m_cufft_workspace, m_cufft_workspace_size, m_stream));
+        }
+
+        // device reduce working area
+        {
+            if (m_log_mel_temp_storage) {
+                CUDA_CHECK(cudaFree(m_log_mel_temp_storage));
+                m_log_mel_temp_storage_size = 0;
+                m_log_mel_temp_storage = nullptr;
+            }
+
+            const auto max_mels = 160;
+
+            size_t nbytes = 0;
+            float * temp = nullptr;
+            cub::DeviceReduce::Max(nullptr, nbytes, temp, temp, max_frames * max_mels);
+            m_log_mel_temp_storage_size = nbytes + LOG_MEL_PREFIX_SIZE;
+
+            CUDA_CHECK(cudaMallocAsync(&m_log_mel_temp_storage, m_log_mel_temp_storage_size, m_stream));
+        }
+
+        m_n_max_samples = n_samples;
+    }
+
+    virtual whisper_mel calculate(whisper_span<const float> samples, int /*n_threads*/) override {
+        ensure_working_areas(samples.len);
+
         const size_t mirror_pad = WHISPER_N_FFT / 2;
         const size_t padded_size = samples.len + WHISPER_N_SAMPLES + WHISPER_N_FFT;

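The temp-storage sizing above relies on CUB's two-phase calling convention: invoking cub::DeviceReduce::Max with a null workspace pointer performs no reduction and only writes the required byte count into the size argument; a second call with the allocated workspace does the actual work. A standalone sketch of that convention (the reduce_max wrapper is illustrative and omits error handling; the diff itself caches the workspace in ensure_working_areas instead of allocating per call):

    #include <cub/cub.cuh>
    #include <cuda_runtime.h>

    // Two-phase CUB call: the first call with a null workspace pointer only
    // queries the required temp-storage size; the second call does the reduction.
    void reduce_max(const float * d_in, float * d_out, int num_items, cudaStream_t stream) {
        size_t temp_bytes = 0;
        // phase 1: size query (no work is done while the workspace pointer is null)
        cub::DeviceReduce::Max(nullptr, temp_bytes, d_in, d_out, num_items, stream);

        void * d_temp = nullptr;
        cudaMallocAsync(&d_temp, temp_bytes, stream);

        // phase 2: the actual reduction, using the sized workspace
        cub::DeviceReduce::Max(d_temp, temp_bytes, d_in, d_out, num_items, stream);

        cudaFreeAsync(d_temp, stream);
    }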

@@ -29,6 +29,6 @@ struct whisper_span {
 
 struct whisper_mel_calc {
     virtual ~whisper_mel_calc();
-    virtual whisper_mel calculate(whisper_span<const float> samples, int n_threads) const = 0;
+    virtual whisper_mel calculate(whisper_span<const float> samples, int n_threads) = 0;
     static whisper_span<const float> hann_window();
 };

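The header change above is small but load-bearing: calculate() loses its const qualifier because implementations now mutate cached state (the CUDA calculator grows its working areas on demand, and growing writes to members). A hypothetical minimal type, not from the codebase, showing why the const version cannot compile:

    struct calc_like {
        int m_n_max_samples = 0;

        // fine: a non-const calculate() may grow cached working areas
        void calculate(int n_samples) {
            if (n_samples > m_n_max_samples) {
                m_n_max_samples = n_samples;
            }
        }

        // would not compile: a const method cannot assign to m_n_max_samples
        // void calculate_const(int n_samples) const { m_n_max_samples = n_samples; }
    };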

@@ -802,6 +802,7 @@ struct whisper_state {
     whisper_mel mel;
     whisper_mel_calc * mel_calc = nullptr;
+    whisper_mel_calc * mel_calc_fallback = nullptr;
 
     whisper_batch batch;
@@ -3079,7 +3080,7 @@ struct mel_calc_cpu : public whisper_mel_calc {
     mel_calc_cpu(ggml_backend_t backend, const whisper_filters & filters) : m_backend(backend), m_filters(filters) {}
 
     // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
-    whisper_mel calculate(whisper_span<const float> ssamples, int n_threads) const override {
+    whisper_mel calculate(whisper_span<const float> ssamples, int n_threads) override {
         // Hann window
         const float * hann = global_cache.hann_window;
@@ -3721,6 +3722,8 @@ void whisper_free_state(struct whisper_state * state) {
     delete state->mel_calc;
     state->mel_calc = nullptr;
 
+    delete state->mel_calc_fallback;
+    state->mel_calc_fallback = nullptr;
+
 #ifdef WHISPER_USE_COREML
     if (state->ctx_coreml != nullptr) {
@@ -3778,11 +3781,24 @@ void whisper_free_params(struct whisper_full_params * params) {
     }
 }
 
-int whisper_pcm_to_mel_with_state(struct whisper_context * /*ctx*/, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
+int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
     const int64_t t_start_us = ggml_time_us();
 
     whisper_mel_free(state->mel);
-    state->mel = state->mel_calc->calculate({samples, n_samples}, n_threads);
+
+    if (n_samples <= 5 * 60 * WHISPER_SAMPLE_RATE) {
+        // calculate the mel spectrogram for audio up to 5 minutes long with the optimal mel calculator
+        state->mel = state->mel_calc->calculate({samples, n_samples}, n_threads);
+    } else {
+        // calculate the mel spectrogram for longer audio on the CPU:
+        // 1. gpu calculations may use hundreds of megabytes of memory for longer audio, so we're being conservative
+        //    with our gpu demands
+        // 2. the time to transcribe audio this long will be dominated by the decoding time, so the mel calculation
+        //    taking longer is not a major concern
+        if (!state->mel_calc_fallback) {
+            state->mel_calc_fallback = new mel_calc_cpu(state->backend, ctx->model.filters);
+        }
+        state->mel = state->mel_calc_fallback->calculate({samples, n_samples}, n_threads);
+    }
+
     state->t_mel_us += ggml_time_us() - t_start_us;
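
For scale, the 5-minute cutoff and the working set behind it can be checked with quick arithmetic (assuming WHISPER_SAMPLE_RATE = 16000, WHISPER_N_SAMPLES = 480000, WHISPER_N_FFT = 400, and WHISPER_HOP_LENGTH = 160, their values in the codebase):

    // 5 minutes of 16 kHz mono PCM:
    constexpr int n_samples = 5 * 60 * 16000;            // 4,800,000 samples (~18.3 MiB as float)
    // padded size and frame count, as computed in mel_calc_cuda:
    constexpr int n_padded  = n_samples + 480000 + 400;  // + WHISPER_N_SAMPLES + WHISPER_N_FFT
    constexpr int n_frames  = 1 + (n_padded - 400) / 160; // 33,001 frames at the cutoff
    static_assert(n_frames == 33001, "frame count at the GPU cutoff");

Every frame then needs FFT, magnitude, and mel storage on the device, which is how GPU memory can reach the "hundreds of megabytes" the comment warns about for long inputs, and why audio beyond the cutoff takes the CPU fallback instead.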