mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-07-04 08:20:57 +02:00
Compare commits
1 Commits
fa-decoder
...
threads
Author | SHA1 | Date | |
---|---|---|---|
4e6d2e98ab |
30
.github/workflows/build.yml
vendored
30
.github/workflows/build.yml
vendored
@ -235,33 +235,3 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
name: whisper-blas-bin-${{ matrix.arch }}
|
name: whisper-blas-bin-${{ matrix.arch }}
|
||||||
path: build/bin/${{ matrix.build }}
|
path: build/bin/${{ matrix.build }}
|
||||||
|
|
||||||
emscripten:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
build: [Release]
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Clone
|
|
||||||
uses: actions/checkout@v1
|
|
||||||
|
|
||||||
- name: Dependencies
|
|
||||||
run: |
|
|
||||||
wget -q https://github.com/emscripten-core/emsdk/archive/master.tar.gz
|
|
||||||
tar -xvf master.tar.gz
|
|
||||||
emsdk-master/emsdk update
|
|
||||||
emsdk-master/emsdk install latest
|
|
||||||
emsdk-master/emsdk activate latest
|
|
||||||
|
|
||||||
- name: Configure
|
|
||||||
run: echo "tmp"
|
|
||||||
|
|
||||||
- name: Build
|
|
||||||
run: |
|
|
||||||
pushd emsdk-master
|
|
||||||
source ./emsdk_env.sh
|
|
||||||
popd
|
|
||||||
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
|
||||||
make
|
|
||||||
|
@ -2,15 +2,14 @@ cmake_minimum_required (VERSION 3.0)
|
|||||||
|
|
||||||
project(whisper.cpp VERSION 1.0.4)
|
project(whisper.cpp VERSION 1.0.4)
|
||||||
|
|
||||||
# Add path to modules
|
set(CMAKE_EXPORT_COMPILE_COMMANDS "on")
|
||||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
|
||||||
|
|
||||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||||
|
set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib")
|
||||||
|
|
||||||
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
||||||
set(WHISPER_STANDALONE ON)
|
set(WHISPER_STANDALONE ON)
|
||||||
include(GitVars)
|
include(cmake/GitVars.cmake)
|
||||||
include(BuildTypes)
|
include(cmake/BuildTypes.cmake)
|
||||||
|
|
||||||
# configure project version
|
# configure project version
|
||||||
if (EXISTS "${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl")
|
if (EXISTS "${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl")
|
||||||
@ -53,7 +52,6 @@ if (APPLE)
|
|||||||
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
|
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
|
||||||
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
|
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
|
||||||
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
|
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
|
||||||
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
|
|
||||||
else()
|
else()
|
||||||
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
|
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
|
||||||
endif()
|
endif()
|
||||||
@ -84,6 +82,9 @@ endif()
|
|||||||
|
|
||||||
# dependencies
|
# dependencies
|
||||||
|
|
||||||
|
set(CMAKE_C_STANDARD 11)
|
||||||
|
set(CMAKE_CXX_STANDARD 11)
|
||||||
|
|
||||||
find_package(Threads REQUIRED)
|
find_package(Threads REQUIRED)
|
||||||
|
|
||||||
# on APPLE - include Accelerate framework
|
# on APPLE - include Accelerate framework
|
||||||
@ -130,7 +131,6 @@ if (WHISPER_ALL_WARNINGS)
|
|||||||
-Wcast-qual \
|
-Wcast-qual \
|
||||||
-Wstrict-prototypes \
|
-Wstrict-prototypes \
|
||||||
-Wpointer-arith \
|
-Wpointer-arith \
|
||||||
-Wno-unused-function \
|
|
||||||
")
|
")
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \
|
||||||
-Wall \
|
-Wall \
|
||||||
@ -157,7 +157,6 @@ else()
|
|||||||
if (MSVC)
|
if (MSVC)
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
|
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
|
||||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
|
|
||||||
else()
|
else()
|
||||||
if (EMSCRIPTEN)
|
if (EMSCRIPTEN)
|
||||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
|
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
|
||||||
@ -169,10 +168,7 @@ else()
|
|||||||
if(NOT WHISPER_NO_AVX2)
|
if(NOT WHISPER_NO_AVX2)
|
||||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
|
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
|
||||||
endif()
|
endif()
|
||||||
if(NOT WHISPER_NO_FMA)
|
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma -mf16c")
|
||||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
|
|
||||||
endif()
|
|
||||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
|
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
@ -194,8 +190,6 @@ add_library(${TARGET}
|
|||||||
whisper.cpp
|
whisper.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
include(DefaultTargetOptions)
|
|
||||||
|
|
||||||
target_include_directories(${TARGET} PUBLIC
|
target_include_directories(${TARGET} PUBLIC
|
||||||
.
|
.
|
||||||
)
|
)
|
||||||
@ -229,7 +223,6 @@ target_compile_definitions(${TARGET} PUBLIC
|
|||||||
install(TARGETS ${TARGET}
|
install(TARGETS ${TARGET}
|
||||||
LIBRARY DESTINATION lib
|
LIBRARY DESTINATION lib
|
||||||
ARCHIVE DESTINATION lib/static
|
ARCHIVE DESTINATION lib/static
|
||||||
RUNTIME DESTINATION bin
|
|
||||||
)
|
)
|
||||||
|
|
||||||
#
|
#
|
||||||
|
29
Makefile
29
Makefile
@ -10,9 +10,6 @@ ifndef UNAME_M
|
|||||||
UNAME_M := $(shell uname -m)
|
UNAME_M := $(shell uname -m)
|
||||||
endif
|
endif
|
||||||
|
|
||||||
CCV := $(shell $(CC) --version | head -n 1)
|
|
||||||
CXXV := $(shell $(CXX) --version | head -n 1)
|
|
||||||
|
|
||||||
# Mac OS + Arm can report x86_64
|
# Mac OS + Arm can report x86_64
|
||||||
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
|
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
|
||||||
ifeq ($(UNAME_S),Darwin)
|
ifeq ($(UNAME_S),Darwin)
|
||||||
@ -56,13 +53,10 @@ endif
|
|||||||
# Architecture specific
|
# Architecture specific
|
||||||
# TODO: probably these flags need to be tweaked on some architectures
|
# TODO: probably these flags need to be tweaked on some architectures
|
||||||
# feel free to update the Makefile for your architecture and send a pull request or issue
|
# feel free to update the Makefile for your architecture and send a pull request or issue
|
||||||
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
|
ifeq ($(UNAME_M),x86_64)
|
||||||
ifeq ($(UNAME_S),Darwin)
|
ifeq ($(UNAME_S),Darwin)
|
||||||
CFLAGS += -mf16c
|
CFLAGS += -mfma -mf16c
|
||||||
AVX1_M := $(shell sysctl machdep.cpu.features)
|
AVX1_M := $(shell sysctl machdep.cpu.features)
|
||||||
ifneq (,$(findstring FMA,$(AVX1_M)))
|
|
||||||
CFLAGS += -mfma
|
|
||||||
endif
|
|
||||||
ifneq (,$(findstring AVX1.0,$(AVX1_M)))
|
ifneq (,$(findstring AVX1.0,$(AVX1_M)))
|
||||||
CFLAGS += -mavx
|
CFLAGS += -mavx
|
||||||
endif
|
endif
|
||||||
@ -87,10 +81,6 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
|
|||||||
ifneq (,$(findstring f16c,$(F16C_M)))
|
ifneq (,$(findstring f16c,$(F16C_M)))
|
||||||
CFLAGS += -mf16c
|
CFLAGS += -mf16c
|
||||||
endif
|
endif
|
||||||
SSE3_M := $(shell grep "sse3 " /proc/cpuinfo)
|
|
||||||
ifneq (,$(findstring sse3,$(SSE3_M)))
|
|
||||||
CFLAGS += -msse3
|
|
||||||
endif
|
|
||||||
else ifeq ($(UNAME_S),Haiku)
|
else ifeq ($(UNAME_S),Haiku)
|
||||||
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
|
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
|
||||||
ifneq (,$(findstring avx,$(AVX1_M)))
|
ifneq (,$(findstring avx,$(AVX1_M)))
|
||||||
@ -151,21 +141,6 @@ ifneq ($(filter armv8%,$(UNAME_M)),)
|
|||||||
CFLAGS += -mfp16-format=ieee -mno-unaligned-access
|
CFLAGS += -mfp16-format=ieee -mno-unaligned-access
|
||||||
endif
|
endif
|
||||||
|
|
||||||
#
|
|
||||||
# Print build information
|
|
||||||
#
|
|
||||||
|
|
||||||
$(info I whisper.cpp build info: )
|
|
||||||
$(info I UNAME_S: $(UNAME_S))
|
|
||||||
$(info I UNAME_P: $(UNAME_P))
|
|
||||||
$(info I UNAME_M: $(UNAME_M))
|
|
||||||
$(info I CFLAGS: $(CFLAGS))
|
|
||||||
$(info I CXXFLAGS: $(CXXFLAGS))
|
|
||||||
$(info I LDFLAGS: $(LDFLAGS))
|
|
||||||
$(info I CC: $(CCV))
|
|
||||||
$(info I CXX: $(CXXV))
|
|
||||||
$(info )
|
|
||||||
|
|
||||||
default: main
|
default: main
|
||||||
|
|
||||||
#
|
#
|
||||||
|
@ -11,7 +11,6 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
|
|||||||
- Plain C/C++ implementation without dependencies
|
- Plain C/C++ implementation without dependencies
|
||||||
- Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework
|
- Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework
|
||||||
- AVX intrinsics support for x86 architectures
|
- AVX intrinsics support for x86 architectures
|
||||||
- VSX intrinsics support for POWER architectures
|
|
||||||
- Mixed F16 / F32 precision
|
- Mixed F16 / F32 precision
|
||||||
- Low memory usage (Flash Attention + Flash Forward)
|
- Low memory usage (Flash Attention + Flash Forward)
|
||||||
- Zero memory allocations at runtime
|
- Zero memory allocations at runtime
|
||||||
|
1
bindings/go/.gitignore
vendored
1
bindings/go/.gitignore
vendored
@ -1,2 +1,3 @@
|
|||||||
build
|
build
|
||||||
models
|
models
|
||||||
|
go.sum
|
||||||
|
@ -1,27 +1,28 @@
|
|||||||
BUILD_DIR := build
|
CMAKE := $(shell which cmake)
|
||||||
MODELS_DIR := models
|
BUILD_DIR := "build"
|
||||||
|
MODELS_DIR := "models"
|
||||||
EXAMPLES_DIR := $(wildcard examples/*)
|
EXAMPLES_DIR := $(wildcard examples/*)
|
||||||
INCLUDE_PATH := $(abspath ../..)
|
C_INCLUDE_PATH := "../.."
|
||||||
LIBRARY_PATH := $(abspath ../..)
|
|
||||||
|
|
||||||
all: clean whisper examples
|
all: clean whisper examples
|
||||||
|
|
||||||
whisper: mkdir
|
whisper: mkdir
|
||||||
@echo Build whisper
|
@echo Build whisper
|
||||||
@${MAKE} -C ../.. libwhisper.a
|
@${CMAKE} -S ../.. -B ${BUILD_DIR} -D BUILD_SHARED_LIBS=off -D WHISPER_NO_AVX2=on
|
||||||
|
@${CMAKE} --build ${BUILD_DIR} --target whisper
|
||||||
|
|
||||||
test: model-small whisper modtidy
|
test: model-small whisper modtidy
|
||||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
|
@go test -v .
|
||||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
|
@go test -v ./pkg/whisper/...
|
||||||
|
|
||||||
examples: $(EXAMPLES_DIR)
|
examples: $(EXAMPLES_DIR)
|
||||||
|
|
||||||
model-small: mkdir examples/go-model-download
|
model-small: mkdir examples/go-model-download
|
||||||
@${BUILD_DIR}/go-model-download -out models ggml-small.en.bin
|
@${BUILD_DIR}/go-model-download -out models small.en
|
||||||
|
|
||||||
$(EXAMPLES_DIR): mkdir whisper modtidy
|
$(EXAMPLES_DIR): mkdir whisper modtidy
|
||||||
@echo Build example $(notdir $@)
|
@echo Build example $(notdir $@)
|
||||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
|
@go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
|
||||||
|
|
||||||
mkdir:
|
mkdir:
|
||||||
@echo Mkdir ${BUILD_DIR}
|
@echo Mkdir ${BUILD_DIR}
|
||||||
|
@ -74,27 +74,4 @@ And you can then test a model against samples with the following command:
|
|||||||
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
|
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
|
||||||
```
|
```
|
||||||
|
|
||||||
## Using the bindings
|
|
||||||
|
|
||||||
To use the bindings in your own software,
|
|
||||||
|
|
||||||
1. Import `github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper` (or `github.com/ggerganov/whisper.cpp/bindings/go` into your package;
|
|
||||||
2. Compile `libwhisper.a` (you can use `make whisper` in the `bindings/go` directory);
|
|
||||||
3. Link your go binary against whisper by setting the environment variables `C_INCLUDE_PATH` and `LIBRARY_PATH`
|
|
||||||
to point to the `whisper.h` file directory and `libwhisper.a` file directory respectively.
|
|
||||||
|
|
||||||
Look at the `Makefile` in the `bindings/go` directory for an example.
|
|
||||||
|
|
||||||
The API Documentation:
|
|
||||||
|
|
||||||
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go
|
|
||||||
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper
|
|
||||||
|
|
||||||
Getting help:
|
|
||||||
|
|
||||||
* Follow the discussion for the go bindings [here](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.
|
|
||||||
|
|
||||||
|
@ -17,14 +17,15 @@ import (
|
|||||||
// CONSTANTS
|
// CONSTANTS
|
||||||
|
|
||||||
const (
|
const (
|
||||||
srcUrl = "https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main" // The location of the models
|
srcUrl = "https://huggingface.co/" // The location of the models
|
||||||
srcExt = ".bin" // Filename extension
|
srcPathPrefix = "/datasets/ggerganov/whisper.cpp/resolve/main/ggml" // Filename prefix
|
||||||
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
|
srcExt = ".bin" // Filename extension
|
||||||
|
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// The models which will be downloaded, if no model is specified as an argument
|
// The models which will be downloaded, if no model is specified as an argument
|
||||||
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large"}
|
modelNames = []string{"tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium", "large-v1", "large"}
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -122,14 +123,11 @@ func GetModels() []string {
|
|||||||
|
|
||||||
// URLForModel returns the URL for the given model on huggingface.co
|
// URLForModel returns the URL for the given model on huggingface.co
|
||||||
func URLForModel(model string) (string, error) {
|
func URLForModel(model string) (string, error) {
|
||||||
if filepath.Ext(model) != srcExt {
|
|
||||||
model += srcExt
|
|
||||||
}
|
|
||||||
url, err := url.Parse(srcUrl)
|
url, err := url.Parse(srcUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
} else {
|
} else {
|
||||||
url.Path = filepath.Join(url.Path, model)
|
url.Path = srcPathPrefix + "-" + model + srcExt
|
||||||
}
|
}
|
||||||
return url.String(), nil
|
return url.String(), nil
|
||||||
}
|
}
|
||||||
|
@ -1,23 +0,0 @@
|
|||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
|
||||||
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
|
|
||||||
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
|
|
||||||
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
|
|
||||||
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
|
|
||||||
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
|
|
||||||
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
|
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
|
||||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
|
||||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
|
||||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
|
||||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
|
@ -1,5 +1,8 @@
|
|||||||
package whisper
|
package whisper
|
||||||
|
|
||||||
|
// This file defines the whisper_token, whisper_token_data and whisper_full_params
|
||||||
|
// structures, which are used by the whisper_full() function.
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
@ -9,7 +9,8 @@ import (
|
|||||||
// CGO
|
// CGO
|
||||||
|
|
||||||
/*
|
/*
|
||||||
#cgo LDFLAGS: -lwhisper -lm -lstdc++
|
#cgo CFLAGS: -I${SRCDIR}/../..
|
||||||
|
#cgo LDFLAGS: -L${SRCDIR}/build -lwhisper -lm -lstdc++
|
||||||
#cgo darwin LDFLAGS: -framework Accelerate
|
#cgo darwin LDFLAGS: -framework Accelerate
|
||||||
#include <whisper.h>
|
#include <whisper.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
@ -170,10 +171,6 @@ func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return the id of the specified language, returns -1 if not found
|
// Return the id of the specified language, returns -1 if not found
|
||||||
// Examples:
|
|
||||||
//
|
|
||||||
// "de" -> 2
|
|
||||||
// "german" -> 2
|
|
||||||
func (ctx *Context) Whisper_lang_id(lang string) int {
|
func (ctx *Context) Whisper_lang_id(lang string) int {
|
||||||
return int(C.whisper_lang_id(C.CString(lang)))
|
return int(C.whisper_lang_id(C.CString(lang)))
|
||||||
}
|
}
|
||||||
@ -214,10 +211,6 @@ func (ctx *Context) Whisper_n_text_ctx() int {
|
|||||||
return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
|
return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ctx *Context) Whisper_n_audio_ctx() int {
|
|
||||||
return int(C.whisper_n_audio_ctx((*C.struct_whisper_context)(ctx)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ctx *Context) Whisper_is_multilingual() int {
|
func (ctx *Context) Whisper_is_multilingual() int {
|
||||||
return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
|
return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
|
||||||
}
|
}
|
||||||
|
@ -50,10 +50,7 @@ func Test_Whisper_001(t *testing.T) {
|
|||||||
ctx := whisper.Whisper_init(ModelPath)
|
ctx := whisper.Whisper_init(ModelPath)
|
||||||
assert.NotNil(ctx)
|
assert.NotNil(ctx)
|
||||||
defer ctx.Whisper_free()
|
defer ctx.Whisper_free()
|
||||||
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
|
assert.NoError(ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf.AsFloat32Buffer().Data, nil, nil))
|
||||||
data := buf.AsFloat32Buffer().Data
|
|
||||||
err = ctx.Whisper_full(params, data, nil, nil)
|
|
||||||
assert.NoError(err)
|
|
||||||
|
|
||||||
// Print out tokens
|
// Print out tokens
|
||||||
num_segments := ctx.Whisper_full_n_segments()
|
num_segments := ctx.Whisper_full_n_segments()
|
||||||
|
Submodule bindings/ios updated: 6707f1ea1c...1502317fe0
File diff suppressed because one or more lines are too long
@ -1,17 +0,0 @@
|
|||||||
# Set the default compile features and properties for a target.
|
|
||||||
|
|
||||||
if (NOT TARGET)
|
|
||||||
message(FATAL_ERROR "TARGET not set before including DefaultTargetOptions")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
target_compile_features(${TARGET}
|
|
||||||
PRIVATE
|
|
||||||
cxx_std_11
|
|
||||||
)
|
|
||||||
|
|
||||||
set_target_properties(${TARGET}
|
|
||||||
PROPERTIES
|
|
||||||
EXPORT_COMPILE_COMMANDS ON
|
|
||||||
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
|
|
||||||
INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib"
|
|
||||||
)
|
|
@ -8,8 +8,6 @@ add_executable(${TARGET}
|
|||||||
emscripten.cpp
|
emscripten.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
include(DefaultTargetOptions)
|
|
||||||
|
|
||||||
target_link_libraries(${TARGET} PRIVATE
|
target_link_libraries(${TARGET} PRIVATE
|
||||||
whisper
|
whisper
|
||||||
)
|
)
|
||||||
|
@ -1,6 +1,3 @@
|
|||||||
set(TARGET bench)
|
set(TARGET bench)
|
||||||
add_executable(${TARGET} bench.cpp)
|
add_executable(${TARGET} bench.cpp)
|
||||||
|
|
||||||
include(DefaultTargetOptions)
|
|
||||||
|
|
||||||
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
|
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||||
|
@ -8,8 +8,6 @@ add_executable(${TARGET}
|
|||||||
emscripten.cpp
|
emscripten.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
include(DefaultTargetOptions)
|
|
||||||
|
|
||||||
target_link_libraries(${TARGET} PRIVATE
|
target_link_libraries(${TARGET} PRIVATE
|
||||||
whisper
|
whisper
|
||||||
)
|
)
|
||||||
|
@ -2,9 +2,6 @@ if (WHISPER_SUPPORT_SDL2)
|
|||||||
# command
|
# command
|
||||||
set(TARGET command)
|
set(TARGET command)
|
||||||
add_executable(${TARGET} command.cpp)
|
add_executable(${TARGET} command.cpp)
|
||||||
|
|
||||||
include(DefaultTargetOptions)
|
|
||||||
|
|
||||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||||
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||||
endif ()
|
endif ()
|
||||||
|
@ -9,19 +9,7 @@ More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/
|
|||||||
|
|
||||||
# On Raspberry Pi, use tiny or base models + "-ac 768" for better performance
|
# On Raspberry Pi, use tiny or base models + "-ac 768" for better performance
|
||||||
./command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0
|
./command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0
|
||||||
```
|
|
||||||
|
|
||||||
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
|
|
||||||
|
|
||||||
Web version: [examples/command.wasm](/examples/command.wasm)
|
|
||||||
|
|
||||||
## Guided mode
|
|
||||||
|
|
||||||
"Guided mode" allows you to specify a list of commands (i.e. strings) and the transcription will be guided to classify your command into one from the list. This can be useful in situations where a device is listening only for a small subset of commands.
|
|
||||||
|
|
||||||
Initial tests show that this approach might be extremely efficient in terms of performance, since it integrates very well with the "partial Encoder" idea from #137.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Run in guided mode, the list of allowed commands is in commands.txt
|
# Run in guided mode, the list of allowed commands is in commands.txt
|
||||||
./command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt
|
./command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt
|
||||||
|
|
||||||
@ -29,8 +17,9 @@ Initial tests show that this approach might be extremely efficient in terms of p
|
|||||||
./command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0
|
./command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0
|
||||||
```
|
```
|
||||||
|
|
||||||
https://user-images.githubusercontent.com/1991296/207435352-8fc4ed3f-bde5-4555-9b8b-aeeb76bee969.mp4
|
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
|
||||||
|
|
||||||
|
Web version: [examples/command.wasm](/examples/command.wasm)
|
||||||
|
|
||||||
## Building
|
## Building
|
||||||
|
|
||||||
|
@ -510,333 +510,6 @@ std::vector<std::string> read_allowed_commands(const std::string & fname) {
|
|||||||
return allowed_commands;
|
return allowed_commands;
|
||||||
}
|
}
|
||||||
|
|
||||||
// command-list mode
|
|
||||||
// guide the transcription to match the most likely command from a provided list
|
|
||||||
int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
fprintf(stderr, "%s: guided mode\n", __func__);
|
|
||||||
|
|
||||||
std::vector<std::string> allowed_commands = read_allowed_commands(params.commands);
|
|
||||||
|
|
||||||
if (allowed_commands.empty()) {
|
|
||||||
fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
|
|
||||||
return 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
int max_len = 0;
|
|
||||||
|
|
||||||
std::vector<std::vector<whisper_token>> allowed_tokens;
|
|
||||||
|
|
||||||
for (const auto & cmd : allowed_commands) {
|
|
||||||
whisper_token tokens[1024];
|
|
||||||
allowed_tokens.emplace_back();
|
|
||||||
|
|
||||||
for (int l = 0; l < (int) cmd.size(); ++l) {
|
|
||||||
// NOTE: very important to add the whitespace !
|
|
||||||
// the reason is that the first decoded token starts with a whitespace too!
|
|
||||||
std::string ss = std::string(" ") + cmd.substr(0, l + 1);
|
|
||||||
|
|
||||||
const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
|
|
||||||
if (n < 0) {
|
|
||||||
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
|
|
||||||
return 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (n == 1) {
|
|
||||||
allowed_tokens.back().push_back(tokens[0]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
max_len = std::max(max_len, (int) cmd.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
|
||||||
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
|
|
||||||
for (const auto & token : allowed_tokens[i]) {
|
|
||||||
fprintf(stderr, " %5d", token);
|
|
||||||
}
|
|
||||||
fprintf(stderr, " ]\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string k_prompt = "select one from the available words: ";
|
|
||||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
|
||||||
if (i > 0) {
|
|
||||||
k_prompt += ", ";
|
|
||||||
}
|
|
||||||
k_prompt += allowed_commands[i];
|
|
||||||
}
|
|
||||||
k_prompt += ". selected word: ";
|
|
||||||
|
|
||||||
// tokenize prompt
|
|
||||||
std::vector<whisper_token> k_tokens;
|
|
||||||
{
|
|
||||||
k_tokens.resize(1024);
|
|
||||||
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
|
|
||||||
if (n < 0) {
|
|
||||||
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
|
|
||||||
return 4;
|
|
||||||
}
|
|
||||||
k_tokens.resize(n);
|
|
||||||
}
|
|
||||||
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
|
|
||||||
fprintf(stderr, "%s: tokens: [", __func__);
|
|
||||||
for (const auto & token : k_tokens) {
|
|
||||||
fprintf(stderr, " %d", token);
|
|
||||||
}
|
|
||||||
fprintf(stderr, " ]\n");
|
|
||||||
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
fprintf(stderr, "%s: listening for a command ...\n", __func__);
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
|
|
||||||
bool is_running = true;
|
|
||||||
|
|
||||||
std::vector<float> pcmf32_cur;
|
|
||||||
std::vector<float> pcmf32_prompt;
|
|
||||||
|
|
||||||
// main loop
|
|
||||||
while (is_running) {
|
|
||||||
// handle Ctrl + C
|
|
||||||
{
|
|
||||||
SDL_Event event;
|
|
||||||
while (SDL_PollEvent(&event)) {
|
|
||||||
switch (event.type) {
|
|
||||||
case SDL_QUIT:
|
|
||||||
{
|
|
||||||
is_running = false;
|
|
||||||
} break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!is_running) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// delay
|
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
|
||||||
|
|
||||||
audio.get(2000, pcmf32_cur);
|
|
||||||
|
|
||||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
|
||||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
|
||||||
|
|
||||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
|
||||||
|
|
||||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
|
||||||
|
|
||||||
wparams.print_progress = false;
|
|
||||||
wparams.print_special = params.print_special;
|
|
||||||
wparams.print_realtime = false;
|
|
||||||
wparams.print_timestamps = !params.no_timestamps;
|
|
||||||
wparams.translate = params.translate;
|
|
||||||
wparams.no_context = true;
|
|
||||||
wparams.single_segment = true;
|
|
||||||
wparams.max_tokens = 1;
|
|
||||||
wparams.language = params.language.c_str();
|
|
||||||
wparams.n_threads = params.n_threads;
|
|
||||||
|
|
||||||
wparams.audio_ctx = params.audio_ctx;
|
|
||||||
wparams.speed_up = params.speed_up;
|
|
||||||
|
|
||||||
wparams.prompt_tokens = k_tokens.data();
|
|
||||||
wparams.prompt_n_tokens = k_tokens.size();
|
|
||||||
|
|
||||||
// run the transformer and a single decoding pass
|
|
||||||
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
|
|
||||||
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto * probs = whisper_get_probs(ctx);
|
|
||||||
std::vector<std::pair<float, int>> probs_id;
|
|
||||||
|
|
||||||
double psum = 0.0;
|
|
||||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
|
||||||
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
|
|
||||||
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
|
|
||||||
probs_id.back().first += probs[allowed_tokens[i][j]];
|
|
||||||
}
|
|
||||||
probs_id.back().first /= allowed_tokens[i].size();
|
|
||||||
psum += probs_id.back().first;
|
|
||||||
}
|
|
||||||
|
|
||||||
// normalize
|
|
||||||
for (auto & p : probs_id) {
|
|
||||||
p.first /= psum;
|
|
||||||
}
|
|
||||||
|
|
||||||
// sort descending
|
|
||||||
{
|
|
||||||
using pair_type = decltype(probs_id)::value_type;
|
|
||||||
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
|
||||||
return a.first > b.first;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// print the commands and the respective probabilities
|
|
||||||
{
|
|
||||||
fprintf(stdout, "\n");
|
|
||||||
for (const auto & cmd : probs_id) {
|
|
||||||
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
|
|
||||||
for (int token : allowed_tokens[cmd.second]) {
|
|
||||||
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
|
|
||||||
}
|
|
||||||
fprintf(stdout, "\n");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// best command
|
|
||||||
{
|
|
||||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
|
||||||
|
|
||||||
const float prob = probs_id[0].first;
|
|
||||||
const int index = probs_id[0].second;
|
|
||||||
|
|
||||||
fprintf(stdout, "\n");
|
|
||||||
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
|
|
||||||
"\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
|
|
||||||
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
|
|
||||||
fprintf(stdout, "\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
audio.clear();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// general-purpose mode
|
|
||||||
// freely transcribe the voice into text
|
|
||||||
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
|
||||||
bool is_running = true;
|
|
||||||
bool have_prompt = false;
|
|
||||||
bool ask_prompt = true;
|
|
||||||
|
|
||||||
float prob0 = 0.0f;
|
|
||||||
float prob = 0.0f;
|
|
||||||
|
|
||||||
std::vector<float> pcmf32_cur;
|
|
||||||
std::vector<float> pcmf32_prompt;
|
|
||||||
|
|
||||||
const std::string k_prompt = "Ok Whisper, start listening for commands.";
|
|
||||||
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
|
||||||
|
|
||||||
// main loop
|
|
||||||
while (is_running) {
|
|
||||||
// handle Ctrl + C
|
|
||||||
{
|
|
||||||
SDL_Event event;
|
|
||||||
while (SDL_PollEvent(&event)) {
|
|
||||||
switch (event.type) {
|
|
||||||
case SDL_QUIT:
|
|
||||||
{
|
|
||||||
is_running = false;
|
|
||||||
} break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!is_running) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// delay
|
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
|
||||||
|
|
||||||
if (ask_prompt) {
|
|
||||||
fprintf(stdout, "\n");
|
|
||||||
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
|
||||||
fprintf(stdout, "\n");
|
|
||||||
|
|
||||||
ask_prompt = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
audio.get(2000, pcmf32_cur);
|
|
||||||
|
|
||||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
|
||||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
|
||||||
|
|
||||||
int64_t t_ms = 0;
|
|
||||||
|
|
||||||
if (!have_prompt) {
|
|
||||||
// wait for activation phrase
|
|
||||||
audio.get(params.prompt_ms, pcmf32_cur);
|
|
||||||
|
|
||||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
|
||||||
|
|
||||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
|
||||||
|
|
||||||
const float sim = similarity(txt, k_prompt);
|
|
||||||
|
|
||||||
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
|
|
||||||
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
|
|
||||||
ask_prompt = true;
|
|
||||||
} else {
|
|
||||||
fprintf(stdout, "\n");
|
|
||||||
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
|
||||||
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
|
||||||
fprintf(stdout, "\n");
|
|
||||||
|
|
||||||
// save the audio for the prompt
|
|
||||||
pcmf32_prompt = pcmf32_cur;
|
|
||||||
have_prompt = true;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// we have heard the activation phrase, now detect the commands
|
|
||||||
audio.get(params.command_ms, pcmf32_cur);
|
|
||||||
|
|
||||||
// prepend the prompt audio
|
|
||||||
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
|
||||||
|
|
||||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
|
||||||
|
|
||||||
prob = 100.0f*(prob - prob0);
|
|
||||||
|
|
||||||
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
|
||||||
|
|
||||||
// find the prompt in the text
|
|
||||||
float best_sim = 0.0f;
|
|
||||||
size_t best_len = 0;
|
|
||||||
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
|
||||||
const auto prompt = txt.substr(0, n);
|
|
||||||
|
|
||||||
const float sim = similarity(prompt, k_prompt);
|
|
||||||
|
|
||||||
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
|
|
||||||
|
|
||||||
if (sim > best_sim) {
|
|
||||||
best_sim = sim;
|
|
||||||
best_len = n;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::string command = ::trim(txt.substr(best_len));
|
|
||||||
|
|
||||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
|
||||||
fprintf(stdout, "\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
audio.clear();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
whisper_params params;
|
whisper_params params;
|
||||||
|
|
||||||
@ -888,12 +561,300 @@ int main(int argc, char ** argv) {
|
|||||||
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
|
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
|
||||||
audio.clear();
|
audio.clear();
|
||||||
|
|
||||||
int ret_val = 0;
|
int max_len = 0;
|
||||||
|
|
||||||
|
bool is_running = true;
|
||||||
|
bool have_prompt = false;
|
||||||
|
bool ask_prompt = true;
|
||||||
|
|
||||||
|
float prob0 = 0.0f;
|
||||||
|
float prob = 0.0f;
|
||||||
|
|
||||||
|
std::vector<float> pcmf32_cur;
|
||||||
|
std::vector<float> pcmf32_prompt;
|
||||||
|
|
||||||
|
std::vector<std::string> allowed_commands;
|
||||||
|
std::vector<std::vector<whisper_token>> allowed_tokens;
|
||||||
|
|
||||||
|
std::string k_prompt;
|
||||||
|
std::vector<whisper_token> k_tokens;
|
||||||
|
|
||||||
if (!params.commands.empty()) {
|
if (!params.commands.empty()) {
|
||||||
ret_val = process_command_list(ctx, audio, params);
|
fprintf(stderr, "\n");
|
||||||
|
fprintf(stderr, "%s: guided mode\n", __func__);
|
||||||
|
|
||||||
|
allowed_commands = read_allowed_commands(params.commands);
|
||||||
|
|
||||||
|
if (allowed_commands.empty()) {
|
||||||
|
fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto & cmd : allowed_commands) {
|
||||||
|
whisper_token tokens[1024];
|
||||||
|
allowed_tokens.emplace_back();
|
||||||
|
|
||||||
|
for (int l = 0; l < (int) cmd.size(); ++l) {
|
||||||
|
// NOTE: very important to add the whitespace !
|
||||||
|
// the reason is that the first decoded token starts with a whitespace too!
|
||||||
|
std::string ss = std::string(" ") + cmd.substr(0, l + 1);
|
||||||
|
|
||||||
|
const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
|
||||||
|
if (n < 0) {
|
||||||
|
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n == 1) {
|
||||||
|
allowed_tokens.back().push_back(tokens[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
max_len = std::max(max_len, (int) cmd.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
|
||||||
|
fprintf(stderr, "\n");
|
||||||
|
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||||
|
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
|
||||||
|
for (const auto & token : allowed_tokens[i]) {
|
||||||
|
fprintf(stderr, " %5d", token);
|
||||||
|
}
|
||||||
|
fprintf(stderr, " ]\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
k_prompt = "select one from the available words: ";
|
||||||
|
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||||
|
if (i > 0) {
|
||||||
|
k_prompt += ", ";
|
||||||
|
}
|
||||||
|
k_prompt += allowed_commands[i];
|
||||||
|
}
|
||||||
|
k_prompt += ". selected word: ";
|
||||||
|
|
||||||
|
// tokenize prompt
|
||||||
|
{
|
||||||
|
k_tokens.resize(1024);
|
||||||
|
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
|
||||||
|
if (n < 0) {
|
||||||
|
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
|
||||||
|
return 4;
|
||||||
|
}
|
||||||
|
k_tokens.resize(n);
|
||||||
|
}
|
||||||
|
|
||||||
|
fprintf(stderr, "\n");
|
||||||
|
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
|
||||||
|
fprintf(stderr, "%s: tokens: [", __func__);
|
||||||
|
for (const auto & token : k_tokens) {
|
||||||
|
fprintf(stderr, " %d", token);
|
||||||
|
}
|
||||||
|
fprintf(stderr, " ]\n");
|
||||||
|
|
||||||
|
fprintf(stderr, "\n");
|
||||||
|
fprintf(stderr, "%s: listening for a command ...\n", __func__);
|
||||||
|
fprintf(stderr, "\n");
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
ret_val = process_general_transcription(ctx, audio, params);
|
fprintf(stderr, "\n");
|
||||||
|
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
||||||
|
|
||||||
|
k_prompt = "Ok Whisper, start listening for commands.";
|
||||||
|
}
|
||||||
|
|
||||||
|
// main loop
|
||||||
|
while (is_running) {
|
||||||
|
// handle Ctrl + C
|
||||||
|
{
|
||||||
|
SDL_Event event;
|
||||||
|
while (SDL_PollEvent(&event)) {
|
||||||
|
switch (event.type) {
|
||||||
|
case SDL_QUIT:
|
||||||
|
{
|
||||||
|
is_running = false;
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!is_running) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// delay
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||||
|
|
||||||
|
if (allowed_commands.empty()) {
|
||||||
|
// general-purpose mode
|
||||||
|
// freely transcribe the voice into text
|
||||||
|
|
||||||
|
if (ask_prompt) {
|
||||||
|
fprintf(stdout, "\n");
|
||||||
|
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
||||||
|
fprintf(stdout, "\n");
|
||||||
|
|
||||||
|
ask_prompt = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
int64_t t_ms = 0;
|
||||||
|
|
||||||
|
audio.get(2000, pcmf32_cur);
|
||||||
|
|
||||||
|
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||||
|
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||||
|
|
||||||
|
if (!have_prompt) {
|
||||||
|
// wait for activation phrase
|
||||||
|
audio.get(params.prompt_ms, pcmf32_cur);
|
||||||
|
|
||||||
|
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
||||||
|
|
||||||
|
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||||
|
|
||||||
|
const float sim = similarity(txt, k_prompt);
|
||||||
|
|
||||||
|
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
|
||||||
|
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
|
||||||
|
ask_prompt = true;
|
||||||
|
} else {
|
||||||
|
fprintf(stdout, "\n");
|
||||||
|
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
||||||
|
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
||||||
|
fprintf(stdout, "\n");
|
||||||
|
|
||||||
|
// save the audio for the prompt
|
||||||
|
pcmf32_prompt = pcmf32_cur;
|
||||||
|
have_prompt = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// we have heard the activation phrase, now detect the commands
|
||||||
|
audio.get(params.command_ms, pcmf32_cur);
|
||||||
|
|
||||||
|
// prepend the prompt audio
|
||||||
|
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||||
|
|
||||||
|
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||||
|
|
||||||
|
prob = 100.0f*(prob - prob0);
|
||||||
|
|
||||||
|
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
||||||
|
|
||||||
|
// find the prompt in the text
|
||||||
|
float best_sim = 0.0f;
|
||||||
|
size_t best_len = 0;
|
||||||
|
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||||
|
const auto prompt = txt.substr(0, n);
|
||||||
|
|
||||||
|
const float sim = similarity(prompt, k_prompt);
|
||||||
|
|
||||||
|
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
|
||||||
|
|
||||||
|
if (sim > best_sim) {
|
||||||
|
best_sim = sim;
|
||||||
|
best_len = n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string command = ::trim(txt.substr(best_len));
|
||||||
|
|
||||||
|
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||||
|
fprintf(stdout, "\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
audio.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// command-list mode
|
||||||
|
// guide the transcription to match the most likely command from a provided list
|
||||||
|
|
||||||
|
audio.get(2000, pcmf32_cur);
|
||||||
|
|
||||||
|
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||||
|
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||||
|
|
||||||
|
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
|
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||||
|
|
||||||
|
wparams.print_progress = false;
|
||||||
|
wparams.print_special = params.print_special;
|
||||||
|
wparams.print_realtime = false;
|
||||||
|
wparams.print_timestamps = !params.no_timestamps;
|
||||||
|
wparams.translate = params.translate;
|
||||||
|
wparams.no_context = true;
|
||||||
|
wparams.single_segment = true;
|
||||||
|
wparams.max_tokens = 1;
|
||||||
|
wparams.language = params.language.c_str();
|
||||||
|
wparams.n_threads = params.n_threads;
|
||||||
|
|
||||||
|
wparams.audio_ctx = params.audio_ctx;
|
||||||
|
wparams.speed_up = params.speed_up;
|
||||||
|
|
||||||
|
wparams.prompt_tokens = k_tokens.data();
|
||||||
|
wparams.prompt_n_tokens = k_tokens.size();
|
||||||
|
|
||||||
|
// run the transformer and a single decoding pass
|
||||||
|
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
|
||||||
|
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto * probs = whisper_get_probs(ctx);
|
||||||
|
std::vector<std::pair<float, int>> probs_id;
|
||||||
|
|
||||||
|
double psum = 0.0;
|
||||||
|
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||||
|
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
|
||||||
|
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
|
||||||
|
probs_id.back().first += probs[allowed_tokens[i][j]];
|
||||||
|
}
|
||||||
|
probs_id.back().first /= allowed_tokens[i].size();
|
||||||
|
psum += probs_id.back().first;
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalize
|
||||||
|
for (auto & p : probs_id) {
|
||||||
|
p.first /= psum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// sort descending
|
||||||
|
{
|
||||||
|
using pair_type = decltype(probs_id)::value_type;
|
||||||
|
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
||||||
|
return a.first > b.first;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// print the commands and the respective probabilities
|
||||||
|
{
|
||||||
|
fprintf(stdout, "\n");
|
||||||
|
for (const auto & cmd : probs_id) {
|
||||||
|
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
|
||||||
|
for (int i = 0; i < (int) allowed_tokens[cmd.second].size(); ++i) {
|
||||||
|
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, allowed_tokens[cmd.second][i]), probs[allowed_tokens[cmd.second][i]]);
|
||||||
|
}
|
||||||
|
fprintf(stdout, "\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// best command
|
||||||
|
{
|
||||||
|
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
|
fprintf(stdout, "\n");
|
||||||
|
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
|
||||||
|
"\033[1m", allowed_commands[probs_id[0].second].c_str(), "\033[0m", probs_id[0].first,
|
||||||
|
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
|
||||||
|
fprintf(stdout, "\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
audio.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
audio.pause();
|
audio.pause();
|
||||||
@ -901,5 +862,5 @@ int main(int argc, char ** argv) {
|
|||||||
whisper_print_timings(ctx);
|
whisper_print_timings(ctx);
|
||||||
whisper_free(ctx);
|
whisper_free(ctx);
|
||||||
|
|
||||||
return ret_val;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,3 @@
|
|||||||
set(TARGET main)
|
set(TARGET main)
|
||||||
add_executable(${TARGET} main.cpp)
|
add_executable(${TARGET} main.cpp)
|
||||||
|
|
||||||
include(DefaultTargetOptions)
|
|
||||||
|
|
||||||
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
|
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||||
|
@ -69,7 +69,6 @@ struct whisper_params {
|
|||||||
bool output_vtt = false;
|
bool output_vtt = false;
|
||||||
bool output_srt = false;
|
bool output_srt = false;
|
||||||
bool output_wts = false;
|
bool output_wts = false;
|
||||||
bool output_csv = false;
|
|
||||||
bool print_special = false;
|
bool print_special = false;
|
||||||
bool print_colors = false;
|
bool print_colors = false;
|
||||||
bool print_progress = false;
|
bool print_progress = false;
|
||||||
@ -112,7 +111,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|||||||
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
|
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
|
||||||
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
|
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
|
||||||
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
|
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
|
||||||
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
|
|
||||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||||
@ -152,7 +150,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|||||||
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
|
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
|
||||||
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
|
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
|
||||||
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
|
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
|
||||||
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
|
|
||||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||||
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
||||||
@ -176,81 +173,90 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
|
|||||||
|
|
||||||
const int n_segments = whisper_full_n_segments(ctx);
|
const int n_segments = whisper_full_n_segments(ctx);
|
||||||
|
|
||||||
std::string speaker = "";
|
|
||||||
|
|
||||||
int64_t t0;
|
|
||||||
int64_t t1;
|
|
||||||
|
|
||||||
// print the last n_new segments
|
// print the last n_new segments
|
||||||
const int s0 = n_segments - n_new;
|
const int s0 = n_segments - n_new;
|
||||||
|
|
||||||
if (s0 == 0) {
|
if (s0 == 0) {
|
||||||
printf("\n");
|
printf("\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = s0; i < n_segments; i++) {
|
for (int i = s0; i < n_segments; i++) {
|
||||||
if (!params.no_timestamps || params.diarize) {
|
if (params.no_timestamps) {
|
-        t0 = whisper_full_get_segment_t0(ctx, i);
-        t1 = whisper_full_get_segment_t1(ctx, i);
-    }
-
-    if (!params.no_timestamps) {
-        printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
-    }
-
-    if (params.diarize && pcmf32s.size() == 2) {
-        const int64_t n_samples = pcmf32s[0].size();
-
-        const int64_t is0 = timestamp_to_sample(t0, n_samples);
-        const int64_t is1 = timestamp_to_sample(t1, n_samples);
-
-        double energy0 = 0.0f;
-        double energy1 = 0.0f;
-
-        for (int64_t j = is0; j < is1; j++) {
-            energy0 += fabs(pcmf32s[0][j]);
-            energy1 += fabs(pcmf32s[1][j]);
-        }
-
-        if (energy0 > 1.1*energy1) {
-            speaker = "(speaker 0)";
-        } else if (energy1 > 1.1*energy0) {
-            speaker = "(speaker 1)";
-        } else {
-            speaker = "(speaker ?)";
-        }
-
-        //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
-    }
-
-    if (params.print_colors) {
-        for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
-            if (params.print_special == false) {
-                const whisper_token id = whisper_full_get_token_id(ctx, i, j);
-                if (id >= whisper_token_eot(ctx)) {
-                    continue;
-                }
-            }
-
-            const char * text = whisper_full_get_token_text(ctx, i, j);
-            const float  p    = whisper_full_get_token_p   (ctx, i, j);
-
-            const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
-
-            printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
-        }
-    } else {
-        const char * text = whisper_full_get_segment_text(ctx, i);
-
-        printf("%s%s", speaker.c_str(), text);
-    }
-
-    // with timestamps or speakers: each segment on new line
-    if (!params.no_timestamps || params.diarize) {
-        printf("\n");
-    }
-
-    fflush(stdout);
+        if (params.print_colors) {
+            for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+                if (params.print_special == false) {
+                    const whisper_token id = whisper_full_get_token_id(ctx, i, j);
+                    if (id >= whisper_token_eot(ctx)) {
+                        continue;
+                    }
+                }
+
+                const char * text = whisper_full_get_token_text(ctx, i, j);
+                const float  p    = whisper_full_get_token_p   (ctx, i, j);
+
+                const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+
+                printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+            }
+        } else {
+            const char * text = whisper_full_get_segment_text(ctx, i);
+            printf("%s", text);
+        }
+        fflush(stdout);
+    } else {
+        const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+        const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+        std::string speaker;
+
+        if (params.diarize && pcmf32s.size() == 2) {
+            const int64_t n_samples = pcmf32s[0].size();
+
+            const int64_t is0 = timestamp_to_sample(t0, n_samples);
+            const int64_t is1 = timestamp_to_sample(t1, n_samples);
+
+            double energy0 = 0.0f;
+            double energy1 = 0.0f;
+
+            for (int64_t j = is0; j < is1; j++) {
+                energy0 += fabs(pcmf32s[0][j]);
+                energy1 += fabs(pcmf32s[1][j]);
+            }
+
+            if (energy0 > 1.1*energy1) {
+                speaker = "(speaker 0)";
+            } else if (energy1 > 1.1*energy0) {
+                speaker = "(speaker 1)";
+            } else {
+                speaker = "(speaker ?)";
+            }
+
+            //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
+        }
+
+        if (params.print_colors) {
+            printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
+            for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+                if (params.print_special == false) {
+                    const whisper_token id = whisper_full_get_token_id(ctx, i, j);
+                    if (id >= whisper_token_eot(ctx)) {
+                        continue;
+                    }
+                }
+
+                const char * text = whisper_full_get_token_text(ctx, i, j);
+                const float  p    = whisper_full_get_token_p   (ctx, i, j);
+
+                const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+
+                printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
+            }
+            printf("\n");
+        } else {
+            const char * text = whisper_full_get_segment_text(ctx, i);
+
+            printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
+        }
+    }
     }
 }
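Note: the diarization branch shown above labels a segment by comparing per-channel energy of the stereo input over the segment's sample range. A minimal standalone sketch of that heuristic (the function name and signature are illustrative, not part of the whisper.cpp API):

#include <cmath>
#include <cstdint>
#include <string>
#include <vector>

// Whichever channel carries noticeably more energy (factor 1.1) over
// [is0, is1) is assumed to be the active speaker; otherwise it is left unknown.
static std::string guess_speaker(const std::vector<std::vector<float>> & pcmf32s, int64_t is0, int64_t is1) {
    double energy0 = 0.0, energy1 = 0.0;
    for (int64_t j = is0; j < is1; j++) {
        energy0 += std::fabs(pcmf32s[0][j]);
        energy1 += std::fabs(pcmf32s[1][j]);
    }
    if (energy0 > 1.1*energy1) return "(speaker 0)";
    if (energy1 > 1.1*energy0) return "(speaker 1)";
    return "(speaker ?)";
}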
@ -319,32 +325,6 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
     return true;
 }
 
-bool output_csv(struct whisper_context * ctx, const char * fname) {
-    std::ofstream fout(fname);
-    if (!fout.is_open()) {
-        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
-        return false;
-    }
-
-    fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
-
-    const int n_segments = whisper_full_n_segments(ctx);
-    for (int i = 0; i < n_segments; ++i) {
-        const char * text = whisper_full_get_segment_text(ctx, i);
-        if (text[0] == ' ')
-            text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
-        const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
-        const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
-        //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
-        fout << 10 * t0 << ", "
-             << 10 * t1 << ", \""
-             << text << "\"\n";
-    }
-
-    return true;
-}
-
 // karaoke video generation
 // outputs a bash script that uses ffmpeg to generate a video with the subtitles
 // TODO: font parameter adjustments
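Note: output_csv writes 10*t0 and 10*t1 because the segment timestamps returned by whisper_full_get_segment_t0/t1() are in units of 10 ms. A small sketch of the conversion (the helper name is illustrative only):

#include <cstdint>

// whisper segment timestamps are in 10 ms units, so milliseconds = 10 * t.
static int64_t segment_time_to_ms(int64_t t) {
    return 10*t;
}
// e.g. t0 = 150 corresponds to 1500 ms, i.e. 1.5 s into the audio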
@ -548,7 +528,7 @@ int main(int argc, char ** argv) {
         }
 
         if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
-            fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
+            fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
             return 8;
         }
 
@ -694,13 +674,6 @@ int main(int argc, char ** argv) {
             const auto fname_wts = fname_inp + ".wts";
             output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
         }
 
-        // output to CSV file
-        if (params.output_csv) {
-            const auto fname_csv = fname_inp + ".csv";
-            output_csv(ctx, fname_csv.c_str());
-        }
-
         }
     }
 
@ -8,8 +8,6 @@ add_executable(${TARGET}
     emscripten.cpp
     )
 
-include(DefaultTargetOptions)
-
 target_link_libraries(${TARGET} PRIVATE
     whisper
     )
@ -2,9 +2,6 @@ if (WHISPER_SUPPORT_SDL2)
     # stream
     set(TARGET stream)
     add_executable(${TARGET} stream.cpp)
 
-    include(DefaultTargetOptions)
-
     target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
     target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
 endif ()
@ -8,7 +8,6 @@
 #include <SDL.h>
 #include <SDL_audio.h>
 
-#include <atomic>
 #include <cassert>
 #include <cstdio>
 #include <string>
@ -145,8 +144,8 @@ private:
     int m_len_ms = 0;
     int m_sample_rate = 0;
 
-    std::atomic_bool m_running;
+    bool m_running = false;
     std::mutex m_mutex;
 
     std::vector<float> m_audio;
     std::vector<float> m_audio_new;
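Note: the std::atomic_bool variant of m_running is the usual pattern for a run/stop flag that one thread toggles (pause/resume) while another thread polls it (the audio capture callback). A minimal sketch of that pattern, independent of the SDL code in this file (type and member names are illustrative):

#include <atomic>

struct run_flag {
    std::atomic_bool running{false};

    void resume() { running = true;  }
    void pause()  { running = false; }
    bool active() const { return running.load(); }
};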
@ -156,8 +155,6 @@ private:
 
 audio_async::audio_async(int len_ms) {
     m_len_ms = len_ms;
-
-    m_running = false;
 }
 
 audio_async::~audio_async() {
@ -430,9 +427,9 @@ int main(int argc, char ** argv) {
     const int n_samples_keep = (params.keep_ms *1e-3)*WHISPER_SAMPLE_RATE;
     const int n_samples_30s  = (30000          *1e-3)*WHISPER_SAMPLE_RATE;
 
-    const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
+    const int n_new_line = params.length_ms / params.step_ms - 1; // number of steps to print new line
 
-    const int n_new_line = !use_vad ? params.length_ms / params.step_ms - 1 : 1; // number of steps to print new line
+    const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
 
     params.no_timestamps = !use_vad;
     params.no_context    = use_vad;
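Note: in the non-VAD branch a newline is printed every length_ms/step_ms - 1 steps, while the sliding-window (VAD) variant pins this to 1. A quick check of the arithmetic with example values (not the tool's defaults):

#include <cstdio>

int main() {
    const int step_ms   = 500;   // example values only
    const int length_ms = 5000;

    const bool use_vad = false;  // sliding-window mode corresponds to step <= 0

    const int n_new_line = !use_vad ? length_ms/step_ms - 1 : 1;
    printf("new line every %d steps\n", n_new_line); // prints 9
    return 0;
}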
@ -9,8 +9,6 @@ add_executable(${TARGET}
     gpt-2.cpp
     )
 
-include(DefaultTargetOptions)
-
 target_link_libraries(${TARGET} PRIVATE
     whisper
     )
@ -33,8 +31,8 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
     --bind \
     -s USE_PTHREADS=1 \
     -s PTHREAD_POOL_SIZE=8 \
-    -s INITIAL_MEMORY=1800MB \
-    -s TOTAL_MEMORY=1800MB \
+    -s INITIAL_MEMORY=1600MB \
+    -s TOTAL_MEMORY=1600MB \
     -s FORCE_FILESYSTEM=1 \
     -s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
     ${EXTRA_FLAGS} \
@ -36,7 +36,7 @@ In order to run this demo efficiently, you need to have the following:
 - Latest Chrome or Firefox browser (Safari is not supported)
 - Run this on a desktop or laptop with modern CPU (a mobile phone will likely not be good enough)
 - Speak phrases that are no longer than 10 seconds - this is the audio context of the AI
-- The web-page uses about 1.8GB of RAM
+- The web-page uses about 1.6GB of RAM
 
 Notice that this demo is using the smallest GPT-2 model, so the generated text responses are not always very good.
 Also, the prompting strategy can likely be improved to achieve better results.
@ -8,9 +8,6 @@ if (WHISPER_SUPPORT_SDL2)
     # TODO: this is temporary
     # need to export ggml symbols for MSVC, but too lazy ..
     add_executable(${TARGET} talk.cpp gpt-2.cpp ../../ggml.c ../../whisper.cpp)
 
-    include(DefaultTargetOptions)
-
     target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
     target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
 endif ()
@ -8,8 +8,6 @@ add_executable(${TARGET}
     emscripten.cpp
     )
 
-include(DefaultTargetOptions)
-
 target_link_libraries(${TARGET} PRIVATE
     whisper
     )
14 ggml.h
@ -177,11 +177,13 @@ extern "C" {
 #include <stddef.h>
 #include <stdbool.h>
 
 #define GGML_MAX_DIMS     4
 #define GGML_MAX_NODES    4096
 #define GGML_MAX_PARAMS   16
 #define GGML_MAX_CONTEXTS 64
 #define GGML_MAX_OPT      4
+#define GGML_MAX_THREADS      64
+#define GGML_MAX_THREAD_POOLS 16
 
 #ifdef __ARM_NEON
 // we use the built-in 16-bit float type
@ -731,8 +733,6 @@ int ggml_cpu_has_f16c(void);
 int ggml_cpu_has_fp16_va(void);
 int ggml_cpu_has_wasm_simd(void);
 int ggml_cpu_has_blas(void);
-int ggml_cpu_has_sse3(void);
-int ggml_cpu_has_vsx(void);
 
 #ifdef __cplusplus
 }
@ -56,7 +56,7 @@ def bytes_to_unicode():
     The reversible bpe codes work on unicode strings.
     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
-    This is a significant percentage of your normal, say, 32K bpe vocab.
+    This is a signficant percentage of your normal, say, 32K bpe vocab.
     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
     And avoids mapping to whitespace/control characters the bpe code barfs on.
     """
@ -40,7 +40,7 @@ if exist "ggml-%model%.bin" (
   goto :eof
 )
 
-PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
+PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/datasets/ggerganov/whisper.cpp/raw/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
 
 if %ERRORLEVEL% neq 0 (
     echo Failed to download ggml model %model%
181 whisper.cpp
@ -204,10 +204,6 @@ struct whisper_vocab {
     std::map<token, id> token_to_id;
     std::map<id, token> id_to_token;
 
-    // used to avoid memory allocations during sampling
-    // TODO: move to whisper_context in the future
-    std::vector<std::pair<double, whisper_vocab::id>> probs_id;
-
     id token_eot  = 50256;
     id token_sot  = 50257;
     id token_prev = 50360;
@ -412,8 +408,6 @@ struct whisper_context {
     std::vector<uint8_t> buf_compute;
     std::vector<uint8_t> buf_compute_layer;
 
-    ggml_type wtype; // weight type (FP32 or FP16)
-
     whisper_model model;
     whisper_vocab vocab;
 
@ -437,8 +431,9 @@ struct whisper_context {
 };
 
 template<typename T>
-static void read_safe(std::ifstream& fin, T& dest) {
-    fin.read((char*)& dest, sizeof(T));
+static void read_safe(std::ifstream& fin, T& dest)
+{
+    fin.read((char*)& dest, sizeof(T));
 }
 
 // load the model from a ggml file
@ -556,9 +551,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
         std::string word;
         std::vector<char> tmp;
 
-        tmp.reserve(128);
-
         for (int i = 0; i < n_vocab; i++) {
             uint32_t len;
             read_safe(fin, len);
@ -611,11 +603,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
                 vocab.id_to_token[i] = word;
             }
         }
 
-        wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
-        wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
-
-        vocab.probs_id.reserve(n_vocab);
     }
 
     {
@ -631,9 +618,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
         // for the big tensors, we have the option to store the data in 16-bit floats
         // in order to save memory and also to speed up the computation
-        wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
-
-        const ggml_type wtype = wctx.wtype;
+        const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
         size_t ctx_size = 0;
 
@ -654,6 +639,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
         // encoder
         {
+            // TODO: F16 .. maybe not?
             ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
 
             ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype);         // e_conv_1_w
@ -668,6 +654,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
         // decoder
         {
+            // TODO: F16 .. maybe not?
             ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
 
             ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
@ -984,8 +971,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             const int n_mem      = n_text_layer*n_text_ctx;
             const int n_elements = n_text_state*n_mem;
 
-            model.memory_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
-            model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
+            model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+            model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
         }
 
         // key/value memory for the cross-attention layer
@ -995,8 +982,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             const int n_mem      = n_text_layer*n_audio_ctx;
             const int n_elements = n_text_state*n_mem;
 
-            model.memory_cross_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
-            model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
+            model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+            model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
         }
 
         const size_t memory_size =
@ -1034,7 +1021,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
             std::string name;
             std::vector<char> tmp(length); // create a buffer
-            fin.read(&tmp[0], tmp.size()); // read to buffer
+            fin.read( &tmp[0], tmp.size() ); // read to buffer
             name.assign(&tmp[0], tmp.size());
 
             if (model.tensors.find(name) == model.tensors.end()) {
@ -1242,14 +1229,14 @@ static bool whisper_encode(
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Qcur,
-                            ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Kcur,
-                            ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             struct ggml_tensor * V =
@ -1259,7 +1246,7 @@ static bool whisper_encode(
                             Vcur,
                             n_state/n_head, n_head, n_ctx),
                         1, 2, 0, 3),
-                    ggml_new_tensor_3d(ctxL, wctx.wtype, n_ctx, n_state/n_head, n_head)
+                    ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
                     );
 
             struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
@ -1275,7 +1262,7 @@ static bool whisper_encode(
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Kcur,
-                            ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             // K * Q
@ -1293,7 +1280,7 @@ static bool whisper_encode(
             //    ggml_permute(ctxL,
             //            ggml_cpy(ctxL,
             //                Vcur,
-            //                ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
+            //                ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
             //            1, 2, 0, 3);
 
             //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
@ -1305,7 +1292,7 @@ static bool whisper_encode(
                             Vcur,
                             n_state/n_head, n_head, n_ctx),
                         0, 2, 1, 3),
-                    ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_ctx, n_head)
+                    ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
                    );
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
@ -1350,7 +1337,7 @@ static bool whisper_encode(
 
 #ifdef USE_FLASH_FF
             cur = ggml_flash_ff(ctxL,
-                    ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
+                    ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
                     layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
 #else
             // fully connected
@ -1457,7 +1444,7 @@ static bool whisper_encode(
                 layer.cross_attn_k_w,
                 cur);
 
-        //Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+        Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
         struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
                 layer.cross_attn_v_w,
@ -1579,14 +1566,14 @@ static bool whisper_decode(
                     Qcur),
                 Qcur);
 
-        //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+        Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
 
         // note: no bias for Key
         struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
                 layer.attn_k_w,
                 cur);
 
-        //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+        Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
 
         struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
                 layer.attn_v_w,
@ -1609,33 +1596,6 @@ static bool whisper_decode(
 
         // ------
 
-#ifdef USE_FLASH_ATTN
-        struct ggml_tensor * Q =
-            ggml_permute(ctxL,
-                    ggml_cpy(ctxL,
-                        Qcur,
-                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
-                    0, 2, 1, 3);
-
-        struct ggml_tensor * K =
-            ggml_permute(ctxL,
-                    ggml_reshape_3d(ctxL,
-                        ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
-                        n_state/n_head, n_head, n_past + N),
-                    0, 2, 1, 3);
-
-        struct ggml_tensor * V =
-            ggml_cpy(ctxL,
-                    ggml_permute(ctxL,
-                        ggml_reshape_3d(ctxL,
-                            ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
-                            n_state/n_head, n_head, n_past + N),
-                        1, 2, 0, 3),
-                    ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_past + N, n_state/n_head, n_head)
-                    );
-
-        struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, true);
-#else
         struct ggml_tensor * Q =
             ggml_permute(ctxL,
                     ggml_cpy(ctxL,
@ -1653,13 +1613,13 @@ static bool whisper_decode(
         // K * Q
         struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
 
-        struct ggml_tensor * KQ_scaled =
-            ggml_scale(ctxL,
-                    KQ,
-                    ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
-                    );
+        //struct ggml_tensor * KQ_scaled =
+        //    ggml_scale(ctxL,
+        //            KQ,
+        //            ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+        //            );
 
-        struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
+        struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
 
         struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
 
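Note: the two variants above apply the same overall attention scaling. One side multiplies Qcur and Kcur each by pow(n_state/n_head, -0.25) and leaves K*Q unscaled; the other leaves Q and K unscaled and multiplies K*Q once by 1.0f/sqrt(n_state/n_head). With d = n_state/n_head these are algebraically identical:

    (d^{-1/4} Q)(d^{-1/4} K)^T = d^{-1/2} Q K^T = Q K^T / \sqrt{d}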
@ -1671,7 +1631,6 @@ static bool whisper_decode(
                     1, 2, 0, 3);
 
         struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
-#endif
 
         struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
 
@ -1717,7 +1676,7 @@ static bool whisper_decode(
                     Qcur),
                 Qcur);
 
-        //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+        Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
 
         // Kcross is already scaled
         struct ggml_tensor * Kcross =
@ -1732,24 +1691,6 @@ static bool whisper_decode(
 
         // ------
 
-#ifdef USE_FLASH_ATTN
-        struct ggml_tensor * Q =
-            ggml_permute(ctxL,
-                    ggml_cpy(ctxL,
-                        Qcur,
-                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
-                    0, 2, 1, 3);
-
-        struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
-
-        struct ggml_tensor * V =
-            ggml_cpy(ctxL,
-                    ggml_permute(ctxL, Vcross, 1, 2, 0, 3),
-                    ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, M, n_state/n_head, n_head)
-                    );
-
-        struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
-#else
         struct ggml_tensor * Q =
             ggml_permute(ctxL,
                     ggml_cpy(ctxL,
@ -1776,7 +1717,6 @@ static bool whisper_decode(
         struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
 
         struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
-#endif
 
         struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
 
@ -1909,7 +1849,7 @@ static bool whisper_decode(
 
 // the most basic sampling scheme - select the top token
 static whisper_token_data whisper_sample_best(
-        whisper_vocab & vocab,
+        const whisper_vocab & vocab,
         const float * probs,
         bool force_timestamp,
         bool is_initial) {
@ -1917,11 +1857,11 @@ static whisper_token_data whisper_sample_best(
         0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
     };
 
-    const int n_logits = vocab.n_vocab;
+    int n_logits = vocab.id_to_token.size();
 
-    auto & probs_id = vocab.probs_id;
+    std::vector<std::pair<double, whisper_vocab::id>> probs_id;
+    probs_id.reserve(n_logits);
 
-    probs_id.clear();
     for (int i = 0; i < n_logits; i++) {
         probs_id.emplace_back(probs[i], i);
     }
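Note: keeping probs_id as a reusable member, as on one side above, avoids re-allocating the scratch vector on every sampling call, because clear() keeps the previously reserved capacity. A small sketch of that pattern with illustrative names:

#include <utility>
#include <vector>

struct sampler_state {
    std::vector<std::pair<double, int>> probs_id; // reused across calls
};

static int sample_top(sampler_state & st, const float * probs, int n_logits) {
    st.probs_id.clear();           // size -> 0, capacity kept
    st.probs_id.reserve(n_logits); // no-op after the first call
    for (int i = 0; i < n_logits; i++) {
        st.probs_id.emplace_back(probs[i], i);
    }
    int best = 0;
    for (int i = 1; i < n_logits; i++) {
        if (st.probs_id[i].first > st.probs_id[best].first) {
            best = i;
        }
    }
    return st.probs_id[best].second;
}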
@ -2061,9 +2001,6 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
     std::vector<float> even;
     std::vector<float> odd;
 
-    even.reserve(N/2);
-    odd.reserve(N/2);
-
     for (int i = 0; i < N; i++) {
         if (i % 2 == 0) {
             even.push_back(in[i]);
@ -2497,7 +2434,7 @@ int whisper_lang_auto_detect(
     std::vector<std::pair<float, int>> probs_id;
     for (const auto & kv : g_lang) {
         const auto token_lang = whisper_token_lang(ctx, kv.second.first);
-        probs_id.emplace_back(ctx->probs[token_lang], kv.second.first);
+        probs_id.emplace_back( ctx->probs[token_lang], kv.second.first );
     }
 
     // sort descending
@ -2521,12 +2458,12 @@ int whisper_lang_auto_detect(
     }
 
     {
-        for (const auto & prob : probs_id) {
+        for (int i = 0; i < (int) probs_id.size(); i++) {
             if (lang_probs) {
-                lang_probs[prob.second] = prob.first;
+                lang_probs[probs_id[i].second] = probs_id[i].first;
             }
 
-            //printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first);
+            //printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
         }
     }
 
@ -2545,10 +2482,6 @@ int whisper_n_text_ctx(struct whisper_context * ctx) {
     return ctx->model.hparams.n_text_ctx;
 }
 
-int whisper_n_audio_ctx(struct whisper_context * ctx) {
-    return ctx->model.hparams.n_audio_ctx;
-}
-
 int whisper_is_multilingual(struct whisper_context * ctx) {
     return ctx->vocab.is_multilingual() ? 1 : 0;
 }
@ -2629,8 +2562,6 @@ const char * whisper_print_system_info(void) {
     s += "FP16_VA = "   + std::to_string(ggml_cpu_has_fp16_va())   + " | ";
     s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
     s += "BLAS = "      + std::to_string(ggml_cpu_has_blas())      + " | ";
-    s += "SSE3 = "      + std::to_string(ggml_cpu_has_sse3())      + " | ";
-    s += "VSX = "       + std::to_string(ggml_cpu_has_vsx())       + " | ";
 
     return s.c_str();
 }
@ -2876,11 +2807,7 @@ int whisper_full(
         std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
     }
 
-    // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
-    if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
-        fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
-        return -4;
-    }
+    // overwrite audio_ctx
     ctx->exp_n_audio_ctx = params.audio_ctx;
 
     // these tokens determine the task that will be performed
@ -3207,7 +3134,7 @@ int whisper_full_parallel(
 
     // separate key + value memory for each processor
     {
-        auto & mctx = model.ctx_mem;
+        auto & ctx = model.ctx_mem;
 
         const auto & hparams = model.hparams;
 
@ -3220,8 +3147,8 @@ int whisper_full_parallel(
             const int n_mem      = n_text_layer*n_text_ctx;
             const int n_elements = n_text_state*n_mem;
 
-            model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
-            model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
+            model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+            model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
         }
 
         // key/value memory for the cross-attention layer
@ -3231,8 +3158,8 @@ int whisper_full_parallel(
             const int n_mem      = n_text_layer*n_audio_ctx;
            const int n_elements = n_text_state*n_mem;
 
-            model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
-            model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
+            model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+            model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
         }
     }
 }
@ -3276,17 +3203,17 @@ int whisper_full_parallel(
     for (int i = 0; i < n_processors - 1; ++i) {
         auto & results_i = ctxs[i].result_all;
 
-        for (auto & result : results_i) {
+        for (int j = 0; j < (int) results_i.size(); ++j) {
             // correct the segment timestamp taking into account the offset
-            result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
-            result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
+            results_i[j].t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
+            results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
 
             // make sure that segments are not overlapping
             if (!ctx->result_all.empty()) {
-                result.t0 = std::max(result.t0, ctx->result_all.back().t1);
+                results_i[j].t0 = std::max(results_i[j].t0, ctx->result_all.back().t1);
             }
 
-            ctx->result_all.push_back(std::move(result));
+            ctx->result_all.push_back(std::move(results_i[j]));
 
             // call the new_segment_callback for each segment
             if (params.new_segment_callback) {
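Note: the per-processor correction above converts a sample offset into the 10 ms units used by segment timestamps: shift = 100 * offset_in_samples / WHISPER_SAMPLE_RATE. A quick check with example numbers:

#include <cstdint>
#include <cstdio>

int main() {
    const int     sample_rate  = 16000;  // WHISPER_SAMPLE_RATE
    const int64_t offset_smpls = 160000; // 10 s worth of samples (example value)

    const int64_t shift = 100*offset_smpls/sample_rate; // timestamp units of 10 ms
    printf("%lld units = %.1f s\n", (long long) shift, shift/100.0); // 1000 units = 10.0 s
    return 0;
}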
@ -3381,18 +3308,18 @@ static int64_t sample_to_timestamp(int i_sample) {
 static float voice_length(const std::string & text) {
     float res = 0.0f;
 
-    for (char c : text) {
-        if (c == ' ') {
+    for (size_t i = 0; i < text.size(); ++i) {
+        if (text[i] == ' ') {
             res += 0.01f;
-        } else if (c == ',') {
+        } else if (text[i] == ',') {
             res += 2.00f;
-        } else if (c == '.') {
+        } else if (text[i] == '.') {
             res += 3.00f;
-        } else if (c == '!') {
+        } else if (text[i] == '!') {
             res += 3.00f;
-        } else if (c == '?') {
+        } else if (text[i] == '?') {
             res += 3.00f;
-        } else if (c >= '0' && c <= '9') {
+        } else if (text[i] >= '0' && text[i] <= '9') {
             res += 3.00f;
         } else {
             res += 1.00f;
@ -148,7 +148,7 @@ extern "C" {
             struct whisper_context * ctx,
                        const char * text,
                     whisper_token * tokens,
                               int   n_max_tokens);
 
     // Largest language id (i.e. number of available languages - 1)
     WHISPER_API int whisper_lang_max_id();
@ -177,7 +177,6 @@ extern "C" {
     WHISPER_API int whisper_n_len          (struct whisper_context * ctx); // mel length
     WHISPER_API int whisper_n_vocab        (struct whisper_context * ctx);
     WHISPER_API int whisper_n_text_ctx     (struct whisper_context * ctx);
-    WHISPER_API int whisper_n_audio_ctx    (struct whisper_context * ctx);
     WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
 
     // The probabilities for the next token