Compare commits

..

1 Commits

Author SHA1 Message Date
4e6d2e98ab ggml : try to improve threading 2022-12-29 13:05:20 +02:00
146 changed files with 6763 additions and 26544 deletions

View File

@ -1,22 +0,0 @@
name: Bindings Tests (Go)
on:
push:
paths:
- bindings/go/**
- whisper.h
pull_request:
paths:
- bindings/go/**
- whisper.h
jobs:
ubuntu-latest:
runs-on: ubuntu-latest
steps:
- uses: actions/setup-go@v3
with:
go-version: '^1.19'
- uses: actions/checkout@v1
- run: |
cd bindings/go
make test

View File

@ -1,22 +0,0 @@
name: Bindings Tests (Ruby)
on:
push:
paths:
- bindings/ruby/**
- whisper.h
pull_request:
paths:
- bindings/ruby/**
- whisper.h
jobs:
ubuntu-latest:
runs-on: ubuntu-latest
steps:
- uses: ruby/setup-ruby@v1
with:
ruby-version: '3.0'
- uses: actions/checkout@v1
- run: |
cd bindings/ruby/ext
ruby extconf.rb && make

View File

@ -1,365 +1,237 @@
name: CI name: CI
on: [push, pull_request] on: [push]
jobs: jobs:
ubuntu-latest: ubuntu-latest:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v1 uses: actions/checkout@v1
- name: Dependencies - name: Dependencies
run: | run: |
sudo apt-get update sudo apt-get update
sudo apt-get install build-essential sudo apt-get install build-essential
sudo apt-get install libsdl2-dev sudo apt-get install libsdl2-dev
- name: Build - name: Build
run: | run: |
make make
make stream make stream
macOS-latest: macOS-latest:
runs-on: macOS-latest runs-on: macOS-latest
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v1 uses: actions/checkout@v1
- name: Dependencies - name: Dependencies
run: | run: |
brew update brew update
brew install sdl2 brew install sdl2
- name: Build - name: Build
run: | run: |
make make
make stream make stream
ubuntu-latest-gcc: ubuntu-latest-gcc:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
build: [Debug, Release] build: [Debug, Release]
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v1 uses: actions/checkout@v1
- name: Dependencies - name: Dependencies
run: | run: |
sudo apt-get update sudo apt-get update
sudo apt-get install build-essential sudo apt-get install build-essential
sudo apt-get install cmake sudo apt-get install cmake
sudo apt-get install libsdl2-dev sudo apt-get install libsdl2-dev
- name: Configure - name: Configure
run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
- name: Build - name: Build
run: | run: |
make make
ctest -L gh --output-on-failure ctest -L gh --output-on-failure
ubuntu-latest-clang: ubuntu-latest-clang:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
build: [Debug, Release] build: [Debug, Release]
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v1 uses: actions/checkout@v1
- name: Dependencies - name: Dependencies
run: | run: |
sudo apt-get update sudo apt-get update
sudo apt-get install build-essential sudo apt-get install build-essential
sudo apt-get install cmake sudo apt-get install cmake
sudo apt-get install libsdl2-dev sudo apt-get install libsdl2-dev
- name: Configure - name: Configure
run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
- name: Build - name: Build
run: | run: |
make make
ctest -L gh --output-on-failure ctest -L gh --output-on-failure
ubuntu-latest-gcc-sanitized: ubuntu-latest-gcc-sanitized:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
sanitizer: [ADDRESS, THREAD, UNDEFINED] sanitizer: [ADDRESS, THREAD, UNDEFINED]
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v1 uses: actions/checkout@v1
- name: Dependencies - name: Dependencies
run: | run: |
sudo apt-get update sudo apt-get update
sudo apt-get install build-essential sudo apt-get install build-essential
sudo apt-get install cmake sudo apt-get install cmake
- name: Configure - name: Configure
run: cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON run: cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
- name: Build - name: Build
run: | run: |
make make
ctest -L gh --output-on-failure ctest -L gh --output-on-failure
windows: windows:
runs-on: windows-latest runs-on: windows-latest
strategy: strategy:
matrix: matrix:
build: [Release] build: [Release]
arch: [Win32, x64] arch: [Win32, x64]
sdl2: [ON] sdl2: [ON]
include: include:
- arch: Win32 - arch: Win32
s2arc: x86 s2arc: x86
- arch: x64 - arch: x64
s2arc: x64 s2arc: x64
- sdl2: ON - sdl2: ON
s2ver: 2.26.0 s2ver: 2.26.0
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v1 uses: actions/checkout@v1
- name: Add msbuild to PATH - name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v1 uses: microsoft/setup-msbuild@v1
- name: Fetch SDL2 and set SDL2_DIR - name: Fetch SDL2 and set SDL2_DIR
if: matrix.sdl2 == 'ON' if: matrix.sdl2 == 'ON'
run: | run: |
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
7z x sdl2.zip 7z x sdl2.zip
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
- name: Configure - name: Configure
run: > run: >
cmake -S . -B ./build -A ${{ matrix.arch }} cmake -S . -B ./build -A ${{ matrix.arch }}
-DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }} -DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }}
- name: Build - name: Build
run: | run: |
cd ./build cd ./build
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }} msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
- name: Copy SDL2.dll - name: Copy SDL2.dll
if: matrix.sdl2 == 'ON' if: matrix.sdl2 == 'ON'
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
- name: Upload binaries - name: Upload binaries
if: matrix.sdl2 == 'ON' if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v1 uses: actions/upload-artifact@v1
with: with:
name: whisper-bin-${{ matrix.arch }} name: whisper-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }} path: build/bin/${{ matrix.build }}
windows-blas: windows-blas:
runs-on: windows-latest runs-on: windows-latest
strategy: strategy:
matrix: matrix:
build: [Release] build: [Release]
arch: [Win32, x64] arch: [Win32, x64]
blas: [ON] blas: [ON]
sdl2: [ON] sdl2: [ON]
include: include:
- arch: Win32 - arch: Win32
obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x86.zip obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x86.zip
s2arc: x86 s2arc: x86
- arch: x64 - arch: x64
obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x64.zip obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x64.zip
s2arc: x64 s2arc: x64
- sdl2: ON - sdl2: ON
s2ver: 2.26.0 s2ver: 2.26.0
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v1 uses: actions/checkout@v1
- name: Add msbuild to PATH - name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v1 uses: microsoft/setup-msbuild@v1
- name: Fetch OpenBLAS - name: Fetch OpenBLAS
if: matrix.blas == 'ON' if: matrix.blas == 'ON'
run: | run: |
C:/msys64/usr/bin/wget.exe -qO blas.zip ${{ matrix.obzip }} C:/msys64/usr/bin/wget.exe -qO blas.zip ${{ matrix.obzip }}
7z x blas.zip -oblas -y 7z x blas.zip -oblas -y
copy blas/include/cblas.h . copy blas/include/cblas.h .
copy blas/include/openblas_config.h . copy blas/include/openblas_config.h .
echo "blasdir=$env:GITHUB_WORKSPACE/blas" >> $env:GITHUB_ENV echo "blasdir=$env:GITHUB_WORKSPACE/blas" >> $env:GITHUB_ENV
- name: Fetch SDL2 and set SDL2_DIR - name: Fetch SDL2 and set SDL2_DIR
if: matrix.sdl2 == 'ON' if: matrix.sdl2 == 'ON'
run: | run: |
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
7z x sdl2.zip 7z x sdl2.zip
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
- name: Configure - name: Configure
run: > run: >
cmake -S . -B ./build -A ${{ matrix.arch }} cmake -S . -B ./build -A ${{ matrix.arch }}
-DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DWHISPER_SUPPORT_OPENBLAS=${{ matrix.blas }} -DWHISPER_SUPPORT_OPENBLAS=${{ matrix.blas }}
-DCMAKE_LIBRARY_PATH="$env:blasdir/lib" -DCMAKE_LIBRARY_PATH="$env:blasdir/lib"
-DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }} -DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }}
- name: Build - name: Build
run: | run: |
cd ./build cd ./build
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }} msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
- name: Copy libopenblas.dll - name: Copy libopenblas.dll
if: matrix.blas == 'ON' if: matrix.blas == 'ON'
run: copy "$env:blasdir/bin/libopenblas.dll" build/bin/${{ matrix.build }} run: copy "$env:blasdir/bin/libopenblas.dll" build/bin/${{ matrix.build }}
- name: Copy SDL2.dll - name: Copy SDL2.dll
if: matrix.sdl2 == 'ON' if: matrix.sdl2 == 'ON'
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
- name: Upload binaries - name: Upload binaries
if: matrix.blas == 'ON' && matrix.sdl2 == 'ON' if: matrix.blas == 'ON' && matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v1 uses: actions/upload-artifact@v1
with: with:
name: whisper-blas-bin-${{ matrix.arch }} name: whisper-blas-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }} path: build/bin/${{ matrix.build }}
windows-cublas:
runs-on: windows-latest
strategy:
matrix:
build: [Release]
arch: [x64]
cublas: [ON]
sdl2: [ON]
include:
- arch: x64
s2arc: x64
- sdl2: ON
s2ver: 2.26.0
steps:
- name: Clone
uses: actions/checkout@v1
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v1
- name: Install CUDA Toolkit
id: cuda-toolkit
uses: Jimver/cuda-toolkit@v0.2.10
- name: Fetch SDL2 and set SDL2_DIR
if: matrix.sdl2 == 'ON'
run: |
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
7z x sdl2.zip
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
- name: Configure
run: >
cmake -S . -B ./build -A ${{ matrix.arch }}
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DWHISPER_CUBLAS=1
- name: Build
run: |
cd ./build
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
- name: Copy SDL2.dll
if: matrix.sdl2 == 'ON'
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
- name: Upload binaries
if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v1
with:
name: whisper-cublas-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }}
emscripten:
runs-on: ubuntu-latest
strategy:
matrix:
build: [Release]
steps:
- name: Clone
uses: actions/checkout@v1
- name: Dependencies
run: |
wget -q https://github.com/emscripten-core/emsdk/archive/master.tar.gz
tar -xvf master.tar.gz
emsdk-master/emsdk update
emsdk-master/emsdk install latest
emsdk-master/emsdk activate latest
- name: Configure
run: echo "tmp"
- name: Build
run: |
pushd emsdk-master
source ./emsdk_env.sh
popd
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
make
ios:
runs-on: macos-latest
strategy:
matrix:
build: [Release]
steps:
- name: Clone
uses: actions/checkout@v1
- name: Configure
run: |
cp models/for-tests-ggml-base.en.bin models/ggml-base.en.bin
mkdir models/ggml-base.en-encoder.mlmodelc
- name: Build objc example
run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphonesimulator build
- name: Build swiftui example
run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphonesimulator build
android:
runs-on: ubuntu-latest
steps:
- name: Clone
uses: actions/checkout@v1
- name: Install Java
uses: actions/setup-java@v3
with:
distribution: zulu
java-version: 17
- name: Setup Android SDK
uses: android-actions/setup-android@v2
- name: Build
run: |
cd examples/whisper.android
./gradlew assembleRelease --no-daemon

View File

@ -1,48 +0,0 @@
name: Examples Tests
on:
push:
paths:
- examples/addon.node/**
- whisper.h
pull_request:
paths:
- examples/addon.node/**
- whisper.h
jobs:
addon_node-ubuntu-latest:
runs-on: ubuntu-latest
strategy:
matrix:
node-version: [ 16.x, 18.x ]
steps:
- name: Clone
uses: actions/checkout@v1
- name: Dependencies
run: |
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install cmake
sudo apt-get install libsdl2-dev
- name: Use Node.js ${{ matrix.node-version }}
uses: actions/setup-node@v1
with:
node-version: ${{ matrix.node-version }}
cache: 'npm'
- name: Install package.json dependencies
working-directory: ./examples/addon.node
run: npm install
- name: Compile addon.node
run: npx cmake-js compile -T whisper-addon -B Release
- name: Download test model
run: |
bash ./models/download-ggml-model.sh base.en
- name: Test
run: |
cd examples/addon.node
npm run test

14
.gitignore vendored
View File

@ -1,8 +1,5 @@
*.o *.o
*.a
.cache/ .cache/
.coreml/
.test/
.vs/ .vs/
.vscode/ .vscode/
.DS_Store .DS_Store
@ -11,9 +8,6 @@ build/
build-em/ build-em/
build-debug/ build-debug/
build-release/ build-release/
build-static/
build-cublas/
build-no-accel/
build-sanitize-addr/ build-sanitize-addr/
build-sanitize-thread/ build-sanitize-thread/
@ -21,13 +15,9 @@ build-sanitize-thread/
/stream /stream
/command /command
/talk /talk
/talk-llama
/bench /bench
/quantize
arm_neon.h
sync.sh sync.sh
libwhisper.a
libwhisper.so libwhisper.so
compile_commands.json compile_commands.json
@ -37,7 +27,3 @@ examples/whisper.objc/whisper.objc.xcodeproj/xcuserdata/
examples/whisper.objc/whisper.objc.xcodeproj/project.xcworkspace/xcuserdata examples/whisper.objc/whisper.objc.xcodeproj/project.xcworkspace/xcuserdata
extra/bench-gg.txt extra/bench-gg.txt
models/*.mlmodel
models/*.mlmodelc
models/*.mlpackage

View File

@ -1,16 +1,15 @@
cmake_minimum_required (VERSION 3.0) cmake_minimum_required (VERSION 3.0)
project(whisper.cpp VERSION 1.4.2) project(whisper.cpp VERSION 1.0.4)
# Add path to modules
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
set(CMAKE_EXPORT_COMPILE_COMMANDS "on")
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib")
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
set(WHISPER_STANDALONE ON) set(WHISPER_STANDALONE ON)
include(GitVars) include(cmake/GitVars.cmake)
include(BuildTypes) include(cmake/BuildTypes.cmake)
# configure project version # configure project version
if (EXISTS "${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl") if (EXISTS "${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl")
@ -35,36 +34,29 @@ endif()
# options # options
option(BUILD_SHARED_LIBS "whisper: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT}) option(BUILD_SHARED_LIBS "whisper: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT})
option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON) option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON)
option(WHISPER_ALL_WARNINGS_3RD_PARTY "whisper: enable all compiler warnings in 3rd party libs" OFF) option(WHISPER_ALL_WARNINGS_3RD_PARTY "whisper: enable all compiler warnings in 3rd party libs" OFF)
option(WHISPER_SANITIZE_THREAD "whisper: enable thread sanitizer" OFF) option(WHISPER_SANITIZE_THREAD "whisper: enable thread sanitizer" OFF)
option(WHISPER_SANITIZE_ADDRESS "whisper: enable address sanitizer" OFF) option(WHISPER_SANITIZE_ADDRESS "whisper: enable address sanitizer" OFF)
option(WHISPER_SANITIZE_UNDEFINED "whisper: enable undefined sanitizer" OFF) option(WHISPER_SANITIZE_UNDEFINED "whisper: enable undefined sanitizer" OFF)
option(WHISPER_BUILD_TESTS "whisper: build tests" ${WHISPER_STANDALONE}) option(WHISPER_BUILD_TESTS "whisper: build tests" ${WHISPER_STANDALONE})
option(WHISPER_BUILD_EXAMPLES "whisper: build examples" ${WHISPER_STANDALONE}) option(WHISPER_BUILD_EXAMPLES "whisper: build examples" ${WHISPER_STANDALONE})
option(WHISPER_SDL2 "whisper: support for libSDL2" OFF) option(WHISPER_SUPPORT_SDL2 "whisper: support for libSDL2" OFF)
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
option(WHISPER_NO_F16C "whisper: disable F16c" OFF)
if (APPLE) if (APPLE)
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF) option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
option(WHISPER_COREML "whisper: enable Core ML framework" OFF) option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
option(WHISPER_COREML_ALLOW_FALLBACK "whisper: allow non-CoreML fallback" OFF) option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
else() else()
option(WHISPER_OPENBLAS "whisper: support for OpenBLAS" OFF) option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF)
option(WHISPER_CLBLAST "whisper: use CLBlast" OFF)
endif() endif()
option(WHISPER_PERF "whisper: enable perf timings" OFF) option(WHISPER_PERF "whisper: enable perf timings" OFF)
# sanitizers # sanitizers
@ -90,43 +82,25 @@ endif()
# dependencies # dependencies
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_STANDARD 11)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
# on APPLE # on APPLE - include Accelerate framework
if (APPLE) if (APPLE AND NOT WHISPER_NO_ACCELERATE)
# include Accelerate framework find_library(ACCELERATE_FRAMEWORK Accelerate)
if (NOT WHISPER_NO_ACCELERATE) if (ACCELERATE_FRAMEWORK)
find_library(ACCELERATE_FRAMEWORK Accelerate) message(STATUS "Accelerate framework found")
if (ACCELERATE_FRAMEWORK) set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
message(STATUS "Accelerate framework found") set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
else()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK}) message(WARNING "Accelerate framework not found")
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
else()
message(WARNING "Accelerate framework not found")
endif()
endif()
if (WHISPER_COREML)
find_library(FOUNDATION_FRAMEWORK Foundation)
find_library(COREML_FRAMEWORK CoreML)
if (COREML_FRAMEWORK)
message(STATUS "CoreML framework found")
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_COREML)
else()
message(WARNING "CoreML framework not found")
endif()
if (WHISPER_COREML_ALLOW_FALLBACK)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_COREML_ALLOW_FALLBACK)
endif()
endif() endif()
endif() endif()
if (WHISPER_OPENBLAS) if (WHISPER_SUPPORT_OPENBLAS)
find_library(OPENBLAS_LIB find_library(OPENBLAS_LIB
NAMES openblas libopenblas NAMES openblas libopenblas
) )
@ -140,46 +114,6 @@ if (WHISPER_OPENBLAS)
endif() endif()
endif() endif()
if (WHISPER_CUBLAS)
cmake_minimum_required(VERSION 3.17)
find_package(CUDAToolkit)
if (CUDAToolkit_FOUND)
message(STATUS "cuBLAS found")
enable_language(CUDA)
set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
add_compile_definitions(GGML_USE_CUBLAS)
if (WHISPER_STATIC)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
else()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
endif()
else()
message(WARNING "cuBLAS not found")
endif()
endif()
if (WHISPER_CLBLAST)
find_package(CLBlast)
if (CLBlast_FOUND)
message(STATUS "CLBlast found")
set(GGML_OPENCL_SOURCES ggml-opencl.c ggml-opencl.h)
add_compile_definitions(GGML_USE_CLBLAST)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} clblast)
else()
message(WARNING "CLBlast not found")
endif()
endif()
# compiler flags # compiler flags
if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
@ -197,7 +131,6 @@ if (WHISPER_ALL_WARNINGS)
-Wcast-qual \ -Wcast-qual \
-Wstrict-prototypes \ -Wstrict-prototypes \
-Wpointer-arith \ -Wpointer-arith \
-Wno-unused-function \
") ")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \
-Wall \ -Wall \
@ -222,17 +155,8 @@ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES
else() else()
message(STATUS "x86 detected") message(STATUS "x86 detected")
if (MSVC) if (MSVC)
if(NOT WHISPER_NO_AVX2) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
else()
if(NOT WHISPER_NO_AVX)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX")
endif()
endif()
else() else()
if (EMSCRIPTEN) if (EMSCRIPTEN)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
@ -244,12 +168,7 @@ else()
if(NOT WHISPER_NO_AVX2) if(NOT WHISPER_NO_AVX2)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
endif() endif()
if(NOT WHISPER_NO_FMA) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma -mf16c")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
endif()
if(NOT WHISPER_NO_F16C)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
endif()
endif() endif()
endif() endif()
endif() endif()
@ -258,33 +177,6 @@ if (WHISPER_PERF)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_PERF) set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_PERF)
endif() endif()
#
# whisper.coreml - Core ML support
#
if (WHISPER_COREML)
set(TARGET whisper.coreml)
add_library(${TARGET}
coreml/whisper-encoder.h
coreml/whisper-encoder.mm
coreml/whisper-encoder-impl.h
coreml/whisper-encoder-impl.m
)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PUBLIC
.
)
target_link_libraries(${TARGET} PRIVATE ${FOUNDATION_FRAMEWORK} ${COREML_FRAMEWORK})
set_target_properties(${TARGET} PROPERTIES
COMPILE_FLAGS "-fobjc-arc"
)
endif()
# #
# whisper - this is the main library of the project # whisper - this is the main library of the project
# #
@ -294,22 +186,14 @@ set(TARGET whisper)
add_library(${TARGET} add_library(${TARGET}
ggml.h ggml.h
ggml.c ggml.c
${GGML_CUDA_SOURCES}
${GGML_OPENCL_SOURCES}
whisper.h whisper.h
whisper.cpp whisper.cpp
) )
include(DefaultTargetOptions)
target_include_directories(${TARGET} PUBLIC target_include_directories(${TARGET} PUBLIC
. .
) )
if (WHISPER_COREML)
target_link_libraries(${TARGET} PRIVATE whisper.coreml)
endif()
if (MSVC) if (MSVC)
target_link_libraries(${TARGET} PRIVATE ${WHISPER_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE ${WHISPER_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})
@ -325,19 +209,7 @@ if (BUILD_SHARED_LIBS)
target_compile_definitions(${TARGET} PUBLIC target_compile_definitions(${TARGET} PUBLIC
WHISPER_SHARED WHISPER_SHARED
GGML_SHARED
) )
target_compile_definitions(${TARGET} PRIVATE
WHISPER_BUILD
GGML_BUILD
)
endif()
if (GGML_CUDA_SOURCES)
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES OFF)
set_property(TARGET whisper PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
endif() endif()
if (EMSCRIPTEN) if (EMSCRIPTEN)
@ -348,13 +220,9 @@ target_compile_definitions(${TARGET} PUBLIC
${WHISPER_EXTRA_FLAGS} ${WHISPER_EXTRA_FLAGS}
) )
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "whisper.h")
install(TARGETS ${TARGET} install(TARGETS ${TARGET}
LIBRARY DESTINATION lib LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib/static ARCHIVE DESTINATION lib/static
RUNTIME DESTINATION bin
PUBLIC_HEADER DESTINATION include
) )
# #
@ -367,7 +235,7 @@ add_subdirectory(bindings)
# programs, examples and tests # programs, examples and tests
# #
if (WHISPER_BUILD_TESTS AND NOT CMAKE_JS_VERSION) if (WHISPER_BUILD_TESTS)
enable_testing() enable_testing()
add_subdirectory(tests) add_subdirectory(tests)
endif () endif ()

View File

@ -1,6 +1,6 @@
MIT License MIT License
Copyright (c) 2023 Georgi Gerganov Copyright (c) 2022 Georgi Gerganov
Permission is hereby granted, free of charge, to any person obtaining a copy Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal of this software and associated documentation files (the "Software"), to deal

184
Makefile
View File

@ -1,5 +1,3 @@
default: main bench quantize
ifndef UNAME_S ifndef UNAME_S
UNAME_S := $(shell uname -s) UNAME_S := $(shell uname -s)
endif endif
@ -12,9 +10,6 @@ ifndef UNAME_M
UNAME_M := $(shell uname -m) UNAME_M := $(shell uname -m)
endif endif
CCV := $(shell $(CC) --version | head -n 1)
CXXV := $(shell $(CXX) --version | head -n 1)
# Mac OS + Arm can report x86_64 # Mac OS + Arm can report x86_64
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789 # ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
ifeq ($(UNAME_S),Darwin) ifeq ($(UNAME_S),Darwin)
@ -32,16 +27,10 @@ endif
# Compile flags # Compile flags
# #
CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC CFLAGS = -I. -O3 -std=c11 -fPIC
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC CXXFLAGS = -I. -I./examples -O3 -std=c++11 -fPIC
LDFLAGS = LDFLAGS =
# ref: https://github.com/ggerganov/whisper.cpp/issues/37
ifneq ($(wildcard /usr/include/musl/*),)
CFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
CXXFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
endif
# OS specific # OS specific
# TODO: support Windows # TODO: support Windows
ifeq ($(UNAME_S),Linux) ifeq ($(UNAME_S),Linux)
@ -64,13 +53,10 @@ endif
# Architecture specific # Architecture specific
# TODO: probably these flags need to be tweaked on some architectures # TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue # feel free to update the Makefile for your architecture and send a pull request or issue
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686)) ifeq ($(UNAME_M),x86_64)
ifeq ($(UNAME_S),Darwin) ifeq ($(UNAME_S),Darwin)
CFLAGS += -mf16c CFLAGS += -mfma -mf16c
AVX1_M := $(shell sysctl machdep.cpu.features) AVX1_M := $(shell sysctl machdep.cpu.features)
ifneq (,$(findstring FMA,$(AVX1_M)))
CFLAGS += -mfma
endif
ifneq (,$(findstring AVX1.0,$(AVX1_M))) ifneq (,$(findstring AVX1.0,$(AVX1_M)))
CFLAGS += -mavx CFLAGS += -mavx
endif endif
@ -79,6 +65,10 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
CFLAGS += -mavx2 CFLAGS += -mavx2
endif endif
else ifeq ($(UNAME_S),Linux) else ifeq ($(UNAME_S),Linux)
AVX1_M := $(shell grep "avx " /proc/cpuinfo)
ifneq (,$(findstring avx,$(AVX1_M)))
CFLAGS += -mavx
endif
AVX2_M := $(shell grep "avx2 " /proc/cpuinfo) AVX2_M := $(shell grep "avx2 " /proc/cpuinfo)
ifneq (,$(findstring avx2,$(AVX2_M))) ifneq (,$(findstring avx2,$(AVX2_M)))
CFLAGS += -mavx2 CFLAGS += -mavx2
@ -90,17 +80,12 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
F16C_M := $(shell grep "f16c " /proc/cpuinfo) F16C_M := $(shell grep "f16c " /proc/cpuinfo)
ifneq (,$(findstring f16c,$(F16C_M))) ifneq (,$(findstring f16c,$(F16C_M)))
CFLAGS += -mf16c CFLAGS += -mf16c
AVX1_M := $(shell grep "avx " /proc/cpuinfo)
ifneq (,$(findstring avx,$(AVX1_M)))
CFLAGS += -mavx
endif
endif
SSE3_M := $(shell grep "sse3 " /proc/cpuinfo)
ifneq (,$(findstring sse3,$(SSE3_M)))
CFLAGS += -msse3
endif endif
else ifeq ($(UNAME_S),Haiku) else ifeq ($(UNAME_S),Haiku)
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
ifneq (,$(findstring avx,$(AVX1_M)))
CFLAGS += -mavx
endif
AVX2_M := $(shell sysinfo -cpu | grep "AVX2 ") AVX2_M := $(shell sysinfo -cpu | grep "AVX2 ")
ifneq (,$(findstring avx2,$(AVX2_M))) ifneq (,$(findstring avx2,$(AVX2_M)))
CFLAGS += -mavx2 CFLAGS += -mavx2
@ -112,11 +97,6 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
F16C_M := $(shell sysinfo -cpu | grep "F16C ") F16C_M := $(shell sysinfo -cpu | grep "F16C ")
ifneq (,$(findstring f16c,$(F16C_M))) ifneq (,$(findstring f16c,$(F16C_M)))
CFLAGS += -mf16c CFLAGS += -mf16c
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
ifneq (,$(findstring avx,$(AVX1_M)))
CFLAGS += -mavx
endif
endif endif
else else
CFLAGS += -mfma -mf16c -mavx -mavx2 CFLAGS += -mfma -mf16c -mavx -mavx2
@ -125,18 +105,12 @@ endif
ifeq ($(UNAME_M),amd64) ifeq ($(UNAME_M),amd64)
CFLAGS += -mavx -mavx2 -mfma -mf16c CFLAGS += -mavx -mavx2 -mfma -mf16c
endif endif
ifeq ($(UNAME_M),ppc64le)
ifneq ($(filter ppc64%,$(UNAME_M)),)
POWER9_M := $(shell grep "POWER9" /proc/cpuinfo) POWER9_M := $(shell grep "POWER9" /proc/cpuinfo)
ifneq (,$(findstring POWER9,$(POWER9_M))) ifneq (,$(findstring POWER9,$(POWER9_M)))
CFLAGS += -mpower9-vector CFLAGS += -mpower9-vector
endif endif
# Require c++23's std::byteswap for big-endian support.
ifeq ($(UNAME_M),ppc64)
CXXFLAGS += -std=c++23 -DGGML_BIG_ENDIAN
endif
endif endif
ifndef WHISPER_NO_ACCELERATE ifndef WHISPER_NO_ACCELERATE
# Mac M1 - include Accelerate framework # Mac M1 - include Accelerate framework
ifeq ($(UNAME_S),Darwin) ifeq ($(UNAME_S),Darwin)
@ -144,118 +118,49 @@ ifndef WHISPER_NO_ACCELERATE
LDFLAGS += -framework Accelerate LDFLAGS += -framework Accelerate
endif endif
endif endif
ifdef WHISPER_COREML
CXXFLAGS += -DWHISPER_USE_COREML
LDFLAGS += -framework Foundation -framework CoreML
ifdef WHISPER_COREML_ALLOW_FALLBACK
CXXFLAGS += -DWHISPER_COREML_ALLOW_FALLBACK
endif
endif
ifdef WHISPER_OPENBLAS ifdef WHISPER_OPENBLAS
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
LDFLAGS += -lopenblas LDFLAGS += -lopenblas
endif endif
ifdef WHISPER_CUBLAS
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
WHISPER_OBJ += ggml-cuda.o
NVCC = nvcc
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
endif
ifdef WHISPER_CLBLAST
CFLAGS += -DGGML_USE_CLBLAST
LDFLAGS += -lclblast -lOpenCL
WHISPER_OBJ += ggml-opencl.o
ggml-opencl.o: ggml-opencl.c ggml-opencl.h
$(CC) $(CFLAGS) -c $< -o $@
endif
ifdef WHISPER_GPROF ifdef WHISPER_GPROF
CFLAGS += -pg CFLAGS += -pg
CXXFLAGS += -pg CXXFLAGS += -pg
endif endif
ifneq ($(filter aarch64%,$(UNAME_M)),) ifneq ($(filter aarch64%,$(UNAME_M)),)
CFLAGS += -mcpu=native
CXXFLAGS += -mcpu=native
endif endif
ifneq ($(filter armv6%,$(UNAME_M)),) ifneq ($(filter armv6%,$(UNAME_M)),)
# 32-bit Raspberry Pi 1, 2, 3 # Raspberry Pi 1, 2, 3
CFLAGS += -mfpu=neon -mfp16-format=ieee -mno-unaligned-access CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
endif endif
ifneq ($(filter armv7%,$(UNAME_M)),) ifneq ($(filter armv7%,$(UNAME_M)),)
# 32-bit ARM, for example on Armbian or possibly raspbian # Raspberry Pi 4
#CFLAGS += -mfpu=neon -mfp16-format=ieee -funsafe-math-optimizations -mno-unaligned-access CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
#CXXFLAGS += -mfpu=neon -mfp16-format=ieee -funsafe-math-optimizations -mno-unaligned-access
# 64-bit ARM on 32-bit OS, use these (TODO: auto-detect 64-bit)
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -funsafe-math-optimizations -mno-unaligned-access
CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -funsafe-math-optimizations -mno-unaligned-access
endif endif
ifneq ($(filter armv8%,$(UNAME_M)),) ifneq ($(filter armv8%,$(UNAME_M)),)
# Raspberry Pi 4 # Raspberry Pi 4
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -funsafe-math-optimizations -mno-unaligned-access CFLAGS += -mfp16-format=ieee -mno-unaligned-access
CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -funsafe-math-optimizations -mno-unaligned-access
endif endif
# default: main
# Print build information
#
$(info I whisper.cpp build info: )
$(info I UNAME_S: $(UNAME_S))
$(info I UNAME_P: $(UNAME_P))
$(info I UNAME_M: $(UNAME_M))
$(info I CFLAGS: $(CFLAGS))
$(info I CXXFLAGS: $(CXXFLAGS))
$(info I LDFLAGS: $(LDFLAGS))
$(info I CC: $(CCV))
$(info I CXX: $(CXXV))
$(info )
# #
# Build library # Build library
# #
ggml.o: ggml.c ggml.h ggml-cuda.h ggml.o: ggml.c ggml.h
$(CC) $(CFLAGS) -c $< -o $@ $(CC) $(CFLAGS) -c ggml.c -o ggml.o
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h whisper.o: whisper.cpp whisper.h
$(CXX) $(CXXFLAGS) -c $< -o $@ $(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o
ifndef WHISPER_COREML libwhisper.a: ggml.o whisper.o
WHISPER_OBJ += whisper.o $(AR) rcs libwhisper.a ggml.o whisper.o
else
whisper-encoder.o: coreml/whisper-encoder.mm coreml/whisper-encoder.h
$(CXX) -O3 -I . -fobjc-arc -c coreml/whisper-encoder.mm -o whisper-encoder.o
whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-impl.h libwhisper.so: ggml.o whisper.o
$(CXX) -O3 -I . -fobjc-arc -c coreml/whisper-encoder-impl.m -o whisper-encoder-impl.o $(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o whisper.o $(LDFLAGS)
WHISPER_OBJ += whisper.o whisper-encoder.o whisper-encoder-impl.o
endif
libwhisper.a: ggml.o $(WHISPER_OBJ)
$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
libwhisper.so: ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
clean: clean:
rm -f *.o main stream command talk talk-llama bench quantize libwhisper.a libwhisper.so rm -f *.o main stream command talk bench libwhisper.a libwhisper.so
# #
# Examples # Examples
@ -263,30 +168,21 @@ clean:
CC_SDL=`sdl2-config --cflags --libs` CC_SDL=`sdl2-config --cflags --libs`
SRC_COMMON = examples/common.cpp examples/common-ggml.cpp main: examples/main/main.cpp ggml.o whisper.o
SRC_COMMON_SDL = examples/common-sdl.cpp $(CXX) $(CXXFLAGS) examples/main/main.cpp ggml.o whisper.o -o main $(LDFLAGS)
main: examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o main $(LDFLAGS)
./main -h ./main -h
bench: examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) stream: examples/stream/stream.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) -o bench $(LDFLAGS) $(CXX) $(CXXFLAGS) examples/stream/stream.cpp ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
quantize: examples/quantize/quantize.cpp ggml.o $(WHISPER_OBJ) $(SRC_COMMON) command: examples/command/command.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o quantize $(LDFLAGS) $(CXX) $(CXXFLAGS) examples/command/command.cpp ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS)
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS) $(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) bench: examples/bench/bench.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS) $(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
# #
# Audio samples # Audio samples

371
README.md
View File

@ -1,26 +1,20 @@
# whisper.cpp # whisper.cpp
![whisper.cpp](https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg)
[![Actions Status](https://github.com/ggerganov/whisper.cpp/workflows/CI/badge.svg)](https://github.com/ggerganov/whisper.cpp/actions) [![Actions Status](https://github.com/ggerganov/whisper.cpp/workflows/CI/badge.svg)](https://github.com/ggerganov/whisper.cpp/actions)
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Beta: [v1.4.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.2) / Stable: [v1.2.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126) [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
- Plain C/C++ implementation without dependencies - Plain C/C++ implementation without dependencies
- Apple silicon first-class citizen - optimized via ARM NEON, Accelerate framework and [Core ML](https://github.com/ggerganov/whisper.cpp#core-ml-support) - Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework
- AVX intrinsics support for x86 architectures - AVX intrinsics support for x86 architectures
- VSX intrinsics support for POWER architectures
- Mixed F16 / F32 precision - Mixed F16 / F32 precision
- [4-bit and 5-bit integer quantization support](https://github.com/ggerganov/whisper.cpp#quantization) - Low memory usage (Flash Attention + Flash Forward)
- Low memory usage (Flash Attention)
- Zero memory allocations at runtime - Zero memory allocations at runtime
- Runs on the CPU - Runs on the CPU
- [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
- [Partial OpenCL GPU support via CLBlast](https://github.com/ggerganov/whisper.cpp#opencl-gpu-support-via-clblast)
- [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h) - [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h)
Supported platforms: Supported platforms:
@ -63,16 +57,12 @@ the Accelerate framework utilizes the special-purpose AMX coprocessor available
## Quick start ## Quick start
First clone the repository. First, download one of the Whisper models converted in [ggml format](models). For example:
Then, download one of the Whisper models converted in [ggml format](models). For example:
```bash ```bash
bash ./models/download-ggml-model.sh base.en bash ./models/download-ggml-model.sh base.en
``` ```
If you wish to convert the Whisper models to ggml format yourself, instructions are in [models/README.md](models/README.md).
Now build the [main](examples/main) example and transcribe an audio file like this: Now build the [main](examples/main) example and transcribe an audio file like this:
```bash ```bash
@ -80,7 +70,7 @@ Now build the [main](examples/main) example and transcribe an audio file like th
make make
# transcribe an audio file # transcribe an audio file
./main -f samples/jfk.wav ./main -f input.wav
``` ```
--- ---
@ -98,38 +88,27 @@ c++ -I. -I./examples -O3 -std=c++11 -pthread examples/main/main.cpp whisper.o gg
usage: ./main [options] file0.wav file1.wav ... usage: ./main [options] file0.wav file1.wav ...
options: options:
-h, --help [default] show this help message and exit -h, --help [default] show this help message and exit
-t N, --threads N [4 ] number of threads to use during computation -t N, --threads N [4 ] number of threads to use during computation
-p N, --processors N [1 ] number of processors to use during computation -p N, --processors N [1 ] number of processors to use during computation
-ot N, --offset-t N [0 ] time offset in milliseconds -ot N, --offset-t N [0 ] time offset in milliseconds
-on N, --offset-n N [0 ] segment index offset -on N, --offset-n N [0 ] segment index offset
-d N, --duration N [0 ] duration of audio to process in milliseconds -d N, --duration N [0 ] duration of audio to process in milliseconds
-mc N, --max-context N [-1 ] maximum number of text context tokens to store -mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters -ml N, --max-len N [0 ] maximum segment length in characters
-bo N, --best-of N [5 ] number of best candidates to keep -wt N, --word-thold N [0.01 ] word timestamp probability threshold
-bs N, --beam-size N [-1 ] beam size for beam search -su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
-wt N, --word-thold N [0.01 ] word timestamp probability threshold -tr, --translate [false ] translate from source language to english
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail -otxt, --output-txt [false ] output result in a text file
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail -ovtt, --output-vtt [false ] output result in a vtt file
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy) -osrt, --output-srt [false ] output result in a srt file
-tr, --translate [false ] translate from source language to english -owts, --output-words [false ] output script for generating karaoke video
-di, --diarize [false ] stereo audio diarization -ps, --print-special [false ] print special tokens
-nf, --no-fallback [false ] do not use temperature fallback while decoding -pc, --print-colors [false ] print colors
-otxt, --output-txt [false ] output result in a text file -nt, --no-timestamps [true ] do not print timestamps
-ovtt, --output-vtt [false ] output result in a vtt file -l LANG, --language LANG [en ] spoken language
-osrt, --output-srt [false ] output result in a srt file -m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-owts, --output-words [false ] output script for generating karaoke video -f FNAME, --file FNAME [ ] input WAV file path
-ocsv, --output-csv [false ] output result in a CSV file
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [true ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
--prompt PROMPT [ ] initial prompt
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
bash ./models/download-ggml-model.sh base.en bash ./models/download-ggml-model.sh base.en
Downloading ggml model base.en ... Downloading ggml model base.en ...
@ -148,8 +127,7 @@ Running base.en on all samples in ./samples ...
[+] Running base.en on samples/jfk.wav ... (run 'ffplay samples/jfk.wav' to listen) [+] Running base.en on samples/jfk.wav ... (run 'ffplay samples/jfk.wav' to listen)
---------------------------------------------- ----------------------------------------------
whisper_init_from_file: loading model from 'models/ggml-base.en.bin' whisper_model_load: loading model from 'models/ggml-base.en.bin'
whisper_model_load: loading model
whisper_model_load: n_vocab = 51864 whisper_model_load: n_vocab = 51864
whisper_model_load: n_audio_ctx = 1500 whisper_model_load: n_audio_ctx = 1500
whisper_model_load: n_audio_state = 512 whisper_model_load: n_audio_state = 512
@ -162,14 +140,13 @@ whisper_model_load: n_text_layer = 6
whisper_model_load: n_mels = 80 whisper_model_load: n_mels = 80
whisper_model_load: f16 = 1 whisper_model_load: f16 = 1
whisper_model_load: type = 2 whisper_model_load: type = 2
whisper_model_load: mem required = 215.00 MB (+ 6.00 MB per decoder)
whisper_model_load: kv self size = 5.25 MB
whisper_model_load: kv cross size = 17.58 MB
whisper_model_load: adding 1607 extra tokens whisper_model_load: adding 1607 extra tokens
whisper_model_load: model ctx = 140.60 MB whisper_model_load: mem_required = 506.00 MB
whisper_model_load: ggml ctx size = 140.60 MB
whisper_model_load: memory size = 22.83 MB
whisper_model_load: model size = 140.54 MB whisper_model_load: model size = 140.54 MB
system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ... main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
@ -177,13 +154,12 @@ main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 proc
[00:00:00.000 --> 00:00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country. [00:00:00.000 --> 00:00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
whisper_print_timings: fallbacks = 0 p / 0 h whisper_print_timings: load time = 105.91 ms
whisper_print_timings: load time = 113.81 ms whisper_print_timings: mel time = 24.62 ms
whisper_print_timings: mel time = 15.40 ms whisper_print_timings: sample time = 3.63 ms
whisper_print_timings: sample time = 11.58 ms / 27 runs ( 0.43 ms per run) whisper_print_timings: encode time = 324.71 ms / 54.12 ms per layer
whisper_print_timings: encode time = 266.60 ms / 1 runs ( 266.60 ms per run) whisper_print_timings: decode time = 83.58 ms / 13.93 ms per layer
whisper_print_timings: decode time = 66.11 ms / 27 runs ( 2.45 ms per run) whisper_print_timings: total time = 542.81 ms
whisper_print_timings: total time = 476.31 ms
``` ```
The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`. The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
@ -226,128 +202,26 @@ make large
| Model | Disk | Mem | SHA | | Model | Disk | Mem | SHA |
| --- | --- | --- | --- | | --- | --- | --- | --- |
| tiny | 75 MB | ~125 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` | | tiny | 75 MB | ~390 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
| base | 142 MB | ~210 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` | | base | 142 MB | ~500 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
| small | 466 MB | ~600 MB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` | | small | 466 MB | ~1.0 GB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
| medium | 1.5 GB | ~1.7 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` | | medium | 1.5 GB | ~2.6 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
| large | 2.9 GB | ~3.3 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` | | large | 2.9 GB | ~4.7 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
## Quantization
`whisper.cpp` supports integer quantization of the Whisper `ggml` models.
Quantized models require less memory and disk space and depending on the hardware can be processed more efficiently.
Here are the steps for creating and using a quantized model:
```bash
# quantize a model with Q5_0 method
make quantize
./quantize models/ggml-base.en.bin models/ggml-base.en-q5_0.bin q5_0
# run the examples as usual, specifying the quantized model file
./main -m models/ggml-base.en-q5_0.bin ./samples/gb0.wav
```
## Core ML support
On Apple Silicon devices, the Encoder inference can be executed on the Apple Neural Engine (ANE) via Core ML. This can result in significant
speed-up - more than x3 faster compared with CPU-only execution. Here are the instructions for generating a Core ML model and using it with `whisper.cpp`:
- Install Python dependencies needed for the creation of the Core ML model:
```bash
pip install ane_transformers
pip install openai-whisper
pip install coremltools
```
- To ensure `coremltools` operates correctly, please confirm that [Xcode](https://developer.apple.com/xcode/) is installed and execute `xcode-select --install` to install the command-line tools.
- Python 3.10 is recommended.
- [OPTIONAL] It is recommended to utilize a Python version management system, such as [Miniconda](https://docs.conda.io/en/latest/miniconda.html) for this step:
- To create an environment, use: `conda create -n py310-whisper python=3.10 -y`
- To activate the environment, use: `conda activate py310-whisper`
- Generate a Core ML model. For example, to generate a `base.en` model, use:
```bash
./models/generate-coreml-model.sh base.en
```
This will generate the folder `models/ggml-base.en-encoder.mlmodelc`
- Build `whisper.cpp` with Core ML support:
```bash
# using Makefile
make clean
WHISPER_COREML=1 make -j
# using CMake
cd build
cmake -DWHISPER_COREML=1 ..
```
- Run the examples as usual. For example:
```bash
./main -m models/ggml-base.en.bin -f samples/jfk.wav
...
whisper_init_state: loading Core ML model from 'models/ggml-base.en-encoder.mlmodelc'
whisper_init_state: first run on a device may take a while ...
whisper_init_state: Core ML model loaded
system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | COREML = 1 |
...
```
The first run on a device is slow, since the ANE service compiles the Core ML model to some device-specific format.
Next runs are faster.
For more information about the Core ML implementation please refer to PR [#566](https://github.com/ggerganov/whisper.cpp/pull/566).
## NVIDIA GPU support via cuBLAS
With NVIDIA cards, the Encoder processing can be offloaded to the GPU to a large extend through cuBLAS.
First, make sure you have installed `cuda`: https://developer.nvidia.com/cuda-downloads
Now build `whisper.cpp` with cuBLAS support:
```
make clean
WHISPER_CUBLAS=1 make -j
```
## OpenCL GPU support via CLBlast
For cards and integrated GPUs that support OpenCL, the Encoder processing can be largely offloaded to the GPU through CLBlast. This is especially useful for users with AMD APU's or low end devices for up to ~2x speedup.
First, make sure you have installed `CLBlast` for your OS or Distribution: https://github.com/CNugteren/CLBlast
Now build `whisper.cpp` with CLBlast support:
```
Makefile:
cd whisper.cpp
make clean
WHISPER_CLBLAST=1 make -j
CMake:
cd whisper.cpp ; mkdir build ; cd build
cmake -DWHISPER_CLBLAST=ON ..
make clean
make -j
cp bin/* ../
```
Run all the examples as usual.
## Limitations ## Limitations
- Inference only - Inference only
- No GPU support
- Very basic greedy sampling scheme - always pick up the token with highest probability.
This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274)
from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure
to run the python code with the following parameters:
```
whisper --best_of None --beam_size None ...
```
In the future, `whisper.cpp` will support more sampling strategies.
## Another example ## Another example
@ -360,8 +234,7 @@ in about half a minute on a MacBook M1 Pro, using `medium.en` model:
```java ```java
$ ./main -m models/ggml-medium.en.bin -f samples/gb1.wav -t 8 $ ./main -m models/ggml-medium.en.bin -f samples/gb1.wav -t 8
whisper_init_from_file: loading model from 'models/ggml-medium.en.bin' whisper_model_load: loading model from 'models/ggml-medium.en.bin'
whisper_model_load: loading model
whisper_model_load: n_vocab = 51864 whisper_model_load: n_vocab = 51864
whisper_model_load: n_audio_ctx = 1500 whisper_model_load: n_audio_ctx = 1500
whisper_model_load: n_audio_state = 1024 whisper_model_load: n_audio_state = 1024
@ -374,71 +247,65 @@ whisper_model_load: n_text_layer = 24
whisper_model_load: n_mels = 80 whisper_model_load: n_mels = 80
whisper_model_load: f16 = 1 whisper_model_load: f16 = 1
whisper_model_load: type = 4 whisper_model_load: type = 4
whisper_model_load: mem required = 1720.00 MB (+ 43.00 MB per decoder) whisper_model_load: mem_required = 2610.00 MB
whisper_model_load: kv self size = 42.00 MB
whisper_model_load: kv cross size = 140.62 MB
whisper_model_load: adding 1607 extra tokens whisper_model_load: adding 1607 extra tokens
whisper_model_load: model ctx = 1462.35 MB whisper_model_load: ggml ctx size = 1644.97 MB
whisper_model_load: model size = 1462.12 MB whisper_model_load: memory size = 182.62 MB
whisper_model_load: model size = 1462.12 MB
system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | main: processing 'samples/gb1.wav' (3179750 samples, 198.7 sec), 8 threads, lang = en, task = transcribe, timestamps = 1 ...
main: processing 'samples/gb1.wav' (3179750 samples, 198.7 sec), 8 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ... [00:00.000 --> 00:08.000] My fellow Americans, this day has brought terrible news and great sadness to our country.
[00:08.000 --> 00:17.000] At nine o'clock this morning, Mission Control in Houston lost contact with our Space Shuttle Columbia.
[00:17.000 --> 00:23.000] A short time later, debris was seen falling from the skies above Texas.
[00:23.000 --> 00:29.000] The Columbia's lost. There are no survivors.
[00:29.000 --> 00:32.000] On board was a crew of seven.
[00:32.000 --> 00:39.000] Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark,
[00:39.000 --> 00:48.000] Captain David Brown, Commander William McCool, Dr. Kultna Shavla, and Ilan Ramon,
[00:48.000 --> 00:52.000] a colonel in the Israeli Air Force.
[00:52.000 --> 00:58.000] These men and women assumed great risk in the service to all humanity.
[00:58.000 --> 01:03.000] In an age when space flight has come to seem almost routine,
[01:03.000 --> 01:07.000] it is easy to overlook the dangers of travel by rocket
[01:07.000 --> 01:12.000] and the difficulties of navigating the fierce outer atmosphere of the Earth.
[01:12.000 --> 01:18.000] These astronauts knew the dangers, and they faced them willingly,
[01:18.000 --> 01:23.000] knowing they had a high and noble purpose in life.
[01:23.000 --> 01:31.000] Because of their courage and daring and idealism, we will miss them all the more.
[01:31.000 --> 01:36.000] All Americans today are thinking as well of the families of these men and women
[01:36.000 --> 01:40.000] who have been given this sudden shock and grief.
[01:40.000 --> 01:45.000] You're not alone. Our entire nation grieves with you,
[01:45.000 --> 01:52.000] and those you love will always have the respect and gratitude of this country.
[01:52.000 --> 01:56.000] The cause in which they died will continue.
[01:56.000 --> 02:04.000] Mankind is led into the darkness beyond our world by the inspiration of discovery
[02:04.000 --> 02:11.000] and the longing to understand. Our journey into space will go on.
[02:11.000 --> 02:16.000] In the skies today, we saw destruction and tragedy.
[02:16.000 --> 02:22.000] Yet farther than we can see, there is comfort and hope.
[02:22.000 --> 02:29.000] In the words of the prophet Isaiah, "Lift your eyes and look to the heavens
[02:29.000 --> 02:35.000] who created all these. He who brings out the starry hosts one by one
[02:35.000 --> 02:39.000] and calls them each by name."
[02:39.000 --> 02:46.000] Because of His great power and mighty strength, not one of them is missing.
[02:46.000 --> 02:55.000] The same Creator who names the stars also knows the names of the seven souls we mourn today.
[02:55.000 --> 03:01.000] The crew of the shuttle Columbia did not return safely to earth,
[03:01.000 --> 03:05.000] yet we can pray that all are safely home.
[03:05.000 --> 03:13.000] May God bless the grieving families, and may God continue to bless America.
[03:13.000 --> 03:41.000] Audio
[00:00:00.000 --> 00:00:08.000] My fellow Americans, this day has brought terrible news and great sadness to our country. whisper_print_timings: load time = 575.92 ms
[00:00:08.000 --> 00:00:17.000] At nine o'clock this morning, Mission Control in Houston lost contact with our Space Shuttle Columbia. whisper_print_timings: mel time = 230.60 ms
[00:00:17.000 --> 00:00:23.000] A short time later, debris was seen falling from the skies above Texas. whisper_print_timings: sample time = 73.19 ms
[00:00:23.000 --> 00:00:29.000] The Columbia's lost. There are no survivors. whisper_print_timings: encode time = 19552.61 ms / 814.69 ms per layer
[00:00:29.000 --> 00:00:32.000] On board was a crew of seven. whisper_print_timings: decode time = 13249.96 ms / 552.08 ms per layer
[00:00:32.000 --> 00:00:39.000] Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark, whisper_print_timings: total time = 33686.27 ms
[00:00:39.000 --> 00:00:48.000] Captain David Brown, Commander William McCool, Dr. Kultna Shavla, and Ilan Ramon,
[00:00:48.000 --> 00:00:52.000] a colonel in the Israeli Air Force.
[00:00:52.000 --> 00:00:58.000] These men and women assumed great risk in the service to all humanity.
[00:00:58.000 --> 00:01:03.000] In an age when space flight has come to seem almost routine,
[00:01:03.000 --> 00:01:07.000] it is easy to overlook the dangers of travel by rocket
[00:01:07.000 --> 00:01:12.000] and the difficulties of navigating the fierce outer atmosphere of the Earth.
[00:01:12.000 --> 00:01:18.000] These astronauts knew the dangers, and they faced them willingly,
[00:01:18.000 --> 00:01:23.000] knowing they had a high and noble purpose in life.
[00:01:23.000 --> 00:01:31.000] Because of their courage and daring and idealism, we will miss them all the more.
[00:01:31.000 --> 00:01:36.000] All Americans today are thinking as well of the families of these men and women
[00:01:36.000 --> 00:01:40.000] who have been given this sudden shock and grief.
[00:01:40.000 --> 00:01:45.000] You're not alone. Our entire nation grieves with you,
[00:01:45.000 --> 00:01:52.000] and those you love will always have the respect and gratitude of this country.
[00:01:52.000 --> 00:01:56.000] The cause in which they died will continue.
[00:01:56.000 --> 00:02:04.000] Mankind is led into the darkness beyond our world by the inspiration of discovery
[00:02:04.000 --> 00:02:11.000] and the longing to understand. Our journey into space will go on.
[00:02:11.000 --> 00:02:16.000] In the skies today, we saw destruction and tragedy.
[00:02:16.000 --> 00:02:22.000] Yet farther than we can see, there is comfort and hope.
[00:02:22.000 --> 00:02:29.000] In the words of the prophet Isaiah, "Lift your eyes and look to the heavens
[00:02:29.000 --> 00:02:35.000] who created all these. He who brings out the starry hosts one by one
[00:02:35.000 --> 00:02:39.000] and calls them each by name."
[00:02:39.000 --> 00:02:46.000] Because of His great power and mighty strength, not one of them is missing.
[00:02:46.000 --> 00:02:55.000] The same Creator who names the stars also knows the names of the seven souls we mourn today.
[00:02:55.000 --> 00:03:01.000] The crew of the shuttle Columbia did not return safely to earth,
[00:03:01.000 --> 00:03:05.000] yet we can pray that all are safely home.
[00:03:05.000 --> 00:03:13.000] May God bless the grieving families, and may God continue to bless America.
[00:03:13.000 --> 00:03:19.000] [Silence]
whisper_print_timings: fallbacks = 1 p / 0 h
whisper_print_timings: load time = 569.03 ms
whisper_print_timings: mel time = 146.85 ms
whisper_print_timings: sample time = 238.66 ms / 553 runs ( 0.43 ms per run)
whisper_print_timings: encode time = 18665.10 ms / 9 runs ( 2073.90 ms per run)
whisper_print_timings: decode time = 13090.93 ms / 549 runs ( 23.85 ms per run)
whisper_print_timings: total time = 32733.52 ms
``` ```
</details> </details>
## Real-time audio input example ## Real-time audio input example
This is a naive example of performing real-time inference on audio from your microphone. This is a naive example of performing real-time inference on audio from your microphone.
The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continuously. The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continously.
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10). More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
```java ```java
make stream
./stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000 ./stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
``` ```
@ -449,22 +316,18 @@ https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a
Adding the `--print-colors` argument will print the transcribed text using an experimental color coding strategy Adding the `--print-colors` argument will print the transcribed text using an experimental color coding strategy
to highlight words with high or low confidence: to highlight words with high or low confidence:
```java
./main -m models/ggml-base.en.bin -f samples/gb0.wav --print-colors
```
<img width="965" alt="image" src="https://user-images.githubusercontent.com/1991296/197356445-311c8643-9397-4e5e-b46e-0b4b4daa2530.png"> <img width="965" alt="image" src="https://user-images.githubusercontent.com/1991296/197356445-311c8643-9397-4e5e-b46e-0b4b4daa2530.png">
## Controlling the length of the generated text segments (experimental) ## Controlling the length of the generated text segments (experimental)
For example, to limit the line length to a maximum of 16 characters, simply add `-ml 16`: For example, to limit the line length to a maximum of 16 characters, simply add `-ml 16`:
```java ```java
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 16 ./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 16
whisper_model_load: loading model from './models/ggml-base.en.bin' whisper_model_load: loading model from './models/ggml-base.en.bin'
... ...
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ... main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
@ -488,11 +351,11 @@ The `--max-len` argument can be used to obtain word-level timestamps. Simply use
whisper_model_load: loading model from './models/ggml-base.en.bin' whisper_model_load: loading model from './models/ggml-base.en.bin'
... ...
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ... main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
[00:00:00.000 --> 00:00:00.320] [00:00:00.000 --> 00:00:00.320]
[00:00:00.320 --> 00:00:00.370] And [00:00:00.320 --> 00:00:00.370] And
[00:00:00.370 --> 00:00:00.690] so [00:00:00.370 --> 00:00:00.690] so
[00:00:00.690 --> 00:00:00.850] my [00:00:00.690 --> 00:00:00.850] my
@ -558,19 +421,6 @@ https://user-images.githubusercontent.com/1991296/199337538-b7b0c7a3-2753-4a88-a
--- ---
## Video comparison of different models
Use the [extra/bench-wts.sh](https://github.com/ggerganov/whisper.cpp/blob/master/extra/bench-wts.sh) script to generate a video in the following format:
```java
./extra/bench-wts.sh samples/jfk.wav
ffplay ./samples/jfk.wav.all.mp4
```
https://user-images.githubusercontent.com/1991296/223206245-2d36d903-cf8e-4f09-8c3b-eb9f9c39d6fc.mp4
---
## Benchmarks ## Benchmarks
In order to have an objective comparison of the performance of the inference across different system configurations, In order to have an objective comparison of the performance of the inference across different system configurations,
@ -591,7 +441,7 @@ The original models are converted to a custom binary format. This allows to pack
You can download the converted models using the [models/download-ggml-model.sh](models/download-ggml-model.sh) script You can download the converted models using the [models/download-ggml-model.sh](models/download-ggml-model.sh) script
or manually from here: or manually from here:
- https://huggingface.co/ggerganov/whisper.cpp - https://huggingface.co/datasets/ggerganov/whisper.cpp
- https://ggml.ggerganov.com - https://ggml.ggerganov.com
For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or the README For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or the README
@ -601,19 +451,9 @@ in [models](models).
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310) - [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
- [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309) - [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
- React Native (iOS / Android): [whisper.rn](https://github.com/mybigday/whisper.rn)
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312) - [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
- [X] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggerganov/whisper.cpp/discussions/507)
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313) - [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
- [exPHAT/SwiftWhisper](https://github.com/exPHAT/SwiftWhisper) - [ ] Python: soon | [WIP](https://github.com/ggerganov/whisper.cpp/issues/9)
- [X] .NET: | [#422](https://github.com/ggerganov/whisper.cpp/discussions/422)
- [sandrohanea/whisper.net](https://github.com/sandrohanea/whisper.net)
- [NickDarvey/whisper](https://github.com/NickDarvey/whisper)
- [X] Python: | [#9](https://github.com/ggerganov/whisper.cpp/issues/9)
- [stlukey/whispercpp.py](https://github.com/stlukey/whispercpp.py) (Cython)
- [aarnphm/whispercpp](https://github.com/aarnphm/whispercpp) (Pybind11)
- [X] R: [bnosac/audio.whisper](https://github.com/bnosac/audio.whisper)
- [X] Unity: [macoron/whisper.unity](https://github.com/Macoron/whisper.unity)
## Examples ## Examples
@ -627,7 +467,6 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
| [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture | | [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic | | [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot | | [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
| [talk-llama](examples/talk-llama) | | Talk with a LLaMA bot |
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp | | [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
| [whisper.swiftui](examples/whisper.swiftui) | | SwiftUI iOS / macOS application using whisper.cpp | | [whisper.swiftui](examples/whisper.swiftui) | | SwiftUI iOS / macOS application using whisper.cpp |
| [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp | | [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |

View File

@ -1,2 +1,3 @@
build build
models models
go.sum

View File

@ -1,27 +1,28 @@
BUILD_DIR := build CMAKE := $(shell which cmake)
MODELS_DIR := models BUILD_DIR := "build"
MODELS_DIR := "models"
EXAMPLES_DIR := $(wildcard examples/*) EXAMPLES_DIR := $(wildcard examples/*)
INCLUDE_PATH := $(abspath ../..) C_INCLUDE_PATH := "../.."
LIBRARY_PATH := $(abspath ../..)
all: clean whisper examples all: clean whisper examples
whisper: mkdir whisper: mkdir
@echo Build whisper @echo Build whisper
@${MAKE} -C ../.. libwhisper.a @${CMAKE} -S ../.. -B ${BUILD_DIR} -D BUILD_SHARED_LIBS=off -D WHISPER_NO_AVX2=on
@${CMAKE} --build ${BUILD_DIR} --target whisper
test: model-small whisper modtidy test: model-small whisper modtidy
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v . @go test -v .
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/... @go test -v ./pkg/whisper/...
examples: $(EXAMPLES_DIR) examples: $(EXAMPLES_DIR)
model-small: mkdir examples/go-model-download model-small: mkdir examples/go-model-download
@${BUILD_DIR}/go-model-download -out models ggml-small.en.bin @${BUILD_DIR}/go-model-download -out models small.en
$(EXAMPLES_DIR): mkdir whisper modtidy $(EXAMPLES_DIR): mkdir whisper modtidy
@echo Build example $(notdir $@) @echo Build example $(notdir $@)
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@ @go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
mkdir: mkdir:
@echo Mkdir ${BUILD_DIR} @echo Mkdir ${BUILD_DIR}

View File

@ -74,27 +74,4 @@ And you can then test a model against samples with the following command:
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav ./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
``` ```
## Using the bindings
To use the bindings in your own software,
1. Import `github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper` (or `github.com/ggerganov/whisper.cpp/bindings/go` into your package;
2. Compile `libwhisper.a` (you can use `make whisper` in the `bindings/go` directory);
3. Link your go binary against whisper by setting the environment variables `C_INCLUDE_PATH` and `LIBRARY_PATH`
to point to the `whisper.h` file directory and `libwhisper.a` file directory respectively.
Look at the `Makefile` in the `bindings/go` directory for an example.
The API Documentation:
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper
Getting help:
* Follow the discussion for the go bindings [here](https://github.com/ggerganov/whisper.cpp/discussions/312)
## License
The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.

View File

@ -17,14 +17,15 @@ import (
// CONSTANTS // CONSTANTS
const ( const (
srcUrl = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main" // The location of the models srcUrl = "https://huggingface.co/" // The location of the models
srcExt = ".bin" // Filename extension srcPathPrefix = "/datasets/ggerganov/whisper.cpp/resolve/main/ggml" // Filename prefix
bufSize = 1024 * 64 // Size of the buffer used for downloading the model srcExt = ".bin" // Filename extension
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
) )
var ( var (
// The models which will be downloaded, if no model is specified as an argument // The models which will be downloaded, if no model is specified as an argument
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large"} modelNames = []string{"tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium", "large-v1", "large"}
) )
var ( var (
@ -122,14 +123,11 @@ func GetModels() []string {
// URLForModel returns the URL for the given model on huggingface.co // URLForModel returns the URL for the given model on huggingface.co
func URLForModel(model string) (string, error) { func URLForModel(model string) (string, error) {
if filepath.Ext(model) != srcExt {
model += srcExt
}
url, err := url.Parse(srcUrl) url, err := url.Parse(srcUrl)
if err != nil { if err != nil {
return "", err return "", err
} else { } else {
url.Path = filepath.Join(url.Path, model) url.Path = srcPathPrefix + "-" + model + srcExt
} }
return url.String(), nil return url.String(), nil
} }

View File

@ -1,22 +0,0 @@
package main
import "fmt"
///////////////////////////////////////////////////////////////////////////////
// CONSTANTS
const (
Reset = "\033[0m"
RGBPrefix = "\033[38;5;" // followed by RGB values in decimal format separated by colons
RGBSuffix = "m"
)
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// Colorize text with RGB values, from 0 to 23
func Colorize(text string, v int) string {
// https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit
// Grayscale colors are in the range 232-255
return RGBPrefix + fmt.Sprint(v%24+232) + RGBSuffix + text + Reset
}

View File

@ -2,12 +2,6 @@ package main
import ( import (
"flag" "flag"
"fmt"
"strings"
"time"
// Packages
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
) )
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@ -48,26 +42,6 @@ func (flags *Flags) GetLanguage() string {
return flags.Lookup("language").Value.String() return flags.Lookup("language").Value.String()
} }
func (flags *Flags) IsTranslate() bool {
return flags.Lookup("translate").Value.(flag.Getter).Get().(bool)
}
func (flags *Flags) GetOffset() time.Duration {
return flags.Lookup("offset").Value.(flag.Getter).Get().(time.Duration)
}
func (flags *Flags) GetDuration() time.Duration {
return flags.Lookup("duration").Value.(flag.Getter).Get().(time.Duration)
}
func (flags *Flags) GetThreads() uint {
return flags.Lookup("threads").Value.(flag.Getter).Get().(uint)
}
func (flags *Flags) GetOut() string {
return strings.ToLower(flags.Lookup("out").Value.String())
}
func (flags *Flags) IsSpeedup() bool { func (flags *Flags) IsSpeedup() bool {
return flags.Lookup("speedup").Value.String() == "true" return flags.Lookup("speedup").Value.String() == "true"
} }
@ -76,81 +50,12 @@ func (flags *Flags) IsTokens() bool {
return flags.Lookup("tokens").Value.String() == "true" return flags.Lookup("tokens").Value.String() == "true"
} }
func (flags *Flags) IsColorize() bool {
return flags.Lookup("colorize").Value.String() == "true"
}
func (flags *Flags) GetMaxLen() uint {
return flags.Lookup("max-len").Value.(flag.Getter).Get().(uint)
}
func (flags *Flags) GetMaxTokens() uint {
return flags.Lookup("max-tokens").Value.(flag.Getter).Get().(uint)
}
func (flags *Flags) GetWordThreshold() float32 {
return float32(flags.Lookup("word-thold").Value.(flag.Getter).Get().(float64))
}
func (flags *Flags) SetParams(context whisper.Context) error {
if lang := flags.GetLanguage(); lang != "" && lang != "auto" {
fmt.Fprintf(flags.Output(), "Setting language to %q\n", lang)
if err := context.SetLanguage(lang); err != nil {
return err
}
}
if flags.IsTranslate() && context.IsMultilingual() {
fmt.Fprintf(flags.Output(), "Setting translate to true\n")
context.SetTranslate(true)
}
if offset := flags.GetOffset(); offset != 0 {
fmt.Fprintf(flags.Output(), "Setting offset to %v\n", offset)
context.SetOffset(offset)
}
if duration := flags.GetDuration(); duration != 0 {
fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration)
context.SetDuration(duration)
}
if flags.IsSpeedup() {
fmt.Fprintf(flags.Output(), "Setting speedup to true\n")
context.SetSpeedup(true)
}
if threads := flags.GetThreads(); threads != 0 {
fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads)
context.SetThreads(threads)
}
if max_len := flags.GetMaxLen(); max_len != 0 {
fmt.Fprintf(flags.Output(), "Setting max_segment_length to %d\n", max_len)
context.SetMaxSegmentLength(max_len)
}
if max_tokens := flags.GetMaxTokens(); max_tokens != 0 {
fmt.Fprintf(flags.Output(), "Setting max_tokens to %d\n", max_tokens)
context.SetMaxTokensPerSegment(max_tokens)
}
if word_threshold := flags.GetWordThreshold(); word_threshold != 0 {
fmt.Fprintf(flags.Output(), "Setting word_threshold to %f\n", word_threshold)
context.SetTokenThreshold(word_threshold)
}
// Return success
return nil
}
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS // PRIVATE METHODS
func registerFlags(flag *Flags) { func registerFlags(flag *Flags) {
flag.String("model", "", "Path to the model file") flag.String("model", "", "Path to the model file")
flag.String("language", "", "Spoken language") flag.String("language", "", "Language")
flag.Bool("translate", false, "Translate from source language to english")
flag.Duration("offset", 0, "Time offset")
flag.Duration("duration", 0, "Duration of audio to process")
flag.Uint("threads", 0, "Number of threads to use")
flag.Bool("speedup", false, "Enable speedup") flag.Bool("speedup", false, "Enable speedup")
flag.Uint("max-len", 0, "Maximum segment length in characters")
flag.Uint("max-tokens", 0, "Maximum tokens per segment")
flag.Float64("word-thold", 0, "Maximum segment score")
flag.Bool("tokens", false, "Display tokens") flag.Bool("tokens", false, "Display tokens")
flag.Bool("colorize", false, "Colorize tokens")
flag.String("out", "", "Output format (srt, none or leave as empty string)")
} }

View File

@ -35,7 +35,8 @@ func main() {
// Process files // Process files
for _, filename := range flags.Args() { for _, filename := range flags.Args() {
if err := Process(model, filename, flags); err != nil { fmt.Println("Processing", filename)
if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil {
fmt.Fprintln(os.Stderr, err) fmt.Fprintln(os.Stderr, err)
continue continue
} }

View File

@ -11,7 +11,7 @@ import (
wav "github.com/go-audio/wav" wav "github.com/go-audio/wav"
) )
func Process(model whisper.Model, path string, flags *Flags) error { func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error {
var data []float32 var data []float32
// Create processing context // Create processing context
@ -20,22 +20,14 @@ func Process(model whisper.Model, path string, flags *Flags) error {
return err return err
} }
// Set the parameters
if err := flags.SetParams(context); err != nil {
return err
}
fmt.Printf("\n%s\n", context.SystemInfo())
// Open the file // Open the file
fmt.Fprintf(flags.Output(), "Loading %q\n", path)
fh, err := os.Open(path) fh, err := os.Open(path)
if err != nil { if err != nil {
return err return err
} }
defer fh.Close() defer fh.Close()
// Decode the WAV file - load the full buffer // Decode the WAV file
dec := wav.NewDecoder(fh) dec := wav.NewDecoder(fh)
if buf, err := dec.FullPCMBuffer(); err != nil { if buf, err := dec.FullPCMBuffer(); err != nil {
return err return err
@ -47,86 +39,42 @@ func Process(model whisper.Model, path string, flags *Flags) error {
data = buf.AsFloat32Buffer().Data data = buf.AsFloat32Buffer().Data
} }
// Segment callback when -tokens is specified // Set the parameters
var cb whisper.SegmentCallback var cb whisper.SegmentCallback
if flags.IsTokens() { if lang != "" {
if err := context.SetLanguage(lang); err != nil {
return err
}
}
if speedup {
context.SetSpeedup(true)
}
if tokens {
cb = func(segment whisper.Segment) { cb = func(segment whisper.Segment) {
fmt.Fprintf(flags.Output(), "%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond)) fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
for _, token := range segment.Tokens { for _, token := range segment.Tokens {
if flags.IsColorize() && context.IsText(token) { fmt.Printf("%q ", token.Text)
fmt.Fprint(flags.Output(), Colorize(token.Text, int(token.P*24.0)), " ")
} else {
fmt.Fprint(flags.Output(), token.Text, " ")
}
} }
fmt.Fprintln(flags.Output(), "") fmt.Println("")
fmt.Fprintln(flags.Output(), "")
} }
} }
// Process the data // Process the data
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
context.ResetTimings()
if err := context.Process(data, cb); err != nil { if err := context.Process(data, cb); err != nil {
return err return err
} }
context.PrintTimings()
// Print out the results // Print out the results
switch {
case flags.GetOut() == "srt":
return OutputSRT(os.Stdout, context)
case flags.GetOut() == "none":
return nil
default:
return Output(os.Stdout, context, flags.IsColorize())
}
}
// Output text as SRT file
func OutputSRT(w io.Writer, context whisper.Context) error {
n := 1
for { for {
segment, err := context.NextSegment() segment, err := context.NextSegment()
if err == io.EOF { if err == io.EOF {
return nil break
} else if err != nil { } else if err != nil {
return err return err
} }
fmt.Fprintln(w, n) fmt.Printf("[%6s->%6s] %s\n", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond), segment.Text)
fmt.Fprintln(w, srtTimestamp(segment.Start), " --> ", srtTimestamp(segment.End))
fmt.Fprintln(w, segment.Text)
fmt.Fprintln(w, "")
n++
} }
}
// Output text to terminal // Return success
func Output(w io.Writer, context whisper.Context, colorize bool) error { return nil
for {
segment, err := context.NextSegment()
if err == io.EOF {
return nil
} else if err != nil {
return err
}
fmt.Fprintf(w, "[%6s->%6s]", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
if colorize {
for _, token := range segment.Tokens {
if !context.IsText(token) {
continue
}
fmt.Fprint(w, " ", Colorize(token.Text, int(token.P*24.0)))
}
fmt.Fprint(w, "\n")
} else {
fmt.Fprintln(w, " ", segment.Text)
}
}
}
// Return srtTimestamp
func srtTimestamp(t time.Duration) string {
return fmt.Sprintf("%02d:%02d:%02d,%03d", t/time.Hour, (t%time.Hour)/time.Minute, (t%time.Minute)/time.Second, (t%time.Second)/time.Millisecond)
} }

View File

@ -1,23 +0,0 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,5 +1,8 @@
package whisper package whisper
// This file defines the whisper_token, whisper_token_data and whisper_full_params
// structures, which are used by the whisper_full() function.
import ( import (
"fmt" "fmt"
) )
@ -47,12 +50,7 @@ func (p *Params) SetSpeedup(v bool) {
p.speed_up = toBool(v) p.speed_up = toBool(v)
} }
// Set language id
func (p *Params) SetLanguage(lang int) error { func (p *Params) SetLanguage(lang int) error {
if lang == -1 {
p.language = nil
return nil
}
str := C.whisper_lang_str(C.int(lang)) str := C.whisper_lang_str(C.int(lang))
if str == nil { if str == nil {
return ErrInvalidLanguage return ErrInvalidLanguage
@ -62,7 +60,6 @@ func (p *Params) SetLanguage(lang int) error {
return nil return nil
} }
// Get language id
func (p *Params) Language() int { func (p *Params) Language() int {
if p.language == nil { if p.language == nil {
return -1 return -1
@ -70,50 +67,18 @@ func (p *Params) Language() int {
return int(C.whisper_lang_id(p.language)) return int(C.whisper_lang_id(p.language))
} }
// Threads available
func (p *Params) Threads() int {
return int(p.n_threads)
}
// Set number of threads to use
func (p *Params) SetThreads(threads int) { func (p *Params) SetThreads(threads int) {
p.n_threads = C.int(threads) p.n_threads = C.int(threads)
} }
// Set start offset in ms
func (p *Params) SetOffset(offset_ms int) { func (p *Params) SetOffset(offset_ms int) {
p.offset_ms = C.int(offset_ms) p.offset_ms = C.int(offset_ms)
} }
// Set audio duration to process in ms
func (p *Params) SetDuration(duration_ms int) { func (p *Params) SetDuration(duration_ms int) {
p.duration_ms = C.int(duration_ms) p.duration_ms = C.int(duration_ms)
} }
// Set timestamp token probability threshold (~0.01)
func (p *Params) SetTokenThreshold(t float32) {
p.thold_pt = C.float(t)
}
// Set timestamp token sum probability threshold (~0.01)
func (p *Params) SetTokenSumThreshold(t float32) {
p.thold_ptsum = C.float(t)
}
// Set max segment length in characters
func (p *Params) SetMaxSegmentLength(n int) {
p.max_len = C.int(n)
}
func (p *Params) SetTokenTimestamps(b bool) {
p.token_timestamps = toBool(b)
}
// Set max tokens per segment (0 = no limit)
func (p *Params) SetMaxTokensPerSegment(n int) {
p.max_tokens = C.int(n)
}
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS // PRIVATE METHODS

View File

@ -11,11 +11,10 @@ import (
// ERRORS // ERRORS
var ( var (
ErrUnableToLoadModel = errors.New("unable to load model") ErrUnableToLoadModel = errors.New("unable to load model")
ErrInternalAppError = errors.New("internal application error") ErrInternalAppError = errors.New("internal application error")
ErrProcessingFailed = errors.New("processing failed") ErrProcessingFailed = errors.New("processing failed")
ErrUnsupportedLanguage = errors.New("unsupported language") ErrUnsupportedLanguage = errors.New("unsupported language")
ErrModelNotMultilingual = errors.New("model is not multilingual")
) )
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////

View File

@ -1,9 +1,7 @@
package whisper package whisper
import ( import (
"fmt"
"io" "io"
"runtime"
"strings" "strings"
"time" "time"
@ -26,7 +24,7 @@ var _ Context = (*context)(nil)
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE // LIFECYCLE
func newContext(model *model, params whisper.Params) (Context, error) { func NewContext(model *model, params whisper.Params) (Context, error) {
context := new(context) context := new(context)
context.model = model context.model = model
context.params = params context.params = params
@ -43,13 +41,7 @@ func (context *context) SetLanguage(lang string) error {
if context.model.ctx == nil { if context.model.ctx == nil {
return ErrInternalAppError return ErrInternalAppError
} }
if !context.model.IsMultilingual() { if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
return ErrModelNotMultilingual
}
if lang == "auto" {
context.params.SetLanguage(-1)
} else if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
return ErrUnsupportedLanguage return ErrUnsupportedLanguage
} else if err := context.params.SetLanguage(id); err != nil { } else if err := context.params.SetLanguage(id); err != nil {
return err return err
@ -58,99 +50,16 @@ func (context *context) SetLanguage(lang string) error {
return nil return nil
} }
func (context *context) IsMultilingual() bool {
return context.model.IsMultilingual()
}
// Get language // Get language
func (context *context) Language() string { func (context *context) Language() string {
id := context.params.Language()
if id == -1 {
return "auto"
}
return whisper.Whisper_lang_str(context.params.Language()) return whisper.Whisper_lang_str(context.params.Language())
} }
// Set translate flag
func (context *context) SetTranslate(v bool) {
context.params.SetTranslate(v)
}
// Set speedup flag // Set speedup flag
func (context *context) SetSpeedup(v bool) { func (context *context) SetSpeedup(v bool) {
context.params.SetSpeedup(v) context.params.SetSpeedup(v)
} }
// Set number of threads to use
func (context *context) SetThreads(v uint) {
context.params.SetThreads(int(v))
}
// Set time offset
func (context *context) SetOffset(v time.Duration) {
context.params.SetOffset(int(v.Milliseconds()))
}
// Set duration of audio to process
func (context *context) SetDuration(v time.Duration) {
context.params.SetOffset(int(v.Milliseconds()))
}
// Set timestamp token probability threshold (~0.01)
func (context *context) SetTokenThreshold(t float32) {
context.params.SetTokenThreshold(t)
}
// Set timestamp token sum probability threshold (~0.01)
func (context *context) SetTokenSumThreshold(t float32) {
context.params.SetTokenSumThreshold(t)
}
// Set max segment length in characters
func (context *context) SetMaxSegmentLength(n uint) {
context.params.SetMaxSegmentLength(int(n))
}
// Set token timestamps flag
func (context *context) SetTokenTimestamps(b bool) {
context.params.SetTokenTimestamps(b)
}
// Set max tokens per segment (0 = no limit)
func (context *context) SetMaxTokensPerSegment(n uint) {
context.params.SetMaxTokensPerSegment(int(n))
}
// ResetTimings resets the mode timings. Should be called before processing
func (context *context) ResetTimings() {
context.model.ctx.Whisper_reset_timings()
}
// PrintTimings prints the model timings to stdout.
func (context *context) PrintTimings() {
context.model.ctx.Whisper_print_timings()
}
// SystemInfo returns the system information
func (context *context) SystemInfo() string {
return fmt.Sprintf("system_info: n_threads = %d / %d | %s\n",
context.params.Threads(),
runtime.NumCPU(),
whisper.Whisper_print_system_info(),
)
}
// Use mel data at offset_ms to try and auto-detect the spoken language
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// Returns the probabilities of all languages.
func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]float32, error) {
langProbs, err := context.model.ctx.Whisper_lang_auto_detect(offset_ms, n_threads)
if err != nil {
return nil, err
}
return langProbs, nil
}
// Process new sample data and return any errors // Process new sample data and return any errors
func (context *context) Process(data []float32, cb SegmentCallback) error { func (context *context) Process(data []float32, cb SegmentCallback) error {
if context.model.ctx == nil { if context.model.ctx == nil {
@ -210,65 +119,6 @@ func (context *context) NextSegment() (Segment, error) {
return result, nil return result, nil
} }
// Test for text tokens
func (context *context) IsText(t Token) bool {
switch {
case context.IsBEG(t):
return false
case context.IsSOT(t):
return false
case whisper.Token(t.Id) >= context.model.ctx.Whisper_token_eot():
return false
case context.IsPREV(t):
return false
case context.IsSOLM(t):
return false
case context.IsNOT(t):
return false
default:
return true
}
}
// Test for "begin" token
func (context *context) IsBEG(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg()
}
// Test for "start of transcription" token
func (context *context) IsSOT(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_sot()
}
// Test for "end of transcription" token
func (context *context) IsEOT(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot()
}
// Test for "start of prev" token
func (context *context) IsPREV(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev()
}
// Test for "start of lm" token
func (context *context) IsSOLM(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm()
}
// Test for "No timestamps" token
func (context *context) IsNOT(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not()
}
// Test for token associated with a specific language
func (context *context) IsLANG(t Token, lang string) bool {
if id := context.model.ctx.Whisper_lang_id(lang); id >= 0 {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id)
} else {
return false
}
}
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS // PRIVATE METHODS
@ -285,14 +135,10 @@ func toSegment(ctx *whisper.Context, n int) Segment {
func toTokens(ctx *whisper.Context, n int) []Token { func toTokens(ctx *whisper.Context, n int) []Token {
result := make([]Token, ctx.Whisper_full_n_tokens(n)) result := make([]Token, ctx.Whisper_full_n_tokens(n))
for i := 0; i < len(result); i++ { for i := 0; i < len(result); i++ {
data := ctx.Whisper_full_get_token_data(n, i)
result[i] = Token{ result[i] = Token{
Id: int(ctx.Whisper_full_get_token_id(n, i)), Id: int(ctx.Whisper_full_get_token_id(n, i)),
Text: ctx.Whisper_full_get_token_text(n, i), Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)),
P: ctx.Whisper_full_get_token_p(n, i), P: ctx.Whisper_full_get_token_p(n, i),
Start: time.Duration(data.T0()) * time.Millisecond * 10,
End: time.Duration(data.T1()) * time.Millisecond * 10,
} }
} }
return result return result

View File

@ -20,29 +20,15 @@ type Model interface {
// Return a new speech-to-text context. // Return a new speech-to-text context.
NewContext() (Context, error) NewContext() (Context, error)
// Return true if the model is multilingual.
IsMultilingual() bool
// Return all languages supported. // Return all languages supported.
Languages() []string Languages() []string
} }
// Context is the speach recognition context. // Context is the speach recognition context.
type Context interface { type Context interface {
SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language. SetLanguage(string) error // Set the language to use for speech recognition.
SetTranslate(bool) // Set translate flag
IsMultilingual() bool // Return true if the model is multilingual.
Language() string // Get language Language() string // Get language
SetSpeedup(bool) // Set speedup flag
SetOffset(time.Duration) // Set offset
SetDuration(time.Duration) // Set duration
SetThreads(uint) // Set number of threads to use
SetSpeedup(bool) // Set speedup flag
SetTokenThreshold(float32) // Set timestamp token probability threshold
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
SetMaxSegmentLength(uint) // Set max segment length in characters
SetTokenTimestamps(bool) // Set token timestamps flag
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
// Process mono audio data and return any errors. // Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the // If defined, newly generated segments are passed to the
@ -52,21 +38,6 @@ type Context interface {
// After process is called, return segments until the end of the stream // After process is called, return segments until the end of the stream
// is reached, when io.EOF is returned. // is reached, when io.EOF is returned.
NextSegment() (Segment, error) NextSegment() (Segment, error)
IsBEG(Token) bool // Test for "begin" token
IsSOT(Token) bool // Test for "start of transcription" token
IsEOT(Token) bool // Test for "end of transcription" token
IsPREV(Token) bool // Test for "start of prev" token
IsSOLM(Token) bool // Test for "start of lm" token
IsNOT(Token) bool // Test for "No timestamps" token
IsLANG(Token, string) bool // Test for token associated with a specific language
IsText(Token) bool // Test for text token
// Timings
PrintTimings()
ResetTimings()
SystemInfo() string
} }
// Segment is the text result of a speech recognition. // Segment is the text result of a speech recognition.
@ -86,8 +57,7 @@ type Segment struct {
// Token is a text or special token // Token is a text or special token
type Token struct { type Token struct {
Id int Id int
Text string Text string
P float32 P float32
Start, End time.Duration
} }

View File

@ -23,7 +23,7 @@ var _ Model = (*model)(nil)
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE // LIFECYCLE
func New(path string) (Model, error) { func New(path string) (*model, error) {
model := new(model) model := new(model)
if _, err := os.Stat(path); err != nil { if _, err := os.Stat(path); err != nil {
return nil, err return nil, err
@ -64,11 +64,6 @@ func (model *model) String() string {
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS // PUBLIC METHODS
// Return true if model is multilingual (language and translation options are supported)
func (model *model) IsMultilingual() bool {
return model.ctx.Whisper_is_multilingual() != 0
}
// Return all recognized languages. Initially it is set to auto-detect // Return all recognized languages. Initially it is set to auto-detect
func (model *model) Languages() []string { func (model *model) Languages() []string {
result := make([]string, 0, whisper.Whisper_lang_max_id()) result := make([]string, 0, whisper.Whisper_lang_max_id())
@ -94,8 +89,7 @@ func (model *model) NewContext() (Context, error) {
params.SetPrintRealtime(false) params.SetPrintRealtime(false)
params.SetPrintTimestamps(false) params.SetPrintTimestamps(false)
params.SetThreads(runtime.NumCPU()) params.SetThreads(runtime.NumCPU())
params.SetNoContext(true)
// Return new context // Return new context
return newContext(model, params) return NewContext(model, params)
} }

View File

@ -9,7 +9,8 @@ import (
// CGO // CGO
/* /*
#cgo LDFLAGS: -lwhisper -lm -lstdc++ #cgo CFLAGS: -I${SRCDIR}/../..
#cgo LDFLAGS: -L${SRCDIR}/build -lwhisper -lm -lstdc++
#cgo darwin LDFLAGS: -framework Accelerate #cgo darwin LDFLAGS: -framework Accelerate
#include <whisper.h> #include <whisper.h>
#include <stdlib.h> #include <stdlib.h>
@ -20,7 +21,7 @@ extern bool callEncoderBegin(void* user_data);
// Text segment callback // Text segment callback
// Called on every newly generated text segment // Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments // Use the whisper_full_...() functions to obtain the text segments
static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_state* state, int n_new, void* user_data) { static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
if(user_data != NULL && ctx != NULL) { if(user_data != NULL && ctx != NULL) {
callNewSegment(user_data, n_new); callNewSegment(user_data, n_new);
} }
@ -29,7 +30,7 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_s
// Encoder begin callback // Encoder begin callback
// If not NULL, called before the encoder starts // If not NULL, called before the encoder starts
// If it returns false, the computation is aborted // If it returns false, the computation is aborted
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, struct whisper_state* state, void* user_data) { static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
if(user_data != NULL && ctx != NULL) { if(user_data != NULL && ctx != NULL) {
return callEncoderBegin(user_data); return callEncoderBegin(user_data);
} }
@ -91,7 +92,7 @@ var (
func Whisper_init(path string) *Context { func Whisper_init(path string) *Context {
cPath := C.CString(path) cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath)) defer C.free(unsafe.Pointer(cPath))
if ctx := C.whisper_init_from_file(cPath); ctx != nil { if ctx := C.whisper_init(cPath); ctx != nil {
return (*Context)(ctx) return (*Context)(ctx)
} else { } else {
return nil return nil
@ -147,6 +148,16 @@ func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error {
} }
} }
// whisper_sample_best() returns the token with the highest probability
func (ctx *Context) Whisper_sample_best() TokenData {
return TokenData(C.whisper_sample_best((*C.struct_whisper_context)(ctx)))
}
// whisper_sample_timestamp() returns the most probable timestamp token
func (ctx *Context) Whisper_sample_timestamp(is_initial bool) TokenData {
return TokenData(C.whisper_sample_timestamp((*C.struct_whisper_context)(ctx), C.bool(is_initial)))
}
// Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens. // Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success // Returns the number of tokens on success
func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) { func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
@ -160,10 +171,6 @@ func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
} }
// Return the id of the specified language, returns -1 if not found // Return the id of the specified language, returns -1 if not found
// Examples:
//
// "de" -> 2
// "german" -> 2
func (ctx *Context) Whisper_lang_id(lang string) int { func (ctx *Context) Whisper_lang_id(lang string) int {
return int(C.whisper_lang_id(C.CString(lang))) return int(C.whisper_lang_id(C.CString(lang)))
} }
@ -204,10 +211,6 @@ func (ctx *Context) Whisper_n_text_ctx() int {
return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx))) return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
} }
func (ctx *Context) Whisper_n_audio_ctx() int {
return int(C.whisper_n_audio_ctx((*C.struct_whisper_context)(ctx)))
}
func (ctx *Context) Whisper_is_multilingual() int { func (ctx *Context) Whisper_is_multilingual() int {
return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx))) return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
} }
@ -356,7 +359,7 @@ func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
// Get token data for the specified token in the specified segment. // Get token data for the specified token in the specified segment.
// This contains probabilities, timestamps, etc. // This contains probabilities, timestamps, etc.
func (ctx *Context) Whisper_full_get_token_data(segment int, token int) TokenData { func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData {
return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token))) return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
} }
@ -407,11 +410,3 @@ func callEncoderBegin(user_data unsafe.Pointer) C.bool {
} }
return true return true
} }
func (t TokenData) T0() int64 {
return int64(t.t0)
}
func (t TokenData) T1() int64 {
return int64(t.t1)
}

View File

@ -50,10 +50,7 @@ func Test_Whisper_001(t *testing.T) {
ctx := whisper.Whisper_init(ModelPath) ctx := whisper.Whisper_init(ModelPath)
assert.NotNil(ctx) assert.NotNil(ctx)
defer ctx.Whisper_free() defer ctx.Whisper_free()
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY) assert.NoError(ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf.AsFloat32Buffer().Data, nil, nil))
data := buf.AsFloat32Buffer().Data
err = ctx.Whisper_full(params, data, nil, nil)
assert.NoError(err)
// Print out tokens // Print out tokens
num_segments := ctx.Whisper_full_n_segments() num_segments := ctx.Whisper_full_n_segments()

View File

@ -20,7 +20,7 @@ struct whisper_context * g_context;
EMSCRIPTEN_BINDINGS(whisper) { EMSCRIPTEN_BINDINGS(whisper) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) { emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
if (g_context == nullptr) { if (g_context == nullptr) {
g_context = whisper_init_from_file(path_model.c_str()); g_context = whisper_init(path_model.c_str());
if (g_context != nullptr) { if (g_context != nullptr) {
return true; return true;
} else { } else {

View File

@ -1,6 +1,6 @@
{ {
"name": "whisper.cpp", "name": "whisper.cpp",
"version": "1.4.2", "version": "1.0.4",
"description": "Whisper speech recognition", "description": "Whisper speech recognition",
"main": "whisper.js", "main": "whisper.js",
"scripts": { "scripts": {

File diff suppressed because one or more lines are too long

View File

@ -1,7 +0,0 @@
Makefile
ggml.c
ggml.h
whisper.bundle
whisper.cpp
whisper.h
dr_wav.h

View File

@ -1,21 +0,0 @@
require 'mkmf'
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','examples','dr_wav.h')} .")
# need to use c++ compiler flags
$CXXFLAGS << ' -std=c++11'
# Set to true when building binary gems
if enable_config('static-stdlib', false)
$LDFLAGS << ' -static-libgcc -static-libstdc++'
end
if enable_config('march-tune-native', false)
$CFLAGS << ' -march=native -mtune=native'
$CXXFLAGS << ' -march=native -mtune=native'
end
create_makefile('whisper')

View File

@ -1,426 +0,0 @@
#include <ruby.h>
#include "ruby_whisper.h"
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include <cmath>
#include <fstream>
#include <cstdio>
#include <string>
#include <thread>
#include <vector>
#ifdef __cplusplus
extern "C" {
#endif
#define BOOL_PARAMS_SETTER(self, prop, value) \
ruby_whisper_params *rwp; \
Data_Get_Struct(self, ruby_whisper_params, rwp); \
if (value == Qfalse || value == Qnil) { \
rwp->params.prop = false; \
} else { \
rwp->params.prop = true; \
} \
return value; \
#define BOOL_PARAMS_GETTER(self, prop) \
ruby_whisper_params *rwp; \
Data_Get_Struct(self, ruby_whisper_params, rwp); \
if (rwp->params.prop) { \
return Qtrue; \
} else { \
return Qfalse; \
}
VALUE mWhisper;
VALUE cContext;
VALUE cParams;
static void ruby_whisper_free(ruby_whisper *rw) {
if (rw->context) {
whisper_free(rw->context);
rw->context = NULL;
}
}
static void ruby_whisper_params_free(ruby_whisper_params *rwp) {
}
void rb_whisper_mark(ruby_whisper *rw) {
// call rb_gc_mark on any ruby references in rw
}
void rb_whisper_free(ruby_whisper *rw) {
ruby_whisper_free(rw);
free(rw);
}
void rb_whisper_params_mark(ruby_whisper_params *rwp) {
}
void rb_whisper_params_free(ruby_whisper_params *rwp) {
ruby_whisper_params_free(rwp);
free(rwp);
}
static VALUE ruby_whisper_allocate(VALUE klass) {
ruby_whisper *rw;
rw = ALLOC(ruby_whisper);
rw->context = NULL;
return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
}
static VALUE ruby_whisper_params_allocate(VALUE klass) {
ruby_whisper_params *rwp;
rwp = ALLOC(ruby_whisper_params);
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
}
static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
ruby_whisper *rw;
VALUE whisper_model_file_path;
// TODO: we can support init from buffer here too maybe another ruby object to expose
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
Data_Get_Struct(self, ruby_whisper, rw);
if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) {
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
}
rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path));
if (rw->context == nullptr) {
rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
}
return self;
}
/*
* transcribe a single file
* can emit to a block results
*
**/
static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
ruby_whisper *rw;
ruby_whisper_params *rwp;
VALUE wave_file_path, blk, params;
rb_scan_args(argc, argv, "02&", &wave_file_path, &params, &blk);
Data_Get_Struct(self, ruby_whisper, rw);
Data_Get_Struct(params, ruby_whisper_params, rwp);
if (!rb_respond_to(wave_file_path, rb_intern("to_s"))) {
rb_raise(rb_eRuntimeError, "Expected file path to wave file");
}
std::string fname_inp = StringValueCStr(wave_file_path);
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// WAV input - this is directly from main.cpp example
{
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname_inp == "-") {
{
uint8_t buf[1024];
while (true) {
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return self;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
} else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return self;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "WAV file '%s' must be mono or stereo\n", fname_inp.c_str());
return self;
}
if (rwp->diarize && wav.channels != 2 && rwp->params.print_timestamps == false) {
fprintf(stderr, "WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str());
return self;
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
return self;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "WAV file '%s' must be 16-bit\n", fname_inp.c_str());
return self;
}
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (rwp->diarize) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
}
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
rwp->params.encoder_begin_callback_user_data = &is_aborted;
}
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
fprintf(stderr, "failed to process audio\n");
return self;
}
const int n_segments = whisper_full_n_segments(rw->context);
VALUE output = rb_str_new2("");
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(rw->context, i);
output = rb_str_concat(output, rb_str_new2(text));
}
VALUE idCall = rb_intern("call");
if (blk != Qnil) {
rb_funcall(blk, idCall, 1, output);
}
return self;
}
/*
* params.language = "auto" | "en", etc...
*/
static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (value == Qfalse || value == Qnil) {
rwp->params.language = "auto";
} else {
rwp->params.language = StringValueCStr(value);
}
return value;
}
static VALUE ruby_whisper_params_get_language(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (rwp->params.language) {
return rb_str_new2(rwp->params.language);
} else {
return rb_str_new2("auto");
}
}
static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, translate, value)
}
static VALUE ruby_whisper_params_get_translate(VALUE self) {
BOOL_PARAMS_GETTER(self, translate)
}
static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, no_context, value)
}
static VALUE ruby_whisper_params_get_no_context(VALUE self) {
BOOL_PARAMS_GETTER(self, no_context)
}
static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, single_segment, value)
}
static VALUE ruby_whisper_params_get_single_segment(VALUE self) {
BOOL_PARAMS_GETTER(self, single_segment)
}
static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_special, value)
}
static VALUE ruby_whisper_params_get_print_special(VALUE self) {
BOOL_PARAMS_GETTER(self, print_special)
}
static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_progress, value)
}
static VALUE ruby_whisper_params_get_print_progress(VALUE self) {
BOOL_PARAMS_GETTER(self, print_progress)
}
static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_realtime, value)
}
static VALUE ruby_whisper_params_get_print_realtime(VALUE self) {
BOOL_PARAMS_GETTER(self, print_realtime)
}
static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_timestamps, value)
}
static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) {
BOOL_PARAMS_GETTER(self, print_timestamps)
}
static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, suppress_blank, value)
}
static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
BOOL_PARAMS_GETTER(self, suppress_blank)
}
static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value)
}
static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) {
BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens)
}
static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) {
BOOL_PARAMS_GETTER(self, token_timestamps)
}
static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, token_timestamps, value)
}
static VALUE ruby_whisper_params_get_split_on_word(VALUE self) {
BOOL_PARAMS_GETTER(self, split_on_word)
}
static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, split_on_word, value)
}
static VALUE ruby_whisper_params_get_speed_up(VALUE self) {
BOOL_PARAMS_GETTER(self, speed_up)
}
static VALUE ruby_whisper_params_set_speed_up(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, speed_up, value)
}
static VALUE ruby_whisper_params_get_diarize(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (rwp->diarize) {
return Qtrue;
} else {
return Qfalse;
}
}
static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (value == Qfalse || value == Qnil) {
rwp->diarize = false;
} else {
rwp->diarize = true;
} \
return value;
}
static VALUE ruby_whisper_params_get_offset(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.offset_ms);
}
static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.offset_ms = NUM2INT(value);
return value;
}
static VALUE ruby_whisper_params_get_duration(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.duration_ms);
}
static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.duration_ms = NUM2INT(value);
return value;
}
static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.n_max_text_ctx);
}
static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.n_max_text_ctx = NUM2INT(value);
return value;
}
void Init_whisper() {
mWhisper = rb_define_module("Whisper");
cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
rb_define_alloc_func(cContext, ruby_whisper_allocate);
rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
rb_define_method(cParams, "language=", ruby_whisper_params_set_language, 1);
rb_define_method(cParams, "language", ruby_whisper_params_get_language, 0);
rb_define_method(cParams, "translate=", ruby_whisper_params_set_translate, 1);
rb_define_method(cParams, "translate", ruby_whisper_params_get_translate, 0);
rb_define_method(cParams, "no_context=", ruby_whisper_params_set_no_context, 1);
rb_define_method(cParams, "no_context", ruby_whisper_params_get_no_context, 0);
rb_define_method(cParams, "single_segment=", ruby_whisper_params_set_single_segment, 1);
rb_define_method(cParams, "single_segment", ruby_whisper_params_get_single_segment, 0);
rb_define_method(cParams, "print_special", ruby_whisper_params_get_print_special, 0);
rb_define_method(cParams, "print_special=", ruby_whisper_params_set_print_special, 1);
rb_define_method(cParams, "print_progress", ruby_whisper_params_get_print_progress, 0);
rb_define_method(cParams, "print_progress=", ruby_whisper_params_set_print_progress, 1);
rb_define_method(cParams, "print_realtime", ruby_whisper_params_get_print_realtime, 0);
rb_define_method(cParams, "print_realtime=", ruby_whisper_params_set_print_realtime, 1);
rb_define_method(cParams, "print_timestamps", ruby_whisper_params_get_print_timestamps, 0);
rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1);
rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0);
rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1);
rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0);
rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1);
rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0);
rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1);
rb_define_method(cParams, "speed_up", ruby_whisper_params_get_speed_up, 0);
rb_define_method(cParams, "speed_up=", ruby_whisper_params_set_speed_up, 1);
rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0);
rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1);
rb_define_method(cParams, "offset", ruby_whisper_params_get_offset, 0);
rb_define_method(cParams, "offset=", ruby_whisper_params_set_offset, 1);
rb_define_method(cParams, "duration", ruby_whisper_params_get_duration, 0);
rb_define_method(cParams, "duration=", ruby_whisper_params_set_duration, 1);
rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0);
rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1);
}
#ifdef __cplusplus
}
#endif

View File

@ -1,15 +0,0 @@
#ifndef __RUBY_WHISPER_H
#define __RUBY_WHISPER_H
#include "whisper.h"
typedef struct {
struct whisper_context *context;
} ruby_whisper;
typedef struct {
struct whisper_full_params params;
bool diarize;
} ruby_whisper_params;
#endif

View File

@ -1,138 +0,0 @@
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
EXTDIR = File.join(TOPDIR, 'ext')
#$LIBDIR = File.join(TOPDIR, 'lib')
#$:.unshift(LIBDIR)
$:.unshift(EXTDIR)
require 'whisper'
require 'test/unit'
class TestWhisper < Test::Unit::TestCase
def setup
@params = Whisper::Params.new
end
def test_language
@params.language = "en"
assert_equal @params.language, "en"
@params.language = "auto"
assert_equal @params.language, "auto"
end
def test_offset
@params.offset = 10_000
assert_equal @params.offset, 10_000
@params.offset = 0
assert_equal @params.offset, 0
end
def test_duration
@params.duration = 60_000
assert_equal @params.duration, 60_000
@params.duration = 0
assert_equal @params.duration, 0
end
def test_max_text_tokens
@params.max_text_tokens = 300
assert_equal @params.max_text_tokens, 300
@params.max_text_tokens = 0
assert_equal @params.max_text_tokens, 0
end
def test_translate
@params.translate = true
assert @params.translate
@params.translate = false
assert !@params.translate
end
def test_no_context
@params.no_context = true
assert @params.no_context
@params.no_context = false
assert !@params.no_context
end
def test_single_segment
@params.single_segment = true
assert @params.single_segment
@params.single_segment = false
assert !@params.single_segment
end
def test_print_special
@params.print_special = true
assert @params.print_special
@params.print_special = false
assert !@params.print_special
end
def test_print_progress
@params.print_progress = true
assert @params.print_progress
@params.print_progress = false
assert !@params.print_progress
end
def test_print_realtime
@params.print_realtime = true
assert @params.print_realtime
@params.print_realtime = false
assert !@params.print_realtime
end
def test_print_timestamps
@params.print_timestamps = true
assert @params.print_timestamps
@params.print_timestamps = false
assert !@params.print_timestamps
end
def test_suppress_blank
@params.suppress_blank = true
assert @params.suppress_blank
@params.suppress_blank = false
assert !@params.suppress_blank
end
def test_suppress_non_speech_tokens
@params.suppress_non_speech_tokens = true
assert @params.suppress_non_speech_tokens
@params.suppress_non_speech_tokens = false
assert !@params.suppress_non_speech_tokens
end
def test_token_timestamps
@params.token_timestamps = true
assert @params.token_timestamps
@params.token_timestamps = false
assert !@params.token_timestamps
end
def test_split_on_word
@params.split_on_word = true
assert @params.split_on_word
@params.split_on_word = false
assert !@params.split_on_word
end
def test_speed_up
@params.speed_up = true
assert @params.speed_up
@params.speed_up = false
assert !@params.speed_up
end
def test_whisper
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
params = Whisper::Params.new
params.print_timestamps = false
jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
@whisper.transcribe(jfk, params) {|text|
assert_match /ask not what your country can do for you, ask what you can do for your country/, text
}
end
end

View File

@ -1,17 +0,0 @@
# Set the default compile features and properties for a target.
if (NOT TARGET)
message(FATAL_ERROR "TARGET not set before including DefaultTargetOptions")
endif()
target_compile_features(${TARGET}
PRIVATE
cxx_std_11
)
set_target_properties(${TARGET}
PROPERTIES
EXPORT_COMPILE_COMMANDS ON
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib"
)

View File

@ -1,146 +0,0 @@
//
// whisper-decoder-impl.h
//
// This file was automatically generated and should not be edited.
//
#import <Foundation/Foundation.h>
#import <CoreML/CoreML.h>
#include <stdint.h>
#include <os/log.h>
NS_ASSUME_NONNULL_BEGIN
/// Model Prediction Input Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_implInput : NSObject<MLFeatureProvider>
/// token_data as 1 by 1 matrix of 32-bit integers
@property (readwrite, nonatomic, strong) MLMultiArray * token_data;
/// audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * audio_data;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data NS_DESIGNATED_INITIALIZER;
@end
/// Model Prediction Output Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_implOutput : NSObject<MLFeatureProvider>
/// var_1346 as multidimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * var_1346;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 NS_DESIGNATED_INITIALIZER;
@end
/// Class for model loading and prediction
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_impl : NSObject
@property (readonly, nonatomic, nullable) MLModel * model;
/**
URL of the underlying .mlmodelc directory.
*/
+ (nullable NSURL *)URLOfModelInThisBundle;
/**
Initialize whisper_decoder_impl instance from an existing MLModel object.
Usually the application does not use this initializer unless it makes a subclass of whisper_decoder_impl.
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
*/
- (instancetype)initWithMLModel:(MLModel *)model NS_DESIGNATED_INITIALIZER;
/**
Initialize whisper_decoder_impl instance with the model in this bundle.
*/
- (nullable instancetype)init;
/**
Initialize whisper_decoder_impl instance with the model in this bundle.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Initialize whisper_decoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Initialize whisper_decoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Construct whisper_decoder_impl instance asynchronously with configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
*/
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
/**
Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param modelURL The model URL.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
*/
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
/**
Make a prediction using the standard interface
@param input an instance of whisper_decoder_implInput to predict from
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_decoder_implOutput
*/
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Make a prediction using the standard interface
@param input an instance of whisper_decoder_implInput to predict from
@param options prediction options
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_decoder_implOutput
*/
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Make a prediction using the convenience interface
@param token_data as 1 by 1 matrix of 32-bit integers:
@param audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats:
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_decoder_implOutput
*/
- (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Batch prediction
@param inputArray array of whisper_decoder_implInput instances to obtain predictions from
@param options prediction options
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the predictions as NSArray<whisper_decoder_implOutput *>
*/
- (nullable NSArray<whisper_decoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_decoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
@end
NS_ASSUME_NONNULL_END

View File

@ -1,201 +0,0 @@
//
// whisper-decoder-impl.m
//
// This file was automatically generated and should not be edited.
//
#if !__has_feature(objc_arc)
#error This file must be compiled with automatic reference counting enabled (-fobjc-arc)
#endif
#import "whisper-decoder-impl.h"
@implementation whisper_decoder_implInput
- (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data {
self = [super init];
if (self) {
_token_data = token_data;
_audio_data = audio_data;
}
return self;
}
- (NSSet<NSString *> *)featureNames {
return [NSSet setWithArray:@[@"token_data", @"audio_data"]];
}
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
if ([featureName isEqualToString:@"token_data"]) {
return [MLFeatureValue featureValueWithMultiArray:self.token_data];
}
if ([featureName isEqualToString:@"audio_data"]) {
return [MLFeatureValue featureValueWithMultiArray:self.audio_data];
}
return nil;
}
@end
@implementation whisper_decoder_implOutput
- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 {
self = [super init];
if (self) {
_var_1346 = var_1346;
}
return self;
}
- (NSSet<NSString *> *)featureNames {
return [NSSet setWithArray:@[@"var_1346"]];
}
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
if ([featureName isEqualToString:@"var_1346"]) {
return [MLFeatureValue featureValueWithMultiArray:self.var_1346];
}
return nil;
}
@end
@implementation whisper_decoder_impl
/**
URL of the underlying .mlmodelc directory.
*/
+ (nullable NSURL *)URLOfModelInThisBundle {
NSString *assetPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"whisper_decoder_impl" ofType:@"mlmodelc"];
if (nil == assetPath) { os_log_error(OS_LOG_DEFAULT, "Could not load whisper-decoder-impl.mlmodelc in the bundle resource"); return nil; }
return [NSURL fileURLWithPath:assetPath];
}
/**
Initialize whisper_decoder_impl instance from an existing MLModel object.
Usually the application does not use this initializer unless it makes a subclass of whisper_decoder_impl.
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
*/
- (instancetype)initWithMLModel:(MLModel *)model {
self = [super init];
if (!self) { return nil; }
_model = model;
if (_model == nil) { return nil; }
return self;
}
/**
Initialize whisper_decoder_impl instance with the model in this bundle.
*/
- (nullable instancetype)init {
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle error:nil];
}
/**
Initialize whisper_decoder_impl instance with the model in this bundle.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle configuration:configuration error:error];
}
/**
Initialize whisper_decoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error {
MLModel *model = [MLModel modelWithContentsOfURL:modelURL error:error];
if (model == nil) { return nil; }
return [self initWithMLModel:model];
}
/**
Initialize whisper_decoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
MLModel *model = [MLModel modelWithContentsOfURL:modelURL configuration:configuration error:error];
if (model == nil) { return nil; }
return [self initWithMLModel:model];
}
/**
Construct whisper_decoder_impl instance asynchronously with configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
*/
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler {
[self loadContentsOfURL:(NSURL * _Nonnull)[self URLOfModelInThisBundle]
configuration:configuration
completionHandler:handler];
}
/**
Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param modelURL The model URL.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
*/
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler {
[MLModel loadContentsOfURL:modelURL
configuration:configuration
completionHandler:^(MLModel *model, NSError *error) {
if (model != nil) {
whisper_decoder_impl *typedModel = [[whisper_decoder_impl alloc] initWithMLModel:model];
handler(typedModel, nil);
} else {
handler(nil, error);
}
}];
}
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error {
return [self predictionFromFeatures:input options:[[MLPredictionOptions alloc] init] error:error];
}
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
if (!outFeatures) { return nil; }
return [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[outFeatures featureValueForName:@"var_1346"].multiArrayValue];
}
- (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
whisper_decoder_implInput *input_ = [[whisper_decoder_implInput alloc] initWithToken_data:token_data audio_data:audio_data];
return [self predictionFromFeatures:input_ error:error];
}
- (nullable NSArray<whisper_decoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_decoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
id<MLBatchProvider> inBatch = [[MLArrayBatchProvider alloc] initWithFeatureProviderArray:inputArray];
id<MLBatchProvider> outBatch = [self.model predictionsFromBatch:inBatch options:options error:error];
if (!outBatch) { return nil; }
NSMutableArray<whisper_decoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
for (NSInteger i = 0; i < outBatch.count; i++) {
id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[resultProvider featureValueForName:@"var_1346"].multiArrayValue];
[results addObject:result];
}
return results;
}
@end

View File

@ -1,142 +0,0 @@
//
// whisper-encoder-impl.h
//
// This file was automatically generated and should not be edited.
//
#import <Foundation/Foundation.h>
#import <CoreML/CoreML.h>
#include <stdint.h>
#include <os/log.h>
NS_ASSUME_NONNULL_BEGIN
/// Model Prediction Input Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_encoder_implInput : NSObject<MLFeatureProvider>
/// logmel_data as 1 × 80 × 3000 3-dimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * logmel_data;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithLogmel_data:(MLMultiArray *)logmel_data NS_DESIGNATED_INITIALIZER;
@end
/// Model Prediction Output Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_encoder_implOutput : NSObject<MLFeatureProvider>
/// output as multidimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * output;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithOutput:(MLMultiArray *)output NS_DESIGNATED_INITIALIZER;
@end
/// Class for model loading and prediction
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_encoder_impl : NSObject
@property (readonly, nonatomic, nullable) MLModel * model;
/**
URL of the underlying .mlmodelc directory.
*/
+ (nullable NSURL *)URLOfModelInThisBundle;
/**
Initialize whisper_encoder_impl instance from an existing MLModel object.
Usually the application does not use this initializer unless it makes a subclass of whisper_encoder_impl.
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
*/
- (instancetype)initWithMLModel:(MLModel *)model NS_DESIGNATED_INITIALIZER;
/**
Initialize whisper_encoder_impl instance with the model in this bundle.
*/
- (nullable instancetype)init;
/**
Initialize whisper_encoder_impl instance with the model in this bundle.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Initialize whisper_encoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Initialize whisper_encoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Construct whisper_encoder_impl instance asynchronously with configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
*/
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
/**
Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param modelURL The model URL.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
*/
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
/**
Make a prediction using the standard interface
@param input an instance of whisper_encoder_implInput to predict from
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_encoder_implOutput
*/
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Make a prediction using the standard interface
@param input an instance of whisper_encoder_implInput to predict from
@param options prediction options
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_encoder_implOutput
*/
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Make a prediction using the convenience interface
@param logmel_data as 1 × 80 × 3000 3-dimensional array of floats:
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_encoder_implOutput
*/
- (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Batch prediction
@param inputArray array of whisper_encoder_implInput instances to obtain predictions from
@param options prediction options
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the predictions as NSArray<whisper_encoder_implOutput *>
*/
- (nullable NSArray<whisper_encoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_encoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
@end
NS_ASSUME_NONNULL_END

View File

@ -1,197 +0,0 @@
//
// whisper-encoder-impl.m
//
// This file was automatically generated and should not be edited.
//
#if !__has_feature(objc_arc)
#error This file must be compiled with automatic reference counting enabled (-fobjc-arc)
#endif
#import "whisper-encoder-impl.h"
@implementation whisper_encoder_implInput
- (instancetype)initWithLogmel_data:(MLMultiArray *)logmel_data {
self = [super init];
if (self) {
_logmel_data = logmel_data;
}
return self;
}
- (NSSet<NSString *> *)featureNames {
return [NSSet setWithArray:@[@"logmel_data"]];
}
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
if ([featureName isEqualToString:@"logmel_data"]) {
return [MLFeatureValue featureValueWithMultiArray:self.logmel_data];
}
return nil;
}
@end
@implementation whisper_encoder_implOutput
- (instancetype)initWithOutput:(MLMultiArray *)output {
self = [super init];
if (self) {
_output = output;
}
return self;
}
- (NSSet<NSString *> *)featureNames {
return [NSSet setWithArray:@[@"output"]];
}
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
if ([featureName isEqualToString:@"output"]) {
return [MLFeatureValue featureValueWithMultiArray:self.output];
}
return nil;
}
@end
@implementation whisper_encoder_impl
/**
URL of the underlying .mlmodelc directory.
*/
+ (nullable NSURL *)URLOfModelInThisBundle {
NSString *assetPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"whisper_encoder_impl" ofType:@"mlmodelc"];
if (nil == assetPath) { os_log_error(OS_LOG_DEFAULT, "Could not load whisper-encoder-impl.mlmodelc in the bundle resource"); return nil; }
return [NSURL fileURLWithPath:assetPath];
}
/**
Initialize whisper_encoder_impl instance from an existing MLModel object.
Usually the application does not use this initializer unless it makes a subclass of whisper_encoder_impl.
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
*/
- (instancetype)initWithMLModel:(MLModel *)model {
self = [super init];
if (!self) { return nil; }
_model = model;
if (_model == nil) { return nil; }
return self;
}
/**
Initialize whisper_encoder_impl instance with the model in this bundle.
*/
- (nullable instancetype)init {
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle error:nil];
}
/**
Initialize whisper_encoder_impl instance with the model in this bundle.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle configuration:configuration error:error];
}
/**
Initialize whisper_encoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error {
MLModel *model = [MLModel modelWithContentsOfURL:modelURL error:error];
if (model == nil) { return nil; }
return [self initWithMLModel:model];
}
/**
Initialize whisper_encoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
MLModel *model = [MLModel modelWithContentsOfURL:modelURL configuration:configuration error:error];
if (model == nil) { return nil; }
return [self initWithMLModel:model];
}
/**
Construct whisper_encoder_impl instance asynchronously with configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
*/
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler {
[self loadContentsOfURL:(NSURL * _Nonnull)[self URLOfModelInThisBundle]
configuration:configuration
completionHandler:handler];
}
/**
Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param modelURL The model URL.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
*/
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler {
[MLModel loadContentsOfURL:modelURL
configuration:configuration
completionHandler:^(MLModel *model, NSError *error) {
if (model != nil) {
whisper_encoder_impl *typedModel = [[whisper_encoder_impl alloc] initWithMLModel:model];
handler(typedModel, nil);
} else {
handler(nil, error);
}
}];
}
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error {
return [self predictionFromFeatures:input options:[[MLPredictionOptions alloc] init] error:error];
}
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
if (!outFeatures) { return nil; }
return [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[outFeatures featureValueForName:@"output"].multiArrayValue];
}
- (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
whisper_encoder_implInput *input_ = [[whisper_encoder_implInput alloc] initWithLogmel_data:logmel_data];
return [self predictionFromFeatures:input_ error:error];
}
- (nullable NSArray<whisper_encoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_encoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
id<MLBatchProvider> inBatch = [[MLArrayBatchProvider alloc] initWithFeatureProviderArray:inputArray];
id<MLBatchProvider> outBatch = [self.model predictionsFromBatch:inBatch options:options error:error];
if (!outBatch) { return nil; }
NSMutableArray<whisper_encoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
for (NSInteger i = 0; i < outBatch.count; i++) {
id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
whisper_encoder_implOutput * result = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[resultProvider featureValueForName:@"output"].multiArrayValue];
[results addObject:result];
}
return results;
}
@end

View File

@ -1,22 +0,0 @@
// Wrapper of the Core ML Whisper Encoder model
//
// Code is derived from the work of Github user @wangchou
// ref: https://github.com/wangchou/callCoreMLFromCpp
#if __cplusplus
extern "C" {
#endif
struct whisper_coreml_context;
struct whisper_coreml_context * whisper_coreml_init(const char * path_model);
void whisper_coreml_free(struct whisper_coreml_context * ctx);
void whisper_coreml_encode(
const whisper_coreml_context * ctx,
float * mel,
float * out);
#if __cplusplus
}
#endif

View File

@ -1,63 +0,0 @@
#if !__has_feature(objc_arc)
#error This file must be compiled with automatic reference counting enabled (-fobjc-arc)
#endif
#import "whisper-encoder.h"
#import "whisper-encoder-impl.h"
#import <CoreML/CoreML.h>
#include <stdlib.h>
#if __cplusplus
extern "C" {
#endif
struct whisper_coreml_context {
const void * data;
};
struct whisper_coreml_context * whisper_coreml_init(const char * path_model) {
NSString * path_model_str = [[NSString alloc] initWithUTF8String:path_model];
NSURL * url_model = [NSURL fileURLWithPath: path_model_str];
const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model error:nil]);
if (data == NULL) {
return NULL;
}
whisper_coreml_context * ctx = new whisper_coreml_context;
ctx->data = data;
return ctx;
}
void whisper_coreml_free(struct whisper_coreml_context * ctx) {
CFRelease(ctx->data);
delete ctx;
}
void whisper_coreml_encode(
const whisper_coreml_context * ctx,
float * mel,
float * out) {
MLMultiArray * inMultiArray = [
[MLMultiArray alloc] initWithDataPointer: mel
shape: @[@1, @80, @3000]
dataType: MLMultiArrayDataTypeFloat32
strides: @[@(240000), @(3000), @1]
deallocator: nil
error: nil
];
whisper_encoder_implOutput * outCoreML = [(__bridge id) ctx->data predictionFromLogmel_data:inMultiArray error:nil];
memcpy(out, outCoreML.output.dataPointer, outCoreML.output.count * sizeof(float));
}
#if __cplusplus
}
#endif

View File

@ -4,7 +4,7 @@ find_package(Threads REQUIRED)
# third-party # third-party
if (WHISPER_SDL2) if (WHISPER_SUPPORT_SDL2)
# SDL2 # SDL2
find_package(SDL2 REQUIRED) find_package(SDL2 REQUIRED)
@ -14,41 +14,6 @@ if (WHISPER_SDL2)
message(STATUS "SDL2_LIBRARIES = ${SDL2_LIBRARIES}") message(STATUS "SDL2_LIBRARIES = ${SDL2_LIBRARIES}")
endif() endif()
# common
set(TARGET common)
add_library(${TARGET} STATIC
common.h
common.cpp
common-ggml.h
common-ggml.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE whisper)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
if (WHISPER_SDL2)
# common-sdl
set(TARGET common-sdl)
add_library(${TARGET} STATIC
common-sdl.h
common-sdl.cpp
)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PUBLIC ${SDL2_INCLUDE_DIRS})
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES})
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
# examples # examples
include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_CURRENT_SOURCE_DIR})
@ -59,14 +24,10 @@ if (EMSCRIPTEN)
add_subdirectory(command.wasm) add_subdirectory(command.wasm)
add_subdirectory(talk.wasm) add_subdirectory(talk.wasm)
add_subdirectory(bench.wasm) add_subdirectory(bench.wasm)
elseif(CMAKE_JS_VERSION)
add_subdirectory(addon.node)
else() else()
add_subdirectory(main) add_subdirectory(main)
add_subdirectory(stream) add_subdirectory(stream)
add_subdirectory(command) add_subdirectory(command)
add_subdirectory(bench) add_subdirectory(bench)
add_subdirectory(quantize)
add_subdirectory(talk) add_subdirectory(talk)
add_subdirectory(talk-llama)
endif() endif()

View File

@ -1,3 +0,0 @@
.idea
node_modules
build

View File

@ -1,31 +0,0 @@
set(TARGET whisper-addon)
# Base settings
#==================================================================
# env var supported by cmake-js
add_definitions(-DNAPI_VERSION=4)
include_directories(${CMAKE_JS_INC})
#==================================================================
add_library(${TARGET} SHARED ${CMAKE_JS_SRC} addon.cpp)
set_target_properties(${TARGET} PROPERTIES PREFIX "" SUFFIX ".node")
include(DefaultTargetOptions)
# Include N-API wrappers
#==================================================================
execute_process(COMMAND node -p "require('node-addon-api').include"
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE NODE_ADDON_API_DIR
)
string(REPLACE "\n" "" NODE_ADDON_API_DIR ${NODE_ADDON_API_DIR})
string(REPLACE "\"" "" NODE_ADDON_API_DIR ${NODE_ADDON_API_DIR})
target_include_directories(${TARGET} PRIVATE ${NODE_ADDON_API_DIR})
#==================================================================
target_link_libraries(${TARGET} ${CMAKE_JS_LIB} common whisper ${CMAKE_THREAD_LIBS_INIT})
if(MSVC AND CMAKE_JS_NODELIB_DEF AND CMAKE_JS_NODELIB_TARGET)
# Generate node.lib
execute_process(COMMAND ${CMAKE_AR} /def:${CMAKE_JS_NODELIB_DEF} /out:${CMAKE_JS_NODELIB_TARGET} ${CMAKE_STATIC_LINKER_FLAGS})
endif()

View File

@ -1,37 +0,0 @@
# addon
This is an addon demo that can **perform whisper model reasoning in `node` and `electron` environments**, based on [cmake-js](https://github.com/cmake-js/cmake-js).
It can be used as a reference for using the whisper.cpp project in other node projects.
## Install
```shell
npm install
```
## Compile
Make sure it is in the project root directory and compiled with make-js.
```shell
npx cmake-js compile -T whisper-addon -B Release
```
For Electron addon and cmake-js options, you can see [cmake-js](https://github.com/cmake-js/cmake-js) and make very few configuration changes.
> Such as appointing special cmake path:
> ```shell
> npx cmake-js compile -c 'xxx/cmake' -T whisper-addon -B Release
> ```
## Run
```shell
cd examples/addon.node
node index.js --language='language' --model='model-path' --fname_inp='file-path'
```
Because this is a simple Demo, only the above parameters are set in the node environment.
Other parameters can also be specified in the node environment.

View File

@ -1,23 +0,0 @@
const path = require("path");
const { whisper } = require(path.join(
__dirname,
"../../../build/Release/whisper-addon"
));
const { promisify } = require("util");
const whisperAsync = promisify(whisper);
const whisperParamsMock = {
language: "en",
model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
};
describe("Run whisper.node", () => {
test("it should receive a non-empty value", async () => {
let result = await whisperAsync(whisperParamsMock);
expect(result.length).toBeGreaterThan(0);
}, 10000);
});

View File

@ -1,338 +0,0 @@
#include "napi.h"
#include "common.h"
#include "whisper.h"
#include <string>
#include <thread>
#include <vector>
#include <cmath>
#include <cstdint>
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t duration_ms = 0;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = 5;
int32_t beam_size = -1;
float word_thold = 0.01f;
float entropy_thold = 2.4f;
float logprob_thold = -1.0f;
bool speed_up = false;
bool translate = false;
bool diarize = false;
bool output_txt = false;
bool output_vtt = false;
bool output_srt = false;
bool output_wts = false;
bool output_csv = false;
bool print_special = false;
bool print_colors = false;
bool print_progress = false;
bool no_timestamps = false;
std::string language = "en";
std::string prompt;
std::string model = "../../ggml-large.bin";
std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_out = {};
};
struct whisper_print_user_data {
const whisper_params * params;
const std::vector<std::vector<float>> * pcmf32s;
};
// 500 -> 00:05.000
// 6000 -> 01:00.000
std::string to_timestamp(int64_t t, bool comma = false) {
int64_t msec = t * 10;
int64_t hr = msec / (1000 * 60 * 60);
msec = msec - hr * (1000 * 60 * 60);
int64_t min = msec / (1000 * 60);
msec = msec - min * (1000 * 60);
int64_t sec = msec / 1000;
msec = msec - sec * 1000;
char buf[32];
snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
return std::string(buf);
}
int timestamp_to_sample(int64_t t, int n_samples) {
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
}
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
const int n_segments = whisper_full_n_segments(ctx);
std::string speaker = "";
int64_t t0;
int64_t t1;
// print the last n_new segments
const int s0 = n_segments - n_new;
if (s0 == 0) {
printf("\n");
}
for (int i = s0; i < n_segments; i++) {
if (!params.no_timestamps || params.diarize) {
t0 = whisper_full_get_segment_t0(ctx, i);
t1 = whisper_full_get_segment_t1(ctx, i);
}
if (!params.no_timestamps) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
}
if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1*energy1) {
speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
}
// colorful print bug
//
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s%s", speaker.c_str(), text);
// with timestamps or speakers: each segment on new line
if (!params.no_timestamps || params.diarize) {
printf("\n");
}
fflush(stdout);
}
}
int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
if (params.fname_inp.empty()) {
fprintf(stderr, "error: no input files specified\n");
return 2;
}
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
exit(0);
}
// whisper init
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
return 3;
}
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
continue;
}
// print system information
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
}
// print some info about the processing
{
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1);
fprintf(stderr, "\n");
}
// run the inference
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special = params.print_special;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms;
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.speed_up = params.speed_up;
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
wparams.initial_prompt = params.prompt.c_str();
whisper_print_user_data user_data = { &params, &pcmf32s };
// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &user_data;
}
// example for abort mechanism
// in this example, we do not abort the processing, but we could if the flag is set to true
// the callback is called before every encoder run - if it returns false, the processing is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
wparams.encoder_begin_callback_user_data = &is_aborted;
}
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "failed to process audio\n");
return 10;
}
}
}
const int n_segments = whisper_full_n_segments(ctx);
result.resize(n_segments);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
result[i].emplace_back(to_timestamp(t0, true));
result[i].emplace_back(to_timestamp(t1, true));
result[i].emplace_back(text);
}
whisper_print_timings(ctx);
whisper_free(ctx);
return 0;
}
class Worker : public Napi::AsyncWorker {
public:
Worker(Napi::Function& callback, whisper_params params)
: Napi::AsyncWorker(callback), params(params) {}
void Execute() override {
run(params, result);
}
void OnOK() override {
Napi::HandleScope scope(Env());
Napi::Object res = Napi::Array::New(Env(), result.size());
for (uint64_t i = 0; i < result.size(); ++i) {
Napi::Object tmp = Napi::Array::New(Env(), 3);
for (uint64_t j = 0; j < 3; ++j) {
tmp[j] = Napi::String::New(Env(), result[i][j]);
}
res[i] = tmp;
}
Callback().Call({Env().Null(), res});
}
private:
whisper_params params;
std::vector<std::vector<std::string>> result;
};
Napi::Value whisper(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
if (info.Length() <= 0 || !info[0].IsObject()) {
Napi::TypeError::New(env, "object expected").ThrowAsJavaScriptException();
}
whisper_params params;
Napi::Object whisper_params = info[0].As<Napi::Object>();
std::string language = whisper_params.Get("language").As<Napi::String>();
std::string model = whisper_params.Get("model").As<Napi::String>();
std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
params.language = language;
params.model = model;
params.fname_inp.emplace_back(input);
Napi::Function callback = info[1].As<Napi::Function>();
Worker* worker = new Worker(callback, params);
worker->Queue();
return env.Undefined();
}
Napi::Object Init(Napi::Env env, Napi::Object exports) {
exports.Set(
Napi::String::New(env, "whisper"),
Napi::Function::New(env, whisper)
);
return exports;
}
NODE_API_MODULE(whisper, Init);

View File

@ -1,36 +0,0 @@
const path = require("path");
const { whisper } = require(path.join(
__dirname,
"../../build/Release/whisper-addon"
));
const { promisify } = require("util");
const whisperAsync = promisify(whisper);
const whisperParams = {
language: "en",
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
fname_inp: "../../samples/jfk.wav",
};
const arguments = process.argv.slice(2);
const params = Object.fromEntries(
arguments.reduce((pre, item) => {
if (item.startsWith("--")) {
return [...pre, item.slice(2).split("=")];
}
return pre;
}, [])
);
for (const key in params) {
if (whisperParams.hasOwnProperty(key)) {
whisperParams[key] = params[key];
}
}
console.log("whisperParams =", whisperParams);
whisperAsync(whisperParams).then((result) => {
console.log(`Result from whisper: ${result}`);
});

View File

@ -1,16 +0,0 @@
{
"name": "whisper-addon",
"version": "0.0.0",
"description": "",
"main": "index.js",
"author": "Qanhe Chen",
"license": "MIT",
"scripts": {
"test": "jest"
},
"devDependencies": {
"cmake-js": "^7.1.1",
"jest": "^29.4.0",
"node-addon-api": "^5.0.0"
}
}

View File

@ -8,8 +8,6 @@ add_executable(${TARGET}
emscripten.cpp emscripten.cpp
) )
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE target_link_libraries(${TARGET} PRIVATE
whisper whisper
) )
@ -31,9 +29,9 @@ endif()
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \ --bind \
-s USE_PTHREADS=1 \ -s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE_STRICT=0 \ -s PTHREAD_POOL_SIZE=8 \
-s INITIAL_MEMORY=2000MB \ -s INITIAL_MEMORY=1024MB \
-s TOTAL_MEMORY=2000MB \ -s TOTAL_MEMORY=1024MB \
-s FORCE_FILESYSTEM=1 \ -s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \ -s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \ ${EXTRA_FLAGS} \

View File

@ -28,11 +28,6 @@ void bench_main(size_t index) {
return; return;
} }
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
}
if (int ret = whisper_encode(ctx, 0, n_threads) != 0) { if (int ret = whisper_encode(ctx, 0, n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret); fprintf(stderr, "error: failed to encode model: %d\n", ret);
return; return;
@ -57,7 +52,7 @@ EMSCRIPTEN_BINDINGS(bench) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) { emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) { for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) { if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file(path_model.c_str()); g_contexts[i] = whisper_init(path_model.c_str());
if (g_contexts[i] != nullptr) { if (g_contexts[i] != nullptr) {
if (g_worker.joinable()) { if (g_worker.joinable()) {
g_worker.join(); g_worker.join();

View File

@ -35,15 +35,6 @@
<br><br> <br><br>
<b>More examples:</b>
<a href="https://whisper.ggerganov.com/">main</a> |
<a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br>
<hr> <hr>
Select the model you would like to use and click the "Bench" button.<br> Select the model you would like to use and click the "Bench" button.<br>
@ -53,18 +44,11 @@
<div id="model-whisper"> <div id="model-whisper">
Whisper model: <span id="model-whisper-status"></span> Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button> <button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button> <button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<button id="fetch-whisper-small-en" onclick="loadWhisper('small.en')">small.en (466 MB)</button>
<input type="file" id="whisper-file" name="file" onchange="loadFile(event, 'whisper.bin')" />
<br><br>
Quantized models:<br><br>
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
<button id="fetch-whisper-base-en-q5_1" onclick="loadWhisper('base-en-q5_1')">base.en (Q5_1, 57 MB)</button>
<button id="fetch-whisper-small-en-q5_1" onclick="loadWhisper('small-en-q5_1')">small.en (Q5_1, 182 MB)</button>
<button id="fetch-whisper-medium-en-q5_0" onclick="loadWhisper('medium-en-q5_0')">medium.en (Q5_0, 515 MB)</button>
<button id="fetch-whisper-large-q5_0" onclick="loadWhisper('large-q5_0')">large (Q5_0, 1030 MB)</button>
<span id="fetch-whisper-progress"></span> <span id="fetch-whisper-progress"></span>
<input type="file" id="whisper-file" name="file" onchange="loadFile(event, 'whisper.bin')" />
</div> </div>
<br> <br>
@ -176,14 +160,6 @@
document.getElementById('fetch-whisper-tiny-en').style.display = 'none'; document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none'; document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('fetch-whisper-small-en').style.display = 'none';
document.getElementById('fetch-whisper-tiny-en-q5_1' ).style.display = 'none';
document.getElementById('fetch-whisper-base-en-q5_1' ).style.display = 'none';
document.getElementById('fetch-whisper-small-en-q5_1' ).style.display = 'none';
document.getElementById('fetch-whisper-medium-en-q5_0').style.display = 'none';
document.getElementById('fetch-whisper-large-q5_0' ).style.display = 'none';
document.getElementById('whisper-file' ).style.display = 'none'; document.getElementById('whisper-file' ).style.display = 'none';
document.getElementById('model-whisper-status' ).innerHTML = 'loaded model: ' + file.name; document.getElementById('model-whisper-status' ).innerHTML = 'loaded model: ' + file.name;
} }
@ -192,42 +168,19 @@
let urls = { let urls = {
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin', 'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin', 'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
'small.en': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en.bin',
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
'small-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en-q5_1.bin',
'medium-en-q5_0':'https://whisper.ggerganov.com/ggml-model-whisper-medium.en-q5_0.bin',
'large-q5_0': 'https://whisper.ggerganov.com/ggml-model-whisper-large-q5_0.bin',
}; };
let sizes = { let sizes = {
'tiny.en': 75, 'tiny.en': 75,
'base.en': 142, 'base.en': 142,
'small.en': 466,
'tiny-en-q5_1': 31,
'base-en-q5_1': 57,
'small-en-q5_1': 182,
'medium-en-q5_0': 515,
'large-q5_0': 1030,
}; };
let url = urls[model]; let url = urls[model];
let dst = 'whisper.bin'; let dst = 'whisper.bin';
let size_mb = sizes[model]; let size_mb = sizes[model];
document.getElementById('fetch-whisper-tiny-en').style.display = 'none'; document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none'; document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('fetch-whisper-small-en').style.display = 'none';
document.getElementById('fetch-whisper-tiny-en-q5_1' ).style.display = 'none';
document.getElementById('fetch-whisper-base-en-q5_1' ).style.display = 'none';
document.getElementById('fetch-whisper-small-en-q5_1' ).style.display = 'none';
document.getElementById('fetch-whisper-medium-en-q5_0').style.display = 'none';
document.getElementById('fetch-whisper-large-q5_0' ).style.display = 'none';
document.getElementById('whisper-file' ).style.display = 'none';
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... '; document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
cbProgress = function(p) { cbProgress = function(p) {
@ -237,18 +190,9 @@
cbCancel = function() { cbCancel = function() {
var el; var el;
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block'; el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block'; el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-small-en'); if (el) el.style.display = 'inline-block'; el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
el = document.getElementById('fetch-whisper-tiny-en-q5_1' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en-q5_1' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-small-en-q5_1' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-medium-en-q5_0'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-large-q5_0' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('whisper-file' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
}; };
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea); loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);

View File

@ -1,6 +1,3 @@
set(TARGET bench) set(TARGET bench)
add_executable(${TARGET} bench.cpp) add_executable(${TARGET} bench.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})

View File

@ -7,7 +7,6 @@
// command-line parameters // command-line parameters
struct whisper_params { struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t what = 0; // what to benchmark: 0 - whisper ecoder, 1 - memcpy, 2 - ggml_mul_mat
std::string model = "models/ggml-base.en.bin"; std::string model = "models/ggml-base.en.bin";
}; };
@ -24,7 +23,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
} }
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
else { else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params); whisper_print_usage(argc, argv, params);
@ -43,17 +41,19 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -h, --help [default] show this help message and exit\n"); fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
fprintf(stderr, " %-7s 0 - whisper encoder\n", "");
fprintf(stderr, " %-7s 1 - memcpy\n", "");
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
int whisper_bench_encoder(const whisper_params & params) { int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
// whisper init // whisper init
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); struct whisper_context * ctx = whisper_init(params.model.c_str());
{ {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
@ -92,22 +92,3 @@ int whisper_bench_encoder(const whisper_params & params) {
return 0; return 0;
} }
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
int ret = -1;
switch (params.what) {
case 0: ret = whisper_bench_encoder(params); break;
case 1: ret = whisper_bench_memcpy(params.n_threads); break;
case 2: ret = whisper_bench_ggml_mul_mat(params.n_threads); break;
default: fprintf(stderr, "error: unknown benchmark: %d\n", params.what); break;
}
return ret;
}

View File

@ -8,10 +8,7 @@ add_executable(${TARGET}
emscripten.cpp emscripten.cpp
) )
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE target_link_libraries(${TARGET} PRIVATE
common
whisper whisper
) )

View File

@ -1,5 +1,4 @@
#include "ggml.h" #include "ggml.h"
#include "common.h"
#include "whisper.h" #include "whisper.h"
#include <emscripten.h> #include <emscripten.h>
@ -28,11 +27,92 @@ std::string g_transcribed = "";
std::vector<float> g_pcmf32; std::vector<float> g_pcmf32;
static std::string trim(const std::string & s) {
std::regex e("^\\s+|\\s+$");
return std::regex_replace(s, e, "");
}
static void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
float y = data[0];
for (size_t i = 1; i < data.size(); i++) {
y = alpha * (y + data[i] - data[i - 1]);
data[i] = y;
}
}
// compute similarity between two strings using Levenshtein distance
static float similarity(const std::string & s0, const std::string & s1) {
const size_t len0 = s0.size() + 1;
const size_t len1 = s1.size() + 1;
std::vector<int> col(len1, 0);
std::vector<int> prevCol(len1, 0);
for (size_t i = 0; i < len1; i++) {
prevCol[i] = i;
}
for (size_t i = 0; i < len0; i++) {
col[0] = i;
for (size_t j = 1; j < len1; j++) {
col[j] = std::min(std::min(1 + col[j - 1], 1 + prevCol[j]), prevCol[j - 1] + (s0[i - 1] == s1[j - 1] ? 0 : 1));
}
col.swap(prevCol);
}
const float dist = prevCol[len1 - 1];
return 1.0f - (dist / std::max(s0.size(), s1.size()));
}
void command_set_status(const std::string & status) { void command_set_status(const std::string & status) {
std::lock_guard<std::mutex> lock(g_mutex); std::lock_guard<std::mutex> lock(g_mutex);
g_status = status; g_status = status;
} }
bool command_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;
if (n_samples_last >= n_samples) {
// not enough samples - assume no speech
return false;
}
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, freq_thold, sample_rate);
}
float energy_all = 0.0f;
float energy_last = 0.0f;
for (size_t i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
energy_last += fabsf(pcmf32[i]);
}
}
energy_all /= n_samples;
energy_last /= n_samples_last;
if (verbose) {
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
}
if (energy_last > vad_thold*energy_all) {
return false;
}
return true;
}
std::string command_transcribe(whisper_context * ctx, const whisper_full_params & wparams, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) { std::string command_transcribe(whisper_context * ctx, const whisper_full_params & wparams, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
@ -75,7 +155,7 @@ void command_get_audio(int ms, int sample_rate, std::vector<float> & audio) {
const int64_t n_samples = (ms * sample_rate) / 1000; const int64_t n_samples = (ms * sample_rate) / 1000;
int64_t n_take = 0; int64_t n_take = 0;
if (n_samples > (int) g_pcmf32.size()) { if (g_pcmf32.size() < n_samples) {
n_take = g_pcmf32.size(); n_take = g_pcmf32.size();
} else { } else {
n_take = n_samples; n_take = n_samples;
@ -107,6 +187,7 @@ void command_main(size_t index) {
printf("command: using %d threads\n", wparams.n_threads); printf("command: using %d threads\n", wparams.n_threads);
bool is_running = true;
bool have_prompt = false; bool have_prompt = false;
bool ask_prompt = true; bool ask_prompt = true;
bool print_energy = false; bool print_energy = false;
@ -152,7 +233,7 @@ void command_main(size_t index) {
{ {
command_get_audio(vad_ms, WHISPER_SAMPLE_RATE, pcmf32_cur); command_get_audio(vad_ms, WHISPER_SAMPLE_RATE, pcmf32_cur);
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) { if (command_vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
command_set_status("Speech detected! Processing ..."); command_set_status("Speech detected! Processing ...");
@ -243,7 +324,7 @@ EMSCRIPTEN_BINDINGS(command) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) { emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) { for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) { if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file(path_model.c_str()); g_contexts[i] = whisper_init(path_model.c_str());
if (g_contexts[i] != nullptr) { if (g_contexts[i] != nullptr) {
g_running = true; g_running = true;
if (g_worker.joinable()) { if (g_worker.joinable()) {

View File

@ -35,15 +35,6 @@
<br><br> <br><br>
<b>More examples:</b>
<a href="https://whisper.ggerganov.com/">main</a> |
<a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br>
<hr> <hr>
Select the model you would like to use, click the "Start" button and follow the instructions. Select the model you would like to use, click the "Start" button and follow the instructions.
@ -54,10 +45,6 @@
Whisper model: <span id="model-whisper-status"></span> Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button> <button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button> <button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<br><br>
Quantized models:<br><br>
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
<button id="fetch-whisper-base-en-q5_1" onclick="loadWhisper('base-en-q5_1')">base.en (Q5_1, 57 MB)</button>
<span id="fetch-whisper-progress"></span> <span id="fetch-whisper-progress"></span>
<!-- <!--
@ -175,17 +162,11 @@
let urls = { let urls = {
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin', 'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin', 'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
}; };
let sizes = { let sizes = {
'tiny.en': 75, 'tiny.en': 75,
'base.en': 142, 'base.en': 142,
'tiny-en-q5_1': 31,
'base-en-q5_1': 57,
}; };
let url = urls[model]; let url = urls[model];
@ -196,10 +177,6 @@
document.getElementById('fetch-whisper-tiny-en').style.display = 'none'; document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none'; document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('fetch-whisper-tiny-en-q5_1').style.display = 'none';
document.getElementById('fetch-whisper-base-en-q5_1').style.display = 'none';
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... '; document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
cbProgress = function(p) { cbProgress = function(p) {
@ -211,10 +188,6 @@
var el; var el;
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block'; el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block'; el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-tiny-en-q5_1'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en-q5_1'); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = ''; el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
}; };

View File

@ -1,9 +1,7 @@
if (WHISPER_SDL2) if (WHISPER_SUPPORT_SDL2)
# command # command
set(TARGET command) set(TARGET command)
add_executable(${TARGET} command.cpp) add_executable(${TARGET} command.cpp)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
include(DefaultTargetOptions) target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
endif () endif ()

View File

@ -9,19 +9,7 @@ More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/
# On Raspberry Pi, use tiny or base models + "-ac 768" for better performance # On Raspberry Pi, use tiny or base models + "-ac 768" for better performance
./command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0 ./command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0
```
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
Web version: [examples/command.wasm](/examples/command.wasm)
## Guided mode
"Guided mode" allows you to specify a list of commands (i.e. strings) and the transcription will be guided to classify your command into one from the list. This can be useful in situations where a device is listening only for a small subset of commands.
Initial tests show that this approach might be extremely efficient in terms of performance, since it integrates very well with the "partial Encoder" idea from #137.
```bash
# Run in guided mode, the list of allowed commands is in commands.txt # Run in guided mode, the list of allowed commands is in commands.txt
./command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt ./command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt
@ -29,8 +17,9 @@ Initial tests show that this approach might be extremely efficient in terms of p
./command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0 ./command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0
``` ```
https://user-images.githubusercontent.com/1991296/207435352-8fc4ed3f-bde5-4555-9b8b-aeeb76bee969.mp4 https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
Web version: [examples/command.wasm](/examples/command.wasm)
## Building ## Building

File diff suppressed because it is too large Load Diff

View File

@ -1,235 +0,0 @@
#include "common-ggml.h"
#include <regex>
#include <map>
static const std::map<std::string, enum ggml_ftype> GGML_FTYPE_MAP = {
{"q4_0", GGML_FTYPE_MOSTLY_Q4_0},
{"q4_1", GGML_FTYPE_MOSTLY_Q4_1},
{"q5_0", GGML_FTYPE_MOSTLY_Q5_0},
{"q5_1", GGML_FTYPE_MOSTLY_Q5_1},
{"q8_0", GGML_FTYPE_MOSTLY_Q8_0},
};
void ggml_print_ftypes(FILE * fp) {
for (auto it = GGML_FTYPE_MAP.begin(); it != GGML_FTYPE_MAP.end(); it++) {
fprintf(fp, " type = \"%s\" or %d\n", it->first.c_str(), it->second);
}
}
enum ggml_ftype ggml_parse_ftype(const char * str) {
enum ggml_ftype ftype;
if (str[0] == 'q') {
const auto it = GGML_FTYPE_MAP.find(str);
if (it == GGML_FTYPE_MAP.end()) {
fprintf(stderr, "%s: unknown ftype '%s'\n", __func__, str);
return GGML_FTYPE_UNKNOWN;
}
ftype = it->second;
} else {
ftype = (enum ggml_ftype) atoi(str);
}
return ftype;
}
bool ggml_common_quantize_0(
std::ifstream & finp,
std::ofstream & fout,
const ggml_ftype ftype,
const std::vector<std::string> & to_quant,
const std::vector<std::string> & to_skip) {
ggml_type qtype = GGML_TYPE_F32;
switch (ftype) {
case GGML_FTYPE_MOSTLY_Q4_0: qtype = GGML_TYPE_Q4_0; break;
case GGML_FTYPE_MOSTLY_Q4_1: qtype = GGML_TYPE_Q4_1; break;
case GGML_FTYPE_MOSTLY_Q5_0: qtype = GGML_TYPE_Q5_0; break;
case GGML_FTYPE_MOSTLY_Q5_1: qtype = GGML_TYPE_Q5_1; break;
case GGML_FTYPE_MOSTLY_Q8_0: qtype = GGML_TYPE_Q8_0; break;
case GGML_FTYPE_UNKNOWN:
case GGML_FTYPE_ALL_F32:
case GGML_FTYPE_MOSTLY_F16:
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16:
{
fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
return false;
}
};
if (!ggml_is_quantized(qtype)) {
fprintf(stderr, "%s: invalid quantization type %d (%s)\n", __func__, qtype, ggml_type_name(qtype));
return false;
}
size_t total_size_org = 0;
size_t total_size_new = 0;
std::vector<float> work;
std::vector<uint8_t> data_u8;
std::vector<ggml_fp16_t> data_f16;
std::vector<float> data_f32;
std::vector<int64_t> hist_all(1 << 4, 0);
while (true) {
int32_t n_dims;
int32_t length;
int32_t ttype;
finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
finp.read(reinterpret_cast<char *>(&length), sizeof(length));
finp.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
if (finp.eof()) {
break;
}
int32_t nelements = 1;
int32_t ne[4] = { 1, 1, 1, 1 };
for (int i = 0; i < n_dims; ++i) {
finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
nelements *= ne[i];
}
std::string name(length, 0);
finp.read (&name[0], length);
printf("%64s - [%5d, %5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype));
bool quantize = false;
// check if we should quantize this tensor
for (const auto & s : to_quant) {
if (std::regex_match(name, std::regex(s))) {
quantize = true;
break;
}
}
// check if we should skip this tensor
for (const auto & s : to_skip) {
if (std::regex_match(name, std::regex(s))) {
quantize = false;
break;
}
}
// quantize only 2D tensors
quantize &= (n_dims == 2);
if (quantize) {
if (ttype != GGML_TYPE_F32 && ttype != GGML_TYPE_F16) {
fprintf(stderr, "%s: unsupported ttype %d (%s) for integer quantization\n", __func__, ttype, ggml_type_name((ggml_type) ttype));
return false;
}
if (ttype == GGML_TYPE_F16) {
data_f16.resize(nelements);
finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
data_f32.resize(nelements);
for (int i = 0; i < nelements; ++i) {
data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
}
} else {
data_f32.resize(nelements);
finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
}
ttype = qtype;
} else {
const int bpe = (ttype == 0) ? sizeof(float) : sizeof(uint16_t);
data_u8.resize(nelements*bpe);
finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bpe);
}
fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fout.write(reinterpret_cast<char *>(&length), sizeof(length));
fout.write(reinterpret_cast<char *>(&ttype), sizeof(ttype));
for (int i = 0; i < n_dims; ++i) {
fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
}
fout.write(&name[0], length);
if (quantize) {
work.resize(nelements); // for quantization
size_t cur_size = 0;
std::vector<int64_t> hist_cur(1 << 4, 0);
switch ((ggml_type) ttype) {
case GGML_TYPE_Q4_0:
{
cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q4_1:
{
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q5_0:
{
cur_size = ggml_quantize_q5_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q5_1:
{
cur_size = ggml_quantize_q5_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q8_0:
{
cur_size = ggml_quantize_q8_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
case GGML_TYPE_Q8_1:
case GGML_TYPE_COUNT:
{
fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype));
return false;
}
}
fout.write(reinterpret_cast<char *>(work.data()), cur_size);
total_size_new += cur_size;
printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0);
for (int i = 0; i < (int) hist_cur.size(); ++i) {
hist_all[i] += hist_cur[i];
}
for (int i = 0; i < (int) hist_cur.size(); ++i) {
printf("%5.3f ", hist_cur[i] / (float)nelements);
}
printf("\n");
} else {
printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0);
fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
total_size_new += data_u8.size();
}
total_size_org += nelements * sizeof(float);
}
printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
printf("%s: quant size = %8.2f MB | ftype = %d (%s)\n", __func__, total_size_new/1024.0/1024.0, ftype, ggml_type_name(qtype));
{
int64_t sum_all = 0;
for (int i = 0; i < (int) hist_all.size(); ++i) {
sum_all += hist_all[i];
}
printf("%s: hist: ", __func__);
for (int i = 0; i < (int) hist_all.size(); ++i) {
printf("%5.3f ", hist_all[i] / (float)sum_all);
}
printf("\n");
}
return true;
}

View File

@ -1,18 +0,0 @@
#pragma once
#include "ggml.h"
#include <fstream>
#include <vector>
#include <string>
enum ggml_ftype ggml_parse_ftype(const char * str);
void ggml_print_ftypes(FILE * fp = stderr);
bool ggml_common_quantize_0(
std::ifstream & finp,
std::ofstream & fout,
const ggml_ftype ftype,
const std::vector<std::string> & to_quant,
const std::vector<std::string> & to_skip);

View File

@ -1,226 +0,0 @@
#include "common-sdl.h"
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;
m_running = false;
}
audio_async::~audio_async() {
if (m_dev_id_in) {
SDL_CloseAudioDevice(m_dev_id_in);
}
}
bool audio_async::init(int capture_id, int sample_rate) {
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
return false;
}
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
{
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
for (int i = 0; i < nDevices; i++) {
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
}
}
SDL_AudioSpec capture_spec_requested;
SDL_AudioSpec capture_spec_obtained;
SDL_zero(capture_spec_requested);
SDL_zero(capture_spec_obtained);
capture_spec_requested.freq = sample_rate;
capture_spec_requested.format = AUDIO_F32;
capture_spec_requested.channels = 1;
capture_spec_requested.samples = 1024;
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
audio_async * audio = (audio_async *) userdata;
audio->callback(stream, len);
};
capture_spec_requested.userdata = this;
if (capture_id >= 0) {
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
} else {
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
}
if (!m_dev_id_in) {
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
m_dev_id_in = 0;
return false;
} else {
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
capture_spec_requested.format);
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
capture_spec_requested.channels);
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
}
m_sample_rate = capture_spec_obtained.freq;
m_audio.resize((m_sample_rate*m_len_ms)/1000);
return true;
}
bool audio_async::resume() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
return false;
}
if (m_running) {
fprintf(stderr, "%s: already running!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 0);
m_running = true;
return true;
}
bool audio_async::pause() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: already paused!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 1);
m_running = false;
return true;
}
bool audio_async::clear() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return false;
}
{
std::lock_guard<std::mutex> lock(m_mutex);
m_audio_pos = 0;
m_audio_len = 0;
}
return true;
}
// callback to be called by SDL
void audio_async::callback(uint8_t * stream, int len) {
if (!m_running) {
return;
}
const size_t n_samples = len / sizeof(float);
m_audio_new.resize(n_samples);
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
{
std::lock_guard<std::mutex> lock(m_mutex);
if (m_audio_pos + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - m_audio_pos;
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = m_audio.size();
} else {
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
}
}
}
void audio_async::get(int ms, std::vector<float> & result) {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
return;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return;
}
result.clear();
{
std::lock_guard<std::mutex> lock(m_mutex);
if (ms <= 0) {
ms = m_len_ms;
}
size_t n_samples = (m_sample_rate * ms) / 1000;
if (n_samples > m_audio_len) {
n_samples = m_audio_len;
}
result.resize(n_samples);
int s0 = m_audio_pos - n_samples;
if (s0 < 0) {
s0 += m_audio.size();
}
if (s0 + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - s0;
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
} else {
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
}
}
}
bool sdl_poll_events() {
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
return false;
} break;
default:
break;
}
}
return true;
}

View File

@ -1,50 +0,0 @@
#pragma once
#include <SDL.h>
#include <SDL_audio.h>
#include <atomic>
#include <cstdint>
#include <vector>
#include <mutex>
//
// SDL Audio capture
//
class audio_async {
public:
audio_async(int len_ms);
~audio_async();
bool init(int capture_id, int sample_rate);
// start capturing audio via the provided SDL callback
// keep last len_ms seconds of audio in a circular buffer
bool resume();
bool pause();
bool clear();
// callback to be called by SDL
void callback(uint8_t * stream, int len);
// get audio data from the circular buffer
void get(int ms, std::vector<float> & audio);
private:
SDL_AudioDeviceID m_dev_id_in = 0;
int m_len_ms = 0;
int m_sample_rate = 0;
std::atomic_bool m_running;
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
size_t m_audio_pos = 0;
size_t m_audio_len = 0;
};
// Return false if need to quit
bool sdl_poll_events();

View File

@ -1,539 +0,0 @@
#include "common.h"
// third-party utilities
// use your favorite implementations
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include <cmath>
#include <fstream>
#include <regex>
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-s" || arg == "--seed") {
params.seed = std::stoi(argv[++i]);
} else if (arg == "-t" || arg == "--threads") {
params.n_threads = std::stoi(argv[++i]);
} else if (arg == "-p" || arg == "--prompt") {
params.prompt = argv[++i];
} else if (arg == "-n" || arg == "--n_predict") {
params.n_predict = std::stoi(argv[++i]);
} else if (arg == "--top_k") {
params.top_k = std::stoi(argv[++i]);
} else if (arg == "--top_p") {
params.top_p = std::stof(argv[++i]);
} else if (arg == "--temp") {
params.temp = std::stof(argv[++i]);
} else if (arg == "-b" || arg == "--batch_size") {
params.n_batch = std::stoi(argv[++i]);
} else if (arg == "-m" || arg == "--model") {
params.model = argv[++i];
} else if (arg == "-h" || arg == "--help") {
gpt_print_usage(argc, argv, params);
exit(0);
} else if (arg == "-f" || arg == "--file") {
if (++i > argc) {
fprintf(stderr, "Invalid file param");
break;
}
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
break;
}
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
if (params.prompt.back() == '\n') {
params.prompt.pop_back();
}
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
gpt_print_usage(argc, argv, params);
exit(0);
}
}
return true;
}
void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
fprintf(stderr, " prompt to start generation with (default: random)\n");
fprintf(stderr, " -f FNAME, --file FNAME\n");
fprintf(stderr, " load prompt from a file\n");
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict);
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
fprintf(stderr, "\n");
}
std::string gpt_random_prompt(std::mt19937 & rng) {
const int r = rng() % 10;
switch (r) {
case 0: return "So";
case 1: return "Once upon a time";
case 2: return "When";
case 3: return "The";
case 4: return "After";
case 5: return "If";
case 6: return "import";
case 7: return "He";
case 8: return "She";
case 9: return "They";
default: return "To";
}
return "The";
}
std::string trim(const std::string & s) {
std::regex e("^\\s+|\\s+$");
return std::regex_replace(s, e, "");
}
std::string replace(const std::string & s, const std::string & from, const std::string & to) {
std::string result = s;
size_t pos = 0;
while ((pos = result.find(from, pos)) != std::string::npos) {
result.replace(pos, from.length(), to);
pos += to.length();
}
return result;
}
std::map<std::string, int32_t> json_parse(const std::string & fname) {
std::map<std::string, int32_t> result;
// read file into string
std::string json;
{
std::ifstream ifs(fname);
if (!ifs) {
fprintf(stderr, "Failed to open %s\n", fname.c_str());
exit(1);
}
json = std::string((std::istreambuf_iterator<char>(ifs)),
(std::istreambuf_iterator<char>()));
}
if (json[0] != '{') {
return result;
}
// parse json
{
bool has_key = false;
bool in_token = false;
std::string str_key = "";
std::string str_val = "";
int n = json.size();
for (int i = 1; i < n; ++i) {
if (!in_token) {
if (json[i] == ' ') continue;
if (json[i] == '"') {
in_token = true;
continue;
}
} else {
if (json[i] == '\\' && i+1 < n) {
if (has_key == false) {
str_key += json[i];
} else {
str_val += json[i];
}
++i;
} else if (json[i] == '"') {
if (has_key == false) {
has_key = true;
++i;
while (json[i] == ' ') ++i;
++i; // :
while (json[i] == ' ') ++i;
if (json[i] != '\"') {
while (json[i] != ',' && json[i] != '}') {
str_val += json[i++];
}
has_key = false;
} else {
in_token = true;
continue;
}
} else {
has_key = false;
}
str_key = ::replace(str_key, "\\u0120", " " ); // \u0120 -> space
str_key = ::replace(str_key, "\\u010a", "\n"); // \u010a -> new line
str_key = ::replace(str_key, "\\\"", "\""); // \\\" -> "
try {
result[str_key] = std::stoi(str_val);
} catch (...) {
//fprintf(stderr, "%s: ignoring key '%s' with value '%s'\n", fname.c_str(), str_key.c_str(), str_val.c_str());
}
str_key = "";
str_val = "";
in_token = false;
continue;
}
if (has_key == false) {
str_key += json[i];
} else {
str_val += json[i];
}
}
}
}
return result;
}
void gpt_vocab::add_special_token(const std::string & token) {
special_tokens.push_back(token);
}
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
std::vector<std::string> words;
// first split the text into words
{
std::string str = text;
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
// Generate the subpattern from the special_tokens vector if it's not empty
if (!vocab.special_tokens.empty()) {
std::string special_tokens_subpattern;
for (const auto & token : vocab.special_tokens) {
if (!special_tokens_subpattern.empty()) {
special_tokens_subpattern += "|";
}
special_tokens_subpattern += token;
}
// Modify the regex pattern with the generated special tokens subpattern
pat = special_tokens_subpattern + "|" + pat;
}
std::regex re(pat);
std::smatch m;
while (std::regex_search(str, m, re)) {
for (auto x : m) {
words.push_back(x);
}
str = m.suffix();
}
}
// find the longest tokens that form the words:
std::vector<gpt_vocab::id> tokens;
for (const auto & word : words) {
if (word.size() == 0) continue;
int i = 0;
int n = word.size();
while (i < n) {
int j = n;
while (j > i) {
auto it = vocab.token_to_id.find(word.substr(i, j-i));
if (it != vocab.token_to_id.end()) {
tokens.push_back(it->second);
i = j;
break;
}
--j;
}
if (i == n) {
break;
}
if (j == i) {
auto sub = word.substr(i, 1);
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
tokens.push_back(vocab.token_to_id.at(sub));
} else {
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
}
++i;
}
}
}
return tokens;
}
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
vocab.token_to_id = ::json_parse(fname);
for (const auto & kv : vocab.token_to_id) {
vocab.id_to_token[kv.second] = kv.first;
}
printf("%s: vocab size = %d\n", __func__, (int) vocab.token_to_id.size());
// print the vocabulary
//for (auto kv : vocab.token_to_id) {
// printf("'%s' -> %d\n", kv.first.data(), kv.second);
//}
return true;
}
gpt_vocab::id gpt_sample_top_k_top_p(
const gpt_vocab & vocab,
const float * logits,
int top_k,
double top_p,
double temp,
std::mt19937 & rng) {
int n_logits = vocab.id_to_token.size();
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
logits_id.reserve(n_logits);
{
const double scale = 1.0/temp;
for (int i = 0; i < n_logits; ++i) {
logits_id.push_back(std::make_pair(logits[i]*scale, i));
}
}
// find the top K tokens
std::partial_sort(
logits_id.begin(),
logits_id.begin() + top_k, logits_id.end(),
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
return a.first > b.first;
});
logits_id.resize(top_k);
double maxl = -INFINITY;
for (const auto & kv : logits_id) {
maxl = std::max(maxl, kv.first);
}
// compute probs for the top K tokens
std::vector<double> probs;
probs.reserve(logits_id.size());
double sum = 0.0;
for (const auto & kv : logits_id) {
double p = exp(kv.first - maxl);
probs.push_back(p);
sum += p;
}
// normalize the probs
for (auto & p : probs) {
p /= sum;
}
if (top_p < 1.0f) {
double cumsum = 0.0f;
for (int i = 0; i < top_k; i++) {
cumsum += probs[i];
if (cumsum >= top_p) {
top_k = i + 1;
probs.resize(top_k);
logits_id.resize(top_k);
break;
}
}
cumsum = 1.0/cumsum;
for (int i = 0; i < (int) probs.size(); i++) {
probs[i] *= cumsum;
}
}
//printf("\n");
//for (int i = 0; i < (int) probs.size(); i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
//}
//exit(0);
std::discrete_distribution<> dist(probs.begin(), probs.end());
int idx = dist(rng);
return logits_id[idx].second;
}
bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname == "-") {
{
uint8_t buf[1024];
while (true)
{
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return false;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
return false;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str());
return false;
}
if (stereo && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str());
return false;
}
if (wav.sampleRate != COMMON_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE/1000);
return false;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str());
return false;
}
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (stereo) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
return true;
}
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
float y = data[0];
for (size_t i = 1; i < data.size(); i++) {
y = alpha * (y + data[i] - data[i - 1]);
data[i] = y;
}
}
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;
if (n_samples_last >= n_samples) {
// not enough samples - assume no speech
return false;
}
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, freq_thold, sample_rate);
}
float energy_all = 0.0f;
float energy_last = 0.0f;
for (int i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
energy_last += fabsf(pcmf32[i]);
}
}
energy_all /= n_samples;
energy_last /= n_samples_last;
if (verbose) {
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
}
if (energy_last > vad_thold*energy_all) {
return false;
}
return true;
}
float similarity(const std::string & s0, const std::string & s1) {
const size_t len0 = s0.size() + 1;
const size_t len1 = s1.size() + 1;
std::vector<int> col(len1, 0);
std::vector<int> prevCol(len1, 0);
for (size_t i = 0; i < len1; i++) {
prevCol[i] = i;
}
for (size_t i = 0; i < len0; i++) {
col[0] = i;
for (size_t j = 1; j < len1; j++) {
col[j] = std::min(std::min(1 + col[j - 1], 1 + prevCol[j]), prevCol[j - 1] + (i > 0 && s0[i - 1] == s1[j - 1] ? 0 : 1));
}
col.swap(prevCol);
}
const float dist = prevCol[len1 - 1];
return 1.0f - (dist / std::max(s0.size(), s1.size()));
}

View File

@ -1,125 +0,0 @@
// Various helper functions and utilities
#pragma once
#include <string>
#include <map>
#include <vector>
#include <random>
#include <thread>
#define COMMON_SAMPLE_RATE 16000
//
// CLI argument parsing
//
struct gpt_params {
int32_t seed = -1; // RNG seed
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_predict = 200; // new tokens to predict
// sampling parameters
int32_t top_k = 40;
float top_p = 0.9f;
float temp = 0.9f;
int32_t n_batch = 8; // batch size for prompt processing
std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path
std::string prompt;
};
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
std::string gpt_random_prompt(std::mt19937 & rng);
//
// Vocab utils
//
std::string trim(const std::string & s);
std::string replace(
const std::string & s,
const std::string & from,
const std::string & to);
struct gpt_vocab {
using id = int32_t;
using token = std::string;
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
std::vector<std::string> special_tokens;
void add_special_token(const std::string & token);
};
// poor-man's JSON parsing
std::map<std::string, int32_t> json_parse(const std::string & fname);
// split text into tokens
//
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
//
// Regex (Python):
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
//
// Regex (C++):
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
//
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);
// load the tokens from encoder.json
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
// sample next token given probabilities for each embedding
//
// - consider only the top K tokens
// - from them, consider only the top tokens with cumulative probability > P
//
// TODO: not sure if this implementation is correct
// TODO: temperature is not implemented
//
gpt_vocab::id gpt_sample_top_k_top_p(
const gpt_vocab & vocab,
const float * logits,
int top_k,
double top_p,
double temp,
std::mt19937 & rng);
//
// Audio utils
//
// Read WAV audio file and store the PCM data into pcmf32
// The sample rate of the audio must be equal to COMMON_SAMPLE_RATE
// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM
bool read_wav(
const std::string & fname,
std::vector<float> & pcmf32,
std::vector<std::vector<float>> & pcmf32s,
bool stereo);
// Apply a high-pass frequency filter to PCM audio
// Suppresses frequencies below cutoff Hz
void high_pass_filter(
std::vector<float> & data,
float cutoff,
float sample_rate);
// Basic voice activity detection (VAD) using audio energy adaptive threshold
bool vad_simple(
std::vector<float> & pcmf32,
int sample_rate,
int last_ms,
float vad_thold,
float freq_thold,
bool verbose);
// compute similarity between two strings using Levenshtein distance
float similarity(const std::string & s0, const std::string & s1);

View File

@ -8,7 +8,7 @@ function convertTypedArray(src, type) {
var printTextarea = (function() { var printTextarea = (function() {
var element = document.getElementById('output'); var element = document.getElementById('output');
if (element) element.value = ''; // clear browser cache if (element) element.alue = ''; // clear browser cache
return function(text) { return function(text) {
if (arguments.length > 1) text = Array.prototype.slice.call(arguments).join(' '); if (arguments.length > 1) text = Array.prototype.slice.call(arguments).join(' ');
console.log(text); console.log(text);
@ -88,15 +88,11 @@ async function fetchRemote(url, cbProgress, cbPrint) {
// - check if the data is already in the IndexedDB // - check if the data is already in the IndexedDB
// - if not, fetch it from the remote URL and store it in the IndexedDB // - if not, fetch it from the remote URL and store it in the IndexedDB
function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) { function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) {
if (!navigator.storage || !navigator.storage.estimate) { // query the storage quota and print it
cbPrint('loadRemote: navigator.storage.estimate() is not supported'); navigator.storage.estimate().then(function (estimate) {
} else { cbPrint('loadRemote: storage quota: ' + estimate.quota + ' bytes');
// query the storage quota and print it cbPrint('loadRemote: storage usage: ' + estimate.usage + ' bytes');
navigator.storage.estimate().then(function (estimate) { });
cbPrint('loadRemote: storage quota: ' + estimate.quota + ' bytes');
cbPrint('loadRemote: storage usage: ' + estimate.usage + ' bytes');
});
}
// check if the data is already in the IndexedDB // check if the data is already in the IndexedDB
var rq = indexedDB.open(dbName, dbVersion); var rq = indexedDB.open(dbName, dbVersion);
@ -145,15 +141,7 @@ function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) {
var db = event.target.result; var db = event.target.result;
var tx = db.transaction(['models'], 'readwrite'); var tx = db.transaction(['models'], 'readwrite');
var os = tx.objectStore('models'); var os = tx.objectStore('models');
var rq = os.put(data, url);
var rq = null;
try {
var rq = os.put(data, url);
} catch (e) {
cbPrint('loadRemote: failed to store "' + url + '" in the IndexedDB: \n' + e);
cbCancel();
return;
}
rq.onsuccess = function (event) { rq.onsuccess = function (event) {
cbPrint('loadRemote: "' + url + '" stored in the IndexedDB'); cbPrint('loadRemote: "' + url + '" stored in the IndexedDB');
@ -188,6 +176,7 @@ function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) {
rq.onabort = function (event) { rq.onabort = function (event) {
cbPrint('loadRemote: failed to open IndexedDB: abort'); cbPrint('loadRemote: failed to open IndexedDB: abort');
cbCancel();
}; };
} }

View File

@ -100,7 +100,7 @@ while [ $running -eq 1 ]; do
err=$(cat /tmp/whisper-live.err | wc -l) err=$(cat /tmp/whisper-live.err | wc -l)
done done
./main -t 8 -m ./models/ggml-${model}.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1 ./main -t 8 -m ./models/ggml-base.en.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1
while [ $SECONDS -lt $((($i+1)*$step_s)) ]; do while [ $SECONDS -lt $((($i+1)*$step_s)) ]; do
sleep 1 sleep 1

View File

@ -1,6 +1,3 @@
set(TARGET main) set(TARGET main)
add_executable(${TARGET} main.cpp) add_executable(${TARGET} main.cpp)
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})

View File

@ -9,36 +9,25 @@ It can be used as a reference for using the `whisper.cpp` library in other proje
usage: ./main [options] file0.wav file1.wav ... usage: ./main [options] file0.wav file1.wav ...
options: options:
-h, --help [default] show this help message and exit -h, --help [default] show this help message and exit
-t N, --threads N [4 ] number of threads to use during computation -t N, --threads N [4 ] number of threads to use during computation
-p N, --processors N [1 ] number of processors to use during computation -p N, --processors N [1 ] number of processors to use during computation
-ot N, --offset-t N [0 ] time offset in milliseconds -ot N, --offset-t N [0 ] time offset in milliseconds
-on N, --offset-n N [0 ] segment index offset -on N, --offset-n N [0 ] segment index offset
-d N, --duration N [0 ] duration of audio to process in milliseconds -d N, --duration N [0 ] duration of audio to process in milliseconds
-mc N, --max-context N [-1 ] maximum number of text context tokens to store -mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters -ml N, --max-len N [0 ] maximum segment length in characters
-bo N, --best-of N [5 ] number of best candidates to keep -wt N, --word-thold N [0.01 ] word timestamp probability threshold
-bs N, --beam-size N [-1 ] beam size for beam search -su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
-wt N, --word-thold N [0.01 ] word timestamp probability threshold -tr, --translate [false ] translate from source language to english
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail -otxt, --output-txt [false ] output result in a text file
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail -ovtt, --output-vtt [false ] output result in a vtt file
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy) -osrt, --output-srt [false ] output result in a srt file
-tr, --translate [false ] translate from source language to english -owts, --output-words [false ] output script for generating karaoke video
-di, --diarize [false ] stereo audio diarization -ps, --print-special [false ] print special tokens
-nf, --no-fallback [false ] do not use temperature fallback while decoding -pc, --print-colors [false ] print colors
-otxt, --output-txt [false ] output result in a text file -nt, --no-timestamps [true ] do not print timestamps
-ovtt, --output-vtt [false ] output result in a vtt file -l LANG, --language LANG [en ] spoken language
-osrt, --output-srt [false ] output result in a srt file -m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-owts, --output-words [false ] output script for generating karaoke video -f FNAME, --file FNAME [ ] input WAV file path
-ocsv, --output-csv [false ] output result in a CSV file
-oj, --output-json [false ] output result in a JSON file
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [true ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
--prompt PROMPT [ ] initial prompt
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
``` ```

View File

@ -1,14 +1,16 @@
#include "common.h"
#include "whisper.h" #include "whisper.h"
// third-party utilities
// use your favorite implementations
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include <cmath> #include <cmath>
#include <fstream> #include <fstream>
#include <cstdio> #include <cstdio>
#include <string> #include <string>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include <cstring>
// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9] // Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
// Lowest is red, middle is yellow, highest is green. // Lowest is red, middle is yellow, highest is green.
@ -51,32 +53,22 @@ void replace_all(std::string & s, const std::string & search, const std::string
// command-line parameters // command-line parameters
struct whisper_params { struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_processors = 1; int32_t n_processors = 1;
int32_t offset_t_ms = 0; int32_t offset_t_ms = 0;
int32_t offset_n = 0; int32_t offset_n = 0;
int32_t duration_ms = 0; int32_t duration_ms = 0;
int32_t max_context = -1; int32_t max_context = -1;
int32_t max_len = 0; int32_t max_len = 0;
int32_t best_of = 2;
int32_t beam_size = -1;
float word_thold = 0.01f; float word_thold = 0.01f;
float entropy_thold = 2.40f;
float logprob_thold = -1.00f;
bool speed_up = false; bool speed_up = false;
bool translate = false; bool translate = false;
bool detect_language= false;
bool diarize = false; bool diarize = false;
bool split_on_word = false;
bool no_fallback = false;
bool output_txt = false; bool output_txt = false;
bool output_vtt = false; bool output_vtt = false;
bool output_srt = false; bool output_srt = false;
bool output_wts = false; bool output_wts = false;
bool output_csv = false;
bool output_jsn = false;
bool output_lrc = false;
bool print_special = false; bool print_special = false;
bool print_colors = false; bool print_colors = false;
bool print_progress = false; bool print_progress = false;
@ -84,11 +76,9 @@ struct whisper_params {
std::string language = "en"; std::string language = "en";
std::string prompt; std::string prompt;
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
std::string model = "models/ggml-base.en.bin"; std::string model = "models/ggml-base.en.bin";
std::vector<std::string> fname_inp = {}; std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_out = {};
}; };
void whisper_print_usage(int argc, char ** argv, const whisper_params & params); void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -97,11 +87,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) { for (int i = 1; i < argc; i++) {
std::string arg = argv[i]; std::string arg = argv[i];
if (arg == "-"){
params.fname_inp.push_back(arg);
continue;
}
if (arg[0] != '-') { if (arg[0] != '-') {
params.fname_inp.push_back(arg); params.fname_inp.push_back(arg);
continue; continue;
@ -118,31 +103,19 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-dl" || arg == "--detect-language"){ params.detect_language= true; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; } else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); } else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
@ -161,42 +134,30 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]); fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "options:\n"); fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n"); fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false"); fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " -olrc, --output-lrc [%-7s] output result in a lrc file\n", params.output_lrc ? "true" : "false"); fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false");
fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", "");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
@ -206,87 +167,96 @@ struct whisper_print_user_data {
const std::vector<std::vector<float>> * pcmf32s; const std::vector<std::vector<float>> * pcmf32s;
}; };
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) { void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
const auto & params = *((whisper_print_user_data *) user_data)->params; const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
std::string speaker = "";
int64_t t0 = 0;
int64_t t1 = 0;
// print the last n_new segments // print the last n_new segments
const int s0 = n_segments - n_new; const int s0 = n_segments - n_new;
if (s0 == 0) { if (s0 == 0) {
printf("\n"); printf("\n");
} }
for (int i = s0; i < n_segments; i++) { for (int i = s0; i < n_segments; i++) {
if (!params.no_timestamps || params.diarize) { if (params.no_timestamps) {
t0 = whisper_full_get_segment_t0(ctx, i); if (params.print_colors) {
t1 = whisper_full_get_segment_t1(ctx, i); for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
} if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (!params.no_timestamps) { if (id >= whisper_token_eot(ctx)) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); continue;
} }
if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1*energy1) {
speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
}
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
} }
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
}
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s", text);
}
fflush(stdout);
} else {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker;
if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
} }
const char * text = whisper_full_get_token_text(ctx, i, j); if (energy0 > 1.1*energy1) {
const float p = whisper_full_get_token_p (ctx, i, j); speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
}
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size())))); //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
} }
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s%s", speaker.c_str(), text); if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
printf("\n");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
}
} }
// with timestamps or speakers: each segment on new line
if (!params.no_timestamps || params.diarize) {
printf("\n");
}
fflush(stdout);
} }
} }
@ -355,201 +325,16 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
return true; return true;
} }
char *escape_double_quotes_and_backslashes(const char *str) {
if (str == NULL) {
return NULL;
}
size_t escaped_length = strlen(str) + 1;
for (size_t i = 0; str[i] != '\0'; i++) {
if (str[i] == '"' || str[i] == '\\') {
escaped_length++;
}
}
char *escaped = (char *)calloc(escaped_length, 1); // pre-zeroed
if (escaped == NULL) {
return NULL;
}
size_t pos = 0;
for (size_t i = 0; str[i] != '\0'; i++) {
if (str[i] == '"' || str[i] == '\\') {
escaped[pos++] = '\\';
}
escaped[pos++] = str[i];
}
// no need to set zero due to calloc() being used prior
return escaped;
}
bool output_csv(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
fout << "start,end,text\n";
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
char * text_escaped = escape_double_quotes_and_backslashes(text);
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << "," << 10 * t1 << ",\"" << text_escaped << "\"\n";
}
return true;
}
bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
std::ofstream fout(fname);
int indent = 0;
auto doindent = [&]() {
for (int i = 0; i < indent; i++) fout << "\t";
};
auto start_arr = [&](const char *name) {
doindent();
fout << "\"" << name << "\": [\n";
indent++;
};
auto end_arr = [&](bool end = false) {
indent--;
doindent();
fout << (end ? "]\n" : "},\n");
};
auto start_obj = [&](const char *name = nullptr) {
doindent();
if (name) {
fout << "\"" << name << "\": {\n";
} else {
fout << "{\n";
}
indent++;
};
auto end_obj = [&](bool end = false) {
indent--;
doindent();
fout << (end ? "}\n" : "},\n");
};
auto start_value = [&](const char *name) {
doindent();
fout << "\"" << name << "\": ";
};
auto value_s = [&](const char *name, const char *val, bool end = false) {
start_value(name);
char * val_escaped = escape_double_quotes_and_backslashes(val);
fout << "\"" << val_escaped << (end ? "\"\n" : "\",\n");
free(val_escaped);
};
auto end_value = [&](bool end = false) {
fout << (end ? "\n" : ",\n");
};
auto value_i = [&](const char *name, const int64_t val, bool end = false) {
start_value(name);
fout << val;
end_value(end);
};
auto value_b = [&](const char *name, const bool val, bool end = false) {
start_value(name);
fout << (val ? "true" : "false");
end_value(end);
};
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
start_obj();
value_s("systeminfo", whisper_print_system_info());
start_obj("model");
value_s("type", whisper_model_type_readable(ctx));
value_b("multilingual", whisper_is_multilingual(ctx));
value_i("vocab", whisper_model_n_vocab(ctx));
start_obj("audio");
value_i("ctx", whisper_model_n_audio_ctx(ctx));
value_i("state", whisper_model_n_audio_state(ctx));
value_i("head", whisper_model_n_audio_head(ctx));
value_i("layer", whisper_model_n_audio_layer(ctx), true);
end_obj();
start_obj("text");
value_i("ctx", whisper_model_n_text_ctx(ctx));
value_i("state", whisper_model_n_text_state(ctx));
value_i("head", whisper_model_n_text_head(ctx));
value_i("layer", whisper_model_n_text_layer(ctx), true);
end_obj();
value_i("mels", whisper_model_n_mels(ctx));
value_i("ftype", whisper_model_ftype(ctx), true);
end_obj();
start_obj("params");
value_s("model", params.model.c_str());
value_s("language", params.language.c_str());
value_b("translate", params.translate, true);
end_obj();
start_obj("result");
value_s("language", whisper_lang_str(whisper_full_lang_id(ctx)), true);
end_obj();
start_arr("transcription");
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
start_obj();
start_obj("timestamps");
value_s("from", to_timestamp(t0, true).c_str());
value_s("to", to_timestamp(t1, true).c_str(), true);
end_obj();
start_obj("offsets");
value_i("from", t0 * 10);
value_i("to", t1 * 10, true);
end_obj();
value_s("text", text, true);
end_obj(i == (n_segments - 1));
}
end_arr(true);
end_obj(true);
return true;
}
// karaoke video generation // karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles // outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments // TODO: font parameter adjustments
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) { bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) {
std::ofstream fout(fname); std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
static const char * font = params.font_path.c_str(); // TODO: become parameter
static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
std::ifstream fin(font);
if (!fin.is_open()) {
fprintf(stderr, "%s: font not found at '%s', please specify a monospace font with -fp\n", __func__, font);
return false;
}
fout << "#!/bin/bash" << "\n"; fout << "#!/bin/bash" << "\n";
fout << "\n"; fout << "\n";
@ -652,39 +437,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
return true; return true;
} }
bool output_lrc(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
fout << "[by:whisper.cpp]\n";
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t = whisper_full_get_segment_t0(ctx, i);
int64_t msec = t * 10;
int64_t min = msec / (1000 * 60);
msec = msec - min * (1000 * 60);
int64_t sec = msec / 1000;
msec = msec - sec * 1000;
char buf[16];
snprintf(buf, sizeof(buf), "%02d:%02d.%02d", (int) min, (int) sec, (int) ( msec / 10));
std::string timestamp_lrc = std::string(buf);
fout << '[' << timestamp_lrc << ']' << text << "\n";
}
return true;
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
whisper_params params; whisper_params params;
@ -706,23 +458,115 @@ int main(int argc, char ** argv) {
// whisper init // whisper init
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); struct whisper_context * ctx = whisper_init(params.model.c_str());
if (ctx == nullptr) { if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n"); fprintf(stderr, "error: failed to initialize whisper context\n");
return 3; return 3;
} }
// initial prompt
std::vector<whisper_token> prompt_tokens;
if (!params.prompt.empty()) {
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
fprintf(stderr, "\n");
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
fprintf(stderr, "initial tokens: [ ");
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
fprintf(stderr, "%d ", prompt_tokens[i]);
}
fprintf(stderr, "]\n");
}
for (int f = 0; f < (int) params.fname_inp.size(); ++f) { for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f]; const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
std::vector<float> pcmf32; // mono-channel F32 PCM std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) { // WAV input
fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str()); {
continue; drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname_inp == "-") {
{
uint8_t buf[1024];
while (true)
{
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return 4;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return 5;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
return 6;
}
if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
return 6;
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
return 8;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
return 9;
}
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (params.diarize) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
} }
// print system information // print system information
@ -742,9 +586,6 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
} }
} }
if (params.detect_language) {
params.language = "auto";
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n", fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors, params.n_threads, params.n_processors,
@ -759,15 +600,12 @@ int main(int argc, char ** argv) {
{ {
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
wparams.print_realtime = false; wparams.print_realtime = false;
wparams.print_progress = params.print_progress; wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps; wparams.print_timestamps = !params.no_timestamps;
wparams.print_special = params.print_special; wparams.print_special = params.print_special;
wparams.translate = params.translate; wparams.translate = params.translate;
wparams.language = params.language.c_str(); wparams.language = params.language.c_str();
wparams.detect_language = params.detect_language;
wparams.n_threads = params.n_threads; wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms; wparams.offset_ms = params.offset_t_ms;
@ -776,18 +614,11 @@ int main(int argc, char ** argv) {
wparams.token_timestamps = params.output_wts || params.max_len > 0; wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold; wparams.thold_pt = params.word_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.split_on_word = params.split_on_word;
wparams.speed_up = params.speed_up; wparams.speed_up = params.speed_up;
wparams.initial_prompt = params.prompt.c_str(); wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
whisper_print_user_data user_data = { &params, &pcmf32s }; whisper_print_user_data user_data = { &params, &pcmf32s };
@ -803,7 +634,7 @@ int main(int argc, char ** argv) {
{ {
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
bool is_aborted = *(bool*)user_data; bool is_aborted = *(bool*)user_data;
return !is_aborted; return !is_aborted;
}; };
@ -822,45 +653,27 @@ int main(int argc, char ** argv) {
// output to text file // output to text file
if (params.output_txt) { if (params.output_txt) {
const auto fname_txt = fname_out + ".txt"; const auto fname_txt = fname_inp + ".txt";
output_txt(ctx, fname_txt.c_str()); output_txt(ctx, fname_txt.c_str());
} }
// output to VTT file // output to VTT file
if (params.output_vtt) { if (params.output_vtt) {
const auto fname_vtt = fname_out + ".vtt"; const auto fname_vtt = fname_inp + ".vtt";
output_vtt(ctx, fname_vtt.c_str()); output_vtt(ctx, fname_vtt.c_str());
} }
// output to SRT file // output to SRT file
if (params.output_srt) { if (params.output_srt) {
const auto fname_srt = fname_out + ".srt"; const auto fname_srt = fname_inp + ".srt";
output_srt(ctx, fname_srt.c_str(), params); output_srt(ctx, fname_srt.c_str(), params);
} }
// output to WTS file // output to WTS file
if (params.output_wts) { if (params.output_wts) {
const auto fname_wts = fname_out + ".wts"; const auto fname_wts = fname_inp + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE); output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
} }
// output to CSV file
if (params.output_csv) {
const auto fname_csv = fname_out + ".csv";
output_csv(ctx, fname_csv.c_str());
}
// output to JSON file
if (params.output_jsn) {
const auto fname_jsn = fname_out + ".json";
output_json(ctx, fname_jsn.c_str(), params);
}
// output to LRC file
if (params.output_lrc) {
const auto fname_lrc = fname_out + ".lrc";
output_lrc(ctx, fname_lrc.c_str());
}
} }
} }

View File

@ -1,6 +0,0 @@
set(TARGET quantize)
add_executable(${TARGET} quantize.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})

View File

@ -1,3 +0,0 @@
# quantize
Tool for integer quantization of Whisper `ggml` model files

View File

@ -1,221 +0,0 @@
#include "ggml.h"
#include "common.h"
#include "common-ggml.h"
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <vector>
#include <regex>
// default hparams (Whisper tiny)
struct whisper_hparams {
int32_t n_vocab = 51864;
int32_t n_audio_ctx = 1500;
int32_t n_audio_state = 384;
int32_t n_audio_head = 6;
int32_t n_audio_layer = 4;
int32_t n_text_ctx = 448;
int32_t n_text_state = 384;
int32_t n_text_head = 6;
int32_t n_text_layer = 4;
int32_t n_mels = 80;
int32_t ftype = 1;
};
struct whisper_filters {
int32_t n_mel;
int32_t n_fft;
std::vector<float> data;
};
// quantize a model
bool whisper_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {
gpt_vocab vocab;
printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
auto finp = std::ifstream(fname_inp, std::ios::binary);
if (!finp) {
fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
return false;
}
auto fout = std::ofstream(fname_out, std::ios::binary);
if (!fout) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
return false;
}
// verify magic
{
uint32_t magic;
finp.read((char *) &magic, sizeof(magic));
if (magic != 0x67676d6c) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
return false;
}
fout.write((char *) &magic, sizeof(magic));
}
whisper_hparams hparams;
// load hparams
{
finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
finp.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
finp.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
finp.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
finp.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
finp.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
finp.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
finp.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
finp.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
finp.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
finp.read((char *) &hparams.ftype, sizeof(hparams.ftype));
const int32_t qntvr_src = hparams.ftype / GGML_QNT_VERSION_FACTOR;
const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype;
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
fprintf(stderr, "%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
fprintf(stderr, "%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
fprintf(stderr, "%s: n_text_state = %d\n", __func__, hparams.n_text_state);
fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head);
fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels);
fprintf(stderr, "%s: ftype (src) = %d\n", __func__, hparams.ftype);
fprintf(stderr, "%s: qntvr (src) = %d\n", __func__, qntvr_src);
fprintf(stderr, "%s: ftype (dst) = %d\n", __func__, ftype_dst);
fprintf(stderr, "%s: qntvr (dst) = %d\n", __func__, GGML_QNT_VERSION);
fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fout.write((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
fout.write((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
fout.write((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
fout.write((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
fout.write((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
fout.write((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
fout.write((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
fout.write((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
fout.write((char *) &hparams.n_mels, sizeof(hparams.n_mels));
fout.write((char *) &ftype_dst, sizeof(hparams.ftype));
}
// load mel filters
{
whisper_filters filters;
finp.read ((char *) &filters.n_mel, sizeof(filters.n_mel));
fout.write((char *) &filters.n_mel, sizeof(filters.n_mel));
finp.read ((char *) &filters.n_fft, sizeof(filters.n_fft));
fout.write((char *) &filters.n_fft, sizeof(filters.n_fft));
filters.data.resize(filters.n_mel * filters.n_fft);
finp.read ((char *) filters.data.data(), filters.data.size() * sizeof(float));
fout.write((char *) filters.data.data(), filters.data.size() * sizeof(float));
}
// load vocab
{
int32_t n_vocab = 0;
finp.read ((char *) &n_vocab, sizeof(n_vocab));
fout.write((char *) &n_vocab, sizeof(n_vocab));
//if (n_vocab != hparams.n_vocab) {
// fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
// __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab);
// return false;
//}
std::string word;
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
finp.read ((char *) &len, sizeof(len));
fout.write((char *) &len, sizeof(len));
word.resize(len);
finp.read ((char *) word.data(), len);
fout.write((char *) word.data(), len);
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
}
}
// regexes of tensor names to not be quantized
const std::vector<std::string> to_skip = {
//"encoder.*",
"encoder.conv1.bias",
"encoder.conv2.bias",
"encoder.positional_embedding",
"decoder.positional_embedding",
};
if (!ggml_common_quantize_0(finp, fout, ftype, { ".*" }, to_skip)) {
fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, fname_inp.c_str());
return false;
}
finp.close();
fout.close();
return true;
}
int main(int argc, char ** argv) {
if (argc != 4) {
fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
ggml_print_ftypes(stderr);
return 1;
}
// needed to initialize f16 tables
{
struct ggml_init_params params = { 0, NULL, false };
struct ggml_context * ctx = ggml_init(params);
ggml_free(ctx);
}
const std::string fname_inp = argv[1];
const std::string fname_out = argv[2];
const ggml_ftype ftype = ggml_parse_ftype(argv[3]);
const int64_t t_main_start_us = ggml_time_us();
int64_t t_quantize_us = 0;
// load the model
{
const int64_t t_start_us = ggml_time_us();
if (!whisper_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {
fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
return 1;
}
t_quantize_us = ggml_time_us() - t_start_us;
}
// report timing
{
const int64_t t_main_end_us = ggml_time_us();
printf("\n");
printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f);
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
}
return 0;
}

View File

@ -8,8 +8,6 @@ add_executable(${TARGET}
emscripten.cpp emscripten.cpp
) )
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE target_link_libraries(${TARGET} PRIVATE
whisper whisper
) )

View File

@ -49,9 +49,6 @@ void stream_main(size_t index) {
wparams.max_tokens = 32; wparams.max_tokens = 32;
wparams.audio_ctx = 768; // partial encoder context for better performance wparams.audio_ctx = 768; // partial encoder context for better performance
// disable temperature fallback
wparams.temperature_inc = -1.0f;
wparams.language = "en"; wparams.language = "en";
printf("stream: using %d threads\n", wparams.n_threads); printf("stream: using %d threads\n", wparams.n_threads);
@ -132,7 +129,7 @@ EMSCRIPTEN_BINDINGS(stream) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) { emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) { for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) { if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file(path_model.c_str()); g_contexts[i] = whisper_init(path_model.c_str());
if (g_contexts[i] != nullptr) { if (g_contexts[i] != nullptr) {
g_running = true; g_running = true;
if (g_worker.joinable()) { if (g_worker.joinable()) {

View File

@ -35,15 +35,6 @@
<br><br> <br><br>
<b>More examples:</b>
<a href="https://whisper.ggerganov.com/">main</a> |
<a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br>
<hr> <hr>
Select the model you would like to use, click the "Start" button and start speaking Select the model you would like to use, click the "Start" button and start speaking
@ -54,10 +45,6 @@
Whisper model: <span id="model-whisper-status"></span> Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button> <button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button> <button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<br><br>
Quantized models:<br><br>
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
<button id="fetch-whisper-base-en-q5_1" onclick="loadWhisper('base-en-q5_1')">base.en (Q5_1, 57 MB)</button>
<span id="fetch-whisper-progress"></span> <span id="fetch-whisper-progress"></span>
<!-- <!--
@ -175,17 +162,11 @@
let urls = { let urls = {
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin', 'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin', 'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
}; };
let sizes = { let sizes = {
'tiny.en': 75, 'tiny.en': 75,
'base.en': 142, 'base.en': 142,
'tiny-en-q5_1': 31,
'base-en-q5_1': 57,
}; };
let url = urls[model]; let url = urls[model];
@ -196,10 +177,6 @@
document.getElementById('fetch-whisper-tiny-en').style.display = 'none'; document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none'; document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('fetch-whisper-tiny-en-q5_1').style.display = 'none';
document.getElementById('fetch-whisper-base-en-q5_1').style.display = 'none';
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... '; document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
cbProgress = function(p) { cbProgress = function(p) {
@ -211,10 +188,6 @@
var el; var el;
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block'; el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block'; el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-tiny-en-q5_1'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en-q5_1'); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = ''; el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
}; };

View File

@ -1,9 +1,7 @@
if (WHISPER_SDL2) if (WHISPER_SUPPORT_SDL2)
# stream # stream
set(TARGET stream) set(TARGET stream)
add_executable(${TARGET} stream.cpp) add_executable(${TARGET} stream.cpp)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
include(DefaultTargetOptions) target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
endif () endif ()

View File

@ -3,16 +3,18 @@
// A very quick-n-dirty implementation serving mainly as a proof of concept. // A very quick-n-dirty implementation serving mainly as a proof of concept.
// //
#include "common.h"
#include "common-sdl.h"
#include "whisper.h" #include "whisper.h"
#include <SDL.h>
#include <SDL_audio.h>
#include <cassert> #include <cassert>
#include <cstdio> #include <cstdio>
#include <string> #include <string>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include <fstream> #include <fstream>
#include <mutex>
// 500 -> 00:05.000 // 500 -> 00:05.000
// 6000 -> 01:00.000 // 6000 -> 01:00.000
@ -43,7 +45,6 @@ struct whisper_params {
bool speed_up = false; bool speed_up = false;
bool translate = false; bool translate = false;
bool no_fallback = false;
bool print_special = false; bool print_special = false;
bool no_context = true; bool no_context = true;
bool no_timestamps = false; bool no_timestamps = false;
@ -74,7 +75,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; } else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
@ -96,26 +96,323 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "options:\n"); fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n"); fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms); fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms); fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms); fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id); fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens); fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
//
// SDL Audio capture
//
class audio_async {
public:
audio_async(int len_ms);
~audio_async();
bool init(int capture_id, int sample_rate);
// start capturing audio via the provided SDL callback
// keep last len_ms seconds of audio in a circular buffer
bool resume();
bool pause();
bool clear();
// callback to be called by SDL
void callback(uint8_t * stream, int len);
// get audio data from the circular buffer
void get(int ms, std::vector<float> & audio);
private:
SDL_AudioDeviceID m_dev_id_in = 0;
int m_len_ms = 0;
int m_sample_rate = 0;
bool m_running = false;
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
size_t m_audio_pos = 0;
size_t m_audio_len = 0;
};
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;
}
audio_async::~audio_async() {
if (m_dev_id_in) {
SDL_CloseAudioDevice(m_dev_id_in);
}
}
bool audio_async::init(int capture_id, int sample_rate) {
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
return false;
}
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
{
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
for (int i = 0; i < nDevices; i++) {
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
}
}
SDL_AudioSpec capture_spec_requested;
SDL_AudioSpec capture_spec_obtained;
SDL_zero(capture_spec_requested);
SDL_zero(capture_spec_obtained);
capture_spec_requested.freq = sample_rate;
capture_spec_requested.format = AUDIO_F32;
capture_spec_requested.channels = 1;
capture_spec_requested.samples = 1024;
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
audio_async * audio = (audio_async *) userdata;
audio->callback(stream, len);
};
capture_spec_requested.userdata = this;
if (capture_id >= 0) {
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
} else {
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
}
if (!m_dev_id_in) {
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
m_dev_id_in = 0;
return false;
} else {
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
capture_spec_requested.format);
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
capture_spec_requested.channels);
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
}
m_sample_rate = capture_spec_obtained.freq;
m_audio.resize((m_sample_rate*m_len_ms)/1000);
return true;
}
bool audio_async::resume() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
return false;
}
if (m_running) {
fprintf(stderr, "%s: already running!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 0);
m_running = true;
return true;
}
bool audio_async::pause() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: already paused!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 1);
m_running = false;
return true;
}
bool audio_async::clear() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return false;
}
{
std::lock_guard<std::mutex> lock(m_mutex);
m_audio_pos = 0;
m_audio_len = 0;
}
return true;
}
// callback to be called by SDL
void audio_async::callback(uint8_t * stream, int len) {
if (!m_running) {
return;
}
const size_t n_samples = len / sizeof(float);
m_audio_new.resize(n_samples);
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
{
std::lock_guard<std::mutex> lock(m_mutex);
if (m_audio_pos + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - m_audio_pos;
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = m_audio.size();
} else {
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
}
}
}
void audio_async::get(int ms, std::vector<float> & result) {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
return;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return;
}
result.clear();
{
std::lock_guard<std::mutex> lock(m_mutex);
if (ms <= 0) {
ms = m_len_ms;
}
size_t n_samples = (m_sample_rate * ms) / 1000;
if (n_samples > m_audio_len) {
n_samples = m_audio_len;
}
result.resize(n_samples);
int s0 = m_audio_pos - n_samples;
if (s0 < 0) {
s0 += m_audio.size();
}
if (s0 + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - s0;
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
} else {
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
}
}
}
///////////////////////////
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
float y = data[0];
for (size_t i = 1; i < data.size(); i++) {
y = alpha * (y + data[i] - data[i - 1]);
data[i] = y;
}
}
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;
if (n_samples_last >= n_samples) {
// not enough samples - assume no speech
return false;
}
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, freq_thold, sample_rate);
}
float energy_all = 0.0f;
float energy_last = 0.0f;
for (int i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
energy_last += fabsf(pcmf32[i]);
}
}
energy_all /= n_samples;
energy_last /= n_samples_last;
if (verbose) {
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
}
if (energy_last > vad_thold*energy_all) {
return false;
}
return true;
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
whisper_params params; whisper_params params;
@ -123,21 +420,20 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
params.keep_ms = std::min(params.keep_ms, params.step_ms); params.keep_ms = std::min(params.keep_ms, params.step_ms); // cannot be more than step_ms
params.length_ms = std::max(params.length_ms, params.step_ms);
const int n_samples_step = (1e-3*params.step_ms )*WHISPER_SAMPLE_RATE; const int n_samples_step = (params.step_ms *1e-3)*WHISPER_SAMPLE_RATE;
const int n_samples_len = (1e-3*params.length_ms)*WHISPER_SAMPLE_RATE; const int n_samples_len = (params.length_ms*1e-3)*WHISPER_SAMPLE_RATE;
const int n_samples_keep = (1e-3*params.keep_ms )*WHISPER_SAMPLE_RATE; const int n_samples_keep = (params.keep_ms *1e-3)*WHISPER_SAMPLE_RATE;
const int n_samples_30s = (1e-3*30000.0 )*WHISPER_SAMPLE_RATE; const int n_samples_30s = (30000 *1e-3)*WHISPER_SAMPLE_RATE;
const int n_new_line = params.length_ms / params.step_ms - 1; // number of steps to print new line
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
const int n_new_line = !use_vad ? std::max(1, params.length_ms / params.step_ms - 1) : 1; // number of steps to print new line params.no_timestamps = !use_vad;
params.no_context = use_vad;
params.no_timestamps = !use_vad; params.max_tokens = 0;
params.no_context |= use_vad;
params.max_tokens = 0;
// init audio // init audio
@ -151,16 +447,16 @@ int main(int argc, char ** argv) {
// whisper init // whisper init
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1){ if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str()); fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params); whisper_print_usage(argc, argv, params);
exit(0); exit(0);
} }
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); struct whisper_context * ctx = whisper_init(params.model.c_str());
std::vector<float> pcmf32 (n_samples_30s, 0.0f); std::vector<float> pcmf32 (n_samples_30s, 0.0f);
std::vector<float> pcmf32_old; std::vector<float> pcmf32_old(n_samples_30s, 0.0f);
std::vector<float> pcmf32_new(n_samples_30s, 0.0f); std::vector<float> pcmf32_new(n_samples_30s, 0.0f);
std::vector<whisper_token> prompt_tokens; std::vector<whisper_token> prompt_tokens;
@ -187,7 +483,7 @@ int main(int argc, char ** argv) {
params.no_timestamps ? 0 : 1); params.no_timestamps ? 0 : 1);
if (!use_vad) { if (!use_vad) {
fprintf(stderr, "%s: n_new_line = %d, no_context = %d\n", __func__, n_new_line, params.no_context); fprintf(stderr, "%s: n_new_line = %d\n", __func__, n_new_line);
} else { } else {
fprintf(stderr, "%s: using VAD, will transcribe on speech activity\n", __func__); fprintf(stderr, "%s: using VAD, will transcribe on speech activity\n", __func__);
} }
@ -217,7 +513,23 @@ int main(int argc, char ** argv) {
// main audio loop // main audio loop
while (is_running) { while (is_running) {
// handle Ctrl + C // handle Ctrl + C
is_running = sdl_poll_events(); {
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
is_running = false;
} break;
default:
break;
}
}
if (!is_running) {
break;
}
}
if (!is_running) { if (!is_running) {
break; break;
@ -240,7 +552,7 @@ int main(int argc, char ** argv) {
break; break;
} }
std::this_thread::sleep_for(std::chrono::milliseconds(1)); SDL_Delay(1);
} }
const int n_samples_new = pcmf32_new.size(); const int n_samples_new = pcmf32_new.size();
@ -271,7 +583,7 @@ int main(int argc, char ** argv) {
audio.get(2000, pcmf32_new); audio.get(2000, pcmf32_new);
if (::vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) { if (vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) {
audio.get(params.length_ms, pcmf32); audio.get(params.length_ms, pcmf32);
} else { } else {
std::this_thread::sleep_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100));
@ -291,6 +603,7 @@ int main(int argc, char ** argv) {
wparams.print_realtime = false; wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps; wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate; wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = !use_vad; wparams.single_segment = !use_vad;
wparams.max_tokens = params.max_tokens; wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str(); wparams.language = params.language.c_str();
@ -299,10 +612,6 @@ int main(int argc, char ** argv) {
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up; wparams.speed_up = params.speed_up;
// disable temperature fallback
//wparams.temperature_inc = -1.0f;
wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data(); wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size(); wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
@ -383,7 +692,6 @@ int main(int argc, char ** argv) {
} }
} }
} }
fflush(stdout);
} }
} }

View File

@ -1 +0,0 @@
audio.mp3

View File

@ -1,16 +0,0 @@
if (WHISPER_SDL2)
# talk-llama
set(TARGET talk-llama)
#add_executable(${TARGET} talk-llama.cpp llama.cpp)
#target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
#target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
# TODO: this is temporary
# need to export ggml symbols for MSVC, but too lazy ..
add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../whisper.cpp)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
include(DefaultTargetOptions)
endif ()

View File

@ -1,50 +0,0 @@
# talk-llama
Talk with an LLaMA AI in your terminal
[Demo Talk](https://user-images.githubusercontent.com/1991296/228024237-848f998c-c334-46a6-bef8-3271590da83b.mp4)
## Building
The `talk-llama` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
```bash
# Install SDL2 on Linux
sudo apt-get install libsdl2-dev
# Install SDL2 on Mac OS
brew install sdl2
# Build the "talk-llama" executable
make talk-llama
# Run it
./talk-llama -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
```
- The `-mw` argument specifies the Whisper model that you would like to use. Recommended `base` or `small` for real-time experience
- The `-ml` argument specifies the LLaMA model that you would like to use. Read the instructions in https://github.com/ggerganov/llama.cpp for information about how to obtain a `ggml` compatible LLaMA model
## Session
The `talk-llama` tool supports session management to enable more coherent and continuous conversations. By maintaining context from previous interactions, it can better understand and respond to user requests in a more natural way.
To enable session support, use the `--session FILE` command line option when running the program. The `talk-llama` model state will be saved to the specified file after each interaction. If the file does not exist, it will be created. If the file exists, the model state will be loaded from it, allowing you to resume a previous session.
This feature is especially helpful for maintaining context in long conversations or when interacting with the AI assistant across multiple sessions. It ensures that the assistant remembers the previous interactions and can provide more relevant and contextual responses.
Example usage:
```bash
./talk-llama --session ./my-session-file -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
```
## TTS
For best experience, this example needs a TTS tool to convert the generated text responses to voice.
You can use any TTS engine that you would like - simply edit the [speak.sh](speak.sh) script to your needs.
By default, it is configured to use MacOS's `say`, but you can use whatever you wish.
## Discussion
If you have any feedback, please let "us" know in the following discussion: https://github.com/ggerganov/whisper.cpp/discussions/672?converting=1

View File

@ -1,23 +0,0 @@
import sys
import importlib.util
api_key = "" #Write your https://beta.elevenlabs.io api key here
if not api_key:
print("To use elevenlabs you have to register to https://beta.elevenlabs.io and add your elevenlabs api key to examples/talk-llama/eleven-labs.py")
sys.exit()
if importlib.util.find_spec("elevenlabs") is None:
print("elevenlabs library is not installed, you can install it to your enviroment using 'pip install elevenlabs'")
sys.exit()
from elevenlabs import ElevenLabs
eleven = ElevenLabs(api_key)
# Get a Voice object, by name or UUID
voice = eleven.voices["Arnold"] #Possible Voices: Adam Antoni Arnold Bella Domi Elli Josh
# Generate the TTS
audio = voice.generate(str(sys.argv[2:]))
# Save the TTS to a file
audio.save("audio")

View File

@ -1,433 +0,0 @@
// Internal header to be included only by llama.cpp.
// Contains wrappers around OS interfaces.
#ifndef LLAMA_UTIL_H
#define LLAMA_UTIL_H
#include <cstdio>
#include <cstdint>
#include <cerrno>
#include <cstring>
#include <cstdarg>
#include <cstdlib>
#include <climits>
#include <string>
#include <vector>
#ifdef __has_include
#if __has_include(<unistd.h>)
#include <unistd.h>
#if defined(_POSIX_MAPPED_FILES)
#include <sys/mman.h>
#endif
#if defined(_POSIX_MEMLOCK_RANGE)
#include <sys/resource.h>
#endif
#endif
#endif
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <io.h>
#include <stdio.h> // for _fseeki64
#endif
#define LLAMA_ASSERT(x) \
do { \
if (!(x)) { \
fprintf(stderr, "LLAMA_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
abort(); \
} \
} while (0)
#ifdef __GNUC__
#ifdef __MINGW32__
__attribute__((format(gnu_printf, 1, 2)))
#else
__attribute__((format(printf, 1, 2)))
#endif
#endif
static std::string format(const char * fmt, ...) {
va_list ap, ap2;
va_start(ap, fmt);
va_copy(ap2, ap);
int size = vsnprintf(NULL, 0, fmt, ap);
LLAMA_ASSERT(size >= 0 && size < INT_MAX);
std::vector<char> buf(size + 1);
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
LLAMA_ASSERT(size2 == size);
va_end(ap2);
va_end(ap);
return std::string(buf.data(), size);
}
struct llama_file {
// use FILE * so we don't have to re-open the file to mmap
FILE * fp;
size_t size;
llama_file(const char * fname, const char * mode) {
fp = std::fopen(fname, mode);
if (fp == NULL) {
throw format("failed to open %s: %s", fname, std::strerror(errno));
}
seek(0, SEEK_END);
size = tell();
seek(0, SEEK_SET);
}
size_t tell() const {
#ifdef _WIN32
__int64 ret = _ftelli64(fp);
#else
long ret = std::ftell(fp);
#endif
LLAMA_ASSERT(ret != -1); // this really shouldn't fail
return (size_t) ret;
}
void seek(size_t offset, int whence) {
#ifdef _WIN32
int ret = _fseeki64(fp, (__int64) offset, whence);
#else
int ret = std::fseek(fp, (long) offset, whence);
#endif
LLAMA_ASSERT(ret == 0); // same
}
void read_raw(void * ptr, size_t size) {
if (size == 0) {
return;
}
errno = 0;
std::size_t ret = std::fread(ptr, size, 1, fp);
if (ferror(fp)) {
throw format("read error: %s", strerror(errno));
}
if (ret != 1) {
throw std::string("unexpectedly reached end of file");
}
}
std::uint32_t read_u32() {
std::uint32_t ret;
read_raw(&ret, sizeof(ret));
return ret;
}
std::string read_string(std::uint32_t len) {
std::vector<char> chars(len);
read_raw(chars.data(), len);
return std::string(chars.data(), len);
}
void write_raw(const void * ptr, size_t size) {
if (size == 0) {
return;
}
errno = 0;
size_t ret = std::fwrite(ptr, size, 1, fp);
if (ret != 1) {
throw format("write error: %s", strerror(errno));
}
}
void write_u32(std::uint32_t val) {
write_raw(&val, sizeof(val));
}
~llama_file() {
if (fp) {
std::fclose(fp);
}
}
};
#if defined(_WIN32)
static std::string llama_format_win_err(DWORD err) {
LPSTR buf;
size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
if (!size) {
return "FormatMessageA failed";
}
std::string ret(buf, size);
LocalFree(buf);
return ret;
}
#endif
struct llama_mmap {
void * addr;
size_t size;
llama_mmap(const llama_mmap &) = delete;
#ifdef _POSIX_MAPPED_FILES
static constexpr bool SUPPORTED = true;
llama_mmap(struct llama_file * file, bool prefetch = true) {
size = file->size;
int fd = fileno(file->fp);
int flags = MAP_SHARED;
#ifdef __linux__
flags |= MAP_POPULATE;
#endif
addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0);
if (addr == MAP_FAILED) {
throw format("mmap failed: %s", strerror(errno));
}
if (prefetch) {
// Advise the kernel to preload the mapped memory
if (madvise(addr, file->size, MADV_WILLNEED)) {
fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
strerror(errno));
}
}
}
~llama_mmap() {
munmap(addr, size);
}
#elif defined(_WIN32)
static constexpr bool SUPPORTED = true;
llama_mmap(struct llama_file * file, bool prefetch = true) {
size = file->size;
HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp));
HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
DWORD error = GetLastError();
if (hMapping == NULL) {
throw format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str());
}
addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
error = GetLastError();
CloseHandle(hMapping);
if (addr == NULL) {
throw format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str());
}
#if _WIN32_WINNT >= _WIN32_WINNT_WIN8
if (prefetch) {
// Advise the kernel to preload the mapped memory
WIN32_MEMORY_RANGE_ENTRY range;
range.VirtualAddress = addr;
range.NumberOfBytes = (SIZE_T)size;
if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
}
#else
#pragma message("warning: You are building for pre-Windows 8; prefetch not supported")
#endif // _WIN32_WINNT >= _WIN32_WINNT_WIN8
}
~llama_mmap() {
if (!UnmapViewOfFile(addr)) {
fprintf(stderr, "warning: UnmapViewOfFile failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
}
#else
static constexpr bool SUPPORTED = false;
llama_mmap(struct llama_file *) {
throw std::string("mmap not supported");
}
#endif
};
// Represents some region of memory being locked using mlock or VirtualLock;
// will automatically unlock on destruction.
struct llama_mlock {
void * addr = NULL;
size_t size = 0;
bool failed_already = false;
llama_mlock() {}
llama_mlock(const llama_mlock &) = delete;
~llama_mlock() {
if (size) {
raw_unlock(addr, size);
}
}
void init(void * addr) {
LLAMA_ASSERT(this->addr == NULL && this->size == 0);
this->addr = addr;
}
void grow_to(size_t target_size) {
LLAMA_ASSERT(addr);
if (failed_already) {
return;
}
size_t granularity = lock_granularity();
target_size = (target_size + granularity - 1) & ~(granularity - 1);
if (target_size > size) {
if (raw_lock((uint8_t *) addr + size, target_size - size)) {
size = target_size;
} else {
failed_already = true;
}
}
}
#ifdef _POSIX_MEMLOCK_RANGE
static constexpr bool SUPPORTED = true;
size_t lock_granularity() {
return (size_t) sysconf(_SC_PAGESIZE);
}
#ifdef __APPLE__
#define MLOCK_SUGGESTION \
"Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
"decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
#else
#define MLOCK_SUGGESTION \
"Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
#endif
bool raw_lock(const void * addr, size_t size) {
if (!mlock(addr, size)) {
return true;
} else {
char* errmsg = std::strerror(errno);
bool suggest = (errno == ENOMEM);
// Check if the resource limit is fine after all
struct rlimit lock_limit;
if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit))
suggest = false;
if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size))
suggest = false;
fprintf(stderr, "warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
return false;
}
}
#undef MLOCK_SUGGESTION
void raw_unlock(void * addr, size_t size) {
if (munlock(addr, size)) {
fprintf(stderr, "warning: failed to munlock buffer: %s\n", std::strerror(errno));
}
}
#elif defined(_WIN32)
static constexpr bool SUPPORTED = true;
size_t lock_granularity() {
SYSTEM_INFO si;
GetSystemInfo(&si);
return (size_t) si.dwPageSize;
}
bool raw_lock(void * addr, size_t size) {
for (int tries = 1; ; tries++) {
if (VirtualLock(addr, size)) {
return true;
}
if (tries == 2) {
fprintf(stderr, "warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
size, this->size, llama_format_win_err(GetLastError()).c_str());
return false;
}
// It failed but this was only the first try; increase the working
// set size and try again.
SIZE_T min_ws_size, max_ws_size;
if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) {
fprintf(stderr, "warning: GetProcessWorkingSetSize failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
return false;
}
// Per MSDN: "The maximum number of pages that a process can lock
// is equal to the number of pages in its minimum working set minus
// a small overhead."
// Hopefully a megabyte is enough overhead:
size_t increment = size + 1048576;
// The minimum must be <= the maximum, so we need to increase both:
min_ws_size += increment;
max_ws_size += increment;
if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) {
fprintf(stderr, "warning: SetProcessWorkingSetSize failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
return false;
}
}
}
void raw_unlock(void * addr, size_t size) {
if (!VirtualUnlock(addr, size)) {
fprintf(stderr, "warning: failed to VirtualUnlock buffer: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
}
#else
static constexpr bool SUPPORTED = false;
void raw_lock(const void * addr, size_t size) {
fprintf(stderr, "warning: mlock not supported on this system\n");
}
void raw_unlock(const void * addr, size_t size) {}
#endif
};
// Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
struct llama_buffer {
uint8_t * addr = NULL;
size_t size = 0;
void resize(size_t size) {
delete[] addr;
addr = new uint8_t[size];
this->size = size;
}
~llama_buffer() {
delete[] addr;
}
};
#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"
struct llama_ctx_buffer {
uint8_t * addr = NULL;
size_t size = 0;
void resize(size_t size) {
if (addr) {
ggml_cuda_host_free(addr);
}
addr = (uint8_t *) ggml_cuda_host_malloc(size);
this->size = size;
}
~llama_ctx_buffer() {
if (addr) {
ggml_cuda_host_free(addr);
}
}
};
#else
typedef llama_buffer llama_ctx_buffer;
#endif
#endif

File diff suppressed because it is too large Load Diff

View File

@ -1,261 +0,0 @@
#ifndef LLAMA_H
#define LLAMA_H
#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>
#ifdef LLAMA_SHARED
# if defined(_WIN32) && !defined(__MINGW32__)
# ifdef LLAMA_BUILD
# define LLAMA_API __declspec(dllexport)
# else
# define LLAMA_API __declspec(dllimport)
# endif
# else
# define LLAMA_API __attribute__ ((visibility ("default")))
# endif
#else
# define LLAMA_API
#endif
#define LLAMA_FILE_VERSION 2
#define LLAMA_FILE_MAGIC 'ggjt'
#define LLAMA_FILE_MAGIC_UNVERSIONED 'ggml'
#define LLAMA_SESSION_MAGIC 'ggsn'
#define LLAMA_SESSION_VERSION 1
#ifdef __cplusplus
extern "C" {
#endif
//
// C interface
//
// TODO: show sample usage
//
struct llama_context;
typedef int llama_token;
typedef struct llama_token_data {
llama_token id; // token id
float logit; // log-odds of the token
float p; // probability of the token
} llama_token_data;
typedef struct llama_token_data_array {
llama_token_data * data;
size_t size;
bool sorted;
} llama_token_data_array;
typedef void (*llama_progress_callback)(float progress, void *ctx);
struct llama_context_params {
int n_ctx; // text context
int n_parts; // -1 for default
int n_gpu_layers; // number of layers to store in VRAM
int seed; // RNG seed, -1 for random
bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
bool use_mlock; // force system to keep model in RAM
bool embedding; // embedding mode only
// called with a progress value between 0 and 1, pass NULL to disable
llama_progress_callback progress_callback;
// context pointer passed to the progress callback
void * progress_callback_user_data;
};
// model file types
enum llama_ftype {
LLAMA_FTYPE_ALL_F32 = 0,
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
// LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
// LLAMA_FTYPE_MOSTLY_Q4_3 (6) support has been removed
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
};
LLAMA_API struct llama_context_params llama_context_default_params();
LLAMA_API bool llama_mmap_supported();
LLAMA_API bool llama_mlock_supported();
// Various functions for loading a ggml llama model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
LLAMA_API struct llama_context * llama_init_from_file(
const char * path_model,
struct llama_context_params params);
// Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx);
// TODO: not great API - very likely to change
// Returns 0 on success
// nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
LLAMA_API int llama_model_quantize(
const char * fname_inp,
const char * fname_out,
enum llama_ftype ftype,
int nthread);
// Apply a LoRA adapter to a loaded model
// path_base_model is the path to a higher quality model to use as a base for
// the layers modified by the adapter. Can be NULL to use the current loaded model.
// The model needs to be reloaded before applying a new adapter, otherwise the adapter
// will be applied on top of the previous one
// Returns 0 on success
LLAMA_API int llama_apply_lora_from_file(
struct llama_context * ctx,
const char * path_lora,
const char * path_base_model,
int n_threads);
// Returns the number of tokens in the KV cache
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
// Sets the current rng seed.
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);
// Returns the maximum size in bytes of the state (rng, logits, embedding
// and kv_cache) - will often be smaller after compacting tokens
LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
// Copies the state to the specified destination address.
// Destination needs to have allocated enough memory.
// Returns the number of bytes copied
LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst);
// Set the state reading from the specified address
// Returns the number of bytes read
LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src);
// Save/load session file
LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
// Run the llama inference to obtain the logits and probabilities for the next token.
// tokens + n_tokens is the provided batch of new tokens to process
// n_past is the number of tokens to use from previous eval calls
// Returns 0 on success
LLAMA_API int llama_eval(
struct llama_context * ctx,
const llama_token * tokens,
int n_tokens,
int n_past,
int n_threads);
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
// Returns a negative number on failure - the number of tokens that would have been returned
// TODO: not sure if correct
LLAMA_API int llama_tokenize(
struct llama_context * ctx,
const char * text,
llama_token * tokens,
int n_max_tokens,
bool add_bos);
LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
LLAMA_API int llama_n_embd (const struct llama_context * ctx);
// Token logits obtained from the last call to llama_eval()
// The logits for the last token are stored in the last row
// Can be mutated in order to change the probabilities of the next token
// Rows: n_tokens
// Cols: n_vocab
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
// Get the embeddings for the input
// shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Token Id -> String. Uses the vocabulary in the provided context
LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token);
// Special tokens
LLAMA_API llama_token llama_token_bos();
LLAMA_API llama_token llama_token_eos();
LLAMA_API llama_token llama_token_nl();
// Sampling functions
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep);
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep);
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
/// @details Selects the token with the highest probability.
LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
/// @details Randomly selects a token from the candidates based on their probabilities.
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
// Performance information
LLAMA_API void llama_print_timings(struct llama_context * ctx);
LLAMA_API void llama_reset_timings(struct llama_context * ctx);
// Print system information
LLAMA_API const char * llama_print_system_info(void);
#ifdef __cplusplus
}
#endif
// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
#ifdef LLAMA_API_INTERNAL
#include <vector>
#include <string>
struct ggml_tensor;
std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx);
#endif
#endif // LLAMA_H

View File

@ -1,23 +0,0 @@
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Write a text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
{1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}s requests immediately and with details and precision.
There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
The transcript only includes text, it does not include markup like HTML and Markdown.
{1} responds with short and concise answers.
### Response:
{0}{4} Hello, {1}!
{1}{4} Hello {0}! How may I help you today?
{0}{4} What time is it?
{1}{4} It is {2} o'clock.
{0}{4} What year is it?
{1}{4} We are in {3}.
{0}{4} What is a cat?
{1}{4} A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{0}{4} Name a color.
{1}{4} Blue
{0}{4}

View File

@ -1,21 +0,0 @@
#!/bin/bash
# Usage:
# speak.sh <voice_id> <text-to-speak>
# espeak
# Mac OS: brew install espeak
# Linux: apt-get install espeak
#
#espeak -v en-us+m$1 -s 225 -p 50 -a 200 -g 5 -k 5 "$2"
# for Mac
say "$2"
# Eleven Labs
# To use it, install the elevenlabs module from pip (pip install elevenlabs), register to https://beta.elevenlabs.io to get an api key and paste it in /examples/talk-llama/eleven-labs.py
#
#wd=$(dirname $0)
#script=$wd/eleven-labs.py
#python3 $script $1 "$2" >/dev/null 2>&1
#ffplay -autoexit -nodisp -loglevel quiet -hide_banner -i ./audio.mp3 >/dev/null 2>&1

View File

@ -1,673 +0,0 @@
// Talk with AI
//
#include "common.h"
#include "common-sdl.h"
#include "whisper.h"
#include "llama.h"
#include <cassert>
#include <cstdio>
#include <fstream>
#include <regex>
#include <string>
#include <thread>
#include <vector>
#include <regex>
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
// initialize to prompt numer of chars, since n_tokens <= n_prompt_chars
std::vector<llama_token> res(text.size() + (int)add_bos);
int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos);
assert(n >= 0);
res.resize(n);
return res;
}
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t voice_ms = 10000;
int32_t capture_id = -1;
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
int32_t n_parts_llama = -1;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false;
bool print_special = false;
bool print_energy = false;
bool no_timestamps = true;
bool verbose_prompt = false;
std::string person = "Georgi";
std::string language = "en";
std::string model_wsp = "models/ggml-base.en.bin";
std::string model_llama = "models/ggml-llama-7B.bin";
std::string speak = "./examples/talk-llama/speak.sh";
std::string prompt = "";
std::string fname_out;
std::string path_session = ""; // path to file for saving/loading model eval state
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-vms" || arg == "--voice-ms") { params.voice_ms = std::stoi(argv[++i]); }
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "--n-parts-llama") { params.n_parts_llama = std::stoi(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "--verbose-prompt") { params.verbose_prompt = true; }
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
else if (arg == "--session") { params.path_session = argv[++i];}
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; }
else if (arg == "--prompt-file") {
std::ifstream file(argv[++i]);
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
if (params.prompt.back() == '\n') {
params.prompt.pop_back();
}
}
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
}
return true;
}
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -vms N, --voice-ms N [%-7d] voice duration in milliseconds\n", params.voice_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
fprintf(stderr, " -ml FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str());
fprintf(stderr, " --n-parts-llama N [%-7d] num parts in llama model file\n", params.n_parts_llama);
fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", "");
fprintf(stderr, " --session FNAME file to cache model state in (may be large!) (default: none)\n");
fprintf(stderr, " --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false");
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, "\n");
}
std::string transcribe(
whisper_context * ctx,
const whisper_params & params,
const std::vector<float> & pcmf32,
const std::string prompt_text,
float & prob,
int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();
prob = 0.0f;
t_ms = 0;
std::vector<whisper_token> prompt_tokens;
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, prompt_text.c_str(), prompt_tokens.data(), prompt_tokens.size()));
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = true;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return "";
}
int prob_n = 0;
std::string result;
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
result += text;
const int n_tokens = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < n_tokens; ++j) {
const auto token = whisper_full_get_token_data(ctx, i, j);
prob += token.p;
++prob_n;
}
}
if (prob_n > 0) {
prob /= prob_n;
}
const auto t_end = std::chrono::high_resolution_clock::now();
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
return result;
}
const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)";
const std::string k_prompt_llama = R"(Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
{1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}s requests immediately and with details and precision.
There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
The transcript only includes text, it does not include markup like HTML and Markdown.
{1} responds with short and concise answers.
{0}{4} Hello, {1}!
{1}{4} Hello {0}! How may I help you today?
{0}{4} What time is it?
{1}{4} It is {2} o'clock.
{0}{4} What year is it?
{1}{4} We are in {3}.
{0}{4} What is a cat?
{1}{4} A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{0}{4} Name a color.
{1}{4} Blue
{0}{4})";
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init
struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
// llama init
auto lparams = llama_context_default_params();
// tune these to your liking
lparams.n_ctx = 2048;
lparams.seed = 1;
lparams.f16_kv = true;
lparams.n_parts = params.n_parts_llama;
struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams);
// print some info about the processing
{
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx_wsp)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
__func__,
params.n_threads,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1);
fprintf(stderr, "\n");
}
// init audio
audio_async audio(30*1000);
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
return 1;
}
audio.resume();
int n_iter = 0;
bool is_running = true;
bool force_speak = false;
float prob0 = 0.0f;
const std::string chat_symb = ":";
const std::string bot_name = "LLaMA";
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
const std::string prompt_whisper = ::replace(k_prompt_whisper, "{1}", bot_name);
// construct the initial prompt for LLaMA inference
std::string prompt_llama = params.prompt.empty() ? k_prompt_llama : params.prompt;
// need to have leading ' '
prompt_llama.insert(0, 1, ' ');
prompt_llama = ::replace(prompt_llama, "{0}", params.person);
prompt_llama = ::replace(prompt_llama, "{1}", bot_name);
{
// get time string
std::string time_str;
{
time_t t = time(0);
struct tm * now = localtime(&t);
char buf[128];
strftime(buf, sizeof(buf), "%H:%M", now);
time_str = buf;
}
prompt_llama = ::replace(prompt_llama, "{2}", time_str);
}
{
// get year string
std::string year_str;
{
time_t t = time(0);
struct tm * now = localtime(&t);
char buf[128];
strftime(buf, sizeof(buf), "%Y", now);
year_str = buf;
}
prompt_llama = ::replace(prompt_llama, "{3}", year_str);
}
prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
// init session
std::string path_session = params.path_session;
std::vector<llama_token> session_tokens;
auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
if (!path_session.empty()) {
fprintf(stderr, "%s: attempting to load saved session from %s\n", __func__, path_session.c_str());
// fopen to check for existing session
FILE * fp = std::fopen(path_session.c_str(), "rb");
if (fp != NULL) {
std::fclose(fp);
session_tokens.resize(lparams.n_ctx);
size_t n_token_count_out = 0;
if (!llama_load_session_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
fprintf(stderr, "%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
return 1;
}
session_tokens.resize(n_token_count_out);
for (size_t i = 0; i < session_tokens.size(); i++) {
embd_inp[i] = session_tokens[i];
}
fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
} else {
fprintf(stderr, "%s: session file does not exist, will create\n", __func__);
}
}
// evaluate the initial prompt
printf("\n");
printf("%s : initializing - please wait ...\n", __func__);
if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
if (params.verbose_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s", prompt_llama.c_str());
fflush(stdout);
}
// debug message about similarity of saved session, if applicable
size_t n_matching_session_tokens = 0;
if (session_tokens.size()) {
for (llama_token id : session_tokens) {
if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) {
break;
}
n_matching_session_tokens++;
}
if (n_matching_session_tokens >= embd_inp.size()) {
fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
} else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
fprintf(stderr, "%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
__func__, n_matching_session_tokens, embd_inp.size());
} else {
fprintf(stderr, "%s: session file matches %zu / %zu tokens of prompt\n",
__func__, n_matching_session_tokens, embd_inp.size());
}
}
// HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
// if we loaded a session with at least 75% similarity. It's currently just used to speed up the
// initial prompt so it doesn't need to be an exact match.
bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);
printf("%s : done! start speaking in the microphone\n", __func__);
printf("\n");
printf("%s%s", params.person.c_str(), chat_symb.c_str());
fflush(stdout);
// clear audio buffer
audio.clear();
// text inference variables
const int voice_id = 2;
const int n_keep = embd_inp.size();
const int n_ctx = llama_n_ctx(ctx_llama);
int n_past = n_keep;
int n_prev = 64; // TODO arg
int n_session_consumed = !path_session.empty() && session_tokens.size() > 0 ? session_tokens.size() : 0;
std::vector<llama_token> embd;
// reverse prompts for detecting when it's time to stop speaking
std::vector<std::string> antiprompts = {
params.person + chat_symb,
};
// main loop
while (is_running) {
// handle Ctrl + C
is_running = sdl_poll_events();
if (!is_running) {
break;
}
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
int64_t t_ms = 0;
{
audio.get(2000, pcmf32_cur);
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
//fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
audio.get(params.voice_ms, pcmf32_cur);
std::string text_heard;
if (!force_speak) {
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms));
}
// remove text between brackets using regex
{
std::regex re("\\[.*?\\]");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove text between brackets using regex
{
std::regex re("\\(.*?\\)");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
// take first line
text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
// remove leading and trailing whitespace
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");
const std::vector<llama_token> tokens = llama_tokenize(ctx_llama, text_heard.c_str(), false);
if (text_heard.empty() || tokens.empty() || force_speak) {
//fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
audio.clear();
continue;
}
force_speak = false;
text_heard.insert(0, 1, ' ');
text_heard += "\n" + bot_name + chat_symb;
fprintf(stdout, "%s%s%s", "\033[1m", text_heard.c_str(), "\033[0m");
fflush(stdout);
embd = ::llama_tokenize(ctx_llama, text_heard, false);
// Append the new input tokens to the session_tokens vector
if (!path_session.empty()) {
session_tokens.insert(session_tokens.end(), tokens.begin(), tokens.end());
}
// text inference
bool done = false;
std::string text_to_speak;
while (true) {
// predict
if (embd.size() > 0) {
if (n_past + (int) embd.size() > n_ctx) {
n_past = n_keep;
// insert n_left/2 tokens at the start of embd from last_n_tokens
embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());
// stop saving session if we run out of context
path_session = "";
//printf("\n---\n");
//printf("resetting: '");
//for (int i = 0; i < (int) embd.size(); i++) {
// printf("%s", llama_token_to_str(ctx_llama, embd[i]));
//}
//printf("'\n");
//printf("\n---\n");
}
// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
// REVIEW
if (n_session_consumed < (int) session_tokens.size()) {
size_t i = 0;
for ( ; i < embd.size(); i++) {
if (embd[i] != session_tokens[n_session_consumed]) {
session_tokens.resize(n_session_consumed);
break;
}
n_past++;
n_session_consumed++;
if (n_session_consumed >= (int) session_tokens.size()) {
i++;
break;
}
}
if (i > 0) {
embd.erase(embd.begin(), embd.begin() + i);
}
}
if (embd.size() > 0 && !path_session.empty()) {
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
n_session_consumed = session_tokens.size();
}
if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
}
embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
n_past += embd.size();
embd.clear();
if (done) break;
{
// out of user input, sample next token
const float top_k = 5;
const float top_p = 0.80f;
const float temp = 0.30f;
const float repeat_penalty = 1.1764f;
const int repeat_last_n = 256;
if (!path_session.empty() && need_to_save_session) {
need_to_save_session = false;
llama_save_session_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.size());
}
llama_token id = 0;
{
auto logits = llama_get_logits(ctx_llama);
auto n_vocab = llama_n_vocab(ctx_llama);
logits[llama_token_eos()] = 0;
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
// apply repeat penalty
const float nl_logit = logits[llama_token_nl()];
llama_sample_repetition_penalty(ctx_llama, &candidates_p,
embd_inp.data() + std::max(0, n_past - repeat_last_n),
repeat_last_n, repeat_penalty);
logits[llama_token_nl()] = nl_logit;
if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx_llama, &candidates_p);
} else {
// Temperature sampling
llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
llama_sample_temperature(ctx_llama, &candidates_p, temp);
id = llama_sample_token(ctx_llama, &candidates_p);
}
}
if (id != llama_token_eos()) {
// add it to the context
embd.push_back(id);
text_to_speak += llama_token_to_str(ctx_llama, id);
printf("%s", llama_token_to_str(ctx_llama, id));
}
}
{
std::string last_output;
for (int i = embd_inp.size() - 16; i < (int) embd_inp.size(); i++) {
last_output += llama_token_to_str(ctx_llama, embd_inp[i]);
}
last_output += llama_token_to_str(ctx_llama, embd[0]);
for (std::string & antiprompt : antiprompts) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
done = true;
text_to_speak = ::replace(text_to_speak, antiprompt, "");
fflush(stdout);
need_to_save_session = true;
break;
}
}
}
is_running = sdl_poll_events();
if (!is_running) {
break;
}
}
text_to_speak = ::replace(text_to_speak, "\"", "");
system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
audio.clear();
++n_iter;
}
}
}
audio.pause();
whisper_print_timings(ctx_wsp);
whisper_free(ctx_wsp);
llama_print_timings(ctx_llama);
llama_free(ctx_llama);
return 0;
}

View File

@ -9,11 +9,8 @@ add_executable(${TARGET}
gpt-2.cpp gpt-2.cpp
) )
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE target_link_libraries(${TARGET} PRIVATE
whisper whisper
common
) )
unset(EXTRA_FLAGS) unset(EXTRA_FLAGS)
@ -34,8 +31,8 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \ --bind \
-s USE_PTHREADS=1 \ -s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \ -s PTHREAD_POOL_SIZE=8 \
-s INITIAL_MEMORY=1800MB \ -s INITIAL_MEMORY=1600MB \
-s TOTAL_MEMORY=1800MB \ -s TOTAL_MEMORY=1600MB \
-s FORCE_FILESYSTEM=1 \ -s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \ -s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \ ${EXTRA_FLAGS} \

View File

@ -36,7 +36,7 @@ In order to run this demo efficiently, you need to have the following:
- Latest Chrome or Firefox browser (Safari is not supported) - Latest Chrome or Firefox browser (Safari is not supported)
- Run this on a desktop or laptop with modern CPU (a mobile phone will likely not be good enough) - Run this on a desktop or laptop with modern CPU (a mobile phone will likely not be good enough)
- Speak phrases that are no longer than 10 seconds - this is the audio context of the AI - Speak phrases that are no longer than 10 seconds - this is the audio context of the AI
- The web-page uses about 1.8GB of RAM - The web-page uses about 1.6GB of RAM
Notice that this demo is using the smallest GPT-2 model, so the generated text responses are not always very good. Notice that this demo is using the smallest GPT-2 model, so the generated text responses are not always very good.
Also, the prompting strategy can likely be improved to achieve better results. Also, the prompting strategy can likely be improved to achieve better results.

View File

@ -271,7 +271,7 @@ EMSCRIPTEN_BINDINGS(talk) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) { emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) { for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) { if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file(path_model.c_str()); g_contexts[i] = whisper_init(path_model.c_str());
if (g_contexts[i] != nullptr) { if (g_contexts[i] != nullptr) {
g_running = true; g_running = true;
if (g_worker.joinable()) { if (g_worker.joinable()) {

View File

@ -1,6 +1,4 @@
#include "ggml.h" #include "ggml.h"
#include "common-ggml.h"
#include "gpt-2.h" #include "gpt-2.h"
#include <cmath> #include <cmath>
@ -16,6 +14,150 @@
/////////////////////// GPT-2 BEGIN ///////////////////////// /////////////////////// GPT-2 BEGIN /////////////////////////
//
// Vocab utils
//
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
std::vector<std::string> words;
// first split the text into words
{
std::string str = text;
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
std::regex re(pat);
std::smatch m;
while (std::regex_search(str, m, re)) {
for (auto x : m) {
words.push_back(x);
}
str = m.suffix();
}
}
// find the longest tokens that form the words:
std::vector<gpt_vocab::id> tokens;
for (const auto & word : words) {
if (word.size() == 0) continue;
int i = 0;
int n = word.size();
while (i < n) {
int j = n;
while (j > i) {
auto it = vocab.token_to_id.find(word.substr(i, j-i));
if (it != vocab.token_to_id.end()) {
tokens.push_back(it->second);
i = j;
break;
}
--j;
}
if (i == n) {
break;
}
if (j == i) {
auto sub = word.substr(i, 1);
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
tokens.push_back(vocab.token_to_id.at(sub));
} else {
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
}
++i;
}
}
}
return tokens;
}
gpt_vocab::id gpt_sample_top_k_top_p(
const gpt_vocab & vocab,
const float * logits,
int top_k,
double top_p,
double temp,
std::mt19937 & rng) {
int n_logits = vocab.id_to_token.size();
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
logits_id.reserve(n_logits);
for (int i = 0; i < n_logits; i++) {
logits_id.push_back(std::make_pair(logits[i], i));
}
// find the top K tokens
std::partial_sort(
logits_id.begin(),
logits_id.begin() + top_k, logits_id.end(),
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
return a.first > b.first;
});
logits_id.resize(top_k);
// normalize
{
double sum = 0.0f;
for (int i = 0; i < (int)logits_id.size(); i++) {
sum += logits_id[i].first;
}
sum = 1.0/sum;
for (int i = 0; i < (int)logits_id.size(); i++) {
logits_id[i].first *= sum;
}
}
if (top_p < 1.0f) {
{
double cumsum = 0.0f;
for (int i = 0; i < top_k; i++) {
cumsum += logits_id[i].first;
if (cumsum >= top_p) {
logits_id.resize(i+1);
break;
}
}
}
// normalize again
{
double sum = 0.0f;
for (int i = 0; i < (int)logits_id.size(); i++) {
sum += logits_id[i].first;
}
sum = 1.0/sum;
for (int i = 0; i < (int)logits_id.size(); i++) {
logits_id[i].first *= sum;
}
}
}
//printf("\n");
//for (int i = 0; i < (int)logits_id.size(); i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
//}
//exit(0);
// sample from the obtained distribution
std::vector<double> probs;
probs.reserve(logits_id.size());
for (int i = 0; i < (int) logits_id.size(); i++) {
probs.push_back(logits_id[i].first);
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
int idx = dist(rng);
return logits_id[idx].second;
}
// default hparams (GPT-2 117M) // default hparams (GPT-2 117M)
struct gpt2_hparams { struct gpt2_hparams {
int32_t n_vocab = 50257; int32_t n_vocab = 50257;
@ -23,7 +165,7 @@ struct gpt2_hparams {
int32_t n_embd = 768; int32_t n_embd = 768;
int32_t n_head = 12; int32_t n_head = 12;
int32_t n_layer = 12; int32_t n_layer = 12;
int32_t ftype = 1; int32_t f16 = 1;
}; };
struct gpt2_layer { struct gpt2_layer {
@ -45,7 +187,7 @@ struct gpt2_layer {
struct ggml_tensor * c_mlp_fc_w; struct ggml_tensor * c_mlp_fc_w;
struct ggml_tensor * c_mlp_fc_b; struct ggml_tensor * c_mlp_fc_b;
struct ggml_tensor * c_mlp_proj_w; struct ggml_tensor * c_mlp_proj_w_trans; // transposed for efficiency
struct ggml_tensor * c_mlp_proj_b; struct ggml_tensor * c_mlp_proj_b;
}; };
@ -56,9 +198,8 @@ struct gpt2_model {
struct ggml_tensor * ln_f_g; struct ggml_tensor * ln_f_g;
struct ggml_tensor * ln_f_b; struct ggml_tensor * ln_f_b;
struct ggml_tensor * wte; // position embedding struct ggml_tensor * wte; // position embedding
struct ggml_tensor * wpe; // token embedding struct ggml_tensor * wpe; // token embedding
struct ggml_tensor * lm_head; // language model head
std::vector<gpt2_layer> layers; std::vector<gpt2_layer> layers;
@ -100,14 +241,14 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd)); fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head)); fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer)); fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
fin.read((char *) &hparams.ftype, sizeof(hparams.ftype)); fin.read((char *) &hparams.f16, sizeof(hparams.f16));
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx); printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
printf("%s: n_embd = %d\n", __func__, hparams.n_embd); printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
printf("%s: n_head = %d\n", __func__, hparams.n_head); printf("%s: n_head = %d\n", __func__, hparams.n_head);
printf("%s: n_layer = %d\n", __func__, hparams.n_layer); printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
printf("%s: ftype = %d\n", __func__, hparams.ftype); printf("%s: f16 = %d\n", __func__, hparams.f16);
} }
// load vocab // load vocab
@ -134,14 +275,9 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
} }
} }
// for the big tensors, we have the option to store the data in 16-bit floats or quantized // for the big tensors, we have the option to store the data in 16-bit floats
// in order to save memory and also to speed up the computation // in order to save memory and also to speed up the computation
ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype)); const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
if (wtype == GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
__func__, fname.c_str(), model.hparams.ftype);
return false;
}
auto & ctx = model.ctx; auto & ctx = model.ctx;
@ -155,33 +291,32 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
const int n_ctx = hparams.n_ctx; const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab; const int n_vocab = hparams.n_vocab;
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte ctx_size += n_vocab*n_embd*ggml_type_size(wtype); // wte
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe ctx_size += n_ctx*n_embd*ggml_type_size(GGML_TYPE_F32); // wpe
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_b
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_size(wtype)); // c_attn_attn_w
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b ctx_size += n_layer*( 3*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_attn_b
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_proj_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_fc_w
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b ctx_size += n_layer*( 4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v
ctx_size += (6 + 12*n_layer)*256; // object overhead ctx_size += (6 + 12*n_layer)*256; // object overhead
@ -190,11 +325,9 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
// create the ggml context // create the ggml context
{ {
struct ggml_init_params params = { struct ggml_init_params params;
.mem_size = ctx_size, params.mem_size = ctx_size;
.mem_buffer = NULL, params.mem_buffer = NULL;
.no_alloc = false,
};
model.ctx = ggml_init(params); model.ctx = ggml_init(params);
if (!model.ctx) { if (!model.ctx) {
@ -217,38 +350,36 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab); model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx); model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
model.lm_head = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
// map by name // map by name
model.tensors["model/ln_f/g"] = model.ln_f_g; model.tensors["model/ln_f/g"] = model.ln_f_g;
model.tensors["model/ln_f/b"] = model.ln_f_b; model.tensors["model/ln_f/b"] = model.ln_f_b;
model.tensors["model/wte"] = model.wte; model.tensors["model/wte"] = model.wte;
model.tensors["model/wpe"] = model.wpe; model.tensors["model/wpe"] = model.wpe;
model.tensors["model/lm_head"] = model.lm_head;
for (int i = 0; i < n_layer; ++i) { for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i]; auto & layer = model.layers[i];
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 3*n_embd); layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, 3*n_embd, n_embd);
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd); layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd); layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd); layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd); layer.c_mlp_proj_w_trans = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
// map by name // map by name
model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g; model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
@ -266,7 +397,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w; model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b; model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w; model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w_trans;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b; model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
} }
} }
@ -294,16 +425,14 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
{ {
size_t total_size = 0; size_t total_size = 0;
bool has_lm_head = false;
while (true) { while (true) {
int32_t n_dims; int32_t n_dims;
int32_t length; int32_t length;
int32_t ttype; int32_t ftype;
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims)); fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fin.read(reinterpret_cast<char *>(&length), sizeof(length)); fin.read(reinterpret_cast<char *>(&length), sizeof(length));
fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype)); fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
if (fin.eof()) { if (fin.eof()) {
break; break;
@ -332,18 +461,13 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) { if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]); __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
return false; return false;
} }
// for debugging const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
if (0) {
printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
}
const size_t bpe = ggml_type_size(ggml_type(ttype)); if (nelements*bpe != ggml_nbytes(tensor)) {
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe); __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false; return false;
@ -351,15 +475,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor)); fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
// GPT-2 models share the WTE tensor as the LM head //printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
if (name == "model/wte" && has_lm_head == false) {
memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
}
if (name == "model/lm_head") {
has_lm_head = true;
}
total_size += ggml_nbytes(tensor); total_size += ggml_nbytes(tensor);
} }
@ -377,7 +493,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
// - n_threads: number of threads to use // - n_threads: number of threads to use
// - n_past: the context size so far // - n_past: the context size so far
// - embd_inp: the embeddings of the tokens in the context // - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted logits for the next token // - embd_w: the predicted probabilities of the next token
// //
bool gpt2_eval( bool gpt2_eval(
const gpt2_model & model, const gpt2_model & model,
@ -396,12 +512,12 @@ bool gpt2_eval(
const int n_head = hparams.n_head; const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab; const int n_vocab = hparams.n_vocab;
static size_t buf_size = 512u*1024*1024; static size_t buf_size = 640u*1024*1024;
static void * buf = malloc(buf_size); static void * buf = malloc(buf_size);
if (mem_per_token > 0 && mem_per_token*N > buf_size) { if (mem_per_token > 0 && mem_per_token*N > buf_size) {
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
//printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new); printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
// reallocate // reallocate
buf_size = buf_size_new; buf_size = buf_size_new;
@ -412,14 +528,13 @@ bool gpt2_eval(
} }
} }
struct ggml_init_params params = { struct ggml_init_params params;
/*.mem_size =*/ buf_size, params.mem_size = buf_size;
/*.mem_buffer =*/ buf, params.mem_buffer = buf;
/*.no_alloc =*/ false,
};
struct ggml_context * ctx0 = ggml_init(params); struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
struct ggml_cgraph gf = { };
gf.n_threads = n_threads; gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
@ -463,7 +578,7 @@ bool gpt2_eval(
// [2304, N] // [2304, N]
{ {
cur = ggml_mul_mat(ctx0, cur = ggml_mul_mat(ctx0,
model.layers[il].c_attn_attn_w, ggml_transpose(ctx0, model.layers[il].c_attn_attn_w),
cur); cur);
cur = ggml_add(ctx0, cur = ggml_add(ctx0,
@ -539,13 +654,11 @@ bool gpt2_eval(
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
// [n_past + N, 64, 12] // [n_past + N, 64, 12]
struct ggml_tensor * V_trans = struct ggml_tensor * V_trans =
ggml_cpy(ctx0, ggml_permute(ctx0,
ggml_permute(ctx0, ggml_reshape_3d(ctx0,
ggml_reshape_3d(ctx0, ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd), n_embd/n_head, n_head, n_past + N),
n_embd/n_head, n_head, n_past + N), 1, 2, 0, 3);
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
// KQV = transpose(V) * KQ_soft_max // KQV = transpose(V) * KQ_soft_max
// [64, N, 12] // [64, N, 12]
@ -572,7 +685,7 @@ bool gpt2_eval(
// [768, N] // [768, N]
{ {
cur = ggml_mul_mat(ctx0, cur = ggml_mul_mat(ctx0,
model.layers[il].c_attn_proj_w, ggml_transpose(ctx0, model.layers[il].c_attn_proj_w),
cur); cur);
cur = ggml_add(ctx0, cur = ggml_add(ctx0,
@ -609,7 +722,7 @@ bool gpt2_eval(
// cur = fc_w*cur + fc_b // cur = fc_w*cur + fc_b
// [3072, N] // [3072, N]
cur = ggml_mul_mat(ctx0, cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_fc_w, ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w),
cur); cur);
cur = ggml_add(ctx0, cur = ggml_add(ctx0,
@ -629,7 +742,7 @@ bool gpt2_eval(
// cur = proj_w*cur + proj_b // cur = proj_w*cur + proj_b
// [768, N] // [768, N]
cur = ggml_mul_mat(ctx0, cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_proj_w, model.layers[il].c_mlp_proj_w_trans,
cur); cur);
cur = ggml_add(ctx0, cur = ggml_add(ctx0,
@ -656,12 +769,12 @@ bool gpt2_eval(
} }
// inpL = WTE * inpL // inpL = WTE * inpL
// [ 768, 50257] - model.lm_head // [ 768, 50257] - model.wte
// [ 768, N] - inpL // [ 768, N] - inpL
inpL = ggml_mul_mat(ctx0, model.lm_head, inpL); inpL = ggml_mul_mat(ctx0, model.wte, inpL);
// logits -> probs // logits -> probs
//inpL = ggml_soft_max(ctx0, inpL); inpL = ggml_soft_max(ctx0, inpL);
// run the computation // run the computation
ggml_build_forward_expand(&gf, inpL); ggml_build_forward_expand(&gf, inpL);
@ -675,7 +788,7 @@ bool gpt2_eval(
//embd_w.resize(n_vocab*N); //embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N); //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
// return result just for the last token // return result for just the last token
embd_w.resize(n_vocab); embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
@ -712,7 +825,7 @@ Me too.
int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency()); int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
// sampling parameters // sampling parameters
int32_t top_k = 5; int32_t top_k = 40;
float top_p = 0.9f; float top_p = 0.9f;
float temp = 1.0f; float temp = 1.0f;
}; };
@ -720,15 +833,14 @@ Me too.
struct gpt2_context * gpt2_init(const char * path_model) { struct gpt2_context * gpt2_init(const char * path_model) {
gpt2_context * ctx = new gpt2_context; gpt2_context * ctx = new gpt2_context;
ctx->rng = std::mt19937(time(nullptr)); ctx->rng = std::mt19937(time(NULL));
// load the model // load the model
{ {
const int64_t t_start_us = ggml_time_us(); const int64_t t_start_us = ggml_time_us();
if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) { if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model); fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, "gpt-2.bin");
delete ctx;
return nullptr; return nullptr;
} }
@ -772,9 +884,9 @@ std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens)
std::string result; std::string result;
for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) { for (int i = embd.size(); i < embd_inp.size() + n_predict; i++) {
// predict // predict
if (!embd.empty()) { if (embd.size() > 0) {
if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) { if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
printf("gpt-2: failed to generate text\n"); printf("gpt-2: failed to generate text\n");
return ""; return "";
@ -801,7 +913,10 @@ std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens)
result += ctx->vocab.id_to_token[embd[0]]; result += ctx->vocab.id_to_token[embd[0]];
// end of text token // end of text token
if (embd.back() == 50256) { if (embd.back() == 50256 ||
ctx->vocab.id_to_token[embd.back()] == "." ||
ctx->vocab.id_to_token[embd.back()] == "!" ||
ctx->vocab.id_to_token[embd.back()] == "?") {
break; break;
} }
} }

View File

@ -2,12 +2,18 @@
// TODO: Change to C-style API and move to ./examples for easy reuse. // TODO: Change to C-style API and move to ./examples for easy reuse.
#include "common.h"
#include <vector> #include <vector>
#include <map> #include <map>
#include <string> #include <string>
struct gpt_vocab {
using id = int32_t;
using token = std::string;
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
};
struct gpt2_context; struct gpt2_context;
struct gpt2_context * gpt2_init(const char * path_model); struct gpt2_context * gpt2_init(const char * path_model);

View File

@ -44,15 +44,6 @@
<br><br> <br><br>
<b>More examples:</b>
<a href="https://whisper.ggerganov.com/">main</a> |
<a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br>
<hr> <hr>
Select the models you would like to use and click the "Start" button to begin the conversation Select the models you would like to use and click the "Start" button to begin the conversation
@ -63,10 +54,6 @@
Whisper model: <span id="model-whisper-status"></span> Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button> <button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button> <button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<br><br>
Quantized models:<br><br>
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
<button id="fetch-whisper-base-en-q5_1" onclick="loadWhisper('base-en-q5_1')">base.en (Q5_1, 57 MB)</button>
<span id="fetch-whisper-progress"></span> <span id="fetch-whisper-progress"></span>
<!-- <!--
@ -279,17 +266,11 @@
let urls = { let urls = {
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin', 'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin', 'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
}; };
let sizes = { let sizes = {
'tiny.en': 75, 'tiny.en': 75,
'base.en': 142, 'base.en': 142,
'tiny-en-q5_1': 31,
'base-en-q5_1': 57,
}; };
let url = urls[model]; let url = urls[model];
@ -300,10 +281,6 @@
document.getElementById('fetch-whisper-tiny-en').style.display = 'none'; document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none'; document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('fetch-whisper-tiny-en-q5_1').style.display = 'none';
document.getElementById('fetch-whisper-base-en-q5_1').style.display = 'none';
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... '; document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
cbProgress = function(p) { cbProgress = function(p) {
@ -315,10 +292,6 @@
var el; var el;
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block'; el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block'; el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-tiny-en-q5_1'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en-q5_1'); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = ''; el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
}; };

View File

@ -1 +1 @@
audio.mp3 eleven-labs.py

View File

@ -1,8 +1,13 @@
if (WHISPER_SDL2) if (WHISPER_SUPPORT_SDL2)
# talk # talk
set(TARGET talk) set(TARGET talk)
add_executable(${TARGET} talk.cpp gpt-2.cpp) #add_executable(${TARGET} talk.cpp gpt-2.cpp)
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT}) #target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
#target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
include(DefaultTargetOptions) # TODO: this is temporary
# need to export ggml symbols for MSVC, but too lazy ..
add_executable(${TARGET} talk.cpp gpt-2.cpp ../../ggml.c ../../whisper.cpp)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif () endif ()

View File

@ -31,7 +31,7 @@ To run this, you will need a ggml GPT-2 model: [instructions](https://github.com
Alternatively, you can simply download the smallest ggml GPT-2 117M model (240 MB) like this: Alternatively, you can simply download the smallest ggml GPT-2 117M model (240 MB) like this:
``` ```
wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/ggerganov/ggml/raw/main/ggml-model-gpt-2-117M.bin wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/datasets/ggerganov/ggml/raw/main/ggml-model-gpt-2-117M.bin
``` ```
## TTS ## TTS

View File

@ -1,23 +0,0 @@
import sys
import importlib.util
api_key = "" #Write your https://beta.elevenlabs.io api key here
if not api_key:
print("To use elevenlabs you have to register to https://beta.elevenlabs.io and add your elevenlabs api key to examples/talk/eleven-labs.py")
sys.exit()
if importlib.util.find_spec("elevenlabs") is None:
print("elevenlabs library is not installed, you can install it to your enviroment using 'pip install elevenlabs'")
sys.exit()
from elevenlabs import ElevenLabs
eleven = ElevenLabs(api_key)
# Get a Voice object, by name or UUID
voice = eleven.voices["Arnold"] #Possible Voices: Adam Antoni Arnold Bella Domi Elli Josh
# Generate the TTS
audio = voice.generate(str(sys.argv[2:]))
# Save the TTS to a file
audio.save("audio")

View File

@ -1,6 +1,4 @@
#include "ggml.h" #include "ggml.h"
#include "common-ggml.h"
#include "gpt-2.h" #include "gpt-2.h"
#include <cmath> #include <cmath>
@ -16,6 +14,150 @@
/////////////////////// GPT-2 BEGIN ///////////////////////// /////////////////////// GPT-2 BEGIN /////////////////////////
//
// Vocab utils
//
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
std::vector<std::string> words;
// first split the text into words
{
std::string str = text;
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
std::regex re(pat);
std::smatch m;
while (std::regex_search(str, m, re)) {
for (auto x : m) {
words.push_back(x);
}
str = m.suffix();
}
}
// find the longest tokens that form the words:
std::vector<gpt_vocab::id> tokens;
for (const auto & word : words) {
if (word.empty()) continue;
int i = 0;
int n = word.size();
while (i < n) {
int j = n;
while (j > i) {
auto it = vocab.token_to_id.find(word.substr(i, j-i));
if (it != vocab.token_to_id.end()) {
tokens.push_back(it->second);
i = j;
break;
}
--j;
}
if (i == n) {
break;
}
if (j == i) {
auto sub = word.substr(i, 1);
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
tokens.push_back(vocab.token_to_id.at(sub));
} else {
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
}
++i;
}
}
}
return tokens;
}
gpt_vocab::id gpt_sample_top_k_top_p(
const gpt_vocab & vocab,
const float * logits,
int top_k,
double top_p,
double /*temp*/,
std::mt19937 & rng) {
int n_logits = vocab.id_to_token.size();
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
logits_id.reserve(n_logits);
for (int i = 0; i < n_logits; i++) {
logits_id.emplace_back(logits[i], i);
}
// find the top K tokens
std::partial_sort(
logits_id.begin(),
logits_id.begin() + top_k, logits_id.end(),
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
return a.first > b.first;
});
logits_id.resize(top_k);
// normalize
{
double sum = 0.0f;
for (int i = 0; i < (int)logits_id.size(); i++) {
sum += logits_id[i].first;
}
sum = 1.0/sum;
for (int i = 0; i < (int)logits_id.size(); i++) {
logits_id[i].first *= sum;
}
}
if (top_p < 1.0f) {
{
double cumsum = 0.0f;
for (int i = 0; i < top_k; i++) {
cumsum += logits_id[i].first;
if (cumsum >= top_p) {
logits_id.resize(i+1);
break;
}
}
}
// normalize again
{
double sum = 0.0f;
for (int i = 0; i < (int)logits_id.size(); i++) {
sum += logits_id[i].first;
}
sum = 1.0/sum;
for (int i = 0; i < (int)logits_id.size(); i++) {
logits_id[i].first *= sum;
}
}
}
//printf("\n");
//for (int i = 0; i < (int) logits_id.size(); i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
//}
//exit(0);
// sample from the obtained distribution
std::vector<double> probs;
probs.reserve(logits_id.size());
for (int i = 0; i < (int) logits_id.size(); i++) {
probs.push_back(logits_id[i].first);
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
int idx = dist(rng);
return logits_id[idx].second;
}
// default hparams (GPT-2 117M) // default hparams (GPT-2 117M)
struct gpt2_hparams { struct gpt2_hparams {
int32_t n_vocab = 50257; int32_t n_vocab = 50257;
@ -23,7 +165,7 @@ struct gpt2_hparams {
int32_t n_embd = 768; int32_t n_embd = 768;
int32_t n_head = 12; int32_t n_head = 12;
int32_t n_layer = 12; int32_t n_layer = 12;
int32_t ftype = 1; int32_t f16 = 1;
}; };
struct gpt2_layer { struct gpt2_layer {
@ -45,7 +187,7 @@ struct gpt2_layer {
struct ggml_tensor * c_mlp_fc_w; struct ggml_tensor * c_mlp_fc_w;
struct ggml_tensor * c_mlp_fc_b; struct ggml_tensor * c_mlp_fc_b;
struct ggml_tensor * c_mlp_proj_w; struct ggml_tensor * c_mlp_proj_w_trans; // transposed for efficiency
struct ggml_tensor * c_mlp_proj_b; struct ggml_tensor * c_mlp_proj_b;
}; };
@ -56,9 +198,8 @@ struct gpt2_model {
struct ggml_tensor * ln_f_g; struct ggml_tensor * ln_f_g;
struct ggml_tensor * ln_f_b; struct ggml_tensor * ln_f_b;
struct ggml_tensor * wte; // position embedding struct ggml_tensor * wte; // position embedding
struct ggml_tensor * wpe; // token embedding struct ggml_tensor * wpe; // token embedding
struct ggml_tensor * lm_head; // language model head
std::vector<gpt2_layer> layers; std::vector<gpt2_layer> layers;
@ -100,14 +241,14 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd)); fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head)); fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer)); fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
fin.read((char *) &hparams.ftype, sizeof(hparams.ftype)); fin.read((char *) &hparams.f16, sizeof(hparams.f16));
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx); printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
printf("%s: n_embd = %d\n", __func__, hparams.n_embd); printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
printf("%s: n_head = %d\n", __func__, hparams.n_head); printf("%s: n_head = %d\n", __func__, hparams.n_head);
printf("%s: n_layer = %d\n", __func__, hparams.n_layer); printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
printf("%s: ftype = %d\n", __func__, hparams.ftype); printf("%s: f16 = %d\n", __func__, hparams.f16);
} }
// load vocab // load vocab
@ -127,21 +268,16 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
fin.read((char *) &len, sizeof(len)); fin.read((char *) &len, sizeof(len));
word.resize(len); word.resize(len);
fin.read((char *) word.data(), len); fin.read((char *) &word[0], len);
vocab.token_to_id[word] = i; vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word; vocab.id_to_token[i] = word;
} }
} }
// for the big tensors, we have the option to store the data in 16-bit floats or quantized // for the big tensors, we have the option to store the data in 16-bit floats
// in order to save memory and also to speed up the computation // in order to save memory and also to speed up the computation
ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype)); const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
if (wtype == GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
__func__, fname.c_str(), model.hparams.ftype);
return false;
}
auto & ctx = model.ctx; auto & ctx = model.ctx;
@ -155,33 +291,32 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
const int n_ctx = hparams.n_ctx; const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab; const int n_vocab = hparams.n_vocab;
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte ctx_size += n_vocab*n_embd*ggml_type_size(wtype); // wte
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe ctx_size += n_ctx*n_embd*ggml_type_size(GGML_TYPE_F32); // wpe
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_b
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_size(wtype)); // c_attn_attn_w
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b ctx_size += n_layer*( 3*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_attn_b
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_proj_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_fc_w
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b ctx_size += n_layer*( 4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v
ctx_size += (6 + 12*n_layer)*256; // object overhead ctx_size += (6 + 12*n_layer)*256; // object overhead
@ -190,11 +325,9 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
// create the ggml context // create the ggml context
{ {
struct ggml_init_params params = { struct ggml_init_params params;
.mem_size = ctx_size, params.mem_size = ctx_size;
.mem_buffer = NULL, params.mem_buffer = nullptr;
.no_alloc = false,
};
model.ctx = ggml_init(params); model.ctx = ggml_init(params);
if (!model.ctx) { if (!model.ctx) {
@ -217,38 +350,36 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab); model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx); model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
model.lm_head = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
// map by name // map by name
model.tensors["model/ln_f/g"] = model.ln_f_g; model.tensors["model/ln_f/g"] = model.ln_f_g;
model.tensors["model/ln_f/b"] = model.ln_f_b; model.tensors["model/ln_f/b"] = model.ln_f_b;
model.tensors["model/wte"] = model.wte; model.tensors["model/wte"] = model.wte;
model.tensors["model/wpe"] = model.wpe; model.tensors["model/wpe"] = model.wpe;
model.tensors["model/lm_head"] = model.lm_head;
for (int i = 0; i < n_layer; ++i) { for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i]; auto & layer = model.layers[i];
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 3*n_embd); layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, 3*n_embd, n_embd);
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd); layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd); layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd); layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd); layer.c_mlp_proj_w_trans = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
// map by name // map by name
model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g; model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
@ -266,7 +397,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w; model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b; model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w; model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w_trans;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b; model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
} }
} }
@ -294,16 +425,14 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
{ {
size_t total_size = 0; size_t total_size = 0;
bool has_lm_head = false;
while (true) { while (true) {
int32_t n_dims; int32_t n_dims;
int32_t length; int32_t length;
int32_t ttype; int32_t ftype;
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims)); fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fin.read(reinterpret_cast<char *>(&length), sizeof(length)); fin.read(reinterpret_cast<char *>(&length), sizeof(length));
fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype)); fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
if (fin.eof()) { if (fin.eof()) {
break; break;
@ -319,7 +448,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
std::string name(length, 0); std::string name(length, 0);
fin.read(&name[0], length); fin.read(&name[0], length);
if (model.tensors.find(name.data()) == model.tensors.end()) { if (model.tensors.find(name) == model.tensors.end()) {
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
return false; return false;
} }
@ -332,18 +461,13 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) { if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]); __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
return false; return false;
} }
// for debugging const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
if (0) {
printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
}
const size_t bpe = ggml_type_size(ggml_type(ttype)); if (nelements*bpe != ggml_nbytes(tensor)) {
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe); __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false; return false;
@ -351,15 +475,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor)); fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
// GPT-2 models share the WTE tensor as the LM head //printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
if (name == "model/wte" && has_lm_head == false) {
memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
}
if (name == "model/lm_head") {
has_lm_head = true;
}
total_size += ggml_nbytes(tensor); total_size += ggml_nbytes(tensor);
} }
@ -377,7 +493,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
// - n_threads: number of threads to use // - n_threads: number of threads to use
// - n_past: the context size so far // - n_past: the context size so far
// - embd_inp: the embeddings of the tokens in the context // - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted logits for the next token // - embd_w: the predicted probabilities of the next token
// //
bool gpt2_eval( bool gpt2_eval(
const gpt2_model & model, const gpt2_model & model,
@ -396,12 +512,12 @@ bool gpt2_eval(
const int n_head = hparams.n_head; const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab; const int n_vocab = hparams.n_vocab;
static size_t buf_size = 512u*1024*1024; static size_t buf_size = 5640ull*1024*1024;
static void * buf = malloc(buf_size); static void * buf = malloc(buf_size);
if (mem_per_token > 0 && mem_per_token*N > buf_size) { if (mem_per_token > 0 && mem_per_token*N > buf_size) {
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
//printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new); printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
// reallocate // reallocate
buf_size = buf_size_new; buf_size = buf_size_new;
@ -412,14 +528,13 @@ bool gpt2_eval(
} }
} }
struct ggml_init_params params = { struct ggml_init_params params;
/*.mem_size =*/ buf_size, params.mem_size = buf_size;
/*.mem_buffer =*/ buf, params.mem_buffer = buf;
/*.no_alloc =*/ false,
};
struct ggml_context * ctx0 = ggml_init(params); struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
struct ggml_cgraph gf = { };
gf.n_threads = n_threads; gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
@ -463,7 +578,7 @@ bool gpt2_eval(
// [2304, N] // [2304, N]
{ {
cur = ggml_mul_mat(ctx0, cur = ggml_mul_mat(ctx0,
model.layers[il].c_attn_attn_w, ggml_transpose(ctx0, model.layers[il].c_attn_attn_w),
cur); cur);
cur = ggml_add(ctx0, cur = ggml_add(ctx0,
@ -539,13 +654,11 @@ bool gpt2_eval(
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
// [n_past + N, 64, 12] // [n_past + N, 64, 12]
struct ggml_tensor * V_trans = struct ggml_tensor * V_trans =
ggml_cpy(ctx0, ggml_permute(ctx0,
ggml_permute(ctx0, ggml_reshape_3d(ctx0,
ggml_reshape_3d(ctx0, ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd), n_embd/n_head, n_head, n_past + N),
n_embd/n_head, n_head, n_past + N), 1, 2, 0, 3);
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
// KQV = transpose(V) * KQ_soft_max // KQV = transpose(V) * KQ_soft_max
// [64, N, 12] // [64, N, 12]
@ -572,7 +685,7 @@ bool gpt2_eval(
// [768, N] // [768, N]
{ {
cur = ggml_mul_mat(ctx0, cur = ggml_mul_mat(ctx0,
model.layers[il].c_attn_proj_w, ggml_transpose(ctx0, model.layers[il].c_attn_proj_w),
cur); cur);
cur = ggml_add(ctx0, cur = ggml_add(ctx0,
@ -609,7 +722,7 @@ bool gpt2_eval(
// cur = fc_w*cur + fc_b // cur = fc_w*cur + fc_b
// [3072, N] // [3072, N]
cur = ggml_mul_mat(ctx0, cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_fc_w, ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w),
cur); cur);
cur = ggml_add(ctx0, cur = ggml_add(ctx0,
@ -629,7 +742,7 @@ bool gpt2_eval(
// cur = proj_w*cur + proj_b // cur = proj_w*cur + proj_b
// [768, N] // [768, N]
cur = ggml_mul_mat(ctx0, cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_proj_w, model.layers[il].c_mlp_proj_w_trans,
cur); cur);
cur = ggml_add(ctx0, cur = ggml_add(ctx0,
@ -656,12 +769,12 @@ bool gpt2_eval(
} }
// inpL = WTE * inpL // inpL = WTE * inpL
// [ 768, 50257] - model.lm_head // [ 768, 50257] - model.wte
// [ 768, N] - inpL // [ 768, N] - inpL
inpL = ggml_mul_mat(ctx0, model.lm_head, inpL); inpL = ggml_mul_mat(ctx0, model.wte, inpL);
// logits -> probs // logits -> probs
//inpL = ggml_soft_max(ctx0, inpL); inpL = ggml_soft_max(ctx0, inpL);
// run the computation // run the computation
ggml_build_forward_expand(&gf, inpL); ggml_build_forward_expand(&gf, inpL);
@ -675,7 +788,7 @@ bool gpt2_eval(
//embd_w.resize(n_vocab*N); //embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N); //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
// return result just for the last token // return result for just the last token
embd_w.resize(n_vocab); embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);

Some files were not shown because too many files have changed in this diff Show More