diff --git a/include/whisper.h b/include/whisper.h
index 1e137503..eaf48d67 100644
--- a/include/whisper.h
+++ b/include/whisper.h
@@ -482,13 +482,14 @@ extern "C" {
         int duration_ms;        // audio duration to process in ms
 
         bool translate;
-        bool no_context;        // do not use past transcription (if any) as initial prompt for the decoder
-        bool no_timestamps;     // do not generate timestamps
-        bool single_segment;    // force single segment output (useful for streaming)
-        bool print_special;     // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
-        bool print_progress;    // print progress information
-        bool print_realtime;    // print results from within whisper.cpp (avoid it, use callback instead)
-        bool print_timestamps;  // print timestamps for each text segment when printing realtime
+        bool no_context;           // do not use past transcription (if any) as initial prompt for the decoder
+        bool carry_initial_prompt; // carry the initial prompt to the next call
+        bool no_timestamps;        // do not generate timestamps
+        bool single_segment;       // force single segment output (useful for streaming)
+        bool print_special;        // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
+        bool print_progress;       // print progress information
+        bool print_realtime;       // print results from within whisper.cpp (avoid it, use callback instead)
+        bool print_timestamps;     // print timestamps for each text segment when printing realtime
 
         // [EXPERIMENTAL] token-level timestamps
         bool token_timestamps; // enable token-level timestamps
diff --git a/src/whisper.cpp b/src/whisper.cpp
index c633765e..7b6d2ad7 100644
--- a/src/whisper.cpp
+++ b/src/whisper.cpp
@@ -4793,14 +4793,15 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.offset_ms        =*/ 0,
         /*.duration_ms      =*/ 0,
 
-        /*.translate        =*/ false,
-        /*.no_context       =*/ true,
-        /*.no_timestamps    =*/ false,
-        /*.single_segment   =*/ false,
-        /*.print_special    =*/ false,
-        /*.print_progress   =*/ true,
-        /*.print_realtime   =*/ false,
-        /*.print_timestamps =*/ true,
+        /*.translate            =*/ false,
+        /*.no_context           =*/ true,
+        /*.carry_initial_prompt =*/ false,
+        /*.no_timestamps        =*/ false,
+        /*.single_segment       =*/ false,
+        /*.print_special        =*/ false,
+        /*.print_progress       =*/ true,
+        /*.print_realtime       =*/ false,
+        /*.print_timestamps     =*/ true,
 
         /*.token_timestamps =*/ false,
         /*.thold_pt         =*/ 0.01f,
@@ -5601,6 +5602,7 @@ int whisper_full_with_state(
             prompt_tokens.resize(-n_needed);
             n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
         }
+        remaining_prompt_length = std::max(0, ctx->model.hparams.n_text_ctx/2 - 1 - (int) initial_prompt_tokens.size());
         prompt_tokens.resize(n_needed);
         params.prompt_tokens   = prompt_tokens.data();
         params.prompt_n_tokens = prompt_tokens.size();
@@ -5757,9 +5759,19 @@ int whisper_full_with_state(
         // init prompt and kv cache for the current iteration
         // TODO: do not recompute the prompt if it is the same as previous time
         {
-            prompt.clear();
+            if (params.carry_initial_prompt) {
+                // Prepend initial_prompt_tokens to the prompt, skipping the copy already at the front of prompt_past
+                const int nignored = std::min((int) initial_prompt_tokens.size(), (int) prompt_past.size());
+                std::vector<whisper_token> remaining_prompt(prompt_past.begin() + nignored, prompt_past.end());
+                remaining_prompt.resize(std::min(remaining_prompt.size(), (size_t) remaining_prompt_length));
+                prompt.clear();
+                prompt.insert(prompt.end(), initial_prompt_tokens.begin(), initial_prompt_tokens.end());
+                prompt.insert(prompt.end(), remaining_prompt.begin(), remaining_prompt.end());
+            } else {
+                prompt.clear();
+            }
 
             // if we have already generated some text, use it as a prompt to condition the next generation
             if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
                 int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
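For reviewers, here is a minimal caller-side sketch of the new flag (not part of the patch). It assumes the patch is applied; the model path, the glossary prompt, and the silent pcmf32 buffer are placeholders, while every call used is the existing whisper.h API.

    #include "whisper.h"

    #include <cstdio>
    #include <vector>

    int main() {
        // "model.bin" is a placeholder path to a ggml Whisper model
        struct whisper_context * wctx = whisper_init_from_file_with_params(
                "model.bin", whisper_context_default_params());
        if (wctx == nullptr) {
            return 1;
        }

        whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
        wparams.initial_prompt       = "Glossary: ggml, GGUF, whisper.cpp"; // placeholder prompt
        wparams.carry_initial_prompt = true;  // re-inject the initial prompt for every chunk
        wparams.no_context           = false; // also condition on past transcription

        // placeholder audio: whisper_full() expects 16 kHz mono float PCM
        std::vector<float> pcmf32(16000, 0.0f);

        if (whisper_full(wctx, wparams, pcmf32.data(), (int) pcmf32.size()) != 0) {
            whisper_free(wctx);
            return 1;
        }

        for (int i = 0; i < whisper_full_n_segments(wctx); ++i) {
            printf("%s\n", whisper_full_get_segment_text(wctx, i));
        }

        whisper_free(wctx);
        return 0;
    }

Compared to setting no_context = true, which drops all conditioning between chunks, carry_initial_prompt re-sends the initial prompt on every chunk and still appends as much of prompt_past as fits in the remaining_prompt_length budget (half the text context, minus one, minus the initial prompt length).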