From f00509d57cc8e208ad2153aff3fe0af924289abc Mon Sep 17 00:00:00 2001 From: Andy Maloney Date: Sat, 31 Dec 2022 07:08:57 -0500 Subject: [PATCH] command : refactor to split command list & general transcription modes (#331) This makes it easier to understand if you're looking for only one of the capabilities. --- examples/command/command.cpp | 623 +++++++++++++++++++---------------- 1 file changed, 331 insertions(+), 292 deletions(-) diff --git a/examples/command/command.cpp b/examples/command/command.cpp index 0ee33067..3ea563ad 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -510,6 +510,333 @@ std::vector read_allowed_commands(const std::string & fname) { return allowed_commands; } +// command-list mode +// guide the transcription to match the most likely command from a provided list +int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) { + fprintf(stderr, "\n"); + fprintf(stderr, "%s: guided mode\n", __func__); + + std::vector allowed_commands = read_allowed_commands(params.commands); + + if (allowed_commands.empty()) { + fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str()); + return 2; + } + + int max_len = 0; + + std::vector> allowed_tokens; + + for (const auto & cmd : allowed_commands) { + whisper_token tokens[1024]; + allowed_tokens.emplace_back(); + + for (int l = 0; l < (int) cmd.size(); ++l) { + // NOTE: very important to add the whitespace ! + // the reason is that the first decoded token starts with a whitespace too! + std::string ss = std::string(" ") + cmd.substr(0, l + 1); + + const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024); + if (n < 0) { + fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str()); + return 3; + } + + if (n == 1) { + allowed_tokens.back().push_back(tokens[0]); + } + } + + max_len = std::max(max_len, (int) cmd.size()); + } + + fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__); + fprintf(stderr, "\n"); + for (int i = 0; i < (int) allowed_commands.size(); ++i) { + fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str()); + for (const auto & token : allowed_tokens[i]) { + fprintf(stderr, " %5d", token); + } + fprintf(stderr, " ]\n"); + } + + std::string k_prompt = "select one from the available words: "; + for (int i = 0; i < (int) allowed_commands.size(); ++i) { + if (i > 0) { + k_prompt += ", "; + } + k_prompt += allowed_commands[i]; + } + k_prompt += ". selected word: "; + + // tokenize prompt + std::vector k_tokens; + { + k_tokens.resize(1024); + const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024); + if (n < 0) { + fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str()); + return 4; + } + k_tokens.resize(n); + } + + fprintf(stderr, "\n"); + fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str()); + fprintf(stderr, "%s: tokens: [", __func__); + for (const auto & token : k_tokens) { + fprintf(stderr, " %d", token); + } + fprintf(stderr, " ]\n"); + + fprintf(stderr, "\n"); + fprintf(stderr, "%s: listening for a command ...\n", __func__); + fprintf(stderr, "\n"); + + bool is_running = true; + + std::vector pcmf32_cur; + std::vector pcmf32_prompt; + + // main loop + while (is_running) { + // handle Ctrl + C + { + SDL_Event event; + while (SDL_PollEvent(&event)) { + switch (event.type) { + case SDL_QUIT: + { + is_running = false; + } break; + default: + break; + } + } + + if (!is_running) { + return 0; + } + } + + // delay + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + audio.get(2000, pcmf32_cur); + + if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) { + fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); + + const auto t_start = std::chrono::high_resolution_clock::now(); + + whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + + wparams.print_progress = false; + wparams.print_special = params.print_special; + wparams.print_realtime = false; + wparams.print_timestamps = !params.no_timestamps; + wparams.translate = params.translate; + wparams.no_context = true; + wparams.single_segment = true; + wparams.max_tokens = 1; + wparams.language = params.language.c_str(); + wparams.n_threads = params.n_threads; + + wparams.audio_ctx = params.audio_ctx; + wparams.speed_up = params.speed_up; + + wparams.prompt_tokens = k_tokens.data(); + wparams.prompt_n_tokens = k_tokens.size(); + + // run the transformer and a single decoding pass + if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) { + fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__); + break; + } + + const auto * probs = whisper_get_probs(ctx); + std::vector> probs_id; + + double psum = 0.0; + for (int i = 0; i < (int) allowed_commands.size(); ++i) { + probs_id.emplace_back(probs[allowed_tokens[i][0]], i); + for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) { + probs_id.back().first += probs[allowed_tokens[i][j]]; + } + probs_id.back().first /= allowed_tokens[i].size(); + psum += probs_id.back().first; + } + + // normalize + for (auto & p : probs_id) { + p.first /= psum; + } + + // sort descending + { + using pair_type = decltype(probs_id)::value_type; + std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) { + return a.first > b.first; + }); + } + + // print the commands and the respective probabilities + { + fprintf(stdout, "\n"); + for (const auto & cmd : probs_id) { + fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first); + for (int token : allowed_tokens[cmd.second]) { + fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]); + } + fprintf(stdout, "\n"); + } + } + + // best command + { + const auto t_end = std::chrono::high_resolution_clock::now(); + + const float prob = probs_id[0].first; + const int index = probs_id[0].second; + + fprintf(stdout, "\n"); + fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__, + "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob, + (int) std::chrono::duration_cast(t_end - t_start).count()); + fprintf(stdout, "\n"); + } + + audio.clear(); + } + } + + return 0; +} + +// general-purpose mode +// freely transcribe the voice into text +int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) { + bool is_running = true; + bool have_prompt = false; + bool ask_prompt = true; + + float prob0 = 0.0f; + float prob = 0.0f; + + std::vector pcmf32_cur; + std::vector pcmf32_prompt; + + const std::string k_prompt = "Ok Whisper, start listening for commands."; + + fprintf(stderr, "\n"); + fprintf(stderr, "%s: general-purpose mode\n", __func__); + + // main loop + while (is_running) { + // handle Ctrl + C + { + SDL_Event event; + while (SDL_PollEvent(&event)) { + switch (event.type) { + case SDL_QUIT: + { + is_running = false; + } break; + default: + break; + } + } + + if (!is_running) { + return 0; + } + } + + // delay + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + if (ask_prompt) { + fprintf(stdout, "\n"); + fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m"); + fprintf(stdout, "\n"); + + ask_prompt = false; + } + + { + audio.get(2000, pcmf32_cur); + + if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) { + fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); + + int64_t t_ms = 0; + + if (!have_prompt) { + // wait for activation phrase + audio.get(params.prompt_ms, pcmf32_cur); + + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms)); + + fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms); + + const float sim = similarity(txt, k_prompt); + + if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) { + fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__); + ask_prompt = true; + } else { + fprintf(stdout, "\n"); + fprintf(stdout, "%s: The prompt has been recognized!\n", __func__); + fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__); + fprintf(stdout, "\n"); + + // save the audio for the prompt + pcmf32_prompt = pcmf32_cur; + have_prompt = true; + } + } else { + // we have heard the activation phrase, now detect the commands + audio.get(params.command_ms, pcmf32_cur); + + // prepend the prompt audio + pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); + + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); + + prob = 100.0f*(prob - prob0); + + //fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str()); + + // find the prompt in the text + float best_sim = 0.0f; + size_t best_len = 0; + for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) { + const auto prompt = txt.substr(0, n); + + const float sim = similarity(prompt, k_prompt); + + //fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim); + + if (sim > best_sim) { + best_sim = sim; + best_len = n; + } + } + + const std::string command = ::trim(txt.substr(best_len)); + + fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms); + fprintf(stdout, "\n"); + } + + audio.clear(); + } + } + } + + return 0; +} + int main(int argc, char ** argv) { whisper_params params; @@ -561,300 +888,12 @@ int main(int argc, char ** argv) { std::this_thread::sleep_for(std::chrono::milliseconds(1000)); audio.clear(); - int max_len = 0; - - bool is_running = true; - bool have_prompt = false; - bool ask_prompt = true; - - float prob0 = 0.0f; - float prob = 0.0f; - - std::vector pcmf32_cur; - std::vector pcmf32_prompt; - - std::vector allowed_commands; - std::vector> allowed_tokens; - - std::string k_prompt; - std::vector k_tokens; + int ret_val = 0; if (!params.commands.empty()) { - fprintf(stderr, "\n"); - fprintf(stderr, "%s: guided mode\n", __func__); - - allowed_commands = read_allowed_commands(params.commands); - - if (allowed_commands.empty()) { - fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str()); - return 2; - } - - for (const auto & cmd : allowed_commands) { - whisper_token tokens[1024]; - allowed_tokens.emplace_back(); - - for (int l = 0; l < (int) cmd.size(); ++l) { - // NOTE: very important to add the whitespace ! - // the reason is that the first decoded token starts with a whitespace too! - std::string ss = std::string(" ") + cmd.substr(0, l + 1); - - const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024); - if (n < 0) { - fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str()); - return 3; - } - - if (n == 1) { - allowed_tokens.back().push_back(tokens[0]); - } - } - - max_len = std::max(max_len, (int) cmd.size()); - } - - fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__); - fprintf(stderr, "\n"); - for (int i = 0; i < (int) allowed_commands.size(); ++i) { - fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str()); - for (const auto & token : allowed_tokens[i]) { - fprintf(stderr, " %5d", token); - } - fprintf(stderr, " ]\n"); - } - - k_prompt = "select one from the available words: "; - for (int i = 0; i < (int) allowed_commands.size(); ++i) { - if (i > 0) { - k_prompt += ", "; - } - k_prompt += allowed_commands[i]; - } - k_prompt += ". selected word: "; - - // tokenize prompt - { - k_tokens.resize(1024); - const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024); - if (n < 0) { - fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str()); - return 4; - } - k_tokens.resize(n); - } - - fprintf(stderr, "\n"); - fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str()); - fprintf(stderr, "%s: tokens: [", __func__); - for (const auto & token : k_tokens) { - fprintf(stderr, " %d", token); - } - fprintf(stderr, " ]\n"); - - fprintf(stderr, "\n"); - fprintf(stderr, "%s: listening for a command ...\n", __func__); - fprintf(stderr, "\n"); - + ret_val = process_command_list(ctx, audio, params); } else { - fprintf(stderr, "\n"); - fprintf(stderr, "%s: general-purpose mode\n", __func__); - - k_prompt = "Ok Whisper, start listening for commands."; - } - - // main loop - while (is_running) { - // handle Ctrl + C - { - SDL_Event event; - while (SDL_PollEvent(&event)) { - switch (event.type) { - case SDL_QUIT: - { - is_running = false; - } break; - default: - break; - } - } - - if (!is_running) { - break; - } - } - - // delay - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - if (allowed_commands.empty()) { - // general-purpose mode - // freely transcribe the voice into text - - if (ask_prompt) { - fprintf(stdout, "\n"); - fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m"); - fprintf(stdout, "\n"); - - ask_prompt = false; - } - - { - int64_t t_ms = 0; - - audio.get(2000, pcmf32_cur); - - if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) { - fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); - - if (!have_prompt) { - // wait for activation phrase - audio.get(params.prompt_ms, pcmf32_cur); - - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms)); - - fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms); - - const float sim = similarity(txt, k_prompt); - - if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) { - fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__); - ask_prompt = true; - } else { - fprintf(stdout, "\n"); - fprintf(stdout, "%s: The prompt has been recognized!\n", __func__); - fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__); - fprintf(stdout, "\n"); - - // save the audio for the prompt - pcmf32_prompt = pcmf32_cur; - have_prompt = true; - } - } else { - // we have heard the activation phrase, now detect the commands - audio.get(params.command_ms, pcmf32_cur); - - // prepend the prompt audio - pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); - - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); - - prob = 100.0f*(prob - prob0); - - //fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str()); - - // find the prompt in the text - float best_sim = 0.0f; - size_t best_len = 0; - for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) { - const auto prompt = txt.substr(0, n); - - const float sim = similarity(prompt, k_prompt); - - //fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim); - - if (sim > best_sim) { - best_sim = sim; - best_len = n; - } - } - - const std::string command = ::trim(txt.substr(best_len)); - - fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms); - fprintf(stdout, "\n"); - } - - audio.clear(); - } - } - } else { - // command-list mode - // guide the transcription to match the most likely command from a provided list - - audio.get(2000, pcmf32_cur); - - if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) { - fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); - - const auto t_start = std::chrono::high_resolution_clock::now(); - - whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - - wparams.print_progress = false; - wparams.print_special = params.print_special; - wparams.print_realtime = false; - wparams.print_timestamps = !params.no_timestamps; - wparams.translate = params.translate; - wparams.no_context = true; - wparams.single_segment = true; - wparams.max_tokens = 1; - wparams.language = params.language.c_str(); - wparams.n_threads = params.n_threads; - - wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; - - wparams.prompt_tokens = k_tokens.data(); - wparams.prompt_n_tokens = k_tokens.size(); - - // run the transformer and a single decoding pass - if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) { - fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__); - break; - } - - const auto * probs = whisper_get_probs(ctx); - std::vector> probs_id; - - double psum = 0.0; - for (int i = 0; i < (int) allowed_commands.size(); ++i) { - probs_id.emplace_back(probs[allowed_tokens[i][0]], i); - for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) { - probs_id.back().first += probs[allowed_tokens[i][j]]; - } - probs_id.back().first /= allowed_tokens[i].size(); - psum += probs_id.back().first; - } - - // normalize - for (auto & p : probs_id) { - p.first /= psum; - } - - // sort descending - { - using pair_type = decltype(probs_id)::value_type; - std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) { - return a.first > b.first; - }); - } - - // print the commands and the respective probabilities - { - fprintf(stdout, "\n"); - for (const auto & cmd : probs_id) { - fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first); - for (int i = 0; i < (int) allowed_tokens[cmd.second].size(); ++i) { - fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, allowed_tokens[cmd.second][i]), probs[allowed_tokens[cmd.second][i]]); - } - fprintf(stdout, "\n"); - } - } - - // best command - { - const auto t_end = std::chrono::high_resolution_clock::now(); - - fprintf(stdout, "\n"); - fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__, - "\033[1m", allowed_commands[probs_id[0].second].c_str(), "\033[0m", probs_id[0].first, - (int) std::chrono::duration_cast(t_end - t_start).count()); - fprintf(stdout, "\n"); - } - - audio.clear(); - } - } + ret_val = process_general_transcription(ctx, audio, params); } audio.pause(); @@ -862,5 +901,5 @@ int main(int argc, char ** argv) { whisper_print_timings(ctx); whisper_free(ctx); - return 0; + return ret_val; }