Compare commits

..

21 Commits

Author SHA1 Message Date
ec44ad0a75 diarization : try conv and self-attention embeddings 2023-02-19 13:00:12 +02:00
d11f35920e diarization : try to cluster embedings from last encoder layer 2023-02-19 10:33:03 +02:00
d5d7769fa7 diarization : more unsuccessful clustering experiments 2023-02-18 18:36:03 +02:00
c2f5be7c11 diarization : some unsuccessful experiments with audio embd clustering 2023-02-18 12:16:39 +02:00
f254e78737 yt-wsp.sh : print help on empty args 2023-02-18 09:42:31 +02:00
a94897bcde whisper : by default disable non-speech tokens suppression (#473)
This seems to be causing hallucinations in the end of the audio, e.g.:

"Thank you for listening"
"Amen"
..
2023-02-15 21:48:49 +02:00
2407ae8ef0 readme : add Ruby discussion + update .NET discussion 2023-02-15 19:51:54 +02:00
b623ca43b1 bindings : add Ruby (#500)
* adding ruby bindings

* avoid adding these they are copied in via extconf.rb

* ignore these files here

* add definitions for boolean params

* initial transcribe for ruby

* use en model and transcribe jfk with assertion

* possibly this works for building ruby binding

* ci : try to add ruby workflow

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-02-15 19:46:55 +02:00
69e6e4644a main : fix std in input (#503)
if we don't add this as an explicit check, then we get an "error: unknown argument: -" later on
2023-02-15 19:31:16 +02:00
09d7d2b68e examples : refactor in order to reuse code and reduce duplication (#482)
* examples : refactor common code into a library

* examples : refactor common SDL code into a library

* make : update Makefile to use common libs

* common : fix MSVC M_PI ..

* addon.node : link common lib
2023-02-15 19:28:10 +02:00
0336161b7d whisper : fix signedness compiler warning (#506) 2023-02-15 19:08:25 +02:00
459753342d yt-wsp.sh : add unique filename generation (#495)
Co-authored-by: genevera <genevera@noreply.users.github.com>
2023-02-14 20:12:51 +02:00
9764782bd9 readme : add another .NET repo (#303) 2023-02-14 20:04:03 +02:00
3b010f9bed readme : add .NET repo (#303) 2023-02-11 17:35:33 +02:00
113fcec513 cmake : install whisper.h header (#485)
Including the header file in the install bundle helps projects that ship binaries.
2023-02-11 09:13:32 +02:00
cfc06bf8df whisper : suppress non-speech-related token outputs (#473)
* add non-speech-token suppression

* add suppress non-speech_tokens param
2023-02-08 09:05:34 +02:00
2bfe0ebc0f whisper : fixed Beam Search Strategy and exposed whisper_pcm_to_mel_phase_vocoder (#474)
Co-authored-by: Sandro Hanea <sandrohanea@microsoft.com>
2023-02-08 09:01:47 +02:00
4dd7119deb whisper : only trim if split_on_word is true (#476) 2023-02-08 08:43:23 +02:00
ab1916fc59 ci : add node addon test and optimize compilation configuration (#468)
* addon: implement node addon call whisper through cpp

* addon: modify the license to MIT

* addon: remove iostream

* addon: rename dir

* addon: fix typo

* addon: configure cmake to build when cmake-js is used

* ci: add addon.node test ci

* addon: remove build WHISPER_BUILD_TESTS

* addon: update build command

* addon: add test

* addon: add test file

* addon: adapt to compile on Windows

* addon: fix typo

* addon: reuse jfk.wav

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* addon: reuse jfk.wav

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-02-05 15:02:08 +02:00
a1c1583cc7 whisper : add whisper_full_lang_id() for getting the context lang (#461) 2023-02-05 14:46:26 +02:00
d012b5c7e4 whisper : add "split_on_word" flag when using using "max_len" option (#455)
* Update whisper.cpp

* fix: trim function

* feat: added flag to split on word

* fix: arguments for main
2023-02-05 14:44:23 +02:00
37 changed files with 2075 additions and 1345 deletions

View File

@ -1,4 +1,4 @@
name: Bindings Tests name: Bindings Tests (Go)
on: on:
push: push:
paths: paths:

22
.github/workflows/bindings-ruby.yml vendored Normal file
View File

@ -0,0 +1,22 @@
name: Bindings Tests (Ruby)
on:
push:
paths:
- bindings/ruby/**
- whisper.h
pull_request:
paths:
- bindings/ruby/**
- whisper.h
jobs:
ubuntu-latest:
runs-on: ubuntu-latest
steps:
- uses: ruby/setup-ruby@v1
with:
ruby-version: '3.0'
- uses: actions/checkout@v1
- run: |
cd bindings/ruby/ext
ruby extconf.rb && make

48
.github/workflows/examples.yml vendored Normal file
View File

@ -0,0 +1,48 @@
name: Examples Tests
on:
push:
paths:
- examples/addon.node/**
- whisper.h
pull_request:
paths:
- examples/addon.node/**
- whisper.h
jobs:
addon_node-ubuntu-latest:
runs-on: ubuntu-latest
strategy:
matrix:
node-version: [ 16.x, 18.x ]
steps:
- name: Clone
uses: actions/checkout@v1
- name: Dependencies
run: |
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install cmake
sudo apt-get install libsdl2-dev
- name: Use Node.js ${{ matrix.node-version }}
uses: actions/setup-node@v1
with:
node-version: ${{ matrix.node-version }}
cache: 'npm'
- name: Install package.json dependencies
working-directory: ./examples/addon.node
run: npm install
- name: Compile addon.node
run: npx cmake-js compile -T whisper-addon -B Release
- name: Download test model
run: |
bash ./models/download-ggml-model.sh base.en
- name: Test
run: |
cd examples/addon.node
npm run test

View File

@ -226,10 +226,13 @@ target_compile_definitions(${TARGET} PUBLIC
${WHISPER_EXTRA_FLAGS} ${WHISPER_EXTRA_FLAGS}
) )
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "whisper.h")
install(TARGETS ${TARGET} install(TARGETS ${TARGET}
LIBRARY DESTINATION lib LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib/static ARCHIVE DESTINATION lib/static
RUNTIME DESTINATION bin RUNTIME DESTINATION bin
PUBLIC_HEADER DESTINATION include
) )
# #
@ -242,7 +245,7 @@ add_subdirectory(bindings)
# programs, examples and tests # programs, examples and tests
# #
if (WHISPER_BUILD_TESTS) if (WHISPER_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
enable_testing() enable_testing()
add_subdirectory(tests) add_subdirectory(tests)
endif () endif ()

View File

@ -197,18 +197,21 @@ clean:
CC_SDL=`sdl2-config --cflags --libs` CC_SDL=`sdl2-config --cflags --libs`
main: examples/main/main.cpp ggml.o whisper.o SRC_COMMON = examples/common.cpp
$(CXX) $(CXXFLAGS) examples/main/main.cpp ggml.o whisper.o -o main $(LDFLAGS) SRC_COMMON_SDL = examples/common-sdl.cpp
main: examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o -o main $(LDFLAGS)
./main -h ./main -h
stream: examples/stream/stream.cpp ggml.o whisper.o stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS) $(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
command: examples/command/command.cpp ggml.o whisper.o command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/command/command.cpp ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS) $(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS)
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS) $(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
bench: examples/bench/bench.cpp ggml.o whisper.o bench: examples/bench/bench.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS) $(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)

View File

@ -464,7 +464,11 @@ in [models](models).
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310) - [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
- [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309) - [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312) - [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
- [X] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggerganov/whisper.cpp/discussions/507)
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313) - [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
- [X] .NET: | [#422](https://github.com/ggerganov/whisper.cpp/discussions/422)
- [sandrohanea/whisper.net](https://github.com/sandrohanea/whisper.net)
- [NickDarvey/whisper](https://github.com/NickDarvey/whisper)
- [ ] Python: soon | [WIP](https://github.com/ggerganov/whisper.cpp/issues/9) - [ ] Python: soon | [WIP](https://github.com/ggerganov/whisper.cpp/issues/9)
## Examples ## Examples

File diff suppressed because one or more lines are too long

7
bindings/ruby/ext/.gitignore vendored Normal file
View File

@ -0,0 +1,7 @@
Makefile
ggml.c
ggml.h
whisper.bundle
whisper.cpp
whisper.h
dr_wav.h

View File

@ -0,0 +1,21 @@
require 'mkmf'
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','examples','dr_wav.h')} .")
# need to use c++ compiler flags
$CXXFLAGS << ' -std=c++11'
# Set to true when building binary gems
if enable_config('static-stdlib', false)
$LDFLAGS << ' -static-libgcc -static-libstdc++'
end
if enable_config('march-tune-native', false)
$CFLAGS << ' -march=native -mtune=native'
$CXXFLAGS << ' -march=native -mtune=native'
end
create_makefile('whisper')

View File

@ -0,0 +1,426 @@
#include <ruby.h>
#include "ruby_whisper.h"
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include <cmath>
#include <fstream>
#include <cstdio>
#include <string>
#include <thread>
#include <vector>
#ifdef __cplusplus
extern "C" {
#endif
#define BOOL_PARAMS_SETTER(self, prop, value) \
ruby_whisper_params *rwp; \
Data_Get_Struct(self, ruby_whisper_params, rwp); \
if (value == Qfalse || value == Qnil) { \
rwp->params.prop = false; \
} else { \
rwp->params.prop = true; \
} \
return value; \
#define BOOL_PARAMS_GETTER(self, prop) \
ruby_whisper_params *rwp; \
Data_Get_Struct(self, ruby_whisper_params, rwp); \
if (rwp->params.prop) { \
return Qtrue; \
} else { \
return Qfalse; \
}
VALUE mWhisper;
VALUE cContext;
VALUE cParams;
static void ruby_whisper_free(ruby_whisper *rw) {
if (rw->context) {
whisper_free(rw->context);
rw->context = NULL;
}
}
static void ruby_whisper_params_free(ruby_whisper_params *rwp) {
}
void rb_whisper_mark(ruby_whisper *rw) {
// call rb_gc_mark on any ruby references in rw
}
void rb_whisper_free(ruby_whisper *rw) {
ruby_whisper_free(rw);
free(rw);
}
void rb_whisper_params_mark(ruby_whisper_params *rwp) {
}
void rb_whisper_params_free(ruby_whisper_params *rwp) {
ruby_whisper_params_free(rwp);
free(rwp);
}
static VALUE ruby_whisper_allocate(VALUE klass) {
ruby_whisper *rw;
rw = ALLOC(ruby_whisper);
rw->context = NULL;
return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
}
static VALUE ruby_whisper_params_allocate(VALUE klass) {
ruby_whisper_params *rwp;
rwp = ALLOC(ruby_whisper_params);
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
}
static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
ruby_whisper *rw;
VALUE whisper_model_file_path;
// TODO: we can support init from buffer here too maybe another ruby object to expose
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
Data_Get_Struct(self, ruby_whisper, rw);
if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) {
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
}
rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path));
if (rw->context == nullptr) {
rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
}
return self;
}
/*
* transcribe a single file
* can emit to a block results
*
**/
static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
ruby_whisper *rw;
ruby_whisper_params *rwp;
VALUE wave_file_path, blk, params;
rb_scan_args(argc, argv, "02&", &wave_file_path, &params, &blk);
Data_Get_Struct(self, ruby_whisper, rw);
Data_Get_Struct(params, ruby_whisper_params, rwp);
if (!rb_respond_to(wave_file_path, rb_intern("to_s"))) {
rb_raise(rb_eRuntimeError, "Expected file path to wave file");
}
std::string fname_inp = StringValueCStr(wave_file_path);
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// WAV input - this is directly from main.cpp example
{
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname_inp == "-") {
{
uint8_t buf[1024];
while (true) {
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return self;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
} else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return self;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "WAV file '%s' must be mono or stereo\n", fname_inp.c_str());
return self;
}
if (rwp->diarize && wav.channels != 2 && rwp->params.print_timestamps == false) {
fprintf(stderr, "WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str());
return self;
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
return self;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "WAV file '%s' must be 16-bit\n", fname_inp.c_str());
return self;
}
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (rwp->diarize) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
}
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
rwp->params.encoder_begin_callback_user_data = &is_aborted;
}
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
fprintf(stderr, "failed to process audio\n");
return self;
}
const int n_segments = whisper_full_n_segments(rw->context);
VALUE output = rb_str_new2("");
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(rw->context, i);
output = rb_str_concat(output, rb_str_new2(text));
}
VALUE idCall = rb_intern("call");
if (blk != Qnil) {
rb_funcall(blk, idCall, 1, output);
}
return self;
}
/*
* params.language = "auto" | "en", etc...
*/
static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (value == Qfalse || value == Qnil) {
rwp->params.language = "auto";
} else {
rwp->params.language = StringValueCStr(value);
}
return value;
}
static VALUE ruby_whisper_params_get_language(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (rwp->params.language) {
return rb_str_new2(rwp->params.language);
} else {
return rb_str_new2("auto");
}
}
static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, translate, value)
}
static VALUE ruby_whisper_params_get_translate(VALUE self) {
BOOL_PARAMS_GETTER(self, translate)
}
static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, no_context, value)
}
static VALUE ruby_whisper_params_get_no_context(VALUE self) {
BOOL_PARAMS_GETTER(self, no_context)
}
static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, single_segment, value)
}
static VALUE ruby_whisper_params_get_single_segment(VALUE self) {
BOOL_PARAMS_GETTER(self, single_segment)
}
static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_special, value)
}
static VALUE ruby_whisper_params_get_print_special(VALUE self) {
BOOL_PARAMS_GETTER(self, print_special)
}
static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_progress, value)
}
static VALUE ruby_whisper_params_get_print_progress(VALUE self) {
BOOL_PARAMS_GETTER(self, print_progress)
}
static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_realtime, value)
}
static VALUE ruby_whisper_params_get_print_realtime(VALUE self) {
BOOL_PARAMS_GETTER(self, print_realtime)
}
static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_timestamps, value)
}
static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) {
BOOL_PARAMS_GETTER(self, print_timestamps)
}
static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, suppress_blank, value)
}
static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
BOOL_PARAMS_GETTER(self, suppress_blank)
}
static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value)
}
static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) {
BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens)
}
static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) {
BOOL_PARAMS_GETTER(self, token_timestamps)
}
static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, token_timestamps, value)
}
static VALUE ruby_whisper_params_get_split_on_word(VALUE self) {
BOOL_PARAMS_GETTER(self, split_on_word)
}
static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, split_on_word, value)
}
static VALUE ruby_whisper_params_get_speed_up(VALUE self) {
BOOL_PARAMS_GETTER(self, speed_up)
}
static VALUE ruby_whisper_params_set_speed_up(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, speed_up, value)
}
static VALUE ruby_whisper_params_get_diarize(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (rwp->diarize) {
return Qtrue;
} else {
return Qfalse;
}
}
static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (value == Qfalse || value == Qnil) {
rwp->diarize = false;
} else {
rwp->diarize = true;
} \
return value;
}
static VALUE ruby_whisper_params_get_offset(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.offset_ms);
}
static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.offset_ms = NUM2INT(value);
return value;
}
static VALUE ruby_whisper_params_get_duration(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.duration_ms);
}
static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.duration_ms = NUM2INT(value);
return value;
}
static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.n_max_text_ctx);
}
static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.n_max_text_ctx = NUM2INT(value);
return value;
}
void Init_whisper() {
mWhisper = rb_define_module("Whisper");
cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
rb_define_alloc_func(cContext, ruby_whisper_allocate);
rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
rb_define_method(cParams, "language=", ruby_whisper_params_set_language, 1);
rb_define_method(cParams, "language", ruby_whisper_params_get_language, 0);
rb_define_method(cParams, "translate=", ruby_whisper_params_set_translate, 1);
rb_define_method(cParams, "translate", ruby_whisper_params_get_translate, 0);
rb_define_method(cParams, "no_context=", ruby_whisper_params_set_no_context, 1);
rb_define_method(cParams, "no_context", ruby_whisper_params_get_no_context, 0);
rb_define_method(cParams, "single_segment=", ruby_whisper_params_set_single_segment, 1);
rb_define_method(cParams, "single_segment", ruby_whisper_params_get_single_segment, 0);
rb_define_method(cParams, "print_special", ruby_whisper_params_get_print_special, 0);
rb_define_method(cParams, "print_special=", ruby_whisper_params_set_print_special, 1);
rb_define_method(cParams, "print_progress", ruby_whisper_params_get_print_progress, 0);
rb_define_method(cParams, "print_progress=", ruby_whisper_params_set_print_progress, 1);
rb_define_method(cParams, "print_realtime", ruby_whisper_params_get_print_realtime, 0);
rb_define_method(cParams, "print_realtime=", ruby_whisper_params_set_print_realtime, 1);
rb_define_method(cParams, "print_timestamps", ruby_whisper_params_get_print_timestamps, 0);
rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1);
rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0);
rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1);
rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0);
rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1);
rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0);
rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1);
rb_define_method(cParams, "speed_up", ruby_whisper_params_get_speed_up, 0);
rb_define_method(cParams, "speed_up=", ruby_whisper_params_set_speed_up, 1);
rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0);
rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1);
rb_define_method(cParams, "offset", ruby_whisper_params_get_offset, 0);
rb_define_method(cParams, "offset=", ruby_whisper_params_set_offset, 1);
rb_define_method(cParams, "duration", ruby_whisper_params_get_duration, 0);
rb_define_method(cParams, "duration=", ruby_whisper_params_set_duration, 1);
rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0);
rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1);
}
#ifdef __cplusplus
}
#endif

View File

@ -0,0 +1,15 @@
#ifndef __RUBY_WHISPER_H
#define __RUBY_WHISPER_H
#include "whisper.h"
typedef struct {
struct whisper_context *context;
} ruby_whisper;
typedef struct {
struct whisper_full_params params;
bool diarize;
} ruby_whisper_params;
#endif

View File

@ -0,0 +1,138 @@
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
EXTDIR = File.join(TOPDIR, 'ext')
#$LIBDIR = File.join(TOPDIR, 'lib')
#$:.unshift(LIBDIR)
$:.unshift(EXTDIR)
require 'whisper'
require 'test/unit'
class TestWhisper < Test::Unit::TestCase
def setup
@params = Whisper::Params.new
end
def test_language
@params.language = "en"
assert_equal @params.language, "en"
@params.language = "auto"
assert_equal @params.language, "auto"
end
def test_offset
@params.offset = 10_000
assert_equal @params.offset, 10_000
@params.offset = 0
assert_equal @params.offset, 0
end
def test_duration
@params.duration = 60_000
assert_equal @params.duration, 60_000
@params.duration = 0
assert_equal @params.duration, 0
end
def test_max_text_tokens
@params.max_text_tokens = 300
assert_equal @params.max_text_tokens, 300
@params.max_text_tokens = 0
assert_equal @params.max_text_tokens, 0
end
def test_translate
@params.translate = true
assert @params.translate
@params.translate = false
assert !@params.translate
end
def test_no_context
@params.no_context = true
assert @params.no_context
@params.no_context = false
assert !@params.no_context
end
def test_single_segment
@params.single_segment = true
assert @params.single_segment
@params.single_segment = false
assert !@params.single_segment
end
def test_print_special
@params.print_special = true
assert @params.print_special
@params.print_special = false
assert !@params.print_special
end
def test_print_progress
@params.print_progress = true
assert @params.print_progress
@params.print_progress = false
assert !@params.print_progress
end
def test_print_realtime
@params.print_realtime = true
assert @params.print_realtime
@params.print_realtime = false
assert !@params.print_realtime
end
def test_print_timestamps
@params.print_timestamps = true
assert @params.print_timestamps
@params.print_timestamps = false
assert !@params.print_timestamps
end
def test_suppress_blank
@params.suppress_blank = true
assert @params.suppress_blank
@params.suppress_blank = false
assert !@params.suppress_blank
end
def test_suppress_non_speech_tokens
@params.suppress_non_speech_tokens = true
assert @params.suppress_non_speech_tokens
@params.suppress_non_speech_tokens = false
assert !@params.suppress_non_speech_tokens
end
def test_token_timestamps
@params.token_timestamps = true
assert @params.token_timestamps
@params.token_timestamps = false
assert !@params.token_timestamps
end
def test_split_on_word
@params.split_on_word = true
assert @params.split_on_word
@params.split_on_word = false
assert !@params.split_on_word
end
def test_speed_up
@params.speed_up = true
assert @params.speed_up
@params.speed_up = false
assert !@params.speed_up
end
def test_whisper
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
params = Whisper::Params.new
params.print_timestamps = false
jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
@whisper.transcribe(jfk, params) {|text|
assert_match /ask not what your country can do for you, ask what you can do for your country/, text
}
end
end

View File

@ -14,6 +14,37 @@ if (WHISPER_SUPPORT_SDL2)
message(STATUS "SDL2_LIBRARIES = ${SDL2_LIBRARIES}") message(STATUS "SDL2_LIBRARIES = ${SDL2_LIBRARIES}")
endif() endif()
# common
set(TARGET common)
add_library(${TARGET} STATIC
common.h
common.cpp
)
include(DefaultTargetOptions)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
if (WHISPER_SUPPORT_SDL2)
# common-sdl
set(TARGET common-sdl)
add_library(${TARGET} STATIC
common-sdl.h
common-sdl.cpp
)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PUBLIC ${SDL2_INCLUDE_DIRS})
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES})
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
# examples # examples
include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_CURRENT_SOURCE_DIR})

View File

@ -23,4 +23,9 @@ string(REPLACE "\"" "" NODE_ADDON_API_DIR ${NODE_ADDON_API_DIR})
target_include_directories(${TARGET} PRIVATE ${NODE_ADDON_API_DIR}) target_include_directories(${TARGET} PRIVATE ${NODE_ADDON_API_DIR})
#================================================================== #==================================================================
target_link_libraries(${TARGET} ${CMAKE_JS_LIB} whisper ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} ${CMAKE_JS_LIB} common whisper ${CMAKE_THREAD_LIBS_INIT})
if(MSVC AND CMAKE_JS_NODELIB_DEF AND CMAKE_JS_NODELIB_TARGET)
# Generate node.lib
execute_process(COMMAND ${CMAKE_AR} /def:${CMAKE_JS_NODELIB_DEF} /out:${CMAKE_JS_NODELIB_TARGET} ${CMAKE_STATIC_LINKER_FLAGS})
endif()

View File

@ -14,14 +14,14 @@ npm install
Make sure it is in the project root directory and compiled with make-js. Make sure it is in the project root directory and compiled with make-js.
```shell ```shell
npx cmake-js compile -T whisper-addon npx cmake-js compile -T whisper-addon -B Release
``` ```
For Electron addon and cmake-js options, you can see [cmake-js](https://github.com/cmake-js/cmake-js) and make very few configuration changes. For Electron addon and cmake-js options, you can see [cmake-js](https://github.com/cmake-js/cmake-js) and make very few configuration changes.
> Such as appointing special cmake path: > Such as appointing special cmake path:
> ```shell > ```shell
> npx cmake-js compile -c 'xxx/cmake' -T whisper-addon > npx cmake-js compile -c 'xxx/cmake' -T whisper-addon -B Release
> ``` > ```
## Run ## Run

View File

@ -0,0 +1,15 @@
const path = require('path');
const { whisper } = require(path.join(__dirname, '../../../build/Release/whisper-addon'));
const whisperParamsMock = {
language: 'en',
model: path.join(__dirname, '../../../models/ggml-base.en.bin'),
fname_inp: path.join(__dirname, '../../../samples/jfk.wav'),
};
describe("Run whisper.node", () => {
test("it should receive a non-empty value", () => {
expect(whisper(whisperParamsMock).length).toBeGreaterThan(0);
});
});

View File

@ -1,14 +1,13 @@
#include "napi.h"
#include "common.h"
#include "whisper.h"
#include <string> #include <string>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include <cmath> #include <cmath>
#include <cstdint>
#include "napi.h"
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include "whisper.h"
struct whisper_params { struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
@ -43,7 +42,7 @@ struct whisper_params {
std::string model = "../../ggml-large.bin"; std::string model = "../../ggml-large.bin";
std::vector<std::string> fname_inp = {}; std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_outp = {}; std::vector<std::string> fname_out = {};
}; };
struct whisper_print_user_data { struct whisper_print_user_data {
@ -142,7 +141,6 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
} }
int run(whisper_params &params, std::vector<std::vector<std::string>> &result) { int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
if (params.fname_inp.empty()) { if (params.fname_inp.empty()) {
fprintf(stderr, "error: no input files specified\n"); fprintf(stderr, "error: no input files specified\n");
return 2; return 2;
@ -180,91 +178,14 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
for (int f = 0; f < (int) params.fname_inp.size(); ++f) { for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f]; const auto fname_inp = params.fname_inp[f];
const auto fname_outp = f < (int)params.fname_outp.size() && !params.fname_outp[f].empty() ? params.fname_outp[f] : params.fname_inp[f]; const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
std::vector<float> pcmf32; // mono-channel F32 PCM std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// WAV input if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
{ fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
drwav wav; continue;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname_inp == "-") {
{
uint8_t buf[1024];
while (true)
{
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return 4;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return 5;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "error: WAV file '%s' must be mono or stereo\n", fname_inp.c_str());
return 6;
}
if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
fprintf(stderr, "error: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str());
return 6;
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "error: WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
return 8;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "error: WAV file '%s' must be 16-bit\n", fname_inp.c_str());
return 9;
}
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (params.diarize) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
} }
// print system information // print system information
@ -398,9 +319,9 @@ Napi::Object whisper(const Napi::CallbackInfo& info) {
} }
Napi::Object res = Napi::Array::New(env, result.size()); Napi::Object res = Napi::Array::New(env, result.size());
for (u_int32_t i = 0; i < result.size(); ++i) { for (uint64_t i = 0; i < result.size(); ++i) {
Napi::Object tmp = Napi::Array::New(env, 3); Napi::Object tmp = Napi::Array::New(env, 3);
for (u_int32_t j = 0; j < 3; ++j) { for (uint64_t j = 0; j < 3; ++j) {
tmp[j] = Napi::String::New(env, result[i][j]); tmp[j] = Napi::String::New(env, result[i][j]);
} }
res[i] = tmp; res[i] = tmp;

View File

@ -5,8 +5,12 @@
"main": "index.js", "main": "index.js",
"author": "Qanhe Chen", "author": "Qanhe Chen",
"license": "MIT", "license": "MIT",
"scripts": {
"test": "jest"
},
"devDependencies": { "devDependencies": {
"cmake-js": "^7.1.1", "cmake-js": "^7.1.1",
"jest": "^29.4.0",
"node-addon-api": "^5.0.0" "node-addon-api": "^5.0.0"
} }
} }

View File

@ -11,6 +11,7 @@ add_executable(${TARGET}
include(DefaultTargetOptions) include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE target_link_libraries(${TARGET} PRIVATE
common
whisper whisper
) )

View File

@ -1,4 +1,5 @@
#include "ggml.h" #include "ggml.h"
#include "common.h"
#include "whisper.h" #include "whisper.h"
#include <emscripten.h> #include <emscripten.h>
@ -27,24 +28,6 @@ std::string g_transcribed = "";
std::vector<float> g_pcmf32; std::vector<float> g_pcmf32;
static std::string trim(const std::string & s) {
std::regex e("^\\s+|\\s+$");
return std::regex_replace(s, e, "");
}
static void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
float y = data[0];
for (size_t i = 1; i < data.size(); i++) {
y = alpha * (y + data[i] - data[i - 1]);
data[i] = y;
}
}
// compute similarity between two strings using Levenshtein distance // compute similarity between two strings using Levenshtein distance
static float similarity(const std::string & s0, const std::string & s1) { static float similarity(const std::string & s0, const std::string & s1) {
const size_t len0 = s0.size() + 1; const size_t len0 = s0.size() + 1;
@ -75,44 +58,6 @@ void command_set_status(const std::string & status) {
g_status = status; g_status = status;
} }
bool command_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;
if (n_samples_last >= n_samples) {
// not enough samples - assume no speech
return false;
}
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, freq_thold, sample_rate);
}
float energy_all = 0.0f;
float energy_last = 0.0f;
for (size_t i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
energy_last += fabsf(pcmf32[i]);
}
}
energy_all /= n_samples;
energy_last /= n_samples_last;
if (verbose) {
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
}
if (energy_last > vad_thold*energy_all) {
return false;
}
return true;
}
std::string command_transcribe(whisper_context * ctx, const whisper_full_params & wparams, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) { std::string command_transcribe(whisper_context * ctx, const whisper_full_params & wparams, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
@ -155,7 +100,7 @@ void command_get_audio(int ms, int sample_rate, std::vector<float> & audio) {
const int64_t n_samples = (ms * sample_rate) / 1000; const int64_t n_samples = (ms * sample_rate) / 1000;
int64_t n_take = 0; int64_t n_take = 0;
if (g_pcmf32.size() < n_samples) { if (n_samples > (int) g_pcmf32.size()) {
n_take = g_pcmf32.size(); n_take = g_pcmf32.size();
} else { } else {
n_take = n_samples; n_take = n_samples;
@ -187,7 +132,6 @@ void command_main(size_t index) {
printf("command: using %d threads\n", wparams.n_threads); printf("command: using %d threads\n", wparams.n_threads);
bool is_running = true;
bool have_prompt = false; bool have_prompt = false;
bool ask_prompt = true; bool ask_prompt = true;
bool print_energy = false; bool print_energy = false;
@ -233,7 +177,7 @@ void command_main(size_t index) {
{ {
command_get_audio(vad_ms, WHISPER_SAMPLE_RATE, pcmf32_cur); command_get_audio(vad_ms, WHISPER_SAMPLE_RATE, pcmf32_cur);
if (command_vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) { if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
command_set_status("Speech detected! Processing ..."); command_set_status("Speech detected! Processing ...");

View File

@ -5,6 +5,5 @@ if (WHISPER_SUPPORT_SDL2)
include(DefaultTargetOptions) include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS}) target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif () endif ()

View File

@ -6,11 +6,10 @@
// ref: https://github.com/ggerganov/whisper.cpp/issues/171 // ref: https://github.com/ggerganov/whisper.cpp/issues/171
// //
#include "common.h"
#include "common-sdl.h"
#include "whisper.h" #include "whisper.h"
#include <SDL.h>
#include <SDL_audio.h>
#include <sstream> #include <sstream>
#include <cassert> #include <cassert>
#include <cstdio> #include <cstdio>
@ -110,309 +109,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
//
// SDL Audio capture
//
class audio_async {
public:
audio_async(int len_ms);
~audio_async();
bool init(int capture_id, int sample_rate);
// start capturing audio via the provided SDL callback
// keep last len_ms seconds of audio in a circular buffer
bool resume();
bool pause();
bool clear();
// callback to be called by SDL
void callback(uint8_t * stream, int len);
// get audio data from the circular buffer
void get(int ms, std::vector<float> & audio);
private:
SDL_AudioDeviceID m_dev_id_in = 0;
int m_len_ms = 0;
int m_sample_rate = 0;
bool m_running = false;
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
size_t m_audio_pos = 0;
size_t m_audio_len = 0;
};
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;
}
audio_async::~audio_async() {
if (m_dev_id_in) {
SDL_CloseAudioDevice(m_dev_id_in);
}
}
bool audio_async::init(int capture_id, int sample_rate) {
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
return false;
}
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
{
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
for (int i = 0; i < nDevices; i++) {
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
}
}
SDL_AudioSpec capture_spec_requested;
SDL_AudioSpec capture_spec_obtained;
SDL_zero(capture_spec_requested);
SDL_zero(capture_spec_obtained);
capture_spec_requested.freq = sample_rate;
capture_spec_requested.format = AUDIO_F32;
capture_spec_requested.channels = 1;
capture_spec_requested.samples = 1024;
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
audio_async * audio = (audio_async *) userdata;
audio->callback(stream, len);
};
capture_spec_requested.userdata = this;
if (capture_id >= 0) {
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
} else {
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
}
if (!m_dev_id_in) {
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
m_dev_id_in = 0;
return false;
} else {
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
capture_spec_requested.format);
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
capture_spec_requested.channels);
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
}
m_sample_rate = capture_spec_obtained.freq;
m_audio.resize((m_sample_rate*m_len_ms)/1000);
return true;
}
bool audio_async::resume() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
return false;
}
if (m_running) {
fprintf(stderr, "%s: already running!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 0);
m_running = true;
return true;
}
bool audio_async::pause() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: already paused!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 1);
m_running = false;
return true;
}
bool audio_async::clear() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return false;
}
{
std::lock_guard<std::mutex> lock(m_mutex);
m_audio_pos = 0;
m_audio_len = 0;
}
return true;
}
// callback to be called by SDL
void audio_async::callback(uint8_t * stream, int len) {
if (!m_running) {
return;
}
const size_t n_samples = len / sizeof(float);
m_audio_new.resize(n_samples);
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
{
std::lock_guard<std::mutex> lock(m_mutex);
if (m_audio_pos + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - m_audio_pos;
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = m_audio.size();
} else {
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
}
}
}
void audio_async::get(int ms, std::vector<float> & result) {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
return;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return;
}
result.clear();
{
std::lock_guard<std::mutex> lock(m_mutex);
if (ms <= 0) {
ms = m_len_ms;
}
size_t n_samples = (m_sample_rate * ms) / 1000;
if (n_samples > m_audio_len) {
n_samples = m_audio_len;
}
result.resize(n_samples);
int s0 = m_audio_pos - n_samples;
if (s0 < 0) {
s0 += m_audio.size();
}
if (s0 + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - s0;
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
} else {
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
}
}
}
///////////////////////////
std::string trim(const std::string & s) {
std::regex e("^\\s+|\\s+$");
return std::regex_replace(s, e, "");
}
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
float y = data[0];
for (size_t i = 1; i < data.size(); i++) {
y = alpha * (y + data[i] - data[i - 1]);
data[i] = y;
}
}
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;
if (n_samples_last >= n_samples) {
// not enough samples - assume no speech
return false;
}
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, freq_thold, sample_rate);
}
float energy_all = 0.0f;
float energy_last = 0.0f;
for (int i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
energy_last += fabsf(pcmf32[i]);
}
}
energy_all /= n_samples;
energy_last /= n_samples_last;
if (verbose) {
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
}
if (energy_last > vad_thold*energy_all) {
return false;
}
return true;
}
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) { std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
@ -502,7 +198,7 @@ std::vector<std::string> read_allowed_commands(const std::string & fname) {
std::string line; std::string line;
while (std::getline(ifs, line)) { while (std::getline(ifs, line)) {
line = trim(line); line = ::trim(line);
if (line.empty()) { if (line.empty()) {
continue; continue;
} }
@ -526,23 +222,6 @@ std::vector<std::string> get_words(const std::string &txt) {
return words; return words;
} }
// returns true if no exit event was received
bool process_sdl_events() {
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
return false;
} break;
default:
break;
}
}
return true;
}
// command-list mode // command-list mode
// guide the transcription to match the most likely command from a provided list // guide the transcription to match the most likely command from a provided list
int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) { int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
@ -634,14 +313,14 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
// main loop // main loop
while (is_running) { while (is_running) {
// handle Ctrl + C // handle Ctrl + C
is_running = process_sdl_events(); is_running = sdl_poll_events();
// delay // delay
std::this_thread::sleep_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100));
audio.get(2000, pcmf32_cur); audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) { if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
@ -775,7 +454,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
// main loop // main loop
while (is_running) { while (is_running) {
// handle Ctrl + C // handle Ctrl + C
is_running = process_sdl_events(); is_running = sdl_poll_events();
// delay // delay
std::this_thread::sleep_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100));
@ -791,7 +470,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
{ {
audio.get(2000, pcmf32_cur); audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) { if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
int64_t t_ms = 0; int64_t t_ms = 0;
@ -854,7 +533,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
// main loop // main loop
while (is_running) { while (is_running) {
// handle Ctrl + C // handle Ctrl + C
is_running = process_sdl_events(); is_running = sdl_poll_events();
// delay // delay
std::this_thread::sleep_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100));
@ -870,7 +549,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
{ {
audio.get(2000, pcmf32_cur); audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) { if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
int64_t t_ms = 0; int64_t t_ms = 0;

226
examples/common-sdl.cpp Normal file
View File

@ -0,0 +1,226 @@
#include "common-sdl.h"
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;
m_running = false;
}
audio_async::~audio_async() {
if (m_dev_id_in) {
SDL_CloseAudioDevice(m_dev_id_in);
}
}
bool audio_async::init(int capture_id, int sample_rate) {
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
return false;
}
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
{
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
for (int i = 0; i < nDevices; i++) {
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
}
}
SDL_AudioSpec capture_spec_requested;
SDL_AudioSpec capture_spec_obtained;
SDL_zero(capture_spec_requested);
SDL_zero(capture_spec_obtained);
capture_spec_requested.freq = sample_rate;
capture_spec_requested.format = AUDIO_F32;
capture_spec_requested.channels = 1;
capture_spec_requested.samples = 1024;
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
audio_async * audio = (audio_async *) userdata;
audio->callback(stream, len);
};
capture_spec_requested.userdata = this;
if (capture_id >= 0) {
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
} else {
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
}
if (!m_dev_id_in) {
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
m_dev_id_in = 0;
return false;
} else {
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
capture_spec_requested.format);
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
capture_spec_requested.channels);
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
}
m_sample_rate = capture_spec_obtained.freq;
m_audio.resize((m_sample_rate*m_len_ms)/1000);
return true;
}
bool audio_async::resume() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
return false;
}
if (m_running) {
fprintf(stderr, "%s: already running!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 0);
m_running = true;
return true;
}
bool audio_async::pause() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: already paused!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 1);
m_running = false;
return true;
}
bool audio_async::clear() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return false;
}
{
std::lock_guard<std::mutex> lock(m_mutex);
m_audio_pos = 0;
m_audio_len = 0;
}
return true;
}
// callback to be called by SDL
void audio_async::callback(uint8_t * stream, int len) {
if (!m_running) {
return;
}
const size_t n_samples = len / sizeof(float);
m_audio_new.resize(n_samples);
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
{
std::lock_guard<std::mutex> lock(m_mutex);
if (m_audio_pos + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - m_audio_pos;
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = m_audio.size();
} else {
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
}
}
}
void audio_async::get(int ms, std::vector<float> & result) {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
return;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return;
}
result.clear();
{
std::lock_guard<std::mutex> lock(m_mutex);
if (ms <= 0) {
ms = m_len_ms;
}
size_t n_samples = (m_sample_rate * ms) / 1000;
if (n_samples > m_audio_len) {
n_samples = m_audio_len;
}
result.resize(n_samples);
int s0 = m_audio_pos - n_samples;
if (s0 < 0) {
s0 += m_audio.size();
}
if (s0 + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - s0;
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
} else {
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
}
}
}
bool sdl_poll_events() {
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
return false;
} break;
default:
break;
}
}
return true;
}

50
examples/common-sdl.h Normal file
View File

@ -0,0 +1,50 @@
#pragma once
#include <SDL.h>
#include <SDL_audio.h>
#include <atomic>
#include <cstdint>
#include <vector>
#include <mutex>
//
// SDL Audio capture
//
class audio_async {
public:
audio_async(int len_ms);
~audio_async();
bool init(int capture_id, int sample_rate);
// start capturing audio via the provided SDL callback
// keep last len_ms seconds of audio in a circular buffer
bool resume();
bool pause();
bool clear();
// callback to be called by SDL
void callback(uint8_t * stream, int len);
// get audio data from the circular buffer
void get(int ms, std::vector<float> & audio);
private:
SDL_AudioDeviceID m_dev_id_in = 0;
int m_len_ms = 0;
int m_sample_rate = 0;
std::atomic_bool m_running;
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
size_t m_audio_pos = 0;
size_t m_audio_len = 0;
};
// Return false if need to quit
bool sdl_poll_events();

162
examples/common.cpp Normal file
View File

@ -0,0 +1,162 @@
#include "common.h"
// third-party utilities
// use your favorite implementations
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include <cmath>
#include <regex>
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
std::string trim(const std::string & s) {
std::regex e("^\\s+|\\s+$");
return std::regex_replace(s, e, "");
}
std::string replace(const std::string & s, const std::string & from, const std::string & to) {
std::string result = s;
size_t pos = 0;
while ((pos = result.find(from, pos)) != std::string::npos) {
result.replace(pos, from.length(), to);
pos += to.length();
}
return result;
}
bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname == "-") {
{
uint8_t buf[1024];
while (true)
{
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return false;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
return false;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str());
return false;
}
if (stereo && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str());
return false;
}
if (wav.sampleRate != COMMON_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE/1000);
return false;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str());
return false;
}
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (stereo) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
return true;
}
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
float y = data[0];
for (size_t i = 1; i < data.size(); i++) {
y = alpha * (y + data[i] - data[i - 1]);
data[i] = y;
}
}
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;
if (n_samples_last >= n_samples) {
// not enough samples - assume no speech
return false;
}
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, freq_thold, sample_rate);
}
float energy_all = 0.0f;
float energy_last = 0.0f;
for (int i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
energy_last += fabsf(pcmf32[i]);
}
}
energy_all /= n_samples;
energy_last /= n_samples_last;
if (verbose) {
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
}
if (energy_last > vad_thold*energy_all) {
return false;
}
return true;
}

40
examples/common.h Normal file
View File

@ -0,0 +1,40 @@
#pragma once
// needs to match WHISPER_SAMPLE_RATE
#define COMMON_SAMPLE_RATE 16000
#include <vector>
#include <string>
std::string trim(const std::string & s);
std::string replace(
const std::string & s,
const std::string & from,
const std::string & to);
// Read WAV audio file and store the PCM data into pcmf32
// The sample rate of the audio must be equal to COMMON_SAMPLE_RATE
// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM
bool read_wav(
const std::string & fname,
std::vector<float> & pcmf32,
std::vector<std::vector<float>> & pcmf32s,
bool stereo);
// Apply a high-pass frequency filter to PCM audio
// Suppresses frequencies below cutoff Hz
void high_pass_filter(
std::vector<float> & data,
float cutoff,
float sample_rate);
// Basic voice activity detection (VAD) using audio energy adaptive threshold
bool vad_simple(
std::vector<float> & pcmf32,
int sample_rate,
int last_ms,
float vad_thold,
float freq_thold,
bool verbose);

View File

@ -3,4 +3,4 @@ add_executable(${TARGET} main.cpp)
include(DefaultTargetOptions) include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})

View File

@ -1,9 +1,6 @@
#include "whisper.h" #include "common.h"
// third-party utilities #include "whisper.h"
// use your favorite implementations
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include <cmath> #include <cmath>
#include <fstream> #include <fstream>
@ -69,6 +66,7 @@ struct whisper_params {
bool speed_up = false; bool speed_up = false;
bool translate = false; bool translate = false;
bool diarize = false; bool diarize = false;
bool split_on_word = false;
bool no_fallback = false; bool no_fallback = false;
bool output_txt = false; bool output_txt = false;
bool output_vtt = false; bool output_vtt = false;
@ -85,7 +83,7 @@ struct whisper_params {
std::string model = "models/ggml-base.en.bin"; std::string model = "models/ggml-base.en.bin";
std::vector<std::string> fname_inp = {}; std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_outp = {}; std::vector<std::string> fname_out = {};
}; };
void whisper_print_usage(int argc, char ** argv, const whisper_params & params); void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -94,6 +92,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) { for (int i = 1; i < argc; i++) {
std::string arg = argv[i]; std::string arg = argv[i];
if (arg == "-"){
params.fname_inp.push_back(arg);
continue;
}
if (arg[0] != '-') { if (arg[0] != '-') {
params.fname_inp.push_back(arg); params.fname_inp.push_back(arg);
continue; continue;
@ -118,13 +121,14 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_outp.emplace_back(argv[++i]); } else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
@ -156,6 +160,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
@ -517,91 +522,14 @@ int main(int argc, char ** argv) {
for (int f = 0; f < (int) params.fname_inp.size(); ++f) { for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f]; const auto fname_inp = params.fname_inp[f];
const auto fname_outp = f < (int) params.fname_outp.size() && !params.fname_outp[f].empty() ? params.fname_outp[f] : params.fname_inp[f]; const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
std::vector<float> pcmf32; // mono-channel F32 PCM std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// WAV input if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
{ fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
drwav wav; continue;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname_inp == "-") {
{
uint8_t buf[1024];
while (true)
{
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return 4;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return 5;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
return 6;
}
if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
return 6;
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
return 8;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
return 9;
}
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (params.diarize) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
} }
// print system information // print system information
@ -651,6 +579,7 @@ int main(int argc, char ** argv) {
wparams.token_timestamps = params.output_wts || params.max_len > 0; wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold; wparams.thold_pt = params.word_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.split_on_word = params.split_on_word;
wparams.speed_up = params.speed_up; wparams.speed_up = params.speed_up;
@ -689,6 +618,8 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]); fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 10; return 10;
} }
whisper_full_cluster_segments(ctx);
} }
// output stuff // output stuff
@ -697,34 +628,33 @@ int main(int argc, char ** argv) {
// output to text file // output to text file
if (params.output_txt) { if (params.output_txt) {
const auto fname_txt = fname_outp + ".txt"; const auto fname_txt = fname_out + ".txt";
output_txt(ctx, fname_txt.c_str()); output_txt(ctx, fname_txt.c_str());
} }
// output to VTT file // output to VTT file
if (params.output_vtt) { if (params.output_vtt) {
const auto fname_vtt = fname_outp + ".vtt"; const auto fname_vtt = fname_out + ".vtt";
output_vtt(ctx, fname_vtt.c_str()); output_vtt(ctx, fname_vtt.c_str());
} }
// output to SRT file // output to SRT file
if (params.output_srt) { if (params.output_srt) {
const auto fname_srt = fname_outp + ".srt"; const auto fname_srt = fname_out + ".srt";
output_srt(ctx, fname_srt.c_str(), params); output_srt(ctx, fname_srt.c_str(), params);
} }
// output to WTS file // output to WTS file
if (params.output_wts) { if (params.output_wts) {
const auto fname_wts = fname_outp + ".wts"; const auto fname_wts = fname_out + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE); output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
} }
// output to CSV file // output to CSV file
if (params.output_csv) { if (params.output_csv) {
const auto fname_csv = fname_outp + ".csv"; const auto fname_csv = fname_out + ".csv";
output_csv(ctx, fname_csv.c_str()); output_csv(ctx, fname_csv.c_str());
} }
} }
} }

View File

@ -5,6 +5,5 @@ if (WHISPER_SUPPORT_SDL2)
include(DefaultTargetOptions) include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS}) target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif () endif ()

View File

@ -3,19 +3,16 @@
// A very quick-n-dirty implementation serving mainly as a proof of concept. // A very quick-n-dirty implementation serving mainly as a proof of concept.
// //
#include "common.h"
#include "common-sdl.h"
#include "whisper.h" #include "whisper.h"
#include <SDL.h>
#include <SDL_audio.h>
#include <atomic>
#include <cassert> #include <cassert>
#include <cstdio> #include <cstdio>
#include <string> #include <string>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include <fstream> #include <fstream>
#include <mutex>
// 500 -> 00:05.000 // 500 -> 00:05.000
// 6000 -> 01:00.000 // 6000 -> 01:00.000
@ -116,306 +113,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
//
// SDL Audio capture
//
class audio_async {
public:
audio_async(int len_ms);
~audio_async();
bool init(int capture_id, int sample_rate);
// start capturing audio via the provided SDL callback
// keep last len_ms seconds of audio in a circular buffer
bool resume();
bool pause();
bool clear();
// callback to be called by SDL
void callback(uint8_t * stream, int len);
// get audio data from the circular buffer
void get(int ms, std::vector<float> & audio);
private:
SDL_AudioDeviceID m_dev_id_in = 0;
int m_len_ms = 0;
int m_sample_rate = 0;
std::atomic_bool m_running;
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
size_t m_audio_pos = 0;
size_t m_audio_len = 0;
};
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;
m_running = false;
}
audio_async::~audio_async() {
if (m_dev_id_in) {
SDL_CloseAudioDevice(m_dev_id_in);
}
}
bool audio_async::init(int capture_id, int sample_rate) {
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
return false;
}
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
{
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
for (int i = 0; i < nDevices; i++) {
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
}
}
SDL_AudioSpec capture_spec_requested;
SDL_AudioSpec capture_spec_obtained;
SDL_zero(capture_spec_requested);
SDL_zero(capture_spec_obtained);
capture_spec_requested.freq = sample_rate;
capture_spec_requested.format = AUDIO_F32;
capture_spec_requested.channels = 1;
capture_spec_requested.samples = 1024;
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
audio_async * audio = (audio_async *) userdata;
audio->callback(stream, len);
};
capture_spec_requested.userdata = this;
if (capture_id >= 0) {
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
} else {
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
}
if (!m_dev_id_in) {
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
m_dev_id_in = 0;
return false;
} else {
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
capture_spec_requested.format);
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
capture_spec_requested.channels);
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
}
m_sample_rate = capture_spec_obtained.freq;
m_audio.resize((m_sample_rate*m_len_ms)/1000);
return true;
}
bool audio_async::resume() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
return false;
}
if (m_running) {
fprintf(stderr, "%s: already running!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 0);
m_running = true;
return true;
}
bool audio_async::pause() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: already paused!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 1);
m_running = false;
return true;
}
bool audio_async::clear() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return false;
}
{
std::lock_guard<std::mutex> lock(m_mutex);
m_audio_pos = 0;
m_audio_len = 0;
}
return true;
}
// callback to be called by SDL
void audio_async::callback(uint8_t * stream, int len) {
if (!m_running) {
return;
}
const size_t n_samples = len / sizeof(float);
m_audio_new.resize(n_samples);
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
{
std::lock_guard<std::mutex> lock(m_mutex);
if (m_audio_pos + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - m_audio_pos;
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = m_audio.size();
} else {
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
}
}
}
void audio_async::get(int ms, std::vector<float> & result) {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
return;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return;
}
result.clear();
{
std::lock_guard<std::mutex> lock(m_mutex);
if (ms <= 0) {
ms = m_len_ms;
}
size_t n_samples = (m_sample_rate * ms) / 1000;
if (n_samples > m_audio_len) {
n_samples = m_audio_len;
}
result.resize(n_samples);
int s0 = m_audio_pos - n_samples;
if (s0 < 0) {
s0 += m_audio.size();
}
if (s0 + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - s0;
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
} else {
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
}
}
}
///////////////////////////
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
float y = data[0];
for (size_t i = 1; i < data.size(); i++) {
y = alpha * (y + data[i] - data[i - 1]);
data[i] = y;
}
}
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;
if (n_samples_last >= n_samples) {
// not enough samples - assume no speech
return false;
}
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, freq_thold, sample_rate);
}
float energy_all = 0.0f;
float energy_last = 0.0f;
for (int i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
energy_last += fabsf(pcmf32[i]);
}
}
energy_all /= n_samples;
energy_last /= n_samples_last;
if (verbose) {
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
}
if (energy_last > vad_thold*energy_all) {
return false;
}
return true;
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
whisper_params params; whisper_params params;
@ -426,10 +123,10 @@ int main(int argc, char ** argv) {
params.keep_ms = std::min(params.keep_ms, params.step_ms); params.keep_ms = std::min(params.keep_ms, params.step_ms);
params.length_ms = std::max(params.length_ms, params.step_ms); params.length_ms = std::max(params.length_ms, params.step_ms);
const int n_samples_step = (params.step_ms *1e-3)*WHISPER_SAMPLE_RATE; const int n_samples_step = (1e-3*params.step_ms )*WHISPER_SAMPLE_RATE;
const int n_samples_len = (params.length_ms*1e-3)*WHISPER_SAMPLE_RATE; const int n_samples_len = (1e-3*params.length_ms)*WHISPER_SAMPLE_RATE;
const int n_samples_keep = (params.keep_ms *1e-3)*WHISPER_SAMPLE_RATE; const int n_samples_keep = (1e-3*params.keep_ms )*WHISPER_SAMPLE_RATE;
const int n_samples_30s = (30000 *1e-3)*WHISPER_SAMPLE_RATE; const int n_samples_30s = (1e-3*30000.0 )*WHISPER_SAMPLE_RATE;
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
@ -517,23 +214,7 @@ int main(int argc, char ** argv) {
// main audio loop // main audio loop
while (is_running) { while (is_running) {
// handle Ctrl + C // handle Ctrl + C
{ is_running = sdl_poll_events();
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
is_running = false;
} break;
default:
break;
}
}
if (!is_running) {
break;
}
}
if (!is_running) { if (!is_running) {
break; break;
@ -556,7 +237,7 @@ int main(int argc, char ** argv) {
break; break;
} }
SDL_Delay(1); std::this_thread::sleep_for(std::chrono::milliseconds(1));
} }
const int n_samples_new = pcmf32_new.size(); const int n_samples_new = pcmf32_new.size();
@ -587,7 +268,7 @@ int main(int argc, char ** argv) {
audio.get(2000, pcmf32_new); audio.get(2000, pcmf32_new);
if (vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) { if (::vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) {
audio.get(params.length_ms, pcmf32); audio.get(params.length_ms, pcmf32);
} else { } else {
std::this_thread::sleep_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100));

View File

@ -7,7 +7,7 @@ if (WHISPER_SUPPORT_SDL2)
# TODO: this is temporary # TODO: this is temporary
# need to export ggml symbols for MSVC, but too lazy .. # need to export ggml symbols for MSVC, but too lazy ..
add_executable(${TARGET} talk.cpp gpt-2.cpp ../../ggml.c ../../whisper.cpp) add_executable(${TARGET} talk.cpp gpt-2.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../whisper.cpp)
include(DefaultTargetOptions) include(DefaultTargetOptions)

View File

@ -1,16 +1,14 @@
// Talk with AI // Talk with AI
// //
#include "common.h"
#include "common-sdl.h"
#include "whisper.h" #include "whisper.h"
#include "gpt-2.h" #include "gpt-2.h"
#include <SDL.h>
#include <SDL_audio.h>
#include <cassert> #include <cassert>
#include <cstdio> #include <cstdio>
#include <fstream> #include <fstream>
#include <mutex>
#include <regex> #include <regex>
#include <string> #include <string>
#include <thread> #include <thread>
@ -105,320 +103,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
//
// SDL Audio capture
//
class audio_async {
public:
audio_async(int len_ms);
~audio_async();
bool init(int capture_id, int sample_rate);
// start capturing audio via the provided SDL callback
// keep last len_ms seconds of audio in a circular buffer
bool resume();
bool pause();
bool clear();
// callback to be called by SDL
void callback(uint8_t * stream, int len);
// get audio data from the circular buffer
void get(int ms, std::vector<float> & audio);
private:
SDL_AudioDeviceID m_dev_id_in = 0;
int m_len_ms = 0;
int m_sample_rate = 0;
bool m_running = false;
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
size_t m_audio_pos = 0;
size_t m_audio_len = 0;
};
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;
}
audio_async::~audio_async() {
if (m_dev_id_in) {
SDL_CloseAudioDevice(m_dev_id_in);
}
}
bool audio_async::init(int capture_id, int sample_rate) {
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
return false;
}
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
{
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
for (int i = 0; i < nDevices; i++) {
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
}
}
SDL_AudioSpec capture_spec_requested;
SDL_AudioSpec capture_spec_obtained;
SDL_zero(capture_spec_requested);
SDL_zero(capture_spec_obtained);
capture_spec_requested.freq = sample_rate;
capture_spec_requested.format = AUDIO_F32;
capture_spec_requested.channels = 1;
capture_spec_requested.samples = 1024;
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
audio_async * audio = (audio_async *) userdata;
audio->callback(stream, len);
};
capture_spec_requested.userdata = this;
if (capture_id >= 0) {
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
} else {
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
}
if (!m_dev_id_in) {
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
m_dev_id_in = 0;
return false;
} else {
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
capture_spec_requested.format);
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
capture_spec_requested.channels);
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
fprintf(stderr, "\n");
}
m_sample_rate = capture_spec_obtained.freq;
m_audio.resize((m_sample_rate*m_len_ms)/1000);
return true;
}
bool audio_async::resume() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
return false;
}
if (m_running) {
fprintf(stderr, "%s: already running!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 0);
m_running = true;
return true;
}
bool audio_async::pause() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: already paused!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 1);
m_running = false;
return true;
}
bool audio_async::clear() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return false;
}
{
std::lock_guard<std::mutex> lock(m_mutex);
m_audio_pos = 0;
m_audio_len = 0;
}
return true;
}
// callback to be called by SDL
void audio_async::callback(uint8_t * stream, int len) {
if (!m_running) {
return;
}
const size_t n_samples = len / sizeof(float);
m_audio_new.resize(n_samples);
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
{
std::lock_guard<std::mutex> lock(m_mutex);
if (m_audio_pos + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - m_audio_pos;
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = m_audio.size();
} else {
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
}
}
}
void audio_async::get(int ms, std::vector<float> & result) {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
return;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return;
}
result.clear();
{
std::lock_guard<std::mutex> lock(m_mutex);
if (ms <= 0) {
ms = m_len_ms;
}
size_t n_samples = (m_sample_rate * ms) / 1000;
if (n_samples > m_audio_len) {
n_samples = m_audio_len;
}
result.resize(n_samples);
int s0 = m_audio_pos - n_samples;
if (s0 < 0) {
s0 += m_audio.size();
}
if (s0 + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - s0;
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
} else {
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
}
}
}
///////////////////////////
std::string trim(const std::string & s) {
std::regex e("^\\s+|\\s+$");
return std::regex_replace(s, e, "");
}
std::string replace(const std::string & s, const std::string & from, const std::string & to) {
std::string result = s;
size_t pos = 0;
while ((pos = result.find(from, pos)) != std::string::npos) {
result.replace(pos, from.length(), to);
pos += to.length();
}
return result;
}
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
float y = data[0];
for (size_t i = 1; i < data.size(); i++) {
y = alpha * (y + data[i] - data[i - 1]);
data[i] = y;
}
}
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;
if (n_samples_last >= n_samples) {
// not enough samples - assume no speech
return false;
}
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, freq_thold, sample_rate);
}
float energy_all = 0.0f;
float energy_last = 0.0f;
for (int i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
energy_last += fabsf(pcmf32[i]);
}
}
energy_all /= n_samples;
energy_last /= n_samples_last;
if (verbose) {
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
}
if (energy_last > vad_thold*energy_all) {
return false;
}
return true;
}
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) { std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
@ -557,22 +241,10 @@ int main(int argc, char ** argv) {
// main loop // main loop
while (is_running) { while (is_running) {
// handle Ctrl + C // handle Ctrl + C
{ is_running = sdl_poll_events();
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
is_running = false;
} break;
default:
break;
}
}
if (!is_running) { if (!is_running) {
break; break;
}
} }
// delay // delay
@ -583,7 +255,7 @@ int main(int argc, char ** argv) {
{ {
audio.get(2000, pcmf32_cur); audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) { if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
audio.get(params.voice_ms, pcmf32_cur); audio.get(params.voice_ms, pcmf32_cur);

View File

@ -1,20 +1,10 @@
#!/usr/bin/env bash #!/usr/bin/env bash
# shellcheck disable=2086
# Small shell script to more easily automatically download and transcribe live stream VODs.
# This uses YT-DLP, ffmpeg and the CPP version of Whisper: https://github.com/ggerganov/whisper.cpp
# Use `./examples/yt-wsp.sh help` to print help info.
#
# Sample usage:
#
# git clone https://github.com/ggerganov/whisper.cpp
# cd whisper.cpp
# make
# ./examples/yt-wsp.sh https://www.youtube.com/watch?v=1234567890
#
# MIT License # MIT License
# Copyright (c) 2022 Daniils Petrovs # Copyright (c) 2022 Daniils Petrovs
# Copyright (c) 2023 Jennifer Capasso
# Permission is hereby granted, free of charge, to any person obtaining a copy # Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal # of this software and associated documentation files (the "Software"), to deal
@ -34,114 +24,178 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE. # SOFTWARE.
# Small shell script to more easily automatically download and transcribe live stream VODs.
# This uses YT-DLP, ffmpeg and the CPP version of Whisper: https://github.com/ggerganov/whisper.cpp
# Use `./examples/yt-wsp.sh help` to print help info.
#
# Sample usage:
#
# git clone https://github.com/ggerganov/whisper.cpp
# cd whisper.cpp
# make
# ./examples/yt-wsp.sh https://www.youtube.com/watch?v=1234567890
#
set -Eeuo pipefail set -Eeuo pipefail
# You can find how to download models in the OG repo: https://github.com/ggerganov/whisper.cpp/#usage # get script file location
MODEL_PATH="${MODEL_PATH:-models/ggml-base.en.bin}" # Set to a multilingual model if you want to translate from foreign lang to en SCRIPT_PATH="$(realpath -e ${BASH_SOURCE[0]})";
WHISPER_EXECUTABLE="${WHISPER_EXECUTABLE:-whisper}" # Where to find the whisper.cpp executable SCRIPT_DIR="${SCRIPT_PATH%/*}"
WHISPER_LANG="${WHISPER_LANG:-en}" # Set to desired lang to translate from
################################################################################
# Documentation on downloading models can be found in the whisper.cpp repo:
# https://github.com/ggerganov/whisper.cpp/#usage
#
# note: unless a multilingual model is specified, WHISPER_LANG will be ignored
# and the video will be transcribed as if the audio were in the English language
################################################################################
MODEL_PATH="${MODEL_PATH:-${SCRIPT_DIR}/../models/ggml-base.en.bin}"
################################################################################
# Where to find the whisper.cpp executable. default to the examples directory
# which holds this script in source control
################################################################################
WHISPER_EXECUTABLE="${WHISPER_EXECUTABLE:-${SCRIPT_DIR}/../main}";
# Set to desired language to be translated into english
WHISPER_LANG="${WHISPER_LANG:-en}";
# Default to 4 threads (this was most performant on my 2020 M1 MBP)
WHISPER_THREAD_COUNT="${WHISPER_THREAD_COUNT:-4}";
msg() { msg() {
echo >&2 -e "${1-}" echo >&2 -e "${1-}"
} }
cleanup() { cleanup() {
msg "Cleaning up..." local -r clean_me="${1}";
rm -rf "${temp_dir}" "vod-resampled.wav" "vod-resampled.wav.srt"
if [ -d "${clean_me}" ]; then
msg "Cleaning up...";
rm -rf "${clean_me}";
else
msg "'${clean_me}' does not appear to be a directory!";
exit 1;
fi;
} }
print_help() { print_help() {
echo "################################################################################"
echo "Usage: ./examples/yt-wsp.sh <video_url>" echo "Usage: ./examples/yt-wsp.sh <video_url>"
echo "See configurable env variables in the script" echo "# See configurable env variables in the script; there are many!"
echo "This will produce an MP4 muxed file called res.mp4 in the working directory" echo "# This script will produce an MP4 muxed file in the working directory; it will"
echo "Requirements: ffmpeg yt-dlp whisper" echo "# be named for the title and id of the video."
echo "Whisper needs to be built into the main binary with make, then you can rename it to something like 'whisper' and add it to your PATH for convenience." echo "# passing in https://youtu.be/VYJtb2YXae8 produces a file named";
echo "E.g. in the root of Whisper.cpp, run: 'make && cp ./main /usr/local/bin/whisper'" echo "# 'Why_we_all_need_subtitles_now-VYJtb2YXae8-res.mp4'"
echo "# Requirements: ffmpeg yt-dlp whisper.cpp"
echo "################################################################################"
} }
check_requirements() { check_requirements() {
if ! command -v ffmpeg &>/dev/null; then if ! command -v ffmpeg &>/dev/null; then
echo "ffmpeg is required (https://ffmpeg.org)." echo "ffmpeg is required: https://ffmpeg.org";
exit 1 exit 1
fi fi;
if ! command -v yt-dlp &>/dev/null; then if ! command -v yt-dlp &>/dev/null; then
echo "yt-dlp is required (https://github.com/yt-dlp/yt-dlp)." echo "yt-dlp is required: https://github.com/yt-dlp/yt-dlp";
exit 1 exit 1;
fi fi;
if ! command -v "${WHISPER_EXECUTABLE}" &>/dev/null; then
echo "The C++ implementation of Whisper is required: https://github.com/ggerganov/whisper.cpp"
echo "Sample usage:";
echo "";
echo " git clone https://github.com/ggerganov/whisper.cpp";
echo " cd whisper.cpp";
echo " make";
echo " ./examples/yt-wsp.sh https://www.youtube.com/watch?v=1234567890";
echo "";
exit 1;
fi;
if ! command -v "$WHISPER_EXECUTABLE" &>/dev/null; then
WHISPER_EXECUTABLE="./main"
if ! command -v "$WHISPER_EXECUTABLE" &>/dev/null; then
echo "Whisper is required (https://github.com/ggerganov/whisper.cpp):"
echo "Sample usage:"
echo ""
echo " git clone https://github.com/ggerganov/whisper.cpp"
echo " cd whisper.cpp"
echo " make"
echo " ./examples/yt-wsp.sh https://www.youtube.com/watch?v=1234567890"
echo ""
exit 1
fi
fi
} }
if [[ $# -lt 1 ]]; then if [[ "${#}" -lt 1 ]]; then
print_help print_help;
exit 1 exit 1;
fi fi
if [[ "$1" == "help" ]]; then if [[ "${1##-*}" == "help" ]]; then
print_help print_help;
exit 0 exit 0;
fi fi
temp_dir="tmp" check_requirements;
source_url="$1"
check_requirements ################################################################################
# create a temporary directory to work in
# set the temp_dir and temp_filename variables
################################################################################
temp_dir="$(mktemp -d ${SCRIPT_DIR}/tmp.XXXXXX)";
temp_filename="${temp_dir}/yt-dlp-filename";
msg "Downloading VOD..." ################################################################################
# for now we only take one argument
# TODO: a for loop
################################################################################
source_url="${1}"
title_name="";
# Optionally add --cookies-from-browser BROWSER[+KEYRING][:PROFILE][::CONTAINER] for members only VODs msg "Downloading VOD...";
################################################################################
# Download the video, put the dynamic output filename into a variable.
# Optionally add --cookies-from-browser BROWSER[+KEYRING][:PROFILE][::CONTAINER]
# for videos only available to logged-in users.
################################################################################
yt-dlp \ yt-dlp \
-f "bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best" \ -f "bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best" \
-o "${temp_dir}/%(title)s-%(id)s.vod.mp4" \
--print-to-file "%(filename)s" "${temp_filename}" \
--no-simulate \
--no-write-auto-subs \
--restrict-filenames \
--embed-thumbnail \ --embed-thumbnail \
--embed-chapters \ --embed-chapters \
--xattrs \ --xattrs \
"${source_url}" -o "${temp_dir}/vod.mp4" "${source_url}";
msg "Extracting audio and resampling..." title_name="$(xargs basename -s .vod.mp4 < ${temp_filename})";
ffmpeg -i "${temp_dir}/vod.mp4" \ msg "Extracting audio and resampling...";
ffmpeg -i "${temp_dir}/${title_name}.vod.mp4" \
-hide_banner \ -hide_banner \
-vn \
-loglevel error \ -loglevel error \
-ar 16000 \ -ar 16000 \
-ac 1 \ -ac 1 \
-c:a \ -c:a pcm_s16le \
pcm_s16le -y "vod-resampled.wav" -y \
"${temp_dir}/${title_name}.vod-resampled.wav";
msg "Transcribing to subtitle file..." msg "Transcribing to subtitle file...";
msg "Whisper specified at: ${WHISPER_EXECUTABLE}" msg "Whisper specified at: '${WHISPER_EXECUTABLE}'";
$WHISPER_EXECUTABLE \ "${WHISPER_EXECUTABLE}" \
-m "${MODEL_PATH}" \ -m "${MODEL_PATH}" \
-l "${WHISPER_LANG}" \ -l "${WHISPER_LANG}" \
-f "vod-resampled.wav" \ -f "${temp_dir}/${title_name}.vod-resampled.wav" \
-t 8 \ -t "${WHISPER_THREAD_COUNT}" \
-osrt \ -osrt \
--translate --translate;
msg "Embedding subtitle track..." msg "Embedding subtitle track...";
ffmpeg -i "${temp_dir}/vod.mp4" \ ffmpeg -i "${temp_dir}/${title_name}.vod.mp4" \
-hide_banner \ -hide_banner \
-loglevel error \ -loglevel error \
-i "vod-resampled.wav.srt" \ -i "${temp_dir}/${title_name}.vod-resampled.wav.srt" \
-c copy \ -c copy \
-c:s mov_text \ -c:s mov_text \
-y res.mp4 -y "${title_name}-res.mp4";
cleanup #cleanup "${temp_dir}";
msg "Done! Your finished file is ready: res.mp4" msg "Done! Your finished file is ready: ${title_name}-res.mp4";

189
ggml.c
View File

@ -8517,6 +8517,195 @@ enum ggml_opt_result ggml_opt(
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
void ggml_svd_reduce_dims(
int ne0,
int ne1,
float * a,
int nd) {
int n = ne1;
int m = ne0;
float * A = a;
float * A0 = (float *) malloc(n * m * sizeof(float));
// average vector
//float * M = (float *) malloc(m * sizeof(float));
//{
// for (int j = 0; j < m; ++j) {
// M[j] = 0.0f;
// }
// for (int i = 0; i < n; ++i) {
// for (int j = 0; j < m; ++j) {
// M[j] += A[i * m + j];
// }
// }
// for (int j = 0; j < m; ++j) {
// M[j] /= (float) n;
// }
//}
//// subtract average vector
//for (int i = 0; i < n; ++i) {
// for (int j = 0; j < m; ++j) {
// A[i * m + j] -= M[j];
// }
//}
//free(M);
memcpy(A0, A, n * m * sizeof(float));
// print A
//printf("A:\n");
//for (int i = 0; i < n; ++i) {
// printf("col %d : ", i);
// for (int j = 0; j < m; ++j) {
// printf("%9.5f ", A[i * m + j]);
// }
// printf("\n");
//}
//printf("\n");
// SVD
// A = U * S * V^T
float * U = (float *) malloc(n * m * sizeof(float));
float * S = (float *) malloc(n * sizeof(float));
float * V = (float *) malloc(n * n * sizeof(float));
int lda = m;
int ldu = m;
int ldvt = n;
float work_size;
int lwork = -1;
int info = 0;
sgesvd_("S", "S", &m, &n, A, &lda, S, U, &ldu, V, &ldvt, &work_size, &lwork, &info);
lwork = (int) work_size;
//printf("work_size = %f, info = %d, lwork = %d\n", work_size, info, lwork);
float * work = (float *) malloc(lwork * sizeof(float));
sgesvd_("S", "S", &m, &n, A, &lda, S, U, &ldu, V, &ldvt, work, &lwork, &info);
free(work);
// print U
//printf("U:\n");
//for (int i = 0; i < n; ++i) {
// printf("col %d : ", i);
// for (int j = 0; j < m; ++j) {
// printf("%9.5f ", U[i * m + j]);
// }
// printf("\n");
//}
//printf("\n");
// normalize S
{
double sum = 0.0;
for (int i = 0; i < n; ++i) {
sum += S[i];
}
sum *= sqrt((double) m);
for (int i = 0; i < n; ++i) {
S[i] /= sum;
}
}
// print S
printf("S:\n");
for (int i = 0; i < n; ++i) {
printf("- %d = %9.5f\n", i, S[i]);
}
printf("\n");
// print V
//printf("V:\n");
//for (int i = 0; i < n; ++i) {
// printf("col %d : ", i);
// for (int j = 0; j < n; ++j) {
// printf("%9.5f ", V[i * n + j]);
// }
// printf("\n");
//}
//printf("\n");
// print A
//printf("A:\n");
//for (int i = 0; i < n; ++i) {
// printf("col %d : ", i);
// for (int j = 0; j < m; ++j) {
// printf("%9.5f ", A[i * m + j]);
// }
// printf("\n");
//}
//printf("\n");
// compute singular vectors in U
for (int i = 0; i < n; ++i) {
for (int j = 0; j < m; ++j) {
U[i * m + j] *= S[i];
}
}
// normalize U
for (int i = 0; i < n; ++i) {
double sum = 0.0;
for (int j = 0; j < m; ++j) {
sum += U[i * m + j] * U[i * m + j];
}
sum = sqrt(sum);
for (int j = 0; j < m; ++j) {
U[i * m + j] /= sum*sqrt((double) m);
}
}
// print U
//printf("U:\n");
//for (int i = 0; i < n; ++i) {
// printf("col %d : ", i);
// for (int j = 0; j < m; ++j) {
// printf("%9.5f ", U[i * m + j]);
// }
// printf("\n");
//}
//printf("\n");
// project A0 onto U
for (int i = 0; i < n; ++i) {
for (int j = 0; j < nd; ++j) {
A[i * nd + j] = 0.0f;
//if (j == 0) continue;
for (int k = 0; k < m; ++k) {
A[i * nd + j] += A0[i * m + k] * U[j * m + k];
}
}
}
// print A
//printf("A:\n");
//for (int i = 0; i < n; ++i) {
// printf("col %d : ", i);
// for (int j = 0; j < n; ++j) {
// printf("%9.5f ", A[i * n + j]);
// }
// printf("\n");
//}
//printf("\n");
free(U);
free(S);
free(V);
free(A0);
}
////////////////////////////////////////////////////////////////////////////////
int ggml_cpu_has_avx(void) { int ggml_cpu_has_avx(void) {
#if defined(__AVX__) #if defined(__AVX__)
return 1; return 1;

10
ggml.h
View File

@ -726,6 +726,16 @@ enum ggml_opt_result ggml_opt(
struct ggml_opt_params params, struct ggml_opt_params params,
struct ggml_tensor * f); struct ggml_tensor * f);
//
// Temp stuff
//
void ggml_svd_reduce_dims(
int ne0,
int ne1,
float * a,
int nd);
// //
// system info // system info
// //

View File

@ -268,6 +268,14 @@ static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
{ MODEL_LARGE, 71ull*MB }, { MODEL_LARGE, 71ull*MB },
}; };
static const std::map<e_model, size_t> MEM_REQ_KV_ENC_SELF = {
{ MODEL_TINY, 23ull*MB },
{ MODEL_BASE, 26ull*MB },
{ MODEL_SMALL, 216ull*MB },
{ MODEL_MEDIUM, 243ull*MB },
{ MODEL_LARGE, 271ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = { static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
{ MODEL_TINY, 9ull*MB }, { MODEL_TINY, 9ull*MB },
{ MODEL_BASE, 18ull*MB }, { MODEL_BASE, 18ull*MB },
@ -571,6 +579,7 @@ struct whisper_context {
// cross-attention KV cache for the decoders // cross-attention KV cache for the decoders
// shared between all decoders // shared between all decoders
whisper_kv_cache kv_cross; whisper_kv_cache kv_cross;
whisper_kv_cache kv_enc_self;
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
@ -592,6 +601,8 @@ struct whisper_context {
mutable std::mt19937 rng; // used for sampling at t > 0.0 mutable std::mt19937 rng; // used for sampling at t > 0.0
int lang_id;
// [EXPERIMENTAL] token-level timestamps data // [EXPERIMENTAL] token-level timestamps data
int64_t t_beg; int64_t t_beg;
int64_t t_last; int64_t t_last;
@ -601,6 +612,8 @@ struct whisper_context {
// [EXPERIMENTAL] speed-up techniques // [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx; // 0 - use default int32_t exp_n_audio_ctx; // 0 - use default
std::vector<float> audio_embd;
void use_buf(struct ggml_context * ctx, int i) { void use_buf(struct ggml_context * ctx, int i) {
#if defined(WHISPER_USE_SCRATCH) #if defined(WHISPER_USE_SCRATCH)
size_t last_size = 0; size_t last_size = 0;
@ -803,7 +816,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
MEM_REQ_SCRATCH3.at (model.type) + MEM_REQ_SCRATCH3.at (model.type) +
scale*MEM_REQ_MODEL.at (model.type) + scale*MEM_REQ_MODEL.at (model.type) +
scale*MEM_REQ_KV_CROSS.at(model.type) + scale*MEM_REQ_KV_CROSS.at(model.type) +
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)); scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
// this is the memory required by one decoder // this is the memory required by one decoder
const size_t mem_required_decoder = const size_t mem_required_decoder =
@ -834,6 +847,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
return false; return false;
} }
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_ENC_SELF.at(model.type), wctx.kv_enc_self, wctx.wtype, model.hparams.n_audio_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
return false;
}
{ {
const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v); const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
@ -1356,7 +1374,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
static bool whisper_encode( static bool whisper_encode(
whisper_context & wctx, whisper_context & wctx,
const int mel_offset, const int mel_offset,
const int n_threads) { const int n_threads,
bool repeat = false) {
const int64_t t_start_us = ggml_time_us(); const int64_t t_start_us = ggml_time_us();
const auto & model = wctx.model; const auto & model = wctx.model;
@ -1388,13 +1407,31 @@ static bool whisper_encode(
const int i0 = std::min(mel_offset, mel_inp.n_len); const int i0 = std::min(mel_offset, mel_inp.n_len);
const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len); const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
for (int j = 0; j < mel_inp.n_mel; ++j) { if (repeat == false) {
for (int i = i0; i < i1; ++i) { for (int j = 0; j < mel_inp.n_mel; ++j) {
dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i]; for (int i = i0; i < i1; ++i) {
dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
}
}
} else {
for (int j = 0; j < mel_inp.n_mel; ++j) {
int k = 0;
while (k < 2*n_ctx) {
for (int i = i0; i < i1; ++i) {
dst[j*2*n_ctx + k] = mel_inp.data[j*mel_inp.n_len + i];
k++;
if (k >= 2*n_ctx) {
break;
}
}
}
} }
} }
} }
struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
struct ggml_tensor * cur; struct ggml_tensor * cur;
// convolution + gelu // convolution + gelu
@ -1422,6 +1459,18 @@ static bool whisper_encode(
cur = ggml_gelu(ctx0, cur); cur = ggml_gelu(ctx0, cur);
} }
//{
// //printf("cur: %d %d %d %d, size element = %d\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_element_size(cur));
// wctx.use_buf(ctx0, -1);
// struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_enc_self.k, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.k)*n_state)*(0*n_ctx));
// //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_enc_self.v, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.v)*n_state)*(il*n_ctx));
// ggml_build_forward_expand(&gf, ggml_cpy(ctx0, cur, k));
// //ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
//}
wctx.use_buf(ctx0, 3); wctx.use_buf(ctx0, 3);
// =================================================================== // ===================================================================
@ -1502,6 +1551,18 @@ static bool whisper_encode(
Vcur), Vcur),
Vcur); Vcur);
//{
// //printf("Kcur: %d %d %d %d, size element = %d\n", Kcur->ne[0], Kcur->ne[1], Kcur->ne[2], Kcur->ne[3], ggml_element_size(Kcur));
// wctx.use_buf(ctx0, -1);
// struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_enc_self.k, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.k)*n_state)*(il*n_ctx));
// struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_enc_self.v, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.v)*n_state)*(il*n_ctx));
// ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
// ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
//}
// ------ // ------
wctx.use_buf(ctx0, 0); wctx.use_buf(ctx0, 0);
@ -1586,6 +1647,18 @@ static bool whisper_encode(
cur = ggml_cpy(ctx0, cur = ggml_cpy(ctx0,
KQV_merged, KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
{
//printf("cur: %d %d %d %d, size element = %d\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_element_size(cur));
wctx.use_buf(ctx0, -1);
struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_enc_self.k, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.k)*n_state)*(il*n_ctx));
//struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_enc_self.v, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.v)*n_state)*(il*n_ctx));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, cur, k));
//ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
}
} }
// projection // projection
@ -1695,8 +1768,6 @@ static bool whisper_encode(
// run the computation // run the computation
{ {
struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
ggml_build_forward_expand(&gf, cur); ggml_build_forward_expand(&gf, cur);
ggml_graph_compute (ctx0, &gf); ggml_graph_compute (ctx0, &gf);
@ -1718,6 +1789,24 @@ static bool whisper_encode(
// printf("\n"); // printf("\n");
//} //}
{
//const int i0 = std::min(mel_offset, mel_inp.n_len);
//const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
const int i0 = 0;
const int i1 = cur->ne[1];
//printf("i0 = %d, i1 = %d, (i1 - i0) = %d, embd size = %d\n", i0, i1, i1 - i0, cur->ne[0]);
wctx.audio_embd.clear();
wctx.audio_embd.resize(cur->ne[0], 0.0f);
for (int j = 0; j < cur->ne[0]; ++j) {
for (int i = i0; i < i1; ++i) {
wctx.audio_embd[j] += ((float *)(cur->data))[(i - i0)*cur->ne[0] + j];
}
wctx.audio_embd[j] /= (i1 - i0);
}
}
// pre-compute cross-attention memory // pre-compute cross-attention memory
{ {
struct ggml_cgraph gf = {}; struct ggml_cgraph gf = {};
@ -2903,7 +2992,7 @@ const char * whisper_print_system_info(void) {
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
struct whisper_full_params result = { struct whisper_full_params result = {
/*.strategy =*/ WHISPER_SAMPLING_GREEDY, /*.strategy =*/ strategy,
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384, /*.n_max_text_ctx =*/ 16384,
@ -2922,6 +3011,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.thold_pt =*/ 0.01f, /*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f, /*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0, /*.max_len =*/ 0,
/*.split_on_word =*/ false,
/*.max_tokens =*/ 0, /*.max_tokens =*/ 0,
/*.speed_up =*/ false, /*.speed_up =*/ false,
@ -2933,6 +3023,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.language =*/ "en", /*.language =*/ "en",
/*.suppress_blank =*/ true, /*.suppress_blank =*/ true,
/*.suppress_non_speech_tokens =*/ false,
/*.temperature =*/ 0.0f, /*.temperature =*/ 0.0f,
/*.max_initial_ts =*/ 1.0f, /*.max_initial_ts =*/ 1.0f,
@ -2988,9 +3079,35 @@ static void whisper_exp_compute_token_level_timestamps(
float thold_pt, float thold_pt,
float thold_ptsum); float thold_ptsum);
// trim from start (in place)
static inline void ltrim(std::string &s) {
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
return !std::isspace(ch);
}));
}
// trim from end (in place)
static inline void rtrim(std::string &s) {
s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) {
return !std::isspace(ch);
}).base(), s.end());
}
// trim from both ends (in place)
static inline void trim(std::string &s) {
rtrim(s);
ltrim(s);
}
static inline bool should_split_on_word(const char * txt, bool split_on_word) {
if (!split_on_word) return true;
return txt[0] == ' ';
}
// wrap the last segment to max_len characters // wrap the last segment to max_len characters
// returns the number of new segments // returns the number of new segments
static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) { static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
auto segment = ctx.result_all.back(); auto segment = ctx.result_all.back();
int res = 1; int res = 1;
@ -3005,11 +3122,14 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
} }
const auto txt = whisper_token_to_str(&ctx, token.id); const auto txt = whisper_token_to_str(&ctx, token.id);
const int cur = strlen(txt); const int cur = strlen(txt);
if (acc + cur > max_len && i > 0) { if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) {
// split here // split here
if (split_on_word) {
trim(text);
}
ctx.result_all.back().text = std::move(text); ctx.result_all.back().text = std::move(text);
ctx.result_all.back().t1 = token.t0; ctx.result_all.back().t1 = token.t0;
ctx.result_all.back().tokens.resize(i); ctx.result_all.back().tokens.resize(i);
@ -3037,11 +3157,21 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
} }
} }
if (split_on_word) {
trim(text);
}
ctx.result_all.back().text = std::move(text); ctx.result_all.back().text = std::move(text);
return res; return res;
} }
static const std::vector<std::string> non_speech_tokens = {
"\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
"_", "`", "{", "|", "}", "~", "", "", "", "", "<<", ">>", "<<<", ">>>", "--",
"---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
"♪♪♪","", "", "", "", "", "", ""
};
// process the logits for the selected decoder // process the logits for the selected decoder
// - applies logit filters // - applies logit filters
// - computes logprobs and probs // - computes logprobs and probs
@ -3102,6 +3232,28 @@ static void whisper_process_logits(
logits[vocab.token_translate] = -INFINITY; logits[vocab.token_translate] = -INFINITY;
logits[vocab.token_transcribe] = -INFINITY; logits[vocab.token_transcribe] = -INFINITY;
// suppress non-speech tokens
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
if (params.suppress_non_speech_tokens) {
for (const std::string & token : non_speech_tokens) {
const std::string suppress_tokens[] = {token, " " + token};
for (const std::string & suppress_token : suppress_tokens) {
if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) {
logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
}
}
}
// allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) {
logits[vocab.token_to_id.at(" -")] = -INFINITY;
}
if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) {
logits[vocab.token_to_id.at(" '")] = -INFINITY;
}
}
// timestamps have to appear in pairs, except directly before EOT; mask logits accordingly // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
// https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424 // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
{ {
@ -3449,7 +3601,7 @@ int whisper_full(
fprintf(stderr, "%s: failed to auto-detect language\n", __func__); fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
return -3; return -3;
} }
ctx->lang_id = lang_id;
params.language = whisper_lang_str(lang_id); params.language = whisper_lang_str(lang_id);
fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
@ -3546,6 +3698,7 @@ int whisper_full(
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) }; std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
if (whisper_is_multilingual(ctx)) { if (whisper_is_multilingual(ctx)) {
const int lang_id = whisper_lang_id(params.language); const int lang_id = whisper_lang_id(params.language);
ctx->lang_id = lang_id;
prompt_init.push_back(whisper_token_lang(ctx, lang_id)); prompt_init.push_back(whisper_token_lang(ctx, lang_id));
if (params.translate) { if (params.translate) {
prompt_init.push_back(whisper_token_translate()); prompt_init.push_back(whisper_token_translate());
@ -3782,7 +3935,7 @@ int whisper_full(
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all; return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
}); });
int cur_c = 0; unsigned int cur_c = 0;
for (int j = 0; j < n_decoders_cur; ++j) { for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j]; auto & decoder = ctx->decoders[j];
@ -3793,7 +3946,7 @@ int whisper_full(
auto & cur = beam_candidates[cur_c++]; auto & cur = beam_candidates[cur_c++];
while (beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) { while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
++cur_c; ++cur_c;
} }
@ -4069,7 +4222,7 @@ int whisper_full(
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) { if (params.max_len > 0) {
n_new = whisper_wrap_segment(*ctx, params.max_len); n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
} }
} }
if (params.new_segment_callback) { if (params.new_segment_callback) {
@ -4113,7 +4266,7 @@ int whisper_full(
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) { if (params.max_len > 0) {
n_new = whisper_wrap_segment(*ctx, params.max_len); n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
} }
} }
if (params.new_segment_callback) { if (params.new_segment_callback) {
@ -4266,6 +4419,10 @@ int whisper_full_n_segments(struct whisper_context * ctx) {
return ctx->result_all.size(); return ctx->result_all.size();
} }
int whisper_full_lang_id(struct whisper_context * ctx) {
return ctx->lang_id;
}
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
return ctx->result_all[i_segment].t0; return ctx->result_all[i_segment].t0;
} }
@ -4736,3 +4893,258 @@ static void whisper_exp_compute_token_level_timestamps(
// } // }
//} //}
} }
//
// diarization stuff
//
void whisper_full_cluster_segments(struct whisper_context * ctx) {
const int n_segments = ctx->result_all.size();
printf("%s: clustering %d segments\n", __func__, n_segments);
const auto mel_len_save = ctx->mel.n_len;
printf("%s: mel_len_save = %d\n", __func__, mel_len_save);
const int n_ctx = ctx->model.hparams.n_audio_ctx;
const int n_state = ctx->model.hparams.n_audio_state;
const int n_layer = ctx->model.hparams.n_audio_layer;
#if 0
// use the last layer of the encoder
{
std::vector<float> embd(n_segments*n_state);
for (int i = 0; i < n_segments; ++i) {
const auto & segment_i = ctx->result_all[i];
printf("%s: segment %3d: t0 = %7d, t1 = %7d, text = %s\n", __func__, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str());
ctx->mel.n_len = segment_i.t1;
whisper_encode(*ctx, segment_i.t0, 7, true);
for (int j = 0; j < n_state; ++j) {
embd[i*n_state + j] = ctx->audio_embd[j];
}
}
const int n_features = std::min(4, n_segments);
ggml_svd_reduce_dims(n_state, n_segments, embd.data(), n_features);
#elif 0
// use cross kv cache of various layers
for (int il = 0; il < n_layer; ++il) {
std::vector<float> embd(n_segments*n_ctx*n_state);
for (int i = 0; i < n_segments; ++i) {
const auto & segment_i = ctx->result_all[i];
printf("%s: layer %2d, segment %3d: t0 = %7d, t1 = %7d, text = %s\n", __func__, il, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str());
ctx->mel.n_len = segment_i.t1;
whisper_encode(*ctx, segment_i.t0, 7, true);
const size_t offs = ggml_element_size(ctx->kv_cross.k)*(il*n_ctx*n_state);
const ggml_fp16_t * f = (const ggml_fp16_t * )((const char *) ctx->kv_cross.k->data + offs);
for (int j = 0; j < n_ctx*n_state; ++j) {
embd[i*n_ctx*n_state + j] = ggml_fp16_to_fp32(f[j]);
}
}
const int n_features = std::min(4, n_segments);
ggml_svd_reduce_dims(n_ctx*n_state, n_segments, embd.data(), n_features);
#elif 0
// use conv embedding
for (int il = 0; il < 1; ++il) {
std::vector<float> embd(n_segments*n_ctx*n_state);
for (int i = 0; i < n_segments; ++i) {
const auto & segment_i = ctx->result_all[i];
printf("%s: layer %2d, segment %3d: t0 = %7d, t1 = %7d, text = %s\n", __func__, il, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str());
ctx->mel.n_len = segment_i.t1;
whisper_encode(*ctx, segment_i.t0, 7, true);
const size_t offs = ggml_element_size(ctx->kv_enc_self.k)*(il*n_ctx*n_state);
const ggml_fp16_t * f = (const ggml_fp16_t * )((const char *) ctx->kv_enc_self.k->data + offs);
for (int j = 0; j < n_ctx*n_state; ++j) {
embd[i*n_ctx*n_state + j] = ggml_fp16_to_fp32(f[j]);
}
}
const int n_features = std::min(3, n_segments);
ggml_svd_reduce_dims(n_ctx*n_state, n_segments, embd.data(), n_features);
#else
// use enc self kv cache of various layers
for (int il = 0; il < n_layer; ++il) {
std::vector<float> embd(n_segments*n_ctx*n_state);
for (int i = 0; i < n_segments; ++i) {
const auto & segment_i = ctx->result_all[i];
printf("%s: layer %2d, segment %3d: t0 = %7d, t1 = %7d, text = %s\n", __func__, il, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str());
ctx->mel.n_len = segment_i.t1;
whisper_encode(*ctx, segment_i.t0, 7, true);
const size_t offs = ggml_element_size(ctx->kv_enc_self.k)*(il*n_ctx*n_state);
const ggml_fp16_t * f = (const ggml_fp16_t * )((const char *) ctx->kv_enc_self.k->data + offs);
for (int j = 0; j < n_ctx*n_state; ++j) {
embd[i*n_ctx*n_state + j] = ggml_fp16_to_fp32(f[j]);
}
}
const int n_features = std::min(4, n_segments);
ggml_svd_reduce_dims(n_ctx*n_state, n_segments, embd.data(), n_features);
#endif
std::vector<std::vector<double>> features(n_segments);
for (int i = 0; i < n_segments; ++i) {
features[i].resize(n_features);
for (int j = 0; j < n_features; ++j) {
features[i][j] = embd[i*n_features + j];
}
}
// fuzzy c-means clustering
const int n_clusters = 2;
std::vector<std::vector<double>> centroids(n_clusters, std::vector<double>(n_features, 0.0));
std::vector<std::vector<double>> membership(n_segments, std::vector<double>(n_clusters, 0.0));
// initialize the centroids
for (int i = 0; i < n_clusters; ++i) {
for (int j = 0; j < n_features; ++j) {
centroids[i][j] = features[i][j];
}
}
// initialize the membership
for (int i = 0; i < n_segments; ++i) {
//membership[i][i % n_clusters] = 1.0;
//for (int j = 0; j < n_clusters; ++j) {
// membership[i][j] = rand() / (float) RAND_MAX;
//}
for (int j = 0; j < n_clusters; ++j) {
membership[i][j] = 1.0 / n_clusters;
}
}
const int niter = 10000;
// iterate
for (int i = 0; i < niter; ++i) {
// print the membership
if (i == niter - 1) {
//{
for (int i = 0; i < n_segments; ++i) {
#if 1
printf("%s: membership %3d: ", __func__, i);
for (int j = 0; j < n_clusters; ++j) {
printf("%.1f ", membership[i][j]);
}
printf(" '%s'\n", ctx->result_all[i].text.c_str());
#else
printf("%s: features : ", __func__);
for (int j = 0; j < n_features; ++j) {
printf("%8.3f ", features[i][j]);
}
printf(" '%s'\n", ctx->result_all[i].text.c_str());
#endif
}
printf("----------------\n");
// print the centroids
for (int i = 0; i < n_clusters; ++i) {
printf("%s: centroid %d: ", __func__, i);
for (int j = 0; j < n_features; ++j) {
printf("%f ", centroids[i][j]);
}
printf("\n");
}
}
// update the membership
for (int j = 0; j < n_segments; ++j) {
for (int k = 0; k < n_clusters; ++k) {
double sum = 0.0;
for (int l = 0; l < n_clusters; ++l) {
//sum += std::pow(whisper_distance(features[j], centroids[k])/whisper_distance(features[j], centroids[l]), 2.0/(2.0 - 1.0));
double d0 = 0.0;
double d1 = 0.0;
#if 1
// use the euclidean distance
{
for (int m = 0; m < n_features; ++m) {
d0 += std::pow(features[j][m] - centroids[k][m], 2.0);
}
d0 = std::sqrt(d0);
for (int m = 0; m < n_features; ++m) {
d1 += std::pow(features[j][m] - centroids[l][m], 2.0);
}
d1 = std::sqrt(d1);
}
#else
// use the cosine distance
{
double dot = 0.0;
double norm0 = 0.0;
double norm1 = 0.0;
for (int m = 0; m < n_features; ++m) {
dot += features[j][m]*centroids[k][m];
norm0 += std::pow(features[j][m], 2.0);
norm1 += std::pow(centroids[k][m], 2.0);
}
d0 = 1.0 - dot/(std::sqrt(norm0)*std::sqrt(norm1));
dot = 0.0;
norm0 = 0.0;
norm1 = 0.0;
for (int m = 0; m < n_features; ++m) {
dot += features[j][m]*centroids[l][m];
norm0 += std::pow(features[j][m], 2.0);
norm1 += std::pow(centroids[l][m], 2.0);
}
d1 = 1.0 - dot/(std::sqrt(norm0)*std::sqrt(norm1));
}
#endif
if (d1 > 0.0) {
sum += std::pow(d0/d1, 2.0/(1.20 - 1.0));
} else {
sum += 1.0;
}
}
membership[j][k] = sum == 0.0 ? 1.0 : 1.0/sum;
}
}
// update the centroids
for (int j = 0; j < n_clusters; ++j) {
for (int k = 0; k < n_features; ++k) {
double sum = 0.0;
double sum2 = 0.0;
for (int l = 0; l < n_segments; ++l) {
sum += membership[l][j]*features[l][k];
sum2 += membership[l][j];
}
centroids[j][k] = sum2 == 0.0 ? 0.0 : sum/sum2;
}
}
}
}
// restore the mel length
ctx->mel.n_len = mel_len_save;
}

View File

@ -113,6 +113,16 @@ extern "C" {
int n_samples, int n_samples,
int n_threads); int n_threads);
// Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
// The resulting spectrogram is stored inside the provided whisper context.
// Returns 0 on success
WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
struct whisper_context* ctx,
const float* samples,
int n_samples,
int n_threads);
// This can be used to set a custom log mel spectrogram inside the provided whisper context. // This can be used to set a custom log mel spectrogram inside the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80 // n_mel must be 80
@ -257,6 +267,7 @@ extern "C" {
float thold_pt; // timestamp token probability threshold (~0.01) float thold_pt; // timestamp token probability threshold (~0.01)
float thold_ptsum; // timestamp token sum probability threshold (~0.01) float thold_ptsum; // timestamp token sum probability threshold (~0.01)
int max_len; // max segment length in characters int max_len; // max segment length in characters
bool split_on_word; // split on word rather than on token (when used with max_len)
int max_tokens; // max tokens per segment (0 = no limit) int max_tokens; // max tokens per segment (0 = no limit)
// [EXPERIMENTAL] speed-up techniques // [EXPERIMENTAL] speed-up techniques
@ -274,6 +285,7 @@ extern "C" {
// common decoding parameters: // common decoding parameters:
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89 bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478 float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97 float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
@ -329,6 +341,9 @@ extern "C" {
// A segment can be a few words, a sentence, or even a paragraph. // A segment can be a few words, a sentence, or even a paragraph.
WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx); WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
// Language id associated with the current context
WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
// Get the start and end time of the specified segment. // Get the start and end time of the specified segment.
WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment); WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment); WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);
@ -357,6 +372,10 @@ extern "C" {
WHISPER_API int whisper_bench_memcpy(int n_threads); WHISPER_API int whisper_bench_memcpy(int n_threads);
WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads); WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads);
// Temporary experimental API
WHISPER_API void whisper_full_cluster_segments(struct whisper_context * ctx);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif