Compare commits

...

2 Commits

Author SHA1 Message Date
Georgi Gerganov
673c55c683 whisper : print log when using distilled models 2023-11-05 19:43:04 +02:00
Georgi Gerganov
b8c93c5f3b whisper : add support for new distilled Whisper models 2023-11-03 13:27:08 +02:00

View File

@@ -3942,6 +3942,7 @@ static void whisper_process_logits(
// suppress task tokens // suppress task tokens
logits[vocab.token_translate] = -INFINITY; logits[vocab.token_translate] = -INFINITY;
logits[vocab.token_transcribe] = -INFINITY; logits[vocab.token_transcribe] = -INFINITY;
logits[vocab.token_prev] = -INFINITY;
if (params.logits_filter_callback) { if (params.logits_filter_callback) {
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data); params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
@@ -4560,6 +4561,7 @@ int whisper_full_with_state(
// these tokens determine the task that will be performed // these tokens determine the task that will be performed
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) }; std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
if (whisper_is_multilingual(ctx)) { if (whisper_is_multilingual(ctx)) {
const int lang_id = whisper_lang_id(params.language); const int lang_id = whisper_lang_id(params.language);
state->lang_id = lang_id; state->lang_id = lang_id;
@@ -4571,6 +4573,17 @@ int whisper_full_with_state(
} }
} }
// Special-case for distilled models: append one extra token to the initial prompt.
// The enclosing braces keep `is_distil` local to this block.
{
// heuristic: the model is treated as distilled when hparams.n_text_layer is exactly 2
// NOTE(review): any other model with n_text_layer == 2 would also match - confirm this is acceptable
const bool is_distil = ctx->model.hparams.n_text_layer == 2;
// distilled models require the "no_timestamps" token
// TODO: add input parameter (#1229)
if (is_distil) {
// report that the behavior is being forced rather than requested by the caller
log("%s: using distilled model - forcing no_timestamps\n", __func__);
// appended after whatever prompt_init already holds at this point
// presumably whisper_token_not() returns the "no_timestamps" token id - confirm against its definition
prompt_init.push_back(whisper_token_not(ctx));
}
}
int seek = seek_start; int seek = seek_start;
std::vector<whisper_token> prompt; std::vector<whisper_token> prompt;