Spaces:
Running
Running
bindings : initial import of golang bindings (#287)
Browse files* Initial import of golang bindings
* Updated makefile rules
* Updated bindings
* Makefile update to add in more tests
- bindings/go/.gitignore +3 -0
- bindings/go/LICENSE +21 -0
- bindings/go/Makefile +38 -0
- bindings/go/README.md +77 -0
- bindings/go/doc.go +5 -0
- bindings/go/examples/go-model-download/context.go +30 -0
- bindings/go/examples/go-model-download/main.go +206 -0
- bindings/go/examples/go-whisper/flags.go +61 -0
- bindings/go/examples/go-whisper/main.go +44 -0
- bindings/go/examples/go-whisper/process.go +80 -0
- bindings/go/go.mod +16 -0
- bindings/go/params.go +134 -0
- bindings/go/pkg/whisper/consts.go +27 -0
- bindings/go/pkg/whisper/context.go +145 -0
- bindings/go/pkg/whisper/context_test.go +55 -0
- bindings/go/pkg/whisper/doc.go +4 -0
- bindings/go/pkg/whisper/interface.go +63 -0
- bindings/go/pkg/whisper/model.go +95 -0
- bindings/go/samples/jfk.wav +3 -0
- bindings/go/whisper.go +412 -0
- bindings/go/whisper_test.go +110 -0
bindings/go/.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
build
|
| 2 |
+
models
|
| 3 |
+
go.sum
|
bindings/go/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2022 David Thorpe
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
bindings/go/Makefile
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CMAKE := $(shell which cmake)
|
| 2 |
+
BUILD_DIR := "build"
|
| 3 |
+
MODELS_DIR := "models"
|
| 4 |
+
EXAMPLES_DIR := $(wildcard examples/*)
|
| 5 |
+
C_INCLUDE_PATH := "../.."
|
| 6 |
+
|
| 7 |
+
all: clean whisper examples
|
| 8 |
+
|
| 9 |
+
whisper: mkdir
|
| 10 |
+
@echo Build whisper
|
| 11 |
+
@${CMAKE} -S ../.. -B ${BUILD_DIR} -D BUILD_SHARED_LIBS=off -D WHISPER_NO_AVX2=on
|
| 12 |
+
@${CMAKE} --build ${BUILD_DIR} --target whisper
|
| 13 |
+
|
| 14 |
+
test: model-small whisper
|
| 15 |
+
@go mod tidy
|
| 16 |
+
@go test -v .
|
| 17 |
+
@go test -v ./pkg/whisper/...
|
| 18 |
+
|
| 19 |
+
examples: $(EXAMPLES_DIR)
|
| 20 |
+
|
| 21 |
+
model-small: mkdir examples/go-model-download
|
| 22 |
+
@${BUILD_DIR}/go-model-download -out models small.en
|
| 23 |
+
|
| 24 |
+
$(EXAMPLES_DIR): mkdir whisper
|
| 25 |
+
@echo Build example $(notdir $@)
|
| 26 |
+
@go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
|
| 27 |
+
|
| 28 |
+
mkdir:
|
| 29 |
+
@echo Mkdir ${BUILD_DIR}
|
| 30 |
+
@install -d ${BUILD_DIR}
|
| 31 |
+
@echo Mkdir ${MODELS_DIR}
|
| 32 |
+
@install -d ${MODELS_DIR}
|
| 33 |
+
|
| 34 |
+
clean:
|
| 35 |
+
@echo Clean
|
| 36 |
+
@rm -fr $(BUILD_DIR)
|
| 37 |
+
@go mod tidy
|
| 38 |
+
@go clean
|
bindings/go/README.md
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Go bindings for Whisper
|
| 2 |
+
|
| 3 |
+
This package provides Go bindings for whisper.cpp. They have been tested on:
|
| 4 |
+
|
| 5 |
+
* Darwin (OS X) 12.6 on x86_64
|
| 6 |
+
* Debian Linux on arm64
|
| 7 |
+
* Fedora Linux on x86_64
|
| 8 |
+
|
| 9 |
+
The "low level" bindings are in the `bindings/go` directory and there is a more
|
| 10 |
+
Go-style package in the `bindings/go/pkg/whisper` directory. The simplest usage
|
| 11 |
+
is as follows:
|
| 12 |
+
|
| 13 |
+
```go
|
| 14 |
+
import (
|
| 15 |
+
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
func main() {
|
| 19 |
+
var modelpath string // Path to the model
|
| 20 |
+
var samples []float32 // Samples to process
|
| 21 |
+
|
| 22 |
+
// Load the model
|
| 23 |
+
model, err := whisper.New(modelpath)
|
| 24 |
+
if err != nil {
|
| 25 |
+
panic(err)
|
| 26 |
+
}
|
| 27 |
+
defer model.Close()
|
| 28 |
+
|
| 29 |
+
// Process samples
|
| 30 |
+
context, err := model.NewContext()
|
| 31 |
+
if err != nil {
|
| 32 |
+
panic(err)
|
| 33 |
+
}
|
| 34 |
+
if err := context.Process(samples, nil); err != nil {
|
| 35 |
+
panic(err)
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
// Print out the results
|
| 39 |
+
for {
|
| 40 |
+
segment, err := context.NextSegment()
|
| 41 |
+
if err != nil {
|
| 42 |
+
break
|
| 43 |
+
}
|
| 44 |
+
fmt.Printf("[%6s->%6s] %s\n", segment.Start, segment.End, segment.Text)
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
## Building & Testing
|
| 50 |
+
|
| 51 |
+
In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with:
|
| 52 |
+
|
| 53 |
+
```bash
|
| 54 |
+
git clone https://github.com/ggerganov/whisper.cpp.git
|
| 55 |
+
cd whisper.cpp/bindings/go
|
| 56 |
+
make test
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
This will compile a static `libwhisper.a` in a `build` folder, download a model file, then run the tests. To build the examples:
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
make examples
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
The examples are placed in the `build` directory. Once built, you can download all the models with the following command:
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
./build/go-model-download -out models
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
And you can then test a model against samples with the following command:
|
| 72 |
+
|
| 73 |
+
```bash
|
| 74 |
+
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
|
bindings/go/doc.go
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
github.com/ggerganov/whisper.cpp/bindings/go
|
| 3 |
+
provides speech-to-text bindings for the Go programming language.
|
| 4 |
+
*/
|
| 5 |
+
package whisper
|
bindings/go/examples/go-model-download/context.go
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package main
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"os"
|
| 6 |
+
"os/signal"
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
// ContextForSignal returns a context which is cancelled when any one of the
// given signals is received. It returns nil if no signals are provided, so
// callers must pass at least one signal before using the result.
func ContextForSignal(signals ...os.Signal) context.Context {
	if len(signals) == 0 {
		return nil
	}

	// signal.Notify never blocks when delivering: with an unbuffered channel
	// a signal that arrives before the goroutine below is parked on the
	// receive would be silently dropped. A buffer of one is enough because
	// only the first signal matters.
	ch := make(chan os.Signal, 1)
	ctx, cancel := context.WithCancel(context.Background())

	// Relay the requested signals onto the channel
	signal.Notify(ch, signals...)

	// Cancel the context as soon as the first signal arrives
	go func() {
		<-ch
		cancel()
	}()

	return ctx
}
|
bindings/go/examples/go-model-download/main.go
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package main
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"flag"
|
| 6 |
+
"fmt"
|
| 7 |
+
"io"
|
| 8 |
+
"net/http"
|
| 9 |
+
"net/url"
|
| 10 |
+
"os"
|
| 11 |
+
"path/filepath"
|
| 12 |
+
"syscall"
|
| 13 |
+
"time"
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 17 |
+
// CONSTANTS
|
| 18 |
+
|
| 19 |
+
const (
|
| 20 |
+
srcUrl = "https://huggingface.co/" // The location of the models
|
| 21 |
+
srcPathPrefix = "/datasets/ggerganov/whisper.cpp/resolve/main/ggml" // Filename prefix
|
| 22 |
+
srcExt = ".bin" // Filename extension
|
| 23 |
+
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
var (
|
| 27 |
+
// The models which will be downloaded, if no model is specified as an argument
|
| 28 |
+
modelNames = []string{"tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium", "large-v1", "large"}
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
var (
|
| 32 |
+
// The output folder. When not set, use current working directory.
|
| 33 |
+
flagOut = flag.String("out", "", "Output folder")
|
| 34 |
+
|
| 35 |
+
// HTTP timeout parameter - will timeout if takes longer than this to download a model
|
| 36 |
+
flagTimeout = flag.Duration("timeout", 30*time.Minute, "HTTP timeout")
|
| 37 |
+
|
| 38 |
+
// Quiet parameter - will not print progress if set
|
| 39 |
+
flagQuiet = flag.Bool("quiet", false, "Quiet mode")
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 43 |
+
// MAIN
|
| 44 |
+
|
| 45 |
+
func main() {
|
| 46 |
+
flag.Usage = func() {
|
| 47 |
+
name := filepath.Base(flag.CommandLine.Name())
|
| 48 |
+
fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] <model>\n\n", name)
|
| 49 |
+
flag.PrintDefaults()
|
| 50 |
+
}
|
| 51 |
+
flag.Parse()
|
| 52 |
+
|
| 53 |
+
// Get output path
|
| 54 |
+
out, err := GetOut()
|
| 55 |
+
if err != nil {
|
| 56 |
+
fmt.Fprintln(os.Stderr, "Error:", err)
|
| 57 |
+
os.Exit(-1)
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
// Create context which quits on SIGINT or SIGQUIT
|
| 61 |
+
ctx := ContextForSignal(os.Interrupt, syscall.SIGQUIT)
|
| 62 |
+
|
| 63 |
+
// Progress filehandle
|
| 64 |
+
progress := os.Stdout
|
| 65 |
+
if *flagQuiet {
|
| 66 |
+
progress, err = os.Open(os.DevNull)
|
| 67 |
+
if err != nil {
|
| 68 |
+
fmt.Fprintln(os.Stderr, "Error:", err)
|
| 69 |
+
os.Exit(-1)
|
| 70 |
+
}
|
| 71 |
+
defer progress.Close()
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
// Download models - exit on error or interrupt
|
| 75 |
+
for _, model := range GetModels() {
|
| 76 |
+
url, err := URLForModel(model)
|
| 77 |
+
if err != nil {
|
| 78 |
+
fmt.Fprintln(os.Stderr, "Error:", err)
|
| 79 |
+
continue
|
| 80 |
+
} else if path, err := Download(ctx, progress, url, out); err == nil || err == io.EOF {
|
| 81 |
+
continue
|
| 82 |
+
} else if err == context.Canceled {
|
| 83 |
+
os.Remove(path)
|
| 84 |
+
fmt.Fprintln(progress, "\nInterrupted")
|
| 85 |
+
break
|
| 86 |
+
} else if err == context.DeadlineExceeded {
|
| 87 |
+
os.Remove(path)
|
| 88 |
+
fmt.Fprintln(progress, "Timeout downloading model")
|
| 89 |
+
continue
|
| 90 |
+
} else {
|
| 91 |
+
os.Remove(path)
|
| 92 |
+
fmt.Fprintln(os.Stderr, "Error:", err)
|
| 93 |
+
break
|
| 94 |
+
}
|
| 95 |
+
}
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 99 |
+
// PUBLIC METHODS
|
| 100 |
+
|
| 101 |
+
// GetOut returns the path to the output directory
|
| 102 |
+
func GetOut() (string, error) {
|
| 103 |
+
if *flagOut == "" {
|
| 104 |
+
return os.Getwd()
|
| 105 |
+
}
|
| 106 |
+
if info, err := os.Stat(*flagOut); err != nil {
|
| 107 |
+
return "", err
|
| 108 |
+
} else if !info.IsDir() {
|
| 109 |
+
return "", fmt.Errorf("not a directory: %s", info.Name())
|
| 110 |
+
} else {
|
| 111 |
+
return *flagOut, nil
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
// GetModels returns the list of models to download
|
| 116 |
+
func GetModels() []string {
|
| 117 |
+
if flag.NArg() == 0 {
|
| 118 |
+
return modelNames
|
| 119 |
+
} else {
|
| 120 |
+
return flag.Args()
|
| 121 |
+
}
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
// URLForModel returns the URL for the given model on huggingface.co
|
| 125 |
+
func URLForModel(model string) (string, error) {
|
| 126 |
+
url, err := url.Parse(srcUrl)
|
| 127 |
+
if err != nil {
|
| 128 |
+
return "", err
|
| 129 |
+
} else {
|
| 130 |
+
url.Path = srcPathPrefix + "-" + model + srcExt
|
| 131 |
+
}
|
| 132 |
+
return url.String(), nil
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
// Download downloads the model from the given URL to the given output directory
|
| 136 |
+
func Download(ctx context.Context, p io.Writer, model, out string) (string, error) {
|
| 137 |
+
// Create HTTP client
|
| 138 |
+
client := http.Client{
|
| 139 |
+
Timeout: *flagTimeout,
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
// Initiate the download
|
| 143 |
+
req, err := http.NewRequest("GET", model, nil)
|
| 144 |
+
if err != nil {
|
| 145 |
+
return "", err
|
| 146 |
+
}
|
| 147 |
+
resp, err := client.Do(req)
|
| 148 |
+
if err != nil {
|
| 149 |
+
return "", err
|
| 150 |
+
}
|
| 151 |
+
defer resp.Body.Close()
|
| 152 |
+
if resp.StatusCode != http.StatusOK {
|
| 153 |
+
return "", fmt.Errorf("%s: %s", model, resp.Status)
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
// If output file exists and is the same size as the model, skip
|
| 157 |
+
path := filepath.Join(out, filepath.Base(model))
|
| 158 |
+
if info, err := os.Stat(path); err == nil && info.Size() == resp.ContentLength {
|
| 159 |
+
fmt.Fprintln(p, "Skipping", model, "as it already exists")
|
| 160 |
+
return "", nil
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
// Create file
|
| 164 |
+
w, err := os.Create(path)
|
| 165 |
+
if err != nil {
|
| 166 |
+
return "", err
|
| 167 |
+
}
|
| 168 |
+
defer w.Close()
|
| 169 |
+
|
| 170 |
+
// Report
|
| 171 |
+
fmt.Fprintln(p, "Downloading", model, "to", out)
|
| 172 |
+
|
| 173 |
+
// Progressively download the model
|
| 174 |
+
data := make([]byte, bufSize)
|
| 175 |
+
count, pct := int64(0), int64(0)
|
| 176 |
+
ticker := time.NewTicker(5 * time.Second)
|
| 177 |
+
for {
|
| 178 |
+
select {
|
| 179 |
+
case <-ctx.Done():
|
| 180 |
+
// Cancelled, return error
|
| 181 |
+
return path, ctx.Err()
|
| 182 |
+
case <-ticker.C:
|
| 183 |
+
pct = DownloadReport(p, pct, count, resp.ContentLength)
|
| 184 |
+
default:
|
| 185 |
+
// Read body
|
| 186 |
+
n, err := resp.Body.Read(data)
|
| 187 |
+
if err != nil {
|
| 188 |
+
DownloadReport(p, pct, count, resp.ContentLength)
|
| 189 |
+
return path, err
|
| 190 |
+
} else if m, err := w.Write(data[:n]); err != nil {
|
| 191 |
+
return path, err
|
| 192 |
+
} else {
|
| 193 |
+
count += int64(m)
|
| 194 |
+
}
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
// DownloadReport writes a progress line to w when the percentage of count
// relative to total is greater than pct, and returns the newly computed
// percentage so the caller can pass it back in on the next call.
//
// When total is zero or negative (for example an unknown content length,
// which net/http reports as -1) no percentage can be computed; nothing is
// written and pct is returned unchanged.
func DownloadReport(w io.Writer, pct, count, total int64) int64 {
	if total <= 0 {
		return pct
	}
	current := count * 100 / total
	if current > pct {
		fmt.Fprintf(w, "  ...%d MB written (%d%%)\n", count/1e6, current)
	}
	return current
}
|
bindings/go/examples/go-whisper/flags.go
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package main
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"flag"
|
| 5 |
+
)
|
| 6 |
+
|
| 7 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 8 |
+
// TYPES
|
| 9 |
+
|
| 10 |
+
// Flags wraps a flag.FlagSet with typed accessors for the options registered
// by registerFlags.
type Flags struct {
	*flag.FlagSet
}

///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE

// NewFlags creates a flag set with the given name, registers the supported
// options and parses args. Parse errors are returned to the caller rather
// than terminating the process.
func NewFlags(name string, args []string) (*Flags, error) {
	f := &Flags{flag.NewFlagSet(name, flag.ContinueOnError)}
	registerFlags(f)
	if err := f.Parse(args); err != nil {
		return nil, err
	}
	return f, nil
}

///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS

// GetModel returns the value of the -model flag.
func (f *Flags) GetModel() string {
	return f.text("model")
}

// GetLanguage returns the value of the -language flag.
func (f *Flags) GetLanguage() string {
	return f.text("language")
}

// IsSpeedup reports whether the -speedup flag is set.
func (f *Flags) IsSpeedup() bool {
	return f.enabled("speedup")
}

// IsTokens reports whether the -tokens flag is set.
func (f *Flags) IsTokens() bool {
	return f.enabled("tokens")
}

///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS

// text returns the string form of a registered flag's current value.
func (f *Flags) text(name string) string {
	return f.Lookup(name).Value.String()
}

// enabled reports whether a registered flag's value renders as "true".
func (f *Flags) enabled(name string) bool {
	return f.text(name) == "true"
}

// registerFlags declares the command line options on the flag set.
func registerFlags(fs *Flags) {
	fs.String("model", "", "Path to the model file")
	fs.String("language", "", "Language")
	fs.Bool("speedup", false, "Enable speedup")
	fs.Bool("tokens", false, "Display tokens")
}
|
bindings/go/examples/go-whisper/main.go
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package main
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"flag"
|
| 5 |
+
"fmt"
|
| 6 |
+
"os"
|
| 7 |
+
"path/filepath"
|
| 8 |
+
|
| 9 |
+
// Packages
|
| 10 |
+
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
func main() {
|
| 14 |
+
flags, err := NewFlags(filepath.Base(os.Args[0]), os.Args[1:])
|
| 15 |
+
if err == flag.ErrHelp {
|
| 16 |
+
os.Exit(0)
|
| 17 |
+
} else if err != nil {
|
| 18 |
+
fmt.Fprintln(os.Stderr, err)
|
| 19 |
+
os.Exit(1)
|
| 20 |
+
} else if flags.GetModel() == "" {
|
| 21 |
+
fmt.Fprintln(os.Stderr, "Use -model flag to specify which model file to use")
|
| 22 |
+
os.Exit(1)
|
| 23 |
+
} else if flags.NArg() == 0 {
|
| 24 |
+
fmt.Fprintln(os.Stderr, "No input files specified")
|
| 25 |
+
os.Exit(1)
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
// Load model
|
| 29 |
+
model, err := whisper.New(flags.GetModel())
|
| 30 |
+
if err != nil {
|
| 31 |
+
fmt.Fprintln(os.Stderr, err)
|
| 32 |
+
os.Exit(1)
|
| 33 |
+
}
|
| 34 |
+
defer model.Close()
|
| 35 |
+
|
| 36 |
+
// Process files
|
| 37 |
+
for _, filename := range flags.Args() {
|
| 38 |
+
fmt.Println("Processing", filename)
|
| 39 |
+
if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil {
|
| 40 |
+
fmt.Fprintln(os.Stderr, err)
|
| 41 |
+
continue
|
| 42 |
+
}
|
| 43 |
+
}
|
| 44 |
+
}
|
bindings/go/examples/go-whisper/process.go
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package main
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"fmt"
|
| 5 |
+
"io"
|
| 6 |
+
"os"
|
| 7 |
+
"time"
|
| 8 |
+
|
| 9 |
+
// Package imports
|
| 10 |
+
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
| 11 |
+
wav "github.com/go-audio/wav"
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error {
|
| 15 |
+
var data []float32
|
| 16 |
+
|
| 17 |
+
// Create processing context
|
| 18 |
+
context, err := model.NewContext()
|
| 19 |
+
if err != nil {
|
| 20 |
+
return err
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
// Open the file
|
| 24 |
+
fh, err := os.Open(path)
|
| 25 |
+
if err != nil {
|
| 26 |
+
return err
|
| 27 |
+
}
|
| 28 |
+
defer fh.Close()
|
| 29 |
+
|
| 30 |
+
// Decode the WAV file
|
| 31 |
+
dec := wav.NewDecoder(fh)
|
| 32 |
+
if buf, err := dec.FullPCMBuffer(); err != nil {
|
| 33 |
+
return err
|
| 34 |
+
} else if dec.SampleRate != whisper.SampleRate {
|
| 35 |
+
return fmt.Errorf("unsupported sample rate: %d", dec.SampleRate)
|
| 36 |
+
} else if dec.NumChans != 1 {
|
| 37 |
+
return fmt.Errorf("unsupported number of channels: %d", dec.NumChans)
|
| 38 |
+
} else {
|
| 39 |
+
data = buf.AsFloat32Buffer().Data
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
// Set the parameters
|
| 43 |
+
var cb whisper.SegmentCallback
|
| 44 |
+
if lang != "" {
|
| 45 |
+
if err := context.SetLanguage(lang); err != nil {
|
| 46 |
+
return err
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
if speedup {
|
| 50 |
+
context.SetSpeedup(true)
|
| 51 |
+
}
|
| 52 |
+
if tokens {
|
| 53 |
+
cb = func(segment whisper.Segment) {
|
| 54 |
+
fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
|
| 55 |
+
for _, token := range segment.Tokens {
|
| 56 |
+
fmt.Printf("%q ", token.Text)
|
| 57 |
+
}
|
| 58 |
+
fmt.Println("")
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
// Process the data
|
| 63 |
+
if err := context.Process(data, cb); err != nil {
|
| 64 |
+
return err
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
// Print out the results
|
| 68 |
+
for {
|
| 69 |
+
segment, err := context.NextSegment()
|
| 70 |
+
if err == io.EOF {
|
| 71 |
+
break
|
| 72 |
+
} else if err != nil {
|
| 73 |
+
return err
|
| 74 |
+
}
|
| 75 |
+
fmt.Printf("[%6s->%6s] %s\n", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond), segment.Text)
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
// Return success
|
| 79 |
+
return nil
|
| 80 |
+
}
|
bindings/go/go.mod
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
module github.com/ggerganov/whisper.cpp/bindings/go
|
| 2 |
+
|
| 3 |
+
go 1.19
|
| 4 |
+
|
| 5 |
+
require (
|
| 6 |
+
github.com/go-audio/wav v1.1.0
|
| 7 |
+
github.com/stretchr/testify v1.8.1
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
require (
|
| 11 |
+
github.com/davecgh/go-spew v1.1.1 // indirect
|
| 12 |
+
github.com/go-audio/audio v1.0.0 // indirect
|
| 13 |
+
github.com/go-audio/riff v1.0.0 // indirect
|
| 14 |
+
github.com/pmezard/go-difflib v1.0.0 // indirect
|
| 15 |
+
gopkg.in/yaml.v3 v3.0.1 // indirect
|
| 16 |
+
)
|
bindings/go/params.go
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package whisper
|
| 2 |
+
|
| 3 |
+
// This file defines the whisper_token, whisper_token_data and whisper_full_params
|
| 4 |
+
// structures, which are used by the whisper_full() function.
|
| 5 |
+
|
| 6 |
+
import (
|
| 7 |
+
"fmt"
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 11 |
+
// CGO
|
| 12 |
+
|
| 13 |
+
/*
|
| 14 |
+
#include <whisper.h>
|
| 15 |
+
*/
|
| 16 |
+
import "C"
|
| 17 |
+
|
| 18 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 19 |
+
// PUBLIC METHODS
|
| 20 |
+
|
| 21 |
+
func (p *Params) SetTranslate(v bool) {
|
| 22 |
+
p.translate = toBool(v)
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
func (p *Params) SetNoContext(v bool) {
|
| 26 |
+
p.no_context = toBool(v)
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
func (p *Params) SetSingleSegment(v bool) {
|
| 30 |
+
p.single_segment = toBool(v)
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
func (p *Params) SetPrintSpecial(v bool) {
|
| 34 |
+
p.print_special = toBool(v)
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
func (p *Params) SetPrintProgress(v bool) {
|
| 38 |
+
p.print_progress = toBool(v)
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
func (p *Params) SetPrintRealtime(v bool) {
|
| 42 |
+
p.print_realtime = toBool(v)
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
func (p *Params) SetPrintTimestamps(v bool) {
|
| 46 |
+
p.print_timestamps = toBool(v)
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
func (p *Params) SetSpeedup(v bool) {
|
| 50 |
+
p.speed_up = toBool(v)
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
func (p *Params) SetLanguage(lang int) error {
|
| 54 |
+
str := C.whisper_lang_str(C.int(lang))
|
| 55 |
+
if str == nil {
|
| 56 |
+
return ErrInvalidLanguage
|
| 57 |
+
} else {
|
| 58 |
+
p.language = str
|
| 59 |
+
}
|
| 60 |
+
return nil
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
func (p *Params) Language() int {
|
| 64 |
+
if p.language == nil {
|
| 65 |
+
return -1
|
| 66 |
+
}
|
| 67 |
+
return int(C.whisper_lang_id(p.language))
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
func (p *Params) SetThreads(threads int) {
|
| 71 |
+
p.n_threads = C.int(threads)
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
func (p *Params) SetOffset(offset_ms int) {
|
| 75 |
+
p.offset_ms = C.int(offset_ms)
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
func (p *Params) SetDuration(duration_ms int) {
|
| 79 |
+
p.duration_ms = C.int(duration_ms)
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 83 |
+
// PRIVATE METHODS
|
| 84 |
+
|
| 85 |
+
func toBool(v bool) C.bool {
|
| 86 |
+
if v {
|
| 87 |
+
return C.bool(true)
|
| 88 |
+
}
|
| 89 |
+
return C.bool(false)
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 93 |
+
// STRINGIFY
|
| 94 |
+
|
| 95 |
+
func (p *Params) String() string {
|
| 96 |
+
str := "<whisper.params"
|
| 97 |
+
str += fmt.Sprintf(" strategy=%v", p.strategy)
|
| 98 |
+
str += fmt.Sprintf(" n_threads=%d", p.n_threads)
|
| 99 |
+
if p.language != nil {
|
| 100 |
+
str += fmt.Sprintf(" language=%s", C.GoString(p.language))
|
| 101 |
+
}
|
| 102 |
+
str += fmt.Sprintf(" n_max_text_ctx=%d", p.n_max_text_ctx)
|
| 103 |
+
str += fmt.Sprintf(" offset_ms=%d", p.offset_ms)
|
| 104 |
+
str += fmt.Sprintf(" duration_ms=%d", p.duration_ms)
|
| 105 |
+
if p.translate {
|
| 106 |
+
str += " translate"
|
| 107 |
+
}
|
| 108 |
+
if p.no_context {
|
| 109 |
+
str += " no_context"
|
| 110 |
+
}
|
| 111 |
+
if p.single_segment {
|
| 112 |
+
str += " single_segment"
|
| 113 |
+
}
|
| 114 |
+
if p.print_special {
|
| 115 |
+
str += " print_special"
|
| 116 |
+
}
|
| 117 |
+
if p.print_progress {
|
| 118 |
+
str += " print_progress"
|
| 119 |
+
}
|
| 120 |
+
if p.print_realtime {
|
| 121 |
+
str += " print_realtime"
|
| 122 |
+
}
|
| 123 |
+
if p.print_timestamps {
|
| 124 |
+
str += " print_timestamps"
|
| 125 |
+
}
|
| 126 |
+
if p.token_timestamps {
|
| 127 |
+
str += " token_timestamps"
|
| 128 |
+
}
|
| 129 |
+
if p.speed_up {
|
| 130 |
+
str += " speed_up"
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
return str + ">"
|
| 134 |
+
}
|
bindings/go/pkg/whisper/consts.go
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package whisper
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"errors"
|
| 5 |
+
|
| 6 |
+
// Bindings
|
| 7 |
+
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 11 |
+
// ERRORS
|
| 12 |
+
|
| 13 |
+
var (
|
| 14 |
+
ErrUnableToLoadModel = errors.New("unable to load model")
|
| 15 |
+
ErrInternalAppError = errors.New("internal application error")
|
| 16 |
+
ErrProcessingFailed = errors.New("processing failed")
|
| 17 |
+
ErrUnsupportedLanguage = errors.New("unsupported language")
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 21 |
+
// CONSTANTS
|
| 22 |
+
|
| 23 |
+
// SampleRate is the sample rate of the audio data.
|
| 24 |
+
const SampleRate = whisper.SampleRate
|
| 25 |
+
|
| 26 |
+
// SampleBits is the number of bits per sample.
|
| 27 |
+
const SampleBits = whisper.SampleBits
|
bindings/go/pkg/whisper/context.go
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package whisper
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"io"
|
| 5 |
+
"strings"
|
| 6 |
+
"time"
|
| 7 |
+
|
| 8 |
+
// Bindings
|
| 9 |
+
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 13 |
+
// TYPES
|
| 14 |
+
|
| 15 |
+
type context struct {
|
| 16 |
+
n int
|
| 17 |
+
model *model
|
| 18 |
+
params whisper.Params
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
// Make sure context adheres to the interface
|
| 22 |
+
var _ Context = (*context)(nil)
|
| 23 |
+
|
| 24 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 25 |
+
// LIFECYCLE
|
| 26 |
+
|
| 27 |
+
func NewContext(model *model, params whisper.Params) (Context, error) {
|
| 28 |
+
context := new(context)
|
| 29 |
+
context.model = model
|
| 30 |
+
context.params = params
|
| 31 |
+
|
| 32 |
+
// Return success
|
| 33 |
+
return context, nil
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 37 |
+
// PUBLIC METHODS
|
| 38 |
+
|
| 39 |
+
// Set the language to use for speech recognition.
|
| 40 |
+
func (context *context) SetLanguage(lang string) error {
|
| 41 |
+
if context.model.ctx == nil {
|
| 42 |
+
return ErrInternalAppError
|
| 43 |
+
}
|
| 44 |
+
if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
|
| 45 |
+
return ErrUnsupportedLanguage
|
| 46 |
+
} else if err := context.params.SetLanguage(id); err != nil {
|
| 47 |
+
return err
|
| 48 |
+
}
|
| 49 |
+
// Return success
|
| 50 |
+
return nil
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
// Get language
|
| 54 |
+
func (context *context) Language() string {
|
| 55 |
+
return whisper.Whisper_lang_str(context.params.Language())
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
// Set speedup flag
|
| 59 |
+
func (context *context) SetSpeedup(v bool) {
|
| 60 |
+
context.params.SetSpeedup(v)
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
// Process new sample data and return any errors
|
| 64 |
+
func (context *context) Process(data []float32, cb SegmentCallback) error {
|
| 65 |
+
if context.model.ctx == nil {
|
| 66 |
+
return ErrInternalAppError
|
| 67 |
+
}
|
| 68 |
+
// If the callback is defined then we force on single_segment mode
|
| 69 |
+
if cb != nil {
|
| 70 |
+
context.params.SetSingleSegment(true)
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
// We don't do parallel processing at the moment
|
| 74 |
+
processors := 0
|
| 75 |
+
if processors > 1 {
|
| 76 |
+
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
|
| 77 |
+
if cb != nil {
|
| 78 |
+
num_segments := context.model.ctx.Whisper_full_n_segments()
|
| 79 |
+
s0 := num_segments - new
|
| 80 |
+
for i := s0; i < num_segments; i++ {
|
| 81 |
+
cb(toSegment(context.model.ctx, i))
|
| 82 |
+
}
|
| 83 |
+
}
|
| 84 |
+
}); err != nil {
|
| 85 |
+
return err
|
| 86 |
+
}
|
| 87 |
+
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
|
| 88 |
+
if cb != nil {
|
| 89 |
+
num_segments := context.model.ctx.Whisper_full_n_segments()
|
| 90 |
+
s0 := num_segments - new
|
| 91 |
+
for i := s0; i < num_segments; i++ {
|
| 92 |
+
cb(toSegment(context.model.ctx, i))
|
| 93 |
+
}
|
| 94 |
+
}
|
| 95 |
+
}); err != nil {
|
| 96 |
+
return err
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
// Return success
|
| 100 |
+
return nil
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
// Return the next segment of tokens
|
| 104 |
+
func (context *context) NextSegment() (Segment, error) {
|
| 105 |
+
if context.model.ctx == nil {
|
| 106 |
+
return Segment{}, ErrInternalAppError
|
| 107 |
+
}
|
| 108 |
+
if context.n >= context.model.ctx.Whisper_full_n_segments() {
|
| 109 |
+
return Segment{}, io.EOF
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
// Populate result
|
| 113 |
+
result := toSegment(context.model.ctx, context.n)
|
| 114 |
+
|
| 115 |
+
// Increment the cursor
|
| 116 |
+
context.n++
|
| 117 |
+
|
| 118 |
+
// Return success
|
| 119 |
+
return result, nil
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 123 |
+
// PRIVATE METHODS
|
| 124 |
+
|
| 125 |
+
func toSegment(ctx *whisper.Context, n int) Segment {
|
| 126 |
+
return Segment{
|
| 127 |
+
Num: n,
|
| 128 |
+
Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
|
| 129 |
+
Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
|
| 130 |
+
End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
|
| 131 |
+
Tokens: toTokens(ctx, n),
|
| 132 |
+
}
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
func toTokens(ctx *whisper.Context, n int) []Token {
|
| 136 |
+
result := make([]Token, ctx.Whisper_full_n_tokens(n))
|
| 137 |
+
for i := 0; i < len(result); i++ {
|
| 138 |
+
result[i] = Token{
|
| 139 |
+
Id: int(ctx.Whisper_full_get_token_id(n, i)),
|
| 140 |
+
Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)),
|
| 141 |
+
P: ctx.Whisper_full_get_token_p(n, i),
|
| 142 |
+
}
|
| 143 |
+
}
|
| 144 |
+
return result
|
| 145 |
+
}
|
bindings/go/pkg/whisper/context_test.go
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package whisper_test
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"os"
|
| 5 |
+
"testing"
|
| 6 |
+
|
| 7 |
+
// Packages
|
| 8 |
+
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
| 9 |
+
assert "github.com/stretchr/testify/assert"
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
const (
|
| 13 |
+
ModelPath = "../../models/ggml-tiny.bin"
|
| 14 |
+
SamplePath = "../../samples/jfk.wav"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
func Test_Whisper_000(t *testing.T) {
|
| 18 |
+
assert := assert.New(t)
|
| 19 |
+
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
| 20 |
+
t.Skip("Skipping test, model not found:", ModelPath)
|
| 21 |
+
}
|
| 22 |
+
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
| 23 |
+
t.Skip("Skipping test, sample not found:", SamplePath)
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
// Load model
|
| 27 |
+
model, err := whisper.New(ModelPath)
|
| 28 |
+
assert.NoError(err)
|
| 29 |
+
assert.NotNil(model)
|
| 30 |
+
assert.NoError(model.Close())
|
| 31 |
+
|
| 32 |
+
t.Log("languages=", model.Languages())
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
func Test_Whisper_001(t *testing.T) {
|
| 36 |
+
assert := assert.New(t)
|
| 37 |
+
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
| 38 |
+
t.Skip("Skipping test, model not found:", ModelPath)
|
| 39 |
+
}
|
| 40 |
+
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
| 41 |
+
t.Skip("Skipping test, sample not found:", SamplePath)
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
// Load model
|
| 45 |
+
model, err := whisper.New(ModelPath)
|
| 46 |
+
assert.NoError(err)
|
| 47 |
+
assert.NotNil(model)
|
| 48 |
+
defer model.Close()
|
| 49 |
+
|
| 50 |
+
// Get context for decoding
|
| 51 |
+
ctx, err := model.NewContext()
|
| 52 |
+
assert.NoError(err)
|
| 53 |
+
assert.NotNil(ctx)
|
| 54 |
+
|
| 55 |
+
}
|
bindings/go/pkg/whisper/doc.go
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
This is the higher-level speech-to-text whisper.cpp API for go
|
| 3 |
+
*/
|
| 4 |
+
package whisper
|
bindings/go/pkg/whisper/interface.go
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package whisper
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"io"
|
| 5 |
+
"time"
|
| 6 |
+
)
|
| 7 |
+
|
| 8 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 9 |
+
// TYPES
|
| 10 |
+
|
| 11 |
+
// SegmentCallback is the callback function for processing segments in real
|
| 12 |
+
// time. It is called during the Process function
|
| 13 |
+
type SegmentCallback func(Segment)
|
| 14 |
+
|
| 15 |
+
// Model is the interface to a whisper model. Create a new model with the
|
| 16 |
+
// function whisper.New(string)
|
| 17 |
+
type Model interface {
|
| 18 |
+
io.Closer
|
| 19 |
+
|
| 20 |
+
// Return a new speech-to-text context.
|
| 21 |
+
NewContext() (Context, error)
|
| 22 |
+
|
| 23 |
+
// Return all languages supported.
|
| 24 |
+
Languages() []string
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
// Context is the speech recognition context.
|
| 28 |
+
type Context interface {
|
| 29 |
+
SetLanguage(string) error // Set the language to use for speech recognition.
|
| 30 |
+
Language() string // Get language
|
| 31 |
+
SetSpeedup(bool) // Set speedup flag
|
| 32 |
+
|
| 33 |
+
// Process mono audio data and return any errors.
|
| 34 |
+
// If defined, newly generated segments are passed to the
|
| 35 |
+
// callback function during processing.
|
| 36 |
+
Process([]float32, SegmentCallback) error
|
| 37 |
+
|
| 38 |
+
// After process is called, return segments until the end of the stream
|
| 39 |
+
// is reached, when io.EOF is returned.
|
| 40 |
+
NextSegment() (Segment, error)
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
// Segment is the text result of a speech recognition.
|
| 44 |
+
type Segment struct {
|
| 45 |
+
// Segment Number
|
| 46 |
+
Num int
|
| 47 |
+
|
| 48 |
+
// Time beginning and end timestamps for the segment.
|
| 49 |
+
Start, End time.Duration
|
| 50 |
+
|
| 51 |
+
// The text of the segment.
|
| 52 |
+
Text string
|
| 53 |
+
|
| 54 |
+
// The tokens of the segment.
|
| 55 |
+
Tokens []Token
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
// Token is a text or special token
|
| 59 |
+
type Token struct {
|
| 60 |
+
Id int
|
| 61 |
+
Text string
|
| 62 |
+
P float32
|
| 63 |
+
}
|
bindings/go/pkg/whisper/model.go
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package whisper
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"fmt"
|
| 5 |
+
"os"
|
| 6 |
+
"runtime"
|
| 7 |
+
|
| 8 |
+
// Bindings
|
| 9 |
+
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 13 |
+
// TYPES
|
| 14 |
+
|
| 15 |
+
type model struct {
|
| 16 |
+
path string
|
| 17 |
+
ctx *whisper.Context
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
// Make sure model adheres to the interface
|
| 21 |
+
var _ Model = (*model)(nil)
|
| 22 |
+
|
| 23 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 24 |
+
// LIFECYCLE
|
| 25 |
+
|
| 26 |
+
func New(path string) (*model, error) {
|
| 27 |
+
model := new(model)
|
| 28 |
+
if _, err := os.Stat(path); err != nil {
|
| 29 |
+
return nil, err
|
| 30 |
+
} else if ctx := whisper.Whisper_init(path); ctx == nil {
|
| 31 |
+
return nil, ErrUnableToLoadModel
|
| 32 |
+
} else {
|
| 33 |
+
model.ctx = ctx
|
| 34 |
+
model.path = path
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
// Return success
|
| 38 |
+
return model, nil
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
func (model *model) Close() error {
|
| 42 |
+
if model.ctx != nil {
|
| 43 |
+
model.ctx.Whisper_free()
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
// Release resources
|
| 47 |
+
model.ctx = nil
|
| 48 |
+
|
| 49 |
+
// Return success
|
| 50 |
+
return nil
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 54 |
+
// STRINGIFY
|
| 55 |
+
|
| 56 |
+
func (model *model) String() string {
|
| 57 |
+
str := "<whisper.model"
|
| 58 |
+
if model.ctx != nil {
|
| 59 |
+
str += fmt.Sprintf(" model=%q", model.path)
|
| 60 |
+
}
|
| 61 |
+
return str + ">"
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 65 |
+
// PUBLIC METHODS
|
| 66 |
+
|
| 67 |
+
// Return all recognized languages. Initially it is set to auto-detect
|
| 68 |
+
func (model *model) Languages() []string {
|
| 69 |
+
result := make([]string, 0, whisper.Whisper_lang_max_id())
|
| 70 |
+
for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
|
| 71 |
+
str := whisper.Whisper_lang_str(i)
|
| 72 |
+
if model.ctx.Whisper_lang_id(str) >= 0 {
|
| 73 |
+
result = append(result, str)
|
| 74 |
+
}
|
| 75 |
+
}
|
| 76 |
+
return result
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
func (model *model) NewContext() (Context, error) {
|
| 80 |
+
if model.ctx == nil {
|
| 81 |
+
return nil, ErrInternalAppError
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
// Create new context
|
| 85 |
+
params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
|
| 86 |
+
params.SetTranslate(false)
|
| 87 |
+
params.SetPrintSpecial(false)
|
| 88 |
+
params.SetPrintProgress(false)
|
| 89 |
+
params.SetPrintRealtime(false)
|
| 90 |
+
params.SetPrintTimestamps(false)
|
| 91 |
+
params.SetThreads(runtime.NumCPU())
|
| 92 |
+
|
| 93 |
+
// Return new context
|
| 94 |
+
return NewContext(model, params)
|
| 95 |
+
}
|
bindings/go/samples/jfk.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:59dfb9a4acb36fe2a2affc14bacbee2920ff435cb13cc314a08c13f66ba7860e
|
| 3 |
+
size 352078
|
bindings/go/whisper.go
ADDED
|
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package whisper
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"errors"
|
| 5 |
+
"unsafe"
|
| 6 |
+
)
|
| 7 |
+
|
| 8 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 9 |
+
// CGO
|
| 10 |
+
|
| 11 |
+
/*
|
| 12 |
+
#cgo CFLAGS: -I${SRCDIR}/../..
|
| 13 |
+
#cgo LDFLAGS: -L${SRCDIR}/build -lwhisper -lm -lstdc++
|
| 14 |
+
#cgo darwin LDFLAGS: -framework Accelerate
|
| 15 |
+
#include <whisper.h>
|
| 16 |
+
#include <stdlib.h>
|
| 17 |
+
|
| 18 |
+
extern void callNewSegment(void* user_data, int new);
|
| 19 |
+
extern bool callEncoderBegin(void* user_data);
|
| 20 |
+
|
| 21 |
+
// Text segment callback
|
| 22 |
+
// Called on every newly generated text segment
|
| 23 |
+
// Use the whisper_full_...() functions to obtain the text segments
|
| 24 |
+
static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
|
| 25 |
+
if(user_data != NULL && ctx != NULL) {
|
| 26 |
+
callNewSegment(user_data, n_new);
|
| 27 |
+
}
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
// Encoder begin callback
|
| 31 |
+
// If not NULL, called before the encoder starts
|
| 32 |
+
// If it returns false, the computation is aborted
|
| 33 |
+
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
|
| 34 |
+
if(user_data != NULL && ctx != NULL) {
|
| 35 |
+
return callEncoderBegin(user_data);
|
| 36 |
+
}
|
| 37 |
+
return false;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
// Get default parameters and set callbacks
|
| 41 |
+
static struct whisper_full_params whisper_full_default_params_cb(struct whisper_context* ctx, enum whisper_sampling_strategy strategy) {
|
| 42 |
+
struct whisper_full_params params = whisper_full_default_params(strategy);
|
| 43 |
+
params.new_segment_callback = whisper_new_segment_cb;
|
| 44 |
+
params.new_segment_callback_user_data = (void*)(ctx);
|
| 45 |
+
params.encoder_begin_callback = whisper_encoder_begin_cb;
|
| 46 |
+
params.encoder_begin_callback_user_data = (void*)(ctx);
|
| 47 |
+
return params;
|
| 48 |
+
}
|
| 49 |
+
*/
|
| 50 |
+
import "C"
|
| 51 |
+
|
| 52 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
// TYPES
|
| 54 |
+
|
| 55 |
+
type (
|
| 56 |
+
Context C.struct_whisper_context
|
| 57 |
+
Token C.whisper_token
|
| 58 |
+
TokenData C.struct_whisper_token_data
|
| 59 |
+
SamplingStrategy C.enum_whisper_sampling_strategy
|
| 60 |
+
Params C.struct_whisper_full_params
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 64 |
+
// GLOBALS
|
| 65 |
+
|
| 66 |
+
const (
|
| 67 |
+
SAMPLING_GREEDY SamplingStrategy = C.WHISPER_SAMPLING_GREEDY
|
| 68 |
+
SAMPLING_BEAM_SEARCH SamplingStrategy = C.WHISPER_SAMPLING_BEAM_SEARCH
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
const (
|
| 72 |
+
SampleRate = C.WHISPER_SAMPLE_RATE // Expected sample rate, samples per second
|
| 73 |
+
SampleBits = uint16(unsafe.Sizeof(C.float(0))) * 8 // Sample size in bits
|
| 74 |
+
NumFFT = C.WHISPER_N_FFT
|
| 75 |
+
NumMEL = C.WHISPER_N_MEL
|
| 76 |
+
HopLength = C.WHISPER_HOP_LENGTH
|
| 77 |
+
ChunkSize = C.WHISPER_CHUNK_SIZE
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
var (
|
| 81 |
+
ErrTokenizerFailed = errors.New("whisper_tokenize failed")
|
| 82 |
+
ErrAutoDetectFailed = errors.New("whisper_lang_auto_detect failed")
|
| 83 |
+
ErrConversionFailed = errors.New("whisper_convert failed")
|
| 84 |
+
ErrInvalidLanguage = errors.New("invalid language")
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 88 |
+
// PUBLIC METHODS
|
| 89 |
+
|
| 90 |
+
// Allocates all memory needed for the model and loads the model from the given file.
|
| 91 |
+
// Returns NULL on failure.
|
| 92 |
+
func Whisper_init(path string) *Context {
|
| 93 |
+
cPath := C.CString(path)
|
| 94 |
+
defer C.free(unsafe.Pointer(cPath))
|
| 95 |
+
if ctx := C.whisper_init(cPath); ctx != nil {
|
| 96 |
+
return (*Context)(ctx)
|
| 97 |
+
} else {
|
| 98 |
+
return nil
|
| 99 |
+
}
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
// Frees all memory allocated by the model.
|
| 103 |
+
func (ctx *Context) Whisper_free() {
|
| 104 |
+
C.whisper_free((*C.struct_whisper_context)(ctx))
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
// Convert RAW PCM audio to log mel spectrogram.
|
| 108 |
+
// The resulting spectrogram is stored inside the provided whisper context.
|
| 109 |
+
func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
|
| 110 |
+
if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 {
|
| 111 |
+
return nil
|
| 112 |
+
} else {
|
| 113 |
+
return ErrConversionFailed
|
| 114 |
+
}
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
// This can be used to set a custom log mel spectrogram inside the provided whisper context.
|
| 118 |
+
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
|
| 119 |
+
// n_mel must be 80
|
| 120 |
+
func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error {
|
| 121 |
+
if C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 {
|
| 122 |
+
return nil
|
| 123 |
+
} else {
|
| 124 |
+
return ErrConversionFailed
|
| 125 |
+
}
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
|
| 129 |
+
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
| 130 |
+
// offset can be used to specify the offset of the first frame in the spectrogram.
|
| 131 |
+
func (ctx *Context) Whisper_encode(offset, threads int) error {
|
| 132 |
+
if C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads)) == 0 {
|
| 133 |
+
return nil
|
| 134 |
+
} else {
|
| 135 |
+
return ErrConversionFailed
|
| 136 |
+
}
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
|
| 140 |
+
// Make sure to call whisper_encode() first.
|
| 141 |
+
// tokens + n_tokens is the provided context for the decoder.
|
| 142 |
+
// n_past is the number of tokens to use from previous decoder calls.
|
| 143 |
+
func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error {
|
| 144 |
+
if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 {
|
| 145 |
+
return nil
|
| 146 |
+
} else {
|
| 147 |
+
return ErrConversionFailed
|
| 148 |
+
}
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
// whisper_sample_best() returns the token with the highest probability
|
| 152 |
+
func (ctx *Context) Whisper_sample_best() TokenData {
|
| 153 |
+
return TokenData(C.whisper_sample_best((*C.struct_whisper_context)(ctx)))
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
// whisper_sample_timestamp() returns the most probable timestamp token
|
| 157 |
+
func (ctx *Context) Whisper_sample_timestamp(is_initial bool) TokenData {
|
| 158 |
+
return TokenData(C.whisper_sample_timestamp((*C.struct_whisper_context)(ctx), C.bool(is_initial)))
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
// Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens.
|
| 162 |
+
// Returns the number of tokens on success
|
| 163 |
+
func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
|
| 164 |
+
cText := C.CString(text)
|
| 165 |
+
defer C.free(unsafe.Pointer(cText))
|
| 166 |
+
if n := C.whisper_tokenize((*C.struct_whisper_context)(ctx), cText, (*C.whisper_token)(&tokens[0]), C.int(len(tokens))); n >= 0 {
|
| 167 |
+
return int(n), nil
|
| 168 |
+
} else {
|
| 169 |
+
return 0, ErrTokenizerFailed
|
| 170 |
+
}
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
// Return the id of the specified language, returns -1 if not found
|
| 174 |
+
func (ctx *Context) Whisper_lang_id(lang string) int {
|
| 175 |
+
return int(C.whisper_lang_id(C.CString(lang)))
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
// Largest language id (i.e. number of available languages - 1)
|
| 179 |
+
func Whisper_lang_max_id() int {
|
| 180 |
+
return int(C.whisper_lang_max_id())
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
// Return the short string of the specified language id (e.g. 2 -> "de"),
|
| 184 |
+
// returns empty string if not found
|
| 185 |
+
func Whisper_lang_str(id int) string {
|
| 186 |
+
return C.GoString(C.whisper_lang_str(C.int(id)))
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
// Use mel data at offset_ms to try and auto-detect the spoken language
|
| 190 |
+
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
| 191 |
+
// Returns the probabilities of all languages.
|
| 192 |
+
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
|
| 193 |
+
func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float32, error) {
|
| 194 |
+
probs := make([]float32, Whisper_lang_max_id()+1)
|
| 195 |
+
if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 {
|
| 196 |
+
return nil, ErrAutoDetectFailed
|
| 197 |
+
} else {
|
| 198 |
+
return probs, nil
|
| 199 |
+
}
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
func (ctx *Context) Whisper_n_len() int {
|
| 203 |
+
return int(C.whisper_n_len((*C.struct_whisper_context)(ctx)))
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
func (ctx *Context) Whisper_n_vocab() int {
|
| 207 |
+
return int(C.whisper_n_vocab((*C.struct_whisper_context)(ctx)))
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
func (ctx *Context) Whisper_n_text_ctx() int {
|
| 211 |
+
return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
func (ctx *Context) Whisper_is_multilingual() int {
|
| 215 |
+
return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
// The probabilities for the next token
|
| 219 |
+
//func (ctx *Whisper_context) Whisper_get_probs() []float32 {
|
| 220 |
+
// return (*[1 << 30]float32)(unsafe.Pointer(C.whisper_get_probs((*C.struct_whisper_context)(ctx))))[:ctx.Whisper_n_vocab()]
|
| 221 |
+
//}
|
| 222 |
+
|
| 223 |
+
// Token Id -> String. Uses the vocabulary in the provided context
|
| 224 |
+
func (ctx *Context) Whisper_token_to_str(token Token) string {
|
| 225 |
+
return C.GoString(C.whisper_token_to_str((*C.struct_whisper_context)(ctx), C.whisper_token(token)))
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
// Special tokens
|
| 229 |
+
func (ctx *Context) Whisper_token_eot() Token {
|
| 230 |
+
return Token(C.whisper_token_eot((*C.struct_whisper_context)(ctx)))
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
// Special tokens
|
| 234 |
+
func (ctx *Context) Whisper_token_sot() Token {
|
| 235 |
+
return Token(C.whisper_token_sot((*C.struct_whisper_context)(ctx)))
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
// Special tokens
|
| 239 |
+
func (ctx *Context) Whisper_token_prev() Token {
|
| 240 |
+
return Token(C.whisper_token_prev((*C.struct_whisper_context)(ctx)))
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
// Special tokens
|
| 244 |
+
func (ctx *Context) Whisper_token_solm() Token {
|
| 245 |
+
return Token(C.whisper_token_solm((*C.struct_whisper_context)(ctx)))
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
// Special tokens
|
| 249 |
+
func (ctx *Context) Whisper_token_not() Token {
|
| 250 |
+
return Token(C.whisper_token_not((*C.struct_whisper_context)(ctx)))
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
// Special tokens
|
| 254 |
+
func (ctx *Context) Whisper_token_beg() Token {
|
| 255 |
+
return Token(C.whisper_token_beg((*C.struct_whisper_context)(ctx)))
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
// Special tokens
|
| 259 |
+
func (ctx *Context) Whisper_token_lang(lang_id int) Token {
|
| 260 |
+
return Token(C.whisper_token_lang((*C.struct_whisper_context)(ctx), C.int(lang_id)))
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
// Task tokens
|
| 264 |
+
func Whisper_token_translate() Token {
|
| 265 |
+
return Token(C.whisper_token_translate())
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
// Task tokens
|
| 269 |
+
func Whisper_token_transcribe() Token {
|
| 270 |
+
return Token(C.whisper_token_transcribe())
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
// Performance information
|
| 274 |
+
func (ctx *Context) Whisper_print_timings() {
|
| 275 |
+
C.whisper_print_timings((*C.struct_whisper_context)(ctx))
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
// Performance information
|
| 279 |
+
func (ctx *Context) Whisper_reset_timings() {
|
| 280 |
+
C.whisper_reset_timings((*C.struct_whisper_context)(ctx))
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
// Print system information
|
| 284 |
+
func Whisper_print_system_info() string {
|
| 285 |
+
return C.GoString(C.whisper_print_system_info())
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
// Return default parameters for a strategy
|
| 289 |
+
func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Params {
|
| 290 |
+
// Get default parameters
|
| 291 |
+
return Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy)))
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
| 295 |
+
// Uses the specified decoding strategy to obtain the text.
|
| 296 |
+
func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
|
| 297 |
+
registerEncoderBeginCallback(ctx, encoderBeginCallback)
|
| 298 |
+
registerNewSegmentCallback(ctx, newSegmentCallback)
|
| 299 |
+
defer registerEncoderBeginCallback(ctx, nil)
|
| 300 |
+
defer registerNewSegmentCallback(ctx, nil)
|
| 301 |
+
if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
|
| 302 |
+
return nil
|
| 303 |
+
} else {
|
| 304 |
+
return ErrConversionFailed
|
| 305 |
+
}
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
// Split the input audio in chunks and process each chunk separately using whisper_full()
|
| 309 |
+
// It seems this approach can offer some speedup in some cases.
|
| 310 |
+
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
|
| 311 |
+
func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
|
| 312 |
+
registerEncoderBeginCallback(ctx, encoderBeginCallback)
|
| 313 |
+
registerNewSegmentCallback(ctx, newSegmentCallback)
|
| 314 |
+
defer registerEncoderBeginCallback(ctx, nil)
|
| 315 |
+
defer registerNewSegmentCallback(ctx, nil)
|
| 316 |
+
|
| 317 |
+
if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 {
|
| 318 |
+
return nil
|
| 319 |
+
} else {
|
| 320 |
+
return ErrConversionFailed
|
| 321 |
+
}
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
// Number of generated text segments.
|
| 325 |
+
// A segment can be a few words, a sentence, or even a paragraph.
|
| 326 |
+
func (ctx *Context) Whisper_full_n_segments() int {
|
| 327 |
+
return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx)))
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
// Get the start time of the specified segment.
|
| 331 |
+
func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 {
|
| 332 |
+
return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment)))
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
// Get the start and end time of the specified segment.
|
| 336 |
+
func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 {
|
| 337 |
+
return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment)))
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
// Get the text of the specified segment.
|
| 341 |
+
func (ctx *Context) Whisper_full_get_segment_text(segment int) string {
|
| 342 |
+
return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment)))
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
// Get number of tokens in the specified segment.
|
| 346 |
+
func (ctx *Context) Whisper_full_n_tokens(segment int) int {
|
| 347 |
+
return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment)))
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
// Get the token text of the specified token index in the specified segment.
|
| 351 |
+
func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string {
|
| 352 |
+
return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
// Get the token of the specified token index in the specified segment.
|
| 356 |
+
func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
|
| 357 |
+
return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
// Get token data for the specified token in the specified segment.
|
| 361 |
+
// This contains probabilities, timestamps, etc.
|
| 362 |
+
func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData {
|
| 363 |
+
return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
// Get the probability of the specified token in the specified segment.
|
| 367 |
+
func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
|
| 368 |
+
return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 372 |
+
// CALLBACKS
|
| 373 |
+
|
| 374 |
+
// Callback registries. Each map is keyed by the pointer value of a *Context
// and holds the Go function to invoke when the matching exported hook
// (callNewSegment / callEncoderBegin) is called with that pointer.
// NOTE(review): these maps are accessed without any locking - confirm that
// contexts are never run concurrently from different goroutines.
var (
	cbNewSegment   = map[unsafe.Pointer]func(int){}
	cbEncoderBegin = map[unsafe.Pointer]func() bool{}
)
|
| 378 |
+
|
| 379 |
+
func registerNewSegmentCallback(ctx *Context, fn func(int)) {
|
| 380 |
+
if fn == nil {
|
| 381 |
+
delete(cbNewSegment, unsafe.Pointer(ctx))
|
| 382 |
+
} else {
|
| 383 |
+
cbNewSegment[unsafe.Pointer(ctx)] = fn
|
| 384 |
+
}
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
|
| 388 |
+
if fn == nil {
|
| 389 |
+
delete(cbEncoderBegin, unsafe.Pointer(ctx))
|
| 390 |
+
} else {
|
| 391 |
+
cbEncoderBegin[unsafe.Pointer(ctx)] = fn
|
| 392 |
+
}
|
| 393 |
+
}
|
| 394 |
+
|
| 395 |
+
//export callNewSegment
|
| 396 |
+
func callNewSegment(user_data unsafe.Pointer, new C.int) {
|
| 397 |
+
if fn, ok := cbNewSegment[user_data]; ok {
|
| 398 |
+
fn(int(new))
|
| 399 |
+
}
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
//export callEncoderBegin
|
| 403 |
+
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
|
| 404 |
+
if fn, ok := cbEncoderBegin[user_data]; ok {
|
| 405 |
+
if fn() {
|
| 406 |
+
return C.bool(true)
|
| 407 |
+
} else {
|
| 408 |
+
return C.bool(false)
|
| 409 |
+
}
|
| 410 |
+
}
|
| 411 |
+
return true
|
| 412 |
+
}
|
bindings/go/whisper_test.go
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package whisper_test
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"os"
|
| 5 |
+
"runtime"
|
| 6 |
+
"testing"
|
| 7 |
+
"time"
|
| 8 |
+
|
| 9 |
+
// Packages
|
| 10 |
+
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
| 11 |
+
wav "github.com/go-audio/wav"
|
| 12 |
+
assert "github.com/stretchr/testify/assert"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
// ModelPath is the model file the tests load; tests that need it are skipped
// when the file does not exist.
const ModelPath = "models/ggml-small.en.bin"

// SamplePath is the WAV sample the tests decode; tests that need it are
// skipped when the file does not exist.
const SamplePath = "samples/jfk.wav"
|
| 19 |
+
|
| 20 |
+
func Test_Whisper_000(t *testing.T) {
|
| 21 |
+
assert := assert.New(t)
|
| 22 |
+
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
| 23 |
+
t.Skip("Skipping test, model not found:", ModelPath)
|
| 24 |
+
}
|
| 25 |
+
ctx := whisper.Whisper_init(ModelPath)
|
| 26 |
+
assert.NotNil(ctx)
|
| 27 |
+
ctx.Whisper_free()
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
func Test_Whisper_001(t *testing.T) {
|
| 31 |
+
assert := assert.New(t)
|
| 32 |
+
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
| 33 |
+
t.Skip("Skipping test, model not found:", ModelPath)
|
| 34 |
+
}
|
| 35 |
+
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
| 36 |
+
t.Skip("Skipping test, sample not found:", SamplePath)
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
// Open samples
|
| 40 |
+
fh, err := os.Open(SamplePath)
|
| 41 |
+
assert.NoError(err)
|
| 42 |
+
defer fh.Close()
|
| 43 |
+
|
| 44 |
+
// Read samples
|
| 45 |
+
d := wav.NewDecoder(fh)
|
| 46 |
+
buf, err := d.FullPCMBuffer()
|
| 47 |
+
assert.NoError(err)
|
| 48 |
+
|
| 49 |
+
// Run whisper
|
| 50 |
+
ctx := whisper.Whisper_init(ModelPath)
|
| 51 |
+
assert.NotNil(ctx)
|
| 52 |
+
defer ctx.Whisper_free()
|
| 53 |
+
assert.NoError(ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf.AsFloat32Buffer().Data, nil, nil))
|
| 54 |
+
|
| 55 |
+
// Print out tokens
|
| 56 |
+
num_segments := ctx.Whisper_full_n_segments()
|
| 57 |
+
assert.GreaterOrEqual(num_segments, 1)
|
| 58 |
+
for i := 0; i < num_segments; i++ {
|
| 59 |
+
str := ctx.Whisper_full_get_segment_text(i)
|
| 60 |
+
assert.NotEmpty(str)
|
| 61 |
+
t0 := time.Duration(ctx.Whisper_full_get_segment_t0(i)) * time.Millisecond
|
| 62 |
+
t1 := time.Duration(ctx.Whisper_full_get_segment_t1(i)) * time.Millisecond
|
| 63 |
+
t.Logf("[%6s->%-6s] %q", t0, t1, str)
|
| 64 |
+
}
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
func Test_Whisper_002(t *testing.T) {
|
| 68 |
+
assert := assert.New(t)
|
| 69 |
+
for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
|
| 70 |
+
str := whisper.Whisper_lang_str(i)
|
| 71 |
+
assert.NotEmpty(str)
|
| 72 |
+
t.Log(str)
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
func Test_Whisper_003(t *testing.T) {
|
| 77 |
+
threads := runtime.NumCPU()
|
| 78 |
+
assert := assert.New(t)
|
| 79 |
+
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
| 80 |
+
t.Skip("Skipping test, model not found:", ModelPath)
|
| 81 |
+
}
|
| 82 |
+
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
| 83 |
+
t.Skip("Skipping test, sample not found:", SamplePath)
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
// Open samples
|
| 87 |
+
fh, err := os.Open(SamplePath)
|
| 88 |
+
assert.NoError(err)
|
| 89 |
+
defer fh.Close()
|
| 90 |
+
|
| 91 |
+
// Read samples
|
| 92 |
+
d := wav.NewDecoder(fh)
|
| 93 |
+
buf, err := d.FullPCMBuffer()
|
| 94 |
+
assert.NoError(err)
|
| 95 |
+
|
| 96 |
+
// Make the model
|
| 97 |
+
ctx := whisper.Whisper_init(ModelPath)
|
| 98 |
+
assert.NotNil(ctx)
|
| 99 |
+
defer ctx.Whisper_free()
|
| 100 |
+
|
| 101 |
+
// Get MEL
|
| 102 |
+
assert.NoError(ctx.Whisper_pcm_to_mel(buf.AsFloat32Buffer().Data, threads))
|
| 103 |
+
|
| 104 |
+
// Get Languages
|
| 105 |
+
languages, err := ctx.Whisper_lang_auto_detect(0, threads)
|
| 106 |
+
assert.NoError(err)
|
| 107 |
+
for i, p := range languages {
|
| 108 |
+
t.Logf("%s: %f", whisper.Whisper_lang_str(i), p)
|
| 109 |
+
}
|
| 110 |
+
}
|