mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-08-20 10:38:49 +02:00
whisper : add grammar-based sampling
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
#include "common.h"
|
||||
#include "common-sdl.h"
|
||||
#include "whisper.h"
|
||||
#include "grammar-parser.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <cassert>
|
||||
@@ -32,6 +33,7 @@ struct whisper_params {
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
float grammar_penalty = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
@@ -44,6 +46,7 @@ struct whisper_params {
|
||||
std::string fname_out;
|
||||
std::string commands;
|
||||
std::string prompt;
|
||||
std::string grammar;
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
@@ -73,6 +76,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
|
||||
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
|
||||
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
||||
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@@ -106,6 +111,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
|
||||
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
|
||||
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
||||
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
@@ -115,6 +122,9 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
|
||||
prob = 0.0f;
|
||||
t_ms = 0;
|
||||
|
||||
grammar_parser::parse_state parsed_grammar;
|
||||
std::vector<const whisper_grammar_element *> grammar_rules;
|
||||
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.print_progress = false;
|
||||
@@ -131,6 +141,15 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
if (!params.grammar.empty()) {
|
||||
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
|
||||
grammar_rules = parsed_grammar.c_rules();
|
||||
wparams.grammar_rules = grammar_rules.data();
|
||||
wparams.n_grammar_rules = grammar_rules.size();
|
||||
wparams.i_start_rule = parsed_grammar.symbol_ids.at("root");
|
||||
wparams.grammar_penalty = params.grammar_penalty;
|
||||
}
|
||||
|
||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
return "";
|
||||
}
|
||||
@@ -648,12 +667,26 @@ int main(int argc, char ** argv) {
|
||||
|
||||
int ret_val = 0;
|
||||
|
||||
if (!params.commands.empty()) {
|
||||
ret_val = process_command_list(ctx, audio, params);
|
||||
} else if (!params.prompt.empty()) {
|
||||
ret_val = always_prompt_transcription(ctx, audio, params);
|
||||
} else {
|
||||
ret_val = process_general_transcription(ctx, audio, params);
|
||||
if (!params.grammar.empty()) {
|
||||
auto parsed_grammar = grammar_parser::parse(params.grammar.c_str());
|
||||
// will be empty (default) if there are parse errors
|
||||
if (parsed_grammar.rules.empty()) {
|
||||
ret_val = 1;
|
||||
} else {
|
||||
fprintf(stderr, "%s: grammar:\n", __func__);
|
||||
grammar_parser::print_grammar(stderr, parsed_grammar);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
}
|
||||
|
||||
if (ret_val == 0) {
|
||||
if (!params.commands.empty()) {
|
||||
ret_val = process_command_list(ctx, audio, params);
|
||||
} else if (!params.prompt.empty()) {
|
||||
ret_val = always_prompt_transcription(ctx, audio, params);
|
||||
} else {
|
||||
ret_val = process_general_transcription(ctx, audio, params);
|
||||
}
|
||||
}
|
||||
|
||||
audio.pause();
|
||||
|
Reference in New Issue
Block a user