Compare commits

...

41 Commits

Author SHA1 Message Date
e2aa556a99 whisper : experiments with Flash Attention in the decoder 2023-01-07 21:00:51 +02:00
f30b5d322c ggml : fix bug in new soft max computation 2023-01-07 21:00:07 +02:00
44efbf7ff1 cmake : add -Wno-unused-function + update whisper.js 2023-01-07 20:18:34 +02:00
d347a59a5f ggml : when using BLAS start only 1 CPU thread 2023-01-07 19:48:56 +02:00
6394c906af ggml : fix running tasks with variable number of threads 2023-01-07 19:20:18 +02:00
74ffa14e1d ggml : unroll ggml_vec_dot_f16 in ggml_compute_forward_flash_attn_f16 2023-01-07 19:19:40 +02:00
65fdcbbbbb whisper : revert accidental MB change 2023-01-07 16:18:21 +02:00
d61d55cd4b ggml : speed-up soft max via Accelerate + unroll 2023-01-07 16:16:42 +02:00
d51fc3ee0a ggml : use vDSP_sve and vDSP_maxv from Accelerate 2023-01-07 16:10:16 +02:00
f82a7dd019 ggml : make gcc happy (minor) 2023-01-07 09:34:39 +02:00
87dd4a3081 talk.wasm : bump memory usage + update whisper.js 2023-01-06 21:13:44 +02:00
41e05c6b1b cmake : support AVX2 in Windows better (#381) 2023-01-06 19:36:33 +02:00
fa379cb22a Revert "tmp"
This reverts commit 1652965529.
2023-01-06 19:33:09 +02:00
322f4e6c4e go : bindings updated so they can be used in third party packages. (#379)
* Updated bindings so they can be used in third pary packages.

* Updated makefiles to set FMA flag on optionally, for xeon E5 on Darwin
2023-01-06 19:32:28 +02:00
1652965529 tmp 2023-01-06 19:32:12 +02:00
6042c7a3be cmake : change min required version to 3.0 (#351)
We increase the min version only when want to use particular
functionality that is available in the newer version
2023-01-06 19:25:28 +02:00
6b351bb669 command : add "guided-mode" video demo in the README.md 2023-01-06 18:59:26 +02:00
a62170c656 ggml : add SSE3 and fp16 conversion lookup table (#368)
* Improves WASM performance:
  On MacBook M1 Pro, I observe 25% faster using Firefox and 35% faster using Chrome

* Add support for SSE3 SIMD

* Add SSE3 to system information

* Add Imath support for fp16-fp32 conversions

* Add Imath to system information

* Wrap Imath calls to avoid static function warnings

* Drop Imath; Add lookup table for f16 -> f32 conversions

* Remove TODO comments

* Update SSE3 to new macro arguments

* Correct updated macro definitions

* Prefer static inline where possible

* ggml : static inlines + add public f16 <-> f32 conversions

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-01-06 18:45:59 +02:00
1944e7c33e whisper : document POWER VSX support 2023-01-05 23:53:00 +02:00
49a8dd6732 ggml : reorganize POWER9 ppc64le SIMD code 2023-01-05 23:53:00 +02:00
8c7f642286 ggml : change f16 load and store macro arguments 2023-01-05 23:53:00 +02:00
ad2a4ffa03 whisper : do not use F16 tensors when in F32 mode (#369) 2023-01-05 22:56:25 +02:00
b3c865083e ci : add emscripten build 2023-01-05 22:10:20 +02:00
a0d4f8e65c main : make whisper_print_segment_callback() more readable (close #371) 2023-01-05 21:45:05 +02:00
4a214d2f07 cmake : add CMAKE_RUNTIME_OUTPUT_DIRECTORY
Currently needed by the wasm examples
2023-01-05 21:40:59 +02:00
0a0cfa7985 ggml : add void to argument-less functions 2023-01-05 21:40:38 +02:00
196d738974 minor : close #370 + Makefile build info print change 2023-01-05 21:35:45 +02:00
84c6b42e65 cmake : update to 3.19 (#351)
- update from 3.0 (from 2014) to 3.19 (from 2020)
- move some global setting onto the targets (through a cmake include)
2023-01-05 21:22:48 +02:00
dd6d582977 whisper : use ranged-based for loops for readability 2023-01-05 21:20:44 +02:00
d51c5eb906 ggml : define MIN / MAX only if not defined (minor) 2023-01-05 21:16:52 +02:00
0be6a1afd9 make : print build information 2023-01-02 13:35:26 +02:00
a466c3404d stream : fix data race on bool + avoid division-by-zero 2023-01-02 10:20:50 +02:00
d629c034a4 models : fix HF model URL (close #356) 2023-01-02 09:54:43 +02:00
f00509d57c command : refactor to split command list & general transcription modes (#331)
This makes it easier to understand if you're looking for only one of the capabilities.
2022-12-31 14:08:57 +02:00
424c410c42 ggml : improve f16 acceleration for POWER9 ppc64le 2022-12-31 10:02:19 +02:00
d97e6005e9 whisper : add whisper_n_audio_ctx and check for invalid audio_ctx
closes #344
2022-12-31 09:57:19 +02:00
3467230a77 models : fix typo in convert-h5-to-ggml.py
signficant -> significant
2022-12-31 09:49:01 +02:00
a091581eb3 cmake : add runtime destination install (#345)
needed for mingw32 build to successfully install the dlls in the correct location
2022-12-31 09:48:00 +02:00
68daf6e487 whisper : avoid some memory allocations 2022-12-30 13:43:48 +02:00
a593b932e4 main : add -ocsv, aka --output-csv to output a CSV file
Adds -ocsv, aka --output-csv feature to examples/main, which outputs a CSV file containing lines formatted as follows <startTime-in-integer-milliseconds>, <endTime-in-integer-milliseconds>, "<transcript-line-including-commas>".
2022-12-29 14:04:00 +02:00
9a8ad3db69 make : add i686 arch (close #329) 2022-12-29 13:58:55 +02:00
36 changed files with 1253 additions and 730 deletions

View File

@ -235,3 +235,33 @@ jobs:
with: with:
name: whisper-blas-bin-${{ matrix.arch }} name: whisper-blas-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }} path: build/bin/${{ matrix.build }}
emscripten:
runs-on: ubuntu-latest
strategy:
matrix:
build: [Release]
steps:
- name: Clone
uses: actions/checkout@v1
- name: Dependencies
run: |
wget -q https://github.com/emscripten-core/emsdk/archive/master.tar.gz
tar -xvf master.tar.gz
emsdk-master/emsdk update
emsdk-master/emsdk install latest
emsdk-master/emsdk activate latest
- name: Configure
run: echo "tmp"
- name: Build
run: |
pushd emsdk-master
source ./emsdk_env.sh
popd
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
make

View File

@ -2,14 +2,15 @@ cmake_minimum_required (VERSION 3.0)
project(whisper.cpp VERSION 1.0.4) project(whisper.cpp VERSION 1.0.4)
set(CMAKE_EXPORT_COMPILE_COMMANDS "on") # Add path to modules
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib")
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
set(WHISPER_STANDALONE ON) set(WHISPER_STANDALONE ON)
include(cmake/GitVars.cmake) include(GitVars)
include(cmake/BuildTypes.cmake) include(BuildTypes)
# configure project version # configure project version
if (EXISTS "${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl") if (EXISTS "${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl")
@ -52,6 +53,7 @@ if (APPLE)
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF) option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
option(WHISPER_NO_AVX "whisper: disable AVX" OFF) option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF) option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
else() else()
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF) option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
endif() endif()
@ -82,9 +84,6 @@ endif()
# dependencies # dependencies
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_STANDARD 11)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
# on APPLE - include Accelerate framework # on APPLE - include Accelerate framework
@ -131,6 +130,7 @@ if (WHISPER_ALL_WARNINGS)
-Wcast-qual \ -Wcast-qual \
-Wstrict-prototypes \ -Wstrict-prototypes \
-Wpointer-arith \ -Wpointer-arith \
-Wno-unused-function \
") ")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \
-Wall \ -Wall \
@ -157,6 +157,7 @@ else()
if (MSVC) if (MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
else() else()
if (EMSCRIPTEN) if (EMSCRIPTEN)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
@ -168,7 +169,10 @@ else()
if(NOT WHISPER_NO_AVX2) if(NOT WHISPER_NO_AVX2)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
endif() endif()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma -mf16c") if(NOT WHISPER_NO_FMA)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
endif()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
endif() endif()
endif() endif()
endif() endif()
@ -190,6 +194,8 @@ add_library(${TARGET}
whisper.cpp whisper.cpp
) )
include(DefaultTargetOptions)
target_include_directories(${TARGET} PUBLIC target_include_directories(${TARGET} PUBLIC
. .
) )
@ -223,6 +229,7 @@ target_compile_definitions(${TARGET} PUBLIC
install(TARGETS ${TARGET} install(TARGETS ${TARGET}
LIBRARY DESTINATION lib LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib/static ARCHIVE DESTINATION lib/static
RUNTIME DESTINATION bin
) )
# #

View File

@ -10,6 +10,9 @@ ifndef UNAME_M
UNAME_M := $(shell uname -m) UNAME_M := $(shell uname -m)
endif endif
CCV := $(shell $(CC) --version | head -n 1)
CXXV := $(shell $(CXX) --version | head -n 1)
# Mac OS + Arm can report x86_64 # Mac OS + Arm can report x86_64
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789 # ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
ifeq ($(UNAME_S),Darwin) ifeq ($(UNAME_S),Darwin)
@ -53,10 +56,13 @@ endif
# Architecture specific # Architecture specific
# TODO: probably these flags need to be tweaked on some architectures # TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue # feel free to update the Makefile for your architecture and send a pull request or issue
ifeq ($(UNAME_M),x86_64) ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
ifeq ($(UNAME_S),Darwin) ifeq ($(UNAME_S),Darwin)
CFLAGS += -mfma -mf16c CFLAGS += -mf16c
AVX1_M := $(shell sysctl machdep.cpu.features) AVX1_M := $(shell sysctl machdep.cpu.features)
ifneq (,$(findstring FMA,$(AVX1_M)))
CFLAGS += -mfma
endif
ifneq (,$(findstring AVX1.0,$(AVX1_M))) ifneq (,$(findstring AVX1.0,$(AVX1_M)))
CFLAGS += -mavx CFLAGS += -mavx
endif endif
@ -81,6 +87,10 @@ ifeq ($(UNAME_M),x86_64)
ifneq (,$(findstring f16c,$(F16C_M))) ifneq (,$(findstring f16c,$(F16C_M)))
CFLAGS += -mf16c CFLAGS += -mf16c
endif endif
SSE3_M := $(shell grep "sse3 " /proc/cpuinfo)
ifneq (,$(findstring sse3,$(SSE3_M)))
CFLAGS += -msse3
endif
else ifeq ($(UNAME_S),Haiku) else ifeq ($(UNAME_S),Haiku)
AVX1_M := $(shell sysinfo -cpu | grep "AVX ") AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
ifneq (,$(findstring avx,$(AVX1_M))) ifneq (,$(findstring avx,$(AVX1_M)))
@ -141,6 +151,21 @@ ifneq ($(filter armv8%,$(UNAME_M)),)
CFLAGS += -mfp16-format=ieee -mno-unaligned-access CFLAGS += -mfp16-format=ieee -mno-unaligned-access
endif endif
#
# Print build information
#
$(info I whisper.cpp build info: )
$(info I UNAME_S: $(UNAME_S))
$(info I UNAME_P: $(UNAME_P))
$(info I UNAME_M: $(UNAME_M))
$(info I CFLAGS: $(CFLAGS))
$(info I CXXFLAGS: $(CXXFLAGS))
$(info I LDFLAGS: $(LDFLAGS))
$(info I CC: $(CCV))
$(info I CXX: $(CXXV))
$(info )
default: main default: main
# #

View File

@ -11,6 +11,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
- Plain C/C++ implementation without dependencies - Plain C/C++ implementation without dependencies
- Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework - Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework
- AVX intrinsics support for x86 architectures - AVX intrinsics support for x86 architectures
- VSX intrinsics support for POWER architectures
- Mixed F16 / F32 precision - Mixed F16 / F32 precision
- Low memory usage (Flash Attention + Flash Forward) - Low memory usage (Flash Attention + Flash Forward)
- Zero memory allocations at runtime - Zero memory allocations at runtime

View File

@ -1,3 +1,2 @@
build build
models models
go.sum

View File

@ -1,28 +1,27 @@
CMAKE := $(shell which cmake) BUILD_DIR := build
BUILD_DIR := "build" MODELS_DIR := models
MODELS_DIR := "models"
EXAMPLES_DIR := $(wildcard examples/*) EXAMPLES_DIR := $(wildcard examples/*)
C_INCLUDE_PATH := "../.." INCLUDE_PATH := $(abspath ../..)
LIBRARY_PATH := $(abspath ../..)
all: clean whisper examples all: clean whisper examples
whisper: mkdir whisper: mkdir
@echo Build whisper @echo Build whisper
@${CMAKE} -S ../.. -B ${BUILD_DIR} -D BUILD_SHARED_LIBS=off -D WHISPER_NO_AVX2=on @${MAKE} -C ../.. libwhisper.a
@${CMAKE} --build ${BUILD_DIR} --target whisper
test: model-small whisper modtidy test: model-small whisper modtidy
@go test -v . @C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
@go test -v ./pkg/whisper/... @C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
examples: $(EXAMPLES_DIR) examples: $(EXAMPLES_DIR)
model-small: mkdir examples/go-model-download model-small: mkdir examples/go-model-download
@${BUILD_DIR}/go-model-download -out models small.en @${BUILD_DIR}/go-model-download -out models ggml-small.en.bin
$(EXAMPLES_DIR): mkdir whisper modtidy $(EXAMPLES_DIR): mkdir whisper modtidy
@echo Build example $(notdir $@) @echo Build example $(notdir $@)
@go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@ @C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
mkdir: mkdir:
@echo Mkdir ${BUILD_DIR} @echo Mkdir ${BUILD_DIR}

View File

@ -74,4 +74,27 @@ And you can then test a model against samples with the following command:
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav ./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
``` ```
## Using the bindings
To use the bindings in your own software,
1. Import `github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper` (or `github.com/ggerganov/whisper.cpp/bindings/go` into your package;
2. Compile `libwhisper.a` (you can use `make whisper` in the `bindings/go` directory);
3. Link your go binary against whisper by setting the environment variables `C_INCLUDE_PATH` and `LIBRARY_PATH`
to point to the `whisper.h` file directory and `libwhisper.a` file directory respectively.
Look at the `Makefile` in the `bindings/go` directory for an example.
The API Documentation:
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper
Getting help:
* Follow the discussion for the go bindings [here](https://github.com/ggerganov/whisper.cpp/discussions/312)
## License
The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.

View File

@ -17,15 +17,14 @@ import (
// CONSTANTS // CONSTANTS
const ( const (
srcUrl = "https://huggingface.co/" // The location of the models srcUrl = "https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main" // The location of the models
srcPathPrefix = "/datasets/ggerganov/whisper.cpp/resolve/main/ggml" // Filename prefix
srcExt = ".bin" // Filename extension srcExt = ".bin" // Filename extension
bufSize = 1024 * 64 // Size of the buffer used for downloading the model bufSize = 1024 * 64 // Size of the buffer used for downloading the model
) )
var ( var (
// The models which will be downloaded, if no model is specified as an argument // The models which will be downloaded, if no model is specified as an argument
modelNames = []string{"tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium", "large-v1", "large"} modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large"}
) )
var ( var (
@ -123,11 +122,14 @@ func GetModels() []string {
// URLForModel returns the URL for the given model on huggingface.co // URLForModel returns the URL for the given model on huggingface.co
func URLForModel(model string) (string, error) { func URLForModel(model string) (string, error) {
if filepath.Ext(model) != srcExt {
model += srcExt
}
url, err := url.Parse(srcUrl) url, err := url.Parse(srcUrl)
if err != nil { if err != nil {
return "", err return "", err
} else { } else {
url.Path = srcPathPrefix + "-" + model + srcExt url.Path = filepath.Join(url.Path, model)
} }
return url.String(), nil return url.String(), nil
} }

23
bindings/go/go.sum Normal file
View File

@ -0,0 +1,23 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,8 +1,5 @@
package whisper package whisper
// This file defines the whisper_token, whisper_token_data and whisper_full_params
// structures, which are used by the whisper_full() function.
import ( import (
"fmt" "fmt"
) )

View File

@ -9,8 +9,7 @@ import (
// CGO // CGO
/* /*
#cgo CFLAGS: -I${SRCDIR}/../.. #cgo LDFLAGS: -lwhisper -lm -lstdc++
#cgo LDFLAGS: -L${SRCDIR}/build -lwhisper -lm -lstdc++
#cgo darwin LDFLAGS: -framework Accelerate #cgo darwin LDFLAGS: -framework Accelerate
#include <whisper.h> #include <whisper.h>
#include <stdlib.h> #include <stdlib.h>
@ -171,6 +170,10 @@ func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
} }
// Return the id of the specified language, returns -1 if not found // Return the id of the specified language, returns -1 if not found
// Examples:
//
// "de" -> 2
// "german" -> 2
func (ctx *Context) Whisper_lang_id(lang string) int { func (ctx *Context) Whisper_lang_id(lang string) int {
return int(C.whisper_lang_id(C.CString(lang))) return int(C.whisper_lang_id(C.CString(lang)))
} }
@ -211,6 +214,10 @@ func (ctx *Context) Whisper_n_text_ctx() int {
return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx))) return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
} }
func (ctx *Context) Whisper_n_audio_ctx() int {
return int(C.whisper_n_audio_ctx((*C.struct_whisper_context)(ctx)))
}
func (ctx *Context) Whisper_is_multilingual() int { func (ctx *Context) Whisper_is_multilingual() int {
return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx))) return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
} }

View File

@ -50,7 +50,10 @@ func Test_Whisper_001(t *testing.T) {
ctx := whisper.Whisper_init(ModelPath) ctx := whisper.Whisper_init(ModelPath)
assert.NotNil(ctx) assert.NotNil(ctx)
defer ctx.Whisper_free() defer ctx.Whisper_free()
assert.NoError(ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf.AsFloat32Buffer().Data, nil, nil)) params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
data := buf.AsFloat32Buffer().Data
err = ctx.Whisper_full(params, data, nil, nil)
assert.NoError(err)
// Print out tokens // Print out tokens
num_segments := ctx.Whisper_full_n_segments() num_segments := ctx.Whisper_full_n_segments()

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,17 @@
# Set the default compile features and properties for a target.
if (NOT TARGET)
message(FATAL_ERROR "TARGET not set before including DefaultTargetOptions")
endif()
target_compile_features(${TARGET}
PRIVATE
cxx_std_11
)
set_target_properties(${TARGET}
PROPERTIES
EXPORT_COMPILE_COMMANDS ON
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib"
)

View File

@ -8,6 +8,8 @@ add_executable(${TARGET}
emscripten.cpp emscripten.cpp
) )
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE target_link_libraries(${TARGET} PRIVATE
whisper whisper
) )

View File

@ -1,3 +1,6 @@
set(TARGET bench) set(TARGET bench)
add_executable(${TARGET} bench.cpp) add_executable(${TARGET} bench.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})

View File

@ -8,6 +8,8 @@ add_executable(${TARGET}
emscripten.cpp emscripten.cpp
) )
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE target_link_libraries(${TARGET} PRIVATE
whisper whisper
) )

View File

@ -2,6 +2,9 @@ if (WHISPER_SUPPORT_SDL2)
# command # command
set(TARGET command) set(TARGET command)
add_executable(${TARGET} command.cpp) add_executable(${TARGET} command.cpp)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS}) target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif () endif ()

View File

@ -9,7 +9,19 @@ More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/
# On Raspberry Pi, use tiny or base models + "-ac 768" for better performance # On Raspberry Pi, use tiny or base models + "-ac 768" for better performance
./command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0 ./command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0
```
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
Web version: [examples/command.wasm](/examples/command.wasm)
## Guided mode
"Guided mode" allows you to specify a list of commands (i.e. strings) and the transcription will be guided to classify your command into one from the list. This can be useful in situations where a device is listening only for a small subset of commands.
Initial tests show that this approach might be extremely efficient in terms of performance, since it integrates very well with the "partial Encoder" idea from #137.
```bash
# Run in guided mode, the list of allowed commands is in commands.txt # Run in guided mode, the list of allowed commands is in commands.txt
./command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt ./command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt
@ -17,9 +29,8 @@ More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/
./command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0 ./command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0
``` ```
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4 https://user-images.githubusercontent.com/1991296/207435352-8fc4ed3f-bde5-4555-9b8b-aeeb76bee969.mp4
Web version: [examples/command.wasm](/examples/command.wasm)
## Building ## Building

View File

@ -510,86 +510,23 @@ std::vector<std::string> read_allowed_commands(const std::string & fname) {
return allowed_commands; return allowed_commands;
} }
int main(int argc, char ** argv) { // command-list mode
whisper_params params; // guide the transcription to match the most likely command from a provided list
int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init
struct whisper_context * ctx = whisper_init(params.model.c_str());
// print some info about the processing
{
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
__func__,
params.n_threads,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1);
fprintf(stderr, "\n");
}
// init audio
audio_async audio(30*1000);
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
return 1;
}
audio.resume();
// wait for 1 second to avoid any buffered noise
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
audio.clear();
int max_len = 0;
bool is_running = true;
bool have_prompt = false;
bool ask_prompt = true;
float prob0 = 0.0f;
float prob = 0.0f;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
std::vector<std::string> allowed_commands;
std::vector<std::vector<whisper_token>> allowed_tokens;
std::string k_prompt;
std::vector<whisper_token> k_tokens;
if (!params.commands.empty()) {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "%s: guided mode\n", __func__); fprintf(stderr, "%s: guided mode\n", __func__);
allowed_commands = read_allowed_commands(params.commands); std::vector<std::string> allowed_commands = read_allowed_commands(params.commands);
if (allowed_commands.empty()) { if (allowed_commands.empty()) {
fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str()); fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
return 2; return 2;
} }
int max_len = 0;
std::vector<std::vector<whisper_token>> allowed_tokens;
for (const auto & cmd : allowed_commands) { for (const auto & cmd : allowed_commands) {
whisper_token tokens[1024]; whisper_token tokens[1024];
allowed_tokens.emplace_back(); allowed_tokens.emplace_back();
@ -623,7 +560,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, " ]\n"); fprintf(stderr, " ]\n");
} }
k_prompt = "select one from the available words: "; std::string k_prompt = "select one from the available words: ";
for (int i = 0; i < (int) allowed_commands.size(); ++i) { for (int i = 0; i < (int) allowed_commands.size(); ++i) {
if (i > 0) { if (i > 0) {
k_prompt += ", "; k_prompt += ", ";
@ -633,6 +570,7 @@ int main(int argc, char ** argv) {
k_prompt += ". selected word: "; k_prompt += ". selected word: ";
// tokenize prompt // tokenize prompt
std::vector<whisper_token> k_tokens;
{ {
k_tokens.resize(1024); k_tokens.resize(1024);
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024); const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
@ -655,12 +593,10 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: listening for a command ...\n", __func__); fprintf(stderr, "%s: listening for a command ...\n", __func__);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} else { bool is_running = true;
fprintf(stderr, "\n");
fprintf(stderr, "%s: general-purpose mode\n", __func__);
k_prompt = "Ok Whisper, start listening for commands."; std::vector<float> pcmf32_cur;
} std::vector<float> pcmf32_prompt;
// main loop // main loop
while (is_running) { while (is_running) {
@ -679,98 +615,13 @@ int main(int argc, char ** argv) {
} }
if (!is_running) { if (!is_running) {
break; return 0;
} }
} }
// delay // delay
std::this_thread::sleep_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (allowed_commands.empty()) {
// general-purpose mode
// freely transcribe the voice into text
if (ask_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
fprintf(stdout, "\n");
ask_prompt = false;
}
{
int64_t t_ms = 0;
audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
if (!have_prompt) {
// wait for activation phrase
audio.get(params.prompt_ms, pcmf32_cur);
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
const float sim = similarity(txt, k_prompt);
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
ask_prompt = true;
} else {
fprintf(stdout, "\n");
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
fprintf(stdout, "\n");
// save the audio for the prompt
pcmf32_prompt = pcmf32_cur;
have_prompt = true;
}
} else {
// we have heard the activation phrase, now detect the commands
audio.get(params.command_ms, pcmf32_cur);
// prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
prob = 100.0f*(prob - prob0);
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
// find the prompt in the text
float best_sim = 0.0f;
size_t best_len = 0;
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
const auto prompt = txt.substr(0, n);
const float sim = similarity(prompt, k_prompt);
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
if (sim > best_sim) {
best_sim = sim;
best_len = n;
}
}
const std::string command = ::trim(txt.substr(best_len));
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
fprintf(stdout, "\n");
}
audio.clear();
}
}
} else {
// command-list mode
// guide the transcription to match the most likely command from a provided list
audio.get(2000, pcmf32_cur); audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) { if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
@ -834,8 +685,8 @@ int main(int argc, char ** argv) {
fprintf(stdout, "\n"); fprintf(stdout, "\n");
for (const auto & cmd : probs_id) { for (const auto & cmd : probs_id) {
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first); fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
for (int i = 0; i < (int) allowed_tokens[cmd.second].size(); ++i) { for (int token : allowed_tokens[cmd.second]) {
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, allowed_tokens[cmd.second][i]), probs[allowed_tokens[cmd.second][i]]); fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
} }
fprintf(stdout, "\n"); fprintf(stdout, "\n");
} }
@ -845,9 +696,12 @@ int main(int argc, char ** argv) {
{ {
const auto t_end = std::chrono::high_resolution_clock::now(); const auto t_end = std::chrono::high_resolution_clock::now();
const float prob = probs_id[0].first;
const int index = probs_id[0].second;
fprintf(stdout, "\n"); fprintf(stdout, "\n");
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__, fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
"\033[1m", allowed_commands[probs_id[0].second].c_str(), "\033[0m", probs_id[0].first, "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count()); (int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
fprintf(stdout, "\n"); fprintf(stdout, "\n");
} }
@ -855,6 +709,191 @@ int main(int argc, char ** argv) {
audio.clear(); audio.clear();
} }
} }
return 0;
}
// general-purpose mode
// freely transcribe the voice into text
//
// The user is first asked to say the activation phrase (k_prompt). Once the
// transcription of the spoken audio is close enough to the phrase (length
// within +/-20% and similarity >= 0.8), the prompt audio is saved. Every
// subsequently detected speech segment is then transcribed with the saved
// prompt audio prepended, the best-matching prompt prefix is located in the
// transcription, and whatever follows it is printed as the command.
//
// Returns 0 on normal termination (SDL_QUIT / Ctrl+C).
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
    bool is_running  = true;
    bool have_prompt = false; // set once the activation phrase has been recognized
    bool ask_prompt  = true;  // when true, (re-)print the "say the phrase" instruction

    float prob0 = 0.0f; // transcription probability of the activation phrase
    float prob  = 0.0f; // transcription probability of the current command

    std::vector<float> pcmf32_cur;    // most recent captured audio
    std::vector<float> pcmf32_prompt; // saved audio of the activation phrase

    const std::string k_prompt = "Ok Whisper, start listening for commands.";

    fprintf(stderr, "\n");
    fprintf(stderr, "%s: general-purpose mode\n", __func__);

    // main loop
    while (is_running) {
        // handle Ctrl + C
        {
            SDL_Event event;
            while (SDL_PollEvent(&event)) {
                switch (event.type) {
                    case SDL_QUIT:
                        {
                            is_running = false;
                        } break;
                    default:
                        break;
                }
            }

            if (!is_running) {
                return 0;
            }
        }

        // delay to avoid busy-polling the audio buffer
        std::this_thread::sleep_for(std::chrono::milliseconds(100));

        if (ask_prompt) {
            fprintf(stdout, "\n");
            // "\033[1m" / "\033[0m" = ANSI bold on/off
            fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
            fprintf(stdout, "\n");

            ask_prompt = false;
        }

        {
            // run simple VAD on the last 2000 ms of captured audio
            audio.get(2000, pcmf32_cur);

            if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
                fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);

                int64_t t_ms = 0;

                if (!have_prompt) {
                    // wait for activation phrase
                    audio.get(params.prompt_ms, pcmf32_cur);

                    const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));

                    fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);

                    const float sim = similarity(txt, k_prompt);

                    // accept only if the heard text length is within +/-20% of the
                    // expected phrase and the similarity is high enough
                    if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
                        fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
                        ask_prompt = true;
                    } else {
                        fprintf(stdout, "\n");
                        fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
                        fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
                        fprintf(stdout, "\n");

                        // save the audio for the prompt
                        pcmf32_prompt = pcmf32_cur;

                        have_prompt = true;
                    }
                } else {
                    // we have heard the activation phrase, now detect the commands
                    audio.get(params.command_ms, pcmf32_cur);

                    // prepend the prompt audio
                    pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());

                    const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));

                    // express the command probability relative to the activation phrase
                    prob = 100.0f*(prob - prob0);

                    //fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());

                    // find the prompt in the text: try prefix lengths within +/-20%
                    // of the expected prompt length and keep the most similar one
                    float best_sim = 0.0f;
                    size_t best_len = 0;
                    for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
                        const auto prompt = txt.substr(0, n); // clamped to txt.size() by substr
                        const float sim = similarity(prompt, k_prompt);

                        //fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);

                        if (sim > best_sim) {
                            best_sim = sim;
                            best_len = n;
                        }
                    }

                    // clamp: if the transcription is shorter than 0.8*k_prompt.size(),
                    // best_len can exceed txt.size() and substr(best_len) would throw
                    // std::out_of_range and crash
                    if (best_len > txt.size()) {
                        best_len = txt.size();
                    }

                    const std::string command = ::trim(txt.substr(best_len));

                    fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
                    fprintf(stdout, "\n");
                }

                audio.clear();
            }
        }
    }

    return 0;
}
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init
struct whisper_context * ctx = whisper_init(params.model.c_str());
// print some info about the processing
{
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
__func__,
params.n_threads,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1);
fprintf(stderr, "\n");
}
// init audio
audio_async audio(30*1000);
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
return 1;
}
audio.resume();
// wait for 1 second to avoid any buffered noise
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
audio.clear();
int ret_val = 0;
if (!params.commands.empty()) {
ret_val = process_command_list(ctx, audio, params);
} else {
ret_val = process_general_transcription(ctx, audio, params);
} }
audio.pause(); audio.pause();
@ -862,5 +901,5 @@ int main(int argc, char ** argv) {
whisper_print_timings(ctx); whisper_print_timings(ctx);
whisper_free(ctx); whisper_free(ctx);
return 0; return ret_val;
} }

View File

@ -1,3 +1,6 @@
set(TARGET main) set(TARGET main)
add_executable(${TARGET} main.cpp) add_executable(${TARGET} main.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})

View File

@ -69,6 +69,7 @@ struct whisper_params {
bool output_vtt = false; bool output_vtt = false;
bool output_srt = false; bool output_srt = false;
bool output_wts = false; bool output_wts = false;
bool output_csv = false;
bool print_special = false; bool print_special = false;
bool print_colors = false; bool print_colors = false;
bool print_progress = false; bool print_progress = false;
@ -111,6 +112,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
@ -150,6 +152,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
@ -173,40 +176,27 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
std::string speaker = "";
int64_t t0;
int64_t t1;
// print the last n_new segments // print the last n_new segments
const int s0 = n_segments - n_new; const int s0 = n_segments - n_new;
if (s0 == 0) { if (s0 == 0) {
printf("\n"); printf("\n");
} }
for (int i = s0; i < n_segments; i++) { for (int i = s0; i < n_segments; i++) {
if (params.no_timestamps) { if (!params.no_timestamps || params.diarize) {
if (params.print_colors) { t0 = whisper_full_get_segment_t0(ctx, i);
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { t1 = whisper_full_get_segment_t1(ctx, i);
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
} }
const char * text = whisper_full_get_token_text(ctx, i, j); if (!params.no_timestamps) {
const float p = whisper_full_get_token_p (ctx, i, j); printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
} }
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s", text);
}
fflush(stdout);
} else {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker;
if (params.diarize && pcmf32s.size() == 2) { if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size(); const int64_t n_samples = pcmf32s[0].size();
@ -234,7 +224,6 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
} }
if (params.print_colors) { if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) { if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j); const whisper_token id = whisper_full_get_token_id(ctx, i, j);
@ -250,13 +239,18 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
} }
printf("\n");
} else { } else {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(ctx, i);
printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text); printf("%s%s", speaker.c_str(), text);
} }
// with timestamps or speakers: each segment on new line
if (!params.no_timestamps || params.diarize) {
printf("\n");
} }
fflush(stdout);
} }
} }
@ -325,6 +319,32 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
return true; return true;
} }
bool output_csv(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
if (text[0] == ' ')
text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << ", "
<< 10 * t1 << ", \""
<< text << "\"\n";
}
return true;
}
// karaoke video generation // karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles // outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments // TODO: font parameter adjustments
@ -528,7 +548,7 @@ int main(int argc, char ** argv) {
} }
if (wav.sampleRate != WHISPER_SAMPLE_RATE) { if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str()); fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
return 8; return 8;
} }
@ -674,6 +694,13 @@ int main(int argc, char ** argv) {
const auto fname_wts = fname_inp + ".wts"; const auto fname_wts = fname_inp + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE); output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
} }
// output to CSV file
if (params.output_csv) {
const auto fname_csv = fname_inp + ".csv";
output_csv(ctx, fname_csv.c_str());
}
} }
} }

View File

@ -8,6 +8,8 @@ add_executable(${TARGET}
emscripten.cpp emscripten.cpp
) )
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE target_link_libraries(${TARGET} PRIVATE
whisper whisper
) )

View File

@ -2,6 +2,9 @@ if (WHISPER_SUPPORT_SDL2)
# stream # stream
set(TARGET stream) set(TARGET stream)
add_executable(${TARGET} stream.cpp) add_executable(${TARGET} stream.cpp)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS}) target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif () endif ()

View File

@ -8,6 +8,7 @@
#include <SDL.h> #include <SDL.h>
#include <SDL_audio.h> #include <SDL_audio.h>
#include <atomic>
#include <cassert> #include <cassert>
#include <cstdio> #include <cstdio>
#include <string> #include <string>
@ -144,7 +145,7 @@ private:
int m_len_ms = 0; int m_len_ms = 0;
int m_sample_rate = 0; int m_sample_rate = 0;
bool m_running = false; std::atomic_bool m_running;
std::mutex m_mutex; std::mutex m_mutex;
std::vector<float> m_audio; std::vector<float> m_audio;
@ -155,6 +156,8 @@ private:
audio_async::audio_async(int len_ms) { audio_async::audio_async(int len_ms) {
m_len_ms = len_ms; m_len_ms = len_ms;
m_running = false;
} }
audio_async::~audio_async() { audio_async::~audio_async() {
@ -427,10 +430,10 @@ int main(int argc, char ** argv) {
const int n_samples_keep = (params.keep_ms *1e-3)*WHISPER_SAMPLE_RATE; const int n_samples_keep = (params.keep_ms *1e-3)*WHISPER_SAMPLE_RATE;
const int n_samples_30s = (30000 *1e-3)*WHISPER_SAMPLE_RATE; const int n_samples_30s = (30000 *1e-3)*WHISPER_SAMPLE_RATE;
const int n_new_line = params.length_ms / params.step_ms - 1; // number of steps to print new line
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
const int n_new_line = !use_vad ? params.length_ms / params.step_ms - 1 : 1; // number of steps to print new line
params.no_timestamps = !use_vad; params.no_timestamps = !use_vad;
params.no_context = use_vad; params.no_context = use_vad;
params.max_tokens = 0; params.max_tokens = 0;

View File

@ -9,6 +9,8 @@ add_executable(${TARGET}
gpt-2.cpp gpt-2.cpp
) )
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE target_link_libraries(${TARGET} PRIVATE
whisper whisper
) )
@ -31,8 +33,8 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \ --bind \
-s USE_PTHREADS=1 \ -s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \ -s PTHREAD_POOL_SIZE=8 \
-s INITIAL_MEMORY=1600MB \ -s INITIAL_MEMORY=1800MB \
-s TOTAL_MEMORY=1600MB \ -s TOTAL_MEMORY=1800MB \
-s FORCE_FILESYSTEM=1 \ -s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \ -s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \ ${EXTRA_FLAGS} \

View File

@ -36,7 +36,7 @@ In order to run this demo efficiently, you need to have the following:
- Latest Chrome or Firefox browser (Safari is not supported) - Latest Chrome or Firefox browser (Safari is not supported)
- Run this on a desktop or laptop with modern CPU (a mobile phone will likely not be good enough) - Run this on a desktop or laptop with modern CPU (a mobile phone will likely not be good enough)
- Speak phrases that are no longer than 10 seconds - this is the audio context of the AI - Speak phrases that are no longer than 10 seconds - this is the audio context of the AI
- The web-page uses about 1.6GB of RAM - The web-page uses about 1.8GB of RAM
Notice that this demo is using the smallest GPT-2 model, so the generated text responses are not always very good. Notice that this demo is using the smallest GPT-2 model, so the generated text responses are not always very good.
Also, the prompting strategy can likely be improved to achieve better results. Also, the prompting strategy can likely be improved to achieve better results.

View File

@ -8,6 +8,9 @@ if (WHISPER_SUPPORT_SDL2)
# TODO: this is temporary # TODO: this is temporary
# need to export ggml symbols for MSVC, but too lazy .. # need to export ggml symbols for MSVC, but too lazy ..
add_executable(${TARGET} talk.cpp gpt-2.cpp ../../ggml.c ../../whisper.cpp) add_executable(${TARGET} talk.cpp gpt-2.cpp ../../ggml.c ../../whisper.cpp)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../) target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif () endif ()

View File

@ -8,6 +8,8 @@ add_executable(${TARGET}
emscripten.cpp emscripten.cpp
) )
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE target_link_libraries(${TARGET} PRIVATE
whisper whisper
) )

659
ggml.c

File diff suppressed because it is too large Load Diff

2
ggml.h
View File

@ -731,6 +731,8 @@ int ggml_cpu_has_f16c(void);
int ggml_cpu_has_fp16_va(void); int ggml_cpu_has_fp16_va(void);
int ggml_cpu_has_wasm_simd(void); int ggml_cpu_has_wasm_simd(void);
int ggml_cpu_has_blas(void); int ggml_cpu_has_blas(void);
int ggml_cpu_has_sse3(void);
int ggml_cpu_has_vsx(void);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -56,7 +56,7 @@ def bytes_to_unicode():
The reversible bpe codes work on unicode strings. The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab. This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings. To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on. And avoids mapping to whitespace/control characters the bpe code barfs on.
""" """

View File

@ -40,7 +40,7 @@ if exist "ggml-%model%.bin" (
goto :eof goto :eof
) )
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/datasets/ggerganov/whisper.cpp/raw/main/ggml-%model%.bin -OutFile ggml-%model%.bin" PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
if %ERRORLEVEL% neq 0 ( if %ERRORLEVEL% neq 0 (
echo Failed to download ggml model %model% echo Failed to download ggml model %model%

View File

@ -204,6 +204,10 @@ struct whisper_vocab {
std::map<token, id> token_to_id; std::map<token, id> token_to_id;
std::map<id, token> id_to_token; std::map<id, token> id_to_token;
// used to avoid memory allocations during sampling
// TODO: move to whisper_context in the future
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
id token_eot = 50256; id token_eot = 50256;
id token_sot = 50257; id token_sot = 50257;
id token_prev = 50360; id token_prev = 50360;
@ -408,6 +412,8 @@ struct whisper_context {
std::vector<uint8_t> buf_compute; std::vector<uint8_t> buf_compute;
std::vector<uint8_t> buf_compute_layer; std::vector<uint8_t> buf_compute_layer;
ggml_type wtype; // weight type (FP32 or FP16)
whisper_model model; whisper_model model;
whisper_vocab vocab; whisper_vocab vocab;
@ -431,8 +437,7 @@ struct whisper_context {
}; };
template<typename T> template<typename T>
static void read_safe(std::ifstream& fin, T& dest) static void read_safe(std::ifstream& fin, T& dest) {
{
fin.read((char*)& dest, sizeof(T)); fin.read((char*)& dest, sizeof(T));
} }
@ -551,6 +556,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
std::string word; std::string word;
std::vector<char> tmp; std::vector<char> tmp;
tmp.reserve(128);
for (int i = 0; i < n_vocab; i++) { for (int i = 0; i < n_vocab; i++) {
uint32_t len; uint32_t len;
read_safe(fin, len); read_safe(fin, len);
@ -603,6 +611,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
vocab.id_to_token[i] = word; vocab.id_to_token[i] = word;
} }
} }
wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
vocab.probs_id.reserve(n_vocab);
} }
{ {
@ -618,7 +631,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// for the big tensors, we have the option to store the data in 16-bit floats // for the big tensors, we have the option to store the data in 16-bit floats
// in order to save memory and also to speed up the computation // in order to save memory and also to speed up the computation
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
const ggml_type wtype = wctx.wtype;
size_t ctx_size = 0; size_t ctx_size = 0;
@ -639,7 +654,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// encoder // encoder
{ {
// TODO: F16 .. maybe not?
ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe; ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
@ -654,7 +668,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// decoder // decoder
{ {
// TODO: F16 .. maybe not?
ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe; ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te; ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
@ -971,8 +984,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
const int n_mem = n_text_layer*n_text_ctx; const int n_mem = n_text_layer*n_text_ctx;
const int n_elements = n_text_state*n_mem; const int n_elements = n_text_state*n_mem;
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
} }
// key/value memory for the cross-attention layer // key/value memory for the cross-attention layer
@ -982,8 +995,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
const int n_mem = n_text_layer*n_audio_ctx; const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem; const int n_elements = n_text_state*n_mem;
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_cross_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
} }
const size_t memory_size = const size_t memory_size =
@ -1229,14 +1242,14 @@ static bool whisper_encode(
ggml_permute(ctxL, ggml_permute(ctxL,
ggml_cpy(ctxL, ggml_cpy(ctxL,
Qcur, Qcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3); 0, 2, 1, 3);
struct ggml_tensor * K = struct ggml_tensor * K =
ggml_permute(ctxL, ggml_permute(ctxL,
ggml_cpy(ctxL, ggml_cpy(ctxL,
Kcur, Kcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3); 0, 2, 1, 3);
struct ggml_tensor * V = struct ggml_tensor * V =
@ -1246,7 +1259,7 @@ static bool whisper_encode(
Vcur, Vcur,
n_state/n_head, n_head, n_ctx), n_state/n_head, n_head, n_ctx),
1, 2, 0, 3), 1, 2, 0, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head) ggml_new_tensor_3d(ctxL, wctx.wtype, n_ctx, n_state/n_head, n_head)
); );
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false); struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
@ -1262,7 +1275,7 @@ static bool whisper_encode(
ggml_permute(ctxL, ggml_permute(ctxL,
ggml_cpy(ctxL, ggml_cpy(ctxL,
Kcur, Kcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3); 0, 2, 1, 3);
// K * Q // K * Q
@ -1280,7 +1293,7 @@ static bool whisper_encode(
// ggml_permute(ctxL, // ggml_permute(ctxL,
// ggml_cpy(ctxL, // ggml_cpy(ctxL,
// Vcur, // Vcur,
// ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), // ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
// 1, 2, 0, 3); // 1, 2, 0, 3);
//struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
@ -1292,7 +1305,7 @@ static bool whisper_encode(
Vcur, Vcur,
n_state/n_head, n_head, n_ctx), n_state/n_head, n_head, n_ctx),
0, 2, 1, 3), 0, 2, 1, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head) ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_ctx, n_head)
); );
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max); struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
@ -1337,7 +1350,7 @@ static bool whisper_encode(
#ifdef USE_FLASH_FF #ifdef USE_FLASH_FF
cur = ggml_flash_ff(ctxL, cur = ggml_flash_ff(ctxL,
ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)), ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
#else #else
// fully connected // fully connected
@ -1444,7 +1457,7 @@ static bool whisper_encode(
layer.cross_attn_k_w, layer.cross_attn_k_w,
cur); cur);
Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); //Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
layer.cross_attn_v_w, layer.cross_attn_v_w,
@ -1566,14 +1579,14 @@ static bool whisper_decode(
Qcur), Qcur),
Qcur); Qcur);
Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
// note: no bias for Key // note: no bias for Key
struct ggml_tensor * Kcur = ggml_mul_mat(ctxL, struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
layer.attn_k_w, layer.attn_k_w,
cur); cur);
Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
struct ggml_tensor * Vcur = ggml_mul_mat(ctxL, struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
layer.attn_v_w, layer.attn_v_w,
@ -1596,6 +1609,33 @@ static bool whisper_decode(
// ------ // ------
#ifdef USE_FLASH_ATTN
struct ggml_tensor * Q =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Qcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
0, 2, 1, 3);
struct ggml_tensor * K =
ggml_permute(ctxL,
ggml_reshape_3d(ctxL,
ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
n_state/n_head, n_head, n_past + N),
0, 2, 1, 3);
struct ggml_tensor * V =
ggml_cpy(ctxL,
ggml_permute(ctxL,
ggml_reshape_3d(ctxL,
ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
n_state/n_head, n_head, n_past + N),
1, 2, 0, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_past + N, n_state/n_head, n_head)
);
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, true);
#else
struct ggml_tensor * Q = struct ggml_tensor * Q =
ggml_permute(ctxL, ggml_permute(ctxL,
ggml_cpy(ctxL, ggml_cpy(ctxL,
@ -1613,13 +1653,13 @@ static bool whisper_decode(
// K * Q // K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q); struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
//struct ggml_tensor * KQ_scaled = struct ggml_tensor * KQ_scaled =
// ggml_scale(ctxL, ggml_scale(ctxL,
// KQ, KQ,
// ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
// ); );
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past); struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked); struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
@ -1631,6 +1671,7 @@ static bool whisper_decode(
1, 2, 0, 3); 1, 2, 0, 3);
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
#endif
struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3); struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
@ -1676,7 +1717,7 @@ static bool whisper_decode(
Qcur), Qcur),
Qcur); Qcur);
Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
// Kcross is already scaled // Kcross is already scaled
struct ggml_tensor * Kcross = struct ggml_tensor * Kcross =
@ -1691,6 +1732,24 @@ static bool whisper_decode(
// ------ // ------
#ifdef USE_FLASH_ATTN
struct ggml_tensor * Q =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Qcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
0, 2, 1, 3);
struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
struct ggml_tensor * V =
ggml_cpy(ctxL,
ggml_permute(ctxL, Vcross, 1, 2, 0, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, M, n_state/n_head, n_head)
);
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
#else
struct ggml_tensor * Q = struct ggml_tensor * Q =
ggml_permute(ctxL, ggml_permute(ctxL,
ggml_cpy(ctxL, ggml_cpy(ctxL,
@ -1717,6 +1776,7 @@ static bool whisper_decode(
struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3); struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
#endif
struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3); struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
@ -1849,7 +1909,7 @@ static bool whisper_decode(
// the most basic sampling scheme - select the top token // the most basic sampling scheme - select the top token
static whisper_token_data whisper_sample_best( static whisper_token_data whisper_sample_best(
const whisper_vocab & vocab, whisper_vocab & vocab,
const float * probs, const float * probs,
bool force_timestamp, bool force_timestamp,
bool is_initial) { bool is_initial) {
@ -1857,11 +1917,11 @@ static whisper_token_data whisper_sample_best(
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
}; };
int n_logits = vocab.id_to_token.size(); const int n_logits = vocab.n_vocab;
std::vector<std::pair<double, whisper_vocab::id>> probs_id; auto & probs_id = vocab.probs_id;
probs_id.reserve(n_logits);
probs_id.clear();
for (int i = 0; i < n_logits; i++) { for (int i = 0; i < n_logits; i++) {
probs_id.emplace_back(probs[i], i); probs_id.emplace_back(probs[i], i);
} }
@ -2001,6 +2061,9 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
std::vector<float> even; std::vector<float> even;
std::vector<float> odd; std::vector<float> odd;
even.reserve(N/2);
odd.reserve(N/2);
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
if (i % 2 == 0) { if (i % 2 == 0) {
even.push_back(in[i]); even.push_back(in[i]);
@ -2458,12 +2521,12 @@ int whisper_lang_auto_detect(
} }
{ {
for (int i = 0; i < (int) probs_id.size(); i++) { for (const auto & prob : probs_id) {
if (lang_probs) { if (lang_probs) {
lang_probs[probs_id[i].second] = probs_id[i].first; lang_probs[prob.second] = prob.first;
} }
//printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first); //printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first);
} }
} }
@ -2482,6 +2545,10 @@ int whisper_n_text_ctx(struct whisper_context * ctx) {
return ctx->model.hparams.n_text_ctx; return ctx->model.hparams.n_text_ctx;
} }
int whisper_n_audio_ctx(struct whisper_context * ctx) {
return ctx->model.hparams.n_audio_ctx;
}
int whisper_is_multilingual(struct whisper_context * ctx) { int whisper_is_multilingual(struct whisper_context * ctx) {
return ctx->vocab.is_multilingual() ? 1 : 0; return ctx->vocab.is_multilingual() ? 1 : 0;
} }
@ -2562,6 +2629,8 @@ const char * whisper_print_system_info(void) {
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
return s.c_str(); return s.c_str();
} }
@ -2807,7 +2876,11 @@ int whisper_full(
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end()); std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
} }
// overwrite audio_ctx // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
return -4;
}
ctx->exp_n_audio_ctx = params.audio_ctx; ctx->exp_n_audio_ctx = params.audio_ctx;
// these tokens determine the task that will be performed // these tokens determine the task that will be performed
@ -3134,7 +3207,7 @@ int whisper_full_parallel(
// separate key + value memory for each processor // separate key + value memory for each processor
{ {
auto & ctx = model.ctx_mem; auto & mctx = model.ctx_mem;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
@ -3147,8 +3220,8 @@ int whisper_full_parallel(
const int n_mem = n_text_layer*n_text_ctx; const int n_mem = n_text_layer*n_text_ctx;
const int n_elements = n_text_state*n_mem; const int n_elements = n_text_state*n_mem;
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
} }
// key/value memory for the cross-attention layer // key/value memory for the cross-attention layer
@ -3158,8 +3231,8 @@ int whisper_full_parallel(
const int n_mem = n_text_layer*n_audio_ctx; const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem; const int n_elements = n_text_state*n_mem;
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
} }
} }
} }
@ -3203,17 +3276,17 @@ int whisper_full_parallel(
for (int i = 0; i < n_processors - 1; ++i) { for (int i = 0; i < n_processors - 1; ++i) {
auto & results_i = ctxs[i].result_all; auto & results_i = ctxs[i].result_all;
for (int j = 0; j < (int) results_i.size(); ++j) { for (auto & result : results_i) {
// correct the segment timestamp taking into account the offset // correct the segment timestamp taking into account the offset
results_i[j].t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t; result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t; result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
// make sure that segments are not overlapping // make sure that segments are not overlapping
if (!ctx->result_all.empty()) { if (!ctx->result_all.empty()) {
results_i[j].t0 = std::max(results_i[j].t0, ctx->result_all.back().t1); result.t0 = std::max(result.t0, ctx->result_all.back().t1);
} }
ctx->result_all.push_back(std::move(results_i[j])); ctx->result_all.push_back(std::move(result));
// call the new_segment_callback for each segment // call the new_segment_callback for each segment
if (params.new_segment_callback) { if (params.new_segment_callback) {
@ -3308,18 +3381,18 @@ static int64_t sample_to_timestamp(int i_sample) {
static float voice_length(const std::string & text) { static float voice_length(const std::string & text) {
float res = 0.0f; float res = 0.0f;
for (size_t i = 0; i < text.size(); ++i) { for (char c : text) {
if (text[i] == ' ') { if (c == ' ') {
res += 0.01f; res += 0.01f;
} else if (text[i] == ',') { } else if (c == ',') {
res += 2.00f; res += 2.00f;
} else if (text[i] == '.') { } else if (c == '.') {
res += 3.00f; res += 3.00f;
} else if (text[i] == '!') { } else if (c == '!') {
res += 3.00f; res += 3.00f;
} else if (text[i] == '?') { } else if (c == '?') {
res += 3.00f; res += 3.00f;
} else if (text[i] >= '0' && text[i] <= '9') { } else if (c >= '0' && c <= '9') {
res += 3.00f; res += 3.00f;
} else { } else {
res += 1.00f; res += 1.00f;

View File

@ -177,6 +177,7 @@ extern "C" {
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx); WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx); WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx); WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
// The probabilities for the next token // The probabilities for the next token