From 700898e6edc14cb29dff3334ad6823cbb93358c8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 5 Oct 2022 23:44:10 +0300 Subject: [PATCH] ref #22 : add option to provide multiple input .wav files --- README.md | 31 ++++----- main.cpp | 198 +++++++++++++++++++++++++++++------------------------- 2 files changed, 121 insertions(+), 108 deletions(-) diff --git a/README.md b/README.md index 9d5685d8..a73dfb8a 100644 --- a/README.md +++ b/README.md @@ -31,13 +31,12 @@ For a quick demo, simply run `make base.en`: ```java $ make base.en - -gcc -pthread -O3 -mavx -mavx2 -mfma -mf16c -c ggml.c -g++ -pthread -O3 -std=c++11 -c main.cpp -g++ -pthread -o main ggml.o main.o +cc -O3 -std=c11 -Wall -Wextra -Wno-unused-parameter -Wno-unused-function -pthread -c ggml.c +c++ -O3 -std=c++11 -Wall -Wextra -Wno-unused-parameter -Wno-unused-function -pthread -c whisper.cpp +c++ -O3 -std=c++11 -Wall -Wextra -Wno-unused-parameter -Wno-unused-function -pthread main.cpp whisper.o ggml.o -o main ./main -h -usage: ./main [options] +usage: ./main [options] file0.wav file1.wav ... options: -h, --help show this help message and exit @@ -49,11 +48,11 @@ options: -nt, --no_timestamps do not print timestamps -l LANG, --language LANG spoken language (default: en) -m FNAME, --model FNAME model path (default: models/ggml-base.en.bin) - -f FNAME, --file FNAME input WAV file path (default: samples/jfk.wav) + -f FNAME, --file FNAME input WAV file path bash ./download-ggml-model.sh base.en Downloading ggml model base.en ... -models/ggml-base.en.bin 100%[=====================================>] 141.11M 8.58MB/s in 22s +models/ggml-base.en.bin 100%[===================================>] 141.11M 6.49MB/s in 23s Done! Model 'base.en' saved in 'models/ggml-base.en.bin' You can now use it like this: @@ -86,20 +85,18 @@ whisper_model_load: adding 1607 extra tokens whisper_model_load: ggml ctx size = 163.43 MB whisper_model_load: memory size = 22.83 MB whisper_model_load: model size = 140.54 MB -log_mel_spectrogram: n_sample = 176000, n_len = 1100 -log_mel_spectrogram: recording length: 11.000000 s -main: processing 176000 samples (11.0 sec), 4 threads, lang = english, task = transcribe, timestamps = 1 ... +main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, lang = en, task = transcribe, timestamps = 1 ... -[00:00.000 --> 00:11.000] And so my fellow Americans ask not what your country can do for you. Ask what you can do for your country. +[00:00.000 --> 00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country. -main: load time = 82.05 ms -main: mel time = 44.15 ms -main: sample time = 1.98 ms -main: encode time = 674.77 ms / 112.46 ms per layer -main: decode time = 82.91 ms -main: total time = 886.29 ms +whisper_print_timings: load time = 77.48 ms +whisper_print_timings: mel time = 26.10 ms +whisper_print_timings: sample time = 2.19 ms +whisper_print_timings: encode time = 632.95 ms / 105.49 ms per layer +whisper_print_timings: decode time = 85.11 ms / 14.18 ms per layer +whisper_print_timings: total time = 824.14 ms ``` The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`. diff --git a/main.cpp b/main.cpp index 6d1c55da..d363ab7e 100644 --- a/main.cpp +++ b/main.cpp @@ -36,7 +36,8 @@ struct whisper_params { std::string language = "en"; std::string model = "models/ggml-base.en.bin"; - std::string fname_inp = "samples/jfk.wav"; + + std::vector fname_inp = {}; }; void whisper_print_usage(int argc, char ** argv, const whisper_params & params); @@ -45,6 +46,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { for (int i = 1; i < argc; i++) { std::string arg = argv[i]; + if (arg[0] != '-') { + params.fname_inp.push_back(arg); + continue; + } + if (arg == "-s" || arg == "--seed") { params.seed = std::stoi(argv[++i]); } else if (arg == "-t" || arg == "--threads") { @@ -67,7 +73,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-f" || arg == "--file") { - params.fname_inp = argv[++i]; + params.fname_inp.push_back(argv[++i]); } else if (arg == "-h" || arg == "--help") { whisper_print_usage(argc, argv, params); exit(0); @@ -83,7 +89,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { void whisper_print_usage(int argc, char ** argv, const whisper_params & params) { fprintf(stderr, "\n"); - fprintf(stderr, "usage: %s [options]\n", argv[0]); + fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); fprintf(stderr, " -h, --help show this help message and exit\n"); @@ -95,7 +101,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n"); fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str()); - fprintf(stderr, " -f FNAME, --file FNAME input WAV file path (default: %s)\n", params.fname_inp.c_str()); + fprintf(stderr, " -f FNAME, --file FNAME input WAV file path\n"); fprintf(stderr, "\n"); } @@ -110,106 +116,116 @@ int main(int argc, char ** argv) { params.seed = time(NULL); } + if (params.fname_inp.empty()) { + fprintf(stderr, "error: no input files specified\n"); + whisper_print_usage(argc, argv, params); + return 1; + } + // whisper init struct whisper_context * ctx = whisper_init(params.model.c_str()); - // WAV input - std::vector pcmf32; - { - drwav wav; - if (!drwav_init_file(&wav, params.fname_inp.c_str(), NULL)) { - fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], params.fname_inp.c_str()); - whisper_print_usage(argc, argv, {}); - return 2; - } + for (int f = 0; f < (int) params.fname_inp.size(); ++f) { + const auto fname_inp = params.fname_inp[f]; - if (wav.channels != 1 && wav.channels != 2) { - fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], params.fname_inp.c_str()); - return 3; - } - - if (wav.sampleRate != WHISPER_SAMPLE_RATE) { - fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], params.fname_inp.c_str()); - return 4; - } - - if (wav.bitsPerSample != 16) { - fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], params.fname_inp.c_str()); - return 5; - } - - int n = wav.totalPCMFrameCount; - - std::vector pcm16; - pcm16.resize(n*wav.channels); - drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); - drwav_uninit(&wav); - - // convert to mono, float - pcmf32.resize(n); - if (wav.channels == 1) { - for (int i = 0; i < n; i++) { - pcmf32[i] = float(pcm16[i])/32768.0f; + // WAV input + std::vector pcmf32; + { + drwav wav; + if (!drwav_init_file(&wav, fname_inp.c_str(), NULL)) { + fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], fname_inp.c_str()); + whisper_print_usage(argc, argv, {}); + return 2; } - } else { - for (int i = 0; i < n; i++) { - pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f; + + if (wav.channels != 1 && wav.channels != 2) { + fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str()); + return 3; + } + + if (wav.sampleRate != WHISPER_SAMPLE_RATE) { + fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str()); + return 4; + } + + if (wav.bitsPerSample != 16) { + fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str()); + return 5; + } + + int n = wav.totalPCMFrameCount; + + std::vector pcm16; + pcm16.resize(n*wav.channels); + drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); + drwav_uninit(&wav); + + // convert to mono, float + pcmf32.resize(n); + if (wav.channels == 1) { + for (int i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[i])/32768.0f; + } + } else { + for (int i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f; + } } } - } - // print some info about the processing - { - printf("\n"); - if (!whisper_is_multilingual(ctx)) { - if (params.language != "en" || params.translate) { - params.language = "en"; - params.translate = false; - printf("%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); - } - } - printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n", - __func__, int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads, - params.language.c_str(), - params.translate ? "translate" : "transcribe", - params.no_timestamps ? 0 : 1); - printf("\n"); - } - - // run the inference - { - whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY); - - wparams.print_realtime = true; - wparams.print_progress = false; - wparams.print_timestamps = !params.no_timestamps; - wparams.print_special_tokens = params.print_special_tokens; - wparams.translate = params.translate; - wparams.language = params.language.c_str(); - wparams.n_threads = params.n_threads; - - if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { - fprintf(stderr, "%s: failed to process audio\n", argv[0]); - return 6; - } - - // print result; - if (!wparams.print_realtime) { + // print some info about the processing + { printf("\n"); + if (!whisper_is_multilingual(ctx)) { + if (params.language != "en" || params.translate) { + params.language = "en"; + params.translate = false; + printf("%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); + } + } + printf("%s: processing '%s' (%d samples, %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n", + __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads, + params.language.c_str(), + params.translate ? "translate" : "transcribe", + params.no_timestamps ? 0 : 1); + printf("\n"); + } - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); + // run the inference + { + whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY); - if (params.no_timestamps) { - printf ("%s", text); - fflush(stdout); - } else { - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + wparams.print_realtime = true; + wparams.print_progress = false; + wparams.print_timestamps = !params.no_timestamps; + wparams.print_special_tokens = params.print_special_tokens; + wparams.translate = params.translate; + wparams.language = params.language.c_str(); + wparams.n_threads = params.n_threads; - printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); + if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { + fprintf(stderr, "%s: failed to process audio\n", argv[0]); + return 6; + } + + // print result; + if (!wparams.print_realtime) { + printf("\n"); + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + + if (params.no_timestamps) { + printf ("%s", text); + fflush(stdout); + } else { + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + + printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); + } } } }