Compare commits

..

11 Commits

Author SHA1 Message Date
4260d4fc70 wchess : minor 2023-11-28 15:10:18 +02:00
ee65df7982 wchess : add clear_audio callback 2023-11-28 13:37:26 +02:00
03f254193b wchess: hardcoded rules 2023-11-27 10:51:20 +02:00
8f2d8eae10 wchess: basic chess rules 2023-11-27 10:41:04 +02:00
a44b21bce0 wchess: tidy up entry files 2023-11-25 11:34:06 +02:00
f07ff2aa6a chess -> wchess 2023-11-25 10:16:48 +02:00
280e631bcf chess.wasm: poc of chess rules 2023-11-23 16:09:00 +02:00
2f86da0d09 chess.wasm: add chessboard 2023-11-23 08:49:47 +02:00
a787f7f85c chess.wasm: encoder context value resulting in echoing 2023-11-21 20:42:20 +02:00
c83a38e89d chess.wasm: go back to greedy 2023-11-21 16:56:22 +02:00
758c951729 chess.wasm: grammar in emscripten 2023-11-21 16:30:44 +02:00
49 changed files with 2490 additions and 8721 deletions

View File

@ -25,7 +25,6 @@ jobs:
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
apt update
apt install -y build-essential libsdl2-dev
make
@ -87,7 +86,6 @@ jobs:
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
apt update
apt install -y build-essential cmake libsdl2-dev
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
@ -115,10 +113,8 @@ jobs:
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
apt update
apt install -y clang
apt install -y clang build-essential cmake libsdl2-dev
apt install -y build-essential cmake libsdl2-dev
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
make
ctest -L gh --output-on-failure'
@ -144,7 +140,6 @@ jobs:
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
apt update
apt install -y build-essential cmake
cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
@ -222,10 +217,10 @@ jobs:
sdl2: [ON]
include:
- arch: Win32
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.25/OpenBLAS-0.3.25-x86.zip
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x86.zip
s2arc: x86
- arch: x64
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.25/OpenBLAS-0.3.25-x64.zip
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x64.zip
s2arc: x64
- sdl2: ON
s2ver: 2.26.0
@ -290,7 +285,6 @@ jobs:
arch: [x64]
cublas: [ON]
sdl2: [ON]
cuda-toolkit: [12.2.0, 11.8.0]
include:
- arch: x64
s2arc: x64
@ -306,9 +300,7 @@ jobs:
- name: Install CUDA Toolkit
id: cuda-toolkit
uses: Jimver/cuda-toolkit@v0.2.11
with:
cuda: '${{ matrix.cuda-toolkit }}'
uses: Jimver/cuda-toolkit@v0.2.10
- name: Fetch SDL2 and set SDL2_DIR
if: matrix.sdl2 == 'ON'
@ -323,10 +315,10 @@ jobs:
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DWHISPER_CUBLAS=1
- name: Build ${{ matrix.cuda-toolkit }}
- name: Build
run: |
cd ./build
cmake --build . --config ${{ matrix.build }}
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
- name: Copy CUDA DLLs
run: >
@ -343,7 +335,7 @@ jobs:
if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v1
with:
name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}
name: whisper-cublas-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }}
emscripten:

View File

@ -1,6 +1,6 @@
cmake_minimum_required (VERSION 3.5)
project(whisper.cpp VERSION 1.5.2)
project(whisper.cpp VERSION 1.5.0)
# Add path to modules
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
@ -533,7 +533,7 @@ target_compile_definitions(${TARGET} PUBLIC
${WHISPER_EXTRA_FLAGS}
)
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "ggml.h;whisper.h")
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "whisper.h")
include(GNUInstallDirs)

View File

@ -2,14 +2,33 @@
import PackageDescription
#if arch(arm) || arch(arm64)
let platforms: [SupportedPlatform]? = [
.macOS(.v12),
.iOS(.v14),
.watchOS(.v4),
.tvOS(.v14)
]
let exclude: [String] = []
let resources: [Resource] = [
.process("ggml-metal.metal")
]
let additionalSources: [String] = ["ggml-metal.m"]
let additionalSettings: [CSetting] = [
.unsafeFlags(["-fno-objc-arc"]),
.define("GGML_USE_METAL")
]
#else
let platforms: [SupportedPlatform]? = nil
let exclude: [String] = ["ggml-metal.metal"]
let resources: [Resource] = []
let additionalSources: [String] = []
let additionalSettings: [CSetting] = []
#endif
let package = Package(
name: "whisper",
platforms: [
.macOS(.v12),
.iOS(.v14),
.watchOS(.v4),
.tvOS(.v14)
],
platforms: platforms,
products: [
.library(name: "whisper", targets: ["whisper"]),
],
@ -17,7 +36,7 @@ let package = Package(
.target(
name: "whisper",
path: ".",
exclude: [
exclude: exclude + [
"bindings",
"cmake",
"coreml",
@ -36,22 +55,19 @@ let package = Package(
"whisper.cpp",
"ggml-alloc.c",
"ggml-backend.c",
"ggml-quants.c",
"ggml-metal.m"
],
resources: [.process("ggml-metal.metal")],
"ggml-quants.c"
] + additionalSources,
resources: resources,
publicHeadersPath: "spm-headers",
cSettings: [
.unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]),
.define("GGML_USE_ACCELERATE"),
.unsafeFlags(["-fno-objc-arc"]),
.define("GGML_USE_METAL")
.define("GGML_USE_ACCELERATE")
// NOTE: NEW_LAPACK will required iOS version 16.4+
// We should consider add this in the future when we drop support for iOS 14
// (ref: ref: https://developer.apple.com/documentation/accelerate/1513264-cblas_sgemm?language=objc)
// .define("ACCELERATE_NEW_LAPACK"),
// .define("ACCELERATE_LAPACK_ILP64")
],
] + additionalSettings,
linkerSettings: [
.linkedFramework("Accelerate")
]

View File

@ -6,7 +6,7 @@
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Stable: [v1.5.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.5.2) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
Stable: [v1.5.0](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.5.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
@ -110,8 +110,8 @@ options:
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-sow, --split-on-word [false ] split on word rather than on token
-bo N, --best-of N [5 ] number of best candidates to keep
-bs N, --beam-size N [5 ] beam size for beam search
-bo N, --best-of N [2 ] number of best candidates to keep
-bs N, --beam-size N [-1 ] beam size for beam search
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
@ -128,7 +128,6 @@ options:
-fp, --font-path [/System/Library/Fonts/Supplemental/Courier New Bold.ttf] path to a monospace font for karaoke video
-ocsv, --output-csv [false ] output result in a CSV file
-oj, --output-json [false ] output result in a JSON file
-ojf, --output-json-full [false ] include more information in the JSON file
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
@ -140,8 +139,7 @@ options:
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
-ls, --log-score [false ] log best decoder scores of tokens
-ng, --no-gpu [false ] disable GPU
-ls, --log-score [false ] log best decoder scores of token
bash ./models/download-ggml-model.sh base.en
@ -770,7 +768,6 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
| [bench](examples/bench) | [bench.wasm](examples/bench.wasm) | Benchmark the performance of Whisper on your machine |
| [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
| [wchess](examples/wchess) | [wchess.wasm](examples/wchess) | Voice-controlled chess |
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
| [talk-llama](examples/talk-llama) | | Talk with a LLaMA bot |
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
@ -780,7 +777,6 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggerganov/whisper.cpp/issues/185) |
| [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) |
| [server](examples/server) | | HTTP transcription server with OAI-like API |
## [Discussions](https://github.com/ggerganov/whisper.cpp/discussions)

View File

@ -1,26 +1,9 @@
ifndef UNAME_S
UNAME_S := $(shell uname -s)
endif
ifndef UNAME_P
UNAME_P := $(shell uname -p)
endif
ifndef UNAME_M
UNAME_M := $(shell uname -m)
endif
GGML_METAL_PATH_RESOURCES := $(abspath ../..)
BUILD_DIR := build
MODELS_DIR := models
EXAMPLES_DIR := $(wildcard examples/*)
INCLUDE_PATH := $(abspath ../..)
LIBRARY_PATH := $(abspath ../..)
ifeq ($(UNAME_S),Darwin)
EXT_LDFLAGS := -framework Foundation -framework Metal -framework MetalKit
endif
all: clean whisper examples
whisper: mkdir
@ -28,13 +11,8 @@ whisper: mkdir
@${MAKE} -C ../.. libwhisper.a
test: model-small whisper modtidy
ifeq ($(UNAME_S),Darwin)
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -v .
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -v ./pkg/whisper/...
else
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
endif
examples: $(EXAMPLES_DIR)
@ -43,11 +21,7 @@ model-small: mkdir examples/go-model-download
$(EXAMPLES_DIR): mkdir whisper modtidy
@echo Build example $(notdir $@)
ifeq ($(UNAME_S),Darwin)
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go build ${BUILD_FLAGS} -ldflags "-extldflags '$(EXT_LDFLAGS)'" -o ${BUILD_DIR}/$(notdir $@) ./$@
else
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
endif
mkdir:
@echo Mkdir ${BUILD_DIR}

View File

@ -1,6 +1,6 @@
{
"name": "whisper.cpp",
"version": "1.5.2",
"version": "1.5.0",
"description": "Whisper speech recognition",
"main": "whisper.js",
"scripts": {

File diff suppressed because one or more lines are too long

View File

@ -22,7 +22,6 @@ var printTextarea = (function() {
async function clearCache() {
if (confirm('Are you sure you want to clear the cache?\nAll the models will be downloaded again.')) {
indexedDB.deleteDatabase(dbName);
location.reload();
}
}

View File

@ -17,37 +17,28 @@ options:
-d N, --duration N [0 ] duration of audio to process in milliseconds
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-sow, --split-on-word [false ] split on word rather than on token
-bo N, --best-of N [5 ] number of best candidates to keep
-bs N, --beam-size N [5 ] beam size for beam search
-bs N, --beam-size N [-1 ] beam size for beam search
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
-tr, --translate [false ] translate from source language to english
-di, --diarize [false ] stereo audio diarization
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
-nf, --no-fallback [false ] do not use temperature fallback while decoding
-otxt, --output-txt [false ] output result in a text file
-ovtt, --output-vtt [false ] output result in a vtt file
-osrt, --output-srt [false ] output result in a srt file
-olrc, --output-lrc [false ] output result in a lrc file
-owts, --output-words [false ] output script for generating karaoke video
-fp, --font-path [/System/Library/Fonts/Supplemental/Courier New Bold.ttf] path to a monospace font for karaoke video
-ocsv, --output-csv [false ] output result in a CSV file
-oj, --output-json [false ] output result in a JSON file
-ojf, --output-json-full [false ] include more information in the JSON file
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [false ] do not print timestamps
-nt, --no-timestamps [true ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
-dl, --detect-language [false ] exit after automatically detecting language
--prompt PROMPT [ ] initial prompt
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
-ls, --log-score [false ] log best decoder scores of tokens
-ng, --no-gpu [false ] disable GPU
```

View File

@ -4,9 +4,3 @@ add_executable(${TARGET} server.cpp httplib.h json.hpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})
# Check if the compiler is MinGW
if(MINGW)
# Link the necessary libraries for SSL and Winsock
target_link_libraries(${TARGET} PRIVATE -lcrypt32 -lssl -lcrypto -lws2_32)
endif()

View File

@ -2,10 +2,6 @@
Simple http server. WAV Files are passed to the inference model via http requests.
https://github.com/ggerganov/whisper.cpp/assets/1991296/e983ee53-8741-4eb5-9048-afe5e4594b8f
## Usage
```
./server -h
@ -33,7 +29,6 @@ options:
-nf, --no-fallback [false ] do not use temperature fallback while decoding
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pr, --print-realtime [false ] print output in realtime
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [false ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
@ -43,12 +38,8 @@ options:
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
--host HOST, [127.0.0.1] Hostname/ip-adress for the server
--port PORT, [8080 ] Port number for the server
--convert, [false ] Convert audio to WAV, requires ffmpeg on the server
```
> [!WARNING]
> **Do not run the server example with administrative privileges and ensure it's operated in a sandbox environment, especially since it involves risky operations like accepting user file uploads and using ffmpeg for format conversions. Always validate and sanitize inputs to guard against potential security threats.**
## request examples
**/inference**

View File

@ -11,7 +11,6 @@
#include <thread>
#include <vector>
#include <cstring>
#include <sstream>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
@ -44,8 +43,6 @@ struct server_params
int32_t port = 8080;
int32_t read_timeout = 600;
int32_t write_timeout = 600;
bool ffmpeg_converter = false;
};
struct whisper_params {
@ -75,7 +72,6 @@ struct whisper_params {
bool no_fallback = false;
bool print_special = false;
bool print_colors = false;
bool print_realtime = false;
bool print_progress = false;
bool no_timestamps = false;
bool use_gpu = true;
@ -148,7 +144,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pr, --print-realtime [%-7s] print output in realtime\n", params.print_realtime ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
@ -160,7 +155,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str());
fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port);
fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str());
fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
fprintf(stderr, "\n");
}
@ -194,7 +188,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pr" || arg == "--print-realtime") { params.print_realtime = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
@ -207,7 +200,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
else if ( arg == "--host") { sparams.hostname = argv[++i]; }
else if ( arg == "--public") { sparams.public_path = argv[++i]; }
else if ( arg == "--convert") { sparams.ffmpeg_converter = true; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params, sparams);
@ -225,45 +217,6 @@ struct whisper_print_user_data {
int progress_prev;
};
void check_ffmpeg_availibility() {
int result = system("ffmpeg -version");
if (result == 0) {
std::cout << "ffmpeg is available." << std::endl;
} else {
// ffmpeg is not available
std::cout << "ffmpeg is not found. Please ensure that ffmpeg is installed ";
std::cout << "and that its executable is included in your system's PATH. ";
exit(0);
}
}
bool convert_to_wav(const std::string & temp_filename, std::string & error_resp) {
std::ostringstream cmd_stream;
std::string converted_filename_temp = temp_filename + "_temp.wav";
cmd_stream << "ffmpeg -i \"" << temp_filename << "\" -ar 16000 -ac 1 -c:a pcm_s16le \"" << converted_filename_temp << "\" 2>&1";
std::string cmd = cmd_stream.str();
int status = std::system(cmd.c_str());
if (status != 0) {
error_resp = "{\"error\":\"FFmpeg conversion failed.\"}";
return false;
}
// Remove the original file
if (remove(temp_filename.c_str()) != 0) {
error_resp = "{\"error\":\"Failed to remove the original file.\"}";
return false;
}
// Rename the temporary file to match the original filename
if (rename(converted_filename_temp.c_str(), temp_filename.c_str()) != 0) {
error_resp = "{\"error\":\"Failed to rename the temporary file.\"}";
return false;
}
return true;
}
std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
std::string speaker = "";
const int64_t n_samples = pcmf32s[0].size();
@ -420,7 +373,7 @@ void get_req_parameters(const Request & req, whisper_params & params)
{
params.response_format = req.get_file_value("response-format").content;
}
if (req.has_file("temperature"))
if (req.has_file("temerature"))
{
params.userdef_temp = std::stof(req.get_file_value("temperature").content);
}
@ -451,9 +404,6 @@ int main(int argc, char ** argv) {
exit(0);
}
if (sparams.ffmpeg_converter) {
check_ffmpeg_availibility();
}
// whisper init
struct whisper_context_params cparams;
cparams.use_gpu = params.use_gpu;
@ -469,9 +419,6 @@ int main(int argc, char ** argv) {
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
Server svr;
svr.set_default_headers({{"Server", "whisper.cpp"},
{"Access-Control-Allow-Origin", "*"},
{"Access-Control-Allow-Headers", "content-type"}});
std::string const default_content = "<html>hello</html>";
@ -482,7 +429,7 @@ int main(int argc, char ** argv) {
});
svr.Post("/inference", [&](const Request &req, Response &res){
// acquire whisper model mutex lock
// aquire whisper model mutex lock
whisper_mutex.lock();
// first check user requested fields of the request
@ -506,35 +453,20 @@ int main(int argc, char ** argv) {
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// write to temporary file
const std::string temp_filename = "whisper_server_temp_file.wav";
std::ofstream temp_file{temp_filename, std::ios::binary};
// write file to temporary file
std::ofstream temp_file{filename, std::ios::binary};
temp_file << audio_file.content;
temp_file.close();
// if file is not wav, convert to wav
if (sparams.ffmpeg_converter) {
std::string error_resp = "{\"error\":\"Failed to execute ffmpeg command.\"}";
const bool is_converted = convert_to_wav(temp_filename, error_resp);
if (!is_converted) {
res.set_content(error_resp, "application/json");
whisper_mutex.unlock();
return;
}
}
// read wav content into pcmf32
if (!::read_wav(temp_filename, pcmf32, pcmf32s, params.diarize)) {
fprintf(stderr, "error: failed to read WAV file '%s'\n", temp_filename.c_str());
if (!::read_wav(filename, pcmf32, pcmf32s, params.diarize)) {
fprintf(stderr, "error: failed to read WAV file '%s'\n", filename.c_str());
const std::string error_resp = "{\"error\":\"failed to read WAV file\"}";
res.set_content(error_resp, "application/json");
std::remove(temp_filename.c_str());
whisper_mutex.unlock();
return;
}
// remove temp file
std::remove(temp_filename.c_str());
std::remove(filename.c_str());
printf("Successfully loaded %s\n", filename.c_str());
@ -571,6 +503,7 @@ int main(int argc, char ** argv) {
// run the inference
{
printf("Running whisper.cpp inference on %s\n", filename.c_str());
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
@ -589,7 +522,6 @@ int main(int argc, char ** argv) {
wparams.duration_ms = params.duration_ms;
wparams.thold_pt = params.word_thold;
wparams.max_len = params.max_len == 0 ? 60 : params.max_len;
wparams.split_on_word = params.split_on_word;
wparams.speed_up = params.speed_up;
@ -609,7 +541,7 @@ int main(int argc, char ** argv) {
whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
// this callback is called on each new segment
if (params.print_realtime) {
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &user_data;
}
@ -659,50 +591,6 @@ int main(int argc, char ** argv) {
std::string results = output_str(ctx, params, pcmf32s);
res.set_content(results.c_str(), "text/html");
}
else if (params.response_format == srt_format)
{
std::stringstream ss;
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
ss << i + 1 + params.offset_n << "\n";
ss << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
ss << speaker << text << "\n\n";
}
res.set_content(ss.str(), "application/x-subrip");
} else if (params.response_format == vtt_format) {
std::stringstream ss;
ss << "WEBVTT\n\n";
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true);
speaker.insert(0, "<v Speaker");
speaker.append(">");
}
ss << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
ss << speaker << text << "\n\n";
}
res.set_content(ss.str(), "text/vtt");
}
// TODO add more output formats
else
{

View File

@ -18,11 +18,6 @@ if (WHISPER_SDL2)
../../ggml-quants.c
../../whisper.cpp)
if(WIN32)
# It requires Windows 8.1 or later for PrefetchVirtualMemory
target_compile_definitions(${TARGET} PRIVATE -D_WIN32_WINNT=0x0602)
endif()
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})

View File

@ -1,40 +0,0 @@
# wchess
Voice-controlled chess using Whisper
Online demo: https://whisper.ggerganov.com/wchess/
https://github.com/ggerganov/whisper.cpp/assets/1991296/c2b2f03c-9684-49f3-8106-357d2d4e67fa
## Command-line tool
```bash
mkdir build && cd build
cmake -DWHISPER_SDL2=1 ..
make -j
./bin/wchess -m ../models/ggml-base.en.bin
Move: start
a b c d e f g h
r n b q k b n r 8
p p p p p p p p 7
. * . * . * . * 6
* . * . * . * . 5
. * . * . * . * 4
* . * . * . * . 3
P P P P P P P P 2
R N B Q K B N R 1
White's turn
[(l)isten/(p)ause/(q)uit]:
```
## TODO
- Improve web-browser audio capture - sometimes it does not record the voice properly
- Add support for more languages by making the generated grammar string multi-lingual
- Fix bugs in the chess moves logic
PRs welcome!

View File

@ -1,19 +1,19 @@
add_library(wchess-core STATIC
add_library(libwchess
WChess.cpp
WChess.h
Chessboard.cpp
Chessboard.h
)
target_link_libraries(wchess-core
target_link_libraries(libwchess
PUBLIC
whisper
common
)
target_include_directories(wchess-core
target_include_directories(libwchess
PUBLIC
"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>"
)
# add_executable(test-chessboard test-chessboard.cpp Chessboard.cpp)
add_executable(test-chessboard test-chessboard.cpp Chessboard.cpp)

File diff suppressed because it is too large Load Diff

View File

@ -1,33 +1,56 @@
#pragma once
#include <string>
#include <set>
#include <memory>
#include <array>
#include <vector>
// just basic validation
// fixme: missing en passant, castling, promotion, etc.
struct State;
class Piece;
class Chessboard {
public:
Chessboard();
~Chessboard();
std::string process(const std::string& command);
std::string process(const std::string& t);
std::string stringifyBoard();
const std::string& grammar() { return m_grammar; }
const std::string& prompt() { return m_prompt; }
void setPrompt(const std::string& prompt);
private:
bool parseCommand(const std::string& command, Piece*& piece, char& pos_to);
bool move(Piece& piece, char pos);
void flagUpdates(char pos_from, char pos_to);
void updatePins(Piece& piece);
void detectChecks();
void setGrammar();
using Move = std::pair<int, int>;
bool move(const Move& move);
std::unique_ptr<State> m_state;
std::set<char> m_allowedInCheck;
bool m_inCheck = false;
struct Piece {
enum Types {
Pawn,
Knight,
Bishop,
Rook,
Queen,
King,
Taken,
};
enum Colors {
Black,
White
};
Types type;
Colors color;
int pos;
};
Piece::Types tokenToType(std::string_view token);
size_t tokenToPos(std::string_view token);
using PieceSet = std::array<Piece, 16>;
PieceSet blackPieces;
PieceSet whitePieces;
int m_moveCounter = 0;
std::string m_grammar;
std::string m_prompt;
using Board = std::array<Piece*, 64>;
Board board;
bool validateMove(const Piece& piece, int pos);
// just basic validation
// fixme: missing en passant, castling, promotion, etc.
bool validatePawnMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
bool validateKnightMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
bool validateBishopMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
bool validateRookMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
bool validateQueenMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
bool validateKingMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
};

View File

@ -4,6 +4,24 @@
#include "common.h"
#include <thread>
static constexpr auto RULES =
"\n"
"root ::= init move move? move? \".\"\n"
"prompt ::= init \".\"\n"
"\n"
"# leading space is very important!\n"
"init ::= \" rook to b4, f3\"\n"
"\n"
"move ::= \", \" ((piece | pawn | king) \" \" \"to \"?)? [a-h] [1-8]\n"
"\n"
"piece ::= \"bishop\" | \"rook\" | \"knight\" | \"queen\"\n"
"king ::= \"king\"\n"
"pawn ::= \"pawn\"\n"
"\n";
static constexpr auto PROMPT = "rook to b4, f3,";
static constexpr auto CONTEXT = "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8,";
WChess::WChess(whisper_context * ctx,
const whisper_full_params & wparams,
callbacks cb,
@ -17,136 +35,175 @@ WChess::WChess(whisper_context * ctx,
WChess::~WChess() = default;
void WChess::set_move(const std::string& moves, float prob) const {
if (m_cb.set_move) (*m_cb.set_move)(moves, prob);
void WChess::set_status(const std::string& msg) const {
if (m_cb.set_status) (*m_cb.set_status)(msg);
}
void WChess::set_grammar(const std::string& grammar) const {
if (m_cb.set_grammar) (*m_cb.set_grammar)(grammar);
void WChess::set_moves(const std::string& moves) const {
if (m_cb.set_moves) (*m_cb.set_moves)(moves);
}
bool WChess::get_audio(std::vector<float>& pcmf32) const {
if (m_cb.get_audio) return (*m_cb.get_audio)(pcmf32);
bool WChess::check_running() const {
if (m_cb.check_running) return (*m_cb.check_running)();
return false;
}
bool WChess::clear_audio() const {
if (m_cb.clear_audio) return (*m_cb.clear_audio)();
return false;
}
void WChess::get_audio(int ms, std::vector<float>& pcmf32) const {
if (m_cb.get_audio) (*m_cb.get_audio)(ms, pcmf32);
}
std::string WChess::stringify_board() const {
return m_board->stringifyBoard();
}
std::string WChess::get_grammar() const {
return m_board->grammar();
}
void WChess::run() {
bool have_prompt = true;
bool ask_prompt = !have_prompt;
set_status("loading data ...");
bool have_prompt = false;
bool ask_prompt = true;
float logprob_min0 = 0.0f;
float logprob_min = 0.0f;
float logprob_sum0 = 0.0f;
float logprob_sum = 0.0f;
int n_tokens0 = 0;
int n_tokens = 0;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
const std::string k_prompt = have_prompt ? "" : "rook to d4, f3";
int64_t t_ms = 0;
const std::string k_prompt = PROMPT;
m_wparams.initial_prompt = CONTEXT;
if (ask_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
fprintf(stdout, "\n");
auto grammar_parsed = grammar_parser::parse(RULES);
auto grammar_rules = grammar_parsed.c_rules();
ask_prompt = false;
if (grammar_parsed.rules.empty()) {
fprintf(stdout, "%s: Failed to parse grammar ...\n", __func__);
}
else {
m_wparams.grammar_rules = grammar_rules.data();
m_wparams.n_grammar_rules = grammar_rules.size();
}
while (get_audio(pcmf32_cur)) {
if (!pcmf32_cur.empty()) {
// fprintf(stdout, "%s: Processing ...\n", __func__);
if (!have_prompt) {
const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
const float sim = similarity(txt, k_prompt);
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
ask_prompt = true;
} else {
fprintf(stdout, "\n");
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
fprintf(stdout, "\n");
// save the audio for the prompt
pcmf32_prompt = pcmf32_cur;
have_prompt = true;
m_board->setPrompt(k_prompt);
}
} else {
if (!pcmf32_prompt.empty()) pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
constexpr size_t MIN_SIZE = 1.2 * WHISPER_SAMPLE_RATE;
if (MIN_SIZE > pcmf32_cur.size()) pcmf32_cur.insert(pcmf32_cur.begin(), MIN_SIZE - pcmf32_cur.size(), 0.0f);
// fprintf(stdout, "%s: grammar rules:\n'%s'\n", __func__, m_board->grammar().c_str());
auto grammar_parsed = grammar_parser::parse(m_board->grammar().c_str());
auto grammar_rules = grammar_parsed.c_rules();
m_wparams.grammar_rules = grammar_rules.data();
m_wparams.n_grammar_rules = grammar_rules.size();
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("move");
auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
const float p = 100.0f * std::exp(logprob_min);
fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
// find the prompt in the text
float best_sim = 0.0f;
size_t best_len = 0;
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
const auto prompt = txt.substr(0, n);
const float sim = similarity(prompt, k_prompt);
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
if (sim > best_sim) {
best_sim = sim;
best_len = n;
}
}
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
std::string command = ::trim(txt.substr(best_len));
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
fprintf(stdout, "\n");
if (!command.empty()) {
set_move(m_board->process(command), p);
set_grammar(m_board->grammar());
}
if (m_board->grammar().empty()) {
fprintf(stdout, "%s: No more moves possible\n", __func__);
break;
}
}
}
while (check_running()) {
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (ask_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
fprintf(stdout, "\n");
{
char txt[1024];
snprintf(txt, sizeof(txt), "Say the following phrase: '%s'", k_prompt.c_str());
set_status(txt);
}
ask_prompt = false;
}
int64_t t_ms = 0;
{
get_audio(m_settings.vad_ms, pcmf32_cur);
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, m_settings.vad_thold, m_settings.freq_thold, m_settings.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
set_status("Speech detected! Processing ...");
if (!have_prompt) {
get_audio(m_settings.prompt_ms, pcmf32_cur);
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("prompt");
const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
const float sim = similarity(txt, k_prompt);
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
ask_prompt = true;
} else {
fprintf(stdout, "\n");
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
fprintf(stdout, "\n");
{
char txt[1024];
snprintf(txt, sizeof(txt), "Success! Waiting for voice commands ...");
set_status(txt);
}
// save the audio for the prompt
pcmf32_prompt = pcmf32_cur;
have_prompt = true;
}
} else {
get_audio(m_settings.command_ms, pcmf32_cur);
// prepend 3 second of silence
pcmf32_cur.insert(pcmf32_cur.begin(), 3*WHISPER_SAMPLE_RATE, 0.0f);
// prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("root");
const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
const float p = 100.0f * std::exp(logprob_min);
fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
// find the prompt in the text
float best_sim = 0.0f;
size_t best_len = 0;
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
if (n >= int(txt.size())) {
break;
}
const auto prompt = txt.substr(0, n);
const float sim = similarity(prompt, k_prompt);
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
if (sim > best_sim) {
best_sim = sim;
best_len = n;
}
}
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
std::string command = ::trim(txt.substr(best_len));
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
fprintf(stdout, "\n");
{
char txt[1024];
snprintf(txt, sizeof(txt), "Command '%s', (t = %d ms)", command.c_str(), (int) t_ms);
set_status(txt);
}
if (!command.empty()) {
set_moves(m_board->process(command));
}
}
clear_audio();
}
}
}
}

View File

@ -8,16 +8,18 @@ class Chessboard;
class WChess {
public:
using SetStatusCb = void (*)(const std::string &);
using CheckRunningCb = bool (*)();
using GetAudioCb = bool (*)(std::vector<float> &);
using SetMovesCb = void (*)(const std::string &, float);
using SetGrammarCb = void (*)(const std::string &);
using ClearAudioCb = void (*)();
using GetAudioCb = void (*)(int, std::vector<float> &);
using SetMovesCb = void (*)(const std::string &);
using CleartAudioCb = bool (*)();
struct callbacks {
SetStatusCb set_status = nullptr;
CheckRunningCb check_running = nullptr;
GetAudioCb get_audio = nullptr;
SetMovesCb set_move = nullptr;
SetGrammarCb set_grammar = nullptr;
SetMovesCb set_moves = nullptr;
CleartAudioCb clear_audio = nullptr;
};
struct settings {
@ -38,16 +40,13 @@ public:
~WChess();
void run();
std::string stringify_board() const;
std::string get_grammar() const;
private:
bool get_audio(std::vector<float>& pcmf32) const;
void set_move(const std::string& moves, float prob) const;
void set_grammar(const std::string& grammar) const;
void get_audio(int ms, std::vector<float>& pcmf32) const;
void set_status(const std::string& msg) const;
void set_moves(const std::string& moves) const;
bool check_running() const;
bool clear_audio() const;
std::string transcribe(
const std::vector<float> & pcmf32,
float & logprob_min,

View File

@ -11,107 +11,78 @@
int main() {
{
Chessboard chess;
ASSERT(chess.process("pawn to d4") == "d2-d4");
ASSERT(chess.process("e5") == "e7-e5");
ASSERT(chess.process("c1 h6") == "c1-h6");
ASSERT(chess.process("queen h4") == "d8-h4");
ASSERT(chess.process("bishop to g5") == "h6-g5");
ASSERT(chess.process("bishop to b4") == "f8-b4");
ASSERT(chess.process("c4") == "");
ASSERT(chess.process("knight c3") == "b1-c3");
ASSERT(chess.process("knight c6") == "b8-c6");
ASSERT(chess.process("f3") == "");
}
{
// pawns
Chessboard chess;
ASSERT(chess.process("d4") == "d2-d4");
ASSERT(chess.process("e5") == "e7-e5");
ASSERT(chess.process("e4") == "e2-e4");
ASSERT(chess.process("queen h4") == "d8-h4");
ASSERT(chess.process("queen h5") == "d1-h5");
ASSERT(chess.process("f5") == "");
ASSERT(chess.process("g6") == "g7-g6");
ASSERT(chess.process("knight e2") == "g1-e2");
ASSERT(chess.process("f5") == "f7-f5");
ASSERT(chess.process("knight g3") == "e2-g3");
ASSERT(chess.process("g5") == "");
ASSERT(chess.process("king e7") == "e8-e7");
ASSERT(chess.process("f4") == "f2-f4");
ASSERT(chess.process("g5") == "g6-g5");
}
{
Chessboard chess;
ASSERT(chess.process("e4") == "e2-e4");
ASSERT(chess.process("c5") == "c7-c5");
ASSERT(chess.process("e5") == "e4-e5");
ASSERT(chess.process("c4") == "c5-c4");
ASSERT(chess.process("e6") == "e5-e6");
ASSERT(chess.process("c3") == "c4-c3");
ASSERT(chess.process("e7") == "");
ASSERT(chess.process("f7") == "e6-f7");
ASSERT(chess.process("d2") == "");
ASSERT(chess.process("king to f7") == "e8-f7");
ASSERT(chess.process("f4") == "f2-f4");
ASSERT(chess.process("d2") == "c3-d2");
ASSERT(chess.process("f5") == "");
ASSERT(chess.process("king to e2") == "e1-e2");
ASSERT(chess.process("king to g6") == "f7-g6");
ASSERT(chess.process("f5") == "f4-f5");
ASSERT(chess.process("e6") == "");
ASSERT(chess.process("king to h5") == "g6-h5");
ASSERT(chess.process("g4") == "g2-g4");
ASSERT(chess.process("king to g5") == "h5-g5");
ASSERT(chess.process("pawn to d4, e5, e3, pawn to d5") == "d2-d4 e7-e5 e2-e3 d7-d5");
ASSERT(chess.process("pawn to d4") == ""); // wrong
ASSERT(chess.process("pawn to c5") == ""); // wrong
ASSERT(chess.process("pawn to d5") == ""); // wrong
ASSERT(chess.process("pawn to d3") == ""); // wrong
ASSERT(chess.process("pawn to f5") == ""); // wrong, white's turn
ASSERT(chess.process("h4") == "h2-h4");
ASSERT(chess.process("king to h5") == "");
ASSERT(chess.process("king to g6") == "");
ASSERT(chess.process("king to h6") == "g5-h6");
ASSERT(chess.process("bishop to d2") == "c1-d2");
ASSERT(chess.process("king to g5") == "");
ASSERT(chess.process("g5") == "g7-g5");
ASSERT(chess.process("d4") == "e5-d4");
ASSERT(chess.process("e4") == "e3-e4");
ASSERT(chess.process("d4") == ""); // wrong
ASSERT(chess.process("e4") == "d5-e4");
}
{
// rook
Chessboard chess;
ASSERT(chess.process("f4") == "f2-f4");
ASSERT(chess.process("e5") == "e7-e5");
ASSERT(chess.process("g4") == "g2-g4");
ASSERT(chess.process("queen to h4") == "d8-h4#");
ASSERT(chess.process("knight f3") == "");
ASSERT(chess.grammar().empty());
ASSERT(chess.process("rook to a3") == ""); // wrong
ASSERT(chess.process("a4, h5, rook to a3, rook to h6") == "a2-a4 h7-h5 a1-a3 h8-h6");
ASSERT(chess.process("rook to d3, rook to e6") == "a3-d3 h6-e6");
ASSERT(chess.process("rook to d4, rook to e5") == "d3-d4 e6-e5");
ASSERT(chess.process("rook to a4") == ""); // wrong
ASSERT(chess.process("rook to d8") == ""); // wrong
ASSERT(chess.process("rook to d3") == "d4-d3");
ASSERT(chess.process("rook to e2") == "e5-e2");
}
{
// knight
Chessboard chess;
ASSERT(chess.process("f4") == "f2-f4");
ASSERT(chess.process("e5") == "e7-e5");
ASSERT(chess.process("g4") == "g2-g4");
ASSERT(chess.process("d5") == "d7-d5");
ASSERT(chess.process("g1 f3") == "g1-f3");
ASSERT(chess.process("queen to h4") == "d8-h4");
ASSERT(!chess.grammar().empty());
ASSERT(chess.process("knight to c3, knight to c6") == "b1-c3 b8-c6");
ASSERT(chess.process("knight to c3") == ""); // wrong
ASSERT(chess.process("knight to a2") == ""); // wrong
ASSERT(chess.process("knight to b4") == ""); // wrong, white's turn
ASSERT(chess.process("knight to b5") == "c3-b5");
ASSERT(chess.process("knight to a5") == "c6-a5");
ASSERT(chess.process("knight to c7") == "b5-c7");
}
{
// bishop
Chessboard chess;
ASSERT(chess.process("knight c3") == "b1-c3");
ASSERT(chess.process("knight c6") == "b8-c6");
ASSERT(chess.process("knight b5") == "c3-b5");
ASSERT(chess.process("knight f6") == "g8-f6");
ASSERT(chess.process("knight d6") == "b5-d6");
ASSERT(chess.process("knight d4") == "");
ASSERT(chess.process("d6") == "c7-d6");
ASSERT(chess.process("e4") == "e2-e4");
ASSERT(chess.process("knight d4") == "c6-d4");
ASSERT(chess.process("d3") == "d2-d3");
ASSERT(chess.process("knight e4") == "f6-e4");
ASSERT(chess.process("king to e2") == "");
ASSERT(chess.process("king to d2") == "");
ASSERT(chess.process("b3, b6, bishop to b2, bishop to b7") == "b2-b3 b7-b6 c1-b2 c8-b7");
ASSERT(chess.process("bishop to a1") == ""); // wrong
ASSERT(chess.process("bishop to h8") == ""); // wrong
ASSERT(chess.process("bishop to a6") == ""); // wrong, white's turn
ASSERT(chess.process("bishop to g7") == "b2-g7");
}
{
// queen
Chessboard chess;
ASSERT(chess.process("queen to d8") == ""); // wrong
ASSERT(chess.process("queen to f1") == ""); // wrong
ASSERT(chess.process("queen to h5") == ""); // wrong
ASSERT(chess.process("e3, d5, queen to h5, queen to d6") == "e2-e3 d7-d5 d1-h5 d8-d6");
ASSERT(chess.process("queen to c5") == ""); // wrong, white's turn
ASSERT(chess.process("queen to f7") == "h5-f7");
}
{
// king
Chessboard chess;
ASSERT(chess.process("d3, d6, king to d2, king to d7, king to c3, king to c6, king to c4") == "d2-d3 d7-d6 e1-d2 e8-d7 d2-c3 d7-c6 c3-c4");
ASSERT(chess.process("bishop to e6") == "c8-e6");
ASSERT(chess.process("king to b3") == "c4-b3"); // !! check check not implemented
}
}

View File

@ -4,5 +4,5 @@ if (WHISPER_SDL2)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE wchess-core common-sdl ${CMAKE_THREAD_LIBS_INIT})
endif ()
target_link_libraries(${TARGET} PRIVATE libwchess common-sdl ${CMAKE_THREAD_LIBS_INIT})
endif ()

View File

@ -7,7 +7,6 @@
#include "WChess.h"
#include "common-sdl.h"
#include <iostream>
#include <memory>
#include <thread>
@ -110,61 +109,17 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
}
std::unique_ptr<WChess> g_wchess;
int g_moveCount = 0;
void set_move(const std::string & move, float) {
if (!move.empty()) {
g_moveCount++;
fprintf(stdout, "Move: %s\n\n", move.c_str());
}
else fprintf(stdout, "Move rejected\n\n");
fprintf(stdout, "%s\n", g_wchess->stringify_board().c_str());
fprintf(stdout, "%s\n", g_moveCount ? "White's turn" : "Black's turn");
void set_moves(const std::string & moves) {
if (!moves.empty()) fprintf(stdout, "%s", g_wchess->stringify_board().c_str());
}
audio_async g_audio(30*1000);
bool g_listening = false;
std::vector<float> g_pcmf32;
bool read_input() {
std::string input;
while (true) {
fprintf(stdout, "[(l)isten/(p)ause/(q)uit]: ");
std::cin >> input;
fprintf(stdout, "\n");
if (input[0] == 'q') {
fprintf(stdout, "Quitting\n");
return false;
}
if (input[0] == 'l') {
if (!g_listening) {
fprintf(stdout, "Listening\n");
g_listening = true;
g_pcmf32.clear();
g_audio.resume();
g_audio.clear();
}
else fprintf(stdout, "Still listening\n");
return true;
}
else {
if (g_listening) {
g_listening = false;
g_audio.get(0, g_pcmf32);
g_audio.pause();
fprintf(stdout, "Processing\n");
}
else fprintf(stdout, "Not listening\n");
return true;
}
}
return true;
void get_audio(int ms, std::vector<float> & pcmf32_cur) {
g_audio.get(ms, pcmf32_cur);
}
bool get_audio(std::vector<float> & pcmf32_cur) {
if (!read_input()) return false;
if (!g_pcmf32.empty()) pcmf32_cur = std::move(g_pcmf32);
else pcmf32_cur.clear();
return true;
bool clear_audio() {
g_audio.clear();
}
int main(int argc, char ** argv) {
@ -186,10 +141,6 @@ int main(int argc, char ** argv) {
cparams.use_gpu = params.use_gpu;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
if (!ctx) {
fprintf(stderr, "%s: whisper_init_from_file_with_params() failed!\n", __func__);
return 1;
}
// init audio
@ -198,35 +149,42 @@ int main(int argc, char ** argv) {
return 1;
}
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
wparams.offset_ms = 0;
wparams.translate = false;
wparams.no_context = true;
wparams.single_segment = true;
wparams.print_realtime = false;
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
wparams.print_progress = false;
wparams.print_timestamps = true;
wparams.print_special = false;
wparams.no_timestamps = true;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.no_timestamps = params.no_timestamps;
wparams.single_segment = true;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.max_tokens = 32;
wparams.audio_ctx = 768; // partial encoder context for better performance
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.temperature = 0.0f;
wparams.temperature_inc = 2.0f;
wparams.greedy.best_of = 1;
wparams.temperature = 0.4f;
wparams.temperature_inc = 1.0f;
wparams.greedy.best_of = 5;
wparams.beam_search.beam_size = 1;
wparams.language = "en";
wparams.grammar_penalty = 100.0;
wparams.beam_search.beam_size = 5;
wparams.initial_prompt = params.context.data();
g_audio.resume();
// wait for 1 second to avoid any buffered noise
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
g_audio.clear();
WChess::callbacks cb;
cb.check_running = sdl_poll_events;
cb.get_audio = get_audio;
cb.set_move = set_move;
cb.set_moves = set_moves;
cb.clear_audio = clear_audio;
WChess::settings s;
s.vad_ms = 2000;
@ -237,9 +195,11 @@ int main(int argc, char ** argv) {
s.print_energy = params.print_energy;
g_wchess.reset(new WChess(ctx, wparams, cb, s));
set_move("start", 0);
set_moves("start");
g_wchess->run();
g_audio.pause();
whisper_print_timings(ctx);
whisper_free(ctx);

View File

@ -8,7 +8,7 @@ include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
common
wchess-core
libwchess
)
unset(EXTRA_FLAGS)

View File

@ -1,11 +1,7 @@
<!doctype html>
<html lang="en-us">
<head>
<title>wchess : voice-controlled chess using Whisper + WebAssembly</title>
<script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
<meta name="viewport" content="width=device-width, initial-scale=0.7, maximum-scale=1, minimum-scale=0.7, user-scalable=no"/>
<meta name="apple-mobile-web-app-capable" content="yes" />
<title>wchess : Voice assistant example using Whisper + WebAssembly</title>
<style>
#output {
@ -27,145 +23,61 @@
overflow-wrap: normal;
overflow-x: scroll;
}
.button {
background-color: #000000;
color: #FFFFFF;
padding: 20px;
border-radius: 10px;
-moz-border-radius: 10px;
-webkit-border-radius: 10px;
margin:10px;
width: 100px;
height: 50px;
-webkit-touch-callout: none; /* Safari */
-webkit-user-select: none; /* Chrome */
-moz-user-select: none; /* Firefox */
-ms-user-select: none; /* Internet Explorer/Edge */
user-select: none;
}
button[disabled]{
background-color: #cccccc;
color: #666666;
padding: 20px;
border-radius: 10px;
-moz-border-radius: 10px;
-webkit-border-radius: 10px;
margin:10px;
width: 100px;
}
.center {
display: flex;
justify-content: center;
align-items: center;
width: 500px;
}
#description {
width: 500px;
}
</style>
<link rel="stylesheet" href="css/chessboard-1.0.0.min.css" integrity="sha384-q94+BZtLrkL1/ohfjR8c6L+A6qzNH9R2hBLwyoAfu3i/WCvQjzL2RQJ3uNHDISdU" crossorigin="anonymous">
</head>
<body>
<body onload="loadWhisper()">
<div id="main-container">
<div id="description">
<b>wchess : voice-controlled chess using Whisper + WebAssembly</b>
<b>wchess : Voice assistant example using Whisper + WebAssembly</b>
<br><br>
<br><br>
This is a demonstration of using Whisper to recognize voice commands in the browser.
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/command.wasm">GitHub</a>.
<br><br>
<br><br>
Usage:<br>
<b>More examples:</b>
<a href="https://whisper.ggerganov.com/">main</a> |
<a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<ul>
<li>Select a Whisper model</li>
<li>Accept the microphone permission request if prompted</li>
<li>Hold the button and say a chess move (e.g. "Knight to c3")</li>
<li>Release the button and wait for the move to be recognized</li>
<li>Repeat</li>
</ul>
Examples:<br>
<ul>
<li><b>"d4"</b></li>
<li><b>"e2 e4"</b></li>
<li><b>"Knight f3"</b></li>
<li><b>"Bishop to b5"</b></li>
</ul>
Features:<br>
<ul>
<li>Model quantization for reduced memory footprint (~42MB)</li>
<li><a href="https://github.com/ggerganov/whisper.cpp/pull/1229">Grammar-based sampling</a> for improved recognition accuracy</li>
</ul>
<b>
Note that not all chess moves are supported. For example, castling and pawn promotion
currently do not work, but can be easily implemented. There could also be some bugs in
the move handling logic in general. The main reason for that is to keep the implementation
simple. The assumption is that a real application would already have a proper move
validation logic in place.<br><br>
The main purpose of this example is to demonstrate the capabilities of whisper.cpp and
its application in the browser for voice recognition locally on your device.
</b>
<br><br>
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/wchess">GitHub</a>.
<br><br>
<b>More examples:</b>
<a href="https://whisper.ggerganov.com/">main</a> |
<a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br>
</div>
<br><br>
<hr>
<div id="model-whisper">
Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper()">tiny.en (Q8_0, 42 MB)</button>
<span id="fetch-whisper-progress"></span>
<br><br>
<button id="clear" onclick="clearCache()">Clear browser cache</button>
<!--
<input type="file" id="file" name="file" onchange="loadFile(event, 'whisper.bin')" />
-->
</div>
<div id="game">
<br>
<div id="chessboard" style="width: 500px"></div>
<script src="js/jquery-3.7.1.min.js"></script>
<script src="js/chessboard-1.0.0.min.js"></script>
<script>
var board = Chessboard('chessboard', 'start')
var move_count = 0;
</script>
<br>
<div id="myBoard" style="width: 400px"></div>
<script src="js/jquery-3.7.1.min.js"></script>
<script src="js/chessboard-1.0.0.min.js"></script>
<script>
var board = Chessboard('myBoard', 'start')
</script>
<br>
<br>
<div id="state">
Status: <b><span id="state-status">select model</span></b>
<div id="input">
<button id="start" onclick="onStart()" disabled>Start</button>
<button id="stop" onclick="onStop()" disabled>Stop</button>
<button id="clear" onclick="clearCache()">Clear Cache</button>
</div>
<div id="input" class="center">
<button id="toggler" class="button" onselectstart="return false" style="display: none">Hold</button>
</div>
<br>
<pre id="state-grammar">[The grammar will be displayed here]</pre>
<div id="state">
Status: <b><span id="state-status">not started</span></b>
<pre id="state-moves">[The moves will be displayed here]</pre>
</div>
<pre id="state-moves">[The moves will be displayed here]</pre>
</div>
<hr>
@ -183,6 +95,7 @@
<ul>
<li>To use a modern web browser (e.g. Chrome, Firefox)</li>
<li>To use a fast desktop or laptop computer (i.e. not a mobile phone)</li>
<li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
</ul>
@ -202,14 +115,15 @@
// web audio context
var context = null;
// audio data
var audio = null;
var audio0 = null;
// the command instance
var instance = null;
// model name
var model_whisper = null;
var model_file = null;
var module_ready = null;
var Module = {
print: printTextarea,
@ -223,30 +137,10 @@
printTextarea('js: Preparing ...');
},
postRun: function() {
printTextarea('js: Module initialized successfully!');
module_ready = true;
initInstance();
printTextarea('js: Initialized successfully!');
}
};
function initInstance() {
if (!module_ready || !model_file || instance) return
instance = Module.init(model_file);
if (instance) {
setStatus('Ready');
printTextarea("js: whisper initialized, instance: " + instance);
}
else {
printTextarea("js: failed to initialize whisper");
}
}
function setStatus(text) {
document.getElementById('state-status').innerHTML = text;
}
//
// fetch models
//
@ -270,21 +164,36 @@
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
model_file = fname;
initInstance();
if (model_whisper != null) {
document.getElementById('start').disabled = false;
document.getElementById('stop' ).disabled = true;
}
}
function loadWhisper() {
setStatus('Loading')
//let url = 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q8_0.bin';
let url = 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-q8_0.bin';
let dst = 'whisper.bin';
let size_mb = 42;
// let urls = {
// 'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
// 'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
model_whisper = 'tiny.en-q8_0';
// 'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
// 'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
// };
// let sizes = {
// 'tiny.en': 75,
// 'base.en': 142,
// 'tiny-en-q5_1': 31,
// 'base-en-q5_1': 57,
// };
let url = 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin';
let dst = 'whisper.bin';
let size_mb = 75;
model_whisper = 'tiny.en';
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model_whisper + '" ... ';
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
cbProgress = function(p) {
let el = document.getElementById('fetch-whisper-progress');
@ -297,30 +206,6 @@
};
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
// init audio capture so that the user receives a permission request
{
let context = new AudioContext({
sampleRate: 16000,
channelCount: 1,
echoCancellation: false,
autoGainControl: true,
noiseSuppression: true,
});
navigator.mediaDevices.getUserMedia({audio: true, video: false})
.then(function(s) {
stream = s;
stream.getTracks().forEach(function(track) {
track.stop();
});
})
.catch(function(err) {
printTextarea('js: error getting audio stream: ' + err);
});
context.close();
}
document.getElementById('toggler').style.display = 'block';
}
//
@ -339,9 +224,11 @@
window.OfflineAudioContext = window.OfflineAudioContext || window.webkitOfflineAudioContext;
function stopRecording() {
if (mediaRecorder) {
mediaRecorder.stop();
}
Module.set_status("paused");
doRecording = false;
audio0 = null;
audio = null;
context = null;
}
function startRecording() {
@ -355,6 +242,12 @@
});
}
Module.set_status("");
document.getElementById('start').disabled = true;
document.getElementById('stop').disabled = false;
doRecording = true;
startTime = Date.now();
var chunks = [];
@ -372,6 +265,10 @@
reader.onload = function(event) {
var buf = new Uint8Array(reader.result);
if (!context) {
return;
}
context.decodeAudioData(buf.buffer, function(audioBuffer) {
var offlineContext = new OfflineAudioContext(audioBuffer.numberOfChannels, audioBuffer.length, audioBuffer.sampleRate);
var source = offlineContext.createBufferSource();
@ -380,13 +277,22 @@
source.start(0);
offlineContext.startRendering().then(function(renderedBuffer) {
let audio = renderedBuffer.getChannelData(0);
printTextarea('js: number of samples: ' + audio.length);
Module.set_audio(instance, audio);
});
audio = renderedBuffer.getChannelData(0);
mediaRecorder = null;
context = null;
//printTextarea('js: audio recorded, size: ' + audio.length + ', old size: ' + (audio0 == null ? 0 : audio0.length));
var audioAll = new Float32Array(audio0 == null ? audio.length : audio0.length + audio.length);
if (audio0 != null) {
audioAll.set(audio0, 0);
}
audioAll.set(audio, audio0 == null ? 0 : audio0.length);
if (instance) {
Module.set_audio(instance, audioAll);
}
});
}, function(e) {
audio = null;
});
}
@ -394,16 +300,48 @@
};
mediaRecorder.onstop = function(e) {
stream.getTracks().forEach(function(track) {
track.stop();
});
if (doRecording) {
setTimeout(function() {
startRecording();
});
}
};
mediaRecorder.start();
mediaRecorder.start(kIntervalAudio_ms);
})
.catch(function(err) {
printTextarea('js: error getting audio stream: ' + err);
});
var interval = setInterval(function() {
if (!doRecording) {
clearInterval(interval);
mediaRecorder.stop();
stream.getTracks().forEach(function(track) {
track.stop();
});
document.getElementById('start').disabled = false;
document.getElementById('stop').disabled = true;
mediaRecorder = null;
}
// if audio length is more than kRestartRecording_s seconds, restart recording
if (audio != null && audio.length > kSampleRate*kRestartRecording_s) {
if (doRecording) {
//printTextarea('js: restarting recording');
clearInterval(interval);
audio0 = audio;
audio = null;
mediaRecorder.stop();
stream.getTracks().forEach(function(track) {
track.stop();
});
}
}
}, 100);
}
//
@ -411,88 +349,56 @@
//
var nLines = 0;
var intervalUpdate = null;
var movesAll = '';
// document.body.addEventListener('keydown', function(event) {
// if (event.keyCode === 32) {
// document.getElementById('toggler').innerText = "";
// onStart();
// }
// }, true);
// document.body.addEventListener('keyup', function(event) {
// if (event.keyCode === 32) {
// document.getElementById('toggler').innerText = "Hold";
// onStop();
// }
// }, true);
document.getElementById('toggler').addEventListener("touchstart", function(event){
this.innerText = "";
onStart();
}, true);
document.getElementById('toggler').addEventListener("touchend", function(event){
this.innerText = "Hold";
onStop();
}, true)
document.getElementById('toggler').addEventListener('mousedown', function(event) {
this.innerText = "";
onStart();
}, true);
document.getElementById('toggler').addEventListener('mouseup', function(event) {
this.innerText = "Hold";
onStop();
}, true);
function onStart() {
if (!instance) return;
setStatus('Listening');
if (!instance) {
instance = Module.init('whisper.bin');
if (instance) {
printTextarea("js: whisper initialized, instance: " + instance);
}
}
if (!instance) {
printTextarea("js: failed to initialize whisper");
return;
}
startRecording();
intervalUpdate = setInterval(function() {
var moves = Module.get_moves();
if (moves != null && moves.length > 1) {
for (move of moves.split(' ')) {
board.move(move);
}
movesAll += moves + '<br>';
nLines++;
// if more than 10 lines, remove the first line
if (nLines > 10) {
var i = movesAll.indexOf('<br>');
if (i > 0) {
movesAll = movesAll.substring(i + 4);
nLines--;
}
}
}
document.getElementById('state-status').innerHTML = Module.get_status();
document.getElementById('state-moves').innerHTML = movesAll;
}, 100);
}
function onStop() {
setStatus('Processing');
printTextarea('js: stopping recording ...');
stopRecording();
}
function setMove(move, prob) {
if (move != null && move.length > 1) {
let gameOver = move[move.length - 1] === '#';
if (gameOver) {
move = move.substring(0, move.length - 1);
document.getElementById('toggler').disabled = true;
}
board.move(move);
movesAll += move + ', prob = ' + prob.toFixed(2) + '% <br>';
nLines++;
// if more than 10 lines, remove the first line
if (nLines > 10) {
var i = movesAll.indexOf('<br>');
if (i > 0) {
movesAll = movesAll.substring(i + 4);
nLines--;
}
}
++move_count;
setStatus(gameOver ? 'Done' : move_count % 2 ? 'Black\'s turn' : 'White\'s turn');
document.getElementById('state-moves').innerHTML = movesAll;
}
else {
setStatus('Failed. ' + (move_count % 2 ? 'Black\'s turn' : 'White\'s turn'));
}
}
function setGrammar(grammar) {
document.getElementById('state-grammar').innerHTML = grammar;
}
</script>
<script type="text/javascript" src="js/chess.js"></script>
</body>

View File

@ -1,7 +1,7 @@
#include <WChess.h>
#include <emscripten.h>
#include <emscripten/bind.h>
#include <atomic>
#include <thread>
constexpr int N_THREAD = 8;
@ -11,28 +11,45 @@ std::vector<struct whisper_context *> g_contexts(4, nullptr);
std::mutex g_mutex;
std::thread g_worker;
std::condition_variable g_cv;
std::atomic<bool> g_running(false);
std::string g_status = "";
std::string g_status_forced = "";
std::string g_moves = "";
bool g_running(false);
std::vector<float> g_pcmf32;
void set_move(const std::string & move, float prob) {
MAIN_THREAD_EM_ASM({
setMove(UTF8ToString($0), $1)
}, move.c_str(), prob);
void set_status(const std::string & status) {
std::lock_guard<std::mutex> lock(g_mutex);
g_status = status;
}
void set_grammar(const std::string & grammar) {
MAIN_THREAD_EM_ASM({
setGrammar(UTF8ToString($0))
}, grammar.c_str());
void set_moves(const std::string & moves) {
std::lock_guard<std::mutex> lock(g_mutex);
g_moves = moves;
}
bool get_audio(std::vector<float> & audio) {
std::unique_lock<std::mutex> lock(g_mutex);
g_cv.wait(lock, [] { return !g_running || !g_pcmf32.empty(); });
if (!g_running) return false;
audio = std::move(g_pcmf32);
void get_audio(int ms, std::vector<float> & audio) {
const int64_t n_samples = (ms * WHISPER_SAMPLE_RATE) / 1000;
int64_t n_take = 0;
if (n_samples > (int) g_pcmf32.size()) {
n_take = g_pcmf32.size();
} else {
n_take = n_samples;
}
audio.resize(n_take);
std::copy(g_pcmf32.end() - n_take, g_pcmf32.end(), audio.begin());
}
bool check_running() {
//g_pcmf32.clear();
return g_running;
}
bool clear_audio() {
g_pcmf32.clear();
return true;
}
@ -48,31 +65,30 @@ void wchess_main(size_t i) {
wparams.print_progress = false;
wparams.print_timestamps = true;
wparams.print_special = false;
wparams.no_timestamps = true;
wparams.max_tokens = 32;
wparams.audio_ctx = 1280; // partial encoder context for better performance
// wparams.audio_ctx = 768; // partial encoder context for better performance
wparams.temperature = 0.0f;
wparams.temperature_inc = 2.0f;
wparams.greedy.best_of = 1;
wparams.temperature = 0.4f;
wparams.temperature_inc = 1.0f;
wparams.greedy.best_of = 1;
wparams.beam_search.beam_size = 1;
wparams.beam_search.beam_size = 5;
wparams.language = "en";
wparams.grammar_penalty = 100.0;
wparams.initial_prompt = "bishop to c3, rook to d4, knight to e5, d4 d5, knight to c3, c3, queen to d4, king b1, pawn to a1, bishop to b2, knight to c3,";
printf("command: using %d threads\n", wparams.n_threads);
WChess::callbacks cb;
cb.set_status = set_status;
cb.check_running = check_running;
cb.get_audio = get_audio;
cb.set_move = set_move;
cb.set_grammar = set_grammar;
cb.set_moves = set_moves;
cb.clear_audio = clear_audio;
WChess(g_contexts[i], wparams, cb, {}).run();
if (i < g_contexts.size()) {
whisper_free(g_contexts[i]);
g_contexts[i] = nullptr;
@ -104,11 +120,9 @@ EMSCRIPTEN_BINDINGS(command) {
}));
emscripten::function("free", emscripten::optional_override([](size_t /* index */) {
{
std::unique_lock<std::mutex> lock(g_mutex);
if (g_running) {
g_running = false;
}
g_cv.notify_one();
}));
emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) {
@ -134,8 +148,37 @@ EMSCRIPTEN_BINDINGS(command) {
emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(g_pcmf32.data()), n);
memoryView.call<void>("set", audio);
}
g_cv.notify_one();
return 0;
}));
emscripten::function("get_moves", emscripten::optional_override([]() {
std::string moves;
{
std::lock_guard<std::mutex> lock(g_mutex);
moves = std::move(g_moves);
}
if (!moves.empty()) fprintf(stdout, "%s: Moves '%s%s%s'\n", __func__, "\033[1m", moves.c_str(), "\033[0m");
return moves;
}));
emscripten::function("get_status", emscripten::optional_override([]() {
std::string status;
{
std::lock_guard<std::mutex> lock(g_mutex);
status = g_status_forced.empty() ? g_status : g_status_forced;
}
return status;
}));
emscripten::function("set_status", emscripten::optional_override([](const std::string & status) {
std::lock_guard<std::mutex> lock(g_mutex);
g_status_forced = status;
}));
}

View File

@ -206,7 +206,6 @@ void AudioInputCallback(void * inUserData,
params.offset_ms = 0;
params.no_context = true;
params.single_segment = self->stateInp.isRealtime;
params.no_timestamps = params.single_segment;
CFTimeInterval startTime = CACurrentMediaTime();

View File

@ -8,15 +8,15 @@ enum WhisperError: Error {
// Meet Whisper C++ constraint: Don't access from more than one thread at a time.
actor WhisperContext {
private var context: OpaquePointer
init(context: OpaquePointer) {
self.context = context
}
deinit {
whisper_free(context)
}
func fullTranscribe(samples: [Float]) {
// Leave 2 processors free (i.e. the high-efficiency cores).
let maxThreads = max(1, min(8, cpuCount() - 2))
@ -24,17 +24,17 @@ actor WhisperContext {
var params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY)
"en".withCString { en in
// Adapted from whisper.objc
params.print_realtime = true
params.print_progress = false
params.print_realtime = true
params.print_progress = false
params.print_timestamps = true
params.print_special = false
params.translate = false
params.language = en
params.n_threads = Int32(maxThreads)
params.offset_ms = 0
params.no_context = true
params.single_segment = false
params.print_special = false
params.translate = false
params.language = en
params.n_threads = Int32(maxThreads)
params.offset_ms = 0
params.no_context = true
params.single_segment = false
whisper_reset_timings(context)
print("About to run whisper_full")
samples.withUnsafeBufferPointer { samples in
@ -46,7 +46,7 @@ actor WhisperContext {
}
}
}
func getTranscription() -> String {
var transcription = ""
for i in 0..<whisper_full_n_segments(context) {
@ -54,7 +54,7 @@ actor WhisperContext {
}
return transcription
}
static func createContext(path: String) throws -> WhisperContext {
var params = whisper_context_default_params()
#if targetEnvironment(simulator)

View File

@ -137,7 +137,7 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
#ifdef GGML_ALLOCATOR_DEBUG
add_allocated_tensor(alloc, tensor);
size_t cur_max = (char*)addr - (char*)alloc->base + size;
size_t cur_max = (char*)addr - (char*)alloc->data + size;
if (cur_max > alloc->max_size) {
printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
for (int i = 0; i < 1024; i++) {
@ -168,6 +168,10 @@ static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor *
size = aligned_offset(NULL, size, alloc->alignment);
AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
if (!alloc->measure) {
ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
}
#ifdef GGML_ALLOCATOR_DEBUG
remove_allocated_tensor(alloc, tensor);
#endif
@ -233,7 +237,7 @@ void ggml_tallocr_reset(ggml_tallocr_t alloc) {
}
ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) {
struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(data, size);
struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(NULL, data, size);
ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
@ -445,6 +449,7 @@ static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * n
static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) {
ggml_tallocr_t alloc = node_tallocr(galloc, view);
//printf("init_view: %s from src %s\n", view->name, view->view_src->name);
GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
if (update_backend) {
view->backend = view->view_src->backend;
@ -454,7 +459,7 @@ static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool upd
// FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
// due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft);
assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
if (!alloc->measure) {
ggml_backend_buffer_init_tensor(alloc->buffer, view);
@ -760,43 +765,3 @@ size_t ggml_allocr_max_size(ggml_allocr_t alloc) {
size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) {
return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
}
// utils
ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
GGML_ASSERT(ggml_get_no_alloc(ctx) == true);
size_t alignment = ggml_backend_buft_get_alignment(buft);
size_t nbytes = 0;
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->data == NULL && t->view_src == NULL) {
nbytes += GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
}
}
if (nbytes == 0) {
fprintf(stderr, "%s: no tensors to allocate\n", __func__);
return NULL;
}
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, nbytes);
ggml_tallocr_t tallocr = ggml_tallocr_new_from_buffer(buffer);
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->data == NULL) {
if (t->view_src == NULL) {
ggml_tallocr_alloc(tallocr, t);
} else {
ggml_backend_view_init(buffer, t);
}
}
}
ggml_tallocr_free(tallocr);
return buffer;
}
ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {
return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));
}

View File

@ -8,7 +8,6 @@ extern "C" {
struct ggml_backend;
struct ggml_backend_buffer;
struct ggml_backend_buffer_type;
//
// Legacy API
@ -43,7 +42,7 @@ GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph
// ggml-backend v2 API
//
// Separate tensor and graph allocator objects
// Seperate tensor and graph allocator objects
// This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
// The original API is kept as a wrapper around the new API
@ -81,12 +80,6 @@ GGML_API void ggml_gallocr_alloc_graph_n(
struct ggml_hash_set hash_set,
ggml_tallocr_t * hash_node_talloc);
// Utils
// Create a buffer and allocate all the tensors in a ggml_context
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, struct ggml_backend_buffer_type * buft);
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, struct ggml_backend * backend);
#ifdef __cplusplus
}
#endif

View File

@ -12,50 +12,31 @@ extern "C" {
// Backend buffer
//
// buffer type
typedef void * ggml_backend_buffer_type_context_t;
struct ggml_backend_buffer_type_i {
ggml_backend_buffer_t (*alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size);
size_t (*get_alignment) (ggml_backend_buffer_type_t buft); // tensor alignment
size_t (*get_alloc_size) (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
bool (*supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend
};
struct ggml_backend_buffer_type {
struct ggml_backend_buffer_type_i iface;
ggml_backend_buffer_type_context_t context;
};
// buffer
typedef void * ggml_backend_buffer_context_t;
struct ggml_backend_buffer_i {
void (*free_buffer)(ggml_backend_buffer_t buffer);
//void (*reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
void * (*get_base) (ggml_backend_buffer_t buffer);
void (*init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
// (optional) copy tensor between different buffer-type, allow for single-copy tranfers
void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to) (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*free_buffer) (ggml_backend_buffer_t buffer);
void * (*get_base) (ggml_backend_buffer_t buffer); // get base pointer
size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
void (*free_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
};
struct ggml_backend_buffer {
struct ggml_backend_buffer_i iface;
ggml_backend_buffer_type_t buft;
struct ggml_backend_buffer_i iface;
ggml_backend_t backend;
ggml_backend_buffer_context_t context;
size_t size;
};
ggml_backend_buffer_t ggml_backend_buffer_init(
ggml_backend_buffer_type_t buft,
GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
struct ggml_backend * backend,
struct ggml_backend_buffer_i iface,
ggml_backend_buffer_context_t context,
size_t size);
//
// Backend
//
@ -68,18 +49,21 @@ extern "C" {
void (*free)(ggml_backend_t backend);
// buffer allocation
ggml_backend_buffer_type_t (*get_default_buffer_type)(ggml_backend_t backend);
ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
// (optional) asynchroneous tensor data access
// get buffer alignment
size_t (*get_alignment)(ggml_backend_t backend);
// tensor data access
// these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize
void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
// (optional) asynchroneous tensor copy
void (*cpy_tensor_from_async)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to_async) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*synchronize) (ggml_backend_t backend);
// (optional) copy tensor between different backends, allow for single-copy tranfers
void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
// compute graph with a plan
ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
@ -98,15 +82,6 @@ extern "C" {
ggml_backend_context_t context;
};
//
// Backend registry
//
typedef ggml_backend_t (*ggml_backend_init_fn)(const char * params, void * user_data);
void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data);
#ifdef __cplusplus
}
#endif

File diff suppressed because it is too large Load Diff

View File

@ -7,44 +7,41 @@
extern "C" {
#endif
typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
typedef struct ggml_backend * ggml_backend_t;
typedef void * ggml_backend_graph_plan_t;
//
// Backend buffer
//
// buffer type
GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend);
struct ggml_backend_buffer;
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
// buffer
// backend buffer functions
GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer);
GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API void ggml_backend_buffer_free_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
//
// Backend
//
struct ggml_backend;
typedef struct ggml_backend * ggml_backend_t;
typedef void * ggml_backend_graph_plan_t;
GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor);
GGML_API const char * ggml_backend_name(ggml_backend_t backend);
GGML_API void ggml_backend_free(ggml_backend_t backend);
GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
GGML_API void ggml_backend_tensor_set_async( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
@ -60,7 +57,6 @@ extern "C" {
// tensor copy between different backends
GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); // automatic fallback to sync copy
//
// CPU backend
@ -72,23 +68,8 @@ extern "C" {
GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
// Create a backend buffer from an existing pointer
GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size);
GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
//
// Backend registry
//
// The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
GGML_API size_t ggml_backend_reg_get_count(void);
GGML_API size_t ggml_backend_reg_find_by_name(const char * name);
GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is name[:params]
GGML_API const char * ggml_backend_reg_get_name(size_t i);
GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i);
GGML_API ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size);
//
// Backend scheduler
@ -150,32 +131,6 @@ extern "C" {
ggml_backend_sched_t sched,
struct ggml_cgraph * graph);
//
// Utils
//
struct ggml_backend_graph_copy {
ggml_backend_buffer_t buffer;
struct ggml_context * ctx_allocated;
struct ggml_context * ctx_unallocated;
struct ggml_cgraph * graph;
};
// Copy a graph to a different backend
GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
// Compare the output of two backends
GGML_API void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
// Tensor initialization
GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
GGML_API void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
#ifdef __cplusplus
}
#endif

File diff suppressed because it is too large Load Diff

View File

@ -49,15 +49,7 @@ GGML_API int ggml_cuda_get_device_count(void);
GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
// backend API
GGML_API ggml_backend_t ggml_backend_cuda_init(int device);
GGML_API bool ggml_backend_is_cuda(ggml_backend_t backend);
GGML_API int ggml_backend_cuda_get_device(ggml_backend_t backend);
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
// pinned host buffer for use with CPU backend for faster copies between CPU and GPU
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use
#ifdef __cplusplus
}

View File

@ -232,7 +232,7 @@ bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml
// returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
// returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
// returns GGML_HAHSHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key);
// return index, asserts if table is full

View File

@ -99,12 +99,6 @@ GGML_API ggml_backend_t ggml_backend_metal_init(void);
GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
GGML_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
// helper to check if the device supports a specific family
// ideally, the user code should be doing these checks
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
#ifdef __cplusplus
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,18 +1,20 @@
#include "ggml.h"
#include "ggml-opencl.h"
#include <array>
#include <atomic>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <sstream>
#include <vector>
#include <limits>
#define CL_TARGET_OPENCL_VERSION 110
#include <clblast.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include "ggml.h"
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

View File

@ -19,7 +19,7 @@
#ifdef __wasm_simd128__
#include <wasm_simd128.h>
#else
#if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
#ifdef __POWER9_VECTOR__
#include <altivec.h>
#undef bool
#define bool _Bool
@ -3114,7 +3114,7 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
size_t vl = __riscv_vsetvl_e8m1(qk/2);
// These temporary registers are for masking and shift operations
// These tempory registers are for masking and shift operations
vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
@ -4757,7 +4757,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
vl = 16;
// retrieve lane to multiply with scale
// retreive lane to multiply with scale
vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);

939
ggml.c

File diff suppressed because it is too large Load Diff

91
ggml.h
View File

@ -215,9 +215,9 @@
#define GGML_QNT_VERSION_FACTOR 1000 // do not change this
#define GGML_MAX_DIMS 4
#define GGML_MAX_PARAMS 2048
#define GGML_MAX_PARAMS 1024
#define GGML_MAX_CONTEXTS 64
#define GGML_MAX_SRC 10
#define GGML_MAX_SRC 6
#define GGML_MAX_NAME 64
#define GGML_MAX_OP_PARAMS 64
#define GGML_DEFAULT_N_THREADS 4
@ -244,10 +244,11 @@
#define GGML_ASSERT(x) \
do { \
if (!(x)) { \
fflush(stdout); \
fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
fflush(stderr); \
fflush(stdout); \
ggml_print_backtrace(); \
abort(); \
exit(1); \
} \
} while (0)
@ -283,20 +284,6 @@
const type prefix##3 = (pointer)->array[3]; \
GGML_UNUSED(prefix##3);
#define GGML_TENSOR_UNARY_OP_LOCALS \
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
#define GGML_TENSOR_BINARY_OP_LOCALS \
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
#ifdef __cplusplus
extern "C" {
#endif
@ -395,7 +382,6 @@ extern "C" {
GGML_OP_GROUP_NORM,
GGML_OP_MUL_MAT,
GGML_OP_MUL_MAT_ID,
GGML_OP_OUT_PROD,
GGML_OP_SCALE,
@ -422,10 +408,8 @@ extern "C" {
GGML_OP_CONV_TRANSPOSE_2D,
GGML_OP_POOL_1D,
GGML_OP_POOL_2D,
GGML_OP_UPSCALE, // nearest interpolate
GGML_OP_PAD,
GGML_OP_ARGSORT,
GGML_OP_LEAKY_RELU,
GGML_OP_FLASH_ATTN,
GGML_OP_FLASH_FF,
@ -465,8 +449,7 @@ extern "C" {
GGML_UNARY_OP_GELU,
GGML_UNARY_OP_GELU_QUICK,
GGML_UNARY_OP_SILU,
GGML_UNARY_OP_COUNT,
GGML_UNARY_OP_LEAKY
};
enum ggml_object_type {
@ -649,9 +632,6 @@ extern "C" {
GGML_API const char * ggml_op_name (enum ggml_op op);
GGML_API const char * ggml_op_symbol(enum ggml_op op);
GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
GGML_API bool ggml_is_quantized(enum ggml_type type);
@ -794,9 +774,6 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
// dst = a
// view(dst, nb1, nb2, nb3, offset) += b
// return dst
GGML_API struct ggml_tensor * ggml_acc(
struct ggml_context * ctx,
struct ggml_tensor * a,
@ -961,14 +938,15 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_leaky_relu(
GGML_API struct ggml_tensor * ggml_leaky(
struct ggml_context * ctx,
struct ggml_tensor * a, float negative_slope, bool inplace);
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_relu_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
// TODO: double-check this computation is correct
GGML_API struct ggml_tensor * ggml_gelu(
struct ggml_context * ctx,
struct ggml_tensor * a);
@ -1050,16 +1028,6 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
// indirect matrix multiplication
// ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
GGML_API struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx,
struct ggml_tensor * const as[],
int n_as,
struct ggml_tensor * ids,
int id,
struct ggml_tensor * b);
// A: m columns, n rows,
// B: p columns, n rows,
// result is m columns, p rows
@ -1267,7 +1235,6 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
// supports 3D: a->ne[2] == b->ne[1]
GGML_API struct ggml_tensor * ggml_get_rows(
struct ggml_context * ctx,
struct ggml_tensor * a,
@ -1316,14 +1283,6 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
// fused soft_max(a*scale + mask)
// mask is optional
GGML_API struct ggml_tensor * ggml_soft_max_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
float scale);
GGML_API struct ggml_tensor * ggml_soft_max_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
@ -1554,32 +1513,6 @@ extern "C" {
struct ggml_tensor * a,
int scale_factor);
// pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
GGML_API struct ggml_tensor * ggml_pad(
struct ggml_context * ctx,
struct ggml_tensor * a,
int p0,
int p1,
int p2,
int p3);
// sort rows
enum ggml_sort_order {
GGML_SORT_ASC,
GGML_SORT_DESC,
};
GGML_API struct ggml_tensor * ggml_argsort(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_sort_order order);
// top k elements per row
GGML_API struct ggml_tensor * ggml_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k);
GGML_API struct ggml_tensor * ggml_flash_attn(
struct ggml_context * ctx,
struct ggml_tensor * q,
@ -1641,6 +1574,7 @@ extern "C" {
int kh);
// used in sam
GGML_API struct ggml_tensor * ggml_add_rel_pos(
struct ggml_context * ctx,
struct ggml_tensor * a,
@ -1815,7 +1749,7 @@ extern "C" {
GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads);
GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
GGML_API struct ggml_cgraph ggml_graph_view (struct ggml_cgraph * cgraph, int i0, int i1);
GGML_API struct ggml_cgraph * ggml_graph_view (struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i0, int i1);
GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads
GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
@ -2111,7 +2045,6 @@ extern "C" {
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);

View File

@ -1063,7 +1063,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
#ifdef GGML_USE_CUBLAS
if (params.use_gpu && ggml_cublas_loaded()) {
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
backend_gpu = ggml_backend_cuda_init(0);
backend_gpu = ggml_backend_cuda_init();
if (!backend_gpu) {
WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
}
@ -1077,10 +1077,6 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
backend_gpu = ggml_backend_metal_init();
if (!backend_gpu) {
WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
} else if (!ggml_backend_metal_supports_family(backend_gpu, 7)) {
WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
ggml_backend_free(backend_gpu);
backend_gpu = NULL;
}
}
#endif
@ -1345,10 +1341,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*n_audio_ctx, n_audio_state);
model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_ctx, n_audio_state);
model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
@ -1578,25 +1574,29 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
auto tensor = model.tensors[name.data()];
if (ggml_nelements(tensor) != nelements) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
return false;
}
const bool is_conv_bias = (name == "encoder.conv1.bias" || name == "encoder.conv2.bias");
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
return false;
}
if (!is_conv_bias) {
if (ggml_nelements(tensor) != nelements) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
return false;
}
const size_t bpe = ggml_type_size(ggml_type(ttype));
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
return false;
}
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
const size_t bpe = ggml_type_size(ggml_type(ttype));
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
}
}
ggml_backend_t backend = wctx.backend;
@ -1607,7 +1607,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
#ifdef GGML_USE_METAL
|| ggml_backend_is_metal(backend)
#endif
)) {
) && !is_conv_bias) {
// for the CPU and Metal backend, we can read directly into the tensor
loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
BYTESWAP_TENSOR(tensor);
@ -1615,7 +1615,24 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
// read into a temporary buffer first, then copy to device memory
read_buf.resize(ggml_nbytes(tensor));
loader->read(loader->context, read_buf.data(), read_buf.size());
// we repeat the 2 bias tensors along dim 0:
// [1, 512] -> [3000, 512] (conv1.bias)
// [1, 512] -> [1500, 512] (conv2.bias)
if (is_conv_bias) {
loader->read(loader->context, read_buf.data(), read_buf.size() / tensor->ne[0]);
float * data_f32 = (float *) read_buf.data();
for (int64_t y = 0; y < tensor->ne[1]; ++y) {
const int64_t yy = tensor->ne[1] - y - 1;
const float val = data_f32[yy];
for (int64_t x = 0; x < tensor->ne[0]; ++x) {
data_f32[yy*tensor->ne[0] + x] = val;
}
}
} else {
loader->read(loader->context, read_buf.data(), read_buf.size());
}
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
}
@ -1715,12 +1732,20 @@ static struct ggml_cgraph * whisper_build_graph_conv(
// convolution + gelu
{
cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
cur = ggml_add(ctx0, cur, model.e_conv_1_b);
if (n_ctx == hparams.n_audio_ctx) {
cur = ggml_add(ctx0, cur, model.e_conv_1_b);
} else {
cur = ggml_add(ctx0, cur, ggml_cont(ctx0, ggml_view_2d(ctx0, model.e_conv_1_b, cur->ne[0], cur->ne[1], model.e_conv_1_b->nb[1], 0)));
}
cur = ggml_gelu(ctx0, cur);
cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
cur = ggml_add(ctx0, cur, model.e_conv_2_b);
if (n_ctx == hparams.n_audio_ctx) {
cur = ggml_add(ctx0, cur, model.e_conv_2_b);
} else {
cur = ggml_add(ctx0, cur, ggml_cont(ctx0, ggml_view_2d(ctx0, model.e_conv_2_b, cur->ne[0], cur->ne[1], model.e_conv_2_b->nb[1], 0)));
}
cur = ggml_gelu(ctx0, cur);
}
@ -3568,17 +3593,6 @@ const char * whisper_lang_str(int id) {
return nullptr;
}
const char * whisper_lang_str_full(int id) {
for (const auto & kv : g_lang) {
if (kv.second.first == id) {
return kv.second.second.c_str();
}
}
WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
return nullptr;
}
int whisper_lang_auto_detect_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
@ -5028,7 +5042,6 @@ int whisper_full_with_state(
// basically don't process anything that is less than 1.0s
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
WHISPER_PRINT_DEBUG("%s: input is too short - %d ms < 1000 ms\n", __func__, (seek_end - seek_start)*10);
return 0;
}
@ -5167,7 +5180,7 @@ int whisper_full_with_state(
ctx, state, progress_cur, params.progress_callback_user_data);
}
// if only 1 second left, then stop
// of only 1 second left, then stop
if (seek + 100 >= seek_end) {
break;
}
@ -5456,7 +5469,6 @@ int whisper_full_with_state(
// do not allow to go back in time
if (has_ts && seek_delta > seek_delta_new && result_len < i) {
WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to seek_delta (%d > %d)\n", __func__, j, seek_delta, seek_delta_new);
failed = true; // TODO: maybe this is not a failure ?
continue;
}
@ -5485,7 +5497,6 @@ int whisper_full_with_state(
if (seek + seek_delta + 100 >= seek_end) {
result_len = i + 1;
} else {
WHISPER_PRINT_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
failed = true;
continue;
}
@ -5496,7 +5507,6 @@ int whisper_full_with_state(
seek_delta = 100*WHISPER_CHUNK_SIZE;
}
WHISPER_PRINT_DEBUG("%s: decoder %d completed\n", __func__, j);
completed = true;
continue;
}
@ -5512,7 +5522,6 @@ int whisper_full_with_state(
// sometimes, the decoding can get stuck in a repetition loop
// this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to repetition loop\n", __func__, j);
failed = true;
continue;
}
@ -5656,27 +5665,28 @@ int whisper_full_with_state(
WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
}
bool success = true;
// was the decoding successful for the current temperature?
// do fallback only if:
// - we are not at the last temperature
if (it != (int) temperatures.size() - 1) {
// - we are not at the end of the audio (3 sec)
if (it != (int) temperatures.size() - 1 &&
seek_end - seek > 10*WHISPER_CHUNK_SIZE) {
bool success = true;
const auto & decoder = state->decoders[best_decoder_id];
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
WHISPER_PRINT_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
success = false;
state->n_fail_p++;
}
}
if (success) {
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
// WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
//}
if (success) {
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
// WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
//}
break;
break;
}
}
WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
@ -6054,43 +6064,6 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
// 1GB array
const size_t size = arr*1e6;
double sum = 0.0;
// heat-up
{
char * src = (char *) malloc(size);
char * dst = (char *) malloc(size);
for (size_t i = 0; i < size; i++) src[i] = i;
memcpy(dst, src, size); // heat-up
double tsum = 0.0;
for (size_t i = 0; i < n; i++) {
const int64_t t0 = ggml_time_us();
memcpy(dst, src, size);
const int64_t t1 = ggml_time_us();
tsum += (t1 - t0)*1e-6;
src[rand() % size] = rand() % 256;
}
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (heat-up)\n", (double) (n*size)/(tsum*1e9));
s += strbuf;
// needed to prevent the compiler from optimizing the memcpy away
{
for (size_t i = 0; i < size; i++) sum += dst[i];
}
free(src);
free(dst);
}
// single-thread
{
char * src = (char *) malloc(size);
@ -6101,6 +6074,7 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
memcpy(dst, src, size); // heat-up
double tsum = 0.0;
double sum = 0.0;
for (size_t i = 0; i < n; i++) {
const int64_t t0 = ggml_time_us();
@ -6114,73 +6088,21 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
src[rand() % size] = rand() % 256;
}
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s ( 1 thread)\n", (double) (n*size)/(tsum*1e9));
snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s (1 thread)\n", (double) (n*size)/(tsum*1e9));
s += strbuf;
// needed to prevent the compiler from optimizing the memcpy away
{
for (size_t i = 0; i < size; i++) sum += dst[i];
snprintf(strbuf, sizeof(strbuf), "sum: %f\n", sum);
s += strbuf;
}
free(src);
free(dst);
}
// multi-thread
for (uint32_t k = 1; k <= n_threads; k++) {
char * src = (char *) malloc(size);
char * dst = (char *) malloc(size);
for (size_t i = 0; i < size; i++) src[i] = i;
memcpy(dst, src, size); // heat-up
double tsum = 0.0;
auto helper = [&](int th) {
const int64_t i0 = (th + 0)*size/k;
const int64_t i1 = (th + 1)*size/k;
for (size_t i = 0; i < n; i++) {
memcpy(dst + i0, src + i0, i1 - i0);
src[i0 + rand() % (i1 - i0)] = rand() % 256;
};
};
const int64_t t0 = ggml_time_us();
std::vector<std::thread> threads(k - 1);
for (uint32_t th = 0; th < k - 1; ++th) {
threads[th] = std::thread(helper, th);
}
helper(k - 1);
for (uint32_t th = 0; th < k - 1; ++th) {
threads[th].join();
}
const int64_t t1 = ggml_time_us();
tsum += (t1 - t0)*1e-6;
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double) (n*size)/(tsum*1e9), k);
s += strbuf;
// needed to prevent the compiler from optimizing the memcpy away
{
for (size_t i = 0; i < size; i++) sum += dst[i];
}
free(src);
free(dst);
}
snprintf(strbuf, sizeof(strbuf), "sum: %f\n", sum);
s += strbuf;
return s.c_str();
}

View File

@ -51,7 +51,7 @@ extern "C" {
// ...
//
// whisper_context_params cparams = whisper_context_default_params();
//
//
// struct whisper_context * ctx = whisper_init_from_file_with_params("/path/to/ggml-base.en.bin", cparams);
//
// if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
@ -315,9 +315,6 @@ extern "C" {
// Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
WHISPER_API const char * whisper_lang_str(int id);
// Return the short string of the specified language name (e.g. 2 -> "german"), returns nullptr if not found
WHISPER_API const char * whisper_lang_str_full(int id);
// Use mel data at offset_ms to try and auto-detect the spoken language
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
// Returns the top language id or negative on failure