From 3dead611bb5f5519467b47e2d0adee88034d7332 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 15 Apr 2023 14:18:46 +0300 Subject: [PATCH] whisper : slightly faster Log Mel computation + n-1 FFT threads (#568) --- whisper.cpp | 52 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 178766e8..8e9fa6cd 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2306,10 +2306,10 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector std::vector fft_in(fft_size, 0.0); std::vector fft_out(2 * fft_size); int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2); - + for (int i = ith; i < mel.n_len; i += n_threads) { const int offset = i * fft_step; - + // apply Hanning window for (int j = 0; j < fft_size; j++) { if (offset + j < n_samples) { @@ -2318,37 +2318,49 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector fft_in[j] = 0.0; } } - + // FFT -> mag^2 fft(fft_in, fft_out); - + for (int j = 0; j < fft_size; j++) { fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); } for (int j = 1; j < fft_size / 2; j++) { fft_out[j] += fft_out[fft_size - j]; } - + if (speed_up) { // scale down in the frequency domain results in a speed up in the time domain for (int j = 0; j < n_fft; j++) { fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]); } } - + // mel spectrogram for (int j = 0; j < mel.n_mel; j++) { double sum = 0.0; - - for (int k = 0; k < n_fft; k++) { + + // unroll loop (suggested by GH user @lunixbochs) + int k = 0; + for (k = 0; k < n_fft - 3; k += 4) { + sum += + fft_out[k + 0] * filters.data[j*n_fft + k + 0] + + fft_out[k + 1] * filters.data[j*n_fft + k + 1] + + fft_out[k + 2] * filters.data[j*n_fft + k + 2] + + fft_out[k + 3] * filters.data[j*n_fft + k + 3]; + } + + // handle n_fft remainder + for (; k < n_fft; k++) { sum += fft_out[k] * filters.data[j * n_fft + k]; } + if (sum < 1e-10) { sum = 1e-10; } - + sum = log10(sum); - + mel.data[j * mel.n_len + i] = sum; } } @@ -2383,17 +2395,19 @@ static bool log_mel_spectrogram( //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len); //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate); - if (n_threads == 1) { - log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel); - } else { - std::vector workers(n_threads); - for (int iw = 0; iw < n_threads; ++iw) { - workers[iw] = std::thread(log_mel_spectrogram_worker_thread, iw, std::cref(hann), samples, - n_samples, fft_size, fft_step, n_threads, - std::cref(filters), speed_up, std::ref(mel)); + { + std::vector workers(n_threads - 1); + for (int iw = 0; iw < n_threads - 1; ++iw) { + workers[iw] = std::thread( + log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples, + n_samples, fft_size, fft_step, n_threads, + std::cref(filters), speed_up, std::ref(mel)); } - for (int iw = 0; iw < n_threads; ++iw) { + // main thread + log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel); + + for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw].join(); } }