diff --git a/whisper.cpp b/whisper.cpp index 846d3a93..0d677153 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2284,6 +2284,60 @@ static void fft(const std::vector & in, std::vector & out) { } } +static void log_mel_spectrogram_worker_thread(int ith, const std::vector &hann, const float *samples, + int n_samples, int fft_size, int fft_step, int n_threads, + const whisper_filters &filters, bool speed_up, whisper_mel &mel) { + std::vector fft_in(fft_size, 0.0); + std::vector fft_out(2 * fft_size); + int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2); + + for (int i = ith; i < mel.n_len; i += n_threads) { + const int offset = i * fft_step; + + // apply Hanning window + for (int j = 0; j < fft_size; j++) { + if (offset + j < n_samples) { + fft_in[j] = hann[j] * samples[offset + j]; + } else { + fft_in[j] = 0.0; + } + } + + // FFT -> mag^2 + fft(fft_in, fft_out); + + for (int j = 0; j < fft_size; j++) { + fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); + } + for (int j = 1; j < fft_size / 2; j++) { + fft_out[j] += fft_out[fft_size - j]; + } + + if (speed_up) { + // scale down in the frequency domain results in a speed up in the time domain + for (int j = 0; j < n_fft; j++) { + fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]); + } + } + + // mel spectrogram + for (int j = 0; j < mel.n_mel; j++) { + double sum = 0.0; + + for (int k = 0; k < n_fft; k++) { + sum += fft_out[k] * filters.data[j * n_fft + k]; + } + if (sum < 1e-10) { + sum = 1e-10; + } + + sum = log10(sum); + + mel.data[j * mel.n_len + i] = sum; + } + } +} + // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124 static bool log_mel_spectrogram( whisper_state & wstate, @@ -2310,81 +2364,22 @@ static bool log_mel_spectrogram( mel.n_len = (n_samples)/fft_step; mel.data.resize(mel.n_mel*mel.n_len); - const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2); - //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len); //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate); - std::vector workers(n_threads); - for (int iw = 0; iw < n_threads; ++iw) { - workers[iw] = std::thread([&](int ith) { - std::vector fft_in; - fft_in.resize(fft_size); - for (int i = 0; i < fft_size; i++) { - fft_in[i] = 0.0; - } + if (n_threads == 1) { + log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel); + } else { + std::vector workers(n_threads); + for (int iw = 0; iw < n_threads; ++iw) { + workers[iw] = std::thread(log_mel_spectrogram_worker_thread, iw, std::cref(hann), samples, + n_samples, fft_size, fft_step, n_threads, + std::cref(filters), speed_up, std::ref(mel)); + } - std::vector fft_out; - fft_out.resize(2*fft_size); - - for (int i = ith; i < mel.n_len; i += n_threads) { - const int offset = i*fft_step; - - // apply Hanning window - for (int j = 0; j < fft_size; j++) { - if (offset + j < n_samples) { - fft_in[j] = hann[j]*samples[offset + j]; - } else { - fft_in[j] = 0.0; - } - } - - // FFT -> mag^2 - fft(fft_in, fft_out); - - for (int j = 0; j < fft_size; j++) { - fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]); - } - for (int j = 1; j < fft_size/2; j++) { - //if (i == 0) { - // printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]); - //} - fft_out[j] += fft_out[fft_size - j]; - } - if (i == 0) { - //for (int j = 0; j < fft_size; j++) { - // printf("%d: %e\n", j, fft_out[j]); - //} - } - - if (speed_up) { - // scale down in the frequency domain results in a speed up in the time domain - for (int j = 0; j < n_fft; j++) { - fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]); - } - } - - // mel spectrogram - for (int j = 0; j < mel.n_mel; j++) { - double sum = 0.0; - - for (int k = 0; k < n_fft; k++) { - sum += fft_out[k]*filters.data[j*n_fft + k]; - } - if (sum < 1e-10) { - sum = 1e-10; - } - - sum = log10(sum); - - mel.data[j*mel.n_len + i] = sum; - } - } - }, iw); - } - - for (int iw = 0; iw < n_threads; ++iw) { - workers[iw].join(); + for (int iw = 0; iw < n_threads; ++iw) { + workers[iw].join(); + } } // clamping and normalization