ruby : support new-segment callback (#2506)

* Add Params#new_segment_callback= method

* Add tests for Params#new_segment_callback=

* Group tests for #transcribe

* Don't use static for thread-safety

* Set new_segment_callback only when necessary

* Remove redundant check

* [skip ci] Add Ruby version README

* Revert "Group tests for #transcribe"

This reverts commit 71b65b00cc.

* Revert "Add tests for Params#new_segment_callback="

This reverts commit 81e6df3bab.

* Add test for Context#full_n_segments

* Add Context#full_n_segments

* Add tests for lang API

* Add lang API

* Add tests for Context#full_lang_id API

* Add Context#full_lang_id

* Add abnormal test cases for lang

* Raise appropriate errors from lang APIs

* Add tests for Context#full_get_segment_t{0,1} API

* Add Context#full_get_segment_t{0,1}

* Add tests for Context#full_get_segment_speaker_turn_next API

* Add Context#full_get_segment_speaker_turn_next

* Add tests for Context#full_get_segment_text

* Add Context#full_get_setgment_text

* Add tests for Params#new_segment_callback=

* Run new segment callback

* Split tests to multiple files

* Use container struct for new segment callback

* Add tests for Params#new_segment_callback_user_data=

* Add Whisper::Params#new_user_callback_user_data=

* Add GC-related test for new segment callback

* Protect new segment callback related structs from GC

* Add meaningful test for build

* Rename: new_segment_callback_user_data -> new_segment_callback_container

* Add tests for Whisper::Segment

* Add Whisper::Segment and Whisper::Context#each_segment

* Extract c_ruby_whisper_callback_container_allocate()

* Add test for Whisper::Params#on_new_segment

* Add Whisper::Params#on_new_egment

* Assign symbol IDs to variables

* Make extsources.yaml simpler

* Update README

* Add document comments

* Add test for calling Whisper::Params#on_new_segment multiple times

* Add file dependencies to GitHub actions config and .gitignore

* Add more files to ext/.gitignore
This commit is contained in:
KITAITI Makoto 2024-10-28 22:43:27 +09:00 committed by GitHub
parent c0ea41f6b2
commit fc49ee4479
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 1117 additions and 175 deletions

View File

@ -16,6 +16,9 @@ on:
- ggml/src/ggml-quants.h - ggml/src/ggml-quants.h
- ggml/src/ggml-quants.c - ggml/src/ggml-quants.c
- ggml/src/ggml-cpu-impl.h - ggml/src/ggml-cpu-impl.h
- ggml/src/ggml-metal.m
- ggml/src/ggml-metal.metal
- ggml/src/ggml-blas.cpp
- ggml/include/ggml.h - ggml/include/ggml.h
- ggml/include/ggml-alloc.h - ggml/include/ggml-alloc.h
- ggml/include/ggml-backend.h - ggml/include/ggml-backend.h
@ -24,6 +27,8 @@ on:
- ggml/include/ggml-metal.h - ggml/include/ggml-metal.h
- ggml/include/ggml-sycl.h - ggml/include/ggml-sycl.h
- ggml/include/ggml-vulkan.h - ggml/include/ggml-vulkan.h
- ggml/include/ggml-blas.h
- scripts/get-flags.mk
- examples/dr_wav.h - examples/dr_wav.h
pull_request: pull_request:
paths: paths:
@ -41,6 +46,9 @@ on:
- ggml/src/ggml-quants.h - ggml/src/ggml-quants.h
- ggml/src/ggml-quants.c - ggml/src/ggml-quants.c
- ggml/src/ggml-cpu-impl.h - ggml/src/ggml-cpu-impl.h
- ggml/src/ggml-metal.m
- ggml/src/ggml-metal.metal
- ggml/src/ggml-blas.cpp
- ggml/include/ggml.h - ggml/include/ggml.h
- ggml/include/ggml-alloc.h - ggml/include/ggml-alloc.h
- ggml/include/ggml-backend.h - ggml/include/ggml-backend.h
@ -49,6 +57,8 @@ on:
- ggml/include/ggml-metal.h - ggml/include/ggml-metal.h
- ggml/include/ggml-sycl.h - ggml/include/ggml-sycl.h
- ggml/include/ggml-vulkan.h - ggml/include/ggml-vulkan.h
- ggml/include/ggml-blas.h
- scripts/get-flags.mk
- examples/dr_wav.h - examples/dr_wav.h
jobs: jobs:

View File

@ -1,4 +1,3 @@
README.md
LICENSE LICENSE
pkg/ pkg/
lib/whisper.* lib/whisper.*

110
bindings/ruby/README.md Normal file
View File

@ -0,0 +1,110 @@
whispercpp
==========
![whisper.cpp](https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg)
Ruby bindings for [whisper.cpp][], an interface of automatic speech recognition model.
Installation
------------
Install the gem and add to the application's Gemfile by executing:
$ bundle add whispercpp
If bundler is not being used to manage dependencies, install the gem by executing:
$ gem install whispercpp
Usage
-----
```ruby
require "whisper"
whisper = Whisper::Context.new("path/to/model.bin")
params = Whisper::Params.new
params.language = "en"
params.offset = 10_000
params.duration = 60_000
params.max_text_tokens = 300
params.translate = true
params.print_timestamps = false
whisper.transcribe("path/to/audio.wav", params) do |whole_text|
puts whole_text
end
```
### Preparing model ###
Use script to download model file(s):
```bash
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp
sh ./models/download-ggml-model.sh base.en
```
There are some types of models. See [models][] page for details.
### Preparing audio file ###
Currently, whisper.cpp accepts only 16-bit WAV files.
### API ###
Once `Whisper::Context#transcribe` called, you can retrieve segments by `#each_segment`:
```ruby
def format_time(time_ms)
sec, decimal_part = time_ms.divmod(1000)
min, sec = sec.divmod(60)
hour, min = min.divmod(60)
"%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part]
end
whisper.transcribe("path/to/audio.wav", params)
whisper.each_segment.with_index do |segment, index|
line = "[%{nth}: %{st} --> %{ed}] %{text}" % {
nth: index + 1,
st: format_time(segment.start_time),
ed: format_time(segment.end_time),
text: segment.text
}
line << " (speaker turned)" if segment.speaker_next_turn?
puts line
end
```
You can also add hook to params called on new segment:
```ruby
def format_time(time_ms)
sec, decimal_part = time_ms.divmod(1000)
min, sec = sec.divmod(60)
hour, min = min.divmod(60)
"%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part]
end
# Add hook before calling #transcribe
params.on_new_segment do |segment|
line = "[%{st} --> %{ed}] %{text}" % {
st: format_time(segment.start_time),
ed: format_time(segment.end_time),
text: segment.text
}
line << " (speaker turned)" if segment.speaker_next_turn?
puts line
end
whisper.transcribe("path/to/audio.wav", params)
```
[whisper.cpp]: https://github.com/ggerganov/whisper.cpp
[models]: https://github.com/ggerganov/whisper.cpp/tree/master/models

View File

@ -5,17 +5,16 @@ require "yaml"
require "rake/testtask" require "rake/testtask"
extsources = YAML.load_file("extsources.yaml") extsources = YAML.load_file("extsources.yaml")
extsources.each_pair do |src_dir, dests| SOURCES = FileList[]
dests.each do |dest| extsources.each do |src|
src = Pathname(src_dir)/File.basename(dest) basename = src.pathmap("%f")
dest = basename == "LICENSE" ? basename : basename.pathmap("ext/%f")
file src file src
file dest => src do |t| file dest => src do |t|
cp t.source, t.name cp t.source, t.name
end
end end
SOURCES.include dest
end end
SOURCES = extsources.values.flatten
CLEAN.include SOURCES CLEAN.include SOURCES
CLEAN.include FileList[ CLEAN.include FileList[
"ext/*.o", "ext/*.o",

View File

@ -11,6 +11,10 @@ ggml-backend.c
ggml-backend.h ggml-backend.h
ggml-common.h ggml-common.h
ggml-cpu-impl.h ggml-cpu-impl.h
ggml-metal.m
ggml-metal.metal
ggml-metal-embed.metal
ggml-blas.cpp
ggml-cuda.h ggml-cuda.h
ggml-impl.h ggml-impl.h
ggml-kompute.h ggml-kompute.h
@ -20,9 +24,12 @@ ggml-quants.c
ggml-quants.h ggml-quants.h
ggml-sycl.h ggml-sycl.h
ggml-vulkan.h ggml-vulkan.h
ggml-blas.h
get-flags.mk
whisper.cpp whisper.cpp
whisper.h whisper.h
dr_wav.h dr_wav.h
depend
whisper.bundle whisper.bundle
whisper.so whisper.so
whisper.dll whisper.dll

View File

@ -36,12 +36,65 @@ VALUE mWhisper;
VALUE cContext; VALUE cContext;
VALUE cParams; VALUE cParams;
static ID id_to_s;
static ID id_call;
static ID id___method__;
static ID id_to_enum;
/*
* call-seq:
* lang_max_id -> Integer
*/
static VALUE ruby_whisper_s_lang_max_id(VALUE self) {
return INT2NUM(whisper_lang_max_id());
}
/*
* call-seq:
* lang_id(lang_name) -> Integer
*/
static VALUE ruby_whisper_s_lang_id(VALUE self, VALUE lang) {
const char * lang_str = StringValueCStr(lang);
const int id = whisper_lang_id(lang_str);
if (-1 == id) {
rb_raise(rb_eArgError, "language not found: %s", lang_str);
}
return INT2NUM(id);
}
/*
* call-seq:
* lang_str(lang_id) -> String
*/
static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) {
const int lang_id = NUM2INT(id);
const char * str = whisper_lang_str(lang_id);
if (nullptr == str) {
rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
}
return rb_str_new2(str);
}
/*
* call-seq:
* lang_str(lang_id) -> String
*/
static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) {
const int lang_id = NUM2INT(id);
const char * str_full = whisper_lang_str_full(lang_id);
if (nullptr == str_full) {
rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
}
return rb_str_new2(str_full);
}
static void ruby_whisper_free(ruby_whisper *rw) { static void ruby_whisper_free(ruby_whisper *rw) {
if (rw->context) { if (rw->context) {
whisper_free(rw->context); whisper_free(rw->context);
rw->context = NULL; rw->context = NULL;
} }
} }
static void ruby_whisper_params_free(ruby_whisper_params *rwp) { static void ruby_whisper_params_free(ruby_whisper_params *rwp) {
} }
@ -55,9 +108,13 @@ void rb_whisper_free(ruby_whisper *rw) {
} }
void rb_whisper_params_mark(ruby_whisper_params *rwp) { void rb_whisper_params_mark(ruby_whisper_params *rwp) {
rb_gc_mark(rwp->new_segment_callback_container->user_data);
rb_gc_mark(rwp->new_segment_callback_container->callback);
rb_gc_mark(rwp->new_segment_callback_container->callbacks);
} }
void rb_whisper_params_free(ruby_whisper_params *rwp) { void rb_whisper_params_free(ruby_whisper_params *rwp) {
// How to free user_data and callback only when not referred to by others?
ruby_whisper_params_free(rwp); ruby_whisper_params_free(rwp);
free(rwp); free(rwp);
} }
@ -69,13 +126,28 @@ static VALUE ruby_whisper_allocate(VALUE klass) {
return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw); return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
} }
static ruby_whisper_callback_container * rb_whisper_callback_container_allocate() {
ruby_whisper_callback_container *container;
container = ALLOC(ruby_whisper_callback_container);
container->context = nullptr;
container->user_data = Qnil;
container->callback = Qnil;
container->callbacks = rb_ary_new();
return container;
}
static VALUE ruby_whisper_params_allocate(VALUE klass) { static VALUE ruby_whisper_params_allocate(VALUE klass) {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
rwp = ALLOC(ruby_whisper_params); rwp = ALLOC(ruby_whisper_params);
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
} }
/*
* call-seq:
* new("path/to/model.bin") -> Whisper::Context
*/
static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
ruby_whisper *rw; ruby_whisper *rw;
VALUE whisper_model_file_path; VALUE whisper_model_file_path;
@ -84,7 +156,7 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
rb_scan_args(argc, argv, "01", &whisper_model_file_path); rb_scan_args(argc, argv, "01", &whisper_model_file_path);
Data_Get_Struct(self, ruby_whisper, rw); Data_Get_Struct(self, ruby_whisper, rw);
if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) { if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
} }
rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params()); rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
@ -94,10 +166,21 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
return self; return self;
} }
// High level API
static VALUE rb_whisper_segment_initialize(VALUE context, int index);
/* /*
* transcribe a single file * transcribe a single file
* can emit to a block results * can emit to a block results
* *
* params = Whisper::Params.new
* params.duration = 60_000
* whisper.transcribe "path/to/audio.wav", params do |text|
* puts text
* end
*
* call-seq:
* transcribe(path_to_audio, params) {|text| ...}
**/ **/
static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
ruby_whisper *rw; ruby_whisper *rw;
@ -108,7 +191,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
Data_Get_Struct(self, ruby_whisper, rw); Data_Get_Struct(self, ruby_whisper, rw);
Data_Get_Struct(params, ruby_whisper_params, rwp); Data_Get_Struct(params, ruby_whisper_params, rwp);
if (!rb_respond_to(wave_file_path, rb_intern("to_s"))) { if (!rb_respond_to(wave_file_path, id_to_s)) {
rb_raise(rb_eRuntimeError, "Expected file path to wave file"); rb_raise(rb_eRuntimeError, "Expected file path to wave file");
} }
@ -206,6 +289,33 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
rwp->params.encoder_begin_callback_user_data = &is_aborted; rwp->params.encoder_begin_callback_user_data = &is_aborted;
} }
if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
// Currently, doesn't support state because
// those require to resolve GC-related problems.
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return;
}
const int n_segments = whisper_full_n_segments_from_state(state);
for (int i = n_new; i > 0; i--) {
int i_segment = n_segments - i;
VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, segment);
}
}
};
rwp->new_segment_callback_container->context = &self;
rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
}
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
fprintf(stderr, "failed to process audio\n"); fprintf(stderr, "failed to process audio\n");
return self; return self;
@ -216,15 +326,114 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
const char * text = whisper_full_get_segment_text(rw->context, i); const char * text = whisper_full_get_segment_text(rw->context, i);
output = rb_str_concat(output, rb_str_new2(text)); output = rb_str_concat(output, rb_str_new2(text));
} }
VALUE idCall = rb_intern("call"); VALUE idCall = id_call;
if (blk != Qnil) { if (blk != Qnil) {
rb_funcall(blk, idCall, 1, output); rb_funcall(blk, idCall, 1, output);
} }
return self; return self;
} }
/*
* Number of segments.
*
* call-seq:
* full_n_segments -> Integer
*/
static VALUE ruby_whisper_full_n_segments(VALUE self) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_full_n_segments(rw->context));
}
/*
* Language ID, which can be converted to string by Whisper.lang_str and Whisper.lang_str_full.
*
* call-seq:
* full_lang_id -> Integer
*/
static VALUE ruby_whisper_full_lang_id(VALUE self) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_full_lang_id(rw->context));
}
static int ruby_whisper_full_check_segment_index(const ruby_whisper * rw, const VALUE i_segment) {
const int c_i_segment = NUM2INT(i_segment);
if (c_i_segment < 0 || c_i_segment >= whisper_full_n_segments(rw->context)) {
rb_raise(rb_eIndexError, "segment index %d out of range", c_i_segment);
}
return c_i_segment;
}
/*
* Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
*
* full_get_segment_t0(3) # => 1668 (16680 ms)
*
* call-seq:
* full_get_segment_t0(segment_index) -> Integer
*/
static VALUE ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment);
return INT2NUM(t0);
}
/*
* End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
*
* full_get_segment_t1(3) # => 1668 (16680 ms)
*
* call-seq:
* full_get_segment_t1(segment_index) -> Integer
*/
static VALUE ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment);
return INT2NUM(t1);
}
/*
* Whether the next segment indexed by +segment_index+ is predicated as a speaker turn.
*
* full_get_segment_speacker_turn_next(3) # => true
*
* call-seq:
* full_get_segment_speacker_turn_next(segment_index) -> bool
*/
static VALUE ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment);
return speaker_turn_next ? Qtrue : Qfalse;
}
/*
* Text of a segment indexed by +segment_index+.
*
* full_get_segment_text(3) # => "ask not what your country can do for you, ..."
*
* call-seq:
* full_get_segment_text(segment_index) -> String
*/
static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const char * text = whisper_full_get_segment_text(rw->context, c_i_segment);
return rb_str_new2(text);
}
/* /*
* params.language = "auto" | "en", etc... * params.language = "auto" | "en", etc...
*
* call-seq:
* language = lang_name -> lang_name
*/ */
static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
@ -236,6 +445,10 @@ static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) {
} }
return value; return value;
} }
/*
* call-seq:
* language -> String
*/
static VALUE ruby_whisper_params_get_language(VALUE self) { static VALUE ruby_whisper_params_get_language(VALUE self) {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
@ -245,72 +458,185 @@ static VALUE ruby_whisper_params_get_language(VALUE self) {
return rb_str_new2("auto"); return rb_str_new2("auto");
} }
} }
/*
* call-seq:
* translate = do_translate -> do_translate
*/
static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, translate, value) BOOL_PARAMS_SETTER(self, translate, value)
} }
/*
* call-seq:
* translate -> bool
*/
static VALUE ruby_whisper_params_get_translate(VALUE self) { static VALUE ruby_whisper_params_get_translate(VALUE self) {
BOOL_PARAMS_GETTER(self, translate) BOOL_PARAMS_GETTER(self, translate)
} }
/*
* call-seq:
* no_context = dont_use_context -> dont_use_context
*/
static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, no_context, value) BOOL_PARAMS_SETTER(self, no_context, value)
} }
/*
* If true, does not use past transcription (if any) as initial prompt for the decoder.
*
* call-seq:
* no_context -> bool
*/
static VALUE ruby_whisper_params_get_no_context(VALUE self) { static VALUE ruby_whisper_params_get_no_context(VALUE self) {
BOOL_PARAMS_GETTER(self, no_context) BOOL_PARAMS_GETTER(self, no_context)
} }
/*
* call-seq:
* single_segment = force_single -> force_single
*/
static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, single_segment, value) BOOL_PARAMS_SETTER(self, single_segment, value)
} }
/*
* If true, forces single segment output (useful for streaming).
*
* call-seq:
* single_segment -> bool
*/
static VALUE ruby_whisper_params_get_single_segment(VALUE self) { static VALUE ruby_whisper_params_get_single_segment(VALUE self) {
BOOL_PARAMS_GETTER(self, single_segment) BOOL_PARAMS_GETTER(self, single_segment)
} }
/*
* call-seq:
* print_special = force_print -> force_print
*/
static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_special, value) BOOL_PARAMS_SETTER(self, print_special, value)
} }
/*
* If true, prints special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.).
*
* call-seq:
* print_special -> bool
*/
static VALUE ruby_whisper_params_get_print_special(VALUE self) { static VALUE ruby_whisper_params_get_print_special(VALUE self) {
BOOL_PARAMS_GETTER(self, print_special) BOOL_PARAMS_GETTER(self, print_special)
} }
/*
* call-seq:
* print_progress = force_print -> force_print
*/
static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_progress, value) BOOL_PARAMS_SETTER(self, print_progress, value)
} }
/*
* If true, prints progress information.
*
* call-seq:
* print_progress -> bool
*/
static VALUE ruby_whisper_params_get_print_progress(VALUE self) { static VALUE ruby_whisper_params_get_print_progress(VALUE self) {
BOOL_PARAMS_GETTER(self, print_progress) BOOL_PARAMS_GETTER(self, print_progress)
} }
/*
* call-seq:
* print_realtime = force_print -> force_print
*/
static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_realtime, value) BOOL_PARAMS_SETTER(self, print_realtime, value)
} }
/*
* If true, prints results from within whisper.cpp. (avoid it, use callback instead)
* call-seq:
* print_realtime -> bool
*/
static VALUE ruby_whisper_params_get_print_realtime(VALUE self) { static VALUE ruby_whisper_params_get_print_realtime(VALUE self) {
BOOL_PARAMS_GETTER(self, print_realtime) BOOL_PARAMS_GETTER(self, print_realtime)
} }
/*
* call-seq:
* print_timestamps = force_print -> force_print
*/
static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, print_timestamps, value) BOOL_PARAMS_SETTER(self, print_timestamps, value)
} }
/*
* If true, prints timestamps for each text segment when printing realtime.
*
* call-seq:
* print_timestamps -> bool
*/
static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) { static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) {
BOOL_PARAMS_GETTER(self, print_timestamps) BOOL_PARAMS_GETTER(self, print_timestamps)
} }
/*
* call-seq:
* suppress_blank = force_suppress -> force_suppress
*/
static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, suppress_blank, value) BOOL_PARAMS_SETTER(self, suppress_blank, value)
} }
/*
* If true, suppresses blank outputs.
*
* call-seq:
* suppress_blank -> bool
*/
static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) { static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
BOOL_PARAMS_GETTER(self, suppress_blank) BOOL_PARAMS_GETTER(self, suppress_blank)
} }
/*
* call-seq:
* suppress_non_speech_tokens = force_suppress -> force_suppress
*/
static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value) BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value)
} }
/*
* If true, suppresses non-speech-tokens.
*
* call-seq:
* suppress_non_speech_tokens -> bool
*/
static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) { static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) {
BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens) BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens)
} }
/*
* If true, enables token-level timestamps.
*
* call-seq:
* token_timestamps -> bool
*/
static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) { static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) {
BOOL_PARAMS_GETTER(self, token_timestamps) BOOL_PARAMS_GETTER(self, token_timestamps)
} }
/*
* call-seq:
* token_timestamps = force_timestamps -> force_timestamps
*/
static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, token_timestamps, value) BOOL_PARAMS_SETTER(self, token_timestamps, value)
} }
/*
* If true, split on word rather than on token (when used with max_len).
*
* call-seq:
* translate -> bool
*/
static VALUE ruby_whisper_params_get_split_on_word(VALUE self) { static VALUE ruby_whisper_params_get_split_on_word(VALUE self) {
BOOL_PARAMS_GETTER(self, split_on_word) BOOL_PARAMS_GETTER(self, split_on_word)
} }
/*
* call-seq:
* split_on_word = force_split -> force_split
*/
static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, split_on_word, value) BOOL_PARAMS_SETTER(self, split_on_word, value)
} }
/*
* If true, enables diarization.
*
* call-seq:
* diarize -> bool
*/
static VALUE ruby_whisper_params_get_diarize(VALUE self) { static VALUE ruby_whisper_params_get_diarize(VALUE self) {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
@ -320,6 +646,10 @@ static VALUE ruby_whisper_params_get_diarize(VALUE self) {
return Qfalse; return Qfalse;
} }
} }
/*
* call-seq:
* diarize = force_diarize -> force_diarize
*/
static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
@ -331,22 +661,42 @@ static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) {
return value; return value;
} }
/*
* Start offset in ms.
*
* call-seq:
* offset -> Integer
*/
static VALUE ruby_whisper_params_get_offset(VALUE self) { static VALUE ruby_whisper_params_get_offset(VALUE self) {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.offset_ms); return INT2NUM(rwp->params.offset_ms);
} }
/*
* call-seq:
* offset = offset_ms -> offset_ms
*/
static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.offset_ms = NUM2INT(value); rwp->params.offset_ms = NUM2INT(value);
return value; return value;
} }
/*
* Audio duration to process in ms.
*
* call-seq:
* duration -> Integer
*/
static VALUE ruby_whisper_params_get_duration(VALUE self) { static VALUE ruby_whisper_params_get_duration(VALUE self) {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.duration_ms); return INT2NUM(rwp->params.duration_ms);
} }
/*
* call-seq:
* duration = duration_ms -> duration_ms
*/
static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
@ -354,27 +704,221 @@ static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) {
return value; return value;
} }
/*
* Max tokens to use from past text as prompt for the decoder.
*
* call-seq:
* max_text_tokens -> Integer
*/
static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) { static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.n_max_text_ctx); return INT2NUM(rwp->params.n_max_text_ctx);
} }
/*
* call-seq:
* max_text_tokens = n_tokens -> n_tokens
*/
static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.n_max_text_ctx = NUM2INT(value); rwp->params.n_max_text_ctx = NUM2INT(value);
return value; return value;
} }
/*
* Sets new segment callback, called for every newly generated text segment.
*
* params.new_segment_callback = ->(context, _, n_new, user_data) {
* # ...
* }
*
* call-seq:
* new_segment_callback = callback -> callback
*/
static VALUE ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->new_segment_callback_container->callback = value;
return value;
}
/*
* Sets user data passed to the last argument of new segment callback.
*
* call-seq:
* new_segment_callback_user_data = user_data -> use_data
*/
static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->new_segment_callback_container->user_data = value;
return value;
}
// High level API
typedef struct {
VALUE context;
int index;
} ruby_whisper_segment;
VALUE cSegment;
static void rb_whisper_segment_mark(ruby_whisper_segment *rws) {
rb_gc_mark(rws->context);
}
static VALUE ruby_whisper_segment_allocate(VALUE klass) {
ruby_whisper_segment *rws;
rws = ALLOC(ruby_whisper_segment);
return Data_Wrap_Struct(klass, rb_whisper_segment_mark, RUBY_DEFAULT_FREE, rws);
}
static VALUE rb_whisper_segment_initialize(VALUE context, int index) {
ruby_whisper_segment *rws;
const VALUE segment = ruby_whisper_segment_allocate(cSegment);
Data_Get_Struct(segment, ruby_whisper_segment, rws);
rws->context = context;
rws->index = index;
return segment;
};
/*
* Yields each Whisper::Segment:
*
* whisper.transcribe("path/to/audio.wav", params)
* whisper.each_segment do |segment|
* puts segment.text
* end
*
* Returns an Enumerator if no block given:
*
* whisper.transcribe("path/to/audio.wav", params)
* enum = whisper.each_segment
* enum.to_a # => [#<Whisper::Segment>, ...]
*
* call-seq:
* each_segment {|segment| ... }
* each_segment -> Enumerator
*/
static VALUE ruby_whisper_each_segment(VALUE self) {
if (!rb_block_given_p()) {
const VALUE method_name = rb_funcall(self, id___method__, 0);
return rb_funcall(self, id_to_enum, 1, method_name);
}
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
const int n_segments = whisper_full_n_segments(rw->context);
for (int i = 0; i < n_segments; ++i) {
rb_yield(rb_whisper_segment_initialize(self, i));
}
return self;
}
/*
* Hook called on new segment. Yields each Whisper::Segment.
*
* whisper.on_new_segment do |segment|
* # ...
* end
*
* call-seq:
* on_new_segment {|segment| ... }
*/
static VALUE ruby_whisper_params_on_new_segment(VALUE self) {
ruby_whisper_params *rws;
Data_Get_Struct(self, ruby_whisper_params, rws);
const VALUE blk = rb_block_proc();
rb_ary_push(rws->new_segment_callback_container->callbacks, blk);
return Qnil;
}
/*
* Start time in milliseconds.
*
* call-seq:
* start_time -> Integer
*/
static VALUE ruby_whisper_segment_get_start_time(VALUE self) {
ruby_whisper_segment *rws;
Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
Data_Get_Struct(rws->context, ruby_whisper, rw);
const int64_t t0 = whisper_full_get_segment_t0(rw->context, rws->index);
// able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
return INT2NUM(t0 * 10);
}
/*
* End time in milliseconds.
*
* call-seq:
* end_time -> Integer
*/
static VALUE ruby_whisper_segment_get_end_time(VALUE self) {
ruby_whisper_segment *rws;
Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
Data_Get_Struct(rws->context, ruby_whisper, rw);
const int64_t t1 = whisper_full_get_segment_t1(rw->context, rws->index);
// able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
return INT2NUM(t1 * 10);
}
/*
* Whether the next segment is predicted as a speaker turn.
*
* call-seq:
* speaker_turn_next? -> bool
*/
static VALUE ruby_whisper_segment_get_speaker_turn_next(VALUE self) {
ruby_whisper_segment *rws;
Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
Data_Get_Struct(rws->context, ruby_whisper, rw);
return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? Qtrue : Qfalse;
}
/*
* call-seq:
* text -> String
*/
static VALUE ruby_whisper_segment_get_text(VALUE self) {
ruby_whisper_segment *rws;
Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
Data_Get_Struct(rws->context, ruby_whisper, rw);
const char * text = whisper_full_get_segment_text(rw->context, rws->index);
return rb_str_new2(text);
}
void Init_whisper() { void Init_whisper() {
id_to_s = rb_intern("to_s");
id_call = rb_intern("call");
id___method__ = rb_intern("__method__");
id_to_enum = rb_intern("to_enum");
mWhisper = rb_define_module("Whisper"); mWhisper = rb_define_module("Whisper");
cContext = rb_define_class_under(mWhisper, "Context", rb_cObject); cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
cParams = rb_define_class_under(mWhisper, "Params", rb_cObject); cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0);
rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);
rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);
rb_define_alloc_func(cContext, ruby_whisper_allocate); rb_define_alloc_func(cContext, ruby_whisper_allocate);
rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1); rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1); rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0);
rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0);
rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1);
rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
rb_define_alloc_func(cParams, ruby_whisper_params_allocate); rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
@ -412,6 +956,20 @@ void Init_whisper() {
rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0); rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0);
rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1); rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1);
rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1);
rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1);
// High leve
cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject);
rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate);
rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0);
rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
} }
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -3,6 +3,13 @@
#include "whisper.h" #include "whisper.h"
typedef struct {
VALUE *context;
VALUE user_data;
VALUE callback;
VALUE callbacks;
} ruby_whisper_callback_container;
typedef struct { typedef struct {
struct whisper_context *context; struct whisper_context *context;
} ruby_whisper; } ruby_whisper;
@ -10,6 +17,7 @@ typedef struct {
typedef struct { typedef struct {
struct whisper_full_params params; struct whisper_full_params params;
bool diarize; bool diarize;
ruby_whisper_callback_container *new_segment_callback_container;
} ruby_whisper_params; } ruby_whisper_params;
#endif #endif

View File

@ -1,37 +1,29 @@
--- ---
../../src: - ../../src/whisper.cpp
- ext/whisper.cpp - ../../include/whisper.h
../../include: - ../../ggml/src/ggml.c
- ext/whisper.h - ../../ggml/src/ggml-impl.h
../../ggml/src: - ../../ggml/src/ggml-aarch64.h
- ext/ggml.c - ../../ggml/src/ggml-aarch64.c
- ext/ggml-impl.h - ../../ggml/src/ggml-alloc.c
- ext/ggml-aarch64.h - ../../ggml/src/ggml-backend-impl.h
- ext/ggml-aarch64.c - ../../ggml/src/ggml-backend.cpp
- ext/ggml-alloc.c - ../../ggml/src/ggml-common.h
- ext/ggml-backend-impl.h - ../../ggml/src/ggml-quants.h
- ext/ggml-backend.cpp - ../../ggml/src/ggml-quants.c
- ext/ggml-common.h - ../../ggml/src/ggml-cpu-impl.h
- ext/ggml-quants.h - ../../ggml/src/ggml-metal.m
- ext/ggml-quants.c - ../../ggml/src/ggml-metal.metal
- ext/ggml-cpu-impl.h - ../../ggml/src/ggml-blas.cpp
- ext/ggml-metal.m - ../../ggml/include/ggml.h
- ext/ggml-metal.metal - ../../ggml/include/ggml-alloc.h
- ext/ggml-blas.cpp - ../../ggml/include/ggml-backend.h
../../ggml/include: - ../../ggml/include/ggml-cuda.h
- ext/ggml.h - ../../ggml/include/ggml-kompute.h
- ext/ggml-alloc.h - ../../ggml/include/ggml-metal.h
- ext/ggml-backend.h - ../../ggml/include/ggml-sycl.h
- ext/ggml-cuda.h - ../../ggml/include/ggml-vulkan.h
- ext/ggml-kompute.h - ../../ggml/include/ggml-blas.h
- ext/ggml-metal.h - ../../scripts/get-flags.mk
- ext/ggml-sycl.h - ../../examples/dr_wav.h
- ext/ggml-vulkan.h - ../../LICENSE
- ext/ggml-blas.h
../../scripts:
- ext/get-flags.mk
../../examples:
- ext/dr_wav.h
../..:
- README.md
- LICENSE

View File

@ -0,0 +1,76 @@
require "test/unit"
require "whisper"
class TestCallback < Test::Unit::TestCase
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
def setup
@params = Whisper::Params.new
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
@audio = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
end
def test_new_segment_callback
@params.new_segment_callback = ->(context, state, n_new, user_data) {
assert_kind_of Integer, n_new
assert n_new > 0
assert_same @whisper, context
n_segments = context.full_n_segments
n_new.times do |i|
i_segment = n_segments - 1 + i
start_time = context.full_get_segment_t0(i_segment) * 10
end_time = context.full_get_segment_t1(i_segment) * 10
text = context.full_get_segment_text(i_segment)
assert_kind_of Integer, start_time
assert start_time >= 0
assert_kind_of Integer, end_time
assert end_time > 0
assert_match /ask not what your country can do for you, ask what you can do for your country/, text if i_segment == 0
end
}
@whisper.transcribe(@audio, @params)
end
def test_new_segment_callback_closure
search_word = "what"
@params.new_segment_callback = ->(context, state, n_new, user_data) {
n_segments = context.full_n_segments
n_new.times do |i|
i_segment = n_segments - 1 + i
text = context.full_get_segment_text(i_segment)
if text.include?(search_word)
t0 = context.full_get_segment_t0(i_segment)
t1 = context.full_get_segment_t1(i_segment)
raise "search word '#{search_word}' found at between #{t0} and #{t1}"
end
end
}
assert_raise RuntimeError do
@whisper.transcribe(@audio, @params)
end
end
def test_new_segment_callback_user_data
udata = Object.new
@params.new_segment_callback_user_data = udata
@params.new_segment_callback = ->(context, state, n_new, user_data) {
assert_same udata, user_data
}
@whisper.transcribe(@audio, @params)
end
def test_new_segment_callback_user_data_gc
@params.new_segment_callback_user_data = "My user data"
@params.new_segment_callback = ->(context, state, n_new, user_data) {
assert_equal "My user data", user_data
}
GC.start
assert_same @whisper, @whisper.transcribe(@audio, @params)
end
end

View File

@ -0,0 +1,28 @@
require 'test/unit'
require 'tempfile'
require 'tmpdir'
require 'shellwords'
class TestPackage < Test::Unit::TestCase
def test_build
Tempfile.create do |file|
assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
assert file.size > 0
end
end
sub_test_case "Building binary on installation" do
def setup
system "rake", "build", exception: true
end
def test_install
filename = `rake -Tbuild`.match(/(whispercpp-(?:.+)\.gem)/)[1]
basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
Dir.mktmpdir do |dir|
system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true
assert_path_exist File.join(dir, "gems/whispercpp-1.3.0/lib", basename)
end
end
end
end

View File

@ -0,0 +1,112 @@
require 'whisper'
class TestParams < Test::Unit::TestCase
def setup
@params = Whisper::Params.new
end
def test_language
@params.language = "en"
assert_equal @params.language, "en"
@params.language = "auto"
assert_equal @params.language, "auto"
end
def test_offset
@params.offset = 10_000
assert_equal @params.offset, 10_000
@params.offset = 0
assert_equal @params.offset, 0
end
def test_duration
@params.duration = 60_000
assert_equal @params.duration, 60_000
@params.duration = 0
assert_equal @params.duration, 0
end
def test_max_text_tokens
@params.max_text_tokens = 300
assert_equal @params.max_text_tokens, 300
@params.max_text_tokens = 0
assert_equal @params.max_text_tokens, 0
end
def test_translate
@params.translate = true
assert @params.translate
@params.translate = false
assert !@params.translate
end
def test_no_context
@params.no_context = true
assert @params.no_context
@params.no_context = false
assert !@params.no_context
end
def test_single_segment
@params.single_segment = true
assert @params.single_segment
@params.single_segment = false
assert !@params.single_segment
end
def test_print_special
@params.print_special = true
assert @params.print_special
@params.print_special = false
assert !@params.print_special
end
def test_print_progress
@params.print_progress = true
assert @params.print_progress
@params.print_progress = false
assert !@params.print_progress
end
def test_print_realtime
@params.print_realtime = true
assert @params.print_realtime
@params.print_realtime = false
assert !@params.print_realtime
end
def test_print_timestamps
@params.print_timestamps = true
assert @params.print_timestamps
@params.print_timestamps = false
assert !@params.print_timestamps
end
def test_suppress_blank
@params.suppress_blank = true
assert @params.suppress_blank
@params.suppress_blank = false
assert !@params.suppress_blank
end
def test_suppress_non_speech_tokens
@params.suppress_non_speech_tokens = true
assert @params.suppress_non_speech_tokens
@params.suppress_non_speech_tokens = false
assert !@params.suppress_non_speech_tokens
end
def test_token_timestamps
@params.token_timestamps = true
assert @params.token_timestamps
@params.token_timestamps = false
assert !@params.token_timestamps
end
def test_split_on_word
@params.split_on_word = true
assert @params.split_on_word
@params.split_on_word = false
assert !@params.split_on_word
end
end

View File

@ -0,0 +1,87 @@
require "test/unit"
require "whisper"
class TestSegment < Test::Unit::TestCase
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
class << self
attr_reader :whisper
def startup
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
params = Whisper::Params.new
params.print_timestamps = false
jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
@whisper.transcribe(jfk, params)
end
end
def test_iteration
whisper.each_segment do |segment|
assert_instance_of Whisper::Segment, segment
end
end
def test_enumerator
enum = whisper.each_segment
assert_instance_of Enumerator, enum
enum.to_a.each_with_index do |segment, index|
assert_instance_of Whisper::Segment, segment
assert_kind_of Integer, index
end
end
def test_start_time
i = 0
whisper.each_segment do |segment|
assert_equal 0, segment.start_time if i == 0
i += 1
end
end
def test_end_time
i = 0
whisper.each_segment do |segment|
assert_equal whisper.full_get_segment_t1(i) * 10, segment.end_time
i += 1
end
end
def test_on_new_segment
params = Whisper::Params.new
seg = nil
index = 0
params.on_new_segment do |segment|
assert_instance_of Whisper::Segment, segment
if index == 0
seg = segment
assert_equal 0, segment.start_time
assert_match /ask not what your country can do for you, ask what you can do for your country/, segment.text
end
index += 1
end
whisper.transcribe(File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav'), params)
assert_equal 0, seg.start_time
assert_match /ask not what your country can do for you, ask what you can do for your country/, seg.text
end
def test_on_new_segment_twice
params = Whisper::Params.new
seg = nil
params.on_new_segment do |segment|
seg = segment
return
end
params.on_new_segment do |segment|
assert_same seg, segment
return
end
whisper.transcribe(File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav'), params)
end
private
def whisper
self.class.whisper
end
end

View File

@ -1,121 +1,13 @@
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
require 'whisper' require 'whisper'
require 'test/unit' require 'test/unit'
require 'tempfile'
require 'tmpdir'
require 'shellwords'
class TestWhisper < Test::Unit::TestCase class TestWhisper < Test::Unit::TestCase
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
def setup def setup
@params = Whisper::Params.new @params = Whisper::Params.new
end end
def test_language
@params.language = "en"
assert_equal @params.language, "en"
@params.language = "auto"
assert_equal @params.language, "auto"
end
def test_offset
@params.offset = 10_000
assert_equal @params.offset, 10_000
@params.offset = 0
assert_equal @params.offset, 0
end
def test_duration
@params.duration = 60_000
assert_equal @params.duration, 60_000
@params.duration = 0
assert_equal @params.duration, 0
end
def test_max_text_tokens
@params.max_text_tokens = 300
assert_equal @params.max_text_tokens, 300
@params.max_text_tokens = 0
assert_equal @params.max_text_tokens, 0
end
def test_translate
@params.translate = true
assert @params.translate
@params.translate = false
assert !@params.translate
end
def test_no_context
@params.no_context = true
assert @params.no_context
@params.no_context = false
assert !@params.no_context
end
def test_single_segment
@params.single_segment = true
assert @params.single_segment
@params.single_segment = false
assert !@params.single_segment
end
def test_print_special
@params.print_special = true
assert @params.print_special
@params.print_special = false
assert !@params.print_special
end
def test_print_progress
@params.print_progress = true
assert @params.print_progress
@params.print_progress = false
assert !@params.print_progress
end
def test_print_realtime
@params.print_realtime = true
assert @params.print_realtime
@params.print_realtime = false
assert !@params.print_realtime
end
def test_print_timestamps
@params.print_timestamps = true
assert @params.print_timestamps
@params.print_timestamps = false
assert !@params.print_timestamps
end
def test_suppress_blank
@params.suppress_blank = true
assert @params.suppress_blank
@params.suppress_blank = false
assert !@params.suppress_blank
end
def test_suppress_non_speech_tokens
@params.suppress_non_speech_tokens = true
assert @params.suppress_non_speech_tokens
@params.suppress_non_speech_tokens = false
assert !@params.suppress_non_speech_tokens
end
def test_token_timestamps
@params.token_timestamps = true
assert @params.token_timestamps
@params.token_timestamps = false
assert !@params.token_timestamps
end
def test_split_on_word
@params.split_on_word = true
assert @params.split_on_word
@params.split_on_word = false
assert !@params.split_on_word
end
def test_whisper def test_whisper
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
params = Whisper::Params.new params = Whisper::Params.new
@ -127,25 +19,81 @@ class TestWhisper < Test::Unit::TestCase
} }
end end
def test_build sub_test_case "After transcription" do
Tempfile.create do |file| class << self
assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) attr_reader :whisper
assert_path_exist file.to_path
def startup
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
params = Whisper::Params.new
params.print_timestamps = false
jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
@whisper.transcribe(jfk, params)
end
end
def whisper
self.class.whisper
end
def test_full_n_segments
assert_equal 1, whisper.full_n_segments
end
def test_full_lang_id
assert_equal 0, whisper.full_lang_id
end
def test_full_get_segment_t0
assert_equal 0, whisper.full_get_segment_t0(0)
assert_raise IndexError do
whisper.full_get_segment_t0(whisper.full_n_segments)
end
assert_raise IndexError do
whisper.full_get_segment_t0(-1)
end
end
def test_full_get_segment_t1
t1 = whisper.full_get_segment_t1(0)
assert_kind_of Integer, t1
assert t1 > 0
assert_raise IndexError do
whisper.full_get_segment_t1(whisper.full_n_segments)
end
end
def test_full_get_segment_speaker_turn_next
assert_false whisper.full_get_segment_speaker_turn_next(0)
end
def test_full_get_segment_text
assert_match /ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0)
end end
end end
sub_test_case "Building binary on installation" do def test_lang_max_id
def setup assert_kind_of Integer, Whisper.lang_max_id
system "rake", "build", exception: true end
end
def test_install def test_lang_id
filename = `rake -Tbuild`.match(/(whispercpp-(?:.+)\.gem)/)[1] assert_equal 0, Whisper.lang_id("en")
basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}" assert_raise ArgumentError do
Dir.mktmpdir do |dir| Whisper.lang_id("non existing language")
system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true end
assert_path_exist File.join(dir, "gems/whispercpp-1.3.0/lib", basename) end
end
def test_lang_str
assert_equal "en", Whisper.lang_str(0)
assert_raise IndexError do
Whisper.lang_str(Whisper.lang_max_id + 1)
end
end
def test_lang_str_full
assert_equal "english", Whisper.lang_str_full(0)
assert_raise IndexError do
Whisper.lang_str_full(Whisper.lang_max_id + 1)
end end
end end
end end

View File

@ -9,7 +9,15 @@ Gem::Specification.new do |s|
s.email = 'todd.fisher@gmail.com' s.email = 'todd.fisher@gmail.com'
s.extra_rdoc_files = ['LICENSE', 'README.md'] s.extra_rdoc_files = ['LICENSE', 'README.md']
s.files = `git ls-files . -z`.split("\x0") + YAML.load_file("extsources.yaml").values.flatten s.files = `git ls-files . -z`.split("\x0") +
YAML.load_file("extsources.yaml").collect {|file|
basename = File.basename(file)
if s.extra_rdoc_files.include?(basename)
basename
else
File.join("ext", basename)
end
}
s.summary = %q{Ruby whisper.cpp bindings} s.summary = %q{Ruby whisper.cpp bindings}
s.test_files = ["tests/test_whisper.rb"] s.test_files = ["tests/test_whisper.rb"]