From b8f34d1ed786194d9787b1f1d086b89136e361f3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 6 Sep 2023 17:05:05 +0300 Subject: [PATCH] whisper : fine-tuning grammar functionality --- examples/command/command.cpp | 18 ++++-- whisper.cpp | 114 ++++++++++++++++++++++++----------- 2 files changed, 92 insertions(+), 40 deletions(-) diff --git a/examples/command/command.cpp b/examples/command/command.cpp index 85789d35..f33f8e15 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -31,8 +31,9 @@ struct whisper_params { int32_t max_tokens = 32; int32_t audio_ctx = 0; - float vad_thold = 0.6f; - float freq_thold = 100.0f; + float vad_thold = 0.6f; + float freq_thold = 100.0f; + float grammar_penalty = 100.0f; bool speed_up = false; @@ -138,6 +139,9 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con wparams.language = params.language.c_str(); wparams.n_threads = params.n_threads; + // disable fallback - seems not useful for command recognition + wparams.temperature_inc = 0.0f; + wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; @@ -508,7 +512,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi // general-purpose mode // freely transcribe the voice into text -int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) { +int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) { bool is_running = true; bool have_prompt = false; bool ask_prompt = true; @@ -519,7 +523,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud std::vector<float> pcmf32_cur; std::vector<float> pcmf32_prompt; - const std::string k_prompt = "Ok Whisper, start listening for commands."; + //const std::string k_prompt = "Ok Whisper, start listening for commands."; + //const std::string k_prompt = "Начало."; + const std::string k_prompt = "Добре Уиспър, 
започни да слушаш за команди."; fprintf(stderr, "\n"); fprintf(stderr, "%s: general-purpose mode\n", __func__); @@ -578,6 +584,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud // prepend the prompt audio pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); + // append 1 second of silence + pcmf32_cur.insert(pcmf32_cur.end(), 1000*WHISPER_SAMPLE_RATE/1000, 0.0f); + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); prob = 100.0f*(prob - prob0); @@ -604,6 +613,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud } } + fprintf(stdout, "%s: DEBUG: txt = '%s'\n", __func__, txt.c_str()); if (best_len == 0) { fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__); } else { diff --git a/whisper.cpp b/whisper.cpp index 078841b3..5e3b86a8 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -3865,7 +3865,7 @@ static struct whisper_grammar whisper_grammar_init( static void whisper_suppress_invalid_grammar( whisper_context & ctx, const whisper_full_params & params, - std::vector<float> & logits, + std::vector<float> & logprobs, const whisper_grammar & grammar) { if (grammar.rules.empty() || grammar.stacks.empty()) { @@ -3883,8 +3883,8 @@ static void whisper_suppress_invalid_grammar( std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded; std::vector<whisper_grammar_candidate> candidates_grammar; - size_t size = logits.size(); - for (whisper_token id = 0; id < size; ++id) { + size_t size = logprobs.size(); + for (whisper_token id = 0; id < (int) size; ++id) { const std::string & text = ctx.vocab.id_to_token[id]; if (!text.empty() && text.rfind("[_", 0) != 0) { candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8)); @@ -3893,14 +3893,18 @@ static void whisper_suppress_invalid_grammar( } const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); + for (const auto & reject : rejects) { - if (logits[reject.id] > 
0) { - logits[reject.id] /= params.grammar_penalty; - } else { - logits[reject.id] *= params.grammar_penalty; - } + logprobs[reject.id] -= params.grammar_penalty; } - // fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size()); + + // when the grammar does not allow any continuation, we don't want to penalize the EOT token + // TODO: is there a better way to do this? + printf("rejects.size() = %zu, whisper_token_eot(&ctx) - 2 = %d\n", rejects.size(), whisper_token_eot(&ctx) - 2); + if ((int) rejects.size() < whisper_token_eot(&ctx) - 2) { + logprobs[whisper_token_eot(&ctx)] -= params.grammar_penalty; + } + //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size()); } static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) { @@ -3908,10 +3912,10 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar return; } - // fprintf(stderr, "Accept: '%s'", ctx.vocab.id_to_token[token].c_str()); + fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str()); const std::string & text = ctx.vocab.id_to_token[token]; - + if (text.rfind("[_", 0) == 0) { // fprintf(stderr, " (skipped)\n"); return; @@ -4015,7 +4019,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.grammar_rules =*/ nullptr, /*.n_grammar_rules =*/ 0, /*.i_start_rule =*/ 0, - /*.grammar_penalty =*/ 1000.0f, + /*.grammar_penalty =*/ 100.0f, }; switch (strategy) { @@ -4181,12 +4185,18 @@ static void whisper_process_logits( logits[vocab.token_translate] = -INFINITY; logits[vocab.token_transcribe] = -INFINITY; + // suppress lang tokens + for (size_t i = 0; i < g_lang.size(); ++i) { + logits[whisper_token_lang(&ctx, i)] = -INFINITY; + } + + // suppress prev token + logits[vocab.token_prev] = -INFINITY; + if (params.logits_filter_callback) { params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), 
params.logits_filter_callback_user_data); } - whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar); - // suppress non-speech tokens // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 if (params.suppress_non_speech_tokens) { @@ -4293,10 +4303,19 @@ static void whisper_process_logits( //log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); if (timestamp_logprob > max_text_token_logprob) { + //printf("sampling timestamp\n"); for (int i = 0; i < vocab.token_beg; ++i) { logits[i] = -INFINITY; logprobs[i] = -INFINITY; } + } else { + //printf("sampling text\n"); + for (int i = vocab.token_beg; i < n_logits; ++i) { + logits[i] = -INFINITY; + logprobs[i] = -INFINITY; + } + + whisper_suppress_invalid_grammar(ctx, params, logprobs, decoder.grammar); } } } @@ -4312,34 +4331,57 @@ static void whisper_process_logits( } } -#if 0 +#if 1 // print first 100 logits - token string : logit - for (int i = 0; i < 100; i++) { - const auto token = vocab.id_to_token.at(i); - const auto prob = probs[i]; - const auto logit = logits[i]; - const auto logprob = logprobs[i]; - printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob); + //for (int i = 0; i < 10; i++) { + // const auto token = vocab.id_to_token.at(i); + // const auto prob = probs[i]; + // const auto logit = logits[i]; + // const auto logprob = logprobs[i]; + // printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob); + //} + + // print sorted + { + std::vector<std::pair<float, int>> pairs; + + for (int i = 0; i < n_logits; ++i) { + pairs.push_back(std::make_pair(probs[i], i)); + } + + std::sort(pairs.begin(), pairs.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) { + return a.first > b.first; + }); + + for (int i = 0; i < 10; i++) { + const auto token = vocab.id_to_token.at(pairs[i].second); + const auto prob = pairs[i].first; + const auto logit = 
logits[pairs[i].second]; + const auto logprob = logprobs[pairs[i].second]; + printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str()); + } + + printf("----------------\n"); } // "And", "and", " And", " and" - printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]); - printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]); - printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]); - printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]); - printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]); + //printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]); + //printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]); + //printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]); + //printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]); + //printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]); - printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]); - printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]); - printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); - printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); - printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); + //printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]); + //printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]); + //printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); + //printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); + //printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); - printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]); - printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]); - printf("probs[\" and\"] = %f\n", 
probs[vocab.token_to_id.at(" and")]); - printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]); - printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]); + //printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]); + //printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]); + //printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]); + //printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]); + //printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]); #endif }