ruby : add encoder begin callback related methods (#3076)

* Lazy run TestBase.whisper

* Fix indentation

* Remove disused GGML_HIP_UMA from Ruby

* Add encoder_begin_callback

* Comment out existing abort mechanism

* Add test for encoder_begin_callback

* Add signatures for encoder_begin_callback related methods

* Update gem date
This commit is contained in:
KITAITI Makoto 2025-04-26 04:33:11 +09:00 committed by GitHub
parent 1c20f46887
commit 50fda73f4c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 192 additions and 17 deletions

View File

@ -114,7 +114,6 @@ class Options
bool "GGML_HIP_GRAPHS" bool "GGML_HIP_GRAPHS"
bool "GGML_HIP_NO_VMM" bool "GGML_HIP_NO_VMM"
bool "GGML_HIP_ROCWMMA_FATTN" bool "GGML_HIP_ROCWMMA_FATTN"
bool "GGML_HIP_UMA"
ignored "GGML_INCLUDE_INSTALL_DIR" ignored "GGML_INCLUDE_INSTALL_DIR"
bool "GGML_KOMPUTE" bool "GGML_KOMPUTE"
bool "GGML_LASX" bool "GGML_LASX"

View File

@ -19,6 +19,7 @@ typedef struct {
bool diarize; bool diarize;
ruby_whisper_callback_container *new_segment_callback_container; ruby_whisper_callback_container *new_segment_callback_container;
ruby_whisper_callback_container *progress_callback_container; ruby_whisper_callback_container *progress_callback_container;
ruby_whisper_callback_container *encoder_begin_callback_container;
ruby_whisper_callback_container *abort_callback_container; ruby_whisper_callback_container *abort_callback_container;
} ruby_whisper_params; } ruby_whisper_params;

View File

@ -26,7 +26,7 @@
rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \ rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1); rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 30 #define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 32
extern VALUE cParams; extern VALUE cParams;
@ -63,6 +63,8 @@ static ID id_new_segment_callback;
static ID id_new_segment_callback_user_data; static ID id_new_segment_callback_user_data;
static ID id_progress_callback; static ID id_progress_callback;
static ID id_progress_callback_user_data; static ID id_progress_callback_user_data;
static ID id_encoder_begin_callback;
static ID id_encoder_begin_callback_user_data;
static ID id_abort_callback; static ID id_abort_callback;
static ID id_abort_callback_user_data; static ID id_abort_callback_user_data;
@ -126,6 +128,33 @@ static void progress_callback(struct whisper_context *ctx, struct whisper_state
} }
} }
/*
 * Encoder-begin trampoline installed into params.encoder_begin_callback by
 * register_callbacks.
 *
 * user_data is the ruby_whisper_callback_container holding the Ruby-side
 * callables:
 *   - container->callback (if not nil) is called with
 *     (context, nil, user_data);
 *   - every entry of container->callbacks (added via on_encoder_begin) is
 *     called with no arguments.
 *
 * All registered callables are invoked even if an earlier one already asked
 * for an abort.
 *
 * Returns false (abort) when any callable returned exactly +false+;
 * otherwise returns true (proceed). nil and other values do not abort.
 */
static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_state *state, void *user_data) {
  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
  bool is_aborted = false;
  VALUE result;

  // The native state is not exposed to Ruby yet because doing so requires
  // resolving GC-related problems; nil is passed in its place.
  if (!NIL_P(container->callback)) {
    result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data);
    if (result == Qfalse) {
      is_aborted = true;
    }
  }

  // RARRAY_LEN yields a long, so the index is a long as well to avoid a
  // narrowing comparison. An empty array simply skips the loop.
  const long callbacks_len = RARRAY_LEN(container->callbacks);
  for (long j = 0; j < callbacks_len; j++) {
    VALUE cb = rb_ary_entry(container->callbacks, j);
    result = rb_funcall(cb, id_call, 0);
    if (result == Qfalse) {
      is_aborted = true;
    }
  }
  return !is_aborted;
}
static bool abort_callback(void * user_data) { static bool abort_callback(void * user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!NIL_P(container->callback)) { if (!NIL_P(container->callback)) {
@ -161,6 +190,12 @@ void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
rwp->params.progress_callback_user_data = rwp->progress_callback_container; rwp->params.progress_callback_user_data = rwp->progress_callback_container;
} }
if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) {
rwp->encoder_begin_callback_container->context = context;
rwp->params.encoder_begin_callback = encoder_begin_callback;
rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container;
}
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
rwp->abort_callback_container->context = context; rwp->abort_callback_container->context = context;
rwp->params.abort_callback = abort_callback; rwp->params.abort_callback = abort_callback;
@ -173,6 +208,7 @@ rb_whisper_params_mark(ruby_whisper_params *rwp)
{ {
rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container); rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
rb_whisper_callbcack_container_mark(rwp->progress_callback_container); rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container);
rb_whisper_callbcack_container_mark(rwp->abort_callback_container); rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
} }
@ -198,6 +234,7 @@ ruby_whisper_params_allocate(VALUE klass)
rwp->diarize = false; rwp->diarize = false;
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
rwp->progress_callback_container = rb_whisper_callback_container_allocate(); rwp->progress_callback_container = rb_whisper_callback_container_allocate();
rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate();
rwp->abort_callback_container = rb_whisper_callback_container_allocate(); rwp->abort_callback_container = rb_whisper_callback_container_allocate();
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
} }
@ -849,6 +886,57 @@ ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value)
rwp->progress_callback_container->user_data = value; rwp->progress_callback_container->user_data = value;
return value; return value;
} }
/*
 * Returns the object currently stored as the encoder begin callback, i.e.
 * whatever was last assigned through encoder_begin_callback=.
 *
 * call-seq:
 *   encoder_begin_callback -> callback
 */
static VALUE
ruby_whisper_params_get_encoder_begin_callback(VALUE self)
{
  ruby_whisper_params *params;
  Data_Get_Struct(self, ruby_whisper_params, params);
  const ruby_whisper_callback_container *container = params->encoder_begin_callback_container;
  return container->callback;
}
/*
 * Registers the callback invoked when the encoder starts. The callable
 * receives the context, nil (in place of the native state) and the user data.
 * Returning +false+ from it aborts the run.
 *
 *   params.encoder_begin_callback = ->(context, _, user_data) {
 *     # ...
 *   }
 *
 * call-seq:
 *   encoder_begin_callback = callback -> callback
 */
static VALUE
ruby_whisper_params_set_encoder_begin_callback(VALUE self, VALUE value)
{
  ruby_whisper_params *params;
  Data_Get_Struct(self, ruby_whisper_params, params);
  ruby_whisper_callback_container *container = params->encoder_begin_callback_container;
  container->callback = value;
  return value;
}
/*
 * Returns the object currently stored as user data for the encoder begin
 * callback.
 *
 * call-seq:
 *   encoder_begin_callback_user_data -> user_data
 */
static VALUE
ruby_whisper_params_get_encoder_begin_callback_user_data(VALUE self)
{
  ruby_whisper_params *params;
  Data_Get_Struct(self, ruby_whisper_params, params);
  const ruby_whisper_callback_container *container = params->encoder_begin_callback_container;
  return container->user_data;
}
/*
 * Stores the object that is passed as the last argument of the encoder begin
 * callback.
 *
 * call-seq:
 *   encoder_begin_callback_user_data = user_data -> user_data
 */
static VALUE
ruby_whisper_params_set_encoder_begin_callback_user_data(VALUE self, VALUE value)
{
  ruby_whisper_params *params;
  Data_Get_Struct(self, ruby_whisper_params, params);
  ruby_whisper_callback_container *container = params->encoder_begin_callback_container;
  container->user_data = value;
  return value;
}
static VALUE static VALUE
ruby_whisper_params_get_abort_callback(VALUE self) ruby_whisper_params_get_abort_callback(VALUE self)
{ {
@ -958,6 +1046,8 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
SET_PARAM_IF_SAME(new_segment_callback_user_data) SET_PARAM_IF_SAME(new_segment_callback_user_data)
SET_PARAM_IF_SAME(progress_callback) SET_PARAM_IF_SAME(progress_callback)
SET_PARAM_IF_SAME(progress_callback_user_data) SET_PARAM_IF_SAME(progress_callback_user_data)
SET_PARAM_IF_SAME(encoder_begin_callback)
SET_PARAM_IF_SAME(encoder_begin_callback_user_data)
SET_PARAM_IF_SAME(abort_callback) SET_PARAM_IF_SAME(abort_callback)
SET_PARAM_IF_SAME(abort_callback_user_data) SET_PARAM_IF_SAME(abort_callback_user_data)
} }
@ -1008,6 +1098,26 @@ ruby_whisper_params_on_progress(VALUE self)
return Qnil; return Qnil;
} }
/*
 * Adds a hook that runs when the encoder starts. The block is called with no
 * arguments; returning +false+ from it aborts the run. Several hooks may be
 * registered, and each is appended to the list of encoder begin callbacks.
 *
 *   params.on_encoder_begin do
 *     # ...
 *   end
 *
 * call-seq:
 *   on_encoder_begin { ... }
 */
static VALUE
ruby_whisper_params_on_encoder_begin(VALUE self)
{
  ruby_whisper_params *rwp;
  Data_Get_Struct(self, ruby_whisper_params, rwp);
  // Capture the caller's block as a Proc and queue it on the container.
  rb_ary_push(rwp->encoder_begin_callback_container->callbacks, rb_block_proc());
  return Qnil;
}
/* /*
* Call block to determine whether abort or not. Return +true+ when you want to abort. * Call block to determine whether abort or not. Return +true+ when you want to abort.
* *
@ -1068,10 +1178,13 @@ init_ruby_whisper_params(VALUE *mWhisper)
DEFINE_PARAM(new_segment_callback_user_data, 25) DEFINE_PARAM(new_segment_callback_user_data, 25)
DEFINE_PARAM(progress_callback, 26) DEFINE_PARAM(progress_callback, 26)
DEFINE_PARAM(progress_callback_user_data, 27) DEFINE_PARAM(progress_callback_user_data, 27)
DEFINE_PARAM(abort_callback, 28) DEFINE_PARAM(encoder_begin_callback, 28)
DEFINE_PARAM(abort_callback_user_data, 29) DEFINE_PARAM(encoder_begin_callback_user_data, 29)
DEFINE_PARAM(abort_callback, 30)
DEFINE_PARAM(abort_callback_user_data, 31)
rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0); rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0); rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
rb_define_method(cParams, "on_encoder_begin", ruby_whisper_params_on_encoder_begin, 0);
rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0); rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
} }

View File

@ -50,15 +50,16 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return self; return self;
} }
{ // Commented out because it is work in progress
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race // {
// static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { // rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data; // bool is_aborted = *(bool*)user_data;
return !is_aborted; // return !is_aborted;
}; // };
rwp->params.encoder_begin_callback_user_data = &is_aborted; // rwp->params.encoder_begin_callback_user_data = &is_aborted;
} // }
register_callbacks(rwp, &self); register_callbacks(rwp, &self);

View File

@ -7,6 +7,7 @@ module Whisper
type log_callback = ^(Integer level, String message, Object user_data) -> void type log_callback = ^(Integer level, String message, Object user_data) -> void
type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void
type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void
type encoder_begin_callback = ^(Whisper::Context, void, Object user_data) -> void
type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish
LOG_LEVEL_NONE: Integer LOG_LEVEL_NONE: Integer
@ -146,6 +147,8 @@ module Whisper
?new_segment_callback_user_data: Object, ?new_segment_callback_user_data: Object,
?progress_callback: progress_callback, ?progress_callback: progress_callback,
?progress_callback_user_data: Object, ?progress_callback_user_data: Object,
?encoder_begin_callback: encoder_begin_callback,
?encoder_begin_callback_user_data: Object,
?abort_callback: abort_callback, ?abort_callback: abort_callback,
?abort_callback_user_data: Object ?abort_callback_user_data: Object
) -> instance ) -> instance
@ -306,6 +309,18 @@ module Whisper
def progress_callback_user_data: () -> Object def progress_callback_user_data: () -> Object
# Sets encoder begin callback, called when the encoder starts.
#
def encoder_begin_callback=: (encoder_begin_callback) -> encoder_begin_callback
def encoder_begin_callback: () -> (encoder_begin_callback | nil)
# Sets user data passed to the last argument of encoder begin callback.
#
def encoder_begin_callback_user_data=: (Object) -> Object
def encoder_begin_callback_user_data: () -> Object
# Sets abort callback, called to check if the process should be aborted. # Sets abort callback, called to check if the process should be aborted.
# #
# params.abort_callback = ->(user_data) { # params.abort_callback = ->(user_data) {
@ -335,6 +350,10 @@ module Whisper
# #
def on_progress: { (Integer progress) -> void } -> void def on_progress: { (Integer progress) -> void } -> void
# Hook called when the encoder starts.
#
def on_encoder_begin: { () -> void } -> void
# Call block to determine whether abort or not. Return +true+ when you want to abort. # Call block to determine whether abort or not. Return +true+ when you want to abort.
# #
# params.abort_on do # params.abort_on do

View File

@ -6,9 +6,9 @@ class TestBase < Test::Unit::TestCase
AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav") AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
class << self class << self
attr_reader :whisper def whisper
return @whisper if @whisper
def startup
@whisper = Whisper::Context.new("base.en") @whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new params = Whisper::Params.new
params.print_timestamps = false params.print_timestamps = false

View File

@ -111,6 +111,48 @@ class TestCallback < TestBase
assert_equal 100, last assert_equal 100, last
end end
# The encoder begin callback must fire at least once during a transcription.
def test_encoder_begin_callback
  call_count = 0
  @params.encoder_begin_callback = lambda do |_context, _state, _user_data|
    call_count += 1
  end
  @whisper.transcribe(@audio, @params)
  assert(call_count > 0)
end
# Returning false from the encoder begin callback must abort the run, which is
# observed through the error-level log message emitted on abort.
def test_encoder_begin_callback_abort
  logs = []
  # Capture only error-level log lines for the assertion below.
  Whisper.log_set ->(level, buffer, _user_data) {
    logs << buffer if level == Whisper::LOG_LEVEL_ERROR
  }, logs
  @params.encoder_begin_callback = ->(_context, _state, _user_data) { false }
  @whisper.transcribe(@audio, @params)
  assert_match(/encoder_begin_callback returned false - aborting/, logs.join)
ensure
  # Reset to a no-op logger even when the assertion fails, so the capturing
  # logger does not leak into subsequent tests.
  Whisper.log_set ->(level, buffer, user_data) {}, nil
end
# The object assigned to encoder_begin_callback_user_data must be handed back,
# as the very same object, in the callback's last argument.
def test_encoder_begin_callback_user_data
  expected = Object.new
  @params.encoder_begin_callback_user_data = expected
  received = nil
  @params.encoder_begin_callback = lambda do |_context, _state, user_data|
    received = user_data
  end
  @whisper.transcribe(@audio, @params)
  assert_same expected, received
end
# A hook registered with on_encoder_begin must run at least once during a
# transcription.
def test_on_encoder_begin
  hook_runs = 0
  @params.on_encoder_begin { hook_runs += 1 }
  @whisper.transcribe(@audio, @params)
  assert(hook_runs > 0)
end
def test_abort_callback def test_abort_callback
i = 0 i = 0
@params.abort_callback = ->(user_data) { @params.abort_callback = ->(user_data) {

View File

@ -4,7 +4,7 @@ Gem::Specification.new do |s|
s.name = "whispercpp" s.name = "whispercpp"
s.authors = ["Georgi Gerganov", "Todd A. Fisher"] s.authors = ["Georgi Gerganov", "Todd A. Fisher"]
s.version = '1.3.2' s.version = '1.3.2'
s.date = '2025-04-17' s.date = '2025-04-25'
s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby} s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby}
s.email = 'todd.fisher@gmail.com' s.email = 'todd.fisher@gmail.com'
s.extra_rdoc_files = ['LICENSE', 'README.md'] s.extra_rdoc_files = ['LICENSE', 'README.md']