diff --git a/examples/wchess/libwchess/Chessboard.cpp b/examples/wchess/libwchess/Chessboard.cpp index ed6532b1..49f8fc87 100644 --- a/examples/wchess/libwchess/Chessboard.cpp +++ b/examples/wchess/libwchess/Chessboard.cpp @@ -2,6 +2,7 @@ #include #include #include +#include static constexpr std::array positions = { "a1", "b1", "c1", "d1", "e1", "f1", "g1", "h1", @@ -14,46 +15,57 @@ static constexpr std::array positions = { "a8", "b8", "c8", "d8", "e8", "f8", "g8", "h8", }; +constexpr auto INVALID_POS = positions.size(); + +constexpr int operator ""_P(const char * c, size_t size) { + if (size < 2) return INVALID_POS; + int file = c[0] - 'a'; + int rank = c[1] - '1'; + int pos = rank * 8 + file; + if (pos < 0 || pos >= int(INVALID_POS)) return INVALID_POS; + return pos; +} + static constexpr std::array pieceNames = { "pawn", "knight", "bishop", "rook", "queen", "king", }; Chessboard::Chessboard() : blackPieces {{ - {Piece::Pawn, Piece::Black, 48 }, - {Piece::Pawn, Piece::Black, 49 }, - {Piece::Pawn, Piece::Black, 50 }, - {Piece::Pawn, Piece::Black, 51 }, - {Piece::Pawn, Piece::Black, 52 }, - {Piece::Pawn, Piece::Black, 53 }, - {Piece::Pawn, Piece::Black, 54 }, - {Piece::Pawn, Piece::Black, 55 }, - {Piece::Rook, Piece::Black, 56 }, - {Piece::Knight, Piece::Black, 57 }, - {Piece::Bishop, Piece::Black, 58 }, - {Piece::Queen, Piece::Black, 59 }, - {Piece::King, Piece::Black, 60 }, - {Piece::Bishop, Piece::Black, 61 }, - {Piece::Knight, Piece::Black, 62 }, - {Piece::Rook, Piece::Black, 63 }, + {Piece::Pawn, Piece::Black, "a7"_P }, + {Piece::Pawn, Piece::Black, "b7"_P }, + {Piece::Pawn, Piece::Black, "c7"_P }, + {Piece::Pawn, Piece::Black, "d7"_P }, + {Piece::Pawn, Piece::Black, "e7"_P }, + {Piece::Pawn, Piece::Black, "f7"_P }, + {Piece::Pawn, Piece::Black, "g7"_P }, + {Piece::Pawn, Piece::Black, "h7"_P }, + {Piece::Rook, Piece::Black, "a8"_P }, + {Piece::Knight, Piece::Black, "b8"_P }, + {Piece::Bishop, Piece::Black, "c8"_P }, + {Piece::Queen, Piece::Black, "d8"_P }, + {Piece::King, Piece::Black, "e8"_P }, + {Piece::Bishop, Piece::Black, "f8"_P }, + {Piece::Knight, Piece::Black, "g8"_P }, + {Piece::Rook, Piece::Black, "h8"_P }, }} , whitePieces {{ - {Piece::Pawn, Piece::White, 8 }, - {Piece::Pawn, Piece::White, 9 }, - {Piece::Pawn, Piece::White, 10 }, - {Piece::Pawn, Piece::White, 11 }, - {Piece::Pawn, Piece::White, 12 }, - {Piece::Pawn, Piece::White, 13 }, - {Piece::Pawn, Piece::White, 14 }, - {Piece::Pawn, Piece::White, 15 }, - {Piece::Rook, Piece::White, 0 }, - {Piece::Knight, Piece::White, 1 }, - {Piece::Bishop, Piece::White, 2 }, - {Piece::Queen, Piece::White, 3 }, - {Piece::King, Piece::White, 4 }, - {Piece::Bishop, Piece::White, 5 }, - {Piece::Knight, Piece::White, 6 }, - {Piece::Rook, Piece::White, 7 }, + {Piece::Pawn, Piece::White, "a2"_P }, + {Piece::Pawn, Piece::White, "b2"_P }, + {Piece::Pawn, Piece::White, "c2"_P }, + {Piece::Pawn, Piece::White, "d2"_P }, + {Piece::Pawn, Piece::White, "e2"_P }, + {Piece::Pawn, Piece::White, "f2"_P }, + {Piece::Pawn, Piece::White, "g2"_P }, + {Piece::Pawn, Piece::White, "h2"_P }, + {Piece::Rook, Piece::White, "a1"_P }, + {Piece::Knight, Piece::White, "b1"_P }, + {Piece::Bishop, Piece::White, "c1"_P }, + {Piece::Queen, Piece::White, "d1"_P }, + {Piece::King, Piece::White, "e1"_P }, + {Piece::Bishop, Piece::White, "f1"_P }, + {Piece::Knight, Piece::White, "g1"_P }, + {Piece::Rook, Piece::White, "h1"_P }, }} , board {{ &whitePieces[ 8], &whitePieces[ 9], &whitePieces[10], &whitePieces[11], &whitePieces[12], &whitePieces[13], &whitePieces[14], &whitePieces[15], @@ -65,8 +77,75 @@ Chessboard::Chessboard() &blackPieces[ 0], &blackPieces[ 1], &blackPieces[ 2], &blackPieces[ 3], &blackPieces[ 4], &blackPieces[ 5], &blackPieces[ 6], &blackPieces[ 7], &blackPieces[ 8], &blackPieces[ 9], &blackPieces[10], &blackPieces[11], &blackPieces[12], &blackPieces[13], &blackPieces[14], &blackPieces[15], }} + , whiteMoves { + {"b1"_P, "a3"_P}, {"b1"_P, "c3"_P}, + {"g1"_P, "f3"_P}, {"g1"_P, "h3"_P}, + {"a2"_P, "a3"_P}, {"a2"_P, "a4"_P}, + {"b2"_P, "b3"_P}, {"b2"_P, "b4"_P}, + {"c2"_P, "c3"_P}, {"c2"_P, "c4"_P}, + {"d2"_P, "d3"_P}, {"d2"_P, "d4"_P}, + {"e2"_P, "e3"_P}, {"e2"_P, "e4"_P}, + {"f2"_P, "f3"_P}, {"f2"_P, "f4"_P}, + {"g2"_P, "g3"_P}, {"g2"_P, "g4"_P}, + {"h2"_P, "h3"_P}, {"h2"_P, "h4"_P}, + } + , blackMoves { + {"a7"_P, "a5"_P}, {"a7"_P, "a6"_P}, + {"b7"_P, "b5"_P}, {"b7"_P, "b6"_P}, + {"c7"_P, "c5"_P}, {"c7"_P, "c6"_P}, + {"d7"_P, "d5"_P}, {"d7"_P, "d6"_P}, + {"e7"_P, "e5"_P}, {"e7"_P, "e6"_P}, + {"f7"_P, "f5"_P}, {"f7"_P, "f6"_P}, + {"g7"_P, "g5"_P}, {"g7"_P, "g6"_P}, + {"h7"_P, "h5"_P}, {"h7"_P, "h6"_P}, + {"b8"_P, "a6"_P}, {"b8"_P, "c6"_P}, + {"g8"_P, "f6"_P}, {"g8"_P, "h6"_P}, + } + { static_assert(pieceNames.size() == Chessboard::Piece::Taken, "Mismatch between piece names and types"); + std::sort(whiteMoves.begin(), whiteMoves.end()); + std::sort(blackMoves.begin(), blackMoves.end()); +} + +std::string Chessboard::getRules() const { + // leading space is very important! + std::string result = + "\n" + "# leading space is very important!\n" + "\n" + "move ::= \" \" ((piece | frompos) \" \" \"to \"?)? topos\n" + "\n"; + + std::set pieces; + std::set from_pos; + std::set to_pos; + auto& allowed_moves = m_moveCounter % 2 ? blackMoves : whiteMoves; + for (auto& m : allowed_moves) { + if (board[m.first]->type != Piece::Taken) pieces.insert(pieceNames[board[m.first]->type]); + from_pos.insert(positions[m.first]); + to_pos.insert(positions[m.second]); + } + if (!pieces.empty()) { + result += "piece ::= ("; + for (auto& p : pieces) result += " \"" + p + "\" |"; + result.pop_back(); + result += ")\n\n"; + } + if (!from_pos.empty()) { + result += "frompos ::= ("; + for (auto& p : from_pos) result += " \"" + p + "\" |"; + result.pop_back(); + result += ")\n"; + } + if (!to_pos.empty()) { + result += "topos ::= ("; + for (auto& p : to_pos) result += " \"" + p + "\" |"; + result.pop_back(); + result += ")\n"; + } + + return result; } std::string Chessboard::stringifyBoard() { @@ -93,7 +172,7 @@ std::string Chessboard::stringifyBoard() { result.push_back('0' + i + 1); result.push_back('\n'); } -return result; + return result; } std::vector split(std::string_view str, char del) { @@ -119,57 +198,59 @@ Chessboard::Piece::Types Chessboard::tokenToType(std::string_view token) { } size_t Chessboard::tokenToPos(std::string_view token) { - if (token.size() < 2) return positions.size(); - int file = token[0] - 'a'; - int rank = token[1] - '1'; - int pos = rank * 8 + file; - if (pos < 0 || pos >= int(positions.size())) return positions.size(); - return pos; + return operator ""_P(token.data(), token.size()); } -std::string Chessboard::process(const std::string& transcription) { - auto commands = split(transcription, ','); +void Chessboard::updateMoves(const Move& m) { + // todo +} - // fixme: lookup depends on grammar - int count = m_moveCounter; - std::vector moves; - std::string result; - result.reserve(commands.size() * 6); - for (auto& command : commands) { - - fprintf(stdout, "%s: Command '%s%.*s%s'\n", __func__, "\033[1m", int(command.size()), command.data(), "\033[0m"); - if (command.empty()) continue; - auto tokens = split(command, ' '); - Piece::Types type = Piece::Types::Taken; - size_t pos = positions.size(); - if (tokens.size() == 1) { - type = Piece::Types::Pawn; - pos = tokenToPos(tokens[0]); - } - else if (tokens.size() == 3) { - type = tokenToType(tokens[0]); - assert(tokens[1] == "to"); - pos = tokenToPos(tokens[2]); - } - if (type == Piece::Types::Taken || pos == positions.size()) continue; - - auto& pieces = count % 2 ? blackPieces : whitePieces; +std::string Chessboard::process(const std::string& command) { + fprintf(stdout, "%s: Command '%s%.*s%s'\n", __func__, "\033[1m", int(command.size()), command.data(), "\033[0m"); + if (command.empty()) return ""; + auto tokens = split(command, ' '); + for (auto& t : tokens) fprintf(stdout, "%s: Token %.*s\n", __func__, int(t.size()), t.data()); + auto pos_from = INVALID_POS; + auto type = Piece::Types::Taken; + auto pos_to = INVALID_POS; + if (tokens.size() == 1) { + type = Piece::Types::Pawn; + pos_to = tokenToPos(tokens[0]); + } + else { + pos_from = tokenToPos(tokens.front()); + if (pos_from == INVALID_POS) type = tokenToType(tokens.front()); + pos_to = tokenToPos(tokens.back()); + } + if (pos_to == INVALID_POS) return ""; + auto color = Piece::Colors(m_moveCounter % 2); + if (pos_from == INVALID_POS) { + if (type == Piece::Types::Taken) return ""; + auto& pieces = color ? blackPieces : whitePieces; auto pieceIndex = 0u; for (; pieceIndex < pieces.size(); ++pieceIndex) { - if (pieces[pieceIndex].type == type && validateMove(pieces[pieceIndex], pos)) break; - } - Move m = {pieces[pieceIndex].pos, pos}; - if (pieceIndex < pieces.size() && move({m})) { - result.append(positions[m.first]); - result.push_back('-'); - result.append(positions[m.second]); - result.push_back(' '); - ++count; + if (pieces[pieceIndex].type == type && validateMove(pieces[pieceIndex], pos_to)) break; } + if (pieceIndex == pieces.size()) return ""; + pos_from = pieces[pieceIndex].pos; } - if (!result.empty()) result.pop_back(); - m_moveCounter = count; - fprintf(stdout, "%s: Moves '%s%s%s'\n", __func__, "\033[1m", result.data(), "\033[0m"); + if (board[pos_from] == nullptr) return ""; + if (board[pos_from]->color != color) return ""; + + Move m = {pos_from, pos_to}; + auto& allowed_moves = color ? blackMoves : whiteMoves; + auto it = std::lower_bound(allowed_moves.begin(), allowed_moves.end(), m); + if (it == allowed_moves.end() || *it != m) return ""; + allowed_moves.erase(it); + + move(m); + updateMoves(m); + + std::string result = positions[m.first]; + result += "-"; + result += positions[m.second]; + ++m_moveCounter; + fprintf(stdout, "%s: Move '%s%s%s'\n", __func__, "\033[1m", result.data(), "\033[0m"); return result; } diff --git a/examples/wchess/libwchess/Chessboard.h b/examples/wchess/libwchess/Chessboard.h index 25fe8ef7..9c6f4be4 100644 --- a/examples/wchess/libwchess/Chessboard.h +++ b/examples/wchess/libwchess/Chessboard.h @@ -8,9 +8,11 @@ public: Chessboard(); std::string process(const std::string& t); std::string stringifyBoard(); -private: + std::string getRules() const; using Move = std::pair; +private: bool move(const Move& move); + void updateMoves(const Move& move); struct Piece { enum Types { @@ -24,8 +26,8 @@ private: }; enum Colors { + White, Black, - White }; Types type; @@ -44,6 +46,9 @@ private: using Board = std::array; Board board; + std::vector whiteMoves; + std::vector blackMoves; + bool validateMove(const Piece& piece, int pos); // just basic validation // fixme: missing en passant, castling, promotion, etc. diff --git a/examples/wchess/libwchess/WChess.cpp b/examples/wchess/libwchess/WChess.cpp index 9293180c..f85c21db 100644 --- a/examples/wchess/libwchess/WChess.cpp +++ b/examples/wchess/libwchess/WChess.cpp @@ -4,24 +4,6 @@ #include "common.h" #include -static constexpr auto RULES = -"\n" -"root ::= move \".\"\n" -"prompt ::= init \".\"\n" -"\n" -"# leading space is very important!\n" -"init ::= \" rook to b4, f3\"\n" -"\n" -"move ::= \" \" ((piece | pawn | king) \" \" \"to \"?)? [a-h] [1-8]\n" -"\n" -"piece ::= \"bishop\" | \"rook\" | \"knight\" | \"queen\"\n" -"king ::= \"king\"\n" -"pawn ::= \"pawn\"\n" -"\n"; - -static constexpr auto PROMPT = "rook to b4, f3"; -static constexpr auto CONTEXT = "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8,"; - WChess::WChess(whisper_context * ctx, const whisper_full_params & wparams, callbacks cb, @@ -78,86 +60,38 @@ void WChess::run() { std::vector pcmf32_cur; std::vector pcmf32_prompt; - const std::string k_prompt = PROMPT; - - auto grammar_parsed = grammar_parser::parse(RULES); - auto grammar_rules = grammar_parsed.c_rules(); - - if (grammar_parsed.rules.empty()) { - fprintf(stdout, "%s: Failed to parse grammar ...\n", __func__); - } - else { - m_wparams.grammar_rules = grammar_rules.data(); - m_wparams.n_grammar_rules = grammar_rules.size(); - } + std::string prompt = ""; + float prompt_prop = 0.0f; while (check_running()) { // delay std::this_thread::sleep_for(std::chrono::milliseconds(100)); - if (ask_prompt) { - fprintf(stdout, "\n"); - fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m"); - fprintf(stdout, "\n"); - - { - char txt[1024]; - snprintf(txt, sizeof(txt), "Say the following phrase: '%s'", k_prompt.c_str()); - set_status(txt); - } - - ask_prompt = false; - } - int64_t t_ms = 0; { get_audio(m_settings.vad_ms, pcmf32_cur); if (!pcmf32_cur.empty()) { - fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); - set_status("Speech detected! Processing ..."); + fprintf(stdout, "%s: Processing ...\n", __func__); + set_status("Processing ..."); - if (!have_prompt) { - - m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("prompt"); - const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms)); - - fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms); - - const float sim = similarity(txt, k_prompt); - - if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) { - fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__); - ask_prompt = true; - } else { - fprintf(stdout, "\n"); - fprintf(stdout, "%s: The prompt has been recognized!\n", __func__); - fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__); - fprintf(stdout, "\n"); - - { - char txt[1024]; - snprintf(txt, sizeof(txt), "Success! Waiting for voice commands ..."); - set_status(txt); - } - - // save the audio for the prompt - pcmf32_prompt = pcmf32_cur; - have_prompt = true; - } - } else { - // prepend 3 second of silence - // pcmf32_cur.insert(pcmf32_cur.begin(), 3*WHISPER_SAMPLE_RATE, 0.0f); - - // prepend the prompt audio - // pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); + { + if (!pcmf32_prompt.empty()) pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); if (WHISPER_SAMPLE_RATE > pcmf32_cur.size()) pcmf32_cur.insert(pcmf32_cur.begin(), WHISPER_SAMPLE_RATE - pcmf32_cur.size(), 0.0f); - m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("root"); + std::string rules = m_board->getRules(); + fprintf(stdout, "%s: grammar rules:\n'%s'\n", __func__, rules.c_str()); + + auto grammar_parsed = grammar_parser::parse(rules.c_str()); + auto grammar_rules = grammar_parsed.c_rules(); + + m_wparams.grammar_rules = grammar_rules.data(); + m_wparams.n_grammar_rules = grammar_rules.size(); + + m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("move"); auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms)); - txt = PROMPT + txt; const float p = 100.0f * std::exp(logprob_min); @@ -166,20 +100,18 @@ void WChess::run() { // find the prompt in the text float best_sim = 0.0f; size_t best_len = 0; - for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) { - if (n >= int(txt.size())) { - break; - } + if (!prompt.empty()) { + auto pos = txt.find_first_of('.'); - const auto prompt = txt.substr(0, n); + const auto header = txt.substr(0, pos); - const float sim = similarity(prompt, k_prompt); + const float sim = similarity(prompt, header); //fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim); if (sim > best_sim) { best_sim = sim; - best_len = n; + best_len = pos + 1; } } @@ -195,7 +127,10 @@ void WChess::run() { set_status(txt); } if (!command.empty()) { - set_moves(m_board->process(command)); + auto move = m_board->process(command); + if (!move.empty()) { + set_moves(std::move(move)); + } } }