mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-01-24 14:58:43 +01:00
ref #52 : improve greedy sampling strategy
Force timestamp token to be sampled if the probability sum over all timestamp tokens is above the probability of any other token
This commit is contained in:
parent
632660abb9
commit
7eeef0358a
39
whisper.cpp
39
whisper.cpp
@ -1784,7 +1784,7 @@ bool whisper_decode(
|
||||
// the most basic sampling scheme - select the top token
|
||||
whisper_vocab::id whisper_sample_best(
|
||||
const whisper_vocab & vocab,
|
||||
const float * probs, bool need_timestamp) {
|
||||
const float * probs) {
|
||||
int n_logits = vocab.id_to_token.size();
|
||||
|
||||
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
|
||||
@ -1794,9 +1794,29 @@ whisper_vocab::id whisper_sample_best(
|
||||
probs_id.push_back(std::make_pair(probs[i], i));
|
||||
}
|
||||
|
||||
const int top_k = 4;
|
||||
double sum_ts = 0.0;
|
||||
double max_tx = 0.0;
|
||||
|
||||
for (int i = 0; i < vocab.token_beg; i++) {
|
||||
max_tx = std::max(max_tx, probs_id[i].first);
|
||||
}
|
||||
|
||||
for (int i = vocab.token_beg; i < n_logits; i++) {
|
||||
sum_ts += probs_id[i].first;
|
||||
}
|
||||
|
||||
// if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
|
||||
// timestamp token
|
||||
if (sum_ts > max_tx) {
|
||||
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
|
||||
for (int i = 0; i < vocab.token_beg; i++) {
|
||||
probs_id[i].first = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
// find the top K tokens
|
||||
const int top_k = 4;
|
||||
|
||||
std::partial_sort(
|
||||
probs_id.begin(),
|
||||
probs_id.begin() + top_k, probs_id.end(),
|
||||
@ -1811,15 +1831,6 @@ whisper_vocab::id whisper_sample_best(
|
||||
// printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
|
||||
//}
|
||||
|
||||
if (need_timestamp) {
|
||||
// at the end of the 30-second audio segment, we start giving preference to time tokens
|
||||
for (int i = 0; i < top_k; i++) {
|
||||
if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > 0.01*probs_id[0].first) {
|
||||
return probs_id[i].second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int res = 0;
|
||||
while ((probs_id[res].second == vocab.token_sot ||
|
||||
probs_id[res].second == vocab.token_solm ||
|
||||
@ -2155,11 +2166,11 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
|
||||
return 0;
|
||||
}
|
||||
|
||||
whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp) {
|
||||
whisper_token whisper_sample_best(struct whisper_context * ctx) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// TODO: simplify
|
||||
auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), need_timestamp);
|
||||
auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
|
||||
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
|
||||
@ -2437,7 +2448,7 @@ int whisper_full(
|
||||
whisper_token id = 0;
|
||||
whisper_token tid = whisper_token_beg(ctx);
|
||||
|
||||
id = whisper_sample_best(ctx, result_len == 0);
|
||||
id = whisper_sample_best(ctx);
|
||||
if (i > 0) {
|
||||
tid = whisper_sample_timestamp(ctx);
|
||||
}
|
||||
|
@ -120,7 +120,7 @@ extern "C" {
|
||||
// You can also implement your own sampling method using the whisper_get_probs() function.
|
||||
// whisper_sample_best() returns the token with the highest probability
|
||||
// whisper_sample_timestamp() returns the most probable timestamp token
|
||||
WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp);
|
||||
WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx);
|
||||
WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
|
||||
|
||||
// Return the id of the specified language, returns -1 if not found
|
||||
|
Loading…
Reference in New Issue
Block a user