Compare commits

..

1 Commits

Author SHA1 Message Date
aaa3b5e5f6 ggml : try to fix the abort mechanism 2023-11-05 20:02:24 +02:00
184 changed files with 2810 additions and 44814 deletions

View File

@ -88,7 +88,7 @@ jobs:
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
apt update
apt install -y build-essential cmake libsdl2-dev
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
make
ctest -L gh --output-on-failure'
@ -115,7 +115,7 @@ jobs:
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
apt update
apt install -y build-essential cmake libsdl2-dev
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
make
ctest -L gh --output-on-failure'
@ -182,7 +182,7 @@ jobs:
run: >
cmake -S . -B ./build -A ${{ matrix.arch }}
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DWHISPER_SDL2=${{ matrix.sdl2 }}
-DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }}
- name: Build
run: |
@ -217,10 +217,10 @@ jobs:
sdl2: [ON]
include:
- arch: Win32
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x86.zip
obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x86.zip
s2arc: x86
- arch: x64
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x64.zip
obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x64.zip
s2arc: x64
- sdl2: ON
s2ver: 2.26.0
@ -239,7 +239,7 @@ jobs:
7z x blas.zip -oblas -y
copy blas/include/cblas.h .
copy blas/include/openblas_config.h .
echo "OPENBLAS_PATH=$env:GITHUB_WORKSPACE/blas" >> $env:GITHUB_ENV
echo "blasdir=$env:GITHUB_WORKSPACE/blas" >> $env:GITHUB_ENV
- name: Fetch SDL2 and set SDL2_DIR
if: matrix.sdl2 == 'ON'
@ -252,9 +252,9 @@ jobs:
run: >
cmake -S . -B ./build -A ${{ matrix.arch }}
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DWHISPER_OPENBLAS=${{ matrix.blas }}
-DCMAKE_LIBRARY_PATH="$env:OPENBLAS_PATH/lib"
-DWHISPER_SDL2=${{ matrix.sdl2 }}
-DWHISPER_SUPPORT_OPENBLAS=${{ matrix.blas }}
-DCMAKE_LIBRARY_PATH="$env:blasdir/lib"
-DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }}
- name: Build
run: |
@ -263,7 +263,7 @@ jobs:
- name: Copy libopenblas.dll
if: matrix.blas == 'ON'
run: copy "$env:OPENBLAS_PATH/bin/libopenblas.dll" build/bin/${{ matrix.build }}
run: copy "$env:blasdir/bin/libopenblas.dll" build/bin/${{ matrix.build }}
- name: Copy SDL2.dll
if: matrix.sdl2 == 'ON'
@ -320,13 +320,6 @@ jobs:
cd ./build
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
- name: Copy CUDA DLLs
run: >
Copy-Item -PassThru
-Path "${{ steps.cuda-toolkit.outputs.CUDA_PATH }}/bin/*.dll"
-Include cudart64_*,cublas64_*,cublasLt64_*
-Destination build/bin/${{ matrix.build }}
- name: Copy SDL2.dll
if: matrix.sdl2 == 'ON'
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
@ -403,32 +396,6 @@ jobs:
cd examples/whisper.android
./gradlew assembleRelease --no-daemon
android_java:
runs-on: ubuntu-latest
steps:
- name: Clone
uses: actions/checkout@v3
- name: set up JDK 11
uses: actions/setup-java@v3
with:
java-version: '11'
distribution: 'temurin'
cache: gradle
- name: Setup Android SDK
uses: android-actions/setup-android@v2
with:
api-level: 30
build-tools-version: 30.0.3
- name: Build
run: |
cd examples/whisper.android.java
chmod +x ./gradlew
./gradlew assembleRelease
java:
needs: [ 'windows' ]
runs-on: windows-latest

11
.gitignore vendored
View File

@ -8,7 +8,6 @@
.DS_Store
build/
build-coreml/
build-em/
build-debug/
build-release/
@ -19,11 +18,6 @@ build-no-accel/
build-sanitize-addr/
build-sanitize-thread/
# SPM
.build/
.swiftpm
*.metallib
/main
/stream
/command
@ -31,7 +25,6 @@ build-sanitize-thread/
/talk-llama
/bench
/quantize
/server
/lsp
arm_neon.h
@ -55,7 +48,3 @@ bindings/java/.idea/
.idea/
benchmark_results.csv
cmake-build-debug/
.cxx/
.gradle/
local.properties

View File

@ -1,6 +1,6 @@
cmake_minimum_required (VERSION 3.5)
project(whisper.cpp VERSION 1.5.0)
project(whisper.cpp VERSION 1.4.2)
# Add path to modules
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")

View File

@ -1,4 +1,4 @@
default: main bench quantize server
default: main bench quantize
ifndef UNAME_S
UNAME_S := $(shell uname -s)
@ -307,7 +307,7 @@ ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h
ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
$(CC) $(CFLAGS) -c $< -o $@
WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o
WHISPER_OBJ += ggml-alloc.o ggml-backend.o ggml-quants.o
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
$(CXX) $(CXXFLAGS) -c $< -o $@
@ -331,14 +331,14 @@ ggml-metal.o: ggml-metal.m ggml-metal.h
WHISPER_OBJ += ggml-metal.o
endif
libwhisper.a: $(WHISPER_OBJ)
$(AR) rcs libwhisper.a $(WHISPER_OBJ)
libwhisper.a: ggml.o $(WHISPER_OBJ)
$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
libwhisper.so: $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so $(WHISPER_OBJ) $(LDFLAGS)
libwhisper.so: ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
clean:
rm -f *.o main stream command talk talk-llama bench quantize server lsp libwhisper.a libwhisper.so
rm -f *.o main stream command talk talk-llama bench quantize lsp libwhisper.a libwhisper.so
#
# Examples
@ -349,33 +349,30 @@ CC_SDL=`sdl2-config --cflags --libs`
SRC_COMMON = examples/common.cpp examples/common-ggml.cpp
SRC_COMMON_SDL = examples/common-sdl.cpp
main: examples/main/main.cpp $(SRC_COMMON) $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o main $(LDFLAGS)
main: examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o main $(LDFLAGS)
./main -h
bench: examples/bench/bench.cpp $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp $(WHISPER_OBJ) -o bench $(LDFLAGS)
bench: examples/bench/bench.cpp ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) -o bench $(LDFLAGS)
quantize: examples/quantize/quantize.cpp $(WHISPER_OBJ) $(SRC_COMMON)
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o quantize $(LDFLAGS)
quantize: examples/quantize/quantize.cpp ggml.o $(WHISPER_OBJ) $(SRC_COMMON)
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o quantize $(LDFLAGS)
server: examples/server/server.cpp $(SRC_COMMON) $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/server/server.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o server $(LDFLAGS)
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
command: examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
#
# Audio samples
@ -420,10 +417,9 @@ samples:
.PHONY: medium.en
.PHONY: medium
.PHONY: large-v1
.PHONY: large-v2
.PHONY: large-v3
.PHONY: large
tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large-v3: main
tiny.en tiny base.en base small.en small medium.en medium large-v1 large: main
bash ./models/download-ggml-model.sh $@
@echo ""
@echo "==============================================="

View File

@ -1,77 +0,0 @@
// swift-tools-version:5.5
import PackageDescription
#if arch(arm) || arch(arm64)
let platforms: [SupportedPlatform]? = [
.macOS(.v12),
.iOS(.v14),
.watchOS(.v4),
.tvOS(.v14)
]
let exclude: [String] = []
let resources: [Resource] = [
.process("ggml-metal.metal")
]
let additionalSources: [String] = ["ggml-metal.m"]
let additionalSettings: [CSetting] = [
.unsafeFlags(["-fno-objc-arc"]),
.define("GGML_USE_METAL")
]
#else
let platforms: [SupportedPlatform]? = nil
let exclude: [String] = ["ggml-metal.metal"]
let resources: [Resource] = []
let additionalSources: [String] = []
let additionalSettings: [CSetting] = []
#endif
let package = Package(
name: "whisper",
platforms: platforms,
products: [
.library(name: "whisper", targets: ["whisper"]),
],
targets: [
.target(
name: "whisper",
path: ".",
exclude: exclude + [
"bindings",
"cmake",
"coreml",
"examples",
"extra",
"models",
"samples",
"tests",
"CMakeLists.txt",
"ggml-cuda.cu",
"ggml-cuda.h",
"Makefile"
],
sources: [
"ggml.c",
"whisper.cpp",
"ggml-alloc.c",
"ggml-backend.c",
"ggml-quants.c"
] + additionalSources,
resources: resources,
publicHeadersPath: "spm-headers",
cSettings: [
.unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]),
.define("GGML_USE_ACCELERATE")
// NOTE: NEW_LAPACK will required iOS version 16.4+
// We should consider add this in the future when we drop support for iOS 14
// (ref: ref: https://developer.apple.com/documentation/accelerate/1513264-cblas_sgemm?language=objc)
// .define("ACCELERATE_NEW_LAPACK"),
// .define("ACCELERATE_LAPACK_ILP64")
] + additionalSettings,
linkerSettings: [
.linkedFramework("Accelerate")
]
)
],
cxxLanguageStandard: .cxx11
)

View File

@ -6,7 +6,7 @@
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Stable: [v1.5.0](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.5.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
Beta: [v1.4.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.2) / Stable: [v1.2.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
@ -16,10 +16,12 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
- VSX intrinsics support for POWER architectures
- Mixed F16 / F32 precision
- [4-bit and 5-bit integer quantization support](https://github.com/ggerganov/whisper.cpp#quantization)
- Low memory usage (Flash Attention)
- Zero memory allocations at runtime
- Support for CPU-only inference
- [Efficient GPU support for NVIDIA](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
- [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
- [Partial OpenCL GPU support via CLBlast](https://github.com/ggerganov/whisper.cpp#opencl-gpu-support-via-clblast)
- [BLAS CPU support via OpenBLAS](https://github.com/ggerganov/whisper.cpp#blas-cpu-support-via-openblas)
- [OpenVINO Support](https://github.com/ggerganov/whisper.cpp#openvino-support)
- [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h)
@ -34,8 +36,10 @@ Supported platforms:
- [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168)]
- [x] [Raspberry Pi](https://github.com/ggerganov/whisper.cpp/discussions/166)
The entire high-level implementation of the model is contained in [whisper.h](whisper.h) and [whisper.cpp](whisper.cpp).
The rest of the code is part of the [ggml](https://github.com/ggerganov/ggml) machine learning library.
The entire implementation of the model is contained in 2 source files:
- Tensor operations: [ggml.h](ggml.h) / [ggml.c](ggml.c)
- Transformer inference: [whisper.h](whisper.h) / [whisper.cpp](whisper.cpp)
Having such a lightweight implementation of the model allows to easily integrate it in different platforms and applications.
As an example, here is a video of running the model on an iPhone 13 device - fully offline, on-device: [whisper.objc](examples/whisper.objc)
@ -230,19 +234,18 @@ make small
make medium.en
make medium
make large-v1
make large-v2
make large-v3
make large
```
## Memory usage
| Model | Disk | Mem |
| --- | --- | --- |
| tiny | 75 MiB | ~273 MB |
| base | 142 MiB | ~388 MB |
| small | 466 MiB | ~852 MB |
| medium | 1.5 GiB | ~2.1 GB |
| large | 2.9 GiB | ~3.9 GB |
| Model | Disk | Mem | SHA |
| --- | --- | --- | --- |
| tiny | 75 MB | ~125 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
| base | 142 MB | ~210 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
| small | 466 MB | ~600 MB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
| medium | 1.5 GB | ~1.7 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
| large | 2.9 GB | ~3.3 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
## Quantization
@ -396,12 +399,12 @@ This can result in significant speedup in encoder performance. Here are the inst
The first time run on an OpenVINO device is slow, since the OpenVINO framework will compile the IR (Intermediate Representation) model to a device-specific 'blob'. This device-specific blob will get
cached for the next run.
For more information about the Core ML implementation please refer to PR [#1037](https://github.com/ggerganov/whisper.cpp/pull/1037).
## NVIDIA GPU support
## NVIDIA GPU support via cuBLAS
With NVIDIA cards the processing of the models is done efficiently on the GPU via cuBLAS and custom CUDA kernels.
With NVIDIA cards the Encoder processing can to a large extent be offloaded to the GPU through cuBLAS.
First, make sure you have installed `cuda`: https://developer.nvidia.com/cuda-downloads
Now build `whisper.cpp` with cuBLAS support:

View File

@ -24,7 +24,7 @@ const (
var (
// The models which will be downloaded, if no model is specified as an argument
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large-v2", "ggml-large-v3"}
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large"}
)
var (

View File

@ -83,6 +83,7 @@ const (
SampleRate = C.WHISPER_SAMPLE_RATE // Expected sample rate, samples per second
SampleBits = uint16(unsafe.Sizeof(C.float(0))) * 8 // Sample size in bits
NumFFT = C.WHISPER_N_FFT
NumMEL = C.WHISPER_N_MEL
HopLength = C.WHISPER_HOP_LENGTH
ChunkSize = C.WHISPER_CHUNK_SIZE
)
@ -102,7 +103,7 @@ var (
func Whisper_init(path string) *Context {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
if ctx := C.whisper_init_from_file_with_params(cPath, C.whisper_context_default_params()); ctx != nil {
if ctx := C.whisper_init_from_file(cPath); ctx != nil {
return (*Context)(ctx)
} else {
return nil

View File

@ -9,7 +9,6 @@ archivesBaseName = 'whispercpp'
group = 'io.github.ggerganov'
version = '1.4.0'
sourceCompatibility = 1.8
targetCompatibility = 1.8

View File

@ -4,7 +4,6 @@ import com.sun.jna.Structure;
import com.sun.jna.ptr.PointerByReference;
import io.github.ggerganov.whispercpp.ggml.GgmlType;
import io.github.ggerganov.whispercpp.WhisperModel;
import io.github.ggerganov.whispercpp.params.WhisperContextParams;
import java.util.List;
@ -24,9 +23,8 @@ public class WhisperContext extends Structure {
public PointerByReference vocab;
public PointerByReference state;
/** populated by whisper_init_from_file_with_params() */
/** populated by whisper_init_from_file() */
String path_model;
WhisperContextParams params;
// public static class ByReference extends WhisperContext implements Structure.ByReference {
// }

View File

@ -2,16 +2,12 @@ package io.github.ggerganov.whispercpp;
import com.sun.jna.Native;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.bean.WhisperSegment;
import io.github.ggerganov.whispercpp.params.WhisperContextParams;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/**
* Before calling most methods, you must call `initContext(modelPath)` to initialise the `ctx` Pointer.
@ -19,9 +15,8 @@ import java.util.List;
public class WhisperCpp implements AutoCloseable {
private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
private Pointer ctx = null;
private Pointer paramsPointer = null;
private Pointer greedyParamsPointer = null;
private Pointer beamParamsPointer = null;
private Pointer greedyPointer = null;
private Pointer beamPointer = null;
public File modelDir() {
String modelDirPath = System.getenv("XDG_CACHE_HOME");
@ -36,18 +31,6 @@ public class WhisperCpp implements AutoCloseable {
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
*/
public void initContext(String modelPath) throws FileNotFoundException {
initContextImpl(modelPath, getContextDefaultParams());
}
/**
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
* @param params - params to use when initialising the context
*/
public void initContext(String modelPath, WhisperContextParams params) throws FileNotFoundException {
initContextImpl(modelPath, params);
}
private void initContextImpl(String modelPath, WhisperContextParams params) throws FileNotFoundException {
if (ctx != null) {
lib.whisper_free(ctx);
}
@ -60,26 +43,13 @@ public class WhisperCpp implements AutoCloseable {
modelPath = new File(modelDir(), modelPath).getAbsolutePath();
}
ctx = lib.whisper_init_from_file_with_params(modelPath, params);
ctx = lib.whisper_init_from_file(modelPath);
if (ctx == null) {
throw new FileNotFoundException(modelPath);
}
}
/**
* Provides default params which can be used with `whisper_init_from_file_with_params()` etc.
* Because this function allocates memory for the params, the caller must call either:
* - call `whisper_free_context_params()`
* - `Native.free(Pointer.nativeValue(pointer));`
*/
public WhisperContextParams getContextDefaultParams() {
paramsPointer = lib.whisper_context_default_params_by_ref();
WhisperContextParams params = new WhisperContextParams(paramsPointer);
params.read();
return params;
}
/**
* Provides default params which can be used with `whisper_full()` etc.
* Because this function allocates memory for the params, the caller must call either:
@ -93,15 +63,15 @@ public class WhisperCpp implements AutoCloseable {
// whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.
if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) {
if (greedyParamsPointer == null) {
greedyParamsPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
if (greedyPointer == null) {
greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
}
pointer = greedyParamsPointer;
pointer = greedyPointer;
} else {
if (beamParamsPointer == null) {
beamParamsPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
if (beamPointer == null) {
beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
}
pointer = beamParamsPointer;
pointer = beamPointer;
}
WhisperFullParams params = new WhisperFullParams(pointer);
@ -123,17 +93,13 @@ public class WhisperCpp implements AutoCloseable {
}
private void freeParams() {
if (paramsPointer != null) {
Native.free(Pointer.nativeValue(paramsPointer));
paramsPointer = null;
if (greedyPointer != null) {
Native.free(Pointer.nativeValue(greedyPointer));
greedyPointer = null;
}
if (greedyParamsPointer != null) {
Native.free(Pointer.nativeValue(greedyParamsPointer));
greedyParamsPointer = null;
}
if (beamParamsPointer != null) {
Native.free(Pointer.nativeValue(beamParamsPointer));
beamParamsPointer = null;
if (beamPointer != null) {
Native.free(Pointer.nativeValue(beamPointer));
beamPointer = null;
}
}
@ -163,28 +129,6 @@ public class WhisperCpp implements AutoCloseable {
return str.toString().trim();
}
public List<WhisperSegment> fullTranscribeWithTime(WhisperFullParams whisperParams, float[] audioData) throws IOException {
if (ctx == null) {
throw new IllegalStateException("Model not initialised");
}
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
throw new IOException("Failed to process audio");
}
int nSegments = lib.whisper_full_n_segments(ctx);
List<WhisperSegment> segments= new ArrayList<>(nSegments);
for (int i = 0; i < nSegments; i++) {
long t0 = lib.whisper_full_get_segment_t0(ctx, i);
String text = lib.whisper_full_get_segment_text(ctx, i);
long t1 = lib.whisper_full_get_segment_t1(ctx, i);
segments.add(new WhisperSegment(t0,t1,text));
}
return segments;
}
// public int getTextSegmentCount(Pointer ctx) {
// return lib.whisper_full_n_segments(ctx);

View File

@ -5,7 +5,6 @@ import com.sun.jna.Native;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.model.WhisperModelLoader;
import io.github.ggerganov.whispercpp.model.WhisperTokenData;
import io.github.ggerganov.whispercpp.params.WhisperContextParams;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
public interface WhisperCppJnaLibrary extends Library {
@ -14,31 +13,12 @@ public interface WhisperCppJnaLibrary extends Library {
String whisper_print_system_info();
/**
* DEPRECATED. Allocate (almost) all memory needed for the model by loading from a file.
* Allocate (almost) all memory needed for the model by loading from a file.
*
* @param path_model Path to the model file
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_file(String path_model);
/**
* Provides default params which can be used with `whisper_init_from_file_with_params()` etc.
* Because this function allocates memory for the params, the caller must call either:
* - call `whisper_free_context_params()`
* - `Native.free(Pointer.nativeValue(pointer));`
*/
Pointer whisper_context_default_params_by_ref();
void whisper_free_context_params(Pointer params);
/**
* Allocate (almost) all memory needed for the model by loading from a file.
*
* @param path_model Path to the model file
* @param params Pointer to whisper_context_params
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_file_with_params(String path_model, WhisperContextParams params);
/**
* Allocate (almost) all memory needed for the model by loading from a buffer.

View File

@ -1,47 +0,0 @@
package io.github.ggerganov.whispercpp.bean;
/**
* Created by litonglinux@qq.com on 10/21/2023_7:48 AM
*/
public class WhisperSegment {
private long start, end;
private String sentence;
public WhisperSegment() {
}
public WhisperSegment(long start, long end, String sentence) {
this.start = start;
this.end = end;
this.sentence = sentence;
}
public long getStart() {
return start;
}
public long getEnd() {
return end;
}
public String getSentence() {
return sentence;
}
public void setStart(long start) {
this.start = start;
}
public void setEnd(long end) {
this.end = end;
}
public void setSentence(String sentence) {
this.sentence = sentence;
}
@Override
public String toString() {
return "[" + start + " --> " + end + "]:" + sentence;
}
}

View File

@ -1,31 +0,0 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.*;
import java.util.Arrays;
import java.util.List;
/**
* Parameters for the whisper_init_from_file_with_params() function.
* If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
* whisper_context_default_params()
*/
public class WhisperContextParams extends Structure {
public WhisperContextParams(Pointer p) {
super(p);
}
/** Use GPU for inference Number (default = true) */
public CBool use_gpu;
/** Use GPU for inference Number (default = true) */
public void useGpu(boolean enable) {
use_gpu = enable ? CBool.TRUE : CBool.FALSE;
}
@Override
protected List<String> getFieldOrder() {
return Arrays.asList("use_gpu");
}
}

View File

@ -58,9 +58,6 @@ public class WhisperFullParams extends Structure {
no_context = enable ? CBool.FALSE : CBool.TRUE;
}
/** Generate timestamps or not? */
public CBool no_timestamps;
/** Flag to force single segment output (useful for streaming). (default = false) */
public CBool single_segment;
@ -307,16 +304,10 @@ public class WhisperFullParams extends Structure {
logits_filter_callback = CallbackReference.getFunctionPointer(callback);
}
/** Grammar stuff */
public Pointer grammar_rules;
public long n_grammar_rules;
public long i_start_rule;
public float grammar_penalty;
@Override
protected List<String> getFieldOrder() {
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
"no_context", "single_segment", "no_timestamps",
"no_context", "single_segment",
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
"tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
@ -325,7 +316,6 @@ public class WhisperFullParams extends Structure {
"new_segment_callback", "new_segment_callback_user_data",
"progress_callback", "progress_callback_user_data",
"encoder_begin_callback", "encoder_begin_callback_user_data",
"logits_filter_callback", "logits_filter_callback_user_data",
"grammar_rules", "n_grammar_rules", "i_start_rule", "grammar_penalty");
"logits_filter_callback", "logits_filter_callback_user_data");
}
}

View File

@ -2,7 +2,6 @@ package io.github.ggerganov.whispercpp;
import static org.junit.jupiter.api.Assertions.*;
import io.github.ggerganov.whispercpp.bean.WhisperSegment;
import io.github.ggerganov.whispercpp.params.CBool;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
@ -12,7 +11,6 @@ import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.List;
class WhisperCppTest {
private static WhisperCpp whisper = new WhisperCpp();
@ -22,12 +20,11 @@ class WhisperCppTest {
static void init() throws FileNotFoundException {
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
// or you can provide the absolute path to the model file.
//String modelName = "../../models/ggml-tiny.bin";
String modelName = "../../models/ggml-tiny.en.bin";
try {
whisper.initContext(modelName);
//whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
//whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
// whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
// whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
modelInitialised = true;
} catch (FileNotFoundException ex) {
System.out.println("Model " + modelName + " not found");
@ -45,7 +42,7 @@ class WhisperCppTest {
assertEquals(16384, params.n_max_text_ctx);
assertFalse(params.translate);
assertEquals(0.01f, params.thold_pt);
assertEquals(5, params.beam_search.beam_size);
assertEquals(2, params.beam_search.beam_size);
assertEquals(-1.0f, params.beam_search.patience);
}
@ -58,7 +55,7 @@ class WhisperCppTest {
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY.ordinal(), params.strategy);
assertNotEquals(0, params.n_threads);
assertEquals(16384, params.n_max_text_ctx);
assertEquals(5, params.greedy.best_of);
assertEquals(2, params.greedy.best_of);
}
@Test
@ -75,11 +72,11 @@ class WhisperCppTest {
byte[] b = new byte[audioInputStream.available()];
float[] floats = new float[b.length / 2];
//WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
// WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
params.print_progress = CBool.FALSE;
//params.initial_prompt = "and so my fellow Americans um, like";
// params.initial_prompt = "and so my fellow Americans um, like";
try {
@ -102,43 +99,4 @@ class WhisperCppTest {
audioInputStream.close();
}
}
@Test
void testFullTranscribeWithTime() throws Exception {
if (!modelInitialised) {
System.out.println("Model not initialised, skipping test");
return;
}
// Given
File file = new File(System.getProperty("user.dir"), "../../samples/jfk.wav");
AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(file);
byte[] b = new byte[audioInputStream.available()];
float[] floats = new float[b.length / 2];
//WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
params.print_progress = CBool.FALSE;
//params.initial_prompt = "and so my fellow Americans um, like";
try {
audioInputStream.read(b);
for (int i = 0, j = 0; i < b.length; i += 2, j++) {
int intSample = (int) (b[i + 1]) << 8 | (int) (b[i]) & 0xFF;
floats[j] = intSample / 32767.0f;
}
List<WhisperSegment> segments = whisper.fullTranscribeWithTime(params, floats);
assertTrue(segments.size() > 0, "The size of segments should be greater than 0");
for (WhisperSegment segment : segments) {
System.out.println(segment);
}
} finally {
audioInputStream.close();
}
}
}

View File

@ -20,7 +20,7 @@ struct whisper_context * g_context;
EMSCRIPTEN_BINDINGS(whisper) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
if (g_context == nullptr) {
g_context = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
g_context = whisper_init_from_file(path_model.c_str());
if (g_context != nullptr) {
return true;
} else {

View File

@ -1,6 +1,6 @@
{
"name": "whisper.cpp",
"version": "1.5.0",
"version": "1.4.2",
"description": "Whisper speech recognition",
"main": "whisper.js",
"scripts": {

File diff suppressed because one or more lines are too long

View File

@ -87,7 +87,7 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) {
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
}
rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path));
if (rw->context == nullptr) {
rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
}

View File

@ -123,7 +123,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
/**
Make a prediction using the convenience interface
@param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats:
@param logmel_data as 1 × 80 × 3000 3-dimensional array of floats:
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_encoder_implOutput
*/

View File

@ -3,8 +3,6 @@
// Code is derived from the work of Github user @wangchou
// ref: https://github.com/wangchou/callCoreMLFromCpp
#include <stdint.h>
#if __cplusplus
extern "C" {
#endif
@ -16,8 +14,6 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx);
void whisper_coreml_encode(
const whisper_coreml_context * ctx,
int64_t n_ctx,
int64_t n_mel,
float * mel,
float * out);

View File

@ -48,15 +48,13 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx) {
void whisper_coreml_encode(
const whisper_coreml_context * ctx,
int64_t n_ctx,
int64_t n_mel,
float * mel,
float * out) {
MLMultiArray * inMultiArray = [
[MLMultiArray alloc] initWithDataPointer: mel
shape: @[@1, @(n_mel), @(n_ctx)]
shape: @[@1, @80, @3000]
dataType: MLMultiArrayDataTypeFloat32
strides: @[@(n_ctx*n_mel), @(n_ctx), @1]
strides: @[@(240000), @(3000), @1]
deallocator: nil
error: nil
];

View File

@ -23,7 +23,6 @@ add_library(${TARGET} STATIC
common.cpp
common-ggml.h
common-ggml.cpp
grammar-parser.cpp
)
include(DefaultTargetOptions)
@ -65,7 +64,6 @@ elseif(CMAKE_JS_VERSION)
else()
add_subdirectory(main)
add_subdirectory(stream)
add_subdirectory(server)
add_subdirectory(command)
add_subdirectory(bench)
add_subdirectory(quantize)
@ -73,5 +71,3 @@ else()
add_subdirectory(talk-llama)
add_subdirectory(lsp)
endif()
add_subdirectory(wchess)

View File

@ -11,7 +11,6 @@ const whisperParamsMock = {
language: "en",
model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
use_gpu: true,
};
describe("Run whisper.node", () => {

View File

@ -36,7 +36,6 @@ struct whisper_params {
bool print_colors = false;
bool print_progress = false;
bool no_timestamps = false;
bool use_gpu = true;
std::string language = "en";
std::string prompt;
@ -154,9 +153,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
// whisper init
struct whisper_context_params cparams;
cparams.use_gpu = params.use_gpu;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
@ -318,12 +315,10 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
std::string language = whisper_params.Get("language").As<Napi::String>();
std::string model = whisper_params.Get("model").As<Napi::String>();
std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
params.language = language;
params.model = model;
params.fname_inp.emplace_back(input);
params.use_gpu = use_gpu;
Napi::Function callback = info[1].As<Napi::Function>();
Worker* worker = new Worker(callback, params);

View File

@ -11,7 +11,6 @@ const whisperParams = {
language: "en",
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
fname_inp: "../../samples/jfk.wav",
use_gpu: true,
};
const arguments = process.argv.slice(2);

View File

@ -23,9 +23,7 @@ void bench_main(size_t index) {
fprintf(stderr, "%s: running benchmark with %d threads - please wait...\n", __func__, n_threads);
const int n_mels = whisper_model_n_mels(ctx);
if (int ret = whisper_set_mel(ctx, nullptr, 0, n_mels)) {
if (int ret = whisper_set_mel(ctx, nullptr, 0, WHISPER_N_MEL)) {
fprintf(stderr, "error: failed to set mel: %d\n", ret);
return;
}
@ -59,7 +57,7 @@ EMSCRIPTEN_BINDINGS(bench) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
g_contexts[i] = whisper_init_from_file(path_model.c_str());
if (g_contexts[i] != nullptr) {
if (g_worker.joinable()) {
g_worker.join();

View File

@ -11,8 +11,6 @@ struct whisper_params {
int32_t what = 0; // what to benchmark: 0 - whisper ecoder, 1 - memcpy, 2 - ggml_mul_mat
std::string model = "models/ggml-base.en.bin";
bool use_gpu = true;
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -25,10 +23,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
@ -48,7 +45,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " %-7s 0 - whisper\n", "");
fprintf(stderr, " %-7s 1 - memcpy\n", "");
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
@ -58,10 +54,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
int whisper_bench_full(const whisper_params & params) {
// whisper init
struct whisper_context_params cparams;
cparams.use_gpu = params.use_gpu;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
{
fprintf(stderr, "\n");
@ -73,15 +66,13 @@ int whisper_bench_full(const whisper_params & params) {
return 2;
}
const int n_mels = whisper_model_n_mels(ctx);
if (int ret = whisper_set_mel(ctx, nullptr, 0, n_mels)) {
if (int ret = whisper_set_mel(ctx, nullptr, 0, WHISPER_N_MEL)) {
fprintf(stderr, "error: failed to set mel: %d\n", ret);
return 3;
}
// heat encoder
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode: %d\n", ret);
fprintf(stderr, "error: failed to encode model: %d\n", ret);
return 4;
}
@ -90,13 +81,13 @@ int whisper_bench_full(const whisper_params & params) {
// prompt heat
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to decode: %d\n", ret);
fprintf(stderr, "error: failed to encode model: %d\n", ret);
return 4;
}
// text-generation heat
if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
fprintf(stderr, "error: failed to decode: %d\n", ret);
fprintf(stderr, "error: failed to encode model: %d\n", ret);
return 4;
}
@ -104,30 +95,20 @@ int whisper_bench_full(const whisper_params & params) {
// actual run
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode: %d\n", ret);
fprintf(stderr, "error: failed to encode model: %d\n", ret);
return 4;
}
// text-generation
for (int i = 0; i < 256; i++) {
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4;
}
}
// batched decoding
for (int i = 0; i < 64; i++) {
if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4;
}
}
// prompt processing
for (int i = 0; i < 16; i++) {
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to decode: %d\n", ret);
fprintf(stderr, "error: failed to encode model: %d\n", ret);
return 4;
}
}
for (int i = 0; i < 256; i++) {
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
return 4;
}
}

View File

@ -243,7 +243,7 @@ EMSCRIPTEN_BINDINGS(command) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
g_contexts[i] = whisper_init_from_file(path_model.c_str());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {

View File

@ -9,7 +9,6 @@
#include "common-sdl.h"
#include "common.h"
#include "whisper.h"
#include "grammar-parser.h"
#include <sstream>
#include <cassert>
@ -22,11 +21,6 @@
#include <vector>
#include <map>
bool file_exists(const std::string & fname) {
std::ifstream f(fname.c_str());
return f.good();
}
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
@ -36,27 +30,20 @@ struct whisper_params {
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
float grammar_penalty = 100.0f;
grammar_parser::parse_state grammar_parsed;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false;
bool print_special = false;
bool print_energy = false;
bool no_timestamps = true;
bool use_gpu = true;
std::string language = "en";
std::string model = "models/ggml-base.en.bin";
std::string fname_out;
std::string commands;
std::string prompt;
std::string context;
std::string grammar;
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -81,15 +68,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
@ -118,36 +101,21 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
fprintf(stderr, "\n");
}
std::string transcribe(
whisper_context * ctx,
const whisper_params & params,
const std::vector<float> & pcmf32,
const std::string & grammar_rule,
float & logprob_min,
float & logprob_sum,
int & n_tokens,
int64_t & t_ms) {
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();
logprob_min = 0.0f;
logprob_sum = 0.0f;
n_tokens = 0;
prob = 0.0f;
t_ms = 0;
//whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_progress = false;
wparams.print_special = params.print_special;
@ -155,41 +123,19 @@ std::string transcribe(
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.no_timestamps = params.no_timestamps;
wparams.single_segment = true;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.temperature = 0.4f;
wparams.temperature_inc = 1.0f;
wparams.greedy.best_of = 5;
wparams.beam_search.beam_size = 5;
wparams.initial_prompt = params.context.data();
const auto & grammar_parsed = params.grammar_parsed;
auto grammar_rules = grammar_parsed.c_rules();
if (!params.grammar_parsed.rules.empty() && !grammar_rule.empty()) {
if (grammar_parsed.symbol_ids.find(grammar_rule) == grammar_parsed.symbol_ids.end()) {
fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, grammar_rule.c_str());
} else {
wparams.grammar_rules = grammar_rules.data();
wparams.n_grammar_rules = grammar_rules.size();
wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule);
wparams.grammar_penalty = params.grammar_penalty;
}
}
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return "";
}
int prob_n = 0;
std::string result;
const int n_segments = whisper_full_n_segments(ctx);
@ -198,17 +144,19 @@ std::string transcribe(
result += text;
const int n = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < n; ++j) {
const int n_tokens = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < n_tokens; ++j) {
const auto token = whisper_full_get_token_data(ctx, i, j);
if(token.plog > 0.0f) exit(0);
logprob_min = std::min(logprob_min, token.plog);
logprob_sum += token.plog;
++n_tokens;
prob += token.p;
++prob_n;
}
}
if (prob_n > 0) {
prob /= prob_n;
}
const auto t_end = std::chrono::high_resolution_clock::now();
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
@ -299,7 +247,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
fprintf(stderr, " ]\n");
}
std::string k_prompt = "select one from the available words: ";
std::string k_prompt = "select one from the available words: ";
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
if (i > 0) {
k_prompt += ", ";
@ -467,9 +415,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
bool is_running = true;
bool ask_prompt = true;
float logprob_min = 0.0f;
float logprob_sum = 0.0f;
int n_tokens = 0;
float prob = 0.0f;
std::vector<float> pcmf32_cur;
@ -507,7 +453,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
// detect the commands
audio.get(params.command_ms, pcmf32_cur);
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", logprob_min, logprob_sum, n_tokens, t_ms));
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
const auto words = get_words(txt);
@ -543,27 +489,18 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
// general-purpose mode
// freely transcribe the voice into text
int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
bool is_running = true;
bool have_prompt = false;
bool ask_prompt = true;
float logprob_min0 = 0.0f;
float logprob_min = 0.0f;
float logprob_sum0 = 0.0f;
float logprob_sum = 0.0f;
int n_tokens0 = 0;
int n_tokens = 0;
float prob0 = 0.0f;
float prob = 0.0f;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
std::string k_prompt = "Ok Whisper, start listening for commands.";
if (!params.prompt.empty()) {
k_prompt = params.prompt;
}
const std::string k_prompt = "Ok Whisper, start listening for commands.";
fprintf(stderr, "\n");
fprintf(stderr, "%s: general-purpose mode\n", __func__);
@ -596,11 +533,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
// wait for activation phrase
audio.get(params.prompt_ms, pcmf32_cur);
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "prompt", logprob_min0, logprob_sum0, n_tokens0, t_ms));
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
const float p = 100.0f * std::exp(logprob_min0);
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms, p = %.2f%%)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms, p);
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
const float sim = similarity(txt, k_prompt);
@ -621,30 +556,19 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
// we have heard the activation phrase, now detect the commands
audio.get(params.command_ms, pcmf32_cur);
//printf("len prompt: %.4f\n", pcmf32_prompt.size() / (float) WHISPER_SAMPLE_RATE);
//printf("len command: %.4f\n", pcmf32_cur.size() / (float) WHISPER_SAMPLE_RATE);
// prepend 3 second of silence
pcmf32_cur.insert(pcmf32_cur.begin(), 3.0f*WHISPER_SAMPLE_RATE, 0.0f);
// prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", logprob_min, logprob_sum, n_tokens, t_ms));
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
//const float p = 100.0f * std::exp((logprob - logprob0) / (n_tokens - n_tokens0));
const float p = 100.0f * std::exp(logprob_min);
prob = 100.0f*(prob - prob0);
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
// find the prompt in the text
float best_sim = 0.0f;
size_t best_len = 0;
for (size_t n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
if (n >= txt.size()) {
break;
}
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
const auto prompt = txt.substr(0, n);
const float sim = similarity(prompt, k_prompt);
@ -657,16 +581,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
}
}
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
if (best_len == 0) {
fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
} else {
// cut the prompt from the decoded text
const std::string command = ::trim(txt.substr(best_len));
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
}
const std::string command = ::trim(txt.substr(best_len));
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
fprintf(stdout, "\n");
}
@ -693,10 +610,7 @@ int main(int argc, char ** argv) {
// whisper init
struct whisper_context_params cparams;
cparams.use_gpu = params.use_gpu;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
// print some info about the processing
{
@ -734,36 +648,12 @@ int main(int argc, char ** argv) {
int ret_val = 0;
if (!params.grammar.empty()) {
auto & grammar = params.grammar_parsed;
if (file_exists(params.grammar.c_str())) {
// read grammar from file
std::ifstream ifs(params.grammar.c_str());
const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
grammar = grammar_parser::parse(txt.c_str());
} else {
// read grammar from string
grammar = grammar_parser::parse(params.grammar.c_str());
}
// will be empty (default) if there are parse errors
if (grammar.rules.empty()) {
ret_val = 1;
} else {
fprintf(stderr, "%s: grammar:\n", __func__);
grammar_parser::print_grammar(stderr, grammar);
fprintf(stderr, "\n");
}
}
if (ret_val == 0) {
if (!params.commands.empty()) {
ret_val = process_command_list(ctx, audio, params);
} else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) {
ret_val = always_prompt_transcription(ctx, audio, params);
} else {
ret_val = process_general_transcription(ctx, audio, params);
}
if (!params.commands.empty()) {
ret_val = process_command_list(ctx, audio, params);
} else if (!params.prompt.empty()) {
ret_val = always_prompt_transcription(ctx, audio, params);
} else {
ret_val = process_general_transcription(ctx, audio, params);
}
audio.pause();

View File

@ -9,11 +9,6 @@ static const std::map<std::string, enum ggml_ftype> GGML_FTYPE_MAP = {
{"q5_0", GGML_FTYPE_MOSTLY_Q5_0},
{"q5_1", GGML_FTYPE_MOSTLY_Q5_1},
{"q8_0", GGML_FTYPE_MOSTLY_Q8_0},
{"q2_k", GGML_FTYPE_MOSTLY_Q2_K},
{"q3_k", GGML_FTYPE_MOSTLY_Q3_K},
{"q4_k", GGML_FTYPE_MOSTLY_Q4_K},
{"q5_k", GGML_FTYPE_MOSTLY_Q5_K},
{"q6_k", GGML_FTYPE_MOSTLY_Q6_K},
};
void ggml_print_ftypes(FILE * fp) {
@ -53,15 +48,15 @@ bool ggml_common_quantize_0(
case GGML_FTYPE_MOSTLY_Q5_0: qtype = GGML_TYPE_Q5_0; break;
case GGML_FTYPE_MOSTLY_Q5_1: qtype = GGML_TYPE_Q5_1; break;
case GGML_FTYPE_MOSTLY_Q8_0: qtype = GGML_TYPE_Q8_0; break;
case GGML_FTYPE_MOSTLY_Q2_K: qtype = GGML_TYPE_Q2_K; break;
case GGML_FTYPE_MOSTLY_Q3_K: qtype = GGML_TYPE_Q3_K; break;
case GGML_FTYPE_MOSTLY_Q4_K: qtype = GGML_TYPE_Q4_K; break;
case GGML_FTYPE_MOSTLY_Q5_K: qtype = GGML_TYPE_Q5_K; break;
case GGML_FTYPE_MOSTLY_Q6_K: qtype = GGML_TYPE_Q6_K; break;
case GGML_FTYPE_UNKNOWN:
case GGML_FTYPE_ALL_F32:
case GGML_FTYPE_MOSTLY_F16:
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16:
case GGML_FTYPE_MOSTLY_Q2_K:
case GGML_FTYPE_MOSTLY_Q3_K:
case GGML_FTYPE_MOSTLY_Q4_K:
case GGML_FTYPE_MOSTLY_Q5_K:
case GGML_FTYPE_MOSTLY_Q6_K:
{
fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
return false;
@ -172,17 +167,24 @@ bool ggml_common_quantize_0(
switch ((ggml_type) ttype) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
{
cur_size = ggml_quantize_chunk((ggml_type) ttype, data_f32.data(), work.data(), 0, nelements, hist_cur.data());
cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q4_1:
{
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q5_0:
{
cur_size = ggml_quantize_q5_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q5_1:
{
cur_size = ggml_quantize_q5_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q8_0:
{
cur_size = ggml_quantize_q8_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_F32:
case GGML_TYPE_F16:
@ -190,6 +192,11 @@ bool ggml_common_quantize_0(
case GGML_TYPE_I16:
case GGML_TYPE_I32:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q8_K:
case GGML_TYPE_COUNT:
{

View File

@ -139,13 +139,10 @@ void audio_async::callback(uint8_t * stream, int len) {
return;
}
size_t n_samples = len / sizeof(float);
const size_t n_samples = len / sizeof(float);
if (n_samples > m_audio.size()) {
n_samples = m_audio.size();
stream += (len - (n_samples * sizeof(float)));
}
m_audio_new.resize(n_samples);
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
@ -156,7 +153,7 @@ void audio_async::callback(uint8_t * stream, int len) {
const size_t n0 = m_audio.size() - m_audio_pos;
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
memcpy(&m_audio[0], stream + n0 * sizeof(float), (n_samples - n0) * sizeof(float));
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = m_audio.size();

View File

@ -41,6 +41,7 @@ private:
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
size_t m_audio_pos = 0;
size_t m_audio_len = 0;
};

View File

@ -181,7 +181,7 @@ private:
// It is assumed that PCM data is normalized to a range from -1 to 1
bool write_audio(const float * data, size_t length) {
for (size_t i = 0; i < length; ++i) {
const int16_t intSample = data[i] * 32767;
const auto intSample = static_cast<const int16_t>(data[i] * 32767);
file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));
dataSize += sizeof(int16_t);
}

View File

@ -1,423 +0,0 @@
#include "grammar-parser.h"
#include <cstdint>
#include <cwchar>
#include <string>
#include <utility>
#include <stdexcept>
#include <exception>
namespace grammar_parser {
// NOTE: assumes valid utf8 (but checks for overrun)
// copied from whisper.cpp
std::pair<uint32_t, const char *> decode_utf8(const char * src) {
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
uint8_t first_byte = static_cast<uint8_t>(*src);
uint8_t highbits = first_byte >> 4;
int len = lookup[highbits];
uint8_t mask = (1 << (8 - len)) - 1;
uint32_t value = first_byte & mask;
const char * end = src + len; // may overrun!
const char * pos = src + 1;
for ( ; pos < end && *pos; pos++) {
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
}
return std::make_pair(value, pos);
}
uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
return result.first->second;
}
uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
return next_id;
}
void add_rule(
parse_state & state,
uint32_t rule_id,
const std::vector<whisper_grammar_element> & rule) {
if (state.rules.size() <= rule_id) {
state.rules.resize(rule_id + 1);
}
state.rules[rule_id] = rule;
}
bool is_word_char(char c) {
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
}
std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
const char * pos = src;
const char * end = src + size;
uint32_t value = 0;
for ( ; pos < end && *pos; pos++) {
value <<= 4;
char c = *pos;
if ('a' <= c && c <= 'f') {
value += c - 'a' + 10;
} else if ('A' <= c && c <= 'F') {
value += c - 'A' + 10;
} else if ('0' <= c && c <= '9') {
value += c - '0';
} else {
break;
}
}
if (pos != end) {
throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
}
return std::make_pair(value, pos);
}
const char * parse_space(const char * src, bool newline_ok) {
const char * pos = src;
while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
(newline_ok && (*pos == '\r' || *pos == '\n'))) {
if (*pos == '#') {
while (*pos && *pos != '\r' && *pos != '\n') {
pos++;
}
} else {
pos++;
}
}
return pos;
}
const char * parse_name(const char * src) {
const char * pos = src;
while (is_word_char(*pos)) {
pos++;
}
if (pos == src) {
throw std::runtime_error(std::string("expecting name at ") + src);
}
return pos;
}
std::pair<uint32_t, const char *> parse_char(const char * src) {
if (*src == '\\') {
switch (src[1]) {
case 'x': return parse_hex(src + 2, 2);
case 'u': return parse_hex(src + 2, 4);
case 'U': return parse_hex(src + 2, 8);
case 't': return std::make_pair('\t', src + 2);
case 'r': return std::make_pair('\r', src + 2);
case 'n': return std::make_pair('\n', src + 2);
case '\\':
case '"':
case '[':
case ']':
return std::make_pair(src[1], src + 2);
default:
throw std::runtime_error(std::string("unknown escape at ") + src);
}
} else if (*src) {
return decode_utf8(src);
}
throw std::runtime_error("unexpected end of input");
}
const char * parse_alternates(
parse_state & state,
const char * src,
const std::string & rule_name,
uint32_t rule_id,
bool is_nested);
const char * parse_sequence(
parse_state & state,
const char * src,
const std::string & rule_name,
std::vector<whisper_grammar_element> & out_elements,
bool is_nested) {
size_t last_sym_start = out_elements.size();
const char * pos = src;
while (*pos) {
if (*pos == '"') { // literal string
pos++;
last_sym_start = out_elements.size();
while (*pos != '"') {
auto char_pair = parse_char(pos);
pos = char_pair.second;
out_elements.push_back({WHISPER_GRETYPE_CHAR, char_pair.first});
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '[') { // char range(s)
pos++;
enum whisper_gretype start_type = WHISPER_GRETYPE_CHAR;
if (*pos == '^') {
pos++;
start_type = WHISPER_GRETYPE_CHAR_NOT;
}
last_sym_start = out_elements.size();
while (*pos != ']') {
auto char_pair = parse_char(pos);
pos = char_pair.second;
enum whisper_gretype type = last_sym_start < out_elements.size()
? WHISPER_GRETYPE_CHAR_ALT
: start_type;
out_elements.push_back({type, char_pair.first});
if (pos[0] == '-' && pos[1] != ']') {
auto endchar_pair = parse_char(pos + 1);
pos = endchar_pair.second;
out_elements.push_back({WHISPER_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
}
}
pos = parse_space(pos + 1, is_nested);
} else if (is_word_char(*pos)) { // rule reference
const char * name_end = parse_name(pos);
uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
pos = parse_space(name_end, is_nested);
last_sym_start = out_elements.size();
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, ref_rule_id});
} else if (*pos == '(') { // grouping
// parse nested alternates into synthesized rule
pos = parse_space(pos + 1, true);
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
last_sym_start = out_elements.size();
// output reference to synthesized rule
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
if (*pos != ')') {
throw std::runtime_error(std::string("expecting ')' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
if (last_sym_start == out_elements.size()) {
throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
}
// apply transformation to previous symbol (last_sym_start to end) according to
// rewrite rules:
// S* --> S' ::= S S' |
// S+ --> S' ::= S S' | S
// S? --> S' ::= S |
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
std::vector<whisper_grammar_element> sub_rule;
// add preceding symbol to generated rule
sub_rule.insert(
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
if (*pos == '*' || *pos == '+') {
// cause generated rule to recurse
sub_rule.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
}
// mark start of alternate def
sub_rule.push_back({WHISPER_GRETYPE_ALT, 0});
if (*pos == '+') {
// add preceding symbol as alternate only for '+' (otherwise empty)
sub_rule.insert(
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
}
sub_rule.push_back({WHISPER_GRETYPE_END, 0});
add_rule(state, sub_rule_id, sub_rule);
// in original rule, replace previous symbol with reference to generated rule
out_elements.resize(last_sym_start);
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
pos = parse_space(pos + 1, is_nested);
} else {
break;
}
}
return pos;
}
const char * parse_alternates(
parse_state & state,
const char * src,
const std::string & rule_name,
uint32_t rule_id,
bool is_nested) {
std::vector<whisper_grammar_element> rule;
const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
while (*pos == '|') {
rule.push_back({WHISPER_GRETYPE_ALT, 0});
pos = parse_space(pos + 1, true);
pos = parse_sequence(state, pos, rule_name, rule, is_nested);
}
rule.push_back({WHISPER_GRETYPE_END, 0});
add_rule(state, rule_id, rule);
return pos;
}
const char * parse_rule(parse_state & state, const char * src) {
const char * name_end = parse_name(src);
const char * pos = parse_space(name_end, false);
size_t name_len = name_end - src;
uint32_t rule_id = get_symbol_id(state, src, name_len);
const std::string name(src, name_len);
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
throw std::runtime_error(std::string("expecting ::= at ") + pos);
}
pos = parse_space(pos + 3, true);
pos = parse_alternates(state, pos, name, rule_id, false);
if (*pos == '\r') {
pos += pos[1] == '\n' ? 2 : 1;
} else if (*pos == '\n') {
pos++;
} else if (*pos) {
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
}
return parse_space(pos, true);
}
parse_state parse(const char * src) {
try {
parse_state state;
const char * pos = parse_space(src, true);
while (*pos) {
pos = parse_rule(state, pos);
}
return state;
} catch (const std::exception & err) {
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
return parse_state();
}
}
void print_grammar_char(FILE * file, uint32_t c) {
if (0x20 <= c && c <= 0x7f) {
fprintf(file, "%c", static_cast<char>(c));
} else {
// cop out of encoding UTF-8
fprintf(file, "<U+%04X>", c);
}
}
bool is_char_element(whisper_grammar_element elem) {
switch (elem.type) {
case WHISPER_GRETYPE_CHAR: return true;
case WHISPER_GRETYPE_CHAR_NOT: return true;
case WHISPER_GRETYPE_CHAR_ALT: return true;
case WHISPER_GRETYPE_CHAR_RNG_UPPER: return true;
default: return false;
}
}
void print_rule_binary(FILE * file, const std::vector<whisper_grammar_element> & rule) {
for (auto elem : rule) {
switch (elem.type) {
case WHISPER_GRETYPE_END: fprintf(file, "END"); break;
case WHISPER_GRETYPE_ALT: fprintf(file, "ALT"); break;
case WHISPER_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
case WHISPER_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
case WHISPER_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
case WHISPER_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
case WHISPER_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
}
switch (elem.type) {
case WHISPER_GRETYPE_END:
case WHISPER_GRETYPE_ALT:
case WHISPER_GRETYPE_RULE_REF:
fprintf(file, "(%u) ", elem.value);
break;
case WHISPER_GRETYPE_CHAR:
case WHISPER_GRETYPE_CHAR_NOT:
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
case WHISPER_GRETYPE_CHAR_ALT:
fprintf(file, "(\"");
print_grammar_char(file, elem.value);
fprintf(file, "\") ");
break;
}
}
fprintf(file, "\n");
}
void print_rule(
FILE * file,
uint32_t rule_id,
const std::vector<whisper_grammar_element> & rule,
const std::map<uint32_t, std::string> & symbol_id_names) {
if (rule.empty() || rule.back().type != WHISPER_GRETYPE_END) {
throw std::runtime_error(
"malformed rule, does not end with WHISPER_GRETYPE_END: " + std::to_string(rule_id));
}
fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
whisper_grammar_element elem = rule[i];
switch (elem.type) {
case WHISPER_GRETYPE_END:
throw std::runtime_error(
"unexpected end of rule: " + std::to_string(rule_id) + "," +
std::to_string(i));
case WHISPER_GRETYPE_ALT:
fprintf(file, "| ");
break;
case WHISPER_GRETYPE_RULE_REF:
fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
break;
case WHISPER_GRETYPE_CHAR:
fprintf(file, "[");
print_grammar_char(file, elem.value);
break;
case WHISPER_GRETYPE_CHAR_NOT:
fprintf(file, "[^");
print_grammar_char(file, elem.value);
break;
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
if (i == 0 || !is_char_element(rule[i - 1])) {
throw std::runtime_error(
"WHISPER_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
std::to_string(rule_id) + "," + std::to_string(i));
}
fprintf(file, "-");
print_grammar_char(file, elem.value);
break;
case WHISPER_GRETYPE_CHAR_ALT:
if (i == 0 || !is_char_element(rule[i - 1])) {
throw std::runtime_error(
"WHISPER_GRETYPE_CHAR_ALT without preceding char: " +
std::to_string(rule_id) + "," + std::to_string(i));
}
print_grammar_char(file, elem.value);
break;
}
if (is_char_element(elem)) {
switch (rule[i + 1].type) {
case WHISPER_GRETYPE_CHAR_ALT:
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
break;
default:
fprintf(file, "] ");
}
}
}
fprintf(file, "\n");
}
void print_grammar(FILE * file, const parse_state & state) {
try {
std::map<uint32_t, std::string> symbol_id_names;
for (auto kv : state.symbol_ids) {
symbol_id_names[kv.second] = kv.first;
}
for (size_t i = 0, end = state.rules.size(); i < end; i++) {
// fprintf(file, "%zu: ", i);
// print_rule_binary(file, state.rules[i]);
print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
// fprintf(file, "\n");
}
} catch (const std::exception & err) {
fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
}
}
std::vector<const whisper_grammar_element *> parse_state::c_rules() const{
std::vector<const whisper_grammar_element *> ret;
for (const auto & rule : rules) {
ret.push_back(rule.data());
}
return ret;
}
}

View File

@ -1,29 +0,0 @@
// Implements a parser for an extended Backus-Naur form (BNF), producing the
// binary context-free grammar format specified by whisper.h. Supports character
// ranges, grouping, and repetition operators. As an example, a grammar for
// arithmetic might look like:
//
// root ::= expr
// expr ::= term ([-+*/] term)*
// term ::= num | "(" space expr ")" space
// num ::= [0-9]+ space
// space ::= [ \t\n]*
#pragma once
#include "whisper.h"
#include <vector>
#include <map>
#include <cstdint>
#include <string>
namespace grammar_parser {
struct parse_state {
std::map<std::string, uint32_t> symbol_ids;
std::vector<std::vector<whisper_grammar_element>> rules;
std::vector<const whisper_grammar_element *> c_rules() const;
};
parse_state parse(const char * src);
void print_grammar(FILE * file, const parse_state & state);
}

View File

@ -48,7 +48,7 @@ if [ -n "$3" ]; then
fi
# Whisper models
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large-v3" )
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large" )
# list available models
function list_models {

View File

@ -30,7 +30,6 @@ struct whisper_params {
bool translate = false;
bool print_special = false;
bool print_energy = false;
bool use_gpu = true;
std::string language = "en";
std::string model = "models/ggml-base.en.bin";
@ -73,7 +72,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else {
@ -104,7 +102,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, "\n");
@ -435,9 +432,7 @@ int main(int argc, char ** argv) {
}
// whisper init
struct whisper_context_params cparams;
cparams.use_gpu = params.use_gpu;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
// init audio
audio_async audio(30*1000);

View File

@ -62,8 +62,8 @@ struct whisper_params {
int32_t progress_step = 5;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
int32_t best_of = 2;
int32_t beam_size = -1;
float word_thold = 0.01f;
float entropy_thold = 2.40f;
@ -90,7 +90,6 @@ struct whisper_params {
bool print_progress = false;
bool no_timestamps = false;
bool log_score = false;
bool use_gpu = true;
std::string language = "en";
std::string prompt;
@ -165,8 +164,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
@ -223,7 +221,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, "\n");
}
@ -880,10 +877,7 @@ int main(int argc, char ** argv) {
// whisper init
struct whisper_context_params cparams;
cparams.use_gpu = params.use_gpu;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
@ -925,9 +919,9 @@ int main(int argc, char ** argv) {
if (params.detect_language) {
params.language = "auto";
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n",
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors, params.beam_size, params.best_of,
params.n_threads, params.n_processors,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.tinydiarize ? "tdrz = 1, " : "",

View File

@ -1,6 +0,0 @@
set(TARGET server)
add_executable(${TARGET} server.cpp httplib.h json.hpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})

View File

@ -1,59 +0,0 @@
# whisper.cpp http server
Simple http server. WAV Files are passed to the inference model via http requests.
```
./server -h
usage: ./bin/server [options]
options:
-h, --help [default] show this help message and exit
-t N, --threads N [4 ] number of threads to use during computation
-p N, --processors N [1 ] number of processors to use during computation
-ot N, --offset-t N [0 ] time offset in milliseconds
-on N, --offset-n N [0 ] segment index offset
-d N, --duration N [0 ] duration of audio to process in milliseconds
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-sow, --split-on-word [false ] split on word rather than on token
-bo N, --best-of N [2 ] number of best candidates to keep
-bs N, --beam-size N [-1 ] beam size for beam search
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
-tr, --translate [false ] translate from source language to english
-di, --diarize [false ] stereo audio diarization
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
-nf, --no-fallback [false ] do not use temperature fallback while decoding
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [false ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
-dl, --detect-language [false ] exit after automatically detecting language
--prompt PROMPT [ ] initial prompt
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
--host HOST, [127.0.0.1] Hostname/ip-adress for the server
--port PORT, [8080 ] Port number for the server
```
## request examples
**/inference**
```
curl 127.0.0.1:8080/inference \
-H "Content-Type: multipart/form-data" \
-F file="@<file-path>" \
-F temperature="0.2" \
-F response-format="json"
```
**/load**
```
curl 127.0.0.1:8080/load \
-H "Content-Type: multipart/form-data" \
-F model="<path-to-model-file>"
```

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,699 +0,0 @@
#include "common.h"
#include "whisper.h"
#include "httplib.h"
#include "json.hpp"
#include <cmath>
#include <fstream>
#include <cstdio>
#include <string>
#include <thread>
#include <vector>
#include <cstring>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
using namespace httplib;
using json = nlohmann::json;
namespace {
// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
// Lowest is red, middle is yellow, highest is green.
const std::vector<std::string> k_colors = {
"\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m",
"\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m",
};
// output formats
const std::string json_format = "json";
const std::string text_format = "text";
const std::string srt_format = "srt";
const std::string vjson_format = "verbose_json";
const std::string vtt_format = "vtt";
struct server_params
{
std::string hostname = "127.0.0.1";
std::string public_path = "examples/server/public";
int32_t port = 8080;
int32_t read_timeout = 600;
int32_t write_timeout = 600;
};
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t duration_ms = 0;
int32_t progress_step = 5;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = 2;
int32_t beam_size = -1;
float word_thold = 0.01f;
float entropy_thold = 2.40f;
float logprob_thold = -1.00f;
float userdef_temp = 0.20f;
bool speed_up = false;
bool debug_mode = false;
bool translate = false;
bool detect_language = false;
bool diarize = false;
bool tinydiarize = false;
bool split_on_word = false;
bool no_fallback = false;
bool print_special = false;
bool print_colors = false;
bool print_progress = false;
bool no_timestamps = false;
bool use_gpu = true;
std::string language = "en";
std::string prompt = "";
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
std::string model = "models/ggml-base.en.bin";
std::string response_format = json_format;
// [TDRZ] speaker turn string
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
std::string openvino_encode_device = "CPU";
};
// 500 -> 00:05.000
// 6000 -> 01:00.000
std::string to_timestamp(int64_t t, bool comma = false) {
int64_t msec = t * 10;
int64_t hr = msec / (1000 * 60 * 60);
msec = msec - hr * (1000 * 60 * 60);
int64_t min = msec / (1000 * 60);
msec = msec - min * (1000 * 60);
int64_t sec = msec / 1000;
msec = msec - sec * 1000;
char buf[32];
snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
return std::string(buf);
}
int timestamp_to_sample(int64_t t, int n_samples) {
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
}
bool is_file_exist(const char *fileName)
{
std::ifstream infile(fileName);
return infile.good();
}
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params,
const server_params& sparams) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options] \n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
// fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
// server params
fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str());
fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port);
fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str());
fprintf(stderr, "\n");
}
bool whisper_params_parse(int argc, char ** argv, whisper_params & params, server_params & sparams) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params, sparams);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
// else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
// server params
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
else if ( arg == "--host") { sparams.hostname = argv[++i]; }
else if ( arg == "--public") { sparams.public_path = argv[++i]; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params, sparams);
exit(0);
}
}
return true;
}
struct whisper_print_user_data {
const whisper_params * params;
const std::vector<std::vector<float>> * pcmf32s;
int progress_prev;
};
std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
std::string speaker = "";
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1*energy1) {
speaker = "0";
} else if (energy1 > 1.1*energy0) {
speaker = "1";
} else {
speaker = "?";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str());
if (!id_only) {
speaker.insert(0, "(speaker ");
speaker.append(")");
}
return speaker;
}
void whisper_print_progress_callback(struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) {
int progress_step = ((whisper_print_user_data *) user_data)->params->progress_step;
int * progress_prev = &(((whisper_print_user_data *) user_data)->progress_prev);
if (progress >= *progress_prev + progress_step) {
*progress_prev += progress_step;
fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress);
}
}
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
const int n_segments = whisper_full_n_segments(ctx);
std::string speaker = "";
int64_t t0 = 0;
int64_t t1 = 0;
// print the last n_new segments
const int s0 = n_segments - n_new;
if (s0 == 0) {
printf("\n");
}
for (int i = s0; i < n_segments; i++) {
if (!params.no_timestamps || params.diarize) {
t0 = whisper_full_get_segment_t0(ctx, i);
t1 = whisper_full_get_segment_t1(ctx, i);
}
if (!params.no_timestamps) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
}
if (params.diarize && pcmf32s.size() == 2) {
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s%s", speaker.c_str(), text);
}
if (params.tinydiarize) {
if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
printf("%s", params.tdrz_speaker_turn.c_str());
}
}
// with timestamps or speakers: each segment on new line
if (!params.no_timestamps || params.diarize) {
printf("\n");
}
fflush(stdout);
}
}
std::string output_str(struct whisper_context * ctx, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::stringstream result;
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
result << speaker << text << "\n";
}
return result.str();
}
void get_req_parameters(const Request & req, whisper_params & params)
{
// user model configu.has_fileion
if (req.has_file("offset-t"))
{
params.offset_t_ms = std::stoi(req.get_file_value("offset-t").content);
}
if (req.has_file("offset-n"))
{
params.offset_n = std::stoi(req.get_file_value("offset-n").content);
}
if (req.has_file("duration"))
{
params.duration_ms = std::stoi(req.get_file_value("duration").content);
}
if (req.has_file("max-context"))
{
params.max_context = std::stoi(req.get_file_value("max-context").content);
}
if (req.has_file("prompt"))
{
params.prompt = req.get_file_value("prompt").content;
}
if (req.has_file("response-format"))
{
params.response_format = req.get_file_value("response-format").content;
}
if (req.has_file("temerature"))
{
params.userdef_temp = std::stof(req.get_file_value("temperature").content);
}
}
} // namespace
int main(int argc, char ** argv) {
whisper_params params;
server_params sparams;
std::mutex whisper_mutex;
if (whisper_params_parse(argc, argv, params, sparams) == false) {
whisper_print_usage(argc, argv, params, sparams);
return 1;
}
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params, sparams);
exit(0);
}
if (params.diarize && params.tinydiarize) {
fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n");
whisper_print_usage(argc, argv, params, sparams);
exit(0);
}
// whisper init
struct whisper_context_params cparams;
cparams.use_gpu = params.use_gpu;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
return 3;
}
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
Server svr;
std::string const default_content = "<html>hello</html>";
// this is only called if no index.html is found in the public --path
svr.Get("/", [&default_content](const Request &, Response &res){
res.set_content(default_content, "text/html");
return false;
});
svr.Post("/inference", [&](const Request &req, Response &res){
// aquire whisper model mutex lock
whisper_mutex.lock();
// first check user requested fields of the request
if (!req.has_file("file"))
{
fprintf(stderr, "error: no 'file' field in the request\n");
const std::string error_resp = "{\"error\":\"no 'file' field in the request\"}";
res.set_content(error_resp, "application/json");
whisper_mutex.unlock();
return;
}
auto audio_file = req.get_file_value("file");
// check non-required fields
get_req_parameters(req, params);
std::string filename{audio_file.filename};
printf("Received request: %s\n", filename.c_str());
// audio arrays
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// write file to temporary file
std::ofstream temp_file{filename, std::ios::binary};
temp_file << audio_file.content;
// read wav content into pcmf32
if (!::read_wav(filename, pcmf32, pcmf32s, params.diarize)) {
fprintf(stderr, "error: failed to read WAV file '%s'\n", filename.c_str());
const std::string error_resp = "{\"error\":\"failed to read WAV file\"}";
res.set_content(error_resp, "application/json");
whisper_mutex.unlock();
return;
}
// remove temp file
std::remove(filename.c_str());
printf("Successfully loaded %s\n", filename.c_str());
// print system information
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
}
// print some info about the processing
{
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
if (params.detect_language) {
params.language = "auto";
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
__func__, filename.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.tinydiarize ? "tdrz = 1, " : "",
params.no_timestamps ? 0 : 1);
fprintf(stderr, "\n");
}
// run the inference
{
printf("Running whisper.cpp inference on %s\n", filename.c_str());
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special = params.print_special;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.detect_language = params.detect_language;
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms;
wparams.thold_pt = params.word_thold;
wparams.split_on_word = params.split_on_word;
wparams.speed_up = params.speed_up;
wparams.debug_mode = params.debug_mode;
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
wparams.initial_prompt = params.prompt.c_str();
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
wparams.temperature_inc = params.userdef_temp;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &user_data;
}
if (wparams.print_progress) {
wparams.progress_callback = whisper_print_progress_callback;
wparams.progress_callback_user_data = &user_data;
}
// examples for abort mechanism
// in examples below, we do not abort the processing, but we could if the flag is set to true
// the callback is called before every encoder run - if it returns false, the processing is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
wparams.encoder_begin_callback_user_data = &is_aborted;
}
// the callback is called before every computation - if it returns true, the computation is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.abort_callback = [](void * user_data) {
bool is_aborted = *(bool*)user_data;
return is_aborted;
};
wparams.abort_callback_user_data = &is_aborted;
}
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
const std::string error_resp = "{\"error\":\"failed to process audio\"}";
res.set_content(error_resp, "application/json");
whisper_mutex.unlock();
return;
}
}
// return results to user
if (params.response_format == text_format)
{
std::string results = output_str(ctx, params, pcmf32s);
res.set_content(results.c_str(), "text/html");
}
// TODO add more output formats
else
{
std::string results = output_str(ctx, params, pcmf32s);
json jres = json{
{"text", results}
};
res.set_content(jres.dump(-1, ' ', false, json::error_handler_t::replace),
"application/json");
}
// return whisper model mutex lock
whisper_mutex.unlock();
});
svr.Post("/load", [&](const Request &req, Response &res){
whisper_mutex.lock();
if (!req.has_file("model"))
{
fprintf(stderr, "error: no 'model' field in the request\n");
const std::string error_resp = "{\"error\":\"no 'model' field in the request\"}";
res.set_content(error_resp, "application/json");
whisper_mutex.unlock();
return;
}
std::string model = req.get_file_value("model").content;
if (!is_file_exist(model.c_str()))
{
fprintf(stderr, "error: 'model': %s not found!\n", model.c_str());
const std::string error_resp = "{\"error\":\"model not found!\"}";
res.set_content(error_resp, "application/json");
whisper_mutex.unlock();
return;
}
// clean up
whisper_free(ctx);
// whisper init
ctx = whisper_init_from_file_with_params(model.c_str(), cparams);
// TODO perhaps load prior model here instead of exit
if (ctx == nullptr) {
fprintf(stderr, "error: model init failed, no model loaded must exit\n");
exit(1);
}
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
const std::string success = "Load was successful!";
res.set_content(success, "application/text");
// check if the model is in the file system
whisper_mutex.unlock();
});
svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
const char fmt[] = "500 Internal Server Error\n%s";
char buf[BUFSIZ];
try {
std::rethrow_exception(std::move(ep));
} catch (std::exception &e) {
snprintf(buf, sizeof(buf), fmt, e.what());
} catch (...) {
snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
}
res.set_content(buf, "text/plain");
res.status = 500;
});
svr.set_error_handler([](const Request &, Response &res) {
if (res.status == 400) {
res.set_content("Invalid request", "text/plain");
} else if (res.status != 500) {
res.set_content("File Not Found", "text/plain");
res.status = 404;
}
});
// set timeouts and change hostname and port
svr.set_read_timeout(sparams.read_timeout);
svr.set_write_timeout(sparams.write_timeout);
if (!svr.bind_to_port(sparams.hostname, sparams.port))
{
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n",
sparams.hostname.c_str(), sparams.port);
return 1;
}
// Set the base directory for serving static files
svr.set_base_dir(sparams.public_path);
// to make it ctrl+clickable:
printf("\nwhisper server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
if (!svr.listen_after_bind())
{
return 1;
}
whisper_print_timings(ctx);
whisper_free(ctx);
return 0;
}

View File

@ -132,7 +132,7 @@ EMSCRIPTEN_BINDINGS(stream) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
g_contexts[i] = whisper_init_from_file(path_model.c_str());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {

View File

@ -48,12 +48,11 @@ struct whisper_params {
bool no_context = true;
bool no_timestamps = false;
bool tinydiarize = false;
bool save_audio = false; // save audio to wav file
bool use_gpu = true;
std::string language = "en";
std::string model = "models/ggml-base.en.bin";
std::string fname_out;
bool save_audio = false; // save audio to wav file
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -66,26 +65,25 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if ( arg == "--step") { params.step_ms = std::stoi(argv[++i]); }
else if ( arg == "--length") { params.length_ms = std::stoi(argv[++i]); }
else if ( arg == "--keep") { params.keep_ms = std::stoi(argv[++i]); }
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if ( arg == "--step") { params.step_ms = std::stoi(argv[++i]); }
else if ( arg == "--length") { params.length_ms = std::stoi(argv[++i]); }
else if ( arg == "--keep") { params.keep_ms = std::stoi(argv[++i]); }
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
@ -120,9 +118,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
fprintf(stderr, " -sa, --save-audio [%-7s] save the recorded audio to a file\n", params.save_audio ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU inference\n", params.use_gpu ? "false" : "true");
fprintf(stderr, "\n");
}
@ -166,10 +163,7 @@ int main(int argc, char ** argv) {
exit(0);
}
struct whisper_context_params cparams;
cparams.use_gpu = params.use_gpu;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
std::vector<float> pcmf32 (n_samples_30s, 0.0f);
std::vector<float> pcmf32_old;
@ -430,4 +424,4 @@ int main(int argc, char ** argv) {
whisper_free(ctx);
return 0;
}
}

View File

@ -14,8 +14,6 @@ if (WHISPER_SDL2)
../common-sdl.cpp
../../ggml.c
../../ggml-alloc.c
../../ggml-backend.c
../../ggml-quants.c
../../whisper.cpp)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)

View File

@ -53,7 +53,6 @@ struct whisper_params {
int32_t capture_id = -1;
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
int32_t n_gpu_layers = 999;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
@ -64,7 +63,6 @@ struct whisper_params {
bool print_energy = false;
bool no_timestamps = true;
bool verbose_prompt = false;
bool use_gpu = true;
std::string person = "Georgi";
std::string language = "en";
@ -86,27 +84,25 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-vms" || arg == "--voice-ms") { params.voice_ms = std::stoi(argv[++i]); }
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-vp" || arg == "--verbose-prompt") { params.verbose_prompt = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
else if (arg == "--session") { params.path_session = argv[++i];}
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; }
else if (arg == "--prompt-file") {
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-vms" || arg == "--voice-ms") { params.voice_ms = std::stoi(argv[++i]); }
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "--verbose-prompt") { params.verbose_prompt = true; }
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
else if (arg == "--session") { params.path_session = argv[++i];}
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; }
else if (arg == "--prompt-file") {
std::ifstream file(argv[++i]);
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
if (params.prompt.back() == '\n') {
@ -114,7 +110,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
}
}
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
@ -130,29 +125,27 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -vms N, --voice-ms N [%-7d] voice duration in milliseconds\n", params.voice_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7d] number of layers to store in VRAM\n", params.n_gpu_layers);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -vp, --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
fprintf(stderr, " -ml FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str());
fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", "");
fprintf(stderr, " --session FNAME file to cache model state in (may be large!) (default: none)\n");
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -vms N, --voice-ms N [%-7d] voice duration in milliseconds\n", params.voice_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
fprintf(stderr, " -ml FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str());
fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", "");
fprintf(stderr, " --session FNAME file to cache model state in (may be large!) (default: none)\n");
fprintf(stderr, " --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false");
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, "\n");
}
@ -251,7 +244,7 @@ int main(int argc, char ** argv) {
return 1;
}
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
@ -259,21 +252,13 @@ int main(int argc, char ** argv) {
// whisper init
struct whisper_context_params cparams;
cparams.use_gpu = params.use_gpu;
struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
// llama init
llama_backend_init(true);
auto lmparams = llama_model_default_params();
if (!params.use_gpu) {
lmparams.n_gpu_layers = 0;
} else {
lmparams.n_gpu_layers = params.n_gpu_layers;
}
struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lmparams);
@ -686,8 +671,8 @@ int main(int argc, char ** argv) {
}
}
text_to_speak = ::replace(text_to_speak, "'", "'\"'\"'");
int ret = system((params.speak + " " + std::to_string(voice_id) + " '" + text_to_speak + "'").c_str());
text_to_speak = ::replace(text_to_speak, "\"", "");
int ret = system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
if (ret != 0) {
fprintf(stderr, "%s: failed to speak\n", __func__);
}

View File

@ -271,7 +271,7 @@ EMSCRIPTEN_BINDINGS(talk) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
g_contexts[i] = whisper_init_from_file(path_model.c_str());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {

View File

@ -121,13 +121,13 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
return false;
}
char word[129];
std::string word;
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
fin.read((char *) &len, sizeof(len));
word[len] = '\0';
fin.read((char *) word, len);
word.resize(len);
fin.read((char *) word.data(), len);
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;

View File

@ -31,7 +31,6 @@ struct whisper_params {
bool print_special = false;
bool print_energy = false;
bool no_timestamps = true;
bool use_gpu = true;
std::string person = "Santa";
std::string language = "en";
@ -62,7 +61,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
@ -96,7 +94,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
@ -184,10 +181,8 @@ int main(int argc, char ** argv) {
}
// whisper init
struct whisper_context_params cparams;
cparams.use_gpu = params.use_gpu;
struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
// gpt init

View File

@ -21,7 +21,7 @@ help()
echo "Usage: ./twitch.sh -s [step] -m [model] -t [threads] [url]"
echo "options:"
echo "-s Step in seconds (default is $step)."
echo "-m Choose model, options are: 'tiny.en' 'tiny' 'base.en' 'base' 'small.en' 'small' 'medium.en' 'medium' 'large-v1' 'large-v2' 'large-v3' (default is '$model')."
echo "-m Choose model, options are: 'tiny.en' 'tiny' 'base.en' 'base' 'small.en' 'small' 'medium.en' 'medium' 'large-v1' 'large' (default is '$model')."
echo "-t Number of threads to use."
echo "-h Print this help page."
echo

View File

@ -1,9 +0,0 @@
set(CMAKE_CXX_STANDARD 11)
add_subdirectory(libwchess)
if (EMSCRIPTEN)
add_subdirectory(wchess.wasm)
else()
add_subdirectory(wchess.cmd)
endif()

View File

@ -1,19 +0,0 @@
add_library(libwchess
WChess.cpp
WChess.h
Chessboard.cpp
Chessboard.h
)
target_link_libraries(libwchess
PUBLIC
whisper
common
)
target_include_directories(libwchess
PUBLIC
"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>"
)
add_executable(test-chessboard test-chessboard.cpp Chessboard.cpp)

View File

@ -1,714 +0,0 @@
#include "Chessboard.h"
#include <vector>
#include <algorithm>
#include <cstring>
#include <set>
namespace {
// remove std::string_view, c++17 -> c++11
constexpr std::array<const char*, 64> positions = {
"a1", "b1", "c1", "d1", "e1", "f1", "g1", "h1",
"a2", "b2", "c2", "d2", "e2", "f2", "g2", "h2",
"a3", "b3", "c3", "d3", "e3", "f3", "g3", "h3",
"a4", "b4", "c4", "d4", "e4", "f4", "g4", "h4",
"a5", "b5", "c5", "d5", "e5", "f5", "g5", "h5",
"a6", "b6", "c6", "d6", "e6", "f6", "g6", "h6",
"a7", "b7", "c7", "d7", "e7", "f7", "g7", "h7",
"a8", "b8", "c8", "d8", "e8", "f8", "g8", "h8",
};
constexpr int INVALID_POS = positions.size();
constexpr int R = 0; // rank index
constexpr int F = 1; // file index
#define POS ((c[F] - '1') * 8 + (c[R] - 'a'))
constexpr int operator ""_P(const char * c, size_t size) {
return size < 2 || POS < 0 || POS > INVALID_POS ? INVALID_POS : POS;
}
#undef POS
struct sview {
const char * ptr = nullptr;
size_t size = 0;
sview() = default;
sview(const char * p, size_t s) : ptr(p), size(s) {}
sview(const std::string& s) : ptr(s.data()), size(s.size()) {}
size_t find(char del, size_t pos) {
while (pos < size && ptr[pos] != del) ++pos;
return pos < size ? pos : std::string::npos;
}
};
std::vector<sview> split(sview str, char del) {
std::vector<sview> res;
size_t cur = 0;
size_t last = 0;
while (cur != std::string::npos) {
if (str.ptr[last] == ' ') {
++last;
continue;
}
cur = str.find(del, last);
size_t len = cur == std::string::npos ? str.size - last : cur - last;
res.emplace_back(str.ptr + last, len);
last = cur + 1;
}
return res;
}
size_t strToPos(sview str) {
return operator ""_P(str.ptr, str.size);
}
constexpr std::array<const char*, 6> pieceNames = {
"pawn", "knight", "bishop", "rook", "queen", "king",
};
int strToType(sview str) {
auto it = std::find_if(pieceNames.begin(), pieceNames.end(), [str] (const char* name) { return strncmp(name, str.ptr, str.size) == 0; });
return it != pieceNames.end() ? int(it - pieceNames.begin()) : pieceNames.size();
}
}
Chessboard::Chessboard()
: blackPieces {{
{Piece::Pawn, Piece::Black, "a7"_P },
{Piece::Pawn, Piece::Black, "b7"_P },
{Piece::Pawn, Piece::Black, "c7"_P },
{Piece::Pawn, Piece::Black, "d7"_P },
{Piece::Pawn, Piece::Black, "e7"_P },
{Piece::Pawn, Piece::Black, "f7"_P },
{Piece::Pawn, Piece::Black, "g7"_P },
{Piece::Pawn, Piece::Black, "h7"_P },
{Piece::Rook, Piece::Black, "a8"_P },
{Piece::Knight, Piece::Black, "b8"_P },
{Piece::Bishop, Piece::Black, "c8"_P },
{Piece::Queen, Piece::Black, "d8"_P },
{Piece::King, Piece::Black, "e8"_P },
{Piece::Bishop, Piece::Black, "f8"_P },
{Piece::Knight, Piece::Black, "g8"_P },
{Piece::Rook, Piece::Black, "h8"_P },
}}
, whitePieces {{
{Piece::Pawn, Piece::White, "a2"_P },
{Piece::Pawn, Piece::White, "b2"_P },
{Piece::Pawn, Piece::White, "c2"_P },
{Piece::Pawn, Piece::White, "d2"_P },
{Piece::Pawn, Piece::White, "e2"_P },
{Piece::Pawn, Piece::White, "f2"_P },
{Piece::Pawn, Piece::White, "g2"_P },
{Piece::Pawn, Piece::White, "h2"_P },
{Piece::Rook, Piece::White, "a1"_P },
{Piece::Knight, Piece::White, "b1"_P },
{Piece::Bishop, Piece::White, "c1"_P },
{Piece::Queen, Piece::White, "d1"_P },
{Piece::King, Piece::White, "e1"_P },
{Piece::Bishop, Piece::White, "f1"_P },
{Piece::Knight, Piece::White, "g1"_P },
{Piece::Rook, Piece::White, "h1"_P },
}}
, board {{
&whitePieces[ 8], &whitePieces[ 9], &whitePieces[10], &whitePieces[11], &whitePieces[12], &whitePieces[13], &whitePieces[14], &whitePieces[15],
&whitePieces[ 0], &whitePieces[ 1], &whitePieces[ 2], &whitePieces[ 3], &whitePieces[ 4], &whitePieces[ 5], &whitePieces[ 6], &whitePieces[ 7],
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
&blackPieces[ 0], &blackPieces[ 1], &blackPieces[ 2], &blackPieces[ 3], &blackPieces[ 4], &blackPieces[ 5], &blackPieces[ 6], &blackPieces[ 7],
&blackPieces[ 8], &blackPieces[ 9], &blackPieces[10], &blackPieces[11], &blackPieces[12], &blackPieces[13], &blackPieces[14], &blackPieces[15],
}}
, whiteMoves {
{"b1"_P, "a3"_P}, {"b1"_P, "c3"_P},
{"g1"_P, "f3"_P}, {"g1"_P, "h3"_P},
{"a2"_P, "a3"_P}, {"a2"_P, "a4"_P},
{"b2"_P, "b3"_P}, {"b2"_P, "b4"_P},
{"c2"_P, "c3"_P}, {"c2"_P, "c4"_P},
{"d2"_P, "d3"_P}, {"d2"_P, "d4"_P},
{"e2"_P, "e3"_P}, {"e2"_P, "e4"_P},
{"f2"_P, "f3"_P}, {"f2"_P, "f4"_P},
{"g2"_P, "g3"_P}, {"g2"_P, "g4"_P},
{"h2"_P, "h3"_P}, {"h2"_P, "h4"_P},
}
, blackMoves {
{"a7"_P, "a5"_P}, {"a7"_P, "a6"_P},
{"b7"_P, "b5"_P}, {"b7"_P, "b6"_P},
{"c7"_P, "c5"_P}, {"c7"_P, "c6"_P},
{"d7"_P, "d5"_P}, {"d7"_P, "d6"_P},
{"e7"_P, "e5"_P}, {"e7"_P, "e6"_P},
{"f7"_P, "f5"_P}, {"f7"_P, "f6"_P},
{"g7"_P, "g5"_P}, {"g7"_P, "g6"_P},
{"h7"_P, "h5"_P}, {"h7"_P, "h6"_P},
{"b8"_P, "a6"_P}, {"b8"_P, "c6"_P},
{"g8"_P, "f6"_P}, {"g8"_P, "h6"_P},
}
{
static_assert(pieceNames.size() == Chessboard::Piece::Taken, "Mismatch between piece names and types");
std::sort(whiteMoves.begin(), whiteMoves.end());
std::sort(blackMoves.begin(), blackMoves.end());
}
std::string Chessboard::getRules(const std::string& prompt) const {
// leading space is very important!
std::string result =
"\n"
"# leading space is very important!\n"
"\n";
if (prompt.empty()) {
result += "move ::= \" \" ((piece | frompos) \" \" \"to \"?)? topos\n";
//result += "move ::= \" \" frompos \" \" \"to \"? topos\n";
}
else {
// result += "move ::= prompt \" \" ((piece | frompos) \" \" \"to \"?)? topos\n"
result += "move ::= prompt \" \" frompos \" \" \"to \"? topos\n"
"\n"
"prompt ::= \" " + prompt + "\"\n";
}
std::set<std::string> pieces;
std::set<std::string> from_pos;
std::set<std::string> to_pos;
auto& allowed_moves = m_moveCounter % 2 ? blackMoves : whiteMoves;
for (auto& m : allowed_moves) {
if (board[m.first]->type != Piece::Taken) pieces.insert(pieceNames[board[m.first]->type]);
from_pos.insert(positions[m.first]);
to_pos.insert(positions[m.second]);
}
if (!pieces.empty()) {
result += "piece ::= (";
for (auto& p : pieces) result += " \"" + p + "\" |";
result.pop_back();
result += ")\n\n";
}
if (!from_pos.empty()) {
result += "frompos ::= (";
for (auto& p : from_pos) result += " \"" + p + "\" |";
result.pop_back();
result += ")\n";
}
if (!to_pos.empty()) {
result += "topos ::= (";
for (auto& p : to_pos) result += " \"" + p + "\" |";
result.pop_back();
result += ")\n";
}
return result;
}
std::string Chessboard::stringifyBoard() {
static constexpr std::array<char, 6> blackShort = {
'p', 'n', 'b', 'r', 'q', 'k',
};
static constexpr std::array<char, 6> whiteShort = {
'P', 'N', 'B', 'R', 'Q', 'K',
};
std::string result;
result.reserve(16 + 2 * 64 + 16);
for (char rank = 'a'; rank <= 'h'; ++rank) {
result.push_back(rank);
result.push_back(' ');
}
result.back() = '\n';
for (int i = 7; i >= 0; --i) {
for (int j = 0; j < 8; ++j) {
auto p = board[i * 8 + j];
if (p) result.push_back(p->color == Piece::White ? whiteShort[p->type] : blackShort[p->type]);
else result.push_back((i + j) % 2 ? '.' : '*');
result.push_back(' ');
}
result.push_back('0' + i + 1);
result.push_back('\n');
}
return result;
}
std::string Chessboard::process(const std::string& command) {
auto color = Piece::Colors(m_moveCounter % 2);
fprintf(stdout, "%s: Command to %s: '%s%.*s%s'\n", __func__, (color ? "Black" : "White"), "\033[1m", int(command.size()), command.data(), "\033[0m");
if (command.empty()) return "";
auto tokens = split(command, ' ');
for (auto& t : tokens) fprintf(stdout, "%s: Token %.*s\n", __func__, int(t.size), t.ptr);
auto pos_from = INVALID_POS;
auto type = Piece::Types::Taken;
auto pos_to = INVALID_POS;
if (tokens.size() == 1) {
type = Piece::Types::Pawn;
pos_to = strToPos(tokens.front());
}
else {
pos_from = strToPos(tokens.front());
if (pos_from == INVALID_POS) type = Piece::Types(strToType(tokens.front()));
pos_to = strToPos(tokens.back());
}
if (pos_to == INVALID_POS) return "";
if (pos_from == INVALID_POS) {
if (type == Piece::Types::Taken) return "";
auto& pieces = color ? blackPieces : whitePieces;
auto pieceIndex = 0u;
for (; pieceIndex < pieces.size(); ++pieceIndex) {
if (pieces[pieceIndex].type == type && validateMove(pieces[pieceIndex], pos_to)) break;
}
if (pieceIndex == pieces.size()) return "";
pos_from = pieces[pieceIndex].pos;
}
if (board[pos_from] == nullptr) return "";
if (board[pos_from]->color != color) return "";
Move m = {pos_from, pos_to};
auto& allowed_moves = color ? blackMoves : whiteMoves;
fprintf(stdout, "%s:allowed size %d :\n", __func__, int(allowed_moves.size()));
for (auto& m : allowed_moves) fprintf(stdout, " %s %s; ", positions[m.first], positions[m.second]);
fprintf(stdout, "\n");
if (!std::binary_search(allowed_moves.begin(), allowed_moves.end(), m)) return "";
move(m);
{
auto it = std::remove_if(allowed_moves.begin(), allowed_moves.end(), [m] (const Move& move) { return move.first == m.first; });
allowed_moves.erase(it, allowed_moves.end());
}
std::vector<Piece*> affected = { board[m.second] };
for (auto& p : whitePieces) {
if (&p == board[m.second]
|| validateMove(p, m.first)
|| validateMove(p, m.second)
|| std::binary_search(whiteMoves.begin(), whiteMoves.end(), Move(p.pos, m.second))
) {
auto it = std::remove_if(whiteMoves.begin(), whiteMoves.end(), [&p] (const Move& m) { return m.first == p.pos; });
whiteMoves.erase(it, whiteMoves.end());
affected.push_back(&p);
}
}
for (auto& p : blackPieces) {
if (&p == board[m.second]
|| validateMove(p, m.first)
|| validateMove(p, m.second)
|| std::binary_search(blackMoves.begin(), blackMoves.end(), Move(p.pos, m.second))
) {
auto it = std::remove_if(blackMoves.begin(), blackMoves.end(), [&p] (const Move& m) { return m.first == p.pos; });
blackMoves.erase(it, blackMoves.end());
affected.push_back(&p);
}
}
for (auto& p : affected) getValidMoves(*p, p->color ? blackMoves : whiteMoves);
std::sort(blackMoves.begin(), blackMoves.end());
std::sort(whiteMoves.begin(), whiteMoves.end());
std::string result = positions[m.first];
result += "-";
result += positions[m.second];
++m_moveCounter;
fprintf(stdout, "%s: Move '%s%s%s'\n", __func__, "\033[1m", result.data(), "\033[0m");
return result;
}
void Chessboard::getValidMoves(const Piece& piece, std::vector<Move>& result) {
std::string cur = positions[piece.pos];
switch (piece.type) {
case Piece::Pawn: {
std::string next = cur;
piece.color ? --next[F] : ++next[F]; // one down / up
std::string left = { char(next[R] - 1), next[F]};
auto pos = strToPos(left);
if (pos != INVALID_POS && board[pos] && board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
std::string right = { char(next[R] + 1), next[F]};
pos = strToPos(right);
if (pos != INVALID_POS && board[pos] && board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
pos = strToPos(next);
if (pos != INVALID_POS && !board[pos]) result.emplace_back(piece.pos, pos);
else break;
if (piece.color ? cur[F] != '7' : cur[F] != '2') break;
piece.color ? --next[F] : ++next[F]; // one down / up
pos = strToPos(next);
if (pos != INVALID_POS && !board[pos]) result.emplace_back(piece.pos, pos);
break;
}
case Piece::Knight: {
std::string next = cur;
--next[F]; --next[F]; --next[R];
auto pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
--next[F]; --next[F]; ++next[R];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
++next[F]; ++next[F]; --next[R];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
++next[F]; ++next[F]; ++next[R];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
--next[F]; --next[R]; --next[R];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
++next[F]; --next[R]; --next[R];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
--next[F]; ++next[R]; ++next[R];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
++next[F]; ++next[R]; ++next[R];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
break;
}
case Piece::Bishop: {
std::string next = cur;
while (true) {
--next[R]; --next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
--next[R]; ++next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
++next[R]; --next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
++next[R]; ++next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
break;
}
case Piece::Rook: {
std::string next = cur;
while (true) {
--next[R];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
++next[R];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
--next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
++next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
break;
}
case Piece::Queen: {
std::string next = cur;
while (true) {
--next[R]; --next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
--next[R]; ++next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
++next[R]; --next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
++next[R]; ++next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
--next[R];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
++next[R];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
--next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
next = cur;
while (true) {
++next[F];
auto pos = strToPos(next);
if (pos == INVALID_POS) break;
else if (board[pos]) {
if (board[pos]->color != piece.color) result.emplace_back(piece.pos, pos);
break;
}
result.emplace_back(piece.pos, pos);
}
break;
}
case Piece::King: {
std::string next = cur;
--next[R]; --next[F];
auto pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
--next[R]; ++next[F];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
++next[R]; --next[F];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
++next[R]; ++next[F];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
--next[R];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
++next[R];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
--next[F];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
next = cur;
++next[F];
pos = strToPos(next);
if (pos != INVALID_POS && !(board[pos] && board[pos]->color == piece.color)) result.emplace_back(piece.pos, pos);
break;
}
case Piece::Taken: break;
default: break;
}
}
bool Chessboard::validatePawnMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file) {
int direction = color == Piece::White ? 1 : -1;
bool two_ranks = color == Piece::White ? from_rank == 1 : from_rank == 6;
if (from_file == to_file) {
if (from_rank == to_rank - direction) return board[to_rank * 8 + to_file] == nullptr;
if (two_ranks && from_rank == to_rank - direction * 2) return board[(to_rank - direction) * 8 + to_file] == nullptr && board[to_rank * 8 + to_file] == nullptr;
}
else if (from_file + 1 == to_file || from_file - 1 == to_file) {
if (from_rank == to_rank - direction) return board[to_rank * 8 + to_file] != nullptr && board[to_rank * 8 + to_file]->color != color;
}
return false;
}
bool Chessboard::validateKnightMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file) {
int dr = std::abs(from_rank - to_rank);
int df = std::abs(from_file - to_file);
if ((dr == 2 && df == 1) || (dr == 1 && df == 2)) return board[to_rank * 8 + to_file] == nullptr || board[to_rank * 8 + to_file]->color != color;
return false;
}
bool Chessboard::validateBishopMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file) {
if (from_rank - from_file == to_rank - to_file) {
int direction = from_rank < to_rank ? 1 : -1;
from_rank += direction;
from_file += direction;
while (from_rank != to_rank) {
if (board[from_rank * 8 + from_file]) return false;
from_rank += direction;
from_file += direction;
}
return board[to_rank * 8 + to_file] == nullptr || board[to_rank * 8 + to_file]->color != color;
}
if (from_rank + from_file == to_rank + to_file) {
int direction = from_rank < to_rank ? 1 : -1;
from_rank += direction;
from_file -= direction;
while (from_rank != to_rank) {
if (board[from_rank * 8 + from_file]) return false;
from_rank += direction;
from_file -= direction;
}
return board[to_rank * 8 + to_file] == nullptr || board[to_rank * 8 + to_file]->color != color;
}
return false;
}
bool Chessboard::validateRookMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file) {
if (from_rank == to_rank) {
int direction = from_file < to_file ? 1 : -1;
from_file += direction;
while (from_file != to_file) {
if (board[from_rank * 8 + from_file]) return false;
from_file += direction;
}
return board[to_rank * 8 + to_file] == nullptr || board[to_rank * 8 + to_file]->color != color;
}
if (from_file == to_file) {
int direction = from_rank < to_rank ? 1 : -1;
from_rank += direction;
while (from_rank != to_rank) {
if (board[from_rank * 8 + from_file]) return false;
from_rank += direction;
}
return board[to_rank * 8 + to_file] == nullptr || board[to_rank * 8 + to_file]->color != color;
}
return false;
}
bool Chessboard::validateQueenMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file) {
if (validateBishopMove(color, from_rank, from_file, to_rank, to_file)) return true;
return validateRookMove(color, from_rank, from_file, to_rank, to_file);
}
bool Chessboard::validateKingMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file) {
if (std::abs(from_rank - to_rank) < 2 && std::abs(from_file - to_file) < 2) {
return board[to_rank * 8 + to_file] == nullptr || board[to_rank * 8 + to_file]->color != color;
}
return false;
}
bool Chessboard::validateMove(const Piece& piece, int pos) {
if (piece.type == Piece::Taken) return false;
if (piece.pos == pos) return false;
int i = piece.pos / 8;
int j = piece.pos - i * 8;
int ii = pos / 8;
int jj = pos - ii * 8;
switch (piece.type) {
case Piece::Pawn: return validatePawnMove(piece.color, i, j, ii, jj);
case Piece::Knight: return validateKnightMove(piece.color, i, j, ii, jj);
case Piece::Bishop: return validateBishopMove(piece.color, i, j, ii, jj);
case Piece::Rook: return validateRookMove(piece.color, i, j, ii, jj);
case Piece::Queen: return validateQueenMove(piece.color, i, j, ii, jj);
case Piece::King: return validateKingMove(piece.color, i, j, ii, jj);
default: break;
}
return false;
}
bool Chessboard::move(const Move& m) {
if (!board[m.first] || (board[m.second] && board[m.first]->color == board[m.second]->color)) return false;
if (board[m.second]) board[m.second]->type = Piece::Taken;
board[m.second] = board[m.first];
board[m.first] = nullptr;
board[m.second]->pos = m.second;
return true;
}

View File

@ -1,59 +0,0 @@
#pragma once
#include <string>
#include <array>
#include <vector>
class Chessboard {
public:
Chessboard();
std::string process(const std::string& t);
std::string stringifyBoard();
std::string getRules(const std::string & prompt) const;
using Move = std::pair<int, int>;
private:
bool move(const Move& move);
struct Piece {
enum Types {
Pawn,
Knight,
Bishop,
Rook,
Queen,
King,
Taken,
};
enum Colors {
White,
Black,
};
Types type;
Colors color;
int pos;
};
using PieceSet = std::array<Piece, 16>;
PieceSet blackPieces;
PieceSet whitePieces;
int m_moveCounter = 0;
using Board = std::array<Piece*, 64>;
Board board;
std::vector<Move> whiteMoves;
std::vector<Move> blackMoves;
bool validateMove(const Piece& piece, int pos);
void getValidMoves(const Piece& piece, std::vector<Move>& moves);
// just basic validation
// fixme: missing en passant, castling, promotion, etc.
bool validatePawnMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
bool validateKnightMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
bool validateBishopMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
bool validateRookMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
bool validateQueenMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
bool validateKingMove(Piece::Colors color, int from_rank, int from_file, int to_rank, int to_file);
};

View File

@ -1,220 +0,0 @@
#include "WChess.h"
#include "Chessboard.h"
#include "grammar-parser.h"
#include "common.h"
#include <thread>
WChess::WChess(whisper_context * ctx,
const whisper_full_params & wparams,
callbacks cb,
settings s)
: m_ctx(ctx)
, m_wparams(wparams)
, m_cb(cb)
, m_settings(s)
, m_board(new Chessboard())
{}
WChess::~WChess() = default;
void WChess::set_status(const std::string& msg) const {
if (m_cb.set_status) (*m_cb.set_status)(msg);
}
void WChess::set_moves(const std::string& moves) const {
if (m_cb.set_moves) (*m_cb.set_moves)(moves);
}
bool WChess::check_running() const {
if (m_cb.check_running) return (*m_cb.check_running)();
return false;
}
void WChess::clear_audio() const {
if (m_cb.clear_audio) (*m_cb.clear_audio)();
}
void WChess::get_audio(int ms, std::vector<float>& pcmf32) const {
if (m_cb.get_audio) (*m_cb.get_audio)(ms, pcmf32);
}
std::string WChess::stringify_board() const {
return m_board->stringifyBoard();
}
void WChess::run() {
set_status("loading data ...");
bool have_prompt = true;
bool ask_prompt = !have_prompt;
float logprob_min0 = 0.0f;
float logprob_min = 0.0f;
float logprob_sum0 = 0.0f;
float logprob_sum = 0.0f;
int n_tokens0 = 0;
int n_tokens = 0;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
const std::string k_prompt = have_prompt ? "" : "checkmate";
while (check_running()) {
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (ask_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
fprintf(stdout, "\n");
{
char txt[1024];
snprintf(txt, sizeof(txt), "Say the following phrase: '%s'", k_prompt.c_str());
set_status(txt);
}
ask_prompt = false;
}
int64_t t_ms = 0;
{
get_audio(m_settings.vad_ms, pcmf32_cur);
if (!pcmf32_cur.empty()) {
fprintf(stdout, "%s: Processing ...\n", __func__);
set_status("Processing ...");
if (!have_prompt) {
const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
const float sim = similarity(txt, k_prompt);
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
ask_prompt = true;
} else {
fprintf(stdout, "\n");
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
fprintf(stdout, "\n");
{
char txt[1024];
snprintf(txt, sizeof(txt), "Success! Waiting for voice commands ...");
set_status(txt);
}
// save the audio for the prompt
pcmf32_prompt = pcmf32_cur;
have_prompt = true;
}
} else {
if (!pcmf32_prompt.empty()) pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
static const size_t MIN_SIZE = 1.2 * WHISPER_SAMPLE_RATE;
if (MIN_SIZE > pcmf32_cur.size()) pcmf32_cur.insert(pcmf32_cur.begin(), MIN_SIZE - pcmf32_cur.size(), 0.0f);
std::string rules = m_board->getRules(k_prompt);
fprintf(stdout, "%s: grammar rules:\n'%s'\n", __func__, rules.c_str());
auto grammar_parsed = grammar_parser::parse(rules.c_str());
auto grammar_rules = grammar_parsed.c_rules();
m_wparams.grammar_rules = grammar_rules.data();
m_wparams.n_grammar_rules = grammar_rules.size();
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("move");
auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
const float p = 100.0f * std::exp(logprob_min);
fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
// find the prompt in the text
float best_sim = 0.0f;
size_t best_len = 0;
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
const auto prompt = txt.substr(0, n);
const float sim = similarity(prompt, k_prompt);
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
if (sim > best_sim) {
best_sim = sim;
best_len = n;
}
}
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
std::string command = ::trim(txt.substr(best_len));
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
fprintf(stdout, "\n");
{
char txt[1024];
snprintf(txt, sizeof(txt), "Command '%s', (t = %d ms)", command.c_str(), (int) t_ms);
set_status(txt);
}
if (!command.empty()) {
auto move = m_board->process(command);
if (!move.empty()) {
set_moves(std::move(move));
}
}
}
clear_audio();
}
}
}
}
std::string WChess::transcribe(
const std::vector<float> & pcmf32,
float & logprob_min,
float & logprob_sum,
int & n_tokens,
int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();
logprob_min = 0.0f;
logprob_sum = 0.0f;
n_tokens = 0;
t_ms = 0;
if (whisper_full(m_ctx, m_wparams, pcmf32.data(), pcmf32.size()) != 0) {
return {};
}
std::string result;
const int n_segments = whisper_full_n_segments(m_ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(m_ctx, i);
result += text;
const int n = whisper_full_n_tokens(m_ctx, i);
for (int j = 0; j < n; ++j) {
const auto token = whisper_full_get_token_data(m_ctx, i, j);
if(token.plog > 0.0f) return {};
logprob_min = std::min(logprob_min, token.plog);
logprob_sum += token.plog;
++n_tokens;
}
}
const auto t_end = std::chrono::high_resolution_clock::now();
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
return result;
}

View File

@ -1,62 +0,0 @@
#pragma once
#include "whisper.h"
#include <string>
#include <vector>
#include <memory>
class Chessboard;
class WChess {
public:
using SetStatusCb = void (*)(const std::string &);
using CheckRunningCb = bool (*)();
using GetAudioCb = void (*)(int, std::vector<float> &);
using SetMovesCb = void (*)(const std::string &);
using ClearAudioCb = void (*)();
struct callbacks {
SetStatusCb set_status = nullptr;
CheckRunningCb check_running = nullptr;
GetAudioCb get_audio = nullptr;
SetMovesCb set_moves = nullptr;
ClearAudioCb clear_audio = nullptr;
};
struct settings {
int32_t vad_ms = 2000;
int32_t prompt_ms = 5000;
int32_t command_ms = 4000;
float vad_thold = 0.2f;
float freq_thold = 100.0f;
bool print_energy = false;
};
WChess(
whisper_context * ctx,
const whisper_full_params & wparams,
callbacks cb,
settings s
);
~WChess();
void run();
std::string stringify_board() const;
private:
void get_audio(int ms, std::vector<float>& pcmf32) const;
void set_status(const std::string& msg) const;
void set_moves(const std::string& moves) const;
bool check_running() const;
void clear_audio() const;
std::string transcribe(
const std::vector<float> & pcmf32,
float & logprob_min,
float & logprob_sum,
int & n_tokens,
int64_t & t_ms);
whisper_context * m_ctx;
whisper_full_params m_wparams;
const callbacks m_cb;
const settings m_settings;
std::unique_ptr<Chessboard> m_board;
};

View File

@ -1,88 +0,0 @@
#include "Chessboard.h"
#define ASSERT(x) \
do { \
if (!(x)) { \
fprintf(stderr, "ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
fflush(stderr); \
exit(1); \
} \
} while (0)
int main() {
{
// pawns
Chessboard chess;
ASSERT(chess.process("pawn to d4, e5, e3, pawn to d5") == "d2-d4 e7-e5 e2-e3 d7-d5");
ASSERT(chess.process("pawn to d4") == ""); // wrong
ASSERT(chess.process("pawn to c5") == ""); // wrong
ASSERT(chess.process("pawn to d5") == ""); // wrong
ASSERT(chess.process("pawn to d3") == ""); // wrong
ASSERT(chess.process("pawn to f5") == ""); // wrong, white's turn
ASSERT(chess.process("h4") == "h2-h4");
ASSERT(chess.process("d4") == "e5-d4");
ASSERT(chess.process("e4") == "e3-e4");
ASSERT(chess.process("d4") == ""); // wrong
ASSERT(chess.process("e4") == "d5-e4");
}
{
// rook
Chessboard chess;
ASSERT(chess.process("rook to a3") == ""); // wrong
ASSERT(chess.process("a4, h5, rook to a3, rook to h6") == "a2-a4 h7-h5 a1-a3 h8-h6");
ASSERT(chess.process("rook to d3, rook to e6") == "a3-d3 h6-e6");
ASSERT(chess.process("rook to d4, rook to e5") == "d3-d4 e6-e5");
ASSERT(chess.process("rook to a4") == ""); // wrong
ASSERT(chess.process("rook to d8") == ""); // wrong
ASSERT(chess.process("rook to d3") == "d4-d3");
ASSERT(chess.process("rook to e2") == "e5-e2");
}
{
// knight
Chessboard chess;
ASSERT(chess.process("knight to c3, knight to c6") == "b1-c3 b8-c6");
ASSERT(chess.process("knight to c3") == ""); // wrong
ASSERT(chess.process("knight to a2") == ""); // wrong
ASSERT(chess.process("knight to b4") == ""); // wrong, white's turn
ASSERT(chess.process("knight to b5") == "c3-b5");
ASSERT(chess.process("knight to a5") == "c6-a5");
ASSERT(chess.process("knight to c7") == "b5-c7");
}
{
// bishop
Chessboard chess;
ASSERT(chess.process("b3, b6, bishop to b2, bishop to b7") == "b2-b3 b7-b6 c1-b2 c8-b7");
ASSERT(chess.process("bishop to a1") == ""); // wrong
ASSERT(chess.process("bishop to h8") == ""); // wrong
ASSERT(chess.process("bishop to a6") == ""); // wrong, white's turn
ASSERT(chess.process("bishop to g7") == "b2-g7");
}
{
// queen
Chessboard chess;
ASSERT(chess.process("queen to d8") == ""); // wrong
ASSERT(chess.process("queen to f1") == ""); // wrong
ASSERT(chess.process("queen to h5") == ""); // wrong
ASSERT(chess.process("e3, d5, queen to h5, queen to d6") == "e2-e3 d7-d5 d1-h5 d8-d6");
ASSERT(chess.process("queen to c5") == ""); // wrong, white's turn
ASSERT(chess.process("queen to f7") == "h5-f7");
}
{
// king
Chessboard chess;
ASSERT(chess.process("d3, d6, king to d2, king to d7, king to c3, king to c6, king to c4") == "d2-d3 d7-d6 e1-d2 e8-d7 d2-c3 d7-c6 c3-c4");
ASSERT(chess.process("bishop to e6") == "c8-e6");
ASSERT(chess.process("king to b3") == "c4-b3"); // !! check check not implemented
}
}

View File

@ -1,8 +0,0 @@
if (WHISPER_SDL2)
set(TARGET wchess)
add_executable(${TARGET} wchess.cmd.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE libwchess common-sdl ${CMAKE_THREAD_LIBS_INIT})
endif ()

View File

@ -1,207 +0,0 @@
// Command line voice assisted chess
//
// Speak chess move commands to the microphone.
// The moves will translated to chessboard positions.
//
//
#include "WChess.h"
#include "common-sdl.h"
#include <memory>
#include <thread>
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t prompt_ms = 5000;
int32_t command_ms = 8000;
int32_t capture_id = -1;
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
float grammar_penalty = 100.0f;
bool speed_up = false;
bool translate = false;
bool print_special = false;
bool print_energy = false;
bool no_timestamps = true;
bool use_gpu = true;
std::string language = "en";
std::string model = "models/ggml-base.en.bin";
std::string fname_out;
std::string commands;
std::string prompt;
std::string context;
std::string grammar;
};
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -pms N, --prompt-ms N [%-7d] prompt duration in milliseconds\n", params.prompt_ms);
fprintf(stderr, " -cms N, --command-ms N [%-7d] command duration in milliseconds\n", params.command_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
fprintf(stderr, "\n");
}
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-pms" || arg == "--prompt-ms") { params.prompt_ms = std::stoi(argv[++i]); }
else if (arg == "-cms" || arg == "--command-ms") { params.command_ms = std::stoi(argv[++i]); }
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
}
return true;
}
std::unique_ptr<WChess> g_wchess;
void set_moves(const std::string & moves) {
if (!moves.empty()) fprintf(stdout, "%s", g_wchess->stringify_board().c_str());
}
audio_async g_audio(30*1000);
void get_audio(int ms, std::vector<float> & pcmf32_cur) {
g_audio.get(ms, pcmf32_cur);
}
void clear_audio() {
g_audio.clear();
}
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init
struct whisper_context_params cparams;
cparams.use_gpu = params.use_gpu;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
// init audio
if (!g_audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
return 1;
}
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.no_timestamps = params.no_timestamps;
wparams.single_segment = true;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.temperature = 0.4f;
wparams.temperature_inc = 1.0f;
wparams.greedy.best_of = 5;
wparams.beam_search.beam_size = 5;
wparams.initial_prompt = params.context.data();
g_audio.resume();
// wait for 1 second to avoid any buffered noise
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
g_audio.clear();
WChess::callbacks cb;
cb.check_running = sdl_poll_events;
cb.get_audio = get_audio;
cb.set_moves = set_moves;
cb.clear_audio = clear_audio;
WChess::settings s;
s.vad_ms = 2000;
s.prompt_ms = params.prompt_ms;
s.command_ms = params.command_ms;
s.vad_thold = params.vad_thold;
s.freq_thold = params.freq_thold;
s.print_energy = params.print_energy;
g_wchess.reset(new WChess(ctx, wparams, cb, s));
set_moves("start");
g_wchess->run();
g_audio.pause();
whisper_print_timings(ctx);
whisper_free(ctx);
return 0;
}

View File

@ -1,51 +0,0 @@
set(TARGET wchess.wasm)
add_executable(${TARGET}
wchess.wasm.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
common
libwchess
)
unset(EXTRA_FLAGS)
if (WHISPER_WASM_SINGLE_FILE)
set(EXTRA_FLAGS "-s SINGLE_FILE=1")
message(STATUS "Embedding WASM inside chess.js")
add_custom_command(
TARGET ${TARGET} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_BINARY_DIR}/bin/${TARGET}.js
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/js/chess.js
)
endif()
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \
-s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \
-s INITIAL_MEMORY=1024MB \
-s TOTAL_MEMORY=1024MB \
-s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \
")
add_custom_command(
TARGET ${TARGET} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_directory
${CMAKE_CURRENT_SOURCE_DIR}/chessboardjs-1.0.0
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_CURRENT_SOURCE_DIR}/jquery-3.7.1.min.js
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/js/
)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY)
configure_file(${CMAKE_SOURCE_DIR}/examples/helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/js/helpers.js @ONLY)

View File

@ -1,32 +0,0 @@
# chessboard.js Change Log
All notable changes to this project will be documented in this file.
## [1.0.0] - 2019-06-11
- Orientation methods now return current orientation. [Issue #64]
- Drop support for IE8
- Do not check for `window.JSON` (Error #1004)
- Rename `ChessBoard` to `Chessboard` (`ChessBoard` is still supported, however)
- id query selectors are now supported as the first argument to `Chessboard()`
- Remove Error #1002
- Format code according to [StandardJS]
- Bump minimum jQuery version to 1.8.3
- Throttle piece drag functions
## [0.3.0] - 2013-08-10
- Added `appearSpeed` animation config property
- Added `onSnapbackEnd` event
- Added `onMoveEnd` event
## [0.2.0] - 2013-08-05
- Added `onMouseoverSquare` and `onMouseoutSquare` events
- Added `onSnapEnd` event
- Added square code as CSS class on the squares
- Added [chess.js] integration examples
## [0.1.0] - 2013-05-21
- Initial release
[chess.js]:https://github.com/jhlywa/chess.js
[Issue #64]:https://github.com/oakmac/chessboardjs/issues/64
[StandardJS]:https://standardjs.com/

View File

@ -1,20 +0,0 @@
Copyright 2019 Chris Oakman
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -1,82 +0,0 @@
# chessboard.js
chessboard.js is a JavaScript chessboard component. It depends on [jQuery].
Please see [chessboardjs.com] for documentation and examples.
## What is chessboard.js?
chessboard.js is a JavaScript chessboard component with a flexible "just a
board" API that
chessboard.js is a standalone JavaScript Chess Board. It is designed to be "just
a board" and expose a powerful API so that it can be used in different ways.
Here's a non-exhaustive list of things you can do with chessboard.js:
- Use chessboard.js to show game positions alongside your expert commentary.
- Use chessboard.js to have a tactics website where users have to guess the best
move.
- Integrate chessboard.js and [chess.js] with a PGN database and allow people to
search and playback games (see [Example 5000])
- Build a chess server and have users play their games out using the
chessboard.js board.
chessboard.js is flexible enough to handle any of these situations with relative
ease.
## What can chessboard.js **not** do?
The scope of chessboard.js is limited to "just a board." This is intentional and
makes chessboard.js flexible for handling a multitude of chess-related problems.
This is a common source of confusion for new users. [remove?]
Specifically, chessboard.js does not understand anything about how the game of
chess is played: how a knight moves, who's turn is it, is White in check?, etc.
Fortunately, the powerful [chess.js] library deals with exactly this sort of
problem domain and plays nicely with chessboard.js's flexible API. Some examples
of chessboard.js combined with chess.js: 5000, 5001, 5002
Please see the powerful [chess.js] library for an API to deal with these sorts
of questions.
This logic is distinct from the logic of the board. Please see the powerful
[chess.js] library for this aspect of your application.
Here is a list of things that chessboard.js is **not**:
- A chess engine
- A legal move validator
- A PGN parser
chessboard.js is designed to work well with any of those things, but the idea
behind chessboard.js is that the logic that controls the board should be
independent of those other problems.
## Docs and Examples
- Docs - <http://chessboardjs.com/docs>
- Examples - <http://chessboardjs.com/examples>
## Developer Tools
```sh
# create a build in the build/ directory
npm run build
# re-build the website
npm run website
```
## License
[MIT License](LICENSE.md)
[jQuery]:https://jquery.com/
[chessboardjs.com]:http://chessboardjs.com
[chess.js]:https://github.com/jhlywa/chess.js
[Example 5000]:http://chessboardjs.com/examples#5000

View File

@ -1,54 +0,0 @@
/*! chessboard.js v1.0.0 | (c) 2019 Chris Oakman | MIT License chessboardjs.com/license */
.clearfix-7da63 {
clear: both;
}
.board-b72b1 {
border: 2px solid #404040;
box-sizing: content-box;
}
.square-55d63 {
float: left;
position: relative;
/* disable any native browser highlighting */
-webkit-touch-callout: none;
-webkit-user-select: none;
-khtml-user-select: none;
-moz-user-select: none;
-ms-user-select: none;
user-select: none;
}
.white-1e1d7 {
background-color: #f0d9b5;
color: #b58863;
}
.black-3c85d {
background-color: #b58863;
color: #f0d9b5;
}
.highlight1-32417, .highlight2-9c5d2 {
box-shadow: inset 0 0 3px 3px yellow;
}
.notation-322f9 {
cursor: default;
font-family: "Helvetica Neue", Helvetica, Arial, sans-serif;
font-size: 14px;
position: absolute;
}
.alpha-d2270 {
bottom: 1px;
right: 3px;
}
.numeric-fc462 {
top: 2px;
left: 2px;
}

View File

@ -1,2 +0,0 @@
/*! chessboard.js v1.0.0 | (c) 2019 Chris Oakman | MIT License chessboardjs.com/license */
.clearfix-7da63{clear:both}.board-b72b1{border:2px solid #404040;box-sizing:content-box}.square-55d63{float:left;position:relative;-webkit-touch-callout:none;-webkit-user-select:none;-khtml-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none}.white-1e1d7{background-color:#f0d9b5;color:#b58863}.black-3c85d{background-color:#b58863;color:#f0d9b5}.highlight1-32417,.highlight2-9c5d2{box-shadow:inset 0 0 3px 3px #ff0}.notation-322f9{cursor:default;font-family:"Helvetica Neue",Helvetica,Arial,sans-serif;font-size:14px;position:absolute}.alpha-d2270{bottom:1px;right:3px}.numeric-fc462{top:2px;left:2px}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.4 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 777 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 748 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.3 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.3 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.1 KiB

File diff suppressed because one or more lines are too long

View File

@ -1,29 +0,0 @@
{
"author": "Chris Oakman <chris@oakmac.com> (http://chrisoakman.com/)",
"name": "@chrisoakman/chessboardjs",
"description": "JavaScript chessboard widget",
"homepage": "https://chessboardjs.com",
"license": "MIT",
"version": "1.0.0",
"repository": {
"type": "git",
"url": "git://github.com/oakmac/chessboardjs.git"
},
"files": ["dist/"],
"dependencies": {
"jquery": ">=3.4.1"
},
"devDependencies": {
"csso": "3.5.1",
"fs-plus": "3.1.1",
"kidif": "1.1.0",
"mustache": "2.3.0",
"standard": "10.0.2",
"uglify-js": "3.6.0"
},
"scripts": {
"build": "standard lib/chessboard.js && node scripts/build.js",
"standard": "standard --fix lib/*.js website/js/*.js",
"website": "node scripts/website.js"
}
}

View File

@ -1,376 +0,0 @@
<!doctype html>
<html lang="en-us">
<head>
<title>wchess : Voice assistant example using Whisper + WebAssembly</title>
<style>
#output {
width: 100%;
height: 100%;
margin: 0 auto;
margin-top: 10px;
border-left: 0px;
border-right: 0px;
padding-left: 0px;
padding-right: 0px;
display: block;
background-color: black;
color: white;
font-size: 10px;
font-family: 'Lucida Console', Monaco, monospace;
outline: none;
white-space: pre;
overflow-wrap: normal;
overflow-x: scroll;
}
</style>
<link rel="stylesheet" href="css/chessboard-1.0.0.min.css" integrity="sha384-q94+BZtLrkL1/ohfjR8c6L+A6qzNH9R2hBLwyoAfu3i/WCvQjzL2RQJ3uNHDISdU" crossorigin="anonymous">
</head>
<body onload="loadWhisper()">
<div id="main-container">
<b>wchess : Voice assistant example using Whisper + WebAssembly</b>
<br><br>
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/command.wasm">GitHub</a>.
<br><br>
<b>More examples:</b>
<a href="https://whisper.ggerganov.com/">main</a> |
<a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br>
<hr>
<div id="model-whisper">
Whisper model: <span id="model-whisper-status"></span>
<span id="fetch-whisper-progress"></span>
<button id="clear" onclick="clearCache()">Clear Cache</button>
<!--
<input type="file" id="file" name="file" onchange="loadFile(event, 'whisper.bin')" />
-->
</div>
<br>
<div id="myBoard" style="width: 400px"></div>
<script src="js/jquery-3.7.1.min.js"></script>
<script src="js/chessboard-1.0.0.min.js"></script>
<script>
var board = Chessboard('myBoard', 'start')
</script>
<br>
<div id="input">
<button id="toggler" disabled>Hold</button>
</div>
<br>
<div id="state">
Status: <b><span id="state-status">not started</span></b>
<pre id="state-moves">[The moves will be displayed here]</pre>
</div>
<hr>
Debug output:
<textarea id="output" rows="20"></textarea>
<br>
<b>Troubleshooting</b>
<br><br>
The page does some heavy computations, so make sure:
<ul>
<li>To use a modern web browser (e.g. Chrome, Firefox)</li>
<li>To use a fast desktop or laptop computer (i.e. not a mobile phone)</li>
<li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
</ul>
<div class="cell-version">
<span>
|
Build time: <span class="nav-link">@GIT_DATE@</span> |
Commit hash: <a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/commit/@GIT_SHA1@">@GIT_SHA1@</a> |
Commit subject: <span class="nav-link">@GIT_COMMIT_SUBJECT@</span> |
<a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/command.wasm">Source Code</a> |
</span>
</div>
</div>
<script type="text/javascript" src="js/helpers.js"></script>
<script type='text/javascript'>
// web audio context
var context = null;
// the command instance
var instance = null;
// model name
var model_whisper = null;
var Module = {
print: printTextarea,
printErr: printTextarea,
setStatus: function(text) {
printTextarea('js: ' + text);
},
monitorRunDependencies: function(left) {
},
preRun: function() {
printTextarea('js: Preparing ...');
},
postRun: function() {
printTextarea('js: Module initialized successfully!');
instance = Module.init('whisper.bin');
if (instance) {
printTextarea("js: whisper initialized, instance: " + instance);
}
else {
printTextarea("js: failed to initialize whisper");
}
}
};
//
// fetch models
//
let dbVersion = 1
let dbName = 'whisper.ggerganov.com';
let indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB
function storeFS(fname, buf) {
// write to WASM file using FS_createDataFile
// if the file exists, delete it
try {
Module.FS_unlink(fname);
} catch (e) {
// ignore
}
Module.FS_createDataFile("/", fname, buf, true, true);
printTextarea('storeFS: stored model: ' + fname + ' size: ' + buf.length);
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
if (model_whisper != null) {
document.getElementById('toggler').disabled = false;
}
}
function loadWhisper() {
// let urls = {
// 'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
// 'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
// 'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
// 'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
// };
// let sizes = {
// 'tiny.en': 75,
// 'base.en': 142,
// 'tiny-en-q5_1': 31,
// 'base-en-q5_1': 57,
// };
let url = 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q8_0.bin';
let dst = 'whisper.bin';
let size_mb = 75;
model_whisper = 'tiny.en';
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model_whisper + '" ... ';
cbProgress = function(p) {
let el = document.getElementById('fetch-whisper-progress');
el.innerHTML = Math.round(100*p) + '%';
};
cbCancel = function() {
var el;
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
};
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
}
//
// microphone
//
const kSampleRate = 16000;
const kRestartRecording_s = 120;
const kIntervalAudio_ms = 250; // pass the recorded audio to the C++ instance at this rate
var mediaRecorder = null;
var doRecording = false;
var startTime = 0;
window.AudioContext = window.AudioContext || window.webkitAudioContext;
window.OfflineAudioContext = window.OfflineAudioContext || window.webkitOfflineAudioContext;
function stopRecording() {
Module.set_status("paused");
mediaRecorder.stop();
}
function startRecording() {
if (!context) {
context = new AudioContext({
sampleRate: kSampleRate,
channelCount: 1,
echoCancellation: false,
autoGainControl: true,
noiseSuppression: true,
});
}
startTime = Date.now();
var chunks = [];
var stream = null;
navigator.mediaDevices.getUserMedia({audio: true, video: false})
.then(function(s) {
stream = s;
mediaRecorder = new MediaRecorder(stream);
mediaRecorder.ondataavailable = function(e) {
chunks.push(e.data);
var blob = new Blob(chunks, { 'type' : 'audio/ogg; codecs=opus' });
var reader = new FileReader();
reader.onload = function(event) {
var buf = new Uint8Array(reader.result);
if (!context) {
return;
}
context.decodeAudioData(buf.buffer, function(audioBuffer) {
var offlineContext = new OfflineAudioContext(audioBuffer.numberOfChannels, audioBuffer.length, audioBuffer.sampleRate);
var source = offlineContext.createBufferSource();
source.buffer = audioBuffer;
source.connect(offlineContext.destination);
source.start(0);
offlineContext.startRendering().then(function(renderedBuffer) {
let audio = renderedBuffer.getChannelData(0);
if (instance) {
printTextarea('js: number of samples: ' + audio.length);
Module.set_audio(instance, audio);
}
});
mediaRecorder = null;
context = null;
});
}
reader.readAsArrayBuffer(blob);
};
mediaRecorder.onstop = function(e) {
stream.getTracks().forEach(function(track) {
track.stop();
});
};
mediaRecorder.start();
})
.catch(function(err) {
printTextarea('js: error getting audio stream: ' + err);
});
}
//
// main
//
var nLines = 0;
var intervalUpdate = null;
var movesAll = '';
document.body.addEventListener('keydown', function(event) {
if (event.keyCode === 32) {
document.getElementById('toggler').innerText = "Release";
onStart();
}
}, true);
document.body.addEventListener('keyup', function(event) {
if (event.keyCode === 32) {
document.getElementById('toggler').innerText = "Hold";
onStop();
}
}, true);
document.getElementById('toggler').addEventListener('mousedown', function(event) {
this.innerText = "Release";
onStart();
}, true);
document.getElementById('toggler').addEventListener('mouseup', function(event) {
this.innerText = "Hold";
onStop();
}, true);
function onStart() {
if (!instance) {
return;
}
startRecording();
}
function onStop() {
printTextarea('js: stopping recording ...');
stopRecording();
var interval = setInterval(function() {
var moves = Module.get_moves();
if (moves != null && moves.length > 1) {
clearInterval(interval);
for (move of moves.split(' ')) {
board.move(move);
}
movesAll += moves + '<br>';
nLines++;
// if more than 10 lines, remove the first line
if (nLines > 10) {
var i = movesAll.indexOf('<br>');
if (i > 0) {
movesAll = movesAll.substring(i + 4);
nLines--;
}
}
document.getElementById('state-status').innerHTML = Module.get_status();
document.getElementById('state-moves').innerHTML = movesAll;
}
}, 100);
}
</script>
<script type="text/javascript" src="js/chess.js"></script>
</body>
</html>

File diff suppressed because one or more lines are too long

View File

@ -1,173 +0,0 @@
#include <WChess.h>
#include <emscripten/bind.h>
#include <atomic>
#include <thread>
constexpr int N_THREAD = 8;
std::vector<struct whisper_context *> g_contexts(4, nullptr);
std::mutex g_mutex;
std::thread g_worker;
std::atomic<bool> g_running(false);
std::string g_status = "";
std::string g_status_forced = "";
std::string g_moves = "";
std::vector<float> g_pcmf32;
void set_status(const std::string & status) {
std::lock_guard<std::mutex> lock(g_mutex);
g_status = status;
}
void set_moves(const std::string & moves) {
std::lock_guard<std::mutex> lock(g_mutex);
g_moves = moves;
}
void get_audio(int /* ms */, std::vector<float> & audio) {
std::lock_guard<std::mutex> lock(g_mutex);
audio = g_pcmf32;
}
bool check_running() {
return g_running;
}
void clear_audio() {
std::lock_guard<std::mutex> lock(g_mutex);
g_pcmf32.clear();
}
void wchess_main(size_t i) {
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
wparams.n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
wparams.offset_ms = 0;
wparams.translate = false;
wparams.no_context = true;
wparams.single_segment = true;
wparams.print_realtime = false;
wparams.print_progress = false;
wparams.print_timestamps = true;
wparams.print_special = false;
wparams.no_timestamps = true;
wparams.max_tokens = 32;
wparams.audio_ctx = 768; // partial encoder context for better performance
wparams.temperature = 0.0f;
wparams.temperature_inc = 2.0f;
wparams.greedy.best_of = 1;
wparams.beam_search.beam_size = 1;
wparams.language = "en";
wparams.grammar_penalty = 100.0;
wparams.initial_prompt = "bishop to c3, rook to d4, knight to e5, d4 d5, knight to c3, c3, queen to d4, king b1, pawn to a1, bishop to b2, knight to c3,";
printf("command: using %d threads\n", wparams.n_threads);
WChess::callbacks cb;
cb.set_status = set_status;
cb.check_running = check_running;
cb.get_audio = get_audio;
cb.set_moves = set_moves;
cb.clear_audio = clear_audio;
WChess(g_contexts[i], wparams, cb, {}).run();
if (i < g_contexts.size()) {
whisper_free(g_contexts[i]);
g_contexts[i] = nullptr;
}
}
EMSCRIPTEN_BINDINGS(command) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {
g_worker.join();
}
g_worker = std::thread([i]() {
wchess_main(i);
});
return i + 1;
} else {
return (size_t) 0;
}
}
}
return (size_t) 0;
}));
emscripten::function("free", emscripten::optional_override([](size_t /* index */) {
if (g_running) {
g_running = false;
}
}));
emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) {
--index;
if (index >= g_contexts.size()) {
return -1;
}
if (g_contexts[index] == nullptr) {
return -2;
}
{
std::lock_guard<std::mutex> lock(g_mutex);
const int n = audio["length"].as<int>();
emscripten::val heap = emscripten::val::module_property("HEAPU8");
emscripten::val memory = heap["buffer"];
g_pcmf32.resize(n);
emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(g_pcmf32.data()), n);
memoryView.call<void>("set", audio);
}
return 0;
}));
emscripten::function("get_moves", emscripten::optional_override([]() {
std::string moves;
{
std::lock_guard<std::mutex> lock(g_mutex);
moves = std::move(g_moves);
}
return moves;
}));
emscripten::function("get_status", emscripten::optional_override([]() {
std::string status;
{
std::lock_guard<std::mutex> lock(g_mutex);
status = g_status_forced.empty() ? g_status : g_status_forced;
}
return status;
}));
emscripten::function("set_status", emscripten::optional_override([](const std::string & status) {
std::lock_guard<std::mutex> lock(g_mutex);
g_status_forced = status;
}));
}

View File

@ -1,15 +0,0 @@
*.iml
.gradle
/local.properties
/.idea/caches
/.idea/libraries
/.idea/modules.xml
/.idea/workspace.xml
/.idea/navEditor.xml
/.idea/assetWizardSettings.xml
.DS_Store
/build
/captures
.externalNativeBuild
.cxx
local.properties

View File

@ -1,20 +0,0 @@
A sample Android app using java code and [whisper.cpp](https://github.com/ggerganov/whisper.cpp/) to do voice-to-text transcriptions.
To use:
1. Select a model from the [whisper.cpp repository](https://github.com/ggerganov/whisper.cpp/tree/master/models).[^1]
2. Copy the model to the "app/src/main/assets/models" folder.
3. Select a sample audio file (for example, [jfk.wav](https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav)).
4. Copy the sample to the "app/src/main/assets/samples" folder.
5. Modify the modelFilePath in the WhisperService.java
6. Modify the sampleFilePath in the WhisperService.java
7. Select the "release" active build variant, and use Android Studio to run and deploy to your device.
[^1]: I recommend the tiny or base models for running on an Android device.
PS:
1. Do not move this android project folder individually to other folders, because this android project folder depends on the files of the whole project.
2. The cpp code is compiled during the build process
3. If you want to import a compiled cpp project in your Android project, please refer to the https://github.com/litongjava/whisper.cpp.android.java.demo
![](README_files/1.jpg)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 67 KiB

View File

@ -1 +0,0 @@
/build

View File

@ -1,58 +0,0 @@
plugins {
id 'com.android.application'
}
android {
compileSdkVersion 30
buildToolsVersion '30.0.3'
defaultConfig {
applicationId "com.litongjava.whisper.android.java"
minSdkVersion 21
targetSdkVersion 30
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
externalNativeBuild {
cmake {
cppFlags ""
}
}
ndk {
abiFilters 'arm64-v8a', 'armeabi-v7a', 'x86', 'x86_64'
}
}
buildTypes {
release {
signingConfig signingConfigs.debug
minifyEnabled true
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
externalNativeBuild {
cmake {
path "src/main/jni/whisper/CMakeLists.txt"
}
}
ndkVersion "25.2.9519653"
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
}
dependencies {
implementation 'androidx.appcompat:appcompat:1.1.0'
implementation 'com.google.android.material:material:1.1.0'
implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
testImplementation 'junit:junit:4.+'
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1'
//litongjava
implementation 'com.litongjava:android-view-inject:1.0'
implementation 'com.litongjava:jfinal-aop:1.0.1'
implementation 'com.litongjava:litongjava-android-utils:1.0.0'
}

View File

@ -1,21 +0,0 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html
# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}
# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile

View File

@ -1,26 +0,0 @@
package com.litongjava.whisper.android.java;
import android.content.Context;
import androidx.test.platform.app.InstrumentationRegistry;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import org.junit.Test;
import org.junit.runner.RunWith;
import static org.junit.Assert.*;
/**
* Instrumented test, which will execute on an Android device.
*
* @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
*/
@RunWith(AndroidJUnit4.class)
public class ExampleInstrumentedTest {
@Test
public void useAppContext() {
// Context of the app under test.
Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
assertEquals("com.litongjava.whisper.android.java", appContext.getPackageName());
}
}

View File

@ -1,22 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.litongjava.whisper.android.java">
<application
android:allowBackup="true"
android:name=".app.App"
android:icon="@mipmap/ic_launcher"
android:label="@string/app_name"
android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true"
android:theme="@style/Theme.Whisperandroidjava">
<activity android:name=".MainActivity">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
</manifest>

View File

@ -1,40 +0,0 @@
<?xml version="1.0" encoding="UTF-8" ?>
<configuration debug="false" xmlns="http://ch.qos.logback/xml/ns/logback"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://ch.qos.logback/xml/ns/logback https://raw.githubusercontent.com/enricopulatzo/logback-XSD/master/src/main/xsd/logback.xsd
http://ch.qos.logback/xml/ns/logback ">
<!--Define the storage address of the log file Do not use relative paths in the LogBack configuration. -->
<property name="LOG_HOME" value="logs" />
<!--Formatted output: %d means the date, %-6level: log level from the left display 6 characters wide, %m: log message, %n is a newline character -->
<property name="CONSOLE_LOG_PATTERN"
value="%d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-6level%logger{0}.%M:%L - %m%n" />
<!-- console output -->
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder class="ch.qos.logback.classic.encoder.PatternLayoutEncoder">
<pattern>${CONSOLE_LOG_PATTERN}</pattern>
</encoder>
</appender>
<!-- Generate log files on a daily basis -->
<appender name="FILE" class="ch.qos.logback.core.rolling.RollingFileAppender">
<encoder class="ch.qos.logback.classic.encoder.PatternLayoutEncoder">
<pattern>${CONSOLE_LOG_PATTERN}</pattern>
</encoder>
<rollingPolicy class="ch.qos.logback.core.rolling.TimeBasedRollingPolicy">
<!--File name for log file output -->
<fileNamePattern>${LOG_HOME}/project-name-%d{yyyy-MM-dd}.log</fileNamePattern>
<!--Maximum size of log file -->
<maxHistory>180</maxHistory>
</rollingPolicy>
<!--日志文件最大的大小 -->
<triggeringPolicy class="ch.qos.logback.core.rolling.SizeBasedTriggeringPolicy">
<maxFileSize>10MB</maxFileSize>
</triggeringPolicy>
</appender>
<!-- Log output level and source-->
<root level="info">
<appender-ref ref="STDOUT" />
<appender-ref ref="FILE" />
</root>
</configuration>

View File

@ -1,107 +0,0 @@
package com.litongjava.whisper.android.java;
import androidx.annotation.RequiresApi;
import androidx.appcompat.app.AppCompatActivity;
import android.content.Context;
import android.os.Build;
import android.os.Bundle;
import android.os.Handler;
import android.os.Looper;
import android.view.View;
import android.widget.TextView;
import com.blankj.utilcode.util.ThreadUtils;
import com.litongjava.android.view.inject.annotation.FindViewById;
import com.litongjava.android.view.inject.annotation.FindViewByIdLayout;
import com.litongjava.android.view.inject.annotation.OnClick;
import com.litongjava.android.view.inject.utils.ViewInjectUtils;
import com.litongjava.jfinal.aop.Aop;
import com.litongjava.jfinal.aop.AopManager;
import com.litongjava.whisper.android.java.services.WhisperService;
import com.litongjava.whisper.android.java.task.LoadModelTask;
import com.litongjava.whisper.android.java.task.TranscriptionTask;
import com.litongjava.whisper.android.java.utils.AssetUtils;
import com.whispercpp.java.whisper.WhisperLib;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
@FindViewByIdLayout(R.layout.activity_main)
public class MainActivity extends AppCompatActivity {
@FindViewById(R.id.sample_text)
private TextView tv;
Logger log = LoggerFactory.getLogger(this.getClass());
private WhisperService whisperService = Aop.get(WhisperService.class);
@RequiresApi(api = Build.VERSION_CODES.O)
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
//setContentView(R.layout.activity_main);
ViewInjectUtils.injectActivity(this, this);
initAopBean();
showSystemInfo();
}
private void initAopBean() {
Handler mainHandler = new Handler(Looper.getMainLooper());
AopManager.me().addSingletonObject(mainHandler);
}
@RequiresApi(api = Build.VERSION_CODES.O)
@OnClick(R.id.loadModelBtn)
public void loadModelBtn_OnClick(View v) {
Context context = getBaseContext();
ThreadUtils.executeByIo(new LoadModelTask(tv));
}
@OnClick(R.id.transcriptSampleBtn)
public void transcriptSampleBtn_OnClick(View v) {
Context context = getBaseContext();
long start = System.currentTimeMillis();
String sampleFilePath = "samples/jfk.wav";
File filesDir = context.getFilesDir();
File sampleFile = AssetUtils.copyFileIfNotExists(context, filesDir, sampleFilePath);
long end = System.currentTimeMillis();
String msg = "copy file:" + (end - start) + "ms";
outputMsg(tv, msg);
ThreadUtils.executeByIo(new TranscriptionTask(tv, sampleFile));
}
private void outputMsg(TextView tv, String msg) {
tv.append(msg + "\n");
log.info(msg);
}
@RequiresApi(api = Build.VERSION_CODES.O)
@OnClick(R.id.systemInfoBtn)
public void systemInfoBtn_OnClick(View v) {
showSystemInfo();
}
@RequiresApi(api = Build.VERSION_CODES.O)
public void showSystemInfo() {
String systemInfo = WhisperLib.getSystemInfo();
tv.append(systemInfo + "\n");
}
@OnClick(R.id.clearBtn)
public void clearBtn_OnClick(View v) {
tv.setText("");
}
@RequiresApi(api = Build.VERSION_CODES.O)
@Override
protected void onDestroy() {
super.onDestroy();
whisperService.release();
}
}

View File

@ -1,13 +0,0 @@
package com.litongjava.whisper.android.java.app;
import android.app.Application;
import com.blankj.utilcode.util.Utils;
public class App extends Application {
@Override
public void onCreate() {
super.onCreate();
Utils.init(this);
}
}

View File

@ -1,47 +0,0 @@
package com.litongjava.whisper.android.java.bean;
/**
* Created by litonglinux@qq.com on 10/21/2023_7:48 AM
*/
public class WhisperSegment {
private long start, end;
private String sentence;
public WhisperSegment() {
}
public WhisperSegment(long start, long end, String sentence) {
this.start = start;
this.end = end;
this.sentence = sentence;
}
public long getStart() {
return start;
}
public long getEnd() {
return end;
}
public String getSentence() {
return sentence;
}
public void setStart(long start) {
this.start = start;
}
public void setEnd(long end) {
this.end = end;
}
public void setSentence(String sentence) {
this.sentence = sentence;
}
@Override
public String toString() {
return "["+start+" --> "+end+"]:"+sentence;
}
}

View File

@ -1,101 +0,0 @@
package com.litongjava.whisper.android.java.services;
import android.content.Context;
import android.os.Build;
import android.os.Handler;
import android.widget.TextView;
import android.widget.Toast;
import androidx.annotation.RequiresApi;
import com.blankj.utilcode.util.ToastUtils;
import com.blankj.utilcode.util.Utils;
import com.litongjava.android.utils.dialog.AlertDialogUtils;
import com.litongjava.jfinal.aop.Aop;
import com.litongjava.whisper.android.java.bean.WhisperSegment;
import com.litongjava.whisper.android.java.single.LocalWhisper;
import com.litongjava.whisper.android.java.utils.WaveEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.ExecutionException;
public class WhisperService {
private Logger log = LoggerFactory.getLogger(this.getClass());
private final Object lock = new Object();
@RequiresApi(api = Build.VERSION_CODES.O)
public void loadModel(TextView tv) {
String modelFilePath = LocalWhisper.modelFilePath;
String msg = "load model from :" + modelFilePath + "\n";
outputMsg(tv, msg);
long start = System.currentTimeMillis();
LocalWhisper.INSTANCE.init();
long end = System.currentTimeMillis();
msg = "model load successful:" + (end - start) + "ms";
outputMsg(tv, msg);
ToastUtils.showLong(msg);
}
@RequiresApi(api = Build.VERSION_CODES.O)
public void transcribeSample(TextView tv, File sampleFile) {
String msg = "";
msg = "transcribe file from :" + sampleFile.getAbsolutePath();
outputMsg(tv, msg);
Long start = System.currentTimeMillis();
float[] audioData = new float[0]; // 读取音频样本
try {
audioData = WaveEncoder.decodeWaveFile(sampleFile);
} catch (IOException e) {
e.printStackTrace();
return;
}
long end = System.currentTimeMillis();
msg = "decode wave file:" + (end - start) + "ms";
outputMsg(tv, msg);
start = System.currentTimeMillis();
List<WhisperSegment> transcription = null;
try {
//transcription = LocalWhisper.INSTANCE.transcribeData(audioData);
transcription = LocalWhisper.INSTANCE.transcribeDataWithTime(audioData);
} catch (ExecutionException e) {
e.printStackTrace();
} catch (InterruptedException e) {
e.printStackTrace();
}
end = System.currentTimeMillis();
if(transcription!=null){
ToastUtils.showLong(transcription.toString());
msg = "Transcript successful:" + (end - start) + "ms";
outputMsg(tv, msg);
outputMsg(tv, transcription.toString());
}else{
msg = "Transcript failed:" + (end - start) + "ms";
outputMsg(tv, msg);
}
}
private void outputMsg(TextView tv, String msg) {
log.info(msg);
if(tv!=null){
Aop.get(Handler.class).post(()->{ tv.append(msg + "\n");});
}
}
@RequiresApi(api = Build.VERSION_CODES.O)
public void release() {
//noting to do
}
}

Some files were not shown because too many files have changed in this diff Show More