From fc49ee4479c59372b34c40cdfb71ea2a96836c8c Mon Sep 17 00:00:00 2001 From: KITAITI Makoto Date: Mon, 28 Oct 2024 22:43:27 +0900 Subject: [PATCH] ruby : support new-segment callback (#2506) * Add Params#new_segment_callback= method * Add tests for Params#new_segment_callback= * Group tests for #transcribe * Don't use static for thread-safety * Set new_segment_callback only when necessary * Remove redundant check * [skip ci] Add Ruby version README * Revert "Group tests for #transcribe" This reverts commit 71b65b00ccf1816c9ea8a247fb30f71bc09707d3. * Revert "Add tests for Params#new_segment_callback=" This reverts commit 81e6df3bab7662da5379db51f28a989db7408c02. * Add test for Context#full_n_segments * Add Context#full_n_segments * Add tests for lang API * Add lang API * Add tests for Context#full_lang_id API * Add Context#full_lang_id * Add abnormal test cases for lang * Raise appropriate errors from lang APIs * Add tests for Context#full_get_segment_t{0,1} API * Add Context#full_get_segment_t{0,1} * Add tests for Context#full_get_segment_speaker_turn_next API * Add Context#full_get_segment_speaker_turn_next * Add tests for Context#full_get_segment_text * Add Context#full_get_setgment_text * Add tests for Params#new_segment_callback= * Run new segment callback * Split tests to multiple files * Use container struct for new segment callback * Add tests for Params#new_segment_callback_user_data= * Add Whisper::Params#new_user_callback_user_data= * Add GC-related test for new segment callback * Protect new segment callback related structs from GC * Add meaningful test for build * Rename: new_segment_callback_user_data -> new_segment_callback_container * Add tests for Whisper::Segment * Add Whisper::Segment and Whisper::Context#each_segment * Extract c_ruby_whisper_callback_container_allocate() * Add test for Whisper::Params#on_new_segment * Add Whisper::Params#on_new_egment * Assign symbol IDs to variables * Make extsources.yaml simpler * Update README * Add document comments * Add test for calling Whisper::Params#on_new_segment multiple times * Add file dependencies to GitHub actions config and .gitignore * Add more files to ext/.gitignore --- .github/workflows/bindings-ruby.yml | 10 + bindings/ruby/.gitignore | 1 - bindings/ruby/README.md | 110 ++++++ bindings/ruby/Rakefile | 17 +- bindings/ruby/ext/.gitignore | 7 + bindings/ruby/ext/ruby_whisper.cpp | 564 ++++++++++++++++++++++++++- bindings/ruby/ext/ruby_whisper.h | 8 + bindings/ruby/extsources.yaml | 64 ++- bindings/ruby/tests/test_callback.rb | 76 ++++ bindings/ruby/tests/test_package.rb | 28 ++ bindings/ruby/tests/test_params.rb | 112 ++++++ bindings/ruby/tests/test_segment.rb | 87 +++++ bindings/ruby/tests/test_whisper.rb | 198 ++++------ bindings/ruby/whispercpp.gemspec | 10 +- 14 files changed, 1117 insertions(+), 175 deletions(-) create mode 100644 bindings/ruby/README.md create mode 100644 bindings/ruby/tests/test_callback.rb create mode 100644 bindings/ruby/tests/test_package.rb create mode 100644 bindings/ruby/tests/test_params.rb create mode 100644 bindings/ruby/tests/test_segment.rb diff --git a/.github/workflows/bindings-ruby.yml b/.github/workflows/bindings-ruby.yml index 2b9b57bf..d1d3c341 100644 --- a/.github/workflows/bindings-ruby.yml +++ b/.github/workflows/bindings-ruby.yml @@ -16,6 +16,9 @@ on: - ggml/src/ggml-quants.h - ggml/src/ggml-quants.c - ggml/src/ggml-cpu-impl.h + - ggml/src/ggml-metal.m + - ggml/src/ggml-metal.metal + - ggml/src/ggml-blas.cpp - ggml/include/ggml.h - ggml/include/ggml-alloc.h - ggml/include/ggml-backend.h @@ -24,6 +27,8 @@ on: - ggml/include/ggml-metal.h - ggml/include/ggml-sycl.h - ggml/include/ggml-vulkan.h + - ggml/include/ggml-blas.h + - scripts/get-flags.mk - examples/dr_wav.h pull_request: paths: @@ -41,6 +46,9 @@ on: - ggml/src/ggml-quants.h - ggml/src/ggml-quants.c - ggml/src/ggml-cpu-impl.h + - ggml/src/ggml-metal.m + - ggml/src/ggml-metal.metal + - ggml/src/ggml-blas.cpp - ggml/include/ggml.h - ggml/include/ggml-alloc.h - ggml/include/ggml-backend.h @@ -49,6 +57,8 @@ on: - ggml/include/ggml-metal.h - ggml/include/ggml-sycl.h - ggml/include/ggml-vulkan.h + - ggml/include/ggml-blas.h + - scripts/get-flags.mk - examples/dr_wav.h jobs: diff --git a/bindings/ruby/.gitignore b/bindings/ruby/.gitignore index 6ff6e5f2..e04a90a9 100644 --- a/bindings/ruby/.gitignore +++ b/bindings/ruby/.gitignore @@ -1,4 +1,3 @@ -README.md LICENSE pkg/ lib/whisper.* diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md new file mode 100644 index 00000000..29dba120 --- /dev/null +++ b/bindings/ruby/README.md @@ -0,0 +1,110 @@ +whispercpp +========== + +![whisper.cpp](https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg) + +Ruby bindings for [whisper.cpp][], an interface of automatic speech recognition model. + +Installation +------------ + +Install the gem and add to the application's Gemfile by executing: + + $ bundle add whispercpp + +If bundler is not being used to manage dependencies, install the gem by executing: + + $ gem install whispercpp + +Usage +----- + +```ruby +require "whisper" + +whisper = Whisper::Context.new("path/to/model.bin") + +params = Whisper::Params.new +params.language = "en" +params.offset = 10_000 +params.duration = 60_000 +params.max_text_tokens = 300 +params.translate = true +params.print_timestamps = false + +whisper.transcribe("path/to/audio.wav", params) do |whole_text| + puts whole_text +end + +``` + +### Preparing model ### + +Use script to download model file(s): + +```bash +git clone https://github.com/ggerganov/whisper.cpp.git +cd whisper.cpp +sh ./models/download-ggml-model.sh base.en +``` + +There are some types of models. See [models][] page for details. + +### Preparing audio file ### + +Currently, whisper.cpp accepts only 16-bit WAV files. + +### API ### + +Once `Whisper::Context#transcribe` called, you can retrieve segments by `#each_segment`: + +```ruby +def format_time(time_ms) + sec, decimal_part = time_ms.divmod(1000) + min, sec = sec.divmod(60) + hour, min = min.divmod(60) + "%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part] +end + +whisper.transcribe("path/to/audio.wav", params) + +whisper.each_segment.with_index do |segment, index| + line = "[%{nth}: %{st} --> %{ed}] %{text}" % { + nth: index + 1, + st: format_time(segment.start_time), + ed: format_time(segment.end_time), + text: segment.text + } + line << " (speaker turned)" if segment.speaker_next_turn? + puts line +end + +``` + +You can also add hook to params called on new segment: + +```ruby +def format_time(time_ms) + sec, decimal_part = time_ms.divmod(1000) + min, sec = sec.divmod(60) + hour, min = min.divmod(60) + "%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part] +end + +# Add hook before calling #transcribe +params.on_new_segment do |segment| + line = "[%{st} --> %{ed}] %{text}" % { + st: format_time(segment.start_time), + ed: format_time(segment.end_time), + text: segment.text + } + line << " (speaker turned)" if segment.speaker_next_turn? + puts line +end + +whisper.transcribe("path/to/audio.wav", params) + +``` + +[whisper.cpp]: https://github.com/ggerganov/whisper.cpp +[models]: https://github.com/ggerganov/whisper.cpp/tree/master/models diff --git a/bindings/ruby/Rakefile b/bindings/ruby/Rakefile index 9b2787e9..5a6a9167 100644 --- a/bindings/ruby/Rakefile +++ b/bindings/ruby/Rakefile @@ -5,17 +5,16 @@ require "yaml" require "rake/testtask" extsources = YAML.load_file("extsources.yaml") -extsources.each_pair do |src_dir, dests| - dests.each do |dest| - src = Pathname(src_dir)/File.basename(dest) - - file src - file dest => src do |t| - cp t.source, t.name - end +SOURCES = FileList[] +extsources.each do |src| + basename = src.pathmap("%f") + dest = basename == "LICENSE" ? basename : basename.pathmap("ext/%f") + file src + file dest => src do |t| + cp t.source, t.name end + SOURCES.include dest end -SOURCES = extsources.values.flatten CLEAN.include SOURCES CLEAN.include FileList[ "ext/*.o", diff --git a/bindings/ruby/ext/.gitignore b/bindings/ruby/ext/.gitignore index 3e996866..c9f31967 100644 --- a/bindings/ruby/ext/.gitignore +++ b/bindings/ruby/ext/.gitignore @@ -11,6 +11,10 @@ ggml-backend.c ggml-backend.h ggml-common.h ggml-cpu-impl.h +ggml-metal.m +ggml-metal.metal +ggml-metal-embed.metal +ggml-blas.cpp ggml-cuda.h ggml-impl.h ggml-kompute.h @@ -20,9 +24,12 @@ ggml-quants.c ggml-quants.h ggml-sycl.h ggml-vulkan.h +ggml-blas.h +get-flags.mk whisper.cpp whisper.h dr_wav.h +depend whisper.bundle whisper.so whisper.dll diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 9d933453..b17a6bca 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -36,12 +36,65 @@ VALUE mWhisper; VALUE cContext; VALUE cParams; +static ID id_to_s; +static ID id_call; +static ID id___method__; +static ID id_to_enum; + +/* + * call-seq: + * lang_max_id -> Integer + */ +static VALUE ruby_whisper_s_lang_max_id(VALUE self) { + return INT2NUM(whisper_lang_max_id()); +} + +/* + * call-seq: + * lang_id(lang_name) -> Integer + */ +static VALUE ruby_whisper_s_lang_id(VALUE self, VALUE lang) { + const char * lang_str = StringValueCStr(lang); + const int id = whisper_lang_id(lang_str); + if (-1 == id) { + rb_raise(rb_eArgError, "language not found: %s", lang_str); + } + return INT2NUM(id); +} + +/* + * call-seq: + * lang_str(lang_id) -> String + */ +static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) { + const int lang_id = NUM2INT(id); + const char * str = whisper_lang_str(lang_id); + if (nullptr == str) { + rb_raise(rb_eIndexError, "id %d outside of language id", lang_id); + } + return rb_str_new2(str); +} + +/* + * call-seq: + * lang_str(lang_id) -> String + */ +static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) { + const int lang_id = NUM2INT(id); + const char * str_full = whisper_lang_str_full(lang_id); + if (nullptr == str_full) { + rb_raise(rb_eIndexError, "id %d outside of language id", lang_id); + } + return rb_str_new2(str_full); +} + static void ruby_whisper_free(ruby_whisper *rw) { if (rw->context) { whisper_free(rw->context); rw->context = NULL; } } + static void ruby_whisper_params_free(ruby_whisper_params *rwp) { } @@ -55,9 +108,13 @@ void rb_whisper_free(ruby_whisper *rw) { } void rb_whisper_params_mark(ruby_whisper_params *rwp) { + rb_gc_mark(rwp->new_segment_callback_container->user_data); + rb_gc_mark(rwp->new_segment_callback_container->callback); + rb_gc_mark(rwp->new_segment_callback_container->callbacks); } void rb_whisper_params_free(ruby_whisper_params *rwp) { + // How to free user_data and callback only when not referred to by others? ruby_whisper_params_free(rwp); free(rwp); } @@ -69,13 +126,28 @@ static VALUE ruby_whisper_allocate(VALUE klass) { return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw); } +static ruby_whisper_callback_container * rb_whisper_callback_container_allocate() { + ruby_whisper_callback_container *container; + container = ALLOC(ruby_whisper_callback_container); + container->context = nullptr; + container->user_data = Qnil; + container->callback = Qnil; + container->callbacks = rb_ary_new(); + return container; +} + static VALUE ruby_whisper_params_allocate(VALUE klass) { ruby_whisper_params *rwp; rwp = ALLOC(ruby_whisper_params); rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); } +/* + * call-seq: + * new("path/to/model.bin") -> Whisper::Context + */ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { ruby_whisper *rw; VALUE whisper_model_file_path; @@ -84,7 +156,7 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { rb_scan_args(argc, argv, "01", &whisper_model_file_path); Data_Get_Struct(self, ruby_whisper, rw); - if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) { + if (!rb_respond_to(whisper_model_file_path, id_to_s)) { rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); } rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params()); @@ -94,10 +166,21 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { return self; } +// High level API +static VALUE rb_whisper_segment_initialize(VALUE context, int index); + /* * transcribe a single file * can emit to a block results * + * params = Whisper::Params.new + * params.duration = 60_000 + * whisper.transcribe "path/to/audio.wav", params do |text| + * puts text + * end + * + * call-seq: + * transcribe(path_to_audio, params) {|text| ...} **/ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { ruby_whisper *rw; @@ -108,7 +191,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { Data_Get_Struct(self, ruby_whisper, rw); Data_Get_Struct(params, ruby_whisper_params, rwp); - if (!rb_respond_to(wave_file_path, rb_intern("to_s"))) { + if (!rb_respond_to(wave_file_path, id_to_s)) { rb_raise(rb_eRuntimeError, "Expected file path to wave file"); } @@ -206,6 +289,33 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { rwp->params.encoder_begin_callback_user_data = &is_aborted; } + if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { + rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + + // Currently, doesn't support state because + // those require to resolve GC-related problems. + if (!NIL_P(container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data); + } + const long callbacks_len = RARRAY_LEN(container->callbacks); + if (0 == callbacks_len) { + return; + } + const int n_segments = whisper_full_n_segments_from_state(state); + for (int i = n_new; i > 0; i--) { + int i_segment = n_segments - i; + VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment); + for (int j = 0; j < callbacks_len; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + rb_funcall(cb, id_call, 1, segment); + } + } + }; + rwp->new_segment_callback_container->context = &self; + rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container; + } + if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { fprintf(stderr, "failed to process audio\n"); return self; @@ -216,15 +326,114 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { const char * text = whisper_full_get_segment_text(rw->context, i); output = rb_str_concat(output, rb_str_new2(text)); } - VALUE idCall = rb_intern("call"); + VALUE idCall = id_call; if (blk != Qnil) { rb_funcall(blk, idCall, 1, output); } return self; } +/* + * Number of segments. + * + * call-seq: + * full_n_segments -> Integer + */ +static VALUE ruby_whisper_full_n_segments(VALUE self) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_full_n_segments(rw->context)); +} + +/* + * Language ID, which can be converted to string by Whisper.lang_str and Whisper.lang_str_full. + * + * call-seq: + * full_lang_id -> Integer + */ +static VALUE ruby_whisper_full_lang_id(VALUE self) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_full_lang_id(rw->context)); +} + +static int ruby_whisper_full_check_segment_index(const ruby_whisper * rw, const VALUE i_segment) { + const int c_i_segment = NUM2INT(i_segment); + if (c_i_segment < 0 || c_i_segment >= whisper_full_n_segments(rw->context)) { + rb_raise(rb_eIndexError, "segment index %d out of range", c_i_segment); + } + return c_i_segment; +} + +/* + * Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds). + * + * full_get_segment_t0(3) # => 1668 (16680 ms) + * + * call-seq: + * full_get_segment_t0(segment_index) -> Integer + */ +static VALUE ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment); + return INT2NUM(t0); +} + +/* + * End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds). + * + * full_get_segment_t1(3) # => 1668 (16680 ms) + * + * call-seq: + * full_get_segment_t1(segment_index) -> Integer + */ +static VALUE ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment); + return INT2NUM(t1); +} + +/* + * Whether the next segment indexed by +segment_index+ is predicated as a speaker turn. + * + * full_get_segment_speacker_turn_next(3) # => true + * + * call-seq: + * full_get_segment_speacker_turn_next(segment_index) -> bool + */ +static VALUE ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment); + return speaker_turn_next ? Qtrue : Qfalse; +} + +/* + * Text of a segment indexed by +segment_index+. + * + * full_get_segment_text(3) # => "ask not what your country can do for you, ..." + * + * call-seq: + * full_get_segment_text(segment_index) -> String + */ +static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + const char * text = whisper_full_get_segment_text(rw->context, c_i_segment); + return rb_str_new2(text); +} + /* * params.language = "auto" | "en", etc... + * + * call-seq: + * language = lang_name -> lang_name */ static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) { ruby_whisper_params *rwp; @@ -236,6 +445,10 @@ static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) { } return value; } +/* + * call-seq: + * language -> String + */ static VALUE ruby_whisper_params_get_language(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); @@ -245,72 +458,185 @@ static VALUE ruby_whisper_params_get_language(VALUE self) { return rb_str_new2("auto"); } } +/* + * call-seq: + * translate = do_translate -> do_translate + */ static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, translate, value) } +/* + * call-seq: + * translate -> bool + */ static VALUE ruby_whisper_params_get_translate(VALUE self) { BOOL_PARAMS_GETTER(self, translate) } +/* + * call-seq: + * no_context = dont_use_context -> dont_use_context + */ static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, no_context, value) } +/* + * If true, does not use past transcription (if any) as initial prompt for the decoder. + * + * call-seq: + * no_context -> bool + */ static VALUE ruby_whisper_params_get_no_context(VALUE self) { BOOL_PARAMS_GETTER(self, no_context) } +/* + * call-seq: + * single_segment = force_single -> force_single + */ static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, single_segment, value) } +/* + * If true, forces single segment output (useful for streaming). + * + * call-seq: + * single_segment -> bool + */ static VALUE ruby_whisper_params_get_single_segment(VALUE self) { BOOL_PARAMS_GETTER(self, single_segment) } +/* + * call-seq: + * print_special = force_print -> force_print + */ static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, print_special, value) } +/* + * If true, prints special tokens (e.g. , , , etc.). + * + * call-seq: + * print_special -> bool + */ static VALUE ruby_whisper_params_get_print_special(VALUE self) { BOOL_PARAMS_GETTER(self, print_special) } +/* + * call-seq: + * print_progress = force_print -> force_print + */ static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, print_progress, value) } +/* + * If true, prints progress information. + * + * call-seq: + * print_progress -> bool + */ static VALUE ruby_whisper_params_get_print_progress(VALUE self) { BOOL_PARAMS_GETTER(self, print_progress) } +/* + * call-seq: + * print_realtime = force_print -> force_print + */ static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, print_realtime, value) } +/* + * If true, prints results from within whisper.cpp. (avoid it, use callback instead) + * call-seq: + * print_realtime -> bool + */ static VALUE ruby_whisper_params_get_print_realtime(VALUE self) { BOOL_PARAMS_GETTER(self, print_realtime) } +/* + * call-seq: + * print_timestamps = force_print -> force_print + */ static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, print_timestamps, value) } +/* + * If true, prints timestamps for each text segment when printing realtime. + * + * call-seq: + * print_timestamps -> bool + */ static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) { BOOL_PARAMS_GETTER(self, print_timestamps) } +/* + * call-seq: + * suppress_blank = force_suppress -> force_suppress + */ static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, suppress_blank, value) } +/* + * If true, suppresses blank outputs. + * + * call-seq: + * suppress_blank -> bool + */ static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) { BOOL_PARAMS_GETTER(self, suppress_blank) } +/* + * call-seq: + * suppress_non_speech_tokens = force_suppress -> force_suppress + */ static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value) } +/* + * If true, suppresses non-speech-tokens. + * + * call-seq: + * suppress_non_speech_tokens -> bool + */ static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) { BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens) } +/* + * If true, enables token-level timestamps. + * + * call-seq: + * token_timestamps -> bool + */ static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) { BOOL_PARAMS_GETTER(self, token_timestamps) } +/* + * call-seq: + * token_timestamps = force_timestamps -> force_timestamps + */ static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, token_timestamps, value) } +/* + * If true, split on word rather than on token (when used with max_len). + * + * call-seq: + * translate -> bool + */ static VALUE ruby_whisper_params_get_split_on_word(VALUE self) { BOOL_PARAMS_GETTER(self, split_on_word) } +/* + * call-seq: + * split_on_word = force_split -> force_split + */ static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, split_on_word, value) } +/* + * If true, enables diarization. + * + * call-seq: + * diarize -> bool + */ static VALUE ruby_whisper_params_get_diarize(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); @@ -320,6 +646,10 @@ static VALUE ruby_whisper_params_get_diarize(VALUE self) { return Qfalse; } } +/* + * call-seq: + * diarize = force_diarize -> force_diarize + */ static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); @@ -331,22 +661,42 @@ static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) { return value; } +/* + * Start offset in ms. + * + * call-seq: + * offset -> Integer + */ static VALUE ruby_whisper_params_get_offset(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); return INT2NUM(rwp->params.offset_ms); } +/* + * call-seq: + * offset = offset_ms -> offset_ms + */ static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); rwp->params.offset_ms = NUM2INT(value); return value; } +/* + * Audio duration to process in ms. + * + * call-seq: + * duration -> Integer + */ static VALUE ruby_whisper_params_get_duration(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); return INT2NUM(rwp->params.duration_ms); } +/* + * call-seq: + * duration = duration_ms -> duration_ms + */ static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); @@ -354,27 +704,221 @@ static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) { return value; } +/* + * Max tokens to use from past text as prompt for the decoder. + * + * call-seq: + * max_text_tokens -> Integer + */ static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); return INT2NUM(rwp->params.n_max_text_ctx); } +/* + * call-seq: + * max_text_tokens = n_tokens -> n_tokens + */ static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); rwp->params.n_max_text_ctx = NUM2INT(value); return value; } +/* + * Sets new segment callback, called for every newly generated text segment. + * + * params.new_segment_callback = ->(context, _, n_new, user_data) { + * # ... + * } + * + * call-seq: + * new_segment_callback = callback -> callback + */ +static VALUE ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->new_segment_callback_container->callback = value; + return value; +} +/* + * Sets user data passed to the last argument of new segment callback. + * + * call-seq: + * new_segment_callback_user_data = user_data -> use_data + */ +static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->new_segment_callback_container->user_data = value; + return value; +} + +// High level API + +typedef struct { + VALUE context; + int index; +} ruby_whisper_segment; + +VALUE cSegment; + +static void rb_whisper_segment_mark(ruby_whisper_segment *rws) { + rb_gc_mark(rws->context); +} + +static VALUE ruby_whisper_segment_allocate(VALUE klass) { + ruby_whisper_segment *rws; + rws = ALLOC(ruby_whisper_segment); + return Data_Wrap_Struct(klass, rb_whisper_segment_mark, RUBY_DEFAULT_FREE, rws); +} + +static VALUE rb_whisper_segment_initialize(VALUE context, int index) { + ruby_whisper_segment *rws; + const VALUE segment = ruby_whisper_segment_allocate(cSegment); + Data_Get_Struct(segment, ruby_whisper_segment, rws); + rws->context = context; + rws->index = index; + return segment; +}; + +/* + * Yields each Whisper::Segment: + * + * whisper.transcribe("path/to/audio.wav", params) + * whisper.each_segment do |segment| + * puts segment.text + * end + * + * Returns an Enumerator if no block given: + * + * whisper.transcribe("path/to/audio.wav", params) + * enum = whisper.each_segment + * enum.to_a # => [#, ...] + * + * call-seq: + * each_segment {|segment| ... } + * each_segment -> Enumerator + */ +static VALUE ruby_whisper_each_segment(VALUE self) { + if (!rb_block_given_p()) { + const VALUE method_name = rb_funcall(self, id___method__, 0); + return rb_funcall(self, id_to_enum, 1, method_name); + } + + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + + const int n_segments = whisper_full_n_segments(rw->context); + for (int i = 0; i < n_segments; ++i) { + rb_yield(rb_whisper_segment_initialize(self, i)); + } + + return self; +} + +/* + * Hook called on new segment. Yields each Whisper::Segment. + * + * whisper.on_new_segment do |segment| + * # ... + * end + * + * call-seq: + * on_new_segment {|segment| ... } + */ +static VALUE ruby_whisper_params_on_new_segment(VALUE self) { + ruby_whisper_params *rws; + Data_Get_Struct(self, ruby_whisper_params, rws); + const VALUE blk = rb_block_proc(); + rb_ary_push(rws->new_segment_callback_container->callbacks, blk); + return Qnil; +} + +/* + * Start time in milliseconds. + * + * call-seq: + * start_time -> Integer + */ +static VALUE ruby_whisper_segment_get_start_time(VALUE self) { + ruby_whisper_segment *rws; + Data_Get_Struct(self, ruby_whisper_segment, rws); + ruby_whisper *rw; + Data_Get_Struct(rws->context, ruby_whisper, rw); + const int64_t t0 = whisper_full_get_segment_t0(rw->context, rws->index); + // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it + return INT2NUM(t0 * 10); +} + +/* + * End time in milliseconds. + * + * call-seq: + * end_time -> Integer + */ +static VALUE ruby_whisper_segment_get_end_time(VALUE self) { + ruby_whisper_segment *rws; + Data_Get_Struct(self, ruby_whisper_segment, rws); + ruby_whisper *rw; + Data_Get_Struct(rws->context, ruby_whisper, rw); + const int64_t t1 = whisper_full_get_segment_t1(rw->context, rws->index); + // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it + return INT2NUM(t1 * 10); +} + +/* + * Whether the next segment is predicted as a speaker turn. + * + * call-seq: + * speaker_turn_next? -> bool + */ +static VALUE ruby_whisper_segment_get_speaker_turn_next(VALUE self) { + ruby_whisper_segment *rws; + Data_Get_Struct(self, ruby_whisper_segment, rws); + ruby_whisper *rw; + Data_Get_Struct(rws->context, ruby_whisper, rw); + return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? Qtrue : Qfalse; +} + +/* + * call-seq: + * text -> String + */ +static VALUE ruby_whisper_segment_get_text(VALUE self) { + ruby_whisper_segment *rws; + Data_Get_Struct(self, ruby_whisper_segment, rws); + ruby_whisper *rw; + Data_Get_Struct(rws->context, ruby_whisper, rw); + const char * text = whisper_full_get_segment_text(rw->context, rws->index); + return rb_str_new2(text); +} void Init_whisper() { + id_to_s = rb_intern("to_s"); + id_call = rb_intern("call"); + id___method__ = rb_intern("__method__"); + id_to_enum = rb_intern("to_enum"); + mWhisper = rb_define_module("Whisper"); cContext = rb_define_class_under(mWhisper, "Context", rb_cObject); cParams = rb_define_class_under(mWhisper, "Params", rb_cObject); + rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0); + rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1); + rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1); + rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1); + rb_define_alloc_func(cContext, ruby_whisper_allocate); rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1); rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1); + rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0); + rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0); + rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1); + rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1); + rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1); + rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1); rb_define_alloc_func(cParams, ruby_whisper_params_allocate); @@ -412,6 +956,20 @@ void Init_whisper() { rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0); rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1); + + rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1); + rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1); + + // High leve + cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject); + + rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate); + rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0); + rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0); + rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0); + rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0); + rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0); + rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0); } #ifdef __cplusplus } diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 8c35b7cb..f3210a4b 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -3,6 +3,13 @@ #include "whisper.h" +typedef struct { + VALUE *context; + VALUE user_data; + VALUE callback; + VALUE callbacks; +} ruby_whisper_callback_container; + typedef struct { struct whisper_context *context; } ruby_whisper; @@ -10,6 +17,7 @@ typedef struct { typedef struct { struct whisper_full_params params; bool diarize; + ruby_whisper_callback_container *new_segment_callback_container; } ruby_whisper_params; #endif diff --git a/bindings/ruby/extsources.yaml b/bindings/ruby/extsources.yaml index e59f6ecf..85488864 100644 --- a/bindings/ruby/extsources.yaml +++ b/bindings/ruby/extsources.yaml @@ -1,37 +1,29 @@ --- -../../src: -- ext/whisper.cpp -../../include: -- ext/whisper.h -../../ggml/src: -- ext/ggml.c -- ext/ggml-impl.h -- ext/ggml-aarch64.h -- ext/ggml-aarch64.c -- ext/ggml-alloc.c -- ext/ggml-backend-impl.h -- ext/ggml-backend.cpp -- ext/ggml-common.h -- ext/ggml-quants.h -- ext/ggml-quants.c -- ext/ggml-cpu-impl.h -- ext/ggml-metal.m -- ext/ggml-metal.metal -- ext/ggml-blas.cpp -../../ggml/include: -- ext/ggml.h -- ext/ggml-alloc.h -- ext/ggml-backend.h -- ext/ggml-cuda.h -- ext/ggml-kompute.h -- ext/ggml-metal.h -- ext/ggml-sycl.h -- ext/ggml-vulkan.h -- ext/ggml-blas.h -../../scripts: -- ext/get-flags.mk -../../examples: -- ext/dr_wav.h -../..: -- README.md -- LICENSE +- ../../src/whisper.cpp +- ../../include/whisper.h +- ../../ggml/src/ggml.c +- ../../ggml/src/ggml-impl.h +- ../../ggml/src/ggml-aarch64.h +- ../../ggml/src/ggml-aarch64.c +- ../../ggml/src/ggml-alloc.c +- ../../ggml/src/ggml-backend-impl.h +- ../../ggml/src/ggml-backend.cpp +- ../../ggml/src/ggml-common.h +- ../../ggml/src/ggml-quants.h +- ../../ggml/src/ggml-quants.c +- ../../ggml/src/ggml-cpu-impl.h +- ../../ggml/src/ggml-metal.m +- ../../ggml/src/ggml-metal.metal +- ../../ggml/src/ggml-blas.cpp +- ../../ggml/include/ggml.h +- ../../ggml/include/ggml-alloc.h +- ../../ggml/include/ggml-backend.h +- ../../ggml/include/ggml-cuda.h +- ../../ggml/include/ggml-kompute.h +- ../../ggml/include/ggml-metal.h +- ../../ggml/include/ggml-sycl.h +- ../../ggml/include/ggml-vulkan.h +- ../../ggml/include/ggml-blas.h +- ../../scripts/get-flags.mk +- ../../examples/dr_wav.h +- ../../LICENSE diff --git a/bindings/ruby/tests/test_callback.rb b/bindings/ruby/tests/test_callback.rb new file mode 100644 index 00000000..80a5f4df --- /dev/null +++ b/bindings/ruby/tests/test_callback.rb @@ -0,0 +1,76 @@ +require "test/unit" +require "whisper" + +class TestCallback < Test::Unit::TestCase + TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..')) + + def setup + @params = Whisper::Params.new + @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + @audio = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + end + + def test_new_segment_callback + @params.new_segment_callback = ->(context, state, n_new, user_data) { + assert_kind_of Integer, n_new + assert n_new > 0 + assert_same @whisper, context + + n_segments = context.full_n_segments + n_new.times do |i| + i_segment = n_segments - 1 + i + start_time = context.full_get_segment_t0(i_segment) * 10 + end_time = context.full_get_segment_t1(i_segment) * 10 + text = context.full_get_segment_text(i_segment) + + assert_kind_of Integer, start_time + assert start_time >= 0 + assert_kind_of Integer, end_time + assert end_time > 0 + assert_match /ask not what your country can do for you, ask what you can do for your country/, text if i_segment == 0 + end + } + + @whisper.transcribe(@audio, @params) + end + + def test_new_segment_callback_closure + search_word = "what" + @params.new_segment_callback = ->(context, state, n_new, user_data) { + n_segments = context.full_n_segments + n_new.times do |i| + i_segment = n_segments - 1 + i + text = context.full_get_segment_text(i_segment) + if text.include?(search_word) + t0 = context.full_get_segment_t0(i_segment) + t1 = context.full_get_segment_t1(i_segment) + raise "search word '#{search_word}' found at between #{t0} and #{t1}" + end + end + } + + assert_raise RuntimeError do + @whisper.transcribe(@audio, @params) + end + end + + def test_new_segment_callback_user_data + udata = Object.new + @params.new_segment_callback_user_data = udata + @params.new_segment_callback = ->(context, state, n_new, user_data) { + assert_same udata, user_data + } + + @whisper.transcribe(@audio, @params) + end + + def test_new_segment_callback_user_data_gc + @params.new_segment_callback_user_data = "My user data" + @params.new_segment_callback = ->(context, state, n_new, user_data) { + assert_equal "My user data", user_data + } + GC.start + + assert_same @whisper, @whisper.transcribe(@audio, @params) + end +end diff --git a/bindings/ruby/tests/test_package.rb b/bindings/ruby/tests/test_package.rb new file mode 100644 index 00000000..adaeedfb --- /dev/null +++ b/bindings/ruby/tests/test_package.rb @@ -0,0 +1,28 @@ +require 'test/unit' +require 'tempfile' +require 'tmpdir' +require 'shellwords' + +class TestPackage < Test::Unit::TestCase + def test_build + Tempfile.create do |file| + assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) + assert file.size > 0 + end + end + + sub_test_case "Building binary on installation" do + def setup + system "rake", "build", exception: true + end + + def test_install + filename = `rake -Tbuild`.match(/(whispercpp-(?:.+)\.gem)/)[1] + basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}" + Dir.mktmpdir do |dir| + system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true + assert_path_exist File.join(dir, "gems/whispercpp-1.3.0/lib", basename) + end + end + end +end diff --git a/bindings/ruby/tests/test_params.rb b/bindings/ruby/tests/test_params.rb new file mode 100644 index 00000000..4484feee --- /dev/null +++ b/bindings/ruby/tests/test_params.rb @@ -0,0 +1,112 @@ +require 'whisper' + +class TestParams < Test::Unit::TestCase + def setup + @params = Whisper::Params.new + end + + def test_language + @params.language = "en" + assert_equal @params.language, "en" + @params.language = "auto" + assert_equal @params.language, "auto" + end + + def test_offset + @params.offset = 10_000 + assert_equal @params.offset, 10_000 + @params.offset = 0 + assert_equal @params.offset, 0 + end + + def test_duration + @params.duration = 60_000 + assert_equal @params.duration, 60_000 + @params.duration = 0 + assert_equal @params.duration, 0 + end + + def test_max_text_tokens + @params.max_text_tokens = 300 + assert_equal @params.max_text_tokens, 300 + @params.max_text_tokens = 0 + assert_equal @params.max_text_tokens, 0 + end + + def test_translate + @params.translate = true + assert @params.translate + @params.translate = false + assert !@params.translate + end + + def test_no_context + @params.no_context = true + assert @params.no_context + @params.no_context = false + assert !@params.no_context + end + + def test_single_segment + @params.single_segment = true + assert @params.single_segment + @params.single_segment = false + assert !@params.single_segment + end + + def test_print_special + @params.print_special = true + assert @params.print_special + @params.print_special = false + assert !@params.print_special + end + + def test_print_progress + @params.print_progress = true + assert @params.print_progress + @params.print_progress = false + assert !@params.print_progress + end + + def test_print_realtime + @params.print_realtime = true + assert @params.print_realtime + @params.print_realtime = false + assert !@params.print_realtime + end + + def test_print_timestamps + @params.print_timestamps = true + assert @params.print_timestamps + @params.print_timestamps = false + assert !@params.print_timestamps + end + + def test_suppress_blank + @params.suppress_blank = true + assert @params.suppress_blank + @params.suppress_blank = false + assert !@params.suppress_blank + end + + def test_suppress_non_speech_tokens + @params.suppress_non_speech_tokens = true + assert @params.suppress_non_speech_tokens + @params.suppress_non_speech_tokens = false + assert !@params.suppress_non_speech_tokens + end + + def test_token_timestamps + @params.token_timestamps = true + assert @params.token_timestamps + @params.token_timestamps = false + assert !@params.token_timestamps + end + + def test_split_on_word + @params.split_on_word = true + assert @params.split_on_word + @params.split_on_word = false + assert !@params.split_on_word + end +end diff --git a/bindings/ruby/tests/test_segment.rb b/bindings/ruby/tests/test_segment.rb new file mode 100644 index 00000000..f3ebc0e9 --- /dev/null +++ b/bindings/ruby/tests/test_segment.rb @@ -0,0 +1,87 @@ +require "test/unit" +require "whisper" + +class TestSegment < Test::Unit::TestCase + TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..')) + + class << self + attr_reader :whisper + + def startup + @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + params = Whisper::Params.new + params.print_timestamps = false + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + @whisper.transcribe(jfk, params) + end + end + + def test_iteration + whisper.each_segment do |segment| + assert_instance_of Whisper::Segment, segment + end + end + + def test_enumerator + enum = whisper.each_segment + assert_instance_of Enumerator, enum + enum.to_a.each_with_index do |segment, index| + assert_instance_of Whisper::Segment, segment + assert_kind_of Integer, index + end + end + + def test_start_time + i = 0 + whisper.each_segment do |segment| + assert_equal 0, segment.start_time if i == 0 + i += 1 + end + end + + def test_end_time + i = 0 + whisper.each_segment do |segment| + assert_equal whisper.full_get_segment_t1(i) * 10, segment.end_time + i += 1 + end + end + + def test_on_new_segment + params = Whisper::Params.new + seg = nil + index = 0 + params.on_new_segment do |segment| + assert_instance_of Whisper::Segment, segment + if index == 0 + seg = segment + assert_equal 0, segment.start_time + assert_match /ask not what your country can do for you, ask what you can do for your country/, segment.text + end + index += 1 + end + whisper.transcribe(File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav'), params) + assert_equal 0, seg.start_time + assert_match /ask not what your country can do for you, ask what you can do for your country/, seg.text + end + + def test_on_new_segment_twice + params = Whisper::Params.new + seg = nil + params.on_new_segment do |segment| + seg = segment + return + end + params.on_new_segment do |segment| + assert_same seg, segment + return + end + whisper.transcribe(File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav'), params) + end + + private + + def whisper + self.class.whisper + end +end diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 410b5248..5ebb8151 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -1,121 +1,13 @@ -TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..')) - require 'whisper' require 'test/unit' -require 'tempfile' -require 'tmpdir' -require 'shellwords' class TestWhisper < Test::Unit::TestCase + TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..')) + def setup @params = Whisper::Params.new end - def test_language - @params.language = "en" - assert_equal @params.language, "en" - @params.language = "auto" - assert_equal @params.language, "auto" - end - - def test_offset - @params.offset = 10_000 - assert_equal @params.offset, 10_000 - @params.offset = 0 - assert_equal @params.offset, 0 - end - - def test_duration - @params.duration = 60_000 - assert_equal @params.duration, 60_000 - @params.duration = 0 - assert_equal @params.duration, 0 - end - - def test_max_text_tokens - @params.max_text_tokens = 300 - assert_equal @params.max_text_tokens, 300 - @params.max_text_tokens = 0 - assert_equal @params.max_text_tokens, 0 - end - - def test_translate - @params.translate = true - assert @params.translate - @params.translate = false - assert !@params.translate - end - - def test_no_context - @params.no_context = true - assert @params.no_context - @params.no_context = false - assert !@params.no_context - end - - def test_single_segment - @params.single_segment = true - assert @params.single_segment - @params.single_segment = false - assert !@params.single_segment - end - - def test_print_special - @params.print_special = true - assert @params.print_special - @params.print_special = false - assert !@params.print_special - end - - def test_print_progress - @params.print_progress = true - assert @params.print_progress - @params.print_progress = false - assert !@params.print_progress - end - - def test_print_realtime - @params.print_realtime = true - assert @params.print_realtime - @params.print_realtime = false - assert !@params.print_realtime - end - - def test_print_timestamps - @params.print_timestamps = true - assert @params.print_timestamps - @params.print_timestamps = false - assert !@params.print_timestamps - end - - def test_suppress_blank - @params.suppress_blank = true - assert @params.suppress_blank - @params.suppress_blank = false - assert !@params.suppress_blank - end - - def test_suppress_non_speech_tokens - @params.suppress_non_speech_tokens = true - assert @params.suppress_non_speech_tokens - @params.suppress_non_speech_tokens = false - assert !@params.suppress_non_speech_tokens - end - - def test_token_timestamps - @params.token_timestamps = true - assert @params.token_timestamps - @params.token_timestamps = false - assert !@params.token_timestamps - end - - def test_split_on_word - @params.split_on_word = true - assert @params.split_on_word - @params.split_on_word = false - assert !@params.split_on_word - end - def test_whisper @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) params = Whisper::Params.new @@ -127,25 +19,81 @@ class TestWhisper < Test::Unit::TestCase } end - def test_build - Tempfile.create do |file| - assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) - assert_path_exist file.to_path + sub_test_case "After transcription" do + class << self + attr_reader :whisper + + def startup + @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + params = Whisper::Params.new + params.print_timestamps = false + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + @whisper.transcribe(jfk, params) + end + end + + def whisper + self.class.whisper + end + + def test_full_n_segments + assert_equal 1, whisper.full_n_segments + end + + def test_full_lang_id + assert_equal 0, whisper.full_lang_id + end + + def test_full_get_segment_t0 + assert_equal 0, whisper.full_get_segment_t0(0) + assert_raise IndexError do + whisper.full_get_segment_t0(whisper.full_n_segments) + end + assert_raise IndexError do + whisper.full_get_segment_t0(-1) + end + end + + def test_full_get_segment_t1 + t1 = whisper.full_get_segment_t1(0) + assert_kind_of Integer, t1 + assert t1 > 0 + assert_raise IndexError do + whisper.full_get_segment_t1(whisper.full_n_segments) + end + end + + def test_full_get_segment_speaker_turn_next + assert_false whisper.full_get_segment_speaker_turn_next(0) + end + + def test_full_get_segment_text + assert_match /ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0) end end - sub_test_case "Building binary on installation" do - def setup - system "rake", "build", exception: true - end + def test_lang_max_id + assert_kind_of Integer, Whisper.lang_max_id + end - def test_install - filename = `rake -Tbuild`.match(/(whispercpp-(?:.+)\.gem)/)[1] - basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}" - Dir.mktmpdir do |dir| - system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true - assert_path_exist File.join(dir, "gems/whispercpp-1.3.0/lib", basename) - end + def test_lang_id + assert_equal 0, Whisper.lang_id("en") + assert_raise ArgumentError do + Whisper.lang_id("non existing language") + end + end + + def test_lang_str + assert_equal "en", Whisper.lang_str(0) + assert_raise IndexError do + Whisper.lang_str(Whisper.lang_max_id + 1) + end + end + + def test_lang_str_full + assert_equal "english", Whisper.lang_str_full(0) + assert_raise IndexError do + Whisper.lang_str_full(Whisper.lang_max_id + 1) end end end diff --git a/bindings/ruby/whispercpp.gemspec b/bindings/ruby/whispercpp.gemspec index 5b24d7e7..251d03fa 100644 --- a/bindings/ruby/whispercpp.gemspec +++ b/bindings/ruby/whispercpp.gemspec @@ -9,7 +9,15 @@ Gem::Specification.new do |s| s.email = 'todd.fisher@gmail.com' s.extra_rdoc_files = ['LICENSE', 'README.md'] - s.files = `git ls-files . -z`.split("\x0") + YAML.load_file("extsources.yaml").values.flatten + s.files = `git ls-files . -z`.split("\x0") + + YAML.load_file("extsources.yaml").collect {|file| + basename = File.basename(file) + if s.extra_rdoc_files.include?(basename) + basename + else + File.join("ext", basename) + end + } s.summary = %q{Ruby whisper.cpp bindings} s.test_files = ["tests/test_whisper.rb"]