mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-07-04 16:30:58 +02:00
Compare commits
14 Commits
Author | SHA1 | Date | |
---|---|---|---|
ad1389003d | |||
f420de1322 | |||
d176160f6f | |||
ca21f7ab16 | |||
373043cabe | |||
fb4d0d470f | |||
0d229163bb | |||
f254e78737 | |||
a94897bcde | |||
2407ae8ef0 | |||
b623ca43b1 | |||
69e6e4644a | |||
09d7d2b68e | |||
0336161b7d |
@ -1,4 +1,4 @@
|
|||||||
name: Bindings Tests
|
name: Bindings Tests (Go)
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
paths:
|
paths:
|
22
.github/workflows/bindings-ruby.yml
vendored
Normal file
22
.github/workflows/bindings-ruby.yml
vendored
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
name: Bindings Tests (Ruby)
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
paths:
|
||||||
|
- bindings/ruby/**
|
||||||
|
- whisper.h
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- bindings/ruby/**
|
||||||
|
- whisper.h
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
ubuntu-latest:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: ruby/setup-ruby@v1
|
||||||
|
with:
|
||||||
|
ruby-version: '3.0'
|
||||||
|
- uses: actions/checkout@v1
|
||||||
|
- run: |
|
||||||
|
cd bindings/ruby/ext
|
||||||
|
ruby extconf.rb && make
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -10,6 +10,7 @@ build-em/
|
|||||||
build-debug/
|
build-debug/
|
||||||
build-release/
|
build-release/
|
||||||
build-static/
|
build-static/
|
||||||
|
build-no-accel/
|
||||||
build-sanitize-addr/
|
build-sanitize-addr/
|
||||||
build-sanitize-thread/
|
build-sanitize-thread/
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
cmake_minimum_required (VERSION 3.0)
|
cmake_minimum_required (VERSION 3.0)
|
||||||
|
|
||||||
project(whisper.cpp VERSION 1.2.0)
|
project(whisper.cpp VERSION 1.2.1)
|
||||||
|
|
||||||
# Add path to modules
|
# Add path to modules
|
||||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||||
|
21
Makefile
21
Makefile
@ -141,6 +141,8 @@ ifdef WHISPER_GPROF
|
|||||||
CXXFLAGS += -pg
|
CXXFLAGS += -pg
|
||||||
endif
|
endif
|
||||||
ifneq ($(filter aarch64%,$(UNAME_M)),)
|
ifneq ($(filter aarch64%,$(UNAME_M)),)
|
||||||
|
CFLAGS += -mcpu=native
|
||||||
|
CXXFLAGS += -mcpu=native
|
||||||
endif
|
endif
|
||||||
ifneq ($(filter armv6%,$(UNAME_M)),)
|
ifneq ($(filter armv6%,$(UNAME_M)),)
|
||||||
# Raspberry Pi 1, 2, 3
|
# Raspberry Pi 1, 2, 3
|
||||||
@ -197,18 +199,21 @@ clean:
|
|||||||
|
|
||||||
CC_SDL=`sdl2-config --cflags --libs`
|
CC_SDL=`sdl2-config --cflags --libs`
|
||||||
|
|
||||||
main: examples/main/main.cpp ggml.o whisper.o
|
SRC_COMMON = examples/common.cpp
|
||||||
$(CXX) $(CXXFLAGS) examples/main/main.cpp ggml.o whisper.o -o main $(LDFLAGS)
|
SRC_COMMON_SDL = examples/common-sdl.cpp
|
||||||
|
|
||||||
|
main: examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o
|
||||||
|
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o -o main $(LDFLAGS)
|
||||||
./main -h
|
./main -h
|
||||||
|
|
||||||
stream: examples/stream/stream.cpp ggml.o whisper.o
|
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
|
||||||
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
|
||||||
|
|
||||||
command: examples/command/command.cpp ggml.o whisper.o
|
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
|
||||||
$(CXX) $(CXXFLAGS) examples/command/command.cpp ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS)
|
||||||
|
|
||||||
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o
|
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
|
||||||
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
|
||||||
|
|
||||||
bench: examples/bench/bench.cpp ggml.o whisper.o
|
bench: examples/bench/bench.cpp ggml.o whisper.o
|
||||||
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
[](https://opensource.org/licenses/MIT)
|
[](https://opensource.org/licenses/MIT)
|
||||||
[](https://www.npmjs.com/package/whisper.cpp/)
|
[](https://www.npmjs.com/package/whisper.cpp/)
|
||||||
|
|
||||||
Stable: [v1.2.0](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
Stable: [v1.2.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||||
|
|
||||||
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
||||||
|
|
||||||
@ -464,11 +464,14 @@ in [models](models).
|
|||||||
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
|
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
|
||||||
- [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
|
- [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
|
||||||
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
||||||
|
- [X] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggerganov/whisper.cpp/discussions/507)
|
||||||
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
|
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
|
||||||
- [X] .NET:
|
- [X] .NET: | [#422](https://github.com/ggerganov/whisper.cpp/discussions/422)
|
||||||
- [sandrohanea/whisper.net](https://github.com/sandrohanea/whisper.net)
|
- [sandrohanea/whisper.net](https://github.com/sandrohanea/whisper.net)
|
||||||
- [NickDarvey/whisper](https://github.com/NickDarvey/whisper)
|
- [NickDarvey/whisper](https://github.com/NickDarvey/whisper)
|
||||||
- [ ] Python: soon | [WIP](https://github.com/ggerganov/whisper.cpp/issues/9)
|
- [X] Python: | [#9](https://github.com/ggerganov/whisper.cpp/issues/9)
|
||||||
|
- [stlukey/whispercpp.py](https://github.com/stlukey/whispercpp.py) (Cython)
|
||||||
|
- [aarnphm/whispercpp](https://github.com/aarnphm/whispercpp) (Pybind11)
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
|
Submodule bindings/ios updated: d5c6d5c8a3...92d4c5c9a0
@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "whisper.cpp",
|
"name": "whisper.cpp",
|
||||||
"version": "1.2.0",
|
"version": "1.2.1",
|
||||||
"description": "Whisper speech recognition",
|
"description": "Whisper speech recognition",
|
||||||
"main": "whisper.js",
|
"main": "whisper.js",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
|
File diff suppressed because one or more lines are too long
7
bindings/ruby/ext/.gitignore
vendored
Normal file
7
bindings/ruby/ext/.gitignore
vendored
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
Makefile
|
||||||
|
ggml.c
|
||||||
|
ggml.h
|
||||||
|
whisper.bundle
|
||||||
|
whisper.cpp
|
||||||
|
whisper.h
|
||||||
|
dr_wav.h
|
21
bindings/ruby/ext/extconf.rb
Normal file
21
bindings/ruby/ext/extconf.rb
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
require 'mkmf'
|
||||||
|
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
|
||||||
|
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
|
||||||
|
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
|
||||||
|
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
|
||||||
|
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','examples','dr_wav.h')} .")
|
||||||
|
|
||||||
|
|
||||||
|
# need to use c++ compiler flags
|
||||||
|
$CXXFLAGS << ' -std=c++11'
|
||||||
|
# Set to true when building binary gems
|
||||||
|
if enable_config('static-stdlib', false)
|
||||||
|
$LDFLAGS << ' -static-libgcc -static-libstdc++'
|
||||||
|
end
|
||||||
|
|
||||||
|
if enable_config('march-tune-native', false)
|
||||||
|
$CFLAGS << ' -march=native -mtune=native'
|
||||||
|
$CXXFLAGS << ' -march=native -mtune=native'
|
||||||
|
end
|
||||||
|
|
||||||
|
create_makefile('whisper')
|
426
bindings/ruby/ext/ruby_whisper.cpp
Normal file
426
bindings/ruby/ext/ruby_whisper.cpp
Normal file
@ -0,0 +1,426 @@
|
|||||||
|
#include <ruby.h>
|
||||||
|
#include "ruby_whisper.h"
|
||||||
|
#define DR_WAV_IMPLEMENTATION
|
||||||
|
#include "dr_wav.h"
|
||||||
|
#include <cmath>
|
||||||
|
#include <fstream>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <string>
|
||||||
|
#include <thread>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define BOOL_PARAMS_SETTER(self, prop, value) \
|
||||||
|
ruby_whisper_params *rwp; \
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp); \
|
||||||
|
if (value == Qfalse || value == Qnil) { \
|
||||||
|
rwp->params.prop = false; \
|
||||||
|
} else { \
|
||||||
|
rwp->params.prop = true; \
|
||||||
|
} \
|
||||||
|
return value; \
|
||||||
|
|
||||||
|
#define BOOL_PARAMS_GETTER(self, prop) \
|
||||||
|
ruby_whisper_params *rwp; \
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp); \
|
||||||
|
if (rwp->params.prop) { \
|
||||||
|
return Qtrue; \
|
||||||
|
} else { \
|
||||||
|
return Qfalse; \
|
||||||
|
}
|
||||||
|
|
||||||
|
VALUE mWhisper;
|
||||||
|
VALUE cContext;
|
||||||
|
VALUE cParams;
|
||||||
|
|
||||||
|
static void ruby_whisper_free(ruby_whisper *rw) {
|
||||||
|
if (rw->context) {
|
||||||
|
whisper_free(rw->context);
|
||||||
|
rw->context = NULL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
static void ruby_whisper_params_free(ruby_whisper_params *rwp) {
|
||||||
|
}
|
||||||
|
|
||||||
|
void rb_whisper_mark(ruby_whisper *rw) {
|
||||||
|
// call rb_gc_mark on any ruby references in rw
|
||||||
|
}
|
||||||
|
|
||||||
|
void rb_whisper_free(ruby_whisper *rw) {
|
||||||
|
ruby_whisper_free(rw);
|
||||||
|
free(rw);
|
||||||
|
}
|
||||||
|
|
||||||
|
void rb_whisper_params_mark(ruby_whisper_params *rwp) {
|
||||||
|
}
|
||||||
|
|
||||||
|
void rb_whisper_params_free(ruby_whisper_params *rwp) {
|
||||||
|
ruby_whisper_params_free(rwp);
|
||||||
|
free(rwp);
|
||||||
|
}
|
||||||
|
|
||||||
|
static VALUE ruby_whisper_allocate(VALUE klass) {
|
||||||
|
ruby_whisper *rw;
|
||||||
|
rw = ALLOC(ruby_whisper);
|
||||||
|
rw->context = NULL;
|
||||||
|
return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
|
||||||
|
}
|
||||||
|
|
||||||
|
static VALUE ruby_whisper_params_allocate(VALUE klass) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
rwp = ALLOC(ruby_whisper_params);
|
||||||
|
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||||
|
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
|
||||||
|
}
|
||||||
|
|
||||||
|
static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
|
||||||
|
ruby_whisper *rw;
|
||||||
|
VALUE whisper_model_file_path;
|
||||||
|
|
||||||
|
// TODO: we can support init from buffer here too maybe another ruby object to expose
|
||||||
|
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
|
||||||
|
Data_Get_Struct(self, ruby_whisper, rw);
|
||||||
|
|
||||||
|
if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) {
|
||||||
|
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
|
||||||
|
}
|
||||||
|
rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path));
|
||||||
|
if (rw->context == nullptr) {
|
||||||
|
rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
|
||||||
|
}
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* transcribe a single file
|
||||||
|
* can emit to a block results
|
||||||
|
*
|
||||||
|
**/
|
||||||
|
static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
|
||||||
|
ruby_whisper *rw;
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
VALUE wave_file_path, blk, params;
|
||||||
|
|
||||||
|
rb_scan_args(argc, argv, "02&", &wave_file_path, ¶ms, &blk);
|
||||||
|
Data_Get_Struct(self, ruby_whisper, rw);
|
||||||
|
Data_Get_Struct(params, ruby_whisper_params, rwp);
|
||||||
|
|
||||||
|
if (!rb_respond_to(wave_file_path, rb_intern("to_s"))) {
|
||||||
|
rb_raise(rb_eRuntimeError, "Expected file path to wave file");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string fname_inp = StringValueCStr(wave_file_path);
|
||||||
|
|
||||||
|
std::vector<float> pcmf32; // mono-channel F32 PCM
|
||||||
|
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
|
||||||
|
|
||||||
|
// WAV input - this is directly from main.cpp example
|
||||||
|
{
|
||||||
|
drwav wav;
|
||||||
|
std::vector<uint8_t> wav_data; // used for pipe input from stdin
|
||||||
|
|
||||||
|
if (fname_inp == "-") {
|
||||||
|
{
|
||||||
|
uint8_t buf[1024];
|
||||||
|
while (true) {
|
||||||
|
const size_t n = fread(buf, 1, sizeof(buf), stdin);
|
||||||
|
if (n == 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
wav_data.insert(wav_data.end(), buf, buf + n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
|
||||||
|
fprintf(stderr, "error: failed to open WAV file from stdin\n");
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
|
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
|
||||||
|
} else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
|
||||||
|
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (wav.channels != 1 && wav.channels != 2) {
|
||||||
|
fprintf(stderr, "WAV file '%s' must be mono or stereo\n", fname_inp.c_str());
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rwp->diarize && wav.channels != 2 && rwp->params.print_timestamps == false) {
|
||||||
|
fprintf(stderr, "WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str());
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
|
||||||
|
fprintf(stderr, "WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (wav.bitsPerSample != 16) {
|
||||||
|
fprintf(stderr, "WAV file '%s' must be 16-bit\n", fname_inp.c_str());
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
|
||||||
|
|
||||||
|
std::vector<int16_t> pcm16;
|
||||||
|
pcm16.resize(n*wav.channels);
|
||||||
|
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
|
||||||
|
drwav_uninit(&wav);
|
||||||
|
|
||||||
|
// convert to mono, float
|
||||||
|
pcmf32.resize(n);
|
||||||
|
if (wav.channels == 1) {
|
||||||
|
for (uint64_t i = 0; i < n; i++) {
|
||||||
|
pcmf32[i] = float(pcm16[i])/32768.0f;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (uint64_t i = 0; i < n; i++) {
|
||||||
|
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rwp->diarize) {
|
||||||
|
// convert to stereo, float
|
||||||
|
pcmf32s.resize(2);
|
||||||
|
|
||||||
|
pcmf32s[0].resize(n);
|
||||||
|
pcmf32s[1].resize(n);
|
||||||
|
for (uint64_t i = 0; i < n; i++) {
|
||||||
|
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
|
||||||
|
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||||
|
|
||||||
|
rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
|
||||||
|
bool is_aborted = *(bool*)user_data;
|
||||||
|
return !is_aborted;
|
||||||
|
};
|
||||||
|
rwp->params.encoder_begin_callback_user_data = &is_aborted;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
|
||||||
|
fprintf(stderr, "failed to process audio\n");
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
const int n_segments = whisper_full_n_segments(rw->context);
|
||||||
|
VALUE output = rb_str_new2("");
|
||||||
|
for (int i = 0; i < n_segments; ++i) {
|
||||||
|
const char * text = whisper_full_get_segment_text(rw->context, i);
|
||||||
|
output = rb_str_concat(output, rb_str_new2(text));
|
||||||
|
}
|
||||||
|
VALUE idCall = rb_intern("call");
|
||||||
|
if (blk != Qnil) {
|
||||||
|
rb_funcall(blk, idCall, 1, output);
|
||||||
|
}
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* params.language = "auto" | "en", etc...
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
if (value == Qfalse || value == Qnil) {
|
||||||
|
rwp->params.language = "auto";
|
||||||
|
} else {
|
||||||
|
rwp->params.language = StringValueCStr(value);
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_get_language(VALUE self) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
if (rwp->params.language) {
|
||||||
|
return rb_str_new2(rwp->params.language);
|
||||||
|
} else {
|
||||||
|
return rb_str_new2("auto");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) {
|
||||||
|
BOOL_PARAMS_SETTER(self, translate, value)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_get_translate(VALUE self) {
|
||||||
|
BOOL_PARAMS_GETTER(self, translate)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) {
|
||||||
|
BOOL_PARAMS_SETTER(self, no_context, value)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_get_no_context(VALUE self) {
|
||||||
|
BOOL_PARAMS_GETTER(self, no_context)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) {
|
||||||
|
BOOL_PARAMS_SETTER(self, single_segment, value)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_get_single_segment(VALUE self) {
|
||||||
|
BOOL_PARAMS_GETTER(self, single_segment)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) {
|
||||||
|
BOOL_PARAMS_SETTER(self, print_special, value)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_get_print_special(VALUE self) {
|
||||||
|
BOOL_PARAMS_GETTER(self, print_special)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) {
|
||||||
|
BOOL_PARAMS_SETTER(self, print_progress, value)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_get_print_progress(VALUE self) {
|
||||||
|
BOOL_PARAMS_GETTER(self, print_progress)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) {
|
||||||
|
BOOL_PARAMS_SETTER(self, print_realtime, value)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_get_print_realtime(VALUE self) {
|
||||||
|
BOOL_PARAMS_GETTER(self, print_realtime)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) {
|
||||||
|
BOOL_PARAMS_SETTER(self, print_timestamps, value)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) {
|
||||||
|
BOOL_PARAMS_GETTER(self, print_timestamps)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) {
|
||||||
|
BOOL_PARAMS_SETTER(self, suppress_blank, value)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
|
||||||
|
BOOL_PARAMS_GETTER(self, suppress_blank)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) {
|
||||||
|
BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) {
|
||||||
|
BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) {
|
||||||
|
BOOL_PARAMS_GETTER(self, token_timestamps)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) {
|
||||||
|
BOOL_PARAMS_SETTER(self, token_timestamps, value)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_get_split_on_word(VALUE self) {
|
||||||
|
BOOL_PARAMS_GETTER(self, split_on_word)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
|
||||||
|
BOOL_PARAMS_SETTER(self, split_on_word, value)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_get_speed_up(VALUE self) {
|
||||||
|
BOOL_PARAMS_GETTER(self, speed_up)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_set_speed_up(VALUE self, VALUE value) {
|
||||||
|
BOOL_PARAMS_SETTER(self, speed_up, value)
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_get_diarize(VALUE self) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
if (rwp->diarize) {
|
||||||
|
return Qtrue;
|
||||||
|
} else {
|
||||||
|
return Qfalse;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
if (value == Qfalse || value == Qnil) {
|
||||||
|
rwp->diarize = false;
|
||||||
|
} else {
|
||||||
|
rwp->diarize = true;
|
||||||
|
} \
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
static VALUE ruby_whisper_params_get_offset(VALUE self) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
return INT2NUM(rwp->params.offset_ms);
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
rwp->params.offset_ms = NUM2INT(value);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_get_duration(VALUE self) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
return INT2NUM(rwp->params.duration_ms);
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
rwp->params.duration_ms = NUM2INT(value);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
return INT2NUM(rwp->params.n_max_text_ctx);
|
||||||
|
}
|
||||||
|
static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
rwp->params.n_max_text_ctx = NUM2INT(value);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Init_whisper() {
|
||||||
|
mWhisper = rb_define_module("Whisper");
|
||||||
|
cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
|
||||||
|
cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
|
||||||
|
|
||||||
|
rb_define_alloc_func(cContext, ruby_whisper_allocate);
|
||||||
|
rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
|
||||||
|
|
||||||
|
rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
|
||||||
|
|
||||||
|
rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
|
||||||
|
|
||||||
|
rb_define_method(cParams, "language=", ruby_whisper_params_set_language, 1);
|
||||||
|
rb_define_method(cParams, "language", ruby_whisper_params_get_language, 0);
|
||||||
|
rb_define_method(cParams, "translate=", ruby_whisper_params_set_translate, 1);
|
||||||
|
rb_define_method(cParams, "translate", ruby_whisper_params_get_translate, 0);
|
||||||
|
rb_define_method(cParams, "no_context=", ruby_whisper_params_set_no_context, 1);
|
||||||
|
rb_define_method(cParams, "no_context", ruby_whisper_params_get_no_context, 0);
|
||||||
|
rb_define_method(cParams, "single_segment=", ruby_whisper_params_set_single_segment, 1);
|
||||||
|
rb_define_method(cParams, "single_segment", ruby_whisper_params_get_single_segment, 0);
|
||||||
|
rb_define_method(cParams, "print_special", ruby_whisper_params_get_print_special, 0);
|
||||||
|
rb_define_method(cParams, "print_special=", ruby_whisper_params_set_print_special, 1);
|
||||||
|
rb_define_method(cParams, "print_progress", ruby_whisper_params_get_print_progress, 0);
|
||||||
|
rb_define_method(cParams, "print_progress=", ruby_whisper_params_set_print_progress, 1);
|
||||||
|
rb_define_method(cParams, "print_realtime", ruby_whisper_params_get_print_realtime, 0);
|
||||||
|
rb_define_method(cParams, "print_realtime=", ruby_whisper_params_set_print_realtime, 1);
|
||||||
|
rb_define_method(cParams, "print_timestamps", ruby_whisper_params_get_print_timestamps, 0);
|
||||||
|
rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1);
|
||||||
|
rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0);
|
||||||
|
rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1);
|
||||||
|
rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0);
|
||||||
|
rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1);
|
||||||
|
rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0);
|
||||||
|
rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
|
||||||
|
rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
|
||||||
|
rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1);
|
||||||
|
rb_define_method(cParams, "speed_up", ruby_whisper_params_get_speed_up, 0);
|
||||||
|
rb_define_method(cParams, "speed_up=", ruby_whisper_params_set_speed_up, 1);
|
||||||
|
rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0);
|
||||||
|
rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1);
|
||||||
|
|
||||||
|
rb_define_method(cParams, "offset", ruby_whisper_params_get_offset, 0);
|
||||||
|
rb_define_method(cParams, "offset=", ruby_whisper_params_set_offset, 1);
|
||||||
|
rb_define_method(cParams, "duration", ruby_whisper_params_get_duration, 0);
|
||||||
|
rb_define_method(cParams, "duration=", ruby_whisper_params_set_duration, 1);
|
||||||
|
|
||||||
|
rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0);
|
||||||
|
rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1);
|
||||||
|
}
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
15
bindings/ruby/ext/ruby_whisper.h
Normal file
15
bindings/ruby/ext/ruby_whisper.h
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
#ifndef __RUBY_WHISPER_H
|
||||||
|
#define __RUBY_WHISPER_H
|
||||||
|
|
||||||
|
#include "whisper.h"
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
struct whisper_context *context;
|
||||||
|
} ruby_whisper;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
struct whisper_full_params params;
|
||||||
|
bool diarize;
|
||||||
|
} ruby_whisper_params;
|
||||||
|
|
||||||
|
#endif
|
138
bindings/ruby/tests/test_whisper.rb
Normal file
138
bindings/ruby/tests/test_whisper.rb
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
|
||||||
|
EXTDIR = File.join(TOPDIR, 'ext')
|
||||||
|
#$LIBDIR = File.join(TOPDIR, 'lib')
|
||||||
|
#$:.unshift(LIBDIR)
|
||||||
|
$:.unshift(EXTDIR)
|
||||||
|
|
||||||
|
require 'whisper'
|
||||||
|
require 'test/unit'
|
||||||
|
|
||||||
|
class TestWhisper < Test::Unit::TestCase
|
||||||
|
def setup
|
||||||
|
@params = Whisper::Params.new
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_language
|
||||||
|
@params.language = "en"
|
||||||
|
assert_equal @params.language, "en"
|
||||||
|
@params.language = "auto"
|
||||||
|
assert_equal @params.language, "auto"
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_offset
|
||||||
|
@params.offset = 10_000
|
||||||
|
assert_equal @params.offset, 10_000
|
||||||
|
@params.offset = 0
|
||||||
|
assert_equal @params.offset, 0
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_duration
|
||||||
|
@params.duration = 60_000
|
||||||
|
assert_equal @params.duration, 60_000
|
||||||
|
@params.duration = 0
|
||||||
|
assert_equal @params.duration, 0
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_max_text_tokens
|
||||||
|
@params.max_text_tokens = 300
|
||||||
|
assert_equal @params.max_text_tokens, 300
|
||||||
|
@params.max_text_tokens = 0
|
||||||
|
assert_equal @params.max_text_tokens, 0
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_translate
|
||||||
|
@params.translate = true
|
||||||
|
assert @params.translate
|
||||||
|
@params.translate = false
|
||||||
|
assert !@params.translate
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_no_context
|
||||||
|
@params.no_context = true
|
||||||
|
assert @params.no_context
|
||||||
|
@params.no_context = false
|
||||||
|
assert !@params.no_context
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_single_segment
|
||||||
|
@params.single_segment = true
|
||||||
|
assert @params.single_segment
|
||||||
|
@params.single_segment = false
|
||||||
|
assert !@params.single_segment
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_print_special
|
||||||
|
@params.print_special = true
|
||||||
|
assert @params.print_special
|
||||||
|
@params.print_special = false
|
||||||
|
assert !@params.print_special
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_print_progress
|
||||||
|
@params.print_progress = true
|
||||||
|
assert @params.print_progress
|
||||||
|
@params.print_progress = false
|
||||||
|
assert !@params.print_progress
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_print_realtime
|
||||||
|
@params.print_realtime = true
|
||||||
|
assert @params.print_realtime
|
||||||
|
@params.print_realtime = false
|
||||||
|
assert !@params.print_realtime
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_print_timestamps
|
||||||
|
@params.print_timestamps = true
|
||||||
|
assert @params.print_timestamps
|
||||||
|
@params.print_timestamps = false
|
||||||
|
assert !@params.print_timestamps
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_suppress_blank
|
||||||
|
@params.suppress_blank = true
|
||||||
|
assert @params.suppress_blank
|
||||||
|
@params.suppress_blank = false
|
||||||
|
assert !@params.suppress_blank
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_suppress_non_speech_tokens
|
||||||
|
@params.suppress_non_speech_tokens = true
|
||||||
|
assert @params.suppress_non_speech_tokens
|
||||||
|
@params.suppress_non_speech_tokens = false
|
||||||
|
assert !@params.suppress_non_speech_tokens
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_token_timestamps
|
||||||
|
@params.token_timestamps = true
|
||||||
|
assert @params.token_timestamps
|
||||||
|
@params.token_timestamps = false
|
||||||
|
assert !@params.token_timestamps
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_split_on_word
|
||||||
|
@params.split_on_word = true
|
||||||
|
assert @params.split_on_word
|
||||||
|
@params.split_on_word = false
|
||||||
|
assert !@params.split_on_word
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_speed_up
|
||||||
|
@params.speed_up = true
|
||||||
|
assert @params.speed_up
|
||||||
|
@params.speed_up = false
|
||||||
|
assert !@params.speed_up
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_whisper
|
||||||
|
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
|
||||||
|
params = Whisper::Params.new
|
||||||
|
params.print_timestamps = false
|
||||||
|
|
||||||
|
jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
|
||||||
|
@whisper.transcribe(jfk, params) {|text|
|
||||||
|
assert_match /ask not what your country can do for you, ask what you can do for your country/, text
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
end
|
@ -14,6 +14,37 @@ if (WHISPER_SUPPORT_SDL2)
|
|||||||
message(STATUS "SDL2_LIBRARIES = ${SDL2_LIBRARIES}")
|
message(STATUS "SDL2_LIBRARIES = ${SDL2_LIBRARIES}")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# common
|
||||||
|
|
||||||
|
set(TARGET common)
|
||||||
|
|
||||||
|
add_library(${TARGET} STATIC
|
||||||
|
common.h
|
||||||
|
common.cpp
|
||||||
|
)
|
||||||
|
|
||||||
|
include(DefaultTargetOptions)
|
||||||
|
|
||||||
|
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||||
|
|
||||||
|
if (WHISPER_SUPPORT_SDL2)
|
||||||
|
# common-sdl
|
||||||
|
|
||||||
|
set(TARGET common-sdl)
|
||||||
|
|
||||||
|
add_library(${TARGET} STATIC
|
||||||
|
common-sdl.h
|
||||||
|
common-sdl.cpp
|
||||||
|
)
|
||||||
|
|
||||||
|
include(DefaultTargetOptions)
|
||||||
|
|
||||||
|
target_include_directories(${TARGET} PUBLIC ${SDL2_INCLUDE_DIRS})
|
||||||
|
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES})
|
||||||
|
|
||||||
|
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||||
|
endif()
|
||||||
|
|
||||||
# examples
|
# examples
|
||||||
|
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
||||||
|
@ -23,7 +23,7 @@ string(REPLACE "\"" "" NODE_ADDON_API_DIR ${NODE_ADDON_API_DIR})
|
|||||||
target_include_directories(${TARGET} PRIVATE ${NODE_ADDON_API_DIR})
|
target_include_directories(${TARGET} PRIVATE ${NODE_ADDON_API_DIR})
|
||||||
#==================================================================
|
#==================================================================
|
||||||
|
|
||||||
target_link_libraries(${TARGET} ${CMAKE_JS_LIB} whisper ${CMAKE_THREAD_LIBS_INIT})
|
target_link_libraries(${TARGET} ${CMAKE_JS_LIB} common whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||||
|
|
||||||
if(MSVC AND CMAKE_JS_NODELIB_DEF AND CMAKE_JS_NODELIB_TARGET)
|
if(MSVC AND CMAKE_JS_NODELIB_DEF AND CMAKE_JS_NODELIB_TARGET)
|
||||||
# Generate node.lib
|
# Generate node.lib
|
||||||
|
@ -1,15 +1,13 @@
|
|||||||
#include <cstdint>
|
#include "napi.h"
|
||||||
|
#include "common.h"
|
||||||
|
|
||||||
|
#include "whisper.h"
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <cstdint>
|
||||||
#include "napi.h"
|
|
||||||
|
|
||||||
#define DR_WAV_IMPLEMENTATION
|
|
||||||
#include "dr_wav.h"
|
|
||||||
|
|
||||||
#include "whisper.h"
|
|
||||||
|
|
||||||
struct whisper_params {
|
struct whisper_params {
|
||||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||||
@ -44,7 +42,7 @@ struct whisper_params {
|
|||||||
std::string model = "../../ggml-large.bin";
|
std::string model = "../../ggml-large.bin";
|
||||||
|
|
||||||
std::vector<std::string> fname_inp = {};
|
std::vector<std::string> fname_inp = {};
|
||||||
std::vector<std::string> fname_outp = {};
|
std::vector<std::string> fname_out = {};
|
||||||
};
|
};
|
||||||
|
|
||||||
struct whisper_print_user_data {
|
struct whisper_print_user_data {
|
||||||
@ -143,7 +141,6 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
|
|||||||
}
|
}
|
||||||
|
|
||||||
int run(whisper_params ¶ms, std::vector<std::vector<std::string>> &result) {
|
int run(whisper_params ¶ms, std::vector<std::vector<std::string>> &result) {
|
||||||
|
|
||||||
if (params.fname_inp.empty()) {
|
if (params.fname_inp.empty()) {
|
||||||
fprintf(stderr, "error: no input files specified\n");
|
fprintf(stderr, "error: no input files specified\n");
|
||||||
return 2;
|
return 2;
|
||||||
@ -181,91 +178,14 @@ int run(whisper_params ¶ms, std::vector<std::vector<std::string>> &result) {
|
|||||||
|
|
||||||
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
||||||
const auto fname_inp = params.fname_inp[f];
|
const auto fname_inp = params.fname_inp[f];
|
||||||
const auto fname_outp = f < (int)params.fname_outp.size() && !params.fname_outp[f].empty() ? params.fname_outp[f] : params.fname_inp[f];
|
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
|
||||||
|
|
||||||
std::vector<float> pcmf32; // mono-channel F32 PCM
|
std::vector<float> pcmf32; // mono-channel F32 PCM
|
||||||
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
|
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
|
||||||
|
|
||||||
// WAV input
|
if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
|
||||||
{
|
fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
|
||||||
drwav wav;
|
continue;
|
||||||
std::vector<uint8_t> wav_data; // used for pipe input from stdin
|
|
||||||
|
|
||||||
if (fname_inp == "-") {
|
|
||||||
{
|
|
||||||
uint8_t buf[1024];
|
|
||||||
while (true)
|
|
||||||
{
|
|
||||||
const size_t n = fread(buf, 1, sizeof(buf), stdin);
|
|
||||||
if (n == 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
wav_data.insert(wav_data.end(), buf, buf + n);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
|
|
||||||
fprintf(stderr, "error: failed to open WAV file from stdin\n");
|
|
||||||
return 4;
|
|
||||||
}
|
|
||||||
|
|
||||||
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
|
|
||||||
}
|
|
||||||
else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
|
|
||||||
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
|
||||||
return 5;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (wav.channels != 1 && wav.channels != 2) {
|
|
||||||
fprintf(stderr, "error: WAV file '%s' must be mono or stereo\n", fname_inp.c_str());
|
|
||||||
return 6;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
|
|
||||||
fprintf(stderr, "error: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str());
|
|
||||||
return 6;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
|
|
||||||
fprintf(stderr, "error: WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
|
|
||||||
return 8;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (wav.bitsPerSample != 16) {
|
|
||||||
fprintf(stderr, "error: WAV file '%s' must be 16-bit\n", fname_inp.c_str());
|
|
||||||
return 9;
|
|
||||||
}
|
|
||||||
|
|
||||||
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
|
|
||||||
|
|
||||||
std::vector<int16_t> pcm16;
|
|
||||||
pcm16.resize(n*wav.channels);
|
|
||||||
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
|
|
||||||
drwav_uninit(&wav);
|
|
||||||
|
|
||||||
// convert to mono, float
|
|
||||||
pcmf32.resize(n);
|
|
||||||
if (wav.channels == 1) {
|
|
||||||
for (uint64_t i = 0; i < n; i++) {
|
|
||||||
pcmf32[i] = float(pcm16[i])/32768.0f;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (uint64_t i = 0; i < n; i++) {
|
|
||||||
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.diarize) {
|
|
||||||
// convert to stereo, float
|
|
||||||
pcmf32s.resize(2);
|
|
||||||
|
|
||||||
pcmf32s[0].resize(n);
|
|
||||||
pcmf32s[1].resize(n);
|
|
||||||
for (uint64_t i = 0; i < n; i++) {
|
|
||||||
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
|
|
||||||
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// print system information
|
// print system information
|
||||||
|
@ -11,6 +11,7 @@ add_executable(${TARGET}
|
|||||||
include(DefaultTargetOptions)
|
include(DefaultTargetOptions)
|
||||||
|
|
||||||
target_link_libraries(${TARGET} PRIVATE
|
target_link_libraries(${TARGET} PRIVATE
|
||||||
|
common
|
||||||
whisper
|
whisper
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
#include "common.h"
|
||||||
#include "whisper.h"
|
#include "whisper.h"
|
||||||
|
|
||||||
#include <emscripten.h>
|
#include <emscripten.h>
|
||||||
@ -27,24 +28,6 @@ std::string g_transcribed = "";
|
|||||||
|
|
||||||
std::vector<float> g_pcmf32;
|
std::vector<float> g_pcmf32;
|
||||||
|
|
||||||
static std::string trim(const std::string & s) {
|
|
||||||
std::regex e("^\\s+|\\s+$");
|
|
||||||
return std::regex_replace(s, e, "");
|
|
||||||
}
|
|
||||||
|
|
||||||
static void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
|
||||||
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
|
||||||
const float dt = 1.0f / sample_rate;
|
|
||||||
const float alpha = dt / (rc + dt);
|
|
||||||
|
|
||||||
float y = data[0];
|
|
||||||
|
|
||||||
for (size_t i = 1; i < data.size(); i++) {
|
|
||||||
y = alpha * (y + data[i] - data[i - 1]);
|
|
||||||
data[i] = y;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// compute similarity between two strings using Levenshtein distance
|
// compute similarity between two strings using Levenshtein distance
|
||||||
static float similarity(const std::string & s0, const std::string & s1) {
|
static float similarity(const std::string & s0, const std::string & s1) {
|
||||||
const size_t len0 = s0.size() + 1;
|
const size_t len0 = s0.size() + 1;
|
||||||
@ -75,44 +58,6 @@ void command_set_status(const std::string & status) {
|
|||||||
g_status = status;
|
g_status = status;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool command_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
|
||||||
const int n_samples = pcmf32.size();
|
|
||||||
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
|
||||||
|
|
||||||
if (n_samples_last >= n_samples) {
|
|
||||||
// not enough samples - assume no speech
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (freq_thold > 0.0f) {
|
|
||||||
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
|
||||||
}
|
|
||||||
|
|
||||||
float energy_all = 0.0f;
|
|
||||||
float energy_last = 0.0f;
|
|
||||||
|
|
||||||
for (size_t i = 0; i < n_samples; i++) {
|
|
||||||
energy_all += fabsf(pcmf32[i]);
|
|
||||||
|
|
||||||
if (i >= n_samples - n_samples_last) {
|
|
||||||
energy_last += fabsf(pcmf32[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
energy_all /= n_samples;
|
|
||||||
energy_last /= n_samples_last;
|
|
||||||
|
|
||||||
if (verbose) {
|
|
||||||
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (energy_last > vad_thold*energy_all) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string command_transcribe(whisper_context * ctx, const whisper_full_params & wparams, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
std::string command_transcribe(whisper_context * ctx, const whisper_full_params & wparams, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
||||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
@ -155,7 +100,7 @@ void command_get_audio(int ms, int sample_rate, std::vector<float> & audio) {
|
|||||||
const int64_t n_samples = (ms * sample_rate) / 1000;
|
const int64_t n_samples = (ms * sample_rate) / 1000;
|
||||||
|
|
||||||
int64_t n_take = 0;
|
int64_t n_take = 0;
|
||||||
if (g_pcmf32.size() < n_samples) {
|
if (n_samples > (int) g_pcmf32.size()) {
|
||||||
n_take = g_pcmf32.size();
|
n_take = g_pcmf32.size();
|
||||||
} else {
|
} else {
|
||||||
n_take = n_samples;
|
n_take = n_samples;
|
||||||
@ -187,7 +132,6 @@ void command_main(size_t index) {
|
|||||||
|
|
||||||
printf("command: using %d threads\n", wparams.n_threads);
|
printf("command: using %d threads\n", wparams.n_threads);
|
||||||
|
|
||||||
bool is_running = true;
|
|
||||||
bool have_prompt = false;
|
bool have_prompt = false;
|
||||||
bool ask_prompt = true;
|
bool ask_prompt = true;
|
||||||
bool print_energy = false;
|
bool print_energy = false;
|
||||||
@ -233,7 +177,7 @@ void command_main(size_t index) {
|
|||||||
{
|
{
|
||||||
command_get_audio(vad_ms, WHISPER_SAMPLE_RATE, pcmf32_cur);
|
command_get_audio(vad_ms, WHISPER_SAMPLE_RATE, pcmf32_cur);
|
||||||
|
|
||||||
if (command_vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) {
|
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) {
|
||||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||||
command_set_status("Speech detected! Processing ...");
|
command_set_status("Speech detected! Processing ...");
|
||||||
|
|
||||||
|
@ -5,6 +5,5 @@ if (WHISPER_SUPPORT_SDL2)
|
|||||||
|
|
||||||
include(DefaultTargetOptions)
|
include(DefaultTargetOptions)
|
||||||
|
|
||||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||||
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
|
||||||
endif ()
|
endif ()
|
||||||
|
@ -6,11 +6,10 @@
|
|||||||
// ref: https://github.com/ggerganov/whisper.cpp/issues/171
|
// ref: https://github.com/ggerganov/whisper.cpp/issues/171
|
||||||
//
|
//
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
#include "common-sdl.h"
|
||||||
#include "whisper.h"
|
#include "whisper.h"
|
||||||
|
|
||||||
#include <SDL.h>
|
|
||||||
#include <SDL_audio.h>
|
|
||||||
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
@ -110,309 +109,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
// SDL Audio capture
|
|
||||||
//
|
|
||||||
|
|
||||||
class audio_async {
|
|
||||||
public:
|
|
||||||
audio_async(int len_ms);
|
|
||||||
~audio_async();
|
|
||||||
|
|
||||||
bool init(int capture_id, int sample_rate);
|
|
||||||
|
|
||||||
// start capturing audio via the provided SDL callback
|
|
||||||
// keep last len_ms seconds of audio in a circular buffer
|
|
||||||
bool resume();
|
|
||||||
bool pause();
|
|
||||||
bool clear();
|
|
||||||
|
|
||||||
// callback to be called by SDL
|
|
||||||
void callback(uint8_t * stream, int len);
|
|
||||||
|
|
||||||
// get audio data from the circular buffer
|
|
||||||
void get(int ms, std::vector<float> & audio);
|
|
||||||
|
|
||||||
private:
|
|
||||||
SDL_AudioDeviceID m_dev_id_in = 0;
|
|
||||||
|
|
||||||
int m_len_ms = 0;
|
|
||||||
int m_sample_rate = 0;
|
|
||||||
|
|
||||||
bool m_running = false;
|
|
||||||
std::mutex m_mutex;
|
|
||||||
|
|
||||||
std::vector<float> m_audio;
|
|
||||||
std::vector<float> m_audio_new;
|
|
||||||
size_t m_audio_pos = 0;
|
|
||||||
size_t m_audio_len = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
audio_async::audio_async(int len_ms) {
|
|
||||||
m_len_ms = len_ms;
|
|
||||||
}
|
|
||||||
|
|
||||||
audio_async::~audio_async() {
|
|
||||||
if (m_dev_id_in) {
|
|
||||||
SDL_CloseAudioDevice(m_dev_id_in);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool audio_async::init(int capture_id, int sample_rate) {
|
|
||||||
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
|
|
||||||
|
|
||||||
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
|
|
||||||
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
|
|
||||||
|
|
||||||
{
|
|
||||||
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
|
|
||||||
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
|
|
||||||
for (int i = 0; i < nDevices; i++) {
|
|
||||||
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SDL_AudioSpec capture_spec_requested;
|
|
||||||
SDL_AudioSpec capture_spec_obtained;
|
|
||||||
|
|
||||||
SDL_zero(capture_spec_requested);
|
|
||||||
SDL_zero(capture_spec_obtained);
|
|
||||||
|
|
||||||
capture_spec_requested.freq = sample_rate;
|
|
||||||
capture_spec_requested.format = AUDIO_F32;
|
|
||||||
capture_spec_requested.channels = 1;
|
|
||||||
capture_spec_requested.samples = 1024;
|
|
||||||
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
|
|
||||||
audio_async * audio = (audio_async *) userdata;
|
|
||||||
audio->callback(stream, len);
|
|
||||||
};
|
|
||||||
capture_spec_requested.userdata = this;
|
|
||||||
|
|
||||||
if (capture_id >= 0) {
|
|
||||||
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
|
|
||||||
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
|
||||||
} else {
|
|
||||||
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
|
|
||||||
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!m_dev_id_in) {
|
|
||||||
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
|
|
||||||
m_dev_id_in = 0;
|
|
||||||
|
|
||||||
return false;
|
|
||||||
} else {
|
|
||||||
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
|
|
||||||
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
|
|
||||||
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
|
|
||||||
capture_spec_requested.format);
|
|
||||||
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
|
|
||||||
capture_spec_requested.channels);
|
|
||||||
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
|
|
||||||
}
|
|
||||||
|
|
||||||
m_sample_rate = capture_spec_obtained.freq;
|
|
||||||
|
|
||||||
m_audio.resize((m_sample_rate*m_len_ms)/1000);
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool audio_async::resume() {
|
|
||||||
if (!m_dev_id_in) {
|
|
||||||
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (m_running) {
|
|
||||||
fprintf(stderr, "%s: already running!\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
SDL_PauseAudioDevice(m_dev_id_in, 0);
|
|
||||||
|
|
||||||
m_running = true;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool audio_async::pause() {
|
|
||||||
if (!m_dev_id_in) {
|
|
||||||
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!m_running) {
|
|
||||||
fprintf(stderr, "%s: already paused!\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
SDL_PauseAudioDevice(m_dev_id_in, 1);
|
|
||||||
|
|
||||||
m_running = false;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool audio_async::clear() {
|
|
||||||
if (!m_dev_id_in) {
|
|
||||||
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!m_running) {
|
|
||||||
fprintf(stderr, "%s: not running!\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> lock(m_mutex);
|
|
||||||
|
|
||||||
m_audio_pos = 0;
|
|
||||||
m_audio_len = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// callback to be called by SDL
|
|
||||||
void audio_async::callback(uint8_t * stream, int len) {
|
|
||||||
if (!m_running) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const size_t n_samples = len / sizeof(float);
|
|
||||||
|
|
||||||
m_audio_new.resize(n_samples);
|
|
||||||
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
|
|
||||||
|
|
||||||
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
|
|
||||||
|
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> lock(m_mutex);
|
|
||||||
|
|
||||||
if (m_audio_pos + n_samples > m_audio.size()) {
|
|
||||||
const size_t n0 = m_audio.size() - m_audio_pos;
|
|
||||||
|
|
||||||
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
|
|
||||||
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
|
|
||||||
|
|
||||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
|
||||||
m_audio_len = m_audio.size();
|
|
||||||
} else {
|
|
||||||
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
|
|
||||||
|
|
||||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
|
||||||
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void audio_async::get(int ms, std::vector<float> & result) {
|
|
||||||
if (!m_dev_id_in) {
|
|
||||||
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!m_running) {
|
|
||||||
fprintf(stderr, "%s: not running!\n", __func__);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.clear();
|
|
||||||
|
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> lock(m_mutex);
|
|
||||||
|
|
||||||
if (ms <= 0) {
|
|
||||||
ms = m_len_ms;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t n_samples = (m_sample_rate * ms) / 1000;
|
|
||||||
if (n_samples > m_audio_len) {
|
|
||||||
n_samples = m_audio_len;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.resize(n_samples);
|
|
||||||
|
|
||||||
int s0 = m_audio_pos - n_samples;
|
|
||||||
if (s0 < 0) {
|
|
||||||
s0 += m_audio.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (s0 + n_samples > m_audio.size()) {
|
|
||||||
const size_t n0 = m_audio.size() - s0;
|
|
||||||
|
|
||||||
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
|
|
||||||
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
|
|
||||||
} else {
|
|
||||||
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////
|
|
||||||
|
|
||||||
std::string trim(const std::string & s) {
|
|
||||||
std::regex e("^\\s+|\\s+$");
|
|
||||||
return std::regex_replace(s, e, "");
|
|
||||||
}
|
|
||||||
|
|
||||||
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
|
||||||
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
|
||||||
const float dt = 1.0f / sample_rate;
|
|
||||||
const float alpha = dt / (rc + dt);
|
|
||||||
|
|
||||||
float y = data[0];
|
|
||||||
|
|
||||||
for (size_t i = 1; i < data.size(); i++) {
|
|
||||||
y = alpha * (y + data[i] - data[i - 1]);
|
|
||||||
data[i] = y;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
|
||||||
const int n_samples = pcmf32.size();
|
|
||||||
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
|
||||||
|
|
||||||
if (n_samples_last >= n_samples) {
|
|
||||||
// not enough samples - assume no speech
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (freq_thold > 0.0f) {
|
|
||||||
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
|
||||||
}
|
|
||||||
|
|
||||||
float energy_all = 0.0f;
|
|
||||||
float energy_last = 0.0f;
|
|
||||||
|
|
||||||
for (int i = 0; i < n_samples; i++) {
|
|
||||||
energy_all += fabsf(pcmf32[i]);
|
|
||||||
|
|
||||||
if (i >= n_samples - n_samples_last) {
|
|
||||||
energy_last += fabsf(pcmf32[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
energy_all /= n_samples;
|
|
||||||
energy_last /= n_samples_last;
|
|
||||||
|
|
||||||
if (verbose) {
|
|
||||||
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (energy_last > vad_thold*energy_all) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
||||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
@ -502,7 +198,7 @@ std::vector<std::string> read_allowed_commands(const std::string & fname) {
|
|||||||
|
|
||||||
std::string line;
|
std::string line;
|
||||||
while (std::getline(ifs, line)) {
|
while (std::getline(ifs, line)) {
|
||||||
line = trim(line);
|
line = ::trim(line);
|
||||||
if (line.empty()) {
|
if (line.empty()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -526,23 +222,6 @@ std::vector<std::string> get_words(const std::string &txt) {
|
|||||||
return words;
|
return words;
|
||||||
}
|
}
|
||||||
|
|
||||||
// returns true if no exit event was received
|
|
||||||
bool process_sdl_events() {
|
|
||||||
SDL_Event event;
|
|
||||||
while (SDL_PollEvent(&event)) {
|
|
||||||
switch (event.type) {
|
|
||||||
case SDL_QUIT:
|
|
||||||
{
|
|
||||||
return false;
|
|
||||||
} break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// command-list mode
|
// command-list mode
|
||||||
// guide the transcription to match the most likely command from a provided list
|
// guide the transcription to match the most likely command from a provided list
|
||||||
int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
||||||
@ -634,14 +313,14 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
|
|||||||
// main loop
|
// main loop
|
||||||
while (is_running) {
|
while (is_running) {
|
||||||
// handle Ctrl + C
|
// handle Ctrl + C
|
||||||
is_running = process_sdl_events();
|
is_running = sdl_poll_events();
|
||||||
|
|
||||||
// delay
|
// delay
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||||
|
|
||||||
audio.get(2000, pcmf32_cur);
|
audio.get(2000, pcmf32_cur);
|
||||||
|
|
||||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||||
|
|
||||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||||
@ -775,7 +454,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
|||||||
// main loop
|
// main loop
|
||||||
while (is_running) {
|
while (is_running) {
|
||||||
// handle Ctrl + C
|
// handle Ctrl + C
|
||||||
is_running = process_sdl_events();
|
is_running = sdl_poll_events();
|
||||||
|
|
||||||
// delay
|
// delay
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||||
@ -791,7 +470,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
|||||||
{
|
{
|
||||||
audio.get(2000, pcmf32_cur);
|
audio.get(2000, pcmf32_cur);
|
||||||
|
|
||||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||||
|
|
||||||
int64_t t_ms = 0;
|
int64_t t_ms = 0;
|
||||||
@ -854,7 +533,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
|
|||||||
// main loop
|
// main loop
|
||||||
while (is_running) {
|
while (is_running) {
|
||||||
// handle Ctrl + C
|
// handle Ctrl + C
|
||||||
is_running = process_sdl_events();
|
is_running = sdl_poll_events();
|
||||||
|
|
||||||
// delay
|
// delay
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||||
@ -870,7 +549,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
|
|||||||
{
|
{
|
||||||
audio.get(2000, pcmf32_cur);
|
audio.get(2000, pcmf32_cur);
|
||||||
|
|
||||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||||
|
|
||||||
int64_t t_ms = 0;
|
int64_t t_ms = 0;
|
||||||
|
226
examples/common-sdl.cpp
Normal file
226
examples/common-sdl.cpp
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
#include "common-sdl.h"
|
||||||
|
|
||||||
|
audio_async::audio_async(int len_ms) {
|
||||||
|
m_len_ms = len_ms;
|
||||||
|
|
||||||
|
m_running = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
audio_async::~audio_async() {
|
||||||
|
if (m_dev_id_in) {
|
||||||
|
SDL_CloseAudioDevice(m_dev_id_in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool audio_async::init(int capture_id, int sample_rate) {
|
||||||
|
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
|
||||||
|
|
||||||
|
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
|
||||||
|
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
|
||||||
|
|
||||||
|
{
|
||||||
|
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
|
||||||
|
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
|
||||||
|
for (int i = 0; i < nDevices; i++) {
|
||||||
|
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SDL_AudioSpec capture_spec_requested;
|
||||||
|
SDL_AudioSpec capture_spec_obtained;
|
||||||
|
|
||||||
|
SDL_zero(capture_spec_requested);
|
||||||
|
SDL_zero(capture_spec_obtained);
|
||||||
|
|
||||||
|
capture_spec_requested.freq = sample_rate;
|
||||||
|
capture_spec_requested.format = AUDIO_F32;
|
||||||
|
capture_spec_requested.channels = 1;
|
||||||
|
capture_spec_requested.samples = 1024;
|
||||||
|
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
|
||||||
|
audio_async * audio = (audio_async *) userdata;
|
||||||
|
audio->callback(stream, len);
|
||||||
|
};
|
||||||
|
capture_spec_requested.userdata = this;
|
||||||
|
|
||||||
|
if (capture_id >= 0) {
|
||||||
|
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
|
||||||
|
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||||
|
} else {
|
||||||
|
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
|
||||||
|
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!m_dev_id_in) {
|
||||||
|
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
|
||||||
|
m_dev_id_in = 0;
|
||||||
|
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
|
||||||
|
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
|
||||||
|
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
|
||||||
|
capture_spec_requested.format);
|
||||||
|
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
|
||||||
|
capture_spec_requested.channels);
|
||||||
|
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
|
||||||
|
}
|
||||||
|
|
||||||
|
m_sample_rate = capture_spec_obtained.freq;
|
||||||
|
|
||||||
|
m_audio.resize((m_sample_rate*m_len_ms)/1000);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool audio_async::resume() {
|
||||||
|
if (!m_dev_id_in) {
|
||||||
|
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (m_running) {
|
||||||
|
fprintf(stderr, "%s: already running!\n", __func__);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
SDL_PauseAudioDevice(m_dev_id_in, 0);
|
||||||
|
|
||||||
|
m_running = true;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool audio_async::pause() {
|
||||||
|
if (!m_dev_id_in) {
|
||||||
|
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!m_running) {
|
||||||
|
fprintf(stderr, "%s: already paused!\n", __func__);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
SDL_PauseAudioDevice(m_dev_id_in, 1);
|
||||||
|
|
||||||
|
m_running = false;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool audio_async::clear() {
|
||||||
|
if (!m_dev_id_in) {
|
||||||
|
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!m_running) {
|
||||||
|
fprintf(stderr, "%s: not running!\n", __func__);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(m_mutex);
|
||||||
|
|
||||||
|
m_audio_pos = 0;
|
||||||
|
m_audio_len = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// callback to be called by SDL
|
||||||
|
void audio_async::callback(uint8_t * stream, int len) {
|
||||||
|
if (!m_running) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t n_samples = len / sizeof(float);
|
||||||
|
|
||||||
|
m_audio_new.resize(n_samples);
|
||||||
|
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
|
||||||
|
|
||||||
|
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
|
||||||
|
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(m_mutex);
|
||||||
|
|
||||||
|
if (m_audio_pos + n_samples > m_audio.size()) {
|
||||||
|
const size_t n0 = m_audio.size() - m_audio_pos;
|
||||||
|
|
||||||
|
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
|
||||||
|
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
|
||||||
|
|
||||||
|
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||||
|
m_audio_len = m_audio.size();
|
||||||
|
} else {
|
||||||
|
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
|
||||||
|
|
||||||
|
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||||
|
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void audio_async::get(int ms, std::vector<float> & result) {
|
||||||
|
if (!m_dev_id_in) {
|
||||||
|
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!m_running) {
|
||||||
|
fprintf(stderr, "%s: not running!\n", __func__);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
result.clear();
|
||||||
|
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(m_mutex);
|
||||||
|
|
||||||
|
if (ms <= 0) {
|
||||||
|
ms = m_len_ms;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t n_samples = (m_sample_rate * ms) / 1000;
|
||||||
|
if (n_samples > m_audio_len) {
|
||||||
|
n_samples = m_audio_len;
|
||||||
|
}
|
||||||
|
|
||||||
|
result.resize(n_samples);
|
||||||
|
|
||||||
|
int s0 = m_audio_pos - n_samples;
|
||||||
|
if (s0 < 0) {
|
||||||
|
s0 += m_audio.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (s0 + n_samples > m_audio.size()) {
|
||||||
|
const size_t n0 = m_audio.size() - s0;
|
||||||
|
|
||||||
|
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
|
||||||
|
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
|
||||||
|
} else {
|
||||||
|
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool sdl_poll_events() {
|
||||||
|
SDL_Event event;
|
||||||
|
while (SDL_PollEvent(&event)) {
|
||||||
|
switch (event.type) {
|
||||||
|
case SDL_QUIT:
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
50
examples/common-sdl.h
Normal file
50
examples/common-sdl.h
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <SDL.h>
|
||||||
|
#include <SDL_audio.h>
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <vector>
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
|
//
|
||||||
|
// SDL Audio capture
|
||||||
|
//
|
||||||
|
|
||||||
|
class audio_async {
|
||||||
|
public:
|
||||||
|
audio_async(int len_ms);
|
||||||
|
~audio_async();
|
||||||
|
|
||||||
|
bool init(int capture_id, int sample_rate);
|
||||||
|
|
||||||
|
// start capturing audio via the provided SDL callback
|
||||||
|
// keep last len_ms seconds of audio in a circular buffer
|
||||||
|
bool resume();
|
||||||
|
bool pause();
|
||||||
|
bool clear();
|
||||||
|
|
||||||
|
// callback to be called by SDL
|
||||||
|
void callback(uint8_t * stream, int len);
|
||||||
|
|
||||||
|
// get audio data from the circular buffer
|
||||||
|
void get(int ms, std::vector<float> & audio);
|
||||||
|
|
||||||
|
private:
|
||||||
|
SDL_AudioDeviceID m_dev_id_in = 0;
|
||||||
|
|
||||||
|
int m_len_ms = 0;
|
||||||
|
int m_sample_rate = 0;
|
||||||
|
|
||||||
|
std::atomic_bool m_running;
|
||||||
|
std::mutex m_mutex;
|
||||||
|
|
||||||
|
std::vector<float> m_audio;
|
||||||
|
std::vector<float> m_audio_new;
|
||||||
|
size_t m_audio_pos = 0;
|
||||||
|
size_t m_audio_len = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Return false if need to quit
|
||||||
|
bool sdl_poll_events();
|
162
examples/common.cpp
Normal file
162
examples/common.cpp
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
#include "common.h"
|
||||||
|
|
||||||
|
// third-party utilities
|
||||||
|
// use your favorite implementations
|
||||||
|
#define DR_WAV_IMPLEMENTATION
|
||||||
|
#include "dr_wav.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <regex>
|
||||||
|
|
||||||
|
#ifndef M_PI
|
||||||
|
#define M_PI 3.14159265358979323846
|
||||||
|
#endif
|
||||||
|
|
||||||
|
std::string trim(const std::string & s) {
|
||||||
|
std::regex e("^\\s+|\\s+$");
|
||||||
|
return std::regex_replace(s, e, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string replace(const std::string & s, const std::string & from, const std::string & to) {
|
||||||
|
std::string result = s;
|
||||||
|
size_t pos = 0;
|
||||||
|
while ((pos = result.find(from, pos)) != std::string::npos) {
|
||||||
|
result.replace(pos, from.length(), to);
|
||||||
|
pos += to.length();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
|
||||||
|
drwav wav;
|
||||||
|
std::vector<uint8_t> wav_data; // used for pipe input from stdin
|
||||||
|
|
||||||
|
if (fname == "-") {
|
||||||
|
{
|
||||||
|
uint8_t buf[1024];
|
||||||
|
while (true)
|
||||||
|
{
|
||||||
|
const size_t n = fread(buf, 1, sizeof(buf), stdin);
|
||||||
|
if (n == 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
wav_data.insert(wav_data.end(), buf, buf + n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
|
||||||
|
fprintf(stderr, "error: failed to open WAV file from stdin\n");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
|
||||||
|
}
|
||||||
|
else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) {
|
||||||
|
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (wav.channels != 1 && wav.channels != 2) {
|
||||||
|
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (stereo && wav.channels != 2) {
|
||||||
|
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (wav.sampleRate != COMMON_SAMPLE_RATE) {
|
||||||
|
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE/1000);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (wav.bitsPerSample != 16) {
|
||||||
|
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
|
||||||
|
|
||||||
|
std::vector<int16_t> pcm16;
|
||||||
|
pcm16.resize(n*wav.channels);
|
||||||
|
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
|
||||||
|
drwav_uninit(&wav);
|
||||||
|
|
||||||
|
// convert to mono, float
|
||||||
|
pcmf32.resize(n);
|
||||||
|
if (wav.channels == 1) {
|
||||||
|
for (uint64_t i = 0; i < n; i++) {
|
||||||
|
pcmf32[i] = float(pcm16[i])/32768.0f;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (uint64_t i = 0; i < n; i++) {
|
||||||
|
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (stereo) {
|
||||||
|
// convert to stereo, float
|
||||||
|
pcmf32s.resize(2);
|
||||||
|
|
||||||
|
pcmf32s[0].resize(n);
|
||||||
|
pcmf32s[1].resize(n);
|
||||||
|
for (uint64_t i = 0; i < n; i++) {
|
||||||
|
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
|
||||||
|
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
||||||
|
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
||||||
|
const float dt = 1.0f / sample_rate;
|
||||||
|
const float alpha = dt / (rc + dt);
|
||||||
|
|
||||||
|
float y = data[0];
|
||||||
|
|
||||||
|
for (size_t i = 1; i < data.size(); i++) {
|
||||||
|
y = alpha * (y + data[i] - data[i - 1]);
|
||||||
|
data[i] = y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
||||||
|
const int n_samples = pcmf32.size();
|
||||||
|
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
||||||
|
|
||||||
|
if (n_samples_last >= n_samples) {
|
||||||
|
// not enough samples - assume no speech
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (freq_thold > 0.0f) {
|
||||||
|
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
||||||
|
}
|
||||||
|
|
||||||
|
float energy_all = 0.0f;
|
||||||
|
float energy_last = 0.0f;
|
||||||
|
|
||||||
|
for (int i = 0; i < n_samples; i++) {
|
||||||
|
energy_all += fabsf(pcmf32[i]);
|
||||||
|
|
||||||
|
if (i >= n_samples - n_samples_last) {
|
||||||
|
energy_last += fabsf(pcmf32[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
energy_all /= n_samples;
|
||||||
|
energy_last /= n_samples_last;
|
||||||
|
|
||||||
|
if (verbose) {
|
||||||
|
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (energy_last > vad_thold*energy_all) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
40
examples/common.h
Normal file
40
examples/common.h
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
// needs to match WHISPER_SAMPLE_RATE
|
||||||
|
#define COMMON_SAMPLE_RATE 16000
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
std::string trim(const std::string & s);
|
||||||
|
|
||||||
|
std::string replace(
|
||||||
|
const std::string & s,
|
||||||
|
const std::string & from,
|
||||||
|
const std::string & to);
|
||||||
|
|
||||||
|
// Read WAV audio file and store the PCM data into pcmf32
|
||||||
|
// The sample rate of the audio must be equal to COMMON_SAMPLE_RATE
|
||||||
|
// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM
|
||||||
|
bool read_wav(
|
||||||
|
const std::string & fname,
|
||||||
|
std::vector<float> & pcmf32,
|
||||||
|
std::vector<std::vector<float>> & pcmf32s,
|
||||||
|
bool stereo);
|
||||||
|
|
||||||
|
// Apply a high-pass frequency filter to PCM audio
|
||||||
|
// Suppresses frequencies below cutoff Hz
|
||||||
|
void high_pass_filter(
|
||||||
|
std::vector<float> & data,
|
||||||
|
float cutoff,
|
||||||
|
float sample_rate);
|
||||||
|
|
||||||
|
// Basic voice activity detection (VAD) using audio energy adaptive threshold
|
||||||
|
bool vad_simple(
|
||||||
|
std::vector<float> & pcmf32,
|
||||||
|
int sample_rate,
|
||||||
|
int last_ms,
|
||||||
|
float vad_thold,
|
||||||
|
float freq_thold,
|
||||||
|
bool verbose);
|
||||||
|
|
@ -3,4 +3,4 @@ add_executable(${TARGET} main.cpp)
|
|||||||
|
|
||||||
include(DefaultTargetOptions)
|
include(DefaultTargetOptions)
|
||||||
|
|
||||||
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
|
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||||
|
@ -1,9 +1,6 @@
|
|||||||
#include "whisper.h"
|
#include "common.h"
|
||||||
|
|
||||||
// third-party utilities
|
#include "whisper.h"
|
||||||
// use your favorite implementations
|
|
||||||
#define DR_WAV_IMPLEMENTATION
|
|
||||||
#include "dr_wav.h"
|
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
@ -86,7 +83,7 @@ struct whisper_params {
|
|||||||
std::string model = "models/ggml-base.en.bin";
|
std::string model = "models/ggml-base.en.bin";
|
||||||
|
|
||||||
std::vector<std::string> fname_inp = {};
|
std::vector<std::string> fname_inp = {};
|
||||||
std::vector<std::string> fname_outp = {};
|
std::vector<std::string> fname_out = {};
|
||||||
};
|
};
|
||||||
|
|
||||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||||
@ -95,6 +92,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|||||||
for (int i = 1; i < argc; i++) {
|
for (int i = 1; i < argc; i++) {
|
||||||
std::string arg = argv[i];
|
std::string arg = argv[i];
|
||||||
|
|
||||||
|
if (arg == "-"){
|
||||||
|
params.fname_inp.push_back(arg);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
if (arg[0] != '-') {
|
if (arg[0] != '-') {
|
||||||
params.fname_inp.push_back(arg);
|
params.fname_inp.push_back(arg);
|
||||||
continue;
|
continue;
|
||||||
@ -126,7 +128,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|||||||
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
|
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
|
||||||
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
|
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
|
||||||
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
|
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
|
||||||
else if (arg == "-of" || arg == "--output-file") { params.fname_outp.emplace_back(argv[++i]); }
|
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
|
||||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||||
@ -520,91 +522,14 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
||||||
const auto fname_inp = params.fname_inp[f];
|
const auto fname_inp = params.fname_inp[f];
|
||||||
const auto fname_outp = f < (int) params.fname_outp.size() && !params.fname_outp[f].empty() ? params.fname_outp[f] : params.fname_inp[f];
|
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
|
||||||
|
|
||||||
std::vector<float> pcmf32; // mono-channel F32 PCM
|
std::vector<float> pcmf32; // mono-channel F32 PCM
|
||||||
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
|
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
|
||||||
|
|
||||||
// WAV input
|
if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
|
||||||
{
|
fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
|
||||||
drwav wav;
|
continue;
|
||||||
std::vector<uint8_t> wav_data; // used for pipe input from stdin
|
|
||||||
|
|
||||||
if (fname_inp == "-") {
|
|
||||||
{
|
|
||||||
uint8_t buf[1024];
|
|
||||||
while (true)
|
|
||||||
{
|
|
||||||
const size_t n = fread(buf, 1, sizeof(buf), stdin);
|
|
||||||
if (n == 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
wav_data.insert(wav_data.end(), buf, buf + n);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
|
|
||||||
fprintf(stderr, "error: failed to open WAV file from stdin\n");
|
|
||||||
return 4;
|
|
||||||
}
|
|
||||||
|
|
||||||
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
|
|
||||||
}
|
|
||||||
else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
|
|
||||||
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
|
||||||
return 5;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (wav.channels != 1 && wav.channels != 2) {
|
|
||||||
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
|
|
||||||
return 6;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
|
|
||||||
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
|
|
||||||
return 6;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
|
|
||||||
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
|
|
||||||
return 8;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (wav.bitsPerSample != 16) {
|
|
||||||
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
|
|
||||||
return 9;
|
|
||||||
}
|
|
||||||
|
|
||||||
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
|
|
||||||
|
|
||||||
std::vector<int16_t> pcm16;
|
|
||||||
pcm16.resize(n*wav.channels);
|
|
||||||
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
|
|
||||||
drwav_uninit(&wav);
|
|
||||||
|
|
||||||
// convert to mono, float
|
|
||||||
pcmf32.resize(n);
|
|
||||||
if (wav.channels == 1) {
|
|
||||||
for (uint64_t i = 0; i < n; i++) {
|
|
||||||
pcmf32[i] = float(pcm16[i])/32768.0f;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (uint64_t i = 0; i < n; i++) {
|
|
||||||
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.diarize) {
|
|
||||||
// convert to stereo, float
|
|
||||||
pcmf32s.resize(2);
|
|
||||||
|
|
||||||
pcmf32s[0].resize(n);
|
|
||||||
pcmf32s[1].resize(n);
|
|
||||||
for (uint64_t i = 0; i < n; i++) {
|
|
||||||
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
|
|
||||||
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// print system information
|
// print system information
|
||||||
@ -701,34 +626,33 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
// output to text file
|
// output to text file
|
||||||
if (params.output_txt) {
|
if (params.output_txt) {
|
||||||
const auto fname_txt = fname_outp + ".txt";
|
const auto fname_txt = fname_out + ".txt";
|
||||||
output_txt(ctx, fname_txt.c_str());
|
output_txt(ctx, fname_txt.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
// output to VTT file
|
// output to VTT file
|
||||||
if (params.output_vtt) {
|
if (params.output_vtt) {
|
||||||
const auto fname_vtt = fname_outp + ".vtt";
|
const auto fname_vtt = fname_out + ".vtt";
|
||||||
output_vtt(ctx, fname_vtt.c_str());
|
output_vtt(ctx, fname_vtt.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
// output to SRT file
|
// output to SRT file
|
||||||
if (params.output_srt) {
|
if (params.output_srt) {
|
||||||
const auto fname_srt = fname_outp + ".srt";
|
const auto fname_srt = fname_out + ".srt";
|
||||||
output_srt(ctx, fname_srt.c_str(), params);
|
output_srt(ctx, fname_srt.c_str(), params);
|
||||||
}
|
}
|
||||||
|
|
||||||
// output to WTS file
|
// output to WTS file
|
||||||
if (params.output_wts) {
|
if (params.output_wts) {
|
||||||
const auto fname_wts = fname_outp + ".wts";
|
const auto fname_wts = fname_out + ".wts";
|
||||||
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
|
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
|
||||||
}
|
}
|
||||||
|
|
||||||
// output to CSV file
|
// output to CSV file
|
||||||
if (params.output_csv) {
|
if (params.output_csv) {
|
||||||
const auto fname_csv = fname_outp + ".csv";
|
const auto fname_csv = fname_out + ".csv";
|
||||||
output_csv(ctx, fname_csv.c_str());
|
output_csv(ctx, fname_csv.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,6 +5,5 @@ if (WHISPER_SUPPORT_SDL2)
|
|||||||
|
|
||||||
include(DefaultTargetOptions)
|
include(DefaultTargetOptions)
|
||||||
|
|
||||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||||
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
|
||||||
endif ()
|
endif ()
|
||||||
|
@ -3,19 +3,16 @@
|
|||||||
// A very quick-n-dirty implementation serving mainly as a proof of concept.
|
// A very quick-n-dirty implementation serving mainly as a proof of concept.
|
||||||
//
|
//
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
#include "common-sdl.h"
|
||||||
#include "whisper.h"
|
#include "whisper.h"
|
||||||
|
|
||||||
#include <SDL.h>
|
|
||||||
#include <SDL_audio.h>
|
|
||||||
|
|
||||||
#include <atomic>
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <mutex>
|
|
||||||
|
|
||||||
// 500 -> 00:05.000
|
// 500 -> 00:05.000
|
||||||
// 6000 -> 01:00.000
|
// 6000 -> 01:00.000
|
||||||
@ -116,306 +113,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
// SDL Audio capture
|
|
||||||
//
|
|
||||||
|
|
||||||
class audio_async {
|
|
||||||
public:
|
|
||||||
audio_async(int len_ms);
|
|
||||||
~audio_async();
|
|
||||||
|
|
||||||
bool init(int capture_id, int sample_rate);
|
|
||||||
|
|
||||||
// start capturing audio via the provided SDL callback
|
|
||||||
// keep last len_ms seconds of audio in a circular buffer
|
|
||||||
bool resume();
|
|
||||||
bool pause();
|
|
||||||
bool clear();
|
|
||||||
|
|
||||||
// callback to be called by SDL
|
|
||||||
void callback(uint8_t * stream, int len);
|
|
||||||
|
|
||||||
// get audio data from the circular buffer
|
|
||||||
void get(int ms, std::vector<float> & audio);
|
|
||||||
|
|
||||||
private:
|
|
||||||
SDL_AudioDeviceID m_dev_id_in = 0;
|
|
||||||
|
|
||||||
int m_len_ms = 0;
|
|
||||||
int m_sample_rate = 0;
|
|
||||||
|
|
||||||
std::atomic_bool m_running;
|
|
||||||
std::mutex m_mutex;
|
|
||||||
|
|
||||||
std::vector<float> m_audio;
|
|
||||||
std::vector<float> m_audio_new;
|
|
||||||
size_t m_audio_pos = 0;
|
|
||||||
size_t m_audio_len = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
audio_async::audio_async(int len_ms) {
|
|
||||||
m_len_ms = len_ms;
|
|
||||||
|
|
||||||
m_running = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
audio_async::~audio_async() {
|
|
||||||
if (m_dev_id_in) {
|
|
||||||
SDL_CloseAudioDevice(m_dev_id_in);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool audio_async::init(int capture_id, int sample_rate) {
|
|
||||||
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
|
|
||||||
|
|
||||||
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
|
|
||||||
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
|
|
||||||
|
|
||||||
{
|
|
||||||
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
|
|
||||||
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
|
|
||||||
for (int i = 0; i < nDevices; i++) {
|
|
||||||
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SDL_AudioSpec capture_spec_requested;
|
|
||||||
SDL_AudioSpec capture_spec_obtained;
|
|
||||||
|
|
||||||
SDL_zero(capture_spec_requested);
|
|
||||||
SDL_zero(capture_spec_obtained);
|
|
||||||
|
|
||||||
capture_spec_requested.freq = sample_rate;
|
|
||||||
capture_spec_requested.format = AUDIO_F32;
|
|
||||||
capture_spec_requested.channels = 1;
|
|
||||||
capture_spec_requested.samples = 1024;
|
|
||||||
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
|
|
||||||
audio_async * audio = (audio_async *) userdata;
|
|
||||||
audio->callback(stream, len);
|
|
||||||
};
|
|
||||||
capture_spec_requested.userdata = this;
|
|
||||||
|
|
||||||
if (capture_id >= 0) {
|
|
||||||
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
|
|
||||||
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
|
||||||
} else {
|
|
||||||
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
|
|
||||||
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!m_dev_id_in) {
|
|
||||||
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
|
|
||||||
m_dev_id_in = 0;
|
|
||||||
|
|
||||||
return false;
|
|
||||||
} else {
|
|
||||||
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
|
|
||||||
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
|
|
||||||
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
|
|
||||||
capture_spec_requested.format);
|
|
||||||
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
|
|
||||||
capture_spec_requested.channels);
|
|
||||||
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
|
|
||||||
}
|
|
||||||
|
|
||||||
m_sample_rate = capture_spec_obtained.freq;
|
|
||||||
|
|
||||||
m_audio.resize((m_sample_rate*m_len_ms)/1000);
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool audio_async::resume() {
|
|
||||||
if (!m_dev_id_in) {
|
|
||||||
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (m_running) {
|
|
||||||
fprintf(stderr, "%s: already running!\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
SDL_PauseAudioDevice(m_dev_id_in, 0);
|
|
||||||
|
|
||||||
m_running = true;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool audio_async::pause() {
|
|
||||||
if (!m_dev_id_in) {
|
|
||||||
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!m_running) {
|
|
||||||
fprintf(stderr, "%s: already paused!\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
SDL_PauseAudioDevice(m_dev_id_in, 1);
|
|
||||||
|
|
||||||
m_running = false;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool audio_async::clear() {
|
|
||||||
if (!m_dev_id_in) {
|
|
||||||
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!m_running) {
|
|
||||||
fprintf(stderr, "%s: not running!\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> lock(m_mutex);
|
|
||||||
|
|
||||||
m_audio_pos = 0;
|
|
||||||
m_audio_len = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// callback to be called by SDL
|
|
||||||
void audio_async::callback(uint8_t * stream, int len) {
|
|
||||||
if (!m_running) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const size_t n_samples = len / sizeof(float);
|
|
||||||
|
|
||||||
m_audio_new.resize(n_samples);
|
|
||||||
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
|
|
||||||
|
|
||||||
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
|
|
||||||
|
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> lock(m_mutex);
|
|
||||||
|
|
||||||
if (m_audio_pos + n_samples > m_audio.size()) {
|
|
||||||
const size_t n0 = m_audio.size() - m_audio_pos;
|
|
||||||
|
|
||||||
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
|
|
||||||
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
|
|
||||||
|
|
||||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
|
||||||
m_audio_len = m_audio.size();
|
|
||||||
} else {
|
|
||||||
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
|
|
||||||
|
|
||||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
|
||||||
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void audio_async::get(int ms, std::vector<float> & result) {
|
|
||||||
if (!m_dev_id_in) {
|
|
||||||
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!m_running) {
|
|
||||||
fprintf(stderr, "%s: not running!\n", __func__);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.clear();
|
|
||||||
|
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> lock(m_mutex);
|
|
||||||
|
|
||||||
if (ms <= 0) {
|
|
||||||
ms = m_len_ms;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t n_samples = (m_sample_rate * ms) / 1000;
|
|
||||||
if (n_samples > m_audio_len) {
|
|
||||||
n_samples = m_audio_len;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.resize(n_samples);
|
|
||||||
|
|
||||||
int s0 = m_audio_pos - n_samples;
|
|
||||||
if (s0 < 0) {
|
|
||||||
s0 += m_audio.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (s0 + n_samples > m_audio.size()) {
|
|
||||||
const size_t n0 = m_audio.size() - s0;
|
|
||||||
|
|
||||||
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
|
|
||||||
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
|
|
||||||
} else {
|
|
||||||
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////
|
|
||||||
|
|
||||||
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
|
||||||
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
|
||||||
const float dt = 1.0f / sample_rate;
|
|
||||||
const float alpha = dt / (rc + dt);
|
|
||||||
|
|
||||||
float y = data[0];
|
|
||||||
|
|
||||||
for (size_t i = 1; i < data.size(); i++) {
|
|
||||||
y = alpha * (y + data[i] - data[i - 1]);
|
|
||||||
data[i] = y;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
|
||||||
const int n_samples = pcmf32.size();
|
|
||||||
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
|
||||||
|
|
||||||
if (n_samples_last >= n_samples) {
|
|
||||||
// not enough samples - assume no speech
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (freq_thold > 0.0f) {
|
|
||||||
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
|
||||||
}
|
|
||||||
|
|
||||||
float energy_all = 0.0f;
|
|
||||||
float energy_last = 0.0f;
|
|
||||||
|
|
||||||
for (int i = 0; i < n_samples; i++) {
|
|
||||||
energy_all += fabsf(pcmf32[i]);
|
|
||||||
|
|
||||||
if (i >= n_samples - n_samples_last) {
|
|
||||||
energy_last += fabsf(pcmf32[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
energy_all /= n_samples;
|
|
||||||
energy_last /= n_samples_last;
|
|
||||||
|
|
||||||
if (verbose) {
|
|
||||||
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (energy_last > vad_thold*energy_all) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
whisper_params params;
|
whisper_params params;
|
||||||
|
|
||||||
@ -426,10 +123,10 @@ int main(int argc, char ** argv) {
|
|||||||
params.keep_ms = std::min(params.keep_ms, params.step_ms);
|
params.keep_ms = std::min(params.keep_ms, params.step_ms);
|
||||||
params.length_ms = std::max(params.length_ms, params.step_ms);
|
params.length_ms = std::max(params.length_ms, params.step_ms);
|
||||||
|
|
||||||
const int n_samples_step = (params.step_ms *1e-3)*WHISPER_SAMPLE_RATE;
|
const int n_samples_step = (1e-3*params.step_ms )*WHISPER_SAMPLE_RATE;
|
||||||
const int n_samples_len = (params.length_ms*1e-3)*WHISPER_SAMPLE_RATE;
|
const int n_samples_len = (1e-3*params.length_ms)*WHISPER_SAMPLE_RATE;
|
||||||
const int n_samples_keep = (params.keep_ms *1e-3)*WHISPER_SAMPLE_RATE;
|
const int n_samples_keep = (1e-3*params.keep_ms )*WHISPER_SAMPLE_RATE;
|
||||||
const int n_samples_30s = (30000 *1e-3)*WHISPER_SAMPLE_RATE;
|
const int n_samples_30s = (1e-3*30000.0 )*WHISPER_SAMPLE_RATE;
|
||||||
|
|
||||||
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
|
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
|
||||||
|
|
||||||
@ -517,23 +214,7 @@ int main(int argc, char ** argv) {
|
|||||||
// main audio loop
|
// main audio loop
|
||||||
while (is_running) {
|
while (is_running) {
|
||||||
// handle Ctrl + C
|
// handle Ctrl + C
|
||||||
{
|
is_running = sdl_poll_events();
|
||||||
SDL_Event event;
|
|
||||||
while (SDL_PollEvent(&event)) {
|
|
||||||
switch (event.type) {
|
|
||||||
case SDL_QUIT:
|
|
||||||
{
|
|
||||||
is_running = false;
|
|
||||||
} break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!is_running) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!is_running) {
|
if (!is_running) {
|
||||||
break;
|
break;
|
||||||
@ -556,7 +237,7 @@ int main(int argc, char ** argv) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
SDL_Delay(1);
|
std::this_thread::sleep_for(std::chrono::milliseconds(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
const int n_samples_new = pcmf32_new.size();
|
const int n_samples_new = pcmf32_new.size();
|
||||||
@ -587,7 +268,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
audio.get(2000, pcmf32_new);
|
audio.get(2000, pcmf32_new);
|
||||||
|
|
||||||
if (vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) {
|
if (::vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) {
|
||||||
audio.get(params.length_ms, pcmf32);
|
audio.get(params.length_ms, pcmf32);
|
||||||
} else {
|
} else {
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||||
|
@ -7,7 +7,7 @@ if (WHISPER_SUPPORT_SDL2)
|
|||||||
|
|
||||||
# TODO: this is temporary
|
# TODO: this is temporary
|
||||||
# need to export ggml symbols for MSVC, but too lazy ..
|
# need to export ggml symbols for MSVC, but too lazy ..
|
||||||
add_executable(${TARGET} talk.cpp gpt-2.cpp ../../ggml.c ../../whisper.cpp)
|
add_executable(${TARGET} talk.cpp gpt-2.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../whisper.cpp)
|
||||||
|
|
||||||
include(DefaultTargetOptions)
|
include(DefaultTargetOptions)
|
||||||
|
|
||||||
|
@ -1,16 +1,14 @@
|
|||||||
// Talk with AI
|
// Talk with AI
|
||||||
//
|
//
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
#include "common-sdl.h"
|
||||||
#include "whisper.h"
|
#include "whisper.h"
|
||||||
#include "gpt-2.h"
|
#include "gpt-2.h"
|
||||||
|
|
||||||
#include <SDL.h>
|
|
||||||
#include <SDL_audio.h>
|
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <mutex>
|
|
||||||
#include <regex>
|
#include <regex>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
@ -105,320 +103,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
// SDL Audio capture
|
|
||||||
//
|
|
||||||
|
|
||||||
class audio_async {
|
|
||||||
public:
|
|
||||||
audio_async(int len_ms);
|
|
||||||
~audio_async();
|
|
||||||
|
|
||||||
bool init(int capture_id, int sample_rate);
|
|
||||||
|
|
||||||
// start capturing audio via the provided SDL callback
|
|
||||||
// keep last len_ms seconds of audio in a circular buffer
|
|
||||||
bool resume();
|
|
||||||
bool pause();
|
|
||||||
bool clear();
|
|
||||||
|
|
||||||
// callback to be called by SDL
|
|
||||||
void callback(uint8_t * stream, int len);
|
|
||||||
|
|
||||||
// get audio data from the circular buffer
|
|
||||||
void get(int ms, std::vector<float> & audio);
|
|
||||||
|
|
||||||
private:
|
|
||||||
SDL_AudioDeviceID m_dev_id_in = 0;
|
|
||||||
|
|
||||||
int m_len_ms = 0;
|
|
||||||
int m_sample_rate = 0;
|
|
||||||
|
|
||||||
bool m_running = false;
|
|
||||||
std::mutex m_mutex;
|
|
||||||
|
|
||||||
std::vector<float> m_audio;
|
|
||||||
std::vector<float> m_audio_new;
|
|
||||||
size_t m_audio_pos = 0;
|
|
||||||
size_t m_audio_len = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
audio_async::audio_async(int len_ms) {
|
|
||||||
m_len_ms = len_ms;
|
|
||||||
}
|
|
||||||
|
|
||||||
audio_async::~audio_async() {
|
|
||||||
if (m_dev_id_in) {
|
|
||||||
SDL_CloseAudioDevice(m_dev_id_in);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool audio_async::init(int capture_id, int sample_rate) {
|
|
||||||
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
|
|
||||||
|
|
||||||
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
|
|
||||||
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
|
|
||||||
|
|
||||||
{
|
|
||||||
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
|
|
||||||
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
|
|
||||||
for (int i = 0; i < nDevices; i++) {
|
|
||||||
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SDL_AudioSpec capture_spec_requested;
|
|
||||||
SDL_AudioSpec capture_spec_obtained;
|
|
||||||
|
|
||||||
SDL_zero(capture_spec_requested);
|
|
||||||
SDL_zero(capture_spec_obtained);
|
|
||||||
|
|
||||||
capture_spec_requested.freq = sample_rate;
|
|
||||||
capture_spec_requested.format = AUDIO_F32;
|
|
||||||
capture_spec_requested.channels = 1;
|
|
||||||
capture_spec_requested.samples = 1024;
|
|
||||||
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
|
|
||||||
audio_async * audio = (audio_async *) userdata;
|
|
||||||
audio->callback(stream, len);
|
|
||||||
};
|
|
||||||
capture_spec_requested.userdata = this;
|
|
||||||
|
|
||||||
if (capture_id >= 0) {
|
|
||||||
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
|
|
||||||
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
|
||||||
} else {
|
|
||||||
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
|
|
||||||
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!m_dev_id_in) {
|
|
||||||
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
|
|
||||||
m_dev_id_in = 0;
|
|
||||||
|
|
||||||
return false;
|
|
||||||
} else {
|
|
||||||
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
|
|
||||||
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
|
|
||||||
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
|
|
||||||
capture_spec_requested.format);
|
|
||||||
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
|
|
||||||
capture_spec_requested.channels);
|
|
||||||
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
m_sample_rate = capture_spec_obtained.freq;
|
|
||||||
|
|
||||||
m_audio.resize((m_sample_rate*m_len_ms)/1000);
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool audio_async::resume() {
|
|
||||||
if (!m_dev_id_in) {
|
|
||||||
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (m_running) {
|
|
||||||
fprintf(stderr, "%s: already running!\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
SDL_PauseAudioDevice(m_dev_id_in, 0);
|
|
||||||
|
|
||||||
m_running = true;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool audio_async::pause() {
|
|
||||||
if (!m_dev_id_in) {
|
|
||||||
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!m_running) {
|
|
||||||
fprintf(stderr, "%s: already paused!\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
SDL_PauseAudioDevice(m_dev_id_in, 1);
|
|
||||||
|
|
||||||
m_running = false;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool audio_async::clear() {
|
|
||||||
if (!m_dev_id_in) {
|
|
||||||
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!m_running) {
|
|
||||||
fprintf(stderr, "%s: not running!\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> lock(m_mutex);
|
|
||||||
|
|
||||||
m_audio_pos = 0;
|
|
||||||
m_audio_len = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// callback to be called by SDL
|
|
||||||
void audio_async::callback(uint8_t * stream, int len) {
|
|
||||||
if (!m_running) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const size_t n_samples = len / sizeof(float);
|
|
||||||
|
|
||||||
m_audio_new.resize(n_samples);
|
|
||||||
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
|
|
||||||
|
|
||||||
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
|
|
||||||
|
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> lock(m_mutex);
|
|
||||||
|
|
||||||
if (m_audio_pos + n_samples > m_audio.size()) {
|
|
||||||
const size_t n0 = m_audio.size() - m_audio_pos;
|
|
||||||
|
|
||||||
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
|
|
||||||
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
|
|
||||||
|
|
||||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
|
||||||
m_audio_len = m_audio.size();
|
|
||||||
} else {
|
|
||||||
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
|
|
||||||
|
|
||||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
|
||||||
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void audio_async::get(int ms, std::vector<float> & result) {
|
|
||||||
if (!m_dev_id_in) {
|
|
||||||
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!m_running) {
|
|
||||||
fprintf(stderr, "%s: not running!\n", __func__);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.clear();
|
|
||||||
|
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> lock(m_mutex);
|
|
||||||
|
|
||||||
if (ms <= 0) {
|
|
||||||
ms = m_len_ms;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t n_samples = (m_sample_rate * ms) / 1000;
|
|
||||||
if (n_samples > m_audio_len) {
|
|
||||||
n_samples = m_audio_len;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.resize(n_samples);
|
|
||||||
|
|
||||||
int s0 = m_audio_pos - n_samples;
|
|
||||||
if (s0 < 0) {
|
|
||||||
s0 += m_audio.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (s0 + n_samples > m_audio.size()) {
|
|
||||||
const size_t n0 = m_audio.size() - s0;
|
|
||||||
|
|
||||||
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
|
|
||||||
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
|
|
||||||
} else {
|
|
||||||
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////
|
|
||||||
|
|
||||||
std::string trim(const std::string & s) {
|
|
||||||
std::regex e("^\\s+|\\s+$");
|
|
||||||
return std::regex_replace(s, e, "");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string replace(const std::string & s, const std::string & from, const std::string & to) {
|
|
||||||
std::string result = s;
|
|
||||||
size_t pos = 0;
|
|
||||||
while ((pos = result.find(from, pos)) != std::string::npos) {
|
|
||||||
result.replace(pos, from.length(), to);
|
|
||||||
pos += to.length();
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
|
||||||
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
|
||||||
const float dt = 1.0f / sample_rate;
|
|
||||||
const float alpha = dt / (rc + dt);
|
|
||||||
|
|
||||||
float y = data[0];
|
|
||||||
|
|
||||||
for (size_t i = 1; i < data.size(); i++) {
|
|
||||||
y = alpha * (y + data[i] - data[i - 1]);
|
|
||||||
data[i] = y;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
|
||||||
const int n_samples = pcmf32.size();
|
|
||||||
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
|
||||||
|
|
||||||
if (n_samples_last >= n_samples) {
|
|
||||||
// not enough samples - assume no speech
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (freq_thold > 0.0f) {
|
|
||||||
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
|
||||||
}
|
|
||||||
|
|
||||||
float energy_all = 0.0f;
|
|
||||||
float energy_last = 0.0f;
|
|
||||||
|
|
||||||
for (int i = 0; i < n_samples; i++) {
|
|
||||||
energy_all += fabsf(pcmf32[i]);
|
|
||||||
|
|
||||||
if (i >= n_samples - n_samples_last) {
|
|
||||||
energy_last += fabsf(pcmf32[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
energy_all /= n_samples;
|
|
||||||
energy_last /= n_samples_last;
|
|
||||||
|
|
||||||
if (verbose) {
|
|
||||||
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (energy_last > vad_thold*energy_all) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
||||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
@ -557,22 +241,10 @@ int main(int argc, char ** argv) {
|
|||||||
// main loop
|
// main loop
|
||||||
while (is_running) {
|
while (is_running) {
|
||||||
// handle Ctrl + C
|
// handle Ctrl + C
|
||||||
{
|
is_running = sdl_poll_events();
|
||||||
SDL_Event event;
|
|
||||||
while (SDL_PollEvent(&event)) {
|
|
||||||
switch (event.type) {
|
|
||||||
case SDL_QUIT:
|
|
||||||
{
|
|
||||||
is_running = false;
|
|
||||||
} break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!is_running) {
|
if (!is_running) {
|
||||||
break;
|
break;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// delay
|
// delay
|
||||||
@ -583,7 +255,7 @@ int main(int argc, char ** argv) {
|
|||||||
{
|
{
|
||||||
audio.get(2000, pcmf32_cur);
|
audio.get(2000, pcmf32_cur);
|
||||||
|
|
||||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
|
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
|
||||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||||
|
|
||||||
audio.get(params.voice_ms, pcmf32_cur);
|
audio.get(params.voice_ms, pcmf32_cur);
|
||||||
|
@ -67,23 +67,6 @@ msg() {
|
|||||||
echo >&2 -e "${1-}"
|
echo >&2 -e "${1-}"
|
||||||
}
|
}
|
||||||
|
|
||||||
################################################################################
|
|
||||||
# create a temporary directory to work in
|
|
||||||
# set the temp_dir and temp_filename variables
|
|
||||||
################################################################################
|
|
||||||
temp_dir="$(mktemp -d ${SCRIPT_DIR}/tmp.XXXXXX)";
|
|
||||||
temp_filename="${temp_dir}/yt-dlp-filename";
|
|
||||||
|
|
||||||
################################################################################
|
|
||||||
# for now we only take one argument
|
|
||||||
# TODO: a for loop
|
|
||||||
################################################################################
|
|
||||||
source_url="${1}"
|
|
||||||
|
|
||||||
|
|
||||||
title_name="";
|
|
||||||
|
|
||||||
|
|
||||||
cleanup() {
|
cleanup() {
|
||||||
local -r clean_me="${1}";
|
local -r clean_me="${1}";
|
||||||
|
|
||||||
@ -145,6 +128,20 @@ fi
|
|||||||
|
|
||||||
check_requirements;
|
check_requirements;
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# create a temporary directory to work in
|
||||||
|
# set the temp_dir and temp_filename variables
|
||||||
|
################################################################################
|
||||||
|
temp_dir="$(mktemp -d ${SCRIPT_DIR}/tmp.XXXXXX)";
|
||||||
|
temp_filename="${temp_dir}/yt-dlp-filename";
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# for now we only take one argument
|
||||||
|
# TODO: a for loop
|
||||||
|
################################################################################
|
||||||
|
source_url="${1}"
|
||||||
|
title_name="";
|
||||||
|
|
||||||
msg "Downloading VOD...";
|
msg "Downloading VOD...";
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
@ -199,6 +196,6 @@ ffmpeg -i "${temp_dir}/${title_name}.vod.mp4" \
|
|||||||
-c:s mov_text \
|
-c:s mov_text \
|
||||||
-y "${title_name}-res.mp4";
|
-y "${title_name}-res.mp4";
|
||||||
|
|
||||||
cleanup "${temp_dir}";
|
#cleanup "${temp_dir}";
|
||||||
|
|
||||||
msg "Done! Your finished file is ready: ${title_name}-res.mp4";
|
msg "Done! Your finished file is ready: ${title_name}-res.mp4";
|
||||||
|
48
whisper.cpp
48
whisper.cpp
@ -592,16 +592,16 @@ struct whisper_context {
|
|||||||
|
|
||||||
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
||||||
|
|
||||||
int lang_id;
|
int lang_id = 0; // english by default
|
||||||
|
|
||||||
// [EXPERIMENTAL] token-level timestamps data
|
// [EXPERIMENTAL] token-level timestamps data
|
||||||
int64_t t_beg;
|
int64_t t_beg = 0;
|
||||||
int64_t t_last;
|
int64_t t_last = 0;
|
||||||
whisper_token tid_last;
|
whisper_token tid_last;
|
||||||
std::vector<float> energy; // PCM signal energy
|
std::vector<float> energy; // PCM signal energy
|
||||||
|
|
||||||
// [EXPERIMENTAL] speed-up techniques
|
// [EXPERIMENTAL] speed-up techniques
|
||||||
int32_t exp_n_audio_ctx; // 0 - use default
|
int32_t exp_n_audio_ctx = 0; // 0 - use default
|
||||||
|
|
||||||
void use_buf(struct ggml_context * ctx, int i) {
|
void use_buf(struct ggml_context * ctx, int i) {
|
||||||
#if defined(WHISPER_USE_SCRATCH)
|
#if defined(WHISPER_USE_SCRATCH)
|
||||||
@ -805,7 +805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||||||
MEM_REQ_SCRATCH3.at (model.type) +
|
MEM_REQ_SCRATCH3.at (model.type) +
|
||||||
scale*MEM_REQ_MODEL.at (model.type) +
|
scale*MEM_REQ_MODEL.at (model.type) +
|
||||||
scale*MEM_REQ_KV_CROSS.at(model.type) +
|
scale*MEM_REQ_KV_CROSS.at(model.type) +
|
||||||
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
|
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
|
||||||
|
|
||||||
// this is the memory required by one decoder
|
// this is the memory required by one decoder
|
||||||
const size_t mem_required_decoder =
|
const size_t mem_required_decoder =
|
||||||
@ -2936,7 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|||||||
/*.language =*/ "en",
|
/*.language =*/ "en",
|
||||||
|
|
||||||
/*.suppress_blank =*/ true,
|
/*.suppress_blank =*/ true,
|
||||||
/*.suppress_non_speech_tokens =*/true,
|
/*.suppress_non_speech_tokens =*/ false,
|
||||||
|
|
||||||
/*.temperature =*/ 0.0f,
|
/*.temperature =*/ 0.0f,
|
||||||
/*.max_initial_ts =*/ 1.0f,
|
/*.max_initial_ts =*/ 1.0f,
|
||||||
@ -2962,6 +2962,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|||||||
|
|
||||||
/*.encoder_begin_callback =*/ nullptr,
|
/*.encoder_begin_callback =*/ nullptr,
|
||||||
/*.encoder_begin_callback_user_data =*/ nullptr,
|
/*.encoder_begin_callback_user_data =*/ nullptr,
|
||||||
|
|
||||||
|
/*.logits_filter_callback =*/ nullptr,
|
||||||
|
/*.logits_filter_callback_user_data =*/ nullptr,
|
||||||
};
|
};
|
||||||
|
|
||||||
switch (strategy) {
|
switch (strategy) {
|
||||||
@ -3078,8 +3081,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
static const std::vector<std::string> non_speech_tokens
|
static const std::vector<std::string> non_speech_tokens = {
|
||||||
{
|
|
||||||
"\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
|
"\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
|
||||||
"_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
|
"_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
|
||||||
"---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
|
"---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
|
||||||
@ -3090,7 +3092,7 @@ static const std::vector<std::string> non_speech_tokens
|
|||||||
// - applies logit filters
|
// - applies logit filters
|
||||||
// - computes logprobs and probs
|
// - computes logprobs and probs
|
||||||
static void whisper_process_logits(
|
static void whisper_process_logits(
|
||||||
const struct whisper_context & ctx,
|
struct whisper_context & ctx,
|
||||||
const struct whisper_full_params params,
|
const struct whisper_full_params params,
|
||||||
struct whisper_decoder & decoder,
|
struct whisper_decoder & decoder,
|
||||||
float temperature) {
|
float temperature) {
|
||||||
@ -3146,29 +3148,27 @@ static void whisper_process_logits(
|
|||||||
logits[vocab.token_translate] = -INFINITY;
|
logits[vocab.token_translate] = -INFINITY;
|
||||||
logits[vocab.token_transcribe] = -INFINITY;
|
logits[vocab.token_transcribe] = -INFINITY;
|
||||||
|
|
||||||
|
if (params.logits_filter_callback) {
|
||||||
|
params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
|
||||||
|
}
|
||||||
|
|
||||||
// suppress non-speech tokens
|
// suppress non-speech tokens
|
||||||
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
||||||
if (params.suppress_non_speech_tokens)
|
if (params.suppress_non_speech_tokens) {
|
||||||
{
|
for (const std::string & token : non_speech_tokens) {
|
||||||
for (const std::string &token : non_speech_tokens)
|
const std::string suppress_tokens[] = {token, " " + token};
|
||||||
{
|
for (const std::string & suppress_token : suppress_tokens) {
|
||||||
std::string suppress_tokens[] = {token, " " + token};
|
if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) {
|
||||||
for (const std::string &suppress_token : suppress_tokens)
|
|
||||||
{
|
|
||||||
if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end())
|
|
||||||
{
|
|
||||||
logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
|
logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
// allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
||||||
if (vocab.token_to_id.find(" -") != vocab.token_to_id.end())
|
if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) {
|
||||||
{
|
|
||||||
logits[vocab.token_to_id.at(" -")] = -INFINITY;
|
logits[vocab.token_to_id.at(" -")] = -INFINITY;
|
||||||
}
|
}
|
||||||
if (vocab.token_to_id.find(" '") != vocab.token_to_id.end())
|
if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) {
|
||||||
{
|
|
||||||
logits[vocab.token_to_id.at(" '")] = -INFINITY;
|
logits[vocab.token_to_id.at(" '")] = -INFINITY;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3854,7 +3854,7 @@ int whisper_full(
|
|||||||
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
|
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
|
||||||
});
|
});
|
||||||
|
|
||||||
int cur_c = 0;
|
uint32_t cur_c = 0;
|
||||||
|
|
||||||
for (int j = 0; j < n_decoders_cur; ++j) {
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
||||||
auto & decoder = ctx->decoders[j];
|
auto & decoder = ctx->decoders[j];
|
||||||
@ -4339,7 +4339,7 @@ int whisper_full_n_segments(struct whisper_context * ctx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int whisper_full_lang_id(struct whisper_context * ctx) {
|
int whisper_full_lang_id(struct whisper_context * ctx) {
|
||||||
return ctx->lang_id;
|
return ctx->lang_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
|
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
|
||||||
|
14
whisper.h
14
whisper.h
@ -243,6 +243,16 @@ extern "C" {
|
|||||||
// If it returns false, the computation is aborted
|
// If it returns false, the computation is aborted
|
||||||
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
|
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
|
||||||
|
|
||||||
|
// Logits filter callback
|
||||||
|
// Can be used to modify the logits before sampling
|
||||||
|
// If not NULL, called after applying temperature to logits
|
||||||
|
typedef void (*whisper_logits_filter_callback)(
|
||||||
|
struct whisper_context * ctx,
|
||||||
|
const whisper_token_data * tokens,
|
||||||
|
int n_tokens,
|
||||||
|
float * logits,
|
||||||
|
void * user_data);
|
||||||
|
|
||||||
// Parameters for the whisper_full() function
|
// Parameters for the whisper_full() function
|
||||||
// If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp:
|
// If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp:
|
||||||
// whisper_full_default_params()
|
// whisper_full_default_params()
|
||||||
@ -315,6 +325,10 @@ extern "C" {
|
|||||||
// called each time before the encoder starts
|
// called each time before the encoder starts
|
||||||
whisper_encoder_begin_callback encoder_begin_callback;
|
whisper_encoder_begin_callback encoder_begin_callback;
|
||||||
void * encoder_begin_callback_user_data;
|
void * encoder_begin_callback_user_data;
|
||||||
|
|
||||||
|
// called by each decoder to filter obtained logits
|
||||||
|
whisper_logits_filter_callback logits_filter_callback;
|
||||||
|
void * logits_filter_callback_user_data;
|
||||||
};
|
};
|
||||||
|
|
||||||
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
|
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
|
||||||
|
Reference in New Issue
Block a user