From d4bc413505b2fba98dffbb9a176ddd1b165941d0 Mon Sep 17 00:00:00 2001 From: KITAITI Makoto Date: Tue, 29 Oct 2024 02:23:23 +0900 Subject: [PATCH] ruby : add more APIs (#2518) * Add test for built package existence * Add more tests for Whisper::Params * Add more Whisper::Params attributes * Add tests for callbacks * Add progress and abort callback features * [skip ci] Add prompt usage in README * Change prompt text in example --- bindings/ruby/README.md | 1 + bindings/ruby/ext/ruby_whisper.cpp | 320 ++++++++++++++++++++++++++- bindings/ruby/ext/ruby_whisper.h | 2 + bindings/ruby/tests/test_callback.rb | 87 ++++++++ bindings/ruby/tests/test_package.rb | 1 + bindings/ruby/tests/test_params.rb | 43 ++++ 6 files changed, 451 insertions(+), 3 deletions(-) diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index 29dba120..928fe662 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -31,6 +31,7 @@ params.duration = 60_000 params.max_text_tokens = 300 params.translate = true params.print_timestamps = false +params.prompt = "Initial prompt here." whisper.transcribe("path/to/audio.wav", params) do |whole_text| puts whole_text diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index b17a6bca..2c720e98 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -107,10 +107,16 @@ void rb_whisper_free(ruby_whisper *rw) { free(rw); } +void rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) { + rb_gc_mark(rwc->user_data); + rb_gc_mark(rwc->callback); + rb_gc_mark(rwc->callbacks); +} + void rb_whisper_params_mark(ruby_whisper_params *rwp) { - rb_gc_mark(rwp->new_segment_callback_container->user_data); - rb_gc_mark(rwp->new_segment_callback_container->callback); - rb_gc_mark(rwp->new_segment_callback_container->callbacks); + rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container); + rb_whisper_callbcack_container_mark(rwp->progress_callback_container); + rb_whisper_callbcack_container_mark(rwp->abort_callback_container); } void rb_whisper_params_free(ruby_whisper_params *rwp) { @@ -141,6 +147,8 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) { rwp = ALLOC(ruby_whisper_params); rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); + rwp->progress_callback_container = rb_whisper_callback_container_allocate(); + rwp->abort_callback_container = rb_whisper_callback_container_allocate(); return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); } @@ -316,6 +324,54 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container; } + if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) { + rwp->params.progress_callback = [](struct whisper_context *ctx, struct whisper_state * /*state*/, int progress_cur, void *user_data) { + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + const VALUE progress = INT2NUM(progress_cur); + // Currently, doesn't support state because + // those require to resolve GC-related problems. + if (!NIL_P(container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data); + } + const long callbacks_len = RARRAY_LEN(container->callbacks); + if (0 == callbacks_len) { + return; + } + for (int j = 0; j < callbacks_len; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + rb_funcall(cb, id_call, 1, progress); + } + }; + rwp->progress_callback_container->context = &self; + rwp->params.progress_callback_user_data = rwp->progress_callback_container; + } + + if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { + rwp->params.abort_callback = [](void * user_data) { + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!NIL_P(container->callback)) { + VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data); + if (!NIL_P(result) && Qfalse != result) { + return true; + } + } + const long callbacks_len = RARRAY_LEN(container->callbacks); + if (0 == callbacks_len) { + return false; + } + for (int j = 0; j < callbacks_len; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + VALUE result = rb_funcall(cb, id_call, 1, container->user_data); + if (!NIL_P(result) && Qfalse != result) { + return true; + } + } + return false; + }; + rwp->abort_callback_container->context = &self; + rwp->params.abort_callback_user_data = rwp->abort_callback_container; + } + if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { fprintf(stderr, "failed to process audio\n"); return self; @@ -631,6 +687,30 @@ static VALUE ruby_whisper_params_get_split_on_word(VALUE self) { static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, split_on_word, value) } +/* + * Tokens to provide to the whisper decoder as initial prompt + * these are prepended to any existing text context from a previous call + * use whisper_tokenize() to convert text to tokens. + * Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224). + * + * call-seq: + * initial_prompt -> String + */ +static VALUE ruby_whisper_params_get_initial_prompt(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return rwp->params.initial_prompt == nullptr ? Qnil : rb_str_new2(rwp->params.initial_prompt); +} +/* + * call-seq: + * initial_prompt = prompt -> prompt + */ +static VALUE ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.initial_prompt = StringValueCStr(value); + return value; +} /* * If true, enables diarization. * @@ -725,6 +805,124 @@ static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) { rwp->params.n_max_text_ctx = NUM2INT(value); return value; } +/* + * call-seq: + * temperature -> Float + */ +static VALUE ruby_whisper_params_get_temperature(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return DBL2NUM(rwp->params.temperature); +} +/* + * call-seq: + * temperature = temp -> temp + */ +static VALUE ruby_whisper_params_set_temperature(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.temperature = RFLOAT_VALUE(value); + return value; +} +/* + * See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97 + * + * call-seq: + * max_initial_ts -> Flaot + */ +static VALUE ruby_whisper_params_get_max_initial_ts(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return DBL2NUM(rwp->params.max_initial_ts); +} +/* + * call-seq: + * max_initial_ts = timestamp -> timestamp + */ +static VALUE ruby_whisper_params_set_max_initial_ts(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.max_initial_ts = RFLOAT_VALUE(value); + return value; +} +/* + * call-seq: + * length_penalty -> Float + */ +static VALUE ruby_whisper_params_get_length_penalty(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return DBL2NUM(rwp->params.length_penalty); +} +/* + * call-seq: + * length_penalty = penalty -> penalty + */ +static VALUE ruby_whisper_params_set_length_penalty(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.length_penalty = RFLOAT_VALUE(value); + return value; +} +/* + * call-seq: + * temperature_inc -> Float + */ +static VALUE ruby_whisper_params_get_temperature_inc(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return DBL2NUM(rwp->params.temperature_inc); +} +/* + * call-seq: + * temperature_inc = inc -> inc + */ +static VALUE ruby_whisper_params_set_temperature_inc(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.temperature_inc = RFLOAT_VALUE(value); + return value; +} +/* + * Similar to OpenAI's "compression_ratio_threshold" + * + * call-seq: + * entropy_thold -> Float + */ +static VALUE ruby_whisper_params_get_entropy_thold(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return DBL2NUM(rwp->params.entropy_thold); +} +/* + * call-seq: + * entropy_thold = threshold -> threshold + */ +static VALUE ruby_whisper_params_set_entropy_thold(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.entropy_thold = RFLOAT_VALUE(value); + return value; +} +/* + * call-seq: + * logprob_thold -> Float + */ +static VALUE ruby_whisper_params_get_logprob_thold(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return DBL2NUM(rwp->params.logprob_thold); +} +/* + * call-seq: + * logprob_thold = threshold -> threshold + */ +static VALUE ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.logprob_thold = RFLOAT_VALUE(value); + return value; +} /* * Sets new segment callback, called for every newly generated text segment. * @@ -753,6 +951,62 @@ static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, rwp->new_segment_callback_container->user_data = value; return value; } +/* + * Sets progress callback, called on each progress update. + * + * params.new_segment_callback = ->(context, _, n_new, user_data) { + * # ... + * } + * + * call-seq: + * progress_callback = callback -> callback + */ +static VALUE ruby_whisper_params_set_progress_callback(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->progress_callback_container->callback = value; + return value; +} +/* + * Sets user data passed to the last argument of progress callback. + * + * call-seq: + * progress_callback_user_data = user_data -> use_data + */ +static VALUE ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->progress_callback_container->user_data = value; + return value; +} +/* + * Sets abort callback, called to check if the process should be aborted. + * + * params.abort_callback = ->(user_data) { + * # ... + * } + * + * call-seq: + * abort_callback = callback -> callback + */ +static VALUE ruby_whisper_params_set_abort_callback(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->abort_callback_container->callback = value; + return value; +} +/* + * Sets user data passed to the last argument of abort callback. + * + * call-seq: + * abort_callback_user_data = user_data -> use_data + */ +static VALUE ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->abort_callback_container->user_data = value; + return value; +} // High level API @@ -835,6 +1089,46 @@ static VALUE ruby_whisper_params_on_new_segment(VALUE self) { return Qnil; } +/* + * Hook called on progress update. Yields each progress Integer between 0 and 100. + * + * whisper.on_progress do |progress| + * # ... + * end + * + * call-seq: + * on_progress {|progress| ... } + */ +static VALUE ruby_whisper_params_on_progress(VALUE self) { + ruby_whisper_params *rws; + Data_Get_Struct(self, ruby_whisper_params, rws); + const VALUE blk = rb_block_proc(); + rb_ary_push(rws->progress_callback_container->callbacks, blk); + return Qnil; +} + +/* + * Call block to determine whether abort or not. Return +true+ when you want to abort. + * + * params.abort_on do + * if some_condition + * true # abort + * else + * false # continue + * end + * end + * + * call-seq: + * abort_on { ... } + */ +static VALUE ruby_whisper_params_abort_on(VALUE self) { + ruby_whisper_params *rws; + Data_Get_Struct(self, ruby_whisper_params, rws); + const VALUE blk = rb_block_proc(); + rb_ary_push(rws->abort_callback_container->callbacks, blk); + return Qnil; +} + /* * Start time in milliseconds. * @@ -946,6 +1240,8 @@ void Init_whisper() { rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1); rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0); rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1); + rb_define_method(cParams, "initial_prompt", ruby_whisper_params_get_initial_prompt, 0); + rb_define_method(cParams, "initial_prompt=", ruby_whisper_params_set_initial_prompt, 1); rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0); rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1); @@ -956,9 +1252,25 @@ void Init_whisper() { rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0); rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1); + rb_define_method(cParams, "temperature", ruby_whisper_params_get_temperature, 0); + rb_define_method(cParams, "temperature=", ruby_whisper_params_set_temperature, 1); + rb_define_method(cParams, "max_initial_ts", ruby_whisper_params_get_max_initial_ts, 0); + rb_define_method(cParams, "max_initial_ts=", ruby_whisper_params_set_max_initial_ts, 1); + rb_define_method(cParams, "length_penalty", ruby_whisper_params_get_length_penalty, 0); + rb_define_method(cParams, "length_penalty=", ruby_whisper_params_set_length_penalty, 1); + rb_define_method(cParams, "temperature_inc", ruby_whisper_params_get_temperature_inc, 0); + rb_define_method(cParams, "temperature_inc=", ruby_whisper_params_set_temperature_inc, 1); + rb_define_method(cParams, "entropy_thold", ruby_whisper_params_get_entropy_thold, 0); + rb_define_method(cParams, "entropy_thold=", ruby_whisper_params_set_entropy_thold, 1); + rb_define_method(cParams, "logprob_thold", ruby_whisper_params_get_logprob_thold, 0); + rb_define_method(cParams, "logprob_thold=", ruby_whisper_params_set_logprob_thold, 1); rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1); rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1); + rb_define_method(cParams, "progress_callback=", ruby_whisper_params_set_progress_callback, 1); + rb_define_method(cParams, "progress_callback_user_data=", ruby_whisper_params_set_progress_callback_user_data, 1); + rb_define_method(cParams, "abort_callback=", ruby_whisper_params_set_abort_callback, 1); + rb_define_method(cParams, "abort_callback_user_data=", ruby_whisper_params_set_abort_callback_user_data, 1); // High leve cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject); @@ -966,6 +1278,8 @@ void Init_whisper() { rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate); rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0); rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0); + rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0); + rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0); rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0); rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0); rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0); diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index f3210a4b..a6771038 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -18,6 +18,8 @@ typedef struct { struct whisper_full_params params; bool diarize; ruby_whisper_callback_container *new_segment_callback_container; + ruby_whisper_callback_container *progress_callback_container; + ruby_whisper_callback_container *abort_callback_container; } ruby_whisper_params; #endif diff --git a/bindings/ruby/tests/test_callback.rb b/bindings/ruby/tests/test_callback.rb index 80a5f4df..1234d31d 100644 --- a/bindings/ruby/tests/test_callback.rb +++ b/bindings/ruby/tests/test_callback.rb @@ -5,6 +5,7 @@ class TestCallback < Test::Unit::TestCase TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..')) def setup + GC.start @params = Whisper::Params.new @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) @audio = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') @@ -73,4 +74,90 @@ class TestCallback < Test::Unit::TestCase assert_same @whisper, @whisper.transcribe(@audio, @params) end + + def test_progress_callback + first = nil + last = nil + @params.progress_callback = ->(context, state, progress, user_data) { + assert_kind_of Integer, progress + assert 0 <= progress && progress <= 100 + assert_same @whisper, context + first = progress if first.nil? + last = progress + } + @whisper.transcribe(@audio, @params) + assert_equal 0, first + assert_equal 100, last + end + + def test_progress_callback_user_data + udata = Object.new + @params.progress_callback_user_data = udata + @params.progress_callback = ->(context, state, n_new, user_data) { + assert_same udata, user_data + } + + @whisper.transcribe(@audio, @params) + end + + def test_on_progress + first = nil + last = nil + @params.on_progress do |progress| + assert_kind_of Integer, progress + assert 0 <= progress && progress <= 100 + first = progress if first.nil? + last = progress + end + @whisper.transcribe(@audio, @params) + assert_equal 0, first + assert_equal 100, last + end + + def test_abort_callback + i = 0 + @params.abort_callback = ->(user_data) { + assert_nil user_data + i += 1 + return false + } + @whisper.transcribe(@audio, @params) + assert i > 0 + end + + def test_abort_callback_abort + i = 0 + @params.abort_callback = ->(user_data) { + i += 1 + return i == 3 + } + @whisper.transcribe(@audio, @params) + assert_equal 3, i + end + + def test_abort_callback_user_data + udata = Object.new + @params.abort_callback_user_data = udata + yielded = nil + @params.abort_callback = ->(user_data) { + yielded = user_data + } + @whisper.transcribe(@audio, @params) + assert_same udata, yielded + end + + def test_abort_on + do_abort = false + aborted_from_callback = false + @params.on_new_segment do |segment| + do_abort = true if segment.text.match? /ask/ + end + i = 0 + @params.abort_on do + i += 1 + do_abort + end + @whisper.transcribe(@audio, @params) + assert i > 0 + end end diff --git a/bindings/ruby/tests/test_package.rb b/bindings/ruby/tests/test_package.rb index adaeedfb..3183c295 100644 --- a/bindings/ruby/tests/test_package.rb +++ b/bindings/ruby/tests/test_package.rb @@ -8,6 +8,7 @@ class TestPackage < Test::Unit::TestCase Tempfile.create do |file| assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) assert file.size > 0 + assert_path_exist file.to_path end end diff --git a/bindings/ruby/tests/test_params.rb b/bindings/ruby/tests/test_params.rb index 4484feee..63860496 100644 --- a/bindings/ruby/tests/test_params.rb +++ b/bindings/ruby/tests/test_params.rb @@ -1,3 +1,4 @@ +require 'test/unit' require 'whisper' class TestParams < Test::Unit::TestCase @@ -109,4 +110,46 @@ class TestParams < Test::Unit::TestCase @params.split_on_word = false assert !@params.split_on_word end + + def test_initial_prompt + assert_nil @params.initial_prompt + @params.initial_prompt = "You are a polite person." + assert_equal "You are a polite person.", @params.initial_prompt + end + + def test_temperature + assert_equal 0.0, @params.temperature + @params.temperature = 0.5 + assert_equal 0.5, @params.temperature + end + + def test_max_initial_ts + assert_equal 1.0, @params.max_initial_ts + @params.max_initial_ts = 600.0 + assert_equal 600.0, @params.max_initial_ts + end + + def test_length_penalty + assert_equal -1.0, @params.length_penalty + @params.length_penalty = 0.5 + assert_equal 0.5, @params.length_penalty + end + + def test_temperature_inc + assert_in_delta 0.2, @params.temperature_inc + @params.temperature_inc = 0.5 + assert_in_delta 0.5, @params.temperature_inc + end + + def test_entropy_thold + assert_in_delta 2.4, @params.entropy_thold + @params.entropy_thold = 3.0 + assert_in_delta 3.0, @params.entropy_thold + end + + def test_logprob_thold + assert_in_delta -1.0, @params.logprob_thold + @params.logprob_thold = -0.5 + assert_in_delta -0.5, @params.logprob_thold + end end