Compare commits

..

1 Commits

Author SHA1 Message Date
210a6fb83c wip : some unsuccessful experiments using DP 2022-11-01 21:28:30 +02:00
18 changed files with 729 additions and 1271 deletions

3
.gitmodules vendored
View File

@ -1,3 +0,0 @@
[submodule "bindings/ios"]
path = bindings/ios
url = https://github.com/ggerganov/whisper.spm

View File

@ -9,11 +9,6 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
set(WHISPER_STANDALONE ON)
include(cmake/GitVars.cmake)
include(cmake/BuildTypes.cmake)
# configure project version
if (EXISTS "${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl")
configure_file(${CMAKE_SOURCE_DIR}/bindings/ios/Makefile-tmpl ${CMAKE_SOURCE_DIR}/bindings/ios/Makefile @ONLY)
endif()
else()
set(WHISPER_STANDALONE OFF)
endif()
@ -52,7 +47,7 @@ else()
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
endif()
option(WHISPER_PERF "whisper: enable perf timings" OFF)
option(WHISPER_PERF "whisper: enable perf timings" OFF)
# sanitizers
@ -94,17 +89,6 @@ if (APPLE AND NOT WHISPER_NO_ACCELERATE)
else()
message(WARNING "Accelerate framework not found")
endif()
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS}
${FOUNDATION_LIBRARY}
${METAL_FRAMEWORK}
${METALKIT_FRAMEWORK}
${METALPERFORMANCE_FRAMEWORK})
endif()
if (WHISPER_SUPPORT_OPENBLAS)
@ -167,10 +151,6 @@ else()
endif()
endif()
if (WHISPER_PERF)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_PERF)
endif()
#
# whisper - this is the main library of the project
#
@ -179,7 +159,6 @@ set(TARGET whisper)
add_library(${TARGET}
ggml.c
ggml-mtl.m
whisper.cpp
)

View File

@ -1,14 +1,6 @@
ifndef UNAME_S
UNAME_S := $(shell uname -s)
endif
ifndef UNAME_P
UNAME_P := $(shell uname -p)
endif
ifndef UNAME_M
UNAME_M := $(shell uname -m)
endif
# Mac OS + Arm can report x86_64
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
@ -16,8 +8,8 @@ ifeq ($(UNAME_S),Darwin)
ifneq ($(UNAME_P),arm)
SYSCTL_M := $(shell sysctl -n hw.optional.arm64)
ifeq ($(SYSCTL_M),1)
# UNAME_P := arm
# UNAME_M := arm64
UNAME_P := arm
UNAME_M := arm64
warn := $(warning Your arch is announced as x86_64, but it seems to actually be ARM64. Not fixing that can lead to bad performance. For more info see: https://github.com/ggerganov/whisper.cpp/issues/66\#issuecomment-1282546789)
endif
endif
@ -55,11 +47,11 @@ endif
ifeq ($(UNAME_M),amd64)
CFLAGS += -mavx -mavx2 -mfma -mf16c
endif
ifndef WHISPER_NO_ACCELERATE
ifneq ($(filter arm%,$(UNAME_M)),)
# Mac M1 - include Accelerate framework
ifeq ($(UNAME_S),Darwin)
CFLAGS += -DGGML_USE_ACCELERATE -DGGML_PERF
LDFLAGS += -framework Foundation -framework Accelerate -framework Metal -framework MetalKit -framework MetalPerformanceShaders
CFLAGS += -DGGML_USE_ACCELERATE
LDFLAGS += -framework Accelerate
endif
endif
ifneq ($(filter aarch64%,$(UNAME_M)),)
@ -81,21 +73,18 @@ endif
# Build library + main
#
main: examples/main/main.cpp ggml.o ggml-mtl.o whisper.o
$(CXX) $(CXXFLAGS) examples/main/main.cpp whisper.o ggml.o ggml-mtl.o -o main $(LDFLAGS)
main: examples/main/main.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/main/main.cpp whisper.o ggml.o -o main $(LDFLAGS)
./main -h
ggml.o: ggml.c ggml.h
$(CC) $(CFLAGS) -c ggml.c -o ggml.o
ggml-mtl.o: ggml-mtl.m ggml-mtl.h
$(CC) $(CFLAGS) -c ggml-mtl.m -o ggml-mtl.o
$(CC) $(CFLAGS) -c ggml.c
whisper.o: whisper.cpp whisper.h
$(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o
$(CXX) $(CXXFLAGS) -c whisper.cpp
libwhisper.a: ggml.o ggml-mtl.o whisper.o
$(AR) rcs libwhisper.a ggml.o ggml-mtl.o whisper.o
libwhisper.a: ggml.o whisper.o
ar rcs libwhisper.a ggml.o whisper.o
clean:
rm -f *.o main stream bench libwhisper.a

190
README.md
View File

@ -26,41 +26,14 @@ Supported platforms:
The entire implementation of the model is contained in 2 source files:
- Tensor operations: [ggml.h](ggml.h) / [ggml.c](ggml.c)
- Transformer inference: [whisper.h](whisper.h) / [whisper.cpp](whisper.cpp)
- [ggml.h](ggml.h) / [ggml.c](ggml.c)
- [whisper.h](whisper.h) / [whisper.cpp](whisper.cpp)
Having such a lightweight implementation of the model makes it easy to integrate it into different platforms and applications.
As an example, here is a video of running the model on an iPhone 13 device - fully offline, on-device:
https://user-images.githubusercontent.com/1991296/197385372-962a6dea-bca1-4d50-bf96-1d8c27b98c81.mp4
## Implementation details
- The core tensor operations are implemented in C ([ggml.h](ggml.h) / [ggml.c](ggml.c))
- The transformer model and the high-level C-style API are implemented in C++ ([whisper.h](whisper.h) / [whisper.cpp](whisper.cpp))
- Sample usage is demonstrated in [main.cpp](examples/main)
- Sample real-time audio transcription from the microphone is demonstrated in [stream.cpp](examples/stream)
- Various other examples are available in the [examples](examples) folder
The tensor operators are optimized heavily for Apple silicon CPUs. Depending on the computation size, Arm Neon SIMD
intrinsics or CBLAS Accelerate framework routines are used. The latter are especially effective for bigger sizes since
the Accelerate framework utilizes the special-purpose AMX coprocessor available in modern Apple products.
## Limitations
- Inference only
- No GPU support
- Very basic greedy sampling scheme - always pick up the token with highest probability.
This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274)
from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure
to run the python code with the following parameters:
```
whisper --best_of None --beam_size None ...
```
In the future, `whisper.cpp` will support more sampling strategies.
## Quick start
First, download one of the Whisper models converted in [ggml format](models). For example:
@ -86,8 +59,8 @@ For a quick demo, simply run `make base.en`:
```java
$ make base.en
cc -I. -O3 -std=c11 -pthread -DGGML_USE_ACCELERATE -c ggml.c -o ggml.o
c++ -I. -I./examples -O3 -std=c++11 -pthread -c whisper.cpp -o whisper.o
cc -I. -O3 -std=c11 -pthread -DGGML_USE_ACCELERATE -c ggml.c
c++ -I. -I./examples -O3 -std=c++11 -pthread -c whisper.cpp
c++ -I. -I./examples -O3 -std=c++11 -pthread examples/main/main.cpp whisper.o ggml.o -o main -framework Accelerate
./main -h
@ -97,18 +70,13 @@ options:
-h, --help show this help message and exit
-s SEED, --seed SEED RNG seed (default: -1)
-t N, --threads N number of threads to use during computation (default: 4)
-p N, --processors N number of processors to use during computation (default: 1)
-ot N, --offset-t N time offset in milliseconds (default: 0)
-on N, --offset-n N segment index offset (default: 0)
-mc N, --max-context N maximum number of text context tokens to store (default: max)
-ml N, --max-len N maximum segment length in characters (default: 0)
-wt N, --word-thold N word timestamp probability threshold (default: 0.010000)
-v, --verbose verbose output
--translate translate from source language to english
-otxt, --output-txt output result in a text file
-ovtt, --output-vtt output result in a vtt file
-osrt, --output-srt output result in a srt file
-owts, --output-words output script for generating karaoke video
-ps, --print_special print special tokens
-pc, --print_colors print colors
-nt, --no_timestamps do not print timestamps
@ -118,7 +86,7 @@ options:
bash ./models/download-ggml-model.sh base.en
Downloading ggml model base.en ...
ggml-base.en.bin 100%[========================>] 141.11M 6.34MB/s in 24s
ggml-base.en.bin 100%[========================>] 141.11M 6.34MB/s in 24s
Done! Model 'base.en' saved in 'models/ggml-base.en.bin'
You can now use it like this:
@ -146,26 +114,23 @@ whisper_model_load: n_text_layer = 6
whisper_model_load: n_mels = 80
whisper_model_load: f16 = 1
whisper_model_load: type = 2
whisper_model_load: mem_required = 670.00 MB
whisper_model_load: mem_required = 505.00 MB
whisper_model_load: adding 1607 extra tokens
whisper_model_load: ggml ctx size = 140.60 MB
whisper_model_load: ggml ctx size = 163.43 MB
whisper_model_load: memory size = 22.83 MB
whisper_model_load: model size = 140.54 MB
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, lang = en, task = transcribe, timestamps = 1 ...
main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
[00:00.000 --> 00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
[00:00:00.000 --> 00:00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
whisper_print_timings: load time = 105.91 ms
whisper_print_timings: mel time = 24.62 ms
whisper_print_timings: sample time = 3.63 ms
whisper_print_timings: encode time = 324.71 ms / 54.12 ms per layer
whisper_print_timings: decode time = 83.58 ms / 13.93 ms per layer
whisper_print_timings: total time = 542.81 ms
whisper_print_timings: load time = 87.21 ms
whisper_print_timings: mel time = 24.26 ms
whisper_print_timings: sample time = 3.87 ms
whisper_print_timings: encode time = 323.67 ms / 53.94 ms per layer
whisper_print_timings: decode time = 83.25 ms / 13.87 ms per layer
whisper_print_timings: total time = 522.66 ms
```
The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
@ -207,8 +172,8 @@ make large
| Model | Disk | Mem | SHA |
| --- | --- | --- | --- |
| tiny | 75 MB | ~390 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
| base | 142 MB | ~500 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
| tiny | 75 MB | ~280 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
| base | 142 MB | ~430 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
| small | 466 MB | ~1.0 GB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
| medium | 1.5 GB | ~2.6 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
| large | 2.9 GB | ~4.7 GB | `b1caaf735c4cc1429223d5a74f0f4d0b9b59a299` |
@ -220,7 +185,7 @@ in about half a minute on a MacBook M1 Pro, using `medium.en` model:
<details>
<summary>Expand to see the result</summary>
```java
$ ./main -m models/ggml-medium.en.bin -f samples/gb1.wav -t 8
@ -308,108 +273,32 @@ to highlight words with high or low confidence:
<img width="965" alt="image" src="https://user-images.githubusercontent.com/1991296/197356445-311c8643-9397-4e5e-b46e-0b4b4daa2530.png">
## Controlling the length of the generated text segments (experimental)
## Implementation details
For example, to limit the line length to a maximum of 16 characters, simply add `-ml 16`:
- The core tensor operations are implemented in C ([ggml.h](ggml.h) / [ggml.c](ggml.c))
- The high-level C-style API is implemented in C++ ([whisper.h](whisper.h) / [whisper.cpp](whisper.cpp))
- Sample usage is demonstrated in [main.cpp](examples/main)
- Sample real-time audio transcription from the microphone is demonstrated in [stream.cpp](examples/stream)
- Various other examples are available in the [examples](examples) folder
```java
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 16
The tensor operators are optimized heavily for Apple silicon CPUs. Depending on the computation size, Arm Neon SIMD
intrinsics or CBLAS Accelerate framework routines are used. The latter are especially effective for bigger sizes since
the Accelerate framework utilizes the special-purpose AMX coprocessor available in modern Apple products.
whisper_model_load: loading model from './models/ggml-base.en.bin'
...
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
## Limitations
main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
- Inference only
- No GPU support
- Very basic greedy sampling scheme - always pick up the token with highest probability.
This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274)
from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure
to run the python code with the following parameters:
[00:00:00.000 --> 00:00:00.850] And so my
[00:00:00.850 --> 00:00:01.590] fellow
[00:00:01.590 --> 00:00:04.140] Americans, ask
[00:00:04.140 --> 00:00:05.660] not what your
[00:00:05.660 --> 00:00:06.840] country can do
[00:00:06.840 --> 00:00:08.430] for you, ask
[00:00:08.430 --> 00:00:09.440] what you can do
[00:00:09.440 --> 00:00:10.020] for your
[00:00:10.020 --> 00:00:11.000] country.
```
```
whisper --best_of None --beam_size None ...
```
## Word-level timestamp
The `--max-len` argument can be used to obtain word-level timestamps. Simply use `-ml 1`:
```java
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 1
whisper_model_load: loading model from './models/ggml-base.en.bin'
...
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
[00:00:00.000 --> 00:00:00.320]
[00:00:00.320 --> 00:00:00.370] And
[00:00:00.370 --> 00:00:00.690] so
[00:00:00.690 --> 00:00:00.850] my
[00:00:00.850 --> 00:00:01.590] fellow
[00:00:01.590 --> 00:00:02.850] Americans
[00:00:02.850 --> 00:00:03.300] ,
[00:00:03.300 --> 00:00:04.140] ask
[00:00:04.140 --> 00:00:04.990] not
[00:00:04.990 --> 00:00:05.410] what
[00:00:05.410 --> 00:00:05.660] your
[00:00:05.660 --> 00:00:06.260] country
[00:00:06.260 --> 00:00:06.600] can
[00:00:06.600 --> 00:00:06.840] do
[00:00:06.840 --> 00:00:07.010] for
[00:00:07.010 --> 00:00:08.170] you
[00:00:08.170 --> 00:00:08.190] ,
[00:00:08.190 --> 00:00:08.430] ask
[00:00:08.430 --> 00:00:08.910] what
[00:00:08.910 --> 00:00:09.040] you
[00:00:09.040 --> 00:00:09.320] can
[00:00:09.320 --> 00:00:09.440] do
[00:00:09.440 --> 00:00:09.760] for
[00:00:09.760 --> 00:00:10.020] your
[00:00:10.020 --> 00:00:10.510] country
[00:00:10.510 --> 00:00:11.000] .
```
## Karaoke-style movie generation (experimental)
The [main](examples/main) example provides support for output of karaoke-style movies, where the
currently pronounced word is highlighted. Use the `-wts` argument and run the generated bash script.
This requires to have `ffmpeg` installed.
Here are a few *"typical"* examples:
```java
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -owts
source ./samples/jfk.wav.wts
ffplay ./samples/jfk.wav.mp4
```
https://user-images.githubusercontent.com/1991296/199337465-dbee4b5e-9aeb-48a3-b1c6-323ac4db5b2c.mp4
---
```java
./main -m ./models/ggml-base.en.bin -f ./samples/mm0.wav -owts
source ./samples/mm0.wav.wts
ffplay ./samples/mm0.wav.mp4
```
https://user-images.githubusercontent.com/1991296/199337504-cc8fd233-0cb7-4920-95f9-4227de3570aa.mp4
---
```java
./main -m ./models/ggml-base.en.bin -f ./samples/gb0.wav -owts
source ./samples/gb0.wav.wts
ffplay ./samples/gb0.wav.mp4
```
https://user-images.githubusercontent.com/1991296/199337538-b7b0c7a3-2753-4a88-a0cd-f28a317987ba.mp4
---
In the future, `whisper.cpp` will support more sampling strategies.
## Benchmarks
@ -437,12 +326,9 @@ For more details, see the conversion script [models/convert-pt-to-ggml.py](model
## Bindings
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs)
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm)
- [ ] Python:
- [ ] Java:
## Examples
There are various examples of using the library for different projects in the [examples](examples) folder. Check them out!
## [Frequently asked questions (#126)](https://github.com/ggerganov/whisper.cpp/discussions/126)

Submodule bindings/ios deleted from 4bda8e9d80

View File

@ -1,49 +0,0 @@
#!/bin/bash
executable="./main"
model="base.en"
model_path="models/ggml-$model.bin"
# require sox and ffmpeg to be installed
if ! command -v sox &> /dev/null
then
echo "sox could not be found"
exit 1
fi
if ! command -v ffmpeg &> /dev/null
then
echo "ffmpeg could not be found"
exit 2
fi
if [ ! -f "$executable" ]; then
echo "'$executable' does not exist. Please build it first."
exit 3
fi
if [ ! -f "$model_path" ]; then
echo "'$model_path' does not exist. Please download it first."
exit 4
fi
# record some raw audio
sox -d rec.wav
# resample to 16kHz
ffmpeg -y -i ./rec.wav -ar 16000 -ac 1 -c:a pcm_s16le ./rec16.wav > /dev/null 2>&1
# run Whisper
echo "Processing ..."
./main -m models/ggml-base.en.bin rec16.wav -owts > /dev/null 2>&1
# generate Karaoke video
echo "Generating video ..."
source rec16.wav.wts > /dev/null 2>&1
# play the video
echo "Playing ./rec16.wav.mp4 ..."
ffplay -loglevel 0 -autoexit ./rec16.wav.mp4
echo "Done"
exit 0

View File

@ -6,29 +6,21 @@ It can be used as a reference for using the `whisper.cpp` library in other proje
```
./main -h
usage: ./bin/main [options] file0.wav file1.wav ...
usage: ./main [options] file0.wav file1.wav ...
options:
-h, --help show this help message and exit
-s SEED, --seed SEED RNG seed (default: -1)
-t N, --threads N number of threads to use during computation (default: 4)
-p N, --processors N number of processors to use during computation (default: 1)
-ot N, --offset-t N time offset in milliseconds (default: 0)
-on N, --offset-n N segment index offset (default: 0)
-mc N, --max-context N maximum number of text context tokens to store (default: max)
-ml N, --max-len N maximum segment length in characters (default: 0)
-wt N, --word-thold N word timestamp probability threshold (default: 0.010000)
-o N, --offset N offset in milliseconds (default: 0)
-v, --verbose verbose output
--translate translate from source language to english
-otxt, --output-txt output result in a text file
-ovtt, --output-vtt output result in a vtt file
-osrt, --output-srt output result in a srt file
-owts, --output-words output script for generating karaoke video
-ps, --print_special print special tokens
-pc, --print_colors print colors
-nt, --no_timestamps do not print timestamps
-l LANG, --language LANG spoken language (default: en)
-m FNAME, --model FNAME model path (default: models/ggml-base.en.bin)
-f FNAME, --file FNAME input WAV file path
-h, --help show this help message and exit
```

View File

@ -36,7 +36,6 @@ std::string to_timestamp(int64_t t, bool comma = false) {
return std::string(buf);
}
// helper function to replace substrings
void replace_all(std::string & s, const std::string & search, const std::string & replace) {
for (size_t pos = 0; ; pos += replace.length()) {
pos = s.find(search, pos);
@ -46,6 +45,29 @@ void replace_all(std::string & s, const std::string & search, const std::string
}
}
// Heuristic "pronunciation cost" of a token's text: ASCII letters count as 1,
// ASCII digits as 3 (they are read out in full), and every other character
// (punctuation, whitespace, non-ASCII bytes) as a near-zero 0.01.
float voice_length(const std::string & text) {
    float total = 0.0f;

    for (const char ch : text) {
        const bool is_digit = (ch >= '0' && ch <= '9');
        const bool is_alpha = (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z');

        if (is_digit) {
            total += 3.0f;
        } else if (is_alpha) {
            total += 1.0f;
        } else {
            total += 0.01f;
        }
        // TODO: support unicode
    }

    return total;
}
// command-line parameters
struct whisper_params {
int32_t seed = -1; // RNG seed, not used currently
@ -53,11 +75,9 @@ struct whisper_params {
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t duration_ms = 0;
int32_t max_context = -1;
int32_t max_len = 0;
float word_thold = 0.01f;
float word_thold = 0.1f;
bool verbose = false;
bool translate = false;
@ -96,12 +116,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
params.offset_t_ms = std::stoi(argv[++i]);
} else if (arg == "-on" || arg == "--offset-n") {
params.offset_n = std::stoi(argv[++i]);
} else if (arg == "-d" || arg == "--duration") {
params.duration_ms = std::stoi(argv[++i]);
} else if (arg == "-mc" || arg == "--max-context") {
params.max_context = std::stoi(argv[++i]);
} else if (arg == "-ml" || arg == "--max-len") {
params.max_len = std::stoi(argv[++i]);
} else if (arg == "-wt" || arg == "--word-thold") {
params.word_thold = std::stof(argv[++i]);
} else if (arg == "-v" || arg == "--verbose") {
@ -157,16 +173,14 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, " -p N, --processors N number of processors to use during computation (default: %d)\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n);
fprintf(stderr, " -d N, --duration N duration of audio to process in milliseconds (default: %d)\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n");
fprintf(stderr, " -ml N, --max-len N maximum segment length in characters (default: %d)\n", params.max_len);
fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold);
fprintf(stderr, " -v, --verbose verbose output\n");
fprintf(stderr, " --translate translate from source language to english\n");
fprintf(stderr, " -otxt, --output-txt output result in a text file\n");
fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n");
fprintf(stderr, " -osrt, --output-srt output result in a srt file\n");
fprintf(stderr, " -owts, --output-words output script for generating karaoke video\n");
fprintf(stderr, " -owts, --output-words output word-level timestamps to a text file\n");
fprintf(stderr, " -ps, --print_special print special tokens\n");
fprintf(stderr, " -pc, --print_colors print colors\n");
fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
@ -176,67 +190,65 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, "\n");
}
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) {
const whisper_params & params = *(whisper_params *) user_data;
const int n_segments = whisper_full_n_segments(ctx);
// print the last n_new segments
const int s0 = n_segments - n_new;
if (s0 == 0) {
// print the last segment
const int i = n_segments - 1;
if (i == 0) {
printf("\n");
}
for (int i = s0; i < n_segments; i++) {
if (params.no_timestamps) {
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special_tokens == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
if (params.no_timestamps) {
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special_tokens == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
}
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s", text);
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
}
fflush(stdout);
} else {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s", text);
}
fflush(stdout);
} else {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special_tokens == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special_tokens == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
}
printf("\n");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
}
printf("\n");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
}
}
}
@ -306,41 +318,566 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
return true;
}
// karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
// A span [x0, x1) over audio-sample indices, tagged with a type:
// 0 = background (below energy threshold), 1 = voiced/high-energy
// ("green" in the backtracking below) — see fit_text_to_audio usage.
struct Interval {
int x0;
int x1;
int type;
};
// A list of intervals plus F, the DP objective value of the fit that
// produced it (-1 = unset). NOTE(review): public inheritance from
// std::vector is fragile (no virtual destructor) — fine as a local
// value type, but never delete one through a std::vector pointer.
struct IntervalArray : public std::vector<Interval> {
int F = -1;
};
// Dynamic-programming fit: given "input" — intervals over [0, x_max] where
// type != 0 marks high-energy (voiced) spans — choose, for every count
// k = 1..N, the k spans that best cover the voiced regions. Returns res
// where res[k] holds the chosen type-1 intervals for k spans (res[0] unused)
// and res[k].F is the achieved objective. "alpha" weighs uncovered voiced
// mass against total span width.
// NOTE(review): input.back() is undefined on an empty vector — presumably
// callers guarantee at least one interval; confirm.
std::vector<IntervalArray> fit_text_to_audio(const IntervalArray & input, int N, float alpha = 2.0f) {
const int x_max = input.back().x1;
// Per voiced interval: left edges (ls), right edges (rs), all edges in
// order (xs = x0,x1 pairs), and gs = prefix sum of voiced length at each
// edge, so gs[2*ir + 1] - gs[2*il] is the voiced mass between two edges.
std::vector<int> ls;
std::vector<int> rs;
std::vector<int> xs;
std::vector<int> gs;
int G_max = 0; // total voiced length across all type != 0 intervals
for (const auto & ii : input) {
if (ii.type == 0) {
continue;
}
ls.push_back(ii.x0);
rs.push_back(ii.x1);
xs.push_back(ii.x0);
xs.push_back(ii.x1);
gs.push_back(G_max);
G_max += ii.x1 - ii.x0;
gs.push_back(G_max);
}
// sentinel "infinity" for unreachable DP states
const int inf = 100*G_max;
// DP cell: best objective value, backpointer to the previous edge index,
// and the width of the span that ends at this state.
struct Cell {
int fval = -1;
int xprev = -1;
int w = -1;
};
// Function F + initial conditions
// F[ix][n] = best cost of placing n spans with the last one ending at edge
// xs[ix]. Base case n = 0 costs alpha*G_max (all voiced mass uncovered).
// NOTE(review): alpha*G_max is a float silently narrowed to int here —
// presumably intentional since all costs are kept integral; confirm.
std::vector<std::vector<Cell>> F(xs.size());
for (auto & Fx : F) {
Fx.resize(N + 1);
for (auto & f : Fx) f.fval = inf;
Fx[0].fval = alpha*G_max;
}
// DP core
// For each span count n and candidate end edge x: try every span [l, r]
// (l a voiced left edge, r a voiced right edge, l < r <= x) extending the
// best (n-1)-span solution that ended at l. The transition adds the span
// width (r - l) and rewards covered voiced mass with weight (alpha + 1).
for (int n = 1; n <= N; ++n) {
for (int ix = 0; ix < (int) xs.size(); ++ix) {
const int x = xs[ix];
int best_fval = inf;
int best_xprev = -1;
int best_w = -1;
for (int il = 0; il < (int) ls.size(); ++il) {
const int l = ls[il];
if (l < n) continue;
if (l >= x) break;
for (int ir = il; ir < (int) rs.size(); ++ir) {
const int r = rs[ir];
if (r < l + 1) continue;
if (r > x) break;
const int cur_fval = F[2*il][n - 1].fval + (r - l) - (alpha + 1)*(gs[2*ir + 1] - gs[2*il]);
if (cur_fval < best_fval) {
best_fval = cur_fval;
best_xprev = 2*il;
best_w = r - l;
}
}
}
F[ix][n].fval = best_fval;
F[ix][n].xprev = best_xprev;
F[ix][n].w = best_w;
}
}
// generate output
// For each span count i: pick the best end edge, backtrack through the
// Cell backpointers painting each chosen span onto a grid, then run-length
// encode the grid back into type-1 intervals.
std::vector<IntervalArray> res(N + 1);
{
for (int i = 1; i <= N; ++i) {
IntervalArray resCur;
int n = i;
std::vector<int> grid(x_max + 1, 0); // initially, everything is background
// end edge with the lowest objective for i spans
int best_ix = 0;
int best_fval = F[0][n].fval;
for (int ix = 1; ix < (int) xs.size(); ++ix) {
if (F[ix][n].fval < best_fval) {
best_ix = ix;
best_fval = F[ix][n].fval;
}
}
// NOTE(review): int vs size_t comparison; input.size()/2 presumably
// approximates the number of voiced intervals (input alternates
// types), forcing F to 0 for the full fit — confirm intent.
resCur.F = (i == input.size()/2) ? 0 : best_fval;
// backtrack: mark the cells of each chosen span
while (true) {
const int ix = F[best_ix][n].xprev;
const int w = F[best_ix][n].w;
for (int x = xs[ix]; x < xs[ix] + w; ++x) {
grid[x] = 1; // i.e. green
}
best_ix = F[best_ix][n].xprev;
if (--n == 0) break;
}
// run-length encode the painted grid into intervals
int x0 = 0;
int type = grid[0];
for (int x1 = 1; x1 <= x_max; ++x1) {
if (grid[x1] != grid[x1 - 1] || x1 == x_max) {
if (type == 1) {
resCur.push_back({x0, x1, 1});
}
x0 = x1;
type = grid[x1];
}
}
res[i] = std::move(resCur);
}
}
return res;
}
// word-level timestamps (experimental)
// TODO: probably still has bugs, needs refactoring, etc..
// TODO: auto threshold
// TODO: extra pass to detect unused speech and assign to tokens
// TODO: font parameter adjustments
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, const std::vector<float> & pcmf32) {
std::vector<float> pcm_avg(pcmf32.size(), 0);
// average the fabs of the signal
{
const int hw = 32;
for (int i = 0; i < pcmf32.size(); i++) {
float sum = 0;
for (int j = -hw; j <= hw; j++) {
if (i + j >= 0 && i + j < pcmf32.size()) {
sum += fabs(pcmf32[i + j]);
}
}
pcm_avg[i] = sum/(2*hw + 1);
}
}
struct token_info {
int64_t t0 = -1;
int64_t t1 = -1;
int64_t tt0 = -1;
int64_t tt1 = -1;
whisper_token id;
whisper_token tid;
float p = 0.0f;
float pt = 0.0f;
float ptsum = 0.0f;
std::string text;
float vlen = 0.0f; // voice length of this token
void calc_vlen(struct whisper_context * ctx) {
if (id >= whisper_token_eot(ctx)) {
vlen = 0.1f;
return;
}
vlen = voice_length(text);
}
bool is_voice() const {
return vlen > 0.5f;
}
};
int64_t t_beg = 0;
int64_t t_last = 0;
whisper_token tid_last = 0;
std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
// TODO: become parameter
static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
fout << "#!/bin/bash" << "\n";
fout << "!/bin/bash" << "\n";
fout << "\n";
fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \"";
fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << float(pcm_avg.size() + 1000)/WHISPER_SAMPLE_RATE << ":rate=25:color=black -vf \"";
bool is_first = true;
for (int i = 0; i < whisper_full_n_segments(ctx); i++) {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
const char *text = whisper_full_get_segment_text(ctx, i);
const int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100));
const int s1 = std::min((int) pcm_avg.size(), (int) (t1*WHISPER_SAMPLE_RATE/100));
const int n = whisper_full_n_tokens(ctx, i);
std::vector<whisper_token_data> tokens(n);
for (int j = 0; j < n; ++j) {
tokens[j] = whisper_full_get_token_data(ctx, i, j);
std::vector<token_info> tokens(n);
if (n <= 1) {
continue;
}
if (i > 0) {
for (int j = 0; j < n; ++j) {
struct whisper_token_data token = whisper_full_get_token_data(ctx, i, j);
if (j == 0) {
if (token.id == whisper_token_beg(ctx)) {
tokens[j ].t0 = t0;
tokens[j ].t1 = t0;
tokens[j + 1].t0 = t0;
t_beg = t0;
t_last = t0;
tid_last = whisper_token_beg(ctx);
} else {
tokens[j ].t0 = t_last;
}
}
const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
tokens[j].id = token.id;
tokens[j].tid = token.tid;
tokens[j].p = token.p;
tokens[j].pt = token.pt;
tokens[j].ptsum = token.ptsum;
tokens[j].text = whisper_token_to_str(ctx, token.id);
tokens[j].calc_vlen(ctx);
if (token.pt > params.word_thold && token.ptsum > 0.01 && token.tid > tid_last) {
if (j > 0) {
tokens[j - 1].t1 = tt;
}
tokens[j].t0 = tt;
tid_last = token.tid;
}
}
tokens[n - 2].t1 = t1;
tokens[n - 1].t0 = t1;
tokens[n - 1].t1 = t1;
t_last = t1;
{
{
int p0 = 0;
int p1 = 0;
while (true) {
while (p1 < n && tokens[p1].t1 < 0) {
p1++;
}
if (p1 >= n) {
p1--;
}
if (p1 > p0) {
IntervalArray arr;
const int s0 = std::max(0, (int) (tokens[p0].t0*WHISPER_SAMPLE_RATE/100));
const int s1 = std::min((int) pcm_avg.size() - 1, (int) (tokens[p1].t1*WHISPER_SAMPLE_RATE/100));
const int ns = s1 - s0;
float sum = 0.0f;
for (int k = s0; k < s1; k++) {
sum += pcm_avg[k];
}
const float thold = sum/ns;
printf("segment %4d: s0 = %6d, s1 = %6d, ns = %6d, thold = %f\n", i, s0, s1, ns, thold);
{
int last_s = -1;
int last_type = -1;
for (int k = s0; k < s1; k++) {
const int type = pcm_avg[k] > thold ? 1 : 0;
if (type != last_type) {
if (last_type != -1) {
arr.push_back({ last_s, k, last_type });
}
last_s = k;
last_type = type;
}
}
}
//for (int k = 0; k < arr.size(); k++) {
// printf(" %4d: %6d, %6d, %d\n", k, arr[k].x0, arr[k].x1, arr[k].type);
//}
int n_voice = 0;
for (int j = p0; j <= p1; ++j) {
if (tokens[j].is_voice()) {
n_voice++;
}
}
if (n_voice > 0 && arr.size() > n_voice) {
printf("xxxxxxxx n = %d, n_voice = %d, arr.size() = %d\n", n, n_voice, (int) arr.size());
auto res = fit_text_to_audio(arr, n_voice, 2.0f);
printf("done fit_text_to_audio, F = %d\n", res[n_voice].F);
{
int tid = p0;
for (int k = 0; k < (int) res[n_voice].size(); ++k) {
while (!tokens[tid].is_voice() && tid <= p1) {
//if (tokens[tid].t0 < 0) {
// tokens[tid].t0 = (int64_t) (100*res[n_voice][k].x0/WHISPER_SAMPLE_RATE);
//}
//if (tokens[tid].t1 < 0) {
// tokens[tid].t1 = tokens[tid].t0;
//}
tid++;
}
if (tid > p1) {
break;
}
if (tokens[tid].t0 < 0) {
tokens[tid].t0 = (int64_t) (100*res[n_voice][k].x0/WHISPER_SAMPLE_RATE);
if (tid > 0) {
tokens[tid - 1].t1 = tokens[tid].t0;
}
}
if (tokens[tid].t1 < 0) {
tokens[tid].t1 = (int64_t) (100*res[n_voice][k].x1/WHISPER_SAMPLE_RATE);
}
tid++;
}
printf("xxxxxxxx n = %d, tid = %d\n", n, tid);
}
}
}
p1++;
p0 = p1;
if (p1 >= n) {
break;
}
}
}
{
int p0 = 0;
int p1 = 0;
while (true) {
while (p1 < n && tokens[p1].t1 < 0) {
p1++;
}
if (p1 >= n) {
p1--;
}
if (p1 > p0) {
double psum = 0.0;
for (int j = p0; j <= p1; j++) {
psum += tokens[j].vlen;
}
//printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
const double dt = tokens[p1].t1 - tokens[p0].t0;
for (int j = p0 + 1; j <= p1; j++) {
const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
//const double ct = tokens[j - 1].t0 + (dt*(j - p0))/(p1 - p0 + 1);
//const double ct = tokens[p0].t0 + (dt*(j - p0))/(p1 - p0 + 1);
tokens[j - 1].t1 = ct;
tokens[j ].t0 = ct;
}
}
p1++;
p0 = p1;
if (p1 >= n) {
break;
}
}
}
for (int j = 0; j < n - 1; j++) {
if (tokens[j + 1].t0 < 0) {
tokens[j + 1].t0 = tokens[j].t1;
}
tokens[j].tt0 = tokens[j].t0;
tokens[j].tt1 = tokens[j].t1;
if (j < n - 2) {
tokens[j].tt1 = std::max(tokens[j].tt1, tokens[j + 1].t0);
}
}
// VAD
{
const int hw = WHISPER_SAMPLE_RATE; // take one second of audio around the token
for (int j = 0; j < n; j++) {
const int64_t t0 = tokens[j].t0;
const int64_t t1 = tokens[j].t1;
int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100));
int s1 = std::min((int) pcm_avg.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100));
const int ss0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100) - hw);
const int ss1 = std::min((int) pcm_avg.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100) + hw);
const int n = ss1 - ss0;
float sum = 0.0f;
for (int k = ss0; k < ss1; k++) {
sum += pcm_avg[k];
}
const float avg = sum/n;
const float thold = 0.5*avg;
{
int k = s0;
if (pcm_avg[k] > thold && j > 0) {
while (k > 0 && pcm_avg[k] > thold) {
k--;
}
tokens[j].t0 = (int64_t) (100*k/WHISPER_SAMPLE_RATE);
if (tokens[j].t0 < tokens[j - 1].t1) {
tokens[j].t0 = tokens[j - 1].t1;
} else {
s0 = k;
}
} else {
while (pcm_avg[k] < thold && k < s1) {
k++;
}
s0 = k;
tokens[j].t0 = 100*k/WHISPER_SAMPLE_RATE;
}
}
{
int k = s1;
if (pcm_avg[k] > thold) {
while (k < (int) pcm_avg.size() - 1 && pcm_avg[k] > thold) {
k++;
}
tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE;
if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
tokens[j].t1 = tokens[j + 1].t0;
} else {
s1 = k;
}
} else {
while (pcm_avg[k] < thold && k > s0) {
k--;
}
s1 = k;
tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE;
}
}
}
}
const int t_expand = 0;
for (int j = 0; j < n; j++) {
if (j > 0) {
tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
}
if (j < n - 1) {
tokens[j].t1 = tokens[j].t1 + t_expand;
}
}
}
{
const std::string fname_tokens = "tokens-" + std::to_string(i) + ".txt";
std::ofstream fout(fname_tokens);
int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100));
int s1 = std::min((int) pcm_avg.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100));
for (int j = s0; j < s1; j++) {
int k = -1;
for (int r = 0; r < n; r++) {
if (j >= (int) (tokens[r].t0*WHISPER_SAMPLE_RATE/100) && j < (int) (tokens[r].t1*WHISPER_SAMPLE_RATE/100)) {
k = r;
break;
}
}
fout << j << " " << pcm_avg[j] << " " << float(k%3 + 1)/30.0 << std::endl;
}
fout.close();
}
for (int j = 0; j < n; ++j) {
const auto & token = tokens[j];
const auto tt = token.pt > params.word_thold && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, token.text.c_str());
if (tokens[j].id >= whisper_token_eot(ctx)) {
continue;
}
//printf("[%s --> %s] %s\n", to_timestamp(token.t0).c_str(), to_timestamp(token.t1).c_str(), whisper_token_to_str(ctx, token.id));
//fout << "# " << to_timestamp(token.t0) << " --> " << to_timestamp(token.t1) << " " << whisper_token_to_str(ctx, token.id) << "\n";
}
static const int line_wrap = 60;
static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
if (!is_first) {
fout << ",";
}
// background text
fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
bool is_first = true;
is_first = false;
for (int j = 0; j < n; ++j) {
const auto & token = tokens[j];
@ -349,6 +886,10 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
continue;
}
//if (!tokens[j].is_voice()) {
// continue;
//}
std::string txt_bg;
std::string txt_fg; // highlight token
std::string txt_ul; // underline
@ -384,6 +925,17 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
}
ncnt += txt.size();
if (ncnt > line_wrap) {
if (k < j) {
txt_bg = "> ";
txt_fg = "> ";
txt_ul = "\\ \\ ";
ncnt = 0;
} else {
break;
}
}
}
::replace_all(txt_bg, "'", "");
@ -392,11 +944,8 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
::replace_all(txt_fg, "\"", "\\\"");
}
if (is_first) {
// background text
fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << t0/100.0 << "," << t1/100.0 << ")'";
is_first = false;
}
// background text
fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << token.tt0/100.0 << "," << token.tt1/100.0 << ")'";
// foreground text
fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
@ -536,11 +1085,6 @@ int main(int argc, char ** argv) {
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms;
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
// this callback is called on each new segment
if (!wparams.print_realtime) {
@ -579,7 +1123,7 @@ int main(int argc, char ** argv) {
// output to WTS file
if (params.output_wts) {
const auto fname_wts = fname_inp + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, pcmf32);
}
}
}

View File

@ -78,14 +78,6 @@ There are a lot of ways to improve this idea and I don't have much experience wi
*"optimize by sorting the data first"*
The plugin would then make an appropriate query using the selected text and code context to Copilot or GPT-3 and return the result.
Here is a proof-of-concept:
https://user-images.githubusercontent.com/1991296/199078847-0278fcde-5667-4748-ba0d-7d55381d6047.mp4
https://user-images.githubusercontent.com/1991296/200067939-f98d2ac2-7519-438a-85f9-79db0841ba4f.mp4
For explanation how this works see: https://twitter.com/ggerganov/status/1587168771789258756
## Discussion

View File

@ -1,7 +0,0 @@
#!/bin/bash
# Compute the SHA1 of all model files in ./models/ggml-*.bin
for f in ./models/ggml-*.bin; do
shasum "$f" -a 1
done

View File

@ -1,38 +0,0 @@
#pragma once
#include <stdint.h>
#include <stddef.h>
// TODO: this will hold dynamic context data in the future
// currently unused
struct ggml_mtl_context {
void * dummy;
};
struct ggml_mtl_object {
int32_t id;
void * data;
};
struct ggml_mtl_context * ggml_mtl_init(void);
struct ggml_mtl_object ggml_mtl_alloc(size_t size);
// multiply matrix by vector
void ggml_mtl_mul_mat_vec_f16(
struct ggml_mtl_context * ctx,
struct ggml_mtl_object src0, // matrix f16
const __fp16 * src1, // vector f16
float * dst, // vector f32
int nrows,
int ncols);
// multiply matrix by matrix
void ggml_mtl_mul_mat_f16(
struct ggml_mtl_context * ctx,
struct ggml_mtl_object src0, // matrix f16
const __fp16 * src1, // matrix f16
float * dst, // matrix f32
int nrows0,
int nrows1,
int ncols);

View File

@ -1,162 +0,0 @@
#import "ggml-mtl.h"
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#define GGML_MTL_MAX_BUFFERS 256
// global static storage for Metal buffers
// TODO: move this into a dynamic context
static id<MTLBuffer> g_buffers[GGML_MTL_MAX_BUFFERS];
// global MTL context
// TODO: move this into a dynamic context
static id<MTLDevice> g_device;
static id<MTLCommandQueue> g_command_queue;
struct ggml_mtl_context * ggml_mtl_init() {
// TODO: implement properly
// for now, init the global MTL context and MTL buffers
g_device = MTLCreateSystemDefaultDevice();
g_command_queue = [g_device newCommandQueue];
if (g_command_queue == nil)
{
NSLog(@"Failed to find the command queue.");
return nil;
}
return nil;
}
// search for unallocated buffer slot and use it
struct ggml_mtl_object ggml_mtl_alloc(size_t size) {
// TODO: temporarily making sure that the buffers are nil at the start
static bool first = true;
if (first) {
for (int i = 0; i < GGML_MTL_MAX_BUFFERS; ++i) {
assert(g_buffers[i] == nil);
}
first = false;
}
struct ggml_mtl_object obj = { -1, nil };
for (int i = 0; i < GGML_MTL_MAX_BUFFERS; i++) {
if (g_buffers[i] == nil) {
g_buffers[i] = [g_device newBufferWithLength:size options:MTLResourceStorageModeManaged];
// lunk the MTL buffer to the ggml object
obj.id = i;
obj.data = [g_buffers[i] contents];
break;
}
}
return obj;
}
struct params_mul_mat_vec {
int N; // rows
int M; // cols
};
// multiply matrix with a vector using MPSMatrixVectorMultiplication
void ggml_mtl_mul_mat_vec_f16(
struct ggml_mtl_context * ctx,
struct ggml_mtl_object src0,
const __fp16 * src1,
float * dst,
int nrows,
int ncols) {
(void) ctx; // unused
// Create a command buffer to hold commands.
id<MTLCommandBuffer> commandBuffer = [g_command_queue commandBuffer];
assert(commandBuffer != nil);
// make managed device buffer to store src1
id<MTLBuffer> src1_buffer = [g_device newBufferWithBytes:src1 length:ncols*sizeof(__fp16) options:MTLResourceStorageModeManaged];
id<MTLBuffer> dst_buffer = [g_device newBufferWithLength:nrows*sizeof(float) options:MTLResourceStorageModeManaged];
// MPSMatrixDescriptor
MPSMatrixDescriptor *src0_desc = [MPSMatrixDescriptor matrixDescriptorWithRows:nrows columns:ncols rowBytes:ncols*sizeof(__fp16) dataType:MPSDataTypeFloat16];
MPSVectorDescriptor *src1_desc = [MPSVectorDescriptor vectorDescriptorWithLength:ncols dataType:MPSDataTypeFloat16];
MPSVectorDescriptor *dst_desc = [MPSVectorDescriptor vectorDescriptorWithLength:nrows dataType:MPSDataTypeFloat32];
// MPSMatrix
MPSMatrix *src0_mat = [[MPSMatrix alloc] initWithBuffer:g_buffers[src0.id] descriptor:src0_desc];
MPSVector *src1_vec = [[MPSVector alloc] initWithBuffer:src1_buffer descriptor:src1_desc];
MPSVector *dst_vec = [[MPSVector alloc] initWithBuffer:dst_buffer descriptor:dst_desc];
// MPSMatrixVectorMultiplication
MPSMatrixVectorMultiplication *mul_mat_vec = [[MPSMatrixVectorMultiplication alloc] initWithDevice:g_device transpose:NO rows:nrows columns:ncols alpha:1.0 beta:0.0];
// encode
[mul_mat_vec encodeToCommandBuffer:commandBuffer
inputMatrix:src0_mat
inputVector:src1_vec
resultVector:dst_vec];
[commandBuffer commit];
[commandBuffer waitUntilCompleted];
// copy GPU result to CPU
memcpy(dst, [dst_buffer contents], nrows*sizeof(float));
}
// multiply matrix with a matrix using MPSMatrixMultiplication
void ggml_mtl_mul_mat_f16(
struct ggml_mtl_context * ctx,
struct ggml_mtl_object src0,
const __fp16 * src1,
float * dst,
int nrows0,
int nrows1,
int ncols) {
(void) ctx; // unused
// Create a command buffer to hold commands.
id<MTLCommandBuffer> commandBuffer = [g_command_queue commandBuffer];
assert(commandBuffer != nil);
// make managed device buffer to store src1
id<MTLBuffer> src1_buffer = [g_device newBufferWithBytes:src1 length:ncols*nrows1*sizeof(__fp16) options:MTLResourceStorageModeManaged];
id<MTLBuffer> dst_buffer = [g_device newBufferWithLength:nrows0*nrows1*sizeof(float) options:MTLResourceStorageModeManaged];
// MPSMatrixDescriptor
MPSMatrixDescriptor *src0_desc = [MPSMatrixDescriptor matrixDescriptorWithRows:nrows0 columns:ncols rowBytes:ncols*sizeof(__fp16) dataType:MPSDataTypeFloat16];
MPSMatrixDescriptor *src1_desc = [MPSMatrixDescriptor matrixDescriptorWithRows:nrows1 columns:ncols rowBytes:ncols*sizeof(__fp16) dataType:MPSDataTypeFloat16];
MPSMatrixDescriptor *dst_desc = [MPSMatrixDescriptor matrixDescriptorWithRows:nrows1 columns:nrows0 rowBytes:nrows0*sizeof(float) dataType:MPSDataTypeFloat32];
// MPSMatrix
MPSMatrix *src0_mat = [[MPSMatrix alloc] initWithBuffer:g_buffers[src0.id] descriptor:src0_desc];
MPSMatrix *src1_mat = [[MPSMatrix alloc] initWithBuffer:src1_buffer descriptor:src1_desc];
MPSMatrix *dst_mat = [[MPSMatrix alloc] initWithBuffer:dst_buffer descriptor:dst_desc];
//// MPSMatrixMultiplication z = x * yT
//MPSMatrixMultiplication *mul_mat = [[MPSMatrixMultiplication alloc] initWithDevice:g_device transposeLeft:NO transposeRight:YES resultRows:nrows resultColumns:nrows interiorColumns:ncols alpha:1.0 beta:0.0];
//// encode
//[mul_mat encodeToCommandBuffer:commandBuffer
// leftMatrix:src0_mat
// rightMatrix:src1_mat
// resultMatrix:dst_mat];
// MPSMatrixMultiplication zT = xT * y
MPSMatrixMultiplication *mul_mat = [[MPSMatrixMultiplication alloc] initWithDevice:g_device transposeLeft:NO transposeRight:YES resultRows:nrows1 resultColumns:nrows0 interiorColumns:ncols alpha:1.0 beta:0.0];
// encode
[mul_mat encodeToCommandBuffer:commandBuffer
leftMatrix:src1_mat
rightMatrix:src0_mat
resultMatrix:dst_mat];
[commandBuffer commit];
[commandBuffer waitUntilCompleted];
// copy GPU result to CPU
memcpy(dst, [dst_buffer contents], nrows0*nrows1*sizeof(float));
}

156
ggml.c
View File

@ -1,7 +1,5 @@
#include "ggml.h"
#include "ggml-mtl.h"
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__)
@ -16,7 +14,7 @@
#include <stdint.h>
#include <stdio.h>
#if defined _MSC_VER || defined(__MINGW32__)
#if defined _MSC_VER
#include <Windows.h>
typedef volatile LONG atomic_int;
@ -46,11 +44,6 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
static int pthread_join(pthread_t thread, void* unused) {
return (int) WaitForSingleObject(thread, INFINITE);
}
static int sched_yield (void) {
Sleep (0);
return 0;
}
#else
#include <pthread.h>
#include <stdatomic.h>
@ -200,7 +193,7 @@ static ggml_fp16_t table_exp_f16[1 << 16];
// timing
//
#if defined(_MSC_VER) || defined(__MINGW32__)
#if defined(_MSC_VER)
static int64_t timer_freq;
void ggml_time_init(void) {
LARGE_INTEGER frequency;
@ -1309,8 +1302,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
static bool first_time = true;
if (first_time) {
ggml_mtl_init(); // TODO: fix this
for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
g_state.contexts[i].used = false;
}
@ -1466,104 +1457,6 @@ struct ggml_tensor * ggml_new_tensor_impl(
/*.perf_cycles =*/ 0,
/*.perf_time_us =*/ 0,
/*.data =*/ data == NULL ? (void *)(result + 1) : data,
/*.id =*/ -1,
/*.pad =*/ { 0 },
};
ggml_assert_aligned(result->data);
for (int i = 0; i < n_dims; i++) {
result->ne[i] = ne[i];
}
result->nb[0] = GGML_TYPE_SIZE[type];
for (int i = 1; i < GGML_MAX_DIMS; i++) {
result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
}
ctx->n_objects++;
return result;
}
struct ggml_tensor * ggml_new_tensor_mtl_impl(
struct ggml_context * ctx,
enum ggml_type type,
int n_dims,
const int* ne,
void* data) {
// always insert objects at the end of the context's memory pool
struct ggml_object * obj_cur = ctx->objects_end;
const size_t cur_offset = obj_cur == NULL ? 0 : obj_cur->offset;
const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
const size_t cur_end = cur_offset + cur_size;
struct ggml_mtl_object obj_mtl;
{
assert(data == NULL); // TODO: in-place metal buffer, need page aligned memory
size_t size_needed_mtl = 0;
if (data == NULL) {
size_needed_mtl += GGML_TYPE_SIZE[type];
for (int i = 0; i < n_dims; i++) {
size_needed_mtl *= ne[i];
}
}
obj_mtl = ggml_mtl_alloc(size_needed_mtl);
}
size_t size_needed = 0;
size_needed += sizeof(struct ggml_tensor);
if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
GGML_PRINT("%s: not enough space in the context's memory pool\n", __func__);
assert(false);
return NULL;
}
char * const mem_buffer = ctx->mem_buffer;
struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
*obj_new = (struct ggml_object) {
.offset = cur_end + GGML_OBJECT_SIZE,
.size = size_needed,
.next = NULL,
};
if (obj_cur != NULL) {
obj_cur->next = obj_new;
} else {
// this is the first object in this context
ctx->objects_begin = obj_new;
}
ctx->objects_end = obj_new;
//GGML_PRINT_DEBUG("%s: inserted new object at %zu\n", __func__, cur_end);
struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offset);
ggml_assert_aligned(result);
*result = (struct ggml_tensor) {
/*.type =*/ type,
/*.n_dims =*/ n_dims,
/*.ne =*/ { 1, 1, 1, 1 },
/*.nb =*/ { 0, 0, 0, 0 },
/*.op =*/ GGML_OP_NONE,
/*.is_param =*/ false,
/*.grad =*/ NULL,
/*.src0 =*/ NULL,
/*.src1 =*/ NULL,
/*.opt =*/ { NULL },
/*.n_tasks =*/ 0,
/*.perf_runs =*/ 0,
/*.perf_cycles =*/ 0,
/*.perf_time_us =*/ 0,
/*.data =*/ obj_mtl.data,
/*.id =*/ obj_mtl.id,
/*.pad =*/ { 0 },
};
@ -1591,14 +1484,6 @@ struct ggml_tensor * ggml_new_tensor(
return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL);
}
struct ggml_tensor * ggml_new_tensor_mtl(
struct ggml_context * ctx,
enum ggml_type type,
int n_dims,
const int* ne) {
return ggml_new_tensor_mtl_impl(ctx, type, n_dims, ne, NULL);
}
struct ggml_tensor * ggml_new_tensor_1d(
struct ggml_context * ctx,
enum ggml_type type,
@ -1615,15 +1500,6 @@ struct ggml_tensor * ggml_new_tensor_2d(
return ggml_new_tensor(ctx, type, 2, ne);
}
struct ggml_tensor * ggml_new_tensor_2d_mtl(
struct ggml_context * ctx,
enum ggml_type type,
int ne0,
int ne1) {
const int ne[2] = { ne0, ne1 };
return ggml_new_tensor_mtl(ctx, type, 2, ne);
}
struct ggml_tensor * ggml_new_tensor_3d(
struct ggml_context * ctx,
enum ggml_type type,
@ -3269,10 +3145,7 @@ void ggml_compute_forward_add_f32(
GGML_ASSERT(nb00 == sizeof(float));
if (nb10 == sizeof(float)) {
const int j0 = (n/nth)*ith;
const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
for (int j = j0; j < j1; j++) {
for (int j = ith; j < n; j += nth) {
ggml_vec_add_f32(nc,
(float *) ((char *) dst->data + j*nb1),
(float *) ((char *) src0->data + j*nb01),
@ -4462,11 +4335,8 @@ void ggml_compute_forward_mul_mat_f16_f32(
// nb00 < nb01 - src0 is transposed
// compute by src0 columns
// are we using Metal?
const bool is_mtl = src0->id >= 0;
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst) && !is_mtl) {
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
GGML_ASSERT(nb10 == sizeof(float));
if (params->ith != 0) return;
@ -4594,20 +4464,6 @@ void ggml_compute_forward_mul_mat_f16_f32(
// parallelize by src0 rows using ggml_vec_dot_f32
if (is_mtl) {
assert(ne02 == 1);
assert(ne03 == 1);
if (params->ith == 0) {
printf("XXXXXXXXXXX src0->ne[0] = %d, src0->ne[1] = %d\n", src0->ne[0], src0->ne[1]);
printf("XXXXXXXXXXX src1->ne[0] = %d, src1->ne[1] = %d\n", src1->ne[0], src1->ne[1]);
struct ggml_mtl_object src0_mtl = { src0->id, src0->data };
ggml_fp16_t * src1_fp16 = params->wdata;
ggml_mtl_mul_mat_f16(NULL, src0_mtl, src1_fp16, dst->data, ne01, ne11, ne00);
}
return;
}
// total rows in src0
const int nr = ne01*ne02*ne03;
@ -6996,7 +6852,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} break;
case GGML_OP_ADD:
{
node->n_tasks = n_threads;
node->n_tasks = 1;
} break;
case GGML_OP_SUB:
case GGML_OP_MUL:
@ -8228,7 +8084,7 @@ int ggml_cpu_has_avx512(void) {
}
int ggml_cpu_has_neon(void) {
#if defined(__ARM_NEON)
#if defined(__ARM_NEON__)
return 1;
#else
return 0;

9
ggml.h
View File

@ -108,8 +108,7 @@ struct ggml_tensor {
int64_t perf_time_us;
void * data;
int32_t id; // TODO: mtl buffer id
char pad[4];
char padding[8];
};
// computation graph
@ -174,12 +173,6 @@ struct ggml_tensor * ggml_new_tensor_2d(
int ne0,
int ne1);
struct ggml_tensor * ggml_new_tensor_2d_mtl(
struct ggml_context * ctx,
enum ggml_type type,
int ne0,
int ne1);
struct ggml_tensor * ggml_new_tensor_3d(
struct ggml_context * ctx,
enum ggml_type type,

View File

@ -22,20 +22,6 @@ A third option to obtain the model files is to download them from Hugging Face:
https://huggingface.co/datasets/ggerganov/whisper.cpp/tree/main
## Available models
| Model | Disk | Mem | SHA |
| --- | --- | --- | --- |
| tiny | 75 MB | ~390 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
| tiny.en | 75 MB | ~390 MB | `c78c86eb1a8faa21b369bcd33207cc90d64ae9df` |
| base | 142 MB | ~500 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
| base.en | 142 MB | ~500 MB | `137c40403d78fd54d454da0f9bd998f78703390c` |
| small | 466 MB | ~1.0 GB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
| small.en | 466 MB | ~1.0 GB | `db8a495a91d927739e50b3fc1cc4c6b8f6c2d022` |
| medium | 1.5 GB | ~2.6 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
| medium.en | 1.5 GB | ~2.6 GB | `8c30f0e44ce9560643ebd10bbe50cd20eafd3723` |
| large | 2.9 GB | ~4.7 GB | `b1caaf735c4cc1429223d5a74f0f4d0b9b59a299` |
## Model files for testing purposes
The model files pefixed with `for-tests-` are empty (i.e. do not contain any weights) and are used by the CI for testing purposes.

View File

@ -1,63 +0,0 @@
@echo off
pushd %~dp0
set models_path=%CD%
popd
set argc=0
for %%x in (%*) do set /A argc+=1
set models=tiny.en tiny base.en base small.en small medium.en medium large
if %argc% neq 1 (
echo.
echo Usage: download-ggml-model.cmd model
CALL :list_models
goto :eof
)
set model=%1
for %%b in (%models%) do (
if "%%b"=="%model%" (
CALL :download_model
goto :eof
)
)
echo Invalid model: %model%
CALL :list_models
goto :eof
:download_model
echo Downloading ggml model %model%...
cd %models_path%
if exist "ggml-%model%.bin" (
echo Model %model% already exists. Skipping download.
goto :eof
)
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://ggml.ggerganov.com/ggml-model-whisper-%model%.bin -OutFile ggml-%model%.bin"
if %ERRORLEVEL% neq 0 (
echo Failed to download ggml model %model%
echo Please try again later or download the original Whisper model files and convert them yourself.
goto :eof
)
echo Done! Model %model% saved in %models_path%\models\ggml-%model%.bin
echo You can now use it like this:
echo main.exe -m %models_path%\models\ggml-%model%.bin -f %models_path%\samples\jfk.wav
goto :eof
:list_models
echo.
echo Available models:
(for %%a in (%models%) do (
echo %%a
))
echo.
exit /b

View File

@ -133,19 +133,11 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
static const size_t MB = 1024*1024;
static const std::map<e_model, size_t> MEM_REQ_MODEL = {
{ MODEL_TINY, 74ull*MB },
{ MODEL_BASE, 142ull*MB },
{ MODEL_SMALL, 466ull*MB },
{ MODEL_MEDIUM, 1464ull*MB },
{ MODEL_LARGE, 2952ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_MEMORY = {
{ MODEL_TINY, 12ull*MB },
{ MODEL_BASE, 24ull*MB },
{ MODEL_SMALL, 70ull*MB },
{ MODEL_MEDIUM, 184ull*MB },
{ MODEL_LARGE, 306ull*MB },
{ MODEL_TINY, 86ull*MB },
{ MODEL_BASE, 165ull*MB },
{ MODEL_SMALL, 540ull*MB },
{ MODEL_MEDIUM, 1650ull*MB },
{ MODEL_LARGE, 3260ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
@ -418,12 +410,6 @@ struct whisper_context {
std::vector<whisper_segment> result_all;
std::vector<whisper_token> prompt_past;
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg;
int64_t t_last;
whisper_token tid_last;
std::vector<float> energy; // PCM signal energy
};
// load the model from a ggml file
@ -437,7 +423,7 @@ struct whisper_context {
//
// see the convert-pt-to-ggml.py script for details
//
static bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());
auto & model = wctx.model;
@ -512,7 +498,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
wctx.buf_model = new std::vector<uint8_t>();
wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
wctx.buf_memory.resize(std::max(MEM_REQ_MODEL.at(model.type), MEM_REQ_MODEL.at(model.type))); // TODO: TMP !!!
wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
@ -736,6 +722,20 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
}
}
// create the ggml memory context
{
struct ggml_init_params params = {
.mem_size = wctx.buf_memory.size(),
.mem_buffer = wctx.buf_memory.data(),
};
model.ctx_mem = ggml_init(params);
if (!model.ctx_mem) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// prepare memory for the weights
{
auto & ctx = model.ctx;
@ -788,10 +788,10 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
layer.mlp_0_w = ggml_new_tensor_2d_mtl(ctx, wtype, n_audio_state, 4*n_audio_state); // offload to GPU
layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
layer.mlp_1_w = ggml_new_tensor_2d_mtl(ctx, wtype, 4*n_audio_state, n_audio_state); // offload to GPU
layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
@ -932,20 +932,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
}
}
// create the ggml memory context
{
struct ggml_init_params params = {
.mem_size = wctx.buf_memory.size(),
.mem_buffer = wctx.buf_memory.data(),
};
model.ctx_mem = ggml_init(params);
if (!model.ctx_mem) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// key + value memory
{
auto & ctx = model.ctx_mem;
@ -1068,7 +1054,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// - n_threads: number of threads to use
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
//
static bool whisper_encode(
bool whisper_encode(
whisper_context & wctx,
const int n_threads,
const int mel_offset) {
@ -1342,7 +1328,7 @@ static bool whisper_encode(
ggml_build_forward_expand(&gf, inpO);
ggml_graph_compute (ctxL, &gf);
ggml_graph_print(&gf);
//ggml_graph_print(&gf);
}
// TODO: this is a hack to have per-layer computation graphs - need to come up with something better
@ -1454,7 +1440,7 @@ static bool whisper_encode(
// - n_tokens: number of tokens in the prompt
// - n_past: number of past tokens to prefix the prompt with
//
static bool whisper_decode(
bool whisper_decode(
whisper_context & wctx,
const int n_threads,
const whisper_token * tokens,
@ -1817,12 +1803,10 @@ static bool whisper_decode(
}
// the most basic sampling scheme - select the top token
static whisper_token_data whisper_sample_best(
whisper_token_data whisper_sample_best(
const whisper_vocab & vocab,
const float * probs) {
whisper_token_data result = {
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
};
whisper_token_data result;
int n_logits = vocab.id_to_token.size();
@ -1895,7 +1879,7 @@ static whisper_token_data whisper_sample_best(
}
// samples only from the timestamps tokens
static whisper_vocab::id whisper_sample_timestamp(
whisper_vocab::id whisper_sample_timestamp(
const whisper_vocab & vocab,
const float * probs) {
int n_logits = vocab.id_to_token.size();
@ -1947,7 +1931,7 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
// naive Discrete Fourier Transform
// input is real-valued
// output is complex-valued
static void dft(const std::vector<float> & in, std::vector<float> & out) {
void dft(const std::vector<float> & in, std::vector<float> & out) {
int N = in.size();
out.resize(N*2);
@ -1971,7 +1955,7 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
static void fft(const std::vector<float> & in, std::vector<float> & out) {
void fft(const std::vector<float> & in, std::vector<float> & out) {
out.resize(in.size()*2);
int N = in.size();
@ -2022,7 +2006,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
}
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
static bool log_mel_spectrogram(
bool log_mel_spectrogram(
const float * samples,
const int n_samples,
const int sample_rate,
@ -2339,7 +2323,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384,
/*.offset_ms =*/ 0,
/*.duration_ms =*/ 0,
/*.translate =*/ false,
/*.no_context =*/ false,
@ -2348,11 +2331,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
/*.token_timestamps =*/ false,
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.language =*/ "en",
/*.greedy =*/ {
@ -2377,7 +2355,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384,
/*.offset_ms =*/ 0,
/*.duration_ms =*/ 0,
/*.translate =*/ false,
/*.no_context =*/ false,
@ -2386,11 +2363,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
/*.token_timestamps =*/ false,
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.language =*/ "en",
/*.greedy =*/ {
@ -2412,68 +2384,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
return result;
}
// forward declarations
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
static void whisper_exp_compute_token_level_timestamps(
struct whisper_context * ctx,
int i_segment,
float thold_pt,
float thold_ptsum);
// wrap the last segment in ctx->result_all so that no segment exceeds
// max_len characters of text (split points fall on token boundaries)
// returns the resulting number of segments (>= 1)
static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
// note: a copy, not a reference - it is re-assigned after every split below
auto segment = ctx->result_all.back();
int res = 1; // number of segments produced so far
int acc = 0; // character count accumulated in the current segment
std::string text;
for (int i = 0; i < (int) segment.tokens.size(); i++) {
const auto & token = segment.tokens[i];
// skip special tokens (EOT and above) - they carry no printable text
if (token.id >= whisper_token_eot(ctx)) {
continue;
}
const auto txt = whisper_token_to_str(ctx, token.id);
const int cur = strlen(txt);
// never split before the first token of a segment (i > 0)
if (acc + cur > max_len && i > 0) {
// split here
// finalize the current tail segment with the text accumulated so far;
// its end time becomes the start time of the token that overflowed
ctx->result_all.back().text = std::move(text);
ctx->result_all.back().t1 = token.t0;
ctx->result_all.back().tokens.resize(i);
// start a new segment covering the remainder of the original one
ctx->result_all.push_back({});
ctx->result_all.back().t0 = token.t0;
ctx->result_all.back().t1 = segment.t1;
// add tokens [i, end] to the new segment
ctx->result_all.back().tokens.insert(
ctx->result_all.back().tokens.end(),
segment.tokens.begin() + i,
segment.tokens.end());
// restart the scan on the freshly created (shorter) segment:
// i = -1 so the for-loop increment brings i back to 0
acc = 0;
text = "";
segment = ctx->result_all.back();
i = -1;
res++;
} else {
acc += cur;
text += txt;
}
}
ctx->result_all.back().text = std::move(text);
return res;
}
int whisper_full(
struct whisper_context * ctx,
struct whisper_full_params params,
@ -2490,20 +2400,12 @@ int whisper_full(
return -1;
}
if (params.token_timestamps) {
ctx->t_beg = 0;
ctx->t_last = 0;
ctx->tid_last = 0;
ctx->energy = get_signal_energy(samples, n_samples, 32);
}
const int seek_start = params.offset_ms/10;
const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10);
// if length of spectrogram is less than 1s (100 samples), then return
// basically don't process anything that is less than 1s
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
if (seek_end < 100 + seek_start) {
if (whisper_n_len(ctx) < 100 + seek_start) {
return 0;
}
@ -2536,7 +2438,7 @@ int whisper_full(
// main loop
int seek = seek_start;
while (true) {
const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
int progress_cur = (100*seek)/whisper_n_len(ctx);
while (progress_cur >= progress_prev + progress_step) {
progress_prev += progress_step;
if (params.print_progress) {
@ -2544,7 +2446,7 @@ int whisper_full(
}
}
if (seek + 100 >= seek_end) {
if (seek + 100 >= whisper_n_len(ctx)) {
break;
}
@ -2625,7 +2527,7 @@ int whisper_full(
// end of text token
if (token.id == whisper_token_eot(ctx)) {
if (result_len == 0) {
if (seek + seek_delta + 100 >= seek_end) {
if (seek + seek_delta + 100 >= whisper_n_len(ctx)) {
result_len = i + 1;
} else {
// TODO: figure out how to resolve this
@ -2647,7 +2549,6 @@ int whisper_full(
}
}
// shrink down to result_len
tokens_cur.resize(result_len);
for (const auto & r : tokens_cur) {
@ -2686,19 +2587,8 @@ int whisper_full(
for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
int n_new = 1;
if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(
ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) {
n_new = whisper_wrap_segment(ctx, params.max_len);
}
}
if (params.new_segment_callback) {
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
}
}
text = "";
@ -2727,19 +2617,8 @@ int whisper_full(
for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
int n_new = 1;
if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(
ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) {
n_new = whisper_wrap_segment(ctx, params.max_len);
}
}
if (params.new_segment_callback) {
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
}
}
}
@ -2873,7 +2752,7 @@ int whisper_full_parallel(
// call the new_segment_callback for each segment
if (params.new_segment_callback) {
params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
}
}
@ -2949,304 +2828,3 @@ const char * whisper_print_system_info() {
return s.c_str();
}
// =================================================================================================
//
// Experimental stuff below
//
// Not sure if these should be part of the library at all, because the quality of the results is not
// guaranteed. Might get removed at some point unless a robust algorithm implementation is found
//
// =================================================================================================
//
// token-level timestamps
//
// convert a timestamp (10 ms units) into a sample index,
// clamped to the valid range [0, n_samples - 1]
static int timestamp_to_sample(int64_t t, int n_samples) {
    const int i_sample = (int) ((t*WHISPER_SAMPLE_RATE)/100);
    return std::max(0, std::min(n_samples - 1, i_sample));
}
// convert a sample index into a timestamp (10 ms units)
// the multiplication is performed in 64-bit: the previous 32-bit form
// (100*i_sample) overflows int for signals longer than ~21.4M samples
// (~22 minutes at 16 kHz), producing negative timestamps
static int64_t sample_to_timestamp(int i_sample) {
    return ((int64_t) 100*i_sample)/WHISPER_SAMPLE_RATE;
}
// a cost-function / heuristic that is high for text that takes longer to pronounce
// obviously, can be improved
// heuristic cost that grows with how long the given text takes to pronounce
// spaces are nearly free, punctuation and digits are weighted heavier,
// every other character counts as 1
static float voice_length(const std::string & text) {
    float res = 0.0f;

    for (const char c : text) {
        switch (c) {
            case ' ':
                res += 0.01f;
                break;
            case ',':
                res += 2.00f;
                break;
            case '.':
            case '!':
            case '?':
                res += 3.00f;
                break;
            default:
                res += (c >= '0' && c <= '9') ? 3.00f : 1.00f;
                break;
        }
    }

    return res;
}
// average the fabs of the signal
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
const int hw = n_samples_per_half_window;
std::vector<float> result(n_samples);
for (int i = 0; i < n_samples; i++) {
float sum = 0;
for (int j = -hw; j <= hw; j++) {
if (i + j >= 0 && i + j < n_samples) {
sum += fabs(signal[i + j]);
}
}
result[i] = sum/(2*hw + 1);
}
return result;
}
// [EXPERIMENTAL] assign per-token t0/t1 timestamps to segment i_segment
// three passes:
//   1) anchor tokens that carry a confident timestamp token (tid)
//   2) proportionally split the remaining intervals by "voice length"
//   3) expand/contract token boundaries based on signal energy (VAD)
// requires ctx->energy to be populated (see get_signal_energy)
// thold_pt    - timestamp token probability threshold (~0.01)
// thold_ptsum - threshold on the summed timestamp-token probability (~0.01)
static void whisper_exp_compute_token_level_timestamps(
struct whisper_context * ctx,
int i_segment,
float thold_pt,
float thold_ptsum) {
auto & segment = ctx->result_all[i_segment];
auto & tokens = segment.tokens;
// number of samples in the precomputed energy envelope
const int n_samples = ctx->energy.size();
if (n_samples == 0) {
fprintf(stderr, "%s: no signal data available\n", __func__);
return;
}
// segment boundaries, in 10 ms timestamp units
const int64_t t0 = segment.t0;
const int64_t t1 = segment.t1;
// NOTE(review): s0/s1 appear unused - the VAD loop below declares its own
// shadowing s0/s1 per token; confirm and remove if so
const int s0 = timestamp_to_sample(t0, n_samples);
const int s1 = timestamp_to_sample(t1, n_samples);
const int n = tokens.size();
if (n == 0) {
return;
}
// a single token simply spans the whole segment
if (n == 1) {
tokens[0].t0 = t0;
tokens[0].t1 = t1;
return;
}
// running state carried across segments (stored in the context)
auto & t_beg = ctx->t_beg;
auto & t_last = ctx->t_last;
auto & tid_last = ctx->tid_last;
// pass 1: anchor tokens whose timestamp token (tid) is confident enough
for (int j = 0; j < n; ++j) {
auto & token = tokens[j];
if (j == 0) {
if (token.id == whisper_token_beg(ctx)) {
// segment opens with a timestamp token - pin the start
// (n >= 2 here, so tokens[j + 1] is valid)
tokens[j ].t0 = t0;
tokens[j ].t1 = t0;
tokens[j + 1].t0 = t0;
t_beg = t0;
t_last = t0;
tid_last = whisper_token_beg(ctx);
} else {
// otherwise continue from where the previous segment ended
tokens[j ].t0 = t_last;
}
}
// candidate time for this token: each timestamp-token id step is 2 units
// (timestamps are in 10 ms units elsewhere in this file)
const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
tokens[j].id = token.id;
tokens[j].tid = token.tid;
tokens[j].p = token.p;
tokens[j].pt = token.pt;
tokens[j].ptsum = token.ptsum;
// heuristic pronunciation length, used by the proportional split below
tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id));
// accept the anchor only if confident, monotonically increasing and in range
if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
if (j > 0) {
tokens[j - 1].t1 = tt;
}
tokens[j].t0 = tt;
tid_last = token.tid;
}
}
// clamp the trailing tokens to the segment end
tokens[n - 2].t1 = t1;
tokens[n - 1].t0 = t1;
tokens[n - 1].t1 = t1;
t_last = t1;
// find intervals of tokens with unknown timestamps
// fill the timestamps by proportionally splitting the interval based on the token voice lengths
{
int p0 = 0;
int p1 = 0;
while (true) {
// advance p1 to the next token with a known (non-negative) end time
while (p1 < n && tokens[p1].t1 < 0) {
p1++;
}
if (p1 >= n) {
p1--;
}
if (p1 > p0) {
double psum = 0.0;
for (int j = p0; j <= p1; j++) {
psum += tokens[j].vlen;
}
//printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
const double dt = tokens[p1].t1 - tokens[p0].t0;
// split the time proportionally to the voice length
for (int j = p0 + 1; j <= p1; j++) {
const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
tokens[j - 1].t1 = ct;
tokens[j ].t0 = ct;
}
}
// next interval starts right after this one
p1++;
p0 = p1;
if (p1 >= n) {
break;
}
}
}
// fix up (just in case)
for (int j = 0; j < n - 1; j++) {
if (tokens[j].t1 < 0) {
// NOTE(review): this copies a still-negative t1 into the next token's
// t0 - looks like it was meant to assign a valid time instead; confirm
tokens[j + 1].t0 = tokens[j].t1;
}
// enforce non-decreasing timestamps between consecutive tokens
if (j > 0) {
if (tokens[j - 1].t1 > tokens[j].t0) {
tokens[j].t0 = tokens[j - 1].t1;
tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
}
}
}
// VAD
// expand or contract tokens based on voice activity
{
// half-window of WHISPER_SAMPLE_RATE/8 samples around each token boundary
const int hw = WHISPER_SAMPLE_RATE/8;
for (int j = 0; j < n; j++) {
// skip special tokens (EOT and above) - they carry no audio
if (tokens[j].id >= whisper_token_eot(ctx)) {
continue;
}
// per-token sample bounds (these shadow the outer s0/s1)
int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
const int ss0 = std::max(s0 - hw, 0);
const int ss1 = std::min(s1 + hw, n_samples);
const int ns = ss1 - ss0;
float sum = 0.0f;
for (int k = ss0; k < ss1; k++) {
sum += ctx->energy[k];
}
// voice-activity threshold: half the mean energy of the padded window
const float thold = 0.5*sum/ns;
{
// adjust the token start
int k = s0;
if (ctx->energy[k] > thold && j > 0) {
// start is inside speech - move it back to where the energy drops
while (k > 0 && ctx->energy[k] > thold) {
k--;
}
tokens[j].t0 = sample_to_timestamp(k);
if (tokens[j].t0 < tokens[j - 1].t1) {
tokens[j].t0 = tokens[j - 1].t1;
} else {
s0 = k;
}
} else {
// start is in silence - move it forward to the first active sample
while (ctx->energy[k] < thold && k < s1) {
k++;
}
s0 = k;
tokens[j].t0 = sample_to_timestamp(k);
}
}
{
// adjust the token end
int k = s1;
if (ctx->energy[k] > thold) {
// end is inside speech - extend it until the energy drops
while (k < n_samples - 1 && ctx->energy[k] > thold) {
k++;
}
tokens[j].t1 = sample_to_timestamp(k);
// NOTE(review): `ns` is the window sample count, not the token
// count - `j < ns - 1` looks like it was meant to be `j < n - 1`;
// as written tokens[j + 1] can be read for the last token - confirm
if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) {
tokens[j].t1 = tokens[j + 1].t0;
} else {
s1 = k;
}
} else {
// end is in silence - pull it back to the last active sample
while (ctx->energy[k] < thold && k > s0) {
k--;
}
s1 = k;
tokens[j].t1 = sample_to_timestamp(k);
}
}
}
}
// fixed token expand (optional)
//{
// const int t_expand = 0;
// for (int j = 0; j < n; j++) {
// if (j > 0) {
// tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
// }
// if (j < n - 1) {
// tokens[j].t1 = tokens[j].t1 + t_expand;
// }
// }
//}
// debug info
//for (int j = 0; j < n; ++j) {
// const auto & token = tokens[j];
// const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
// printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
// tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id));
// if (tokens[j].id >= whisper_token_eot(ctx)) {
// continue;
// }
//}
}

View File

@ -68,21 +68,14 @@ extern "C" {
typedef int whisper_token;
typedef struct whisper_token_data {
struct whisper_token_data {
whisper_token id; // token id
whisper_token tid; // forced timestamp token id
float p; // probability of the token
float pt; // probability of the timestamp token
float ptsum; // sum of probabilities of all timestamp tokens
// token-level timestamp data
// do not use if you haven't computed token-level timestamps
int64_t t0; // start time of the token
int64_t t1; // end time of the token
float vlen; // voice length of the token
} whisper_token_data;
};
// Allocates all memory needed for the model and loads the model from the given file.
// Returns NULL on failure.
@ -136,7 +129,7 @@ extern "C" {
// You can also implement your own sampling method using the whisper_get_probs() function.
// whisper_sample_best() returns the token with the highest probability
// whisper_sample_timestamp() returns the most probable timestamp token
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
WHISPER_API struct whisper_token_data whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
// Return the id of the specified language, returns -1 if not found
@ -179,15 +172,14 @@ extern "C" {
// Text segment callback
// Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data);
struct whisper_full_params {
enum whisper_sampling_strategy strategy;
int n_threads;
int n_max_text_ctx;
int offset_ms; // start offset in ms
int duration_ms; // audio duration to process in ms
int offset_ms;
bool translate;
bool no_context;
@ -196,12 +188,6 @@ extern "C" {
bool print_realtime;
bool print_timestamps;
// [EXPERIMENTAL] token-level timestamps
bool token_timestamps; // enable token-level timestamps
float thold_pt; // timestamp token probability threshold (~0.01)
float thold_ptsum; // timestamp token sum probability threshold (~0.01)
int max_len; // max segment length in characters
const char * language;
struct {
@ -258,7 +244,7 @@ extern "C" {
// Get token data for the specified token in the specified segment.
// This contains probabilities, timestamps, etc.
WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
// Get the probability of the specified token in the specified segment.
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);