From 50fda73f4c46632722df0f102e294e91a4fa731a Mon Sep 17 00:00:00 2001 From: KITAITI Makoto Date: Sat, 26 Apr 2025 04:33:11 +0900 Subject: [PATCH] ruby : add encoder begin callback related methods (#3076) * Lazy run TestBase.whisper * Fix indentation * Remove disused GGML_HIP_UMA from Ruby * Add encoder_begin_callback * Comment out existing abort mechanism * Add test for encoder_begin_callback * Add signatures for encoder_begin_callback related methods * Update gem date --- bindings/ruby/ext/options.rb | 1 - bindings/ruby/ext/ruby_whisper.h | 1 + bindings/ruby/ext/ruby_whisper_params.c | 119 +++++++++++++++++- bindings/ruby/ext/ruby_whisper_transcribe.cpp | 17 +-- bindings/ruby/lib/whisper/model/uri.rb | 4 +- bindings/ruby/sig/whisper.rbs | 19 +++ bindings/ruby/tests/helper.rb | 4 +- bindings/ruby/tests/test_callback.rb | 42 +++++++ bindings/ruby/whispercpp.gemspec | 2 +- 9 files changed, 192 insertions(+), 17 deletions(-) diff --git a/bindings/ruby/ext/options.rb b/bindings/ruby/ext/options.rb index 679b74d1..6fed3184 100644 --- a/bindings/ruby/ext/options.rb +++ b/bindings/ruby/ext/options.rb @@ -114,7 +114,6 @@ class Options bool "GGML_HIP_GRAPHS" bool "GGML_HIP_NO_VMM" bool "GGML_HIP_ROCWMMA_FATTN" - bool "GGML_HIP_UMA" ignored "GGML_INCLUDE_INSTALL_DIR" bool "GGML_KOMPUTE" bool "GGML_LASX" diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index bbf3435e..6111a151 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -19,6 +19,7 @@ typedef struct { bool diarize; ruby_whisper_callback_container *new_segment_callback_container; ruby_whisper_callback_container *progress_callback_container; + ruby_whisper_callback_container *encoder_begin_callback_container; ruby_whisper_callback_container *abort_callback_container; } ruby_whisper_params; diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index caeb34f2..c07f2372 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -26,7 +26,7 @@ rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \ rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1); -#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 30 +#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 32 extern VALUE cParams; @@ -63,6 +63,8 @@ static ID id_new_segment_callback; static ID id_new_segment_callback_user_data; static ID id_progress_callback; static ID id_progress_callback_user_data; +static ID id_encoder_begin_callback; +static ID id_encoder_begin_callback_user_data; static ID id_abort_callback; static ID id_abort_callback_user_data; @@ -126,6 +128,33 @@ static void progress_callback(struct whisper_context *ctx, struct whisper_state } } +static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_state *state, void *user_data) { + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + bool is_aborted = false; + VALUE result; + + // Currently, doesn't support state because + // those require to resolve GC-related problems. + if (!NIL_P(container->callback)) { + result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data); + if (result == Qfalse) { + is_aborted = true; + } + } + const long callbacks_len = RARRAY_LEN(container->callbacks); + if (0 == callbacks_len) { + return !is_aborted; + } + for (int j = 0; j < callbacks_len; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + result = rb_funcall(cb, id_call, 0); + if (result == Qfalse) { + is_aborted = true; + } + } + return !is_aborted; +} + static bool abort_callback(void * user_data) { const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; if (!NIL_P(container->callback)) { @@ -161,6 +190,12 @@ void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { rwp->params.progress_callback_user_data = rwp->progress_callback_container; } + if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) { + rwp->encoder_begin_callback_container->context = context; + rwp->params.encoder_begin_callback = encoder_begin_callback; + rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container; + } + if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { rwp->abort_callback_container->context = context; rwp->params.abort_callback = abort_callback; @@ -173,6 +208,7 @@ rb_whisper_params_mark(ruby_whisper_params *rwp) { rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container); rb_whisper_callbcack_container_mark(rwp->progress_callback_container); + rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container); rb_whisper_callbcack_container_mark(rwp->abort_callback_container); } @@ -198,6 +234,7 @@ ruby_whisper_params_allocate(VALUE klass) rwp->diarize = false; rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); rwp->progress_callback_container = rb_whisper_callback_container_allocate(); + rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate(); rwp->abort_callback_container = rb_whisper_callback_container_allocate(); return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); } @@ -849,6 +886,57 @@ ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value) rwp->progress_callback_container->user_data = value; return value; } + +static VALUE +ruby_whisper_params_get_encoder_begin_callback(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return rwp->encoder_begin_callback_container->callback; +} + +/* + * Sets encoder begin callback, called when the encoder starts. + * + * params.encoder_begin_callback = ->(context, _, user_data) { + * # ... + * } + * + * call-seq: + * encoder_begin_callback = callback -> callback + */ +static VALUE +ruby_whisper_params_set_encoder_begin_callback(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->encoder_begin_callback_container->callback = value; + return value; +} + +static VALUE +ruby_whisper_params_get_encoder_begin_callback_user_data(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return rwp->encoder_begin_callback_container->user_data; +} + +/* + * Sets user data passed to the last argument of encoder begin callback. + * + * call-seq: + * encoder_begin_callback_user_data = user_data -> use_data + */ +static VALUE +ruby_whisper_params_set_encoder_begin_callback_user_data(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->encoder_begin_callback_container->user_data = value; + return value; +} + static VALUE ruby_whisper_params_get_abort_callback(VALUE self) { @@ -958,6 +1046,8 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self) SET_PARAM_IF_SAME(new_segment_callback_user_data) SET_PARAM_IF_SAME(progress_callback) SET_PARAM_IF_SAME(progress_callback_user_data) + SET_PARAM_IF_SAME(encoder_begin_callback) + SET_PARAM_IF_SAME(encoder_begin_callback_user_data) SET_PARAM_IF_SAME(abort_callback) SET_PARAM_IF_SAME(abort_callback_user_data) } @@ -1008,6 +1098,26 @@ ruby_whisper_params_on_progress(VALUE self) return Qnil; } +/* + * Hook called when the encoder starts. + * + * whisper.on_encoder_begin do + * # ... + * end + * + * call-seq: + * on_encoder_begin { ... } + */ +static VALUE +ruby_whisper_params_on_encoder_begin(VALUE self) +{ + ruby_whisper_params *rws; + Data_Get_Struct(self, ruby_whisper_params, rws); + const VALUE blk = rb_block_proc(); + rb_ary_push(rws->encoder_begin_callback_container->callbacks, blk); + return Qnil; +} + /* * Call block to determine whether abort or not. Return +true+ when you want to abort. * @@ -1068,10 +1178,13 @@ init_ruby_whisper_params(VALUE *mWhisper) DEFINE_PARAM(new_segment_callback_user_data, 25) DEFINE_PARAM(progress_callback, 26) DEFINE_PARAM(progress_callback_user_data, 27) - DEFINE_PARAM(abort_callback, 28) - DEFINE_PARAM(abort_callback_user_data, 29) + DEFINE_PARAM(encoder_begin_callback, 28) + DEFINE_PARAM(encoder_begin_callback_user_data, 29) + DEFINE_PARAM(abort_callback, 30) + DEFINE_PARAM(abort_callback_user_data, 31) rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0); rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0); + rb_define_method(cParams, "on_encoder_begin", ruby_whisper_params_on_encoder_begin, 0); rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0); } diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp index 00b9d2e1..ef3c0780 100644 --- a/bindings/ruby/ext/ruby_whisper_transcribe.cpp +++ b/bindings/ruby/ext/ruby_whisper_transcribe.cpp @@ -50,15 +50,16 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); return self; } - { - static bool is_aborted = false; // NOTE: this should be atomic to avoid data race + // Commented out because it is work in progress + // { + // static bool is_aborted = false; // NOTE: this should be atomic to avoid data race - rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { - bool is_aborted = *(bool*)user_data; - return !is_aborted; - }; - rwp->params.encoder_begin_callback_user_data = &is_aborted; - } + // rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { + // bool is_aborted = *(bool*)user_data; + // return !is_aborted; + // }; + // rwp->params.encoder_begin_callback_user_data = &is_aborted; + // } register_callbacks(rwp, &self); diff --git a/bindings/ruby/lib/whisper/model/uri.rb b/bindings/ruby/lib/whisper/model/uri.rb index b2bc9c4b..47c23c52 100644 --- a/bindings/ruby/lib/whisper/model/uri.rb +++ b/bindings/ruby/lib/whisper/model/uri.rb @@ -53,7 +53,7 @@ module Whisper http.request request do |response| case response when Net::HTTPNotModified - # noop + # noop when Net::HTTPOK download response when Net::HTTPRedirection @@ -68,7 +68,7 @@ module Whisper rescue => err if cache_path.exist? warn err - # Use cache file + # Use cache file else raise end diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index 0f3d74e0..a3ce94b8 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -7,6 +7,7 @@ module Whisper type log_callback = ^(Integer level, String message, Object user_data) -> void type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void + type encoder_begin_callback = ^(Whisper::Context, void, Object user_data) -> void type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish LOG_LEVEL_NONE: Integer @@ -146,6 +147,8 @@ module Whisper ?new_segment_callback_user_data: Object, ?progress_callback: progress_callback, ?progress_callback_user_data: Object, + ?encoder_begin_callback: encoder_begin_callback, + ?encoder_begin_callback_user_data: Object, ?abort_callback: abort_callback, ?abort_callback_user_data: Object ) -> instance @@ -306,6 +309,18 @@ module Whisper def progress_callback_user_data: () -> Object + # Sets encoder begin callback, called when the encoder starts. + # + def encoder_begin_callback=: (encoder_begin_callback) -> encoder_begin_callback + + def encoder_begin_callback: () -> (encoder_begin_callback | nil) + + # Sets user data passed to the last argument of encoder begin callback. + # + def encoder_begin_callback_user_data=: (Object) -> Object + + def encoder_begin_callback_user_data: () -> Object + # Sets abort callback, called to check if the process should be aborted. # # params.abort_callback = ->(user_data) { @@ -335,6 +350,10 @@ module Whisper # def on_progress: { (Integer progress) -> void } -> void + # Hook called on encoder starts. + # + def on_encoder_begin: { () -> void } -> void + # Call block to determine whether abort or not. Return +true+ when you want to abort. # # params.abort_on do diff --git a/bindings/ruby/tests/helper.rb b/bindings/ruby/tests/helper.rb index a69a2b7e..bc5e4724 100644 --- a/bindings/ruby/tests/helper.rb +++ b/bindings/ruby/tests/helper.rb @@ -6,9 +6,9 @@ class TestBase < Test::Unit::TestCase AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav") class << self - attr_reader :whisper + def whisper + return @whisper if @whisper - def startup @whisper = Whisper::Context.new("base.en") params = Whisper::Params.new params.print_timestamps = false diff --git a/bindings/ruby/tests/test_callback.rb b/bindings/ruby/tests/test_callback.rb index 61ef366c..a7f49245 100644 --- a/bindings/ruby/tests/test_callback.rb +++ b/bindings/ruby/tests/test_callback.rb @@ -111,6 +111,48 @@ class TestCallback < TestBase assert_equal 100, last end + def test_encoder_begin_callback + i = 0 + @params.encoder_begin_callback = ->(context, state, user_data) { + i += 1 + } + @whisper.transcribe(@audio, @params) + assert i > 0 + end + + def test_encoder_begin_callback_abort + logs = [] + Whisper.log_set -> (level, buffer, user_data) { + logs << buffer if level == Whisper::LOG_LEVEL_ERROR + }, logs + @params.encoder_begin_callback = ->(context, state, user_data) { + return false + } + @whisper.transcribe(@audio, @params) + assert_match(/encoder_begin_callback returned false - aborting/, logs.join) + Whisper.log_set ->(level, buffer, user_data) {}, nil + end + + def test_encoder_begin_callback_user_data + udata = Object.new + @params.encoder_begin_callback_user_data = udata + yielded = nil + @params.encoder_begin_callback = ->(context, state, user_data) { + yielded = user_data + } + @whisper.transcribe(@audio, @params) + assert_same udata, yielded + end + + def test_on_encoder_begin + i = 0 + @params.on_encoder_begin do + i += 1 + end + @whisper.transcribe(@audio, @params) + assert i > 0 + end + def test_abort_callback i = 0 @params.abort_callback = ->(user_data) { diff --git a/bindings/ruby/whispercpp.gemspec b/bindings/ruby/whispercpp.gemspec index 329e670b..97cf4e27 100644 --- a/bindings/ruby/whispercpp.gemspec +++ b/bindings/ruby/whispercpp.gemspec @@ -4,7 +4,7 @@ Gem::Specification.new do |s| s.name = "whispercpp" s.authors = ["Georgi Gerganov", "Todd A. Fisher"] s.version = '1.3.2' - s.date = '2025-04-17' + s.date = '2025-04-25' s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby} s.email = 'todd.fisher@gmail.com' s.extra_rdoc_files = ['LICENSE', 'README.md']