Wrong implementation of carry_initial_prompt

Andreas Lubbe 2024-12-30 23:13:32 +01:00
parent 53c9a3a984
commit 4c60e6b0c1
2 changed files with 31 additions and 16 deletions

whisper.h

@@ -482,13 +482,14 @@ extern "C" {
         int duration_ms;           // audio duration to process in ms
         bool translate;
-        bool no_context;           // do not use past transcription (if any) as initial prompt for the decoder
-        bool no_timestamps;        // do not generate timestamps
-        bool single_segment;       // force single segment output (useful for streaming)
-        bool print_special;        // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
-        bool print_progress;       // print progress information
-        bool print_realtime;       // print results from within whisper.cpp (avoid it, use callback instead)
-        bool print_timestamps;     // print timestamps for each text segment when printing realtime
+        bool no_context;           // do not use past transcription (if any) as initial prompt for the decoder
+        bool carry_initial_prompt; // carry the initial prompt to the next call
+        bool no_timestamps;        // do not generate timestamps
+        bool single_segment;       // force single segment output (useful for streaming)
+        bool print_special;        // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
+        bool print_progress;       // print progress information
+        bool print_realtime;       // print results from within whisper.cpp (avoid it, use callback instead)
+        bool print_timestamps;     // print timestamps for each text segment when printing realtime
         // [EXPERIMENTAL] token-level timestamps
         bool token_timestamps;     // enable token-level timestamps
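In practice the new flag is meant to be paired with initial_prompt. A minimal usage sketch against the standard whisper.cpp API (the model path, prompt text, and streaming loop are placeholder assumptions, not part of this commit):

    struct whisper_context * ctx = whisper_init_from_file_with_params(
            "ggml-base.en.bin", whisper_context_default_params());

    struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    params.initial_prompt       = "Glossary: GGML, GGUF, quantization"; // domain hints for the decoder
    params.carry_initial_prompt = true;  // new flag: re-apply initial_prompt on every call
    params.no_context           = false; // also condition on past transcription

    // each chunk of a streaming loop is then decoded with:
    // whisper_full(ctx, params, pcm_chunk, n_samples);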

whisper.cpp

@@ -4646,14 +4646,15 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.offset_ms            =*/ 0,
         /*.duration_ms          =*/ 0,
-        /*.translate            =*/ false,
-        /*.no_context           =*/ true,
-        /*.no_timestamps        =*/ false,
-        /*.single_segment       =*/ false,
-        /*.print_special        =*/ false,
-        /*.print_progress       =*/ true,
-        /*.print_realtime       =*/ false,
-        /*.print_timestamps     =*/ true,
+        /*.translate            =*/ false,
+        /*.no_context           =*/ true,
+        /*.carry_initial_prompt =*/ false,
+        /*.no_timestamps        =*/ false,
+        /*.single_segment       =*/ false,
+        /*.print_special        =*/ false,
+        /*.print_progress       =*/ true,
+        /*.print_realtime       =*/ false,
+        /*.print_timestamps     =*/ true,
         /*.token_timestamps     =*/ false,
         /*.thold_pt             =*/ 0.01f,
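Note that carry_initial_prompt defaults to false, so the new field is strictly opt-in: existing callers of whisper_full_default_params() keep their current behavior.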
@@ -5454,6 +5455,7 @@ int whisper_full_with_state(
             prompt_tokens.resize(-n_needed);
             n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
         }
+        remaining_prompt_length = ctx->model.hparams.n_text_ctx / 2 - 1 - initial_prompt_tokens.size();
         prompt_tokens.resize(n_needed);
         params.prompt_tokens = prompt_tokens.data();
         params.prompt_n_tokens = prompt_tokens.size();
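For scale: Whisper models ship with model.hparams.n_text_ctx = 448, and whisper.cpp reserves half the text context for the prompt; the -1 presumably leaves room for the sot_prev token that whisper.cpp puts at the front of the prompt. A worked example of the budget (the 20-token initial prompt is an assumed figure):

    const int n_text_ctx = 448;                           // model.hparams.n_text_ctx for Whisper
    const int n_initial  = 20;                            // assumed length of the tokenized initial prompt
    const int remaining  = n_text_ctx/2 - 1 - n_initial;  // 224 - 1 - 20 = 203 past tokens may still be carried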
@@ -5610,9 +5612,21 @@ int whisper_full_with_state(
         // init prompt and kv cache for the current iteration
         // TODO: do not recompute the prompt if it is the same as previous time
         {
-            prompt.clear();
+            if (params.carry_initial_prompt) {
+                // prepend initial_prompt_tokens, then append whatever past context still fits;
+                // note: std::max looks suspect here (std::min was likely intended), and
+                // prompt_past.begin() + nignored can run past the end of prompt_past
+                int nignored = std::max((int) initial_prompt_tokens.size(), (int) prompt_past.size());
+                std::vector<whisper_token> remaining_prompt(prompt_past.begin() + nignored, prompt_past.end());
+                remaining_prompt.resize(std::min(remaining_prompt.size(), remaining_prompt_length));
+                prompt.clear();
+                prompt.insert(prompt.end(), initial_prompt_tokens.begin(), initial_prompt_tokens.end());
+                prompt.insert(prompt.end(), remaining_prompt.begin(), remaining_prompt.end());
+            } else {
+                prompt.clear();
+            }
             // if we have already generated some text, use it as a prompt to condition the next generation
+            // (the carry logic could arguably live inside this conditional instead)
             if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
                 int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
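As the commit title admits, the carry block above is not right: std::max makes nignored at least prompt_past.size(), so remaining_prompt is empty whenever the past transcript is longer than the initial prompt, and when it is shorter, prompt_past.begin() + nignored points past the end of the vector (undefined behavior). In upstream whisper.cpp the body of the conditional that follows also rebuilds prompt from prompt_past, which would discard the carried tokens. A sketch of one plausible correction, assuming the intent is "initial prompt first, then as many recent past tokens as the budget allows" (a reading of the intent, not the final upstream implementation):

    prompt.clear();
    if (params.carry_initial_prompt) {
        // the initial prompt always occupies the front of the decoding window
        prompt.insert(prompt.end(), initial_prompt_tokens.begin(), initial_prompt_tokens.end());
    }
    if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
        // fill the remaining budget with the most recent past tokens
        const int n_budget = whisper_n_text_ctx(ctx)/2 - 1 - (int) prompt.size();
        const int n_take   = std::min(std::min(params.n_max_text_ctx, n_budget), (int) prompt_past.size());
        if (n_take > 0) {
            prompt.insert(prompt.end(), prompt_past.end() - n_take, prompt_past.end());
        }
    }

Taking n_take tokens from the tail of prompt_past keeps the most recent context; the committed version instead trims from the front, which is the opposite of what conditioning on recent transcription calls for.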