diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index 03a8b9e1..8492e4ed 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -63,7 +63,7 @@ whisper = Whisper::Context.new("base.en") You can see the list of prepared model names by `Whisper::Model.preconverted_models.keys`: ```ruby -puts Whisper::Model.preconverted_model_names +puts Whisper::Model.preconverted_models.keys # tiny # tiny.en # tiny-q5_1 @@ -220,7 +220,7 @@ whisper.each_segment do |segment| end ``` -The second argument `samples` may be an array, an object with `length` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy. +The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy. License ------- diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index aa526577..88a4fd2c 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -53,6 +53,9 @@ static ID id_pre_converted_models; static bool is_log_callback_finalized = false; +// High level API +static VALUE rb_whisper_segment_initialize(VALUE context, int index); + /* * call-seq: * lang_max_id -> Integer @@ -187,6 +190,69 @@ static ruby_whisper_callback_container * rb_whisper_callback_container_allocate( return container; } +static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) { + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + + // Currently, doesn't support state because + // those require to resolve GC-related problems. + if (!NIL_P(container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data); + } + const long callbacks_len = RARRAY_LEN(container->callbacks); + if (0 == callbacks_len) { + return; + } + const int n_segments = whisper_full_n_segments_from_state(state); + for (int i = n_new; i > 0; i--) { + int i_segment = n_segments - i; + VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment); + for (int j = 0; j < callbacks_len; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + rb_funcall(cb, id_call, 1, segment); + } + } +} + +static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) { + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + const VALUE progress = INT2NUM(progress_cur); + // Currently, doesn't support state because + // those require to resolve GC-related problems. + if (!NIL_P(container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data); + } + const long callbacks_len = RARRAY_LEN(container->callbacks); + if (0 == callbacks_len) { + return; + } + for (int j = 0; j < callbacks_len; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + rb_funcall(cb, id_call, 1, progress); + } +} + +static bool abort_callback(void * user_data) { + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!NIL_P(container->callback)) { + VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data); + if (!NIL_P(result) && Qfalse != result) { + return true; + } + } + const long callbacks_len = RARRAY_LEN(container->callbacks); + if (0 == callbacks_len) { + return false; + } + for (int j = 0; j < callbacks_len; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + VALUE result = rb_funcall(cb, id_call, 1, container->user_data); + if (!NIL_P(result) && Qfalse != result) { + return true; + } + } + return false; +} + static VALUE ruby_whisper_params_allocate(VALUE klass) { ruby_whisper_params *rwp; rwp = ALLOC(ruby_whisper_params); @@ -230,8 +296,25 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { return self; } -// High level API -static VALUE rb_whisper_segment_initialize(VALUE context, int index); +static void register_callbacks(ruby_whisper_params * rwp, VALUE * self) { + if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { + rwp->new_segment_callback_container->context = self; + rwp->params.new_segment_callback = new_segment_callback; + rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container; + } + + if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) { + rwp->progress_callback_container->context = self; + rwp->params.progress_callback = progress_callback; + rwp->params.progress_callback_user_data = rwp->progress_callback_container; + } + + if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { + rwp->abort_callback_container->context = self; + rwp->params.abort_callback = abort_callback; + rwp->params.abort_callback_user_data = rwp->abort_callback_container; + } +} /* * transcribe a single file @@ -353,80 +436,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { rwp->params.encoder_begin_callback_user_data = &is_aborted; } - if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { - rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { - const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; - - // Currently, doesn't support state because - // those require to resolve GC-related problems. - if (!NIL_P(container->callback)) { - rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data); - } - const long callbacks_len = RARRAY_LEN(container->callbacks); - if (0 == callbacks_len) { - return; - } - const int n_segments = whisper_full_n_segments_from_state(state); - for (int i = n_new; i > 0; i--) { - int i_segment = n_segments - i; - VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment); - for (int j = 0; j < callbacks_len; j++) { - VALUE cb = rb_ary_entry(container->callbacks, j); - rb_funcall(cb, id_call, 1, segment); - } - } - }; - rwp->new_segment_callback_container->context = &self; - rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container; - } - - if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) { - rwp->params.progress_callback = [](struct whisper_context *ctx, struct whisper_state * /*state*/, int progress_cur, void *user_data) { - const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; - const VALUE progress = INT2NUM(progress_cur); - // Currently, doesn't support state because - // those require to resolve GC-related problems. - if (!NIL_P(container->callback)) { - rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data); - } - const long callbacks_len = RARRAY_LEN(container->callbacks); - if (0 == callbacks_len) { - return; - } - for (int j = 0; j < callbacks_len; j++) { - VALUE cb = rb_ary_entry(container->callbacks, j); - rb_funcall(cb, id_call, 1, progress); - } - }; - rwp->progress_callback_container->context = &self; - rwp->params.progress_callback_user_data = rwp->progress_callback_container; - } - - if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { - rwp->params.abort_callback = [](void * user_data) { - const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; - if (!NIL_P(container->callback)) { - VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data); - if (!NIL_P(result) && Qfalse != result) { - return true; - } - } - const long callbacks_len = RARRAY_LEN(container->callbacks); - if (0 == callbacks_len) { - return false; - } - for (int j = 0; j < callbacks_len; j++) { - VALUE cb = rb_ary_entry(container->callbacks, j); - VALUE result = rb_funcall(cb, id_call, 1, container->user_data); - if (!NIL_P(result) && Qfalse != result) { - return true; - } - } - return false; - }; - rwp->abort_callback_container->context = &self; - rwp->params.abort_callback_user_data = rwp->abort_callback_container; - } + register_callbacks(rwp, &self); if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { fprintf(stderr, "failed to process audio\n"); @@ -631,6 +641,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) { } } } + register_callbacks(rwp, &self); const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples); if (0 == result) { return Qnil; @@ -719,6 +730,7 @@ static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) { } } } + register_callbacks(rwp, &self); const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors); if (0 == result) { return Qnil; @@ -823,6 +835,18 @@ static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) { return rb_str_new2(text); } +/* + * call-seq: + * full_get_segment_no_speech_prob -> Float + */ +static VALUE ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment); + return DBL2NUM(no_speech_prob); +} + /* * params.language = "auto" | "en", etc... * @@ -1547,6 +1571,18 @@ static VALUE ruby_whisper_segment_get_text(VALUE self) { return rb_str_new2(text); } +/* + * call-seq: + * no_speech_prob -> Float + */ +static VALUE ruby_whisper_segment_get_no_speech_prob(VALUE self) { + ruby_whisper_segment *rws; + Data_Get_Struct(self, ruby_whisper_segment, rws); + ruby_whisper *rw; + Data_Get_Struct(rws->context, ruby_whisper, rw); + return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index)); +} + static void rb_whisper_model_mark(ruby_whisper_model *rwm) { rb_gc_mark(rwm->context); } @@ -1809,6 +1845,7 @@ void Init_whisper() { rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1); rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1); rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1); + rb_define_method(cContext, "full_get_segment_no_speech_prob", ruby_whisper_full_get_segment_no_speech_prob, 1); rb_define_method(cContext, "full", ruby_whisper_full, -1); rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1); @@ -1887,6 +1924,7 @@ void Init_whisper() { rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0); rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0); rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0); + rb_define_method(cSegment, "no_speech_prob", ruby_whisper_segment_get_no_speech_prob, 0); cModel = rb_define_class_under(mWhisper, "Model", rb_cObject); rb_define_alloc_func(cModel, ruby_whisper_model_allocate); diff --git a/bindings/ruby/lib/whisper/model/uri.rb b/bindings/ruby/lib/whisper/model/uri.rb index 5ca77ed4..fe5ed56b 100644 --- a/bindings/ruby/lib/whisper/model/uri.rb +++ b/bindings/ruby/lib/whisper/model/uri.rb @@ -79,30 +79,36 @@ class Whisper::Model downloaded += chunk.bytesize show_progress downloaded, size end + $stderr.puts end downloading_path.rename path end def show_progress(current, size) - return unless $stderr.tty? - return unless size + progress_rate_available = size && $stderr.tty? unless @prev @prev = Time.now - $stderr.puts "Downloading #{@uri}" + $stderr.puts "Downloading #{@uri} to #{cache_path}" end now = Time.now - return if now - @prev < 1 && current < size - progress_width = 20 - progress = current.to_f / size - arrow_length = progress * progress_width - arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length) - line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})" - padding = ' ' * ($stderr.winsize[1] - line.size) - $stderr.print "\r#{line}#{padding}" - $stderr.puts if current >= size + if progress_rate_available + return if now - @prev < 1 && current < size + + progress_width = 20 + progress = current.to_f / size + arrow_length = progress * progress_width + arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length) + line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})" + padding = ' ' * ($stderr.winsize[1] - line.size) + $stderr.print "\r#{line}#{padding}" + else + return if now - @prev < 1 + + $stderr.print "." + end @prev = now end diff --git a/bindings/ruby/tests/helper.rb b/bindings/ruby/tests/helper.rb index da52f268..a182319d 100644 --- a/bindings/ruby/tests/helper.rb +++ b/bindings/ruby/tests/helper.rb @@ -4,4 +4,21 @@ require_relative "jfk_reader/jfk_reader" class TestBase < Test::Unit::TestCase AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav") + + class << self + attr_reader :whisper + + def startup + @whisper = Whisper::Context.new("base.en") + params = Whisper::Params.new + params.print_timestamps = false + @whisper.transcribe(TestBase::AUDIO, params) + end + end + + private + + def whisper + self.class.whisper + end end diff --git a/bindings/ruby/tests/test_package.rb b/bindings/ruby/tests/test_package.rb index 9c47870e..33c2b37e 100644 --- a/bindings/ruby/tests/test_package.rb +++ b/bindings/ruby/tests/test_package.rb @@ -23,7 +23,7 @@ class TestPackage < TestBase version = match_data[2] basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}" Dir.mktmpdir do |dir| - system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true + system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{filename.shellescape}", exception: true assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", basename) end end diff --git a/bindings/ruby/tests/test_segment.rb b/bindings/ruby/tests/test_segment.rb index 559bcea7..44ab0a6b 100644 --- a/bindings/ruby/tests/test_segment.rb +++ b/bindings/ruby/tests/test_segment.rb @@ -1,17 +1,6 @@ require_relative "helper" class TestSegment < TestBase - class << self - attr_reader :whisper - - def startup - @whisper = Whisper::Context.new("base.en") - params = Whisper::Params.new - params.print_timestamps = false - @whisper.transcribe(TestBase::AUDIO, params) - end - end - def test_iteration whisper.each_segment do |segment| assert_instance_of Whisper::Segment, segment @@ -43,6 +32,14 @@ class TestSegment < TestBase end end + def test_no_speech_prob + no_speech_prob = nil + whisper.each_segment do |segment| + no_speech_prob = segment.no_speech_prob + end + assert no_speech_prob > 0.0 + end + def test_on_new_segment params = Whisper::Params.new seg = nil @@ -74,10 +71,4 @@ class TestSegment < TestBase end whisper.transcribe(AUDIO, params) end - - private - - def whisper - self.class.whisper - end end diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 115569ed..5b0d189e 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -21,21 +21,6 @@ class TestWhisper < TestBase end sub_test_case "After transcription" do - class << self - attr_reader :whisper - - def startup - @whisper = Whisper::Context.new("base.en") - params = Whisper::Params.new - params.print_timestamps = false - @whisper.transcribe(TestBase::AUDIO, params) - end - end - - def whisper - self.class.whisper - end - def test_full_n_segments assert_equal 1, whisper.full_n_segments end @@ -70,6 +55,12 @@ class TestWhisper < TestBase def test_full_get_segment_text assert_match /ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0) end + + def test_full_get_segment_no_speech_prob + prob = whisper.full_get_segment_no_speech_prob(0) + assert prob > 0.0 + assert prob < 1.0 + end end def test_lang_max_id