ruby : bug fix on callbacks and no_speech_prob (#2656)

* Don't generate documentation on test

* Move .startup to TestBase class

* Extract new_segment_callback as a function

* Extract progress_callback as a function

* Extract abort_callback as a function

* Extract register_callbacks as a function

* Call callbacks in Whiser::Context#full and #full_parallel

* Fix README

* Care about the cases content-size is nil and TTY is not available

* Add tests for no_speech_prob

* Add Whisper::Context#full_get_segment_no_speech_prob and Whisper::Segment#no_speech_prob
This commit is contained in:
KITAITI Makoto 2024-12-22 04:52:06 +09:00 committed by GitHub
parent 4183517076
commit f07a81aa9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 166 additions and 123 deletions

View File

@ -63,7 +63,7 @@ whisper = Whisper::Context.new("base.en")
You can see the list of prepared model names by `Whisper::Model.preconverted_models.keys`:
```ruby
puts Whisper::Model.preconverted_model_names
puts Whisper::Model.preconverted_models.keys
# tiny
# tiny.en
# tiny-q5_1
@ -220,7 +220,7 @@ whisper.each_segment do |segment|
end
```
The second argument `samples` may be an array, an object with `length` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
License
-------

View File

@ -53,6 +53,9 @@ static ID id_pre_converted_models;
static bool is_log_callback_finalized = false;
// High level API
static VALUE rb_whisper_segment_initialize(VALUE context, int index);
/*
* call-seq:
* lang_max_id -> Integer
@ -187,6 +190,69 @@ static ruby_whisper_callback_container * rb_whisper_callback_container_allocate(
return container;
}
static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
// Currently, doesn't support state because
// those require to resolve GC-related problems.
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return;
}
const int n_segments = whisper_full_n_segments_from_state(state);
for (int i = n_new; i > 0; i--) {
int i_segment = n_segments - i;
VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, segment);
}
}
}
static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
const VALUE progress = INT2NUM(progress_cur);
// Currently, doesn't support state because
// those require to resolve GC-related problems.
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return;
}
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, progress);
}
}
static bool abort_callback(void * user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!NIL_P(container->callback)) {
VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
return true;
}
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return false;
}
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
return true;
}
}
return false;
}
static VALUE ruby_whisper_params_allocate(VALUE klass) {
ruby_whisper_params *rwp;
rwp = ALLOC(ruby_whisper_params);
@ -230,8 +296,25 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
return self;
}
// High level API
static VALUE rb_whisper_segment_initialize(VALUE context, int index);
static void register_callbacks(ruby_whisper_params * rwp, VALUE * self) {
if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
rwp->new_segment_callback_container->context = self;
rwp->params.new_segment_callback = new_segment_callback;
rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
}
if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
rwp->progress_callback_container->context = self;
rwp->params.progress_callback = progress_callback;
rwp->params.progress_callback_user_data = rwp->progress_callback_container;
}
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
rwp->abort_callback_container->context = self;
rwp->params.abort_callback = abort_callback;
rwp->params.abort_callback_user_data = rwp->abort_callback_container;
}
}
/*
* transcribe a single file
@ -353,80 +436,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
rwp->params.encoder_begin_callback_user_data = &is_aborted;
}
if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
// Currently, doesn't support state because
// those require to resolve GC-related problems.
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return;
}
const int n_segments = whisper_full_n_segments_from_state(state);
for (int i = n_new; i > 0; i--) {
int i_segment = n_segments - i;
VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, segment);
}
}
};
rwp->new_segment_callback_container->context = &self;
rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
}
if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
rwp->params.progress_callback = [](struct whisper_context *ctx, struct whisper_state * /*state*/, int progress_cur, void *user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
const VALUE progress = INT2NUM(progress_cur);
// Currently, doesn't support state because
// those require to resolve GC-related problems.
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return;
}
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, progress);
}
};
rwp->progress_callback_container->context = &self;
rwp->params.progress_callback_user_data = rwp->progress_callback_container;
}
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
rwp->params.abort_callback = [](void * user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!NIL_P(container->callback)) {
VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
return true;
}
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return false;
}
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
return true;
}
}
return false;
};
rwp->abort_callback_container->context = &self;
rwp->params.abort_callback_user_data = rwp->abort_callback_container;
}
register_callbacks(rwp, &self);
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
fprintf(stderr, "failed to process audio\n");
@ -631,6 +641,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) {
}
}
}
register_callbacks(rwp, &self);
const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
if (0 == result) {
return Qnil;
@ -719,6 +730,7 @@ static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) {
}
}
}
register_callbacks(rwp, &self);
const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
if (0 == result) {
return Qnil;
@ -823,6 +835,18 @@ static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) {
return rb_str_new2(text);
}
/*
* call-seq:
* full_get_segment_no_speech_prob -> Float
*/
static VALUE ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment);
return DBL2NUM(no_speech_prob);
}
/*
* params.language = "auto" | "en", etc...
*
@ -1547,6 +1571,18 @@ static VALUE ruby_whisper_segment_get_text(VALUE self) {
return rb_str_new2(text);
}
/*
* call-seq:
* no_speech_prob -> Float
*/
static VALUE ruby_whisper_segment_get_no_speech_prob(VALUE self) {
ruby_whisper_segment *rws;
Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
Data_Get_Struct(rws->context, ruby_whisper, rw);
return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index));
}
static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
rb_gc_mark(rwm->context);
}
@ -1809,6 +1845,7 @@ void Init_whisper() {
rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
rb_define_method(cContext, "full_get_segment_no_speech_prob", ruby_whisper_full_get_segment_no_speech_prob, 1);
rb_define_method(cContext, "full", ruby_whisper_full, -1);
rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
@ -1887,6 +1924,7 @@ void Init_whisper() {
rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
rb_define_method(cSegment, "no_speech_prob", ruby_whisper_segment_get_no_speech_prob, 0);
cModel = rb_define_class_under(mWhisper, "Model", rb_cObject);
rb_define_alloc_func(cModel, ruby_whisper_model_allocate);

View File

@ -79,30 +79,36 @@ class Whisper::Model
downloaded += chunk.bytesize
show_progress downloaded, size
end
$stderr.puts
end
downloading_path.rename path
end
def show_progress(current, size)
return unless $stderr.tty?
return unless size
progress_rate_available = size && $stderr.tty?
unless @prev
@prev = Time.now
$stderr.puts "Downloading #{@uri}"
$stderr.puts "Downloading #{@uri} to #{cache_path}"
end
now = Time.now
return if now - @prev < 1 && current < size
progress_width = 20
progress = current.to_f / size
arrow_length = progress * progress_width
arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length)
line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
padding = ' ' * ($stderr.winsize[1] - line.size)
$stderr.print "\r#{line}#{padding}"
$stderr.puts if current >= size
if progress_rate_available
return if now - @prev < 1 && current < size
progress_width = 20
progress = current.to_f / size
arrow_length = progress * progress_width
arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length)
line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
padding = ' ' * ($stderr.winsize[1] - line.size)
$stderr.print "\r#{line}#{padding}"
else
return if now - @prev < 1
$stderr.print "."
end
@prev = now
end

View File

@ -4,4 +4,21 @@ require_relative "jfk_reader/jfk_reader"
class TestBase < Test::Unit::TestCase
AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
class << self
attr_reader :whisper
def startup
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
params.print_timestamps = false
@whisper.transcribe(TestBase::AUDIO, params)
end
end
private
def whisper
self.class.whisper
end
end

View File

@ -23,7 +23,7 @@ class TestPackage < TestBase
version = match_data[2]
basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
Dir.mktmpdir do |dir|
system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true
system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{filename.shellescape}", exception: true
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", basename)
end
end

View File

@ -1,17 +1,6 @@
require_relative "helper"
class TestSegment < TestBase
class << self
attr_reader :whisper
def startup
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
params.print_timestamps = false
@whisper.transcribe(TestBase::AUDIO, params)
end
end
def test_iteration
whisper.each_segment do |segment|
assert_instance_of Whisper::Segment, segment
@ -43,6 +32,14 @@ class TestSegment < TestBase
end
end
def test_no_speech_prob
no_speech_prob = nil
whisper.each_segment do |segment|
no_speech_prob = segment.no_speech_prob
end
assert no_speech_prob > 0.0
end
def test_on_new_segment
params = Whisper::Params.new
seg = nil
@ -74,10 +71,4 @@ class TestSegment < TestBase
end
whisper.transcribe(AUDIO, params)
end
private
def whisper
self.class.whisper
end
end

View File

@ -21,21 +21,6 @@ class TestWhisper < TestBase
end
sub_test_case "After transcription" do
class << self
attr_reader :whisper
def startup
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
params.print_timestamps = false
@whisper.transcribe(TestBase::AUDIO, params)
end
end
def whisper
self.class.whisper
end
def test_full_n_segments
assert_equal 1, whisper.full_n_segments
end
@ -70,6 +55,12 @@ class TestWhisper < TestBase
def test_full_get_segment_text
assert_match /ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0)
end
def test_full_get_segment_no_speech_prob
prob = whisper.full_get_segment_no_speech_prob(0)
assert prob > 0.0
assert prob < 1.0
end
end
def test_lang_max_id