From b8c93c5f3b30fd506898f2e94fc7e97c4758cc4f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 3 Nov 2023 13:24:05 +0200 Subject: [PATCH] whisper : add support for new distilled Whisper models --- whisper.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/whisper.cpp b/whisper.cpp index ccac6aaf..602d43b8 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -3942,6 +3942,7 @@ static void whisper_process_logits( // suppress task tokens logits[vocab.token_translate] = -INFINITY; logits[vocab.token_transcribe] = -INFINITY; + logits[vocab.token_prev] = -INFINITY; if (params.logits_filter_callback) { params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data); @@ -4560,6 +4561,7 @@ int whisper_full_with_state( // these tokens determine the task that will be performed std::vector prompt_init = { whisper_token_sot(ctx) }; + if (whisper_is_multilingual(ctx)) { const int lang_id = whisper_lang_id(params.language); state->lang_id = lang_id; @@ -4571,6 +4573,16 @@ int whisper_full_with_state( } } + { + const bool is_distil = ctx->model.hparams.n_text_layer == 2; + + // distilled models require the "no_timestamps" token + // TODO: add input parameter (#1229) + if (is_distil) { + prompt_init.push_back(whisper_token_not(ctx)); + } + } + int seek = seek_start; std::vector prompt;