Amanda Der Bedrosian Amanda Der Bedrosian commited on
Commit
6eea7b7
·
unverified ·
1 Parent(s): ab959d1

go : add Encoder Begin Callback (#2900)

Browse files

Adding in EncoderBeginCallback to the Context's Process callback.
This optional callback function returns false if computation should
be aborted.

Co-authored-by: Amanda Der Bedrosian <[email protected]>

bindings/go/README.md CHANGED
@@ -31,7 +31,7 @@ func main() {
31
  if err != nil {
32
  panic(err)
33
  }
34
- if err := context.Process(samples, nil, nil); err != nil {
35
  return err
36
  }
37
 
 
31
  if err != nil {
32
  panic(err)
33
  }
34
+ if err := context.Process(samples, nil, nil, nil); err != nil {
35
  return err
36
  }
37
 
bindings/go/examples/go-whisper/process.go CHANGED
@@ -67,7 +67,7 @@ func Process(model whisper.Model, path string, flags *Flags) error {
67
  // Process the data
68
  fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
69
  context.ResetTimings()
70
- if err := context.Process(data, cb, nil); err != nil {
71
  return err
72
  }
73
 
 
67
  // Process the data
68
  fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
69
  context.ResetTimings()
70
+ if err := context.Process(data, nil, cb, nil); err != nil {
71
  return err
72
  }
73
 
bindings/go/pkg/whisper/context.go CHANGED
@@ -189,6 +189,7 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f
189
  // Process new sample data and return any errors
190
  func (context *context) Process(
191
  data []float32,
 
192
  callNewSegment SegmentCallback,
193
  callProgress ProgressCallback,
194
  ) error {
@@ -203,7 +204,20 @@ func (context *context) Process(
203
  // We don't do parallel processing at the moment
204
  processors := 0
205
  if processors > 1 {
206
- if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  if callNewSegment != nil {
208
  num_segments := context.model.ctx.Whisper_full_n_segments()
209
  s0 := num_segments - new
@@ -211,22 +225,11 @@ func (context *context) Process(
211
  callNewSegment(toSegment(context.model.ctx, i))
212
  }
213
  }
214
- }); err != nil {
215
- return err
216
- }
217
- } else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
218
- if callNewSegment != nil {
219
- num_segments := context.model.ctx.Whisper_full_n_segments()
220
- s0 := num_segments - new
221
- for i := s0; i < num_segments; i++ {
222
- callNewSegment(toSegment(context.model.ctx, i))
223
  }
224
- }
225
- }, func(progress int) {
226
- if callProgress != nil {
227
- callProgress(progress)
228
- }
229
- }); err != nil {
230
  return err
231
  }
232
 
 
189
  // Process new sample data and return any errors
190
  func (context *context) Process(
191
  data []float32,
192
+ callEncoderBegin EncoderBeginCallback,
193
  callNewSegment SegmentCallback,
194
  callProgress ProgressCallback,
195
  ) error {
 
204
  // We don't do parallel processing at the moment
205
  processors := 0
206
  if processors > 1 {
207
+ if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, callEncoderBegin,
208
+ func(new int) {
209
+ if callNewSegment != nil {
210
+ num_segments := context.model.ctx.Whisper_full_n_segments()
211
+ s0 := num_segments - new
212
+ for i := s0; i < num_segments; i++ {
213
+ callNewSegment(toSegment(context.model.ctx, i))
214
+ }
215
+ }
216
+ }); err != nil {
217
+ return err
218
+ }
219
+ } else if err := context.model.ctx.Whisper_full(context.params, data, callEncoderBegin,
220
+ func(new int) {
221
  if callNewSegment != nil {
222
  num_segments := context.model.ctx.Whisper_full_n_segments()
223
  s0 := num_segments - new
 
225
  callNewSegment(toSegment(context.model.ctx, i))
226
  }
227
  }
228
+ }, func(progress int) {
229
+ if callProgress != nil {
230
+ callProgress(progress)
 
 
 
 
 
 
231
  }
232
+ }); err != nil {
 
 
 
 
 
233
  return err
234
  }
235
 
bindings/go/pkg/whisper/context_test.go CHANGED
@@ -88,6 +88,6 @@ func TestProcess(t *testing.T) {
88
  context, err := model.NewContext()
89
  assert.NoError(err)
90
 
91
- err = context.Process(data, nil, nil)
92
  assert.NoError(err)
93
  }
 
88
  context, err := model.NewContext()
89
  assert.NoError(err)
90
 
91
+ err = context.Process(data, nil, nil, nil)
92
  assert.NoError(err)
93
  }
bindings/go/pkg/whisper/interface.go CHANGED
@@ -16,6 +16,10 @@ type SegmentCallback func(Segment)
16
  // processing. It is called during the Process function
17
  type ProgressCallback func(int)
18
 
 
 
 
 
19
  // Model is the interface to a whisper model. Create a new model with the
20
  // function whisper.New(string)
21
  type Model interface {
@@ -31,7 +35,7 @@ type Model interface {
31
  Languages() []string
32
  }
33
 
34
- // Context is the speach recognition context.
35
  type Context interface {
36
  SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
37
  SetTranslate(bool) // Set translate flag
@@ -58,7 +62,7 @@ type Context interface {
58
  // Process mono audio data and return any errors.
59
  // If defined, newly generated segments are passed to the
60
  // callback function during processing.
61
- Process([]float32, SegmentCallback, ProgressCallback) error
62
 
63
  // After process is called, return segments until the end of the stream
64
  // is reached, when io.EOF is returned.
 
16
  // processing. It is called during the Process function
17
  type ProgressCallback func(int)
18
 
19
+ // EncoderBeginCallback is the callback function for checking if we want to
20
+ // continue processing. It is called during the Process function
21
+ type EncoderBeginCallback func() bool
22
+
23
  // Model is the interface to a whisper model. Create a new model with the
24
  // function whisper.New(string)
25
  type Model interface {
 
35
  Languages() []string
36
  }
37
 
38
+ // Context is the speech recognition context.
39
  type Context interface {
40
  SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
41
  SetTranslate(bool) // Set translate flag
 
62
  // Process mono audio data and return any errors.
63
  // If defined, newly generated segments are passed to the
64
  // callback function during processing.
65
+ Process([]float32, EncoderBeginCallback, SegmentCallback, ProgressCallback) error
66
 
67
  // After process is called, return segments until the end of the stream
68
  // is reached, when io.EOF is returned.