whisper : optimize fft() function (#2242)

Co-authored-by: Mike Fan <60965742+mike-fzy@users.noreply.github.com>
This commit is contained in:
mky_coder 2024-06-18 23:10:33 +08:00 committed by GitHub
parent e293f17d34
commit bf4cb4abad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2974,10 +2974,7 @@ whisper_span<const float> whisper_mel_calc::hann_window() {
// naive Discrete Fourier Transform // naive Discrete Fourier Transform
// input is real-valued // input is real-valued
// output is complex-valued // output is complex-valued
static void dft(const std::vector<float> & in, std::vector<float> & out) { static void dft(const float* in, int N, float* out) {
int N = in.size();
out.resize(N*2);
const int sin_cos_step = SIN_COS_N_COUNT / N; const int sin_cos_step = SIN_COS_N_COUNT / N;
for (int k = 0; k < N; k++) { for (int k = 0; k < N; k++) {
@ -2999,44 +2996,35 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
// poor man's implementation - use something better // poor man's implementation - use something better
// input is real-valued // input is real-valued
// output is complex-valued // output is complex-valued
static void fft(const std::vector<float> & in, std::vector<float> & out) { static void fft(float* in, int N, float* out) {
out.resize(in.size()*2);
int N = in.size();
if (N == 1) { if (N == 1) {
out[0] = in[0]; out[0] = in[0];
out[1] = 0; out[1] = 0;
return; return;
} }
if (N%2 == 1) { const int half_N = N / 2;
dft(in, out); if (N - half_N*2 == 1) {
dft(in, N, out);
return; return;
} }
std::vector<float> even; float* even = in + N;
std::vector<float> odd; for (int i = 0; i < half_N; ++i) {
even[i]= in[2*i];
even.reserve(N/2);
odd.reserve(N/2);
for (int i = 0; i < N; i++) {
if (i % 2 == 0) {
even.push_back(in[i]);
} else {
odd.push_back(in[i]);
}
} }
float* even_fft = out + 2 * N;
fft(even, half_N, even_fft);
std::vector<float> even_fft; float* odd = even;
std::vector<float> odd_fft; for (int i = 0; i < half_N; ++i) {
odd[i] = in[2*i + 1];
fft(even, even_fft); }
fft(odd, odd_fft); float* odd_fft = even_fft + N;
fft(odd, half_N, odd_fft);
const int sin_cos_step = SIN_COS_N_COUNT / N; const int sin_cos_step = SIN_COS_N_COUNT / N;
for (int k = 0; k < N/2; k++) { for (int k = 0; k < half_N; k++) {
int idx = k * sin_cos_step; // t = 2*M_PI*k/N int idx = k * sin_cos_step; // t = 2*M_PI*k/N
float re = global_cache.cos_vals[idx]; // cos(t) float re = global_cache.cos_vals[idx]; // cos(t)
float im = -global_cache.sin_vals[idx]; // sin(t) float im = -global_cache.sin_vals[idx]; // sin(t)
@ -3047,8 +3035,8 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
} }
} }
@ -3066,8 +3054,8 @@ void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::v
const whisper_filters & filters, whisper_mel_data & mel) { const whisper_filters & filters, whisper_mel_data & mel) {
const auto frame_size = WHISPER_N_FFT; const auto frame_size = WHISPER_N_FFT;
const auto frame_step = WHISPER_HOP_LENGTH; const auto frame_step = WHISPER_HOP_LENGTH;
std::vector<float> fft_in(frame_size, 0.0); std::vector<float> fft_in(frame_size * 2, 0.0);
std::vector<float> fft_out(2 * frame_size); std::vector<float> fft_out(frame_size * 2 * 2 * 2);
int n_fft = filters.n_fft; int n_fft = filters.n_fft;
int i = ith; int i = ith;
@ -3088,7 +3076,7 @@ void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::v
} }
// FFT // FFT
fft(fft_in, fft_out); fft(fft_in.data(), frame_size, fft_out.data());
// Calculate modulus^2 of complex numbers // Calculate modulus^2 of complex numbers
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.