mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-11-07 08:34:37 +01:00
ruby : add more APIs (#2518)
* Add test for built package existence * Add more tests for Whisper::Params * Add more Whisper::Params attributes * Add tests for callbacks * Add progress and abort callback features * [skip ci] Add prompt usage in README * Change prompt text in example
This commit is contained in:
parent
fc49ee4479
commit
d4bc413505
@ -31,6 +31,7 @@ params.duration = 60_000
|
|||||||
params.max_text_tokens = 300
|
params.max_text_tokens = 300
|
||||||
params.translate = true
|
params.translate = true
|
||||||
params.print_timestamps = false
|
params.print_timestamps = false
|
||||||
|
params.prompt = "Initial prompt here."
|
||||||
|
|
||||||
whisper.transcribe("path/to/audio.wav", params) do |whole_text|
|
whisper.transcribe("path/to/audio.wav", params) do |whole_text|
|
||||||
puts whole_text
|
puts whole_text
|
||||||
|
@ -107,10 +107,16 @@ void rb_whisper_free(ruby_whisper *rw) {
|
|||||||
free(rw);
|
free(rw);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) {
|
||||||
|
rb_gc_mark(rwc->user_data);
|
||||||
|
rb_gc_mark(rwc->callback);
|
||||||
|
rb_gc_mark(rwc->callbacks);
|
||||||
|
}
|
||||||
|
|
||||||
void rb_whisper_params_mark(ruby_whisper_params *rwp) {
|
void rb_whisper_params_mark(ruby_whisper_params *rwp) {
|
||||||
rb_gc_mark(rwp->new_segment_callback_container->user_data);
|
rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
|
||||||
rb_gc_mark(rwp->new_segment_callback_container->callback);
|
rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
|
||||||
rb_gc_mark(rwp->new_segment_callback_container->callbacks);
|
rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
|
||||||
}
|
}
|
||||||
|
|
||||||
void rb_whisper_params_free(ruby_whisper_params *rwp) {
|
void rb_whisper_params_free(ruby_whisper_params *rwp) {
|
||||||
@ -141,6 +147,8 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) {
|
|||||||
rwp = ALLOC(ruby_whisper_params);
|
rwp = ALLOC(ruby_whisper_params);
|
||||||
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||||
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
|
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
|
||||||
|
rwp->progress_callback_container = rb_whisper_callback_container_allocate();
|
||||||
|
rwp->abort_callback_container = rb_whisper_callback_container_allocate();
|
||||||
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
|
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -316,6 +324,54 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
|
|||||||
rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
|
rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
|
||||||
|
rwp->params.progress_callback = [](struct whisper_context *ctx, struct whisper_state * /*state*/, int progress_cur, void *user_data) {
|
||||||
|
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
|
||||||
|
const VALUE progress = INT2NUM(progress_cur);
|
||||||
|
// Currently, doesn't support state because
|
||||||
|
// those require to resolve GC-related problems.
|
||||||
|
if (!NIL_P(container->callback)) {
|
||||||
|
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
|
||||||
|
}
|
||||||
|
const long callbacks_len = RARRAY_LEN(container->callbacks);
|
||||||
|
if (0 == callbacks_len) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (int j = 0; j < callbacks_len; j++) {
|
||||||
|
VALUE cb = rb_ary_entry(container->callbacks, j);
|
||||||
|
rb_funcall(cb, id_call, 1, progress);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
rwp->progress_callback_container->context = &self;
|
||||||
|
rwp->params.progress_callback_user_data = rwp->progress_callback_container;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
|
||||||
|
rwp->params.abort_callback = [](void * user_data) {
|
||||||
|
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
|
||||||
|
if (!NIL_P(container->callback)) {
|
||||||
|
VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
|
||||||
|
if (!NIL_P(result) && Qfalse != result) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const long callbacks_len = RARRAY_LEN(container->callbacks);
|
||||||
|
if (0 == callbacks_len) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (int j = 0; j < callbacks_len; j++) {
|
||||||
|
VALUE cb = rb_ary_entry(container->callbacks, j);
|
||||||
|
VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
|
||||||
|
if (!NIL_P(result) && Qfalse != result) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
rwp->abort_callback_container->context = &self;
|
||||||
|
rwp->params.abort_callback_user_data = rwp->abort_callback_container;
|
||||||
|
}
|
||||||
|
|
||||||
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
|
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
|
||||||
fprintf(stderr, "failed to process audio\n");
|
fprintf(stderr, "failed to process audio\n");
|
||||||
return self;
|
return self;
|
||||||
@ -631,6 +687,30 @@ static VALUE ruby_whisper_params_get_split_on_word(VALUE self) {
|
|||||||
static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
|
static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
|
||||||
BOOL_PARAMS_SETTER(self, split_on_word, value)
|
BOOL_PARAMS_SETTER(self, split_on_word, value)
|
||||||
}
|
}
|
||||||
|
/*
|
||||||
|
* Tokens to provide to the whisper decoder as initial prompt
|
||||||
|
* these are prepended to any existing text context from a previous call
|
||||||
|
* use whisper_tokenize() to convert text to tokens.
|
||||||
|
* Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224).
|
||||||
|
*
|
||||||
|
* call-seq:
|
||||||
|
* initial_prompt -> String
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_get_initial_prompt(VALUE self) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
return rwp->params.initial_prompt == nullptr ? Qnil : rb_str_new2(rwp->params.initial_prompt);
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
* call-seq:
|
||||||
|
* initial_prompt = prompt -> prompt
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
rwp->params.initial_prompt = StringValueCStr(value);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
/*
|
/*
|
||||||
* If true, enables diarization.
|
* If true, enables diarization.
|
||||||
*
|
*
|
||||||
@ -725,6 +805,124 @@ static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) {
|
|||||||
rwp->params.n_max_text_ctx = NUM2INT(value);
|
rwp->params.n_max_text_ctx = NUM2INT(value);
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
/*
|
||||||
|
* call-seq:
|
||||||
|
* temperature -> Float
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_get_temperature(VALUE self) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
return DBL2NUM(rwp->params.temperature);
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
* call-seq:
|
||||||
|
* temperature = temp -> temp
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_set_temperature(VALUE self, VALUE value) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
rwp->params.temperature = RFLOAT_VALUE(value);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
* See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
|
||||||
|
*
|
||||||
|
* call-seq:
|
||||||
|
* max_initial_ts -> Flaot
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_get_max_initial_ts(VALUE self) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
return DBL2NUM(rwp->params.max_initial_ts);
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
* call-seq:
|
||||||
|
* max_initial_ts = timestamp -> timestamp
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_set_max_initial_ts(VALUE self, VALUE value) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
rwp->params.max_initial_ts = RFLOAT_VALUE(value);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
* call-seq:
|
||||||
|
* length_penalty -> Float
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_get_length_penalty(VALUE self) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
return DBL2NUM(rwp->params.length_penalty);
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
* call-seq:
|
||||||
|
* length_penalty = penalty -> penalty
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_set_length_penalty(VALUE self, VALUE value) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
rwp->params.length_penalty = RFLOAT_VALUE(value);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
* call-seq:
|
||||||
|
* temperature_inc -> Float
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_get_temperature_inc(VALUE self) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
return DBL2NUM(rwp->params.temperature_inc);
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
* call-seq:
|
||||||
|
* temperature_inc = inc -> inc
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_set_temperature_inc(VALUE self, VALUE value) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
rwp->params.temperature_inc = RFLOAT_VALUE(value);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
* Similar to OpenAI's "compression_ratio_threshold"
|
||||||
|
*
|
||||||
|
* call-seq:
|
||||||
|
* entropy_thold -> Float
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_get_entropy_thold(VALUE self) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
return DBL2NUM(rwp->params.entropy_thold);
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
* call-seq:
|
||||||
|
* entropy_thold = threshold -> threshold
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_set_entropy_thold(VALUE self, VALUE value) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
rwp->params.entropy_thold = RFLOAT_VALUE(value);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
* call-seq:
|
||||||
|
* logprob_thold -> Float
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_get_logprob_thold(VALUE self) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
return DBL2NUM(rwp->params.logprob_thold);
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
* call-seq:
|
||||||
|
* logprob_thold = threshold -> threshold
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
rwp->params.logprob_thold = RFLOAT_VALUE(value);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
/*
|
/*
|
||||||
* Sets new segment callback, called for every newly generated text segment.
|
* Sets new segment callback, called for every newly generated text segment.
|
||||||
*
|
*
|
||||||
@ -753,6 +951,62 @@ static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self,
|
|||||||
rwp->new_segment_callback_container->user_data = value;
|
rwp->new_segment_callback_container->user_data = value;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
/*
|
||||||
|
* Sets progress callback, called on each progress update.
|
||||||
|
*
|
||||||
|
* params.new_segment_callback = ->(context, _, n_new, user_data) {
|
||||||
|
* # ...
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
* call-seq:
|
||||||
|
* progress_callback = callback -> callback
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_set_progress_callback(VALUE self, VALUE value) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
rwp->progress_callback_container->callback = value;
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
* Sets user data passed to the last argument of progress callback.
|
||||||
|
*
|
||||||
|
* call-seq:
|
||||||
|
* progress_callback_user_data = user_data -> use_data
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
rwp->progress_callback_container->user_data = value;
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
* Sets abort callback, called to check if the process should be aborted.
|
||||||
|
*
|
||||||
|
* params.abort_callback = ->(user_data) {
|
||||||
|
* # ...
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
* call-seq:
|
||||||
|
* abort_callback = callback -> callback
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_set_abort_callback(VALUE self, VALUE value) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
rwp->abort_callback_container->callback = value;
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
* Sets user data passed to the last argument of abort callback.
|
||||||
|
*
|
||||||
|
* call-seq:
|
||||||
|
* abort_callback_user_data = user_data -> use_data
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value) {
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
rwp->abort_callback_container->user_data = value;
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
// High level API
|
// High level API
|
||||||
|
|
||||||
@ -835,6 +1089,46 @@ static VALUE ruby_whisper_params_on_new_segment(VALUE self) {
|
|||||||
return Qnil;
|
return Qnil;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Hook called on progress update. Yields each progress Integer between 0 and 100.
|
||||||
|
*
|
||||||
|
* whisper.on_progress do |progress|
|
||||||
|
* # ...
|
||||||
|
* end
|
||||||
|
*
|
||||||
|
* call-seq:
|
||||||
|
* on_progress {|progress| ... }
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_on_progress(VALUE self) {
|
||||||
|
ruby_whisper_params *rws;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rws);
|
||||||
|
const VALUE blk = rb_block_proc();
|
||||||
|
rb_ary_push(rws->progress_callback_container->callbacks, blk);
|
||||||
|
return Qnil;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Call block to determine whether abort or not. Return +true+ when you want to abort.
|
||||||
|
*
|
||||||
|
* params.abort_on do
|
||||||
|
* if some_condition
|
||||||
|
* true # abort
|
||||||
|
* else
|
||||||
|
* false # continue
|
||||||
|
* end
|
||||||
|
* end
|
||||||
|
*
|
||||||
|
* call-seq:
|
||||||
|
* abort_on { ... }
|
||||||
|
*/
|
||||||
|
static VALUE ruby_whisper_params_abort_on(VALUE self) {
|
||||||
|
ruby_whisper_params *rws;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rws);
|
||||||
|
const VALUE blk = rb_block_proc();
|
||||||
|
rb_ary_push(rws->abort_callback_container->callbacks, blk);
|
||||||
|
return Qnil;
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Start time in milliseconds.
|
* Start time in milliseconds.
|
||||||
*
|
*
|
||||||
@ -946,6 +1240,8 @@ void Init_whisper() {
|
|||||||
rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
|
rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
|
||||||
rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
|
rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
|
||||||
rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1);
|
rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1);
|
||||||
|
rb_define_method(cParams, "initial_prompt", ruby_whisper_params_get_initial_prompt, 0);
|
||||||
|
rb_define_method(cParams, "initial_prompt=", ruby_whisper_params_set_initial_prompt, 1);
|
||||||
rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0);
|
rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0);
|
||||||
rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1);
|
rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1);
|
||||||
|
|
||||||
@ -956,9 +1252,25 @@ void Init_whisper() {
|
|||||||
|
|
||||||
rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0);
|
rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0);
|
||||||
rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1);
|
rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1);
|
||||||
|
rb_define_method(cParams, "temperature", ruby_whisper_params_get_temperature, 0);
|
||||||
|
rb_define_method(cParams, "temperature=", ruby_whisper_params_set_temperature, 1);
|
||||||
|
rb_define_method(cParams, "max_initial_ts", ruby_whisper_params_get_max_initial_ts, 0);
|
||||||
|
rb_define_method(cParams, "max_initial_ts=", ruby_whisper_params_set_max_initial_ts, 1);
|
||||||
|
rb_define_method(cParams, "length_penalty", ruby_whisper_params_get_length_penalty, 0);
|
||||||
|
rb_define_method(cParams, "length_penalty=", ruby_whisper_params_set_length_penalty, 1);
|
||||||
|
rb_define_method(cParams, "temperature_inc", ruby_whisper_params_get_temperature_inc, 0);
|
||||||
|
rb_define_method(cParams, "temperature_inc=", ruby_whisper_params_set_temperature_inc, 1);
|
||||||
|
rb_define_method(cParams, "entropy_thold", ruby_whisper_params_get_entropy_thold, 0);
|
||||||
|
rb_define_method(cParams, "entropy_thold=", ruby_whisper_params_set_entropy_thold, 1);
|
||||||
|
rb_define_method(cParams, "logprob_thold", ruby_whisper_params_get_logprob_thold, 0);
|
||||||
|
rb_define_method(cParams, "logprob_thold=", ruby_whisper_params_set_logprob_thold, 1);
|
||||||
|
|
||||||
rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1);
|
rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1);
|
||||||
rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1);
|
rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1);
|
||||||
|
rb_define_method(cParams, "progress_callback=", ruby_whisper_params_set_progress_callback, 1);
|
||||||
|
rb_define_method(cParams, "progress_callback_user_data=", ruby_whisper_params_set_progress_callback_user_data, 1);
|
||||||
|
rb_define_method(cParams, "abort_callback=", ruby_whisper_params_set_abort_callback, 1);
|
||||||
|
rb_define_method(cParams, "abort_callback_user_data=", ruby_whisper_params_set_abort_callback_user_data, 1);
|
||||||
|
|
||||||
// High leve
|
// High leve
|
||||||
cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject);
|
cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject);
|
||||||
@ -966,6 +1278,8 @@ void Init_whisper() {
|
|||||||
rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate);
|
rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate);
|
||||||
rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
|
rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
|
||||||
rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
|
rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
|
||||||
|
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
|
||||||
|
rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
|
||||||
rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0);
|
rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0);
|
||||||
rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
|
rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
|
||||||
rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
|
rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
|
||||||
|
@ -18,6 +18,8 @@ typedef struct {
|
|||||||
struct whisper_full_params params;
|
struct whisper_full_params params;
|
||||||
bool diarize;
|
bool diarize;
|
||||||
ruby_whisper_callback_container *new_segment_callback_container;
|
ruby_whisper_callback_container *new_segment_callback_container;
|
||||||
|
ruby_whisper_callback_container *progress_callback_container;
|
||||||
|
ruby_whisper_callback_container *abort_callback_container;
|
||||||
} ruby_whisper_params;
|
} ruby_whisper_params;
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@ -5,6 +5,7 @@ class TestCallback < Test::Unit::TestCase
|
|||||||
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
|
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
|
||||||
|
|
||||||
def setup
|
def setup
|
||||||
|
GC.start
|
||||||
@params = Whisper::Params.new
|
@params = Whisper::Params.new
|
||||||
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
|
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
|
||||||
@audio = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
|
@audio = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
|
||||||
@ -73,4 +74,90 @@ class TestCallback < Test::Unit::TestCase
|
|||||||
|
|
||||||
assert_same @whisper, @whisper.transcribe(@audio, @params)
|
assert_same @whisper, @whisper.transcribe(@audio, @params)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def test_progress_callback
|
||||||
|
first = nil
|
||||||
|
last = nil
|
||||||
|
@params.progress_callback = ->(context, state, progress, user_data) {
|
||||||
|
assert_kind_of Integer, progress
|
||||||
|
assert 0 <= progress && progress <= 100
|
||||||
|
assert_same @whisper, context
|
||||||
|
first = progress if first.nil?
|
||||||
|
last = progress
|
||||||
|
}
|
||||||
|
@whisper.transcribe(@audio, @params)
|
||||||
|
assert_equal 0, first
|
||||||
|
assert_equal 100, last
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_progress_callback_user_data
|
||||||
|
udata = Object.new
|
||||||
|
@params.progress_callback_user_data = udata
|
||||||
|
@params.progress_callback = ->(context, state, n_new, user_data) {
|
||||||
|
assert_same udata, user_data
|
||||||
|
}
|
||||||
|
|
||||||
|
@whisper.transcribe(@audio, @params)
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_on_progress
|
||||||
|
first = nil
|
||||||
|
last = nil
|
||||||
|
@params.on_progress do |progress|
|
||||||
|
assert_kind_of Integer, progress
|
||||||
|
assert 0 <= progress && progress <= 100
|
||||||
|
first = progress if first.nil?
|
||||||
|
last = progress
|
||||||
|
end
|
||||||
|
@whisper.transcribe(@audio, @params)
|
||||||
|
assert_equal 0, first
|
||||||
|
assert_equal 100, last
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_abort_callback
|
||||||
|
i = 0
|
||||||
|
@params.abort_callback = ->(user_data) {
|
||||||
|
assert_nil user_data
|
||||||
|
i += 1
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
@whisper.transcribe(@audio, @params)
|
||||||
|
assert i > 0
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_abort_callback_abort
|
||||||
|
i = 0
|
||||||
|
@params.abort_callback = ->(user_data) {
|
||||||
|
i += 1
|
||||||
|
return i == 3
|
||||||
|
}
|
||||||
|
@whisper.transcribe(@audio, @params)
|
||||||
|
assert_equal 3, i
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_abort_callback_user_data
|
||||||
|
udata = Object.new
|
||||||
|
@params.abort_callback_user_data = udata
|
||||||
|
yielded = nil
|
||||||
|
@params.abort_callback = ->(user_data) {
|
||||||
|
yielded = user_data
|
||||||
|
}
|
||||||
|
@whisper.transcribe(@audio, @params)
|
||||||
|
assert_same udata, yielded
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_abort_on
|
||||||
|
do_abort = false
|
||||||
|
aborted_from_callback = false
|
||||||
|
@params.on_new_segment do |segment|
|
||||||
|
do_abort = true if segment.text.match? /ask/
|
||||||
|
end
|
||||||
|
i = 0
|
||||||
|
@params.abort_on do
|
||||||
|
i += 1
|
||||||
|
do_abort
|
||||||
|
end
|
||||||
|
@whisper.transcribe(@audio, @params)
|
||||||
|
assert i > 0
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
@ -8,6 +8,7 @@ class TestPackage < Test::Unit::TestCase
|
|||||||
Tempfile.create do |file|
|
Tempfile.create do |file|
|
||||||
assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
|
assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
|
||||||
assert file.size > 0
|
assert file.size > 0
|
||||||
|
assert_path_exist file.to_path
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
require 'test/unit'
|
||||||
require 'whisper'
|
require 'whisper'
|
||||||
|
|
||||||
class TestParams < Test::Unit::TestCase
|
class TestParams < Test::Unit::TestCase
|
||||||
@ -109,4 +110,46 @@ class TestParams < Test::Unit::TestCase
|
|||||||
@params.split_on_word = false
|
@params.split_on_word = false
|
||||||
assert !@params.split_on_word
|
assert !@params.split_on_word
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def test_initial_prompt
|
||||||
|
assert_nil @params.initial_prompt
|
||||||
|
@params.initial_prompt = "You are a polite person."
|
||||||
|
assert_equal "You are a polite person.", @params.initial_prompt
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_temperature
|
||||||
|
assert_equal 0.0, @params.temperature
|
||||||
|
@params.temperature = 0.5
|
||||||
|
assert_equal 0.5, @params.temperature
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_max_initial_ts
|
||||||
|
assert_equal 1.0, @params.max_initial_ts
|
||||||
|
@params.max_initial_ts = 600.0
|
||||||
|
assert_equal 600.0, @params.max_initial_ts
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_length_penalty
|
||||||
|
assert_equal -1.0, @params.length_penalty
|
||||||
|
@params.length_penalty = 0.5
|
||||||
|
assert_equal 0.5, @params.length_penalty
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_temperature_inc
|
||||||
|
assert_in_delta 0.2, @params.temperature_inc
|
||||||
|
@params.temperature_inc = 0.5
|
||||||
|
assert_in_delta 0.5, @params.temperature_inc
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_entropy_thold
|
||||||
|
assert_in_delta 2.4, @params.entropy_thold
|
||||||
|
@params.entropy_thold = 3.0
|
||||||
|
assert_in_delta 3.0, @params.entropy_thold
|
||||||
|
end
|
||||||
|
|
||||||
|
def test_logprob_thold
|
||||||
|
assert_in_delta -1.0, @params.logprob_thold
|
||||||
|
@params.logprob_thold = -0.5
|
||||||
|
assert_in_delta -0.5, @params.logprob_thold
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user