mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-26 00:29:21 +01:00
whisper : add whisper_n_audio_ctx and check for invalid audio_ctx
closes #344
This commit is contained in:
parent
3467230a77
commit
d97e6005e9
10
whisper.cpp
10
whisper.cpp
@ -2497,6 +2497,10 @@ int whisper_n_text_ctx(struct whisper_context * ctx) {
|
||||
return ctx->model.hparams.n_text_ctx;
|
||||
}
|
||||
|
||||
int whisper_n_audio_ctx(struct whisper_context * ctx) {
|
||||
return ctx->model.hparams.n_audio_ctx;
|
||||
}
|
||||
|
||||
int whisper_is_multilingual(struct whisper_context * ctx) {
|
||||
return ctx->vocab.is_multilingual() ? 1 : 0;
|
||||
}
|
||||
@ -2822,7 +2826,11 @@ int whisper_full(
|
||||
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
|
||||
}
|
||||
|
||||
// overwrite audio_ctx
|
||||
// overwrite audio_ctx, max allowed is hparams.n_audio_ctx
|
||||
if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
|
||||
fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
|
||||
return -4;
|
||||
}
|
||||
ctx->exp_n_audio_ctx = params.audio_ctx;
|
||||
|
||||
// these tokens determine the task that will be performed
|
||||
|
@ -177,6 +177,7 @@ extern "C" {
|
||||
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
|
||||
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
|
||||
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
|
||||
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
|
||||
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
|
||||
|
||||
// The probabilities for the next token
|
||||
|
Loading…
Reference in New Issue
Block a user