Compare commits

..

10 Commits

Author SHA1 Message Date
15c4fdce45 chess : tuning performance 2023-11-30 10:50:47 +02:00
70741ba794 wchess: c++17 -> c++11 2023-11-30 08:37:54 +02:00
bb723282cc wchess: off/on prompt 2023-11-30 01:17:29 +02:00
dc5513a709 wchess: prompt 2023-11-29 19:30:57 +02:00
ffc244845b wchess : dynamic grammar 2023-11-29 18:53:28 +02:00
8962a6bd67 wchess: preparing dyn grammar 2023-11-29 15:29:16 +02:00
d313034b9c wchess grammar tweaks 2023-11-29 09:25:45 +02:00
8b0b0acff3 wchess : remove vad 2023-11-28 19:03:17 +02:00
02ade14f67 wchess minor 2023-11-28 16:21:46 +02:00
8dba8204eb Merge pull request #1 from ggerganov/gg/wchess
wchess : add clear_audio callback
2023-11-28 15:45:17 +02:00
7 changed files with 642 additions and 278 deletions

View File

@ -1,9 +1,12 @@
#include "Chessboard.h"
#include <vector>
#include <algorithm>
#include <cassert>
#include <cstring>
#include <set>
static constexpr std::array<const char*, 64> positions = {
namespace {
// remove std::string_view, c++17 -> c++11
constexpr std::array<const char*, 64> positions = {
"a1", "b1", "c1", "d1", "e1", "f1", "g1", "h1",
"a2", "b2", "c2", "d2", "e2", "f2", "g2", "h2",
"a3", "b3", "c3", "d3", "e3", "f3", "g3", "h3",
@ -13,47 +16,96 @@ static constexpr std::array<const char*, 64> positions = {
"a7", "b7", "c7", "d7", "e7", "f7", "g7", "h7",
"a8", "b8", "c8", "d8", "e8", "f8", "g8", "h8",
};
constexpr int INVALID_POS = positions.size();
constexpr int R = 0; // rank index
constexpr int F = 1; // file index
#define POS ((c[F] - '1') * 8 + (c[R] - 'a'))
constexpr int operator ""_P(const char * c, size_t size) {
return size < 2 || POS < 0 || POS > INVALID_POS ? INVALID_POS : POS;
}
#undef POS
static constexpr std::array<const char*, 6> pieceNames = {
struct sview {
const char * ptr = nullptr;
size_t size = 0;
sview() = default;
sview(const char * p, size_t s) : ptr(p), size(s) {}
sview(const std::string& s) : ptr(s.data()), size(s.size()) {}
size_t find(char del, size_t pos) {
while (pos < size && ptr[pos] != del) ++pos;
return pos < size ? pos : std::string::npos;
}
};
std::vector<sview> split(sview str, char del) {
std::vector<sview> res;
size_t cur = 0;
size_t last = 0;
while (cur != std::string::npos) {
if (str.ptr[last] == ' ') {
++last;
continue;
}
cur = str.find(del, last);
size_t len = cur == std::string::npos ? str.size - last : cur - last;
res.emplace_back(str.ptr + last, len);
last = cur + 1;
}
return res;
}
size_t strToPos(sview str) {
return operator ""_P(str.ptr, str.size);
}
constexpr std::array<const char*, 6> pieceNames = {
"pawn", "knight", "bishop", "rook", "queen", "king",
};
int strToType(sview str) {
auto it = std::find_if(pieceNames.begin(), pieceNames.end(), [str] (const char* name) { return strncmp(name, str.ptr, str.size) == 0; });
return it != pieceNames.end() ? int(it - pieceNames.begin()) : pieceNames.size();
}
}
Chessboard::Chessboard()
: blackPieces {{
{Piece::Pawn, Piece::Black, 48 },
{Piece::Pawn, Piece::Black, 49 },
{Piece::Pawn, Piece::Black, 50 },
{Piece::Pawn, Piece::Black, 51 },
{Piece::Pawn, Piece::Black, 52 },
{Piece::Pawn, Piece::Black, 53 },
{Piece::Pawn, Piece::Black, 54 },
{Piece::Pawn, Piece::Black, 55 },
{Piece::Rook, Piece::Black, 56 },
{Piece::Knight, Piece::Black, 57 },
{Piece::Bishop, Piece::Black, 58 },
{Piece::Queen, Piece::Black, 59 },
{Piece::King, Piece::Black, 60 },
{Piece::Bishop, Piece::Black, 61 },
{Piece::Knight, Piece::Black, 62 },
{Piece::Rook, Piece::Black, 63 },
{Piece::Pawn, Piece::Black, "a7"_P },
{Piece::Pawn, Piece::Black, "b7"_P },
{Piece::Pawn, Piece::Black, "c7"_P },
{Piece::Pawn, Piece::Black, "d7"_P },
{Piece::Pawn, Piece::Black, "e7"_P },
{Piece::Pawn, Piece::Black, "f7"_P },
{Piece::Pawn, Piece::Black, "g7"_P },
{Piece::Pawn, Piece::Black, "h7"_P },
{Piece::Rook, Piece::Black, "a8"_P },
{Piece::Knight, Piece::Black, "b8"_P },
{Piece::Bishop, Piece::Black, "c8"_P },
{Piece::Queen, Piece::Black, "d8"_P },
{Piece::King, Piece::Black, "e8"_P },
{Piece::Bishop, Piece::Black, "f8"_P },
{Piece::Knight, Piece::Black, "g8"_P },
{Piece::Rook, Piece::Black, "h8"_P },
}}
, whitePieces {{
{Piece::Pawn, Piece::White, 8 },
{Piece::Pawn, Piece::White, 9 },
{Piece::Pawn, Piece::White, 10 },
{Piece::Pawn, Piece::White, 11 },
{Piece::Pawn, Piece::White, 12 },
{Piece::Pawn, Piece::White, 13 },
{Piece::Pawn, Piece::White, 14 },
{Piece::Pawn, Piece::White, 15 },
{Piece::Rook, Piece::White, 0 },
{Piece::Knight, Piece::White, 1 },
{Piece::Bishop, Piece::White, 2 },
{Piece::Queen, Piece::White, 3 },
{Piece::King, Piece::White, 4 },
{Piece::Bishop, Piece::White, 5 },
{Piece::Knight, Piece::White, 6 },
{Piece::Rook, Piece::White, 7 },
{Piece::Pawn, Piece::White, "a2"_P },
{Piece::Pawn, Piece::White, "b2"_P },
{Piece::Pawn, Piece::White, "c2"_P },
{Piece::Pawn, Piece::White, "d2"_P },
{Piece::Pawn, Piece::White, "e2"_P },
{Piece::Pawn, Piece::White, "f2"_P },
{Piece::Pawn, Piece::White, "g2"_P },
{Piece::Pawn, Piece::White, "h2"_P },
{Piece::Rook, Piece::White, "a1"_P },
{Piece::Knight, Piece::White, "b1"_P },
{Piece::Bishop, Piece::White, "c1"_P },
{Piece::Queen, Piece::White, "d1"_P },
{Piece::King, Piece::White, "e1"_P },
{Piece::Bishop, Piece::White, "f1"_P },
{Piece::Knight, Piece::White, "g1"_P },
{Piece::Rook, Piece::White, "h1"_P },
}}
, board {{
&whitePieces[ 8], &whitePieces[ 9], &whitePieces[10], &whitePieces[11], &whitePieces[12], &whitePieces[13], &whitePieces[14], &whitePieces[15],
@ -65,8 +117,83 @@ Chessboard::Chessboard()
&blackPieces[ 0], &blackPieces[ 1], &blackPieces[ 2], &blackPieces[ 3], &blackPieces[ 4], &blackPieces[ 5], &blackPieces[ 6], &blackPieces[ 7],
&blackPieces[ 8], &blackPieces[ 9], &blackPieces[10], &blackPieces[11], &blackPieces[12], &blackPieces[13], &blackPieces[14], &blackPieces[15],
}}
, whiteMoves {
{"b1"_P, "a3"_P}, {"b1"_P, "c3"_P},
{"g1"_P, "f3"_P}, {"g1"_P, "h3"_P},
{"a2"_P, "a3"_P}, {"a2"_P, "a4"_P},
{"b2"_P, "b3"_P}, {"b2"_P, "b4"_P},
{"c2"_P, "c3"_P}, {"c2"_P, "c4"_P},
{"d2"_P, "d3"_P}, {"d2"_P, "d4"_P},
{"e2"_P, "e3"_P}, {"e2"_P, "e4"_P},
{"f2"_P, "f3"_P}, {"f2"_P, "f4"_P},
{"g2"_P, "g3"_P}, {"g2"_P, "g4"_P},
{"h2"_P, "h3"_P}, {"h2"_P, "h4"_P},
}
, blackMoves {
{"a7"_P, "a5"_P}, {"a7"_P, "a6"_P},
{"b7"_P, "b5"_P}, {"b7"_P, "b6"_P},
{"c7"_P, "c5"_P}, {"c7"_P, "c6"_P},
{"d7"_P, "d5"_P}, {"d7"_P, "d6"_P},
{"e7"_P, "e5"_P}, {"e7"_P, "e6"_P},
{"f7"_P, "f5"_P}, {"f7"_P, "f6"_P},
{"g7"_P, "g5"_P}, {"g7"_P, "g6"_P},
{"h7"_P, "h5"_P}, {"h7"_P, "h6"_P},
{"b8"_P, "a6"_P}, {"b8"_P, "c6"_P},
{"g8"_P, "f6"_P}, {"g8"_P, "h6"_P},
}
{
static_assert(pieceNames.size() == Chessboard::Piece::Taken, "Mismatch between piece names and types");
std::sort(whiteMoves.begin(), whiteMoves.end());
std::sort(blackMoves.begin(), blackMoves.end());
}
std::string Chessboard::getRules(const std::string& prompt) const {
// leading space is very important!
std::string result =
"\n"
"# leading space is very important!\n"
"\n";
if (prompt.empty()) {
result += "move ::= \" \" ((piece | frompos) \" \" \"to \"?)? topos\n";
//result += "move ::= \" \" frompos \" \" \"to \"? topos\n";
}
else {
// result += "move ::= prompt \" \" ((piece | frompos) \" \" \"to \"?)? topos\n"
result += "move ::= prompt \" \" frompos \" \" \"to \"? topos\n"
"\n"
"prompt ::= \" " + prompt + "\"\n";
}
std::set<std::string> pieces;
std::set<std::string> from_pos;
std::set<std::string> to_pos;
auto& allowed_moves = m_moveCounter % 2 ? blackMoves : whiteMoves;
for (auto& m : allowed_moves) {
if (board[m.first]->type != Piece::Taken) pieces.insert(pieceNames[board[m.first]->type]);
from_pos.insert(positions[m.first]);
to_pos.insert(positions[m.second]);
}
if (!pieces.empty()) {
result += "piece ::= (";
for (auto& p : pieces) result += " \"" + p + "\" |";
result.pop_back();
result += ")\n\n";
}
if (!from_pos.empty()) {
result += "frompos ::= (";
for (auto& p : from_pos) result += " \"" + p + "\" |";
result.pop_back();
result += ")\n";
}
if (!to_pos.empty()) {
result += "topos ::= (";
for (auto& p : to_pos) result += " \"" + p + "\" |";
result.pop_back();
result += ")\n";
}
return result;
}
std::string Chessboard::stringifyBoard() {
@ -86,98 +213,402 @@ std::string Chessboard::stringifyBoard() {
result.back() = '\n';
for (int i = 7; i >= 0; --i) {
for (int j = 0; j < 8; ++j) {
if (auto p = board[i * 8 + j]; p) result.push_back(p->color == Piece::White ? whiteShort[p->type] : blackShort[p->type]);
auto p = board[i * 8 + j];
if (p) result.push_back(p->color == Piece::White ? whiteShort[p->type] : blackShort[p->type]);
else result.push_back((i + j) % 2 ? '.' : '*');
result.push_back(' ');
}
result.push_back('0' + i + 1);
result.push_back('\n');
}
return result;
return result;
}
std::vector<std::string_view> split(std::string_view str, char del) {
std::vector<std::string_view> res;
size_t cur = 0;
size_t last = 0;
while (cur != std::string::npos) {
if (str[last] == ' ') { // trim white
++last;
continue;
}
cur = str.find(del, last);
size_t len = cur == std::string::npos ? str.size() - last : cur - last;
res.emplace_back(str.data() + last, len);
last = cur + 1;
std::string Chessboard::process(const std::string& command) {
auto color = Piece::Colors(m_moveCounter % 2);
fprintf(stdout, "%s: Command to %s: '%s%.*s%s'\n", __func__, (color ? "Black" : "White"), "\033[1m", int(command.size()), command.data(), "\033[0m");
if (command.empty()) return "";
auto tokens = split(command, ' ');
for (auto& t : tokens) fprintf(stdout, "%s: Token %.*s\n", __func__, int(t.size), t.ptr);
auto pos_from = INVALID_POS;
auto type = Piece::Types::Taken;
auto pos_to = INVALID_POS;
if (tokens.size() == 1) {
type = Piece::Types::Pawn;
pos_to = strToPos(tokens.front());
}
return res;
}
Chessboard::Piece::Types Chessboard::tokenToType(std::string_view token) {
auto it = std::find(pieceNames.begin(), pieceNames.end(), token);
return it != pieceNames.end() ? Piece::Types(it - pieceNames.begin()) : Piece::Taken;
}
size_t Chessboard::tokenToPos(std::string_view token) {
if (token.size() < 2) return positions.size();
int file = token[0] - 'a';
int rank = token[1] - '1';
int pos = rank * 8 + file;
if (pos < 0 || pos >= int(positions.size())) return positions.size();
return pos;
}
std::string Chessboard::process(const std::string& transcription) {
auto commands = split(transcription, ',');
// fixme: lookup depends on grammar
int count = m_moveCounter;
std::vector<Move> moves;
std::string result;
result.reserve(commands.size() * 6);
for (auto& command : commands) {
fprintf(stdout, "%s: Command '%s%.*s%s'\n", __func__, "\033[1m", int(command.size()), command.data(), "\033[0m");
if (command.empty()) continue;
auto tokens = split(command, ' ');
Piece::Types type = Piece::Types::Taken;
size_t pos = positions.size();
if (tokens.size() == 1) {
type = Piece::Types::Pawn;
pos = tokenToPos(tokens[0]);
}
else if (tokens.size() == 3) {
type = tokenToType(tokens[0]);
assert(tokens[1] == "to");
pos = tokenToPos(tokens[2]);
}
if (type == Piece::Types::Taken || pos == positions.size()) continue;
auto& pieces = count % 2 ? blackPieces : whitePieces;
else {
pos_from = strToPos(tokens.front());
if (pos_from == INVALID_POS) type = Piece::Types(strToType(tokens.front()));
pos_to = strToPos(tokens.back());
}
if (pos_to == INVALID_POS) return "";
if (pos_from == INVALID_POS) {
if (type == Piece::Types::Taken) return "";
auto& pieces = color ? blackPieces : whitePieces;
auto pieceIndex = 0u;
for (; pieceIndex < pieces.size(); ++pieceIndex) {
if (pieces[pieceIndex].type == type && validateMove(pieces[pieceIndex], pos)) break;
if (pieces[pieceIndex].type == type && validateMove(pieces[pieceIndex], pos_to)) break;
}
Move m = {pieces[pieceIndex].pos, pos};
if (pieceIndex < pieces.size() && move({m})) {
result.append(positions[m.first]);
result.push_back('-');
result.append(positions[m.second]);
result.push_back(' ');
++count;
if (pieceIndex == pieces.size()) return "";
pos_from = pieces[pieceIndex].pos;
}
if (board[pos_from] == nullptr) return "";
if (board[pos_from]->color != color) return "";
Move m = {pos_from, pos_to};
auto& allowed_moves = color ? blackMoves : whiteMoves;
fprintf(stdout, "%s:allowed size %d :\n", __func__, int(allowed_moves.size()));
for (auto& m : allowed_moves) fprintf(stdout, " %s %s; ", positions[m.first], positions[m.second]);
fprintf(stdout, "\n");
if (!std::binary_search(allowed_moves.begin(), allowed_moves.end(), m)) return "";
move(m);
{
auto it = std::remove_if(allowed_moves.begin(), allowed_moves.end(), [m] (const Move& move) { return move.first == m.first; });
allowed_moves.erase(it, allowed_moves.end());
}
std::vector<Piece*> affected = { board[m.second] };
for (auto& p : whitePieces) {
if (&p == board[m.second]
|| validateMove(p, m.first)
|| validateMove(p, m.second)
|| std::binary_search(whiteMoves.begin(), whiteMoves.end(), Move(p.pos, m.second))
) {
auto it = std::remove_if(whiteMoves.begin(), whiteMoves.end(), [&p] (const Move& m) { return m.first == p.pos; });
whiteMoves.erase(it, whiteMoves.end());
affected.push_back(&p);
}
}
if (!result.empty()) result.pop_back();
m_moveCounter = count;
fprintf(stdout, "%s: Moves '%s%s%s'\n", __func__, "\033[1m", result.data(), "\033[0m");
for (auto& p : blackPieces) {
if (&p == board[m.second]
|| validateMove(p, m.first)
|| validateMove(p, m.second)
|| std::binary_search(blackMoves.begin(), blackMoves.end(), Move(p.pos, m.second))
) {
auto it = std::remove_if(blackMoves.begin(), blackMoves.end(), [&p] (const Move& m) { return m.first == p.pos; });
blackMoves.erase(it, blackMoves.end());
affected.push_back(&p);
}
}
for (auto& p : affected) getValidMoves(*p, p->color ? blackMoves : whiteMoves);
std::sort(blackMoves.begin(), blackMoves.end());
std::sort(whiteMoves.begin(), whiteMoves.end());
std::string result = positions[m.first];
result += "-";
result += positions[m.second];
++m_moveCounter;
fprintf(stdout, "%s: Move '%s%s%s'\n", __func__, "\033[1m", result.data(), "\033[0m");
return result;
}
void Chessboard::getValidMoves(const Piece& piece, std::vector<Move>& result) {
std::string cur = positions[piece.pos];
switch (piece.type) {
case Piece::Pawn: {
std::string next = cur;
piece.color ? --next[F] : ++next[F]; // one down / up
std::string left = { char(next[R] - 1), next[F]};
auto pos = strToPos(left);
if (pos != INVALID_POS && board[pos] && board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
std::string right = { char(next[R] + 1), next[F]};
pos = strToPos(right);
if (pos != INVALID_POS && board[pos] && board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
pos = strToPos(next);
if (pos != INVALID_POS && !board[pos]) result.emplace_back(piece.pos, pos);
else break;
if (piece.color ? cur[F] != '7' : cur[F] != '2') break;
piece.color ? --next[F] : ++next[F]; // one down / up
pos = strToPos(next);
if (pos != INVALID_POS && !board[pos]) result.emplace_back(piece.pos, pos);
break;
}
case Piece::Knight: {
std::string next = cur;
--next[F]; --next[F]; --next[R];
auto pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
--next[F]; --next[F]; ++next[R];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
++next[F]; ++next[F]; --next[R];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
++next[F]; ++next[F]; ++next[R];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
--next[F]; --next[R]; --next[R];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
++next[F]; --next[R]; --next[R];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
--next[F]; ++next[R]; ++next[R];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
++next[F]; ++next[R]; ++next[R];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
break;
}
case Piece::Bishop: {
std::string next = cur;
while (true) {
--next[R]; --next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
--next[R]; ++next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
++next[R]; --next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
++next[R]; ++next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
break;
}
case Piece::Rook: {
std::string next = cur;
while (true) {
--next[R];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
++next[R];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
--next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
++next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
break;
}
case Piece::Queen: {
std::string next = cur;
while (true) {
--next[R]; --next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
--next[R]; ++next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
++next[R]; --next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
++next[R]; ++next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
--next[R];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
++next[R];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
--next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
++next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
break;
}
case Piece::King: {
std::string next = cur;
--next[R]; --next[F];
auto pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
--next[R]; ++next[F];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
++next[R]; --next[F];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
++next[R]; ++next[F];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
--next[R];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
++next[R];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
--next[F];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
++next[F];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
break;
}
case Piece::Taken: break;
default: break;
}
}
bool Chessboard::validatePawnMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file) {
int direction = color == Piece::White ? 1 : -1;
bool two_ranks = color == Piece::White ? from_rank == 1 : from_rank == 6;
if (from_file == to_file) {
if (from_rank == to_rank - direction) return board[to_rank * 8 + to_file] == nullptr;
if (from_rank == to_rank - direction * 2) return board[(to_rank - direction) * 8 + to_file] == nullptr && board[to_rank * 8 + to_file] == nullptr;
if (two_ranks && from_rank == to_rank - direction * 2) return board[(to_rank - direction) * 8 + to_file] == nullptr && board[to_rank * 8 + to_file] == nullptr;
}
else if (from_file + 1 == to_file || from_file - 1 == to_file) {
if (from_rank == to_rank - direction) return board[to_rank * 8 + to_file] != nullptr && board[to_rank * 8 + to_file]->color != color;
@ -280,4 +711,4 @@ bool Chessboard::move(const Move& m) {
board[m.first] = nullptr;
board[m.second]->pos = m.second;
return true;
}
}

View File

@ -8,8 +8,9 @@ public:
Chessboard();
std::string process(const std::string& t);
std::string stringifyBoard();
private:
std::string getRules(const std::string & prompt) const;
using Move = std::pair<int, int>;
private:
bool move(const Move& move);
struct Piece {
@ -24,8 +25,8 @@ private:
};
enum Colors {
White,
Black,
White
};
Types type;
@ -33,8 +34,6 @@ private:
int pos;
};
Piece::Types tokenToType(std::string_view token);
size_t tokenToPos(std::string_view token);
using PieceSet = std::array<Piece, 16>;
PieceSet blackPieces;
@ -44,7 +43,11 @@ private:
using Board = std::array<Piece*, 64>;
Board board;
std::vector<Move> whiteMoves;
std::vector<Move> blackMoves;
bool validateMove(const Piece& piece, int pos);
void getValidMoves(const Piece& piece, std::vector<Move>& moves);
// just basic validation
// fixme: missing en passant, castling, promotion, etc.
bool validatePawnMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);

View File

@ -4,24 +4,6 @@
#include "common.h"
#include <thread>
static constexpr auto RULES =
"\n"
"root ::= init move move? move? \".\"\n"
"prompt ::= init \".\"\n"
"\n"
"# leading space is very important!\n"
"init ::= \" rook to b4, f3\"\n"
"\n"
"move ::= \", \" ((piece | pawn | king) \" \" \"to \"?)? [a-h] [1-8]\n"
"\n"
"piece ::= \"bishop\" | \"rook\" | \"knight\" | \"queen\"\n"
"king ::= \"king\"\n"
"pawn ::= \"pawn\"\n"
"\n";
static constexpr auto PROMPT = "rook to b4, f3,";
static constexpr auto CONTEXT = "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8,";
WChess::WChess(whisper_context * ctx,
const whisper_full_params & wparams,
callbacks cb,
@ -48,9 +30,8 @@ bool WChess::check_running() const {
return false;
}
bool WChess::clear_audio() const {
if (m_cb.clear_audio) return (*m_cb.clear_audio)();
return false;
void WChess::clear_audio() const {
if (m_cb.clear_audio) (*m_cb.clear_audio)();
}
void WChess::get_audio(int ms, std::vector<float>& pcmf32) const {
@ -64,8 +45,8 @@ std::string WChess::stringify_board() const {
void WChess::run() {
set_status("loading data ...");
bool have_prompt = false;
bool ask_prompt = true;
bool have_prompt = true;
bool ask_prompt = !have_prompt;
float logprob_min0 = 0.0f;
float logprob_min = 0.0f;
@ -79,19 +60,7 @@ void WChess::run() {
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
const std::string k_prompt = PROMPT;
m_wparams.initial_prompt = CONTEXT;
auto grammar_parsed = grammar_parser::parse(RULES);
auto grammar_rules = grammar_parsed.c_rules();
if (grammar_parsed.rules.empty()) {
fprintf(stdout, "%s: Failed to parse grammar ...\n", __func__);
}
else {
m_wparams.grammar_rules = grammar_rules.data();
m_wparams.n_grammar_rules = grammar_rules.size();
}
const std::string k_prompt = have_prompt ? "" : "checkmate";
while (check_running()) {
// delay
@ -116,14 +85,11 @@ void WChess::run() {
{
get_audio(m_settings.vad_ms, pcmf32_cur);
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, m_settings.vad_thold, m_settings.freq_thold, m_settings.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
set_status("Speech detected! Processing ...");
if (!pcmf32_cur.empty()) {
fprintf(stdout, "%s: Processing ...\n", __func__);
set_status("Processing ...");
if (!have_prompt) {
get_audio(m_settings.prompt_ms, pcmf32_cur);
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("prompt");
const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
@ -150,16 +116,21 @@ void WChess::run() {
have_prompt = true;
}
} else {
get_audio(m_settings.command_ms, pcmf32_cur);
if (!pcmf32_prompt.empty()) pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
static const size_t MIN_SIZE = 1.2 * WHISPER_SAMPLE_RATE;
if (MIN_SIZE > pcmf32_cur.size()) pcmf32_cur.insert(pcmf32_cur.begin(), MIN_SIZE - pcmf32_cur.size(), 0.0f);
// prepend 3 second of silence
pcmf32_cur.insert(pcmf32_cur.begin(), 3*WHISPER_SAMPLE_RATE, 0.0f);
std::string rules = m_board->getRules(k_prompt);
fprintf(stdout, "%s: grammar rules:\n'%s'\n", __func__, rules.c_str());
// prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
auto grammar_parsed = grammar_parser::parse(rules.c_str());
auto grammar_rules = grammar_parsed.c_rules();
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("root");
const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
m_wparams.grammar_rules = grammar_rules.data();
m_wparams.n_grammar_rules = grammar_rules.size();
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("move");
auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
const float p = 100.0f * std::exp(logprob_min);
@ -169,10 +140,6 @@ void WChess::run() {
float best_sim = 0.0f;
size_t best_len = 0;
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
if (n >= int(txt.size())) {
break;
}
const auto prompt = txt.substr(0, n);
const float sim = similarity(prompt, k_prompt);
@ -197,7 +164,10 @@ void WChess::run() {
set_status(txt);
}
if (!command.empty()) {
set_moves(m_board->process(command));
auto move = m_board->process(command);
if (!move.empty()) {
set_moves(std::move(move));
}
}
}

View File

@ -12,14 +12,14 @@ public:
using CheckRunningCb = bool (*)();
using GetAudioCb = void (*)(int, std::vector<float> &);
using SetMovesCb = void (*)(const std::string &);
using CleartAudioCb = bool (*)();
using ClearAudioCb = void (*)();
struct callbacks {
SetStatusCb set_status = nullptr;
CheckRunningCb check_running = nullptr;
GetAudioCb get_audio = nullptr;
SetMovesCb set_moves = nullptr;
CleartAudioCb clear_audio = nullptr;
ClearAudioCb clear_audio = nullptr;
};
struct settings {
@ -46,7 +46,7 @@ private:
void set_status(const std::string& msg) const;
void set_moves(const std::string& moves) const;
bool check_running() const;
bool clear_audio() const;
void clear_audio() const;
std::string transcribe(
const std::vector<float> & pcmf32,
float & logprob_min,

View File

@ -118,7 +118,7 @@ void get_audio(int ms, std::vector<float> & pcmf32_cur) {
g_audio.get(ms, pcmf32_cur);
}
bool clear_audio() {
void clear_audio() {
g_audio.clear();
}

View File

@ -50,7 +50,7 @@
<div id="model-whisper">
Whisper model: <span id="model-whisper-status"></span>
<span id="fetch-whisper-progress"></span>
<button id="clear" onclick="clearCache()">Clear Cache</button>
<!--
<input type="file" id="file" name="file" onchange="loadFile(event, 'whisper.bin')" />
-->
@ -67,9 +67,7 @@
<br>
<div id="input">
<button id="start" onclick="onStart()" disabled>Start</button>
<button id="stop" onclick="onStop()" disabled>Stop</button>
<button id="clear" onclick="clearCache()">Clear Cache</button>
<button id="toggler" disabled>Hold</button>
</div>
<br>
@ -115,10 +113,6 @@
// web audio context
var context = null;
// audio data
var audio = null;
var audio0 = null;
// the command instance
var instance = null;
@ -137,7 +131,15 @@
printTextarea('js: Preparing ...');
},
postRun: function() {
printTextarea('js: Initialized successfully!');
printTextarea('js: Module initialized successfully!');
instance = Module.init('whisper.bin');
if (instance) {
printTextarea("js: whisper initialized, instance: " + instance);
}
else {
printTextarea("js: failed to initialize whisper");
}
}
};
@ -165,8 +167,7 @@
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
if (model_whisper != null) {
document.getElementById('start').disabled = false;
document.getElementById('stop' ).disabled = true;
document.getElementById('toggler').disabled = false;
}
}
@ -187,7 +188,7 @@
// 'base-en-q5_1': 57,
// };
let url = 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin';
let url = 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q8_0.bin';
let dst = 'whisper.bin';
let size_mb = 75;
@ -225,10 +226,7 @@
function stopRecording() {
Module.set_status("paused");
doRecording = false;
audio0 = null;
audio = null;
context = null;
mediaRecorder.stop();
}
function startRecording() {
@ -242,12 +240,6 @@
});
}
Module.set_status("");
document.getElementById('start').disabled = true;
document.getElementById('stop').disabled = false;
doRecording = true;
startTime = Date.now();
var chunks = [];
@ -277,22 +269,16 @@
source.start(0);
offlineContext.startRendering().then(function(renderedBuffer) {
audio = renderedBuffer.getChannelData(0);
//printTextarea('js: audio recorded, size: ' + audio.length + ', old size: ' + (audio0 == null ? 0 : audio0.length));
var audioAll = new Float32Array(audio0 == null ? audio.length : audio0.length + audio.length);
if (audio0 != null) {
audioAll.set(audio0, 0);
}
audioAll.set(audio, audio0 == null ? 0 : audio0.length);
let audio = renderedBuffer.getChannelData(0);
if (instance) {
Module.set_audio(instance, audioAll);
printTextarea('js: number of samples: ' + audio.length);
Module.set_audio(instance, audio);
}
});
}, function(e) {
audio = null;
mediaRecorder = null;
context = null;
});
}
@ -300,48 +286,16 @@
};
mediaRecorder.onstop = function(e) {
if (doRecording) {
setTimeout(function() {
startRecording();
});
}
stream.getTracks().forEach(function(track) {
track.stop();
});
};
mediaRecorder.start(kIntervalAudio_ms);
mediaRecorder.start();
})
.catch(function(err) {
printTextarea('js: error getting audio stream: ' + err);
});
var interval = setInterval(function() {
if (!doRecording) {
clearInterval(interval);
mediaRecorder.stop();
stream.getTracks().forEach(function(track) {
track.stop();
});
document.getElementById('start').disabled = false;
document.getElementById('stop').disabled = true;
mediaRecorder = null;
}
// if audio length is more than kRestartRecording_s seconds, restart recording
if (audio != null && audio.length > kSampleRate*kRestartRecording_s) {
if (doRecording) {
//printTextarea('js: restarting recording');
clearInterval(interval);
audio0 = audio;
audio = null;
mediaRecorder.stop();
stream.getTracks().forEach(function(track) {
track.stop();
});
}
}
}, 100);
}
//
@ -352,26 +306,47 @@
var intervalUpdate = null;
var movesAll = '';
document.body.addEventListener('keydown', function(event) {
if (event.keyCode === 32) {
document.getElementById('toggler').innerText = "Release";
onStart();
}
}, true);
document.body.addEventListener('keyup', function(event) {
if (event.keyCode === 32) {
document.getElementById('toggler').innerText = "Hold";
onStop();
}
}, true);
document.getElementById('toggler').addEventListener('mousedown', function(event) {
this.innerText = "Release";
onStart();
}, true);
document.getElementById('toggler').addEventListener('mouseup', function(event) {
this.innerText = "Hold";
onStop();
}, true);
function onStart() {
if (!instance) {
instance = Module.init('whisper.bin');
if (instance) {
printTextarea("js: whisper initialized, instance: " + instance);
}
}
if (!instance) {
printTextarea("js: failed to initialize whisper");
return;
}
startRecording();
}
intervalUpdate = setInterval(function() {
function onStop() {
printTextarea('js: stopping recording ...');
stopRecording();
var interval = setInterval(function() {
var moves = Module.get_moves();
if (moves != null && moves.length > 1) {
clearInterval(interval);
for (move of moves.split(' ')) {
board.move(move);
@ -388,17 +363,13 @@
nLines--;
}
}
document.getElementById('state-status').innerHTML = Module.get_status();
document.getElementById('state-moves').innerHTML = movesAll;
}
document.getElementById('state-status').innerHTML = Module.get_status();
document.getElementById('state-moves').innerHTML = movesAll;
}, 100);
}
function onStop() {
stopRecording();
}
</script>
<script type="text/javascript" src="js/chess.js"></script>
</body>

View File

@ -29,28 +29,18 @@ void set_moves(const std::string & moves) {
g_moves = moves;
}
void get_audio(int ms, std::vector<float> & audio) {
const int64_t n_samples = (ms * WHISPER_SAMPLE_RATE) / 1000;
int64_t n_take = 0;
if (n_samples > (int) g_pcmf32.size()) {
n_take = g_pcmf32.size();
} else {
n_take = n_samples;
}
audio.resize(n_take);
std::copy(g_pcmf32.end() - n_take, g_pcmf32.end(), audio.begin());
void get_audio(int /* ms */, std::vector<float> & audio) {
std::lock_guard<std::mutex> lock(g_mutex);
audio = g_pcmf32;
}
bool check_running() {
//g_pcmf32.clear();
return g_running;
}
bool clear_audio() {
void clear_audio() {
std::lock_guard<std::mutex> lock(g_mutex);
g_pcmf32.clear();
return true;
}
void wchess_main(size_t i) {
@ -65,19 +55,21 @@ void wchess_main(size_t i) {
wparams.print_progress = false;
wparams.print_timestamps = true;
wparams.print_special = false;
wparams.no_timestamps = true;
wparams.max_tokens = 32;
// wparams.audio_ctx = 768; // partial encoder context for better performance
wparams.audio_ctx = 768; // partial encoder context for better performance
wparams.temperature = 0.4f;
wparams.temperature_inc = 1.0f;
wparams.temperature = 0.0f;
wparams.temperature_inc = 2.0f;
wparams.greedy.best_of = 1;
wparams.beam_search.beam_size = 5;
wparams.beam_search.beam_size = 1;
wparams.language = "en";
wparams.grammar_penalty = 100.0;
wparams.initial_prompt = "bishop to c3, rook to d4, knight to e5, d4 d5, knight to c3, c3, queen to d4, king b1, pawn to a1, bishop to b2, knight to c3,";
printf("command: using %d threads\n", wparams.n_threads);
@ -160,9 +152,6 @@ EMSCRIPTEN_BINDINGS(command) {
moves = std::move(g_moves);
}
if (!moves.empty()) fprintf(stdout, "%s: Moves '%s%s%s'\n", __func__, "\033[1m", moves.c_str(), "\033[0m");
return moves;
}));