ref #52 : improve greedy sampling strategy

Force a timestamp token to be sampled if the probability sum over all
timestamp tokens is above the probability of any single non-timestamp token.
This commit is contained in:
Georgi Gerganov 2022-10-18 19:33:10 +03:00
parent 632660abb9
commit 7eeef0358a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 26 additions and 15 deletions

View File

@ -1784,7 +1784,7 @@ bool whisper_decode(
// the most basic sampling scheme - select the top token
whisper_vocab::id whisper_sample_best(
const whisper_vocab & vocab,
const float * probs, bool need_timestamp) {
const float * probs) {
int n_logits = vocab.id_to_token.size();
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
@ -1794,9 +1794,29 @@ whisper_vocab::id whisper_sample_best(
probs_id.push_back(std::make_pair(probs[i], i));
}
const int top_k = 4;
double sum_ts = 0.0;
double max_tx = 0.0;
for (int i = 0; i < vocab.token_beg; i++) {
max_tx = std::max(max_tx, probs_id[i].first);
}
for (int i = vocab.token_beg; i < n_logits; i++) {
sum_ts += probs_id[i].first;
}
// if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
// timestamp token
if (sum_ts > max_tx) {
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
for (int i = 0; i < vocab.token_beg; i++) {
probs_id[i].first = -INFINITY;
}
}
// find the top K tokens
const int top_k = 4;
std::partial_sort(
probs_id.begin(),
probs_id.begin() + top_k, probs_id.end(),
@ -1811,15 +1831,6 @@ whisper_vocab::id whisper_sample_best(
// printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
//}
if (need_timestamp) {
// at the end of the 30-second audio segment, we start giving preference to time tokens
for (int i = 0; i < top_k; i++) {
if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > 0.01*probs_id[0].first) {
return probs_id[i].second;
}
}
}
int res = 0;
while ((probs_id[res].second == vocab.token_sot ||
probs_id[res].second == vocab.token_solm ||
@ -2155,11 +2166,11 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
return 0;
}
whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp) {
whisper_token whisper_sample_best(struct whisper_context * ctx) {
const int64_t t_start_sample_us = ggml_time_us();
// TODO: simplify
auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), need_timestamp);
auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
@ -2437,7 +2448,7 @@ int whisper_full(
whisper_token id = 0;
whisper_token tid = whisper_token_beg(ctx);
id = whisper_sample_best(ctx, result_len == 0);
id = whisper_sample_best(ctx);
if (i > 0) {
tid = whisper_sample_timestamp(ctx);
}

View File

@ -120,7 +120,7 @@ extern "C" {
// You can also implement your own sampling method using the whisper_get_probs() function.
// whisper_sample_best() returns the token with the highest probability
// whisper_sample_timestamp() returns the most probable timestamp token
WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp);
WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
// Return the id of the specified language, returns -1 if not found