mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-04-27 04:39:44 +02:00
go : add Encoder Begin Callback (#2900)
Adds an optional EncoderBeginCallback parameter to the Context's Process method. The callback returns false if computation should be aborted. Co-authored-by: Amanda Der Bedrosian <aderbedr@gmail.com>
This commit is contained in:
parent
d2aaffd5d9
commit
96db0c5a9c
@ -31,7 +31,7 @@ func main() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
if err := context.Process(samples, nil, nil); err != nil {
|
if err := context.Process(samples, nil, nil, nil); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ func Process(model whisper.Model, path string, flags *Flags) error {
|
|||||||
// Process the data
|
// Process the data
|
||||||
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
|
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
|
||||||
context.ResetTimings()
|
context.ResetTimings()
|
||||||
if err := context.Process(data, cb, nil); err != nil {
|
if err := context.Process(data, nil, cb, nil); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -189,6 +189,7 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f
|
|||||||
// Process new sample data and return any errors
|
// Process new sample data and return any errors
|
||||||
func (context *context) Process(
|
func (context *context) Process(
|
||||||
data []float32,
|
data []float32,
|
||||||
|
callEncoderBegin EncoderBeginCallback,
|
||||||
callNewSegment SegmentCallback,
|
callNewSegment SegmentCallback,
|
||||||
callProgress ProgressCallback,
|
callProgress ProgressCallback,
|
||||||
) error {
|
) error {
|
||||||
@ -203,7 +204,20 @@ func (context *context) Process(
|
|||||||
// We don't do parallel processing at the moment
|
// We don't do parallel processing at the moment
|
||||||
processors := 0
|
processors := 0
|
||||||
if processors > 1 {
|
if processors > 1 {
|
||||||
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
|
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, callEncoderBegin,
|
||||||
|
func(new int) {
|
||||||
|
if callNewSegment != nil {
|
||||||
|
num_segments := context.model.ctx.Whisper_full_n_segments()
|
||||||
|
s0 := num_segments - new
|
||||||
|
for i := s0; i < num_segments; i++ {
|
||||||
|
callNewSegment(toSegment(context.model.ctx, i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if err := context.model.ctx.Whisper_full(context.params, data, callEncoderBegin,
|
||||||
|
func(new int) {
|
||||||
if callNewSegment != nil {
|
if callNewSegment != nil {
|
||||||
num_segments := context.model.ctx.Whisper_full_n_segments()
|
num_segments := context.model.ctx.Whisper_full_n_segments()
|
||||||
s0 := num_segments - new
|
s0 := num_segments - new
|
||||||
@ -211,22 +225,11 @@ func (context *context) Process(
|
|||||||
callNewSegment(toSegment(context.model.ctx, i))
|
callNewSegment(toSegment(context.model.ctx, i))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}); err != nil {
|
}, func(progress int) {
|
||||||
return err
|
if callProgress != nil {
|
||||||
}
|
callProgress(progress)
|
||||||
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
|
|
||||||
if callNewSegment != nil {
|
|
||||||
num_segments := context.model.ctx.Whisper_full_n_segments()
|
|
||||||
s0 := num_segments - new
|
|
||||||
for i := s0; i < num_segments; i++ {
|
|
||||||
callNewSegment(toSegment(context.model.ctx, i))
|
|
||||||
}
|
}
|
||||||
}
|
}); err != nil {
|
||||||
}, func(progress int) {
|
|
||||||
if callProgress != nil {
|
|
||||||
callProgress(progress)
|
|
||||||
}
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -88,6 +88,6 @@ func TestProcess(t *testing.T) {
|
|||||||
context, err := model.NewContext()
|
context, err := model.NewContext()
|
||||||
assert.NoError(err)
|
assert.NoError(err)
|
||||||
|
|
||||||
err = context.Process(data, nil, nil)
|
err = context.Process(data, nil, nil, nil)
|
||||||
assert.NoError(err)
|
assert.NoError(err)
|
||||||
}
|
}
|
||||||
|
@ -16,6 +16,10 @@ type SegmentCallback func(Segment)
|
|||||||
// processing. It is called during the Process function
|
// processing. It is called during the Process function
|
||||||
type ProgressCallback func(int)
|
type ProgressCallback func(int)
|
||||||
|
|
||||||
|
// EncoderBeginCallback is the callback function for checking if we want to
|
||||||
|
// continue processing. It is called during the Process function
|
||||||
|
type EncoderBeginCallback func() bool
|
||||||
|
|
||||||
// Model is the interface to a whisper model. Create a new model with the
|
// Model is the interface to a whisper model. Create a new model with the
|
||||||
// function whisper.New(string)
|
// function whisper.New(string)
|
||||||
type Model interface {
|
type Model interface {
|
||||||
@ -31,7 +35,7 @@ type Model interface {
|
|||||||
Languages() []string
|
Languages() []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Context is the speach recognition context.
|
// Context is the speech recognition context.
|
||||||
type Context interface {
|
type Context interface {
|
||||||
SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
|
SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
|
||||||
SetTranslate(bool) // Set translate flag
|
SetTranslate(bool) // Set translate flag
|
||||||
@ -58,7 +62,7 @@ type Context interface {
|
|||||||
// Process mono audio data and return any errors.
|
// Process mono audio data and return any errors.
|
||||||
// If defined, newly generated segments are passed to the
|
// If defined, newly generated segments are passed to the
|
||||||
// callback function during processing.
|
// callback function during processing.
|
||||||
Process([]float32, SegmentCallback, ProgressCallback) error
|
Process([]float32, EncoderBeginCallback, SegmentCallback, ProgressCallback) error
|
||||||
|
|
||||||
// After process is called, return segments until the end of the stream
|
// After process is called, return segments until the end of the stream
|
||||||
// is reached, when io.EOF is returned.
|
// is reached, when io.EOF is returned.
|
||||||
|
Loading…
Reference in New Issue
Block a user