diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index 7b1a7f29..208a89f3 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -111,6 +111,41 @@ See [models][] page for details. Currently, whisper.cpp accepts only 16-bit WAV files. +### Voice Activity Detection (VAD) ### + +Support for Voice Activity Detection (VAD) can be enabled by setting `Whisper::Params`'s `vad` argument to `true` and specifying VAD model: + +```ruby +Whisper::Params.new( + vad: true, + vad_model_path: "silero-v5.1.2", + # other arguments... +) +``` + +When you pass the model name (`"silero-v5.1.2"`) or URI (`https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin`), it will be downloaded automatically. +Currently, "silero-v5.1.2" is registered as pre-converted model like ASR models. You also specify file path or URI of model. + +If you need configure VAD behavior, pass params for that: + +```ruby +Whisper::Params.new( + vad: true, + vad_model_path: "silero-v5.1.2", + vad_params: Whisper::VAD::Params.new( + threshold: 1.0, # defaults to 0.5 + min_speech_duration_ms: 500, # defaults to 250 + min_silence_duration_ms: 200, # defaults to 100 + max_speech_duration_s: 30000, # default is FLT_MAX, + speech_pad_ms: 50, # defaults to 30 + samples_overlap: 0.5 # defaults to 0.1 + ), + # other arguments... +) +``` + +For details on VAD, see [whisper.cpp's README](https://github.com/ggml-org/whisper.cpp?tab=readme-ov-file#voice-activity-detection-vad). + API --- diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index 43227786..4a83aac9 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -3,8 +3,10 @@ #include "ruby_whisper.h" VALUE mWhisper; +VALUE mVAD; VALUE cContext; VALUE cParams; +VALUE cVADParams; VALUE eError; VALUE cSegment; @@ -31,6 +33,7 @@ extern void init_ruby_whisper_params(VALUE *mWhisper); extern void init_ruby_whisper_error(VALUE *mWhisper); extern void init_ruby_whisper_segment(VALUE *mWhisper, VALUE *cSegment); extern void init_ruby_whisper_model(VALUE *mWhisper); +extern void init_ruby_whisper_vad_params(VALUE *mVAD); extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context); /* @@ -116,16 +119,6 @@ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_d return Qnil; } -static void rb_whisper_model_mark(ruby_whisper_model *rwm) { - rb_gc_mark(rwm->context); -} - -static VALUE ruby_whisper_model_allocate(VALUE klass) { - ruby_whisper_model *rwm; - rwm = ALLOC(ruby_whisper_model); - return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm); -} - void Init_whisper() { id_to_s = rb_intern("to_s"); id_call = rb_intern("call"); @@ -139,6 +132,7 @@ void Init_whisper() { id_pre_converted_models = rb_intern("pre_converted_models"); mWhisper = rb_define_module("Whisper"); + mVAD = rb_define_module_under(mWhisper, "VAD"); rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE)); rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO)); @@ -159,6 +153,7 @@ void Init_whisper() { init_ruby_whisper_error(&mWhisper); init_ruby_whisper_segment(&mWhisper, &cContext); init_ruby_whisper_model(&mWhisper); + init_ruby_whisper_vad_params(&mVAD); rb_require("whisper/model/uri"); } diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 6111a151..65b88122 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -21,8 +21,13 @@ typedef struct { ruby_whisper_callback_container *progress_callback_container; ruby_whisper_callback_container *encoder_begin_callback_container; ruby_whisper_callback_container *abort_callback_container; + VALUE vad_params; } ruby_whisper_params; +typedef struct { + struct whisper_vad_params params; +} ruby_whisper_vad_params; + typedef struct { VALUE context; int index; diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index df375218..c498184e 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -16,10 +16,11 @@ extern VALUE cContext; extern VALUE eError; extern VALUE cModel; +extern const rb_data_type_t ruby_whisper_params_type; extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self); extern VALUE rb_whisper_model_initialize(VALUE context); extern VALUE rb_whisper_segment_initialize(VALUE context, int index); -extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context); +extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context); static void ruby_whisper_free(ruby_whisper *rw) @@ -37,19 +38,64 @@ rb_whisper_mark(ruby_whisper *rw) } void -rb_whisper_free(ruby_whisper *rw) +rb_whisper_free(void *p) { + ruby_whisper *rw = (ruby_whisper *)p; ruby_whisper_free(rw); free(rw); } +static size_t +ruby_whisper_memsize(const void *p) +{ + const ruby_whisper *rw = (const ruby_whisper *)p; + size_t size = sizeof(rw); + if (!rw) { + return 0; + } + return size; +} + +const rb_data_type_t ruby_whisper_type = { + "ruby_whisper", + {0, rb_whisper_free, ruby_whisper_memsize,}, + 0, 0, + 0 +}; + static VALUE ruby_whisper_allocate(VALUE klass) { ruby_whisper *rw; - rw = ALLOC(ruby_whisper); + VALUE obj = TypedData_Make_Struct(klass, ruby_whisper, &ruby_whisper_type, rw); rw->context = NULL; - return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw); + return obj; +} + +VALUE +ruby_whisper_normalize_model_path(VALUE model_path) +{ + VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0); + VALUE pre_converted_model = rb_hash_aref(pre_converted_models, model_path); + if (!NIL_P(pre_converted_model)) { + model_path = pre_converted_model; + } + else if (TYPE(model_path) == T_STRING) { + const char * model_path_str = StringValueCStr(model_path); + if (strncmp("http://", model_path_str, 7) == 0 || strncmp("https://", model_path_str, 8) == 0) { + VALUE uri_class = rb_const_get(cModel, id_URI); + model_path = rb_class_new_instance(1, &model_path, uri_class); + } + } + else if (rb_obj_is_kind_of(model_path, rb_path2class("URI::HTTP"))) { + VALUE uri_class = rb_const_get(cModel, id_URI); + model_path = rb_class_new_instance(1, &model_path, uri_class); + } + if (rb_respond_to(model_path, id_to_path)) { + model_path = rb_funcall(model_path, id_to_path, 0); + } + + return model_path; } /* @@ -66,27 +112,9 @@ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) // TODO: we can support init from buffer here too maybe another ruby object to expose rb_scan_args(argc, argv, "01", &whisper_model_file_path); - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); - VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0); - VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path); - if (!NIL_P(pre_converted_model)) { - whisper_model_file_path = pre_converted_model; - } - if (TYPE(whisper_model_file_path) == T_STRING) { - const char * whisper_model_file_path_str = StringValueCStr(whisper_model_file_path); - if (strncmp("http://", whisper_model_file_path_str, 7) == 0 || strncmp("https://", whisper_model_file_path_str, 8) == 0) { - VALUE uri_class = rb_const_get(cModel, id_URI); - whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class); - } - } - if (rb_obj_is_kind_of(whisper_model_file_path, rb_path2class("URI::HTTP"))) { - VALUE uri_class = rb_const_get(cModel, id_URI); - whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class); - } - if (rb_respond_to(whisper_model_file_path, id_to_path)) { - whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0); - } + whisper_model_file_path = ruby_whisper_normalize_model_path(whisper_model_file_path); if (!rb_respond_to(whisper_model_file_path, id_to_s)) { rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); } @@ -104,7 +132,7 @@ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) VALUE ruby_whisper_model_n_vocab(VALUE self) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_vocab(rw->context)); } @@ -115,7 +143,7 @@ VALUE ruby_whisper_model_n_vocab(VALUE self) VALUE ruby_whisper_model_n_audio_ctx(VALUE self) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_audio_ctx(rw->context)); } @@ -126,7 +154,7 @@ VALUE ruby_whisper_model_n_audio_ctx(VALUE self) VALUE ruby_whisper_model_n_audio_state(VALUE self) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_audio_state(rw->context)); } @@ -137,7 +165,7 @@ VALUE ruby_whisper_model_n_audio_state(VALUE self) VALUE ruby_whisper_model_n_audio_head(VALUE self) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_audio_head(rw->context)); } @@ -148,7 +176,7 @@ VALUE ruby_whisper_model_n_audio_head(VALUE self) VALUE ruby_whisper_model_n_audio_layer(VALUE self) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_audio_layer(rw->context)); } @@ -159,7 +187,7 @@ VALUE ruby_whisper_model_n_audio_layer(VALUE self) VALUE ruby_whisper_model_n_text_ctx(VALUE self) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_text_ctx(rw->context)); } @@ -170,7 +198,7 @@ VALUE ruby_whisper_model_n_text_ctx(VALUE self) VALUE ruby_whisper_model_n_text_state(VALUE self) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_text_state(rw->context)); } @@ -181,7 +209,7 @@ VALUE ruby_whisper_model_n_text_state(VALUE self) VALUE ruby_whisper_model_n_text_head(VALUE self) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_text_head(rw->context)); } @@ -192,7 +220,7 @@ VALUE ruby_whisper_model_n_text_head(VALUE self) VALUE ruby_whisper_model_n_text_layer(VALUE self) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_text_layer(rw->context)); } @@ -203,7 +231,7 @@ VALUE ruby_whisper_model_n_text_layer(VALUE self) VALUE ruby_whisper_model_n_mels(VALUE self) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_mels(rw->context)); } @@ -214,7 +242,7 @@ VALUE ruby_whisper_model_n_mels(VALUE self) VALUE ruby_whisper_model_ftype(VALUE self) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_ftype(rw->context)); } @@ -225,7 +253,7 @@ VALUE ruby_whisper_model_ftype(VALUE self) VALUE ruby_whisper_model_type(VALUE self) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); return rb_str_new2(whisper_model_type_readable(rw->context)); } @@ -248,9 +276,9 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) ruby_whisper *rw; ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); VALUE params = argv[0]; - Data_Get_Struct(params, ruby_whisper_params, rwp); + TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp); VALUE samples = argv[1]; int n_samples; rb_memory_view_t view; @@ -296,7 +324,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) } } } - register_callbacks(rwp, &self); + prepare_transcription(rwp, &self); const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples); if (0 == result) { return self; @@ -327,9 +355,9 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) ruby_whisper *rw; ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); VALUE params = argv[0]; - Data_Get_Struct(params, ruby_whisper_params, rwp); + TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp); VALUE samples = argv[1]; int n_samples; int n_processors; @@ -387,7 +415,7 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) } } } - register_callbacks(rwp, &self); + prepare_transcription(rwp, &self); const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors); if (0 == result) { return self; @@ -406,7 +434,7 @@ static VALUE ruby_whisper_full_n_segments(VALUE self) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_full_n_segments(rw->context)); } @@ -420,7 +448,7 @@ static VALUE ruby_whisper_full_lang_id(VALUE self) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_full_lang_id(rw->context)); } @@ -445,7 +473,7 @@ static VALUE ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment); return INT2NUM(t0); @@ -463,7 +491,7 @@ static VALUE ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment); return INT2NUM(t1); @@ -481,7 +509,7 @@ static VALUE ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment); return speaker_turn_next ? Qtrue : Qfalse; @@ -499,7 +527,7 @@ static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); const char * text = whisper_full_get_segment_text(rw->context, c_i_segment); return rb_str_new2(text); @@ -513,7 +541,7 @@ static VALUE ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment) { ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment); return DBL2NUM(no_speech_prob); @@ -554,7 +582,7 @@ ruby_whisper_each_segment(VALUE self) } ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); const int n_segments = whisper_full_n_segments(rw->context); for (int i = 0; i < n_segments; ++i) { diff --git a/bindings/ruby/ext/ruby_whisper_model.c b/bindings/ruby/ext/ruby_whisper_model.c index 1e0648fd..54763c92 100644 --- a/bindings/ruby/ext/ruby_whisper_model.c +++ b/bindings/ruby/ext/ruby_whisper_model.c @@ -1,22 +1,44 @@ #include #include "ruby_whisper.h" +extern const rb_data_type_t ruby_whisper_type; + extern VALUE cModel; -static void rb_whisper_model_mark(ruby_whisper_model *rwm) { - rb_gc_mark(rwm->context); +static void rb_whisper_model_mark(void *p) { + ruby_whisper_model *rwm = (ruby_whisper_model *)p; + if (rwm->context) { + rb_gc_mark(rwm->context); + } } +static size_t +ruby_whisper_model_memsize(const void *p) +{ + const ruby_whisper_model *rwm = (const ruby_whisper_model *)p; + size_t size = sizeof(rwm); + if (!rwm) { + return 0; + } + return size; +} + +static const rb_data_type_t rb_whisper_model_type = { + "ruby_whisper_model", + {rb_whisper_model_mark, RUBY_DEFAULT_FREE, ruby_whisper_model_memsize,}, + 0, 0, + 0 +}; + static VALUE ruby_whisper_model_allocate(VALUE klass) { ruby_whisper_model *rwm; - rwm = ALLOC(ruby_whisper_model); - return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm); + return TypedData_Make_Struct(klass, ruby_whisper_model, &rb_whisper_model_type, rwm); } VALUE rb_whisper_model_initialize(VALUE context) { ruby_whisper_model *rwm; const VALUE model = ruby_whisper_model_allocate(cModel); - Data_Get_Struct(model, ruby_whisper_model, rwm); + TypedData_Get_Struct(model, ruby_whisper_model, &rb_whisper_model_type, rwm); rwm->context = context; return model; }; @@ -29,9 +51,9 @@ static VALUE ruby_whisper_model_n_vocab(VALUE self) { ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); + TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm); ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); + TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_vocab(rw->context)); } @@ -43,9 +65,9 @@ static VALUE ruby_whisper_model_n_audio_ctx(VALUE self) { ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); + TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm); ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); + TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_audio_ctx(rw->context)); } @@ -57,9 +79,9 @@ static VALUE ruby_whisper_model_n_audio_state(VALUE self) { ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); + TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm); ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); + TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_audio_state(rw->context)); } @@ -71,9 +93,9 @@ static VALUE ruby_whisper_model_n_audio_head(VALUE self) { ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); + TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm); ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); + TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_audio_head(rw->context)); } @@ -85,9 +107,9 @@ static VALUE ruby_whisper_model_n_audio_layer(VALUE self) { ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); + TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm); ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); + TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_audio_layer(rw->context)); } @@ -99,9 +121,9 @@ static VALUE ruby_whisper_model_n_text_ctx(VALUE self) { ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); + TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm); ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); + TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_text_ctx(rw->context)); } @@ -113,9 +135,9 @@ static VALUE ruby_whisper_model_n_text_state(VALUE self) { ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); + TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm); ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); + TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_text_state(rw->context)); } @@ -127,9 +149,9 @@ static VALUE ruby_whisper_model_n_text_head(VALUE self) { ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); + TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm); ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); + TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_text_head(rw->context)); } @@ -141,9 +163,9 @@ static VALUE ruby_whisper_model_n_text_layer(VALUE self) { ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); + TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm); ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); + TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_text_layer(rw->context)); } @@ -155,9 +177,9 @@ static VALUE ruby_whisper_model_n_mels(VALUE self) { ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); + TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm); ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); + TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_n_mels(rw->context)); } @@ -169,9 +191,9 @@ static VALUE ruby_whisper_model_ftype(VALUE self) { ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); + TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm); ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); + TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw); return INT2NUM(whisper_model_ftype(rw->context)); } @@ -183,9 +205,9 @@ static VALUE ruby_whisper_model_type(VALUE self) { ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); + TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm); ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); + TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw); return rb_str_new2(whisper_model_type_readable(rw->context)); } diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index c07f2372..4a65c92a 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -3,7 +3,7 @@ #define BOOL_PARAMS_SETTER(self, prop, value) \ ruby_whisper_params *rwp; \ - Data_Get_Struct(self, ruby_whisper_params, rwp); \ + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); \ if (value == Qfalse || value == Qnil) { \ rwp->params.prop = false; \ } else { \ @@ -13,7 +13,7 @@ #define BOOL_PARAMS_GETTER(self, prop) \ ruby_whisper_params *rwp; \ - Data_Get_Struct(self, ruby_whisper_params, rwp); \ + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); \ if (rwp->params.prop) { \ return Qtrue; \ } else { \ @@ -26,13 +26,16 @@ rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \ rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1); -#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 32 +#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 35 extern VALUE cParams; +extern VALUE cVADParams; extern ID id_call; +extern VALUE ruby_whisper_normalize_model_path(VALUE model_path); extern VALUE rb_whisper_segment_initialize(VALUE context, int index); +extern const rb_data_type_t ruby_whisper_vad_params_type; static ID param_names[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT]; static ID id_language; @@ -67,6 +70,9 @@ static ID id_encoder_begin_callback; static ID id_encoder_begin_callback_user_data; static ID id_abort_callback; static ID id_abort_callback_user_data; +static ID id_vad; +static ID id_vad_model_path; +static ID id_vad_params; static void rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) @@ -177,7 +183,7 @@ static bool abort_callback(void * user_data) { return false; } -void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { +static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { rwp->new_segment_callback_container->context = context; rwp->params.new_segment_callback = new_segment_callback; @@ -203,13 +209,29 @@ void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { } } -void -rb_whisper_params_mark(ruby_whisper_params *rwp) +static void set_vad_params(ruby_whisper_params *rwp) { + ruby_whisper_vad_params * rwvp; + TypedData_Get_Struct(rwp->vad_params, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + rwp->params.vad_params = rwvp->params; +} + +void +prepare_transcription(ruby_whisper_params *rwp, VALUE *context) +{ + register_callbacks(rwp, context); + set_vad_params(rwp); +} + +void +rb_whisper_params_mark(void *p) +{ + ruby_whisper_params *rwp = (ruby_whisper_params *)p; rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container); rb_whisper_callbcack_container_mark(rwp->progress_callback_container); rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container); rb_whisper_callbcack_container_mark(rwp->abort_callback_container); + rb_gc_mark(rwp->vad_params); } void @@ -218,25 +240,46 @@ ruby_whisper_params_free(ruby_whisper_params *rwp) } void -rb_whisper_params_free(ruby_whisper_params *rwp) +rb_whisper_params_free(void *p) { + ruby_whisper_params *rwp = (ruby_whisper_params *)p; // How to free user_data and callback only when not referred to by others? ruby_whisper_params_free(rwp); free(rwp); } +static size_t +ruby_whisper_params_memsize(const void *p) +{ + const ruby_whisper_params *rwp = (const ruby_whisper_params *)p; + + return sizeof(ruby_whisper_params) + sizeof(rwp->params) + sizeof(rwp->vad_params); +} + +const rb_data_type_t ruby_whisper_params_type = { + "ruby_whisper_params", + { + rb_whisper_params_mark, + rb_whisper_params_free, + ruby_whisper_params_memsize, + }, + 0, 0, + 0 +}; + static VALUE ruby_whisper_params_allocate(VALUE klass) { ruby_whisper_params *rwp; - rwp = ALLOC(ruby_whisper_params); + VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); rwp->diarize = false; + rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params); rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); rwp->progress_callback_container = rb_whisper_callback_container_allocate(); rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate(); rwp->abort_callback_container = rb_whisper_callback_container_allocate(); - return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); + return obj; } /* @@ -249,7 +292,7 @@ static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); if (value == Qfalse || value == Qnil) { rwp->params.language = "auto"; } else { @@ -265,7 +308,7 @@ static VALUE ruby_whisper_params_get_language(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); if (rwp->params.language) { return rb_str_new2(rwp->params.language); } else { @@ -502,7 +545,7 @@ static VALUE ruby_whisper_params_get_initial_prompt(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return rwp->params.initial_prompt == NULL ? Qnil : rb_str_new2(rwp->params.initial_prompt); } /* @@ -513,7 +556,7 @@ static VALUE ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->params.initial_prompt = StringValueCStr(value); return value; } @@ -527,7 +570,7 @@ static VALUE ruby_whisper_params_get_diarize(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); if (rwp->diarize) { return Qtrue; } else { @@ -542,7 +585,7 @@ static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); if (value == Qfalse || value == Qnil) { rwp->diarize = false; } else { @@ -561,7 +604,7 @@ static VALUE ruby_whisper_params_get_offset(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return INT2NUM(rwp->params.offset_ms); } /* @@ -572,7 +615,7 @@ static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->params.offset_ms = NUM2INT(value); return value; } @@ -586,7 +629,7 @@ static VALUE ruby_whisper_params_get_duration(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return INT2NUM(rwp->params.duration_ms); } /* @@ -597,7 +640,7 @@ static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->params.duration_ms = NUM2INT(value); return value; } @@ -612,7 +655,7 @@ static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return INT2NUM(rwp->params.n_max_text_ctx); } /* @@ -623,7 +666,7 @@ static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->params.n_max_text_ctx = NUM2INT(value); return value; } @@ -635,7 +678,7 @@ static VALUE ruby_whisper_params_get_temperature(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return DBL2NUM(rwp->params.temperature); } /* @@ -646,7 +689,7 @@ static VALUE ruby_whisper_params_set_temperature(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->params.temperature = RFLOAT_VALUE(value); return value; } @@ -660,7 +703,7 @@ static VALUE ruby_whisper_params_get_max_initial_ts(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return DBL2NUM(rwp->params.max_initial_ts); } /* @@ -671,7 +714,7 @@ static VALUE ruby_whisper_params_set_max_initial_ts(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->params.max_initial_ts = RFLOAT_VALUE(value); return value; } @@ -683,7 +726,7 @@ static VALUE ruby_whisper_params_get_length_penalty(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return DBL2NUM(rwp->params.length_penalty); } /* @@ -694,7 +737,7 @@ static VALUE ruby_whisper_params_set_length_penalty(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->params.length_penalty = RFLOAT_VALUE(value); return value; } @@ -706,7 +749,7 @@ static VALUE ruby_whisper_params_get_temperature_inc(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return DBL2NUM(rwp->params.temperature_inc); } /* @@ -717,7 +760,7 @@ static VALUE ruby_whisper_params_set_temperature_inc(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->params.temperature_inc = RFLOAT_VALUE(value); return value; } @@ -731,7 +774,7 @@ static VALUE ruby_whisper_params_get_entropy_thold(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return DBL2NUM(rwp->params.entropy_thold); } /* @@ -742,7 +785,7 @@ static VALUE ruby_whisper_params_set_entropy_thold(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->params.entropy_thold = RFLOAT_VALUE(value); return value; } @@ -754,7 +797,7 @@ static VALUE ruby_whisper_params_get_logprob_thold(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return DBL2NUM(rwp->params.logprob_thold); } /* @@ -765,7 +808,7 @@ static VALUE ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->params.logprob_thold = RFLOAT_VALUE(value); return value; } @@ -777,7 +820,7 @@ static VALUE ruby_whisper_params_get_no_speech_thold(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return DBL2NUM(rwp->params.no_speech_thold); } /* @@ -788,7 +831,7 @@ static VALUE ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->params.no_speech_thold = RFLOAT_VALUE(value); return value; } @@ -796,7 +839,7 @@ static VALUE ruby_whisper_params_get_new_segment_callback(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return rwp->new_segment_callback_container->callback; } /* @@ -813,7 +856,7 @@ static VALUE ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->new_segment_callback_container->callback = value; return value; } @@ -821,7 +864,7 @@ static VALUE ruby_whisper_params_get_new_segment_callback_user_data(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return rwp->new_segment_callback_container->user_data; } /* @@ -834,7 +877,7 @@ static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->new_segment_callback_container->user_data = value; return value; } @@ -842,7 +885,7 @@ static VALUE ruby_whisper_params_get_progress_callback(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return rwp->progress_callback_container->callback; } /* @@ -861,7 +904,7 @@ static VALUE ruby_whisper_params_set_progress_callback(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->progress_callback_container->callback = value; return value; } @@ -869,7 +912,7 @@ static VALUE ruby_whisper_params_get_progress_callback_user_data(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return rwp->progress_callback_container->user_data; } /* @@ -882,7 +925,7 @@ static VALUE ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->progress_callback_container->user_data = value; return value; } @@ -891,7 +934,7 @@ static VALUE ruby_whisper_params_get_encoder_begin_callback(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return rwp->encoder_begin_callback_container->callback; } @@ -909,7 +952,7 @@ static VALUE ruby_whisper_params_set_encoder_begin_callback(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->encoder_begin_callback_container->callback = value; return value; } @@ -918,7 +961,7 @@ static VALUE ruby_whisper_params_get_encoder_begin_callback_user_data(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return rwp->encoder_begin_callback_container->user_data; } @@ -932,7 +975,7 @@ static VALUE ruby_whisper_params_set_encoder_begin_callback_user_data(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->encoder_begin_callback_container->user_data = value; return value; } @@ -941,7 +984,7 @@ static VALUE ruby_whisper_params_get_abort_callback(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return rwp->abort_callback_container->callback; } /* @@ -958,7 +1001,7 @@ static VALUE ruby_whisper_params_set_abort_callback(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->abort_callback_container->callback = value; return value; } @@ -966,7 +1009,7 @@ static VALUE ruby_whisper_params_get_abort_callback_user_data(VALUE self) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); return rwp->abort_callback_container->user_data; } /* @@ -979,11 +1022,74 @@ static VALUE ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value) { ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->abort_callback_container->user_data = value; return value; } +/* + * call-seq: + * vad = use_vad -> use_vad + */ +static VALUE +ruby_whisper_params_get_vad(VALUE self) +{ + BOOL_PARAMS_GETTER(self, vad) +} + +static VALUE +ruby_whisper_params_set_vad(VALUE self, VALUE value) +{ + BOOL_PARAMS_SETTER(self, vad, value) +} + +/* + * call-seq: + * vad_model_path = model_path -> model_path + */ +static VALUE +ruby_whisper_params_set_vad_model_path(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + if (NIL_P(value)) { + rwp->params.vad_model_path = NULL; + return value; + } + VALUE path = ruby_whisper_normalize_model_path(value); + rwp->params.vad_model_path = StringValueCStr(path); + return value; +} + +static VALUE +ruby_whisper_params_get_vad_model_path(VALUE self) +{ + ruby_whisper_params *rwp; + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + return rwp->params.vad_model_path == NULL ? Qnil : rb_str_new2(rwp->params.vad_model_path); +} + +/* + * call-seq: + * vad_params = params -> params + */ +static VALUE +ruby_whisper_params_set_vad_params(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + rwp->vad_params = value; + return value; +} + +static VALUE +ruby_whisper_params_get_vad_params(VALUE self) +{ + ruby_whisper_params *rwp; + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + return rwp->vad_params; +} + #define SET_PARAM_IF_SAME(param_name) \ if (id == id_ ## param_name) { \ ruby_whisper_params_set_ ## param_name(self, value); \ @@ -993,7 +1099,6 @@ ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value) static VALUE ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self) { - VALUE kw_hash; VALUE values[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT] = {Qundef}; VALUE value; @@ -1007,7 +1112,7 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self) } rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT, values); - Data_Get_Struct(self, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); for (i = 0; i < RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT; i++) { id = param_names[i]; @@ -1050,6 +1155,9 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self) SET_PARAM_IF_SAME(encoder_begin_callback_user_data) SET_PARAM_IF_SAME(abort_callback) SET_PARAM_IF_SAME(abort_callback_user_data) + SET_PARAM_IF_SAME(vad) + SET_PARAM_IF_SAME(vad_model_path) + SET_PARAM_IF_SAME(vad_params) } } @@ -1071,10 +1179,10 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self) static VALUE ruby_whisper_params_on_new_segment(VALUE self) { - ruby_whisper_params *rws; - Data_Get_Struct(self, ruby_whisper_params, rws); + ruby_whisper_params *rwp; + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); - rb_ary_push(rws->new_segment_callback_container->callbacks, blk); + rb_ary_push(rwp->new_segment_callback_container->callbacks, blk); return Qnil; } @@ -1091,10 +1199,10 @@ ruby_whisper_params_on_new_segment(VALUE self) static VALUE ruby_whisper_params_on_progress(VALUE self) { - ruby_whisper_params *rws; - Data_Get_Struct(self, ruby_whisper_params, rws); + ruby_whisper_params *rwp; + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); - rb_ary_push(rws->progress_callback_container->callbacks, blk); + rb_ary_push(rwp->progress_callback_container->callbacks, blk); return Qnil; } @@ -1111,10 +1219,10 @@ ruby_whisper_params_on_progress(VALUE self) static VALUE ruby_whisper_params_on_encoder_begin(VALUE self) { - ruby_whisper_params *rws; - Data_Get_Struct(self, ruby_whisper_params, rws); + ruby_whisper_params *rwp; + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); - rb_ary_push(rws->encoder_begin_callback_container->callbacks, blk); + rb_ary_push(rwp->encoder_begin_callback_container->callbacks, blk); return Qnil; } @@ -1135,10 +1243,10 @@ ruby_whisper_params_on_encoder_begin(VALUE self) static VALUE ruby_whisper_params_abort_on(VALUE self) { - ruby_whisper_params *rws; - Data_Get_Struct(self, ruby_whisper_params, rws); + ruby_whisper_params *rwp; + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); - rb_ary_push(rws->abort_callback_container->callbacks, blk); + rb_ary_push(rwp->abort_callback_container->callbacks, blk); return Qnil; } @@ -1182,6 +1290,9 @@ init_ruby_whisper_params(VALUE *mWhisper) DEFINE_PARAM(encoder_begin_callback_user_data, 29) DEFINE_PARAM(abort_callback, 30) DEFINE_PARAM(abort_callback_user_data, 31) + DEFINE_PARAM(vad, 32) + DEFINE_PARAM(vad_model_path, 33) + DEFINE_PARAM(vad_params, 34) rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0); rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0); diff --git a/bindings/ruby/ext/ruby_whisper_segment.c b/bindings/ruby/ext/ruby_whisper_segment.c index 3440ff95..9399f286 100644 --- a/bindings/ruby/ext/ruby_whisper_segment.c +++ b/bindings/ruby/ext/ruby_whisper_segment.c @@ -1,20 +1,40 @@ #include #include "ruby_whisper.h" +extern const rb_data_type_t ruby_whisper_type; + extern VALUE cSegment; static void -rb_whisper_segment_mark(ruby_whisper_segment *rws) +rb_whisper_segment_mark(void *p) { + ruby_whisper_segment *rws = (ruby_whisper_segment *)p; rb_gc_mark(rws->context); } +static size_t +ruby_whisper_segment_memsize(const void *p) +{ + const ruby_whisper_segment *rws = (const ruby_whisper_segment *)p; + size_t size = sizeof(rws); + if (!rws) { + return 0; + } + return size; +} + +static const rb_data_type_t ruby_whisper_segment_type = { + "ruby_whisper_segment", + {rb_whisper_segment_mark, RUBY_DEFAULT_FREE, ruby_whisper_segment_memsize,}, + 0, 0, + 0 +}; + VALUE ruby_whisper_segment_allocate(VALUE klass) { ruby_whisper_segment *rws; - rws = ALLOC(ruby_whisper_segment); - return Data_Wrap_Struct(klass, rb_whisper_segment_mark, RUBY_DEFAULT_FREE, rws); + return TypedData_Make_Struct(klass, ruby_whisper_segment, &ruby_whisper_segment_type, rws); } VALUE @@ -22,7 +42,7 @@ rb_whisper_segment_initialize(VALUE context, int index) { ruby_whisper_segment *rws; const VALUE segment = ruby_whisper_segment_allocate(cSegment); - Data_Get_Struct(segment, ruby_whisper_segment, rws); + TypedData_Get_Struct(segment, ruby_whisper_segment, &ruby_whisper_segment_type, rws); rws->context = context; rws->index = index; return segment; @@ -38,9 +58,9 @@ static VALUE ruby_whisper_segment_get_start_time(VALUE self) { ruby_whisper_segment *rws; - Data_Get_Struct(self, ruby_whisper_segment, rws); + TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws); ruby_whisper *rw; - Data_Get_Struct(rws->context, ruby_whisper, rw); + TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw); const int64_t t0 = whisper_full_get_segment_t0(rw->context, rws->index); // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it return INT2NUM(t0 * 10); @@ -56,9 +76,9 @@ static VALUE ruby_whisper_segment_get_end_time(VALUE self) { ruby_whisper_segment *rws; - Data_Get_Struct(self, ruby_whisper_segment, rws); + TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws); ruby_whisper *rw; - Data_Get_Struct(rws->context, ruby_whisper, rw); + TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw); const int64_t t1 = whisper_full_get_segment_t1(rw->context, rws->index); // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it return INT2NUM(t1 * 10); @@ -74,9 +94,9 @@ static VALUE ruby_whisper_segment_get_speaker_turn_next(VALUE self) { ruby_whisper_segment *rws; - Data_Get_Struct(self, ruby_whisper_segment, rws); + TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws); ruby_whisper *rw; - Data_Get_Struct(rws->context, ruby_whisper, rw); + TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw); return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? Qtrue : Qfalse; } @@ -88,9 +108,9 @@ static VALUE ruby_whisper_segment_get_text(VALUE self) { ruby_whisper_segment *rws; - Data_Get_Struct(self, ruby_whisper_segment, rws); + TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws); ruby_whisper *rw; - Data_Get_Struct(rws->context, ruby_whisper, rw); + TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw); const char * text = whisper_full_get_segment_text(rw->context, rws->index); return rb_str_new2(text); } @@ -103,9 +123,9 @@ static VALUE ruby_whisper_segment_get_no_speech_prob(VALUE self) { ruby_whisper_segment *rws; - Data_Get_Struct(self, ruby_whisper_segment, rws); + TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws); ruby_whisper *rw; - Data_Get_Struct(rws->context, ruby_whisper, rw); + TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw); return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index)); } diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp index ef3c0780..d12d2de9 100644 --- a/bindings/ruby/ext/ruby_whisper_transcribe.cpp +++ b/bindings/ruby/ext/ruby_whisper_transcribe.cpp @@ -8,11 +8,14 @@ extern "C" { #endif +extern const rb_data_type_t ruby_whisper_type; +extern const rb_data_type_t ruby_whisper_params_type; + extern ID id_to_s; extern ID id_call; extern void -register_callbacks(ruby_whisper_params * rwp, VALUE * self); +prepare_transcription(ruby_whisper_params * rwp, VALUE * self); /* * transcribe a single file @@ -34,8 +37,8 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { VALUE wave_file_path, blk, params; rb_scan_args(argc, argv, "02&", &wave_file_path, ¶ms, &blk); - Data_Get_Struct(self, ruby_whisper, rw); - Data_Get_Struct(params, ruby_whisper_params, rwp); + TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); + TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp); if (!rb_respond_to(wave_file_path, id_to_s)) { rb_raise(rb_eRuntimeError, "Expected file path to wave file"); @@ -61,7 +64,7 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { // rwp->params.encoder_begin_callback_user_data = &is_aborted; // } - register_callbacks(rwp, &self); + prepare_transcription(rwp, &self); if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { fprintf(stderr, "failed to process audio\n"); diff --git a/bindings/ruby/ext/ruby_whisper_vad_params.c b/bindings/ruby/ext/ruby_whisper_vad_params.c new file mode 100644 index 00000000..be7bc465 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_vad_params.c @@ -0,0 +1,288 @@ +#include +#include "ruby_whisper.h" + +#define DEFINE_PARAM(param_name, nth) \ + id_ ## param_name = rb_intern(#param_name); \ + param_names[nth] = id_ ## param_name; \ + rb_define_method(cVADParams, #param_name, ruby_whisper_vad_params_get_ ## param_name, 0); \ + rb_define_method(cVADParams, #param_name "=", ruby_whisper_vad_params_set_ ## param_name, 1); + +#define NUM_PARAMS 6 + +extern VALUE cVADParams; + +static size_t +ruby_whisper_vad_params_memsize(const void *p) +{ + const struct ruby_whisper_vad_params *params = p; + size_t size = sizeof(params); + if (!params) { + return 0; + } + return size; +} + +static ID param_names[NUM_PARAMS]; +static ID id_threshold; +static ID id_min_speech_duration_ms; +static ID id_min_silence_duration_ms; +static ID id_max_speech_duration_s; +static ID id_speech_pad_ms; +static ID id_samples_overlap; + +const rb_data_type_t ruby_whisper_vad_params_type = { + "ruby_whisper_vad_params", + {0, 0, ruby_whisper_vad_params_memsize,}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_vad_params_s_allocate(VALUE klass) +{ + ruby_whisper_vad_params *rwvp; + VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + rwvp->params = whisper_vad_default_params(); + return obj; +} + +/* + * Probability threshold to consider as speech. + * + * call-seq: + * threshold = th -> th + */ +static VALUE +ruby_whisper_vad_params_set_threshold(VALUE self, VALUE value) +{ + ruby_whisper_vad_params *rwvp; + TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + rwvp->params.threshold = RFLOAT_VALUE(value); + return value; +} + +static VALUE +ruby_whisper_vad_params_get_threshold(VALUE self) +{ + ruby_whisper_vad_params *rwvp; + TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + return DBL2NUM(rwvp->params.threshold); +} + +/* + * Min duration for a valid speech segment. + * + * call-seq: + * min_speech_duration_ms = duration_ms -> duration_ms + */ +static VALUE +ruby_whisper_vad_params_set_min_speech_duration_ms(VALUE self, VALUE value) +{ + ruby_whisper_vad_params *rwvp; + TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + rwvp->params.min_speech_duration_ms = NUM2INT(value); + return value; +} + +static VALUE +ruby_whisper_vad_params_get_min_speech_duration_ms(VALUE self) +{ + ruby_whisper_vad_params *rwvp; + TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + return INT2NUM(rwvp->params.min_speech_duration_ms); +} + +/* + * Min silence duration to consider speech as ended. + * + * call-seq: + * min_silence_duration_ms = duration_ms -> duration_ms + */ +static VALUE +ruby_whisper_vad_params_set_min_silence_duration_ms(VALUE self, VALUE value) +{ + ruby_whisper_vad_params *rwvp; + TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + rwvp->params.min_silence_duration_ms = NUM2INT(value); + return value; +} + +static VALUE +ruby_whisper_vad_params_get_min_silence_duration_ms(VALUE self) +{ + ruby_whisper_vad_params *rwvp; + TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + return INT2NUM(rwvp->params.min_silence_duration_ms); +} + +/* + * Max duration of a speech segment before forcing a new segment. + * + * call-seq: + * max_speech_duration_s = duration_s -> duration_s + */ +static VALUE +ruby_whisper_vad_params_set_max_speech_duration_s(VALUE self, VALUE value) +{ + ruby_whisper_vad_params *rwvp; + TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + rwvp->params.max_speech_duration_s = RFLOAT_VALUE(value); + return value; +} + +static VALUE +ruby_whisper_vad_params_get_max_speech_duration_s(VALUE self) +{ + ruby_whisper_vad_params *rwvp; + TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + return DBL2NUM(rwvp->params.max_speech_duration_s); +} + +/* + * Padding added before and after speech segments. + * + * call-seq: + * speech_pad_ms = pad_ms -> pad_ms + */ +static VALUE +ruby_whisper_vad_params_set_speech_pad_ms(VALUE self, VALUE value) +{ + ruby_whisper_vad_params *rwvp; + TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + rwvp->params.speech_pad_ms = NUM2INT(value); + return value; +} + +static VALUE +ruby_whisper_vad_params_get_speech_pad_ms(VALUE self) +{ + ruby_whisper_vad_params *rwvp; + TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + return INT2NUM(rwvp->params.speech_pad_ms); +} + +/* + * Overlap in seconds when copying audio samples from speech segment. + * + * call-seq: + * samples_overlap = overlap -> overlap + */ +static VALUE +ruby_whisper_vad_params_set_samples_overlap(VALUE self, VALUE value) +{ + ruby_whisper_vad_params *rwvp; + TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + rwvp->params.samples_overlap = RFLOAT_VALUE(value); + return value; +} + +static VALUE +ruby_whisper_vad_params_get_samples_overlap(VALUE self) +{ + ruby_whisper_vad_params *rwvp; + TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + return DBL2NUM(rwvp->params.samples_overlap); +} + +static VALUE +ruby_whisper_vad_params_equal(VALUE self, VALUE other) +{ + ruby_whisper_vad_params *rwvp1; + ruby_whisper_vad_params *rwvp2; + + if (self == other) { + return Qtrue; + } + + if (!rb_obj_is_kind_of(other, cVADParams)) { + return Qfalse; + } + + TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp1); + TypedData_Get_Struct(other, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp2); + + if (rwvp1->params.threshold != rwvp2->params.threshold) { + return Qfalse; + } + if (rwvp1->params.min_speech_duration_ms != rwvp2->params.min_speech_duration_ms) { + return Qfalse; + } + if (rwvp1->params.min_silence_duration_ms != rwvp2->params.min_silence_duration_ms) { + return Qfalse; + } + if (rwvp1->params.max_speech_duration_s != rwvp2->params.max_speech_duration_s) { + return Qfalse; + } + if (rwvp1->params.speech_pad_ms != rwvp2->params.speech_pad_ms) { + return Qfalse; + } + if (rwvp1->params.samples_overlap != rwvp2->params.samples_overlap) { + return Qfalse; + } + + return Qtrue; +} + +#define SET_PARAM_IF_SAME(param_name) \ + if (id == id_ ## param_name) { \ + ruby_whisper_vad_params_set_ ## param_name(self, value); \ + continue; \ + } + +VALUE +ruby_whisper_vad_params_initialize(int argc, VALUE *argv, VALUE self) +{ + VALUE kw_hash; + VALUE values[NUM_PARAMS] = {Qundef}; + VALUE value; + ruby_whisper_vad_params *rwvp; + ID id; + int i; + + TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + + rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash); + if (NIL_P(kw_hash)) { + return self; + } + + rb_get_kwargs(kw_hash, param_names, 0, NUM_PARAMS, values); + + for (i = 0; i < NUM_PARAMS; i++) { + id= param_names[i]; + value = values[i]; + if (value == Qundef) { + continue; + } + SET_PARAM_IF_SAME(threshold) + SET_PARAM_IF_SAME(min_speech_duration_ms) + SET_PARAM_IF_SAME(min_silence_duration_ms) + SET_PARAM_IF_SAME(max_speech_duration_s) + SET_PARAM_IF_SAME(speech_pad_ms) + SET_PARAM_IF_SAME(samples_overlap) + } + + return self; +} + +#undef SET_PARAM_IF_SAME + +void +init_ruby_whisper_vad_params(VALUE *mVAD) +{ + cVADParams = rb_define_class_under(*mVAD, "Params", rb_cObject); + rb_define_alloc_func(cVADParams, ruby_whisper_vad_params_s_allocate); + rb_define_method(cVADParams, "initialize", ruby_whisper_vad_params_initialize, -1); + + DEFINE_PARAM(threshold, 0) + DEFINE_PARAM(min_speech_duration_ms, 1) + DEFINE_PARAM(min_silence_duration_ms, 2) + DEFINE_PARAM(max_speech_duration_s, 3) + DEFINE_PARAM(speech_pad_ms, 4) + DEFINE_PARAM(samples_overlap, 5) + + rb_define_method(cVADParams, "==", ruby_whisper_vad_params_equal, 1); +} + +#undef DEFINE_PARAM +#undef NUM_PARAMS diff --git a/bindings/ruby/lib/whisper/model/uri.rb b/bindings/ruby/lib/whisper/model/uri.rb index 06e7a263..fb3ee5db 100644 --- a/bindings/ruby/lib/whisper/model/uri.rb +++ b/bindings/ruby/lib/whisper/model/uri.rb @@ -165,6 +165,12 @@ module Whisper models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin") } + %w[ + silero-v5.1.2 + ].each do |name| + @pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-#{name}.bin") + end + class << self attr_reader :pre_converted_models end diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index a3ce94b8..c1373c87 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -150,7 +150,10 @@ module Whisper ?encoder_begin_callback: encoder_begin_callback, ?encoder_begin_callback_user_data: Object, ?abort_callback: abort_callback, - ?abort_callback_user_data: Object + ?abort_callback_user_data: Object, + ?vad: boolish, + ?vad_model_path: path | URI, + ?vad_params: Whisper::VAD::Params ) -> instance # params.language = "auto" | "en", etc... @@ -338,6 +341,20 @@ module Whisper def abort_callback_user_data: () -> Object + # Enable VAD + # + def vad=: (boolish) -> boolish + + def vad: () -> (true | false) + + # Path to the VAD model + def vad_model_path=: (path | URI | nil) -> (path | URI | nil) + + def vad_model_path: () -> (String | nil) + + def vad_params=: (Whisper::VAD::Params) -> Whisper::VAD::Params + def vad_params: () -> (Whisper::VAD::Params) + # Hook called on new segment. Yields each Whisper::Segment. # # whisper.on_new_segment do |segment| @@ -406,6 +423,55 @@ module Whisper def no_speech_prob: () -> Float end + module VAD + class Params + def self.new: ( + ?threshold: Float, + ?min_speech_duration_ms: Integer, + ?min_silence_duration_ms: Integer, + ?max_speech_duration_s: Float, + ?speech_pad_ms: Integer, + ?samples_overlap: Float + ) -> instance + + # Probability threshold to consider as speech. + # + def threshold=: (Float) -> Float + + def threshold: () -> Float + + # Min duration for a valid speech segment. + # + def min_speech_duration_ms=: (Integer) -> Integer + + def min_speech_duration_ms: () -> Integer + + # Min silence duration to consider speech as ended. + # + def min_silence_duration_ms=: (Integer) -> Integer + + def min_silence_duration_ms: () -> Integer + + # Max duration of a speech segment before forcing a new segment. + def max_speech_duration_s=: (Float) -> Float + + def max_speech_duration_s: () -> Float + + # Padding added before and after speech segments. + # + def speech_pad_ms=: (Integer) -> Integer + + def speech_pad_ms: () -> Integer + + # Overlap in seconds when copying audio samples from speech segment. + # + def samples_overlap=: (Float) -> Float + + def samples_overlap: () -> Float + def ==: (Params) -> (true | false) + end + end + class Error < StandardError attr_reader code: Integer diff --git a/bindings/ruby/tests/test_params.rb b/bindings/ruby/tests/test_params.rb index 5f7fc387..9a953579 100644 --- a/bindings/ruby/tests/test_params.rb +++ b/bindings/ruby/tests/test_params.rb @@ -32,6 +32,9 @@ class TestParams < TestBase :progress_callback_user_data, :abort_callback, :abort_callback_user_data, + :vad, + :vad_model_path, + :vad_params, ] def setup @@ -191,6 +194,50 @@ class TestParams < TestBase assert_in_delta 0.2, @params.no_speech_thold end + def test_vad + assert_false @params.vad + @params.vad = true + assert_true @params.vad + end + + def test_vad_model_path + assert_nil @params.vad_model_path + @params.vad_model_path = "silero-v5.1.2" + assert_equal Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path, @params.vad_model_path + end + + def test_vad_model_path_with_nil + @params.vad_model_path = "silero-v5.1.2" + @params.vad_model_path = nil + assert_nil @params.vad_model_path + end + + def test_vad_model_path_with_invalid + assert_raise TypeError do + @params.vad_model_path = Object.new + end + end + + def test_vad_model_path_with_URI_string + @params.vad_model_path = "https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin" + assert_equal @params.vad_model_path, Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path + end + + def test_vad_model_path_with_URI + @params.vad_model_path = URI("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin") + assert_equal @params.vad_model_path, Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path + end + + def test_vad_params + assert_kind_of Whisper::VAD::Params, @params.vad_params + default_params = @params.vad_params + assert_same default_params, @params.vad_params + assert_equal 0.5, default_params.threshold + new_params = Whisper::VAD::Params.new + @params.vad_params = new_params + assert_same new_params, @params.vad_params + end + def test_new_with_kw_args params = Whisper::Params.new(language: "es") assert_equal "es", params.language @@ -225,6 +272,10 @@ class TestParams < TestBase proc {} in [/_user_data\Z/, *] Object.new + in [:vad_model_path, *] + Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path + in [:vad_params, *] + Whisper::VAD::Params.new end params = Whisper::Params.new(param => value) if Float === value diff --git a/bindings/ruby/tests/test_vad.rb b/bindings/ruby/tests/test_vad.rb new file mode 100644 index 00000000..cb5e3c79 --- /dev/null +++ b/bindings/ruby/tests/test_vad.rb @@ -0,0 +1,19 @@ +require_relative "helper" + +class TestVAD < TestBase + def setup + @whisper = Whisper::Context.new("base.en") + vad_params = Whisper::VAD::Params.new + @params = Whisper::Params.new( + vad: true, + vad_model_path: "silero-v5.1.2", + vad_params: + ) + end + + def test_transcribe + @whisper.transcribe(TestBase::AUDIO, @params) do |text| + assert_match(/ask not what your country can do for you[,.] ask what you can do for your country/i, text) + end + end +end diff --git a/bindings/ruby/tests/test_vad_params.rb b/bindings/ruby/tests/test_vad_params.rb new file mode 100644 index 00000000..add4899e --- /dev/null +++ b/bindings/ruby/tests/test_vad_params.rb @@ -0,0 +1,103 @@ +require_relative "helper" + +class TestVADParams < TestBase + PARAM_NAMES = [ + :threshold, + :min_speech_duration_ms, + :min_silence_duration_ms, + :max_speech_duration_s, + :speech_pad_ms, + :samples_overlap + ] + + def setup + @params = Whisper::VAD::Params.new + end + + def test_new + params = Whisper::VAD::Params.new + assert_kind_of Whisper::VAD::Params, params + end + + def test_threshold + assert_in_delta @params.threshold, 0.5 + @params.threshold = 0.7 + assert_in_delta @params.threshold, 0.7 + end + + def test_min_speech_duration + pend + end + + def test_min_speech_duration_ms + assert_equal 250, @params.min_speech_duration_ms + @params.min_speech_duration_ms = 500 + assert_equal 500, @params.min_speech_duration_ms + end + + def test_min_silence_duration_ms + assert_equal 100, @params.min_silence_duration_ms + @params.min_silence_duration_ms = 200 + assert_equal 200, @params.min_silence_duration_ms + end + + def test_max_speech_duration + pend + end + + def test_max_speech_duration_s + assert @params.max_speech_duration_s >= 10e37 # Defaults to FLT_MAX + @params.max_speech_duration_s = 60.0 + assert_equal 60.0, @params.max_speech_duration_s + end + + def test_speech_pad_ms + assert_equal 30, @params.speech_pad_ms + @params.speech_pad_ms = 50 + assert_equal 50, @params.speech_pad_ms + end + + def test_samples_overlap + assert_in_delta @params.samples_overlap, 0.1 + @params.samples_overlap = 0.5 + assert_in_delta @params.samples_overlap, 0.5 + end + + def test_equal + assert_equal @params, Whisper::VAD::Params.new + end + + def test_new_with_kw_args + params = Whisper::VAD::Params.new(threshold: 0.7) + assert_in_delta params.threshold, 0.7 + assert_equal 250, params.min_speech_duration_ms + end + + def test_new_with_kw_args_non_existent + assert_raise ArgumentError do + Whisper::VAD::Params.new(non_existent: "value") + end + end + + data(PARAM_NAMES.collect {|param| [param, param]}.to_h) + def test_new_with_kw_args_default_values(param) + default_value = @params.send(param) + value = default_value + 1 + params = Whisper::VAD::Params.new(param => value) + if Float === value + assert_in_delta value, params.send(param) + else + assert_equal value, params.send(param) + end + + PARAM_NAMES.reject {|name| name == param}.each do |name| + expected = @params.send(name) + actual = params.send(name) + if Float === expected + assert_in_delta expected, actual + else + assert_equal expected, actual + end + end + end +end