From d77603578bb55681357df0c30481828acdac8ea9 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 14 Nov 2023 21:04:33 +0200
Subject: [PATCH] whisper : clear kv cache when using whisper_decode API

---
Review note: in whisper_decode_with_state the KV cache that is cleared must
belong to the `state` argument, which is the state the rest of that function
operates on (state->batch, *state). The hunk originally cleared
ctx->state->kv_self, which targets a different cache whenever the caller
supplies its own state and dereferences ctx->state without any check.
whisper_decode keeps using ctx->state, matching the rest of that function.
The blob hashes on the "index" line below predate this amendment.

 whisper.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/whisper.cpp b/whisper.cpp
index a8184546..502ac64d 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -3516,6 +3516,8 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
 int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
     whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0);
 
+    whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);
+
     if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
         WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
@@ -3530,6 +3532,8 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
         return false;
     }
 
+    whisper_kv_cache_seq_rm(ctx->state->kv_self, 0, n_past, -1);
+
     whisper_batch_prep_legacy(ctx->state->batch, tokens, n_tokens, n_past, 0);
 
     if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->batch, n_threads, nullptr, nullptr)) {