Mirror of https://github.com/ggerganov/whisper.cpp.git (synced 2025-02-02 19:39:44 +01:00)

Commit 4c60e6b0c1: Wrong implementation of carry_initial_prompt
Parent: 53c9a3a984
@@ -482,13 +482,14 @@ extern "C" {
         int duration_ms;        // audio duration to process in ms
 
         bool translate;
-        bool no_context;        // do not use past transcription (if any) as initial prompt for the decoder
-        bool no_timestamps;     // do not generate timestamps
-        bool single_segment;    // force single segment output (useful for streaming)
-        bool print_special;     // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
-        bool print_progress;    // print progress information
-        bool print_realtime;    // print results from within whisper.cpp (avoid it, use callback instead)
-        bool print_timestamps;  // print timestamps for each text segment when printing realtime
+        bool no_context;            // do not use past transcription (if any) as initial prompt for the decoder
+        bool carry_initial_prompt;  // carry the initial prompt to the next call
+        bool no_timestamps;         // do not generate timestamps
+        bool single_segment;        // force single segment output (useful for streaming)
+        bool print_special;         // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
+        bool print_progress;        // print progress information
+        bool print_realtime;        // print results from within whisper.cpp (avoid it, use callback instead)
+        bool print_timestamps;      // print timestamps for each text segment when printing realtime
 
         // [EXPERIMENTAL] token-level timestamps
         bool token_timestamps;  // enable token-level timestamps
@@ -4646,14 +4646,15 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
         /*.offset_ms        =*/ 0,
         /*.duration_ms      =*/ 0,
 
-        /*.translate        =*/ false,
-        /*.no_context       =*/ true,
-        /*.no_timestamps    =*/ false,
-        /*.single_segment   =*/ false,
-        /*.print_special    =*/ false,
-        /*.print_progress   =*/ true,
-        /*.print_realtime   =*/ false,
-        /*.print_timestamps =*/ true,
+        /*.translate            =*/ false,
+        /*.no_context           =*/ true,
+        /*.carry_initial_prompt =*/ false,
+        /*.no_timestamps        =*/ false,
+        /*.single_segment       =*/ false,
+        /*.print_special        =*/ false,
+        /*.print_progress       =*/ true,
+        /*.print_realtime       =*/ false,
+        /*.print_timestamps     =*/ true,
 
         /*.token_timestamps =*/ false,
         /*.thold_pt         =*/ 0.01f,
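For orientation (not part of the commit), this is roughly how a caller would opt into the new flag once the later hunks wire it into whisper_full. A minimal sketch: the model path, prompt text, and the pcm/n_samples buffer are assumptions, while whisper_full_default_params, WHISPER_SAMPLING_GREEDY, initial_prompt, no_context, and whisper_full are the existing API.

    // Usage sketch (illustrative, not from the commit). Assumes `ctx` was created
    // with the usual whisper init API and `pcm`/`n_samples` hold 16 kHz mono audio.
    #include "whisper.h"

    int transcribe_chunk(struct whisper_context * ctx, const float * pcm, int n_samples) {
        struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

        params.initial_prompt       = "ggml, whisper.cpp, GGUF"; // hypothetical domain terms to bias decoding
        params.no_context           = false;                     // keep past transcription as decoder context
        params.carry_initial_prompt = true;                      // new field: re-send the initial prompt on every call

        return whisper_full(ctx, params, pcm, n_samples);        // 0 on success
    }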
@@ -5454,6 +5455,7 @@ int whisper_full_with_state(
                 prompt_tokens.resize(-n_needed);
                 n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
             }
+            remaining_prompt_length = ctx->model.hparams.n_text_ctx / 2 - 1 - initial_prompt_tokens.size();
             prompt_tokens.resize(n_needed);
             params.prompt_tokens   = prompt_tokens.data();
             params.prompt_n_tokens = prompt_tokens.size();
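For scale: the standard Whisper models have a decoder text context of 448 tokens (ctx->model.hparams.n_text_ctx), so this budget works out to 448/2 - 1 = 223 tokens minus the length of the initial prompt, i.e. how many past-transcription tokens can still be carried alongside it. Note that neither remaining_prompt_length nor initial_prompt_tokens is declared in this hunk; both are presumably introduced elsewhere in the change.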
@@ -5610,9 +5612,21 @@ int whisper_full_with_state(
         // init prompt and kv cache for the current iteration
         // TODO: do not recompute the prompt if it is the same as previous time
         {
-            prompt.clear();
+            // LLMs think we should add this if block here
+            if (params.carry_initial_prompt) {
+                // Prepend initial_prompt_tokens to the prompt
+                int nignored = std::max((int)initial_prompt_tokens.size(), prompt_past.size());
+                std::vector<whisper_token> remaining_prompt(prompt_past.begin() + nignored, prompt_past.end());
+                remaining_prompt.resize(std::min(remaining_prompt.size(), remaining_prompt_length));
+                prompt.clear();
+                prompt.insert(prompt.end(), initial_prompt_tokens.begin(), initial_prompt_tokens.end());
+                prompt.insert(prompt.end(), remaining_prompt.begin(), remaining_prompt.end());
+            } else {
+                prompt.clear();
+            }
 
             // if we have already generated some text, use it as a prompt to condition the next generation
+            // But maybe we can put it here?
             if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
                 int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
 