Compare commits

..

11 Commits

Author SHA1 Message Date
4260d4fc70 wchess : minor 2023-11-28 15:10:18 +02:00
ee65df7982 wchess : add clear_audio callback 2023-11-28 13:37:26 +02:00
03f254193b wchess: hardcoded rules 2023-11-27 10:51:20 +02:00
8f2d8eae10 wchess: basic chess rules 2023-11-27 10:41:04 +02:00
a44b21bce0 wchess: tidy up entry files 2023-11-25 11:34:06 +02:00
f07ff2aa6a chess -> wchess 2023-11-25 10:16:48 +02:00
280e631bcf chess.wasm: poc of chess rules 2023-11-23 16:09:00 +02:00
2f86da0d09 chess.wasm: add chessboard 2023-11-23 08:49:47 +02:00
a787f7f85c chess.wasm: encoder context value resulting in echoing 2023-11-21 20:42:20 +02:00
c83a38e89d chess.wasm: go back to greedy 2023-11-21 16:56:22 +02:00
758c951729 chess.wasm: grammar in emscripten 2023-11-21 16:30:44 +02:00
49 changed files with 2490 additions and 8721 deletions

View File

@@ -25,7 +25,6 @@ jobs:
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
apt update
apt install -y build-essential libsdl2-dev
make
@@ -87,7 +86,6 @@ jobs:
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
apt update
apt install -y build-essential cmake libsdl2-dev
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
@@ -115,10 +113,8 @@ jobs:
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
apt update
apt install -y clang
apt install -y build-essential cmake libsdl2-dev
apt install -y clang build-essential cmake libsdl2-dev
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
make
ctest -L gh --output-on-failure'
@@ -144,7 +140,6 @@ jobs:
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
apt update
apt install -y build-essential cmake
cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
@@ -222,10 +217,10 @@ jobs:
sdl2: [ON]
include:
- arch: Win32
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.25/OpenBLAS-0.3.25-x86.zip
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x86.zip
s2arc: x86
- arch: x64
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.25/OpenBLAS-0.3.25-x64.zip
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x64.zip
s2arc: x64
- sdl2: ON
s2ver: 2.26.0
@@ -290,7 +285,6 @@ jobs:
arch: [x64]
cublas: [ON]
sdl2: [ON]
cuda-toolkit: [12.2.0, 11.8.0]
include:
- arch: x64
s2arc: x64
@@ -306,9 +300,7 @@ jobs:
- name: Install CUDA Toolkit
id: cuda-toolkit
uses: Jimver/cuda-toolkit@v0.2.11
uses: Jimver/cuda-toolkit@v0.2.10
with:
cuda: '${{ matrix.cuda-toolkit }}'
- name: Fetch SDL2 and set SDL2_DIR
if: matrix.sdl2 == 'ON'
@@ -323,10 +315,10 @@ jobs:
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DWHISPER_CUBLAS=1
- name: Build ${{ matrix.cuda-toolkit }}
- name: Build
run: |
cd ./build
cmake --build . --config ${{ matrix.build }}
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
- name: Copy CUDA DLLs
run: >
@@ -343,7 +335,7 @@ jobs:
if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v1
with:
name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}
name: whisper-cublas-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }}
emscripten:
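
Aside (not part of the diff): each of the Linux CI jobs above amounts to running the build inside an emulated-architecture Ubuntu container. A rough local equivalent, assuming `ubuntu:22.04` stands in for `env.ubuntu_image`, `linux/arm64` for `matrix.arch`, and that QEMU/binfmt emulation is already set up, would be:

```bash
# Sketch only: image tag and platform are assumptions, not values taken from the workflow.
docker run --platform linux/arm64 --rm \
  -v "$PWD":/workspace \
  -w /workspace ubuntu:22.04 /bin/sh -c '
    apt update
    apt install -y build-essential libsdl2-dev
    make'
```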

View File

@@ -1,6 +1,6 @@
cmake_minimum_required (VERSION 3.5)
project(whisper.cpp VERSION 1.5.2)
project(whisper.cpp VERSION 1.5.0)
# Add path to modules
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
@@ -533,7 +533,7 @@ target_compile_definitions(${TARGET} PUBLIC
${WHISPER_EXTRA_FLAGS}
)
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "ggml.h;whisper.h")
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "whisper.h")
include(GNUInstallDirs)

View File

@@ -2,14 +2,33 @@
import PackageDescription
let package = Package(
#if arch(arm) || arch(arm64)
name: "whisper",
let platforms: [SupportedPlatform]? = [
platforms: [
.macOS(.v12),
.iOS(.v14),
.watchOS(.v4),
.tvOS(.v14)
],
]
let exclude: [String] = []
let resources: [Resource] = [
.process("ggml-metal.metal")
]
let additionalSources: [String] = ["ggml-metal.m"]
let additionalSettings: [CSetting] = [
.unsafeFlags(["-fno-objc-arc"]),
.define("GGML_USE_METAL")
]
#else
let platforms: [SupportedPlatform]? = nil
let exclude: [String] = ["ggml-metal.metal"]
let resources: [Resource] = []
let additionalSources: [String] = []
let additionalSettings: [CSetting] = []
#endif
let package = Package(
name: "whisper",
platforms: platforms,
products: [
.library(name: "whisper", targets: ["whisper"]),
],
@@ -17,7 +36,7 @@ let package = Package(
.target(
name: "whisper",
path: ".",
exclude: [
exclude: exclude + [
"bindings",
"cmake",
"coreml",
@@ -36,22 +55,19 @@ let package = Package(
"whisper.cpp",
"ggml-alloc.c",
"ggml-backend.c",
"ggml-quants.c",
"ggml-quants.c"
"ggml-metal.m"
] + additionalSources,
],
resources: resources,
resources: [.process("ggml-metal.metal")],
publicHeadersPath: "spm-headers",
cSettings: [
.unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]),
.define("GGML_USE_ACCELERATE"),
.define("GGML_USE_ACCELERATE")
.unsafeFlags(["-fno-objc-arc"]),
.define("GGML_USE_METAL")
// NOTE: NEW_LAPACK will required iOS version 16.4+
// We should consider add this in the future when we drop support for iOS 14
// (ref: ref: https://developer.apple.com/documentation/accelerate/1513264-cblas_sgemm?language=objc)
// .define("ACCELERATE_NEW_LAPACK"),
// .define("ACCELERATE_LAPACK_ILP64")
],
] + additionalSettings,
linkerSettings: [
.linkedFramework("Accelerate")
]

View File

@@ -6,7 +6,7 @@
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Stable: [v1.5.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.5.2) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
Stable: [v1.5.0](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.5.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
@@ -110,8 +110,8 @@ options:
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-sow, --split-on-word [false ] split on word rather than on token
-bo N, --best-of N [5 ] number of best candidates to keep
-bo N, --best-of N [2 ] number of best candidates to keep
-bs N, --beam-size N [5 ] beam size for beam search
-bs N, --beam-size N [-1 ] beam size for beam search
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
@@ -128,7 +128,6 @@ options:
-fp, --font-path [/System/Library/Fonts/Supplemental/Courier New Bold.ttf] path to a monospace font for karaoke video
-ocsv, --output-csv [false ] output result in a CSV file
-oj, --output-json [false ] output result in a JSON file
-ojf, --output-json-full [false ] include more information in the JSON file
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
@@ -140,8 +139,7 @@ options:
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
-ls, --log-score [false ] log best decoder scores of tokens
-ls, --log-score [false ] log best decoder scores of token
-ng, --no-gpu [false ] disable GPU
bash ./models/download-ggml-model.sh base.en
@@ -770,7 +768,6 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
| [bench](examples/bench) | [bench.wasm](examples/bench.wasm) | Benchmark the performance of Whisper on your machine |
| [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
| [wchess](examples/wchess) | [wchess.wasm](examples/wchess) | Voice-controlled chess |
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
| [talk-llama](examples/talk-llama) | | Talk with a LLaMA bot |
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
@@ -780,7 +777,6 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggerganov/whisper.cpp/issues/185) |
| [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) |
| [server](examples/server) | | HTTP transcription server with OAI-like API |
## [Discussions](https://github.com/ggerganov/whisper.cpp/discussions)

View File

@@ -1,26 +1,9 @@
ifndef UNAME_S
UNAME_S := $(shell uname -s)
endif
ifndef UNAME_P
UNAME_P := $(shell uname -p)
endif
ifndef UNAME_M
UNAME_M := $(shell uname -m)
endif
GGML_METAL_PATH_RESOURCES := $(abspath ../..)
BUILD_DIR := build
MODELS_DIR := models
EXAMPLES_DIR := $(wildcard examples/*)
INCLUDE_PATH := $(abspath ../..)
LIBRARY_PATH := $(abspath ../..)
ifeq ($(UNAME_S),Darwin)
EXT_LDFLAGS := -framework Foundation -framework Metal -framework MetalKit
endif
all: clean whisper examples
whisper: mkdir
@@ -28,13 +11,8 @@ whisper: mkdir
@${MAKE} -C ../.. libwhisper.a
test: model-small whisper modtidy
ifeq ($(UNAME_S),Darwin)
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -v .
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -v ./pkg/whisper/...
else
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
endif
examples: $(EXAMPLES_DIR)
@@ -43,11 +21,7 @@ model-small: mkdir examples/go-model-download
$(EXAMPLES_DIR): mkdir whisper modtidy
@echo Build example $(notdir $@)
ifeq ($(UNAME_S),Darwin)
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go build ${BUILD_FLAGS} -ldflags "-extldflags '$(EXT_LDFLAGS)'" -o ${BUILD_DIR}/$(notdir $@) ./$@
else
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
endif
mkdir:
@echo Mkdir ${BUILD_DIR}

View File

@ -1,6 +1,6 @@
{ {
"name": "whisper.cpp", "name": "whisper.cpp",
"version": "1.5.2", "version": "1.5.0",
"description": "Whisper speech recognition", "description": "Whisper speech recognition",
"main": "whisper.js", "main": "whisper.js",
"scripts": { "scripts": {

File diff suppressed because one or more lines are too long

View File

@@ -22,7 +22,6 @@ var printTextarea = (function() {
async function clearCache() {
if (confirm('Are you sure you want to clear the cache?\nAll the models will be downloaded again.')) {
indexedDB.deleteDatabase(dbName);
location.reload();
}
}

View File

@@ -17,37 +17,28 @@ options:
-d N, --duration N [0 ] duration of audio to process in milliseconds
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-sow, --split-on-word [false ] split on word rather than on token
-bo N, --best-of N [5 ] number of best candidates to keep
-bs N, --beam-size N [5 ] beam size for beam search
-bs N, --beam-size N [-1 ] beam size for beam search
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
-tr, --translate [false ] translate from source language to english
-di, --diarize [false ] stereo audio diarization
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
-nf, --no-fallback [false ] do not use temperature fallback while decoding
-otxt, --output-txt [false ] output result in a text file
-ovtt, --output-vtt [false ] output result in a vtt file
-osrt, --output-srt [false ] output result in a srt file
-olrc, --output-lrc [false ] output result in a lrc file
-owts, --output-words [false ] output script for generating karaoke video
-fp, --font-path [/System/Library/Fonts/Supplemental/Courier New Bold.ttf] path to a monospace font for karaoke video
-ocsv, --output-csv [false ] output result in a CSV file
-oj, --output-json [false ] output result in a JSON file
-ojf, --output-json-full [false ] include more information in the JSON file
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [false ] do not print timestamps
-nt, --no-timestamps [true ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
-dl, --detect-language [false ] exit after automatically detecting language
--prompt PROMPT [ ] initial prompt
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
-ls, --log-score [false ] log best decoder scores of tokens
-ng, --no-gpu [false ] disable GPU
```

View File

@@ -4,9 +4,3 @@ add_executable(${TARGET} server.cpp httplib.h json.hpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})
# Check if the compiler is MinGW
if(MINGW)
# Link the necessary libraries for SSL and Winsock
target_link_libraries(${TARGET} PRIVATE -lcrypt32 -lssl -lcrypto -lws2_32)
endif()

View File

@@ -2,10 +2,6 @@
Simple http server. WAV Files are passed to the inference model via http requests.
https://github.com/ggerganov/whisper.cpp/assets/1991296/e983ee53-8741-4eb5-9048-afe5e4594b8f
## Usage
```
./server -h
@@ -33,7 +29,6 @@ options:
-nf, --no-fallback [false ] do not use temperature fallback while decoding
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pr, --print-realtime [false ] print output in realtime
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [false ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
@@ -43,12 +38,8 @@ options:
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
--host HOST, [127.0.0.1] Hostname/ip-adress for the server
--port PORT, [8080 ] Port number for the server
--convert, [false ] Convert audio to WAV, requires ffmpeg on the server
```
> [!WARNING]
> **Do not run the server example with administrative privileges and ensure it's operated in a sandbox environment, especially since it involves risky operations like accepting user file uploads and using ffmpeg for format conversions. Always validate and sanitize inputs to guard against potential security threats.**
## request examples
**/inference**
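
Aside (not part of the diff): going by the handlers visible in `server.cpp` further down, the `/inference` endpoint takes a multipart form upload with a `file` field plus optional fields such as `temperature` and `response-format`. A request against the default host and port shown above might look like this, with the sample path being purely illustrative:

```
# Sketch only: the WAV path is an assumption; field names follow the server.cpp handlers in this diff.
curl 127.0.0.1:8080/inference \
  -H "Content-Type: multipart/form-data" \
  -F file="@samples/jfk.wav" \
  -F response-format="json"
```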

View File

@@ -11,7 +11,6 @@
#include <thread>
#include <vector>
#include <cstring>
#include <sstream>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
@@ -44,8 +43,6 @@ struct server_params
int32_t port = 8080;
int32_t read_timeout = 600;
int32_t write_timeout = 600;
bool ffmpeg_converter = false;
};
struct whisper_params {
@@ -75,7 +72,6 @@ struct whisper_params {
bool no_fallback = false;
bool print_special = false;
bool print_colors = false;
bool print_realtime = false;
bool print_progress = false;
bool no_timestamps = false;
bool use_gpu = true;
@@ -148,7 +144,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pr, --print-realtime [%-7s] print output in realtime\n", params.print_realtime ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
@@ -160,7 +155,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str());
fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port);
fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str());
fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
fprintf(stderr, "\n");
}
@@ -194,7 +188,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pr" || arg == "--print-realtime") { params.print_realtime = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
@@ -207,7 +200,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
else if ( arg == "--host") { sparams.hostname = argv[++i]; }
else if ( arg == "--public") { sparams.public_path = argv[++i]; }
else if ( arg == "--convert") { sparams.ffmpeg_converter = true; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params, sparams);
@@ -225,45 +217,6 @@ struct whisper_print_user_data {
int progress_prev;
};
void check_ffmpeg_availibility() {
int result = system("ffmpeg -version");
if (result == 0) {
std::cout << "ffmpeg is available." << std::endl;
} else {
// ffmpeg is not available
std::cout << "ffmpeg is not found. Please ensure that ffmpeg is installed ";
std::cout << "and that its executable is included in your system's PATH. ";
exit(0);
}
}
bool convert_to_wav(const std::string & temp_filename, std::string & error_resp) {
std::ostringstream cmd_stream;
std::string converted_filename_temp = temp_filename + "_temp.wav";
cmd_stream << "ffmpeg -i \"" << temp_filename << "\" -ar 16000 -ac 1 -c:a pcm_s16le \"" << converted_filename_temp << "\" 2>&1";
std::string cmd = cmd_stream.str();
int status = std::system(cmd.c_str());
if (status != 0) {
error_resp = "{\"error\":\"FFmpeg conversion failed.\"}";
return false;
}
// Remove the original file
if (remove(temp_filename.c_str()) != 0) {
error_resp = "{\"error\":\"Failed to remove the original file.\"}";
return false;
}
// Rename the temporary file to match the original filename
if (rename(converted_filename_temp.c_str(), temp_filename.c_str()) != 0) {
error_resp = "{\"error\":\"Failed to rename the temporary file.\"}";
return false;
}
return true;
}
std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
std::string speaker = "";
const int64_t n_samples = pcmf32s[0].size();
@@ -420,7 +373,7 @@ void get_req_parameters(const Request & req, whisper_params & params)
{
params.response_format = req.get_file_value("response-format").content;
}
if (req.has_file("temperature"))
if (req.has_file("temerature"))
{
params.userdef_temp = std::stof(req.get_file_value("temperature").content);
}
@@ -451,9 +404,6 @@ int main(int argc, char ** argv) {
exit(0);
}
if (sparams.ffmpeg_converter) {
check_ffmpeg_availibility();
}
// whisper init
struct whisper_context_params cparams;
cparams.use_gpu = params.use_gpu;
@@ -469,9 +419,6 @@ int main(int argc, char ** argv) {
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
Server svr;
svr.set_default_headers({{"Server", "whisper.cpp"},
{"Access-Control-Allow-Origin", "*"},
{"Access-Control-Allow-Headers", "content-type"}});
std::string const default_content = "<html>hello</html>";
@@ -482,7 +429,7 @@ int main(int argc, char ** argv) {
});
svr.Post("/inference", [&](const Request &req, Response &res){
// acquire whisper model mutex lock
// aquire whisper model mutex lock
whisper_mutex.lock();
// first check user requested fields of the request
@@ -506,35 +453,20 @@ int main(int argc, char ** argv) {
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// write to temporary file
// write file to temporary file
const std::string temp_filename = "whisper_server_temp_file.wav";
std::ofstream temp_file{filename, std::ios::binary};
std::ofstream temp_file{temp_filename, std::ios::binary};
temp_file << audio_file.content;
temp_file.close();
// if file is not wav, convert to wav
if (sparams.ffmpeg_converter) {
std::string error_resp = "{\"error\":\"Failed to execute ffmpeg command.\"}";
const bool is_converted = convert_to_wav(temp_filename, error_resp);
if (!is_converted) {
res.set_content(error_resp, "application/json");
whisper_mutex.unlock();
return;
}
}
// read wav content into pcmf32
if (!::read_wav(temp_filename, pcmf32, pcmf32s, params.diarize)) {
if (!::read_wav(filename, pcmf32, pcmf32s, params.diarize)) {
fprintf(stderr, "error: failed to read WAV file '%s'\n", temp_filename.c_str());
fprintf(stderr, "error: failed to read WAV file '%s'\n", filename.c_str());
const std::string error_resp = "{\"error\":\"failed to read WAV file\"}";
res.set_content(error_resp, "application/json");
std::remove(temp_filename.c_str());
whisper_mutex.unlock();
return;
}
// remove temp file
std::remove(temp_filename.c_str());
std::remove(filename.c_str());
printf("Successfully loaded %s\n", filename.c_str());
@@ -571,6 +503,7 @@ int main(int argc, char ** argv) {
// run the inference
{
printf("Running whisper.cpp inference on %s\n", filename.c_str());
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
@@ -589,7 +522,6 @@ int main(int argc, char ** argv) {
wparams.duration_ms = params.duration_ms;
wparams.thold_pt = params.word_thold;
wparams.max_len = params.max_len == 0 ? 60 : params.max_len;
wparams.split_on_word = params.split_on_word;
wparams.speed_up = params.speed_up;
@@ -609,7 +541,7 @@ int main(int argc, char ** argv) {
whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
// this callback is called on each new segment
if (params.print_realtime) {
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &user_data;
}
@@ -659,50 +591,6 @@ int main(int argc, char ** argv) {
std::string results = output_str(ctx, params, pcmf32s);
res.set_content(results.c_str(), "text/html");
}
else if (params.response_format == srt_format)
{
std::stringstream ss;
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
ss << i + 1 + params.offset_n << "\n";
ss << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
ss << speaker << text << "\n\n";
}
res.set_content(ss.str(), "application/x-subrip");
} else if (params.response_format == vtt_format) {
std::stringstream ss;
ss << "WEBVTT\n\n";
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true);
speaker.insert(0, "<v Speaker");
speaker.append(">");
}
ss << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
ss << speaker << text << "\n\n";
}
res.set_content(ss.str(), "text/vtt");
}
// TODO add more output formats
else
{

View File

@@ -18,11 +18,6 @@ if (WHISPER_SDL2)
../../ggml-quants.c
../../whisper.cpp)
if(WIN32)
# It requires Windows 8.1 or later for PrefetchVirtualMemory
target_compile_definitions(${TARGET} PRIVATE -D_WIN32_WINNT=0x0602)
endif()
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})

View File

@@ -1,40 +0,0 @@
# wchess
Voice-controlled chess using Whisper
Online demo: https://whisper.ggerganov.com/wchess/
https://github.com/ggerganov/whisper.cpp/assets/1991296/c2b2f03c-9684-49f3-8106-357d2d4e67fa
## Command-line tool
```bash
mkdir build && cd build
cmake -DWHISPER_SDL2=1 ..
make -j
./bin/wchess -m ../models/ggml-base.en.bin
Move: start
a b c d e f g h
r n b q k b n r 8
p p p p p p p p 7
. * . * . * . * 6
* . * . * . * . 5
. * . * . * . * 4
* . * . * . * . 3
P P P P P P P P 2
R N B Q K B N R 1
White's turn
[(l)isten/(p)ause/(q)uit]:
```
## TODO
- Improve web-browser audio capture - sometimes it does not record the voice properly
- Add support for more languages by making the generated grammar string multi-lingual
- Fix bugs in the chess moves logic
PRs welcome!

View File

@@ -1,19 +1,19 @@
add_library(wchess-core STATIC
add_library(libwchess
WChess.cpp
WChess.h
Chessboard.cpp
Chessboard.h
)
target_link_libraries(wchess-core
target_link_libraries(libwchess
PUBLIC
whisper
common
)
target_include_directories(wchess-core
target_include_directories(libwchess
PUBLIC
"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>"
)
# add_executable(test-chessboard test-chessboard.cpp Chessboard.cpp)
add_executable(test-chessboard test-chessboard.cpp Chessboard.cpp)

File diff suppressed because it is too large

View File

@@ -1,33 +1,56 @@
#pragma once
#include <string>
#include <set>
#include <array>
#include <memory>
#include <vector>
// just basic validation
// fixme: missing en passant, castling, promotion, etc.
struct State;
class Piece;
class Chessboard {
public:
Chessboard();
~Chessboard();
std::string process(const std::string& t);
std::string process(const std::string& command);
std::string stringifyBoard();
const std::string& grammar() { return m_grammar; }
const std::string& prompt() { return m_prompt; }
void setPrompt(const std::string& prompt);
private:
bool parseCommand(const std::string& command, Piece*& piece, char& pos_to);
using Move = std::pair<int, int>;
bool move(Piece& piece, char pos);
bool move(const Move& move);
void flagUpdates(char pos_from, char pos_to);
void updatePins(Piece& piece);
void detectChecks();
void setGrammar();
std::unique_ptr<State> m_state;
std::set<char> m_allowedInCheck;
bool m_inCheck = false;
int m_moveCounter = 0;
std::string m_grammar;
std::string m_prompt;
struct Piece {
enum Types {
Pawn,
Knight,
Bishop,
Rook,
Queen,
King,
Taken,
};
enum Colors {
Black,
White
};
Types type;
Colors color;
int pos;
};
Piece::Types tokenToType(std::string_view token);
size_t tokenToPos(std::string_view token);
using PieceSet = std::array<Piece, 16>;
PieceSet blackPieces;
PieceSet whitePieces;
int m_moveCounter = 0;
using Board = std::array<Piece*, 64>;
Board board;
bool validateMove(const Piece& piece, int pos);
// just basic validation
// fixme: missing en passant, castling, promotion, etc.
bool validatePawnMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
bool validateKnightMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
bool validateBishopMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
bool validateRookMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
bool validateQueenMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
bool validateKingMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
};

View File

@@ -4,6 +4,24 @@
#include "common.h"
#include <thread>
static constexpr auto RULES =
"\n"
"root ::= init move move? move? \".\"\n"
"prompt ::= init \".\"\n"
"\n"
"# leading space is very important!\n"
"init ::= \" rook to b4, f3\"\n"
"\n"
"move ::= \", \" ((piece | pawn | king) \" \" \"to \"?)? [a-h] [1-8]\n"
"\n"
"piece ::= \"bishop\" | \"rook\" | \"knight\" | \"queen\"\n"
"king ::= \"king\"\n"
"pawn ::= \"pawn\"\n"
"\n";
static constexpr auto PROMPT = "rook to b4, f3,";
static constexpr auto CONTEXT = "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8,";
WChess::WChess(whisper_context * ctx,
const whisper_full_params & wparams,
callbacks cb,
@@ -17,56 +35,95 @@ WChess::WChess(whisper_context * ctx,
WChess::~WChess() = default;
void WChess::set_move(const std::string& moves, float prob) const {
void WChess::set_status(const std::string& msg) const {
if (m_cb.set_move) (*m_cb.set_move)(moves, prob);
if (m_cb.set_status) (*m_cb.set_status)(msg);
}
void WChess::set_grammar(const std::string& grammar) const {
void WChess::set_moves(const std::string& moves) const {
if (m_cb.set_grammar) (*m_cb.set_grammar)(grammar);
if (m_cb.set_moves) (*m_cb.set_moves)(moves);
}
bool WChess::get_audio(std::vector<float>& pcmf32) const {
bool WChess::check_running() const {
if (m_cb.get_audio) return (*m_cb.get_audio)(pcmf32);
if (m_cb.check_running) return (*m_cb.check_running)();
return false;
}
bool WChess::clear_audio() const {
if (m_cb.clear_audio) return (*m_cb.clear_audio)();
return false;
}
void WChess::get_audio(int ms, std::vector<float>& pcmf32) const {
if (m_cb.get_audio) (*m_cb.get_audio)(ms, pcmf32);
}
std::string WChess::stringify_board() const {
return m_board->stringifyBoard();
}
std::string WChess::get_grammar() const {
return m_board->grammar();
}
void WChess::run() {
bool have_prompt = true;
set_status("loading data ...");
bool ask_prompt = !have_prompt;
bool have_prompt = false;
bool ask_prompt = true;
float logprob_min0 = 0.0f;
float logprob_min = 0.0f;
float logprob_sum0 = 0.0f;
float logprob_sum = 0.0f;
int n_tokens0 = 0;
int n_tokens = 0;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
const std::string k_prompt = have_prompt ? "" : "rook to d4, f3";
const std::string k_prompt = PROMPT;
int64_t t_ms = 0;
m_wparams.initial_prompt = CONTEXT;
auto grammar_parsed = grammar_parser::parse(RULES);
auto grammar_rules = grammar_parsed.c_rules();
if (grammar_parsed.rules.empty()) {
fprintf(stdout, "%s: Failed to parse grammar ...\n", __func__);
}
else {
m_wparams.grammar_rules = grammar_rules.data();
m_wparams.n_grammar_rules = grammar_rules.size();
}
while (check_running()) {
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (ask_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
fprintf(stdout, "\n");
{
char txt[1024];
snprintf(txt, sizeof(txt), "Say the following phrase: '%s'", k_prompt.c_str());
set_status(txt);
}
ask_prompt = false;
}
while (get_audio(pcmf32_cur)) {
int64_t t_ms = 0;
if (!pcmf32_cur.empty()) {
// fprintf(stdout, "%s: Processing ...\n", __func__);
{
get_audio(m_settings.vad_ms, pcmf32_cur);
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, m_settings.vad_thold, m_settings.freq_thold, m_settings.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
set_status("Speech detected! Processing ...");
if (!have_prompt) {
get_audio(m_settings.prompt_ms, pcmf32_cur);
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("prompt");
const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
@@ -82,26 +139,27 @@ void WChess::run() {
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
fprintf(stdout, "\n");
{
char txt[1024];
snprintf(txt, sizeof(txt), "Success! Waiting for voice commands ...");
set_status(txt);
}
// save the audio for the prompt
pcmf32_prompt = pcmf32_cur;
have_prompt = true;
m_board->setPrompt(k_prompt);
}
} else {
if (!pcmf32_prompt.empty()) pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
get_audio(m_settings.command_ms, pcmf32_cur);
constexpr size_t MIN_SIZE = 1.2 * WHISPER_SAMPLE_RATE;
if (MIN_SIZE > pcmf32_cur.size()) pcmf32_cur.insert(pcmf32_cur.begin(), MIN_SIZE - pcmf32_cur.size(), 0.0f);
// fprintf(stdout, "%s: grammar rules:\n'%s'\n", __func__, m_board->grammar().c_str()); // prepend 3 second of silence
pcmf32_cur.insert(pcmf32_cur.begin(), 3*WHISPER_SAMPLE_RATE, 0.0f);
auto grammar_parsed = grammar_parser::parse(m_board->grammar().c_str()); // prepend the prompt audio
auto grammar_rules = grammar_parsed.c_rules(); pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
m_wparams.grammar_rules = grammar_rules.data(); m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("root");
m_wparams.n_grammar_rules = grammar_rules.size(); const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("move");
auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
const float p = 100.0f * std::exp(logprob_min); const float p = 100.0f * std::exp(logprob_min);
@@ -111,6 +169,10 @@ void WChess::run() {
float best_sim = 0.0f;
size_t best_len = 0;
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
if (n >= int(txt.size())) {
break;
}
const auto prompt = txt.substr(0, n);
const float sim = similarity(prompt, k_prompt);
@@ -129,23 +191,18 @@ void WChess::run() {
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
fprintf(stdout, "\n");
{
char txt[1024];
snprintf(txt, sizeof(txt), "Command '%s', (t = %d ms)", command.c_str(), (int) t_ms);
set_status(txt);
}
if (!command.empty()) {
set_move(m_board->process(command), p);
set_moves(m_board->process(command));
set_grammar(m_board->grammar());
}
if (m_board->grammar().empty()) {
fprintf(stdout, "%s: No more moves possible\n", __func__);
break;
}
}
}
if (ask_prompt) {
clear_audio();
fprintf(stdout, "\n");
}
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
fprintf(stdout, "\n");
ask_prompt = false;
}
}
}
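
Aside (not part of the diff): reading the RULES grammar above literally, the prompt rule only accepts the fixed phrase ` rook to b4, f3.`, while the root rule accepts that phrase followed by one to three comma-separated moves and a closing period, for example:

```
 rook to b4, f3, knight to c3, e4.
```

Note the leading space, which the grammar's own comment calls out as important.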

View File

@@ -8,16 +8,18 @@ class Chessboard;
class WChess {
public:
using SetStatusCb = void (*)(const std::string &);
using CheckRunningCb = bool (*)();
using GetAudioCb = bool (*)(std::vector<float> &);
using GetAudioCb = void (*)(int, std::vector<float> &);
using SetMovesCb = void (*)(const std::string &, float);
using SetMovesCb = void (*)(const std::string &);
using SetGrammarCb = void (*)(const std::string &);
using CleartAudioCb = bool (*)();
using ClearAudioCb = void (*)();
struct callbacks {
SetStatusCb set_status = nullptr;
CheckRunningCb check_running = nullptr;
GetAudioCb get_audio = nullptr;
SetMovesCb set_move = nullptr;
SetMovesCb set_moves = nullptr;
SetGrammarCb set_grammar = nullptr;
CleartAudioCb clear_audio = nullptr;
};
struct settings {
@@ -38,16 +40,13 @@ public:
~WChess();
void run();
std::string stringify_board() const;
std::string get_grammar() const;
private:
bool get_audio(std::vector<float>& pcmf32) const;
void get_audio(int ms, std::vector<float>& pcmf32) const;
void set_move(const std::string& moves, float prob) const;
void set_status(const std::string& msg) const;
void set_grammar(const std::string& grammar) const;
void set_moves(const std::string& moves) const;
bool check_running() const;
bool clear_audio() const;
std::string transcribe(
const std::vector<float> & pcmf32,
float & logprob_min,

View File

@@ -11,107 +11,78 @@
int main() {
{
Chessboard chess;
ASSERT(chess.process("pawn to d4") == "d2-d4");
ASSERT(chess.process("e5") == "e7-e5");
ASSERT(chess.process("c1 h6") == "c1-h6");
ASSERT(chess.process("queen h4") == "d8-h4");
ASSERT(chess.process("bishop to g5") == "h6-g5");
ASSERT(chess.process("bishop to b4") == "f8-b4");
ASSERT(chess.process("c4") == "");
ASSERT(chess.process("knight c3") == "b1-c3");
ASSERT(chess.process("knight c6") == "b8-c6");
ASSERT(chess.process("f3") == "");
}
{
// pawns
Chessboard chess;
ASSERT(chess.process("d4") == "d2-d4");
ASSERT(chess.process("pawn to d4, e5, e3, pawn to d5") == "d2-d4 e7-e5 e2-e3 d7-d5");
ASSERT(chess.process("e5") == "e7-e5");
ASSERT(chess.process("pawn to d4") == ""); // wrong
ASSERT(chess.process("e4") == "e2-e4");
ASSERT(chess.process("pawn to c5") == ""); // wrong
ASSERT(chess.process("queen h4") == "d8-h4");
ASSERT(chess.process("pawn to d5") == ""); // wrong
ASSERT(chess.process("queen h5") == "d1-h5");
ASSERT(chess.process("pawn to d3") == ""); // wrong
ASSERT(chess.process("f5") == "");
ASSERT(chess.process("pawn to f5") == ""); // wrong, white's turn
ASSERT(chess.process("g6") == "g7-g6");
ASSERT(chess.process("knight e2") == "g1-e2");
ASSERT(chess.process("f5") == "f7-f5");
ASSERT(chess.process("knight g3") == "e2-g3");
ASSERT(chess.process("g5") == "");
ASSERT(chess.process("king e7") == "e8-e7");
ASSERT(chess.process("f4") == "f2-f4");
ASSERT(chess.process("g5") == "g6-g5");
}
{
Chessboard chess;
ASSERT(chess.process("e4") == "e2-e4");
ASSERT(chess.process("c5") == "c7-c5");
ASSERT(chess.process("e5") == "e4-e5");
ASSERT(chess.process("c4") == "c5-c4");
ASSERT(chess.process("e6") == "e5-e6");
ASSERT(chess.process("c3") == "c4-c3");
ASSERT(chess.process("e7") == "");
ASSERT(chess.process("f7") == "e6-f7");
ASSERT(chess.process("d2") == "");
ASSERT(chess.process("king to f7") == "e8-f7");
ASSERT(chess.process("f4") == "f2-f4");
ASSERT(chess.process("d2") == "c3-d2");
ASSERT(chess.process("f5") == "");
ASSERT(chess.process("king to e2") == "e1-e2");
ASSERT(chess.process("king to g6") == "f7-g6");
ASSERT(chess.process("f5") == "f4-f5");
ASSERT(chess.process("e6") == "");
ASSERT(chess.process("king to h5") == "g6-h5");
ASSERT(chess.process("g4") == "g2-g4");
ASSERT(chess.process("king to g5") == "h5-g5");
ASSERT(chess.process("h4") == "h2-h4"); ASSERT(chess.process("h4") == "h2-h4");
ASSERT(chess.process("king to h5") == ""); ASSERT(chess.process("d4") == "e5-d4");
ASSERT(chess.process("king to g6") == ""); ASSERT(chess.process("e4") == "e3-e4");
ASSERT(chess.process("king to h6") == "g5-h6"); ASSERT(chess.process("d4") == ""); // wrong
ASSERT(chess.process("bishop to d2") == "c1-d2"); ASSERT(chess.process("e4") == "d5-e4");
ASSERT(chess.process("king to g5") == "");
ASSERT(chess.process("g5") == "g7-g5");
}
{
// rook
Chessboard chess;
ASSERT(chess.process("f4") == "f2-f4");
ASSERT(chess.process("e5") == "e7-e5"); ASSERT(chess.process("rook to a3") == ""); // wrong
ASSERT(chess.process("g4") == "g2-g4"); ASSERT(chess.process("a4, h5, rook to a3, rook to h6") == "a2-a4 h7-h5 a1-a3 h8-h6");
ASSERT(chess.process("queen to h4") == "d8-h4#"); ASSERT(chess.process("rook to d3, rook to e6") == "a3-d3 h6-e6");
ASSERT(chess.process("knight f3") == ""); ASSERT(chess.process("rook to d4, rook to e5") == "d3-d4 e6-e5");
ASSERT(chess.grammar().empty()); ASSERT(chess.process("rook to a4") == ""); // wrong
ASSERT(chess.process("rook to d8") == ""); // wrong
ASSERT(chess.process("rook to d3") == "d4-d3");
ASSERT(chess.process("rook to e2") == "e5-e2");
}
{
// knight
Chessboard chess;
ASSERT(chess.process("f4") == "f2-f4");
ASSERT(chess.process("e5") == "e7-e5"); ASSERT(chess.process("knight to c3, knight to c6") == "b1-c3 b8-c6");
ASSERT(chess.process("g4") == "g2-g4"); ASSERT(chess.process("knight to c3") == ""); // wrong
ASSERT(chess.process("d5") == "d7-d5"); ASSERT(chess.process("knight to a2") == ""); // wrong
ASSERT(chess.process("g1 f3") == "g1-f3"); ASSERT(chess.process("knight to b4") == ""); // wrong, white's turn
ASSERT(chess.process("queen to h4") == "d8-h4"); ASSERT(chess.process("knight to b5") == "c3-b5");
ASSERT(!chess.grammar().empty()); ASSERT(chess.process("knight to a5") == "c6-a5");
ASSERT(chess.process("knight to c7") == "b5-c7");
}
{
// bishop
Chessboard chess;
ASSERT(chess.process("knight c3") == "b1-c3");
ASSERT(chess.process("knight c6") == "b8-c6"); ASSERT(chess.process("b3, b6, bishop to b2, bishop to b7") == "b2-b3 b7-b6 c1-b2 c8-b7");
ASSERT(chess.process("knight b5") == "c3-b5"); ASSERT(chess.process("bishop to a1") == ""); // wrong
ASSERT(chess.process("knight f6") == "g8-f6"); ASSERT(chess.process("bishop to h8") == ""); // wrong
ASSERT(chess.process("knight d6") == "b5-d6"); ASSERT(chess.process("bishop to a6") == ""); // wrong, white's turn
ASSERT(chess.process("knight d4") == ""); ASSERT(chess.process("bishop to g7") == "b2-g7");
ASSERT(chess.process("d6") == "c7-d6"); }
ASSERT(chess.process("e4") == "e2-e4");
ASSERT(chess.process("knight d4") == "c6-d4"); {
ASSERT(chess.process("d3") == "d2-d3"); // queen
ASSERT(chess.process("knight e4") == "f6-e4"); Chessboard chess;
ASSERT(chess.process("king to e2") == ""); ASSERT(chess.process("queen to d8") == ""); // wrong
ASSERT(chess.process("king to d2") == ""); ASSERT(chess.process("queen to f1") == ""); // wrong
ASSERT(chess.process("queen to h5") == ""); // wrong
ASSERT(chess.process("e3, d5, queen to h5, queen to d6") == "e2-e3 d7-d5 d1-h5 d8-d6");
ASSERT(chess.process("queen to c5") == ""); // wrong, white's turn
ASSERT(chess.process("queen to f7") == "h5-f7");
}
{
// king
Chessboard chess;
ASSERT(chess.process("d3, d6, king to d2, king to d7, king to c3, king to c6, king to c4") == "d2-d3 d7-d6 e1-d2 e8-d7 d2-c3 d7-c6 c3-c4");
ASSERT(chess.process("bishop to e6") == "c8-e6");
ASSERT(chess.process("king to b3") == "c4-b3"); // !! check check not implemented
} }
} }
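The assertions above double as documentation for the Chessboard API: process() takes a transcribed phrase and returns the resolved move(s) in from-to notation, or an empty string when the phrase is not a legal move for the side to play, and grammar() (used on one side of this diff) exposes the decoding grammar for the current position. A minimal sketch of driving that API outside the test harness, assuming the class is reachable through a "Chessboard.h" header (the header name is an assumption):

#include "Chessboard.h" // assumed header name for the wchess-core Chessboard class

#include <cstdio>
#include <string>

int main() {
    Chessboard chess;

    // white moves first; each successful call switches the side to move, as in the tests above
    std::string m1 = chess.process("e4");            // -> "e2-e4"
    std::string m2 = chess.process("knight to c6");  // -> "b8-c6"
    std::string m3 = chess.process("knight to c6");  // not a legal white move -> ""

    std::printf("%s | %s | rejected: %s\n", m1.c_str(), m2.c_str(), m3.empty() ? "yes" : "no");

    return 0;
}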

View File

@ -4,5 +4,5 @@ if (WHISPER_SDL2)
include(DefaultTargetOptions) include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE wchess-core common-sdl ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE libwchess common-sdl ${CMAKE_THREAD_LIBS_INIT})
endif () endif ()

View File

@ -7,7 +7,6 @@
#include "WChess.h" #include "WChess.h"
#include "common-sdl.h" #include "common-sdl.h"
#include <iostream>
#include <memory> #include <memory>
#include <thread> #include <thread>
@ -110,62 +109,18 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
} }
std::unique_ptr<WChess> g_wchess; std::unique_ptr<WChess> g_wchess;
int g_moveCount = 0; void set_moves(const std::string & moves) {
void set_move(const std::string & move, float) { if (!moves.empty()) fprintf(stdout, "%s", g_wchess->stringify_board().c_str());
if (!move.empty()) {
g_moveCount++;
fprintf(stdout, "Move: %s\n\n", move.c_str());
}
else fprintf(stdout, "Move rejected\n\n");
fprintf(stdout, "%s\n", g_wchess->stringify_board().c_str());
fprintf(stdout, "%s\n", g_moveCount ? "White's turn" : "Black's turn");
} }
audio_async g_audio(30*1000); audio_async g_audio(30*1000);
bool g_listening = false; void get_audio(int ms, std::vector<float> & pcmf32_cur) {
std::vector<float> g_pcmf32; g_audio.get(ms, pcmf32_cur);
bool read_input() {
std::string input;
while (true) {
fprintf(stdout, "[(l)isten/(p)ause/(q)uit]: ");
std::cin >> input;
fprintf(stdout, "\n");
if (input[0] == 'q') {
fprintf(stdout, "Quitting\n");
return false;
} }
if (input[0] == 'l') {
if (!g_listening) { bool clear_audio() {
fprintf(stdout, "Listening\n");
g_listening = true;
g_pcmf32.clear();
g_audio.resume();
g_audio.clear(); g_audio.clear();
} }
else fprintf(stdout, "Still listening\n");
return true;
}
else {
if (g_listening) {
g_listening = false;
g_audio.get(0, g_pcmf32);
g_audio.pause();
fprintf(stdout, "Processing\n");
}
else fprintf(stdout, "Not listening\n");
return true;
}
}
return true;
}
bool get_audio(std::vector<float> & pcmf32_cur) {
if (!read_input()) return false;
if (!g_pcmf32.empty()) pcmf32_cur = std::move(g_pcmf32);
else pcmf32_cur.clear();
return true;
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
whisper_params params; whisper_params params;
@ -186,10 +141,6 @@ int main(int argc, char ** argv) {
cparams.use_gpu = params.use_gpu; cparams.use_gpu = params.use_gpu;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
if (!ctx) {
fprintf(stderr, "%s: whisper_init_from_file_with_params() failed!\n", __func__);
return 1;
}
// init audio // init audio
@ -198,35 +149,42 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY); whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
wparams.offset_ms = 0;
wparams.translate = false;
wparams.no_context = true;
wparams.single_segment = true;
wparams.print_realtime = false;
wparams.print_progress = false; wparams.print_progress = false;
wparams.print_timestamps = true; wparams.print_special = params.print_special;
wparams.print_special = false; wparams.print_realtime = false;
wparams.no_timestamps = true; wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.no_timestamps = params.no_timestamps;
wparams.single_segment = true;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.max_tokens = 32; wparams.audio_ctx = params.audio_ctx;
wparams.audio_ctx = 768; // partial encoder context for better performance wparams.speed_up = params.speed_up;
wparams.temperature = 0.0f; wparams.temperature = 0.4f;
wparams.temperature_inc = 2.0f; wparams.temperature_inc = 1.0f;
wparams.greedy.best_of = 1; wparams.greedy.best_of = 5;
wparams.beam_search.beam_size = 1; wparams.beam_search.beam_size = 5;
wparams.language = "en";
wparams.grammar_penalty = 100.0;
wparams.initial_prompt = params.context.data(); wparams.initial_prompt = params.context.data();
g_audio.resume();
// wait for 1 second to avoid any buffered noise
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
g_audio.clear();
WChess::callbacks cb; WChess::callbacks cb;
cb.check_running = sdl_poll_events;
cb.get_audio = get_audio; cb.get_audio = get_audio;
cb.set_move = set_move; cb.set_moves = set_moves;
cb.clear_audio = clear_audio;
WChess::settings s; WChess::settings s;
s.vad_ms = 2000; s.vad_ms = 2000;
@ -237,9 +195,11 @@ int main(int argc, char ** argv) {
s.print_energy = params.print_energy; s.print_energy = params.print_energy;
g_wchess.reset(new WChess(ctx, wparams, cb, s)); g_wchess.reset(new WChess(ctx, wparams, cb, s));
set_move("start", 0); set_moves("start");
g_wchess->run(); g_wchess->run();
g_audio.pause();
whisper_print_timings(ctx); whisper_print_timings(ctx);
whisper_free(ctx); whisper_free(ctx);
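One side of this diff drives WChess through a small set of hooks (get_audio(ms, pcm), set_moves(moves) and clear_audio(), with check_running()/set_status() appearing elsewhere in the diff), while the other keeps a single set_move(move, prob) callback. A minimal sketch of the multi-hook variant with stub implementations, assuming ctx and wparams are prepared as in main() above and that the callback members accept capture-less lambdas (plain free functions, as used in this file, work the same way):

#include "WChess.h"

#include <cstdio>
#include <string>
#include <vector>

// stub harness: a real front-end feeds microphone PCM through get_audio
static void run_stub(struct whisper_context * ctx, whisper_full_params wparams) {
    WChess::callbacks cb;
    cb.get_audio   = [](int /*ms*/, std::vector<float> & pcm) { pcm.clear(); };  // no audio in this stub
    cb.set_moves   = [](const std::string & moves) { std::printf("moves: %s\n", moves.c_str()); };
    cb.clear_audio = []() { return true; };                                      // drop any buffered PCM
    // other hooks seen in this diff (check_running, set_status) are omitted here

    WChess::settings s;
    s.vad_ms = 2000; // same voice-activity window as above

    WChess(ctx, wparams, cb, s).run();
}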

View File

@ -8,7 +8,7 @@ include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE target_link_libraries(${TARGET} PRIVATE
common common
wchess-core libwchess
) )
unset(EXTRA_FLAGS) unset(EXTRA_FLAGS)

View File

@ -1,11 +1,7 @@
<!doctype html> <!doctype html>
<html lang="en-us"> <html lang="en-us">
<head> <head>
<title>wchess : voice-controlled chess using Whisper + WebAssembly</title> <title>wchess : Voice assistant example using Whisper + WebAssembly</title>
<script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
<meta name="viewport" content="width=device-width, initial-scale=0.7, maximum-scale=1, minimum-scale=0.7, user-scalable=no"/>
<meta name="apple-mobile-web-app-capable" content="yes" />
<style> <style>
#output { #output {
@ -27,95 +23,16 @@
overflow-wrap: normal; overflow-wrap: normal;
overflow-x: scroll; overflow-x: scroll;
} }
.button {
background-color: #000000;
color: #FFFFFF;
padding: 20px;
border-radius: 10px;
-moz-border-radius: 10px;
-webkit-border-radius: 10px;
margin:10px;
width: 100px;
height: 50px;
-webkit-touch-callout: none; /* Safari */
-webkit-user-select: none; /* Chrome */
-moz-user-select: none; /* Firefox */
-ms-user-select: none; /* Internet Explorer/Edge */
user-select: none;
}
button[disabled]{
background-color: #cccccc;
color: #666666;
padding: 20px;
border-radius: 10px;
-moz-border-radius: 10px;
-webkit-border-radius: 10px;
margin:10px;
width: 100px;
}
.center {
display: flex;
justify-content: center;
align-items: center;
width: 500px;
}
#description {
width: 500px;
}
</style> </style>
<link rel="stylesheet" href="css/chessboard-1.0.0.min.css" integrity="sha384-q94+BZtLrkL1/ohfjR8c6L+A6qzNH9R2hBLwyoAfu3i/WCvQjzL2RQJ3uNHDISdU" crossorigin="anonymous"> <link rel="stylesheet" href="css/chessboard-1.0.0.min.css" integrity="sha384-q94+BZtLrkL1/ohfjR8c6L+A6qzNH9R2hBLwyoAfu3i/WCvQjzL2RQJ3uNHDISdU" crossorigin="anonymous">
</head> </head>
<body> <body onload="loadWhisper()">
<div id="main-container"> <div id="main-container">
<div id="description"> <b>wchess : Voice assistant example using Whisper + WebAssembly</b>
<b>wchess : voice-controlled chess using Whisper + WebAssembly</b>
<br><br> <br><br>
This is a demonstration of using Whisper to recognize voice commands in the browser. You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/command.wasm">GitHub</a>.
<br><br>
Usage:<br>
<ul>
<li>Select a Whisper model</li>
<li>Accept the microphone permission request if prompted</li>
<li>Hold the button and say a chess move (e.g. "Knight to c3")</li>
<li>Release the button and wait for the move to be recognized</li>
<li>Repeat</li>
</ul>
Examples:<br>
<ul>
<li><b>"d4"</b></li>
<li><b>"e2 e4"</b></li>
<li><b>"Knight f3"</b></li>
<li><b>"Bishop to b5"</b></li>
</ul>
Features:<br>
<ul>
<li>Model quantization for reduced memory footprint (~42MB)</li>
<li><a href="https://github.com/ggerganov/whisper.cpp/pull/1229">Grammar-based sampling</a> for improved recognition accuracy</li>
</ul>
<b>
Note that not all chess moves are supported. For example, castling and pawn promotion
currently do not work, but can be easily implemented. There could also be some bugs in
the move handling logic in general. The main reason for that is to keep the implementation
simple. The assumption is that a real application would already have a proper move
validation logic in place.<br><br>
The main purpose of this example is to demonstrate the capabilities of whisper.cpp and
its application in the browser for voice recognition locally on your device.
</b>
<br><br>
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/wchess">GitHub</a>.
<br><br> <br><br>
@ -128,45 +45,40 @@
<br><br> <br><br>
</div>
<hr> <hr>
<div id="model-whisper"> <div id="model-whisper">
Whisper model: <span id="model-whisper-status"></span> Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper()">tiny.en (Q8_0, 42 MB)</button>
<span id="fetch-whisper-progress"></span> <span id="fetch-whisper-progress"></span>
<br><br>
<button id="clear" onclick="clearCache()">Clear browser cache</button>
<!-- <!--
<input type="file" id="file" name="file" onchange="loadFile(event, 'whisper.bin')" /> <input type="file" id="file" name="file" onchange="loadFile(event, 'whisper.bin')" />
--> -->
</div> </div>
<div id="game">
<br> <br>
<div id="chessboard" style="width: 500px"></div> <div id="myBoard" style="width: 400px"></div>
<script src="js/jquery-3.7.1.min.js"></script> <script src="js/jquery-3.7.1.min.js"></script>
<script src="js/chessboard-1.0.0.min.js"></script> <script src="js/chessboard-1.0.0.min.js"></script>
<script> <script>
var board = Chessboard('chessboard', 'start') var board = Chessboard('myBoard', 'start')
var move_count = 0;
</script> </script>
<br> <br>
<div id="state"> <div id="input">
Status: <b><span id="state-status">select model</span></b> <button id="start" onclick="onStart()" disabled>Start</button>
<button id="stop" onclick="onStop()" disabled>Stop</button>
<div id="input" class="center"> <button id="clear" onclick="clearCache()">Clear Cache</button>
<button id="toggler" class="button" onselectstart="return false" style="display: none">Hold</button>
</div> </div>
<pre id="state-grammar">[The grammar will be displayed here]</pre> <br>
<div id="state">
Status: <b><span id="state-status">not started</span></b>
<pre id="state-moves">[The moves will be displayed here]</pre> <pre id="state-moves">[The moves will be displayed here]</pre>
</div> </div>
</div>
<hr> <hr>
@ -183,6 +95,7 @@
<ul> <ul>
<li>To use a modern web browser (e.g. Chrome, Firefox)</li> <li>To use a modern web browser (e.g. Chrome, Firefox)</li>
<li>To use a fast desktop or laptop computer (i.e. not a mobile phone)</li>
<li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li> <li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
</ul> </ul>
@ -202,14 +115,15 @@
// web audio context // web audio context
var context = null; var context = null;
// audio data
var audio = null;
var audio0 = null;
// the command instance // the command instance
var instance = null; var instance = null;
// model name // model name
var model_whisper = null; var model_whisper = null;
var model_file = null;
var module_ready = null;
var Module = { var Module = {
print: printTextarea, print: printTextarea,
@ -223,30 +137,10 @@
printTextarea('js: Preparing ...'); printTextarea('js: Preparing ...');
}, },
postRun: function() { postRun: function() {
printTextarea('js: Module initialized successfully!'); printTextarea('js: Initialized successfully!');
module_ready = true;
initInstance();
} }
}; };
function initInstance() {
if (!module_ready || !model_file || instance) return
instance = Module.init(model_file);
if (instance) {
setStatus('Ready');
printTextarea("js: whisper initialized, instance: " + instance);
}
else {
printTextarea("js: failed to initialize whisper");
}
}
function setStatus(text) {
document.getElementById('state-status').innerHTML = text;
}
// //
// fetch models // fetch models
// //
@ -270,21 +164,36 @@
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!'; document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
model_file = fname; if (model_whisper != null) {
initInstance(); document.getElementById('start').disabled = false;
document.getElementById('stop' ).disabled = true;
}
} }
function loadWhisper() { function loadWhisper() {
setStatus('Loading') // let urls = {
//let url = 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q8_0.bin'; // 'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
let url = 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-q8_0.bin'; // 'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
let dst = 'whisper.bin';
let size_mb = 42;
model_whisper = 'tiny.en-q8_0'; // 'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
// 'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
// };
// let sizes = {
// 'tiny.en': 75,
// 'base.en': 142,
// 'tiny-en-q5_1': 31,
// 'base-en-q5_1': 57,
// };
let url = 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin';
let dst = 'whisper.bin';
let size_mb = 75;
model_whisper = 'tiny.en';
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model_whisper + '" ... '; document.getElementById('model-whisper-status').innerHTML = 'loading "' + model_whisper + '" ... ';
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
cbProgress = function(p) { cbProgress = function(p) {
let el = document.getElementById('fetch-whisper-progress'); let el = document.getElementById('fetch-whisper-progress');
@ -297,30 +206,6 @@
}; };
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea); loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
// init audio capture so that the user receives a permission request
{
let context = new AudioContext({
sampleRate: 16000,
channelCount: 1,
echoCancellation: false,
autoGainControl: true,
noiseSuppression: true,
});
navigator.mediaDevices.getUserMedia({audio: true, video: false})
.then(function(s) {
stream = s;
stream.getTracks().forEach(function(track) {
track.stop();
});
})
.catch(function(err) {
printTextarea('js: error getting audio stream: ' + err);
});
context.close();
}
document.getElementById('toggler').style.display = 'block';
} }
// //
@ -339,9 +224,11 @@
window.OfflineAudioContext = window.OfflineAudioContext || window.webkitOfflineAudioContext; window.OfflineAudioContext = window.OfflineAudioContext || window.webkitOfflineAudioContext;
function stopRecording() { function stopRecording() {
if (mediaRecorder) { Module.set_status("paused");
mediaRecorder.stop(); doRecording = false;
} audio0 = null;
audio = null;
context = null;
} }
function startRecording() { function startRecording() {
@ -355,6 +242,12 @@
}); });
} }
Module.set_status("");
document.getElementById('start').disabled = true;
document.getElementById('stop').disabled = false;
doRecording = true;
startTime = Date.now(); startTime = Date.now();
var chunks = []; var chunks = [];
@ -372,6 +265,10 @@
reader.onload = function(event) { reader.onload = function(event) {
var buf = new Uint8Array(reader.result); var buf = new Uint8Array(reader.result);
if (!context) {
return;
}
context.decodeAudioData(buf.buffer, function(audioBuffer) { context.decodeAudioData(buf.buffer, function(audioBuffer) {
var offlineContext = new OfflineAudioContext(audioBuffer.numberOfChannels, audioBuffer.length, audioBuffer.sampleRate); var offlineContext = new OfflineAudioContext(audioBuffer.numberOfChannels, audioBuffer.length, audioBuffer.sampleRate);
var source = offlineContext.createBufferSource(); var source = offlineContext.createBufferSource();
@ -380,13 +277,22 @@
source.start(0); source.start(0);
offlineContext.startRendering().then(function(renderedBuffer) { offlineContext.startRendering().then(function(renderedBuffer) {
let audio = renderedBuffer.getChannelData(0); audio = renderedBuffer.getChannelData(0);
printTextarea('js: number of samples: ' + audio.length);
Module.set_audio(instance, audio);
});
mediaRecorder = null; //printTextarea('js: audio recorded, size: ' + audio.length + ', old size: ' + (audio0 == null ? 0 : audio0.length));
context = null;
var audioAll = new Float32Array(audio0 == null ? audio.length : audio0.length + audio.length);
if (audio0 != null) {
audioAll.set(audio0, 0);
}
audioAll.set(audio, audio0 == null ? 0 : audio0.length);
if (instance) {
Module.set_audio(instance, audioAll);
}
});
}, function(e) {
audio = null;
}); });
} }
@ -394,16 +300,48 @@
}; };
mediaRecorder.onstop = function(e) { mediaRecorder.onstop = function(e) {
stream.getTracks().forEach(function(track) { if (doRecording) {
track.stop(); setTimeout(function() {
startRecording();
}); });
}
}; };
mediaRecorder.start(); mediaRecorder.start(kIntervalAudio_ms);
}) })
.catch(function(err) { .catch(function(err) {
printTextarea('js: error getting audio stream: ' + err); printTextarea('js: error getting audio stream: ' + err);
}); });
var interval = setInterval(function() {
if (!doRecording) {
clearInterval(interval);
mediaRecorder.stop();
stream.getTracks().forEach(function(track) {
track.stop();
});
document.getElementById('start').disabled = false;
document.getElementById('stop').disabled = true;
mediaRecorder = null;
}
// if audio length is more than kRestartRecording_s seconds, restart recording
if (audio != null && audio.length > kSampleRate*kRestartRecording_s) {
if (doRecording) {
//printTextarea('js: restarting recording');
clearInterval(interval);
audio0 = audio;
audio = null;
mediaRecorder.stop();
stream.getTracks().forEach(function(track) {
track.stop();
});
}
}
}, 100);
} }
// //
@ -411,65 +349,35 @@
// //
var nLines = 0; var nLines = 0;
var intervalUpdate = null;
var movesAll = ''; var movesAll = '';
// document.body.addEventListener('keydown', function(event) {
// if (event.keyCode === 32) {
// document.getElementById('toggler').innerText = "";
// onStart();
// }
// }, true);
// document.body.addEventListener('keyup', function(event) {
// if (event.keyCode === 32) {
// document.getElementById('toggler').innerText = "Hold";
// onStop();
// }
// }, true);
document.getElementById('toggler').addEventListener("touchstart", function(event){
this.innerText = "";
onStart();
}, true);
document.getElementById('toggler').addEventListener("touchend", function(event){
this.innerText = "Hold";
onStop();
}, true)
document.getElementById('toggler').addEventListener('mousedown', function(event) {
this.innerText = "";
onStart();
}, true);
document.getElementById('toggler').addEventListener('mouseup', function(event) {
this.innerText = "Hold";
onStop();
}, true);
function onStart() { function onStart() {
if (!instance) return; if (!instance) {
setStatus('Listening'); instance = Module.init('whisper.bin');
if (instance) {
printTextarea("js: whisper initialized, instance: " + instance);
}
}
if (!instance) {
printTextarea("js: failed to initialize whisper");
return;
}
startRecording(); startRecording();
}
function onStop() { intervalUpdate = setInterval(function() {
setStatus('Processing'); var moves = Module.get_moves();
printTextarea('js: stopping recording ...');
stopRecording();
}
function setMove(move, prob) { if (moves != null && moves.length > 1) {
if (move != null && move.length > 1) {
let gameOver = move[move.length - 1] === '#'; for (move of moves.split(' ')) {
if (gameOver) {
move = move.substring(0, move.length - 1);
document.getElementById('toggler').disabled = true;
}
board.move(move); board.move(move);
}
movesAll += move + ', prob = ' + prob.toFixed(2) + '% <br>'; movesAll += moves + '<br>';
nLines++; nLines++;
// if more than 10 lines, remove the first line // if more than 10 lines, remove the first line
@ -480,17 +388,15 @@
nLines--; nLines--;
} }
} }
++move_count;
setStatus(gameOver ? 'Done' : move_count % 2 ? 'Black\'s turn' : 'White\'s turn');
document.getElementById('state-moves').innerHTML = movesAll;
}
else {
setStatus('Failed. ' + (move_count % 2 ? 'Black\'s turn' : 'White\'s turn'));
}
} }
function setGrammar(grammar) { document.getElementById('state-status').innerHTML = Module.get_status();
document.getElementById('state-grammar').innerHTML = grammar; document.getElementById('state-moves').innerHTML = movesAll;
}, 100);
}
function onStop() {
stopRecording();
} }
</script> </script>

View File

@ -1,7 +1,7 @@
#include <WChess.h> #include <WChess.h>
#include <emscripten.h>
#include <emscripten/bind.h> #include <emscripten/bind.h>
#include <atomic>
#include <thread> #include <thread>
constexpr int N_THREAD = 8; constexpr int N_THREAD = 8;
@ -11,28 +11,45 @@ std::vector<struct whisper_context *> g_contexts(4, nullptr);
std::mutex g_mutex; std::mutex g_mutex;
std::thread g_worker; std::thread g_worker;
std::condition_variable g_cv; std::atomic<bool> g_running(false);
std::string g_status = "";
std::string g_status_forced = "";
std::string g_moves = "";
bool g_running(false);
std::vector<float> g_pcmf32; std::vector<float> g_pcmf32;
void set_move(const std::string & move, float prob) { void set_status(const std::string & status) {
MAIN_THREAD_EM_ASM({ std::lock_guard<std::mutex> lock(g_mutex);
setMove(UTF8ToString($0), $1) g_status = status;
}, move.c_str(), prob);
} }
void set_grammar(const std::string & grammar) { void set_moves(const std::string & moves) {
MAIN_THREAD_EM_ASM({ std::lock_guard<std::mutex> lock(g_mutex);
setGrammar(UTF8ToString($0)) g_moves = moves;
}, grammar.c_str());
} }
bool get_audio(std::vector<float> & audio) { void get_audio(int ms, std::vector<float> & audio) {
std::unique_lock<std::mutex> lock(g_mutex); const int64_t n_samples = (ms * WHISPER_SAMPLE_RATE) / 1000;
g_cv.wait(lock, [] { return !g_running || !g_pcmf32.empty(); });
if (!g_running) return false; int64_t n_take = 0;
audio = std::move(g_pcmf32); if (n_samples > (int) g_pcmf32.size()) {
n_take = g_pcmf32.size();
} else {
n_take = n_samples;
}
audio.resize(n_take);
std::copy(g_pcmf32.end() - n_take, g_pcmf32.end(), audio.begin());
}
bool check_running() {
//g_pcmf32.clear();
return g_running;
}
bool clear_audio() {
g_pcmf32.clear();
return true; return true;
} }
@ -48,31 +65,30 @@ void wchess_main(size_t i) {
wparams.print_progress = false; wparams.print_progress = false;
wparams.print_timestamps = true; wparams.print_timestamps = true;
wparams.print_special = false; wparams.print_special = false;
wparams.no_timestamps = true;
wparams.max_tokens = 32; wparams.max_tokens = 32;
wparams.audio_ctx = 1280; // partial encoder context for better performance // wparams.audio_ctx = 768; // partial encoder context for better performance
wparams.temperature = 0.0f; wparams.temperature = 0.4f;
wparams.temperature_inc = 2.0f; wparams.temperature_inc = 1.0f;
wparams.greedy.best_of = 1; wparams.greedy.best_of = 1;
wparams.beam_search.beam_size = 1; wparams.beam_search.beam_size = 5;
wparams.language = "en"; wparams.language = "en";
wparams.grammar_penalty = 100.0; wparams.grammar_penalty = 100.0;
wparams.initial_prompt = "bishop to c3, rook to d4, knight to e5, d4 d5, knight to c3, c3, queen to d4, king b1, pawn to a1, bishop to b2, knight to c3,";
printf("command: using %d threads\n", wparams.n_threads); printf("command: using %d threads\n", wparams.n_threads);
WChess::callbacks cb; WChess::callbacks cb;
cb.set_status = set_status;
cb.check_running = check_running;
cb.get_audio = get_audio; cb.get_audio = get_audio;
cb.set_move = set_move; cb.set_moves = set_moves;
cb.set_grammar = set_grammar; cb.clear_audio = clear_audio;
WChess(g_contexts[i], wparams, cb, {}).run(); WChess(g_contexts[i], wparams, cb, {}).run();
if (i < g_contexts.size()) { if (i < g_contexts.size()) {
whisper_free(g_contexts[i]); whisper_free(g_contexts[i]);
g_contexts[i] = nullptr; g_contexts[i] = nullptr;
@ -104,11 +120,9 @@ EMSCRIPTEN_BINDINGS(command) {
})); }));
emscripten::function("free", emscripten::optional_override([](size_t /* index */) { emscripten::function("free", emscripten::optional_override([](size_t /* index */) {
{ if (g_running) {
std::unique_lock<std::mutex> lock(g_mutex);
g_running = false; g_running = false;
} }
g_cv.notify_one();
})); }));
emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) { emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) {
@ -134,8 +148,37 @@ EMSCRIPTEN_BINDINGS(command) {
emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(g_pcmf32.data()), n); emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(g_pcmf32.data()), n);
memoryView.call<void>("set", audio); memoryView.call<void>("set", audio);
} }
g_cv.notify_one();
return 0; return 0;
})); }));
emscripten::function("get_moves", emscripten::optional_override([]() {
std::string moves;
{
std::lock_guard<std::mutex> lock(g_mutex);
moves = std::move(g_moves);
}
if (!moves.empty()) fprintf(stdout, "%s: Moves '%s%s%s'\n", __func__, "\033[1m", moves.c_str(), "\033[0m");
return moves;
}));
emscripten::function("get_status", emscripten::optional_override([]() {
std::string status;
{
std::lock_guard<std::mutex> lock(g_mutex);
status = g_status_forced.empty() ? g_status : g_status_forced;
}
return status;
}));
emscripten::function("set_status", emscripten::optional_override([](const std::string & status) {
std::lock_guard<std::mutex> lock(g_mutex);
g_status_forced = status;
}));
} }
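Both columns of this file set wparams.grammar_penalty, and the demo page above advertises grammar-based sampling; the grammar itself comes from the Chessboard class (see chess.grammar() in the tests above) rather than being written by hand. Purely as an illustration of the kind of constraint involved (this is not the grammar the example generates), a GBNF-style rule set for spoken chess moves could look like:

// illustrative only; the real grammar is built dynamically from the legal moves
static const char * k_chess_grammar_example = R"grammar(
root   ::= move
move   ::= piece " to " square | square " " square | square
piece  ::= "pawn" | "knight" | "bishop" | "rook" | "queen" | "king"
square ::= [a-h] [1-8]
)grammar";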

View File

@ -206,7 +206,6 @@ void AudioInputCallback(void * inUserData,
params.offset_ms = 0; params.offset_ms = 0;
params.no_context = true; params.no_context = true;
params.single_segment = self->stateInp.isRealtime; params.single_segment = self->stateInp.isRealtime;
params.no_timestamps = params.single_segment;
CFTimeInterval startTime = CACurrentMediaTime(); CFTimeInterval startTime = CACurrentMediaTime();

View File

@ -137,7 +137,7 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
#ifdef GGML_ALLOCATOR_DEBUG #ifdef GGML_ALLOCATOR_DEBUG
add_allocated_tensor(alloc, tensor); add_allocated_tensor(alloc, tensor);
size_t cur_max = (char*)addr - (char*)alloc->base + size; size_t cur_max = (char*)addr - (char*)alloc->data + size;
if (cur_max > alloc->max_size) { if (cur_max > alloc->max_size) {
printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0); printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
for (int i = 0; i < 1024; i++) { for (int i = 0; i < 1024; i++) {
@ -168,6 +168,10 @@ static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor *
size = aligned_offset(NULL, size, alloc->alignment); size = aligned_offset(NULL, size, alloc->alignment);
AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks); AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
if (!alloc->measure) {
ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
}
#ifdef GGML_ALLOCATOR_DEBUG #ifdef GGML_ALLOCATOR_DEBUG
remove_allocated_tensor(alloc, tensor); remove_allocated_tensor(alloc, tensor);
#endif #endif
@ -233,7 +237,7 @@ void ggml_tallocr_reset(ggml_tallocr_t alloc) {
} }
ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) { ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) {
struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(data, size); struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(NULL, data, size);
ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr)); ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
@ -445,6 +449,7 @@ static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * n
static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) { static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) {
ggml_tallocr_t alloc = node_tallocr(galloc, view); ggml_tallocr_t alloc = node_tallocr(galloc, view);
//printf("init_view: %s from src %s\n", view->name, view->view_src->name);
GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL); GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
if (update_backend) { if (update_backend) {
view->backend = view->view_src->backend; view->backend = view->view_src->backend;
@ -454,7 +459,7 @@ static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool upd
// FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
// due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras // due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft); assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
if (!alloc->measure) { if (!alloc->measure) {
ggml_backend_buffer_init_tensor(alloc->buffer, view); ggml_backend_buffer_init_tensor(alloc->buffer, view);
@ -760,43 +765,3 @@ size_t ggml_allocr_max_size(ggml_allocr_t alloc) {
size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) { size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) {
return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph); return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
} }
// utils
ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
GGML_ASSERT(ggml_get_no_alloc(ctx) == true);
size_t alignment = ggml_backend_buft_get_alignment(buft);
size_t nbytes = 0;
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->data == NULL && t->view_src == NULL) {
nbytes += GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
}
}
if (nbytes == 0) {
fprintf(stderr, "%s: no tensors to allocate\n", __func__);
return NULL;
}
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, nbytes);
ggml_tallocr_t tallocr = ggml_tallocr_new_from_buffer(buffer);
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->data == NULL) {
if (t->view_src == NULL) {
ggml_tallocr_alloc(tallocr, t);
} else {
ggml_backend_view_init(buffer, t);
}
}
}
ggml_tallocr_free(tallocr);
return buffer;
}
ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {
return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));
}
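The ggml_backend_alloc_ctx_tensors helpers defined above wrap a common pattern: create tensors in a no_alloc context, then back all of them with a single backend buffer. A minimal sketch of how they are meant to be called, using the CPU backend (error handling omitted; assumes the ggml / ggml-alloc / ggml-backend headers from this tree):

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 8*ggml_tensor_overhead(), // metadata only; tensor data lives in the backend buffer
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,                     // required: the context must not allocate tensor data
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);

    ggml_backend_t        backend = ggml_backend_cpu_init();
    ggml_backend_buffer_t buf     = ggml_backend_alloc_ctx_tensors(ctx, backend);

    const float ones[16] = {1,1,1,1, 1,1,1,1, 1,1,1,1, 1,1,1,1};
    ggml_backend_tensor_set(a, ones, 0, sizeof(ones)); // copy host data into the backend buffer

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}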

View File

@ -8,7 +8,6 @@ extern "C" {
struct ggml_backend; struct ggml_backend;
struct ggml_backend_buffer; struct ggml_backend_buffer;
struct ggml_backend_buffer_type;
// //
// Legacy API // Legacy API
@ -43,7 +42,7 @@ GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph
// ggml-backend v2 API // ggml-backend v2 API
// //
// Separate tensor and graph allocator objects // Seperate tensor and graph allocator objects
// This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators // This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
// The original API is kept as a wrapper around the new API // The original API is kept as a wrapper around the new API
@ -81,12 +80,6 @@ GGML_API void ggml_gallocr_alloc_graph_n(
struct ggml_hash_set hash_set, struct ggml_hash_set hash_set,
ggml_tallocr_t * hash_node_talloc); ggml_tallocr_t * hash_node_talloc);
// Utils
// Create a buffer and allocate all the tensors in a ggml_context
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, struct ggml_backend_buffer_type * buft);
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, struct ggml_backend * backend);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -12,50 +12,31 @@ extern "C" {
// Backend buffer // Backend buffer
// //
// buffer type
typedef void * ggml_backend_buffer_type_context_t;
struct ggml_backend_buffer_type_i {
ggml_backend_buffer_t (*alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size);
size_t (*get_alignment) (ggml_backend_buffer_type_t buft); // tensor alignment
size_t (*get_alloc_size) (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
bool (*supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend
};
struct ggml_backend_buffer_type {
struct ggml_backend_buffer_type_i iface;
ggml_backend_buffer_type_context_t context;
};
// buffer
typedef void * ggml_backend_buffer_context_t; typedef void * ggml_backend_buffer_context_t;
struct ggml_backend_buffer_i { struct ggml_backend_buffer_i {
void (*free_buffer) (ggml_backend_buffer_t buffer); void (*free_buffer) (ggml_backend_buffer_t buffer);
//void (*reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras void * (*get_base) (ggml_backend_buffer_t buffer); // get base pointer
void * (*get_base) (ggml_backend_buffer_t buffer); size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
void (*init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); void (*free_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
// (optional) copy tensor between different buffer-type, allow for single-copy tranfers
void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to) (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
}; };
struct ggml_backend_buffer { struct ggml_backend_buffer {
struct ggml_backend_buffer_i iface; struct ggml_backend_buffer_i iface;
ggml_backend_buffer_type_t buft;
ggml_backend_t backend;
ggml_backend_buffer_context_t context; ggml_backend_buffer_context_t context;
size_t size; size_t size;
}; };
ggml_backend_buffer_t ggml_backend_buffer_init( GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
ggml_backend_buffer_type_t buft, struct ggml_backend * backend,
struct ggml_backend_buffer_i iface, struct ggml_backend_buffer_i iface,
ggml_backend_buffer_context_t context, ggml_backend_buffer_context_t context,
size_t size); size_t size);
// //
// Backend // Backend
// //
@ -68,18 +49,21 @@ extern "C" {
void (*free)(ggml_backend_t backend); void (*free)(ggml_backend_t backend);
// buffer allocation // buffer allocation
ggml_backend_buffer_type_t (*get_default_buffer_type)(ggml_backend_t backend); ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
// (optional) asynchroneous tensor data access // get buffer alignment
size_t (*get_alignment)(ggml_backend_t backend);
// tensor data access
// these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize
void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
// (optional) asynchroneous tensor copy
void (*cpy_tensor_from_async)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to_async) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*synchronize) (ggml_backend_t backend); void (*synchronize) (ggml_backend_t backend);
// (optional) copy tensor between different backends, allow for single-copy tranfers
void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
// compute graph with a plan // compute graph with a plan
ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph); ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan); void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
@ -98,15 +82,6 @@ extern "C" {
ggml_backend_context_t context; ggml_backend_context_t context;
}; };
//
// Backend registry
//
typedef ggml_backend_t (*ggml_backend_init_fn)(const char * params, void * user_data);
void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

File diff suppressed because it is too large


View File

@ -7,44 +7,41 @@
extern "C" { extern "C" {
#endif #endif
typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
typedef struct ggml_backend * ggml_backend_t;
typedef void * ggml_backend_graph_plan_t;
// //
// Backend buffer // Backend buffer
// //
// buffer type struct ggml_backend_buffer;
GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend);
// buffer // backend buffer functions
GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer); GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer); GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer); GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer); GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API void ggml_backend_buffer_free_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
// //
// Backend // Backend
// //
struct ggml_backend;
typedef struct ggml_backend * ggml_backend_t;
typedef void * ggml_backend_graph_plan_t;
GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor);
GGML_API const char * ggml_backend_name(ggml_backend_t backend); GGML_API const char * ggml_backend_name(ggml_backend_t backend);
GGML_API void ggml_backend_free(ggml_backend_t backend); GGML_API void ggml_backend_free(ggml_backend_t backend);
GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size); GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend); GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); GGML_API void ggml_backend_tensor_set_async( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
@ -60,7 +57,6 @@ extern "C" {
// tensor copy between different backends // tensor copy between different backends
GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst); GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); // automatic fallback to sync copy
// //
// CPU backend // CPU backend
@ -72,23 +68,8 @@ extern "C" {
GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads); GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
// Create a backend buffer from an existing pointer // Create a backend buffer from an existing pointer
GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size); GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size);
GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
//
// Backend registry
//
// The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
GGML_API size_t ggml_backend_reg_get_count(void);
GGML_API size_t ggml_backend_reg_find_by_name(const char * name);
GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is name[:params]
GGML_API const char * ggml_backend_reg_get_name(size_t i);
GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i);
GGML_API ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size);
// //
// Backend scheduler // Backend scheduler
@ -150,32 +131,6 @@ extern "C" {
ggml_backend_sched_t sched, ggml_backend_sched_t sched,
struct ggml_cgraph * graph); struct ggml_cgraph * graph);
//
// Utils
//
struct ggml_backend_graph_copy {
ggml_backend_buffer_t buffer;
struct ggml_context * ctx_allocated;
struct ggml_context * ctx_unallocated;
struct ggml_cgraph * graph;
};
// Copy a graph to a different backend
GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
// Compare the output of two backends
GGML_API void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
// Tensor initialization
GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
GGML_API void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
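One column of this header also declares a small backend registry. A minimal sketch of enumerating the registered backends and initializing one by name, assuming the CPU backend registers itself under the name "CPU" (an assumption; available names depend on the build):

#include "ggml-backend.h"

#include <cstdio>

int main(void) {
    for (size_t i = 0; i < ggml_backend_reg_get_count(); ++i) {
        std::printf("backend %zu: %s\n", i, ggml_backend_reg_get_name(i));
    }

    // accepts the "name[:params]" form, per the declaration above
    ggml_backend_t backend = ggml_backend_reg_init_backend_from_str("CPU");
    if (backend) {
        ggml_backend_free(backend);
    }
    return 0;
}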

File diff suppressed because it is too large

View File

@ -49,15 +49,7 @@ GGML_API int ggml_cuda_get_device_count(void);
GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size); GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
// backend API // backend API
GGML_API ggml_backend_t ggml_backend_cuda_init(int device); GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use
GGML_API bool ggml_backend_is_cuda(ggml_backend_t backend);
GGML_API int ggml_backend_cuda_get_device(ggml_backend_t backend);
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
// pinned host buffer for use with CPU backend for faster copies between CPU and GPU
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -232,7 +232,7 @@ bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml
// returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted // returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key); size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
// returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full // returns GGML_HAHSHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key); size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key);
// return index, asserts if table is full // return index, asserts if table is full

View File

@ -99,12 +99,6 @@ GGML_API ggml_backend_t ggml_backend_metal_init(void);
GGML_API bool ggml_backend_is_metal(ggml_backend_t backend); GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb); GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
GGML_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
// helper to check if the device supports a specific family
// ideally, the user code should be doing these checks
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
#ifdef __cplusplus #ifdef __cplusplus
} }

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -1,18 +1,20 @@
#include "ggml.h"
#include "ggml-opencl.h" #include "ggml-opencl.h"
#include <array> #include <array>
#include <atomic> #include <atomic>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <sstream> #include <sstream>
#include <vector> #include <vector>
#include <limits>
#define CL_TARGET_OPENCL_VERSION 110 #define CL_TARGET_OPENCL_VERSION 110
#include <clblast.h> #include <clblast.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include "ggml.h"
#if defined(_MSC_VER) #if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data #pragma warning(disable: 4244 4267) // possible loss of data
#endif #endif

View File

@ -19,7 +19,7 @@
#ifdef __wasm_simd128__ #ifdef __wasm_simd128__
#include <wasm_simd128.h> #include <wasm_simd128.h>
#else #else
#if defined(__POWER9_VECTOR__) || defined(__powerpc64__) #ifdef __POWER9_VECTOR__
#include <altivec.h> #include <altivec.h>
#undef bool #undef bool
#define bool _Bool #define bool _Bool
@ -3114,7 +3114,7 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
size_t vl = __riscv_vsetvl_e8m1(qk/2); size_t vl = __riscv_vsetvl_e8m1(qk/2);
// These temporary registers are for masking and shift operations // These tempory registers are for masking and shift operations
vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl); vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
@ -4757,7 +4757,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
vl = 16; vl = 16;
// retrieve lane to multiply with scale // retreive lane to multiply with scale
vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);

ggml.c (913 changed lines)

File diff suppressed because it is too large

ggml.h (91 changed lines)
View File

@ -215,9 +215,9 @@
#define GGML_QNT_VERSION_FACTOR 1000 // do not change this #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
#define GGML_MAX_DIMS 4 #define GGML_MAX_DIMS 4
#define GGML_MAX_PARAMS 2048 #define GGML_MAX_PARAMS 1024
#define GGML_MAX_CONTEXTS 64 #define GGML_MAX_CONTEXTS 64
#define GGML_MAX_SRC 10 #define GGML_MAX_SRC 6
#define GGML_MAX_NAME 64 #define GGML_MAX_NAME 64
#define GGML_MAX_OP_PARAMS 64 #define GGML_MAX_OP_PARAMS 64
#define GGML_DEFAULT_N_THREADS 4 #define GGML_DEFAULT_N_THREADS 4
@ -244,10 +244,11 @@
#define GGML_ASSERT(x) \ #define GGML_ASSERT(x) \
do { \ do { \
if (!(x)) { \ if (!(x)) { \
fflush(stdout); \
fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
fflush(stderr); \
fflush(stdout); \
ggml_print_backtrace(); \ ggml_print_backtrace(); \
abort(); \ exit(1); \
} \ } \
} while (0) } while (0)
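Both versions of the macro behave identically on success; they differ only in how a failure is reported (flush order, and abort() versus exit(1) after printing the backtrace). A trivial sketch of its use:

#include "ggml.h"

int main(void) {
    GGML_ASSERT(GGML_MAX_DIMS == 4); // passes: nothing is printed
    // GGML_ASSERT(1 == 2);          // would print "GGML_ASSERT: <file>:<line>: 1 == 2",
    //                               // flush the streams, print a backtrace and terminate
    return 0;
}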
@ -283,20 +284,6 @@
const type prefix##3 = (pointer)->array[3]; \ const type prefix##3 = (pointer)->array[3]; \
GGML_UNUSED(prefix##3); GGML_UNUSED(prefix##3);
#define GGML_TENSOR_UNARY_OP_LOCALS \
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
#define GGML_TENSOR_BINARY_OP_LOCALS \
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
@ -395,7 +382,6 @@ extern "C" {
GGML_OP_GROUP_NORM, GGML_OP_GROUP_NORM,
GGML_OP_MUL_MAT, GGML_OP_MUL_MAT,
GGML_OP_MUL_MAT_ID,
GGML_OP_OUT_PROD, GGML_OP_OUT_PROD,
GGML_OP_SCALE, GGML_OP_SCALE,
@ -422,10 +408,8 @@ extern "C" {
GGML_OP_CONV_TRANSPOSE_2D, GGML_OP_CONV_TRANSPOSE_2D,
GGML_OP_POOL_1D, GGML_OP_POOL_1D,
GGML_OP_POOL_2D, GGML_OP_POOL_2D,
GGML_OP_UPSCALE, // nearest interpolate GGML_OP_UPSCALE, // nearest interpolate
GGML_OP_PAD,
GGML_OP_ARGSORT,
GGML_OP_LEAKY_RELU,
GGML_OP_FLASH_ATTN, GGML_OP_FLASH_ATTN,
GGML_OP_FLASH_FF, GGML_OP_FLASH_FF,
@ -465,8 +449,7 @@ extern "C" {
GGML_UNARY_OP_GELU, GGML_UNARY_OP_GELU,
GGML_UNARY_OP_GELU_QUICK, GGML_UNARY_OP_GELU_QUICK,
GGML_UNARY_OP_SILU, GGML_UNARY_OP_SILU,
GGML_UNARY_OP_LEAKY
GGML_UNARY_OP_COUNT,
}; };
enum ggml_object_type { enum ggml_object_type {
@ -649,9 +632,6 @@ extern "C" {
GGML_API const char * ggml_op_name (enum ggml_op op); GGML_API const char * ggml_op_name (enum ggml_op op);
GGML_API const char * ggml_op_symbol(enum ggml_op op); GGML_API const char * ggml_op_symbol(enum ggml_op op);
GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor); GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
GGML_API bool ggml_is_quantized(enum ggml_type type); GGML_API bool ggml_is_quantized(enum ggml_type type);
@ -794,9 +774,6 @@ extern "C" {
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b); struct ggml_tensor * b);
// dst = a
// view(dst, nb1, nb2, nb3, offset) += b
// return dst
GGML_API struct ggml_tensor * ggml_acc( GGML_API struct ggml_tensor * ggml_acc(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
@ -961,14 +938,15 @@ extern "C" {
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_leaky_relu( GGML_API struct ggml_tensor * ggml_leaky(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, float negative_slope, bool inplace); struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_relu_inplace( GGML_API struct ggml_tensor * ggml_relu_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
// TODO: double-check this computation is correct
GGML_API struct ggml_tensor * ggml_gelu( GGML_API struct ggml_tensor * ggml_gelu(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
@ -1050,16 +1028,6 @@ extern "C" {
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b); struct ggml_tensor * b);
// indirect matrix multiplication
// ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
GGML_API struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx,
struct ggml_tensor * const as[],
int n_as,
struct ggml_tensor * ids,
int id,
struct ggml_tensor * b);
// A: m columns, n rows, // A: m columns, n rows,
// B: p columns, n rows, // B: p columns, n rows,
// result is m columns, p rows // result is m columns, p rows
@ -1267,7 +1235,6 @@ extern "C" {
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
// supports 3D: a->ne[2] == b->ne[1]
GGML_API struct ggml_tensor * ggml_get_rows( GGML_API struct ggml_tensor * ggml_get_rows(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
@ -1316,14 +1283,6 @@ extern "C" {
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
// fused soft_max(a*scale + mask)
// mask is optional
GGML_API struct ggml_tensor * ggml_soft_max_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
float scale);
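For reference, the removed ggml_soft_max_ext fuses scaling, an optional additive mask and the softmax into one op. A sketch of a typical attention use, assuming kq and kq_mask tensors exist:

    // fused: soft_max(kq*scale + kq_mask); the mask argument may be NULL
    const float scale = 0.125f;   // e.g. 1/sqrt(n_embd_head) for a head size of 64
    struct ggml_tensor * probs = ggml_soft_max_ext(ctx, kq, kq_mask, scale);
    // unfused equivalent (roughly): soft_max(add(scale(kq), kq_mask))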
GGML_API struct ggml_tensor * ggml_soft_max_back( GGML_API struct ggml_tensor * ggml_soft_max_back(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
@ -1554,32 +1513,6 @@ extern "C" {
struct ggml_tensor * a, struct ggml_tensor * a,
int scale_factor); int scale_factor);
// pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
GGML_API struct ggml_tensor * ggml_pad(
struct ggml_context * ctx,
struct ggml_tensor * a,
int p0,
int p1,
int p2,
int p3);
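For reference, ggml_pad appends p0..p3 zeros to the end of dimensions 0..3, as the comment above states. A minimal sketch, again assuming an initialized ctx:

    // zero-pad the first dimension from 1500 to 1504 elements, other dims unchanged
    struct ggml_tensor * x      = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1500, 512);
    struct ggml_tensor * padded = ggml_pad(ctx, x, 4, 0, 0, 0);   // -> [1504, 512]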
// sort rows
enum ggml_sort_order {
GGML_SORT_ASC,
GGML_SORT_DESC,
};
GGML_API struct ggml_tensor * ggml_argsort(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_sort_order order);
// top k elements per row
GGML_API struct ggml_tensor * ggml_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k);
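For reference, ggml_argsort produces per-row index permutations in the requested order, and ggml_top_k builds on it to keep the k best entries per row. A sketch over a hypothetical logits tensor:

    // per-row descending order (I32 indices) and the 10 best entries per row
    struct ggml_tensor * order = ggml_argsort(ctx, logits, GGML_SORT_DESC);
    struct ggml_tensor * top10 = ggml_top_k (ctx, logits, 10);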
GGML_API struct ggml_tensor * ggml_flash_attn( GGML_API struct ggml_tensor * ggml_flash_attn(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * q, struct ggml_tensor * q,
@ -1641,6 +1574,7 @@ extern "C" {
int kh); int kh);
// used in sam // used in sam
GGML_API struct ggml_tensor * ggml_add_rel_pos( GGML_API struct ggml_tensor * ggml_add_rel_pos(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
@ -1815,7 +1749,7 @@ extern "C" {
GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads); GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads);
GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph); GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
GGML_API struct ggml_cgraph ggml_graph_view (struct ggml_cgraph * cgraph, int i0, int i1); GGML_API struct ggml_cgraph * ggml_graph_view (struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i0, int i1);
GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst); GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads
GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph); GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
@ -2111,7 +2045,6 @@ extern "C" {
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id); GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id); GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id); GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id); GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id); GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i); GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);

View File

@ -1063,7 +1063,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
#ifdef GGML_USE_CUBLAS #ifdef GGML_USE_CUBLAS
if (params.use_gpu && ggml_cublas_loaded()) { if (params.use_gpu && ggml_cublas_loaded()) {
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__); WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
backend_gpu = ggml_backend_cuda_init(0); backend_gpu = ggml_backend_cuda_init();
if (!backend_gpu) { if (!backend_gpu) {
WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__); WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
} }
@ -1077,10 +1077,6 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
backend_gpu = ggml_backend_metal_init(); backend_gpu = ggml_backend_metal_init();
if (!backend_gpu) { if (!backend_gpu) {
WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__); WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
} else if (!ggml_backend_metal_supports_family(backend_gpu, 7)) {
WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
ggml_backend_free(backend_gpu);
backend_gpu = NULL;
} }
} }
#endif #endif
@ -1345,10 +1341,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx); model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state); model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*n_audio_ctx, n_audio_state);
model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state); model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_ctx, n_audio_state);
model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
@ -1578,6 +1574,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
auto tensor = model.tensors[name.data()]; auto tensor = model.tensors[name.data()];
const bool is_conv_bias = (name == "encoder.conv1.bias" || name == "encoder.conv2.bias");
if (!is_conv_bias) {
if (ggml_nelements(tensor) != nelements) { if (ggml_nelements(tensor) != nelements) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
@ -1598,6 +1597,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe); __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false; return false;
} }
}
ggml_backend_t backend = wctx.backend; ggml_backend_t backend = wctx.backend;
@ -1607,7 +1607,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
|| ggml_backend_is_metal(backend) || ggml_backend_is_metal(backend)
#endif #endif
)) { ) && !is_conv_bias) {
// for the CPU and Metal backend, we can read directly into the tensor // for the CPU and Metal backend, we can read directly into the tensor
loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
BYTESWAP_TENSOR(tensor); BYTESWAP_TENSOR(tensor);
@ -1615,7 +1615,24 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
// read into a temporary buffer first, then copy to device memory // read into a temporary buffer first, then copy to device memory
read_buf.resize(ggml_nbytes(tensor)); read_buf.resize(ggml_nbytes(tensor));
// we repeat the 2 bias tensors along dim 0:
// [1, 512] -> [3000, 512] (conv1.bias)
// [1, 512] -> [1500, 512] (conv2.bias)
if (is_conv_bias) {
loader->read(loader->context, read_buf.data(), read_buf.size() / tensor->ne[0]);
float * data_f32 = (float *) read_buf.data();
for (int64_t y = 0; y < tensor->ne[1]; ++y) {
const int64_t yy = tensor->ne[1] - y - 1;
const float val = data_f32[yy];
for (int64_t x = 0; x < tensor->ne[0]; ++x) {
data_f32[yy*tensor->ne[0] + x] = val;
}
}
} else {
loader->read(loader->context, read_buf.data(), read_buf.size()); loader->read(loader->context, read_buf.data(), read_buf.size());
}
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor)); ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
} }
@ -1715,12 +1732,20 @@ static struct ggml_cgraph * whisper_build_graph_conv(
// convolution + gelu // convolution + gelu
{ {
cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1); cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
if (n_ctx == hparams.n_audio_ctx) {
cur = ggml_add(ctx0, cur, model.e_conv_1_b); cur = ggml_add(ctx0, cur, model.e_conv_1_b);
} else {
cur = ggml_add(ctx0, cur, ggml_cont(ctx0, ggml_view_2d(ctx0, model.e_conv_1_b, cur->ne[0], cur->ne[1], model.e_conv_1_b->nb[1], 0)));
}
cur = ggml_gelu(ctx0, cur); cur = ggml_gelu(ctx0, cur);
cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1); cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
if (n_ctx == hparams.n_audio_ctx) {
cur = ggml_add(ctx0, cur, model.e_conv_2_b); cur = ggml_add(ctx0, cur, model.e_conv_2_b);
} else {
cur = ggml_add(ctx0, cur, ggml_cont(ctx0, ggml_view_2d(ctx0, model.e_conv_2_b, cur->ne[0], cur->ne[1], model.e_conv_2_b->nb[1], 0)));
}
cur = ggml_gelu(ctx0, cur); cur = ggml_gelu(ctx0, cur);
} }
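Because the two conv biases were expanded at load time to their full-length shapes (the [1, 512] -> [3000, 512] and [1, 512] -> [1500, 512] broadcasts in the loader hunk above), a graph built with the full audio context can add the stored tensors directly, while a shorter context only needs the leading cur->ne[0] columns; the ggml_view_2d/ggml_cont pair slices that prefix of each bias before the add.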
@ -3568,17 +3593,6 @@ const char * whisper_lang_str(int id) {
return nullptr; return nullptr;
} }
const char * whisper_lang_str_full(int id) {
for (const auto & kv : g_lang) {
if (kv.second.first == id) {
return kv.second.second.c_str();
}
}
WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
return nullptr;
}
int whisper_lang_auto_detect_with_state( int whisper_lang_auto_detect_with_state(
struct whisper_context * ctx, struct whisper_context * ctx,
struct whisper_state * state, struct whisper_state * state,
@ -5028,7 +5042,6 @@ int whisper_full_with_state(
// basically don't process anything that is less than 1.0s // basically don't process anything that is less than 1.0s
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39 // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
if (seek_end < seek_start + (params.speed_up ? 50 : 100)) { if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
WHISPER_PRINT_DEBUG("%s: input is too short - %d ms < 1000 ms\n", __func__, (seek_end - seek_start)*10);
return 0; return 0;
} }
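The seek positions here are counted in 10 ms frames, which is why the removed debug line converts the span to milliseconds with (seek_end - seek_start)*10; the 100-frame threshold is the 1.0 s mentioned in the comment, and the 50-frame variant applies when speed_up halves the effective frame count.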
@ -5167,7 +5180,7 @@ int whisper_full_with_state(
ctx, state, progress_cur, params.progress_callback_user_data); ctx, state, progress_cur, params.progress_callback_user_data);
} }
// if only 1 second left, then stop // of only 1 second left, then stop
if (seek + 100 >= seek_end) { if (seek + 100 >= seek_end) {
break; break;
} }
@ -5456,7 +5469,6 @@ int whisper_full_with_state(
// do not allow to go back in time // do not allow to go back in time
if (has_ts && seek_delta > seek_delta_new && result_len < i) { if (has_ts && seek_delta > seek_delta_new && result_len < i) {
WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to seek_delta (%d > %d)\n", __func__, j, seek_delta, seek_delta_new);
failed = true; // TODO: maybe this is not a failure ? failed = true; // TODO: maybe this is not a failure ?
continue; continue;
} }
@ -5485,7 +5497,6 @@ int whisper_full_with_state(
if (seek + seek_delta + 100 >= seek_end) { if (seek + seek_delta + 100 >= seek_end) {
result_len = i + 1; result_len = i + 1;
} else { } else {
WHISPER_PRINT_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
failed = true; failed = true;
continue; continue;
} }
@ -5496,7 +5507,6 @@ int whisper_full_with_state(
seek_delta = 100*WHISPER_CHUNK_SIZE; seek_delta = 100*WHISPER_CHUNK_SIZE;
} }
WHISPER_PRINT_DEBUG("%s: decoder %d completed\n", __func__, j);
completed = true; completed = true;
continue; continue;
} }
@ -5512,7 +5522,6 @@ int whisper_full_with_state(
// sometimes, the decoding can get stuck in a repetition loop // sometimes, the decoding can get stuck in a repetition loop
// this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to repetition loop\n", __func__, j);
failed = true; failed = true;
continue; continue;
} }
@ -5656,20 +5665,20 @@ int whisper_full_with_state(
WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id); WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
} }
bool success = true;
// was the decoding successful for the current temperature? // was the decoding successful for the current temperature?
// do fallback only if: // do fallback only if:
// - we are not at the last temperature // - we are not at the last temperature
if (it != (int) temperatures.size() - 1) { // - we are not at the end of the audio (3 sec)
if (it != (int) temperatures.size() - 1 &&
seek_end - seek > 10*WHISPER_CHUNK_SIZE) {
bool success = true;
const auto & decoder = state->decoders[best_decoder_id]; const auto & decoder = state->decoders[best_decoder_id];
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) { if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
WHISPER_PRINT_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
success = false; success = false;
state->n_fail_p++; state->n_fail_p++;
} }
}
if (success) { if (success) {
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) { //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
@ -5678,6 +5687,7 @@ int whisper_full_with_state(
break; break;
} }
}
WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur); WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
} }
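For reference, in the reworked check above the current decode is accepted (breaking out of the temperature loop) only when three things hold; seek is counted in 10 ms frames and WHISPER_CHUNK_SIZE is 30 in whisper.h, so 10*WHISPER_CHUNK_SIZE corresponds to the 3 s mentioned in the comment. A condensed sketch of the combined condition:

    // accept and stop retrying only if: not the last temperature, more than ~3 s of
    // audio left, and the best decoder neither failed nor fell below logprob_thold
    const bool accept = it != (int) temperatures.size() - 1     &&
                        seek_end - seek > 10*WHISPER_CHUNK_SIZE &&
                        !decoder.failed                         &&
                        decoder.sequence.avg_logprobs >= params.logprob_thold;
    // otherwise the loop falls through and retries with the next temperature, if any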
@ -6054,43 +6064,6 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
// 1GB array // 1GB array
const size_t size = arr*1e6; const size_t size = arr*1e6;
double sum = 0.0;
// heat-up
{
char * src = (char *) malloc(size);
char * dst = (char *) malloc(size);
for (size_t i = 0; i < size; i++) src[i] = i;
memcpy(dst, src, size); // heat-up
double tsum = 0.0;
for (size_t i = 0; i < n; i++) {
const int64_t t0 = ggml_time_us();
memcpy(dst, src, size);
const int64_t t1 = ggml_time_us();
tsum += (t1 - t0)*1e-6;
src[rand() % size] = rand() % 256;
}
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (heat-up)\n", (double) (n*size)/(tsum*1e9));
s += strbuf;
// needed to prevent the compiler from optimizing the memcpy away
{
for (size_t i = 0; i < size; i++) sum += dst[i];
}
free(src);
free(dst);
}
// single-thread // single-thread
{ {
char * src = (char *) malloc(size); char * src = (char *) malloc(size);
@ -6101,6 +6074,7 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
memcpy(dst, src, size); // heat-up memcpy(dst, src, size); // heat-up
double tsum = 0.0; double tsum = 0.0;
double sum = 0.0;
for (size_t i = 0; i < n; i++) { for (size_t i = 0; i < n; i++) {
const int64_t t0 = ggml_time_us(); const int64_t t0 = ggml_time_us();
@ -6114,72 +6088,20 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
src[rand() % size] = rand() % 256; src[rand() % size] = rand() % 256;
} }
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s ( 1 thread)\n", (double) (n*size)/(tsum*1e9)); snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s (1 thread)\n", (double) (n*size)/(tsum*1e9));
s += strbuf; s += strbuf;
// needed to prevent the compiler from optimizing the memcpy away // needed to prevent the compiler from optimizing the memcpy away
{ {
for (size_t i = 0; i < size; i++) sum += dst[i]; for (size_t i = 0; i < size; i++) sum += dst[i];
}
free(src);
free(dst);
}
// multi-thread
for (uint32_t k = 1; k <= n_threads; k++) {
char * src = (char *) malloc(size);
char * dst = (char *) malloc(size);
for (size_t i = 0; i < size; i++) src[i] = i;
memcpy(dst, src, size); // heat-up
double tsum = 0.0;
auto helper = [&](int th) {
const int64_t i0 = (th + 0)*size/k;
const int64_t i1 = (th + 1)*size/k;
for (size_t i = 0; i < n; i++) {
memcpy(dst + i0, src + i0, i1 - i0);
src[i0 + rand() % (i1 - i0)] = rand() % 256;
};
};
const int64_t t0 = ggml_time_us();
std::vector<std::thread> threads(k - 1);
for (uint32_t th = 0; th < k - 1; ++th) {
threads[th] = std::thread(helper, th);
}
helper(k - 1);
for (uint32_t th = 0; th < k - 1; ++th) {
threads[th].join();
}
const int64_t t1 = ggml_time_us();
tsum += (t1 - t0)*1e-6;
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double) (n*size)/(tsum*1e9), k);
s += strbuf;
// needed to prevent the compiler from optimizing the memcpy away
{
for (size_t i = 0; i < size; i++) sum += dst[i];
}
free(src);
free(dst);
}
snprintf(strbuf, sizeof(strbuf), "sum: %f\n", sum); snprintf(strbuf, sizeof(strbuf), "sum: %f\n", sum);
s += strbuf; s += strbuf;
}
free(src);
free(dst);
}
return s.c_str(); return s.c_str();
} }
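For reference, the bandwidth figures printed above are total bytes copied over elapsed time: with n copies of a size-byte buffer taking tsum seconds, (n*size)/(tsum*1e9) gives GB/s, and the trailing sum over dst (reported in the final "sum:" line) exists only so the compiler cannot drop the memcpy calls as dead code. A one-line sketch of the arithmetic:

    const double gb_per_s = ((double) n*size)/(tsum*1e9);   // e.g. 3 copies of 1 GB in 0.1 s -> 30 GB/s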

View File

@ -315,9 +315,6 @@ extern "C" {
// Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found // Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
WHISPER_API const char * whisper_lang_str(int id); WHISPER_API const char * whisper_lang_str(int id);
// Return the short string of the specified language name (e.g. 2 -> "german"), returns nullptr if not found
WHISPER_API const char * whisper_lang_str_full(int id);
// Use mel data at offset_ms to try and auto-detect the spoken language // Use mel data at offset_ms to try and auto-detect the spoken language
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
// Returns the top language id or negative on failure // Returns the top language id or negative on failure
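For reference, the comments above describe the language auto-detection entry point declared next in this header. A usage sketch, assuming a loaded whisper_context in ctx, PCM samples in pcm, and the whisper_pcm_to_mel()/whisper_lang_auto_detect() signatures from this revision of whisper.h:

    // convert audio to mel first, then ask for the most likely language
    if (whisper_pcm_to_mel(ctx, pcm, n_samples, /*n_threads =*/ 4) != 0) {
        // handle error
    }
    float lang_probs[1 + 99];   // assumed large enough; see whisper_lang_max_id()
    const int lang_id = whisper_lang_auto_detect(ctx, /*offset_ms =*/ 0, /*n_threads =*/ 4, lang_probs);
    if (lang_id >= 0) {
        printf("detected language: %s\n", whisper_lang_str(lang_id));
    }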