Mirror of https://github.com/ggerganov/whisper.cpp.git (synced 2025-08-10 03:49:32 +02:00)

Compare commits: v1.7.4-pre ... v1.7.4-pre (6 commits)

Commits:
53c9a3a984
ed09075ca0
f07a81aa9f
4183517076
f4668169a0
944ce49439

@@ -181,11 +181,11 @@ public class WhisperFullParams extends Structure {
    }

    /** Flag to suppress non-speech tokens. */
    public CBool suppress_non_speech_tokens;
    public CBool suppress_nst;

    /** Flag to suppress non-speech tokens. */
    public void suppressNonSpeechTokens(boolean enable) {
        suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE;
        suppress_nst = enable ? CBool.TRUE : CBool.FALSE;
    }

    /** Initial decoding temperature. */

@@ -315,7 +315,7 @@ public class WhisperFullParams extends Structure {
            "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
            "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx",
            "tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
            "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
            "suppress_blank", "suppress_nst", "temperature", "max_initial_ts", "length_penalty",
            "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
            "new_segment_callback", "new_segment_callback_user_data",
            "progress_callback", "progress_callback_user_data",

@@ -63,7 +63,7 @@ whisper = Whisper::Context.new("base.en")
You can see the list of prepared model names by `Whisper::Model.preconverted_models.keys`:

```ruby
puts Whisper::Model.preconverted_model_names
puts Whisper::Model.preconverted_models.keys
# tiny
# tiny.en
# tiny-q5_1

@@ -220,7 +220,7 @@ whisper.each_segment do |segment|
end
```

The second argument `samples` may be an array, an object with `length` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
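
For illustration, a minimal sketch of passing `samples` directly (this assumes the gem's `Whisper::Context#full` entry point defined later in this diff; the require name, file names, and sample rate are placeholders):

```ruby
require "whisper"  # whispercpp gem; require name assumed

whisper = Whisper::Context.new("base.en")
params  = Whisper::Params.new

# Any Array, any object responding to `length` and `each`, or a MemoryView
# over C float data can be passed as the second argument.
samples = Array.new(16_000, 0.0)  # one second of silence at 16 kHz

whisper.full(params, samples)
whisper.each_segment { |segment| puts segment.text }
```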

License
-------

@@ -53,6 +53,9 @@ static ID id_pre_converted_models;

static bool is_log_callback_finalized = false;

// High level API
static VALUE rb_whisper_segment_initialize(VALUE context, int index);

/*
 * call-seq:
 *   lang_max_id -> Integer

@@ -187,6 +190,69 @@ static ruby_whisper_callback_container * rb_whisper_callback_container_allocate(
  return container;
}

static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) {
  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;

  // Currently, doesn't support state because
  // those require to resolve GC-related problems.
  if (!NIL_P(container->callback)) {
    rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
  }
  const long callbacks_len = RARRAY_LEN(container->callbacks);
  if (0 == callbacks_len) {
    return;
  }
  const int n_segments = whisper_full_n_segments_from_state(state);
  for (int i = n_new; i > 0; i--) {
    int i_segment = n_segments - i;
    VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
    for (int j = 0; j < callbacks_len; j++) {
      VALUE cb = rb_ary_entry(container->callbacks, j);
      rb_funcall(cb, id_call, 1, segment);
    }
  }
}

static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) {
  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
  const VALUE progress = INT2NUM(progress_cur);
  // Currently, doesn't support state because
  // those require to resolve GC-related problems.
  if (!NIL_P(container->callback)) {
    rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
  }
  const long callbacks_len = RARRAY_LEN(container->callbacks);
  if (0 == callbacks_len) {
    return;
  }
  for (int j = 0; j < callbacks_len; j++) {
    VALUE cb = rb_ary_entry(container->callbacks, j);
    rb_funcall(cb, id_call, 1, progress);
  }
}

static bool abort_callback(void * user_data) {
  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
  if (!NIL_P(container->callback)) {
    VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
    if (!NIL_P(result) && Qfalse != result) {
      return true;
    }
  }
  const long callbacks_len = RARRAY_LEN(container->callbacks);
  if (0 == callbacks_len) {
    return false;
  }
  for (int j = 0; j < callbacks_len; j++) {
    VALUE cb = rb_ary_entry(container->callbacks, j);
    VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
    if (!NIL_P(result) && Qfalse != result) {
      return true;
    }
  }
  return false;
}

static VALUE ruby_whisper_params_allocate(VALUE klass) {
  ruby_whisper_params *rwp;
  rwp = ALLOC(ruby_whisper_params);

@@ -230,8 +296,25 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
  return self;
}

// High level API
static VALUE rb_whisper_segment_initialize(VALUE context, int index);
static void register_callbacks(ruby_whisper_params * rwp, VALUE * self) {
  if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
    rwp->new_segment_callback_container->context = self;
    rwp->params.new_segment_callback = new_segment_callback;
    rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
  }

  if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
    rwp->progress_callback_container->context = self;
    rwp->params.progress_callback = progress_callback;
    rwp->params.progress_callback_user_data = rwp->progress_callback_container;
  }

  if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
    rwp->abort_callback_container->context = self;
    rwp->params.abort_callback = abort_callback;
    rwp->params.abort_callback_user_data = rwp->abort_callback_container;
  }
}
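
For orientation, this is roughly how the consolidated `register_callbacks` path is driven from Ruby; the rest of the hunk continues below. `on_new_segment` appears in the test file later in this diff, while `on_progress` and `abort_on` are assumed names for the other two containers, so treat this as a sketch rather than the gem's documented API:

```ruby
params  = Whisper::Params.new
whisper = Whisper::Context.new("base.en")

# Block callbacks accumulate in the container's `callbacks` array; each new
# segment is wrapped in a Whisper::Segment before the block is invoked.
params.on_new_segment do |segment|
  puts "#{segment.start_time}: #{segment.text}"
end

params.on_progress { |progress| warn "#{progress}%" }  # assumed name
params.abort_on { |user_data| false }                  # assumed name; truthy return aborts decoding

whisper.transcribe("samples/jfk.wav", params)
```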

/*
 * transcribe a single file

@@ -353,80 +436,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
    rwp->params.encoder_begin_callback_user_data = &is_aborted;
  }

  if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
    rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
      const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;

      // Currently, doesn't support state because
      // those require to resolve GC-related problems.
      if (!NIL_P(container->callback)) {
        rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
      }
      const long callbacks_len = RARRAY_LEN(container->callbacks);
      if (0 == callbacks_len) {
        return;
      }
      const int n_segments = whisper_full_n_segments_from_state(state);
      for (int i = n_new; i > 0; i--) {
        int i_segment = n_segments - i;
        VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
        for (int j = 0; j < callbacks_len; j++) {
          VALUE cb = rb_ary_entry(container->callbacks, j);
          rb_funcall(cb, id_call, 1, segment);
        }
      }
    };
    rwp->new_segment_callback_container->context = &self;
    rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
  }

  if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
    rwp->params.progress_callback = [](struct whisper_context *ctx, struct whisper_state * /*state*/, int progress_cur, void *user_data) {
      const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
      const VALUE progress = INT2NUM(progress_cur);
      // Currently, doesn't support state because
      // those require to resolve GC-related problems.
      if (!NIL_P(container->callback)) {
        rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
      }
      const long callbacks_len = RARRAY_LEN(container->callbacks);
      if (0 == callbacks_len) {
        return;
      }
      for (int j = 0; j < callbacks_len; j++) {
        VALUE cb = rb_ary_entry(container->callbacks, j);
        rb_funcall(cb, id_call, 1, progress);
      }
    };
    rwp->progress_callback_container->context = &self;
    rwp->params.progress_callback_user_data = rwp->progress_callback_container;
  }

  if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
    rwp->params.abort_callback = [](void * user_data) {
      const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
      if (!NIL_P(container->callback)) {
        VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
        if (!NIL_P(result) && Qfalse != result) {
          return true;
        }
      }
      const long callbacks_len = RARRAY_LEN(container->callbacks);
      if (0 == callbacks_len) {
        return false;
      }
      for (int j = 0; j < callbacks_len; j++) {
        VALUE cb = rb_ary_entry(container->callbacks, j);
        VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
        if (!NIL_P(result) && Qfalse != result) {
          return true;
        }
      }
      return false;
    };
    rwp->abort_callback_container->context = &self;
    rwp->params.abort_callback_user_data = rwp->abort_callback_container;
  }
  register_callbacks(rwp, &self);

  if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
    fprintf(stderr, "failed to process audio\n");

@@ -631,6 +641,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) {
      }
    }
  }
  register_callbacks(rwp, &self);
  const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
  if (0 == result) {
    return Qnil;

@@ -719,6 +730,7 @@ static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) {
      }
    }
  }
  register_callbacks(rwp, &self);
  const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
  if (0 == result) {
    return Qnil;

@@ -823,6 +835,18 @@ static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) {
  return rb_str_new2(text);
}

/*
 * call-seq:
 *   full_get_segment_no_speech_prob -> Float
 */
static VALUE ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment) {
  ruby_whisper *rw;
  Data_Get_Struct(self, ruby_whisper, rw);
  const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
  const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment);
  return DBL2NUM(no_speech_prob);
}

/*
 * params.language = "auto" | "en", etc...
 *
|
||||
}
|
||||
/*
|
||||
* call-seq:
|
||||
* suppress_non_speech_tokens = force_suppress -> force_suppress
|
||||
* suppress_nst = force_suppress -> force_suppress
|
||||
*/
|
||||
static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value)
|
||||
static VALUE ruby_whisper_params_set_suppress_nst(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, suppress_nst, value)
|
||||
}
|
||||
/*
|
||||
* If true, suppresses non-speech-tokens.
|
||||
*
|
||||
* call-seq:
|
||||
* suppress_non_speech_tokens -> bool
|
||||
* suppress_nst -> bool
|
||||
*/
|
||||
static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens)
|
||||
static VALUE ruby_whisper_params_get_suppress_nst(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, suppress_nst)
|
||||
}
|
||||
/*
|
||||
* If true, enables token-level timestamps.
|
||||

@@ -1547,6 +1571,18 @@ static VALUE ruby_whisper_segment_get_text(VALUE self) {
  return rb_str_new2(text);
}

/*
 * call-seq:
 *   no_speech_prob -> Float
 */
static VALUE ruby_whisper_segment_get_no_speech_prob(VALUE self) {
  ruby_whisper_segment *rws;
  Data_Get_Struct(self, ruby_whisper_segment, rws);
  ruby_whisper *rw;
  Data_Get_Struct(rws->context, ruby_whisper, rw);
  return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index));
}
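
A short sketch of how the new accessor reads from Ruby once the methods are registered later in this diff (the 0.5 threshold is arbitrary, for illustration only); the remaining context of this hunk follows:

```ruby
whisper.each_segment do |segment|
  # Skip segments the decoder itself rates as probably not speech.
  next if segment.no_speech_prob > 0.5
  puts segment.text
end

# The same value is also available on the context by segment index:
puts whisper.full_get_segment_no_speech_prob(0)
```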

static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
  rb_gc_mark(rwm->context);
}

@@ -1809,6 +1845,7 @@ void Init_whisper() {
  rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
  rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
  rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
  rb_define_method(cContext, "full_get_segment_no_speech_prob", ruby_whisper_full_get_segment_no_speech_prob, 1);
  rb_define_method(cContext, "full", ruby_whisper_full, -1);
  rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);

@@ -1832,8 +1869,8 @@ void Init_whisper() {
  rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1);
  rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0);
  rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1);
  rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0);
  rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1);
  rb_define_method(cParams, "suppress_nst", ruby_whisper_params_get_suppress_nst, 0);
  rb_define_method(cParams, "suppress_nst=", ruby_whisper_params_set_suppress_nst, 1);
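
After this change only the shortened name is exposed to Ruby; a minimal before/after sketch (the rest of the hunk continues below):

```ruby
params = Whisper::Params.new
params.suppress_nst = true   # replaces params.suppress_non_speech_tokens = true
params.suppress_nst          # => true
```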

  rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0);
  rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
  rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);

@@ -1887,6 +1924,7 @@ void Init_whisper() {
  rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
  rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
  rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
  rb_define_method(cSegment, "no_speech_prob", ruby_whisper_segment_get_no_speech_prob, 0);

  cModel = rb_define_class_under(mWhisper, "Model", rb_cObject);
  rb_define_alloc_func(cModel, ruby_whisper_model_allocate);

@@ -79,30 +79,36 @@ class Whisper::Model
        downloaded += chunk.bytesize
        show_progress downloaded, size
      end
      $stderr.puts
    end
    downloading_path.rename path
  end

  def show_progress(current, size)
    return unless $stderr.tty?
    return unless size
    progress_rate_available = size && $stderr.tty?

    unless @prev
      @prev = Time.now
      $stderr.puts "Downloading #{@uri}"
      $stderr.puts "Downloading #{@uri} to #{cache_path}"
    end

    now = Time.now
    return if now - @prev < 1 && current < size

    progress_width = 20
    progress = current.to_f / size
    arrow_length = progress * progress_width
    arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length)
    line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
    padding = ' ' * ($stderr.winsize[1] - line.size)
    $stderr.print "\r#{line}#{padding}"
    $stderr.puts if current >= size
    if progress_rate_available
      return if now - @prev < 1 && current < size

      progress_width = 20
      progress = current.to_f / size
      arrow_length = progress * progress_width
      arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length)
      line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
      padding = ' ' * ($stderr.winsize[1] - line.size)
      $stderr.print "\r#{line}#{padding}"
    else
      return if now - @prev < 1

      $stderr.print "."
    end
    @prev = now
  end

@@ -4,4 +4,21 @@ require_relative "jfk_reader/jfk_reader"

class TestBase < Test::Unit::TestCase
  AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")

  class << self
    attr_reader :whisper

    def startup
      @whisper = Whisper::Context.new("base.en")
      params = Whisper::Params.new
      params.print_timestamps = false
      @whisper.transcribe(TestBase::AUDIO, params)
    end
  end

  private

  def whisper
    self.class.whisper
  end
end

@@ -23,7 +23,7 @@ class TestPackage < TestBase
    version = match_data[2]
    basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
    Dir.mktmpdir do |dir|
      system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true
      system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{filename.shellescape}", exception: true
      assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", basename)
    end
  end

@@ -89,11 +89,11 @@ class TestParams < TestBase
    assert !@params.suppress_blank
  end

  def test_suppress_non_speech_tokens
    @params.suppress_non_speech_tokens = true
    assert @params.suppress_non_speech_tokens
    @params.suppress_non_speech_tokens = false
    assert !@params.suppress_non_speech_tokens
  def test_suppress_nst
    @params.suppress_nst = true
    assert @params.suppress_nst
    @params.suppress_nst = false
    assert !@params.suppress_nst
  end

  def test_token_timestamps

@@ -1,17 +1,6 @@
require_relative "helper"

class TestSegment < TestBase
  class << self
    attr_reader :whisper

    def startup
      @whisper = Whisper::Context.new("base.en")
      params = Whisper::Params.new
      params.print_timestamps = false
      @whisper.transcribe(TestBase::AUDIO, params)
    end
  end

  def test_iteration
    whisper.each_segment do |segment|
      assert_instance_of Whisper::Segment, segment

@@ -43,6 +32,14 @@ class TestSegment < TestBase
    end
  end

  def test_no_speech_prob
    no_speech_prob = nil
    whisper.each_segment do |segment|
      no_speech_prob = segment.no_speech_prob
    end
    assert no_speech_prob > 0.0
  end

  def test_on_new_segment
    params = Whisper::Params.new
    seg = nil

@@ -74,10 +71,4 @@ class TestSegment < TestBase
    end
    whisper.transcribe(AUDIO, params)
  end

  private

  def whisper
    self.class.whisper
  end
end

@@ -21,21 +21,6 @@ class TestWhisper < TestBase
  end

  sub_test_case "After transcription" do
    class << self
      attr_reader :whisper

      def startup
        @whisper = Whisper::Context.new("base.en")
        params = Whisper::Params.new
        params.print_timestamps = false
        @whisper.transcribe(TestBase::AUDIO, params)
      end
    end

    def whisper
      self.class.whisper
    end

    def test_full_n_segments
      assert_equal 1, whisper.full_n_segments
    end

@@ -70,6 +55,12 @@ class TestWhisper < TestBase
    def test_full_get_segment_text
      assert_match /ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0)
    end

    def test_full_get_segment_no_speech_prob
      prob = whisper.full_get_segment_no_speech_prob(0)
      assert prob > 0.0
      assert prob < 1.0
    end
  end

  def test_lang_max_id

@@ -13,5 +13,4 @@ set_target_properties(${TARGET}
    PROPERTIES
        EXPORT_COMPILE_COMMANDS ON
        RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
        INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib"
)

@@ -181,7 +181,7 @@ static json unguided_transcription(struct whisper_context * ctx, audio_async &au
    wparams.n_threads = params.n_threads;

    wparams.audio_ctx = params.audio_ctx;
    wparams.suppress_non_speech_tokens = true;
    wparams.suppress_nst = true;
    // run the transformer and a single decoding pass
    if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
        fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);

@@ -225,7 +225,7 @@ static json guided_transcription(struct whisper_context * ctx, audio_async &audi
    wparams.prompt_tokens = cs.prompt_tokens.data();
    wparams.prompt_n_tokens = cs.prompt_tokens.size();
    // TODO: properly expose as option
    wparams.suppress_non_speech_tokens = true;
    wparams.suppress_nst = true;

    // run the transformer and a single decoding pass
    if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {

@@ -61,6 +61,7 @@ struct whisper_params {
    float logprob_thold = -1.00f;
    float temperature = 0.00f;
    float temperature_inc = 0.20f;
    float no_speech_thold = 0.6f;

    bool debug_mode = false;
    bool translate = false;

@@ -76,6 +77,7 @@ struct whisper_params {
    bool no_timestamps = false;
    bool use_gpu = true;
    bool flash_attn = false;
    bool suppress_nst = false;

    std::string language = "en";
    std::string prompt = "";

@@ -134,7 +136,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
    fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str());
    fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str());
    fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str());
    fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
    fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server\n", sparams.ffmpeg_converter ? "true" : "false");
    fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
    fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold);
    fprintf(stderr, "\n");
}

@@ -179,6 +183,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
    else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
    else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
    else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
    else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
    else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); }

    // server params
    else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
    else if ( arg == "--host") { sparams.hostname = argv[++i]; }

@@ -472,6 +479,14 @@ void get_req_parameters(const Request & req, whisper_params & params)
    {
        params.temperature_inc = std::stof(req.get_file_value("temperature_inc").content);
    }
    if (req.has_file("suppress_non_speech"))
    {
        params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_non_speech").content);
    }
    if (req.has_file("suppress_nst"))
    {
        params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_nst").content);
    }
}

} // namespace

@@ -779,6 +794,7 @@ int main(int argc, char ** argv) {
        wparams.beam_search.beam_size = params.beam_size;

        wparams.temperature = params.temperature;
        wparams.no_speech_thold = params.no_speech_thold;
        wparams.temperature_inc = params.temperature_inc;
        wparams.entropy_thold = params.entropy_thold;
        wparams.logprob_thold = params.logprob_thold;

@@ -786,6 +802,8 @@ int main(int argc, char ** argv) {
        wparams.no_timestamps = params.no_timestamps;
        wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format;

        wparams.suppress_nst = params.suppress_nst;

        whisper_print_user_data user_data = { &params, &pcmf32s, 0 };

        // this callback is called on each new segment

@@ -929,7 +947,7 @@ int main(int argc, char ** argv) {

                // TODO compression_ratio and no_speech_prob are not implemented yet
                // segment["compression_ratio"] = 0;
                // segment["no_speech_prob"] = 0;
                segment["no_speech_prob"] = whisper_full_get_segment_no_speech_prob(ctx, i);

                jres["segments"].push_back(segment);
            }
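
A hedged client-side sketch of the new server form parameter and response field, written with Ruby's Net::HTTP. The `/inference` route, default port, and the `file`/`response_format` field names are assumptions based on the server example's usual defaults; only `suppress_nst`, `segments`, and `no_speech_prob` come from this diff:

```ruby
require "net/http"
require "json"

uri = URI("http://127.0.0.1:8080/inference")  # assumed default host/port/route
req = Net::HTTP::Post.new(uri)

File.open("samples/jfk.wav", "rb") do |audio|
  req.set_form(
    [
      ["file",            audio],
      ["suppress_nst",    "true"],          # new form parameter parsed above
      ["response_format", "verbose_json"],  # segments then carry no_speech_prob
    ],
    "multipart/form-data"
  )
  res = Net::HTTP.start(uri.hostname, uri.port) { |http| http.request(req) }
  JSON.parse(res.body).fetch("segments", []).each do |seg|
    puts "#{seg["no_speech_prob"]} #{seg["text"]}"
  end
end
```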

@@ -522,8 +522,8 @@ extern "C" {
        bool detect_language;

        // common decoding parameters:
        bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
        bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
        bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
        bool suppress_nst; // non-speech tokens, ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253

        float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
        float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97

@@ -665,6 +665,8 @@ extern "C" {

    WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data);

    // Get the no_speech probability for the specified segment
    WHISPER_API float whisper_full_get_segment_no_speech_prob (struct whisper_context * ctx, int i_segment);
#ifdef __cplusplus
}
#endif

@@ -428,6 +428,7 @@ struct whisper_segment {
    int64_t t1;

    std::string text;
    float no_speech_prob;

    std::vector<whisper_token_data> tokens;

@@ -4676,7 +4677,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
        /*.detect_language =*/ false,

        /*.suppress_blank =*/ true,
        /*.suppress_non_speech_tokens =*/ false,
        /*.suppress_nst =*/ false,

        /*.temperature =*/ 0.0f,
        /*.max_initial_ts =*/ 1.0f,

@@ -4960,7 +4961,7 @@ static void whisper_process_logits(

            // suppress non-speech tokens
            // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
            if (params.suppress_non_speech_tokens) {
            if (params.suppress_nst) {
                for (const std::string & token : non_speech_tokens) {
                    const std::string suppress_tokens[] = {token, " " + token};
                    for (const std::string & suppress_token : suppress_tokens) {

@@ -6147,7 +6148,7 @@ int whisper_full_with_state(

                //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);

                result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
                result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
                for (int j = i0; j <= i; j++) {
                    result_all.back().tokens.push_back(tokens_cur[j]);
                }

@@ -6192,7 +6193,7 @@ int whisper_full_with_state(
                }
            }

            result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next });
            result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
            for (int j = i0; j < (int) tokens_cur.size(); j++) {
                result_all.back().tokens.push_back(tokens_cur[j]);
            }

@@ -6459,6 +6460,10 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
    return ctx->state->result_all[i_segment].tokens[i_token].p;
}

float whisper_full_get_segment_no_speech_prob(struct whisper_context * ctx, int i_segment) {
    return ctx->state->result_all[i_segment].no_speech_prob;
}

// =================================================================================================

//