mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-08-17 17:02:00 +02:00
command : enable beam-search, add "no_timestamps", add "context", add p
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -11,6 +11,7 @@ build/
|
|||||||
build-em/
|
build-em/
|
||||||
build-debug/
|
build-debug/
|
||||||
build-release/
|
build-release/
|
||||||
|
build-rwdi/
|
||||||
build-static/
|
build-static/
|
||||||
build-cublas/
|
build-cublas/
|
||||||
build-no-accel/
|
build-no-accel/
|
||||||
|
@ -54,6 +54,7 @@ struct whisper_params {
|
|||||||
std::string fname_out;
|
std::string fname_out;
|
||||||
std::string commands;
|
std::string commands;
|
||||||
std::string prompt;
|
std::string prompt;
|
||||||
|
std::string context;
|
||||||
std::string grammar;
|
std::string grammar;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -84,6 +85,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|||||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||||
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
|
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
|
||||||
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
|
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
|
||||||
|
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
|
||||||
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
||||||
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
||||||
else {
|
else {
|
||||||
@ -119,6 +121,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|||||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||||
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
|
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
|
||||||
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
|
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
|
||||||
|
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
|
||||||
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
||||||
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
@ -129,15 +132,19 @@ std::string transcribe(
|
|||||||
const whisper_params & params,
|
const whisper_params & params,
|
||||||
const std::vector<float> & pcmf32,
|
const std::vector<float> & pcmf32,
|
||||||
const std::string & grammar_rule,
|
const std::string & grammar_rule,
|
||||||
float & prob,
|
float & logprob_min,
|
||||||
|
float & logprob_sum,
|
||||||
|
int & n_tokens,
|
||||||
int64_t & t_ms) {
|
int64_t & t_ms) {
|
||||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
prob = 0.0f;
|
logprob_min = 0.0f;
|
||||||
|
logprob_sum = 0.0f;
|
||||||
|
n_tokens = 0;
|
||||||
t_ms = 0;
|
t_ms = 0;
|
||||||
|
|
||||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
//whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||||
//whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
|
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
|
||||||
|
|
||||||
wparams.print_progress = false;
|
wparams.print_progress = false;
|
||||||
wparams.print_special = params.print_special;
|
wparams.print_special = params.print_special;
|
||||||
@ -145,6 +152,7 @@ std::string transcribe(
|
|||||||
wparams.print_timestamps = !params.no_timestamps;
|
wparams.print_timestamps = !params.no_timestamps;
|
||||||
wparams.translate = params.translate;
|
wparams.translate = params.translate;
|
||||||
wparams.no_context = true;
|
wparams.no_context = true;
|
||||||
|
wparams.no_timestamps = params.no_timestamps;
|
||||||
wparams.single_segment = true;
|
wparams.single_segment = true;
|
||||||
wparams.max_tokens = params.max_tokens;
|
wparams.max_tokens = params.max_tokens;
|
||||||
wparams.language = params.language.c_str();
|
wparams.language = params.language.c_str();
|
||||||
@ -153,12 +161,18 @@ std::string transcribe(
|
|||||||
wparams.audio_ctx = params.audio_ctx;
|
wparams.audio_ctx = params.audio_ctx;
|
||||||
wparams.speed_up = params.speed_up;
|
wparams.speed_up = params.speed_up;
|
||||||
|
|
||||||
//wparams.initial_prompt = params.prompt.data();
|
wparams.temperature = 0.4f;
|
||||||
|
wparams.temperature_inc = 1.0f;
|
||||||
|
wparams.greedy.best_of = 5;
|
||||||
|
|
||||||
|
wparams.beam_search.beam_size = 5;
|
||||||
|
|
||||||
|
wparams.initial_prompt = params.context.data();
|
||||||
|
|
||||||
const auto & grammar_parsed = params.grammar_parsed;
|
const auto & grammar_parsed = params.grammar_parsed;
|
||||||
auto grammar_rules = grammar_parsed.c_rules();
|
auto grammar_rules = grammar_parsed.c_rules();
|
||||||
|
|
||||||
if (!params.grammar_parsed.rules.empty()) {
|
if (!params.grammar_parsed.rules.empty() && !grammar_rule.empty()) {
|
||||||
wparams.grammar_rules = grammar_rules.data();
|
wparams.grammar_rules = grammar_rules.data();
|
||||||
wparams.n_grammar_rules = grammar_rules.size();
|
wparams.n_grammar_rules = grammar_rules.size();
|
||||||
wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule);
|
wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule);
|
||||||
@ -169,7 +183,6 @@ std::string transcribe(
|
|||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
int prob_n = 0;
|
|
||||||
std::string result;
|
std::string result;
|
||||||
|
|
||||||
const int n_segments = whisper_full_n_segments(ctx);
|
const int n_segments = whisper_full_n_segments(ctx);
|
||||||
@ -178,19 +191,17 @@ std::string transcribe(
|
|||||||
|
|
||||||
result += text;
|
result += text;
|
||||||
|
|
||||||
const int n_tokens = whisper_full_n_tokens(ctx, i);
|
const int n = whisper_full_n_tokens(ctx, i);
|
||||||
for (int j = 0; j < n_tokens; ++j) {
|
for (int j = 0; j < n; ++j) {
|
||||||
const auto token = whisper_full_get_token_data(ctx, i, j);
|
const auto token = whisper_full_get_token_data(ctx, i, j);
|
||||||
|
|
||||||
prob += token.p;
|
if(token.plog > 0.0f) exit(0);
|
||||||
++prob_n;
|
logprob_min = std::min(logprob_min, token.plog);
|
||||||
|
logprob_sum += token.plog;
|
||||||
|
++n_tokens;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (prob_n > 0) {
|
|
||||||
prob /= prob_n;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||||
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
||||||
|
|
||||||
@ -449,7 +460,9 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
|||||||
bool is_running = true;
|
bool is_running = true;
|
||||||
bool ask_prompt = true;
|
bool ask_prompt = true;
|
||||||
|
|
||||||
float prob = 0.0f;
|
float logprob_min = 0.0f;
|
||||||
|
float logprob_sum = 0.0f;
|
||||||
|
int n_tokens = 0;
|
||||||
|
|
||||||
std::vector<float> pcmf32_cur;
|
std::vector<float> pcmf32_cur;
|
||||||
|
|
||||||
@ -487,7 +500,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
|||||||
// detect the commands
|
// detect the commands
|
||||||
audio.get(params.command_ms, pcmf32_cur);
|
audio.get(params.command_ms, pcmf32_cur);
|
||||||
|
|
||||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", prob, t_ms));
|
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", logprob_min, logprob_sum, n_tokens, t_ms));
|
||||||
|
|
||||||
const auto words = get_words(txt);
|
const auto words = get_words(txt);
|
||||||
|
|
||||||
@ -528,8 +541,14 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
|||||||
bool have_prompt = false;
|
bool have_prompt = false;
|
||||||
bool ask_prompt = true;
|
bool ask_prompt = true;
|
||||||
|
|
||||||
float prob0 = 0.0f;
|
float logprob_min0 = 0.0f;
|
||||||
float prob = 0.0f;
|
float logprob_min = 0.0f;
|
||||||
|
|
||||||
|
float logprob_sum0 = 0.0f;
|
||||||
|
float logprob_sum = 0.0f;
|
||||||
|
|
||||||
|
int n_tokens0 = 0;
|
||||||
|
int n_tokens = 0;
|
||||||
|
|
||||||
std::vector<float> pcmf32_cur;
|
std::vector<float> pcmf32_cur;
|
||||||
std::vector<float> pcmf32_prompt;
|
std::vector<float> pcmf32_prompt;
|
||||||
@ -570,9 +589,11 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
|||||||
// wait for activation phrase
|
// wait for activation phrase
|
||||||
audio.get(params.prompt_ms, pcmf32_cur);
|
audio.get(params.prompt_ms, pcmf32_cur);
|
||||||
|
|
||||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", prob0, t_ms));
|
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "prompt", logprob_min0, logprob_sum0, n_tokens0, t_ms));
|
||||||
|
|
||||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
const float p = 100.0f * std::exp(logprob_min0);
|
||||||
|
|
||||||
|
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms, p = %.2f%%)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms, p);
|
||||||
|
|
||||||
const float sim = similarity(txt, k_prompt);
|
const float sim = similarity(txt, k_prompt);
|
||||||
|
|
||||||
@ -602,9 +623,10 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
|||||||
// prepend the prompt audio
|
// prepend the prompt audio
|
||||||
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||||
|
|
||||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", prob, t_ms));
|
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", logprob_min, logprob_sum, n_tokens, t_ms));
|
||||||
|
|
||||||
prob = 100.0f*(prob - prob0);
|
//const float p = 100.0f * std::exp((logprob - logprob0) / (n_tokens - n_tokens0));
|
||||||
|
const float p = 100.0f * std::exp(logprob_min);
|
||||||
|
|
||||||
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
||||||
|
|
||||||
@ -628,7 +650,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fprintf(stdout, "%s: DEBUG: txt = '%s'\n", __func__, txt.c_str());
|
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
|
||||||
if (best_len == 0) {
|
if (best_len == 0) {
|
||||||
fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
|
fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
|
||||||
} else {
|
} else {
|
||||||
|
@ -23,14 +23,14 @@
|
|||||||
#
|
#
|
||||||
# example:
|
# example:
|
||||||
#
|
#
|
||||||
# ./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "Ok Whisper, start listening for commands"
|
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/assistant.gbnf --prompt "Ok Whisper, start listening for commands." --context "Whisper is a home assistant. It recognizes voice commands. Time is 11pm." --grammar-penalty 10
|
||||||
#
|
#
|
||||||
|
|
||||||
root ::= init " " (command | question) "."
|
root ::= init " " (command | question) "."
|
||||||
prompt ::= init "."
|
prompt ::= init
|
||||||
|
|
||||||
# leading space is very important!
|
# leading space is very important!
|
||||||
init ::= " Ok Whisper, start listening for commands"
|
init ::= " Ok Whisper, start listening for commands."
|
||||||
|
|
||||||
command ::= "Turn " ("on" | "off") " " device | "Set " device " to " value |
|
command ::= "Turn " ("on" | "off") " " device | "Set " device " to " value |
|
||||||
"Increase " device " by " value | "Decrease " device " by " value |
|
"Increase " device " by " value | "Decrease " device " by " value |
|
||||||
@ -54,4 +54,4 @@ media ::= "music" | "radio" | "podcast" | "audiobook" | "TV show" | "movie"
|
|||||||
|
|
||||||
task ::= [a-zA-Z]+ (" " [a-zA-Z]+)?
|
task ::= [a-zA-Z]+ (" " [a-zA-Z]+)?
|
||||||
|
|
||||||
time ::= [0-9] [0-9]? ":" [0-9] [0-9] ("am" | "pm")?
|
time ::= [0-9] [0-9]? ("am" | "pm")?
|
||||||
|
@ -5,18 +5,24 @@
|
|||||||
# - c3 queen to d4 king b1
|
# - c3 queen to d4 king b1
|
||||||
# - pawn to a1 bishop to b2 knight to c3
|
# - pawn to a1 bishop to b2 knight to c3
|
||||||
#
|
#
|
||||||
|
# The prompt (--prompt) is the initial phrase that the user has to say.
|
||||||
|
# This is used to prime Whisper with how the user is expected to speak.
|
||||||
|
#
|
||||||
|
# Provide long context (--context) with sample moves to help Whisper decode the correct sequence.
|
||||||
|
# Longer context is better, but it slightly increases the processing time.
|
||||||
|
#
|
||||||
# example:
|
# example:
|
||||||
#
|
#
|
||||||
# ./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "pawn knight king a1 f5 h6"
|
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "rook to b4, f3," --context "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8," --grammar-penalty 100
|
||||||
#
|
#
|
||||||
|
|
||||||
root ::= init (move? move? move? ".")
|
root ::= init move move? move? "."
|
||||||
prompt ::= init "."
|
prompt ::= init "."
|
||||||
|
|
||||||
# leading space is very important!
|
# leading space is very important!
|
||||||
init ::= " pawn knight king a1 f5 h6"
|
init ::= " rook to b4, f3"
|
||||||
|
|
||||||
move ::= " " ((piece | pawn | king) " " "to "?)? [a-h] [1-8]
|
move ::= ", " ((piece | pawn | king) " " "to "?)? [a-h] [1-8]
|
||||||
|
|
||||||
piece ::= "bishop" | "rook" | "knight" | "queen"
|
piece ::= "bishop" | "rook" | "knight" | "queen"
|
||||||
king ::= "king"
|
king ::= "king"
|
||||||
|
@ -1,20 +1,16 @@
|
|||||||
# - red
|
# - red
|
||||||
# - green
|
# - green
|
||||||
# - blue
|
# - blue
|
||||||
# - red green
|
|
||||||
# - red blue
|
|
||||||
# - green red
|
|
||||||
# - green blue green
|
|
||||||
#
|
#
|
||||||
# example:
|
# example:
|
||||||
#
|
#
|
||||||
# ./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red green blue"
|
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red, green, blue," --context "green, red, blue,"
|
||||||
#
|
#
|
||||||
|
|
||||||
root ::= init color (color)? (color)? "."
|
root ::= init color "."
|
||||||
prompt ::= init "."
|
prompt ::= init "."
|
||||||
|
|
||||||
# leading space is very important!
|
# leading space is very important!
|
||||||
init ::= " red green blue"
|
init ::= " red, green, blue"
|
||||||
|
|
||||||
color ::= " " ("red" | "green" | "blue")
|
color ::= ", " ("red" | "green" | "blue")
|
||||||
|
45
whisper.cpp
45
whisper.cpp
@ -3872,13 +3872,13 @@ static void whisper_suppress_invalid_grammar(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool allow_eot = false;
|
//bool allow_eot = false;
|
||||||
for (const auto & stack : grammar.stacks) {
|
//for (const auto & stack : grammar.stacks) {
|
||||||
if (stack.empty()) {
|
// if (stack.empty()) {
|
||||||
allow_eot = true;
|
// allow_eot = true;
|
||||||
break;
|
// break;
|
||||||
}
|
// }
|
||||||
}
|
//}
|
||||||
|
|
||||||
const whisper_token eot = whisper_token_eot(&ctx);
|
const whisper_token eot = whisper_token_eot(&ctx);
|
||||||
|
|
||||||
@ -3900,9 +3900,9 @@ static void whisper_suppress_invalid_grammar(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// when the grammar allows a continuation, we penalize the end-of-text token
|
// when the grammar allows a continuation, we penalize the end-of-text token
|
||||||
if (!allow_eot) {
|
//if (!allow_eot) {
|
||||||
logits[eot] -= params.grammar_penalty;
|
// logits[eot] -= params.grammar_penalty;
|
||||||
}
|
//}
|
||||||
//fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
|
//fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3955,6 +3955,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|||||||
|
|
||||||
/*.translate =*/ false,
|
/*.translate =*/ false,
|
||||||
/*.no_context =*/ true,
|
/*.no_context =*/ true,
|
||||||
|
/*.no_timestamps =*/ false,
|
||||||
/*.single_segment =*/ false,
|
/*.single_segment =*/ false,
|
||||||
/*.print_special =*/ false,
|
/*.print_special =*/ false,
|
||||||
/*.print_progress =*/ true,
|
/*.print_progress =*/ true,
|
||||||
@ -4170,6 +4171,11 @@ static void whisper_process_logits(
|
|||||||
// suppress <|notimestamps|> token
|
// suppress <|notimestamps|> token
|
||||||
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
|
||||||
logits[vocab.token_not] = -INFINITY;
|
logits[vocab.token_not] = -INFINITY;
|
||||||
|
if (params.no_timestamps) {
|
||||||
|
for (int i = vocab.token_beg; i < n_logits; ++i) {
|
||||||
|
logits[i] = -INFINITY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// suppress sot and nosp tokens
|
// suppress sot and nosp tokens
|
||||||
logits[vocab.token_sot] = -INFINITY;
|
logits[vocab.token_sot] = -INFINITY;
|
||||||
@ -4515,8 +4521,11 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|||||||
ptsum = sum_ts;
|
ptsum = sum_ts;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||||
|
|
||||||
for (int i = 0; i < k; ++i) {
|
for (int i = 0; i < k; ++i) {
|
||||||
const auto id = logits_id[i].second;
|
const auto id = dist(state.rng);
|
||||||
|
//printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
|
||||||
|
|
||||||
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
|
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
|
||||||
|
|
||||||
@ -4726,7 +4735,7 @@ int whisper_full_with_state(
|
|||||||
state->exp_n_audio_ctx = params.audio_ctx;
|
state->exp_n_audio_ctx = params.audio_ctx;
|
||||||
|
|
||||||
// these tokens determine the task that will be performed
|
// these tokens determine the task that will be performed
|
||||||
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
|
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx), };
|
||||||
if (whisper_is_multilingual(ctx)) {
|
if (whisper_is_multilingual(ctx)) {
|
||||||
const int lang_id = whisper_lang_id(params.language);
|
const int lang_id = whisper_lang_id(params.language);
|
||||||
state->lang_id = lang_id;
|
state->lang_id = lang_id;
|
||||||
@ -4737,6 +4746,9 @@ int whisper_full_with_state(
|
|||||||
prompt_init.push_back(whisper_token_transcribe(ctx));
|
prompt_init.push_back(whisper_token_transcribe(ctx));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (params.no_timestamps) {
|
||||||
|
prompt_init.push_back(whisper_token_not(ctx));
|
||||||
|
}
|
||||||
|
|
||||||
int seek = seek_start;
|
int seek = seek_start;
|
||||||
|
|
||||||
@ -4821,7 +4833,7 @@ int whisper_full_with_state(
|
|||||||
|
|
||||||
n_decoders_cur = std::max(1, n_decoders_cur);
|
n_decoders_cur = std::max(1, n_decoders_cur);
|
||||||
|
|
||||||
WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
|
WHISPER_PRINT_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur);
|
||||||
|
|
||||||
// TAGS: WHISPER_DECODER_INIT
|
// TAGS: WHISPER_DECODER_INIT
|
||||||
for (int j = 0; j < n_decoders_cur; ++j) {
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
||||||
@ -4978,8 +4990,15 @@ int whisper_full_with_state(
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (cur_c >= beam_candidates.size()) {
|
||||||
|
cur_c = 0;
|
||||||
|
}
|
||||||
|
|
||||||
auto & cur = beam_candidates[cur_c++];
|
auto & cur = beam_candidates[cur_c++];
|
||||||
|
|
||||||
|
// TODO: test if this is better:
|
||||||
|
//while (beam_candidates.size() > cur_c && i > 0) {
|
||||||
|
|
||||||
while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
|
while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
|
||||||
++cur_c;
|
++cur_c;
|
||||||
}
|
}
|
||||||
|
@ -389,6 +389,7 @@ extern "C" {
|
|||||||
|
|
||||||
bool translate;
|
bool translate;
|
||||||
bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
|
bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
|
||||||
|
bool no_timestamps; // do not generate timestamps
|
||||||
bool single_segment; // force single segment output (useful for streaming)
|
bool single_segment; // force single segment output (useful for streaming)
|
||||||
bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
|
bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
|
||||||
bool print_progress; // print progress information
|
bool print_progress; // print progress information
|
||||||
|
Reference in New Issue
Block a user