mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-01-13 17:38:36 +01:00
whisper : add whisper_n_audio_ctx and check for invalid audio_ctx
closes #344
This commit is contained in:
parent
3467230a77
commit
d97e6005e9
10
whisper.cpp
10
whisper.cpp
@ -2497,6 +2497,10 @@ int whisper_n_text_ctx(struct whisper_context * ctx) {
|
|||||||
return ctx->model.hparams.n_text_ctx;
|
return ctx->model.hparams.n_text_ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int whisper_n_audio_ctx(struct whisper_context * ctx) {
|
||||||
|
return ctx->model.hparams.n_audio_ctx;
|
||||||
|
}
|
||||||
|
|
||||||
int whisper_is_multilingual(struct whisper_context * ctx) {
|
int whisper_is_multilingual(struct whisper_context * ctx) {
|
||||||
return ctx->vocab.is_multilingual() ? 1 : 0;
|
return ctx->vocab.is_multilingual() ? 1 : 0;
|
||||||
}
|
}
|
||||||
@ -2822,7 +2826,11 @@ int whisper_full(
|
|||||||
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
|
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
// overwrite audio_ctx
|
// overwrite audio_ctx, max allowed is hparams.n_audio_ctx
|
||||||
|
if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
|
||||||
|
fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
|
||||||
|
return -4;
|
||||||
|
}
|
||||||
ctx->exp_n_audio_ctx = params.audio_ctx;
|
ctx->exp_n_audio_ctx = params.audio_ctx;
|
||||||
|
|
||||||
// these tokens determine the task that will be performed
|
// these tokens determine the task that will be performed
|
||||||
|
@ -177,6 +177,7 @@ extern "C" {
|
|||||||
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
|
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
|
||||||
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
|
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
|
||||||
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
|
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
|
||||||
|
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
|
||||||
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
|
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
|
||||||
|
|
||||||
// The probabilities for the next token
|
// The probabilities for the next token
|
||||||
|
Loading…
Reference in New Issue
Block a user