Compare commits

..

1 Commits

Author SHA1 Message Date
59c997ca2d wip ignore 2023-02-15 19:11:12 +02:00
45 changed files with 2416 additions and 2225 deletions

View File

@ -1,22 +0,0 @@
name: Bindings Tests (Ruby)
on:
push:
paths:
- bindings/ruby/**
- whisper.h
pull_request:
paths:
- bindings/ruby/**
- whisper.h
jobs:
ubuntu-latest:
runs-on: ubuntu-latest
steps:
- uses: ruby/setup-ruby@v1
with:
ruby-version: '3.0'
- uses: actions/checkout@v1
- run: |
cd bindings/ruby/ext
ruby extconf.rb && make

View File

@ -1,4 +1,4 @@
name: Bindings Tests (Go) name: Bindings Tests
on: on:
push: push:
paths: paths:

1
.gitignore vendored
View File

@ -10,7 +10,6 @@ build-em/
build-debug/ build-debug/
build-release/ build-release/
build-static/ build-static/
build-no-accel/
build-sanitize-addr/ build-sanitize-addr/
build-sanitize-thread/ build-sanitize-thread/

View File

@ -1,6 +1,6 @@
cmake_minimum_required (VERSION 3.0) cmake_minimum_required (VERSION 3.0)
project(whisper.cpp VERSION 1.2.1) project(whisper.cpp VERSION 1.2.0)
# Add path to modules # Add path to modules
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")

View File

@ -30,8 +30,8 @@ endif
# Compile flags # Compile flags
# #
CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC CFLAGS = -I. -O3 -std=c11 -fPIC
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC CXXFLAGS = -I. -I./examples -O3 -std=c++11 -fPIC
LDFLAGS = LDFLAGS =
# OS specific # OS specific
@ -141,8 +141,6 @@ ifdef WHISPER_GPROF
CXXFLAGS += -pg CXXFLAGS += -pg
endif endif
ifneq ($(filter aarch64%,$(UNAME_M)),) ifneq ($(filter aarch64%,$(UNAME_M)),)
CFLAGS += -mcpu=native
CXXFLAGS += -mcpu=native
endif endif
ifneq ($(filter armv6%,$(UNAME_M)),) ifneq ($(filter armv6%,$(UNAME_M)),)
# Raspberry Pi 1, 2, 3 # Raspberry Pi 1, 2, 3
@ -199,21 +197,18 @@ clean:
CC_SDL=`sdl2-config --cflags --libs` CC_SDL=`sdl2-config --cflags --libs`
SRC_COMMON = examples/common.cpp main: examples/main/main.cpp ggml.o whisper.o
SRC_COMMON_SDL = examples/common-sdl.cpp $(CXX) $(CXXFLAGS) examples/main/main.cpp ggml.o whisper.o -o main $(LDFLAGS)
main: examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o -o main $(LDFLAGS)
./main -h ./main -h
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o stream: examples/stream/stream.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS) $(CXX) $(CXXFLAGS) examples/stream/stream.cpp ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o command: examples/command/command.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS) $(CXX) $(CXXFLAGS) examples/command/command.cpp ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS)
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS) $(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
bench: examples/bench/bench.cpp ggml.o whisper.o bench: examples/bench/bench.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS) $(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)

View File

@ -4,7 +4,7 @@
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Stable: [v1.2.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126) Stable: [v1.2.0](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
@ -433,19 +433,6 @@ https://user-images.githubusercontent.com/1991296/199337538-b7b0c7a3-2753-4a88-a
--- ---
## Video comparison of different models
Use the [extra/bench-wts.sh](https://github.com/ggerganov/whisper.cpp/blob/master/extra/bench-wts.sh) script to generate a video in the following format:
```java
./extra/bench-wts.sh samples/jfk.wav
ffplay ./samples/jfk.wav.all.mp4
```
https://user-images.githubusercontent.com/1991296/223206245-2d36d903-cf8e-4f09-8c3b-eb9f9c39d6fc.mp4
---
## Benchmarks ## Benchmarks
In order to have an objective comparison of the performance of the inference across different system configurations, In order to have an objective comparison of the performance of the inference across different system configurations,
@ -477,14 +464,11 @@ in [models](models).
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310) - [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
- [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309) - [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312) - [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
- [X] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggerganov/whisper.cpp/discussions/507)
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313) - [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
- [X] .NET: | [#422](https://github.com/ggerganov/whisper.cpp/discussions/422) - [X] .NET:
- [sandrohanea/whisper.net](https://github.com/sandrohanea/whisper.net) - [sandrohanea/whisper.net](https://github.com/sandrohanea/whisper.net)
- [NickDarvey/whisper](https://github.com/NickDarvey/whisper) - [NickDarvey/whisper](https://github.com/NickDarvey/whisper)
- [X] Python: | [#9](https://github.com/ggerganov/whisper.cpp/issues/9) - [ ] Python: soon | [WIP](https://github.com/ggerganov/whisper.cpp/issues/9)
- [stlukey/whispercpp.py](https://github.com/stlukey/whispercpp.py) (Cython)
- [aarnphm/whispercpp](https://github.com/aarnphm/whispercpp) (Pybind11)
## Examples ## Examples

View File

@ -94,7 +94,6 @@ func (model *model) NewContext() (Context, error) {
params.SetPrintRealtime(false) params.SetPrintRealtime(false)
params.SetPrintTimestamps(false) params.SetPrintTimestamps(false)
params.SetThreads(runtime.NumCPU()) params.SetThreads(runtime.NumCPU())
params.SetNoContext(true)
// Return new context // Return new context
return newContext(model, params) return newContext(model, params)

View File

@ -20,7 +20,7 @@ extern bool callEncoderBegin(void* user_data);
// Text segment callback // Text segment callback
// Called on every newly generated text segment // Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments // Use the whisper_full_...() functions to obtain the text segments
static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_state* state, int n_new, void* user_data) { static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
if(user_data != NULL && ctx != NULL) { if(user_data != NULL && ctx != NULL) {
callNewSegment(user_data, n_new); callNewSegment(user_data, n_new);
} }
@ -29,7 +29,7 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_s
// Encoder begin callback // Encoder begin callback
// If not NULL, called before the encoder starts // If not NULL, called before the encoder starts
// If it returns false, the computation is aborted // If it returns false, the computation is aborted
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, struct whisper_state* state, void* user_data) { static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
if(user_data != NULL && ctx != NULL) { if(user_data != NULL && ctx != NULL) {
return callEncoderBegin(user_data); return callEncoderBegin(user_data);
} }

View File

@ -1,6 +1,6 @@
{ {
"name": "whisper.cpp", "name": "whisper.cpp",
"version": "1.2.1", "version": "1.2.0",
"description": "Whisper speech recognition", "description": "Whisper speech recognition",
"main": "whisper.js", "main": "whisper.js",
"scripts": { "scripts": {

File diff suppressed because one or more lines are too long

View File

@ -1,7 +0,0 @@
Makefile
ggml.c
ggml.h
whisper.bundle
whisper.cpp
whisper.h
dr_wav.h

View File

@ -1,21 +0,0 @@
require 'mkmf'
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','examples','dr_wav.h')} .")
# need to use c++ compiler flags
$CXXFLAGS << ' -std=c++11'
# Set to true when building binary gems
if enable_config('static-stdlib', false)
$LDFLAGS << ' -static-libgcc -static-libstdc++'
end
if enable_config('march-tune-native', false)
$CFLAGS << ' -march=native -mtune=native'
$CXXFLAGS << ' -march=native -mtune=native'
end
create_makefile('whisper')

View File

@ -1,426 +0,0 @@
#include <ruby.h>
#include "ruby_whisper.h"
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include <cmath>
#include <fstream>
#include <cstdio>
#include <string>
#include <thread>
#include <vector>
#ifdef __cplusplus
extern "C" {
#endif
#define BOOL_PARAMS_SETTER(self, prop, value) \
ruby_whisper_params *rwp; \
Data_Get_Struct(self, ruby_whisper_params, rwp); \
if (value == Qfalse || value == Qnil) { \
rwp->params.prop = false; \
} else { \
rwp->params.prop = true; \
} \
return value; \
#define BOOL_PARAMS_GETTER(self, prop) \
ruby_whisper_params *rwp; \
Data_Get_Struct(self, ruby_whisper_params, rwp); \
if (rwp->params.prop) { \
return Qtrue; \
} else { \
return Qfalse; \
}
VALUE mWhisper;
VALUE cContext;
VALUE cParams;
static void ruby_whisper_free(ruby_whisper *rw) {
if (rw->context) {
whisper_free(rw->context);
rw->context = NULL;
}
}
static void ruby_whisper_params_free(ruby_whisper_params *rwp) {
}
void rb_whisper_mark(ruby_whisper *rw) {
// call rb_gc_mark on any ruby references in rw
}
void rb_whisper_free(ruby_whisper *rw) {
ruby_whisper_free(rw);
free(rw);
}
void rb_whisper_params_mark(ruby_whisper_params *rwp) {
}
void rb_whisper_params_free(ruby_whisper_params *rwp) {
ruby_whisper_params_free(rwp);
free(rwp);
}
static VALUE ruby_whisper_allocate(VALUE klass) {
ruby_whisper *rw;
rw = ALLOC(ruby_whisper);
rw->context = NULL;
return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
}
static VALUE ruby_whisper_params_allocate(VALUE klass) {
ruby_whisper_params *rwp;
rwp = ALLOC(ruby_whisper_params);
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
}
static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
ruby_whisper *rw;
VALUE whisper_model_file_path;
// TODO: we can support init from buffer here too maybe another ruby object to expose
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
Data_Get_Struct(self, ruby_whisper, rw);
if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) {
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
}
rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path));
if (rw->context == nullptr) {
rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
}
return self;
}
/*
* transcribe a single file
* can emit to a block results
*
**/
static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
ruby_whisper *rw;
ruby_whisper_params *rwp;
VALUE wave_file_path, blk, params;
rb_scan_args(argc, argv, "02&", &wave_file_path, &params, &blk);
Data_Get_Struct(self, ruby_whisper, rw);
Data_Get_Struct(params, ruby_whisper_params, rwp);
if (!rb_respond_to(wave_file_path, rb_intern("to_s"))) {
rb_raise(rb_eRuntimeError, "Expected file path to wave file");
}
std::string fname_inp = StringValueCStr(wave_file_path);
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// WAV input - this is directly from main.cpp example
{
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname_inp == "-") {
{
uint8_t buf[1024];
while (true) {
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return self;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
} else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return self;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "WAV file '%s' must be mono or stereo\n", fname_inp.c_str());
return self;
}
if (rwp->diarize && wav.channels != 2 && rwp->params.print_timestamps == false) {
fprintf(stderr, "WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str());
return self;
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
return self;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "WAV file '%s' must be 16-bit\n", fname_inp.c_str());
return self;
}
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (rwp->diarize) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
}
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
rwp->params.encoder_begin_callback_user_data = &is_aborted;
}
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
fprintf(stderr, "failed to process audio\n");
return self;
}
const int n_segments = whisper_full_n_segments(rw->context);
VALUE output = rb_str_new2("");
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(rw->context, i);
output = rb_str_concat(output, rb_str_new2(text));
}
VALUE idCall = rb_intern("call");
if (blk != Qnil) {
rb_funcall(blk, idCall, 1, output);
}
return self;
}
/*
* params.language = "auto" | "en", etc...
*/
static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (value == Qfalse || value == Qnil) {
rwp->params.language = "auto";
} else {
rwp->params.language = StringValueCStr(value);
}
return value;
}
static VALUE ruby_whisper_params_get_language(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (rwp->params.language) {
return rb_str_new2(rwp->params.language);
} else {
return rb_str_new2("auto");
}
}
static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, translate, value)
}
static VALUE ruby_whisper_params_get_translate(VALUE self) {
BOOL_PARAMS_GETTER(self, translate)
}
static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, no_context, value)
}
static VALUE ruby_whisper_params_get_no_context(VALUE self) {
BOOL_PARAMS_GETTER(self, no_context)
}
static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, single_segment, value)
}
static VALUE ruby_whisper_params_get_single_segment(VALUE self) {
BOOL_PARAMS_GETTER(self, single_segment)
}
static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_special, value)
}
static VALUE ruby_whisper_params_get_print_special(VALUE self) {
BOOL_PARAMS_GETTER(self, print_special)
}
static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_progress, value)
}
static VALUE ruby_whisper_params_get_print_progress(VALUE self) {
BOOL_PARAMS_GETTER(self, print_progress)
}
static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_realtime, value)
}
static VALUE ruby_whisper_params_get_print_realtime(VALUE self) {
BOOL_PARAMS_GETTER(self, print_realtime)
}
static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_timestamps, value)
}
static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) {
BOOL_PARAMS_GETTER(self, print_timestamps)
}
static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, suppress_blank, value)
}
static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
BOOL_PARAMS_GETTER(self, suppress_blank)
}
static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value)
}
static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) {
BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens)
}
static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) {
BOOL_PARAMS_GETTER(self, token_timestamps)
}
static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, token_timestamps, value)
}
static VALUE ruby_whisper_params_get_split_on_word(VALUE self) {
BOOL_PARAMS_GETTER(self, split_on_word)
}
static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, split_on_word, value)
}
static VALUE ruby_whisper_params_get_speed_up(VALUE self) {
BOOL_PARAMS_GETTER(self, speed_up)
}
static VALUE ruby_whisper_params_set_speed_up(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, speed_up, value)
}
static VALUE ruby_whisper_params_get_diarize(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (rwp->diarize) {
return Qtrue;
} else {
return Qfalse;
}
}
static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (value == Qfalse || value == Qnil) {
rwp->diarize = false;
} else {
rwp->diarize = true;
} \
return value;
}
static VALUE ruby_whisper_params_get_offset(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.offset_ms);
}
static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.offset_ms = NUM2INT(value);
return value;
}
static VALUE ruby_whisper_params_get_duration(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.duration_ms);
}
static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.duration_ms = NUM2INT(value);
return value;
}
static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.n_max_text_ctx);
}
static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.n_max_text_ctx = NUM2INT(value);
return value;
}
void Init_whisper() {
mWhisper = rb_define_module("Whisper");
cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
rb_define_alloc_func(cContext, ruby_whisper_allocate);
rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
rb_define_method(cParams, "language=", ruby_whisper_params_set_language, 1);
rb_define_method(cParams, "language", ruby_whisper_params_get_language, 0);
rb_define_method(cParams, "translate=", ruby_whisper_params_set_translate, 1);
rb_define_method(cParams, "translate", ruby_whisper_params_get_translate, 0);
rb_define_method(cParams, "no_context=", ruby_whisper_params_set_no_context, 1);
rb_define_method(cParams, "no_context", ruby_whisper_params_get_no_context, 0);
rb_define_method(cParams, "single_segment=", ruby_whisper_params_set_single_segment, 1);
rb_define_method(cParams, "single_segment", ruby_whisper_params_get_single_segment, 0);
rb_define_method(cParams, "print_special", ruby_whisper_params_get_print_special, 0);
rb_define_method(cParams, "print_special=", ruby_whisper_params_set_print_special, 1);
rb_define_method(cParams, "print_progress", ruby_whisper_params_get_print_progress, 0);
rb_define_method(cParams, "print_progress=", ruby_whisper_params_set_print_progress, 1);
rb_define_method(cParams, "print_realtime", ruby_whisper_params_get_print_realtime, 0);
rb_define_method(cParams, "print_realtime=", ruby_whisper_params_set_print_realtime, 1);
rb_define_method(cParams, "print_timestamps", ruby_whisper_params_get_print_timestamps, 0);
rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1);
rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0);
rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1);
rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0);
rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1);
rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0);
rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1);
rb_define_method(cParams, "speed_up", ruby_whisper_params_get_speed_up, 0);
rb_define_method(cParams, "speed_up=", ruby_whisper_params_set_speed_up, 1);
rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0);
rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1);
rb_define_method(cParams, "offset", ruby_whisper_params_get_offset, 0);
rb_define_method(cParams, "offset=", ruby_whisper_params_set_offset, 1);
rb_define_method(cParams, "duration", ruby_whisper_params_get_duration, 0);
rb_define_method(cParams, "duration=", ruby_whisper_params_set_duration, 1);
rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0);
rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1);
}
#ifdef __cplusplus
}
#endif

View File

@ -1,15 +0,0 @@
#ifndef __RUBY_WHISPER_H
#define __RUBY_WHISPER_H
#include "whisper.h"
typedef struct {
struct whisper_context *context;
} ruby_whisper;
typedef struct {
struct whisper_full_params params;
bool diarize;
} ruby_whisper_params;
#endif

View File

@ -1,138 +0,0 @@
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
EXTDIR = File.join(TOPDIR, 'ext')
#$LIBDIR = File.join(TOPDIR, 'lib')
#$:.unshift(LIBDIR)
$:.unshift(EXTDIR)
require 'whisper'
require 'test/unit'
class TestWhisper < Test::Unit::TestCase
def setup
@params = Whisper::Params.new
end
def test_language
@params.language = "en"
assert_equal @params.language, "en"
@params.language = "auto"
assert_equal @params.language, "auto"
end
def test_offset
@params.offset = 10_000
assert_equal @params.offset, 10_000
@params.offset = 0
assert_equal @params.offset, 0
end
def test_duration
@params.duration = 60_000
assert_equal @params.duration, 60_000
@params.duration = 0
assert_equal @params.duration, 0
end
def test_max_text_tokens
@params.max_text_tokens = 300
assert_equal @params.max_text_tokens, 300
@params.max_text_tokens = 0
assert_equal @params.max_text_tokens, 0
end
def test_translate
@params.translate = true
assert @params.translate
@params.translate = false
assert !@params.translate
end
def test_no_context
@params.no_context = true
assert @params.no_context
@params.no_context = false
assert !@params.no_context
end
def test_single_segment
@params.single_segment = true
assert @params.single_segment
@params.single_segment = false
assert !@params.single_segment
end
def test_print_special
@params.print_special = true
assert @params.print_special
@params.print_special = false
assert !@params.print_special
end
def test_print_progress
@params.print_progress = true
assert @params.print_progress
@params.print_progress = false
assert !@params.print_progress
end
def test_print_realtime
@params.print_realtime = true
assert @params.print_realtime
@params.print_realtime = false
assert !@params.print_realtime
end
def test_print_timestamps
@params.print_timestamps = true
assert @params.print_timestamps
@params.print_timestamps = false
assert !@params.print_timestamps
end
def test_suppress_blank
@params.suppress_blank = true
assert @params.suppress_blank
@params.suppress_blank = false
assert !@params.suppress_blank
end
def test_suppress_non_speech_tokens
@params.suppress_non_speech_tokens = true
assert @params.suppress_non_speech_tokens
@params.suppress_non_speech_tokens = false
assert !@params.suppress_non_speech_tokens
end
def test_token_timestamps
@params.token_timestamps = true
assert @params.token_timestamps
@params.token_timestamps = false
assert !@params.token_timestamps
end
def test_split_on_word
@params.split_on_word = true
assert @params.split_on_word
@params.split_on_word = false
assert !@params.split_on_word
end
def test_speed_up
@params.speed_up = true
assert @params.speed_up
@params.speed_up = false
assert !@params.speed_up
end
def test_whisper
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
params = Whisper::Params.new
params.print_timestamps = false
jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
@whisper.transcribe(jfk, params) {|text|
assert_match /ask not what your country can do for you, ask what you can do for your country/, text
}
end
end

View File

@ -14,37 +14,6 @@ if (WHISPER_SUPPORT_SDL2)
message(STATUS "SDL2_LIBRARIES = ${SDL2_LIBRARIES}") message(STATUS "SDL2_LIBRARIES = ${SDL2_LIBRARIES}")
endif() endif()
# common
set(TARGET common)
add_library(${TARGET} STATIC
common.h
common.cpp
)
include(DefaultTargetOptions)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
if (WHISPER_SUPPORT_SDL2)
# common-sdl
set(TARGET common-sdl)
add_library(${TARGET} STATIC
common-sdl.h
common-sdl.cpp
)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PUBLIC ${SDL2_INCLUDE_DIRS})
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES})
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
# examples # examples
include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_CURRENT_SOURCE_DIR})

View File

@ -23,7 +23,7 @@ string(REPLACE "\"" "" NODE_ADDON_API_DIR ${NODE_ADDON_API_DIR})
target_include_directories(${TARGET} PRIVATE ${NODE_ADDON_API_DIR}) target_include_directories(${TARGET} PRIVATE ${NODE_ADDON_API_DIR})
#================================================================== #==================================================================
target_link_libraries(${TARGET} ${CMAKE_JS_LIB} common whisper ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} ${CMAKE_JS_LIB} whisper ${CMAKE_THREAD_LIBS_INIT})
if(MSVC AND CMAKE_JS_NODELIB_DEF AND CMAKE_JS_NODELIB_TARGET) if(MSVC AND CMAKE_JS_NODELIB_DEF AND CMAKE_JS_NODELIB_TARGET)
# Generate node.lib # Generate node.lib

View File

@ -1,13 +1,15 @@
#include "napi.h" #include <cstdint>
#include "common.h"
#include "whisper.h"
#include <string> #include <string>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include <cmath> #include <cmath>
#include <cstdint>
#include "napi.h"
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include "whisper.h"
struct whisper_params { struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
@ -42,7 +44,7 @@ struct whisper_params {
std::string model = "../../ggml-large.bin"; std::string model = "../../ggml-large.bin";
std::vector<std::string> fname_inp = {}; std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_out = {}; std::vector<std::string> fname_outp = {};
}; };
struct whisper_print_user_data { struct whisper_print_user_data {
@ -72,7 +74,7 @@ int timestamp_to_sample(int64_t t, int n_samples) {
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100))); return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
} }
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
const auto & params = *((whisper_print_user_data *) user_data)->params; const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
@ -141,6 +143,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
} }
int run(whisper_params &params, std::vector<std::vector<std::string>> &result) { int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
if (params.fname_inp.empty()) { if (params.fname_inp.empty()) {
fprintf(stderr, "error: no input files specified\n"); fprintf(stderr, "error: no input files specified\n");
return 2; return 2;
@ -178,14 +181,91 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
for (int f = 0; f < (int) params.fname_inp.size(); ++f) { for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f]; const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f]; const auto fname_outp = f < (int)params.fname_outp.size() && !params.fname_outp[f].empty() ? params.fname_outp[f] : params.fname_inp[f];
std::vector<float> pcmf32; // mono-channel F32 PCM std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) { // WAV input
fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str()); {
continue; drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname_inp == "-") {
{
uint8_t buf[1024];
while (true)
{
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return 4;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return 5;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "error: WAV file '%s' must be mono or stereo\n", fname_inp.c_str());
return 6;
}
if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
fprintf(stderr, "error: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str());
return 6;
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "error: WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
return 8;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "error: WAV file '%s' must be 16-bit\n", fname_inp.c_str());
return 9;
}
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (params.diarize) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
} }
// print system information // print system information
@ -260,7 +340,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
{ {
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
bool is_aborted = *(bool*)user_data; bool is_aborted = *(bool*)user_data;
return !is_aborted; return !is_aborted;
}; };

View File

@ -0,0 +1,10 @@
if (WHISPER_SUPPORT_SDL2)
# chess
set(TARGET chess)
add_executable(${TARGET} chess.cpp)
include(DefaultTargetOptions)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
target_link_libraries(${TARGET} PRIVATE common whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif ()

634
examples/chess/chess.cpp Normal file
View File

@ -0,0 +1,634 @@
// Input chess moves via voice
//
#include "common.h"
#include "whisper.h"
#include <SDL.h>
#include <SDL_audio.h>
#include <atomic>
#include <cassert>
#include <cstdio>
#include <string>
#include <thread>
#include <vector>
#include <fstream>
#include <mutex>
// 500 -> 00:05.000
// 6000 -> 01:00.000
std::string to_timestamp(int64_t t) {
int64_t sec = t/100;
int64_t msec = t - sec*100;
int64_t min = sec/60;
sec = sec - min*60;
char buf[32];
snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec);
return std::string(buf);
}
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t step_ms = 3000;
int32_t length_ms = 10000;
int32_t keep_ms = 200;
int32_t capture_id = -1;
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
bool translate = false;
bool print_special = false;
bool no_context = true;
bool no_timestamps = false;
std::string language = "en";
std::string model = "models/ggml-base.en.bin";
std::string fname_inp;
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if ( arg == "--step") { params.step_ms = std::stoi(argv[++i]); }
else if ( arg == "--length") { params.length_ms = std::stoi(argv[++i]); }
else if ( arg == "--keep") { params.keep_ms = std::stoi(argv[++i]); }
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp = argv[++i]; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
}
return true;
}
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", params.fname_inp.c_str());
fprintf(stderr, "\n");
}
//
// SDL Audio capture
//
class audio_async {
public:
audio_async(int len_ms);
~audio_async();
bool init(int capture_id, int sample_rate);
// start capturing audio via the provided SDL callback
// keep last len_ms seconds of audio in a circular buffer
bool resume();
bool pause();
bool clear();
// callback to be called by SDL
void callback(uint8_t * stream, int len);
// get audio data from the circular buffer
void get(int ms, std::vector<float> & audio);
private:
SDL_AudioDeviceID m_dev_id_in = 0;
int m_len_ms = 0;
int m_sample_rate = 0;
std::atomic_bool m_running;
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
size_t m_audio_pos = 0;
size_t m_audio_len = 0;
};
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;
m_running = false;
}
audio_async::~audio_async() {
if (m_dev_id_in) {
SDL_CloseAudioDevice(m_dev_id_in);
}
}
bool audio_async::init(int capture_id, int sample_rate) {
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
return false;
}
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
{
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
for (int i = 0; i < nDevices; i++) {
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
}
}
SDL_AudioSpec capture_spec_requested;
SDL_AudioSpec capture_spec_obtained;
SDL_zero(capture_spec_requested);
SDL_zero(capture_spec_obtained);
capture_spec_requested.freq = sample_rate;
capture_spec_requested.format = AUDIO_F32;
capture_spec_requested.channels = 1;
capture_spec_requested.samples = 1024;
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
audio_async * audio = (audio_async *) userdata;
audio->callback(stream, len);
};
capture_spec_requested.userdata = this;
if (capture_id >= 0) {
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
} else {
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
}
if (!m_dev_id_in) {
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
m_dev_id_in = 0;
return false;
} else {
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
capture_spec_requested.format);
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
capture_spec_requested.channels);
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
}
m_sample_rate = capture_spec_obtained.freq;
m_audio.resize((m_sample_rate*m_len_ms)/1000);
return true;
}
bool audio_async::resume() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
return false;
}
if (m_running) {
fprintf(stderr, "%s: already running!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 0);
m_running = true;
return true;
}
bool audio_async::pause() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: already paused!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 1);
m_running = false;
return true;
}
bool audio_async::clear() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return false;
}
{
std::lock_guard<std::mutex> lock(m_mutex);
m_audio_pos = 0;
m_audio_len = 0;
}
return true;
}
// callback to be called by SDL
void audio_async::callback(uint8_t * stream, int len) {
if (!m_running) {
return;
}
const size_t n_samples = len / sizeof(float);
m_audio_new.resize(n_samples);
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
{
std::lock_guard<std::mutex> lock(m_mutex);
if (m_audio_pos + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - m_audio_pos;
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = m_audio.size();
} else {
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
}
}
}
void audio_async::get(int ms, std::vector<float> & result) {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
return;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return;
}
result.clear();
{
std::lock_guard<std::mutex> lock(m_mutex);
if (ms <= 0) {
ms = m_len_ms;
}
size_t n_samples = (m_sample_rate * ms) / 1000;
if (n_samples > m_audio_len) {
n_samples = m_audio_len;
}
result.resize(n_samples);
int s0 = m_audio_pos - n_samples;
if (s0 < 0) {
s0 += m_audio.size();
}
if (s0 + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - s0;
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
} else {
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
}
}
}
///////////////////////////
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
params.keep_ms = std::min(params.keep_ms, params.step_ms);
params.length_ms = std::max(params.length_ms, params.step_ms);
const int n_samples_step = (1e-3*params.step_ms )*WHISPER_SAMPLE_RATE;
const int n_samples_len = (1e-3*params.length_ms)*WHISPER_SAMPLE_RATE;
const int n_samples_keep = (1e-3*params.keep_ms )*WHISPER_SAMPLE_RATE;
const int n_samples_30s = (1e-3*30000.0 )*WHISPER_SAMPLE_RATE;
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
const int n_new_line = !use_vad ? std::max(1, params.length_ms / params.step_ms - 1) : 1; // number of steps to print new line
params.no_timestamps = !use_vad;
params.no_context |= use_vad;
params.max_tokens = 0;
// init audio
audio_async audio(params.length_ms);
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
return 1;
}
audio.resume();
// whisper init
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
std::vector<float> pcmf32 (n_samples_30s, 0.0f);
std::vector<float> pcmf32_old;
std::vector<float> pcmf32_new(n_samples_30s, 0.0f);
std::vector<whisper_token> prompt_tokens;
// print some info about the processing
{
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing %d samples (step = %.1f sec / len = %.1f sec / keep = %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
__func__,
n_samples_step,
float(n_samples_step)/WHISPER_SAMPLE_RATE,
float(n_samples_len )/WHISPER_SAMPLE_RATE,
float(n_samples_keep)/WHISPER_SAMPLE_RATE,
params.n_threads,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1);
if (!use_vad) {
fprintf(stderr, "%s: n_new_line = %d, no_context = %d\n", __func__, n_new_line, params.no_context);
} else {
fprintf(stderr, "%s: using VAD, will transcribe on speech activity\n", __func__);
}
fprintf(stderr, "\n");
}
int n_iter = 0;
bool is_running = true;
printf("[Start speaking]");
fflush(stdout);
auto t_last = std::chrono::high_resolution_clock::now();
const auto t_start = t_last;
// main audio loop
while (is_running) {
// handle Ctrl + C
{
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
is_running = false;
} break;
default:
break;
}
}
if (!is_running) {
break;
}
}
if (!is_running) {
break;
}
// process new audio
if (!use_vad) {
while (true) {
audio.get(params.step_ms, pcmf32_new);
if ((int) pcmf32_new.size() > 2*n_samples_step) {
fprintf(stderr, "\n\n%s: WARNING: cannot process audio fast enough, dropping audio ...\n\n", __func__);
audio.clear();
continue;
}
if ((int) pcmf32_new.size() >= n_samples_step) {
audio.clear();
break;
}
SDL_Delay(1);
}
const int n_samples_new = pcmf32_new.size();
// take up to params.length_ms audio from previous iteration
const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_keep + n_samples_len - n_samples_new));
//printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size());
pcmf32.resize(n_samples_new + n_samples_take);
for (int i = 0; i < n_samples_take; i++) {
pcmf32[i] = pcmf32_old[pcmf32_old.size() - n_samples_take + i];
}
memcpy(pcmf32.data() + n_samples_take, pcmf32_new.data(), n_samples_new*sizeof(float));
pcmf32_old = pcmf32;
} else {
const auto t_now = std::chrono::high_resolution_clock::now();
const auto t_diff = std::chrono::duration_cast<std::chrono::milliseconds>(t_now - t_last).count();
if (t_diff < 2000) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
continue;
}
audio.get(2000, pcmf32_new);
if (vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) {
audio.get(params.length_ms, pcmf32);
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
continue;
}
t_last = t_now;
}
// run the inference
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = !use_vad;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
// disable temperature fallback
wparams.temperature_inc = -1.0f;
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 6;
}
// print result;
{
if (!use_vad) {
printf("\33[2K\r");
// print long empty line to clear the previous line
printf("%s", std::string(100, ' ').c_str());
printf("\33[2K\r");
} else {
const int64_t t1 = (t_last - t_start).count()/1000000;
const int64_t t0 = std::max(0.0, t1 - pcmf32.size()*1000.0/WHISPER_SAMPLE_RATE);
printf("\n");
printf("### Transcription %d START | t0 = %d ms | t1 = %d ms\n", n_iter, (int) t0, (int) t1);
printf("\n");
}
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
if (params.no_timestamps) {
printf("%s", text);
fflush(stdout);
} else {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
}
}
if (use_vad){
printf("\n");
printf("### Transcription %d END\n", n_iter);
}
}
++n_iter;
if (!use_vad && (n_iter % n_new_line) == 0) {
printf("\n");
// keep part of the audio for next iteration to try to mitigate word boundary issues
pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end());
// Add tokens of the last full length segment as the prompt
if (!params.no_context) {
prompt_tokens.clear();
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const int token_count = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < token_count; ++j) {
prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j));
}
}
}
}
}
}
audio.pause();
whisper_print_timings(ctx);
whisper_free(ctx);
return 0;
}

View File

@ -11,7 +11,6 @@ add_executable(${TARGET}
include(DefaultTargetOptions) include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE target_link_libraries(${TARGET} PRIVATE
common
whisper whisper
) )

View File

@ -1,5 +1,4 @@
#include "ggml.h" #include "ggml.h"
#include "common.h"
#include "whisper.h" #include "whisper.h"
#include <emscripten.h> #include <emscripten.h>
@ -28,6 +27,24 @@ std::string g_transcribed = "";
std::vector<float> g_pcmf32; std::vector<float> g_pcmf32;
static std::string trim(const std::string & s) {
std::regex e("^\\s+|\\s+$");
return std::regex_replace(s, e, "");
}
static void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
float y = data[0];
for (size_t i = 1; i < data.size(); i++) {
y = alpha * (y + data[i] - data[i - 1]);
data[i] = y;
}
}
// compute similarity between two strings using Levenshtein distance // compute similarity between two strings using Levenshtein distance
static float similarity(const std::string & s0, const std::string & s1) { static float similarity(const std::string & s0, const std::string & s1) {
const size_t len0 = s0.size() + 1; const size_t len0 = s0.size() + 1;
@ -58,6 +75,44 @@ void command_set_status(const std::string & status) {
g_status = status; g_status = status;
} }
bool command_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;
if (n_samples_last >= n_samples) {
// not enough samples - assume no speech
return false;
}
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, freq_thold, sample_rate);
}
float energy_all = 0.0f;
float energy_last = 0.0f;
for (size_t i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
energy_last += fabsf(pcmf32[i]);
}
}
energy_all /= n_samples;
energy_last /= n_samples_last;
if (verbose) {
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
}
if (energy_last > vad_thold*energy_all) {
return false;
}
return true;
}
std::string command_transcribe(whisper_context * ctx, const whisper_full_params & wparams, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) { std::string command_transcribe(whisper_context * ctx, const whisper_full_params & wparams, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
@ -100,7 +155,7 @@ void command_get_audio(int ms, int sample_rate, std::vector<float> & audio) {
const int64_t n_samples = (ms * sample_rate) / 1000; const int64_t n_samples = (ms * sample_rate) / 1000;
int64_t n_take = 0; int64_t n_take = 0;
if (n_samples > (int) g_pcmf32.size()) { if (g_pcmf32.size() < n_samples) {
n_take = g_pcmf32.size(); n_take = g_pcmf32.size();
} else { } else {
n_take = n_samples; n_take = n_samples;
@ -132,6 +187,7 @@ void command_main(size_t index) {
printf("command: using %d threads\n", wparams.n_threads); printf("command: using %d threads\n", wparams.n_threads);
bool is_running = true;
bool have_prompt = false; bool have_prompt = false;
bool ask_prompt = true; bool ask_prompt = true;
bool print_energy = false; bool print_energy = false;
@ -177,7 +233,7 @@ void command_main(size_t index) {
{ {
command_get_audio(vad_ms, WHISPER_SAMPLE_RATE, pcmf32_cur); command_get_audio(vad_ms, WHISPER_SAMPLE_RATE, pcmf32_cur);
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) { if (command_vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
command_set_status("Speech detected! Processing ..."); command_set_status("Speech detected! Processing ...");

View File

@ -5,5 +5,6 @@ if (WHISPER_SUPPORT_SDL2)
include(DefaultTargetOptions) include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT}) target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif () endif ()

View File

@ -6,10 +6,11 @@
// ref: https://github.com/ggerganov/whisper.cpp/issues/171 // ref: https://github.com/ggerganov/whisper.cpp/issues/171
// //
#include "common.h"
#include "common-sdl.h"
#include "whisper.h" #include "whisper.h"
#include <SDL.h>
#include <SDL_audio.h>
#include <sstream> #include <sstream>
#include <cassert> #include <cassert>
#include <cstdio> #include <cstdio>
@ -109,6 +110,309 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
//
// SDL Audio capture
//
class audio_async {
public:
audio_async(int len_ms);
~audio_async();
bool init(int capture_id, int sample_rate);
// start capturing audio via the provided SDL callback
// keep last len_ms seconds of audio in a circular buffer
bool resume();
bool pause();
bool clear();
// callback to be called by SDL
void callback(uint8_t * stream, int len);
// get audio data from the circular buffer
void get(int ms, std::vector<float> & audio);
private:
SDL_AudioDeviceID m_dev_id_in = 0;
int m_len_ms = 0;
int m_sample_rate = 0;
bool m_running = false;
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
size_t m_audio_pos = 0;
size_t m_audio_len = 0;
};
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;
}
audio_async::~audio_async() {
if (m_dev_id_in) {
SDL_CloseAudioDevice(m_dev_id_in);
}
}
bool audio_async::init(int capture_id, int sample_rate) {
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
return false;
}
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
{
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
for (int i = 0; i < nDevices; i++) {
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
}
}
SDL_AudioSpec capture_spec_requested;
SDL_AudioSpec capture_spec_obtained;
SDL_zero(capture_spec_requested);
SDL_zero(capture_spec_obtained);
capture_spec_requested.freq = sample_rate;
capture_spec_requested.format = AUDIO_F32;
capture_spec_requested.channels = 1;
capture_spec_requested.samples = 1024;
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
audio_async * audio = (audio_async *) userdata;
audio->callback(stream, len);
};
capture_spec_requested.userdata = this;
if (capture_id >= 0) {
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
} else {
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
}
if (!m_dev_id_in) {
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
m_dev_id_in = 0;
return false;
} else {
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
capture_spec_requested.format);
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
capture_spec_requested.channels);
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
}
m_sample_rate = capture_spec_obtained.freq;
m_audio.resize((m_sample_rate*m_len_ms)/1000);
return true;
}
bool audio_async::resume() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
return false;
}
if (m_running) {
fprintf(stderr, "%s: already running!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 0);
m_running = true;
return true;
}
bool audio_async::pause() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: already paused!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 1);
m_running = false;
return true;
}
bool audio_async::clear() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return false;
}
{
std::lock_guard<std::mutex> lock(m_mutex);
m_audio_pos = 0;
m_audio_len = 0;
}
return true;
}
// callback to be called by SDL
void audio_async::callback(uint8_t * stream, int len) {
if (!m_running) {
return;
}
const size_t n_samples = len / sizeof(float);
m_audio_new.resize(n_samples);
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
{
std::lock_guard<std::mutex> lock(m_mutex);
if (m_audio_pos + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - m_audio_pos;
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = m_audio.size();
} else {
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
}
}
}
void audio_async::get(int ms, std::vector<float> & result) {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
return;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return;
}
result.clear();
{
std::lock_guard<std::mutex> lock(m_mutex);
if (ms <= 0) {
ms = m_len_ms;
}
size_t n_samples = (m_sample_rate * ms) / 1000;
if (n_samples > m_audio_len) {
n_samples = m_audio_len;
}
result.resize(n_samples);
int s0 = m_audio_pos - n_samples;
if (s0 < 0) {
s0 += m_audio.size();
}
if (s0 + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - s0;
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
} else {
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
}
}
}
///////////////////////////
std::string trim(const std::string & s) {
std::regex e("^\\s+|\\s+$");
return std::regex_replace(s, e, "");
}
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
float y = data[0];
for (size_t i = 1; i < data.size(); i++) {
y = alpha * (y + data[i] - data[i - 1]);
data[i] = y;
}
}
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;
if (n_samples_last >= n_samples) {
// not enough samples - assume no speech
return false;
}
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, freq_thold, sample_rate);
}
float energy_all = 0.0f;
float energy_last = 0.0f;
for (int i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
energy_last += fabsf(pcmf32[i]);
}
}
energy_all /= n_samples;
energy_last /= n_samples_last;
if (verbose) {
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
}
if (energy_last > vad_thold*energy_all) {
return false;
}
return true;
}
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) { std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
@ -198,7 +502,7 @@ std::vector<std::string> read_allowed_commands(const std::string & fname) {
std::string line; std::string line;
while (std::getline(ifs, line)) { while (std::getline(ifs, line)) {
line = ::trim(line); line = trim(line);
if (line.empty()) { if (line.empty()) {
continue; continue;
} }
@ -222,6 +526,23 @@ std::vector<std::string> get_words(const std::string &txt) {
return words; return words;
} }
// returns true if no exit event was received
bool process_sdl_events() {
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
return false;
} break;
default:
break;
}
}
return true;
}
// command-list mode // command-list mode
// guide the transcription to match the most likely command from a provided list // guide the transcription to match the most likely command from a provided list
int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) { int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
@ -313,14 +634,14 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
// main loop // main loop
while (is_running) { while (is_running) {
// handle Ctrl + C // handle Ctrl + C
is_running = sdl_poll_events(); is_running = process_sdl_events();
// delay // delay
std::this_thread::sleep_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100));
audio.get(2000, pcmf32_cur); audio.get(2000, pcmf32_cur);
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) { if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
@ -454,7 +775,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
// main loop // main loop
while (is_running) { while (is_running) {
// handle Ctrl + C // handle Ctrl + C
is_running = sdl_poll_events(); is_running = process_sdl_events();
// delay // delay
std::this_thread::sleep_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100));
@ -470,7 +791,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
{ {
audio.get(2000, pcmf32_cur); audio.get(2000, pcmf32_cur);
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) { if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
int64_t t_ms = 0; int64_t t_ms = 0;
@ -533,7 +854,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
// main loop // main loop
while (is_running) { while (is_running) {
// handle Ctrl + C // handle Ctrl + C
is_running = sdl_poll_events(); is_running = process_sdl_events();
// delay // delay
std::this_thread::sleep_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100));
@ -549,7 +870,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
{ {
audio.get(2000, pcmf32_cur); audio.get(2000, pcmf32_cur);
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) { if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
int64_t t_ms = 0; int64_t t_ms = 0;

View File

@ -1,226 +0,0 @@
#include "common-sdl.h"
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;
m_running = false;
}
audio_async::~audio_async() {
if (m_dev_id_in) {
SDL_CloseAudioDevice(m_dev_id_in);
}
}
bool audio_async::init(int capture_id, int sample_rate) {
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
return false;
}
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
{
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
for (int i = 0; i < nDevices; i++) {
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
}
}
SDL_AudioSpec capture_spec_requested;
SDL_AudioSpec capture_spec_obtained;
SDL_zero(capture_spec_requested);
SDL_zero(capture_spec_obtained);
capture_spec_requested.freq = sample_rate;
capture_spec_requested.format = AUDIO_F32;
capture_spec_requested.channels = 1;
capture_spec_requested.samples = 1024;
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
audio_async * audio = (audio_async *) userdata;
audio->callback(stream, len);
};
capture_spec_requested.userdata = this;
if (capture_id >= 0) {
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
} else {
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
}
if (!m_dev_id_in) {
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
m_dev_id_in = 0;
return false;
} else {
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
capture_spec_requested.format);
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
capture_spec_requested.channels);
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
}
m_sample_rate = capture_spec_obtained.freq;
m_audio.resize((m_sample_rate*m_len_ms)/1000);
return true;
}
bool audio_async::resume() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
return false;
}
if (m_running) {
fprintf(stderr, "%s: already running!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 0);
m_running = true;
return true;
}
bool audio_async::pause() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: already paused!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 1);
m_running = false;
return true;
}
bool audio_async::clear() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return false;
}
{
std::lock_guard<std::mutex> lock(m_mutex);
m_audio_pos = 0;
m_audio_len = 0;
}
return true;
}
// callback to be called by SDL
void audio_async::callback(uint8_t * stream, int len) {
if (!m_running) {
return;
}
const size_t n_samples = len / sizeof(float);
m_audio_new.resize(n_samples);
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
{
std::lock_guard<std::mutex> lock(m_mutex);
if (m_audio_pos + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - m_audio_pos;
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = m_audio.size();
} else {
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
}
}
}
void audio_async::get(int ms, std::vector<float> & result) {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
return;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return;
}
result.clear();
{
std::lock_guard<std::mutex> lock(m_mutex);
if (ms <= 0) {
ms = m_len_ms;
}
size_t n_samples = (m_sample_rate * ms) / 1000;
if (n_samples > m_audio_len) {
n_samples = m_audio_len;
}
result.resize(n_samples);
int s0 = m_audio_pos - n_samples;
if (s0 < 0) {
s0 += m_audio.size();
}
if (s0 + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - s0;
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
} else {
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
}
}
}
bool sdl_poll_events() {
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
return false;
} break;
default:
break;
}
}
return true;
}

View File

@ -1,50 +0,0 @@
#pragma once
#include <SDL.h>
#include <SDL_audio.h>
#include <atomic>
#include <cstdint>
#include <vector>
#include <mutex>
//
// SDL Audio capture
//
class audio_async {
public:
audio_async(int len_ms);
~audio_async();
bool init(int capture_id, int sample_rate);
// start capturing audio via the provided SDL callback
// keep last len_ms seconds of audio in a circular buffer
bool resume();
bool pause();
bool clear();
// callback to be called by SDL
void callback(uint8_t * stream, int len);
// get audio data from the circular buffer
void get(int ms, std::vector<float> & audio);
private:
SDL_AudioDeviceID m_dev_id_in = 0;
int m_len_ms = 0;
int m_sample_rate = 0;
std::atomic_bool m_running;
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
size_t m_audio_pos = 0;
size_t m_audio_len = 0;
};
// Return false if need to quit
bool sdl_poll_events();

View File

@ -1,162 +0,0 @@
#include "common.h"
// third-party utilities
// use your favorite implementations
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include <cmath>
#include <regex>
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
std::string trim(const std::string & s) {
std::regex e("^\\s+|\\s+$");
return std::regex_replace(s, e, "");
}
std::string replace(const std::string & s, const std::string & from, const std::string & to) {
std::string result = s;
size_t pos = 0;
while ((pos = result.find(from, pos)) != std::string::npos) {
result.replace(pos, from.length(), to);
pos += to.length();
}
return result;
}
bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname == "-") {
{
uint8_t buf[1024];
while (true)
{
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return false;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
return false;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str());
return false;
}
if (stereo && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str());
return false;
}
if (wav.sampleRate != COMMON_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE/1000);
return false;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str());
return false;
}
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (stereo) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
return true;
}
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
float y = data[0];
for (size_t i = 1; i < data.size(); i++) {
y = alpha * (y + data[i] - data[i - 1]);
data[i] = y;
}
}
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;
if (n_samples_last >= n_samples) {
// not enough samples - assume no speech
return false;
}
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, freq_thold, sample_rate);
}
float energy_all = 0.0f;
float energy_last = 0.0f;
for (int i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
energy_last += fabsf(pcmf32[i]);
}
}
energy_all /= n_samples;
energy_last /= n_samples_last;
if (verbose) {
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
}
if (energy_last > vad_thold*energy_all) {
return false;
}
return true;
}

View File

@ -1,40 +0,0 @@
#pragma once
// needs to match WHISPER_SAMPLE_RATE
#define COMMON_SAMPLE_RATE 16000
#include <vector>
#include <string>
std::string trim(const std::string & s);
std::string replace(
const std::string & s,
const std::string & from,
const std::string & to);
// Read WAV audio file and store the PCM data into pcmf32
// The sample rate of the audio must be equal to COMMON_SAMPLE_RATE
// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM
bool read_wav(
const std::string & fname,
std::vector<float> & pcmf32,
std::vector<std::vector<float>> & pcmf32s,
bool stereo);
// Apply a high-pass frequency filter to PCM audio
// Suppresses frequencies below cutoff Hz
void high_pass_filter(
std::vector<float> & data,
float cutoff,
float sample_rate);
// Basic voice activity detection (VAD) using audio energy adaptive threshold
bool vad_simple(
std::vector<float> & pcmf32,
int sample_rate,
int last_ms,
float vad_thold,
float freq_thold,
bool verbose);

View File

@ -3,4 +3,4 @@ add_executable(${TARGET} main.cpp)
include(DefaultTargetOptions) include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})

View File

@ -1,7 +1,10 @@
#include "common.h"
#include "whisper.h" #include "whisper.h"
// third-party utilities
// use your favorite implementations
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include <cmath> #include <cmath>
#include <fstream> #include <fstream>
#include <cstdio> #include <cstdio>
@ -80,11 +83,10 @@ struct whisper_params {
std::string language = "en"; std::string language = "en";
std::string prompt; std::string prompt;
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
std::string model = "models/ggml-base.en.bin"; std::string model = "models/ggml-base.en.bin";
std::vector<std::string> fname_inp = {}; std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_out = {}; std::vector<std::string> fname_outp = {};
}; };
void whisper_print_usage(int argc, char ** argv, const whisper_params & params); void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -93,11 +95,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) { for (int i = 1; i < argc; i++) {
std::string arg = argv[i]; std::string arg = argv[i];
if (arg == "-"){
params.fname_inp.push_back(arg);
continue;
}
if (arg[0] != '-') { if (arg[0] != '-') {
params.fname_inp.push_back(arg); params.fname_inp.push_back(arg);
continue; continue;
@ -128,9 +125,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); } else if (arg == "-of" || arg == "--output-file") { params.fname_outp.emplace_back(argv[++i]); }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
@ -176,7 +172,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str());
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", ""); fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", "");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
@ -196,7 +191,7 @@ struct whisper_print_user_data {
const std::vector<std::vector<float>> * pcmf32s; const std::vector<std::vector<float>> * pcmf32s;
}; };
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) { void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
const auto & params = *((whisper_print_user_data *) user_data)->params; const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
@ -355,7 +350,6 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
fout << "start,end,text\n";
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
@ -371,18 +365,13 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
// karaoke video generation // karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles // outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments // TODO: font parameter adjustments
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) { bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) {
std::ofstream fout(fname); std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
static const char * font = params.font_path.c_str(); // TODO: become parameter
static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
std::ifstream fin(font);
if (!fin.is_open()) {
fprintf(stderr, "%s: font not found at '%s', please specify a monospace font with -fp\n", __func__, font);
return false;
}
fout << "#!/bin/bash" << "\n"; fout << "#!/bin/bash" << "\n";
fout << "\n"; fout << "\n";
@ -531,14 +520,91 @@ int main(int argc, char ** argv) {
for (int f = 0; f < (int) params.fname_inp.size(); ++f) { for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f]; const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f]; const auto fname_outp = f < (int) params.fname_outp.size() && !params.fname_outp[f].empty() ? params.fname_outp[f] : params.fname_inp[f];
std::vector<float> pcmf32; // mono-channel F32 PCM std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) { // WAV input
fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str()); {
continue; drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname_inp == "-") {
{
uint8_t buf[1024];
while (true)
{
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return 4;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return 5;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
return 6;
}
if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
return 6;
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
return 8;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
return 9;
}
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (params.diarize) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
} }
// print system information // print system information
@ -616,7 +682,7 @@ int main(int argc, char ** argv) {
{ {
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
bool is_aborted = *(bool*)user_data; bool is_aborted = *(bool*)user_data;
return !is_aborted; return !is_aborted;
}; };
@ -635,33 +701,34 @@ int main(int argc, char ** argv) {
// output to text file // output to text file
if (params.output_txt) { if (params.output_txt) {
const auto fname_txt = fname_out + ".txt"; const auto fname_txt = fname_outp + ".txt";
output_txt(ctx, fname_txt.c_str()); output_txt(ctx, fname_txt.c_str());
} }
// output to VTT file // output to VTT file
if (params.output_vtt) { if (params.output_vtt) {
const auto fname_vtt = fname_out + ".vtt"; const auto fname_vtt = fname_outp + ".vtt";
output_vtt(ctx, fname_vtt.c_str()); output_vtt(ctx, fname_vtt.c_str());
} }
// output to SRT file // output to SRT file
if (params.output_srt) { if (params.output_srt) {
const auto fname_srt = fname_out + ".srt"; const auto fname_srt = fname_outp + ".srt";
output_srt(ctx, fname_srt.c_str(), params); output_srt(ctx, fname_srt.c_str(), params);
} }
// output to WTS file // output to WTS file
if (params.output_wts) { if (params.output_wts) {
const auto fname_wts = fname_out + ".wts"; const auto fname_wts = fname_outp + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE); output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
} }
// output to CSV file // output to CSV file
if (params.output_csv) { if (params.output_csv) {
const auto fname_csv = fname_out + ".csv"; const auto fname_csv = fname_outp + ".csv";
output_csv(ctx, fname_csv.c_str()); output_csv(ctx, fname_csv.c_str());
} }
} }
} }

View File

@ -5,5 +5,6 @@ if (WHISPER_SUPPORT_SDL2)
include(DefaultTargetOptions) include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT}) target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif () endif ()

View File

@ -3,16 +3,19 @@
// A very quick-n-dirty implementation serving mainly as a proof of concept. // A very quick-n-dirty implementation serving mainly as a proof of concept.
// //
#include "common.h"
#include "common-sdl.h"
#include "whisper.h" #include "whisper.h"
#include <SDL.h>
#include <SDL_audio.h>
#include <atomic>
#include <cassert> #include <cassert>
#include <cstdio> #include <cstdio>
#include <string> #include <string>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include <fstream> #include <fstream>
#include <mutex>
// 500 -> 00:05.000 // 500 -> 00:05.000
// 6000 -> 01:00.000 // 6000 -> 01:00.000
@ -113,6 +116,306 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
//
// SDL Audio capture
//
class audio_async {
public:
audio_async(int len_ms);
~audio_async();
bool init(int capture_id, int sample_rate);
// start capturing audio via the provided SDL callback
// keep last len_ms seconds of audio in a circular buffer
bool resume();
bool pause();
bool clear();
// callback to be called by SDL
void callback(uint8_t * stream, int len);
// get audio data from the circular buffer
void get(int ms, std::vector<float> & audio);
private:
SDL_AudioDeviceID m_dev_id_in = 0;
int m_len_ms = 0;
int m_sample_rate = 0;
std::atomic_bool m_running;
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
size_t m_audio_pos = 0;
size_t m_audio_len = 0;
};
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;
m_running = false;
}
audio_async::~audio_async() {
if (m_dev_id_in) {
SDL_CloseAudioDevice(m_dev_id_in);
}
}
bool audio_async::init(int capture_id, int sample_rate) {
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
return false;
}
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
{
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
for (int i = 0; i < nDevices; i++) {
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
}
}
SDL_AudioSpec capture_spec_requested;
SDL_AudioSpec capture_spec_obtained;
SDL_zero(capture_spec_requested);
SDL_zero(capture_spec_obtained);
capture_spec_requested.freq = sample_rate;
capture_spec_requested.format = AUDIO_F32;
capture_spec_requested.channels = 1;
capture_spec_requested.samples = 1024;
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
audio_async * audio = (audio_async *) userdata;
audio->callback(stream, len);
};
capture_spec_requested.userdata = this;
if (capture_id >= 0) {
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
} else {
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
}
if (!m_dev_id_in) {
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
m_dev_id_in = 0;
return false;
} else {
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
capture_spec_requested.format);
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
capture_spec_requested.channels);
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
}
m_sample_rate = capture_spec_obtained.freq;
m_audio.resize((m_sample_rate*m_len_ms)/1000);
return true;
}
bool audio_async::resume() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
return false;
}
if (m_running) {
fprintf(stderr, "%s: already running!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 0);
m_running = true;
return true;
}
bool audio_async::pause() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: already paused!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 1);
m_running = false;
return true;
}
bool audio_async::clear() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return false;
}
{
std::lock_guard<std::mutex> lock(m_mutex);
m_audio_pos = 0;
m_audio_len = 0;
}
return true;
}
// callback to be called by SDL
void audio_async::callback(uint8_t * stream, int len) {
if (!m_running) {
return;
}
const size_t n_samples = len / sizeof(float);
m_audio_new.resize(n_samples);
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
{
std::lock_guard<std::mutex> lock(m_mutex);
if (m_audio_pos + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - m_audio_pos;
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = m_audio.size();
} else {
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
}
}
}
void audio_async::get(int ms, std::vector<float> & result) {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
return;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return;
}
result.clear();
{
std::lock_guard<std::mutex> lock(m_mutex);
if (ms <= 0) {
ms = m_len_ms;
}
size_t n_samples = (m_sample_rate * ms) / 1000;
if (n_samples > m_audio_len) {
n_samples = m_audio_len;
}
result.resize(n_samples);
int s0 = m_audio_pos - n_samples;
if (s0 < 0) {
s0 += m_audio.size();
}
if (s0 + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - s0;
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
} else {
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
}
}
}
///////////////////////////
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
float y = data[0];
for (size_t i = 1; i < data.size(); i++) {
y = alpha * (y + data[i] - data[i - 1]);
data[i] = y;
}
}
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;
if (n_samples_last >= n_samples) {
// not enough samples - assume no speech
return false;
}
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, freq_thold, sample_rate);
}
float energy_all = 0.0f;
float energy_last = 0.0f;
for (int i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
energy_last += fabsf(pcmf32[i]);
}
}
energy_all /= n_samples;
energy_last /= n_samples_last;
if (verbose) {
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
}
if (energy_last > vad_thold*energy_all) {
return false;
}
return true;
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
whisper_params params; whisper_params params;
@ -123,10 +426,10 @@ int main(int argc, char ** argv) {
params.keep_ms = std::min(params.keep_ms, params.step_ms); params.keep_ms = std::min(params.keep_ms, params.step_ms);
params.length_ms = std::max(params.length_ms, params.step_ms); params.length_ms = std::max(params.length_ms, params.step_ms);
const int n_samples_step = (1e-3*params.step_ms )*WHISPER_SAMPLE_RATE; const int n_samples_step = (params.step_ms *1e-3)*WHISPER_SAMPLE_RATE;
const int n_samples_len = (1e-3*params.length_ms)*WHISPER_SAMPLE_RATE; const int n_samples_len = (params.length_ms*1e-3)*WHISPER_SAMPLE_RATE;
const int n_samples_keep = (1e-3*params.keep_ms )*WHISPER_SAMPLE_RATE; const int n_samples_keep = (params.keep_ms *1e-3)*WHISPER_SAMPLE_RATE;
const int n_samples_30s = (1e-3*30000.0 )*WHISPER_SAMPLE_RATE; const int n_samples_30s = (30000 *1e-3)*WHISPER_SAMPLE_RATE;
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
@ -214,7 +517,23 @@ int main(int argc, char ** argv) {
// main audio loop // main audio loop
while (is_running) { while (is_running) {
// handle Ctrl + C // handle Ctrl + C
is_running = sdl_poll_events(); {
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
is_running = false;
} break;
default:
break;
}
}
if (!is_running) {
break;
}
}
if (!is_running) { if (!is_running) {
break; break;
@ -237,7 +556,7 @@ int main(int argc, char ** argv) {
break; break;
} }
std::this_thread::sleep_for(std::chrono::milliseconds(1)); SDL_Delay(1);
} }
const int n_samples_new = pcmf32_new.size(); const int n_samples_new = pcmf32_new.size();
@ -268,7 +587,7 @@ int main(int argc, char ** argv) {
audio.get(2000, pcmf32_new); audio.get(2000, pcmf32_new);
if (::vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) { if (vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) {
audio.get(params.length_ms, pcmf32); audio.get(params.length_ms, pcmf32);
} else { } else {
std::this_thread::sleep_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100));
@ -288,6 +607,7 @@ int main(int argc, char ** argv) {
wparams.print_realtime = false; wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps; wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate; wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = !use_vad; wparams.single_segment = !use_vad;
wparams.max_tokens = params.max_tokens; wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str(); wparams.language = params.language.c_str();

View File

@ -7,7 +7,7 @@ if (WHISPER_SUPPORT_SDL2)
# TODO: this is temporary # TODO: this is temporary
# need to export ggml symbols for MSVC, but too lazy .. # need to export ggml symbols for MSVC, but too lazy ..
add_executable(${TARGET} talk.cpp gpt-2.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../whisper.cpp) add_executable(${TARGET} talk.cpp gpt-2.cpp ../../ggml.c ../../whisper.cpp)
include(DefaultTargetOptions) include(DefaultTargetOptions)

View File

@ -1,14 +1,16 @@
// Talk with AI // Talk with AI
// //
#include "common.h"
#include "common-sdl.h"
#include "whisper.h" #include "whisper.h"
#include "gpt-2.h" #include "gpt-2.h"
#include <SDL.h>
#include <SDL_audio.h>
#include <cassert> #include <cassert>
#include <cstdio> #include <cstdio>
#include <fstream> #include <fstream>
#include <mutex>
#include <regex> #include <regex>
#include <string> #include <string>
#include <thread> #include <thread>
@ -103,6 +105,320 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
//
// SDL Audio capture
//
class audio_async {
public:
audio_async(int len_ms);
~audio_async();
bool init(int capture_id, int sample_rate);
// start capturing audio via the provided SDL callback
// keep last len_ms seconds of audio in a circular buffer
bool resume();
bool pause();
bool clear();
// callback to be called by SDL
void callback(uint8_t * stream, int len);
// get audio data from the circular buffer
void get(int ms, std::vector<float> & audio);
private:
SDL_AudioDeviceID m_dev_id_in = 0;
int m_len_ms = 0;
int m_sample_rate = 0;
bool m_running = false;
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
size_t m_audio_pos = 0;
size_t m_audio_len = 0;
};
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;
}
audio_async::~audio_async() {
if (m_dev_id_in) {
SDL_CloseAudioDevice(m_dev_id_in);
}
}
bool audio_async::init(int capture_id, int sample_rate) {
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
return false;
}
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
{
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
for (int i = 0; i < nDevices; i++) {
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
}
}
SDL_AudioSpec capture_spec_requested;
SDL_AudioSpec capture_spec_obtained;
SDL_zero(capture_spec_requested);
SDL_zero(capture_spec_obtained);
capture_spec_requested.freq = sample_rate;
capture_spec_requested.format = AUDIO_F32;
capture_spec_requested.channels = 1;
capture_spec_requested.samples = 1024;
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
audio_async * audio = (audio_async *) userdata;
audio->callback(stream, len);
};
capture_spec_requested.userdata = this;
if (capture_id >= 0) {
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
} else {
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
}
if (!m_dev_id_in) {
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
m_dev_id_in = 0;
return false;
} else {
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
capture_spec_requested.format);
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
capture_spec_requested.channels);
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
fprintf(stderr, "\n");
}
m_sample_rate = capture_spec_obtained.freq;
m_audio.resize((m_sample_rate*m_len_ms)/1000);
return true;
}
bool audio_async::resume() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
return false;
}
if (m_running) {
fprintf(stderr, "%s: already running!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 0);
m_running = true;
return true;
}
bool audio_async::pause() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: already paused!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 1);
m_running = false;
return true;
}
bool audio_async::clear() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return false;
}
{
std::lock_guard<std::mutex> lock(m_mutex);
m_audio_pos = 0;
m_audio_len = 0;
}
return true;
}
// callback to be called by SDL
void audio_async::callback(uint8_t * stream, int len) {
if (!m_running) {
return;
}
const size_t n_samples = len / sizeof(float);
m_audio_new.resize(n_samples);
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
{
std::lock_guard<std::mutex> lock(m_mutex);
if (m_audio_pos + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - m_audio_pos;
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = m_audio.size();
} else {
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
}
}
}
void audio_async::get(int ms, std::vector<float> & result) {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
return;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return;
}
result.clear();
{
std::lock_guard<std::mutex> lock(m_mutex);
if (ms <= 0) {
ms = m_len_ms;
}
size_t n_samples = (m_sample_rate * ms) / 1000;
if (n_samples > m_audio_len) {
n_samples = m_audio_len;
}
result.resize(n_samples);
int s0 = m_audio_pos - n_samples;
if (s0 < 0) {
s0 += m_audio.size();
}
if (s0 + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - s0;
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
} else {
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
}
}
}
///////////////////////////
std::string trim(const std::string & s) {
std::regex e("^\\s+|\\s+$");
return std::regex_replace(s, e, "");
}
std::string replace(const std::string & s, const std::string & from, const std::string & to) {
std::string result = s;
size_t pos = 0;
while ((pos = result.find(from, pos)) != std::string::npos) {
result.replace(pos, from.length(), to);
pos += to.length();
}
return result;
}
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
float y = data[0];
for (size_t i = 1; i < data.size(); i++) {
y = alpha * (y + data[i] - data[i - 1]);
data[i] = y;
}
}
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;
if (n_samples_last >= n_samples) {
// not enough samples - assume no speech
return false;
}
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, freq_thold, sample_rate);
}
float energy_all = 0.0f;
float energy_last = 0.0f;
for (int i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
energy_last += fabsf(pcmf32[i]);
}
}
energy_all /= n_samples;
energy_last /= n_samples_last;
if (verbose) {
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
}
if (energy_last > vad_thold*energy_all) {
return false;
}
return true;
}
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) { std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
@ -241,11 +557,23 @@ int main(int argc, char ** argv) {
// main loop // main loop
while (is_running) { while (is_running) {
// handle Ctrl + C // handle Ctrl + C
is_running = sdl_poll_events(); {
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
is_running = false;
} break;
default:
break;
}
}
if (!is_running) { if (!is_running) {
break; break;
} }
}
// delay // delay
std::this_thread::sleep_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100));
@ -255,7 +583,7 @@ int main(int argc, char ** argv) {
{ {
audio.get(2000, pcmf32_cur); audio.get(2000, pcmf32_cur);
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) { if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
audio.get(params.voice_ms, pcmf32_cur); audio.get(params.voice_ms, pcmf32_cur);

View File

@ -9,4 +9,4 @@ To use:
5. Select the "release" active build variant, and use Android Studio to run and deploy to your device. 5. Select the "release" active build variant, and use Android Studio to run and deploy to your device.
[^1]: I recommend the tiny or base models for running on an Android device. [^1]: I recommend the tiny or base models for running on an Android device.
<img width="300" alt="image" src="https://user-images.githubusercontent.com/1670775/221613663-a17bf770-27ef-45ab-9a46-a5f99ba65d2a.jpg"> <img width="300" alt="image" src="https://user-images.githubusercontent.com/1991296/208154256-82d972dc-221b-48c4-bfcb-36ce68602f93.png">

View File

@ -2,7 +2,6 @@ package com.whispercppdemo.ui.main
import androidx.compose.foundation.layout.* import androidx.compose.foundation.layout.*
import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.text.selection.SelectionContainer
import androidx.compose.foundation.verticalScroll import androidx.compose.foundation.verticalScroll
import androidx.compose.material3.* import androidx.compose.material3.*
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
@ -20,7 +19,6 @@ fun MainScreen(viewModel: MainScreenViewModel) {
canTranscribe = viewModel.canTranscribe, canTranscribe = viewModel.canTranscribe,
isRecording = viewModel.isRecording, isRecording = viewModel.isRecording,
messageLog = viewModel.dataLog, messageLog = viewModel.dataLog,
onBenchmarkTapped = viewModel::benchmark,
onTranscribeSampleTapped = viewModel::transcribeSample, onTranscribeSampleTapped = viewModel::transcribeSample,
onRecordTapped = viewModel::toggleRecord onRecordTapped = viewModel::toggleRecord
) )
@ -32,7 +30,6 @@ private fun MainScreen(
canTranscribe: Boolean, canTranscribe: Boolean,
isRecording: Boolean, isRecording: Boolean,
messageLog: String, messageLog: String,
onBenchmarkTapped: () -> Unit,
onTranscribeSampleTapped: () -> Unit, onTranscribeSampleTapped: () -> Unit,
onRecordTapped: () -> Unit onRecordTapped: () -> Unit
) { ) {
@ -48,11 +45,8 @@ private fun MainScreen(
.padding(innerPadding) .padding(innerPadding)
.padding(16.dp) .padding(16.dp)
) { ) {
Column(verticalArrangement = Arrangement.SpaceBetween) { Row(horizontalArrangement = Arrangement.SpaceBetween) {
Row(horizontalArrangement = Arrangement.SpaceBetween, modifier = Modifier.fillMaxWidth()) {
BenchmarkButton(enabled = canTranscribe, onClick = onBenchmarkTapped)
TranscribeSampleButton(enabled = canTranscribe, onClick = onTranscribeSampleTapped) TranscribeSampleButton(enabled = canTranscribe, onClick = onTranscribeSampleTapped)
}
RecordButton( RecordButton(
enabled = canTranscribe, enabled = canTranscribe,
isRecording = isRecording, isRecording = isRecording,
@ -66,17 +60,8 @@ private fun MainScreen(
@Composable @Composable
private fun MessageLog(log: String) { private fun MessageLog(log: String) {
SelectionContainer() {
Text(modifier = Modifier.verticalScroll(rememberScrollState()), text = log) Text(modifier = Modifier.verticalScroll(rememberScrollState()), text = log)
} }
}
@Composable
private fun BenchmarkButton(enabled: Boolean, onClick: () -> Unit) {
Button(onClick = onClick, enabled = enabled) {
Text("Benchmark")
}
}
@Composable @Composable
private fun TranscribeSampleButton(enabled: Boolean, onClick: () -> Unit) { private fun TranscribeSampleButton(enabled: Boolean, onClick: () -> Unit) {

View File

@ -41,15 +41,10 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
init { init {
viewModelScope.launch { viewModelScope.launch {
printSystemInfo()
loadData() loadData()
} }
} }
private suspend fun printSystemInfo() {
printMessage(String.format("System Info: %s\n", WhisperContext.getSystemInfo()));
}
private suspend fun loadData() { private suspend fun loadData() {
printMessage("Loading data...\n") printMessage("Loading data...\n")
try { try {
@ -86,29 +81,10 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
//whisperContext = WhisperContext.createContextFromFile(firstModel.absolutePath) //whisperContext = WhisperContext.createContextFromFile(firstModel.absolutePath)
} }
fun benchmark() = viewModelScope.launch {
runBenchmark(6)
}
fun transcribeSample() = viewModelScope.launch { fun transcribeSample() = viewModelScope.launch {
transcribeAudio(getFirstSample()) transcribeAudio(getFirstSample())
} }
private suspend fun runBenchmark(nthreads: Int) {
if (!canTranscribe) {
return
}
canTranscribe = false
printMessage("Running benchmark. This will take minutes...\n")
whisperContext?.benchMemory(nthreads)?.let{ printMessage(it) }
printMessage("\n")
whisperContext?.benchGgmlMulMat(nthreads)?.let{ printMessage(it) }
canTranscribe = true
}
private suspend fun getFirstSample(): File = withContext(Dispatchers.IO) { private suspend fun getFirstSample(): File = withContext(Dispatchers.IO) {
samplesPath.listFiles()!!.first() samplesPath.listFiles()!!.first()
} }
@ -138,14 +114,11 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
canTranscribe = false canTranscribe = false
try { try {
printMessage("Reading wave samples... ") printMessage("Reading wave samples...\n")
val data = readAudioSamples(file) val data = readAudioSamples(file)
printMessage("${data.size / (16000 / 1000)} ms\n")
printMessage("Transcribing data...\n") printMessage("Transcribing data...\n")
val start = System.currentTimeMillis()
val text = whisperContext?.transcribeData(data) val text = whisperContext?.transcribeData(data)
val elapsed = System.currentTimeMillis() - start printMessage("Done: $text\n")
printMessage("Done ($elapsed ms): $text\n")
} catch (e: Exception) { } catch (e: Exception) {
Log.w(LOG_TAG, e) Log.w(LOG_TAG, e)
printMessage("${e.localizedMessage}\n") printMessage("${e.localizedMessage}\n")

View File

@ -27,14 +27,6 @@ class WhisperContext private constructor(private var ptr: Long) {
} }
} }
suspend fun benchMemory(nthreads: Int): String = withContext(scope.coroutineContext) {
return@withContext WhisperLib.benchMemcpy(nthreads)
}
suspend fun benchGgmlMulMat(nthreads: Int): String = withContext(scope.coroutineContext) {
return@withContext WhisperLib.benchGgmlMulMat(nthreads)
}
suspend fun release() = withContext(scope.coroutineContext) { suspend fun release() = withContext(scope.coroutineContext) {
if (ptr != 0L) { if (ptr != 0L) {
WhisperLib.freeContext(ptr) WhisperLib.freeContext(ptr)
@ -74,10 +66,6 @@ class WhisperContext private constructor(private var ptr: Long) {
} }
return WhisperContext(ptr) return WhisperContext(ptr)
} }
fun getSystemInfo(): String {
return WhisperLib.getSystemInfo()
}
} }
} }
@ -86,7 +74,6 @@ private class WhisperLib {
init { init {
Log.d(LOG_TAG, "Primary ABI: ${Build.SUPPORTED_ABIS[0]}") Log.d(LOG_TAG, "Primary ABI: ${Build.SUPPORTED_ABIS[0]}")
var loadVfpv4 = false var loadVfpv4 = false
var loadV8fp16 = false
if (isArmEabiV7a()) { if (isArmEabiV7a()) {
// armeabi-v7a needs runtime detection support // armeabi-v7a needs runtime detection support
val cpuInfo = cpuInfo() val cpuInfo = cpuInfo()
@ -97,24 +84,11 @@ private class WhisperLib {
loadVfpv4 = true loadVfpv4 = true
} }
} }
} else if (isArmEabiV8a()) {
// ARMv8.2a needs runtime detection support
val cpuInfo = cpuInfo()
cpuInfo?.let {
Log.d(LOG_TAG, "CPU info: $cpuInfo")
if (cpuInfo.contains("fphp")) {
Log.d(LOG_TAG, "CPU supports fp16 arithmetic")
loadV8fp16 = true
}
}
} }
if (loadVfpv4) { if (loadVfpv4) {
Log.d(LOG_TAG, "Loading libwhisper_vfpv4.so") Log.d(LOG_TAG, "Loading libwhisper_vfpv4.so")
System.loadLibrary("whisper_vfpv4") System.loadLibrary("whisper_vfpv4")
} else if (loadV8fp16) {
Log.d(LOG_TAG, "Loading libwhisper_v8fp16_va.so")
System.loadLibrary("whisper_v8fp16_va")
} else { } else {
Log.d(LOG_TAG, "Loading libwhisper.so") Log.d(LOG_TAG, "Loading libwhisper.so")
System.loadLibrary("whisper") System.loadLibrary("whisper")
@ -129,9 +103,6 @@ private class WhisperLib {
external fun fullTranscribe(contextPtr: Long, audioData: FloatArray) external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
external fun getTextSegmentCount(contextPtr: Long): Int external fun getTextSegmentCount(contextPtr: Long): Int
external fun getTextSegment(contextPtr: Long, index: Int): String external fun getTextSegment(contextPtr: Long, index: Int): String
external fun getSystemInfo(): String
external fun benchMemcpy(nthread: Int): String
external fun benchGgmlMulMat(nthread: Int): String
} }
} }
@ -139,10 +110,6 @@ private fun isArmEabiV7a(): Boolean {
return Build.SUPPORTED_ABIS[0].equals("armeabi-v7a") return Build.SUPPORTED_ABIS[0].equals("armeabi-v7a")
} }
private fun isArmEabiV8a(): Boolean {
return Build.SUPPORTED_ABIS[0].equals("arm64-v8a")
}
private fun cpuInfo(): String? { private fun cpuInfo(): String? {
return try { return try {
File("/proc/cpuinfo").inputStream().bufferedReader().use { File("/proc/cpuinfo").inputStream().bufferedReader().use {

View File

@ -13,14 +13,3 @@ ifeq ($(TARGET_ARCH_ABI),armeabi-v7a)
LOCAL_CFLAGS += -mfpu=neon-vfpv4 LOCAL_CFLAGS += -mfpu=neon-vfpv4
include $(BUILD_SHARED_LIBRARY) include $(BUILD_SHARED_LIBRARY)
endif endif
ifeq ($(TARGET_ARCH_ABI),arm64-v8a)
include $(CLEAR_VARS)
LOCAL_MODULE := libwhisper_v8fp16_va
include $(LOCAL_PATH)/Whisper.mk
# Allow building NEON FMA code.
# https://android.googlesource.com/platform/ndk/+/master/sources/android/cpufeatures/cpu-features.h
LOCAL_CFLAGS += -march=armv8.2-a+fp16
include $(BUILD_SHARED_LIBRARY)
endif

View File

@ -6,7 +6,6 @@
#include <sys/sysinfo.h> #include <sys/sysinfo.h>
#include <string.h> #include <string.h>
#include "whisper.h" #include "whisper.h"
#include "ggml.h"
#define UNUSED(x) (void)(x) #define UNUSED(x) (void)(x)
#define TAG "JNI" #define TAG "JNI"
@ -215,29 +214,3 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_getTextSegment(
jstring string = (*env)->NewStringUTF(env, text); jstring string = (*env)->NewStringUTF(env, text);
return string; return string;
} }
JNIEXPORT jstring JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_getSystemInfo(
JNIEnv *env, jobject thiz
) {
UNUSED(thiz);
const char *sysinfo = whisper_print_system_info();
jstring string = (*env)->NewStringUTF(env, sysinfo);
return string;
}
JNIEXPORT jstring JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_benchMemcpy(JNIEnv *env, jobject thiz,
jint n_threads) {
UNUSED(thiz);
const char *bench_ggml_memcpy = whisper_bench_memcpy_str(n_threads);
jstring string = (*env)->NewStringUTF(env, bench_ggml_memcpy);
}
JNIEXPORT jstring JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_benchGgmlMulMat(JNIEnv *env, jobject thiz,
jint n_threads) {
UNUSED(thiz);
const char *bench_ggml_mul_mat = whisper_bench_ggml_mul_mat_str(n_threads);
jstring string = (*env)->NewStringUTF(env, bench_ggml_mul_mat);
}

View File

@ -67,6 +67,23 @@ msg() {
echo >&2 -e "${1-}" echo >&2 -e "${1-}"
} }
################################################################################
# create a temporary directory to work in
# set the temp_dir and temp_filename variables
################################################################################
temp_dir="$(mktemp -d ${SCRIPT_DIR}/tmp.XXXXXX)";
temp_filename="${temp_dir}/yt-dlp-filename";
################################################################################
# for now we only take one argument
# TODO: a for loop
################################################################################
source_url="${1}"
title_name="";
cleanup() { cleanup() {
local -r clean_me="${1}"; local -r clean_me="${1}";
@ -128,20 +145,6 @@ fi
check_requirements; check_requirements;
################################################################################
# create a temporary directory to work in
# set the temp_dir and temp_filename variables
################################################################################
temp_dir="$(mktemp -d ${SCRIPT_DIR}/tmp.XXXXXX)";
temp_filename="${temp_dir}/yt-dlp-filename";
################################################################################
# for now we only take one argument
# TODO: a for loop
################################################################################
source_url="${1}"
title_name="";
msg "Downloading VOD..."; msg "Downloading VOD...";
################################################################################ ################################################################################
@ -196,6 +199,6 @@ ffmpeg -i "${temp_dir}/${title_name}.vod.mp4" \
-c:s mov_text \ -c:s mov_text \
-y "${title_name}-res.mp4"; -y "${title_name}-res.mp4";
#cleanup "${temp_dir}"; cleanup "${temp_dir}";
msg "Done! Your finished file is ready: ${title_name}-res.mp4"; msg "Done! Your finished file is ready: ${title_name}-res.mp4";

View File

@ -1,70 +0,0 @@
# Benchmark word-level timestamps for different models
#
# This script takes two arguments
# - an audio file
# - [optional] path to a font file
# I'm using "/usr/share/fonts/truetype/freefont/FreeMono.ttf" on Ubuntu
if [ -z "$1" ]; then
echo "Usage: $0 <audio file> [font file]"
exit 1
fi
#TODO: Make this a command line parameter
#models="base small large"
#models="tiny.en tiny base.en base small.en small medium.en medium large-v1 large"
models="tiny.en base.en small.en medium.en large"
DURATION=$(ffprobe -i $1 -show_entries format=duration -v quiet -of csv="p=0")
DURATION=$(printf "%.2f" $DURATION)
echo "Input file duration: ${DURATION}s"
for model in $models; do
echo "Running $model"
COMMAND="./main -m models/ggml-$model.bin -owts -f $1 -of $1.$model"
if [ ! -z "$2" ]; then
COMMAND="$COMMAND -fp $2"
fi
#TODO: Surface errors better
# TIMEFMT is for zsh, TIMEFORMAT is for bash
EXECTIME=$({ TIMEFMT="%E";TIMEFORMAT=%E; time $COMMAND >/dev/null 2>&1; } 2>&1)
# Slightly different formats between zsh and bash
if [ "${EXECTIME: -1}" == "s" ]; then
EXECTIME=${EXECTIME::-1}
fi
RATIO=$(echo "$DURATION / $EXECTIME" | bc -l)
RATIO=$(printf "%.2f" $RATIO)
echo "Execution time: ${EXECTIME}s (${RATIO}x realtime)"
# If the file already exists, delete it
if [ -f $1.mp4 ]; then
rm $1.mp4
fi
bash $1.$model.wts >/dev/null 2>&1
mv $1.mp4 $1.$model.mp4
ffmpeg -y -f lavfi -i color=c=black:s=1200x50:d=$DURATION -vf "drawtext=fontfile=$2:fontsize=36:x=10:y=(h-text_h)/2:text='ggml-$model - ${EXECTIME}s (${RATIO}x realtime)':fontcolor=lightgrey" $1.$model.info.mp4 >/dev/null 2>&1
done
COMMAND="ffmpeg -y"
for model in $models; do
COMMAND="$COMMAND -i $1.$model.info.mp4 -i $1.$model.mp4"
done
COMMAND="$COMMAND -filter_complex \""
COUNT=0
for model in $models; do
COMMAND="$COMMAND[${COUNT}:v][$(($COUNT+1)):v]"
COUNT=$((COUNT+2))
done
COMMAND="$COMMAND vstack=inputs=${COUNT}[v]\" -map \"[v]\" -map 1:a $1.all.mp4 >/dev/null 2>&1"
echo $COMMAND
# Run the command
eval $COMMAND

File diff suppressed because it is too large Load Diff

128
whisper.h
View File

@ -66,7 +66,6 @@ extern "C" {
// //
struct whisper_context; struct whisper_context;
struct whisper_state;
typedef int whisper_token; typedef int whisper_token;
@ -102,20 +101,11 @@ extern "C" {
WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size); WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader); WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
// These are the same as the above, but the internal state of the context is not allocated automatically // Frees all memory allocated by the model.
// It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model);
WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size);
WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader);
WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx);
// Frees all allocated memory
WHISPER_API void whisper_free(struct whisper_context * ctx); WHISPER_API void whisper_free(struct whisper_context * ctx);
WHISPER_API void whisper_free_state(struct whisper_state * state);
// Convert RAW PCM audio to log mel spectrogram. // Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the default state of the provided whisper context. // The resulting spectrogram is stored inside the provided whisper context.
// Returns 0 on success // Returns 0 on success
WHISPER_API int whisper_pcm_to_mel( WHISPER_API int whisper_pcm_to_mel(
struct whisper_context * ctx, struct whisper_context * ctx,
@ -123,15 +113,8 @@ extern "C" {
int n_samples, int n_samples,
int n_threads); int n_threads);
WHISPER_API int whisper_pcm_to_mel_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const float * samples,
int n_samples,
int n_threads);
// Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2. // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
// The resulting spectrogram is stored inside the default state of the provided whisper context. // The resulting spectrogram is stored inside the provided whisper context.
// Returns 0 on success // Returns 0 on success
WHISPER_API int whisper_pcm_to_mel_phase_vocoder( WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
struct whisper_context* ctx, struct whisper_context* ctx,
@ -139,14 +122,8 @@ extern "C" {
int n_samples, int n_samples,
int n_threads); int n_threads);
WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const float * samples,
int n_samples,
int n_threads);
// This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context. // This can be used to set a custom log mel spectrogram inside the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80 // n_mel must be 80
// Returns 0 on success // Returns 0 on success
@ -156,14 +133,7 @@ extern "C" {
int n_len, int n_len,
int n_mel); int n_mel);
WHISPER_API int whisper_set_mel_with_state( // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
struct whisper_context * ctx,
struct whisper_state * state,
const float * data,
int n_len,
int n_mel);
// Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// offset can be used to specify the offset of the first frame in the spectrogram. // offset can be used to specify the offset of the first frame in the spectrogram.
// Returns 0 on success // Returns 0 on success
@ -172,12 +142,6 @@ extern "C" {
int offset, int offset,
int n_threads); int n_threads);
WHISPER_API int whisper_encode_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
int offset,
int n_threads);
// Run the Whisper decoder to obtain the logits and probabilities for the next token. // Run the Whisper decoder to obtain the logits and probabilities for the next token.
// Make sure to call whisper_encode() first. // Make sure to call whisper_encode() first.
// tokens + n_tokens is the provided context for the decoder. // tokens + n_tokens is the provided context for the decoder.
@ -191,14 +155,6 @@ extern "C" {
int n_past, int n_past,
int n_threads); int n_threads);
WHISPER_API int whisper_decode_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const whisper_token * tokens,
int n_tokens,
int n_past,
int n_threads);
// Convert the provided text into tokens. // Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens. // The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens // Returns the number of tokens on success, no more than n_max_tokens
@ -234,15 +190,7 @@ extern "C" {
int n_threads, int n_threads,
float * lang_probs); float * lang_probs);
WHISPER_API int whisper_lang_auto_detect_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
int offset_ms,
int n_threads,
float * lang_probs);
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx); WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx); WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx); WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
@ -253,7 +201,6 @@ extern "C" {
// Rows: n_tokens // Rows: n_tokens
// Cols: n_vocab // Cols: n_vocab
WHISPER_API float * whisper_get_logits(struct whisper_context * ctx); WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
WHISPER_API float * whisper_get_logits_from_state(struct whisper_state * state);
// Token Id -> String. Uses the vocabulary in the provided context // Token Id -> String. Uses the vocabulary in the provided context
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token); WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
@ -271,7 +218,7 @@ extern "C" {
WHISPER_API whisper_token whisper_token_translate (void); WHISPER_API whisper_token whisper_token_translate (void);
WHISPER_API whisper_token whisper_token_transcribe(void); WHISPER_API whisper_token whisper_token_transcribe(void);
// Performance information from the default state. // Performance information
WHISPER_API void whisper_print_timings(struct whisper_context * ctx); WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
WHISPER_API void whisper_reset_timings(struct whisper_context * ctx); WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
@ -289,23 +236,12 @@ extern "C" {
// Text segment callback // Text segment callback
// Called on every newly generated text segment // Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments // Use the whisper_full_...() functions to obtain the text segments
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data); typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
// Encoder begin callback // Encoder begin callback
// If not NULL, called before the encoder starts // If not NULL, called before the encoder starts
// If it returns false, the computation is aborted // If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data); typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
// Logits filter callback
// Can be used to modify the logits before sampling
// If not NULL, called after applying temperature to logits
typedef void (*whisper_logits_filter_callback)(
struct whisper_context * ctx,
struct whisper_state * state,
const whisper_token_data * tokens,
int n_tokens,
float * logits,
void * user_data);
// Parameters for the whisper_full() function // Parameters for the whisper_full() function
// If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp: // If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp:
@ -379,16 +315,11 @@ extern "C" {
// called each time before the encoder starts // called each time before the encoder starts
whisper_encoder_begin_callback encoder_begin_callback; whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data; void * encoder_begin_callback_user_data;
// called by each decoder to filter obtained logits
whisper_logits_filter_callback logits_filter_callback;
void * logits_filter_callback_user_data;
}; };
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Not thread safe for same context
// Uses the specified decoding strategy to obtain the text. // Uses the specified decoding strategy to obtain the text.
WHISPER_API int whisper_full( WHISPER_API int whisper_full(
struct whisper_context * ctx, struct whisper_context * ctx,
@ -396,16 +327,7 @@ extern "C" {
const float * samples, const float * samples,
int n_samples); int n_samples);
WHISPER_API int whisper_full_with_state( // Split the input audio in chunks and process each chunk separately using whisper_full()
struct whisper_context * ctx,
struct whisper_state * state,
struct whisper_full_params params,
const float * samples,
int n_samples);
// Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
// Result is stored in the default state of the context
// Not thread safe if executed in parallel on the same context.
// It seems this approach can offer some speedup in some cases. // It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk. // However, the transcription accuracy can be worse at the beginning and end of each chunk.
WHISPER_API int whisper_full_parallel( WHISPER_API int whisper_full_parallel(
@ -415,56 +337,40 @@ extern "C" {
int n_samples, int n_samples,
int n_processors); int n_processors);
// Number of generated text segments // Number of generated text segments.
// A segment can be a few words, a sentence, or even a paragraph. // A segment can be a few words, a sentence, or even a paragraph.
WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx); WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
WHISPER_API int whisper_full_n_segments_from_state(struct whisper_state * state);
// Language id associated with the context's default state // Language id associated with the current context
WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx); WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
// Language id associated with the provided state // Get the start and end time of the specified segment.
WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state * state);
// Get the start and end time of the specified segment
WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment); WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment); WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment);
// Get the text of the specified segment // Get the text of the specified segment.
WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment); WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment);
WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment);
// Get number of tokens in the specified segment // Get number of tokens in the specified segment.
WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment); WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment);
WHISPER_API int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment);
// Get the token text of the specified token in the specified segment // Get the token text of the specified token in the specified segment.
WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token); WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token);
WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token); WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token);
// Get token data for the specified token in the specified segment // Get token data for the specified token in the specified segment.
// This contains probabilities, timestamps, etc. // This contains probabilities, timestamps, etc.
WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token); WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token);
// Get the probability of the specified token in the specified segment // Get the probability of the specified token in the specified segment.
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token); WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token);
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
// Temporary helpers needed for exposing ggml interface // Temporary helpers needed for exposing ggml interface
WHISPER_API int whisper_bench_memcpy(int n_threads); WHISPER_API int whisper_bench_memcpy(int n_threads);
WHISPER_API const char * whisper_bench_memcpy_str(int n_threads);
WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads); WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads);
WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads);
#ifdef __cplusplus #ifdef __cplusplus
} }