diff --git a/bindings/go/pkg/whisper/interface.go b/bindings/go/pkg/whisper/interface.go index 2b6a9c8e..579c8ecd 100644 --- a/bindings/go/pkg/whisper/interface.go +++ b/bindings/go/pkg/whisper/interface.go @@ -27,6 +27,7 @@ type Model interface { // Return a new speech-to-text context. NewContext() (Context, error) + NewContextWithStrategy(SamplingStrategy) (Context, error) // Return true if the model is multilingual. IsMultilingual() bool diff --git a/bindings/go/pkg/whisper/model.go b/bindings/go/pkg/whisper/model.go index 68a15022..3e9a6c18 100644 --- a/bindings/go/pkg/whisper/model.go +++ b/bindings/go/pkg/whisper/model.go @@ -20,6 +20,13 @@ type model struct { // Make sure model adheres to the interface var _ Model = (*model)(nil) +type SamplingStrategy whisper.SamplingStrategy + +const ( + SAMPLING_GREEDY SamplingStrategy = (SamplingStrategy)(whisper.SAMPLING_GREEDY) + SAMPLING_BEAM_SEARCH SamplingStrategy = (SamplingStrategy)(whisper.SAMPLING_BEAM_SEARCH) +) + /////////////////////////////////////////////////////////////////////////////// // LIFECYCLE @@ -82,12 +89,17 @@ func (model *model) Languages() []string { } func (model *model) NewContext() (Context, error) { + // By default, specify the greedy strategy + return model.NewContextWithStrategy(SAMPLING_GREEDY) +} + +func (model *model) NewContextWithStrategy(strategy SamplingStrategy) (Context, error) { if model.ctx == nil { return nil, ErrInternalAppError } // Create new context - params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY) + params := model.ctx.Whisper_full_default_params((whisper.SamplingStrategy)(strategy)) params.SetTranslate(false) params.SetPrintSpecial(false) params.SetPrintProgress(false)