From 7dfc11843c8f9b1951360895a934c7188c15919a Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sun, 25 Jun 2023 19:07:55 +0800 Subject: [PATCH] go : improve progress reporting and callback handling (#1024) - Rename `cb` to `callNewSegment` in the `Process` function - Add `callProgress` as a new parameter to the `Process` function - Introduce `ProgressCallback` type for reporting progress during processing - Update `Whisper_full` function to include `progressCallback` parameter - Add `registerProgressCallback` function and `cbProgress` map for handling progress callbacks Signed-off-by: appleboy --- bindings/go/Makefile | 2 +- bindings/go/pkg/whisper/context.go | 18 +++++++++---- bindings/go/pkg/whisper/interface.go | 6 ++++- bindings/go/whisper.go | 38 +++++++++++++++++++++++++++- bindings/go/whisper_test.go | 2 +- 5 files changed, 57 insertions(+), 9 deletions(-) diff --git a/bindings/go/Makefile b/bindings/go/Makefile index 6be29799..74118262 100644 --- a/bindings/go/Makefile +++ b/bindings/go/Makefile @@ -32,7 +32,7 @@ mkdir: modtidy: @go mod tidy -clean: +clean: @echo Clean @rm -fr $(BUILD_DIR) @go clean diff --git a/bindings/go/pkg/whisper/context.go b/bindings/go/pkg/whisper/context.go index 593b32b3..e193832e 100644 --- a/bindings/go/pkg/whisper/context.go +++ b/bindings/go/pkg/whisper/context.go @@ -152,7 +152,11 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f } // Process new sample data and return any errors -func (context *context) Process(data []float32, cb SegmentCallback) error { +func (context *context) Process( + data []float32, + callNewSegment SegmentCallback, + callProgress ProgressCallback, +) error { if context.model.ctx == nil { return ErrInternalAppError } @@ -165,24 +169,28 @@ func (context *context) Process(data []float32, cb SegmentCallback) error { processors := 0 if processors > 1 { if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) { - if cb != nil { + if callNewSegment != nil { num_segments := context.model.ctx.Whisper_full_n_segments() s0 := num_segments - new for i := s0; i < num_segments; i++ { - cb(toSegment(context.model.ctx, i)) + callNewSegment(toSegment(context.model.ctx, i)) } } }); err != nil { return err } } else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) { - if cb != nil { + if callNewSegment != nil { num_segments := context.model.ctx.Whisper_full_n_segments() s0 := num_segments - new for i := s0; i < num_segments; i++ { - cb(toSegment(context.model.ctx, i)) + callNewSegment(toSegment(context.model.ctx, i)) } } + }, func(progress int) { + if callProgress != nil { + callProgress(progress) + } }); err != nil { return err } diff --git a/bindings/go/pkg/whisper/interface.go b/bindings/go/pkg/whisper/interface.go index e65fed17..dc9c66df 100644 --- a/bindings/go/pkg/whisper/interface.go +++ b/bindings/go/pkg/whisper/interface.go @@ -12,6 +12,10 @@ import ( // time. It is called during the Process function type SegmentCallback func(Segment) +// ProgressCallback is the callback function for reporting progress during +// processing. It is called during the Process function +type ProgressCallback func(int) + // Model is the interface to a whisper model. Create a new model with the // function whisper.New(string) type Model interface { @@ -47,7 +51,7 @@ type Context interface { // Process mono audio data and return any errors. // If defined, newly generated segments are passed to the // callback function during processing. - Process([]float32, SegmentCallback) error + Process([]float32, SegmentCallback, ProgressCallback) error // After process is called, return segments until the end of the stream // is reached, when io.EOF is returned. diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go index babadf00..451f3f8d 100644 --- a/bindings/go/whisper.go +++ b/bindings/go/whisper.go @@ -15,6 +15,7 @@ import ( #include extern void callNewSegment(void* user_data, int new); +extern void callProgress(void* user_data, int progress); extern bool callEncoderBegin(void* user_data); // Text segment callback @@ -26,6 +27,15 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_s } } +// Progress callback +// Called on every newly generated text segment +// Use the whisper_full_...() functions to obtain the text segments +static void whisper_progress_cb(struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) { + if(user_data != NULL && ctx != NULL) { + callProgress(user_data, progress); + } +} + // Encoder begin callback // If not NULL, called before the encoder starts // If it returns false, the computation is aborted @@ -43,6 +53,8 @@ static struct whisper_full_params whisper_full_default_params_cb(struct whisper_ params.new_segment_callback_user_data = (void*)(ctx); params.encoder_begin_callback = whisper_encoder_begin_cb; params.encoder_begin_callback_user_data = (void*)(ctx); + params.progress_callback = whisper_progress_cb; + params.progress_callback_user_data = (void*)(ctx); return params; } */ @@ -290,11 +302,19 @@ func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Param // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text // Uses the specified decoding strategy to obtain the text. -func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error { +func (ctx *Context) Whisper_full( + params Params, + samples []float32, + encoderBeginCallback func() bool, + newSegmentCallback func(int), + progressCallback func(int), +) error { registerEncoderBeginCallback(ctx, encoderBeginCallback) registerNewSegmentCallback(ctx, newSegmentCallback) + registerProgressCallback(ctx, progressCallback) defer registerEncoderBeginCallback(ctx, nil) defer registerNewSegmentCallback(ctx, nil) + defer registerProgressCallback(ctx, nil) if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 { return nil } else { @@ -370,6 +390,7 @@ func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 { var ( cbNewSegment = make(map[unsafe.Pointer]func(int)) + cbProgress = make(map[unsafe.Pointer]func(int)) cbEncoderBegin = make(map[unsafe.Pointer]func() bool) ) @@ -381,6 +402,14 @@ func registerNewSegmentCallback(ctx *Context, fn func(int)) { } } +func registerProgressCallback(ctx *Context, fn func(int)) { + if fn == nil { + delete(cbProgress, unsafe.Pointer(ctx)) + } else { + cbProgress[unsafe.Pointer(ctx)] = fn + } +} + func registerEncoderBeginCallback(ctx *Context, fn func() bool) { if fn == nil { delete(cbEncoderBegin, unsafe.Pointer(ctx)) @@ -396,6 +425,13 @@ func callNewSegment(user_data unsafe.Pointer, new C.int) { } } +//export callProgress +func callProgress(user_data unsafe.Pointer, progress C.int) { + if fn, ok := cbProgress[user_data]; ok { + fn(int(progress)) + } +} + //export callEncoderBegin func callEncoderBegin(user_data unsafe.Pointer) C.bool { if fn, ok := cbEncoderBegin[user_data]; ok { diff --git a/bindings/go/whisper_test.go b/bindings/go/whisper_test.go index 2c95c81f..40648ffa 100644 --- a/bindings/go/whisper_test.go +++ b/bindings/go/whisper_test.go @@ -52,7 +52,7 @@ func Test_Whisper_001(t *testing.T) { defer ctx.Whisper_free() params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY) data := buf.AsFloat32Buffer().Data - err = ctx.Whisper_full(params, data, nil, nil) + err = ctx.Whisper_full(params, data, nil, nil, nil) assert.NoError(err) // Print out tokens