whisper : add grammar-based sampling

This commit is contained in:
Evan Jones
2023-08-29 22:56:29 -04:00
parent fb0a24fba2
commit 31476ccc0e
6 changed files with 983 additions and 8 deletions

View File

@@ -9,6 +9,7 @@
#include "common.h"
#include "common-sdl.h"
#include "whisper.h"
#include "grammar-parser.h"
#include <sstream>
#include <cassert>
@@ -32,6 +33,7 @@ struct whisper_params {
float vad_thold = 0.6f;
float freq_thold = 100.0f;
float grammar_penalty = 100.0f;
bool speed_up = false;
bool translate = false;
@@ -44,6 +46,7 @@ struct whisper_params {
std::string fname_out;
std::string commands;
std::string prompt;
std::string grammar;
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -73,6 +76,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
@@ -106,6 +111,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
fprintf(stderr, "\n");
}
@@ -115,6 +122,9 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
prob = 0.0f;
t_ms = 0;
grammar_parser::parse_state parsed_grammar;
std::vector<const whisper_grammar_element *> grammar_rules;
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_progress = false;
@@ -131,6 +141,15 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
if (!params.grammar.empty()) {
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
grammar_rules = parsed_grammar.c_rules();
wparams.grammar_rules = grammar_rules.data();
wparams.n_grammar_rules = grammar_rules.size();
wparams.i_start_rule = parsed_grammar.symbol_ids.at("root");
wparams.grammar_penalty = params.grammar_penalty;
}
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return "";
}
@@ -648,12 +667,26 @@ int main(int argc, char ** argv) {
int ret_val = 0;
if (!params.commands.empty()) {
ret_val = process_command_list(ctx, audio, params);
} else if (!params.prompt.empty()) {
ret_val = always_prompt_transcription(ctx, audio, params);
} else {
ret_val = process_general_transcription(ctx, audio, params);
if (!params.grammar.empty()) {
auto parsed_grammar = grammar_parser::parse(params.grammar.c_str());
// will be empty (default) if there are parse errors
if (parsed_grammar.rules.empty()) {
ret_val = 1;
} else {
fprintf(stderr, "%s: grammar:\n", __func__);
grammar_parser::print_grammar(stderr, parsed_grammar);
fprintf(stderr, "\n");
}
}
if (ret_val == 0) {
if (!params.commands.empty()) {
ret_val = process_command_list(ctx, audio, params);
} else if (!params.prompt.empty()) {
ret_val = always_prompt_transcription(ctx, audio, params);
} else {
ret_val = process_general_transcription(ctx, audio, params);
}
}
audio.pause();