Compare commits

...

30 Commits

Author SHA1 Message Date
5d895d60b6 Merge branch 'master' into avx512 2022-11-06 09:09:50 +02:00
b71d45beff ggml : fix AVX 512-bit kernels 2022-11-06 08:50:57 +02:00
c4350356de Update ggml.c 2022-11-05 22:56:56 +02:00
a09e9123ca Update README.md 2022-11-05 08:44:41 +02:00
d42cf6d0df Update README.md 2022-11-04 22:26:08 +02:00
ef47d77492 main : fix generated bash script 2022-11-04 18:30:38 +02:00
75171c2b79 ggml : multi-thread the ggml_add operator 2022-11-03 20:53:44 +02:00
a2eeb941f6 cmake : fix passing GGML_PERF compile option 2022-11-03 20:19:06 +02:00
0e689f83d8 Update README.md 2022-11-02 22:03:27 +02:00
d5afebd37c whisper : token-level timestamp refactoring (#49, #120)
This turned out pretty good overall. The algorithm has been moved from
main.cpp to whisper.cpp and can be reused for all subtitles types. This
means that now you can specify the maximum length of the generated
lines. Simply provide the "-ml" argument specifying the max length in
number of characters
2022-11-02 21:45:54 +02:00
4b1c32e8ea Update README.md 2022-11-02 18:33:29 +02:00
b5dde365e9 extra : compute SHA of all models files 2022-11-02 18:31:55 +02:00
02dfd5b8c3 whisper : fix extra memory usage after recent processor changes
Had increased the memory buffer to the size of the model and forgot to
bring it down.
2022-11-02 18:31:18 +02:00
c63ce24834 Allow building with Accelerate for x86_64 Macs (#123)
* Cross compile windows

* set env properly

* rm log

* fix review

* Add back space

* Don't force architecture

* Allow building x86_64 with accelerate
2022-11-02 18:00:19 +02:00
137321915f ggml : fix the check for NEON support (#7)
Was using the wrong preprocessor macro
2022-11-02 17:52:24 +02:00
24cd12f647 Cross compilation (#121)
* Cross compile windows

* set env properly

* rm log

* fix review

* Add back space
2022-11-02 08:46:49 +02:00
e46bc56e71 Update README.md 2022-11-01 22:47:58 +02:00
6fb98370ba main : add some comments for the word-level timestamp algorithm 2022-11-01 22:35:21 +02:00
0729da9a3b main : fix some edge cases for word-level timestamps 2022-11-01 22:09:25 +02:00
5dc74e3aff Update README.md 2022-10-31 22:06:05 +02:00
ac8ef34039 Update README.md 2022-10-31 20:19:41 +02:00
b26345cc7b Added for Windows implemenated script download-ggml-model.cmd 2022-10-31 19:38:20 +02:00
8dac3c6e10 Fixed sched_yield 2022-10-30 21:38:18 +02:00
6417e59aad Implemenated sched_yield function for Windows 2022-10-30 21:38:18 +02:00
dc12994603 Update README.md 2022-10-30 17:11:37 +02:00
b0f2aa0ea6 Update README.md 2022-10-30 17:10:46 +02:00
361395187d Merge remote-tracking branch 'origin/master' into avx512 2022-10-27 17:46:04 +03:00
7fc52fa7ef Another shot at AVX-512 support 2022-10-27 17:45:38 +03:00
01e037c6c6 Merge branch 'master' into avx512 2022-10-27 17:31:55 +03:00
95f4fc70ca Try to add AVX 512-bit support 2022-10-26 18:48:54 +03:00
12 changed files with 1089 additions and 532 deletions

View File

@ -47,7 +47,7 @@ else()
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
endif()
option(WHISPER_PERF "whisper: enable perf timings" OFF)
option(WHISPER_PERF "whisper: enable perf timings" OFF)
# sanitizers
@ -151,6 +151,10 @@ else()
endif()
endif()
if (WHISPER_PERF)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_PERF)
endif()
#
# whisper - this is the main library of the project
#

View File

@ -1,6 +1,14 @@
ifndef UNAME_S
UNAME_S := $(shell uname -s)
endif
ifndef UNAME_P
UNAME_P := $(shell uname -p)
endif
ifndef UNAME_M
UNAME_M := $(shell uname -m)
endif
# Mac OS + Arm can report x86_64
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
@ -8,8 +16,8 @@ ifeq ($(UNAME_S),Darwin)
ifneq ($(UNAME_P),arm)
SYSCTL_M := $(shell sysctl -n hw.optional.arm64)
ifeq ($(SYSCTL_M),1)
UNAME_P := arm
UNAME_M := arm64
# UNAME_P := arm
# UNAME_M := arm64
warn := $(warning Your arch is announced as x86_64, but it seems to actually be ARM64. Not fixing that can lead to bad performance. For more info see: https://github.com/ggerganov/whisper.cpp/issues/66\#issuecomment-1282546789)
endif
endif
@ -42,12 +50,16 @@ endif
# TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue
ifeq ($(UNAME_M),x86_64)
CFLAGS += -mavx -mavx2 -mfma -mf16c
# AVX 512
CFLAGS += -mavx512f -mfma -mf16c
# AVX 256
#CFLAGS += -mavx -mavx2 -mfma -mf16c
endif
ifeq ($(UNAME_M),amd64)
CFLAGS += -mavx -mavx2 -mfma -mf16c
endif
ifneq ($(filter arm%,$(UNAME_M)),)
ifndef WHISPER_NO_ACCELERATE
# Mac M1 - include Accelerate framework
ifeq ($(UNAME_S),Darwin)
CFLAGS += -DGGML_USE_ACCELERATE
@ -78,13 +90,13 @@ main: examples/main/main.cpp ggml.o whisper.o
./main -h
ggml.o: ggml.c ggml.h
$(CC) $(CFLAGS) -c ggml.c
$(CC) $(CFLAGS) -c ggml.c -o ggml.o
whisper.o: whisper.cpp whisper.h
$(CXX) $(CXXFLAGS) -c whisper.cpp
$(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o
libwhisper.a: ggml.o whisper.o
ar rcs libwhisper.a ggml.o whisper.o
$(AR) rcs libwhisper.a ggml.o whisper.o
clean:
rm -f *.o main stream bench libwhisper.a

187
README.md
View File

@ -26,14 +26,41 @@ Supported platforms:
The entire implementation of the model is contained in 2 source files:
- [ggml.h](ggml.h) / [ggml.c](ggml.c)
- [whisper.h](whisper.h) / [whisper.cpp](whisper.cpp)
- Tensor operations: [ggml.h](ggml.h) / [ggml.c](ggml.c)
- Transformer inference: [whisper.h](whisper.h) / [whisper.cpp](whisper.cpp)
Having such a lightweight implementation of the model allows to easily integrate it in different platforms and applications.
As an example, here is a video of running the model on an iPhone 13 device - fully offline, on-device:
https://user-images.githubusercontent.com/1991296/197385372-962a6dea-bca1-4d50-bf96-1d8c27b98c81.mp4
## Implementation details
- The core tensor operations are implemented in C ([ggml.h](ggml.h) / [ggml.c](ggml.c))
- The transformer model and the high-level C-style API are implemented in C++ ([whisper.h](whisper.h) / [whisper.cpp](whisper.cpp))
- Sample usage is demonstrated in [main.cpp](examples/main)
- Sample real-time audio transcription from the microphone is demonstrated in [stream.cpp](examples/stream)
- Various other examples are available in the [examples](examples) folder
The tensor operators are optimized heavily for Apple silicon CPUs. Depending on the computation size, Arm Neon SIMD
instrisics or CBLAS Accelerate framework routines are used. The latter are especially effective for bigger sizes since
the Accelerate framework utilizes the special-purpose AMX coprocessor available in modern Apple products.
## Limitations
- Inference only
- No GPU support
- Very basic greedy sampling scheme - always pick up the token with highest probability.
This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274)
from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure
to run the python code with the following parameters:
```
whisper --best_of None --beam_size None ...
```
In the future, `whisper.cpp` will support more sampling strategies.
## Quick start
First, download one of the Whisper models converted in [ggml format](models). For example:
@ -59,8 +86,8 @@ For a quick demo, simply run `make base.en`:
```java
$ make base.en
cc -I. -O3 -std=c11 -pthread -DGGML_USE_ACCELERATE -c ggml.c
c++ -I. -I./examples -O3 -std=c++11 -pthread -c whisper.cpp
cc -I. -O3 -std=c11 -pthread -DGGML_USE_ACCELERATE -c ggml.c -o ggml.o
c++ -I. -I./examples -O3 -std=c++11 -pthread -c whisper.cpp -o whisper.o
c++ -I. -I./examples -O3 -std=c++11 -pthread examples/main/main.cpp whisper.o ggml.o -o main -framework Accelerate
./main -h
@ -70,13 +97,18 @@ options:
-h, --help show this help message and exit
-s SEED, --seed SEED RNG seed (default: -1)
-t N, --threads N number of threads to use during computation (default: 4)
-p N, --processors N number of processors to use during computation (default: 1)
-ot N, --offset-t N time offset in milliseconds (default: 0)
-on N, --offset-n N segment index offset (default: 0)
-mc N, --max-context N maximum number of text context tokens to store (default: max)
-ml N, --max-len N maximum segment length in characters (default: 0)
-wt N, --word-thold N word timestamp probability threshold (default: 0.010000)
-v, --verbose verbose output
--translate translate from source language to english
-otxt, --output-txt output result in a text file
-ovtt, --output-vtt output result in a vtt file
-osrt, --output-srt output result in a srt file
-owts, --output-words output script for generating karaoke video
-ps, --print_special print special tokens
-pc, --print_colors print colors
-nt, --no_timestamps do not print timestamps
@ -86,7 +118,7 @@ options:
bash ./models/download-ggml-model.sh base.en
Downloading ggml model base.en ...
ggml-base.en.bin 100%[========================>] 141.11M 6.34MB/s in 24s
ggml-base.en.bin 100%[========================>] 141.11M 6.34MB/s in 24s
Done! Model 'base.en' saved in 'models/ggml-base.en.bin'
You can now use it like this:
@ -114,23 +146,26 @@ whisper_model_load: n_text_layer = 6
whisper_model_load: n_mels = 80
whisper_model_load: f16 = 1
whisper_model_load: type = 2
whisper_model_load: mem_required = 505.00 MB
whisper_model_load: mem_required = 670.00 MB
whisper_model_load: adding 1607 extra tokens
whisper_model_load: ggml ctx size = 163.43 MB
whisper_model_load: ggml ctx size = 140.60 MB
whisper_model_load: memory size = 22.83 MB
whisper_model_load: model size = 140.54 MB
main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, lang = en, task = transcribe, timestamps = 1 ...
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
[00:00.000 --> 00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
whisper_print_timings: load time = 87.21 ms
whisper_print_timings: mel time = 24.26 ms
whisper_print_timings: sample time = 3.87 ms
whisper_print_timings: encode time = 323.67 ms / 53.94 ms per layer
whisper_print_timings: decode time = 83.25 ms / 13.87 ms per layer
whisper_print_timings: total time = 522.66 ms
[00:00:00.000 --> 00:00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
whisper_print_timings: load time = 105.91 ms
whisper_print_timings: mel time = 24.62 ms
whisper_print_timings: sample time = 3.63 ms
whisper_print_timings: encode time = 324.71 ms / 54.12 ms per layer
whisper_print_timings: decode time = 83.58 ms / 13.93 ms per layer
whisper_print_timings: total time = 542.81 ms
```
The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
@ -172,8 +207,8 @@ make large
| Model | Disk | Mem | SHA |
| --- | --- | --- | --- |
| tiny | 75 MB | ~280 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
| base | 142 MB | ~430 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
| tiny | 75 MB | ~390 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
| base | 142 MB | ~500 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
| small | 466 MB | ~1.0 GB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
| medium | 1.5 GB | ~2.6 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
| large | 2.9 GB | ~4.7 GB | `b1caaf735c4cc1429223d5a74f0f4d0b9b59a299` |
@ -185,7 +220,7 @@ in about half a minute on a MacBook M1 Pro, using `medium.en` model:
<details>
<summary>Expand to see the result</summary>
```java
$ ./main -m models/ggml-medium.en.bin -f samples/gb1.wav -t 8
@ -273,32 +308,108 @@ to highlight words with high or low confidence:
<img width="965" alt="image" src="https://user-images.githubusercontent.com/1991296/197356445-311c8643-9397-4e5e-b46e-0b4b4daa2530.png">
## Implementation details
## Controlling the length of the generated text segments (experimental)
- The core tensor operations are implemented in C ([ggml.h](ggml.h) / [ggml.c](ggml.c))
- The high-level C-style API is implemented in C++ ([whisper.h](whisper.h) / [whisper.cpp](whisper.cpp))
- Sample usage is demonstrated in [main.cpp](examples/main)
- Sample real-time audio transcription from the microphone is demonstrated in [stream.cpp](examples/stream)
- Various other examples are available in the [examples](examples) folder
For example, to limit the line length to a maximum of 16 characters, simply add `-ml 16`:
The tensor operators are optimized heavily for Apple silicon CPUs. Depending on the computation size, Arm Neon SIMD
instrisics or CBLAS Accelerate framework routines are used. The latter are especially effective for bigger sizes since
the Accelerate framework utilizes the special-purpose AMX coprocessor available in modern Apple products.
```java
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 16
## Limitations
whisper_model_load: loading model from './models/ggml-base.en.bin'
...
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
- Inference only
- No GPU support
- Very basic greedy sampling scheme - always pick up the token with highest probability.
This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274)
from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure
to run the python code with the following parameters:
main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
```
whisper --best_of None --beam_size None ...
```
[00:00:00.000 --> 00:00:00.850] And so my
[00:00:00.850 --> 00:00:01.590] fellow
[00:00:01.590 --> 00:00:04.140] Americans, ask
[00:00:04.140 --> 00:00:05.660] not what your
[00:00:05.660 --> 00:00:06.840] country can do
[00:00:06.840 --> 00:00:08.430] for you, ask
[00:00:08.430 --> 00:00:09.440] what you can do
[00:00:09.440 --> 00:00:10.020] for your
[00:00:10.020 --> 00:00:11.000] country.
```
In the future, `whisper.cpp` will support more sampling strategies.
## Word-level timestamp
The `--max-len` argument can be used to obtain word-level timestamps. Simply use `-ml 1`:
```java
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 1
whisper_model_load: loading model from './models/ggml-base.en.bin'
...
system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 |
main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
[00:00:00.000 --> 00:00:00.320]
[00:00:00.320 --> 00:00:00.370] And
[00:00:00.370 --> 00:00:00.690] so
[00:00:00.690 --> 00:00:00.850] my
[00:00:00.850 --> 00:00:01.590] fellow
[00:00:01.590 --> 00:00:02.850] Americans
[00:00:02.850 --> 00:00:03.300] ,
[00:00:03.300 --> 00:00:04.140] ask
[00:00:04.140 --> 00:00:04.990] not
[00:00:04.990 --> 00:00:05.410] what
[00:00:05.410 --> 00:00:05.660] your
[00:00:05.660 --> 00:00:06.260] country
[00:00:06.260 --> 00:00:06.600] can
[00:00:06.600 --> 00:00:06.840] do
[00:00:06.840 --> 00:00:07.010] for
[00:00:07.010 --> 00:00:08.170] you
[00:00:08.170 --> 00:00:08.190] ,
[00:00:08.190 --> 00:00:08.430] ask
[00:00:08.430 --> 00:00:08.910] what
[00:00:08.910 --> 00:00:09.040] you
[00:00:09.040 --> 00:00:09.320] can
[00:00:09.320 --> 00:00:09.440] do
[00:00:09.440 --> 00:00:09.760] for
[00:00:09.760 --> 00:00:10.020] your
[00:00:10.020 --> 00:00:10.510] country
[00:00:10.510 --> 00:00:11.000] .
```
## Karaoke-style movie generation (experimental)
The [main](examples/main) example provides support for output of karaoke-style movies, where the
currently pronounced word is highlighted. Use the `-wts` argument and run the generated bash script.
This requires to have `ffmpeg` installed.
Here are a few *"typical"* examples:
```java
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -owts
source ./samples/jfk.wav.wts
ffplay ./samples/jfk.wav.mp4
```
https://user-images.githubusercontent.com/1991296/199337465-dbee4b5e-9aeb-48a3-b1c6-323ac4db5b2c.mp4
---
```java
./main -m ./models/ggml-base.en.bin -f ./samples/mm0.wav -owts
source ./samples/mm0.wav.wts
ffplay ./samples/mm0.wav.mp4
```
https://user-images.githubusercontent.com/1991296/199337504-cc8fd233-0cb7-4920-95f9-4227de3570aa.mp4
---
```java
./main -m ./models/ggml-base.en.bin -f ./samples/gb0.wav -owts
source ./samples/gb0.wav.wts
ffplay ./samples/gb0.wav.mp4
```
https://user-images.githubusercontent.com/1991296/199337538-b7b0c7a3-2753-4a88-a0cd-f28a317987ba.mp4
---
## Benchmarks

View File

@ -6,21 +6,29 @@ It can be used as a reference for using the `whisper.cpp` library in other proje
```
./main -h
usage: ./main [options] file0.wav file1.wav ...
usage: ./bin/main [options] file0.wav file1.wav ...
options:
-h, --help show this help message and exit
-s SEED, --seed SEED RNG seed (default: -1)
-t N, --threads N number of threads to use during computation (default: 4)
-o N, --offset N offset in milliseconds (default: 0)
-p N, --processors N number of processors to use during computation (default: 1)
-ot N, --offset-t N time offset in milliseconds (default: 0)
-on N, --offset-n N segment index offset (default: 0)
-mc N, --max-context N maximum number of text context tokens to store (default: max)
-ml N, --max-len N maximum segment length in characters (default: 0)
-wt N, --word-thold N word timestamp probability threshold (default: 0.010000)
-v, --verbose verbose output
--translate translate from source language to english
-otxt, --output-txt output result in a text file
-ovtt, --output-vtt output result in a vtt file
-osrt, --output-srt output result in a srt file
-owts, --output-words output script for generating karaoke video
-ps, --print_special print special tokens
-pc, --print_colors print colors
-nt, --no_timestamps do not print timestamps
-l LANG, --language LANG spoken language (default: en)
-m FNAME, --model FNAME model path (default: models/ggml-base.en.bin)
-f FNAME, --file FNAME input WAV file path
-h, --help show this help message and exit
```

View File

@ -36,6 +36,7 @@ std::string to_timestamp(int64_t t, bool comma = false) {
return std::string(buf);
}
// helper function to replace substrings
void replace_all(std::string & s, const std::string & search, const std::string & replace) {
for (size_t pos = 0; ; pos += replace.length()) {
pos = s.find(search, pos);
@ -45,31 +46,6 @@ void replace_all(std::string & s, const std::string & search, const std::string
}
}
// a cost-function that is high for text that takes longer to pronounce
float voice_length(const std::string & text) {
float res = 0.0f;
for (size_t i = 0; i < text.size(); ++i) {
if (text[i] == ' ') {
res += 0.01f;
} else if (text[i] == ',') {
res += 2.00f;
} else if (text[i] == '.') {
res += 3.00f;
} else if (text[i] == '!') {
res += 3.00f;
} else if (text[i] == '?') {
res += 3.00f;
} else if (text[i] >= '0' && text[i] <= '9') {
res += 3.00f;
} else {
res += 1.00f;
}
}
return res;
}
// command-line parameters
struct whisper_params {
int32_t seed = -1; // RNG seed, not used currently
@ -78,6 +54,7 @@ struct whisper_params {
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t max_context = -1;
int32_t max_len = 0;
float word_thold = 0.01f;
@ -120,6 +97,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
params.offset_n = std::stoi(argv[++i]);
} else if (arg == "-mc" || arg == "--max-context") {
params.max_context = std::stoi(argv[++i]);
} else if (arg == "-ml" || arg == "--max-len") {
params.max_len = std::stoi(argv[++i]);
} else if (arg == "-wt" || arg == "--word-thold") {
params.word_thold = std::stof(argv[++i]);
} else if (arg == "-v" || arg == "--verbose") {
@ -176,13 +155,14 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n);
fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n");
fprintf(stderr, " -ml N, --max-len N maximum segment length in characters (default: %d)\n", params.max_len);
fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold);
fprintf(stderr, " -v, --verbose verbose output\n");
fprintf(stderr, " --translate translate from source language to english\n");
fprintf(stderr, " -otxt, --output-txt output result in a text file\n");
fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n");
fprintf(stderr, " -osrt, --output-srt output result in a srt file\n");
fprintf(stderr, " -owts, --output-words output word-level timestamps to a text file\n");
fprintf(stderr, " -owts, --output-words output script for generating karaoke video\n");
fprintf(stderr, " -ps, --print_special print special tokens\n");
fprintf(stderr, " -pc, --print_colors print colors\n");
fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
@ -192,65 +172,67 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, "\n");
}
void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) {
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
const whisper_params & params = *(whisper_params *) user_data;
const int n_segments = whisper_full_n_segments(ctx);
// print the last segment
const int i = n_segments - 1;
if (i == 0) {
// print the last n_new segments
const int s0 = n_segments - n_new;
if (s0 == 0) {
printf("\n");
}
if (params.no_timestamps) {
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special_tokens == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
for (int i = s0; i < n_segments; i++) {
if (params.no_timestamps) {
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special_tokens == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s", text);
}
fflush(stdout);
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s", text);
}
fflush(stdout);
} else {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special_tokens == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special_tokens == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
}
printf("\n");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
}
printf("\n");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
}
}
}
@ -320,364 +302,118 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
return true;
}
// word-level timestamps (experimental)
// TODO: probably still has bugs, needs refactoring, etc..
// TODO: auto threshold
// TODO: extra pass to detect unused speech and assign to tokens
// karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, const std::vector<float> & pcmf32) {
if (params.output_wts) {
std::vector<float> pcm_avg(pcmf32.size(), 0);
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
std::ofstream fout(fname);
// average the fabs of the signal
{
const int hw = 32;
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
for (int i = 0; i < pcmf32.size(); i++) {
float sum = 0;
for (int j = -hw; j <= hw; j++) {
if (i + j >= 0 && i + j < pcmf32.size()) {
sum += fabs(pcmf32[i + j]);
}
}
pcm_avg[i] = sum/(2*hw + 1);
}
// TODO: become parameter
static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
fout << "#!/bin/bash" << "\n";
fout << "\n";
fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \"";
for (int i = 0; i < whisper_full_n_segments(ctx); i++) {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
const int n = whisper_full_n_tokens(ctx, i);
std::vector<whisper_token_data> tokens(n);
for (int j = 0; j < n; ++j) {
tokens[j] = whisper_full_get_token_data(ctx, i, j);
}
struct token_info {
int64_t t0 = -1;
int64_t t1 = -1;
if (i > 0) {
fout << ",";
}
int64_t tt0 = -1;
int64_t tt1 = -1;
whisper_token id;
whisper_token tid;
float p = 0.0f;
float pt = 0.0f;
float ptsum = 0.0f;
std::string text;
float vlen = 0.0f; // voice length of this token
};
int64_t t_beg = 0;
int64_t t_last = 0;
whisper_token tid_last = 0;
std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
fout << "!/bin/bash" << "\n";
fout << "\n";
fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE << ":rate=25:color=black -vf \"";
// background text
fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
bool is_first = true;
for (int i = 0; i < whisper_full_n_segments(ctx); i++) {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
for (int j = 0; j < n; ++j) {
const auto & token = tokens[j];
const char *text = whisper_full_get_segment_text(ctx, i);
const int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100));
const int s1 = std::min((int) pcmf32.size(), (int) (t1*WHISPER_SAMPLE_RATE/100));
const int n = whisper_full_n_tokens(ctx, i);
std::vector<token_info> tokens(n);
if (n <= 1) {
if (tokens[j].id >= whisper_token_eot(ctx)) {
continue;
}
for (int j = 0; j < n; ++j) {
struct whisper_token_data token = whisper_full_get_token_data(ctx, i, j);
std::string txt_bg;
std::string txt_fg; // highlight token
std::string txt_ul; // underline
if (j == 0) {
if (token.id == whisper_token_beg(ctx)) {
tokens[j ].t0 = t0;
tokens[j ].t1 = t0;
tokens[j + 1].t0 = t0;
txt_bg = "> ";
txt_fg = "> ";
txt_ul = "\\ \\ ";
t_beg = t0;
t_last = t0;
tid_last = whisper_token_beg(ctx);
} else {
tokens[j ].t0 = t_last;
}
}
const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
tokens[j].id = token.id;
tokens[j].tid = token.tid;
tokens[j].p = token.p;
tokens[j].pt = token.pt;
tokens[j].ptsum = token.ptsum;
tokens[j].text = whisper_token_to_str(ctx, token.id);
//tokens[j].vlen = tokens[j].pt;
tokens[j].vlen = voice_length(tokens[j].text);
if (token.pt > params.word_thold && token.ptsum > 0.01 && token.tid > tid_last) {
if (j > 0) {
tokens[j - 1].t1 = tt;
}
tokens[j].t0 = tt;
tid_last = token.tid;
}
}
tokens[n - 2].t1 = t1;
tokens[n - 1].t0 = t1;
tokens[n - 1].t1 = t1;
t_last = t1;
int p0 = 0;
int p1 = 0;
while (true) {
while (p1 < n && tokens[p1].t1 < 0) {
p1++;
}
if (p1 >= n) {
p1--;
}
if (p1 > p0) {
double psum = 0.0;
for (int j = p0; j <= p1; j++) {
psum += tokens[j].vlen;
}
//printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
const double dt = tokens[p1].t1 - tokens[p0].t0;
for (int j = p0 + 1; j <= p1; j++) {
const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
//const double ct = tokens[j - 1].t0 + (dt*(j - p0))/(p1 - p0 + 1);
//const double ct = tokens[p0].t0 + (dt*(j - p0))/(p1 - p0 + 1);
tokens[j - 1].t1 = ct;
tokens[j ].t0 = ct;
}
}
p1++;
p0 = p1;
if (p1 >= n) {
break;
}
}
for (int j = 0; j < n - 1; j++) {
if (tokens[j].t1 < 0) {
tokens[j + 1].t0 = tokens[j].t1;
}
tokens[j].tt0 = tokens[j].t0;
tokens[j].tt1 = tokens[j].t1;
}
// VAD
{
const int hw = WHISPER_SAMPLE_RATE; // take one second of audio around the token
int ncnt = 0;
for (int k = 0; k < n; ++k) {
const auto & token2 = tokens[k];
for (int j = 0; j < n; j++) {
const int64_t t0 = tokens[j].t0;
const int64_t t1 = tokens[j].t1;
int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100));
int s1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100));
const int ss0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100) - hw);
const int ss1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100) + hw);
const int n = ss1 - ss0;
float sum = 0.0f;
for (int k = ss0; k < ss1; k++) {
sum += pcm_avg[k];
if (tokens[k].id >= whisper_token_eot(ctx)) {
continue;
}
const float avg = sum/n;
const std::string txt = whisper_token_to_str(ctx, token2.id);
const float thold = 0.5*avg;
txt_bg += txt;
{
int k = s0;
if (pcm_avg[k] > thold && j > 0) {
while (k > 0 && pcm_avg[k] > thold) {
k--;
}
tokens[j].t0 = (int64_t) (100*k/WHISPER_SAMPLE_RATE);
if (tokens[j].t0 < tokens[j - 1].t1) {
tokens[j].t0 = tokens[j - 1].t1;
} else {
s0 = k;
}
} else {
while (pcm_avg[k] < thold && k < s1) {
k++;
}
s0 = k;
tokens[j].t0 = 100*k/WHISPER_SAMPLE_RATE;
if (k == j) {
for (int l = 0; l < (int) txt.size(); ++l) {
txt_fg += txt[l];
txt_ul += "_";
}
txt_fg += "|";
} else {
for (int l = 0; l < (int) txt.size(); ++l) {
txt_fg += "\\ ";
txt_ul += "\\ ";
}
}
{
int k = s1;
if (pcm_avg[k] > thold) {
while (k < (int) pcmf32.size() - 1 && pcm_avg[k] > thold) {
k++;
}
tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE;
if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
tokens[j].t1 = tokens[j + 1].t0;
} else {
s1 = k;
}
} else {
while (pcm_avg[k] < thold && k > s0) {
k--;
}
s1 = k;
tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE;
}
}
ncnt += txt.size();
}
::replace_all(txt_bg, "'", "");
::replace_all(txt_bg, "\"", "\\\"");
::replace_all(txt_fg, "'", "");
::replace_all(txt_fg, "\"", "\\\"");
}
const int t_expand = 0;
for (int j = 0; j < n; j++) {
if (j > 0) {
tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
}
if (j < n - 1) {
tokens[j].t1 = tokens[j].t1 + t_expand;
}
}
for (int j = 0; j < n; ++j) {
const auto & token = tokens[j];
const auto tt = token.pt > params.word_thold && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, token.text.c_str());
if (tokens[j].id >= whisper_token_eot(ctx)) {
continue;
}
//printf("[%s --> %s] %s\n", to_timestamp(token.t0).c_str(), to_timestamp(token.t1).c_str(), whisper_token_to_str(ctx, token.id));
//fout << "# " << to_timestamp(token.t0) << " --> " << to_timestamp(token.t1) << " " << whisper_token_to_str(ctx, token.id) << "\n";
}
static const int line_wrap = 60;
static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
if (!is_first) {
fout << ",";
}
// background text
fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
is_first = false;
for (int j = 0; j < n; ++j) {
const auto & token = tokens[j];
if (tokens[j].id >= whisper_token_eot(ctx)) {
continue;
}
std::string txt_bg;
std::string txt_fg; // highlight token
std::string txt_ul; // underline
txt_bg = "> ";
txt_fg = "> ";
txt_ul = "\\ \\ ";
{
int ncnt = 0;
for (int k = 0; k < n; ++k) {
const auto & token2 = tokens[k];
if (tokens[k].id >= whisper_token_eot(ctx)) {
continue;
}
const std::string txt = whisper_token_to_str(ctx, token2.id);
txt_bg += txt;
if (k == j) {
for (int l = 0; l < (int) txt.size(); ++l) {
txt_fg += txt[l];
txt_ul += "_";
}
txt_fg += "|";
} else {
for (int l = 0; l < (int) txt.size(); ++l) {
txt_fg += "\\ ";
txt_ul += "\\ ";
}
}
ncnt += txt.size();
if (ncnt > line_wrap) {
if (k < j) {
txt_bg = "> ";
txt_fg = "> ";
txt_ul = "\\ \\ ";
ncnt = 0;
} else {
break;
}
}
}
::replace_all(txt_bg, "'", "");
::replace_all(txt_bg, "\"", "\\\"");
::replace_all(txt_fg, "'", "");
::replace_all(txt_fg, "\"", "\\\"");
}
if (is_first) {
// background text
fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << token.tt0/100.0 << "," << token.tt1/100.0 << ")'";
// foreground text
fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
// underline
fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2+16:text='" << txt_ul << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << t0/100.0 << "," << t1/100.0 << ")'";
is_first = false;
}
// foreground text
fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
// underline
fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2+16:text='" << txt_ul << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
}
fout << "\" -c:v libx264 -pix_fmt yuv420p -y " << fname_inp << ".mp4" << "\n";
fout << "\n\n";
fout << "echo \"Your video has been saved to " << fname_inp << ".mp4\"" << "\n";
fout << "\n";
fout << "echo \" ffplay " << fname_inp << ".mp4\"\n";
fout << "\n";
fout.close();
fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname);
}
fout << "\" -c:v libx264 -pix_fmt yuv420p -y " << fname_inp << ".mp4" << "\n";
fout << "\n\n";
fout << "echo \"Your video has been saved to " << fname_inp << ".mp4\"" << "\n";
fout << "\n";
fout << "echo \" ffplay " << fname_inp << ".mp4\"\n";
fout << "\n";
fout.close();
fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname);
return true;
}
@ -797,6 +533,10 @@ int main(int argc, char ** argv) {
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
@ -834,7 +574,7 @@ int main(int argc, char ** argv) {
// output to WTS file
if (params.output_wts) {
const auto fname_wts = fname_inp + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, pcmf32);
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
}
}
}

View File

@ -78,6 +78,14 @@ There are a lot of ways to improve this idea and I don't have much experience wi
*"optimize by sorting the data first"*
The plugin would then make an appropriate query using the selected text and code context to Copilot or GPT-3 and return the result.
Here is a proof-of-concept:
https://user-images.githubusercontent.com/1991296/199078847-0278fcde-5667-4748-ba0d-7d55381d6047.mp4
https://user-images.githubusercontent.com/1991296/200067939-f98d2ac2-7519-438a-85f9-79db0841ba4f.mp4
For explanation how this works see: https://twitter.com/ggerganov/status/1587168771789258756
## Discussion

7
extra/sha-all.sh Executable file
View File

@ -0,0 +1,7 @@
#!/bin/bash
# Compute the SHA1 of all model files in ./models/ggml-*.bin
for f in ./models/ggml-*.bin; do
shasum "$f" -a 1
done

258
ggml.c
View File

@ -14,7 +14,7 @@
#include <stdint.h>
#include <stdio.h>
#if defined _MSC_VER
#if defined _MSC_VER || defined(__MINGW32__)
#include <Windows.h>
typedef volatile LONG atomic_int;
@ -44,6 +44,11 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
static int pthread_join(pthread_t thread, void* unused) {
return (int) WaitForSingleObject(thread, INFINITE);
}
static int sched_yield (void) {
Sleep (0);
return 0;
}
#else
#include <pthread.h>
#include <stdatomic.h>
@ -193,7 +198,7 @@ static ggml_fp16_t table_exp_f16[1 << 16];
// timing
//
#if defined(_MSC_VER)
#if defined(_MSC_VER) || defined(__MINGW32__)
static int64_t timer_freq;
void ggml_time_init(void) {
LARGE_INTEGER frequency;
@ -322,6 +327,45 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
for (int i = n16; i < n; ++i) {
sumf += x[i]*y[i];
}
#elif defined(__AVX512F__)
const int n64 = (n & ~63);
__m512 sum0 = _mm512_setzero_ps();
__m512 sum1 = _mm512_setzero_ps();
__m512 sum2 = _mm512_setzero_ps();
__m512 sum3 = _mm512_setzero_ps();
__m512 x0, x1, x2, x3;
__m512 y0, y1, y2, y3;
for (int i = 0; i < n64; i += 64) {
x0 = _mm512_loadu_ps(x + i + 0);
x1 = _mm512_loadu_ps(x + i + 16);
x2 = _mm512_loadu_ps(x + i + 32);
x3 = _mm512_loadu_ps(x + i + 48);
y0 = _mm512_loadu_ps(y + i + 0);
y1 = _mm512_loadu_ps(y + i + 16);
y2 = _mm512_loadu_ps(y + i + 32);
y3 = _mm512_loadu_ps(y + i + 48);
sum0 = _mm512_fmadd_ps(x0, y0, sum0);
sum1 = _mm512_fmadd_ps(x1, y1, sum1);
sum2 = _mm512_fmadd_ps(x2, y2, sum2);
sum3 = _mm512_fmadd_ps(x3, y3, sum3);
}
sum0 = _mm512_add_ps(sum0, sum1);
sum2 = _mm512_add_ps(sum2, sum3);
sum0 = _mm512_add_ps(sum0, sum2);
sumf = sum0[0] + sum0[1] + sum0[2] + sum0[3] + sum0[4] + sum0[5] + sum0[6] + sum0[7] +
sum0[8] + sum0[9] + sum0[10] + sum0[11] + sum0[12] + sum0[13] + sum0[14] + sum0[15];
// leftovers
for (int i = n64; i < n; ++i) {
sumf += x[i]*y[i];
}
#elif defined(__AVX2__)
// AVX 256-bit
const int n32 = (n & ~31);
@ -519,6 +563,47 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
for (int i = n32; i < n; ++i) {
sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
}
#elif defined(__AVX512F__)
// AVX 512-bit
const int n64 = (n & ~63);
__m512 sum0 = _mm512_setzero_ps();
__m512 sum1 = _mm512_setzero_ps();
__m512 sum2 = _mm512_setzero_ps();
__m512 sum3 = _mm512_setzero_ps();
__m512 x0, x1, x2, x3;
__m512 y0, y1, y2, y3;
for (int i = 0; i < n64; i += 64) {
x0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 0 )));
x1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 16)));
x2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 32)));
x3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 48)));
y0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 0 )));
y1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 16)));
y2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 32)));
y3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 48)));
sum0 = _mm512_fmadd_ps(x0, y0, sum0);
sum1 = _mm512_fmadd_ps(x1, y1, sum1);
sum2 = _mm512_fmadd_ps(x2, y2, sum2);
sum3 = _mm512_fmadd_ps(x3, y3, sum3);
}
const __m512 sum01 = _mm512_add_ps(sum0, sum1);
const __m512 sum23 = _mm512_add_ps(sum2, sum3);
const __m512 sum0123 = _mm512_add_ps(sum01, sum23);
sumf = sum0123[0] + sum0123[1] + sum0123[2] + sum0123[3] + sum0123[4] + sum0123[5] + sum0123[6] + sum0123[7] +
sum0123[8] + sum0123[9] + sum0123[10] + sum0123[11] + sum0123[12] + sum0123[13] + sum0123[14] + sum0123[15];
// leftovers
for (int i = n64; i < n; ++i) {
//GGML_ASSERT(false);
sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
}
#elif defined(__AVX2__)
// AVX 256-bit
const int n32 = (n & ~31);
@ -625,7 +710,7 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
// NEON 128-bit
const int n16 = (n & ~15);
const float32x4_t v4 = vdupq_n_f32(v);
const float32x4_t v0 = vdupq_n_f32(v);
float32x4_t x0, x1, x2, x3;
float32x4_t y0, y1, y2, y3;
@ -641,14 +726,14 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
y2 = vld1q_f32(y + i + 8);
y3 = vld1q_f32(y + i + 12);
y0 = vfmaq_f32(y0, x0, v4);
y1 = vfmaq_f32(y1, x1, v4);
y2 = vfmaq_f32(y2, x2, v4);
y3 = vfmaq_f32(y3, x3, v4);
y0 = vfmaq_f32(y0, x0, v0);
y1 = vfmaq_f32(y1, x1, v0);
y2 = vfmaq_f32(y2, x2, v0);
y3 = vfmaq_f32(y3, x3, v0);
vst1q_f32(y + i + 0, y0);
vst1q_f32(y + i + 4, y1);
vst1q_f32(y + i + 8, y2);
vst1q_f32(y + i + 0, y0);
vst1q_f32(y + i + 4, y1);
vst1q_f32(y + i + 8, y2);
vst1q_f32(y + i + 12, y3);
}
@ -656,11 +741,46 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
for (int i = n16; i < n; ++i) {
y[i] += x[i]*v;
}
#elif defined(__AVX512F__)
// AVX512 512-bit
const int n64 = (n & ~63);
const __m512 v0 = _mm512_set1_ps(v);
__m512 x0, x1, x2, x3;
__m512 y0, y1, y2, y3;
for (int i = 0; i < n64; i += 64) {
x0 = _mm512_loadu_ps(x + i + 0);
x1 = _mm512_loadu_ps(x + i + 16);
x2 = _mm512_loadu_ps(x + i + 32);
x3 = _mm512_loadu_ps(x + i + 48);
y0 = _mm512_loadu_ps(y + i + 0);
y1 = _mm512_loadu_ps(y + i + 16);
y2 = _mm512_loadu_ps(y + i + 32);
y3 = _mm512_loadu_ps(y + i + 48);
y0 = _mm512_fmadd_ps(x0, v0, y0);
y1 = _mm512_fmadd_ps(x1, v0, y1);
y2 = _mm512_fmadd_ps(x2, v0, y2);
y3 = _mm512_fmadd_ps(x3, v0, y3);
_mm512_storeu_ps(y + i + 0, y0);
_mm512_storeu_ps(y + i + 16, y1);
_mm512_storeu_ps(y + i + 32, y2);
_mm512_storeu_ps(y + i + 48, y3);
}
// leftovers
for (int i = n64; i < n; ++i) {
y[i] += x[i]*v;
}
#elif defined(__AVX2__)
// AVX 256-bit
const int n32 = (n & ~31);
const __m256 v4 = _mm256_set1_ps(v);
const __m256 v0 = _mm256_set1_ps(v);
__m256 x0, x1, x2, x3;
__m256 y0, y1, y2, y3;
@ -676,13 +796,13 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
y2 = _mm256_loadu_ps(y + i + 16);
y3 = _mm256_loadu_ps(y + i + 24);
y0 = _mm256_fmadd_ps(x0, v4, y0);
y1 = _mm256_fmadd_ps(x1, v4, y1);
y2 = _mm256_fmadd_ps(x2, v4, y2);
y3 = _mm256_fmadd_ps(x3, v4, y3);
y0 = _mm256_fmadd_ps(x0, v0, y0);
y1 = _mm256_fmadd_ps(x1, v0, y1);
y2 = _mm256_fmadd_ps(x2, v0, y2);
y3 = _mm256_fmadd_ps(x3, v0, y3);
_mm256_storeu_ps(y + i + 0, y0);
_mm256_storeu_ps(y + i + 8, y1);
_mm256_storeu_ps(y + i + 0, y0);
_mm256_storeu_ps(y + i + 8, y1);
_mm256_storeu_ps(y + i + 16, y2);
_mm256_storeu_ps(y + i + 24, y3);
}
@ -695,7 +815,7 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
// WASM SIMD 128-bit
const int n16 = (n & ~15);
const v128_t v4 = wasm_f32x4_splat(v);
const v128_t v0 = wasm_f32x4_splat(v);
v128_t x0, x1, x2, x3;
v128_t y0, y1, y2, y3;
@ -711,10 +831,10 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
y2 = wasm_v128_load(y + i + 8);
y3 = wasm_v128_load(y + i + 12);
y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4));
y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4));
y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4));
y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4));
y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v0));
y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v0));
y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v0));
y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v0));
wasm_v128_store(y + i + 0, y0);
wasm_v128_store(y + i + 4, y1);
@ -740,7 +860,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
const int n32 = (n & ~31);
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
const float16x8_t v8 = vdupq_n_f16(v);
const float16x8_t v0 = vdupq_n_f16(v);
float16x8_t x0, x1, x2, x3;
float16x8_t y0, y1, y2, y3;
@ -756,10 +876,10 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
x2 = vld1q_f16(x + i + 16);
x3 = vld1q_f16(x + i + 24);
y0 = vfmaq_f16(y0, x0, v8);
y1 = vfmaq_f16(y1, x1, v8);
y2 = vfmaq_f16(y2, x2, v8);
y3 = vfmaq_f16(y3, x3, v8);
y0 = vfmaq_f16(y0, x0, v0);
y1 = vfmaq_f16(y1, x1, v0);
y2 = vfmaq_f16(y2, x2, v0);
y3 = vfmaq_f16(y3, x3, v0);
vst1q_f16(y + i + 0 , y0);
vst1q_f16(y + i + 8 , y1);
@ -767,8 +887,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
vst1q_f16(y + i + 24, y3);
}
#else
const float32x4_t v40 = vdupq_n_f32(v);
const float32x4_t v41 = vdupq_n_f32(v);
const float32x4_t v0 = vdupq_n_f32(v);
float32x4_t x0, x1, x2, x3, x4, x5, x6, x7;
float32x4_t y0, y1, y2, y3, y4, y5, y6, y7;
@ -792,14 +911,14 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
x6 = vcvt_f32_f16(vld1_f16(x + i + 24));
x7 = vcvt_f32_f16(vld1_f16(x + i + 28));
y0 = vfmaq_f32(y0, x0, v40);
y1 = vfmaq_f32(y1, x1, v40);
y2 = vfmaq_f32(y2, x2, v40);
y3 = vfmaq_f32(y3, x3, v40);
y4 = vfmaq_f32(y4, x4, v41);
y5 = vfmaq_f32(y5, x5, v41);
y6 = vfmaq_f32(y6, x6, v41);
y7 = vfmaq_f32(y7, x7, v41);
y0 = vfmaq_f32(y0, x0, v0);
y1 = vfmaq_f32(y1, x1, v0);
y2 = vfmaq_f32(y2, x2, v0);
y3 = vfmaq_f32(y3, x3, v0);
y4 = vfmaq_f32(y4, x4, v0);
y5 = vfmaq_f32(y5, x5, v0);
y6 = vfmaq_f32(y6, x6, v0);
y7 = vfmaq_f32(y7, x7, v0);
vst1_f16(y + i + 0 , vcvt_f16_f32(y0));
vst1_f16(y + i + 4 , vcvt_f16_f32(y1));
@ -817,11 +936,47 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
GGML_ASSERT(false);
y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
}
#elif defined(__AVX512F__)
// AVX 512-bit
const int n64 = (n & ~63);
const __m512 v0 = _mm512_set1_ps(v);
__m512 x0, x1, x2, x3;
__m512 y0, y1, y2, y3;
for (int i = 0; i < n64; i += 64) {
x0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 0 )));
x1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 16)));
x2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 32)));
x3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 48)));
y0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 0 )));
y1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 16)));
y2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 32)));
y3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 48)));
y0 = _mm512_fmadd_ps(x0, v0, y0);
y1 = _mm512_fmadd_ps(x1, v0, y1);
y2 = _mm512_fmadd_ps(x2, v0, y2);
y3 = _mm512_fmadd_ps(x3, v0, y3);
_mm256_storeu_si256((__m256i*)(y + i + 0 ), _mm512_cvtps_ph(y0, 0));
_mm256_storeu_si256((__m256i*)(y + i + 16), _mm512_cvtps_ph(y1, 0));
_mm256_storeu_si256((__m256i*)(y + i + 32), _mm512_cvtps_ph(y2, 0));
_mm256_storeu_si256((__m256i*)(y + i + 48), _mm512_cvtps_ph(y3, 0));
}
// leftovers
for (int i = n64; i < n; ++i) {
GGML_ASSERT(false);
y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
}
#elif defined(__AVX2__)
// AVX 256-bit
const int n32 = (n & ~31);
const __m256 v8 = _mm256_set1_ps(v);
const __m256 v0 = _mm256_set1_ps(v);
__m256 x0, x1, x2, x3;
__m256 y0, y1, y2, y3;
@ -837,10 +992,10 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
y0 = _mm256_fmadd_ps(x0, v8, y0);
y1 = _mm256_fmadd_ps(x1, v8, y1);
y2 = _mm256_fmadd_ps(x2, v8, y2);
y3 = _mm256_fmadd_ps(x3, v8, y3);
y0 = _mm256_fmadd_ps(x0, v0, y0);
y1 = _mm256_fmadd_ps(x1, v0, y1);
y2 = _mm256_fmadd_ps(x2, v0, y2);
y3 = _mm256_fmadd_ps(x3, v0, y3);
_mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0));
_mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0));
@ -857,7 +1012,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
// WASM SIMD 128-bit
const int n16 = (n & ~15);
const v128_t v4 = wasm_f32x4_splat(v);
const v128_t v0 = wasm_f32x4_splat(v);
v128_t x0, x1, x2, x3;
v128_t y0, y1, y2, y3;
@ -881,10 +1036,10 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
y2 = wasm_v128_load(ty + 8);
y3 = wasm_v128_load(ty + 12);
y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4));
y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4));
y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4));
y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4));
y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v0));
y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v0));
y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v0));
y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v0));
wasm_v128_store(ty + 0, y0);
wasm_v128_store(ty + 4, y1);
@ -3145,7 +3300,10 @@ void ggml_compute_forward_add_f32(
GGML_ASSERT(nb00 == sizeof(float));
if (nb10 == sizeof(float)) {
for (int j = ith; j < n; j += nth) {
const int j0 = (n/nth)*ith;
const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
for (int j = j0; j < j1; j++) {
ggml_vec_add_f32(nc,
(float *) ((char *) dst->data + j*nb1),
(float *) ((char *) src0->data + j*nb01),
@ -6852,7 +7010,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} break;
case GGML_OP_ADD:
{
node->n_tasks = 1;
node->n_tasks = n_threads;
} break;
case GGML_OP_SUB:
case GGML_OP_MUL:
@ -8084,7 +8242,7 @@ int ggml_cpu_has_avx512(void) {
}
int ggml_cpu_has_neon(void) {
#if defined(__ARM_NEON__)
#if defined(__ARM_NEON)
return 1;
#else
return 0;

View File

@ -22,6 +22,20 @@ A third option to obtain the model files is to download them from Hugging Face:
https://huggingface.co/datasets/ggerganov/whisper.cpp/tree/main
## Available models
| Model | Disk | Mem | SHA |
| --- | --- | --- | --- |
| tiny | 75 MB | ~390 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
| tiny.en | 75 MB | ~390 MB | `c78c86eb1a8faa21b369bcd33207cc90d64ae9df` |
| base | 142 MB | ~500 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
| base.en | 142 MB | ~500 MB | `137c40403d78fd54d454da0f9bd998f78703390c` |
| small | 466 MB | ~1.0 GB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
| small.en | 466 MB | ~1.0 GB | `db8a495a91d927739e50b3fc1cc4c6b8f6c2d022` |
| medium | 1.5 GB | ~2.6 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
| medium.en | 1.5 GB | ~2.6 GB | `8c30f0e44ce9560643ebd10bbe50cd20eafd3723` |
| large | 2.9 GB | ~4.7 GB | `b1caaf735c4cc1429223d5a74f0f4d0b9b59a299` |
## Model files for testing purposes
The model files pefixed with `for-tests-` are empty (i.e. do not contain any weights) and are used by the CI for testing purposes.

View File

@ -0,0 +1,63 @@
@echo off
pushd %~dp0
set models_path=%CD%
popd
set argc=0
for %%x in (%*) do set /A argc+=1
set models=tiny.en tiny base.en base small.en small medium.en medium large
if %argc% neq 1 (
echo.
echo Usage: download-ggml-model.cmd model
CALL :list_models
goto :eof
)
set model=%1
for %%b in (%models%) do (
if "%%b"=="%model%" (
CALL :download_model
goto :eof
)
)
echo Invalid model: %model%
CALL :list_models
goto :eof
:download_model
echo Downloading ggml model %model%...
cd %models_path%
if exist "ggml-%model%.bin" (
echo Model %model% already exists. Skipping download.
goto :eof
)
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://ggml.ggerganov.com/ggml-model-whisper-%model%.bin -OutFile ggml-%model%.bin"
if %ERRORLEVEL% neq 0 (
echo Failed to download ggml model %model%
echo Please try again later or download the original Whisper model files and convert them yourself.
goto :eof
)
echo Done! Model %model% saved in %models_path%\models\ggml-%model%.bin
echo You can now use it like this:
echo main.exe -m %models_path%\models\ggml-%model%.bin -f %models_path%\samples\jfk.wav
goto :eof
:list_models
echo.
echo Available models:
(for %%a in (%models%) do (
echo %%a
))
echo.
exit /b

View File

@ -133,11 +133,19 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
static const size_t MB = 1024*1024;
static const std::map<e_model, size_t> MEM_REQ_MODEL = {
{ MODEL_TINY, 86ull*MB },
{ MODEL_BASE, 165ull*MB },
{ MODEL_SMALL, 540ull*MB },
{ MODEL_MEDIUM, 1650ull*MB },
{ MODEL_LARGE, 3260ull*MB },
{ MODEL_TINY, 74ull*MB },
{ MODEL_BASE, 142ull*MB },
{ MODEL_SMALL, 466ull*MB },
{ MODEL_MEDIUM, 1464ull*MB },
{ MODEL_LARGE, 2952ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_MEMORY = {
{ MODEL_TINY, 12ull*MB },
{ MODEL_BASE, 24ull*MB },
{ MODEL_SMALL, 70ull*MB },
{ MODEL_MEDIUM, 184ull*MB },
{ MODEL_LARGE, 306ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
@ -410,6 +418,12 @@ struct whisper_context {
std::vector<whisper_segment> result_all;
std::vector<whisper_token> prompt_past;
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg;
int64_t t_last;
whisper_token tid_last;
std::vector<float> energy; // PCM signal energy
};
// load the model from a ggml file
@ -423,7 +437,7 @@ struct whisper_context {
//
// see the convert-pt-to-ggml.py script for details
//
bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
static bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());
auto & model = wctx.model;
@ -498,7 +512,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
wctx.buf_model = new std::vector<uint8_t>();
wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
wctx.buf_memory.resize(std::max(MEM_REQ_MODEL.at(model.type), MEM_REQ_MODEL.at(model.type))); // TODO: TMP !!!
wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
@ -722,20 +736,6 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
}
}
// create the ggml memory context
{
struct ggml_init_params params = {
.mem_size = wctx.buf_memory.size(),
.mem_buffer = wctx.buf_memory.data(),
};
model.ctx_mem = ggml_init(params);
if (!model.ctx_mem) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// prepare memory for the weights
{
auto & ctx = model.ctx;
@ -932,6 +932,20 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
}
}
// create the ggml memory context
{
struct ggml_init_params params = {
.mem_size = wctx.buf_memory.size(),
.mem_buffer = wctx.buf_memory.data(),
};
model.ctx_mem = ggml_init(params);
if (!model.ctx_mem) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// key + value memory
{
auto & ctx = model.ctx_mem;
@ -1054,7 +1068,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
// - n_threads: number of threads to use
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
//
bool whisper_encode(
static bool whisper_encode(
whisper_context & wctx,
const int n_threads,
const int mel_offset) {
@ -1440,7 +1454,7 @@ bool whisper_encode(
// - n_tokens: number of tokens in the prompt
// - n_past: number of past tokens to prefix the prompt with
//
bool whisper_decode(
static bool whisper_decode(
whisper_context & wctx,
const int n_threads,
const whisper_token * tokens,
@ -1803,10 +1817,12 @@ bool whisper_decode(
}
// the most basic sampling scheme - select the top token
whisper_token_data whisper_sample_best(
static whisper_token_data whisper_sample_best(
const whisper_vocab & vocab,
const float * probs) {
whisper_token_data result;
whisper_token_data result = {
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
};
int n_logits = vocab.id_to_token.size();
@ -1879,7 +1895,7 @@ whisper_token_data whisper_sample_best(
}
// samples only from the timestamps tokens
whisper_vocab::id whisper_sample_timestamp(
static whisper_vocab::id whisper_sample_timestamp(
const whisper_vocab & vocab,
const float * probs) {
int n_logits = vocab.id_to_token.size();
@ -1931,7 +1947,7 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
// naive Discrete Fourier Transform
// input is real-valued
// output is complex-valued
void dft(const std::vector<float> & in, std::vector<float> & out) {
static void dft(const std::vector<float> & in, std::vector<float> & out) {
int N = in.size();
out.resize(N*2);
@ -1955,7 +1971,7 @@ void dft(const std::vector<float> & in, std::vector<float> & out) {
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
void fft(const std::vector<float> & in, std::vector<float> & out) {
static void fft(const std::vector<float> & in, std::vector<float> & out) {
out.resize(in.size()*2);
int N = in.size();
@ -2006,7 +2022,7 @@ void fft(const std::vector<float> & in, std::vector<float> & out) {
}
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
bool log_mel_spectrogram(
static bool log_mel_spectrogram(
const float * samples,
const int n_samples,
const int sample_rate,
@ -2331,6 +2347,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
/*.token_timestamps =*/ false,
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.language =*/ "en",
/*.greedy =*/ {
@ -2363,6 +2384,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
/*.token_timestamps =*/ false,
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.language =*/ "en",
/*.greedy =*/ {
@ -2384,6 +2410,68 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
return result;
}
// forward declarations
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
static void whisper_exp_compute_token_level_timestamps(
struct whisper_context * ctx,
int i_segment,
float thold_pt,
float thold_ptsum);
// wrap the last segment to max_len characters
// returns the number of new segments
static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
auto segment = ctx->result_all.back();
int res = 1;
int acc = 0;
std::string text;
for (int i = 0; i < (int) segment.tokens.size(); i++) {
const auto & token = segment.tokens[i];
if (token.id >= whisper_token_eot(ctx)) {
continue;
}
const auto txt = whisper_token_to_str(ctx, token.id);
const int cur = strlen(txt);
if (acc + cur > max_len && i > 0) {
// split here
ctx->result_all.back().text = std::move(text);
ctx->result_all.back().t1 = token.t0;
ctx->result_all.back().tokens.resize(i);
ctx->result_all.push_back({});
ctx->result_all.back().t0 = token.t0;
ctx->result_all.back().t1 = segment.t1;
// add tokens [i, end] to the new segment
ctx->result_all.back().tokens.insert(
ctx->result_all.back().tokens.end(),
segment.tokens.begin() + i,
segment.tokens.end());
acc = 0;
text = "";
segment = ctx->result_all.back();
i = -1;
res++;
} else {
acc += cur;
text += txt;
}
}
ctx->result_all.back().text = std::move(text);
return res;
}
int whisper_full(
struct whisper_context * ctx,
struct whisper_full_params params,
@ -2400,6 +2488,13 @@ int whisper_full(
return -1;
}
if (params.token_timestamps) {
ctx->t_beg = 0;
ctx->t_last = 0;
ctx->tid_last = 0;
ctx->energy = get_signal_energy(samples, n_samples, 32);
}
const int seek_start = params.offset_ms/10;
// if length of spectrogram is less than 1s (100 samples), then return
@ -2549,6 +2644,7 @@ int whisper_full(
}
}
// shrink down to result_len
tokens_cur.resize(result_len);
for (const auto & r : tokens_cur) {
@ -2587,8 +2683,19 @@ int whisper_full(
for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
int n_new = 1;
if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(
ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) {
n_new = whisper_wrap_segment(ctx, params.max_len);
}
}
if (params.new_segment_callback) {
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
}
}
text = "";
@ -2617,8 +2724,19 @@ int whisper_full(
for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
int n_new = 1;
if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(
ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) {
n_new = whisper_wrap_segment(ctx, params.max_len);
}
}
if (params.new_segment_callback) {
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
}
}
}
@ -2752,7 +2870,7 @@ int whisper_full_parallel(
// call the new_segment_callback for each segment
if (params.new_segment_callback) {
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
}
}
@ -2828,3 +2946,304 @@ const char * whisper_print_system_info() {
return s.c_str();
}
// =================================================================================================
//
// Experimental stuff below
//
// Not sure if these should be part of the library at all, because the quality of the results is not
// guaranteed. Might get removed at some point unless a robust algorithm implementation is found
//
// =================================================================================================
//
// token-level timestamps
//
static int timestamp_to_sample(int64_t t, int n_samples) {
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
}
static int64_t sample_to_timestamp(int i_sample) {
return (100*i_sample)/WHISPER_SAMPLE_RATE;
}
// a cost-function / heuristic that is high for text that takes longer to pronounce
// obviously, can be improved
static float voice_length(const std::string & text) {
float res = 0.0f;
for (size_t i = 0; i < text.size(); ++i) {
if (text[i] == ' ') {
res += 0.01f;
} else if (text[i] == ',') {
res += 2.00f;
} else if (text[i] == '.') {
res += 3.00f;
} else if (text[i] == '!') {
res += 3.00f;
} else if (text[i] == '?') {
res += 3.00f;
} else if (text[i] >= '0' && text[i] <= '9') {
res += 3.00f;
} else {
res += 1.00f;
}
}
return res;
}
// average the fabs of the signal
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
const int hw = n_samples_per_half_window;
std::vector<float> result(n_samples);
for (int i = 0; i < n_samples; i++) {
float sum = 0;
for (int j = -hw; j <= hw; j++) {
if (i + j >= 0 && i + j < n_samples) {
sum += fabs(signal[i + j]);
}
}
result[i] = sum/(2*hw + 1);
}
return result;
}
static void whisper_exp_compute_token_level_timestamps(
struct whisper_context * ctx,
int i_segment,
float thold_pt,
float thold_ptsum) {
auto & segment = ctx->result_all[i_segment];
auto & tokens = segment.tokens;
const int n_samples = ctx->energy.size();
if (n_samples == 0) {
fprintf(stderr, "%s: no signal data available\n", __func__);
return;
}
const int64_t t0 = segment.t0;
const int64_t t1 = segment.t1;
const int s0 = timestamp_to_sample(t0, n_samples);
const int s1 = timestamp_to_sample(t1, n_samples);
const int n = tokens.size();
if (n == 0) {
return;
}
if (n == 1) {
tokens[0].t0 = t0;
tokens[0].t1 = t1;
return;
}
auto & t_beg = ctx->t_beg;
auto & t_last = ctx->t_last;
auto & tid_last = ctx->tid_last;
for (int j = 0; j < n; ++j) {
auto & token = tokens[j];
if (j == 0) {
if (token.id == whisper_token_beg(ctx)) {
tokens[j ].t0 = t0;
tokens[j ].t1 = t0;
tokens[j + 1].t0 = t0;
t_beg = t0;
t_last = t0;
tid_last = whisper_token_beg(ctx);
} else {
tokens[j ].t0 = t_last;
}
}
const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
tokens[j].id = token.id;
tokens[j].tid = token.tid;
tokens[j].p = token.p;
tokens[j].pt = token.pt;
tokens[j].ptsum = token.ptsum;
tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id));
if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
if (j > 0) {
tokens[j - 1].t1 = tt;
}
tokens[j].t0 = tt;
tid_last = token.tid;
}
}
tokens[n - 2].t1 = t1;
tokens[n - 1].t0 = t1;
tokens[n - 1].t1 = t1;
t_last = t1;
// find intervals of tokens with unknown timestamps
// fill the timestamps by proportionally splitting the interval based on the token voice lengths
{
int p0 = 0;
int p1 = 0;
while (true) {
while (p1 < n && tokens[p1].t1 < 0) {
p1++;
}
if (p1 >= n) {
p1--;
}
if (p1 > p0) {
double psum = 0.0;
for (int j = p0; j <= p1; j++) {
psum += tokens[j].vlen;
}
//printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
const double dt = tokens[p1].t1 - tokens[p0].t0;
// split the time proportionally to the voice length
for (int j = p0 + 1; j <= p1; j++) {
const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
tokens[j - 1].t1 = ct;
tokens[j ].t0 = ct;
}
}
p1++;
p0 = p1;
if (p1 >= n) {
break;
}
}
}
// fix up (just in case)
for (int j = 0; j < n - 1; j++) {
if (tokens[j].t1 < 0) {
tokens[j + 1].t0 = tokens[j].t1;
}
if (j > 0) {
if (tokens[j - 1].t1 > tokens[j].t0) {
tokens[j].t0 = tokens[j - 1].t1;
tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
}
}
}
// VAD
// expand or contract tokens based on voice activity
{
const int hw = WHISPER_SAMPLE_RATE/8;
for (int j = 0; j < n; j++) {
if (tokens[j].id >= whisper_token_eot(ctx)) {
continue;
}
int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
const int ss0 = std::max(s0 - hw, 0);
const int ss1 = std::min(s1 + hw, n_samples);
const int ns = ss1 - ss0;
float sum = 0.0f;
for (int k = ss0; k < ss1; k++) {
sum += ctx->energy[k];
}
const float thold = 0.5*sum/ns;
{
int k = s0;
if (ctx->energy[k] > thold && j > 0) {
while (k > 0 && ctx->energy[k] > thold) {
k--;
}
tokens[j].t0 = sample_to_timestamp(k);
if (tokens[j].t0 < tokens[j - 1].t1) {
tokens[j].t0 = tokens[j - 1].t1;
} else {
s0 = k;
}
} else {
while (ctx->energy[k] < thold && k < s1) {
k++;
}
s0 = k;
tokens[j].t0 = sample_to_timestamp(k);
}
}
{
int k = s1;
if (ctx->energy[k] > thold) {
while (k < n_samples - 1 && ctx->energy[k] > thold) {
k++;
}
tokens[j].t1 = sample_to_timestamp(k);
if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) {
tokens[j].t1 = tokens[j + 1].t0;
} else {
s1 = k;
}
} else {
while (ctx->energy[k] < thold && k > s0) {
k--;
}
s1 = k;
tokens[j].t1 = sample_to_timestamp(k);
}
}
}
}
// fixed token expand (optional)
//{
// const int t_expand = 0;
// for (int j = 0; j < n; j++) {
// if (j > 0) {
// tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
// }
// if (j < n - 1) {
// tokens[j].t1 = tokens[j].t1 + t_expand;
// }
// }
//}
// debug info
//for (int j = 0; j < n; ++j) {
// const auto & token = tokens[j];
// const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
// printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
// tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id));
// if (tokens[j].id >= whisper_token_eot(ctx)) {
// continue;
// }
//}
}

View File

@ -68,14 +68,21 @@ extern "C" {
typedef int whisper_token;
struct whisper_token_data {
typedef struct whisper_token_data {
whisper_token id; // token id
whisper_token tid; // forced timestamp token id
float p; // probability of the token
float pt; // probability of the timestamp token
float ptsum; // sum of probabilities of all timestamp tokens
};
// token-level timestamp data
// do not use if you haven't computed token-level timestamps
int64_t t0; // start time of the token
int64_t t1; // end time of the token
float vlen; // voice length of the token
} whisper_token_data;
// Allocates all memory needed for the model and loads the model from the given file.
// Returns NULL on failure.
@ -129,7 +136,7 @@ extern "C" {
// You can also implement your own sampling method using the whisper_get_probs() function.
// whisper_sample_best() returns the token with the highest probability
// whisper_sample_timestamp() returns the most probable timestamp token
WHISPER_API struct whisper_token_data whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
// Return the id of the specified language, returns -1 if not found
@ -172,7 +179,7 @@ extern "C" {
// Text segment callback
// Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data);
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
struct whisper_full_params {
enum whisper_sampling_strategy strategy;
@ -188,6 +195,12 @@ extern "C" {
bool print_realtime;
bool print_timestamps;
// [EXPERIMENTAL] token-level timestamps
bool token_timestamps; // enable token-level timestamps
float thold_pt; // timestamp token probability threshold (~0.01)
float thold_ptsum; // timestamp token sum probability threshold (~0.01)
int max_len; // max segment length in characters
const char * language;
struct {
@ -244,7 +257,7 @@ extern "C" {
// Get token data for the specified token in the specified segment.
// This contains probabilities, timestamps, etc.
WHISPER_API struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
// Get the probability of the specified token in the specified segment.
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);