mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-07-01 06:50:24 +02:00
Compare commits
1 Commits
fa-decoder
...
threads
Author | SHA1 | Date | |
---|---|---|---|
4e6d2e98ab |
30
.github/workflows/build.yml
vendored
30
.github/workflows/build.yml
vendored
@ -235,33 +235,3 @@ jobs:
|
||||
with:
|
||||
name: whisper-blas-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
emscripten:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
wget -q https://github.com/emscripten-core/emsdk/archive/master.tar.gz
|
||||
tar -xvf master.tar.gz
|
||||
emsdk-master/emsdk update
|
||||
emsdk-master/emsdk install latest
|
||||
emsdk-master/emsdk activate latest
|
||||
|
||||
- name: Configure
|
||||
run: echo "tmp"
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
pushd emsdk-master
|
||||
source ./emsdk_env.sh
|
||||
popd
|
||||
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
make
|
||||
|
@ -2,15 +2,14 @@ cmake_minimum_required (VERSION 3.0)
|
||||
|
||||
project(whisper.cpp VERSION 1.0.4)
|
||||
|
||||
# Add path to modules
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS "on")
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib")
|
||||
|
||||
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
||||
set(WHISPER_STANDALONE ON)
|
||||
include(GitVars)
|
||||
include(BuildTypes)
|
||||
include(cmake/GitVars.cmake)
|
||||
include(cmake/BuildTypes.cmake)
|
||||
|
||||
# configure project version
|
||||
if (EXISTS "${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl")
|
||||
@ -53,7 +52,6 @@ if (APPLE)
|
||||
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
|
||||
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
|
||||
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
|
||||
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
|
||||
else()
|
||||
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
|
||||
endif()
|
||||
@ -84,6 +82,9 @@ endif()
|
||||
|
||||
# dependencies
|
||||
|
||||
set(CMAKE_C_STANDARD 11)
|
||||
set(CMAKE_CXX_STANDARD 11)
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
# on APPLE - include Accelerate framework
|
||||
@ -130,7 +131,6 @@ if (WHISPER_ALL_WARNINGS)
|
||||
-Wcast-qual \
|
||||
-Wstrict-prototypes \
|
||||
-Wpointer-arith \
|
||||
-Wno-unused-function \
|
||||
")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \
|
||||
-Wall \
|
||||
@ -157,7 +157,6 @@ else()
|
||||
if (MSVC)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
|
||||
else()
|
||||
if (EMSCRIPTEN)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
|
||||
@ -169,10 +168,7 @@ else()
|
||||
if(NOT WHISPER_NO_AVX2)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
|
||||
endif()
|
||||
if(NOT WHISPER_NO_FMA)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
|
||||
endif()
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma -mf16c")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
@ -194,8 +190,6 @@ add_library(${TARGET}
|
||||
whisper.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC
|
||||
.
|
||||
)
|
||||
@ -229,7 +223,6 @@ target_compile_definitions(${TARGET} PUBLIC
|
||||
install(TARGETS ${TARGET}
|
||||
LIBRARY DESTINATION lib
|
||||
ARCHIVE DESTINATION lib/static
|
||||
RUNTIME DESTINATION bin
|
||||
)
|
||||
|
||||
#
|
||||
|
29
Makefile
29
Makefile
@ -10,9 +10,6 @@ ifndef UNAME_M
|
||||
UNAME_M := $(shell uname -m)
|
||||
endif
|
||||
|
||||
CCV := $(shell $(CC) --version | head -n 1)
|
||||
CXXV := $(shell $(CXX) --version | head -n 1)
|
||||
|
||||
# Mac OS + Arm can report x86_64
|
||||
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
@ -56,13 +53,10 @@ endif
|
||||
# Architecture specific
|
||||
# TODO: probably these flags need to be tweaked on some architectures
|
||||
# feel free to update the Makefile for your architecture and send a pull request or issue
|
||||
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
|
||||
ifeq ($(UNAME_M),x86_64)
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
CFLAGS += -mf16c
|
||||
CFLAGS += -mfma -mf16c
|
||||
AVX1_M := $(shell sysctl machdep.cpu.features)
|
||||
ifneq (,$(findstring FMA,$(AVX1_M)))
|
||||
CFLAGS += -mfma
|
||||
endif
|
||||
ifneq (,$(findstring AVX1.0,$(AVX1_M)))
|
||||
CFLAGS += -mavx
|
||||
endif
|
||||
@ -87,10 +81,6 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
|
||||
ifneq (,$(findstring f16c,$(F16C_M)))
|
||||
CFLAGS += -mf16c
|
||||
endif
|
||||
SSE3_M := $(shell grep "sse3 " /proc/cpuinfo)
|
||||
ifneq (,$(findstring sse3,$(SSE3_M)))
|
||||
CFLAGS += -msse3
|
||||
endif
|
||||
else ifeq ($(UNAME_S),Haiku)
|
||||
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
|
||||
ifneq (,$(findstring avx,$(AVX1_M)))
|
||||
@ -151,21 +141,6 @@ ifneq ($(filter armv8%,$(UNAME_M)),)
|
||||
CFLAGS += -mfp16-format=ieee -mno-unaligned-access
|
||||
endif
|
||||
|
||||
#
|
||||
# Print build information
|
||||
#
|
||||
|
||||
$(info I whisper.cpp build info: )
|
||||
$(info I UNAME_S: $(UNAME_S))
|
||||
$(info I UNAME_P: $(UNAME_P))
|
||||
$(info I UNAME_M: $(UNAME_M))
|
||||
$(info I CFLAGS: $(CFLAGS))
|
||||
$(info I CXXFLAGS: $(CXXFLAGS))
|
||||
$(info I LDFLAGS: $(LDFLAGS))
|
||||
$(info I CC: $(CCV))
|
||||
$(info I CXX: $(CXXV))
|
||||
$(info )
|
||||
|
||||
default: main
|
||||
|
||||
#
|
||||
|
@ -11,7 +11,6 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
|
||||
- Plain C/C++ implementation without dependencies
|
||||
- Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework
|
||||
- AVX intrinsics support for x86 architectures
|
||||
- VSX intrinsics support for POWER architectures
|
||||
- Mixed F16 / F32 precision
|
||||
- Low memory usage (Flash Attention + Flash Forward)
|
||||
- Zero memory allocations at runtime
|
||||
|
1
bindings/go/.gitignore
vendored
1
bindings/go/.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
build
|
||||
models
|
||||
go.sum
|
||||
|
@ -1,27 +1,28 @@
|
||||
BUILD_DIR := build
|
||||
MODELS_DIR := models
|
||||
CMAKE := $(shell which cmake)
|
||||
BUILD_DIR := "build"
|
||||
MODELS_DIR := "models"
|
||||
EXAMPLES_DIR := $(wildcard examples/*)
|
||||
INCLUDE_PATH := $(abspath ../..)
|
||||
LIBRARY_PATH := $(abspath ../..)
|
||||
C_INCLUDE_PATH := "../.."
|
||||
|
||||
all: clean whisper examples
|
||||
|
||||
whisper: mkdir
|
||||
@echo Build whisper
|
||||
@${MAKE} -C ../.. libwhisper.a
|
||||
@${CMAKE} -S ../.. -B ${BUILD_DIR} -D BUILD_SHARED_LIBS=off -D WHISPER_NO_AVX2=on
|
||||
@${CMAKE} --build ${BUILD_DIR} --target whisper
|
||||
|
||||
test: model-small whisper modtidy
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
|
||||
@go test -v .
|
||||
@go test -v ./pkg/whisper/...
|
||||
|
||||
examples: $(EXAMPLES_DIR)
|
||||
|
||||
model-small: mkdir examples/go-model-download
|
||||
@${BUILD_DIR}/go-model-download -out models ggml-small.en.bin
|
||||
@${BUILD_DIR}/go-model-download -out models small.en
|
||||
|
||||
$(EXAMPLES_DIR): mkdir whisper modtidy
|
||||
@echo Build example $(notdir $@)
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
|
||||
@go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
|
||||
|
||||
mkdir:
|
||||
@echo Mkdir ${BUILD_DIR}
|
||||
|
@ -74,27 +74,4 @@ And you can then test a model against samples with the following command:
|
||||
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
|
||||
```
|
||||
|
||||
## Using the bindings
|
||||
|
||||
To use the bindings in your own software,
|
||||
|
||||
1. Import `github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper` (or `github.com/ggerganov/whisper.cpp/bindings/go` into your package;
|
||||
2. Compile `libwhisper.a` (you can use `make whisper` in the `bindings/go` directory);
|
||||
3. Link your go binary against whisper by setting the environment variables `C_INCLUDE_PATH` and `LIBRARY_PATH`
|
||||
to point to the `whisper.h` file directory and `libwhisper.a` file directory respectively.
|
||||
|
||||
Look at the `Makefile` in the `bindings/go` directory for an example.
|
||||
|
||||
The API Documentation:
|
||||
|
||||
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go
|
||||
* https://pkg.go.dev/github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper
|
||||
|
||||
Getting help:
|
||||
|
||||
* Follow the discussion for the go bindings [here](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
||||
|
||||
## License
|
||||
|
||||
The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.
|
||||
|
||||
|
@ -17,14 +17,15 @@ import (
|
||||
// CONSTANTS
|
||||
|
||||
const (
|
||||
srcUrl = "https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main" // The location of the models
|
||||
srcExt = ".bin" // Filename extension
|
||||
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
|
||||
srcUrl = "https://huggingface.co/" // The location of the models
|
||||
srcPathPrefix = "/datasets/ggerganov/whisper.cpp/resolve/main/ggml" // Filename prefix
|
||||
srcExt = ".bin" // Filename extension
|
||||
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
|
||||
)
|
||||
|
||||
var (
|
||||
// The models which will be downloaded, if no model is specified as an argument
|
||||
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large"}
|
||||
modelNames = []string{"tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium", "large-v1", "large"}
|
||||
)
|
||||
|
||||
var (
|
||||
@ -122,14 +123,11 @@ func GetModels() []string {
|
||||
|
||||
// URLForModel returns the URL for the given model on huggingface.co
|
||||
func URLForModel(model string) (string, error) {
|
||||
if filepath.Ext(model) != srcExt {
|
||||
model += srcExt
|
||||
}
|
||||
url, err := url.Parse(srcUrl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else {
|
||||
url.Path = filepath.Join(url.Path, model)
|
||||
url.Path = srcPathPrefix + "-" + model + srcExt
|
||||
}
|
||||
return url.String(), nil
|
||||
}
|
||||
|
@ -1,23 +0,0 @@
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
|
||||
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
|
||||
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
|
||||
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
|
||||
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
|
||||
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
@ -1,5 +1,8 @@
|
||||
package whisper
|
||||
|
||||
// This file defines the whisper_token, whisper_token_data and whisper_full_params
|
||||
// structures, which are used by the whisper_full() function.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
@ -9,7 +9,8 @@ import (
|
||||
// CGO
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -lwhisper -lm -lstdc++
|
||||
#cgo CFLAGS: -I${SRCDIR}/../..
|
||||
#cgo LDFLAGS: -L${SRCDIR}/build -lwhisper -lm -lstdc++
|
||||
#cgo darwin LDFLAGS: -framework Accelerate
|
||||
#include <whisper.h>
|
||||
#include <stdlib.h>
|
||||
@ -170,10 +171,6 @@ func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
|
||||
}
|
||||
|
||||
// Return the id of the specified language, returns -1 if not found
|
||||
// Examples:
|
||||
//
|
||||
// "de" -> 2
|
||||
// "german" -> 2
|
||||
func (ctx *Context) Whisper_lang_id(lang string) int {
|
||||
return int(C.whisper_lang_id(C.CString(lang)))
|
||||
}
|
||||
@ -214,10 +211,6 @@ func (ctx *Context) Whisper_n_text_ctx() int {
|
||||
return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_audio_ctx() int {
|
||||
return int(C.whisper_n_audio_ctx((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_is_multilingual() int {
|
||||
return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
@ -50,10 +50,7 @@ func Test_Whisper_001(t *testing.T) {
|
||||
ctx := whisper.Whisper_init(ModelPath)
|
||||
assert.NotNil(ctx)
|
||||
defer ctx.Whisper_free()
|
||||
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
err = ctx.Whisper_full(params, data, nil, nil)
|
||||
assert.NoError(err)
|
||||
assert.NoError(ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf.AsFloat32Buffer().Data, nil, nil))
|
||||
|
||||
// Print out tokens
|
||||
num_segments := ctx.Whisper_full_n_segments()
|
||||
|
Submodule bindings/ios updated: 6707f1ea1c...1502317fe0
File diff suppressed because one or more lines are too long
@ -1,17 +0,0 @@
|
||||
# Set the default compile features and properties for a target.
|
||||
|
||||
if (NOT TARGET)
|
||||
message(FATAL_ERROR "TARGET not set before including DefaultTargetOptions")
|
||||
endif()
|
||||
|
||||
target_compile_features(${TARGET}
|
||||
PRIVATE
|
||||
cxx_std_11
|
||||
)
|
||||
|
||||
set_target_properties(${TARGET}
|
||||
PROPERTIES
|
||||
EXPORT_COMPILE_COMMANDS ON
|
||||
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
|
||||
INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib"
|
||||
)
|
@ -8,8 +8,6 @@ add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
|
@ -1,6 +1,3 @@
|
||||
set(TARGET bench)
|
||||
add_executable(${TARGET} bench.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
@ -8,8 +8,6 @@ add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
|
@ -2,9 +2,6 @@ if (WHISPER_SUPPORT_SDL2)
|
||||
# command
|
||||
set(TARGET command)
|
||||
add_executable(${TARGET} command.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
||||
|
@ -9,19 +9,7 @@ More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/
|
||||
|
||||
# On Raspberry Pi, use tiny or base models + "-ac 768" for better performance
|
||||
./command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
|
||||
|
||||
Web version: [examples/command.wasm](/examples/command.wasm)
|
||||
|
||||
## Guided mode
|
||||
|
||||
"Guided mode" allows you to specify a list of commands (i.e. strings) and the transcription will be guided to classify your command into one from the list. This can be useful in situations where a device is listening only for a small subset of commands.
|
||||
|
||||
Initial tests show that this approach might be extremely efficient in terms of performance, since it integrates very well with the "partial Encoder" idea from #137.
|
||||
|
||||
```bash
|
||||
# Run in guided mode, the list of allowed commands is in commands.txt
|
||||
./command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt
|
||||
|
||||
@ -29,8 +17,9 @@ Initial tests show that this approach might be extremely efficient in terms of p
|
||||
./command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/207435352-8fc4ed3f-bde5-4555-9b8b-aeeb76bee969.mp4
|
||||
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
|
||||
|
||||
Web version: [examples/command.wasm](/examples/command.wasm)
|
||||
|
||||
## Building
|
||||
|
||||
|
@ -510,333 +510,6 @@ std::vector<std::string> read_allowed_commands(const std::string & fname) {
|
||||
return allowed_commands;
|
||||
}
|
||||
|
||||
// command-list mode
|
||||
// guide the transcription to match the most likely command from a provided list
|
||||
int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: guided mode\n", __func__);
|
||||
|
||||
std::vector<std::string> allowed_commands = read_allowed_commands(params.commands);
|
||||
|
||||
if (allowed_commands.empty()) {
|
||||
fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
|
||||
return 2;
|
||||
}
|
||||
|
||||
int max_len = 0;
|
||||
|
||||
std::vector<std::vector<whisper_token>> allowed_tokens;
|
||||
|
||||
for (const auto & cmd : allowed_commands) {
|
||||
whisper_token tokens[1024];
|
||||
allowed_tokens.emplace_back();
|
||||
|
||||
for (int l = 0; l < (int) cmd.size(); ++l) {
|
||||
// NOTE: very important to add the whitespace !
|
||||
// the reason is that the first decoded token starts with a whitespace too!
|
||||
std::string ss = std::string(" ") + cmd.substr(0, l + 1);
|
||||
|
||||
const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
|
||||
if (n < 0) {
|
||||
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
|
||||
return 3;
|
||||
}
|
||||
|
||||
if (n == 1) {
|
||||
allowed_tokens.back().push_back(tokens[0]);
|
||||
}
|
||||
}
|
||||
|
||||
max_len = std::max(max_len, (int) cmd.size());
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
|
||||
fprintf(stderr, "\n");
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
|
||||
for (const auto & token : allowed_tokens[i]) {
|
||||
fprintf(stderr, " %5d", token);
|
||||
}
|
||||
fprintf(stderr, " ]\n");
|
||||
}
|
||||
|
||||
std::string k_prompt = "select one from the available words: ";
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
if (i > 0) {
|
||||
k_prompt += ", ";
|
||||
}
|
||||
k_prompt += allowed_commands[i];
|
||||
}
|
||||
k_prompt += ". selected word: ";
|
||||
|
||||
// tokenize prompt
|
||||
std::vector<whisper_token> k_tokens;
|
||||
{
|
||||
k_tokens.resize(1024);
|
||||
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
|
||||
if (n < 0) {
|
||||
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
|
||||
return 4;
|
||||
}
|
||||
k_tokens.resize(n);
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
|
||||
fprintf(stderr, "%s: tokens: [", __func__);
|
||||
for (const auto & token : k_tokens) {
|
||||
fprintf(stderr, " %d", token);
|
||||
}
|
||||
fprintf(stderr, " ]\n");
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: listening for a command ...\n", __func__);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
bool is_running = true;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
// main loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
{
|
||||
SDL_Event event;
|
||||
while (SDL_PollEvent(&event)) {
|
||||
switch (event.type) {
|
||||
case SDL_QUIT:
|
||||
{
|
||||
is_running = false;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_running) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.print_progress = false;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = true;
|
||||
wparams.max_tokens = 1;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.prompt_tokens = k_tokens.data();
|
||||
wparams.prompt_n_tokens = k_tokens.size();
|
||||
|
||||
// run the transformer and a single decoding pass
|
||||
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
|
||||
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
|
||||
break;
|
||||
}
|
||||
|
||||
const auto * probs = whisper_get_probs(ctx);
|
||||
std::vector<std::pair<float, int>> probs_id;
|
||||
|
||||
double psum = 0.0;
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
|
||||
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
|
||||
probs_id.back().first += probs[allowed_tokens[i][j]];
|
||||
}
|
||||
probs_id.back().first /= allowed_tokens[i].size();
|
||||
psum += probs_id.back().first;
|
||||
}
|
||||
|
||||
// normalize
|
||||
for (auto & p : probs_id) {
|
||||
p.first /= psum;
|
||||
}
|
||||
|
||||
// sort descending
|
||||
{
|
||||
using pair_type = decltype(probs_id)::value_type;
|
||||
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
}
|
||||
|
||||
// print the commands and the respective probabilities
|
||||
{
|
||||
fprintf(stdout, "\n");
|
||||
for (const auto & cmd : probs_id) {
|
||||
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
|
||||
for (int token : allowed_tokens[cmd.second]) {
|
||||
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
|
||||
}
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
}
|
||||
|
||||
// best command
|
||||
{
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
const float prob = probs_id[0].first;
|
||||
const int index = probs_id[0].second;
|
||||
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
|
||||
"\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
|
||||
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
audio.clear();
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// general-purpose mode
|
||||
// freely transcribe the voice into text
|
||||
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
||||
bool is_running = true;
|
||||
bool have_prompt = false;
|
||||
bool ask_prompt = true;
|
||||
|
||||
float prob0 = 0.0f;
|
||||
float prob = 0.0f;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
const std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
||||
|
||||
// main loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
{
|
||||
SDL_Event event;
|
||||
while (SDL_PollEvent(&event)) {
|
||||
switch (event.type) {
|
||||
case SDL_QUIT:
|
||||
{
|
||||
is_running = false;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_running) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
if (ask_prompt) {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
ask_prompt = false;
|
||||
}
|
||||
|
||||
{
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
int64_t t_ms = 0;
|
||||
|
||||
if (!have_prompt) {
|
||||
// wait for activation phrase
|
||||
audio.get(params.prompt_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||
|
||||
const float sim = similarity(txt, k_prompt);
|
||||
|
||||
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
|
||||
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
|
||||
ask_prompt = true;
|
||||
} else {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
||||
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
// save the audio for the prompt
|
||||
pcmf32_prompt = pcmf32_cur;
|
||||
have_prompt = true;
|
||||
}
|
||||
} else {
|
||||
// we have heard the activation phrase, now detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
// prepend the prompt audio
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
|
||||
prob = 100.0f*(prob - prob0);
|
||||
|
||||
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
||||
|
||||
// find the prompt in the text
|
||||
float best_sim = 0.0f;
|
||||
size_t best_len = 0;
|
||||
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
const auto prompt = txt.substr(0, n);
|
||||
|
||||
const float sim = similarity(prompt, k_prompt);
|
||||
|
||||
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
|
||||
|
||||
if (sim > best_sim) {
|
||||
best_sim = sim;
|
||||
best_len = n;
|
||||
}
|
||||
}
|
||||
|
||||
const std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
audio.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
@ -888,12 +561,300 @@ int main(int argc, char ** argv) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
|
||||
audio.clear();
|
||||
|
||||
int ret_val = 0;
|
||||
int max_len = 0;
|
||||
|
||||
bool is_running = true;
|
||||
bool have_prompt = false;
|
||||
bool ask_prompt = true;
|
||||
|
||||
float prob0 = 0.0f;
|
||||
float prob = 0.0f;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
std::vector<std::string> allowed_commands;
|
||||
std::vector<std::vector<whisper_token>> allowed_tokens;
|
||||
|
||||
std::string k_prompt;
|
||||
std::vector<whisper_token> k_tokens;
|
||||
|
||||
if (!params.commands.empty()) {
|
||||
ret_val = process_command_list(ctx, audio, params);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: guided mode\n", __func__);
|
||||
|
||||
allowed_commands = read_allowed_commands(params.commands);
|
||||
|
||||
if (allowed_commands.empty()) {
|
||||
fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
|
||||
return 2;
|
||||
}
|
||||
|
||||
for (const auto & cmd : allowed_commands) {
|
||||
whisper_token tokens[1024];
|
||||
allowed_tokens.emplace_back();
|
||||
|
||||
for (int l = 0; l < (int) cmd.size(); ++l) {
|
||||
// NOTE: very important to add the whitespace !
|
||||
// the reason is that the first decoded token starts with a whitespace too!
|
||||
std::string ss = std::string(" ") + cmd.substr(0, l + 1);
|
||||
|
||||
const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
|
||||
if (n < 0) {
|
||||
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
|
||||
return 3;
|
||||
}
|
||||
|
||||
if (n == 1) {
|
||||
allowed_tokens.back().push_back(tokens[0]);
|
||||
}
|
||||
}
|
||||
|
||||
max_len = std::max(max_len, (int) cmd.size());
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
|
||||
fprintf(stderr, "\n");
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
|
||||
for (const auto & token : allowed_tokens[i]) {
|
||||
fprintf(stderr, " %5d", token);
|
||||
}
|
||||
fprintf(stderr, " ]\n");
|
||||
}
|
||||
|
||||
k_prompt = "select one from the available words: ";
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
if (i > 0) {
|
||||
k_prompt += ", ";
|
||||
}
|
||||
k_prompt += allowed_commands[i];
|
||||
}
|
||||
k_prompt += ". selected word: ";
|
||||
|
||||
// tokenize prompt
|
||||
{
|
||||
k_tokens.resize(1024);
|
||||
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
|
||||
if (n < 0) {
|
||||
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
|
||||
return 4;
|
||||
}
|
||||
k_tokens.resize(n);
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
|
||||
fprintf(stderr, "%s: tokens: [", __func__);
|
||||
for (const auto & token : k_tokens) {
|
||||
fprintf(stderr, " %d", token);
|
||||
}
|
||||
fprintf(stderr, " ]\n");
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: listening for a command ...\n", __func__);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
} else {
|
||||
ret_val = process_general_transcription(ctx, audio, params);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
||||
|
||||
k_prompt = "Ok Whisper, start listening for commands.";
|
||||
}
|
||||
|
||||
// main loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
{
|
||||
SDL_Event event;
|
||||
while (SDL_PollEvent(&event)) {
|
||||
switch (event.type) {
|
||||
case SDL_QUIT:
|
||||
{
|
||||
is_running = false;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_running) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
if (allowed_commands.empty()) {
|
||||
// general-purpose mode
|
||||
// freely transcribe the voice into text
|
||||
|
||||
if (ask_prompt) {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
ask_prompt = false;
|
||||
}
|
||||
|
||||
{
|
||||
int64_t t_ms = 0;
|
||||
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
if (!have_prompt) {
|
||||
// wait for activation phrase
|
||||
audio.get(params.prompt_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||
|
||||
const float sim = similarity(txt, k_prompt);
|
||||
|
||||
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
|
||||
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
|
||||
ask_prompt = true;
|
||||
} else {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
||||
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
// save the audio for the prompt
|
||||
pcmf32_prompt = pcmf32_cur;
|
||||
have_prompt = true;
|
||||
}
|
||||
} else {
|
||||
// we have heard the activation phrase, now detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
// prepend the prompt audio
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
|
||||
prob = 100.0f*(prob - prob0);
|
||||
|
||||
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
||||
|
||||
// find the prompt in the text
|
||||
float best_sim = 0.0f;
|
||||
size_t best_len = 0;
|
||||
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
const auto prompt = txt.substr(0, n);
|
||||
|
||||
const float sim = similarity(prompt, k_prompt);
|
||||
|
||||
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
|
||||
|
||||
if (sim > best_sim) {
|
||||
best_sim = sim;
|
||||
best_len = n;
|
||||
}
|
||||
}
|
||||
|
||||
const std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
audio.clear();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// command-list mode
|
||||
// guide the transcription to match the most likely command from a provided list
|
||||
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.print_progress = false;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = true;
|
||||
wparams.max_tokens = 1;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.prompt_tokens = k_tokens.data();
|
||||
wparams.prompt_n_tokens = k_tokens.size();
|
||||
|
||||
// run the transformer and a single decoding pass
|
||||
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
|
||||
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
|
||||
break;
|
||||
}
|
||||
|
||||
const auto * probs = whisper_get_probs(ctx);
|
||||
std::vector<std::pair<float, int>> probs_id;
|
||||
|
||||
double psum = 0.0;
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
|
||||
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
|
||||
probs_id.back().first += probs[allowed_tokens[i][j]];
|
||||
}
|
||||
probs_id.back().first /= allowed_tokens[i].size();
|
||||
psum += probs_id.back().first;
|
||||
}
|
||||
|
||||
// normalize
|
||||
for (auto & p : probs_id) {
|
||||
p.first /= psum;
|
||||
}
|
||||
|
||||
// sort descending
|
||||
{
|
||||
using pair_type = decltype(probs_id)::value_type;
|
||||
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
}
|
||||
|
||||
// print the commands and the respective probabilities
|
||||
{
|
||||
fprintf(stdout, "\n");
|
||||
for (const auto & cmd : probs_id) {
|
||||
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
|
||||
for (int i = 0; i < (int) allowed_tokens[cmd.second].size(); ++i) {
|
||||
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, allowed_tokens[cmd.second][i]), probs[allowed_tokens[cmd.second][i]]);
|
||||
}
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
}
|
||||
|
||||
// best command
|
||||
{
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
|
||||
"\033[1m", allowed_commands[probs_id[0].second].c_str(), "\033[0m", probs_id[0].first,
|
||||
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
audio.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
audio.pause();
|
||||
@ -901,5 +862,5 @@ int main(int argc, char ** argv) {
|
||||
whisper_print_timings(ctx);
|
||||
whisper_free(ctx);
|
||||
|
||||
return ret_val;
|
||||
return 0;
|
||||
}
|
||||
|
@ -1,6 +1,3 @@
|
||||
set(TARGET main)
|
||||
add_executable(${TARGET} main.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
@ -69,7 +69,6 @@ struct whisper_params {
|
||||
bool output_vtt = false;
|
||||
bool output_srt = false;
|
||||
bool output_wts = false;
|
||||
bool output_csv = false;
|
||||
bool print_special = false;
|
||||
bool print_colors = false;
|
||||
bool print_progress = false;
|
||||
@ -112,7 +111,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
|
||||
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
|
||||
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
|
||||
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||
@ -152,7 +150,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
|
||||
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
|
||||
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
|
||||
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
||||
@ -176,81 +173,90 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
|
||||
std::string speaker = "";
|
||||
|
||||
int64_t t0;
|
||||
int64_t t1;
|
||||
|
||||
// print the last n_new segments
|
||||
const int s0 = n_segments - n_new;
|
||||
|
||||
if (s0 == 0) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
for (int i = s0; i < n_segments; i++) {
|
||||
if (!params.no_timestamps || params.diarize) {
|
||||
t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
}
|
||||
|
||||
if (!params.no_timestamps) {
|
||||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
||||
}
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
const int64_t n_samples = pcmf32s[0].size();
|
||||
|
||||
const int64_t is0 = timestamp_to_sample(t0, n_samples);
|
||||
const int64_t is1 = timestamp_to_sample(t1, n_samples);
|
||||
|
||||
double energy0 = 0.0f;
|
||||
double energy1 = 0.0f;
|
||||
|
||||
for (int64_t j = is0; j < is1; j++) {
|
||||
energy0 += fabs(pcmf32s[0][j]);
|
||||
energy1 += fabs(pcmf32s[1][j]);
|
||||
}
|
||||
|
||||
if (energy0 > 1.1*energy1) {
|
||||
speaker = "(speaker 0)";
|
||||
} else if (energy1 > 1.1*energy0) {
|
||||
speaker = "(speaker 1)";
|
||||
} else {
|
||||
speaker = "(speaker ?)";
|
||||
}
|
||||
|
||||
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
|
||||
}
|
||||
|
||||
if (params.print_colors) {
|
||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||
if (params.print_special == false) {
|
||||
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
||||
if (id >= whisper_token_eot(ctx)) {
|
||||
continue;
|
||||
if (params.no_timestamps) {
|
||||
if (params.print_colors) {
|
||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||
if (params.print_special == false) {
|
||||
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
||||
if (id >= whisper_token_eot(ctx)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
|
||||
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||
|
||||
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
|
||||
}
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
printf("%s", text);
|
||||
}
|
||||
fflush(stdout);
|
||||
} else {
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
std::string speaker;
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
const int64_t n_samples = pcmf32s[0].size();
|
||||
|
||||
const int64_t is0 = timestamp_to_sample(t0, n_samples);
|
||||
const int64_t is1 = timestamp_to_sample(t1, n_samples);
|
||||
|
||||
double energy0 = 0.0f;
|
||||
double energy1 = 0.0f;
|
||||
|
||||
for (int64_t j = is0; j < is1; j++) {
|
||||
energy0 += fabs(pcmf32s[0][j]);
|
||||
energy1 += fabs(pcmf32s[1][j]);
|
||||
}
|
||||
|
||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
if (energy0 > 1.1*energy1) {
|
||||
speaker = "(speaker 0)";
|
||||
} else if (energy1 > 1.1*energy0) {
|
||||
speaker = "(speaker 1)";
|
||||
} else {
|
||||
speaker = "(speaker ?)";
|
||||
}
|
||||
|
||||
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||
|
||||
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
|
||||
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
|
||||
}
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
printf("%s%s", speaker.c_str(), text);
|
||||
if (params.print_colors) {
|
||||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||
if (params.print_special == false) {
|
||||
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
||||
if (id >= whisper_token_eot(ctx)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
|
||||
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||
|
||||
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
|
||||
}
|
||||
printf("\n");
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
|
||||
}
|
||||
}
|
||||
|
||||
// with timestamps or speakers: each segment on new line
|
||||
if (!params.no_timestamps || params.diarize) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
|
||||
@ -319,32 +325,6 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
|
||||
return true;
|
||||
}
|
||||
|
||||
bool output_csv(struct whisper_context * ctx, const char * fname) {
|
||||
std::ofstream fout(fname);
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
if (text[0] == ' ')
|
||||
text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
|
||||
fout << 10 * t0 << ", "
|
||||
<< 10 * t1 << ", \""
|
||||
<< text << "\"\n";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
// karaoke video generation
|
||||
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
|
||||
// TODO: font parameter adjustments
|
||||
@ -548,7 +528,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
|
||||
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
|
||||
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
|
||||
return 8;
|
||||
}
|
||||
|
||||
@ -694,13 +674,6 @@ int main(int argc, char ** argv) {
|
||||
const auto fname_wts = fname_inp + ".wts";
|
||||
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
|
||||
}
|
||||
|
||||
// output to CSV file
|
||||
if (params.output_csv) {
|
||||
const auto fname_csv = fname_inp + ".csv";
|
||||
output_csv(ctx, fname_csv.c_str());
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -8,8 +8,6 @@ add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
|
@ -2,9 +2,6 @@ if (WHISPER_SUPPORT_SDL2)
|
||||
# stream
|
||||
set(TARGET stream)
|
||||
add_executable(${TARGET} stream.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
||||
|
@ -8,7 +8,6 @@
|
||||
#include <SDL.h>
|
||||
#include <SDL_audio.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
@ -145,8 +144,8 @@ private:
|
||||
int m_len_ms = 0;
|
||||
int m_sample_rate = 0;
|
||||
|
||||
std::atomic_bool m_running;
|
||||
std::mutex m_mutex;
|
||||
bool m_running = false;
|
||||
std::mutex m_mutex;
|
||||
|
||||
std::vector<float> m_audio;
|
||||
std::vector<float> m_audio_new;
|
||||
@ -156,8 +155,6 @@ private:
|
||||
|
||||
audio_async::audio_async(int len_ms) {
|
||||
m_len_ms = len_ms;
|
||||
|
||||
m_running = false;
|
||||
}
|
||||
|
||||
audio_async::~audio_async() {
|
||||
@ -430,9 +427,9 @@ int main(int argc, char ** argv) {
|
||||
const int n_samples_keep = (params.keep_ms *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_30s = (30000 *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
|
||||
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
|
||||
const int n_new_line = params.length_ms / params.step_ms - 1; // number of steps to print new line
|
||||
|
||||
const int n_new_line = !use_vad ? params.length_ms / params.step_ms - 1 : 1; // number of steps to print new line
|
||||
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
|
||||
|
||||
params.no_timestamps = !use_vad;
|
||||
params.no_context = use_vad;
|
||||
|
@ -9,8 +9,6 @@ add_executable(${TARGET}
|
||||
gpt-2.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
@ -33,8 +31,8 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
|
||||
--bind \
|
||||
-s USE_PTHREADS=1 \
|
||||
-s PTHREAD_POOL_SIZE=8 \
|
||||
-s INITIAL_MEMORY=1800MB \
|
||||
-s TOTAL_MEMORY=1800MB \
|
||||
-s INITIAL_MEMORY=1600MB \
|
||||
-s TOTAL_MEMORY=1600MB \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
|
||||
${EXTRA_FLAGS} \
|
||||
|
@ -36,7 +36,7 @@ In order to run this demo efficiently, you need to have the following:
|
||||
- Latest Chrome or Firefox browser (Safari is not supported)
|
||||
- Run this on a desktop or laptop with modern CPU (a mobile phone will likely not be good enough)
|
||||
- Speak phrases that are no longer than 10 seconds - this is the audio context of the AI
|
||||
- The web-page uses about 1.8GB of RAM
|
||||
- The web-page uses about 1.6GB of RAM
|
||||
|
||||
Notice that this demo is using the smallest GPT-2 model, so the generated text responses are not always very good.
|
||||
Also, the prompting strategy can likely be improved to achieve better results.
|
||||
|
@ -8,9 +8,6 @@ if (WHISPER_SUPPORT_SDL2)
|
||||
# TODO: this is temporary
|
||||
# need to export ggml symbols for MSVC, but too lazy ..
|
||||
add_executable(${TARGET} talk.cpp gpt-2.cpp ../../ggml.c ../../whisper.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
|
||||
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif ()
|
||||
|
@ -8,8 +8,6 @@ add_executable(${TARGET}
|
||||
emscripten.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
whisper
|
||||
)
|
||||
|
14
ggml.h
14
ggml.h
@ -177,11 +177,13 @@ extern "C" {
|
||||
#include <stddef.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#define GGML_MAX_DIMS 4
|
||||
#define GGML_MAX_NODES 4096
|
||||
#define GGML_MAX_PARAMS 16
|
||||
#define GGML_MAX_CONTEXTS 64
|
||||
#define GGML_MAX_OPT 4
|
||||
#define GGML_MAX_DIMS 4
|
||||
#define GGML_MAX_NODES 4096
|
||||
#define GGML_MAX_PARAMS 16
|
||||
#define GGML_MAX_CONTEXTS 64
|
||||
#define GGML_MAX_OPT 4
|
||||
#define GGML_MAX_THREADS 64
|
||||
#define GGML_MAX_THREAD_POOLS 16
|
||||
|
||||
#ifdef __ARM_NEON
|
||||
// we use the built-in 16-bit float type
|
||||
@ -731,8 +733,6 @@ int ggml_cpu_has_f16c(void);
|
||||
int ggml_cpu_has_fp16_va(void);
|
||||
int ggml_cpu_has_wasm_simd(void);
|
||||
int ggml_cpu_has_blas(void);
|
||||
int ggml_cpu_has_sse3(void);
|
||||
int ggml_cpu_has_vsx(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
@ -56,7 +56,7 @@ def bytes_to_unicode():
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a significant percentage of your normal, say, 32K bpe vocab.
|
||||
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
|
@ -40,7 +40,7 @@ if exist "ggml-%model%.bin" (
|
||||
goto :eof
|
||||
)
|
||||
|
||||
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
|
||||
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/datasets/ggerganov/whisper.cpp/raw/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
|
||||
|
||||
if %ERRORLEVEL% neq 0 (
|
||||
echo Failed to download ggml model %model%
|
||||
|
181
whisper.cpp
181
whisper.cpp
@ -204,10 +204,6 @@ struct whisper_vocab {
|
||||
std::map<token, id> token_to_id;
|
||||
std::map<id, token> id_to_token;
|
||||
|
||||
// used to avoid memory allocations during sampling
|
||||
// TODO: move to whisper_context in the future
|
||||
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
|
||||
|
||||
id token_eot = 50256;
|
||||
id token_sot = 50257;
|
||||
id token_prev = 50360;
|
||||
@ -412,8 +408,6 @@ struct whisper_context {
|
||||
std::vector<uint8_t> buf_compute;
|
||||
std::vector<uint8_t> buf_compute_layer;
|
||||
|
||||
ggml_type wtype; // weight type (FP32 or FP16)
|
||||
|
||||
whisper_model model;
|
||||
whisper_vocab vocab;
|
||||
|
||||
@ -437,8 +431,9 @@ struct whisper_context {
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
static void read_safe(std::ifstream& fin, T& dest) {
|
||||
fin.read((char*)& dest, sizeof(T));
|
||||
static void read_safe(std::ifstream& fin, T& dest)
|
||||
{
|
||||
fin.read((char*)& dest, sizeof(T));
|
||||
}
|
||||
|
||||
// load the model from a ggml file
|
||||
@ -556,9 +551,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
||||
|
||||
std::string word;
|
||||
std::vector<char> tmp;
|
||||
|
||||
tmp.reserve(128);
|
||||
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
uint32_t len;
|
||||
read_safe(fin, len);
|
||||
@ -611,11 +603,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
||||
vocab.id_to_token[i] = word;
|
||||
}
|
||||
}
|
||||
|
||||
wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
|
||||
wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
|
||||
|
||||
vocab.probs_id.reserve(n_vocab);
|
||||
}
|
||||
|
||||
{
|
||||
@ -631,9 +618,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
||||
|
||||
// for the big tensors, we have the option to store the data in 16-bit floats
|
||||
// in order to save memory and also to speed up the computation
|
||||
wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
|
||||
const ggml_type wtype = wctx.wtype;
|
||||
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
|
||||
size_t ctx_size = 0;
|
||||
|
||||
@ -654,6 +639,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
||||
|
||||
// encoder
|
||||
{
|
||||
// TODO: F16 .. maybe not?
|
||||
ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
|
||||
|
||||
ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
|
||||
@ -668,6 +654,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
||||
|
||||
// decoder
|
||||
{
|
||||
// TODO: F16 .. maybe not?
|
||||
ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
|
||||
|
||||
ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
|
||||
@ -984,8 +971,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
||||
const int n_mem = n_text_layer*n_text_ctx;
|
||||
const int n_elements = n_text_state*n_mem;
|
||||
|
||||
model.memory_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
|
||||
model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
|
||||
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
||||
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
||||
}
|
||||
|
||||
// key/value memory for the cross-attention layer
|
||||
@ -995,8 +982,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
||||
const int n_mem = n_text_layer*n_audio_ctx;
|
||||
const int n_elements = n_text_state*n_mem;
|
||||
|
||||
model.memory_cross_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
|
||||
model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
|
||||
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
||||
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
||||
}
|
||||
|
||||
const size_t memory_size =
|
||||
@ -1034,7 +1021,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
||||
|
||||
std::string name;
|
||||
std::vector<char> tmp(length); // create a buffer
|
||||
fin.read(&tmp[0], tmp.size()); // read to buffer
|
||||
fin.read( &tmp[0], tmp.size() ); // read to buffer
|
||||
name.assign(&tmp[0], tmp.size());
|
||||
|
||||
if (model.tensors.find(name) == model.tensors.end()) {
|
||||
@ -1242,14 +1229,14 @@ static bool whisper_encode(
|
||||
ggml_permute(ctxL,
|
||||
ggml_cpy(ctxL,
|
||||
Qcur,
|
||||
ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
|
||||
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
|
||||
0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * K =
|
||||
ggml_permute(ctxL,
|
||||
ggml_cpy(ctxL,
|
||||
Kcur,
|
||||
ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
|
||||
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
|
||||
0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * V =
|
||||
@ -1259,7 +1246,7 @@ static bool whisper_encode(
|
||||
Vcur,
|
||||
n_state/n_head, n_head, n_ctx),
|
||||
1, 2, 0, 3),
|
||||
ggml_new_tensor_3d(ctxL, wctx.wtype, n_ctx, n_state/n_head, n_head)
|
||||
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
|
||||
);
|
||||
|
||||
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
|
||||
@ -1275,7 +1262,7 @@ static bool whisper_encode(
|
||||
ggml_permute(ctxL,
|
||||
ggml_cpy(ctxL,
|
||||
Kcur,
|
||||
ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
|
||||
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
|
||||
0, 2, 1, 3);
|
||||
|
||||
// K * Q
|
||||
@ -1293,7 +1280,7 @@ static bool whisper_encode(
|
||||
// ggml_permute(ctxL,
|
||||
// ggml_cpy(ctxL,
|
||||
// Vcur,
|
||||
// ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
|
||||
// ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
|
||||
// 1, 2, 0, 3);
|
||||
|
||||
//struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
|
||||
@ -1305,7 +1292,7 @@ static bool whisper_encode(
|
||||
Vcur,
|
||||
n_state/n_head, n_head, n_ctx),
|
||||
0, 2, 1, 3),
|
||||
ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_ctx, n_head)
|
||||
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
|
||||
);
|
||||
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
|
||||
@ -1350,7 +1337,7 @@ static bool whisper_encode(
|
||||
|
||||
#ifdef USE_FLASH_FF
|
||||
cur = ggml_flash_ff(ctxL,
|
||||
ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
|
||||
ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
|
||||
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
||||
#else
|
||||
// fully connected
|
||||
@ -1457,7 +1444,7 @@ static bool whisper_encode(
|
||||
layer.cross_attn_k_w,
|
||||
cur);
|
||||
|
||||
//Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
||||
Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
||||
|
||||
struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
|
||||
layer.cross_attn_v_w,
|
||||
@ -1579,14 +1566,14 @@ static bool whisper_decode(
|
||||
Qcur),
|
||||
Qcur);
|
||||
|
||||
//Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
||||
Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
||||
|
||||
// note: no bias for Key
|
||||
struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
|
||||
layer.attn_k_w,
|
||||
cur);
|
||||
|
||||
//Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
||||
Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
||||
|
||||
struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
|
||||
layer.attn_v_w,
|
||||
@ -1609,33 +1596,6 @@ static bool whisper_decode(
|
||||
|
||||
// ------
|
||||
|
||||
#ifdef USE_FLASH_ATTN
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctxL,
|
||||
ggml_cpy(ctxL,
|
||||
Qcur,
|
||||
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
|
||||
0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * K =
|
||||
ggml_permute(ctxL,
|
||||
ggml_reshape_3d(ctxL,
|
||||
ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
|
||||
n_state/n_head, n_head, n_past + N),
|
||||
0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * V =
|
||||
ggml_cpy(ctxL,
|
||||
ggml_permute(ctxL,
|
||||
ggml_reshape_3d(ctxL,
|
||||
ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
|
||||
n_state/n_head, n_head, n_past + N),
|
||||
1, 2, 0, 3),
|
||||
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_past + N, n_state/n_head, n_head)
|
||||
);
|
||||
|
||||
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, true);
|
||||
#else
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctxL,
|
||||
ggml_cpy(ctxL,
|
||||
@ -1653,13 +1613,13 @@ static bool whisper_decode(
|
||||
// K * Q
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
|
||||
|
||||
struct ggml_tensor * KQ_scaled =
|
||||
ggml_scale(ctxL,
|
||||
KQ,
|
||||
ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
|
||||
);
|
||||
//struct ggml_tensor * KQ_scaled =
|
||||
// ggml_scale(ctxL,
|
||||
// KQ,
|
||||
// ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
|
||||
// );
|
||||
|
||||
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
|
||||
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
|
||||
|
||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
|
||||
|
||||
@ -1671,7 +1631,6 @@ static bool whisper_decode(
|
||||
1, 2, 0, 3);
|
||||
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
|
||||
#endif
|
||||
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
|
||||
|
||||
@ -1717,7 +1676,7 @@ static bool whisper_decode(
|
||||
Qcur),
|
||||
Qcur);
|
||||
|
||||
//Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
||||
Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
|
||||
|
||||
// Kcross is already scaled
|
||||
struct ggml_tensor * Kcross =
|
||||
@ -1732,24 +1691,6 @@ static bool whisper_decode(
|
||||
|
||||
// ------
|
||||
|
||||
#ifdef USE_FLASH_ATTN
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctxL,
|
||||
ggml_cpy(ctxL,
|
||||
Qcur,
|
||||
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
|
||||
0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * V =
|
||||
ggml_cpy(ctxL,
|
||||
ggml_permute(ctxL, Vcross, 1, 2, 0, 3),
|
||||
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, M, n_state/n_head, n_head)
|
||||
);
|
||||
|
||||
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
|
||||
#else
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctxL,
|
||||
ggml_cpy(ctxL,
|
||||
@ -1776,7 +1717,6 @@ static bool whisper_decode(
|
||||
struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
|
||||
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
|
||||
#endif
|
||||
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
|
||||
|
||||
@ -1909,7 +1849,7 @@ static bool whisper_decode(
|
||||
|
||||
// the most basic sampling scheme - select the top token
|
||||
static whisper_token_data whisper_sample_best(
|
||||
whisper_vocab & vocab,
|
||||
const whisper_vocab & vocab,
|
||||
const float * probs,
|
||||
bool force_timestamp,
|
||||
bool is_initial) {
|
||||
@ -1917,11 +1857,11 @@ static whisper_token_data whisper_sample_best(
|
||||
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
|
||||
};
|
||||
|
||||
const int n_logits = vocab.n_vocab;
|
||||
int n_logits = vocab.id_to_token.size();
|
||||
|
||||
auto & probs_id = vocab.probs_id;
|
||||
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
|
||||
probs_id.reserve(n_logits);
|
||||
|
||||
probs_id.clear();
|
||||
for (int i = 0; i < n_logits; i++) {
|
||||
probs_id.emplace_back(probs[i], i);
|
||||
}
|
||||
@ -2061,9 +2001,6 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
||||
std::vector<float> even;
|
||||
std::vector<float> odd;
|
||||
|
||||
even.reserve(N/2);
|
||||
odd.reserve(N/2);
|
||||
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (i % 2 == 0) {
|
||||
even.push_back(in[i]);
|
||||
@ -2497,7 +2434,7 @@ int whisper_lang_auto_detect(
|
||||
std::vector<std::pair<float, int>> probs_id;
|
||||
for (const auto & kv : g_lang) {
|
||||
const auto token_lang = whisper_token_lang(ctx, kv.second.first);
|
||||
probs_id.emplace_back(ctx->probs[token_lang], kv.second.first);
|
||||
probs_id.emplace_back( ctx->probs[token_lang], kv.second.first );
|
||||
}
|
||||
|
||||
// sort descending
|
||||
@ -2521,12 +2458,12 @@ int whisper_lang_auto_detect(
|
||||
}
|
||||
|
||||
{
|
||||
for (const auto & prob : probs_id) {
|
||||
for (int i = 0; i < (int) probs_id.size(); i++) {
|
||||
if (lang_probs) {
|
||||
lang_probs[prob.second] = prob.first;
|
||||
lang_probs[probs_id[i].second] = probs_id[i].first;
|
||||
}
|
||||
|
||||
//printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first);
|
||||
//printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
|
||||
}
|
||||
}
|
||||
|
||||
@ -2545,10 +2482,6 @@ int whisper_n_text_ctx(struct whisper_context * ctx) {
|
||||
return ctx->model.hparams.n_text_ctx;
|
||||
}
|
||||
|
||||
int whisper_n_audio_ctx(struct whisper_context * ctx) {
|
||||
return ctx->model.hparams.n_audio_ctx;
|
||||
}
|
||||
|
||||
int whisper_is_multilingual(struct whisper_context * ctx) {
|
||||
return ctx->vocab.is_multilingual() ? 1 : 0;
|
||||
}
|
||||
@ -2629,8 +2562,6 @@ const char * whisper_print_system_info(void) {
|
||||
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
|
||||
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
|
||||
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
|
||||
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
|
||||
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
|
||||
|
||||
return s.c_str();
|
||||
}
|
||||
@ -2876,11 +2807,7 @@ int whisper_full(
|
||||
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
|
||||
}
|
||||
|
||||
// overwrite audio_ctx, max allowed is hparams.n_audio_ctx
|
||||
if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
|
||||
fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
|
||||
return -4;
|
||||
}
|
||||
// overwrite audio_ctx
|
||||
ctx->exp_n_audio_ctx = params.audio_ctx;
|
||||
|
||||
// these tokens determine the task that will be performed
|
||||
@ -3207,7 +3134,7 @@ int whisper_full_parallel(
|
||||
|
||||
// separate key + value memory for each processor
|
||||
{
|
||||
auto & mctx = model.ctx_mem;
|
||||
auto & ctx = model.ctx_mem;
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
@ -3220,8 +3147,8 @@ int whisper_full_parallel(
|
||||
const int n_mem = n_text_layer*n_text_ctx;
|
||||
const int n_elements = n_text_state*n_mem;
|
||||
|
||||
model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
|
||||
model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
|
||||
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
||||
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
||||
}
|
||||
|
||||
// key/value memory for the cross-attention layer
|
||||
@ -3231,8 +3158,8 @@ int whisper_full_parallel(
|
||||
const int n_mem = n_text_layer*n_audio_ctx;
|
||||
const int n_elements = n_text_state*n_mem;
|
||||
|
||||
model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
|
||||
model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
|
||||
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
||||
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3276,17 +3203,17 @@ int whisper_full_parallel(
|
||||
for (int i = 0; i < n_processors - 1; ++i) {
|
||||
auto & results_i = ctxs[i].result_all;
|
||||
|
||||
for (auto & result : results_i) {
|
||||
for (int j = 0; j < (int) results_i.size(); ++j) {
|
||||
// correct the segment timestamp taking into account the offset
|
||||
result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
|
||||
result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
|
||||
results_i[j].t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
|
||||
results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
|
||||
|
||||
// make sure that segments are not overlapping
|
||||
if (!ctx->result_all.empty()) {
|
||||
result.t0 = std::max(result.t0, ctx->result_all.back().t1);
|
||||
results_i[j].t0 = std::max(results_i[j].t0, ctx->result_all.back().t1);
|
||||
}
|
||||
|
||||
ctx->result_all.push_back(std::move(result));
|
||||
ctx->result_all.push_back(std::move(results_i[j]));
|
||||
|
||||
// call the new_segment_callback for each segment
|
||||
if (params.new_segment_callback) {
|
||||
@ -3381,18 +3308,18 @@ static int64_t sample_to_timestamp(int i_sample) {
|
||||
static float voice_length(const std::string & text) {
|
||||
float res = 0.0f;
|
||||
|
||||
for (char c : text) {
|
||||
if (c == ' ') {
|
||||
for (size_t i = 0; i < text.size(); ++i) {
|
||||
if (text[i] == ' ') {
|
||||
res += 0.01f;
|
||||
} else if (c == ',') {
|
||||
} else if (text[i] == ',') {
|
||||
res += 2.00f;
|
||||
} else if (c == '.') {
|
||||
} else if (text[i] == '.') {
|
||||
res += 3.00f;
|
||||
} else if (c == '!') {
|
||||
} else if (text[i] == '!') {
|
||||
res += 3.00f;
|
||||
} else if (c == '?') {
|
||||
} else if (text[i] == '?') {
|
||||
res += 3.00f;
|
||||
} else if (c >= '0' && c <= '9') {
|
||||
} else if (text[i] >= '0' && text[i] <= '9') {
|
||||
res += 3.00f;
|
||||
} else {
|
||||
res += 1.00f;
|
||||
|
@ -148,7 +148,7 @@ extern "C" {
|
||||
struct whisper_context * ctx,
|
||||
const char * text,
|
||||
whisper_token * tokens,
|
||||
int n_max_tokens);
|
||||
int n_max_tokens);
|
||||
|
||||
// Largest language id (i.e. number of available languages - 1)
|
||||
WHISPER_API int whisper_lang_max_id();
|
||||
@ -177,7 +177,6 @@ extern "C" {
|
||||
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
|
||||
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
|
||||
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
|
||||
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
|
||||
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
|
||||
|
||||
// The probabilities for the next token
|
||||
|
Reference in New Issue
Block a user