wchess grammar tweaks

This commit is contained in:
Fraxy V 2023-11-29 09:25:45 +02:00
parent 8b0b0acff3
commit d313034b9c

View File

@ -6,20 +6,20 @@
static constexpr auto RULES = static constexpr auto RULES =
"\n" "\n"
"root ::= init move move? move? \".\"\n" "root ::= move \".\"\n"
"prompt ::= init \".\"\n" "prompt ::= init \".\"\n"
"\n" "\n"
"# leading space is very important!\n" "# leading space is very important!\n"
"init ::= \" rook to b4, f3\"\n" "init ::= \" rook to b4, f3\"\n"
"\n" "\n"
"move ::= \", \" ((piece | pawn | king) \" \" \"to \"?)? [a-h] [1-8]\n" "move ::= \" \" ((piece | pawn | king) \" \" \"to \"?)? [a-h] [1-8]\n"
"\n" "\n"
"piece ::= \"bishop\" | \"rook\" | \"knight\" | \"queen\"\n" "piece ::= \"bishop\" | \"rook\" | \"knight\" | \"queen\"\n"
"king ::= \"king\"\n" "king ::= \"king\"\n"
"pawn ::= \"pawn\"\n" "pawn ::= \"pawn\"\n"
"\n"; "\n";
static constexpr auto PROMPT = "rook to b4, f3,"; static constexpr auto PROMPT = "rook to b4, f3";
static constexpr auto CONTEXT = "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8,"; static constexpr auto CONTEXT = "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8,";
WChess::WChess(whisper_context * ctx, WChess::WChess(whisper_context * ctx,
@ -63,8 +63,8 @@ std::string WChess::stringify_board() const {
void WChess::run() { void WChess::run() {
set_status("loading data ..."); set_status("loading data ...");
bool have_prompt = false; bool have_prompt = true;
bool ask_prompt = true; bool ask_prompt = false;
float logprob_min0 = 0.0f; float logprob_min0 = 0.0f;
float logprob_min = 0.0f; float logprob_min = 0.0f;
@ -79,7 +79,6 @@ void WChess::run() {
std::vector<float> pcmf32_prompt; std::vector<float> pcmf32_prompt;
const std::string k_prompt = PROMPT; const std::string k_prompt = PROMPT;
m_wparams.initial_prompt = CONTEXT;
auto grammar_parsed = grammar_parser::parse(RULES); auto grammar_parsed = grammar_parser::parse(RULES);
auto grammar_rules = grammar_parsed.c_rules(); auto grammar_rules = grammar_parsed.c_rules();
@ -149,13 +148,16 @@ void WChess::run() {
} }
} else { } else {
// prepend 3 second of silence // prepend 3 second of silence
pcmf32_cur.insert(pcmf32_cur.begin(), 3*WHISPER_SAMPLE_RATE, 0.0f); // pcmf32_cur.insert(pcmf32_cur.begin(), 3*WHISPER_SAMPLE_RATE, 0.0f);
// prepend the prompt audio // prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); // pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
if (WHISPER_SAMPLE_RATE > pcmf32_cur.size()) pcmf32_cur.insert(pcmf32_cur.begin(), WHISPER_SAMPLE_RATE - pcmf32_cur.size(), 0.0f);
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("root"); m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("root");
const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms)); auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
txt = PROMPT + txt;
const float p = 100.0f * std::exp(logprob_min); const float p = 100.0f * std::exp(logprob_min);