Compare commits


4 Commits

70 changed files with 2383 additions and 10135 deletions


@@ -265,44 +265,3 @@ jobs:
         popd
         emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
         make
-  ios:
-    runs-on: macos-latest
-    strategy:
-      matrix:
-        build: [Release]
-    steps:
-      - name: Clone
-        uses: actions/checkout@v1
-      - name: Configure
-        run: cp models/for-tests-ggml-base.en.bin models/ggml-base.en.bin
-      - name: Build objc example
-        run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphonesimulator build
-      - name: Build swiftui example
-        run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphonesimulator build
-  android:
-    runs-on: ubuntu-latest
-    steps:
-      - name: Clone
-        uses: actions/checkout@v1
-      - name: Install Java
-        uses: actions/setup-java@v3
-        with:
-          distribution: zulu
-          java-version: 17
-      - name: Setup Android SDK
-        uses: android-actions/setup-android@v2
-      - name: Build
-        run: |
-          cd examples/whisper.android
-          ./gradlew assembleRelease --no-daemon

.gitignore (vendored)

@@ -1,8 +1,6 @@
 *.o
 *.a
 .cache/
-.coreml/
-.test/
 .vs/
 .vscode/
 .DS_Store
@@ -12,7 +10,6 @@ build-em/
 build-debug/
 build-release/
 build-static/
-build-no-accel/
 build-sanitize-addr/
 build-sanitize-thread/
@@ -20,7 +17,6 @@ build-sanitize-thread/
 /stream
 /command
 /talk
-/talk-llama
 /bench
 arm_neon.h
@@ -35,7 +31,3 @@ examples/whisper.objc/whisper.objc.xcodeproj/xcuserdata/
 examples/whisper.objc/whisper.objc.xcodeproj/project.xcworkspace/xcuserdata
 extra/bench-gg.txt
-models/*.mlmodel
-models/*.mlmodelc
-models/*.mlpackage


@@ -1,10 +1,6 @@
 cmake_minimum_required (VERSION 3.0)
-project(whisper.cpp VERSION 1.3.0)
+project(whisper.cpp VERSION 1.2.0)
-if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC")
-    add_compile_options(/utf-8)
-endif ()
 # Add path to modules
 list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
@@ -58,8 +54,6 @@ if (APPLE)
     option(WHISPER_NO_AVX              "whisper: disable AVX" OFF)
     option(WHISPER_NO_AVX2             "whisper: disable AVX2" OFF)
     option(WHISPER_NO_FMA              "whisper: disable FMA" OFF)
-    option(WHISPER_COREML              "whisper: enable Core ML framework" OFF)
 else()
     option(WHISPER_SUPPORT_OPENBLAS    "whisper: support for OpenBLAS" OFF)
 endif()
@@ -92,33 +86,16 @@ endif()
 find_package(Threads REQUIRED)
-# on APPLE
-if (APPLE)
-    # include Accelerate framework
-    if (NOT WHISPER_NO_ACCELERATE)
-        find_library(ACCELERATE_FRAMEWORK Accelerate)
-        if (ACCELERATE_FRAMEWORK)
-            message(STATUS "Accelerate framework found")
-            set(WHISPER_EXTRA_LIBS  ${WHISPER_EXTRA_LIBS}  ${ACCELERATE_FRAMEWORK})
-            set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
-        else()
-            message(WARNING "Accelerate framework not found")
-        endif()
-    endif()
-    if (WHISPER_COREML)
-        find_library(FOUNDATION_FRAMEWORK Foundation)
-        find_library(COREML_FRAMEWORK CoreML)
-        if (COREML_FRAMEWORK)
-            message(STATUS "CoreML framework found")
-            set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_COREML)
-        else()
-            message(WARNING "CoreML framework not found")
-        endif()
-    endif()
-endif()
+# on APPLE - include Accelerate framework
+if (APPLE AND NOT WHISPER_NO_ACCELERATE)
+    find_library(ACCELERATE_FRAMEWORK Accelerate)
+    if (ACCELERATE_FRAMEWORK)
+        message(STATUS "Accelerate framework found")
+        set(WHISPER_EXTRA_LIBS  ${WHISPER_EXTRA_LIBS}  ${ACCELERATE_FRAMEWORK})
+        set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
+    else()
+        message(WARNING "Accelerate framework not found")
+    endif()
+endif()
@@ -195,9 +172,7 @@ else()
         if(NOT WHISPER_NO_FMA)
             set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
         endif()
-        if(NOT WHISPER_NO_F16C)
-            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
-        endif()
+        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
     endif()
   endif()
 endif()
@@ -206,33 +181,6 @@ if (WHISPER_PERF)
     set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_PERF)
 endif()
-#
-# whisper.coreml - Core ML support
-#
-if (WHISPER_COREML)
-    set(TARGET whisper.coreml)
-    add_library(${TARGET}
-        coreml/whisper-encoder.h
-        coreml/whisper-encoder.mm
-        coreml/whisper-encoder-impl.h
-        coreml/whisper-encoder-impl.m
-        )
-    include(DefaultTargetOptions)
-    target_include_directories(${TARGET} PUBLIC
-        .
-        )
-    target_link_libraries(${TARGET} PRIVATE ${FOUNDATION_FRAMEWORK} ${COREML_FRAMEWORK})
-    set_target_properties(${TARGET} PROPERTIES
-        COMPILE_FLAGS "-fobjc-arc"
-        )
-endif()
 #
 # whisper - this is the main library of the project
 #
@@ -252,10 +200,6 @@ target_include_directories(${TARGET} PUBLIC
     .
     )
-if (WHISPER_COREML)
-    target_link_libraries(${TARGET} PRIVATE whisper.coreml)
-endif()
 if (MSVC)
     target_link_libraries(${TARGET} PRIVATE ${WHISPER_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})


@@ -1,6 +1,6 @@
 MIT License
-Copyright (c) 2023 Georgi Gerganov
+Copyright (c) 2022 Georgi Gerganov
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal


@@ -30,16 +30,10 @@ endif
 # Compile flags
 #
-CFLAGS   = -I.              -O3 -DNDEBUG -std=c11   -fPIC
-CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
+CFLAGS   = -I.              -O3 -std=c11   -fPIC
+CXXFLAGS = -I. -I./examples -O3 -std=c++11 -fPIC
 LDFLAGS  =
-# ref: https://github.com/ggerganov/whisper.cpp/issues/37
-ifneq ($(wildcard /usr/include/musl/*),)
-    CFLAGS   += -D_POSIX_SOURCE -D_GNU_SOURCE
-    CXXFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
-endif
 # OS specific
 # TODO: support Windows
 ifeq ($(UNAME_S),Linux)
@@ -77,6 +71,10 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
         CFLAGS += -mavx2
     endif
 else ifeq ($(UNAME_S),Linux)
+    AVX1_M := $(shell grep "avx " /proc/cpuinfo)
+    ifneq (,$(findstring avx,$(AVX1_M)))
+        CFLAGS += -mavx
+    endif
     AVX2_M := $(shell grep "avx2 " /proc/cpuinfo)
     ifneq (,$(findstring avx2,$(AVX2_M)))
         CFLAGS += -mavx2
@@ -88,17 +86,16 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
     F16C_M := $(shell grep "f16c " /proc/cpuinfo)
     ifneq (,$(findstring f16c,$(F16C_M)))
         CFLAGS += -mf16c
-        AVX1_M := $(shell grep "avx " /proc/cpuinfo)
-        ifneq (,$(findstring avx,$(AVX1_M)))
-            CFLAGS += -mavx
-        endif
     endif
     SSE3_M := $(shell grep "sse3 " /proc/cpuinfo)
     ifneq (,$(findstring sse3,$(SSE3_M)))
        CFLAGS += -msse3
     endif
 else ifeq ($(UNAME_S),Haiku)
+    AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
+    ifneq (,$(findstring avx,$(AVX1_M)))
+        CFLAGS += -mavx
+    endif
     AVX2_M := $(shell sysinfo -cpu | grep "AVX2 ")
     ifneq (,$(findstring avx2,$(AVX2_M)))
         CFLAGS += -mavx2
@@ -110,11 +107,6 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
     F16C_M := $(shell sysinfo -cpu | grep "F16C ")
     ifneq (,$(findstring f16c,$(F16C_M)))
         CFLAGS += -mf16c
-        AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
-        ifneq (,$(findstring avx,$(AVX1_M)))
-            CFLAGS += -mavx
-        endif
     endif
 else
     CFLAGS += -mfma -mf16c -mavx -mavx2
@@ -140,10 +132,6 @@ ifndef WHISPER_NO_ACCELERATE
         LDFLAGS += -framework Accelerate
     endif
 endif
-ifdef WHISPER_COREML
-    CXXFLAGS += -DWHISPER_USE_COREML
-    LDFLAGS  += -framework Foundation -framework CoreML
-endif
 ifdef WHISPER_OPENBLAS
     CFLAGS  += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
     LDFLAGS += -lopenblas
@@ -153,19 +141,14 @@ ifdef WHISPER_GPROF
     CXXFLAGS += -pg
 endif
 ifneq ($(filter aarch64%,$(UNAME_M)),)
-    CFLAGS   += -mcpu=native
-    CXXFLAGS += -mcpu=native
 endif
 ifneq ($(filter armv6%,$(UNAME_M)),)
-    # 32-bit Raspberry Pi 1, 2, 3
-    CFLAGS += -mfpu=neon -mfp16-format=ieee -mno-unaligned-access
+    # Raspberry Pi 1, 2, 3
+    CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
 endif
 ifneq ($(filter armv7%,$(UNAME_M)),)
-    # 32-bit ARM, for example on Armbian or possibly raspbian
-    CFLAGS += -mfpu=neon -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
-    # 64-bit ARM, use these (TODO: auto-detect 64-bit)
-    # CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
+    # Raspberry Pi 4
+    CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
 endif
 ifneq ($(filter armv8%,$(UNAME_M)),)
     # Raspberry Pi 4
@@ -187,7 +170,7 @@ $(info I CC: $(CCV))
 $(info I CXX: $(CXXV))
 $(info )
-default: main bench
+default: main
 #
 # Build library
@@ -196,29 +179,17 @@ default: main bench
 ggml.o: ggml.c ggml.h
     $(CC) $(CFLAGS) -c ggml.c -o ggml.o
-whisper.o: whisper.cpp whisper.h ggml.h
+whisper.o: whisper.cpp whisper.h
     $(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o
-ifndef WHISPER_COREML
-WHISPER_OBJ = whisper.o
-else
-whisper-encoder.o: coreml/whisper-encoder.mm coreml/whisper-encoder.h
-    $(CXX) -O3 -I . -c coreml/whisper-encoder.mm -o whisper-encoder.o
-whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-impl.h
-    $(CXX) -O3 -I . -fobjc-arc -c coreml/whisper-encoder-impl.m -o whisper-encoder-impl.o
-WHISPER_OBJ = whisper.o whisper-encoder.o whisper-encoder-impl.o
-endif
-libwhisper.a: ggml.o $(WHISPER_OBJ)
-    $(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
+libwhisper.a: ggml.o whisper.o
+    $(AR) rcs libwhisper.a ggml.o whisper.o
-libwhisper.so: ggml.o $(WHISPER_OBJ)
-    $(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
+libwhisper.so: ggml.o whisper.o
+    $(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o whisper.o $(LDFLAGS)
 clean:
-    rm -f *.o main stream command talk talk-llama bench libwhisper.a libwhisper.so
+    rm -f *.o main stream command talk bench libwhisper.a libwhisper.so
 #
 # Examples
@@ -229,24 +200,21 @@ CC_SDL=`sdl2-config --cflags --libs`
 SRC_COMMON = examples/common.cpp
 SRC_COMMON_SDL = examples/common-sdl.cpp
-main: examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ)
-    $(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o main $(LDFLAGS)
+main: examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o
+    $(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o -o main $(LDFLAGS)
     ./main -h
-bench: examples/bench/bench.cpp ggml.o $(WHISPER_OBJ)
-    $(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) -o bench $(LDFLAGS)
-stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-    $(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
-command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-    $(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
-talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-    $(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
-talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-    $(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
+stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
+    $(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
+command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
+    $(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS)
+talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
+    $(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
+bench: examples/bench/bench.cpp ggml.o whisper.o
+    $(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
 #
 # Audio samples


@@ -4,12 +4,12 @@
 [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
 [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
-Beta: [v1.3.0](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.3.0) / Stable: [v1.2.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
+Stable: [v1.2.0](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
 High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
 - Plain C/C++ implementation without dependencies
-- Apple silicon first-class citizen - optimized via ARM NEON, Accelerate framework and [Core ML](https://github.com/ggerganov/whisper.cpp#core-ml-support)
+- Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework
 - AVX intrinsics support for x86 architectures
 - VSX intrinsics support for POWER architectures
 - Mixed F16 / F32 precision
@@ -58,9 +58,7 @@ the Accelerate framework utilizes the special-purpose AMX coprocessor available
 ## Quick start
-First clone the repository.
-Then, download one of the Whisper models converted in [ggml format](models). For example:
+First, download one of the Whisper models converted in [ggml format](models). For example:
 ```bash
 bash ./models/download-ggml-model.sh base.en
@@ -225,60 +223,6 @@ make large
 | medium | 1.5 GB | ~1.7 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
 | large  | 2.9 GB | ~3.3 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
-## Core ML support
-On Apple Silicon devices, the Encoder inference can be executed on the Apple Neural Engine (ANE) via Core ML. This can result in significant
-speed-up - more than x3 faster compared with CPU-only execution. Here are the instructions for generating a Core ML model and using it with `whisper.cpp`:
-- Install Python dependencies needed for the creation of the Core ML model:
-  ```bash
-  pip install ane_transformers
-  pip install openai-whisper
-  pip install coremltools
-  ```
-- Generate a Core ML model. For example, to generate a `base.en` model, use:
-  ```bash
-  ./models/generate-coreml-model.sh base.en
-  ```
-  This will generate the folder `models/ggml-base.en-encoder.mlmodelc`
-- Build `whisper.cpp` with Core ML support:
-  ```bash
-  # using Makefile
-  make clean
-  WHISPER_COREML=1 make -j
-  # using CMake
-  cd build
-  cmake -DWHISPER_COREML=1 ..
-  ```
-- Run the examples as usual. For example:
-  ```bash
-  ./main -m models/ggml-base.en.bin -f samples/jfk.wav
-  ...
-  whisper_init_state: loading Core ML model from 'models/ggml-base.en-encoder.mlmodelc'
-  whisper_init_state: first run on a device may take a while ...
-  whisper_init_state: Core ML model loaded
-  system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | COREML = 1 |
-  ...
-  ```
-  The first run on a device is slow, since the ANE service compiles the Core ML model to some device-specific format.
-  Next runs are faster.
-For more information about the Core ML implementation please refer to PR [#566](https://github.com/ggerganov/whisper.cpp/pull/566).
 ## Limitations
 - Inference only
@@ -369,7 +313,7 @@ whisper_print_timings: total time = 32733.52 ms
 ## Real-time audio input example
 This is a naive example of performing real-time inference on audio from your microphone.
-The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continuously.
+The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continously.
 More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
 ```java
@@ -384,10 +328,6 @@ https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a
 Adding the `--print-colors` argument will print the transcribed text using an experimental color coding strategy
 to highlight words with high or low confidence:
-```java
-./main -m models/ggml-base.en.bin -f samples/gb0.wav --print-colors
-```
 <img width="965" alt="image" src="https://user-images.githubusercontent.com/1991296/197356445-311c8643-9397-4e5e-b46e-0b4b4daa2530.png">
 ## Controlling the length of the generated text segments (experimental)
@@ -493,19 +433,6 @@ https://user-images.githubusercontent.com/1991296/199337538-b7b0c7a3-2753-4a88-a
 ---
-## Video comparison of different models
-Use the [extra/bench-wts.sh](https://github.com/ggerganov/whisper.cpp/blob/master/extra/bench-wts.sh) script to generate a video in the following format:
-```java
-./extra/bench-wts.sh samples/jfk.wav
-ffplay ./samples/jfk.wav.all.mp4
-```
-https://user-images.githubusercontent.com/1991296/223206245-2d36d903-cf8e-4f09-8c3b-eb9f9c39d6fc.mp4
----
 ## Benchmarks
 In order to have an objective comparison of the performance of the inference across different system configurations,
@@ -526,7 +453,7 @@ The original models are converted to a custom binary format. This allows to pack
 You can download the converted models using the [models/download-ggml-model.sh](models/download-ggml-model.sh) script
 or manually from here:
-- https://huggingface.co/ggerganov/whisper.cpp
+- https://huggingface.co/datasets/ggerganov/whisper.cpp
 - https://ggml.ggerganov.com
 For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or the README
@@ -536,19 +463,13 @@ in [models](models).
 - [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
 - [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
-  - React Native (iOS / Android): [whisper.rn](https://github.com/mybigday/whisper.rn)
 - [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
 - [X] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggerganov/whisper.cpp/discussions/507)
 - [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
-  - [exPHAT/SwiftWhisper](https://github.com/exPHAT/SwiftWhisper)
 - [X] .NET: | [#422](https://github.com/ggerganov/whisper.cpp/discussions/422)
   - [sandrohanea/whisper.net](https://github.com/sandrohanea/whisper.net)
   - [NickDarvey/whisper](https://github.com/NickDarvey/whisper)
-- [X] Python: | [#9](https://github.com/ggerganov/whisper.cpp/issues/9)
-  - [stlukey/whispercpp.py](https://github.com/stlukey/whispercpp.py) (Cython)
-  - [aarnphm/whispercpp](https://github.com/aarnphm/whispercpp) (Pybind11)
-- [X] R: [bnosac/audio.whisper](https://github.com/bnosac/audio.whisper)
-- [X] Unity: [macoron/whisper.unity](https://github.com/Macoron/whisper.unity)
+- [ ] Python: soon | [WIP](https://github.com/ggerganov/whisper.cpp/issues/9)
 ## Examples
@@ -562,7 +483,6 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
 | [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
 | [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
 | [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
-| [talk-llama](examples/talk-llama) | | Talk with a LLaMA bot |
 | [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
 | [whisper.swiftui](examples/whisper.swiftui) | | SwiftUI iOS / macOS application using whisper.cpp |
 | [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |


@@ -17,9 +17,9 @@ import (
 // CONSTANTS
 const (
-    srcUrl  = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main" // The location of the models
+    srcUrl  = "https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main" // The location of the models
     srcExt  = ".bin"      // Filename extension
     bufSize = 1024 * 64   // Size of the buffer used for downloading the model
 )
 var (


@@ -105,10 +105,6 @@ func (p *Params) SetMaxSegmentLength(n int) {
     p.max_len = C.int(n)
 }
-func (p *Params) SetTokenTimestamps(b bool) {
-    p.token_timestamps = toBool(b)
-}
 // Set max tokens per segment (0 = no limit)
 func (p *Params) SetMaxTokensPerSegment(n int) {
     p.max_tokens = C.int(n)


@@ -111,11 +111,6 @@ func (context *context) SetMaxSegmentLength(n uint) {
     context.params.SetMaxSegmentLength(int(n))
 }
-// Set token timestamps flag
-func (context *context) SetTokenTimestamps(b bool) {
-    context.params.SetTokenTimestamps(b)
-}
 // Set max tokens per segment (0 = no limit)
 func (context *context) SetMaxTokensPerSegment(n uint) {
     context.params.SetMaxTokensPerSegment(int(n))
@@ -285,14 +280,10 @@ func toSegment(ctx *whisper.Context, n int) Segment {
 func toTokens(ctx *whisper.Context, n int) []Token {
     result := make([]Token, ctx.Whisper_full_n_tokens(n))
     for i := 0; i < len(result); i++ {
-        data := ctx.Whisper_full_get_token_data(n, i)
         result[i] = Token{
             Id:   int(ctx.Whisper_full_get_token_id(n, i)),
-            Text: ctx.Whisper_full_get_token_text(n, i),
+            Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)),
             P:    ctx.Whisper_full_get_token_p(n, i),
-            Start: time.Duration(data.T0()) * time.Millisecond * 10,
-            End:   time.Duration(data.T1()) * time.Millisecond * 10,
         }
     }
     return result
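A note on the removed Start/End lines above: the token t0/t1 values are expressed in 10 ms ticks, which is why they are scaled by `time.Millisecond * 10`. Below is a minimal, hypothetical Go helper showing the same conversion in isolation; the name `tickToDuration` is ours and not part of the bindings.

```go
package main

import (
	"fmt"
	"time"
)

// tickToDuration converts a whisper token timestamp, given in 10 ms ticks
// (the unit used by t0/t1 in the token data), to a time.Duration.
func tickToDuration(ticks int64) time.Duration {
	return time.Duration(ticks) * 10 * time.Millisecond
}

func main() {
	// Example: a token reported at t0 = 150 ticks starts 1.5 s into the audio.
	fmt.Println(tickToDuration(150)) // prints "1.5s"
}
```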


@@ -41,7 +41,6 @@ type Context interface {
     SetTokenThreshold(float32)    // Set timestamp token probability threshold
     SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
     SetMaxSegmentLength(uint)     // Set max segment length in characters
-    SetTokenTimestamps(bool)      // Set token timestamps flag
     SetMaxTokensPerSegment(uint)  // Set max tokens per segment (0 = no limit)
     // Process mono audio data and return any errors.
@@ -86,8 +85,7 @@ type Segment struct {
 // Token is a text or special token
 type Token struct {
     Id   int
     Text string
     P    float32
-    Start, End time.Duration
 }
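To see how the interface removed above is used from application code, here is a small sketch against the left-hand (newer) bindings API. It assumes a `whisper.Context` has already been created from a loaded model; the helper names are ours, not part of the repository.

```go
package example

import (
	"fmt"

	whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)

// enableTokenTimestamps flips the flag whose setter is removed on the
// right-hand side above; it must be called before processing so that
// per-token t0/t1 data is produced.
func enableTokenTimestamps(ctx whisper.Context) {
	ctx.SetTokenTimestamps(true)
}

// printTokenTimings walks the decoded segments after processing and prints
// the Start/End fields that toTokens fills in from the token data.
func printTokenTimings(ctx whisper.Context) {
	for {
		segment, err := ctx.NextSegment()
		if err != nil {
			break // io.EOF once all segments have been consumed
		}
		for _, tok := range segment.Tokens {
			fmt.Printf("[%8s -> %8s] %q (p=%.2f)\n", tok.Start, tok.End, tok.Text, tok.P)
		}
	}
}
```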


@@ -94,7 +94,6 @@ func (model *model) NewContext() (Context, error) {
     params.SetPrintRealtime(false)
     params.SetPrintTimestamps(false)
     params.SetThreads(runtime.NumCPU())
-    params.SetNoContext(true)
     // Return new context
     return newContext(model, params)


@@ -20,7 +20,7 @@ extern bool callEncoderBegin(void* user_data);
 // Text segment callback
 // Called on every newly generated text segment
 // Use the whisper_full_...() functions to obtain the text segments
-static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_state* state, int n_new, void* user_data) {
+static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
     if(user_data != NULL && ctx != NULL) {
         callNewSegment(user_data, n_new);
     }
@@ -29,7 +29,7 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_s
 // Encoder begin callback
 // If not NULL, called before the encoder starts
 // If it returns false, the computation is aborted
-static bool whisper_encoder_begin_cb(struct whisper_context* ctx, struct whisper_state* state, void* user_data) {
+static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
     if(user_data != NULL && ctx != NULL) {
         return callEncoderBegin(user_data);
     }
@@ -356,7 +356,7 @@ func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
 // Get token data for the specified token in the specified segment.
 // This contains probabilities, timestamps, etc.
-func (ctx *Context) Whisper_full_get_token_data(segment int, token int) TokenData {
+func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData {
     return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
 }
@@ -407,11 +407,3 @@ func callEncoderBegin(user_data unsafe.Pointer) C.bool {
     }
     return true
 }
-func (t TokenData) T0() int64 {
-    return int64(t.t0)
-}
-func (t TokenData) T1() int64 {
-    return int64(t.t1)
-}
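For context, the two C callbacks whose signatures change above are the cgo trampolines behind the Go-level segment callback. A rough sketch of the Go side is below; `SegmentCallback` is the type from `pkg/whisper`, while the exact `Process` signature has shifted between releases, so treat this as illustrative only.

```go
package example

import (
	"fmt"

	whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)

// onSegment is invoked (via the whisper_new_segment_cb trampoline above)
// once per newly decoded segment while processing is running.
var onSegment whisper.SegmentCallback = func(segment whisper.Segment) {
	fmt.Printf("[%6s -> %6s] %s\n", segment.Start, segment.End, segment.Text)
}
```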


@@ -1,6 +1,6 @@
 {
   "name": "whisper.cpp",
-  "version": "1.3.0",
+  "version": "1.2.0",
   "description": "Whisper speech recognition",
   "main": "whisper.js",
   "scripts": {

File diff suppressed because one or more lines are too long


@@ -199,7 +199,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
   {
     static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
-    rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
+    rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
       bool is_aborted = *(bool*)user_data;
       return !is_aborted;
     };


@@ -1,146 +0,0 @@
//
// whisper-decoder-impl.h
//
// This file was automatically generated and should not be edited.
//
#import <Foundation/Foundation.h>
#import <CoreML/CoreML.h>
#include <stdint.h>
#include <os/log.h>
NS_ASSUME_NONNULL_BEGIN
/// Model Prediction Input Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_implInput : NSObject<MLFeatureProvider>
/// token_data as 1 by 1 matrix of 32-bit integers
@property (readwrite, nonatomic, strong) MLMultiArray * token_data;
/// audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * audio_data;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data NS_DESIGNATED_INITIALIZER;
@end
/// Model Prediction Output Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_implOutput : NSObject<MLFeatureProvider>
/// var_1346 as multidimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * var_1346;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 NS_DESIGNATED_INITIALIZER;
@end
/// Class for model loading and prediction
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_impl : NSObject
@property (readonly, nonatomic, nullable) MLModel * model;
/**
URL of the underlying .mlmodelc directory.
*/
+ (nullable NSURL *)URLOfModelInThisBundle;
/**
Initialize whisper_decoder_impl instance from an existing MLModel object.
Usually the application does not use this initializer unless it makes a subclass of whisper_decoder_impl.
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
*/
- (instancetype)initWithMLModel:(MLModel *)model NS_DESIGNATED_INITIALIZER;
/**
Initialize whisper_decoder_impl instance with the model in this bundle.
*/
- (nullable instancetype)init;
/**
Initialize whisper_decoder_impl instance with the model in this bundle.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Initialize whisper_decoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Initialize whisper_decoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Construct whisper_decoder_impl instance asynchronously with configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
*/
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
/**
Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param modelURL The model URL.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
*/
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
/**
Make a prediction using the standard interface
@param input an instance of whisper_decoder_implInput to predict from
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_decoder_implOutput
*/
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Make a prediction using the standard interface
@param input an instance of whisper_decoder_implInput to predict from
@param options prediction options
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_decoder_implOutput
*/
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Make a prediction using the convenience interface
@param token_data as 1 by 1 matrix of 32-bit integers:
@param audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats:
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_decoder_implOutput
*/
- (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Batch prediction
@param inputArray array of whisper_decoder_implInput instances to obtain predictions from
@param options prediction options
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the predictions as NSArray<whisper_decoder_implOutput *>
*/
- (nullable NSArray<whisper_decoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_decoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
@end
NS_ASSUME_NONNULL_END


@@ -1,201 +0,0 @@
//
// whisper-decoder-impl.m
//
// This file was automatically generated and should not be edited.
//
#if !__has_feature(objc_arc)
#error This file must be compiled with automatic reference counting enabled (-fobjc-arc)
#endif
#import "whisper-decoder-impl.h"
@implementation whisper_decoder_implInput
- (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data {
self = [super init];
if (self) {
_token_data = token_data;
_audio_data = audio_data;
}
return self;
}
- (NSSet<NSString *> *)featureNames {
return [NSSet setWithArray:@[@"token_data", @"audio_data"]];
}
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
if ([featureName isEqualToString:@"token_data"]) {
return [MLFeatureValue featureValueWithMultiArray:self.token_data];
}
if ([featureName isEqualToString:@"audio_data"]) {
return [MLFeatureValue featureValueWithMultiArray:self.audio_data];
}
return nil;
}
@end
@implementation whisper_decoder_implOutput
- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 {
self = [super init];
if (self) {
_var_1346 = var_1346;
}
return self;
}
- (NSSet<NSString *> *)featureNames {
return [NSSet setWithArray:@[@"var_1346"]];
}
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
if ([featureName isEqualToString:@"var_1346"]) {
return [MLFeatureValue featureValueWithMultiArray:self.var_1346];
}
return nil;
}
@end
@implementation whisper_decoder_impl
/**
URL of the underlying .mlmodelc directory.
*/
+ (nullable NSURL *)URLOfModelInThisBundle {
NSString *assetPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"whisper_decoder_impl" ofType:@"mlmodelc"];
if (nil == assetPath) { os_log_error(OS_LOG_DEFAULT, "Could not load whisper-decoder-impl.mlmodelc in the bundle resource"); return nil; }
return [NSURL fileURLWithPath:assetPath];
}
/**
Initialize whisper_decoder_impl instance from an existing MLModel object.
Usually the application does not use this initializer unless it makes a subclass of whisper_decoder_impl.
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
*/
- (instancetype)initWithMLModel:(MLModel *)model {
self = [super init];
if (!self) { return nil; }
_model = model;
if (_model == nil) { return nil; }
return self;
}
/**
Initialize whisper_decoder_impl instance with the model in this bundle.
*/
- (nullable instancetype)init {
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle error:nil];
}
/**
Initialize whisper_decoder_impl instance with the model in this bundle.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle configuration:configuration error:error];
}
/**
Initialize whisper_decoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error {
MLModel *model = [MLModel modelWithContentsOfURL:modelURL error:error];
if (model == nil) { return nil; }
return [self initWithMLModel:model];
}
/**
Initialize whisper_decoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
MLModel *model = [MLModel modelWithContentsOfURL:modelURL configuration:configuration error:error];
if (model == nil) { return nil; }
return [self initWithMLModel:model];
}
/**
Construct whisper_decoder_impl instance asynchronously with configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
*/
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler {
[self loadContentsOfURL:(NSURL * _Nonnull)[self URLOfModelInThisBundle]
configuration:configuration
completionHandler:handler];
}
/**
Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param modelURL The model URL.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
*/
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler {
[MLModel loadContentsOfURL:modelURL
configuration:configuration
completionHandler:^(MLModel *model, NSError *error) {
if (model != nil) {
whisper_decoder_impl *typedModel = [[whisper_decoder_impl alloc] initWithMLModel:model];
handler(typedModel, nil);
} else {
handler(nil, error);
}
}];
}
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error {
return [self predictionFromFeatures:input options:[[MLPredictionOptions alloc] init] error:error];
}
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
if (!outFeatures) { return nil; }
return [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[outFeatures featureValueForName:@"var_1346"].multiArrayValue];
}
- (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
whisper_decoder_implInput *input_ = [[whisper_decoder_implInput alloc] initWithToken_data:token_data audio_data:audio_data];
return [self predictionFromFeatures:input_ error:error];
}
- (nullable NSArray<whisper_decoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_decoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
id<MLBatchProvider> inBatch = [[MLArrayBatchProvider alloc] initWithFeatureProviderArray:inputArray];
id<MLBatchProvider> outBatch = [self.model predictionsFromBatch:inBatch options:options error:error];
if (!outBatch) { return nil; }
NSMutableArray<whisper_decoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
for (NSInteger i = 0; i < outBatch.count; i++) {
id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[resultProvider featureValueForName:@"var_1346"].multiArrayValue];
[results addObject:result];
}
return results;
}
@end


@@ -1,142 +0,0 @@
//
// whisper-encoder-impl.h
//
// This file was automatically generated and should not be edited.
//
#import <Foundation/Foundation.h>
#import <CoreML/CoreML.h>
#include <stdint.h>
#include <os/log.h>
NS_ASSUME_NONNULL_BEGIN
/// Model Prediction Input Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_encoder_implInput : NSObject<MLFeatureProvider>
/// logmel_data as 1 × 80 × 3000 3-dimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * logmel_data;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithLogmel_data:(MLMultiArray *)logmel_data NS_DESIGNATED_INITIALIZER;
@end
/// Model Prediction Output Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_encoder_implOutput : NSObject<MLFeatureProvider>
/// output as multidimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * output;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithOutput:(MLMultiArray *)output NS_DESIGNATED_INITIALIZER;
@end
/// Class for model loading and prediction
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_encoder_impl : NSObject
@property (readonly, nonatomic, nullable) MLModel * model;
/**
URL of the underlying .mlmodelc directory.
*/
+ (nullable NSURL *)URLOfModelInThisBundle;
/**
Initialize whisper_encoder_impl instance from an existing MLModel object.
Usually the application does not use this initializer unless it makes a subclass of whisper_encoder_impl.
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
*/
- (instancetype)initWithMLModel:(MLModel *)model NS_DESIGNATED_INITIALIZER;
/**
Initialize whisper_encoder_impl instance with the model in this bundle.
*/
- (nullable instancetype)init;
/**
Initialize whisper_encoder_impl instance with the model in this bundle.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Initialize whisper_encoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Initialize whisper_encoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Construct whisper_encoder_impl instance asynchronously with configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
*/
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
/**
Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param modelURL The model URL.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
*/
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
/**
Make a prediction using the standard interface
@param input an instance of whisper_encoder_implInput to predict from
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_encoder_implOutput
*/
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Make a prediction using the standard interface
@param input an instance of whisper_encoder_implInput to predict from
@param options prediction options
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_encoder_implOutput
*/
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Make a prediction using the convenience interface
@param logmel_data as 1 × 80 × 3000 3-dimensional array of floats:
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_encoder_implOutput
*/
- (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Batch prediction
@param inputArray array of whisper_encoder_implInput instances to obtain predictions from
@param options prediction options
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the predictions as NSArray<whisper_encoder_implOutput *>
*/
- (nullable NSArray<whisper_encoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_encoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
@end
NS_ASSUME_NONNULL_END


@@ -1,197 +0,0 @@
//
// whisper-encoder-impl.m
//
// This file was automatically generated and should not be edited.
//
#if !__has_feature(objc_arc)
#error This file must be compiled with automatic reference counting enabled (-fobjc-arc)
#endif
#import "whisper-encoder-impl.h"
@implementation whisper_encoder_implInput
- (instancetype)initWithLogmel_data:(MLMultiArray *)logmel_data {
self = [super init];
if (self) {
_logmel_data = logmel_data;
}
return self;
}
- (NSSet<NSString *> *)featureNames {
return [NSSet setWithArray:@[@"logmel_data"]];
}
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
if ([featureName isEqualToString:@"logmel_data"]) {
return [MLFeatureValue featureValueWithMultiArray:self.logmel_data];
}
return nil;
}
@end
@implementation whisper_encoder_implOutput
- (instancetype)initWithOutput:(MLMultiArray *)output {
self = [super init];
if (self) {
_output = output;
}
return self;
}
- (NSSet<NSString *> *)featureNames {
return [NSSet setWithArray:@[@"output"]];
}
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
if ([featureName isEqualToString:@"output"]) {
return [MLFeatureValue featureValueWithMultiArray:self.output];
}
return nil;
}
@end
@implementation whisper_encoder_impl
/**
URL of the underlying .mlmodelc directory.
*/
+ (nullable NSURL *)URLOfModelInThisBundle {
NSString *assetPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"whisper_encoder_impl" ofType:@"mlmodelc"];
if (nil == assetPath) { os_log_error(OS_LOG_DEFAULT, "Could not load whisper-encoder-impl.mlmodelc in the bundle resource"); return nil; }
return [NSURL fileURLWithPath:assetPath];
}
/**
Initialize whisper_encoder_impl instance from an existing MLModel object.
Usually the application does not use this initializer unless it makes a subclass of whisper_encoder_impl.
Such an application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create an MLModel object to pass in.
*/
- (instancetype)initWithMLModel:(MLModel *)model {
self = [super init];
if (!self) { return nil; }
_model = model;
if (_model == nil) { return nil; }
return self;
}
/**
Initialize whisper_encoder_impl instance with the model in this bundle.
*/
- (nullable instancetype)init {
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle error:nil];
}
/**
Initialize whisper_encoder_impl instance with the model in this bundle.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle configuration:configuration error:error];
}
/**
Initialize whisper_encoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error {
MLModel *model = [MLModel modelWithContentsOfURL:modelURL error:error];
if (model == nil) { return nil; }
return [self initWithMLModel:model];
}
/**
Initialize whisper_encoder_impl instance from the model URL.
@param modelURL URL to the .mlmodelc directory for whisper_encoder_impl.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
MLModel *model = [MLModel modelWithContentsOfURL:modelURL configuration:configuration error:error];
if (model == nil) { return nil; }
return [self initWithMLModel:model];
}
/**
Construct whisper_encoder_impl instance asynchronously with configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
*/
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler {
[self loadContentsOfURL:(NSURL * _Nonnull)[self URLOfModelInThisBundle]
configuration:configuration
completionHandler:handler];
}
/**
Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param modelURL The model URL.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
*/
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler {
[MLModel loadContentsOfURL:modelURL
configuration:configuration
completionHandler:^(MLModel *model, NSError *error) {
if (model != nil) {
whisper_encoder_impl *typedModel = [[whisper_encoder_impl alloc] initWithMLModel:model];
handler(typedModel, nil);
} else {
handler(nil, error);
}
}];
}
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error {
return [self predictionFromFeatures:input options:[[MLPredictionOptions alloc] init] error:error];
}
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
if (!outFeatures) { return nil; }
return [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[outFeatures featureValueForName:@"output"].multiArrayValue];
}
- (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
whisper_encoder_implInput *input_ = [[whisper_encoder_implInput alloc] initWithLogmel_data:logmel_data];
return [self predictionFromFeatures:input_ error:error];
}
- (nullable NSArray<whisper_encoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_encoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
id<MLBatchProvider> inBatch = [[MLArrayBatchProvider alloc] initWithFeatureProviderArray:inputArray];
id<MLBatchProvider> outBatch = [self.model predictionsFromBatch:inBatch options:options error:error];
if (!outBatch) { return nil; }
NSMutableArray<whisper_encoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
for (NSInteger i = 0; i < outBatch.count; i++) {
id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
whisper_encoder_implOutput * result = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[resultProvider featureValueForName:@"output"].multiArrayValue];
[results addObject:result];
}
return results;
}
@end

View File

@ -1,22 +0,0 @@
// Wrapper of the Core ML Whisper Encoder model
//
// Code is derived from the work of GitHub user @wangchou
// ref: https://github.com/wangchou/callCoreMLFromCpp
#if __cplusplus
extern "C" {
#endif
struct whisper_coreml_context;
struct whisper_coreml_context * whisper_coreml_init(const char * path_model);
void whisper_coreml_free(struct whisper_coreml_context * ctx);
void whisper_coreml_encode(
const whisper_coreml_context * ctx,
float * mel,
float * out);
#if __cplusplus
}
#endif
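For orientation, here is a minimal C++ sketch of how this C wrapper might be driven. The model path and the 1500 × 512 output size are assumptions (roughly matching the base encoder) and are not specified by this header; only the functions declared above are used.

```cpp
// Hypothetical driver for the Core ML encoder wrapper declared above.
// The model path and output buffer size are assumptions for illustration only.
#include "coreml/whisper-encoder.h"

#include <cstdio>
#include <vector>

int main() {
    // Assumed location of a compiled Core ML encoder model
    const char * path_model = "models/ggml-base.en-encoder.mlmodelc";

    struct whisper_coreml_context * cml = whisper_coreml_init(path_model);
    if (!cml) {
        fprintf(stderr, "failed to load Core ML model from '%s'\n", path_model);
        return 1;
    }

    // input log-mel spectrogram: 1 x 80 x 3000 floats, as documented for the model input
    std::vector<float> mel(1*80*3000, 0.0f);

    // output buffer: assumed 1500 x 512 floats for the base encoder; the size depends on the model
    std::vector<float> out(1500*512, 0.0f);

    whisper_coreml_encode(cml, mel.data(), out.data());

    whisper_coreml_free(cml);
    return 0;
}
```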

View File

@ -1,67 +0,0 @@
#import "coreml/whisper-encoder.h"
#import "coreml/whisper-encoder-impl.h"
#import <CoreML/CoreML.h>
#include <stdlib.h>
#if __cplusplus
extern "C" {
#endif
struct whisper_coreml_context {
const void * data;
};
struct whisper_coreml_context * whisper_coreml_init(const char * path_model) {
NSString * path_model_str = [[NSString alloc] initWithUTF8String:path_model];
NSURL * url_model = [NSURL fileURLWithPath: path_model_str];
const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model error:nil]);
if (data == NULL) {
return NULL;
}
whisper_coreml_context * ctx = new whisper_coreml_context;
ctx->data = data;
return ctx;
}
void whisper_coreml_free(struct whisper_coreml_context * ctx) {
CFRelease(ctx->data);
delete ctx;
}
void whisper_coreml_encode(
const whisper_coreml_context * ctx,
float * mel,
float * out) {
MLMultiArray * inMultiArray = [
[MLMultiArray alloc] initWithDataPointer: mel
shape: @[@1, @80, @3000]
dataType: MLMultiArrayDataTypeFloat32
strides: @[@(240000), @(3000), @1]
deallocator: nil
error: nil
];
whisper_encoder_implOutput * outCoreML = [(__bridge id) ctx->data predictionFromLogmel_data:inMultiArray error:nil];
MLMultiArray * outMA = outCoreML.output;
//NSArray<NSNumber *> * shape = outMA.shape;
//NSArray<NSNumber *> * strides = outMA.strides;
//printf("shape: %ld %ld %ld %ld\n", [shape[0] longValue], [shape[1] longValue], [shape[2] longValue], [shape[3] longValue]);
//printf("strides: %ld %ld %ld %ld\n", [strides[0] longValue], [strides[1] longValue], [strides[2] longValue], [strides[3] longValue]);
memcpy(out, outMA.dataPointer, outMA.count * sizeof(float));
}
#if __cplusplus
}
#endif

View File

@ -63,5 +63,4 @@ else()
add_subdirectory(command) add_subdirectory(command)
add_subdirectory(bench) add_subdirectory(bench)
add_subdirectory(talk) add_subdirectory(talk)
add_subdirectory(talk-llama)
endif() endif()

View File

@ -1,22 +1,15 @@
const path = require("path"); const path = require('path');
const { whisper } = require(path.join( const { whisper } = require(path.join(__dirname, '../../../build/Release/whisper-addon'));
__dirname,
"../../../build/Release/whisper-addon"
));
const { promisify } = require("util");
const whisperAsync = promisify(whisper);
const whisperParamsMock = { const whisperParamsMock = {
language: "en", language: 'en',
model: path.join(__dirname, "../../../models/ggml-base.en.bin"), model: path.join(__dirname, '../../../models/ggml-base.en.bin'),
fname_inp: path.join(__dirname, "../../../samples/jfk.wav"), fname_inp: path.join(__dirname, '../../../samples/jfk.wav'),
}; };
describe("Run whisper.node", () => { describe("Run whisper.node", () => {
test("it should receive a non-empty value", async () => {
let result = await whisperAsync(whisperParamsMock);
expect(result.length).toBeGreaterThan(0); test("it should receive a non-empty value", () => {
}); expect(whisper(whisperParamsMock).length).toBeGreaterThan(0);
});
}); });

View File

@ -72,7 +72,7 @@ int timestamp_to_sample(int64_t t, int n_samples) {
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100))); return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
} }
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
const auto & params = *((whisper_print_user_data *) user_data)->params; const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
@ -160,6 +160,22 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
return 3; return 3;
} }
// initial prompt
std::vector<whisper_token> prompt_tokens;
if (!params.prompt.empty()) {
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
fprintf(stderr, "\n");
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
fprintf(stderr, "initial tokens: [ ");
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
fprintf(stderr, "%d ", prompt_tokens[i]);
}
fprintf(stderr, "]\n");
}
for (int f = 0; f < (int) params.fname_inp.size(); ++f) { for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f]; const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f]; const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
@ -227,7 +243,8 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
wparams.greedy.best_of = params.best_of; wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size; wparams.beam_search.beam_size = params.beam_size;
wparams.initial_prompt = params.prompt.c_str(); wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
whisper_print_user_data user_data = { &params, &pcmf32s }; whisper_print_user_data user_data = { &params, &pcmf32s };
@ -243,7 +260,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
{ {
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
bool is_aborted = *(bool*)user_data; bool is_aborted = *(bool*)user_data;
return !is_aborted; return !is_aborted;
}; };
@ -275,64 +292,51 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
return 0; return 0;
} }
class Worker : public Napi::AsyncWorker { Napi::Object whisper(const Napi::CallbackInfo& info) {
public: Napi::Env env = info.Env();
Worker(Napi::Function& callback, whisper_params params) if (info.Length() <= 0 || !info[0].IsObject()) {
: Napi::AsyncWorker(callback), params(params) {} Napi::TypeError::New(env, "object expected").ThrowAsJavaScriptException();
void Execute() override {
run(params, result);
}
void OnOK() override {
Napi::HandleScope scope(Env());
Napi::Object res = Napi::Array::New(Env(), result.size());
for (uint64_t i = 0; i < result.size(); ++i) {
Napi::Object tmp = Napi::Array::New(Env(), 3);
for (uint64_t j = 0; j < 3; ++j) {
tmp[j] = Napi::String::New(Env(), result[i][j]);
}
res[i] = tmp;
} }
Callback().Call({Env().Null(), res}); whisper_params params;
} std::vector<std::vector<std::string>> result;
private: Napi::Object whisper_params = info[0].As<Napi::Object>();
whisper_params params; std::string language = whisper_params.Get("language").As<Napi::String>();
std::vector<std::vector<std::string>> result; std::string model = whisper_params.Get("model").As<Napi::String>();
}; std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
params.language = language;
params.model = model;
params.fname_inp.emplace_back(input);
// run model
run(params, result);
Napi::Value whisper(const Napi::CallbackInfo& info) { fprintf(stderr, "RESULT:\n");
Napi::Env env = info.Env(); for (auto sentence:result) {
if (info.Length() <= 0 || !info[0].IsObject()) { fprintf(stderr, "t0: %s, t1: %s, content: %s \n",
Napi::TypeError::New(env, "object expected").ThrowAsJavaScriptException(); sentence[0].c_str(), sentence[1].c_str(), sentence[2].c_str());
} }
whisper_params params;
Napi::Object whisper_params = info[0].As<Napi::Object>(); Napi::Object res = Napi::Array::New(env, result.size());
std::string language = whisper_params.Get("language").As<Napi::String>(); for (uint64_t i = 0; i < result.size(); ++i) {
std::string model = whisper_params.Get("model").As<Napi::String>(); Napi::Object tmp = Napi::Array::New(env, 3);
std::string input = whisper_params.Get("fname_inp").As<Napi::String>(); for (uint64_t j = 0; j < 3; ++j) {
tmp[j] = Napi::String::New(env, result[i][j]);
}
res[i] = tmp;
}
params.language = language; return res;
params.model = model;
params.fname_inp.emplace_back(input);
Napi::Function callback = info[1].As<Napi::Function>();
Worker* worker = new Worker(callback, params);
worker->Queue();
return env.Undefined();
} }
Napi::Object Init(Napi::Env env, Napi::Object exports) { Napi::Object Init(Napi::Env env, Napi::Object exports) {
exports.Set( exports.Set(
Napi::String::New(env, "whisper"), Napi::String::New(env, "whisper"),
Napi::Function::New(env, whisper) Napi::Function::New(env, whisper)
); );
return exports; return exports;
} }
NODE_API_MODULE(whisper, Init); NODE_API_MODULE(whisper, Init);

View File

@ -1,36 +1,27 @@
const path = require("path"); const path = require('path');
const { whisper } = require(path.join( const { whisper } = require(path.join(__dirname, '../../build/Release/whisper-addon'));
__dirname,
"../../build/Release/whisper-addon"
));
const { promisify } = require("util");
const whisperAsync = promisify(whisper);
const whisperParams = { const whisperParams = {
language: "en", language: 'en',
model: path.join(__dirname, "../../models/ggml-base.en.bin"), model: path.join(__dirname, '../../models/ggml-base.en.bin'),
fname_inp: "../../samples/jfk.wav", fname_inp: '',
}; };
const arguments = process.argv.slice(2); const arguments = process.argv.slice(2);
const params = Object.fromEntries( const params = Object.fromEntries(
arguments.reduce((pre, item) => { arguments.reduce((pre, item) => {
if (item.startsWith("--")) { if (item.startsWith("--")) {
return [...pre, item.slice(2).split("=")]; return [...pre, item.slice(2).split("=")];
} }
return pre; return pre;
}, []) }, []),
); );
for (const key in params) { for (const key in params) {
if (whisperParams.hasOwnProperty(key)) { if (whisperParams.hasOwnProperty(key)) {
whisperParams[key] = params[key]; whisperParams[key] = params[key];
} }
} }
console.log("whisperParams =", whisperParams); console.log('whisperParams =', whisperParams);
console.log(whisper(whisperParams));
whisperAsync(whisperParams).then((result) => {
console.log(`Result from whisper: ${result}`);
});

View File

@ -31,7 +31,6 @@ options:
-osrt, --output-srt [false ] output result in a srt file -osrt, --output-srt [false ] output result in a srt file
-owts, --output-words [false ] output script for generating karaoke video -owts, --output-words [false ] output script for generating karaoke video
-ocsv, --output-csv [false ] output result in a CSV file -ocsv, --output-csv [false ] output result in a CSV file
-oj, --output-json [false ] output result in a JSON file
-of FNAME, --output-file FNAME [ ] output file path (without file extension) -of FNAME, --output-file FNAME [ ] output file path (without file extension)
-ps, --print-special [false ] print special tokens -ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors -pc, --print-colors [false ] print colors

View File

@ -8,7 +8,6 @@
#include <string> #include <string>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include <cstring>
// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9] // Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
// Lowest is red, middle is yellow, highest is green. // Lowest is red, middle is yellow, highest is green.
@ -57,7 +56,7 @@ struct whisper_params {
int32_t duration_ms = 0; int32_t duration_ms = 0;
int32_t max_context = -1; int32_t max_context = -1;
int32_t max_len = 0; int32_t max_len = 0;
int32_t best_of = 2; int32_t best_of = 5;
int32_t beam_size = -1; int32_t beam_size = -1;
float word_thold = 0.01f; float word_thold = 0.01f;
@ -74,8 +73,6 @@ struct whisper_params {
bool output_srt = false; bool output_srt = false;
bool output_wts = false; bool output_wts = false;
bool output_csv = false; bool output_csv = false;
bool output_jsn = false;
bool output_lrc = false;
bool print_special = false; bool print_special = false;
bool print_colors = false; bool print_colors = false;
bool print_progress = false; bool print_progress = false;
@ -83,7 +80,6 @@ struct whisper_params {
std::string language = "en"; std::string language = "en";
std::string prompt; std::string prompt;
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
std::string model = "models/ggml-base.en.bin"; std::string model = "models/ggml-base.en.bin";
std::vector<std::string> fname_inp = {}; std::vector<std::string> fname_inp = {};
@ -131,10 +127,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); } else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
@ -180,11 +173,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -olrc, --output-lrc [%-7s] output result in a lrc file\n", params.output_lrc ? "true" : "false");
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str());
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false");
fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", ""); fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", "");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
@ -203,7 +193,7 @@ struct whisper_print_user_data {
const std::vector<std::vector<float>> * pcmf32s; const std::vector<std::vector<float>> * pcmf32s;
}; };
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) { void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
const auto & params = *((whisper_print_user_data *) user_data)->params; const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
@ -211,8 +201,8 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
std::string speaker = ""; std::string speaker = "";
int64_t t0 = 0; int64_t t0;
int64_t t1 = 0; int64_t t1;
// print the last n_new segments // print the last n_new segments
const int s0 = n_segments - n_new; const int s0 = n_segments - n_new;
@ -362,192 +352,28 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
fout << "start,end,text\n";
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds. //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << "," << 10 * t1 << ",\"" << text << "\"\n"; fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n";
} }
return true; return true;
} }
char *escape_double_quotes(const char *str) {
if (str == NULL) {
return NULL;
}
size_t escaped_length = strlen(str) + 1;
for (size_t i = 0; str[i] != '\0'; i++) {
if (str[i] == '"') {
escaped_length++;
}
}
char *escaped = (char *)calloc(escaped_length, 1); // pre-zeroed
if (escaped == NULL) {
return NULL;
}
size_t pos = 0;
for (size_t i = 0; str[i] != '\0'; i++) {
if (str[i] == '"') {
escaped[pos++] = '\\';
escaped[pos++] = '"';
} else {
escaped[pos++] = str[i];
}
}
// no need to null-terminate; calloc() already zeroed the buffer
return escaped;
}
bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
std::ofstream fout(fname);
int indent = 0;
auto doindent = [&]() {
for (int i = 0; i < indent; i++) fout << "\t";
};
auto start_arr = [&](const char *name) {
doindent();
fout << "\"" << name << "\": [\n";
indent++;
};
auto end_arr = [&](bool end = false) {
indent--;
doindent();
fout << (end ? "]\n" : "},\n");
};
auto start_obj = [&](const char *name = nullptr) {
doindent();
if (name) {
fout << "\"" << name << "\": {\n";
} else {
fout << "{\n";
}
indent++;
};
auto end_obj = [&](bool end = false) {
indent--;
doindent();
fout << (end ? "}\n" : "},\n");
};
auto start_value = [&](const char *name) {
doindent();
fout << "\"" << name << "\": ";
};
auto value_s = [&](const char *name, const char *val, bool end = false) {
start_value(name);
char * val_escaped = escape_double_quotes(val);
fout << "\"" << val_escaped << (end ? "\"\n" : "\",\n");
free(val_escaped);
};
auto end_value = [&](bool end = false) {
fout << (end ? "\n" : ",\n");
};
auto value_i = [&](const char *name, const int64_t val, bool end = false) {
start_value(name);
fout << val;
end_value(end);
};
auto value_b = [&](const char *name, const bool val, bool end = false) {
start_value(name);
fout << (val ? "true" : "false");
end_value(end);
};
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
start_obj();
value_s("systeminfo", whisper_print_system_info());
start_obj("model");
value_s("type", whisper_model_type_readable(ctx));
value_b("multilingual", whisper_is_multilingual(ctx));
value_i("vocab", whisper_model_n_vocab(ctx));
start_obj("audio");
value_i("ctx", whisper_model_n_audio_ctx(ctx));
value_i("state", whisper_model_n_audio_state(ctx));
value_i("head", whisper_model_n_audio_head(ctx));
value_i("layer", whisper_model_n_audio_layer(ctx), true);
end_obj();
start_obj("text");
value_i("ctx", whisper_model_n_text_ctx(ctx));
value_i("state", whisper_model_n_text_state(ctx));
value_i("head", whisper_model_n_text_head(ctx));
value_i("layer", whisper_model_n_text_layer(ctx), true);
end_obj();
value_i("mels", whisper_model_n_mels(ctx));
value_i("f16", whisper_model_f16(ctx), true);
end_obj();
start_obj("params");
value_s("model", params.model.c_str());
value_s("language", params.language.c_str());
value_b("translate", params.translate, true);
end_obj();
start_obj("result");
value_s("language", whisper_lang_str(whisper_full_lang_id(ctx)), true);
end_obj();
start_arr("transcription");
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
start_obj();
start_obj("timestamps");
value_s("from", to_timestamp(t0, true).c_str());
value_s("to", to_timestamp(t1, true).c_str(), true);
end_obj();
start_obj("offsets");
value_i("from", t0 * 10);
value_i("to", t1 * 10, true);
end_obj();
value_s("text", text, true);
end_obj(i == (n_segments - 1));
}
end_arr(true);
end_obj(true);
return true;
}
// karaoke video generation // karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles // outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments // TODO: font parameter adjustments
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) { bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) {
std::ofstream fout(fname); std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
static const char * font = params.font_path.c_str(); // TODO: become parameter
static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
std::ifstream fin(font);
if (!fin.is_open()) {
fprintf(stderr, "%s: font not found at '%s', please specify a monospace font with -fp\n", __func__, font);
return false;
}
fout << "#!/bin/bash" << "\n"; fout << "#!/bin/bash" << "\n";
fout << "\n"; fout << "\n";
@ -650,39 +476,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
return true; return true;
} }
bool output_lrc(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
fout << "[by:whisper.cpp]\n";
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t = whisper_full_get_segment_t0(ctx, i);
int64_t msec = t * 10;
int64_t min = msec / (1000 * 60);
msec = msec - min * (1000 * 60);
int64_t sec = msec / 1000;
msec = msec - sec * 1000;
char buf[16];
snprintf(buf, sizeof(buf), "%02d:%02d.%02d", (int) min, (int) sec, (int) ( msec / 10));
std::string timestamp_lrc = std::string(buf);
fout << '[' << timestamp_lrc << ']' << text << "\n";
}
return true;
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
whisper_params params; whisper_params params;
@ -711,6 +504,22 @@ int main(int argc, char ** argv) {
return 3; return 3;
} }
// initial prompt
std::vector<whisper_token> prompt_tokens;
if (!params.prompt.empty()) {
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
fprintf(stderr, "\n");
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
fprintf(stderr, "initial tokens: [ ");
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
fprintf(stderr, "%d ", prompt_tokens[i]);
}
fprintf(stderr, "]\n");
}
for (int f = 0; f < (int) params.fname_inp.size(); ++f) { for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f]; const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f]; const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
@ -774,7 +583,8 @@ int main(int argc, char ** argv) {
wparams.speed_up = params.speed_up; wparams.speed_up = params.speed_up;
wparams.initial_prompt = params.prompt.c_str(); wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
wparams.greedy.best_of = params.best_of; wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size; wparams.beam_search.beam_size = params.beam_size;
@ -797,7 +607,7 @@ int main(int argc, char ** argv) {
{ {
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
bool is_aborted = *(bool*)user_data; bool is_aborted = *(bool*)user_data;
return !is_aborted; return !is_aborted;
}; };
@ -808,6 +618,8 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]); fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 10; return 10;
} }
whisper_full_cluster_segments(ctx);
} }
// output stuff // output stuff
@ -843,18 +655,6 @@ int main(int argc, char ** argv) {
const auto fname_csv = fname_out + ".csv"; const auto fname_csv = fname_out + ".csv";
output_csv(ctx, fname_csv.c_str()); output_csv(ctx, fname_csv.c_str());
} }
// output to JSON file
if (params.output_jsn) {
const auto fname_jsn = fname_out + ".json";
output_json(ctx, fname_jsn.c_str(), params);
}
// output to LRC file
if (params.output_lrc) {
const auto fname_lrc = fname_out + ".lrc";
output_lrc(ctx, fname_lrc.c_str());
}
} }
} }

View File

@ -43,7 +43,6 @@ struct whisper_params {
bool speed_up = false; bool speed_up = false;
bool translate = false; bool translate = false;
bool no_fallback = false;
bool print_special = false; bool print_special = false;
bool no_context = true; bool no_context = true;
bool no_timestamps = false; bool no_timestamps = false;
@ -74,7 +73,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; } else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
@ -96,23 +94,22 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "options:\n"); fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n"); fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms); fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms); fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms); fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id); fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens); fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
@ -151,7 +148,7 @@ int main(int argc, char ** argv) {
// whisper init // whisper init
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1){ if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str()); fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params); whisper_print_usage(argc, argv, params);
exit(0); exit(0);
@ -291,6 +288,7 @@ int main(int argc, char ** argv) {
wparams.print_realtime = false; wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps; wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate; wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = !use_vad; wparams.single_segment = !use_vad;
wparams.max_tokens = params.max_tokens; wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str(); wparams.language = params.language.c_str();
@ -300,8 +298,7 @@ int main(int argc, char ** argv) {
wparams.speed_up = params.speed_up; wparams.speed_up = params.speed_up;
// disable temperature fallback // disable temperature fallback
//wparams.temperature_inc = -1.0f; wparams.temperature_inc = -1.0f;
wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data(); wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size(); wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();

View File

@ -1 +0,0 @@
audio.mp3

View File

@ -1,16 +0,0 @@
if (WHISPER_SUPPORT_SDL2)
# talk-llama
set(TARGET talk-llama)
#add_executable(${TARGET} talk-llama.cpp llama.cpp)
#target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
#target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
# TODO: this is temporary
# need to export ggml symbols for MSVC, but too lazy ..
add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../whisper.cpp)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
include(DefaultTargetOptions)
endif ()

View File

@ -1,36 +0,0 @@
# talk-llama
Talk with an LLaMA AI in your terminal
[Demo Talk](https://user-images.githubusercontent.com/1991296/228024237-848f998c-c334-46a6-bef8-3271590da83b.mp4)
## Building
The `talk-llama` tool depends on the SDL2 library to capture audio from the microphone. You can build it like this:
```bash
# Install SDL2 on Linux
sudo apt-get install libsdl2-dev
# Install SDL2 on Mac OS
brew install sdl2
# Build the "talk-llama" executable
make talk-llama
# Run it
./talk-llama -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
```
- The `-mw` argument specifies the Whisper model that you would like to use. Recommended: `base` or `small` for a real-time experience
- The `-ml` argument specifies the LLaMA model that you would like to use. Read the instructions at https://github.com/ggerganov/llama.cpp for information about how to obtain a `ggml`-compatible LLaMA model
## TTS
For the best experience, this example needs a TTS tool to convert the generated text responses to voice.
You can use any TTS engine you like - simply edit the [speak.sh](speak.sh) script to suit your needs.
By default, it is configured to use macOS's `say`, but you can use whatever you wish.
## Discussion
If you have any feedback, please let "us" know in the following discussion: https://github.com/ggerganov/whisper.cpp/discussions/672?converting=1

View File

@ -1,23 +0,0 @@
import sys
import importlib.util
api_key = "" # Write your https://beta.elevenlabs.io API key here
if not api_key:
print("To use elevenlabs you have to register at https://beta.elevenlabs.io and add your ElevenLabs API key to examples/talk-llama/eleven-labs.py")
sys.exit()
if importlib.util.find_spec("elevenlabs") is None:
print("The elevenlabs library is not installed; you can install it into your environment using 'pip install elevenlabs'")
sys.exit()
from elevenlabs import ElevenLabs
eleven = ElevenLabs(api_key)
# Get a Voice object, by name or UUID
voice = eleven.voices["Arnold"] # Possible voices: Adam Antoni Arnold Bella Domi Elli Josh
# Generate the TTS
audio = voice.generate(str(sys.argv[2:]))
# Save the TTS to a file
audio.save("audio")

File diff suppressed because it is too large

View File

@ -1,173 +0,0 @@
#ifndef LLAMA_H
#define LLAMA_H
#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>
#ifdef LLAMA_SHARED
# if defined(_WIN32) && !defined(__MINGW32__)
# ifdef LLAMA_BUILD
# define LLAMA_API __declspec(dllexport)
# else
# define LLAMA_API __declspec(dllimport)
# endif
# else
# define LLAMA_API __attribute__ ((visibility ("default")))
# endif
#else
# define LLAMA_API
#endif
#define LLAMA_FILE_VERSION 1
#define LLAMA_FILE_MAGIC 0x67676a74 // 'ggjt' in hex
#define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files
#ifdef __cplusplus
extern "C" {
#endif
//
// C interface
//
// TODO: show sample usage
//
struct llama_context;
typedef int llama_token;
typedef struct llama_token_data {
llama_token id; // token id
float p; // probability of the token
float plog; // log probability of the token
} llama_token_data;
typedef void (*llama_progress_callback)(float progress, void *ctx);
struct llama_context_params {
int n_ctx; // text context
int n_parts; // -1 for default
int seed; // RNG seed, 0 for random
bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
bool use_mlock; // force system to keep model in RAM
bool embedding; // embedding mode only
// called with a progress value between 0 and 1, pass NULL to disable
llama_progress_callback progress_callback;
// context pointer passed to the progress callback
void * progress_callback_user_data;
};
LLAMA_API struct llama_context_params llama_context_default_params();
LLAMA_API bool llama_mmap_supported();
LLAMA_API bool llama_mlock_supported();
// Various functions for loading a ggml llama model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
LLAMA_API struct llama_context * llama_init_from_file(
const char * path_model,
struct llama_context_params params);
// Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx);
// TODO: not great API - very likely to change
// Returns 0 on success
LLAMA_API int llama_model_quantize(
const char * fname_inp,
const char * fname_out,
int itype);
// Returns the KV cache that will contain the context for the
// ongoing prediction with the model.
LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx);
// Returns the size of the KV cache
LLAMA_API size_t llama_get_kv_cache_size(struct llama_context * ctx);
// Returns the number of tokens in the KV cache
LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);
// Sets the KV cache containing the current context for the model
LLAMA_API void llama_set_kv_cache(
struct llama_context * ctx,
const uint8_t * kv_cache,
size_t n_size,
int n_token_count);
// Run the llama inference to obtain the logits and probabilities for the next token.
// tokens + n_tokens is the provided batch of new tokens to process
// n_past is the number of tokens to use from previous eval calls
// Returns 0 on success
LLAMA_API int llama_eval(
struct llama_context * ctx,
const llama_token * tokens,
int n_tokens,
int n_past,
int n_threads);
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
// Returns a negative number on failure - the number of tokens that would have been returned
// TODO: not sure if correct
LLAMA_API int llama_tokenize(
struct llama_context * ctx,
const char * text,
llama_token * tokens,
int n_max_tokens,
bool add_bos);
LLAMA_API int llama_n_vocab(struct llama_context * ctx);
LLAMA_API int llama_n_ctx (struct llama_context * ctx);
LLAMA_API int llama_n_embd (struct llama_context * ctx);
// Token logits obtained from the last call to llama_eval()
// The logits for the last token are stored in the last row
// Can be mutated in order to change the probabilities of the next token
// Rows: n_tokens
// Cols: n_vocab
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
// Get the embeddings for the input
// shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Token Id -> String. Uses the vocabulary in the provided context
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
// Special tokens
LLAMA_API llama_token llama_token_bos();
LLAMA_API llama_token llama_token_eos();
// TODO: improve the last_n_tokens interface ?
LLAMA_API llama_token llama_sample_top_p_top_k(
struct llama_context * ctx,
const llama_token * last_n_tokens_data,
int last_n_tokens_size,
int top_k,
float top_p,
float temp,
float repeat_penalty);
// Performance information
LLAMA_API void llama_print_timings(struct llama_context * ctx);
LLAMA_API void llama_reset_timings(struct llama_context * ctx);
// Print system information
LLAMA_API const char * llama_print_system_info(void);
#ifdef __cplusplus
}
#endif
#endif // LLAMA_H
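The header above leaves a `// TODO: show sample usage` note. As a hedged, non-authoritative sketch using only the functions declared in it, a minimal generation loop might look like the following; the model path, thread count, and sampling parameters are placeholders, not values taken from this repository.

```cpp
// Minimal sketch of the C API declared in llama.h (illustrative only).
#include "llama.h"

#include <cstdio>
#include <string>
#include <vector>

int main() {
    llama_context_params params = llama_context_default_params();
    params.n_ctx = 512;

    // placeholder model path
    llama_context * ctx = llama_init_from_file("models/7B/ggml-model-q4_0.bin", params);
    if (!ctx) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // tokenize the prompt (add_bos = true at the start of a new context)
    const std::string prompt = "Hello";
    std::vector<llama_token> tokens(params.n_ctx);
    const int n = llama_tokenize(ctx, prompt.c_str(), tokens.data(), (int) tokens.size(), true);
    if (n < 0) {
        fprintf(stderr, "tokenization failed\n");
        return 1;
    }
    tokens.resize(n);

    int n_past = 0;
    for (int i = 0; i < 16; ++i) {
        // evaluate the pending tokens, then sample the next one from the logits
        if (llama_eval(ctx, tokens.data(), (int) tokens.size(), n_past, 4) != 0) {
            fprintf(stderr, "llama_eval failed\n");
            break;
        }
        n_past += (int) tokens.size();

        // placeholder sampling parameters; the just-evaluated tokens double as the repeat-penalty window
        const llama_token id = llama_sample_top_p_top_k(ctx, tokens.data(), (int) tokens.size(), 40, 0.95f, 0.8f, 1.1f);
        if (id == llama_token_eos()) {
            break;
        }

        printf("%s", llama_token_to_str(ctx, id));
        tokens = { id };
    }

    llama_print_timings(ctx);
    llama_free(ctx);
    return 0;
}
```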

View File

@ -1,12 +0,0 @@
// Internal header to be included by llama.cpp and tests/benchmarks only.
#ifndef LLAMA_INTERNAL_H
#define LLAMA_INTERNAL_H
#include <vector>
#include <string>
struct ggml_tensor;
std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx);
#endif // LLAMA_INTERNAL_H

View File

@ -1,383 +0,0 @@
// Internal header to be included only by llama.cpp.
// Contains wrappers around OS interfaces.
#ifndef LLAMA_UTIL_H
#define LLAMA_UTIL_H
#include <cstdio>
#include <cstdint>
#include <cerrno>
#include <cstring>
#include <cstdarg>
#include <cstdlib>
#include <climits>
#include <string>
#include <vector>
#ifdef __has_include
#if __has_include(<unistd.h>)
#include <unistd.h>
#if defined(_POSIX_MAPPED_FILES)
#include <sys/mman.h>
#endif
#endif
#endif
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <windows.h>
#include <io.h>
#include <stdio.h> // for _fseeki64
#endif
#define LLAMA_ASSERT(x) \
do { \
if (!(x)) { \
fprintf(stderr, "LLAMA_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
abort(); \
} \
} while (0)
#ifdef __GNUC__
__attribute__((format(printf, 1, 2)))
#endif
static std::string format(const char * fmt, ...) {
va_list ap, ap2;
va_start(ap, fmt);
va_copy(ap2, ap);
int size = vsnprintf(NULL, 0, fmt, ap);
LLAMA_ASSERT(size >= 0 && size < INT_MAX);
std::vector<char> buf(size + 1);
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
LLAMA_ASSERT(size2 == size);
va_end(ap2);
va_end(ap);
return std::string(buf.data(), size);
};
struct llama_file {
// use FILE * so we don't have to re-open the file to mmap
FILE * fp;
size_t size;
llama_file(const char * fname, const char * mode) {
fp = std::fopen(fname, mode);
if (fp == NULL) {
throw format("failed to open %s: %s", fname, std::strerror(errno));
}
seek(0, SEEK_END);
size = tell();
seek(0, SEEK_SET);
}
size_t tell() const {
#ifdef _WIN32
__int64 ret = _ftelli64(fp);
#else
long ret = std::ftell(fp);
#endif
LLAMA_ASSERT(ret != -1); // this really shouldn't fail
return (size_t) ret;
}
void seek(size_t offset, int whence) {
#ifdef _WIN32
int ret = _fseeki64(fp, (__int64) offset, whence);
#else
int ret = std::fseek(fp, (long) offset, whence);
#endif
LLAMA_ASSERT(ret == 0); // same
}
void read_raw(void * ptr, size_t size) {
if (size == 0) {
return;
}
errno = 0;
std::size_t ret = std::fread(ptr, size, 1, fp);
if (ferror(fp)) {
throw format("read error: %s", strerror(errno));
}
if (ret != 1) {
throw std::string("unexpectedly reached end of file");
}
}
std::uint32_t read_u32() {
std::uint32_t ret;
read_raw(&ret, sizeof(ret));
return ret;
}
std::string read_string(std::uint32_t len) {
std::vector<char> chars(len);
read_raw(chars.data(), len);
return std::string(chars.data(), len);
}
void write_raw(const void * ptr, size_t size) {
if (size == 0) {
return;
}
errno = 0;
size_t ret = std::fwrite(ptr, size, 1, fp);
if (ret != 1) {
throw format("write error: %s", strerror(errno));
}
}
void write_u32(std::uint32_t val) {
write_raw(&val, sizeof(val));
}
~llama_file() {
if (fp) {
std::fclose(fp);
}
}
};
#if defined(_WIN32)
static std::string llama_format_win_err(DWORD err) {
LPSTR buf;
size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
if (!size) {
return "FormatMessageA failed";
}
std::string ret(buf, size);
LocalFree(buf);
return ret;
}
#endif
struct llama_mmap {
void * addr;
size_t size;
llama_mmap(const llama_mmap &) = delete;
#ifdef _POSIX_MAPPED_FILES
static constexpr bool SUPPORTED = true;
llama_mmap(struct llama_file * file) {
size = file->size;
int fd = fileno(file->fp);
int flags = MAP_SHARED;
#ifdef __linux__
flags |= MAP_POPULATE;
#endif
addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0);
close(fd);
if (addr == MAP_FAILED) {
throw format("mmap failed: %s", strerror(errno));
}
// Advise the kernel to preload the mapped memory
if (madvise(addr, file->size, MADV_WILLNEED)) {
fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
strerror(errno));
}
}
~llama_mmap() {
munmap(addr, size);
}
#elif defined(_WIN32)
static constexpr bool SUPPORTED = true;
llama_mmap(struct llama_file * file) {
size = file->size;
HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp));
HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
DWORD error = GetLastError();
CloseHandle(hFile);
if (hMapping == NULL) {
throw format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str());
}
addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
error = GetLastError();
CloseHandle(hMapping);
if (addr == NULL) {
throw format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str());
}
// Advise the kernel to preload the mapped memory
WIN32_MEMORY_RANGE_ENTRY range;
range.VirtualAddress = addr;
range.NumberOfBytes = (SIZE_T)size;
if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
}
~llama_mmap() {
if (!UnmapViewOfFile(addr)) {
fprintf(stderr, "warning: UnmapViewOfFile failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
}
#else
static constexpr bool SUPPORTED = false;
llama_mmap(struct llama_file *) {
throw std::string("mmap not supported");
}
#endif
};
// Represents some region of memory being locked using mlock or VirtualLock;
// will automatically unlock on destruction.
struct llama_mlock {
void * addr = NULL;
size_t size = 0;
bool failed_already = false;
llama_mlock() {}
llama_mlock(const llama_mlock &) = delete;
~llama_mlock() {
if (size) {
raw_unlock(addr, size);
}
}
void init(void * addr) {
LLAMA_ASSERT(this->addr == NULL && this->size == 0);
this->addr = addr;
}
void grow_to(size_t target_size) {
LLAMA_ASSERT(addr);
if (failed_already) {
return;
}
size_t granularity = lock_granularity();
target_size = (target_size + granularity - 1) & ~(granularity - 1);
if (target_size > size) {
if (raw_lock((uint8_t *) addr + size, target_size - size)) {
size = target_size;
} else {
failed_already = true;
}
}
}
#ifdef _POSIX_MEMLOCK_RANGE
static constexpr bool SUPPORTED = true;
size_t lock_granularity() {
return (size_t) sysconf(_SC_PAGESIZE);
}
#ifdef __APPLE__
#define MLOCK_SUGGESTION \
"Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
"decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
#else
#define MLOCK_SUGGESTION \
"Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
#endif
bool raw_lock(const void * addr, size_t size) {
if (!mlock(addr, size)) {
return true;
} else {
fprintf(stderr, "warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n" MLOCK_SUGGESTION,
size, this->size, std::strerror(errno));
return false;
}
}
#undef MLOCK_SUGGESTION
void raw_unlock(void * addr, size_t size) {
if (munlock(addr, size)) {
fprintf(stderr, "warning: failed to munlock buffer: %s\n", std::strerror(errno));
}
}
#elif defined(_WIN32)
static constexpr bool SUPPORTED = true;
size_t lock_granularity() {
SYSTEM_INFO si;
GetSystemInfo(&si);
return (size_t) si.dwPageSize;
}
bool raw_lock(void * addr, size_t size) {
for (int tries = 1; ; tries++) {
if (VirtualLock(addr, size)) {
return true;
}
if (tries == 2) {
fprintf(stderr, "warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
size, this->size, llama_format_win_err(GetLastError()).c_str());
return false;
}
// It failed but this was only the first try; increase the working
// set size and try again.
SIZE_T min_ws_size, max_ws_size;
if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) {
fprintf(stderr, "warning: GetProcessWorkingSetSize failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
return false;
}
// Per MSDN: "The maximum number of pages that a process can lock
// is equal to the number of pages in its minimum working set minus
// a small overhead."
// Hopefully a megabyte is enough overhead:
size_t increment = size + 1048576;
// The minimum must be <= the maximum, so we need to increase both:
min_ws_size += increment;
max_ws_size += increment;
if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) {
fprintf(stderr, "warning: SetProcessWorkingSetSize failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
return false;
}
}
}
void raw_unlock(void * addr, size_t size) {
if (!VirtualUnlock(addr, size)) {
fprintf(stderr, "warning: failed to VirtualUnlock buffer: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
}
#else
static constexpr bool SUPPORTED = false;
bool raw_lock(const void * addr, size_t size) {
fprintf(stderr, "warning: mlock not supported on this system\n");
return false;
}
void raw_unlock(const void * addr, size_t size) {}
#endif
};
// Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
struct llama_buffer {
uint8_t * addr = NULL;
size_t size = 0;
void resize(size_t size) {
delete[] addr;
addr = new uint8_t[size];
this->size = size;
}
~llama_buffer() {
delete[] addr;
}
};
#endif
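// Illustrative sketch (not part of the original source): llama_buffer and llama_mlock are
// meant to be used together - allocate uninitialized storage, then optionally pin it so the
// pages cannot be swapped out. The buffer size and the use_mlock flag are assumptions here.
//
//     llama_buffer buf;
//     buf.resize(64u * 1024 * 1024);       // uninitialized, unlike std::vector<uint8_t>
//
//     llama_mlock lock;
//     if (use_mlock) {
//         lock.init(buf.addr);             // remember the region to be locked
//         lock.grow_to(buf.size);          // mlock/VirtualLock up to the buffer size
//     }
//     // both objects release their resources (munlock / delete[]) in their destructors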

View File

@ -1,23 +0,0 @@
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Write a text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
{1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}'s requests immediately and with details and precision.
There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
The transcript only includes text, it does not include markup like HTML and Markdown.
{1} responds with short and concise answers.
### Response:
{0}{4} Hello, {1}!
{1}{4} Hello {0}! How may I help you today?
{0}{4} What time is it?
{1}{4} It is {2} o'clock.
{0}{4} What year is it?
{1}{4} We are in {3}.
{0}{4} What is a cat?
{1}{4} A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{0}{4} Name a color.
{1}{4} Blue
{0}{4}

View File

@ -1,21 +0,0 @@
#!/bin/bash
# Usage:
# speak.sh <voice_id> <text-to-speak>
# espeak
# Mac OS: brew install espeak
# Linux: apt-get install espeak
#
#espeak -v en-us+m$1 -s 225 -p 50 -a 200 -g 5 -k 5 "$2"
# for Mac
say "$2"
# Eleven Labs
# To use it, install the elevenlabs module with pip (pip install elevenlabs), register at https://beta.elevenlabs.io to get an API key, and paste it into /examples/talk-llama/eleven-labs.py
#
#wd=$(dirname $0)
#script=$wd/eleven-labs.py
#python3 $script $1 "$2" >/dev/null 2>&1
#ffplay -autoexit -nodisp -loglevel quiet -hide_banner -i ./audio.mp3 >/dev/null 2>&1

View File

@ -1,550 +0,0 @@
// Talk with AI
//
#include "common.h"
#include "common-sdl.h"
#include "whisper.h"
#include "llama.h"
#include <cassert>
#include <cstdio>
#include <fstream>
#include <regex>
#include <string>
#include <thread>
#include <vector>
#include <regex>
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
// initialize to the prompt's number of chars, since n_tokens <= n_prompt_chars
std::vector<llama_token> res(text.size() + (int)add_bos);
int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos);
assert(n >= 0);
res.resize(n);
return res;
}
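// Illustrative note (not part of the original source): the wrapper above reserves one slot per
// prompt character (plus one for BOS) and then shrinks to the real token count, e.g.
//
//     std::vector<llama_token> toks = llama_tokenize(ctx_llama, " Hello, world!", /*add_bos=*/true);
//     // toks.size() is now the number of tokens the model's tokenizer actually produced
//
// where ctx_llama is the llama_context created later in main().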
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t voice_ms = 10000;
int32_t capture_id = -1;
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
int32_t n_parts_llama = -1;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false;
bool print_special = false;
bool print_energy = false;
bool no_timestamps = true;
bool verbose_prompt = false;
std::string person = "Georgi";
std::string language = "en";
std::string model_wsp = "models/ggml-base.en.bin";
std::string model_llama = "models/ggml-llama-7B.bin";
std::string speak = "./examples/talk-llama/speak.sh";
std::string prompt = "";
std::string fname_out;
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-vms" || arg == "--voice-ms") { params.voice_ms = std::stoi(argv[++i]); }
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "--n-parts-llama") { params.n_parts_llama = std::stoi(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "--verbose-prompt") { params.verbose_prompt = true; }
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; }
else if (arg == "--prompt-file") {
std::ifstream file(argv[++i]);
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
if (!params.prompt.empty() && params.prompt.back() == '\n') {
params.prompt.pop_back();
}
}
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
}
return true;
}
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -vms N, --voice-ms N [%-7d] voice duration in milliseconds\n", params.voice_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
fprintf(stderr, " -ml FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str());
fprintf(stderr, " --n-parts-llama N [%-7d] num parts in llama model file\n", params.n_parts_llama);
fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", "");
fprintf(stderr, " --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false");
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, "\n");
}
std::string transcribe(
whisper_context * ctx,
const whisper_params & params,
const std::vector<float> & pcmf32,
const std::string prompt_text,
float & prob,
int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();
prob = 0.0f;
t_ms = 0;
std::vector<whisper_token> prompt_tokens;
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, prompt_text.c_str(), prompt_tokens.data(), prompt_tokens.size()));
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = true;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return "";
}
int prob_n = 0;
std::string result;
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
result += text;
const int n_tokens = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < n_tokens; ++j) {
const auto token = whisper_full_get_token_data(ctx, i, j);
prob += token.p;
++prob_n;
}
}
if (prob_n > 0) {
prob /= prob_n;
}
const auto t_end = std::chrono::high_resolution_clock::now();
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
return result;
}
const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)";
const std::string k_prompt_llama = R"(Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
{1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}'s requests immediately and with details and precision.
There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
The transcript only includes text, it does not include markup like HTML and Markdown.
{1} responds with short and concise answers.
{0}{4} Hello, {1}!
{1}{4} Hello {0}! How may I help you today?
{0}{4} What time is it?
{1}{4} It is {2} o'clock.
{0}{4} What year is it?
{1}{4} We are in {3}.
{0}{4} What is a cat?
{1}{4} A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{0}{4} Name a color.
{1}{4} Blue
{0}{4})";
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init
struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
// llama init
auto lparams = llama_context_default_params();
// tune these to your liking
lparams.n_ctx = 2048;
lparams.seed = 1;
lparams.f16_kv = true;
lparams.n_parts = params.n_parts_llama;
struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams);
// print some info about the processing
{
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx_wsp)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
__func__,
params.n_threads,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1);
fprintf(stderr, "\n");
}
// init audio
audio_async audio(30*1000);
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
return 1;
}
audio.resume();
int n_iter = 0;
bool is_running = true;
bool force_speak = false;
float prob0 = 0.0f;
const std::string chat_symb = ":";
const std::string bot_name = "LLaMA";
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
const std::string prompt_whisper = ::replace(k_prompt_whisper, "{1}", bot_name);
// construct the initial prompt for LLaMA inference
std::string prompt_llama = params.prompt.empty() ? k_prompt_llama : params.prompt;
// need to have leading ' '
prompt_llama.insert(0, 1, ' ');
prompt_llama = ::replace(prompt_llama, "{0}", params.person);
prompt_llama = ::replace(prompt_llama, "{1}", bot_name);
{
// get time string
std::string time_str;
{
time_t t = time(0);
struct tm * now = localtime(&t);
char buf[128];
strftime(buf, sizeof(buf), "%H:%M", now);
time_str = buf;
}
prompt_llama = ::replace(prompt_llama, "{2}", time_str);
}
{
// get year string
std::string year_str;
{
time_t t = time(0);
struct tm * now = localtime(&t);
char buf[128];
strftime(buf, sizeof(buf), "%Y", now);
year_str = buf;
}
prompt_llama = ::replace(prompt_llama, "{3}", year_str);
}
prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
// evaluate the initial prompt
auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
printf("\n");
printf("%s : initializing - please wait ...\n", __func__);
if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
if (params.verbose_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s", prompt_llama.c_str());
fflush(stdout);
}
printf("%s : done! start speaking in the microphone\n", __func__);
printf("\n");
printf("%s%s", params.person.c_str(), chat_symb.c_str());
fflush(stdout);
// clear audio buffer
audio.clear();
// text inference variables
const int voice_id = 2;
const int n_keep = embd_inp.size();
const int n_ctx = llama_n_ctx(ctx_llama);
int n_past = n_keep;
int n_prev = 64; // TODO arg
std::vector<llama_token> embd;
// reverse prompts for detecting when it's time to stop speaking
std::vector<std::string> antiprompts = {
params.person + chat_symb,
};
// main loop
while (is_running) {
// handle Ctrl + C
is_running = sdl_poll_events();
if (!is_running) {
break;
}
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
int64_t t_ms = 0;
{
audio.get(2000, pcmf32_cur);
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
//fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
audio.get(params.voice_ms, pcmf32_cur);
std::string text_heard;
if (!force_speak) {
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms));
}
// remove text between brackets using regex
{
std::regex re("\\[.*?\\]");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove text between parentheses using regex
{
std::regex re("\\(.*?\\)");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
// take first line
text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
// remove leading and trailing whitespace
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");
const std::vector<llama_token> tokens = llama_tokenize(ctx_llama, text_heard.c_str(), false);
if (text_heard.empty() || tokens.empty() || force_speak) {
//fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
audio.clear();
continue;
}
force_speak = false;
text_heard.insert(0, 1, ' ');
text_heard += "\n" + bot_name + chat_symb;
fprintf(stdout, "%s%s%s", "\033[1m", text_heard.c_str(), "\033[0m");
fflush(stdout);
embd = ::llama_tokenize(ctx_llama, text_heard, false);
// text inference
bool done = false;
std::string text_to_speak;
while (true) {
// predict
if (embd.size() > 0) {
if (n_past + (int) embd.size() > n_ctx) {
n_past = n_keep;
// re-inject the last n_prev tokens of the accepted context (embd_inp) at the start of embd
embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());
//printf("\n---\n");
//printf("resetting: '");
//for (int i = 0; i < (int) embd.size(); i++) {
// printf("%s", llama_token_to_str(ctx_llama, embd[i]));
//}
//printf("'\n");
//printf("\n---\n");
}
if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
}
//printf("n_iter = %d, n_past = %d, n_ctx = %d, n_keep = %d, n_prev = %d, embd.size() = %d\n", n_iter, n_past, n_ctx, n_keep, n_prev, (int) embd.size());
embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
n_past += embd.size();
embd.clear();
if (done) break;
{
// out of user input, sample next token
const float top_k = 5;
const float top_p = 0.80f;
const float temp = 0.30f;
const float repeat_penalty = 1.1764f;
const int repeat_last_n = 256;
llama_token id = 0;
{
auto logits = llama_get_logits(ctx_llama);
logits[llama_token_eos()] = 0;
id = llama_sample_top_p_top_k(ctx_llama,
embd_inp.data() + std::max(0, n_past - repeat_last_n),
repeat_last_n, top_k, top_p, temp, repeat_penalty);
}
if (id != llama_token_eos()) {
// add it to the context
embd.push_back(id);
text_to_speak += llama_token_to_str(ctx_llama, id);
printf("%s", llama_token_to_str(ctx_llama, id));
}
}
{
std::string last_output;
for (int i = embd_inp.size() - 16; i < (int) embd_inp.size(); i++) {
last_output += llama_token_to_str(ctx_llama, embd_inp[i]);
}
last_output += llama_token_to_str(ctx_llama, embd[0]);
for (std::string & antiprompt : antiprompts) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
done = true;
text_to_speak = ::replace(text_to_speak, antiprompt, "");
fflush(stdout);
break;
}
}
}
is_running = sdl_poll_events();
if (!is_running) {
break;
}
}
text_to_speak = ::replace(text_to_speak, "\"", "");
system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
audio.clear();
++n_iter;
}
}
}
audio.pause();
whisper_print_timings(ctx_wsp);
whisper_free(ctx_wsp);
llama_print_timings(ctx_llama);
llama_free(ctx_llama);
return 0;
}

View File

@ -841,7 +841,6 @@ struct gpt2_context * gpt2_init(const char * path_model) {
if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) { if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, "gpt-2.bin"); fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, "gpt-2.bin");
delete ctx;
return nullptr; return nullptr;
} }

View File

@ -1 +1 @@
audio.mp3 eleven-labs.py

View File

@ -31,7 +31,7 @@ To run this, you will need a ggml GPT-2 model: [instructions](https://github.com
Alternatively, you can simply download the smallest ggml GPT-2 117M model (240 MB) like this: Alternatively, you can simply download the smallest ggml GPT-2 117M model (240 MB) like this:
``` ```
wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/ggerganov/ggml/raw/main/ggml-model-gpt-2-117M.bin wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/datasets/ggerganov/ggml/raw/main/ggml-model-gpt-2-117M.bin
``` ```
## TTS ## TTS

View File

@ -1,23 +0,0 @@
import sys
import importlib.util
api_key = "" #Write your https://beta.elevenlabs.io api key here
if not api_key:
print("To use elevenlabs you have to register to https://beta.elevenlabs.io and add your elevenlabs api key to examples/talk/eleven-labs.py")
sys.exit()
if importlib.util.find_spec("elevenlabs") is None:
print("elevenlabs library is not installed, you can install it to your enviroment using 'pip install elevenlabs'")
sys.exit()
from elevenlabs import ElevenLabs
eleven = ElevenLabs(api_key)
# Get a Voice object, by name or UUID
voice = eleven.voices["Arnold"] #Possible Voices: Adam Antoni Arnold Bella Domi Elli Josh
# Generate the TTS
audio = voice.generate(str(sys.argv[2:]))
# Save the TTS to a file
audio.save("audio")

View File

@ -7,13 +7,9 @@
# Mac OS: brew install espeak # Mac OS: brew install espeak
# Linux: apt-get install espeak # Linux: apt-get install espeak
# #
#espeak -v en-us+m$1 -s 175 -p 50 -a 200 -g 5 -k 5 "$2" espeak -v en-us+m$1 -s 175 -p 50 -a 200 -g 5 -k 5 "$2"
# Mac OS "say" command
say "$2"
# Eleven Labs # Eleven Labs
# To use it, install the elevenlabs module with pip (pip install elevenlabs), register at https://beta.elevenlabs.io to get an API key, and paste it into /examples/talk/eleven-labs.py
# #
#wd=$(dirname $0) #wd=$(dirname $0)
#script=$wd/eleven-labs.py #script=$wd/eleven-labs.py

View File

@ -9,4 +9,4 @@ To use:
5. Select the "release" active build variant, and use Android Studio to run and deploy to your device. 5. Select the "release" active build variant, and use Android Studio to run and deploy to your device.
[^1]: I recommend the tiny or base models for running on an Android device. [^1]: I recommend the tiny or base models for running on an Android device.
<img width="300" alt="image" src="https://user-images.githubusercontent.com/1670775/221613663-a17bf770-27ef-45ab-9a46-a5f99ba65d2a.jpg"> <img width="300" alt="image" src="https://user-images.githubusercontent.com/1991296/208154256-82d972dc-221b-48c4-bfcb-36ce68602f93.png">

View File

@ -2,7 +2,6 @@ package com.whispercppdemo.ui.main
import androidx.compose.foundation.layout.* import androidx.compose.foundation.layout.*
import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.text.selection.SelectionContainer
import androidx.compose.foundation.verticalScroll import androidx.compose.foundation.verticalScroll
import androidx.compose.material3.* import androidx.compose.material3.*
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
@ -20,7 +19,6 @@ fun MainScreen(viewModel: MainScreenViewModel) {
canTranscribe = viewModel.canTranscribe, canTranscribe = viewModel.canTranscribe,
isRecording = viewModel.isRecording, isRecording = viewModel.isRecording,
messageLog = viewModel.dataLog, messageLog = viewModel.dataLog,
onBenchmarkTapped = viewModel::benchmark,
onTranscribeSampleTapped = viewModel::transcribeSample, onTranscribeSampleTapped = viewModel::transcribeSample,
onRecordTapped = viewModel::toggleRecord onRecordTapped = viewModel::toggleRecord
) )
@ -32,7 +30,6 @@ private fun MainScreen(
canTranscribe: Boolean, canTranscribe: Boolean,
isRecording: Boolean, isRecording: Boolean,
messageLog: String, messageLog: String,
onBenchmarkTapped: () -> Unit,
onTranscribeSampleTapped: () -> Unit, onTranscribeSampleTapped: () -> Unit,
onRecordTapped: () -> Unit onRecordTapped: () -> Unit
) { ) {
@ -48,11 +45,8 @@ private fun MainScreen(
.padding(innerPadding) .padding(innerPadding)
.padding(16.dp) .padding(16.dp)
) { ) {
Column(verticalArrangement = Arrangement.SpaceBetween) { Row(horizontalArrangement = Arrangement.SpaceBetween) {
Row(horizontalArrangement = Arrangement.SpaceBetween, modifier = Modifier.fillMaxWidth()) { TranscribeSampleButton(enabled = canTranscribe, onClick = onTranscribeSampleTapped)
BenchmarkButton(enabled = canTranscribe, onClick = onBenchmarkTapped)
TranscribeSampleButton(enabled = canTranscribe, onClick = onTranscribeSampleTapped)
}
RecordButton( RecordButton(
enabled = canTranscribe, enabled = canTranscribe,
isRecording = isRecording, isRecording = isRecording,
@ -66,16 +60,7 @@ private fun MainScreen(
@Composable @Composable
private fun MessageLog(log: String) { private fun MessageLog(log: String) {
SelectionContainer() { Text(modifier = Modifier.verticalScroll(rememberScrollState()), text = log)
Text(modifier = Modifier.verticalScroll(rememberScrollState()), text = log)
}
}
@Composable
private fun BenchmarkButton(enabled: Boolean, onClick: () -> Unit) {
Button(onClick = onClick, enabled = enabled) {
Text("Benchmark")
}
} }
@Composable @Composable

View File

@ -41,15 +41,10 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
init { init {
viewModelScope.launch { viewModelScope.launch {
printSystemInfo()
loadData() loadData()
} }
} }
private suspend fun printSystemInfo() {
printMessage(String.format("System Info: %s\n", WhisperContext.getSystemInfo()));
}
private suspend fun loadData() { private suspend fun loadData() {
printMessage("Loading data...\n") printMessage("Loading data...\n")
try { try {
@ -86,29 +81,10 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
//whisperContext = WhisperContext.createContextFromFile(firstModel.absolutePath) //whisperContext = WhisperContext.createContextFromFile(firstModel.absolutePath)
} }
fun benchmark() = viewModelScope.launch {
runBenchmark(6)
}
fun transcribeSample() = viewModelScope.launch { fun transcribeSample() = viewModelScope.launch {
transcribeAudio(getFirstSample()) transcribeAudio(getFirstSample())
} }
private suspend fun runBenchmark(nthreads: Int) {
if (!canTranscribe) {
return
}
canTranscribe = false
printMessage("Running benchmark. This will take minutes...\n")
whisperContext?.benchMemory(nthreads)?.let{ printMessage(it) }
printMessage("\n")
whisperContext?.benchGgmlMulMat(nthreads)?.let{ printMessage(it) }
canTranscribe = true
}
private suspend fun getFirstSample(): File = withContext(Dispatchers.IO) { private suspend fun getFirstSample(): File = withContext(Dispatchers.IO) {
samplesPath.listFiles()!!.first() samplesPath.listFiles()!!.first()
} }
@ -138,14 +114,11 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
canTranscribe = false canTranscribe = false
try { try {
printMessage("Reading wave samples... ") printMessage("Reading wave samples...\n")
val data = readAudioSamples(file) val data = readAudioSamples(file)
printMessage("${data.size / (16000 / 1000)} ms\n")
printMessage("Transcribing data...\n") printMessage("Transcribing data...\n")
val start = System.currentTimeMillis()
val text = whisperContext?.transcribeData(data) val text = whisperContext?.transcribeData(data)
val elapsed = System.currentTimeMillis() - start printMessage("Done: $text\n")
printMessage("Done ($elapsed ms): $text\n")
} catch (e: Exception) { } catch (e: Exception) {
Log.w(LOG_TAG, e) Log.w(LOG_TAG, e)
printMessage("${e.localizedMessage}\n") printMessage("${e.localizedMessage}\n")

View File

@ -27,14 +27,6 @@ class WhisperContext private constructor(private var ptr: Long) {
} }
} }
suspend fun benchMemory(nthreads: Int): String = withContext(scope.coroutineContext) {
return@withContext WhisperLib.benchMemcpy(nthreads)
}
suspend fun benchGgmlMulMat(nthreads: Int): String = withContext(scope.coroutineContext) {
return@withContext WhisperLib.benchGgmlMulMat(nthreads)
}
suspend fun release() = withContext(scope.coroutineContext) { suspend fun release() = withContext(scope.coroutineContext) {
if (ptr != 0L) { if (ptr != 0L) {
WhisperLib.freeContext(ptr) WhisperLib.freeContext(ptr)
@ -74,10 +66,6 @@ class WhisperContext private constructor(private var ptr: Long) {
} }
return WhisperContext(ptr) return WhisperContext(ptr)
} }
fun getSystemInfo(): String {
return WhisperLib.getSystemInfo()
}
} }
} }
@ -86,7 +74,6 @@ private class WhisperLib {
init { init {
Log.d(LOG_TAG, "Primary ABI: ${Build.SUPPORTED_ABIS[0]}") Log.d(LOG_TAG, "Primary ABI: ${Build.SUPPORTED_ABIS[0]}")
var loadVfpv4 = false var loadVfpv4 = false
var loadV8fp16 = false
if (isArmEabiV7a()) { if (isArmEabiV7a()) {
// armeabi-v7a needs runtime detection support // armeabi-v7a needs runtime detection support
val cpuInfo = cpuInfo() val cpuInfo = cpuInfo()
@ -97,24 +84,11 @@ private class WhisperLib {
loadVfpv4 = true loadVfpv4 = true
} }
} }
} else if (isArmEabiV8a()) {
// ARMv8.2a needs runtime detection support
val cpuInfo = cpuInfo()
cpuInfo?.let {
Log.d(LOG_TAG, "CPU info: $cpuInfo")
if (cpuInfo.contains("fphp")) {
Log.d(LOG_TAG, "CPU supports fp16 arithmetic")
loadV8fp16 = true
}
}
} }
if (loadVfpv4) { if (loadVfpv4) {
Log.d(LOG_TAG, "Loading libwhisper_vfpv4.so") Log.d(LOG_TAG, "Loading libwhisper_vfpv4.so")
System.loadLibrary("whisper_vfpv4") System.loadLibrary("whisper_vfpv4")
} else if (loadV8fp16) {
Log.d(LOG_TAG, "Loading libwhisper_v8fp16_va.so")
System.loadLibrary("whisper_v8fp16_va")
} else { } else {
Log.d(LOG_TAG, "Loading libwhisper.so") Log.d(LOG_TAG, "Loading libwhisper.so")
System.loadLibrary("whisper") System.loadLibrary("whisper")
@ -129,9 +103,6 @@ private class WhisperLib {
external fun fullTranscribe(contextPtr: Long, audioData: FloatArray) external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
external fun getTextSegmentCount(contextPtr: Long): Int external fun getTextSegmentCount(contextPtr: Long): Int
external fun getTextSegment(contextPtr: Long, index: Int): String external fun getTextSegment(contextPtr: Long, index: Int): String
external fun getSystemInfo(): String
external fun benchMemcpy(nthread: Int): String
external fun benchGgmlMulMat(nthread: Int): String
} }
} }
@ -139,10 +110,6 @@ private fun isArmEabiV7a(): Boolean {
return Build.SUPPORTED_ABIS[0].equals("armeabi-v7a") return Build.SUPPORTED_ABIS[0].equals("armeabi-v7a")
} }
private fun isArmEabiV8a(): Boolean {
return Build.SUPPORTED_ABIS[0].equals("arm64-v8a")
}
private fun cpuInfo(): String? { private fun cpuInfo(): String? {
return try { return try {
File("/proc/cpuinfo").inputStream().bufferedReader().use { File("/proc/cpuinfo").inputStream().bufferedReader().use {

View File

@ -13,14 +13,3 @@ ifeq ($(TARGET_ARCH_ABI),armeabi-v7a)
LOCAL_CFLAGS += -mfpu=neon-vfpv4 LOCAL_CFLAGS += -mfpu=neon-vfpv4
include $(BUILD_SHARED_LIBRARY) include $(BUILD_SHARED_LIBRARY)
endif endif
ifeq ($(TARGET_ARCH_ABI),arm64-v8a)
include $(CLEAR_VARS)
LOCAL_MODULE := libwhisper_v8fp16_va
include $(LOCAL_PATH)/Whisper.mk
# Allow building NEON FMA code.
# https://android.googlesource.com/platform/ndk/+/master/sources/android/cpufeatures/cpu-features.h
LOCAL_CFLAGS += -march=armv8.2-a+fp16
include $(BUILD_SHARED_LIBRARY)
endif

View File

@ -6,7 +6,6 @@
#include <sys/sysinfo.h> #include <sys/sysinfo.h>
#include <string.h> #include <string.h>
#include "whisper.h" #include "whisper.h"
#include "ggml.h"
#define UNUSED(x) (void)(x) #define UNUSED(x) (void)(x)
#define TAG "JNI" #define TAG "JNI"
@ -215,29 +214,3 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_getTextSegment(
jstring string = (*env)->NewStringUTF(env, text); jstring string = (*env)->NewStringUTF(env, text);
return string; return string;
} }
JNIEXPORT jstring JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_getSystemInfo(
JNIEnv *env, jobject thiz
) {
UNUSED(thiz);
const char *sysinfo = whisper_print_system_info();
jstring string = (*env)->NewStringUTF(env, sysinfo);
return string;
}
JNIEXPORT jstring JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_benchMemcpy(JNIEnv *env, jobject thiz,
jint n_threads) {
UNUSED(thiz);
const char *bench_ggml_memcpy = whisper_bench_memcpy_str(n_threads);
jstring string = (*env)->NewStringUTF(env, bench_ggml_memcpy);
return string;
}
JNIEXPORT jstring JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_benchGgmlMulMat(JNIEnv *env, jobject thiz,
jint n_threads) {
UNUSED(thiz);
const char *bench_ggml_mul_mat = whisper_bench_ggml_mul_mat_str(n_threads);
jstring string = (*env)->NewStringUTF(env, bench_ggml_mul_mat);
return string;
}

View File

@ -24,5 +24,3 @@ Also, don't forget to add the `-DGGML_USE_ACCELERATE` compiler flag in Build Pha
This can significantly improve the performance of the transcription: This can significantly improve the performance of the transcription:
<img width="1072" alt="image" src="https://user-images.githubusercontent.com/1991296/208511239-8d7cdbd1-aa48-41b5-becd-ca288d53cc07.png"> <img width="1072" alt="image" src="https://user-images.githubusercontent.com/1991296/208511239-8d7cdbd1-aa48-41b5-becd-ca288d53cc07.png">
This project also adds `-O3 -DNDEBUG` to `Other C Flags`, but adding flags at the app project level is not ideal in the real world (it applies to all C/C++ files); consider splitting the xcodeproj into a workspace in your own project.

View File

@ -296,10 +296,6 @@
IPHONEOS_DEPLOYMENT_TARGET = 16.0; IPHONEOS_DEPLOYMENT_TARGET = 16.0;
MTL_ENABLE_DEBUG_INFO = NO; MTL_ENABLE_DEBUG_INFO = NO;
MTL_FAST_MATH = YES; MTL_FAST_MATH = YES;
OTHER_CFLAGS = (
"-O3",
"-DNDEBUG",
);
SDKROOT = iphoneos; SDKROOT = iphoneos;
VALIDATE_PRODUCT = YES; VALIDATE_PRODUCT = YES;
}; };

View File

@ -1,18 +1,14 @@
A sample SwiftUI app using [whisper.cpp](https://github.com/ggerganov/whisper.cpp/) to do voice-to-text transcriptions. A sample SwiftUI app using [whisper.cpp](https://github.com/ggerganov/whisper.cpp/) to do voice-to-text transcriptions.
See also: [whisper.objc](https://github.com/ggerganov/whisper.cpp/tree/master/examples/whisper.objc). See also: [whisper.objc](https://github.com/ggerganov/whisper.cpp/tree/master/examples/whisper.objc).
**Usage**: To use:
1. Select a model from the [whisper.cpp repository](https://github.com/ggerganov/whisper.cpp/tree/master/models).[^1] 1. Select a model from the [whisper.cpp repository](https://github.com/ggerganov/whisper.cpp/tree/master/models).[^1]
2. Add the model to `whisper.swiftui.demo/Resources/models` **via Xcode**. 2. Add the model to "whisper.swiftui.demo/Resources/models" via Xcode.
3. Select a sample audio file (for example, [jfk.wav](https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav)). 3. Select a sample audio file (for example, [jfk.wav](https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav)).
4. Add the sample audio file to `whisper.swiftui.demo/Resources/samples` **via Xcode**. 4. Add the model to "whisper.swiftui.demo/Resources/samples" via Xcode.
5. Select the "Release" [^2] build configuration under "Run", then deploy and run to your device. 5. Select the "release" build configuration under "Run", then deploy and run to your device.
**Note:** Pay attention to the folder path: `whisper.swiftui.demo/Resources/models` is the directory for resource files, while `whisper.swiftui.demo/Models` contains the actual code.
[^1]: I recommend the tiny, base or small models for running on an iOS device. [^1]: I recommend the tiny, base or small models for running on an iOS device.
[^2]: The `Release` build can significantly boost transcription performance. This project also adds `-O3 -DNDEBUG` to `Other C Flags`, but adding flags at the app project level is not ideal in the real world (it applies to all C/C++ files); consider splitting the xcodeproj into a workspace in your own project.
![image](https://user-images.githubusercontent.com/1991296/212539216-0aef65e4-f882-480a-8358-0f816838fd52.png) ![image](https://user-images.githubusercontent.com/1991296/212539216-0aef65e4-f882-480a-8358-0f816838fd52.png)

View File

@ -430,10 +430,6 @@
LLVM_LTO = YES; LLVM_LTO = YES;
MACOSX_DEPLOYMENT_TARGET = 13.0; MACOSX_DEPLOYMENT_TARGET = 13.0;
MARKETING_VERSION = 1.0; MARKETING_VERSION = 1.0;
OTHER_CFLAGS = (
"-O3",
"-DNDEBUG",
);
PRODUCT_BUNDLE_IDENTIFIER = com.whispercppdemo.WhisperCppDemo; PRODUCT_BUNDLE_IDENTIFIER = com.whispercppdemo.WhisperCppDemo;
PRODUCT_NAME = "$(TARGET_NAME)"; PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto; SDKROOT = auto;

View File

@ -64,10 +64,6 @@ for model in "${models[@]}"; do
config="$config BLAS" config="$config BLAS"
fi fi
if [[ $system_info == *"COREML = 1"* ]]; then
config="$config COREML"
fi
commit=$(git rev-parse --short HEAD) commit=$(git rev-parse --short HEAD)
printf "| <todo> | <todo> | $config | $model | $n_threads | $load_time | $encode_time | $commit |\n" printf "| <todo> | <todo> | $config | $model | $n_threads | $load_time | $encode_time | $commit |\n"

View File

@ -1,70 +0,0 @@
# Benchmark word-level timestamps for different models
#
# This script takes two arguments
# - an audio file
# - [optional] path to a font file
# I'm using "/usr/share/fonts/truetype/freefont/FreeMono.ttf" on Ubuntu
if [ -z "$1" ]; then
echo "Usage: $0 <audio file> [font file]"
exit 1
fi
#TODO: Make this a command line parameter
#models="base small large"
#models="tiny.en tiny base.en base small.en small medium.en medium large-v1 large"
models="tiny.en base.en small.en medium.en large"
DURATION=$(ffprobe -i $1 -show_entries format=duration -v quiet -of csv="p=0")
DURATION=$(printf "%.2f" $DURATION)
echo "Input file duration: ${DURATION}s"
for model in $models; do
echo "Running $model"
COMMAND="./main -m models/ggml-$model.bin -owts -f $1 -of $1.$model"
if [ ! -z "$2" ]; then
COMMAND="$COMMAND -fp $2"
fi
#TODO: Surface errors better
# TIMEFMT is for zsh, TIMEFORMAT is for bash
EXECTIME=$({ TIMEFMT="%E";TIMEFORMAT=%E; time $COMMAND >/dev/null 2>&1; } 2>&1)
# Slightly different formats between zsh and bash
if [ "${EXECTIME: -1}" == "s" ]; then
EXECTIME=${EXECTIME::-1}
fi
RATIO=$(echo "$DURATION / $EXECTIME" | bc -l)
RATIO=$(printf "%.2f" $RATIO)
echo "Execution time: ${EXECTIME}s (${RATIO}x realtime)"
# If the file already exists, delete it
if [ -f $1.mp4 ]; then
rm $1.mp4
fi
bash $1.$model.wts >/dev/null 2>&1
mv $1.mp4 $1.$model.mp4
ffmpeg -y -f lavfi -i color=c=black:s=1200x50:d=$DURATION -vf "drawtext=fontfile=$2:fontsize=36:x=10:y=(h-text_h)/2:text='ggml-$model - ${EXECTIME}s (${RATIO}x realtime)':fontcolor=lightgrey" $1.$model.info.mp4 >/dev/null 2>&1
done
COMMAND="ffmpeg -y"
for model in $models; do
COMMAND="$COMMAND -i $1.$model.info.mp4 -i $1.$model.mp4"
done
COMMAND="$COMMAND -filter_complex \""
COUNT=0
for model in $models; do
COMMAND="$COMMAND[${COUNT}:v][$(($COUNT+1)):v]"
COUNT=$((COUNT+2))
done
COMMAND="$COMMAND vstack=inputs=${COUNT}[v]\" -map \"[v]\" -map 1:a $1.all.mp4 >/dev/null 2>&1"
echo $COMMAND
# Run the command
eval $COMMAND

4725
ggml.c

File diff suppressed because it is too large

166
ggml.h
View File

@ -177,12 +177,11 @@ extern "C" {
#include <stddef.h> #include <stddef.h>
#include <stdbool.h> #include <stdbool.h>
#define GGML_MAX_DIMS 4 #define GGML_MAX_DIMS 4
#define GGML_MAX_NODES 4096 #define GGML_MAX_NODES 4096
#define GGML_MAX_PARAMS 16 #define GGML_MAX_PARAMS 16
#define GGML_MAX_CONTEXTS 64 #define GGML_MAX_CONTEXTS 64
#define GGML_MAX_OPT 4 #define GGML_MAX_OPT 4
#define GGML_DEFAULT_N_THREADS 4
#ifdef __ARM_NEON #ifdef __ARM_NEON
// we use the built-in 16-bit float type // we use the built-in 16-bit float type
@ -199,14 +198,11 @@ struct ggml_object;
struct ggml_context; struct ggml_context;
enum ggml_type { enum ggml_type {
// explicitly numbered values are used in llama.cpp files
GGML_TYPE_F32 = 0,
GGML_TYPE_F16 = 1,
GGML_TYPE_Q4_0 = 2,
GGML_TYPE_Q4_1 = 3,
GGML_TYPE_I8, GGML_TYPE_I8,
GGML_TYPE_I16, GGML_TYPE_I16,
GGML_TYPE_I32, GGML_TYPE_I32,
GGML_TYPE_F16,
GGML_TYPE_F32,
GGML_TYPE_COUNT, GGML_TYPE_COUNT,
}; };
@ -230,15 +226,12 @@ enum ggml_op {
GGML_OP_STEP, GGML_OP_STEP,
GGML_OP_RELU, GGML_OP_RELU,
GGML_OP_GELU, GGML_OP_GELU,
GGML_OP_SILU,
GGML_OP_NORM, // normalize GGML_OP_NORM, // normalize
GGML_OP_RMS_NORM,
GGML_OP_MUL_MAT, GGML_OP_MUL_MAT,
GGML_OP_SCALE, GGML_OP_SCALE,
GGML_OP_CPY, GGML_OP_CPY,
GGML_OP_CONT,
GGML_OP_RESHAPE, GGML_OP_RESHAPE,
GGML_OP_VIEW, GGML_OP_VIEW,
GGML_OP_PERMUTE, GGML_OP_PERMUTE,
@ -253,35 +246,19 @@ enum ggml_op {
GGML_OP_FLASH_ATTN, GGML_OP_FLASH_ATTN,
GGML_OP_FLASH_FF, GGML_OP_FLASH_FF,
GGML_OP_MAP_UNARY,
GGML_OP_MAP_BINARY,
GGML_OP_COUNT, GGML_OP_COUNT,
}; };
// ggml object
struct ggml_object {
size_t offs;
size_t size;
struct ggml_object * next;
char padding[8];
};
static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
// n-dimensional tensor // n-dimensional tensor
struct ggml_tensor { struct ggml_tensor {
enum ggml_type type; enum ggml_type type;
int n_dims; int n_dims;
int64_t ne[GGML_MAX_DIMS]; // number of elements int ne[GGML_MAX_DIMS]; // number of elements
size_t nb[GGML_MAX_DIMS]; // stride in bytes: size_t nb[GGML_MAX_DIMS]; // stride in bytes:
// nb[0] = sizeof(type) // nb[0] = sizeof(type)
// nb[1] = nb[0] * ne[0] + padding // nb[1] = nb[0] * ne[0] + padding
// nb[i] = nb[i-1] * ne[i-1] // nb[i] = nb[i-1] * ne[i-1]
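// (illustrative note, not in the original header) e.g. an F32 tensor with ne = {4, 3, 1, 1} and no padding has nb = {4, 16, 48, 48}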
// compute data // compute data
enum ggml_op op; enum ggml_op op;
@ -335,7 +312,6 @@ struct ggml_init_params {
// memory pool // memory pool
size_t mem_size; // bytes size_t mem_size; // bytes
void * mem_buffer; // if NULL, memory will be allocated internally void * mem_buffer; // if NULL, memory will be allocated internally
bool no_alloc; // don't allocate memory for the tensor data
}; };
void ggml_time_init(void); // call this once at the beginning of the program void ggml_time_init(void); // call this once at the beginning of the program
@ -347,13 +323,10 @@ int64_t ggml_cycles_per_ms(void);
void ggml_print_object (const struct ggml_object * obj); void ggml_print_object (const struct ggml_object * obj);
void ggml_print_objects(const struct ggml_context * ctx); void ggml_print_objects(const struct ggml_context * ctx);
int64_t ggml_nelements(const struct ggml_tensor * tensor); int ggml_nelements(const struct ggml_tensor * tensor);
size_t ggml_nbytes (const struct ggml_tensor * tensor); size_t ggml_nbytes (const struct ggml_tensor * tensor);
int ggml_blck_size (enum ggml_type type);
size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
size_t ggml_type_size (enum ggml_type type);
size_t ggml_element_size(const struct ggml_tensor * tensor); size_t ggml_element_size(const struct ggml_tensor * tensor);
struct ggml_context * ggml_init(struct ggml_init_params params); struct ggml_context * ggml_init(struct ggml_init_params params);
@ -367,33 +340,33 @@ struct ggml_tensor * ggml_new_tensor(
struct ggml_context * ctx, struct ggml_context * ctx,
enum ggml_type type, enum ggml_type type,
int n_dims, int n_dims,
const int64_t *ne); const int *ne);
struct ggml_tensor * ggml_new_tensor_1d( struct ggml_tensor * ggml_new_tensor_1d(
struct ggml_context * ctx, struct ggml_context * ctx,
enum ggml_type type, enum ggml_type type,
int64_t ne0); int ne0);
struct ggml_tensor * ggml_new_tensor_2d( struct ggml_tensor * ggml_new_tensor_2d(
struct ggml_context * ctx, struct ggml_context * ctx,
enum ggml_type type, enum ggml_type type,
int64_t ne0, int ne0,
int64_t ne1); int ne1);
struct ggml_tensor * ggml_new_tensor_3d( struct ggml_tensor * ggml_new_tensor_3d(
struct ggml_context * ctx, struct ggml_context * ctx,
enum ggml_type type, enum ggml_type type,
int64_t ne0, int ne0,
int64_t ne1, int ne1,
int64_t ne2); int ne2);
struct ggml_tensor * ggml_new_tensor_4d( struct ggml_tensor * ggml_new_tensor_4d(
struct ggml_context * ctx, struct ggml_context * ctx,
enum ggml_type type, enum ggml_type type,
int64_t ne0, int ne0,
int64_t ne1, int ne1,
int64_t ne2, int ne2,
int64_t ne3); int ne3);
struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value); struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value); struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
@ -493,20 +466,12 @@ struct ggml_tensor * ggml_gelu(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
struct ggml_tensor * ggml_silu(
struct ggml_context * ctx,
struct ggml_tensor * a);
// normalize along rows // normalize along rows
// TODO: eps is hardcoded to 1e-5 for now // TODO: eps is hardcoded to 1e-5 for now
struct ggml_tensor * ggml_norm( struct ggml_tensor * ggml_norm(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx,
struct ggml_tensor * a);
// A: m rows, n columns // A: m rows, n columns
// B: p rows, n columns (i.e. we transpose it internally) // B: p rows, n columns (i.e. we transpose it internally)
// result is m columns, p rows // result is m columns, p rows
@ -531,11 +496,6 @@ struct ggml_tensor * ggml_cpy(
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b); struct ggml_tensor * b);
// make contiguous
struct ggml_tensor * ggml_cont(
struct ggml_context * ctx,
struct ggml_tensor * a);
// return view(a), b specifies the new shape // return view(a), b specifies the new shape
// TODO: when we start computing gradient, make a copy instead of view // TODO: when we start computing gradient, make a copy instead of view
struct ggml_tensor * ggml_reshape( struct ggml_tensor * ggml_reshape(
@ -548,43 +508,33 @@ struct ggml_tensor * ggml_reshape(
struct ggml_tensor * ggml_reshape_2d( struct ggml_tensor * ggml_reshape_2d(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int64_t ne0, int ne0,
int64_t ne1); int ne1);
// return view(a) // return view(a)
// TODO: when we start computing gradient, make a copy instead of view // TODO: when we start computing gradient, make a copy instead of view
struct ggml_tensor * ggml_reshape_3d( struct ggml_tensor * ggml_reshape_3d(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int64_t ne0, int ne0,
int64_t ne1, int ne1,
int64_t ne2); int ne2);
// offset in bytes // offset in bytes
struct ggml_tensor * ggml_view_1d( struct ggml_tensor * ggml_view_1d(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int64_t ne0, int ne0,
size_t offset); size_t offset);
struct ggml_tensor * ggml_view_2d( struct ggml_tensor * ggml_view_2d(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int64_t ne0, int ne0,
int64_t ne1, int ne1,
size_t nb1, // row stride in bytes size_t nb1, // row stride in bytes
size_t offset); size_t offset);
struct ggml_tensor * ggml_view_3d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2,
size_t nb1, // row stride in bytes
size_t nb2, // slice stride in bytes
size_t offset);
struct ggml_tensor * ggml_permute( struct ggml_tensor * ggml_permute(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
@ -655,21 +605,6 @@ struct ggml_tensor * ggml_flash_ff(
struct ggml_tensor * c0, struct ggml_tensor * c0,
struct ggml_tensor * c1); struct ggml_tensor * c1);
// Mapping operations
typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *);
typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
struct ggml_tensor * ggml_map_unary_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
const ggml_unary_op_f32_t fun);
struct ggml_tensor * ggml_map_binary_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
const ggml_binary_op_f32_t fun);
// //
// automatic differentiation // automatic differentiation
// //
@ -792,11 +727,14 @@ enum ggml_opt_result ggml_opt(
struct ggml_tensor * f); struct ggml_tensor * f);
// //
// quantization // Temp stuff
// //
size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist); void ggml_svd_reduce_dims(
size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); int ne0,
int ne1,
float * a,
int nd);
// //
// system info // system info
@ -815,30 +753,6 @@ int ggml_cpu_has_blas(void);
int ggml_cpu_has_sse3(void); int ggml_cpu_has_sse3(void);
int ggml_cpu_has_vsx(void); int ggml_cpu_has_vsx(void);
//
// Internal types and functions exposed for tests and benchmarks
//
#ifdef __cplusplus
// restrict not standard in C++
#define GGML_RESTRICT
#else
#define GGML_RESTRICT restrict
#endif
typedef void (*dequantize_row_q_t)(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
typedef void (*quantize_row_q_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
typedef void (*vec_dot_q_t)(const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
typedef struct {
dequantize_row_q_t dequantize_row_q;
quantize_row_q_t quantize_row_q;
quantize_row_q_t quantize_row_q_reference;
vec_dot_q_t vec_dot_q;
} quantize_fns_t;
quantize_fns_t ggml_internal_get_quantize_fn(size_t i);
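// Illustrative sketch (not part of the original header): a quantize/dequantize round trip through
// the function table. Indexing the table by the ggml_type enum value and the 32-element Q4_0 block
// size are assumptions for the example.
//
//     float src[64], dst[64];                               // two Q4_0 blocks worth of floats
//     uint8_t packed[64];                                   // large enough for the quantized blocks
//     quantize_fns_t fns = ggml_internal_get_quantize_fn(GGML_TYPE_Q4_0);
//     fns.quantize_row_q  (src,    packed, 64);
//     fns.dequantize_row_q(packed, dst,    64);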
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -6,7 +6,7 @@ using the [convert-pt-to-ggml.py](convert-pt-to-ggml.py) script. You can either
the `ggml` files yourself using the conversion script, or you can use the [download-ggml-model.sh](download-ggml-model.sh) the `ggml` files yourself using the conversion script, or you can use the [download-ggml-model.sh](download-ggml-model.sh)
script to download the already converted models. Currently, they are hosted on the following locations: script to download the already converted models. Currently, they are hosted on the following locations:
- https://huggingface.co/ggerganov/whisper.cpp - https://huggingface.co/datasets/ggerganov/whisper.cpp
- https://ggml.ggerganov.com - https://ggml.ggerganov.com
Sample usage: Sample usage:
@ -23,7 +23,7 @@ You can now use it like this:
A third option to obtain the model files is to download them from Hugging Face: A third option to obtain the model files is to download them from Hugging Face:
https://huggingface.co/ggerganov/whisper.cpp/tree/main https://huggingface.co/datasets/ggerganov/whisper.cpp/tree/main
## Available models ## Available models

View File

@ -79,11 +79,11 @@ dir_model = sys.argv[1]
dir_whisper = sys.argv[2] dir_whisper = sys.argv[2]
dir_out = sys.argv[3] dir_out = sys.argv[3]
with open(dir_model + "/vocab.json", "r", encoding="utf8") as f: with open(dir_model + "/vocab.json", "r") as f:
encoder = json.load(f) encoder = json.load(f)
with open(dir_model + "/added_tokens.json", "r", encoding="utf8") as f: with open(dir_model + "/added_tokens.json", "r") as f:
encoder_added = json.load(f) encoder_added = json.load(f)
with open(dir_model + "/config.json", "r", encoding="utf8") as f: with open(dir_model + "/config.json", "r") as f:
hparams = json.load(f) hparams = json.load(f)
model = WhisperForConditionalGeneration.from_pretrained(dir_model) model = WhisperForConditionalGeneration.from_pretrained(dir_model)

View File

@ -39,7 +39,6 @@ import json
import code import code
import torch import torch
import numpy as np import numpy as np
import base64
#from transformers import GPTJForCausalLM #from transformers import GPTJForCausalLM
#from transformers import GPT2TokenizerFast #from transformers import GPT2TokenizerFast
@ -225,14 +224,18 @@ with np.load(os.path.join(dir_whisper, "whisper/assets", "mel_filters.npz")) as
#code.interact(local=locals()) #code.interact(local=locals())
multilingual = hparams["n_vocab"] == 51865 multilingual = hparams["n_vocab"] == 51865
tokenizer = os.path.join(dir_whisper, "whisper/assets", multilingual and "multilingual.tiktoken" or "gpt2.tiktoken") dir_tokenizer = os.path.join(dir_whisper, "whisper/assets", multilingual and "multilingual" or "gpt2")
#tokenizer = build_tokenizer(dir_whisper, multilingual and "multilingual" or "gpt2")
#print(tokenizer)
#print(tokenizer.name_or_path)
#print(len(tokenizer.additional_special_tokens))
# output in the same directory as the model # output in the same directory as the model
fname_out = dir_out + "/ggml-model.bin" fname_out = dir_out + "/ggml-model.bin"
with open(tokenizer, "rb") as f: with open(dir_tokenizer + "/vocab.json", "r", encoding="utf8") as f:
contents = f.read() tokens = json.load(f)
tokens = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)}
# use 16-bit or 32-bit floats # use 16-bit or 32-bit floats
use_f16 = True use_f16 = True
@ -268,8 +271,9 @@ byte_decoder = {v:k for k, v in byte_encoder.items()}
fout.write(struct.pack("i", len(tokens))) fout.write(struct.pack("i", len(tokens)))
for key in tokens: for key in tokens:
fout.write(struct.pack("i", len(key))) text = bytearray([byte_decoder[c] for c in key])
fout.write(key) fout.write(struct.pack("i", len(text)))
fout.write(text)
for name in list_vars.keys(): for name in list_vars.keys():
data = list_vars[name].squeeze().numpy() data = list_vars[name].squeeze().numpy()

View File

@ -1,334 +0,0 @@
import argparse
import torch
import torch.nn.functional as F
import coremltools as ct
from torch import Tensor
from torch import nn
from typing import Dict
from typing import Optional
from ane_transformers.reference.layer_norm import LayerNormANE as LayerNormANEBase
from coremltools.models.neural_network.quantization_utils import quantize_weights
from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
from whisper import load_model
# Use for changing dim of input in encoder and decoder embeddings
def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
"""
Unsqueeze twice to map nn.Linear weights to nn.Conv2d weights
"""
for k in state_dict:
is_attention = all(substr in k for substr in ['attn', '.weight'])
is_mlp = any([k.endswith(s) for s in ['mlp.0.weight', 'mlp.2.weight']])
if (is_attention or is_mlp) and len(state_dict[k].shape) == 2:
state_dict[k] = state_dict[k][:, :, None, None]
def correct_for_bias_scale_order_inversion(state_dict, prefix, local_metadata,
strict, missing_keys,
unexpected_keys, error_msgs):
state_dict[prefix + 'bias'] = state_dict[prefix + 'bias'] / state_dict[prefix + 'weight']
return state_dict
class LayerNormANE(LayerNormANEBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._register_load_state_dict_pre_hook(
correct_for_bias_scale_order_inversion)
class MultiHeadAttentionANE(MultiHeadAttention):
def __init__(self, n_state: int, n_head: int):
super().__init__(n_state, n_head)
setattr(self, 'query', nn.Conv2d(n_state, n_state, kernel_size=1))
setattr(self, 'key', nn.Conv2d(n_state, n_state, kernel_size=1, bias=False))
setattr(self, 'value', nn.Conv2d(n_state, n_state, kernel_size=1))
setattr(self, 'out', nn.Conv2d(n_state, n_state, kernel_size=1))
def forward(self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None):
q = self.query(x)
if kv_cache is None or xa is None or self.key not in kv_cache:
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
# otherwise, perform key/value projections for self- or cross-attention as usual.
k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa)
else:
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
k = kv_cache[self.key]
v = kv_cache[self.value]
wv, qk = self.qkv_attention_ane(q, k, v, mask)
return self.out(wv), qk
def qkv_attention_ane(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
_, dim, _, seqlen = q.size()
dim_per_head = dim // self.n_head
scale = float(dim_per_head)**-0.5
q = q * scale
mh_q = q.split(dim_per_head, dim=1)
mh_k = k.transpose(1,3).split(dim_per_head, dim=3)
mh_v = v.split(dim_per_head, dim=1)
mh_qk = [
torch.einsum('bchq,bkhc->bkhq', [qi, ki])
for qi, ki in zip(mh_q, mh_k)
] # (batch_size, max_seq_length, 1, max_seq_length) * n_heads
if mask is not None:
for head_idx in range(self.n_head):
mh_qk[head_idx] = mh_qk[head_idx] + mask[:, :seqlen, :, :seqlen]
attn_weights = [aw.softmax(dim=1) for aw in mh_qk] # (batch_size, max_seq_length, 1, max_seq_length) * n_heads
attn = [torch.einsum('bkhq,bchk->bchq', wi, vi) for wi, vi in zip(attn_weights, mh_v)] # (batch_size, dim_per_head, 1, max_seq_length) * n_heads
attn = torch.cat(attn, dim=1) # (batch_size, dim, 1, max_seq_length)
return attn, torch.cat(mh_qk, dim=1).float().detach()
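
# Illustrative sketch (standalone, not part of the original script): the per-head einsum
# formulation above is numerically the same as ordinary scaled dot-product attention; the
# only difference is the (batch, channels, 1, seq) layout the ANE prefers. Sizes are made up.
def _check_qkv_attention_ane(B=1, C=64, S=6, H=8):
    import torch
    import torch.nn.functional as F
    d = C // H
    q, k, v = (torch.randn(B, C, 1, S) for _ in range(3))

    def to_bhsd(t):                                   # (B, C, 1, S) -> (B, H, S, d)
        return t.squeeze(2).view(B, H, d, S).permute(0, 1, 3, 2)

    ref = F.softmax(to_bhsd(q) @ to_bhsd(k).transpose(-1, -2) * d ** -0.5, dim=-1) @ to_bhsd(v)

    qs = (q * d ** -0.5).split(d, dim=1)
    ks = k.transpose(1, 3).split(d, dim=3)
    vs = v.split(d, dim=1)
    qk = [torch.einsum('bchq,bkhc->bkhq', qi, ki) for qi, ki in zip(qs, ks)]
    w = [s.softmax(dim=1) for s in qk]
    out = torch.cat([torch.einsum('bkhq,bchk->bchq', wi, vi) for wi, vi in zip(w, vs)], dim=1)

    assert torch.allclose(ref, to_bhsd(out), atol=1e-5)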
class ResidualAttentionBlockANE(ResidualAttentionBlock):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
super().__init__(n_state, n_head, cross_attention)
setattr(self, 'attn', MultiHeadAttentionANE(n_state, n_head))
setattr(self, 'attn_ln', LayerNormANE(n_state))
setattr(self, 'cross_attn', MultiHeadAttentionANE(n_state, n_head) if cross_attention else None)
setattr(self, 'cross_attn_ln', LayerNormANE(n_state) if cross_attention else None)
n_mlp = n_state * 4
setattr(self, 'mlp', nn.Sequential(
nn.Conv2d(n_state, n_mlp, kernel_size=1),
nn.GELU(),
nn.Conv2d(n_mlp, n_state, kernel_size=1)
))
setattr(self, 'mlp_ln', LayerNormANE(n_state))
class AudioEncoderANE(AudioEncoder):
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
super().__init__(n_mels, n_ctx, n_state, n_head, n_layer)
setattr(self, 'blocks', nn.ModuleList(
[ResidualAttentionBlockANE(n_state, n_head) for _ in range(n_layer)]
))
setattr(self, 'ln_post', LayerNormANE(n_state))
def forward(self, x: Tensor):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
assert x.shape[1:] == self.positional_embedding.shape[::-1], "incorrect audio shape"
# Add positional embedding and add dummy dim for ANE
x = (x + self.positional_embedding.transpose(0,1)).to(x.dtype).unsqueeze(2)
for block in self.blocks:
x = block(x)
x = self.ln_post(x)
# """
# TODO:
# I think we need to transpose the result here to make it fit whisper.cpp memory order.
# However, even doing this, the results are still wrong. Kind of less wrong compared to
# not transposing, but still wrong.
# Also, I don't know why the original OpenAI implementation does not need to transpose
# transpose to (batch_size, n_ctx, n_state)
# x : torch.Tensor, shape = (batch_size, n_state, 1, n_ctx)
# """
# x = x.transpose(1,3)
return x
class TextDecoderANE(TextDecoder):
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer)
setattr(self, 'blocks', nn.ModuleList(
[ResidualAttentionBlockANE(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
))
setattr(self, 'ln', LayerNormANE(n_state))
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
"""
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
the text tokens
xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
the encoded audio features to be attended on
"""
offset = next(iter(kv_cache.values())).shape[3] if kv_cache else 0
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
x = x.to(xa.dtype)
# Reformat for ANE
mask = self.mask[None, None, :, :].permute(0,3,1,2)
x = x.transpose(1,2).unsqueeze(2)
for block in self.blocks:
x = block(x, xa, mask=mask, kv_cache=kv_cache)
x = self.ln(x)
# Reformat back from ANE
x = x.permute(0,2,3,1).squeeze(0)
# ANE can only load tensors with dim size of at most 16,384 - whisper uses 51,864 (en) or 51,865 (multi-lang) tokens so we need to compute in chunks
if self.token_embedding.weight.shape[0] == 51865:
# split in 11 chunks - 4715 each
splits = self.token_embedding.weight.split(self.token_embedding.weight.shape[0]//11, dim=0)
logits = torch.cat([torch.einsum('bid,jd->bij', x, split) for split in splits]).view(*x.shape[:2], -1)
else:
# split in 12 chunks - 4322 each
assert(self.token_embedding.weight.shape[0] == 51864)
splits = self.token_embedding.weight.split(self.token_embedding.weight.shape[0]//12, dim=0)
logits = torch.cat([torch.einsum('bid,jd->bij', x, split) for split in splits]).view(*x.shape[:2], -1)
return logits
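
# Illustrative sketch (standalone, not part of the original script): for a single decoded
# token (the traced decoder below uses tokens_shape = (1, 1)), the chunked einsum + cat +
# view above reproduces one big matmul against the full embedding matrix. Sizes are made-up
# stand-ins for n_state / n_vocab.
def _check_chunked_logits():
    import torch
    x = torch.randn(1, 1, 8)                     # (batch, n_ctx = 1, n_state)
    emb = torch.randn(12, 8)                     # stand-in for token_embedding.weight
    full = x @ emb.t()                           # (1, 1, 12) in one matmul
    splits = emb.split(12 // 4, dim=0)           # 4 vocabulary chunks
    chunked = torch.cat([torch.einsum('bid,jd->bij', x, s) for s in splits]).view(*x.shape[:2], -1)
    assert torch.allclose(full, chunked, atol=1e-6)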
class WhisperANE(Whisper):
def __init__(self, dims: ModelDimensions):
super().__init__(dims)
setattr(self, 'encoder', AudioEncoderANE(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
))
setattr(self, 'decoder', TextDecoderANE(
self.dims.n_vocab,
self.dims.n_text_ctx,
self.dims.n_text_state,
self.dims.n_text_head,
self.dims.n_text_layer,
))
self._register_load_state_dict_pre_hook(linear_to_conv2d_map)
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
return self.decoder(tokens, self.encoder(mel))
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
cache = {**cache} if cache is not None else {}
hooks = []
def save_to_cache(module, _, output):
if module not in cache or output.shape[3] > self.decoder.positional_embedding.shape[0]:
cache[module] = output # save as-is, for the first token or cross attention
else:
cache[module] = torch.cat([cache[module], output], dim=3).detach()
return cache[module]
def install_hooks(layer: nn.Module):
if isinstance(layer, MultiHeadAttentionANE):
hooks.append(layer.key.register_forward_hook(save_to_cache))
hooks.append(layer.value.register_forward_hook(save_to_cache))
self.decoder.apply(install_hooks)
return cache, hooks
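
# Illustrative sketch (standalone, not part of the original script): install_kv_cache_hooks
# relies on register_forward_hook returning a replacement output; a simplified version of
# save_to_cache on a toy 1x1 Conv2d "key" projection (the positional-embedding size check
# is omitted here) behaves like this:
def _check_kv_cache_hook():
    import torch
    from torch import nn
    cache = {}
    key = nn.Conv2d(4, 4, kernel_size=1)

    def save_to_cache(module, _, output):
        if module not in cache:
            cache[module] = output                # first call: store as-is
        else:                                     # later calls: append along the seq dim
            cache[module] = torch.cat([cache[module], output], dim=3).detach()
        return cache[module]

    hook = key.register_forward_hook(save_to_cache)
    key(torch.randn(1, 4, 1, 1))                  # step 1 -> cached seq length 1
    key(torch.randn(1, 4, 1, 1))                  # step 2 -> cached seq length 2
    hook.remove()
    assert cache[key].shape == (1, 4, 1, 2)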
def convert_encoder(hparams, model, quantize=False):
model.eval()
input_shape = (1, 80, 3000)
input_data = torch.randn(input_shape)
traced_model = torch.jit.trace(model, input_data)
model = ct.convert(
traced_model,
convert_to=None if quantize else "mlprogram", # convert will fail if weights are quantized, not sure why
inputs=[ct.TensorType(name="logmel_data", shape=input_shape)],
outputs=[ct.TensorType(name="output")],
compute_units=ct.ComputeUnit.ALL
)
if quantize:
model = quantize_weights(model, nbits=16)
return model
def convert_decoder(hparams, model, quantize=False):
model.eval()
tokens_shape = (1, 1)
audio_shape = (1, hparams.n_audio_state, 1, 1500)
audio_data = torch.randn(audio_shape)
token_data = torch.randint(50257, tokens_shape).long()
traced_model = torch.jit.trace(model, (token_data, audio_data))
model = ct.convert(
traced_model,
convert_to=None if quantize else "mlprogram", # convert will fail if weights are quantized, not sure why
inputs=[
ct.TensorType(name="token_data", shape=tokens_shape, dtype=int),
ct.TensorType(name="audio_data", shape=audio_shape)
]
)
if quantize:
model = quantize_weights(model, nbits=16)
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large)", required=True)
parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False)
parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False)
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
args = parser.parse_args()
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"]:
raise ValueError("Invalid model name")
whisper = load_model(args.model).cpu()
hparams = whisper.dims
print(hparams)
if args.optimize_ane:
whisperANE = WhisperANE(hparams).eval()
whisperANE.load_state_dict(whisper.state_dict())
encoder = whisperANE.encoder
decoder = whisperANE.decoder
else:
encoder = whisper.encoder
decoder = whisper.decoder
# Convert encoder
encoder = convert_encoder(hparams, encoder, quantize=args.quantize)
encoder.save(f"models/coreml-encoder-{args.model}.mlpackage")
if args.encoder_only is False:
# Convert decoder
decoder = convert_decoder(hparams, decoder, quantize=args.quantize)
decoder.save(f"models/coreml-decoder-{args.model}.mlpackage")
print("done converting")


@ -1,82 +0,0 @@
#!/bin/bash
# This script downloads Whisper model files that have already been converted to Core ML format.
# This way you don't have to convert them yourself.
src="https://huggingface.co/datasets/ggerganov/whisper.cpp-coreml"
pfx="resolve/main/ggml"
# get the path of this script
function get_script_path() {
if [ -x "$(command -v realpath)" ]; then
echo "$(dirname $(realpath $0))"
else
local ret="$(cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P)"
echo "$ret"
fi
}
models_path="$(get_script_path)"
# Whisper models
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large" )
# list available models
function list_models {
printf "\n"
printf " Available models:"
for model in "${models[@]}"; do
printf " $model"
done
printf "\n\n"
}
if [ "$#" -ne 1 ]; then
printf "Usage: $0 <model>\n"
list_models
exit 1
fi
model=$1
if [[ ! " ${models[@]} " =~ " ${model} " ]]; then
printf "Invalid model: $model\n"
list_models
exit 1
fi
# download Core ML model
printf "Downloading Core ML model $model from '$src' ...\n"
cd $models_path
if [ -f "ggml-$model.mlmodel" ]; then
printf "Model $model already exists. Skipping download.\n"
exit 0
fi
if [ -x "$(command -v wget)" ]; then
wget --quiet --show-progress -O ggml-$model.mlmodel $src/$pfx-$model.mlmodel
elif [ -x "$(command -v curl)" ]; then
curl -L --output ggml-$model.mlmodel $src/$pfx-$model.mlmodel
else
printf "Either wget or curl is required to download models.\n"
exit 1
fi
if [ $? -ne 0 ]; then
printf "Failed to download Core ML model $model \n"
printf "Please try again later or download the original Whisper model files and convert them yourself.\n"
exit 1
fi
printf "Done! Model '$model' saved in 'models/ggml-$model.mlmodel'\n"
printf "Run the following command to compile it:\n\n"
printf " $ xcrun coremlc compile ./models/ggml-$model.mlmodel ./models\n\n"
printf "You can now use it like this:\n\n"
printf " $ ./main -m models/ggml-$model.bin -f samples/jfk.wav\n"
printf "\n"


@ -40,7 +40,7 @@ if exist "ggml-%model%.bin" (
   goto :eof
 )
-PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
+PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
 if %ERRORLEVEL% neq 0 (
   echo Failed to download ggml model %model%


@ -6,13 +6,13 @@
 #src="https://ggml.ggerganov.com"
 #pfx="ggml-model-whisper"
-src="https://huggingface.co/ggerganov/whisper.cpp"
+src="https://huggingface.co/datasets/ggerganov/whisper.cpp"
 pfx="resolve/main/ggml"
 # get the path of this script
 function get_script_path() {
     if [ -x "$(command -v realpath)" ]; then
-        echo "$(dirname "$(realpath "$0")")"
+        echo "$(dirname $(realpath $0))"
     else
         local ret="$(cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P)"
         echo "$ret"


@ -1,29 +0,0 @@
#!/bin/bash
#
# This generates:
# - coreml/whisper-encoder-impl.h and coreml/whisper-encoder-impl.m
# - coreml/whisper-decoder-impl.h and coreml/whisper-decoder-impl.m
#
wd=$(dirname "$0")
cd "$wd/../"
python3 models/convert-whisper-to-coreml.py --model tiny.en
mv -v models/coreml-encoder-tiny.en.mlpackage models/whisper-encoder-impl.mlpackage
xcrun coremlc generate models/whisper-encoder-impl.mlpackage coreml/
mv coreml/whisper_encoder_impl.h coreml/whisper-encoder-impl.h
mv coreml/whisper_encoder_impl.m coreml/whisper-encoder-impl.m
sed -i '' 's/whisper_encoder_impl\.h/whisper-encoder-impl.h/g' coreml/whisper-encoder-impl.m
sed -i '' 's/whisper_encoder_impl\.m/whisper-encoder-impl.m/g' coreml/whisper-encoder-impl.m
sed -i '' 's/whisper_encoder_impl\.h/whisper-encoder-impl.h/g' coreml/whisper-encoder-impl.h
mv -v models/coreml-decoder-tiny.en.mlpackage models/whisper-decoder-impl.mlpackage
xcrun coremlc generate models/whisper-decoder-impl.mlpackage coreml/
mv coreml/whisper_decoder_impl.h coreml/whisper-decoder-impl.h
mv coreml/whisper_decoder_impl.m coreml/whisper-decoder-impl.m
sed -i '' 's/whisper_decoder_impl\.h/whisper-decoder-impl.h/g' coreml/whisper-decoder-impl.m
sed -i '' 's/whisper_decoder_impl\.m/whisper-decoder-impl.m/g' coreml/whisper-decoder-impl.m
sed -i '' 's/whisper_decoder_impl\.h/whisper-decoder-impl.h/g' coreml/whisper-decoder-impl.h
rm -rfv models/whisper-encoder-impl.mlpackage models/whisper-decoder-impl.mlpackage


@ -1,25 +0,0 @@
#!/bin/bash
# Usage: ./generate-coreml-model.sh <model-name>
if [ $# -eq 0 ]
then
echo "No model name supplied"
echo "Usage: ./generate-coreml-model.sh <model-name>"
exit 1
fi
mname="$1"
wd=$(dirname "$0")
cd "$wd/../"
python3 models/convert-whisper-to-coreml.py --model $mname --encoder-only True
xcrun coremlc compile models/coreml-encoder-${mname}.mlpackage models/
rm -rf models/ggml-${mname}-encoder.mlmodelc
mv -v models/coreml-encoder-${mname}.mlmodelc models/ggml-${mname}-encoder.mlmodelc
# TODO: decoder (sometime in the future maybe)
#xcrun coremlc compile models/whisper-decoder-${mname}.mlpackage models/
#rm -rf models/ggml-${mname}-decoder.mlmodelc
#mv -v models/coreml_decoder_${mname}.mlmodelc models/ggml-${mname}-decoder.mlmodelc

File diff suppressed because it is too large

199
whisper.h

@ -66,7 +66,6 @@ extern "C" {
     //
     //
     struct whisper_context;
-    struct whisper_state;
     typedef int whisper_token;
@ -102,20 +101,11 @@ extern "C" {
     WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
     WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
-    // These are the same as the above, but the internal state of the context is not allocated automatically
-    // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
-    WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model);
-    WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size);
-    WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader);
-    WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx);
-    // Frees all allocated memory
-    WHISPER_API void whisper_free (struct whisper_context * ctx);
-    WHISPER_API void whisper_free_state(struct whisper_state * state);
+    // Frees all memory allocated by the model.
+    WHISPER_API void whisper_free(struct whisper_context * ctx);
     // Convert RAW PCM audio to log mel spectrogram.
-    // The resulting spectrogram is stored inside the default state of the provided whisper context.
+    // The resulting spectrogram is stored inside the provided whisper context.
     // Returns 0 on success
     WHISPER_API int whisper_pcm_to_mel(
             struct whisper_context * ctx,
@ -123,30 +113,17 @@ extern "C" {
             int n_samples,
             int n_threads);
-    WHISPER_API int whisper_pcm_to_mel_with_state(
-            struct whisper_context * ctx,
-            struct whisper_state * state,
-            const float * samples,
-            int n_samples,
-            int n_threads);
     // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
-    // The resulting spectrogram is stored inside the default state of the provided whisper context.
+    // The resulting spectrogram is stored inside the provided whisper context.
     // Returns 0 on success
     WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
-            struct whisper_context * ctx,
-            const float * samples,
+            struct whisper_context* ctx,
+            const float* samples,
             int n_samples,
             int n_threads);
-    WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state(
-            struct whisper_context * ctx,
-            struct whisper_state * state,
-            const float * samples,
-            int n_samples,
-            int n_threads);
-    // This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context.
+    // This can be used to set a custom log mel spectrogram inside the provided whisper context.
     // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
     // n_mel must be 80
     // Returns 0 on success
@ -156,14 +133,7 @@ extern "C" {
             int n_len,
             int n_mel);
-    WHISPER_API int whisper_set_mel_with_state(
-            struct whisper_context * ctx,
-            struct whisper_state * state,
-            const float * data,
-            int n_len,
-            int n_mel);
-    // Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context.
+    // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
     // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
     // offset can be used to specify the offset of the first frame in the spectrogram.
     // Returns 0 on success
@ -172,12 +142,6 @@ extern "C" {
             int offset,
             int n_threads);
-    WHISPER_API int whisper_encode_with_state(
-            struct whisper_context * ctx,
-            struct whisper_state * state,
-            int offset,
-            int n_threads);
     // Run the Whisper decoder to obtain the logits and probabilities for the next token.
     // Make sure to call whisper_encode() first.
     // tokens + n_tokens is the provided context for the decoder.
@ -191,14 +155,6 @@ extern "C" {
             int n_past,
             int n_threads);
-    WHISPER_API int whisper_decode_with_state(
-            struct whisper_context * ctx,
-            struct whisper_state * state,
-            const whisper_token * tokens,
-            int n_tokens,
-            int n_past,
-            int n_threads);
     // Convert the provided text into tokens.
     // The tokens pointer must be large enough to hold the resulting tokens.
     // Returns the number of tokens on success, no more than n_max_tokens
@ -226,7 +182,7 @@ extern "C" {
     // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
     // Returns the top language id or negative on failure
     // If not null, fills the lang_probs array with the probabilities of all languages
-    // The array must be whisper_lang_max_id() + 1 in size
+    // The array must be whispe_lang_max_id() + 1 in size
     // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
     WHISPER_API int whisper_lang_auto_detect(
             struct whisper_context * ctx,
@ -234,44 +190,20 @@ extern "C" {
             int n_threads,
             float * lang_probs);
-    WHISPER_API int whisper_lang_auto_detect_with_state(
-            struct whisper_context * ctx,
-            struct whisper_state * state,
-            int offset_ms,
-            int n_threads,
-            float * lang_probs);
-    WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
-    WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length
-    WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
-    WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
-    WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
-    WHISPER_API int whisper_is_multilingual (struct whisper_context * ctx);
-    WHISPER_API int whisper_model_n_vocab (struct whisper_context * ctx);
-    WHISPER_API int whisper_model_n_audio_ctx (struct whisper_context * ctx);
-    WHISPER_API int whisper_model_n_audio_state(struct whisper_context * ctx);
-    WHISPER_API int whisper_model_n_audio_head (struct whisper_context * ctx);
-    WHISPER_API int whisper_model_n_audio_layer(struct whisper_context * ctx);
-    WHISPER_API int whisper_model_n_text_ctx (struct whisper_context * ctx);
-    WHISPER_API int whisper_model_n_text_state (struct whisper_context * ctx);
-    WHISPER_API int whisper_model_n_text_head (struct whisper_context * ctx);
-    WHISPER_API int whisper_model_n_text_layer (struct whisper_context * ctx);
-    WHISPER_API int whisper_model_n_mels (struct whisper_context * ctx);
-    WHISPER_API int whisper_model_f16 (struct whisper_context * ctx);
-    WHISPER_API int whisper_model_type (struct whisper_context * ctx);
+    WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
+    WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
+    WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
+    WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
+    WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
     // Token logits obtained from the last call to whisper_decode()
     // The logits for the last token are stored in the last row
     // Rows: n_tokens
     // Cols: n_vocab
-    WHISPER_API float * whisper_get_logits (struct whisper_context * ctx);
-    WHISPER_API float * whisper_get_logits_from_state(struct whisper_state * state);
+    WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
     // Token Id -> String. Uses the vocabulary in the provided context
     WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
-    WHISPER_API const char * whisper_model_type_readable(struct whisper_context * ctx);
     // Special tokens
     WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx);
@ -286,7 +218,7 @@ extern "C" {
     WHISPER_API whisper_token whisper_token_translate (void);
     WHISPER_API whisper_token whisper_token_transcribe(void);
-    // Performance information from the default state.
+    // Performance information
     WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
     WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
@ -297,33 +229,19 @@ extern "C" {
     // Available sampling strategies
     enum whisper_sampling_strategy {
-        WHISPER_SAMPLING_GREEDY, // similar to OpenAI's GreedyDecoder
+        WHISPER_SAMPLING_GREEDY, // similar to OpenAI's GreefyDecoder
         WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder
     };
     // Text segment callback
     // Called on every newly generated text segment
     // Use the whisper_full_...() functions to obtain the text segments
-    typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data);
-    // Progress callback
-    typedef void (*whisper_progress_callback)(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data);
+    typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
     // Encoder begin callback
     // If not NULL, called before the encoder starts
     // If it returns false, the computation is aborted
-    typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
-    // Logits filter callback
-    // Can be used to modify the logits before sampling
-    // If not NULL, called after applying temperature to logits
-    typedef void (*whisper_logits_filter_callback)(
-            struct whisper_context * ctx,
-            struct whisper_state * state,
-            const whisper_token_data * tokens,
-            int n_tokens,
-            float * logits,
-            void * user_data);
+    typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
     // Parameters for the whisper_full() function
     // If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp:
@ -359,7 +277,6 @@ extern "C" {
         // tokens to provide to the whisper decoder as initial prompt
         // these are prepended to any existing text context from a previous call
-        const char * initial_prompt;
         const whisper_token * prompt_tokens;
         int prompt_n_tokens;
@ -395,23 +312,14 @@ extern "C" {
         whisper_new_segment_callback new_segment_callback;
         void * new_segment_callback_user_data;
-        // called on each progress update
-        whisper_progress_callback progress_callback;
-        void * progress_callback_user_data;
         // called each time before the encoder starts
         whisper_encoder_begin_callback encoder_begin_callback;
        void * encoder_begin_callback_user_data;
-        // called by each decoder to filter obtained logits
-        whisper_logits_filter_callback logits_filter_callback;
-        void * logits_filter_callback_user_data;
     };
     WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
     // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
-    // Not thread safe for same context
     // Uses the specified decoding strategy to obtain the text.
     WHISPER_API int whisper_full(
             struct whisper_context * ctx,
@ -419,16 +327,7 @@ extern "C" {
             const float * samples,
             int n_samples);
-    WHISPER_API int whisper_full_with_state(
-            struct whisper_context * ctx,
-            struct whisper_state * state,
-            struct whisper_full_params params,
-            const float * samples,
-            int n_samples);
-    // Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
-    // Result is stored in the default state of the context
-    // Not thread safe if executed in parallel on the same context.
+    // Split the input audio in chunks and process each chunk separately using whisper_full()
     // It seems this approach can offer some speedup in some cases.
     // However, the transcription accuracy can be worse at the beginning and end of each chunk.
     WHISPER_API int whisper_full_parallel(
@ -438,56 +337,44 @@ extern "C" {
             int n_samples,
             int n_processors);
-    // Number of generated text segments
+    // Number of generated text segments.
     // A segment can be a few words, a sentence, or even a paragraph.
-    WHISPER_API int whisper_full_n_segments (struct whisper_context * ctx);
-    WHISPER_API int whisper_full_n_segments_from_state(struct whisper_state * state);
+    WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
-    // Language id associated with the context's default state
+    // Language id associated with the current context
     WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
-    // Language id associated with the provided state
-    WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state * state);
-    // Get the start and end time of the specified segment
-    WHISPER_API int64_t whisper_full_get_segment_t0 (struct whisper_context * ctx, int i_segment);
-    WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment);
-    WHISPER_API int64_t whisper_full_get_segment_t1 (struct whisper_context * ctx, int i_segment);
-    WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment);
-    // Get the text of the specified segment
-    WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment);
-    WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment);
-    // Get number of tokens in the specified segment
-    WHISPER_API int whisper_full_n_tokens (struct whisper_context * ctx, int i_segment);
-    WHISPER_API int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment);
-    // Get the token text of the specified token in the specified segment
-    WHISPER_API const char * whisper_full_get_token_text (struct whisper_context * ctx, int i_segment, int i_token);
-    WHISPER_API const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token);
-    WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
-    WHISPER_API whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token);
-    // Get token data for the specified token in the specified segment
+    // Get the start and end time of the specified segment.
+    WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment);
+    WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);
+    // Get the text of the specified segment.
+    WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment);
+    // Get number of tokens in the specified segment.
+    WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment);
+    // Get the token text of the specified token in the specified segment.
+    WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
+    // Get token data for the specified token in the specified segment.
     // This contains probabilities, timestamps, etc.
-    WHISPER_API whisper_token_data whisper_full_get_token_data (struct whisper_context * ctx, int i_segment, int i_token);
-    WHISPER_API whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token);
+    WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
-    // Get the probability of the specified token in the specified segment
-    WHISPER_API float whisper_full_get_token_p (struct whisper_context * ctx, int i_segment, int i_token);
-    WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token);
+    // Get the probability of the specified token in the specified segment.
+    WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
     ////////////////////////////////////////////////////////////////////////////
     // Temporary helpers needed for exposing ggml interface
     WHISPER_API int whisper_bench_memcpy(int n_threads);
-    WHISPER_API const char * whisper_bench_memcpy_str(int n_threads);
     WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads);
-    WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads);
-    // Temporary experimental API
-    WHISPER_API void whisper_full_cluster_segments(struct whisper_context * ctx);
 #ifdef __cplusplus
 }