Compare commits

...

6 Commits

Author SHA1 Message Date
53c9a3a984 cmake : remove hardcoded install rpath 2024-12-23 21:22:10 +02:00
ed09075ca0 server : fix help print 2024-12-22 15:32:05 +02:00
f07a81aa9f ruby : bug fix on callbacks and no_speech_prob (#2656)
* Don't generate documentation on test

* Move .startup to TestBase class

* Extract new_segment_callback as a function

* Extract progress_callback as a function

* Extract abort_callback as a function

* Extract register_callbacks as a function

* Call callbacks in Whiser::Context#full and #full_parallel

* Fix README

* Care about the cases content-size is nil and TTY is not available

* Add tests for no_speech_prob

* Add Whisper::Context#full_get_segment_no_speech_prob and Whisper::Segment#no_speech_prob
2024-12-21 21:52:06 +02:00
4183517076 server : add no-speech threshold parameter and functionality (#2654) 2024-12-21 17:00:08 +02:00
f4668169a0 whisper : rename suppress_non_speech_tokens to suppress_nst (#2653) 2024-12-21 12:54:35 +02:00
944ce49439 server : add option to suppress non-speech tokens (#2649)
* The parameter will suppress non-speech tokens like [LAUGH], [SIGH], etc. from the output when enabled.

* add to whisper_params_parse

* add missing param
2024-12-21 12:05:05 +02:00
14 changed files with 217 additions and 150 deletions

View File

@ -181,11 +181,11 @@ public class WhisperFullParams extends Structure {
}
/** Flag to suppress non-speech tokens. */
public CBool suppress_non_speech_tokens;
public CBool suppress_nst;
/** Flag to suppress non-speech tokens. */
public void suppressNonSpeechTokens(boolean enable) {
suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE;
suppress_nst = enable ? CBool.TRUE : CBool.FALSE;
}
/** Initial decoding temperature. */
@ -315,7 +315,7 @@ public class WhisperFullParams extends Structure {
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx",
"tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
"suppress_blank", "suppress_nst", "temperature", "max_initial_ts", "length_penalty",
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
"new_segment_callback", "new_segment_callback_user_data",
"progress_callback", "progress_callback_user_data",

View File

@ -63,7 +63,7 @@ whisper = Whisper::Context.new("base.en")
You can see the list of prepared model names by `Whisper::Model.preconverted_models.keys`:
```ruby
puts Whisper::Model.preconverted_model_names
puts Whisper::Model.preconverted_models.keys
# tiny
# tiny.en
# tiny-q5_1
@ -220,7 +220,7 @@ whisper.each_segment do |segment|
end
```
The second argument `samples` may be an array, an object with `length` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
License
-------

View File

@ -53,6 +53,9 @@ static ID id_pre_converted_models;
static bool is_log_callback_finalized = false;
// High level API
static VALUE rb_whisper_segment_initialize(VALUE context, int index);
/*
* call-seq:
* lang_max_id -> Integer
@ -187,6 +190,69 @@ static ruby_whisper_callback_container * rb_whisper_callback_container_allocate(
return container;
}
static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
// Currently, doesn't support state because
// those require to resolve GC-related problems.
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return;
}
const int n_segments = whisper_full_n_segments_from_state(state);
for (int i = n_new; i > 0; i--) {
int i_segment = n_segments - i;
VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, segment);
}
}
}
static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
const VALUE progress = INT2NUM(progress_cur);
// Currently, doesn't support state because
// those require to resolve GC-related problems.
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return;
}
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, progress);
}
}
static bool abort_callback(void * user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!NIL_P(container->callback)) {
VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
return true;
}
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return false;
}
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
return true;
}
}
return false;
}
static VALUE ruby_whisper_params_allocate(VALUE klass) {
ruby_whisper_params *rwp;
rwp = ALLOC(ruby_whisper_params);
@ -230,8 +296,25 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
return self;
}
// High level API
static VALUE rb_whisper_segment_initialize(VALUE context, int index);
static void register_callbacks(ruby_whisper_params * rwp, VALUE * self) {
if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
rwp->new_segment_callback_container->context = self;
rwp->params.new_segment_callback = new_segment_callback;
rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
}
if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
rwp->progress_callback_container->context = self;
rwp->params.progress_callback = progress_callback;
rwp->params.progress_callback_user_data = rwp->progress_callback_container;
}
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
rwp->abort_callback_container->context = self;
rwp->params.abort_callback = abort_callback;
rwp->params.abort_callback_user_data = rwp->abort_callback_container;
}
}
/*
* transcribe a single file
@ -353,80 +436,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
rwp->params.encoder_begin_callback_user_data = &is_aborted;
}
if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
// Currently, doesn't support state because
// those require to resolve GC-related problems.
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return;
}
const int n_segments = whisper_full_n_segments_from_state(state);
for (int i = n_new; i > 0; i--) {
int i_segment = n_segments - i;
VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, segment);
}
}
};
rwp->new_segment_callback_container->context = &self;
rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
}
if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
rwp->params.progress_callback = [](struct whisper_context *ctx, struct whisper_state * /*state*/, int progress_cur, void *user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
const VALUE progress = INT2NUM(progress_cur);
// Currently, doesn't support state because
// those require to resolve GC-related problems.
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return;
}
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, progress);
}
};
rwp->progress_callback_container->context = &self;
rwp->params.progress_callback_user_data = rwp->progress_callback_container;
}
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
rwp->params.abort_callback = [](void * user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!NIL_P(container->callback)) {
VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
return true;
}
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return false;
}
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
return true;
}
}
return false;
};
rwp->abort_callback_container->context = &self;
rwp->params.abort_callback_user_data = rwp->abort_callback_container;
}
register_callbacks(rwp, &self);
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
fprintf(stderr, "failed to process audio\n");
@ -631,6 +641,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) {
}
}
}
register_callbacks(rwp, &self);
const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
if (0 == result) {
return Qnil;
@ -719,6 +730,7 @@ static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) {
}
}
}
register_callbacks(rwp, &self);
const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
if (0 == result) {
return Qnil;
@ -823,6 +835,18 @@ static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) {
return rb_str_new2(text);
}
/*
* call-seq:
* full_get_segment_no_speech_prob -> Float
*/
static VALUE ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment);
return DBL2NUM(no_speech_prob);
}
/*
* params.language = "auto" | "en", etc...
*
@ -979,19 +1003,19 @@ static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
}
/*
* call-seq:
* suppress_non_speech_tokens = force_suppress -> force_suppress
* suppress_nst = force_suppress -> force_suppress
*/
static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value)
static VALUE ruby_whisper_params_set_suppress_nst(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, suppress_nst, value)
}
/*
* If true, suppresses non-speech-tokens.
*
* call-seq:
* suppress_non_speech_tokens -> bool
* suppress_nst -> bool
*/
static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) {
BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens)
static VALUE ruby_whisper_params_get_suppress_nst(VALUE self) {
BOOL_PARAMS_GETTER(self, suppress_nst)
}
/*
* If true, enables token-level timestamps.
@ -1547,6 +1571,18 @@ static VALUE ruby_whisper_segment_get_text(VALUE self) {
return rb_str_new2(text);
}
/*
* call-seq:
* no_speech_prob -> Float
*/
static VALUE ruby_whisper_segment_get_no_speech_prob(VALUE self) {
ruby_whisper_segment *rws;
Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
Data_Get_Struct(rws->context, ruby_whisper, rw);
return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index));
}
static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
rb_gc_mark(rwm->context);
}
@ -1809,6 +1845,7 @@ void Init_whisper() {
rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
rb_define_method(cContext, "full_get_segment_no_speech_prob", ruby_whisper_full_get_segment_no_speech_prob, 1);
rb_define_method(cContext, "full", ruby_whisper_full, -1);
rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
@ -1832,8 +1869,8 @@ void Init_whisper() {
rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1);
rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0);
rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1);
rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0);
rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1);
rb_define_method(cParams, "suppress_nst", ruby_whisper_params_get_suppress_nst, 0);
rb_define_method(cParams, "suppress_nst=", ruby_whisper_params_set_suppress_nst, 1);
rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0);
rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
@ -1887,6 +1924,7 @@ void Init_whisper() {
rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
rb_define_method(cSegment, "no_speech_prob", ruby_whisper_segment_get_no_speech_prob, 0);
cModel = rb_define_class_under(mWhisper, "Model", rb_cObject);
rb_define_alloc_func(cModel, ruby_whisper_model_allocate);

View File

@ -79,30 +79,36 @@ class Whisper::Model
downloaded += chunk.bytesize
show_progress downloaded, size
end
$stderr.puts
end
downloading_path.rename path
end
def show_progress(current, size)
return unless $stderr.tty?
return unless size
progress_rate_available = size && $stderr.tty?
unless @prev
@prev = Time.now
$stderr.puts "Downloading #{@uri}"
$stderr.puts "Downloading #{@uri} to #{cache_path}"
end
now = Time.now
return if now - @prev < 1 && current < size
progress_width = 20
progress = current.to_f / size
arrow_length = progress * progress_width
arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length)
line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
padding = ' ' * ($stderr.winsize[1] - line.size)
$stderr.print "\r#{line}#{padding}"
$stderr.puts if current >= size
if progress_rate_available
return if now - @prev < 1 && current < size
progress_width = 20
progress = current.to_f / size
arrow_length = progress * progress_width
arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length)
line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
padding = ' ' * ($stderr.winsize[1] - line.size)
$stderr.print "\r#{line}#{padding}"
else
return if now - @prev < 1
$stderr.print "."
end
@prev = now
end

View File

@ -4,4 +4,21 @@ require_relative "jfk_reader/jfk_reader"
class TestBase < Test::Unit::TestCase
AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
class << self
attr_reader :whisper
def startup
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
params.print_timestamps = false
@whisper.transcribe(TestBase::AUDIO, params)
end
end
private
def whisper
self.class.whisper
end
end

View File

@ -23,7 +23,7 @@ class TestPackage < TestBase
version = match_data[2]
basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
Dir.mktmpdir do |dir|
system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true
system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{filename.shellescape}", exception: true
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", basename)
end
end

View File

@ -89,11 +89,11 @@ class TestParams < TestBase
assert !@params.suppress_blank
end
def test_suppress_non_speech_tokens
@params.suppress_non_speech_tokens = true
assert @params.suppress_non_speech_tokens
@params.suppress_non_speech_tokens = false
assert !@params.suppress_non_speech_tokens
def test_suppress_nst
@params.suppress_nst = true
assert @params.suppress_nst
@params.suppress_nst = false
assert !@params.suppress_nst
end
def test_token_timestamps

View File

@ -1,17 +1,6 @@
require_relative "helper"
class TestSegment < TestBase
class << self
attr_reader :whisper
def startup
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
params.print_timestamps = false
@whisper.transcribe(TestBase::AUDIO, params)
end
end
def test_iteration
whisper.each_segment do |segment|
assert_instance_of Whisper::Segment, segment
@ -43,6 +32,14 @@ class TestSegment < TestBase
end
end
def test_no_speech_prob
no_speech_prob = nil
whisper.each_segment do |segment|
no_speech_prob = segment.no_speech_prob
end
assert no_speech_prob > 0.0
end
def test_on_new_segment
params = Whisper::Params.new
seg = nil
@ -74,10 +71,4 @@ class TestSegment < TestBase
end
whisper.transcribe(AUDIO, params)
end
private
def whisper
self.class.whisper
end
end

View File

@ -21,21 +21,6 @@ class TestWhisper < TestBase
end
sub_test_case "After transcription" do
class << self
attr_reader :whisper
def startup
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
params.print_timestamps = false
@whisper.transcribe(TestBase::AUDIO, params)
end
end
def whisper
self.class.whisper
end
def test_full_n_segments
assert_equal 1, whisper.full_n_segments
end
@ -70,6 +55,12 @@ class TestWhisper < TestBase
def test_full_get_segment_text
assert_match /ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0)
end
def test_full_get_segment_no_speech_prob
prob = whisper.full_get_segment_no_speech_prob(0)
assert prob > 0.0
assert prob < 1.0
end
end
def test_lang_max_id

View File

@ -13,5 +13,4 @@ set_target_properties(${TARGET}
PROPERTIES
EXPORT_COMPILE_COMMANDS ON
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib"
)

View File

@ -181,7 +181,7 @@ static json unguided_transcription(struct whisper_context * ctx, audio_async &au
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
wparams.suppress_non_speech_tokens = true;
wparams.suppress_nst = true;
// run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
@ -225,7 +225,7 @@ static json guided_transcription(struct whisper_context * ctx, audio_async &audi
wparams.prompt_tokens = cs.prompt_tokens.data();
wparams.prompt_n_tokens = cs.prompt_tokens.size();
// TODO: properly expose as option
wparams.suppress_non_speech_tokens = true;
wparams.suppress_nst = true;
// run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {

View File

@ -61,6 +61,7 @@ struct whisper_params {
float logprob_thold = -1.00f;
float temperature = 0.00f;
float temperature_inc = 0.20f;
float no_speech_thold = 0.6f;
bool debug_mode = false;
bool translate = false;
@ -76,6 +77,7 @@ struct whisper_params {
bool no_timestamps = false;
bool use_gpu = true;
bool flash_attn = false;
bool suppress_nst = false;
std::string language = "en";
std::string prompt = "";
@ -134,7 +136,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str());
fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str());
fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str());
fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server\n", sparams.ffmpeg_converter ? "true" : "false");
fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold);
fprintf(stderr, "\n");
}
@ -179,6 +183,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); }
// server params
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
else if ( arg == "--host") { sparams.hostname = argv[++i]; }
@ -472,6 +479,14 @@ void get_req_parameters(const Request & req, whisper_params & params)
{
params.temperature_inc = std::stof(req.get_file_value("temperature_inc").content);
}
if (req.has_file("suppress_non_speech"))
{
params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_non_speech").content);
}
if (req.has_file("suppress_nst"))
{
params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_nst").content);
}
}
} // namespace
@ -779,6 +794,7 @@ int main(int argc, char ** argv) {
wparams.beam_search.beam_size = params.beam_size;
wparams.temperature = params.temperature;
wparams.no_speech_thold = params.no_speech_thold;
wparams.temperature_inc = params.temperature_inc;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
@ -786,6 +802,8 @@ int main(int argc, char ** argv) {
wparams.no_timestamps = params.no_timestamps;
wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format;
wparams.suppress_nst = params.suppress_nst;
whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
// this callback is called on each new segment
@ -929,7 +947,7 @@ int main(int argc, char ** argv) {
// TODO compression_ratio and no_speech_prob are not implemented yet
// segment["compression_ratio"] = 0;
// segment["no_speech_prob"] = 0;
segment["no_speech_prob"] = whisper_full_get_segment_no_speech_prob(ctx, i);
jres["segments"].push_back(segment);
}

View File

@ -522,8 +522,8 @@ extern "C" {
bool detect_language;
// common decoding parameters:
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
bool suppress_nst; // non-speech tokens, ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
@ -665,6 +665,8 @@ extern "C" {
WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data);
// Get the no_speech probability for the specified segment
WHISPER_API float whisper_full_get_segment_no_speech_prob (struct whisper_context * ctx, int i_segment);
#ifdef __cplusplus
}
#endif

View File

@ -428,6 +428,7 @@ struct whisper_segment {
int64_t t1;
std::string text;
float no_speech_prob;
std::vector<whisper_token_data> tokens;
@ -4676,7 +4677,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.detect_language =*/ false,
/*.suppress_blank =*/ true,
/*.suppress_non_speech_tokens =*/ false,
/*.suppress_nst =*/ false,
/*.temperature =*/ 0.0f,
/*.max_initial_ts =*/ 1.0f,
@ -4960,7 +4961,7 @@ static void whisper_process_logits(
// suppress non-speech tokens
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
if (params.suppress_non_speech_tokens) {
if (params.suppress_nst) {
for (const std::string & token : non_speech_tokens) {
const std::string suppress_tokens[] = {token, " " + token};
for (const std::string & suppress_token : suppress_tokens) {
@ -6147,7 +6148,7 @@ int whisper_full_with_state(
//printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
@ -6192,7 +6193,7 @@ int whisper_full_with_state(
}
}
result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next });
result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
@ -6459,6 +6460,10 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
return ctx->state->result_all[i_segment].tokens[i_token].p;
}
float whisper_full_get_segment_no_speech_prob(struct whisper_context * ctx, int i_segment) {
return ctx->state->result_all[i_segment].no_speech_prob;
}
// =================================================================================================
//