From d97e6005e95f31ff812f72cd2cad3347080d1520 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 31 Dec 2022 09:55:33 +0200 Subject: [PATCH] whisper : add whisper_n_audio_ctx and check for invalid audio_ctx closes #344 --- whisper.cpp | 10 +++++++++- whisper.h | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/whisper.cpp b/whisper.cpp index d23e97fe..84c24900 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2497,6 +2497,10 @@ int whisper_n_text_ctx(struct whisper_context * ctx) { return ctx->model.hparams.n_text_ctx; } +int whisper_n_audio_ctx(struct whisper_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + int whisper_is_multilingual(struct whisper_context * ctx) { return ctx->vocab.is_multilingual() ? 1 : 0; } @@ -2822,7 +2826,11 @@ int whisper_full( std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end()); } - // overwrite audio_ctx + // overwrite audio_ctx, max allowed is hparams.n_audio_ctx + if (params.audio_ctx > whisper_n_audio_ctx(ctx)) { + fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); + return -4; + } ctx->exp_n_audio_ctx = params.audio_ctx; // these tokens determine the task that will be performed diff --git a/whisper.h b/whisper.h index 92c14da0..e36b761f 100644 --- a/whisper.h +++ b/whisper.h @@ -177,6 +177,7 @@ extern "C" { WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length WHISPER_API int whisper_n_vocab (struct whisper_context * ctx); WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx); WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx); // The probabilities for the next token