Spaces:
Running
Running
Amanda Der Bedrosian
Amanda Der Bedrosian
committed on
go : add Encoder Begin Callback (#2900)
Browse files
Adding in EncoderBeginCallback to the Context's Process callback.
This optional callback function returns false if computation should
be aborted.
Co-authored-by: Amanda Der Bedrosian <[email protected]>
bindings/go/README.md
CHANGED
|
@@ -31,7 +31,7 @@ func main() {
|
|
| 31 |
if err != nil {
|
| 32 |
panic(err)
|
| 33 |
}
|
| 34 |
-
if err := context.Process(samples, nil, nil); err != nil {
|
| 35 |
return err
|
| 36 |
}
|
| 37 |
|
|
|
|
| 31 |
if err != nil {
|
| 32 |
panic(err)
|
| 33 |
}
|
| 34 |
+
if err := context.Process(samples, nil, nil, nil); err != nil {
|
| 35 |
return err
|
| 36 |
}
|
| 37 |
|
bindings/go/examples/go-whisper/process.go
CHANGED
|
@@ -67,7 +67,7 @@ func Process(model whisper.Model, path string, flags *Flags) error {
|
|
| 67 |
// Process the data
|
| 68 |
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
|
| 69 |
context.ResetTimings()
|
| 70 |
-
if err := context.Process(data, cb, nil); err != nil {
|
| 71 |
return err
|
| 72 |
}
|
| 73 |
|
|
|
|
| 67 |
// Process the data
|
| 68 |
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
|
| 69 |
context.ResetTimings()
|
| 70 |
+
if err := context.Process(data, nil, cb, nil); err != nil {
|
| 71 |
return err
|
| 72 |
}
|
| 73 |
|
bindings/go/pkg/whisper/context.go
CHANGED
|
@@ -189,6 +189,7 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f
|
|
| 189 |
// Process new sample data and return any errors
|
| 190 |
func (context *context) Process(
|
| 191 |
data []float32,
|
|
|
|
| 192 |
callNewSegment SegmentCallback,
|
| 193 |
callProgress ProgressCallback,
|
| 194 |
) error {
|
|
@@ -203,7 +204,20 @@ func (context *context) Process(
|
|
| 203 |
// We don't do parallel processing at the moment
|
| 204 |
processors := 0
|
| 205 |
if processors > 1 {
|
| 206 |
-
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
if callNewSegment != nil {
|
| 208 |
num_segments := context.model.ctx.Whisper_full_n_segments()
|
| 209 |
s0 := num_segments - new
|
|
@@ -211,22 +225,11 @@ func (context *context) Process(
|
|
| 211 |
callNewSegment(toSegment(context.model.ctx, i))
|
| 212 |
}
|
| 213 |
}
|
| 214 |
-
}
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
|
| 218 |
-
if callNewSegment != nil {
|
| 219 |
-
num_segments := context.model.ctx.Whisper_full_n_segments()
|
| 220 |
-
s0 := num_segments - new
|
| 221 |
-
for i := s0; i < num_segments; i++ {
|
| 222 |
-
callNewSegment(toSegment(context.model.ctx, i))
|
| 223 |
}
|
| 224 |
-
}
|
| 225 |
-
}, func(progress int) {
|
| 226 |
-
if callProgress != nil {
|
| 227 |
-
callProgress(progress)
|
| 228 |
-
}
|
| 229 |
-
}); err != nil {
|
| 230 |
return err
|
| 231 |
}
|
| 232 |
|
|
|
|
| 189 |
// Process new sample data and return any errors
|
| 190 |
func (context *context) Process(
|
| 191 |
data []float32,
|
| 192 |
+
callEncoderBegin EncoderBeginCallback,
|
| 193 |
callNewSegment SegmentCallback,
|
| 194 |
callProgress ProgressCallback,
|
| 195 |
) error {
|
|
|
|
| 204 |
// We don't do parallel processing at the moment
|
| 205 |
processors := 0
|
| 206 |
if processors > 1 {
|
| 207 |
+
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, callEncoderBegin,
|
| 208 |
+
func(new int) {
|
| 209 |
+
if callNewSegment != nil {
|
| 210 |
+
num_segments := context.model.ctx.Whisper_full_n_segments()
|
| 211 |
+
s0 := num_segments - new
|
| 212 |
+
for i := s0; i < num_segments; i++ {
|
| 213 |
+
callNewSegment(toSegment(context.model.ctx, i))
|
| 214 |
+
}
|
| 215 |
+
}
|
| 216 |
+
}); err != nil {
|
| 217 |
+
return err
|
| 218 |
+
}
|
| 219 |
+
} else if err := context.model.ctx.Whisper_full(context.params, data, callEncoderBegin,
|
| 220 |
+
func(new int) {
|
| 221 |
if callNewSegment != nil {
|
| 222 |
num_segments := context.model.ctx.Whisper_full_n_segments()
|
| 223 |
s0 := num_segments - new
|
|
|
|
| 225 |
callNewSegment(toSegment(context.model.ctx, i))
|
| 226 |
}
|
| 227 |
}
|
| 228 |
+
}, func(progress int) {
|
| 229 |
+
if callProgress != nil {
|
| 230 |
+
callProgress(progress)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
}
|
| 232 |
+
}); err != nil {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
return err
|
| 234 |
}
|
| 235 |
|
bindings/go/pkg/whisper/context_test.go
CHANGED
|
@@ -88,6 +88,6 @@ func TestProcess(t *testing.T) {
|
|
| 88 |
context, err := model.NewContext()
|
| 89 |
assert.NoError(err)
|
| 90 |
|
| 91 |
-
err = context.Process(data, nil, nil)
|
| 92 |
assert.NoError(err)
|
| 93 |
}
|
|
|
|
| 88 |
context, err := model.NewContext()
|
| 89 |
assert.NoError(err)
|
| 90 |
|
| 91 |
+
err = context.Process(data, nil, nil, nil)
|
| 92 |
assert.NoError(err)
|
| 93 |
}
|
bindings/go/pkg/whisper/interface.go
CHANGED
|
@@ -16,6 +16,10 @@ type SegmentCallback func(Segment)
|
|
| 16 |
// processing. It is called during the Process function
|
| 17 |
type ProgressCallback func(int)
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
// Model is the interface to a whisper model. Create a new model with the
|
| 20 |
// function whisper.New(string)
|
| 21 |
type Model interface {
|
|
@@ -31,7 +35,7 @@ type Model interface {
|
|
| 31 |
Languages() []string
|
| 32 |
}
|
| 33 |
|
| 34 |
-
// Context is the
|
| 35 |
type Context interface {
|
| 36 |
SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
|
| 37 |
SetTranslate(bool) // Set translate flag
|
|
@@ -58,7 +62,7 @@ type Context interface {
|
|
| 58 |
// Process mono audio data and return any errors.
|
| 59 |
// If defined, newly generated segments are passed to the
|
| 60 |
// callback function during processing.
|
| 61 |
-
Process([]float32, SegmentCallback, ProgressCallback) error
|
| 62 |
|
| 63 |
// After process is called, return segments until the end of the stream
|
| 64 |
// is reached, when io.EOF is returned.
|
|
|
|
| 16 |
// processing. It is called during the Process function
|
| 17 |
type ProgressCallback func(int)
|
| 18 |
|
| 19 |
+
// EncoderBeginCallback is the callback function for checking if we want to
|
| 20 |
+
// continue processing. It is called during the Process function
|
| 21 |
+
type EncoderBeginCallback func() bool
|
| 22 |
+
|
| 23 |
// Model is the interface to a whisper model. Create a new model with the
|
| 24 |
// function whisper.New(string)
|
| 25 |
type Model interface {
|
|
|
|
| 35 |
Languages() []string
|
| 36 |
}
|
| 37 |
|
| 38 |
+
// Context is the speech recognition context.
|
| 39 |
type Context interface {
|
| 40 |
SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
|
| 41 |
SetTranslate(bool) // Set translate flag
|
|
|
|
| 62 |
// Process mono audio data and return any errors.
|
| 63 |
// If defined, newly generated segments are passed to the
|
| 64 |
// callback function during processing.
|
| 65 |
+
Process([]float32, EncoderBeginCallback, SegmentCallback, ProgressCallback) error
|
| 66 |
|
| 67 |
// After process is called, return segments until the end of the stream
|
| 68 |
// is reached, when io.EOF is returned.
|