From 78d13257be8094a71b65af401d4753281af2205a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 2 Dec 2022 21:51:50 +0200 Subject: [PATCH] Try to improve the token sampling strategy (#193) * whisper : try to improve the token sampling strategy - Add the "max_initial_timestaamp" token logic from OpenAI - Disallow sampling timestamps that are in the past * whisper : fix the max initial timestamp logic + fallback decoding --- whisper.cpp | 97 +++++++++++++++++++++++++---------------------------- whisper.h | 2 +- 2 files changed, 46 insertions(+), 53 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index fbcb5d1..42467ef 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -1846,7 +1846,9 @@ static bool whisper_decode( // the most basic sampling scheme - select the top token static whisper_token_data whisper_sample_best( const whisper_vocab & vocab, - const float * probs) { + const float * probs, + bool force_timestamp, + bool is_initial) { whisper_token_data result = { 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, }; @@ -1869,7 +1871,18 @@ static whisper_token_data whisper_sample_best( max_tx = std::max(max_tx, probs_id[i].first); } - for (int i = vocab.token_beg; i < n_logits; i++) { + const auto i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg; + const auto i1 = is_initial ? vocab.token_beg + 101 : n_logits; + + // the initial timestamp cannot be larger than 100 + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 + if (is_initial) { + for (int i = i0; i < n_logits; ++ i) { + probs_id[i].first = -INFINITY; + } + } + + for (int i = vocab.token_beg; i < i1; i++) { sum_ts += probs_id[i].first; if (probs_id[i].first > max_ts) { max_ts = probs_id[i].first; @@ -1879,7 +1892,7 @@ static whisper_token_data whisper_sample_best( // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a // timestamp token - if (sum_ts > max_tx) { + if (sum_ts > max_tx || force_timestamp) { // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438 for (int i = 0; i < vocab.token_beg; i++) { probs_id[i].first = -INFINITY; @@ -1921,39 +1934,6 @@ static whisper_token_data whisper_sample_best( return result; } -// samples only from the timestamps tokens -static whisper_vocab::id whisper_sample_timestamp( - const whisper_vocab & vocab, - const float * probs) { - int n_logits = vocab.id_to_token.size(); - - std::vector> probs_id; - probs_id.reserve(n_logits); - - for (int i = vocab.token_beg + 1; i < n_logits; i++) { - probs_id.push_back(std::make_pair(probs[i], i)); - } - - const int top_k = 10; - - // find the top K tokens - std::partial_sort( - probs_id.begin(), - probs_id.begin() + top_k, probs_id.end(), - [](const std::pair & a, const std::pair & b) { - return a.first > b.first; - }); - - probs_id.resize(top_k); - - //printf("\n"); - //for (int i = 0; i < (int) probs_id.size(); i++) { - // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second); - //} - - return probs_id[0].second; -} - // 500 -> 00:05.000 // 6000 -> 01:00.000 static std::string to_timestamp(int64_t t, bool comma = false) { @@ -2284,19 +2264,17 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) { const int64_t t_start_sample_us = ggml_time_us(); - // TODO: simplify - auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab)); + const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false); ctx->t_sample_us += ggml_time_us() - t_start_sample_us; return res; } -whisper_token whisper_sample_timestamp(struct whisper_context * ctx) { +struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) { const int64_t t_start_sample_us = ggml_time_us(); - // TODO: simplify - auto res = whisper_sample_timestamp(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab)); + const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial); ctx->t_sample_us += ggml_time_us() - t_start_sample_us; @@ -2694,7 +2672,6 @@ int whisper_full( prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); - bool done = false; int seek_delta = 100*WHISPER_CHUNK_SIZE; // print the prompt @@ -2708,7 +2685,9 @@ int whisper_full( int result_len = 0; tokens_cur.clear(); - for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) { + bool failed = false; + + for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { fprintf(stderr, "%s: failed to decode\n", __func__); return 8; @@ -2725,15 +2704,19 @@ int whisper_full( // feel free to experiment! // { - auto token = whisper_sample_best(ctx); - - if (i == 0) { - token.tid = whisper_token_beg(ctx); - } + const auto token = (i == 0) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx); // timestamp token - update sliding window if (token.id > whisper_token_beg(ctx)) { - seek_delta = 2*(token.id - whisper_token_beg(ctx)); + const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); + + // do not allow to go back in time + if (seek_delta != 100*WHISPER_CHUNK_SIZE && + seek_delta > seek_delta_new && result_len < i) { + break; + } + + seek_delta = seek_delta_new; result_len = i + 1; } @@ -2752,8 +2735,8 @@ int whisper_full( if (seek + seek_delta + 100 >= seek_end) { result_len = i + 1; } else { - // TODO: figure out how to resolve this - fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__); + failed = true; + break; } } @@ -2772,11 +2755,21 @@ int whisper_full( } } - if (done) { + // sometimes, the decoding can get stuck in a repetition loop + // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance + // the sliding window by 1 second + if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { + failed = true; break; } } + if (failed) { + fprintf(stderr, "\n%s: failed to generate timestamp token - using fallback strategy\n\n", __func__); + seek += 100; + continue; + } + // shrink down to result_len tokens_cur.resize(result_len); diff --git a/whisper.h b/whisper.h index 156edbb..def77d4 100644 --- a/whisper.h +++ b/whisper.h @@ -137,7 +137,7 @@ extern "C" { // whisper_sample_best() returns the token with the highest probability // whisper_sample_timestamp() returns the most probable timestamp token WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx); - WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx); + WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial); // Return the id of the specified language, returns -1 if not found WHISPER_API int whisper_lang_id(const char * lang);