mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-07-01 23:10:47 +02:00
Compare commits
11 Commits
Author | SHA1 | Date | |
---|---|---|---|
4260d4fc70 | |||
ee65df7982 | |||
03f254193b | |||
8f2d8eae10 | |||
a44b21bce0 | |||
f07ff2aa6a | |||
280e631bcf | |||
2f86da0d09 | |||
a787f7f85c | |||
c83a38e89d | |||
758c951729 |
22
.github/workflows/build.yml
vendored
22
.github/workflows/build.yml
vendored
@ -25,7 +25,6 @@ jobs:
|
||||
docker run --platform ${{ matrix.arch }} --rm \
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y build-essential libsdl2-dev
|
||||
make
|
||||
@ -87,7 +86,6 @@ jobs:
|
||||
docker run --platform ${{ matrix.arch }} --rm \
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y build-essential cmake libsdl2-dev
|
||||
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
@ -115,10 +113,8 @@ jobs:
|
||||
docker run --platform ${{ matrix.arch }} --rm \
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y clang
|
||||
apt install -y clang build-essential cmake libsdl2-dev
|
||||
apt install -y build-essential cmake libsdl2-dev
|
||||
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
|
||||
make
|
||||
ctest -L gh --output-on-failure'
|
||||
@ -144,7 +140,6 @@ jobs:
|
||||
docker run --platform ${{ matrix.arch }} --rm \
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y build-essential cmake
|
||||
cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
|
||||
@ -222,10 +217,10 @@ jobs:
|
||||
sdl2: [ON]
|
||||
include:
|
||||
- arch: Win32
|
||||
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.25/OpenBLAS-0.3.25-x86.zip
|
||||
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x86.zip
|
||||
s2arc: x86
|
||||
- arch: x64
|
||||
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.25/OpenBLAS-0.3.25-x64.zip
|
||||
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x64.zip
|
||||
s2arc: x64
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
@ -290,7 +285,6 @@ jobs:
|
||||
arch: [x64]
|
||||
cublas: [ON]
|
||||
sdl2: [ON]
|
||||
cuda-toolkit: [12.2.0, 11.8.0]
|
||||
include:
|
||||
- arch: x64
|
||||
s2arc: x64
|
||||
@ -306,9 +300,7 @@ jobs:
|
||||
|
||||
- name: Install CUDA Toolkit
|
||||
id: cuda-toolkit
|
||||
uses: Jimver/cuda-toolkit@v0.2.11
|
||||
with:
|
||||
cuda: '${{ matrix.cuda-toolkit }}'
|
||||
uses: Jimver/cuda-toolkit@v0.2.10
|
||||
|
||||
- name: Fetch SDL2 and set SDL2_DIR
|
||||
if: matrix.sdl2 == 'ON'
|
||||
@ -323,10 +315,10 @@ jobs:
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
-DWHISPER_CUBLAS=1
|
||||
|
||||
- name: Build ${{ matrix.cuda-toolkit }}
|
||||
- name: Build
|
||||
run: |
|
||||
cd ./build
|
||||
cmake --build . --config ${{ matrix.build }}
|
||||
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
|
||||
|
||||
- name: Copy CUDA DLLs
|
||||
run: >
|
||||
@ -343,7 +335,7 @@ jobs:
|
||||
if: matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
with:
|
||||
name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}
|
||||
name: whisper-cublas-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
emscripten:
|
||||
|
@ -1,6 +1,6 @@
|
||||
cmake_minimum_required (VERSION 3.5)
|
||||
|
||||
project(whisper.cpp VERSION 1.5.2)
|
||||
project(whisper.cpp VERSION 1.5.0)
|
||||
|
||||
# Add path to modules
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
@ -533,7 +533,7 @@ target_compile_definitions(${TARGET} PUBLIC
|
||||
${WHISPER_EXTRA_FLAGS}
|
||||
)
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "ggml.h;whisper.h")
|
||||
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "whisper.h")
|
||||
|
||||
include(GNUInstallDirs)
|
||||
|
||||
|
@ -2,14 +2,33 @@
|
||||
|
||||
import PackageDescription
|
||||
|
||||
#if arch(arm) || arch(arm64)
|
||||
let platforms: [SupportedPlatform]? = [
|
||||
.macOS(.v12),
|
||||
.iOS(.v14),
|
||||
.watchOS(.v4),
|
||||
.tvOS(.v14)
|
||||
]
|
||||
let exclude: [String] = []
|
||||
let resources: [Resource] = [
|
||||
.process("ggml-metal.metal")
|
||||
]
|
||||
let additionalSources: [String] = ["ggml-metal.m"]
|
||||
let additionalSettings: [CSetting] = [
|
||||
.unsafeFlags(["-fno-objc-arc"]),
|
||||
.define("GGML_USE_METAL")
|
||||
]
|
||||
#else
|
||||
let platforms: [SupportedPlatform]? = nil
|
||||
let exclude: [String] = ["ggml-metal.metal"]
|
||||
let resources: [Resource] = []
|
||||
let additionalSources: [String] = []
|
||||
let additionalSettings: [CSetting] = []
|
||||
#endif
|
||||
|
||||
let package = Package(
|
||||
name: "whisper",
|
||||
platforms: [
|
||||
.macOS(.v12),
|
||||
.iOS(.v14),
|
||||
.watchOS(.v4),
|
||||
.tvOS(.v14)
|
||||
],
|
||||
platforms: platforms,
|
||||
products: [
|
||||
.library(name: "whisper", targets: ["whisper"]),
|
||||
],
|
||||
@ -17,7 +36,7 @@ let package = Package(
|
||||
.target(
|
||||
name: "whisper",
|
||||
path: ".",
|
||||
exclude: [
|
||||
exclude: exclude + [
|
||||
"bindings",
|
||||
"cmake",
|
||||
"coreml",
|
||||
@ -36,22 +55,19 @@ let package = Package(
|
||||
"whisper.cpp",
|
||||
"ggml-alloc.c",
|
||||
"ggml-backend.c",
|
||||
"ggml-quants.c",
|
||||
"ggml-metal.m"
|
||||
],
|
||||
resources: [.process("ggml-metal.metal")],
|
||||
"ggml-quants.c"
|
||||
] + additionalSources,
|
||||
resources: resources,
|
||||
publicHeadersPath: "spm-headers",
|
||||
cSettings: [
|
||||
.unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]),
|
||||
.define("GGML_USE_ACCELERATE"),
|
||||
.unsafeFlags(["-fno-objc-arc"]),
|
||||
.define("GGML_USE_METAL")
|
||||
.define("GGML_USE_ACCELERATE")
|
||||
// NOTE: NEW_LAPACK will required iOS version 16.4+
|
||||
// We should consider add this in the future when we drop support for iOS 14
|
||||
// (ref: ref: https://developer.apple.com/documentation/accelerate/1513264-cblas_sgemm?language=objc)
|
||||
// .define("ACCELERATE_NEW_LAPACK"),
|
||||
// .define("ACCELERATE_LAPACK_ILP64")
|
||||
],
|
||||
] + additionalSettings,
|
||||
linkerSettings: [
|
||||
.linkedFramework("Accelerate")
|
||||
]
|
||||
|
12
README.md
12
README.md
@ -6,7 +6,7 @@
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://www.npmjs.com/package/whisper.cpp/)
|
||||
|
||||
Stable: [v1.5.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.5.2) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||
Stable: [v1.5.0](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.5.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||
|
||||
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
||||
|
||||
@ -110,8 +110,8 @@ options:
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-sow, --split-on-word [false ] split on word rather than on token
|
||||
-bo N, --best-of N [5 ] number of best candidates to keep
|
||||
-bs N, --beam-size N [5 ] beam size for beam search
|
||||
-bo N, --best-of N [2 ] number of best candidates to keep
|
||||
-bs N, --beam-size N [-1 ] beam size for beam search
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
||||
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
||||
@ -128,7 +128,6 @@ options:
|
||||
-fp, --font-path [/System/Library/Fonts/Supplemental/Courier New Bold.ttf] path to a monospace font for karaoke video
|
||||
-ocsv, --output-csv [false ] output result in a CSV file
|
||||
-oj, --output-json [false ] output result in a JSON file
|
||||
-ojf, --output-json-full [false ] include more information in the JSON file
|
||||
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
@ -140,8 +139,7 @@ options:
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-f FNAME, --file FNAME [ ] input WAV file path
|
||||
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
|
||||
-ls, --log-score [false ] log best decoder scores of tokens
|
||||
-ng, --no-gpu [false ] disable GPU
|
||||
-ls, --log-score [false ] log best decoder scores of token
|
||||
|
||||
|
||||
bash ./models/download-ggml-model.sh base.en
|
||||
@ -770,7 +768,6 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
|
||||
| [bench](examples/bench) | [bench.wasm](examples/bench.wasm) | Benchmark the performance of Whisper on your machine |
|
||||
| [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
|
||||
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
|
||||
| [wchess](examples/wchess) | [wchess.wasm](examples/wchess) | Voice-controlled chess |
|
||||
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
|
||||
| [talk-llama](examples/talk-llama) | | Talk with a LLaMA bot |
|
||||
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
|
||||
@ -780,7 +777,6 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
|
||||
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
|
||||
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggerganov/whisper.cpp/issues/185) |
|
||||
| [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) |
|
||||
| [server](examples/server) | | HTTP transcription server with OAI-like API |
|
||||
|
||||
## [Discussions](https://github.com/ggerganov/whisper.cpp/discussions)
|
||||
|
||||
|
@ -1,26 +1,9 @@
|
||||
ifndef UNAME_S
|
||||
UNAME_S := $(shell uname -s)
|
||||
endif
|
||||
|
||||
ifndef UNAME_P
|
||||
UNAME_P := $(shell uname -p)
|
||||
endif
|
||||
|
||||
ifndef UNAME_M
|
||||
UNAME_M := $(shell uname -m)
|
||||
endif
|
||||
|
||||
GGML_METAL_PATH_RESOURCES := $(abspath ../..)
|
||||
BUILD_DIR := build
|
||||
MODELS_DIR := models
|
||||
EXAMPLES_DIR := $(wildcard examples/*)
|
||||
INCLUDE_PATH := $(abspath ../..)
|
||||
LIBRARY_PATH := $(abspath ../..)
|
||||
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
EXT_LDFLAGS := -framework Foundation -framework Metal -framework MetalKit
|
||||
endif
|
||||
|
||||
all: clean whisper examples
|
||||
|
||||
whisper: mkdir
|
||||
@ -28,13 +11,8 @@ whisper: mkdir
|
||||
@${MAKE} -C ../.. libwhisper.a
|
||||
|
||||
test: model-small whisper modtidy
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -v .
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -v ./pkg/whisper/...
|
||||
else
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
|
||||
endif
|
||||
|
||||
examples: $(EXAMPLES_DIR)
|
||||
|
||||
@ -43,11 +21,7 @@ model-small: mkdir examples/go-model-download
|
||||
|
||||
$(EXAMPLES_DIR): mkdir whisper modtidy
|
||||
@echo Build example $(notdir $@)
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go build ${BUILD_FLAGS} -ldflags "-extldflags '$(EXT_LDFLAGS)'" -o ${BUILD_DIR}/$(notdir $@) ./$@
|
||||
else
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
|
||||
endif
|
||||
|
||||
mkdir:
|
||||
@echo Mkdir ${BUILD_DIR}
|
||||
|
Submodule bindings/ios updated: 88c28eb833...f5e5cf24ca
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "whisper.cpp",
|
||||
"version": "1.5.2",
|
||||
"version": "1.5.0",
|
||||
"description": "Whisper speech recognition",
|
||||
"main": "whisper.js",
|
||||
"scripts": {
|
||||
|
File diff suppressed because one or more lines are too long
@ -22,7 +22,6 @@ var printTextarea = (function() {
|
||||
async function clearCache() {
|
||||
if (confirm('Are you sure you want to clear the cache?\nAll the models will be downloaded again.')) {
|
||||
indexedDB.deleteDatabase(dbName);
|
||||
location.reload();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -17,37 +17,28 @@ options:
|
||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-sow, --split-on-word [false ] split on word rather than on token
|
||||
-bo N, --best-of N [5 ] number of best candidates to keep
|
||||
-bs N, --beam-size N [5 ] beam size for beam search
|
||||
-bs N, --beam-size N [-1 ] beam size for beam search
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
||||
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
||||
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
|
||||
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
|
||||
-tr, --translate [false ] translate from source language to english
|
||||
-di, --diarize [false ] stereo audio diarization
|
||||
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
|
||||
-nf, --no-fallback [false ] do not use temperature fallback while decoding
|
||||
-otxt, --output-txt [false ] output result in a text file
|
||||
-ovtt, --output-vtt [false ] output result in a vtt file
|
||||
-osrt, --output-srt [false ] output result in a srt file
|
||||
-olrc, --output-lrc [false ] output result in a lrc file
|
||||
-owts, --output-words [false ] output script for generating karaoke video
|
||||
-fp, --font-path [/System/Library/Fonts/Supplemental/Courier New Bold.ttf] path to a monospace font for karaoke video
|
||||
-ocsv, --output-csv [false ] output result in a CSV file
|
||||
-oj, --output-json [false ] output result in a JSON file
|
||||
-ojf, --output-json-full [false ] include more information in the JSON file
|
||||
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
-pp, --print-progress [false ] print progress
|
||||
-nt, --no-timestamps [false ] do not print timestamps
|
||||
-nt, --no-timestamps [true ] do not print timestamps
|
||||
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
||||
-dl, --detect-language [false ] exit after automatically detecting language
|
||||
--prompt PROMPT [ ] initial prompt
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-f FNAME, --file FNAME [ ] input WAV file path
|
||||
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
|
||||
-ls, --log-score [false ] log best decoder scores of tokens
|
||||
-ng, --no-gpu [false ] disable GPU
|
||||
```
|
||||
|
@ -4,9 +4,3 @@ add_executable(${TARGET} server.cpp httplib.h json.hpp)
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
# Check if the compiler is MinGW
|
||||
if(MINGW)
|
||||
# Link the necessary libraries for SSL and Winsock
|
||||
target_link_libraries(${TARGET} PRIVATE -lcrypt32 -lssl -lcrypto -lws2_32)
|
||||
endif()
|
||||
|
@ -2,10 +2,6 @@
|
||||
|
||||
Simple http server. WAV Files are passed to the inference model via http requests.
|
||||
|
||||
https://github.com/ggerganov/whisper.cpp/assets/1991296/e983ee53-8741-4eb5-9048-afe5e4594b8f
|
||||
|
||||
## Usage
|
||||
|
||||
```
|
||||
./server -h
|
||||
|
||||
@ -33,7 +29,6 @@ options:
|
||||
-nf, --no-fallback [false ] do not use temperature fallback while decoding
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
-pr, --print-realtime [false ] print output in realtime
|
||||
-pp, --print-progress [false ] print progress
|
||||
-nt, --no-timestamps [false ] do not print timestamps
|
||||
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
||||
@ -43,12 +38,8 @@ options:
|
||||
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
|
||||
--host HOST, [127.0.0.1] Hostname/ip-adress for the server
|
||||
--port PORT, [8080 ] Port number for the server
|
||||
--convert, [false ] Convert audio to WAV, requires ffmpeg on the server
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> **Do not run the server example with administrative privileges and ensure it's operated in a sandbox environment, especially since it involves risky operations like accepting user file uploads and using ffmpeg for format conversions. Always validate and sanitize inputs to guard against potential security threats.**
|
||||
|
||||
## request examples
|
||||
|
||||
**/inference**
|
||||
|
@ -11,7 +11,6 @@
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
@ -44,8 +43,6 @@ struct server_params
|
||||
int32_t port = 8080;
|
||||
int32_t read_timeout = 600;
|
||||
int32_t write_timeout = 600;
|
||||
|
||||
bool ffmpeg_converter = false;
|
||||
};
|
||||
|
||||
struct whisper_params {
|
||||
@ -75,7 +72,6 @@ struct whisper_params {
|
||||
bool no_fallback = false;
|
||||
bool print_special = false;
|
||||
bool print_colors = false;
|
||||
bool print_realtime = false;
|
||||
bool print_progress = false;
|
||||
bool no_timestamps = false;
|
||||
bool use_gpu = true;
|
||||
@ -148,7 +144,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||
fprintf(stderr, " -pr, --print-realtime [%-7s] print output in realtime\n", params.print_realtime ? "true" : "false");
|
||||
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
||||
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
|
||||
@ -160,7 +155,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str());
|
||||
fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port);
|
||||
fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str());
|
||||
fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
@ -194,7 +188,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
|
||||
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||
else if (arg == "-pr" || arg == "--print-realtime") { params.print_realtime = true; }
|
||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
@ -207,7 +200,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
|
||||
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
|
||||
else if ( arg == "--host") { sparams.hostname = argv[++i]; }
|
||||
else if ( arg == "--public") { sparams.public_path = argv[++i]; }
|
||||
else if ( arg == "--convert") { sparams.ffmpeg_converter = true; }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params, sparams);
|
||||
@ -225,45 +217,6 @@ struct whisper_print_user_data {
|
||||
int progress_prev;
|
||||
};
|
||||
|
||||
void check_ffmpeg_availibility() {
|
||||
int result = system("ffmpeg -version");
|
||||
|
||||
if (result == 0) {
|
||||
std::cout << "ffmpeg is available." << std::endl;
|
||||
} else {
|
||||
// ffmpeg is not available
|
||||
std::cout << "ffmpeg is not found. Please ensure that ffmpeg is installed ";
|
||||
std::cout << "and that its executable is included in your system's PATH. ";
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
bool convert_to_wav(const std::string & temp_filename, std::string & error_resp) {
|
||||
std::ostringstream cmd_stream;
|
||||
std::string converted_filename_temp = temp_filename + "_temp.wav";
|
||||
cmd_stream << "ffmpeg -i \"" << temp_filename << "\" -ar 16000 -ac 1 -c:a pcm_s16le \"" << converted_filename_temp << "\" 2>&1";
|
||||
std::string cmd = cmd_stream.str();
|
||||
|
||||
int status = std::system(cmd.c_str());
|
||||
if (status != 0) {
|
||||
error_resp = "{\"error\":\"FFmpeg conversion failed.\"}";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Remove the original file
|
||||
if (remove(temp_filename.c_str()) != 0) {
|
||||
error_resp = "{\"error\":\"Failed to remove the original file.\"}";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Rename the temporary file to match the original filename
|
||||
if (rename(converted_filename_temp.c_str(), temp_filename.c_str()) != 0) {
|
||||
error_resp = "{\"error\":\"Failed to rename the temporary file.\"}";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
|
||||
std::string speaker = "";
|
||||
const int64_t n_samples = pcmf32s[0].size();
|
||||
@ -420,7 +373,7 @@ void get_req_parameters(const Request & req, whisper_params & params)
|
||||
{
|
||||
params.response_format = req.get_file_value("response-format").content;
|
||||
}
|
||||
if (req.has_file("temperature"))
|
||||
if (req.has_file("temerature"))
|
||||
{
|
||||
params.userdef_temp = std::stof(req.get_file_value("temperature").content);
|
||||
}
|
||||
@ -451,9 +404,6 @@ int main(int argc, char ** argv) {
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if (sparams.ffmpeg_converter) {
|
||||
check_ffmpeg_availibility();
|
||||
}
|
||||
// whisper init
|
||||
struct whisper_context_params cparams;
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
@ -469,9 +419,6 @@ int main(int argc, char ** argv) {
|
||||
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
|
||||
|
||||
Server svr;
|
||||
svr.set_default_headers({{"Server", "whisper.cpp"},
|
||||
{"Access-Control-Allow-Origin", "*"},
|
||||
{"Access-Control-Allow-Headers", "content-type"}});
|
||||
|
||||
std::string const default_content = "<html>hello</html>";
|
||||
|
||||
@ -482,7 +429,7 @@ int main(int argc, char ** argv) {
|
||||
});
|
||||
|
||||
svr.Post("/inference", [&](const Request &req, Response &res){
|
||||
// acquire whisper model mutex lock
|
||||
// aquire whisper model mutex lock
|
||||
whisper_mutex.lock();
|
||||
|
||||
// first check user requested fields of the request
|
||||
@ -506,35 +453,20 @@ int main(int argc, char ** argv) {
|
||||
std::vector<float> pcmf32; // mono-channel F32 PCM
|
||||
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
|
||||
|
||||
// write to temporary file
|
||||
const std::string temp_filename = "whisper_server_temp_file.wav";
|
||||
std::ofstream temp_file{temp_filename, std::ios::binary};
|
||||
// write file to temporary file
|
||||
std::ofstream temp_file{filename, std::ios::binary};
|
||||
temp_file << audio_file.content;
|
||||
temp_file.close();
|
||||
|
||||
// if file is not wav, convert to wav
|
||||
|
||||
if (sparams.ffmpeg_converter) {
|
||||
std::string error_resp = "{\"error\":\"Failed to execute ffmpeg command.\"}";
|
||||
const bool is_converted = convert_to_wav(temp_filename, error_resp);
|
||||
if (!is_converted) {
|
||||
res.set_content(error_resp, "application/json");
|
||||
whisper_mutex.unlock();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// read wav content into pcmf32
|
||||
if (!::read_wav(temp_filename, pcmf32, pcmf32s, params.diarize)) {
|
||||
fprintf(stderr, "error: failed to read WAV file '%s'\n", temp_filename.c_str());
|
||||
if (!::read_wav(filename, pcmf32, pcmf32s, params.diarize)) {
|
||||
fprintf(stderr, "error: failed to read WAV file '%s'\n", filename.c_str());
|
||||
const std::string error_resp = "{\"error\":\"failed to read WAV file\"}";
|
||||
res.set_content(error_resp, "application/json");
|
||||
std::remove(temp_filename.c_str());
|
||||
whisper_mutex.unlock();
|
||||
return;
|
||||
}
|
||||
// remove temp file
|
||||
std::remove(temp_filename.c_str());
|
||||
std::remove(filename.c_str());
|
||||
|
||||
printf("Successfully loaded %s\n", filename.c_str());
|
||||
|
||||
@ -571,6 +503,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// run the inference
|
||||
{
|
||||
|
||||
printf("Running whisper.cpp inference on %s\n", filename.c_str());
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
@ -589,7 +522,6 @@ int main(int argc, char ** argv) {
|
||||
wparams.duration_ms = params.duration_ms;
|
||||
|
||||
wparams.thold_pt = params.word_thold;
|
||||
wparams.max_len = params.max_len == 0 ? 60 : params.max_len;
|
||||
wparams.split_on_word = params.split_on_word;
|
||||
|
||||
wparams.speed_up = params.speed_up;
|
||||
@ -609,7 +541,7 @@ int main(int argc, char ** argv) {
|
||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 };
|
||||
|
||||
// this callback is called on each new segment
|
||||
if (params.print_realtime) {
|
||||
if (!wparams.print_realtime) {
|
||||
wparams.new_segment_callback = whisper_print_segment_callback;
|
||||
wparams.new_segment_callback_user_data = &user_data;
|
||||
}
|
||||
@ -659,50 +591,6 @@ int main(int argc, char ** argv) {
|
||||
std::string results = output_str(ctx, params, pcmf32s);
|
||||
res.set_content(results.c_str(), "text/html");
|
||||
}
|
||||
else if (params.response_format == srt_format)
|
||||
{
|
||||
std::stringstream ss;
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
std::string speaker = "";
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2)
|
||||
{
|
||||
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
|
||||
}
|
||||
|
||||
ss << i + 1 + params.offset_n << "\n";
|
||||
ss << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
|
||||
ss << speaker << text << "\n\n";
|
||||
}
|
||||
res.set_content(ss.str(), "application/x-subrip");
|
||||
} else if (params.response_format == vtt_format) {
|
||||
std::stringstream ss;
|
||||
|
||||
ss << "WEBVTT\n\n";
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
std::string speaker = "";
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2)
|
||||
{
|
||||
speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true);
|
||||
speaker.insert(0, "<v Speaker");
|
||||
speaker.append(">");
|
||||
}
|
||||
|
||||
ss << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
|
||||
ss << speaker << text << "\n\n";
|
||||
}
|
||||
res.set_content(ss.str(), "text/vtt");
|
||||
}
|
||||
// TODO add more output formats
|
||||
else
|
||||
{
|
||||
|
@ -18,11 +18,6 @@ if (WHISPER_SDL2)
|
||||
../../ggml-quants.c
|
||||
../../whisper.cpp)
|
||||
|
||||
if(WIN32)
|
||||
# It requires Windows 8.1 or later for PrefetchVirtualMemory
|
||||
target_compile_definitions(${TARGET} PRIVATE -D_WIN32_WINNT=0x0602)
|
||||
endif()
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
|
||||
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
|
@ -1,40 +0,0 @@
|
||||
# wchess
|
||||
|
||||
Voice-controlled chess using Whisper
|
||||
|
||||
Online demo: https://whisper.ggerganov.com/wchess/
|
||||
|
||||
https://github.com/ggerganov/whisper.cpp/assets/1991296/c2b2f03c-9684-49f3-8106-357d2d4e67fa
|
||||
|
||||
## Command-line tool
|
||||
|
||||
```bash
|
||||
mkdir build && cd build
|
||||
cmake -DWHISPER_SDL2=1 ..
|
||||
make -j
|
||||
|
||||
./bin/wchess -m ../models/ggml-base.en.bin
|
||||
|
||||
Move: start
|
||||
|
||||
a b c d e f g h
|
||||
r n b q k b n r 8
|
||||
p p p p p p p p 7
|
||||
. * . * . * . * 6
|
||||
* . * . * . * . 5
|
||||
. * . * . * . * 4
|
||||
* . * . * . * . 3
|
||||
P P P P P P P P 2
|
||||
R N B Q K B N R 1
|
||||
|
||||
White's turn
|
||||
[(l)isten/(p)ause/(q)uit]:
|
||||
```
|
||||
|
||||
## TODO
|
||||
|
||||
- Improve web-browser audio capture - sometimes it does not record the voice properly
|
||||
- Add support for more languages by making the generated grammar string multi-lingual
|
||||
- Fix bugs in the chess moves logic
|
||||
|
||||
PRs welcome!
|
@ -1,19 +1,19 @@
|
||||
add_library(wchess-core STATIC
|
||||
add_library(libwchess
|
||||
WChess.cpp
|
||||
WChess.h
|
||||
Chessboard.cpp
|
||||
Chessboard.h
|
||||
)
|
||||
|
||||
target_link_libraries(wchess-core
|
||||
target_link_libraries(libwchess
|
||||
PUBLIC
|
||||
whisper
|
||||
common
|
||||
)
|
||||
|
||||
target_include_directories(wchess-core
|
||||
target_include_directories(libwchess
|
||||
PUBLIC
|
||||
"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>"
|
||||
)
|
||||
|
||||
# add_executable(test-chessboard test-chessboard.cpp Chessboard.cpp)
|
||||
add_executable(test-chessboard test-chessboard.cpp Chessboard.cpp)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,33 +1,56 @@
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
// just basic validation
|
||||
// fixme: missing en passant, castling, promotion, etc.
|
||||
struct State;
|
||||
class Piece;
|
||||
class Chessboard {
|
||||
public:
|
||||
Chessboard();
|
||||
~Chessboard();
|
||||
std::string process(const std::string& command);
|
||||
std::string process(const std::string& t);
|
||||
std::string stringifyBoard();
|
||||
const std::string& grammar() { return m_grammar; }
|
||||
const std::string& prompt() { return m_prompt; }
|
||||
void setPrompt(const std::string& prompt);
|
||||
private:
|
||||
bool parseCommand(const std::string& command, Piece*& piece, char& pos_to);
|
||||
bool move(Piece& piece, char pos);
|
||||
void flagUpdates(char pos_from, char pos_to);
|
||||
void updatePins(Piece& piece);
|
||||
void detectChecks();
|
||||
void setGrammar();
|
||||
using Move = std::pair<int, int>;
|
||||
bool move(const Move& move);
|
||||
|
||||
std::unique_ptr<State> m_state;
|
||||
std::set<char> m_allowedInCheck;
|
||||
bool m_inCheck = false;
|
||||
struct Piece {
|
||||
enum Types {
|
||||
Pawn,
|
||||
Knight,
|
||||
Bishop,
|
||||
Rook,
|
||||
Queen,
|
||||
King,
|
||||
Taken,
|
||||
};
|
||||
|
||||
enum Colors {
|
||||
Black,
|
||||
White
|
||||
};
|
||||
|
||||
Types type;
|
||||
Colors color;
|
||||
int pos;
|
||||
};
|
||||
|
||||
Piece::Types tokenToType(std::string_view token);
|
||||
size_t tokenToPos(std::string_view token);
|
||||
using PieceSet = std::array<Piece, 16>;
|
||||
|
||||
PieceSet blackPieces;
|
||||
PieceSet whitePieces;
|
||||
int m_moveCounter = 0;
|
||||
std::string m_grammar;
|
||||
std::string m_prompt;
|
||||
|
||||
using Board = std::array<Piece*, 64>;
|
||||
Board board;
|
||||
|
||||
bool validateMove(const Piece& piece, int pos);
|
||||
// just basic validation
|
||||
// fixme: missing en passant, castling, promotion, etc.
|
||||
bool validatePawnMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
|
||||
bool validateKnightMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
|
||||
bool validateBishopMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
|
||||
bool validateRookMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
|
||||
bool validateQueenMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
|
||||
bool validateKingMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
|
||||
};
|
||||
|
@ -4,6 +4,24 @@
|
||||
#include "common.h"
|
||||
#include <thread>
|
||||
|
||||
static constexpr auto RULES =
|
||||
"\n"
|
||||
"root ::= init move move? move? \".\"\n"
|
||||
"prompt ::= init \".\"\n"
|
||||
"\n"
|
||||
"# leading space is very important!\n"
|
||||
"init ::= \" rook to b4, f3\"\n"
|
||||
"\n"
|
||||
"move ::= \", \" ((piece | pawn | king) \" \" \"to \"?)? [a-h] [1-8]\n"
|
||||
"\n"
|
||||
"piece ::= \"bishop\" | \"rook\" | \"knight\" | \"queen\"\n"
|
||||
"king ::= \"king\"\n"
|
||||
"pawn ::= \"pawn\"\n"
|
||||
"\n";
|
||||
|
||||
static constexpr auto PROMPT = "rook to b4, f3,";
|
||||
static constexpr auto CONTEXT = "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8,";
|
||||
|
||||
WChess::WChess(whisper_context * ctx,
|
||||
const whisper_full_params & wparams,
|
||||
callbacks cb,
|
||||
@ -17,136 +35,175 @@ WChess::WChess(whisper_context * ctx,
|
||||
|
||||
WChess::~WChess() = default;
|
||||
|
||||
void WChess::set_move(const std::string& moves, float prob) const {
|
||||
if (m_cb.set_move) (*m_cb.set_move)(moves, prob);
|
||||
void WChess::set_status(const std::string& msg) const {
|
||||
if (m_cb.set_status) (*m_cb.set_status)(msg);
|
||||
}
|
||||
|
||||
void WChess::set_grammar(const std::string& grammar) const {
|
||||
if (m_cb.set_grammar) (*m_cb.set_grammar)(grammar);
|
||||
void WChess::set_moves(const std::string& moves) const {
|
||||
if (m_cb.set_moves) (*m_cb.set_moves)(moves);
|
||||
}
|
||||
|
||||
bool WChess::get_audio(std::vector<float>& pcmf32) const {
|
||||
if (m_cb.get_audio) return (*m_cb.get_audio)(pcmf32);
|
||||
bool WChess::check_running() const {
|
||||
if (m_cb.check_running) return (*m_cb.check_running)();
|
||||
return false;
|
||||
}
|
||||
|
||||
bool WChess::clear_audio() const {
|
||||
if (m_cb.clear_audio) return (*m_cb.clear_audio)();
|
||||
return false;
|
||||
}
|
||||
|
||||
void WChess::get_audio(int ms, std::vector<float>& pcmf32) const {
|
||||
if (m_cb.get_audio) (*m_cb.get_audio)(ms, pcmf32);
|
||||
}
|
||||
|
||||
std::string WChess::stringify_board() const {
|
||||
return m_board->stringifyBoard();
|
||||
}
|
||||
|
||||
std::string WChess::get_grammar() const {
|
||||
return m_board->grammar();
|
||||
}
|
||||
|
||||
void WChess::run() {
|
||||
bool have_prompt = true;
|
||||
bool ask_prompt = !have_prompt;
|
||||
set_status("loading data ...");
|
||||
|
||||
bool have_prompt = false;
|
||||
bool ask_prompt = true;
|
||||
|
||||
float logprob_min0 = 0.0f;
|
||||
float logprob_min = 0.0f;
|
||||
|
||||
float logprob_sum0 = 0.0f;
|
||||
float logprob_sum = 0.0f;
|
||||
|
||||
int n_tokens0 = 0;
|
||||
int n_tokens = 0;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
const std::string k_prompt = have_prompt ? "" : "rook to d4, f3";
|
||||
int64_t t_ms = 0;
|
||||
const std::string k_prompt = PROMPT;
|
||||
m_wparams.initial_prompt = CONTEXT;
|
||||
|
||||
if (ask_prompt) {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
||||
fprintf(stdout, "\n");
|
||||
auto grammar_parsed = grammar_parser::parse(RULES);
|
||||
auto grammar_rules = grammar_parsed.c_rules();
|
||||
|
||||
ask_prompt = false;
|
||||
if (grammar_parsed.rules.empty()) {
|
||||
fprintf(stdout, "%s: Failed to parse grammar ...\n", __func__);
|
||||
}
|
||||
else {
|
||||
m_wparams.grammar_rules = grammar_rules.data();
|
||||
m_wparams.n_grammar_rules = grammar_rules.size();
|
||||
}
|
||||
|
||||
while (get_audio(pcmf32_cur)) {
|
||||
if (!pcmf32_cur.empty()) {
|
||||
// fprintf(stdout, "%s: Processing ...\n", __func__);
|
||||
|
||||
if (!have_prompt) {
|
||||
const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||
|
||||
const float sim = similarity(txt, k_prompt);
|
||||
|
||||
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
|
||||
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
|
||||
ask_prompt = true;
|
||||
} else {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
||||
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
// save the audio for the prompt
|
||||
pcmf32_prompt = pcmf32_cur;
|
||||
have_prompt = true;
|
||||
m_board->setPrompt(k_prompt);
|
||||
}
|
||||
} else {
|
||||
if (!pcmf32_prompt.empty()) pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
constexpr size_t MIN_SIZE = 1.2 * WHISPER_SAMPLE_RATE;
|
||||
if (MIN_SIZE > pcmf32_cur.size()) pcmf32_cur.insert(pcmf32_cur.begin(), MIN_SIZE - pcmf32_cur.size(), 0.0f);
|
||||
|
||||
// fprintf(stdout, "%s: grammar rules:\n'%s'\n", __func__, m_board->grammar().c_str());
|
||||
|
||||
auto grammar_parsed = grammar_parser::parse(m_board->grammar().c_str());
|
||||
auto grammar_rules = grammar_parsed.c_rules();
|
||||
|
||||
m_wparams.grammar_rules = grammar_rules.data();
|
||||
m_wparams.n_grammar_rules = grammar_rules.size();
|
||||
|
||||
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("move");
|
||||
auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
|
||||
|
||||
const float p = 100.0f * std::exp(logprob_min);
|
||||
|
||||
fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
||||
|
||||
// find the prompt in the text
|
||||
float best_sim = 0.0f;
|
||||
size_t best_len = 0;
|
||||
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
const auto prompt = txt.substr(0, n);
|
||||
|
||||
const float sim = similarity(prompt, k_prompt);
|
||||
|
||||
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
|
||||
|
||||
if (sim > best_sim) {
|
||||
best_sim = sim;
|
||||
best_len = n;
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
|
||||
std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
if (!command.empty()) {
|
||||
set_move(m_board->process(command), p);
|
||||
set_grammar(m_board->grammar());
|
||||
}
|
||||
if (m_board->grammar().empty()) {
|
||||
fprintf(stdout, "%s: No more moves possible\n", __func__);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
while (check_running()) {
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
if (ask_prompt) {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
{
|
||||
char txt[1024];
|
||||
snprintf(txt, sizeof(txt), "Say the following phrase: '%s'", k_prompt.c_str());
|
||||
set_status(txt);
|
||||
}
|
||||
|
||||
ask_prompt = false;
|
||||
}
|
||||
|
||||
int64_t t_ms = 0;
|
||||
|
||||
{
|
||||
get_audio(m_settings.vad_ms, pcmf32_cur);
|
||||
|
||||
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, m_settings.vad_thold, m_settings.freq_thold, m_settings.print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
set_status("Speech detected! Processing ...");
|
||||
|
||||
if (!have_prompt) {
|
||||
get_audio(m_settings.prompt_ms, pcmf32_cur);
|
||||
|
||||
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("prompt");
|
||||
const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||
|
||||
const float sim = similarity(txt, k_prompt);
|
||||
|
||||
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
|
||||
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
|
||||
ask_prompt = true;
|
||||
} else {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
||||
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
{
|
||||
char txt[1024];
|
||||
snprintf(txt, sizeof(txt), "Success! Waiting for voice commands ...");
|
||||
set_status(txt);
|
||||
}
|
||||
|
||||
// save the audio for the prompt
|
||||
pcmf32_prompt = pcmf32_cur;
|
||||
have_prompt = true;
|
||||
}
|
||||
} else {
|
||||
get_audio(m_settings.command_ms, pcmf32_cur);
|
||||
|
||||
// prepend 3 second of silence
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), 3*WHISPER_SAMPLE_RATE, 0.0f);
|
||||
|
||||
// prepend the prompt audio
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
|
||||
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("root");
|
||||
const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
|
||||
|
||||
const float p = 100.0f * std::exp(logprob_min);
|
||||
|
||||
fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
||||
|
||||
// find the prompt in the text
|
||||
float best_sim = 0.0f;
|
||||
size_t best_len = 0;
|
||||
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
if (n >= int(txt.size())) {
|
||||
break;
|
||||
}
|
||||
|
||||
const auto prompt = txt.substr(0, n);
|
||||
|
||||
const float sim = similarity(prompt, k_prompt);
|
||||
|
||||
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
|
||||
|
||||
if (sim > best_sim) {
|
||||
best_sim = sim;
|
||||
best_len = n;
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
|
||||
std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
{
|
||||
char txt[1024];
|
||||
snprintf(txt, sizeof(txt), "Command '%s', (t = %d ms)", command.c_str(), (int) t_ms);
|
||||
set_status(txt);
|
||||
}
|
||||
if (!command.empty()) {
|
||||
set_moves(m_board->process(command));
|
||||
}
|
||||
}
|
||||
|
||||
clear_audio();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -8,16 +8,18 @@ class Chessboard;
|
||||
|
||||
class WChess {
|
||||
public:
|
||||
using SetStatusCb = void (*)(const std::string &);
|
||||
using CheckRunningCb = bool (*)();
|
||||
using GetAudioCb = bool (*)(std::vector<float> &);
|
||||
using SetMovesCb = void (*)(const std::string &, float);
|
||||
using SetGrammarCb = void (*)(const std::string &);
|
||||
using ClearAudioCb = void (*)();
|
||||
using GetAudioCb = void (*)(int, std::vector<float> &);
|
||||
using SetMovesCb = void (*)(const std::string &);
|
||||
using CleartAudioCb = bool (*)();
|
||||
|
||||
struct callbacks {
|
||||
SetStatusCb set_status = nullptr;
|
||||
CheckRunningCb check_running = nullptr;
|
||||
GetAudioCb get_audio = nullptr;
|
||||
SetMovesCb set_move = nullptr;
|
||||
SetGrammarCb set_grammar = nullptr;
|
||||
SetMovesCb set_moves = nullptr;
|
||||
CleartAudioCb clear_audio = nullptr;
|
||||
};
|
||||
|
||||
struct settings {
|
||||
@ -38,16 +40,13 @@ public:
|
||||
~WChess();
|
||||
|
||||
void run();
|
||||
|
||||
std::string stringify_board() const;
|
||||
|
||||
std::string get_grammar() const;
|
||||
|
||||
private:
|
||||
bool get_audio(std::vector<float>& pcmf32) const;
|
||||
void set_move(const std::string& moves, float prob) const;
|
||||
void set_grammar(const std::string& grammar) const;
|
||||
|
||||
void get_audio(int ms, std::vector<float>& pcmf32) const;
|
||||
void set_status(const std::string& msg) const;
|
||||
void set_moves(const std::string& moves) const;
|
||||
bool check_running() const;
|
||||
bool clear_audio() const;
|
||||
std::string transcribe(
|
||||
const std::vector<float> & pcmf32,
|
||||
float & logprob_min,
|
||||
|
@ -11,107 +11,78 @@
|
||||
|
||||
|
||||
int main() {
|
||||
{
|
||||
Chessboard chess;
|
||||
|
||||
ASSERT(chess.process("pawn to d4") == "d2-d4");
|
||||
ASSERT(chess.process("e5") == "e7-e5");
|
||||
ASSERT(chess.process("c1 h6") == "c1-h6");
|
||||
ASSERT(chess.process("queen h4") == "d8-h4");
|
||||
ASSERT(chess.process("bishop to g5") == "h6-g5");
|
||||
ASSERT(chess.process("bishop to b4") == "f8-b4");
|
||||
ASSERT(chess.process("c4") == "");
|
||||
ASSERT(chess.process("knight c3") == "b1-c3");
|
||||
ASSERT(chess.process("knight c6") == "b8-c6");
|
||||
ASSERT(chess.process("f3") == "");
|
||||
}
|
||||
|
||||
{
|
||||
// pawns
|
||||
Chessboard chess;
|
||||
|
||||
ASSERT(chess.process("d4") == "d2-d4");
|
||||
ASSERT(chess.process("e5") == "e7-e5");
|
||||
ASSERT(chess.process("e4") == "e2-e4");
|
||||
ASSERT(chess.process("queen h4") == "d8-h4");
|
||||
ASSERT(chess.process("queen h5") == "d1-h5");
|
||||
ASSERT(chess.process("f5") == "");
|
||||
ASSERT(chess.process("g6") == "g7-g6");
|
||||
ASSERT(chess.process("knight e2") == "g1-e2");
|
||||
ASSERT(chess.process("f5") == "f7-f5");
|
||||
ASSERT(chess.process("knight g3") == "e2-g3");
|
||||
ASSERT(chess.process("g5") == "");
|
||||
ASSERT(chess.process("king e7") == "e8-e7");
|
||||
ASSERT(chess.process("f4") == "f2-f4");
|
||||
ASSERT(chess.process("g5") == "g6-g5");
|
||||
}
|
||||
|
||||
{
|
||||
Chessboard chess;
|
||||
|
||||
ASSERT(chess.process("e4") == "e2-e4");
|
||||
ASSERT(chess.process("c5") == "c7-c5");
|
||||
ASSERT(chess.process("e5") == "e4-e5");
|
||||
ASSERT(chess.process("c4") == "c5-c4");
|
||||
ASSERT(chess.process("e6") == "e5-e6");
|
||||
ASSERT(chess.process("c3") == "c4-c3");
|
||||
ASSERT(chess.process("e7") == "");
|
||||
ASSERT(chess.process("f7") == "e6-f7");
|
||||
ASSERT(chess.process("d2") == "");
|
||||
ASSERT(chess.process("king to f7") == "e8-f7");
|
||||
ASSERT(chess.process("f4") == "f2-f4");
|
||||
ASSERT(chess.process("d2") == "c3-d2");
|
||||
ASSERT(chess.process("f5") == "");
|
||||
ASSERT(chess.process("king to e2") == "e1-e2");
|
||||
ASSERT(chess.process("king to g6") == "f7-g6");
|
||||
ASSERT(chess.process("f5") == "f4-f5");
|
||||
ASSERT(chess.process("e6") == "");
|
||||
ASSERT(chess.process("king to h5") == "g6-h5");
|
||||
ASSERT(chess.process("g4") == "g2-g4");
|
||||
ASSERT(chess.process("king to g5") == "h5-g5");
|
||||
ASSERT(chess.process("pawn to d4, e5, e3, pawn to d5") == "d2-d4 e7-e5 e2-e3 d7-d5");
|
||||
ASSERT(chess.process("pawn to d4") == ""); // wrong
|
||||
ASSERT(chess.process("pawn to c5") == ""); // wrong
|
||||
ASSERT(chess.process("pawn to d5") == ""); // wrong
|
||||
ASSERT(chess.process("pawn to d3") == ""); // wrong
|
||||
ASSERT(chess.process("pawn to f5") == ""); // wrong, white's turn
|
||||
ASSERT(chess.process("h4") == "h2-h4");
|
||||
ASSERT(chess.process("king to h5") == "");
|
||||
ASSERT(chess.process("king to g6") == "");
|
||||
ASSERT(chess.process("king to h6") == "g5-h6");
|
||||
ASSERT(chess.process("bishop to d2") == "c1-d2");
|
||||
ASSERT(chess.process("king to g5") == "");
|
||||
ASSERT(chess.process("g5") == "g7-g5");
|
||||
ASSERT(chess.process("d4") == "e5-d4");
|
||||
ASSERT(chess.process("e4") == "e3-e4");
|
||||
ASSERT(chess.process("d4") == ""); // wrong
|
||||
ASSERT(chess.process("e4") == "d5-e4");
|
||||
}
|
||||
|
||||
{
|
||||
// rook
|
||||
Chessboard chess;
|
||||
ASSERT(chess.process("f4") == "f2-f4");
|
||||
ASSERT(chess.process("e5") == "e7-e5");
|
||||
ASSERT(chess.process("g4") == "g2-g4");
|
||||
ASSERT(chess.process("queen to h4") == "d8-h4#");
|
||||
ASSERT(chess.process("knight f3") == "");
|
||||
ASSERT(chess.grammar().empty());
|
||||
|
||||
ASSERT(chess.process("rook to a3") == ""); // wrong
|
||||
ASSERT(chess.process("a4, h5, rook to a3, rook to h6") == "a2-a4 h7-h5 a1-a3 h8-h6");
|
||||
ASSERT(chess.process("rook to d3, rook to e6") == "a3-d3 h6-e6");
|
||||
ASSERT(chess.process("rook to d4, rook to e5") == "d3-d4 e6-e5");
|
||||
ASSERT(chess.process("rook to a4") == ""); // wrong
|
||||
ASSERT(chess.process("rook to d8") == ""); // wrong
|
||||
ASSERT(chess.process("rook to d3") == "d4-d3");
|
||||
ASSERT(chess.process("rook to e2") == "e5-e2");
|
||||
}
|
||||
|
||||
{
|
||||
// knight
|
||||
Chessboard chess;
|
||||
ASSERT(chess.process("f4") == "f2-f4");
|
||||
ASSERT(chess.process("e5") == "e7-e5");
|
||||
ASSERT(chess.process("g4") == "g2-g4");
|
||||
ASSERT(chess.process("d5") == "d7-d5");
|
||||
ASSERT(chess.process("g1 f3") == "g1-f3");
|
||||
ASSERT(chess.process("queen to h4") == "d8-h4");
|
||||
ASSERT(!chess.grammar().empty());
|
||||
|
||||
ASSERT(chess.process("knight to c3, knight to c6") == "b1-c3 b8-c6");
|
||||
ASSERT(chess.process("knight to c3") == ""); // wrong
|
||||
ASSERT(chess.process("knight to a2") == ""); // wrong
|
||||
ASSERT(chess.process("knight to b4") == ""); // wrong, white's turn
|
||||
ASSERT(chess.process("knight to b5") == "c3-b5");
|
||||
ASSERT(chess.process("knight to a5") == "c6-a5");
|
||||
ASSERT(chess.process("knight to c7") == "b5-c7");
|
||||
}
|
||||
|
||||
{
|
||||
// bishop
|
||||
Chessboard chess;
|
||||
ASSERT(chess.process("knight c3") == "b1-c3");
|
||||
ASSERT(chess.process("knight c6") == "b8-c6");
|
||||
ASSERT(chess.process("knight b5") == "c3-b5");
|
||||
ASSERT(chess.process("knight f6") == "g8-f6");
|
||||
ASSERT(chess.process("knight d6") == "b5-d6");
|
||||
ASSERT(chess.process("knight d4") == "");
|
||||
ASSERT(chess.process("d6") == "c7-d6");
|
||||
ASSERT(chess.process("e4") == "e2-e4");
|
||||
ASSERT(chess.process("knight d4") == "c6-d4");
|
||||
ASSERT(chess.process("d3") == "d2-d3");
|
||||
ASSERT(chess.process("knight e4") == "f6-e4");
|
||||
ASSERT(chess.process("king to e2") == "");
|
||||
ASSERT(chess.process("king to d2") == "");
|
||||
|
||||
ASSERT(chess.process("b3, b6, bishop to b2, bishop to b7") == "b2-b3 b7-b6 c1-b2 c8-b7");
|
||||
ASSERT(chess.process("bishop to a1") == ""); // wrong
|
||||
ASSERT(chess.process("bishop to h8") == ""); // wrong
|
||||
ASSERT(chess.process("bishop to a6") == ""); // wrong, white's turn
|
||||
ASSERT(chess.process("bishop to g7") == "b2-g7");
|
||||
}
|
||||
|
||||
{
|
||||
// queen
|
||||
Chessboard chess;
|
||||
ASSERT(chess.process("queen to d8") == ""); // wrong
|
||||
ASSERT(chess.process("queen to f1") == ""); // wrong
|
||||
ASSERT(chess.process("queen to h5") == ""); // wrong
|
||||
ASSERT(chess.process("e3, d5, queen to h5, queen to d6") == "e2-e3 d7-d5 d1-h5 d8-d6");
|
||||
ASSERT(chess.process("queen to c5") == ""); // wrong, white's turn
|
||||
ASSERT(chess.process("queen to f7") == "h5-f7");
|
||||
}
|
||||
|
||||
{
|
||||
// king
|
||||
Chessboard chess;
|
||||
ASSERT(chess.process("d3, d6, king to d2, king to d7, king to c3, king to c6, king to c4") == "d2-d3 d7-d6 e1-d2 e8-d7 d2-c3 d7-c6 c3-c4");
|
||||
ASSERT(chess.process("bishop to e6") == "c8-e6");
|
||||
ASSERT(chess.process("king to b3") == "c4-b3"); // !! check check not implemented
|
||||
}
|
||||
}
|
@ -4,5 +4,5 @@ if (WHISPER_SDL2)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE wchess-core common-sdl ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
||||
target_link_libraries(${TARGET} PRIVATE libwchess common-sdl ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
@ -7,7 +7,6 @@
|
||||
|
||||
#include "WChess.h"
|
||||
#include "common-sdl.h"
|
||||
#include <iostream>
|
||||
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
@ -110,61 +109,17 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
}
|
||||
|
||||
std::unique_ptr<WChess> g_wchess;
|
||||
int g_moveCount = 0;
|
||||
void set_move(const std::string & move, float) {
|
||||
if (!move.empty()) {
|
||||
g_moveCount++;
|
||||
fprintf(stdout, "Move: %s\n\n", move.c_str());
|
||||
}
|
||||
else fprintf(stdout, "Move rejected\n\n");
|
||||
fprintf(stdout, "%s\n", g_wchess->stringify_board().c_str());
|
||||
fprintf(stdout, "%s\n", g_moveCount ? "White's turn" : "Black's turn");
|
||||
void set_moves(const std::string & moves) {
|
||||
if (!moves.empty()) fprintf(stdout, "%s", g_wchess->stringify_board().c_str());
|
||||
}
|
||||
|
||||
audio_async g_audio(30*1000);
|
||||
bool g_listening = false;
|
||||
std::vector<float> g_pcmf32;
|
||||
|
||||
bool read_input() {
|
||||
std::string input;
|
||||
while (true) {
|
||||
fprintf(stdout, "[(l)isten/(p)ause/(q)uit]: ");
|
||||
std::cin >> input;
|
||||
fprintf(stdout, "\n");
|
||||
if (input[0] == 'q') {
|
||||
fprintf(stdout, "Quitting\n");
|
||||
return false;
|
||||
}
|
||||
if (input[0] == 'l') {
|
||||
if (!g_listening) {
|
||||
fprintf(stdout, "Listening\n");
|
||||
g_listening = true;
|
||||
g_pcmf32.clear();
|
||||
g_audio.resume();
|
||||
g_audio.clear();
|
||||
}
|
||||
else fprintf(stdout, "Still listening\n");
|
||||
return true;
|
||||
}
|
||||
else {
|
||||
if (g_listening) {
|
||||
g_listening = false;
|
||||
g_audio.get(0, g_pcmf32);
|
||||
g_audio.pause();
|
||||
fprintf(stdout, "Processing\n");
|
||||
}
|
||||
else fprintf(stdout, "Not listening\n");
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
void get_audio(int ms, std::vector<float> & pcmf32_cur) {
|
||||
g_audio.get(ms, pcmf32_cur);
|
||||
}
|
||||
|
||||
bool get_audio(std::vector<float> & pcmf32_cur) {
|
||||
if (!read_input()) return false;
|
||||
if (!g_pcmf32.empty()) pcmf32_cur = std::move(g_pcmf32);
|
||||
else pcmf32_cur.clear();
|
||||
return true;
|
||||
bool clear_audio() {
|
||||
g_audio.clear();
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
@ -186,10 +141,6 @@ int main(int argc, char ** argv) {
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
||||
if (!ctx) {
|
||||
fprintf(stderr, "%s: whisper_init_from_file_with_params() failed!\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// init audio
|
||||
|
||||
@ -198,35 +149,42 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
|
||||
wparams.offset_ms = 0;
|
||||
wparams.translate = false;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = true;
|
||||
wparams.print_realtime = false;
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
|
||||
wparams.print_progress = false;
|
||||
wparams.print_timestamps = true;
|
||||
wparams.print_special = false;
|
||||
wparams.no_timestamps = true;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.no_timestamps = params.no_timestamps;
|
||||
wparams.single_segment = true;
|
||||
wparams.max_tokens = params.max_tokens;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.max_tokens = 32;
|
||||
wparams.audio_ctx = 768; // partial encoder context for better performance
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.temperature = 0.0f;
|
||||
wparams.temperature_inc = 2.0f;
|
||||
wparams.greedy.best_of = 1;
|
||||
wparams.temperature = 0.4f;
|
||||
wparams.temperature_inc = 1.0f;
|
||||
wparams.greedy.best_of = 5;
|
||||
|
||||
wparams.beam_search.beam_size = 1;
|
||||
|
||||
wparams.language = "en";
|
||||
|
||||
wparams.grammar_penalty = 100.0;
|
||||
wparams.beam_search.beam_size = 5;
|
||||
|
||||
wparams.initial_prompt = params.context.data();
|
||||
|
||||
g_audio.resume();
|
||||
|
||||
// wait for 1 second to avoid any buffered noise
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
|
||||
g_audio.clear();
|
||||
|
||||
WChess::callbacks cb;
|
||||
cb.check_running = sdl_poll_events;
|
||||
cb.get_audio = get_audio;
|
||||
cb.set_move = set_move;
|
||||
cb.set_moves = set_moves;
|
||||
cb.clear_audio = clear_audio;
|
||||
|
||||
WChess::settings s;
|
||||
s.vad_ms = 2000;
|
||||
@ -237,9 +195,11 @@ int main(int argc, char ** argv) {
|
||||
s.print_energy = params.print_energy;
|
||||
|
||||
g_wchess.reset(new WChess(ctx, wparams, cb, s));
|
||||
set_move("start", 0);
|
||||
set_moves("start");
|
||||
g_wchess->run();
|
||||
|
||||
g_audio.pause();
|
||||
|
||||
whisper_print_timings(ctx);
|
||||
whisper_free(ctx);
|
||||
|
||||
|
@ -8,7 +8,7 @@ include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
common
|
||||
wchess-core
|
||||
libwchess
|
||||
)
|
||||
|
||||
unset(EXTRA_FLAGS)
|
||||
|
@ -1,11 +1,7 @@
|
||||
<!doctype html>
|
||||
<html lang="en-us">
|
||||
<head>
|
||||
<title>wchess : voice-controlled chess using Whisper + WebAssembly</title>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=0.7, maximum-scale=1, minimum-scale=0.7, user-scalable=no"/>
|
||||
<meta name="apple-mobile-web-app-capable" content="yes" />
|
||||
<title>wchess : Voice assistant example using Whisper + WebAssembly</title>
|
||||
|
||||
<style>
|
||||
#output {
|
||||
@ -27,145 +23,61 @@
|
||||
overflow-wrap: normal;
|
||||
overflow-x: scroll;
|
||||
}
|
||||
.button {
|
||||
background-color: #000000;
|
||||
color: #FFFFFF;
|
||||
padding: 20px;
|
||||
border-radius: 10px;
|
||||
-moz-border-radius: 10px;
|
||||
-webkit-border-radius: 10px;
|
||||
margin:10px;
|
||||
width: 100px;
|
||||
height: 50px;
|
||||
-webkit-touch-callout: none; /* Safari */
|
||||
-webkit-user-select: none; /* Chrome */
|
||||
-moz-user-select: none; /* Firefox */
|
||||
-ms-user-select: none; /* Internet Explorer/Edge */
|
||||
user-select: none;
|
||||
}
|
||||
button[disabled]{
|
||||
background-color: #cccccc;
|
||||
color: #666666;
|
||||
padding: 20px;
|
||||
border-radius: 10px;
|
||||
-moz-border-radius: 10px;
|
||||
-webkit-border-radius: 10px;
|
||||
margin:10px;
|
||||
width: 100px;
|
||||
}
|
||||
.center {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
width: 500px;
|
||||
}
|
||||
#description {
|
||||
width: 500px;
|
||||
}
|
||||
</style>
|
||||
<link rel="stylesheet" href="css/chessboard-1.0.0.min.css" integrity="sha384-q94+BZtLrkL1/ohfjR8c6L+A6qzNH9R2hBLwyoAfu3i/WCvQjzL2RQJ3uNHDISdU" crossorigin="anonymous">
|
||||
</head>
|
||||
<body>
|
||||
<body onload="loadWhisper()">
|
||||
<div id="main-container">
|
||||
<div id="description">
|
||||
<b>wchess : voice-controlled chess using Whisper + WebAssembly</b>
|
||||
<b>wchess : Voice assistant example using Whisper + WebAssembly</b>
|
||||
|
||||
<br><br>
|
||||
<br><br>
|
||||
|
||||
This is a demonstration of using Whisper to recognize voice commands in the browser.
|
||||
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/command.wasm">GitHub</a>.
|
||||
|
||||
<br><br>
|
||||
<br><br>
|
||||
|
||||
Usage:<br>
|
||||
<b>More examples:</b>
|
||||
<a href="https://whisper.ggerganov.com/">main</a> |
|
||||
<a href="https://whisper.ggerganov.com/bench">bench</a> |
|
||||
<a href="https://whisper.ggerganov.com/stream">stream</a> |
|
||||
<a href="https://whisper.ggerganov.com/command">command</a> |
|
||||
<a href="https://whisper.ggerganov.com/talk">talk</a> |
|
||||
|
||||
<ul>
|
||||
<li>Select a Whisper model</li>
|
||||
<li>Accept the microphone permission request if prompted</li>
|
||||
<li>Hold the button and say a chess move (e.g. "Knight to c3")</li>
|
||||
<li>Release the button and wait for the move to be recognized</li>
|
||||
<li>Repeat</li>
|
||||
</ul>
|
||||
|
||||
Examples:<br>
|
||||
|
||||
<ul>
|
||||
<li><b>"d4"</b></li>
|
||||
<li><b>"e2 e4"</b></li>
|
||||
<li><b>"Knight f3"</b></li>
|
||||
<li><b>"Bishop to b5"</b></li>
|
||||
</ul>
|
||||
|
||||
Features:<br>
|
||||
|
||||
<ul>
|
||||
<li>Model quantization for reduced memory footprint (~42MB)</li>
|
||||
<li><a href="https://github.com/ggerganov/whisper.cpp/pull/1229">Grammar-based sampling</a> for improved recognition accuracy</li>
|
||||
</ul>
|
||||
|
||||
<b>
|
||||
Note that not all chess moves are supported. For example, castling and pawn promotion
|
||||
currently do not work, but can be easily implemented. There could also be some bugs in
|
||||
the move handling logic in general. The main reason for that is to keep the implementation
|
||||
simple. The assumption is that a real application would already have a proper move
|
||||
validation logic in place.<br><br>
|
||||
|
||||
The main purpose of this example is to demonstrate the capabilities of whisper.cpp and
|
||||
its application in the browser for voice recognition locally on your device.
|
||||
</b>
|
||||
|
||||
<br><br>
|
||||
|
||||
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/wchess">GitHub</a>.
|
||||
|
||||
<br><br>
|
||||
|
||||
<b>More examples:</b>
|
||||
<a href="https://whisper.ggerganov.com/">main</a> |
|
||||
<a href="https://whisper.ggerganov.com/bench">bench</a> |
|
||||
<a href="https://whisper.ggerganov.com/stream">stream</a> |
|
||||
<a href="https://whisper.ggerganov.com/command">command</a> |
|
||||
<a href="https://whisper.ggerganov.com/talk">talk</a> |
|
||||
|
||||
<br><br>
|
||||
|
||||
</div>
|
||||
<br><br>
|
||||
|
||||
<hr>
|
||||
|
||||
<div id="model-whisper">
|
||||
Whisper model: <span id="model-whisper-status"></span>
|
||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper()">tiny.en (Q8_0, 42 MB)</button>
|
||||
<span id="fetch-whisper-progress"></span>
|
||||
<br><br>
|
||||
<button id="clear" onclick="clearCache()">Clear browser cache</button>
|
||||
|
||||
<!--
|
||||
<input type="file" id="file" name="file" onchange="loadFile(event, 'whisper.bin')" />
|
||||
-->
|
||||
</div>
|
||||
|
||||
<div id="game">
|
||||
<br>
|
||||
<div id="chessboard" style="width: 500px"></div>
|
||||
<script src="js/jquery-3.7.1.min.js"></script>
|
||||
<script src="js/chessboard-1.0.0.min.js"></script>
|
||||
<script>
|
||||
var board = Chessboard('chessboard', 'start')
|
||||
var move_count = 0;
|
||||
</script>
|
||||
<br>
|
||||
<div id="myBoard" style="width: 400px"></div>
|
||||
<script src="js/jquery-3.7.1.min.js"></script>
|
||||
<script src="js/chessboard-1.0.0.min.js"></script>
|
||||
<script>
|
||||
var board = Chessboard('myBoard', 'start')
|
||||
</script>
|
||||
|
||||
<br>
|
||||
<br>
|
||||
|
||||
<div id="state">
|
||||
Status: <b><span id="state-status">select model</span></b>
|
||||
<div id="input">
|
||||
<button id="start" onclick="onStart()" disabled>Start</button>
|
||||
<button id="stop" onclick="onStop()" disabled>Stop</button>
|
||||
<button id="clear" onclick="clearCache()">Clear Cache</button>
|
||||
</div>
|
||||
|
||||
<div id="input" class="center">
|
||||
<button id="toggler" class="button" onselectstart="return false" style="display: none">Hold</button>
|
||||
</div>
|
||||
<br>
|
||||
|
||||
<pre id="state-grammar">[The grammar will be displayed here]</pre>
|
||||
<div id="state">
|
||||
Status: <b><span id="state-status">not started</span></b>
|
||||
|
||||
<pre id="state-moves">[The moves will be displayed here]</pre>
|
||||
</div>
|
||||
<pre id="state-moves">[The moves will be displayed here]</pre>
|
||||
</div>
|
||||
|
||||
<hr>
|
||||
@ -183,6 +95,7 @@
|
||||
|
||||
<ul>
|
||||
<li>To use a modern web browser (e.g. Chrome, Firefox)</li>
|
||||
<li>To use a fast desktop or laptop computer (i.e. not a mobile phone)</li>
|
||||
<li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
|
||||
</ul>
|
||||
|
||||
@ -202,14 +115,15 @@
|
||||
// web audio context
|
||||
var context = null;
|
||||
|
||||
// audio data
|
||||
var audio = null;
|
||||
var audio0 = null;
|
||||
|
||||
// the command instance
|
||||
var instance = null;
|
||||
|
||||
// model name
|
||||
var model_whisper = null;
|
||||
var model_file = null;
|
||||
|
||||
var module_ready = null;
|
||||
|
||||
var Module = {
|
||||
print: printTextarea,
|
||||
@ -223,30 +137,10 @@
|
||||
printTextarea('js: Preparing ...');
|
||||
},
|
||||
postRun: function() {
|
||||
printTextarea('js: Module initialized successfully!');
|
||||
module_ready = true;
|
||||
initInstance();
|
||||
printTextarea('js: Initialized successfully!');
|
||||
}
|
||||
};
|
||||
|
||||
function initInstance() {
|
||||
if (!module_ready || !model_file || instance) return
|
||||
|
||||
instance = Module.init(model_file);
|
||||
|
||||
if (instance) {
|
||||
setStatus('Ready');
|
||||
printTextarea("js: whisper initialized, instance: " + instance);
|
||||
}
|
||||
else {
|
||||
printTextarea("js: failed to initialize whisper");
|
||||
}
|
||||
}
|
||||
|
||||
function setStatus(text) {
|
||||
document.getElementById('state-status').innerHTML = text;
|
||||
}
|
||||
|
||||
//
|
||||
// fetch models
|
||||
//
|
||||
@ -270,21 +164,36 @@
|
||||
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
|
||||
|
||||
model_file = fname;
|
||||
initInstance();
|
||||
if (model_whisper != null) {
|
||||
document.getElementById('start').disabled = false;
|
||||
document.getElementById('stop' ).disabled = true;
|
||||
}
|
||||
}
|
||||
|
||||
function loadWhisper() {
|
||||
setStatus('Loading')
|
||||
//let url = 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q8_0.bin';
|
||||
let url = 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-q8_0.bin';
|
||||
let dst = 'whisper.bin';
|
||||
let size_mb = 42;
|
||||
// let urls = {
|
||||
// 'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
|
||||
// 'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
|
||||
|
||||
model_whisper = 'tiny.en-q8_0';
|
||||
// 'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
|
||||
// 'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
|
||||
// };
|
||||
|
||||
// let sizes = {
|
||||
// 'tiny.en': 75,
|
||||
// 'base.en': 142,
|
||||
|
||||
// 'tiny-en-q5_1': 31,
|
||||
// 'base-en-q5_1': 57,
|
||||
// };
|
||||
|
||||
let url = 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin';
|
||||
let dst = 'whisper.bin';
|
||||
let size_mb = 75;
|
||||
|
||||
model_whisper = 'tiny.en';
|
||||
|
||||
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model_whisper + '" ... ';
|
||||
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
|
||||
|
||||
cbProgress = function(p) {
|
||||
let el = document.getElementById('fetch-whisper-progress');
|
||||
@ -297,30 +206,6 @@
|
||||
};
|
||||
|
||||
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
|
||||
|
||||
// init audio capture so that the user receives a permission request
|
||||
{
|
||||
let context = new AudioContext({
|
||||
sampleRate: 16000,
|
||||
channelCount: 1,
|
||||
echoCancellation: false,
|
||||
autoGainControl: true,
|
||||
noiseSuppression: true,
|
||||
});
|
||||
navigator.mediaDevices.getUserMedia({audio: true, video: false})
|
||||
.then(function(s) {
|
||||
stream = s;
|
||||
stream.getTracks().forEach(function(track) {
|
||||
track.stop();
|
||||
});
|
||||
})
|
||||
.catch(function(err) {
|
||||
printTextarea('js: error getting audio stream: ' + err);
|
||||
});
|
||||
context.close();
|
||||
}
|
||||
|
||||
document.getElementById('toggler').style.display = 'block';
|
||||
}
|
||||
|
||||
//
|
||||
@ -339,9 +224,11 @@
|
||||
window.OfflineAudioContext = window.OfflineAudioContext || window.webkitOfflineAudioContext;
|
||||
|
||||
function stopRecording() {
|
||||
if (mediaRecorder) {
|
||||
mediaRecorder.stop();
|
||||
}
|
||||
Module.set_status("paused");
|
||||
doRecording = false;
|
||||
audio0 = null;
|
||||
audio = null;
|
||||
context = null;
|
||||
}
|
||||
|
||||
function startRecording() {
|
||||
@ -355,6 +242,12 @@
|
||||
});
|
||||
}
|
||||
|
||||
Module.set_status("");
|
||||
|
||||
document.getElementById('start').disabled = true;
|
||||
document.getElementById('stop').disabled = false;
|
||||
|
||||
doRecording = true;
|
||||
startTime = Date.now();
|
||||
|
||||
var chunks = [];
|
||||
@ -372,6 +265,10 @@
|
||||
|
||||
reader.onload = function(event) {
|
||||
var buf = new Uint8Array(reader.result);
|
||||
|
||||
if (!context) {
|
||||
return;
|
||||
}
|
||||
context.decodeAudioData(buf.buffer, function(audioBuffer) {
|
||||
var offlineContext = new OfflineAudioContext(audioBuffer.numberOfChannels, audioBuffer.length, audioBuffer.sampleRate);
|
||||
var source = offlineContext.createBufferSource();
|
||||
@ -380,13 +277,22 @@
|
||||
source.start(0);
|
||||
|
||||
offlineContext.startRendering().then(function(renderedBuffer) {
|
||||
let audio = renderedBuffer.getChannelData(0);
|
||||
printTextarea('js: number of samples: ' + audio.length);
|
||||
Module.set_audio(instance, audio);
|
||||
});
|
||||
audio = renderedBuffer.getChannelData(0);
|
||||
|
||||
mediaRecorder = null;
|
||||
context = null;
|
||||
//printTextarea('js: audio recorded, size: ' + audio.length + ', old size: ' + (audio0 == null ? 0 : audio0.length));
|
||||
|
||||
var audioAll = new Float32Array(audio0 == null ? audio.length : audio0.length + audio.length);
|
||||
if (audio0 != null) {
|
||||
audioAll.set(audio0, 0);
|
||||
}
|
||||
audioAll.set(audio, audio0 == null ? 0 : audio0.length);
|
||||
|
||||
if (instance) {
|
||||
Module.set_audio(instance, audioAll);
|
||||
}
|
||||
});
|
||||
}, function(e) {
|
||||
audio = null;
|
||||
});
|
||||
}
|
||||
|
||||
@ -394,16 +300,48 @@
|
||||
};
|
||||
|
||||
mediaRecorder.onstop = function(e) {
|
||||
stream.getTracks().forEach(function(track) {
|
||||
track.stop();
|
||||
});
|
||||
if (doRecording) {
|
||||
setTimeout(function() {
|
||||
startRecording();
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
mediaRecorder.start();
|
||||
mediaRecorder.start(kIntervalAudio_ms);
|
||||
})
|
||||
.catch(function(err) {
|
||||
printTextarea('js: error getting audio stream: ' + err);
|
||||
});
|
||||
|
||||
var interval = setInterval(function() {
|
||||
if (!doRecording) {
|
||||
clearInterval(interval);
|
||||
mediaRecorder.stop();
|
||||
stream.getTracks().forEach(function(track) {
|
||||
track.stop();
|
||||
});
|
||||
|
||||
document.getElementById('start').disabled = false;
|
||||
document.getElementById('stop').disabled = true;
|
||||
|
||||
mediaRecorder = null;
|
||||
}
|
||||
|
||||
// if audio length is more than kRestartRecording_s seconds, restart recording
|
||||
if (audio != null && audio.length > kSampleRate*kRestartRecording_s) {
|
||||
if (doRecording) {
|
||||
//printTextarea('js: restarting recording');
|
||||
|
||||
clearInterval(interval);
|
||||
audio0 = audio;
|
||||
audio = null;
|
||||
mediaRecorder.stop();
|
||||
stream.getTracks().forEach(function(track) {
|
||||
track.stop();
|
||||
});
|
||||
}
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
|
||||
//
|
||||
@ -411,88 +349,56 @@
|
||||
//
|
||||
|
||||
var nLines = 0;
|
||||
var intervalUpdate = null;
|
||||
var movesAll = '';
|
||||
|
||||
// document.body.addEventListener('keydown', function(event) {
|
||||
// if (event.keyCode === 32) {
|
||||
// document.getElementById('toggler').innerText = "";
|
||||
// onStart();
|
||||
// }
|
||||
// }, true);
|
||||
|
||||
// document.body.addEventListener('keyup', function(event) {
|
||||
// if (event.keyCode === 32) {
|
||||
// document.getElementById('toggler').innerText = "Hold";
|
||||
// onStop();
|
||||
// }
|
||||
// }, true);
|
||||
|
||||
document.getElementById('toggler').addEventListener("touchstart", function(event){
|
||||
this.innerText = "";
|
||||
onStart();
|
||||
}, true);
|
||||
|
||||
document.getElementById('toggler').addEventListener("touchend", function(event){
|
||||
this.innerText = "Hold";
|
||||
onStop();
|
||||
}, true)
|
||||
|
||||
document.getElementById('toggler').addEventListener('mousedown', function(event) {
|
||||
this.innerText = "";
|
||||
onStart();
|
||||
}, true);
|
||||
|
||||
document.getElementById('toggler').addEventListener('mouseup', function(event) {
|
||||
this.innerText = "Hold";
|
||||
onStop();
|
||||
}, true);
|
||||
|
||||
function onStart() {
|
||||
if (!instance) return;
|
||||
setStatus('Listening');
|
||||
if (!instance) {
|
||||
instance = Module.init('whisper.bin');
|
||||
|
||||
if (instance) {
|
||||
printTextarea("js: whisper initialized, instance: " + instance);
|
||||
}
|
||||
}
|
||||
|
||||
if (!instance) {
|
||||
printTextarea("js: failed to initialize whisper");
|
||||
return;
|
||||
}
|
||||
|
||||
startRecording();
|
||||
|
||||
intervalUpdate = setInterval(function() {
|
||||
var moves = Module.get_moves();
|
||||
|
||||
if (moves != null && moves.length > 1) {
|
||||
|
||||
for (move of moves.split(' ')) {
|
||||
board.move(move);
|
||||
}
|
||||
|
||||
movesAll += moves + '<br>';
|
||||
nLines++;
|
||||
|
||||
// if more than 10 lines, remove the first line
|
||||
if (nLines > 10) {
|
||||
var i = movesAll.indexOf('<br>');
|
||||
if (i > 0) {
|
||||
movesAll = movesAll.substring(i + 4);
|
||||
nLines--;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
document.getElementById('state-status').innerHTML = Module.get_status();
|
||||
document.getElementById('state-moves').innerHTML = movesAll;
|
||||
}, 100);
|
||||
}
|
||||
|
||||
function onStop() {
|
||||
setStatus('Processing');
|
||||
printTextarea('js: stopping recording ...');
|
||||
stopRecording();
|
||||
}
|
||||
|
||||
function setMove(move, prob) {
|
||||
if (move != null && move.length > 1) {
|
||||
let gameOver = move[move.length - 1] === '#';
|
||||
if (gameOver) {
|
||||
move = move.substring(0, move.length - 1);
|
||||
document.getElementById('toggler').disabled = true;
|
||||
}
|
||||
board.move(move);
|
||||
|
||||
movesAll += move + ', prob = ' + prob.toFixed(2) + '% <br>';
|
||||
nLines++;
|
||||
|
||||
// if more than 10 lines, remove the first line
|
||||
if (nLines > 10) {
|
||||
var i = movesAll.indexOf('<br>');
|
||||
if (i > 0) {
|
||||
movesAll = movesAll.substring(i + 4);
|
||||
nLines--;
|
||||
}
|
||||
}
|
||||
++move_count;
|
||||
setStatus(gameOver ? 'Done' : move_count % 2 ? 'Black\'s turn' : 'White\'s turn');
|
||||
document.getElementById('state-moves').innerHTML = movesAll;
|
||||
}
|
||||
else {
|
||||
setStatus('Failed. ' + (move_count % 2 ? 'Black\'s turn' : 'White\'s turn'));
|
||||
}
|
||||
}
|
||||
|
||||
function setGrammar(grammar) {
|
||||
document.getElementById('state-grammar').innerHTML = grammar;
|
||||
}
|
||||
|
||||
</script>
|
||||
<script type="text/javascript" src="js/chess.js"></script>
|
||||
</body>
|
||||
|
@ -1,7 +1,7 @@
|
||||
#include <WChess.h>
|
||||
#include <emscripten.h>
|
||||
#include <emscripten/bind.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <thread>
|
||||
|
||||
constexpr int N_THREAD = 8;
|
||||
@ -11,28 +11,45 @@ std::vector<struct whisper_context *> g_contexts(4, nullptr);
|
||||
std::mutex g_mutex;
|
||||
std::thread g_worker;
|
||||
|
||||
std::condition_variable g_cv;
|
||||
std::atomic<bool> g_running(false);
|
||||
|
||||
std::string g_status = "";
|
||||
std::string g_status_forced = "";
|
||||
std::string g_moves = "";
|
||||
|
||||
bool g_running(false);
|
||||
std::vector<float> g_pcmf32;
|
||||
|
||||
void set_move(const std::string & move, float prob) {
|
||||
MAIN_THREAD_EM_ASM({
|
||||
setMove(UTF8ToString($0), $1)
|
||||
}, move.c_str(), prob);
|
||||
void set_status(const std::string & status) {
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
g_status = status;
|
||||
}
|
||||
|
||||
void set_grammar(const std::string & grammar) {
|
||||
MAIN_THREAD_EM_ASM({
|
||||
setGrammar(UTF8ToString($0))
|
||||
}, grammar.c_str());
|
||||
void set_moves(const std::string & moves) {
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
g_moves = moves;
|
||||
}
|
||||
|
||||
bool get_audio(std::vector<float> & audio) {
|
||||
std::unique_lock<std::mutex> lock(g_mutex);
|
||||
g_cv.wait(lock, [] { return !g_running || !g_pcmf32.empty(); });
|
||||
if (!g_running) return false;
|
||||
audio = std::move(g_pcmf32);
|
||||
void get_audio(int ms, std::vector<float> & audio) {
|
||||
const int64_t n_samples = (ms * WHISPER_SAMPLE_RATE) / 1000;
|
||||
|
||||
int64_t n_take = 0;
|
||||
if (n_samples > (int) g_pcmf32.size()) {
|
||||
n_take = g_pcmf32.size();
|
||||
} else {
|
||||
n_take = n_samples;
|
||||
}
|
||||
|
||||
audio.resize(n_take);
|
||||
std::copy(g_pcmf32.end() - n_take, g_pcmf32.end(), audio.begin());
|
||||
}
|
||||
|
||||
bool check_running() {
|
||||
//g_pcmf32.clear();
|
||||
return g_running;
|
||||
}
|
||||
|
||||
bool clear_audio() {
|
||||
g_pcmf32.clear();
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -48,31 +65,30 @@ void wchess_main(size_t i) {
|
||||
wparams.print_progress = false;
|
||||
wparams.print_timestamps = true;
|
||||
wparams.print_special = false;
|
||||
wparams.no_timestamps = true;
|
||||
|
||||
wparams.max_tokens = 32;
|
||||
wparams.audio_ctx = 1280; // partial encoder context for better performance
|
||||
// wparams.audio_ctx = 768; // partial encoder context for better performance
|
||||
|
||||
wparams.temperature = 0.0f;
|
||||
wparams.temperature_inc = 2.0f;
|
||||
wparams.greedy.best_of = 1;
|
||||
wparams.temperature = 0.4f;
|
||||
wparams.temperature_inc = 1.0f;
|
||||
wparams.greedy.best_of = 1;
|
||||
|
||||
wparams.beam_search.beam_size = 1;
|
||||
wparams.beam_search.beam_size = 5;
|
||||
|
||||
wparams.language = "en";
|
||||
|
||||
wparams.grammar_penalty = 100.0;
|
||||
wparams.initial_prompt = "bishop to c3, rook to d4, knight to e5, d4 d5, knight to c3, c3, queen to d4, king b1, pawn to a1, bishop to b2, knight to c3,";
|
||||
|
||||
printf("command: using %d threads\n", wparams.n_threads);
|
||||
|
||||
WChess::callbacks cb;
|
||||
cb.set_status = set_status;
|
||||
cb.check_running = check_running;
|
||||
cb.get_audio = get_audio;
|
||||
cb.set_move = set_move;
|
||||
cb.set_grammar = set_grammar;
|
||||
cb.set_moves = set_moves;
|
||||
cb.clear_audio = clear_audio;
|
||||
|
||||
WChess(g_contexts[i], wparams, cb, {}).run();
|
||||
|
||||
if (i < g_contexts.size()) {
|
||||
whisper_free(g_contexts[i]);
|
||||
g_contexts[i] = nullptr;
|
||||
@ -104,11 +120,9 @@ EMSCRIPTEN_BINDINGS(command) {
|
||||
}));
|
||||
|
||||
emscripten::function("free", emscripten::optional_override([](size_t /* index */) {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(g_mutex);
|
||||
if (g_running) {
|
||||
g_running = false;
|
||||
}
|
||||
g_cv.notify_one();
|
||||
}));
|
||||
|
||||
emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) {
|
||||
@ -134,8 +148,37 @@ EMSCRIPTEN_BINDINGS(command) {
|
||||
emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(g_pcmf32.data()), n);
|
||||
memoryView.call<void>("set", audio);
|
||||
}
|
||||
g_cv.notify_one();
|
||||
|
||||
return 0;
|
||||
}));
|
||||
|
||||
emscripten::function("get_moves", emscripten::optional_override([]() {
|
||||
std::string moves;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
moves = std::move(g_moves);
|
||||
}
|
||||
|
||||
|
||||
if (!moves.empty()) fprintf(stdout, "%s: Moves '%s%s%s'\n", __func__, "\033[1m", moves.c_str(), "\033[0m");
|
||||
|
||||
return moves;
|
||||
}));
|
||||
|
||||
emscripten::function("get_status", emscripten::optional_override([]() {
|
||||
std::string status;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
status = g_status_forced.empty() ? g_status : g_status_forced;
|
||||
}
|
||||
|
||||
return status;
|
||||
}));
|
||||
|
||||
emscripten::function("set_status", emscripten::optional_override([](const std::string & status) {
|
||||
std::lock_guard<std::mutex> lock(g_mutex);
|
||||
g_status_forced = status;
|
||||
}));
|
||||
}
|
||||
|
@ -206,7 +206,6 @@ void AudioInputCallback(void * inUserData,
|
||||
params.offset_ms = 0;
|
||||
params.no_context = true;
|
||||
params.single_segment = self->stateInp.isRealtime;
|
||||
params.no_timestamps = params.single_segment;
|
||||
|
||||
CFTimeInterval startTime = CACurrentMediaTime();
|
||||
|
||||
|
@ -8,15 +8,15 @@ enum WhisperError: Error {
|
||||
// Meet Whisper C++ constraint: Don't access from more than one thread at a time.
|
||||
actor WhisperContext {
|
||||
private var context: OpaquePointer
|
||||
|
||||
|
||||
init(context: OpaquePointer) {
|
||||
self.context = context
|
||||
}
|
||||
|
||||
|
||||
deinit {
|
||||
whisper_free(context)
|
||||
}
|
||||
|
||||
|
||||
func fullTranscribe(samples: [Float]) {
|
||||
// Leave 2 processors free (i.e. the high-efficiency cores).
|
||||
let maxThreads = max(1, min(8, cpuCount() - 2))
|
||||
@ -24,17 +24,17 @@ actor WhisperContext {
|
||||
var params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY)
|
||||
"en".withCString { en in
|
||||
// Adapted from whisper.objc
|
||||
params.print_realtime = true
|
||||
params.print_progress = false
|
||||
params.print_realtime = true
|
||||
params.print_progress = false
|
||||
params.print_timestamps = true
|
||||
params.print_special = false
|
||||
params.translate = false
|
||||
params.language = en
|
||||
params.n_threads = Int32(maxThreads)
|
||||
params.offset_ms = 0
|
||||
params.no_context = true
|
||||
params.single_segment = false
|
||||
|
||||
params.print_special = false
|
||||
params.translate = false
|
||||
params.language = en
|
||||
params.n_threads = Int32(maxThreads)
|
||||
params.offset_ms = 0
|
||||
params.no_context = true
|
||||
params.single_segment = false
|
||||
|
||||
whisper_reset_timings(context)
|
||||
print("About to run whisper_full")
|
||||
samples.withUnsafeBufferPointer { samples in
|
||||
@ -46,7 +46,7 @@ actor WhisperContext {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func getTranscription() -> String {
|
||||
var transcription = ""
|
||||
for i in 0..<whisper_full_n_segments(context) {
|
||||
@ -54,7 +54,7 @@ actor WhisperContext {
|
||||
}
|
||||
return transcription
|
||||
}
|
||||
|
||||
|
||||
static func createContext(path: String) throws -> WhisperContext {
|
||||
var params = whisper_context_default_params()
|
||||
#if targetEnvironment(simulator)
|
||||
|
51
ggml-alloc.c
51
ggml-alloc.c
@ -137,7 +137,7 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
|
||||
|
||||
#ifdef GGML_ALLOCATOR_DEBUG
|
||||
add_allocated_tensor(alloc, tensor);
|
||||
size_t cur_max = (char*)addr - (char*)alloc->base + size;
|
||||
size_t cur_max = (char*)addr - (char*)alloc->data + size;
|
||||
if (cur_max > alloc->max_size) {
|
||||
printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
|
||||
for (int i = 0; i < 1024; i++) {
|
||||
@ -168,6 +168,10 @@ static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor *
|
||||
size = aligned_offset(NULL, size, alloc->alignment);
|
||||
AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
|
||||
|
||||
if (!alloc->measure) {
|
||||
ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
|
||||
}
|
||||
|
||||
#ifdef GGML_ALLOCATOR_DEBUG
|
||||
remove_allocated_tensor(alloc, tensor);
|
||||
#endif
|
||||
@ -233,7 +237,7 @@ void ggml_tallocr_reset(ggml_tallocr_t alloc) {
|
||||
}
|
||||
|
||||
ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) {
|
||||
struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(data, size);
|
||||
struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(NULL, data, size);
|
||||
|
||||
ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
|
||||
|
||||
@ -445,6 +449,7 @@ static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * n
|
||||
static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) {
|
||||
ggml_tallocr_t alloc = node_tallocr(galloc, view);
|
||||
|
||||
//printf("init_view: %s from src %s\n", view->name, view->view_src->name);
|
||||
GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
|
||||
if (update_backend) {
|
||||
view->backend = view->view_src->backend;
|
||||
@ -454,7 +459,7 @@ static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool upd
|
||||
|
||||
// FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
|
||||
// due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
|
||||
assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft);
|
||||
assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
|
||||
|
||||
if (!alloc->measure) {
|
||||
ggml_backend_buffer_init_tensor(alloc->buffer, view);
|
||||
@ -760,43 +765,3 @@ size_t ggml_allocr_max_size(ggml_allocr_t alloc) {
|
||||
size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) {
|
||||
return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
|
||||
}
|
||||
|
||||
// utils
|
||||
ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
|
||||
GGML_ASSERT(ggml_get_no_alloc(ctx) == true);
|
||||
|
||||
size_t alignment = ggml_backend_buft_get_alignment(buft);
|
||||
|
||||
size_t nbytes = 0;
|
||||
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->data == NULL && t->view_src == NULL) {
|
||||
nbytes += GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
|
||||
}
|
||||
}
|
||||
|
||||
if (nbytes == 0) {
|
||||
fprintf(stderr, "%s: no tensors to allocate\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, nbytes);
|
||||
ggml_tallocr_t tallocr = ggml_tallocr_new_from_buffer(buffer);
|
||||
|
||||
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->data == NULL) {
|
||||
if (t->view_src == NULL) {
|
||||
ggml_tallocr_alloc(tallocr, t);
|
||||
} else {
|
||||
ggml_backend_view_init(buffer, t);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tallocr_free(tallocr);
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {
|
||||
return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));
|
||||
}
|
||||
|
@ -8,7 +8,6 @@ extern "C" {
|
||||
|
||||
struct ggml_backend;
|
||||
struct ggml_backend_buffer;
|
||||
struct ggml_backend_buffer_type;
|
||||
|
||||
//
|
||||
// Legacy API
|
||||
@ -43,7 +42,7 @@ GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph
|
||||
// ggml-backend v2 API
|
||||
//
|
||||
|
||||
// Separate tensor and graph allocator objects
|
||||
// Seperate tensor and graph allocator objects
|
||||
// This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
|
||||
// The original API is kept as a wrapper around the new API
|
||||
|
||||
@ -81,12 +80,6 @@ GGML_API void ggml_gallocr_alloc_graph_n(
|
||||
struct ggml_hash_set hash_set,
|
||||
ggml_tallocr_t * hash_node_talloc);
|
||||
|
||||
|
||||
// Utils
|
||||
// Create a buffer and allocate all the tensors in a ggml_context
|
||||
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, struct ggml_backend_buffer_type * buft);
|
||||
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, struct ggml_backend * backend);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
@ -12,50 +12,31 @@ extern "C" {
|
||||
// Backend buffer
|
||||
//
|
||||
|
||||
// buffer type
|
||||
typedef void * ggml_backend_buffer_type_context_t;
|
||||
|
||||
struct ggml_backend_buffer_type_i {
|
||||
ggml_backend_buffer_t (*alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size);
|
||||
size_t (*get_alignment) (ggml_backend_buffer_type_t buft); // tensor alignment
|
||||
size_t (*get_alloc_size) (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
|
||||
bool (*supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend
|
||||
};
|
||||
|
||||
struct ggml_backend_buffer_type {
|
||||
struct ggml_backend_buffer_type_i iface;
|
||||
ggml_backend_buffer_type_context_t context;
|
||||
};
|
||||
|
||||
// buffer
|
||||
typedef void * ggml_backend_buffer_context_t;
|
||||
|
||||
struct ggml_backend_buffer_i {
|
||||
void (*free_buffer)(ggml_backend_buffer_t buffer);
|
||||
//void (*reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
|
||||
void * (*get_base) (ggml_backend_buffer_t buffer);
|
||||
void (*init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
||||
void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||
void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||
// (optional) copy tensor between different buffer-type, allow for single-copy tranfers
|
||||
void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
|
||||
void (*cpy_tensor_to) (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
|
||||
void (*free_buffer) (ggml_backend_buffer_t buffer);
|
||||
void * (*get_base) (ggml_backend_buffer_t buffer); // get base pointer
|
||||
size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
|
||||
void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
|
||||
void (*free_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
|
||||
};
|
||||
|
||||
struct ggml_backend_buffer {
|
||||
struct ggml_backend_buffer_i iface;
|
||||
ggml_backend_buffer_type_t buft;
|
||||
struct ggml_backend_buffer_i iface;
|
||||
|
||||
ggml_backend_t backend;
|
||||
ggml_backend_buffer_context_t context;
|
||||
|
||||
size_t size;
|
||||
};
|
||||
|
||||
ggml_backend_buffer_t ggml_backend_buffer_init(
|
||||
ggml_backend_buffer_type_t buft,
|
||||
GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
|
||||
struct ggml_backend * backend,
|
||||
struct ggml_backend_buffer_i iface,
|
||||
ggml_backend_buffer_context_t context,
|
||||
size_t size);
|
||||
|
||||
|
||||
//
|
||||
// Backend
|
||||
//
|
||||
@ -68,18 +49,21 @@ extern "C" {
|
||||
void (*free)(ggml_backend_t backend);
|
||||
|
||||
// buffer allocation
|
||||
ggml_backend_buffer_type_t (*get_default_buffer_type)(ggml_backend_t backend);
|
||||
ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
|
||||
|
||||
// (optional) asynchroneous tensor data access
|
||||
// get buffer alignment
|
||||
size_t (*get_alignment)(ggml_backend_t backend);
|
||||
|
||||
// tensor data access
|
||||
// these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize
|
||||
void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||
void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||
|
||||
// (optional) asynchroneous tensor copy
|
||||
void (*cpy_tensor_from_async)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
|
||||
void (*cpy_tensor_to_async) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
|
||||
|
||||
void (*synchronize) (ggml_backend_t backend);
|
||||
|
||||
// (optional) copy tensor between different backends, allow for single-copy tranfers
|
||||
void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
|
||||
void (*cpy_tensor_to) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
|
||||
|
||||
// compute graph with a plan
|
||||
ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
||||
void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
||||
@ -98,15 +82,6 @@ extern "C" {
|
||||
ggml_backend_context_t context;
|
||||
};
|
||||
|
||||
|
||||
//
|
||||
// Backend registry
|
||||
//
|
||||
|
||||
typedef ggml_backend_t (*ggml_backend_init_fn)(const char * params, void * user_data);
|
||||
|
||||
void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
775
ggml-backend.c
775
ggml-backend.c
File diff suppressed because it is too large
Load Diff
@ -7,44 +7,41 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
|
||||
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
|
||||
typedef struct ggml_backend * ggml_backend_t;
|
||||
typedef void * ggml_backend_graph_plan_t;
|
||||
|
||||
//
|
||||
// Backend buffer
|
||||
//
|
||||
|
||||
// buffer type
|
||||
GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
|
||||
GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
|
||||
GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend);
|
||||
struct ggml_backend_buffer;
|
||||
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
|
||||
|
||||
// buffer
|
||||
// backend buffer functions
|
||||
GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
|
||||
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
|
||||
GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
|
||||
GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
|
||||
GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
||||
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
|
||||
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer);
|
||||
GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
||||
GGML_API void ggml_backend_buffer_free_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
||||
|
||||
//
|
||||
// Backend
|
||||
//
|
||||
|
||||
struct ggml_backend;
|
||||
typedef struct ggml_backend * ggml_backend_t;
|
||||
typedef void * ggml_backend_graph_plan_t;
|
||||
|
||||
GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor);
|
||||
|
||||
GGML_API const char * ggml_backend_name(ggml_backend_t backend);
|
||||
GGML_API void ggml_backend_free(ggml_backend_t backend);
|
||||
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
|
||||
GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
|
||||
GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
|
||||
GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
|
||||
|
||||
GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||
GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||
GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
|
||||
|
||||
GGML_API void ggml_backend_tensor_set_async( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||
GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||
|
||||
GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||
GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||
@ -60,7 +57,6 @@ extern "C" {
|
||||
|
||||
// tensor copy between different backends
|
||||
GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
|
||||
GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); // automatic fallback to sync copy
|
||||
|
||||
//
|
||||
// CPU backend
|
||||
@ -72,23 +68,8 @@ extern "C" {
|
||||
GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
|
||||
|
||||
// Create a backend buffer from an existing pointer
|
||||
GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
|
||||
GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size);
|
||||
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
|
||||
|
||||
//
|
||||
// Backend registry
|
||||
//
|
||||
|
||||
// The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
|
||||
|
||||
GGML_API size_t ggml_backend_reg_get_count(void);
|
||||
GGML_API size_t ggml_backend_reg_find_by_name(const char * name);
|
||||
GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is name[:params]
|
||||
GGML_API const char * ggml_backend_reg_get_name(size_t i);
|
||||
GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i);
|
||||
GGML_API ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size);
|
||||
|
||||
//
|
||||
// Backend scheduler
|
||||
@ -150,32 +131,6 @@ extern "C" {
|
||||
ggml_backend_sched_t sched,
|
||||
struct ggml_cgraph * graph);
|
||||
|
||||
|
||||
//
|
||||
// Utils
|
||||
//
|
||||
|
||||
struct ggml_backend_graph_copy {
|
||||
ggml_backend_buffer_t buffer;
|
||||
struct ggml_context * ctx_allocated;
|
||||
struct ggml_context * ctx_unallocated;
|
||||
struct ggml_cgraph * graph;
|
||||
};
|
||||
|
||||
// Copy a graph to a different backend
|
||||
GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
|
||||
GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
|
||||
|
||||
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
|
||||
|
||||
// Compare the output of two backends
|
||||
GGML_API void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
|
||||
|
||||
// Tensor initialization
|
||||
GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
|
||||
GGML_API void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
2421
ggml-cuda.cu
2421
ggml-cuda.cu
File diff suppressed because it is too large
Load Diff
10
ggml-cuda.h
10
ggml-cuda.h
@ -49,15 +49,7 @@ GGML_API int ggml_cuda_get_device_count(void);
|
||||
GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
|
||||
|
||||
// backend API
|
||||
GGML_API ggml_backend_t ggml_backend_cuda_init(int device);
|
||||
|
||||
GGML_API bool ggml_backend_is_cuda(ggml_backend_t backend);
|
||||
GGML_API int ggml_backend_cuda_get_device(ggml_backend_t backend);
|
||||
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
|
||||
|
||||
// pinned host buffer for use with CPU backend for faster copies between CPU and GPU
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
|
||||
GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
@ -232,7 +232,7 @@ bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml
|
||||
// returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
|
||||
size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
|
||||
|
||||
// returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
|
||||
// returns GGML_HAHSHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
|
||||
size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key);
|
||||
|
||||
// return index, asserts if table is full
|
||||
|
@ -99,12 +99,6 @@ GGML_API ggml_backend_t ggml_backend_metal_init(void);
|
||||
GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
|
||||
|
||||
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
|
||||
|
||||
// helper to check if the device supports a specific family
|
||||
// ideally, the user code should be doing these checks
|
||||
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
||||
GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
1387
ggml-metal.m
1387
ggml-metal.m
File diff suppressed because it is too large
Load Diff
2649
ggml-metal.metal
2649
ggml-metal.metal
File diff suppressed because it is too large
Load Diff
@ -1,18 +1,20 @@
|
||||
#include "ggml.h"
|
||||
#include "ggml-opencl.h"
|
||||
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
|
||||
#define CL_TARGET_OPENCL_VERSION 110
|
||||
#include <clblast.h>
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
@ -19,7 +19,7 @@
|
||||
#ifdef __wasm_simd128__
|
||||
#include <wasm_simd128.h>
|
||||
#else
|
||||
#if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
|
||||
#ifdef __POWER9_VECTOR__
|
||||
#include <altivec.h>
|
||||
#undef bool
|
||||
#define bool _Bool
|
||||
@ -3114,7 +3114,7 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
|
||||
|
||||
size_t vl = __riscv_vsetvl_e8m1(qk/2);
|
||||
|
||||
// These temporary registers are for masking and shift operations
|
||||
// These tempory registers are for masking and shift operations
|
||||
vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
|
||||
vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
|
||||
|
||||
@ -4757,7 +4757,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||
|
||||
vl = 16;
|
||||
|
||||
// retrieve lane to multiply with scale
|
||||
// retreive lane to multiply with scale
|
||||
vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
|
||||
vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
|
||||
vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
|
||||
|
91
ggml.h
91
ggml.h
@ -215,9 +215,9 @@
|
||||
#define GGML_QNT_VERSION_FACTOR 1000 // do not change this
|
||||
|
||||
#define GGML_MAX_DIMS 4
|
||||
#define GGML_MAX_PARAMS 2048
|
||||
#define GGML_MAX_PARAMS 1024
|
||||
#define GGML_MAX_CONTEXTS 64
|
||||
#define GGML_MAX_SRC 10
|
||||
#define GGML_MAX_SRC 6
|
||||
#define GGML_MAX_NAME 64
|
||||
#define GGML_MAX_OP_PARAMS 64
|
||||
#define GGML_DEFAULT_N_THREADS 4
|
||||
@ -244,10 +244,11 @@
|
||||
#define GGML_ASSERT(x) \
|
||||
do { \
|
||||
if (!(x)) { \
|
||||
fflush(stdout); \
|
||||
fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
|
||||
fflush(stderr); \
|
||||
fflush(stdout); \
|
||||
ggml_print_backtrace(); \
|
||||
abort(); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
@ -283,20 +284,6 @@
|
||||
const type prefix##3 = (pointer)->array[3]; \
|
||||
GGML_UNUSED(prefix##3);
|
||||
|
||||
#define GGML_TENSOR_UNARY_OP_LOCALS \
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
#define GGML_TENSOR_BINARY_OP_LOCALS \
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
|
||||
GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
|
||||
GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
@ -395,7 +382,6 @@ extern "C" {
|
||||
GGML_OP_GROUP_NORM,
|
||||
|
||||
GGML_OP_MUL_MAT,
|
||||
GGML_OP_MUL_MAT_ID,
|
||||
GGML_OP_OUT_PROD,
|
||||
|
||||
GGML_OP_SCALE,
|
||||
@ -422,10 +408,8 @@ extern "C" {
|
||||
GGML_OP_CONV_TRANSPOSE_2D,
|
||||
GGML_OP_POOL_1D,
|
||||
GGML_OP_POOL_2D,
|
||||
|
||||
GGML_OP_UPSCALE, // nearest interpolate
|
||||
GGML_OP_PAD,
|
||||
GGML_OP_ARGSORT,
|
||||
GGML_OP_LEAKY_RELU,
|
||||
|
||||
GGML_OP_FLASH_ATTN,
|
||||
GGML_OP_FLASH_FF,
|
||||
@ -465,8 +449,7 @@ extern "C" {
|
||||
GGML_UNARY_OP_GELU,
|
||||
GGML_UNARY_OP_GELU_QUICK,
|
||||
GGML_UNARY_OP_SILU,
|
||||
|
||||
GGML_UNARY_OP_COUNT,
|
||||
GGML_UNARY_OP_LEAKY
|
||||
};
|
||||
|
||||
enum ggml_object_type {
|
||||
@ -649,9 +632,6 @@ extern "C" {
|
||||
GGML_API const char * ggml_op_name (enum ggml_op op);
|
||||
GGML_API const char * ggml_op_symbol(enum ggml_op op);
|
||||
|
||||
GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
|
||||
GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
|
||||
|
||||
GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
|
||||
|
||||
GGML_API bool ggml_is_quantized(enum ggml_type type);
|
||||
@ -794,9 +774,6 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
|
||||
// dst = a
|
||||
// view(dst, nb1, nb2, nb3, offset) += b
|
||||
// return dst
|
||||
GGML_API struct ggml_tensor * ggml_acc(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
@ -961,14 +938,15 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_leaky_relu(
|
||||
GGML_API struct ggml_tensor * ggml_leaky(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a, float negative_slope, bool inplace);
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_relu_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// TODO: double-check this computation is correct
|
||||
GGML_API struct ggml_tensor * ggml_gelu(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
@ -1050,16 +1028,6 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
|
||||
// indirect matrix multiplication
|
||||
// ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
|
||||
GGML_API struct ggml_tensor * ggml_mul_mat_id(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * const as[],
|
||||
int n_as,
|
||||
struct ggml_tensor * ids,
|
||||
int id,
|
||||
struct ggml_tensor * b);
|
||||
|
||||
// A: m columns, n rows,
|
||||
// B: p columns, n rows,
|
||||
// result is m columns, p rows
|
||||
@ -1267,7 +1235,6 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// supports 3D: a->ne[2] == b->ne[1]
|
||||
GGML_API struct ggml_tensor * ggml_get_rows(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
@ -1316,14 +1283,6 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// fused soft_max(a*scale + mask)
|
||||
// mask is optional
|
||||
GGML_API struct ggml_tensor * ggml_soft_max_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * mask,
|
||||
float scale);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_soft_max_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
@ -1554,32 +1513,6 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
int scale_factor);
|
||||
|
||||
// pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
|
||||
GGML_API struct ggml_tensor * ggml_pad(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int p0,
|
||||
int p1,
|
||||
int p2,
|
||||
int p3);
|
||||
|
||||
// sort rows
|
||||
enum ggml_sort_order {
|
||||
GGML_SORT_ASC,
|
||||
GGML_SORT_DESC,
|
||||
};
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_argsort(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_sort_order order);
|
||||
|
||||
// top k elements per row
|
||||
GGML_API struct ggml_tensor * ggml_top_k(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int k);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_flash_attn(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * q,
|
||||
@ -1641,6 +1574,7 @@ extern "C" {
|
||||
int kh);
|
||||
|
||||
// used in sam
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_add_rel_pos(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
@ -1815,7 +1749,7 @@ extern "C" {
|
||||
GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
|
||||
GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads);
|
||||
GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
|
||||
GGML_API struct ggml_cgraph ggml_graph_view (struct ggml_cgraph * cgraph, int i0, int i1);
|
||||
GGML_API struct ggml_cgraph * ggml_graph_view (struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i0, int i1);
|
||||
GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
|
||||
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads
|
||||
GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
|
||||
@ -2111,7 +2045,6 @@ extern "C" {
|
||||
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
|
||||
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
|
||||
GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
|
||||
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
|
||||
|
216
whisper.cpp
216
whisper.cpp
@ -1063,7 +1063,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
if (params.use_gpu && ggml_cublas_loaded()) {
|
||||
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
|
||||
backend_gpu = ggml_backend_cuda_init(0);
|
||||
backend_gpu = ggml_backend_cuda_init();
|
||||
if (!backend_gpu) {
|
||||
WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
|
||||
}
|
||||
@ -1077,10 +1077,6 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
|
||||
backend_gpu = ggml_backend_metal_init();
|
||||
if (!backend_gpu) {
|
||||
WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
|
||||
} else if (!ggml_backend_metal_supports_family(backend_gpu, 7)) {
|
||||
WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
|
||||
ggml_backend_free(backend_gpu);
|
||||
backend_gpu = NULL;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@ -1345,10 +1341,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
||||
model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
|
||||
|
||||
model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
|
||||
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
|
||||
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*n_audio_ctx, n_audio_state);
|
||||
|
||||
model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
|
||||
model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
|
||||
model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_ctx, n_audio_state);
|
||||
|
||||
model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
||||
model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
||||
@ -1578,25 +1574,29 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
||||
|
||||
auto tensor = model.tensors[name.data()];
|
||||
|
||||
if (ggml_nelements(tensor) != nelements) {
|
||||
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
||||
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
|
||||
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
|
||||
return false;
|
||||
}
|
||||
const bool is_conv_bias = (name == "encoder.conv1.bias" || name == "encoder.conv2.bias");
|
||||
|
||||
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
|
||||
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
|
||||
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
|
||||
return false;
|
||||
}
|
||||
if (!is_conv_bias) {
|
||||
if (ggml_nelements(tensor) != nelements) {
|
||||
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
||||
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
|
||||
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t bpe = ggml_type_size(ggml_type(ttype));
|
||||
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
|
||||
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
|
||||
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
|
||||
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
||||
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
||||
return false;
|
||||
const size_t bpe = ggml_type_size(ggml_type(ttype));
|
||||
|
||||
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
|
||||
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
||||
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_t backend = wctx.backend;
|
||||
@ -1607,7 +1607,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
||||
#ifdef GGML_USE_METAL
|
||||
|| ggml_backend_is_metal(backend)
|
||||
#endif
|
||||
)) {
|
||||
) && !is_conv_bias) {
|
||||
// for the CPU and Metal backend, we can read directly into the tensor
|
||||
loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
|
||||
BYTESWAP_TENSOR(tensor);
|
||||
@ -1615,7 +1615,24 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
||||
// read into a temporary buffer first, then copy to device memory
|
||||
read_buf.resize(ggml_nbytes(tensor));
|
||||
|
||||
loader->read(loader->context, read_buf.data(), read_buf.size());
|
||||
// we repeat the 2 bias tensors along dim 0:
|
||||
// [1, 512] -> [3000, 512] (conv1.bias)
|
||||
// [1, 512] -> [1500, 512] (conv2.bias)
|
||||
if (is_conv_bias) {
|
||||
loader->read(loader->context, read_buf.data(), read_buf.size() / tensor->ne[0]);
|
||||
|
||||
float * data_f32 = (float *) read_buf.data();
|
||||
for (int64_t y = 0; y < tensor->ne[1]; ++y) {
|
||||
const int64_t yy = tensor->ne[1] - y - 1;
|
||||
const float val = data_f32[yy];
|
||||
|
||||
for (int64_t x = 0; x < tensor->ne[0]; ++x) {
|
||||
data_f32[yy*tensor->ne[0] + x] = val;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
loader->read(loader->context, read_buf.data(), read_buf.size());
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
|
||||
}
|
||||
@ -1715,12 +1732,20 @@ static struct ggml_cgraph * whisper_build_graph_conv(
|
||||
// convolution + gelu
|
||||
{
|
||||
cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
|
||||
cur = ggml_add(ctx0, cur, model.e_conv_1_b);
|
||||
if (n_ctx == hparams.n_audio_ctx) {
|
||||
cur = ggml_add(ctx0, cur, model.e_conv_1_b);
|
||||
} else {
|
||||
cur = ggml_add(ctx0, cur, ggml_cont(ctx0, ggml_view_2d(ctx0, model.e_conv_1_b, cur->ne[0], cur->ne[1], model.e_conv_1_b->nb[1], 0)));
|
||||
}
|
||||
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
|
||||
cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
|
||||
cur = ggml_add(ctx0, cur, model.e_conv_2_b);
|
||||
if (n_ctx == hparams.n_audio_ctx) {
|
||||
cur = ggml_add(ctx0, cur, model.e_conv_2_b);
|
||||
} else {
|
||||
cur = ggml_add(ctx0, cur, ggml_cont(ctx0, ggml_view_2d(ctx0, model.e_conv_2_b, cur->ne[0], cur->ne[1], model.e_conv_2_b->nb[1], 0)));
|
||||
}
|
||||
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
}
|
||||
@ -3568,17 +3593,6 @@ const char * whisper_lang_str(int id) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const char * whisper_lang_str_full(int id) {
|
||||
for (const auto & kv : g_lang) {
|
||||
if (kv.second.first == id) {
|
||||
return kv.second.second.c_str();
|
||||
}
|
||||
}
|
||||
|
||||
WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int whisper_lang_auto_detect_with_state(
|
||||
struct whisper_context * ctx,
|
||||
struct whisper_state * state,
|
||||
@ -5028,7 +5042,6 @@ int whisper_full_with_state(
|
||||
// basically don't process anything that is less than 1.0s
|
||||
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
|
||||
if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
|
||||
WHISPER_PRINT_DEBUG("%s: input is too short - %d ms < 1000 ms\n", __func__, (seek_end - seek_start)*10);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -5167,7 +5180,7 @@ int whisper_full_with_state(
|
||||
ctx, state, progress_cur, params.progress_callback_user_data);
|
||||
}
|
||||
|
||||
// if only 1 second left, then stop
|
||||
// of only 1 second left, then stop
|
||||
if (seek + 100 >= seek_end) {
|
||||
break;
|
||||
}
|
||||
@ -5456,7 +5469,6 @@ int whisper_full_with_state(
|
||||
|
||||
// do not allow to go back in time
|
||||
if (has_ts && seek_delta > seek_delta_new && result_len < i) {
|
||||
WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to seek_delta (%d > %d)\n", __func__, j, seek_delta, seek_delta_new);
|
||||
failed = true; // TODO: maybe this is not a failure ?
|
||||
continue;
|
||||
}
|
||||
@ -5485,7 +5497,6 @@ int whisper_full_with_state(
|
||||
if (seek + seek_delta + 100 >= seek_end) {
|
||||
result_len = i + 1;
|
||||
} else {
|
||||
WHISPER_PRINT_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
|
||||
failed = true;
|
||||
continue;
|
||||
}
|
||||
@ -5496,7 +5507,6 @@ int whisper_full_with_state(
|
||||
seek_delta = 100*WHISPER_CHUNK_SIZE;
|
||||
}
|
||||
|
||||
WHISPER_PRINT_DEBUG("%s: decoder %d completed\n", __func__, j);
|
||||
completed = true;
|
||||
continue;
|
||||
}
|
||||
@ -5512,7 +5522,6 @@ int whisper_full_with_state(
|
||||
// sometimes, the decoding can get stuck in a repetition loop
|
||||
// this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
|
||||
if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
|
||||
WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to repetition loop\n", __func__, j);
|
||||
failed = true;
|
||||
continue;
|
||||
}
|
||||
@ -5656,27 +5665,28 @@ int whisper_full_with_state(
|
||||
WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
|
||||
}
|
||||
|
||||
bool success = true;
|
||||
|
||||
// was the decoding successful for the current temperature?
|
||||
// do fallback only if:
|
||||
// - we are not at the last temperature
|
||||
if (it != (int) temperatures.size() - 1) {
|
||||
// - we are not at the end of the audio (3 sec)
|
||||
if (it != (int) temperatures.size() - 1 &&
|
||||
seek_end - seek > 10*WHISPER_CHUNK_SIZE) {
|
||||
bool success = true;
|
||||
|
||||
const auto & decoder = state->decoders[best_decoder_id];
|
||||
|
||||
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
|
||||
WHISPER_PRINT_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
|
||||
success = false;
|
||||
state->n_fail_p++;
|
||||
}
|
||||
}
|
||||
|
||||
if (success) {
|
||||
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
|
||||
// WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
|
||||
//}
|
||||
if (success) {
|
||||
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
|
||||
// WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
|
||||
//}
|
||||
|
||||
break;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
|
||||
@ -6054,43 +6064,6 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
|
||||
// 1GB array
|
||||
const size_t size = arr*1e6;
|
||||
|
||||
double sum = 0.0;
|
||||
|
||||
// heat-up
|
||||
{
|
||||
char * src = (char *) malloc(size);
|
||||
char * dst = (char *) malloc(size);
|
||||
|
||||
for (size_t i = 0; i < size; i++) src[i] = i;
|
||||
|
||||
memcpy(dst, src, size); // heat-up
|
||||
|
||||
double tsum = 0.0;
|
||||
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
const int64_t t0 = ggml_time_us();
|
||||
|
||||
memcpy(dst, src, size);
|
||||
|
||||
const int64_t t1 = ggml_time_us();
|
||||
|
||||
tsum += (t1 - t0)*1e-6;
|
||||
|
||||
src[rand() % size] = rand() % 256;
|
||||
}
|
||||
|
||||
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (heat-up)\n", (double) (n*size)/(tsum*1e9));
|
||||
s += strbuf;
|
||||
|
||||
// needed to prevent the compiler from optimizing the memcpy away
|
||||
{
|
||||
for (size_t i = 0; i < size; i++) sum += dst[i];
|
||||
}
|
||||
|
||||
free(src);
|
||||
free(dst);
|
||||
}
|
||||
|
||||
// single-thread
|
||||
{
|
||||
char * src = (char *) malloc(size);
|
||||
@ -6101,6 +6074,7 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
|
||||
memcpy(dst, src, size); // heat-up
|
||||
|
||||
double tsum = 0.0;
|
||||
double sum = 0.0;
|
||||
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
const int64_t t0 = ggml_time_us();
|
||||
@ -6114,73 +6088,21 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
|
||||
src[rand() % size] = rand() % 256;
|
||||
}
|
||||
|
||||
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s ( 1 thread)\n", (double) (n*size)/(tsum*1e9));
|
||||
snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s (1 thread)\n", (double) (n*size)/(tsum*1e9));
|
||||
s += strbuf;
|
||||
|
||||
// needed to prevent the compiler from optimizing the memcpy away
|
||||
{
|
||||
for (size_t i = 0; i < size; i++) sum += dst[i];
|
||||
|
||||
snprintf(strbuf, sizeof(strbuf), "sum: %f\n", sum);
|
||||
s += strbuf;
|
||||
}
|
||||
|
||||
free(src);
|
||||
free(dst);
|
||||
}
|
||||
|
||||
// multi-thread
|
||||
|
||||
for (uint32_t k = 1; k <= n_threads; k++) {
|
||||
char * src = (char *) malloc(size);
|
||||
char * dst = (char *) malloc(size);
|
||||
|
||||
for (size_t i = 0; i < size; i++) src[i] = i;
|
||||
|
||||
memcpy(dst, src, size); // heat-up
|
||||
|
||||
double tsum = 0.0;
|
||||
|
||||
auto helper = [&](int th) {
|
||||
const int64_t i0 = (th + 0)*size/k;
|
||||
const int64_t i1 = (th + 1)*size/k;
|
||||
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
memcpy(dst + i0, src + i0, i1 - i0);
|
||||
|
||||
src[i0 + rand() % (i1 - i0)] = rand() % 256;
|
||||
};
|
||||
};
|
||||
|
||||
const int64_t t0 = ggml_time_us();
|
||||
|
||||
std::vector<std::thread> threads(k - 1);
|
||||
for (uint32_t th = 0; th < k - 1; ++th) {
|
||||
threads[th] = std::thread(helper, th);
|
||||
}
|
||||
|
||||
helper(k - 1);
|
||||
|
||||
for (uint32_t th = 0; th < k - 1; ++th) {
|
||||
threads[th].join();
|
||||
}
|
||||
|
||||
const int64_t t1 = ggml_time_us();
|
||||
|
||||
tsum += (t1 - t0)*1e-6;
|
||||
|
||||
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double) (n*size)/(tsum*1e9), k);
|
||||
s += strbuf;
|
||||
|
||||
// needed to prevent the compiler from optimizing the memcpy away
|
||||
{
|
||||
for (size_t i = 0; i < size; i++) sum += dst[i];
|
||||
}
|
||||
|
||||
free(src);
|
||||
free(dst);
|
||||
}
|
||||
|
||||
snprintf(strbuf, sizeof(strbuf), "sum: %f\n", sum);
|
||||
s += strbuf;
|
||||
|
||||
return s.c_str();
|
||||
}
|
||||
|
||||
|
@ -51,7 +51,7 @@ extern "C" {
|
||||
// ...
|
||||
//
|
||||
// whisper_context_params cparams = whisper_context_default_params();
|
||||
//
|
||||
//
|
||||
// struct whisper_context * ctx = whisper_init_from_file_with_params("/path/to/ggml-base.en.bin", cparams);
|
||||
//
|
||||
// if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
@ -315,9 +315,6 @@ extern "C" {
|
||||
// Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
|
||||
WHISPER_API const char * whisper_lang_str(int id);
|
||||
|
||||
// Return the short string of the specified language name (e.g. 2 -> "german"), returns nullptr if not found
|
||||
WHISPER_API const char * whisper_lang_str_full(int id);
|
||||
|
||||
// Use mel data at offset_ms to try and auto-detect the spoken language
|
||||
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
|
||||
// Returns the top language id or negative on failure
|
||||
|
Reference in New Issue
Block a user