djthorpe commited on
Commit
5e456c0
·
unverified ·
1 Parent(s): 9fc9845

bindings : initial import of golang bindings (#287)

Browse files

* Initial import of golang bindings

* Updated makefile rules

* Updated bindings

* Makefile update to add in more tests

bindings/go/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ build
2
+ models
3
+ go.sum
bindings/go/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 David Thorpe
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
bindings/go/Makefile ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CMAKE := $(shell which cmake)
2
+ BUILD_DIR := "build"
3
+ MODELS_DIR := "models"
4
+ EXAMPLES_DIR := $(wildcard examples/*)
5
+ C_INCLUDE_PATH := "../.."
6
+
7
+ all: clean whisper examples
8
+
9
+ whisper: mkdir
10
+ @echo Build whisper
11
+ @${CMAKE} -S ../.. -B ${BUILD_DIR} -D BUILD_SHARED_LIBS=off -D WHISPER_NO_AVX2=on
12
+ @${CMAKE} --build ${BUILD_DIR} --target whisper
13
+
14
+ test: model-small whisper
15
+ @go mod tidy
16
+ @go test -v .
17
+ @go test -v ./pkg/whisper/...
18
+
19
+ examples: $(EXAMPLES_DIR)
20
+
21
+ model-small: mkdir examples/go-model-download
22
+ @${BUILD_DIR}/go-model-download -out models small.en
23
+
24
+ $(EXAMPLES_DIR): mkdir whisper
25
+ @echo Build example $(notdir $@)
26
+ @go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
27
+
28
+ mkdir:
29
+ @echo Mkdir ${BUILD_DIR}
30
+ @install -d ${BUILD_DIR}
31
+ @echo Mkdir ${MODELS_DIR}
32
+ @install -d ${MODELS_DIR}
33
+
34
+ clean:
35
+ @echo Clean
36
+ @rm -fr $(BUILD_DIR)
37
+ @go mod tidy
38
+ @go clean
bindings/go/README.md ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Go bindings for Whisper
2
+
3
+ This package provides Go bindings for whisper.cpp. They have been tested on:
4
+
5
+ * Darwin (OS X) 12.6 on x64_64
6
+ * Debian Linux on arm64
7
+ * Fedora Linux on x86_64
8
+
9
+ The "low level" bindings are in the `bindings/go` directory and there is a more
10
+ Go-style package in the `bindings/go/pkg/whisper` directory. The most simple usage
11
+ is as follows:
12
+
13
+ ```go
14
+ import (
15
+ "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
16
+ )
17
+
18
+ func main() {
19
+ var modelpath string // Path to the model
20
+ var samples []float32 // Samples to process
21
+
22
+ // Load the model
23
+ model, err := whisper.New(modelpath)
24
+ if err != nil {
25
+ panic(err)
26
+ }
27
+ defer model.Close()
28
+
29
+ // Process samples
30
+ context, err := model.NewContext()
31
+ if err != nil {
32
+ panic(err)
33
+ }
34
+ if err := context.Process(samples, nil); err != nil {
35
+ return err
36
+ }
37
+
38
+ // Print out the results
39
+ for {
40
+ segment, err := context.NextSegment()
41
+ if err != nil {
42
+ break
43
+ }
44
+ fmt.Printf("[%6s->%6s] %s\n", segment.Start, segment.End, segment.Text)
45
+ }
46
+ }
47
+ ```
48
+
49
+ ## Building & Testing
50
+
51
+ In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with:
52
+
53
+ ```bash
54
+ git clone https://github.com/ggerganov/whisper.cpp.git
55
+ cd whisper.cpp/bindings/go
56
+ make test
57
+ ```
58
+
59
+ This will compile a static `libwhisper.a` in a `build` folder, download a model file, then run the tests. To build the examples:
60
+
61
+ ```bash
62
+ make examples
63
+ ```
64
+
65
+ The examples are placed in the `build` directory. Once built, you can download all the models with the following command:
66
+
67
+ ```bash
68
+ ./build/go-model-download -out models
69
+ ```
70
+
71
+ And you can then test a model against samples with the following command:
72
+
73
+ ```bash
74
+ ./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
75
+ ```
76
+
77
+
bindings/go/doc.go ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ /*
2
+ github.com/ggerganov/whisper.cpp/bindings/go
3
+ provides a speech-to-text service bindings for the Go programming language.
4
+ */
5
+ package whisper
bindings/go/examples/go-model-download/context.go ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "context"
5
+ "os"
6
+ "os/signal"
7
+ )
8
+
9
+ // ContextForSignal returns a context object which is cancelled when a signal
10
+ // is received. It returns nil if no signal parameter is provided
11
+ func ContextForSignal(signals ...os.Signal) context.Context {
12
+ if len(signals) == 0 {
13
+ return nil
14
+ }
15
+
16
+ ch := make(chan os.Signal)
17
+ ctx, cancel := context.WithCancel(context.Background())
18
+
19
+ // Send message on channel when signal received
20
+ signal.Notify(ch, signals...)
21
+
22
+ // When any signal received, call cancel
23
+ go func() {
24
+ <-ch
25
+ cancel()
26
+ }()
27
+
28
+ // Return success
29
+ return ctx
30
+ }
bindings/go/examples/go-model-download/main.go ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "context"
5
+ "flag"
6
+ "fmt"
7
+ "io"
8
+ "net/http"
9
+ "net/url"
10
+ "os"
11
+ "path/filepath"
12
+ "syscall"
13
+ "time"
14
+ )
15
+
16
+ ///////////////////////////////////////////////////////////////////////////////
17
+ // CONSTANTS
18
+
19
+ const (
20
+ srcUrl = "https://huggingface.co/" // The location of the models
21
+ srcPathPrefix = "/datasets/ggerganov/whisper.cpp/resolve/main/ggml" // Filename prefix
22
+ srcExt = ".bin" // Filename extension
23
+ bufSize = 1024 * 64 // Size of the buffer used for downloading the model
24
+ )
25
+
26
+ var (
27
+ // The models which will be downloaded, if no model is specified as an argument
28
+ modelNames = []string{"tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium", "large-v1", "large"}
29
+ )
30
+
31
+ var (
32
+ // The output folder. When not set, use current working directory.
33
+ flagOut = flag.String("out", "", "Output folder")
34
+
35
+ // HTTP timeout parameter - will timeout if takes longer than this to download a model
36
+ flagTimeout = flag.Duration("timeout", 30*time.Minute, "HTTP timeout")
37
+
38
+ // Quiet parameter - will not print progress if set
39
+ flagQuiet = flag.Bool("quiet", false, "Quiet mode")
40
+ )
41
+
42
+ ///////////////////////////////////////////////////////////////////////////////
43
+ // MAIN
44
+
45
+ func main() {
46
+ flag.Usage = func() {
47
+ name := filepath.Base(flag.CommandLine.Name())
48
+ fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] <model>\n\n", name)
49
+ flag.PrintDefaults()
50
+ }
51
+ flag.Parse()
52
+
53
+ // Get output path
54
+ out, err := GetOut()
55
+ if err != nil {
56
+ fmt.Fprintln(os.Stderr, "Error:", err)
57
+ os.Exit(-1)
58
+ }
59
+
60
+ // Create context which quits on SIGINT or SIGQUIT
61
+ ctx := ContextForSignal(os.Interrupt, syscall.SIGQUIT)
62
+
63
+ // Progress filehandle
64
+ progress := os.Stdout
65
+ if *flagQuiet {
66
+ progress, err = os.Open(os.DevNull)
67
+ if err != nil {
68
+ fmt.Fprintln(os.Stderr, "Error:", err)
69
+ os.Exit(-1)
70
+ }
71
+ defer progress.Close()
72
+ }
73
+
74
+ // Download models - exit on error or interrupt
75
+ for _, model := range GetModels() {
76
+ url, err := URLForModel(model)
77
+ if err != nil {
78
+ fmt.Fprintln(os.Stderr, "Error:", err)
79
+ continue
80
+ } else if path, err := Download(ctx, progress, url, out); err == nil || err == io.EOF {
81
+ continue
82
+ } else if err == context.Canceled {
83
+ os.Remove(path)
84
+ fmt.Fprintln(progress, "\nInterrupted")
85
+ break
86
+ } else if err == context.DeadlineExceeded {
87
+ os.Remove(path)
88
+ fmt.Fprintln(progress, "Timeout downloading model")
89
+ continue
90
+ } else {
91
+ os.Remove(path)
92
+ fmt.Fprintln(os.Stderr, "Error:", err)
93
+ break
94
+ }
95
+ }
96
+ }
97
+
98
+ ///////////////////////////////////////////////////////////////////////////////
99
+ // PUBLIC METHODS
100
+
101
+ // GetOut returns the path to the output directory
102
+ func GetOut() (string, error) {
103
+ if *flagOut == "" {
104
+ return os.Getwd()
105
+ }
106
+ if info, err := os.Stat(*flagOut); err != nil {
107
+ return "", err
108
+ } else if !info.IsDir() {
109
+ return "", fmt.Errorf("not a directory: %s", info.Name())
110
+ } else {
111
+ return *flagOut, nil
112
+ }
113
+ }
114
+
115
+ // GetModels returns the list of models to download
116
+ func GetModels() []string {
117
+ if flag.NArg() == 0 {
118
+ return modelNames
119
+ } else {
120
+ return flag.Args()
121
+ }
122
+ }
123
+
124
+ // URLForModel returns the URL for the given model on huggingface.co
125
+ func URLForModel(model string) (string, error) {
126
+ url, err := url.Parse(srcUrl)
127
+ if err != nil {
128
+ return "", err
129
+ } else {
130
+ url.Path = srcPathPrefix + "-" + model + srcExt
131
+ }
132
+ return url.String(), nil
133
+ }
134
+
135
+ // Download downloads the model from the given URL to the given output directory
136
+ func Download(ctx context.Context, p io.Writer, model, out string) (string, error) {
137
+ // Create HTTP client
138
+ client := http.Client{
139
+ Timeout: *flagTimeout,
140
+ }
141
+
142
+ // Initiate the download
143
+ req, err := http.NewRequest("GET", model, nil)
144
+ if err != nil {
145
+ return "", err
146
+ }
147
+ resp, err := client.Do(req)
148
+ if err != nil {
149
+ return "", err
150
+ }
151
+ defer resp.Body.Close()
152
+ if resp.StatusCode != http.StatusOK {
153
+ return "", fmt.Errorf("%s: %s", model, resp.Status)
154
+ }
155
+
156
+ // If output file exists and is the same size as the model, skip
157
+ path := filepath.Join(out, filepath.Base(model))
158
+ if info, err := os.Stat(path); err == nil && info.Size() == resp.ContentLength {
159
+ fmt.Fprintln(p, "Skipping", model, "as it already exists")
160
+ return "", nil
161
+ }
162
+
163
+ // Create file
164
+ w, err := os.Create(path)
165
+ if err != nil {
166
+ return "", err
167
+ }
168
+ defer w.Close()
169
+
170
+ // Report
171
+ fmt.Fprintln(p, "Downloading", model, "to", out)
172
+
173
+ // Progressively download the model
174
+ data := make([]byte, bufSize)
175
+ count, pct := int64(0), int64(0)
176
+ ticker := time.NewTicker(5 * time.Second)
177
+ for {
178
+ select {
179
+ case <-ctx.Done():
180
+ // Cancelled, return error
181
+ return path, ctx.Err()
182
+ case <-ticker.C:
183
+ pct = DownloadReport(p, pct, count, resp.ContentLength)
184
+ default:
185
+ // Read body
186
+ n, err := resp.Body.Read(data)
187
+ if err != nil {
188
+ DownloadReport(p, pct, count, resp.ContentLength)
189
+ return path, err
190
+ } else if m, err := w.Write(data[:n]); err != nil {
191
+ return path, err
192
+ } else {
193
+ count += int64(m)
194
+ }
195
+ }
196
+ }
197
+ }
198
+
199
+ // Report periodically reports the download progress when percentage changes
200
+ func DownloadReport(w io.Writer, pct, count, total int64) int64 {
201
+ pct_ := count * 100 / total
202
+ if pct_ > pct {
203
+ fmt.Fprintf(w, " ...%d MB written (%d%%)\n", count/1e6, pct_)
204
+ }
205
+ return pct_
206
+ }
bindings/go/examples/go-whisper/flags.go ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "flag"
5
+ )
6
+
7
+ ///////////////////////////////////////////////////////////////////////////////
8
+ // TYPES
9
+
10
+ type Flags struct {
11
+ *flag.FlagSet
12
+ }
13
+
14
+ ///////////////////////////////////////////////////////////////////////////////
15
+ // LIFECYCLE
16
+
17
+ func NewFlags(name string, args []string) (*Flags, error) {
18
+ flags := &Flags{
19
+ FlagSet: flag.NewFlagSet(name, flag.ContinueOnError),
20
+ }
21
+
22
+ // Register the command line arguments
23
+ registerFlags(flags)
24
+
25
+ // Parse command line
26
+ if err := flags.Parse(args); err != nil {
27
+ return nil, err
28
+ }
29
+
30
+ // Return success
31
+ return flags, nil
32
+ }
33
+
34
+ ///////////////////////////////////////////////////////////////////////////////
35
+ // PUBLIC METHODS
36
+
37
+ func (flags *Flags) GetModel() string {
38
+ return flags.Lookup("model").Value.String()
39
+ }
40
+
41
+ func (flags *Flags) GetLanguage() string {
42
+ return flags.Lookup("language").Value.String()
43
+ }
44
+
45
+ func (flags *Flags) IsSpeedup() bool {
46
+ return flags.Lookup("speedup").Value.String() == "true"
47
+ }
48
+
49
+ func (flags *Flags) IsTokens() bool {
50
+ return flags.Lookup("tokens").Value.String() == "true"
51
+ }
52
+
53
+ ///////////////////////////////////////////////////////////////////////////////
54
+ // PRIVATE METHODS
55
+
56
+ func registerFlags(flag *Flags) {
57
+ flag.String("model", "", "Path to the model file")
58
+ flag.String("language", "", "Language")
59
+ flag.Bool("speedup", false, "Enable speedup")
60
+ flag.Bool("tokens", false, "Display tokens")
61
+ }
bindings/go/examples/go-whisper/main.go ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "flag"
5
+ "fmt"
6
+ "os"
7
+ "path/filepath"
8
+
9
+ // Packages
10
+ whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
11
+ )
12
+
13
+ func main() {
14
+ flags, err := NewFlags(filepath.Base(os.Args[0]), os.Args[1:])
15
+ if err == flag.ErrHelp {
16
+ os.Exit(0)
17
+ } else if err != nil {
18
+ fmt.Fprintln(os.Stderr, err)
19
+ os.Exit(1)
20
+ } else if flags.GetModel() == "" {
21
+ fmt.Fprintln(os.Stderr, "Use -model flag to specify which model file to use")
22
+ os.Exit(1)
23
+ } else if flags.NArg() == 0 {
24
+ fmt.Fprintln(os.Stderr, "No input files specified")
25
+ os.Exit(1)
26
+ }
27
+
28
+ // Load model
29
+ model, err := whisper.New(flags.GetModel())
30
+ if err != nil {
31
+ fmt.Fprintln(os.Stderr, err)
32
+ os.Exit(1)
33
+ }
34
+ defer model.Close()
35
+
36
+ // Process files
37
+ for _, filename := range flags.Args() {
38
+ fmt.Println("Processing", filename)
39
+ if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil {
40
+ fmt.Fprintln(os.Stderr, err)
41
+ continue
42
+ }
43
+ }
44
+ }
bindings/go/examples/go-whisper/process.go ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "fmt"
5
+ "io"
6
+ "os"
7
+ "time"
8
+
9
+ // Package imports
10
+ whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
11
+ wav "github.com/go-audio/wav"
12
+ )
13
+
14
+ func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error {
15
+ var data []float32
16
+
17
+ // Create processing context
18
+ context, err := model.NewContext()
19
+ if err != nil {
20
+ return err
21
+ }
22
+
23
+ // Open the file
24
+ fh, err := os.Open(path)
25
+ if err != nil {
26
+ return err
27
+ }
28
+ defer fh.Close()
29
+
30
+ // Decode the WAV file
31
+ dec := wav.NewDecoder(fh)
32
+ if buf, err := dec.FullPCMBuffer(); err != nil {
33
+ return err
34
+ } else if dec.SampleRate != whisper.SampleRate {
35
+ return fmt.Errorf("unsupported sample rate: %d", dec.SampleRate)
36
+ } else if dec.NumChans != 1 {
37
+ return fmt.Errorf("unsupported number of channels: %d", dec.NumChans)
38
+ } else {
39
+ data = buf.AsFloat32Buffer().Data
40
+ }
41
+
42
+ // Set the parameters
43
+ var cb whisper.SegmentCallback
44
+ if lang != "" {
45
+ if err := context.SetLanguage(lang); err != nil {
46
+ return err
47
+ }
48
+ }
49
+ if speedup {
50
+ context.SetSpeedup(true)
51
+ }
52
+ if tokens {
53
+ cb = func(segment whisper.Segment) {
54
+ fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
55
+ for _, token := range segment.Tokens {
56
+ fmt.Printf("%q ", token.Text)
57
+ }
58
+ fmt.Println("")
59
+ }
60
+ }
61
+
62
+ // Process the data
63
+ if err := context.Process(data, cb); err != nil {
64
+ return err
65
+ }
66
+
67
+ // Print out the results
68
+ for {
69
+ segment, err := context.NextSegment()
70
+ if err == io.EOF {
71
+ break
72
+ } else if err != nil {
73
+ return err
74
+ }
75
+ fmt.Printf("[%6s->%6s] %s\n", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond), segment.Text)
76
+ }
77
+
78
+ // Return success
79
+ return nil
80
+ }
bindings/go/go.mod ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ module github.com/ggerganov/whisper.cpp/bindings/go
2
+
3
+ go 1.19
4
+
5
+ require (
6
+ github.com/go-audio/wav v1.1.0
7
+ github.com/stretchr/testify v1.8.1
8
+ )
9
+
10
+ require (
11
+ github.com/davecgh/go-spew v1.1.1 // indirect
12
+ github.com/go-audio/audio v1.0.0 // indirect
13
+ github.com/go-audio/riff v1.0.0 // indirect
14
+ github.com/pmezard/go-difflib v1.0.0 // indirect
15
+ gopkg.in/yaml.v3 v3.0.1 // indirect
16
+ )
bindings/go/params.go ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package whisper
2
+
3
+ // This file defines the whisper_token, whisper_token_data and whisper_full_params
4
+ // structures, which are used by the whisper_full() function.
5
+
6
+ import (
7
+ "fmt"
8
+ )
9
+
10
+ ///////////////////////////////////////////////////////////////////////////////
11
+ // CGO
12
+
13
+ /*
14
+ #include <whisper.h>
15
+ */
16
+ import "C"
17
+
18
+ ///////////////////////////////////////////////////////////////////////////////
19
+ // PUBLIC METHODS
20
+
21
+ func (p *Params) SetTranslate(v bool) {
22
+ p.translate = toBool(v)
23
+ }
24
+
25
+ func (p *Params) SetNoContext(v bool) {
26
+ p.no_context = toBool(v)
27
+ }
28
+
29
+ func (p *Params) SetSingleSegment(v bool) {
30
+ p.single_segment = toBool(v)
31
+ }
32
+
33
+ func (p *Params) SetPrintSpecial(v bool) {
34
+ p.print_special = toBool(v)
35
+ }
36
+
37
+ func (p *Params) SetPrintProgress(v bool) {
38
+ p.print_progress = toBool(v)
39
+ }
40
+
41
+ func (p *Params) SetPrintRealtime(v bool) {
42
+ p.print_realtime = toBool(v)
43
+ }
44
+
45
+ func (p *Params) SetPrintTimestamps(v bool) {
46
+ p.print_timestamps = toBool(v)
47
+ }
48
+
49
+ func (p *Params) SetSpeedup(v bool) {
50
+ p.speed_up = toBool(v)
51
+ }
52
+
53
+ func (p *Params) SetLanguage(lang int) error {
54
+ str := C.whisper_lang_str(C.int(lang))
55
+ if str == nil {
56
+ return ErrInvalidLanguage
57
+ } else {
58
+ p.language = str
59
+ }
60
+ return nil
61
+ }
62
+
63
+ func (p *Params) Language() int {
64
+ if p.language == nil {
65
+ return -1
66
+ }
67
+ return int(C.whisper_lang_id(p.language))
68
+ }
69
+
70
+ func (p *Params) SetThreads(threads int) {
71
+ p.n_threads = C.int(threads)
72
+ }
73
+
74
+ func (p *Params) SetOffset(offset_ms int) {
75
+ p.offset_ms = C.int(offset_ms)
76
+ }
77
+
78
+ func (p *Params) SetDuration(duration_ms int) {
79
+ p.duration_ms = C.int(duration_ms)
80
+ }
81
+
82
+ ///////////////////////////////////////////////////////////////////////////////
83
+ // PRIVATE METHODS
84
+
85
+ func toBool(v bool) C.bool {
86
+ if v {
87
+ return C.bool(true)
88
+ }
89
+ return C.bool(false)
90
+ }
91
+
92
+ ///////////////////////////////////////////////////////////////////////////////
93
+ // STRINGIFY
94
+
95
+ func (p *Params) String() string {
96
+ str := "<whisper.params"
97
+ str += fmt.Sprintf(" strategy=%v", p.strategy)
98
+ str += fmt.Sprintf(" n_threads=%d", p.n_threads)
99
+ if p.language != nil {
100
+ str += fmt.Sprintf(" language=%s", C.GoString(p.language))
101
+ }
102
+ str += fmt.Sprintf(" n_max_text_ctx=%d", p.n_max_text_ctx)
103
+ str += fmt.Sprintf(" offset_ms=%d", p.offset_ms)
104
+ str += fmt.Sprintf(" duration_ms=%d", p.duration_ms)
105
+ if p.translate {
106
+ str += " translate"
107
+ }
108
+ if p.no_context {
109
+ str += " no_context"
110
+ }
111
+ if p.single_segment {
112
+ str += " single_segment"
113
+ }
114
+ if p.print_special {
115
+ str += " print_special"
116
+ }
117
+ if p.print_progress {
118
+ str += " print_progress"
119
+ }
120
+ if p.print_realtime {
121
+ str += " print_realtime"
122
+ }
123
+ if p.print_timestamps {
124
+ str += " print_timestamps"
125
+ }
126
+ if p.token_timestamps {
127
+ str += " token_timestamps"
128
+ }
129
+ if p.speed_up {
130
+ str += " speed_up"
131
+ }
132
+
133
+ return str + ">"
134
+ }
bindings/go/pkg/whisper/consts.go ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package whisper
2
+
3
+ import (
4
+ "errors"
5
+
6
+ // Bindings
7
+ whisper "github.com/ggerganov/whisper.cpp/bindings/go"
8
+ )
9
+
10
+ ///////////////////////////////////////////////////////////////////////////////
11
+ // ERRORS
12
+
13
+ var (
14
+ ErrUnableToLoadModel = errors.New("unable to load model")
15
+ ErrInternalAppError = errors.New("internal application error")
16
+ ErrProcessingFailed = errors.New("processing failed")
17
+ ErrUnsupportedLanguage = errors.New("unsupported language")
18
+ )
19
+
20
+ ///////////////////////////////////////////////////////////////////////////////
21
+ // CONSTANTS
22
+
23
+ // SampleRate is the sample rate of the audio data.
24
+ const SampleRate = whisper.SampleRate
25
+
26
+ // SampleBits is the number of bytes per sample.
27
+ const SampleBits = whisper.SampleBits
bindings/go/pkg/whisper/context.go ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package whisper
2
+
3
+ import (
4
+ "io"
5
+ "strings"
6
+ "time"
7
+
8
+ // Bindings
9
+ whisper "github.com/ggerganov/whisper.cpp/bindings/go"
10
+ )
11
+
12
+ ///////////////////////////////////////////////////////////////////////////////
13
+ // TYPES
14
+
15
+ type context struct {
16
+ n int
17
+ model *model
18
+ params whisper.Params
19
+ }
20
+
21
+ // Make sure context adheres to the interface
22
+ var _ Context = (*context)(nil)
23
+
24
+ ///////////////////////////////////////////////////////////////////////////////
25
+ // LIFECYCLE
26
+
27
+ func NewContext(model *model, params whisper.Params) (Context, error) {
28
+ context := new(context)
29
+ context.model = model
30
+ context.params = params
31
+
32
+ // Return success
33
+ return context, nil
34
+ }
35
+
36
+ ///////////////////////////////////////////////////////////////////////////////
37
+ // PUBLIC METHODS
38
+
39
+ // Set the language to use for speech recognition.
40
+ func (context *context) SetLanguage(lang string) error {
41
+ if context.model.ctx == nil {
42
+ return ErrInternalAppError
43
+ }
44
+ if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
45
+ return ErrUnsupportedLanguage
46
+ } else if err := context.params.SetLanguage(id); err != nil {
47
+ return err
48
+ }
49
+ // Return success
50
+ return nil
51
+ }
52
+
53
+ // Get language
54
+ func (context *context) Language() string {
55
+ return whisper.Whisper_lang_str(context.params.Language())
56
+ }
57
+
58
+ // Set speedup flag
59
+ func (context *context) SetSpeedup(v bool) {
60
+ context.params.SetSpeedup(v)
61
+ }
62
+
63
+ // Process new sample data and return any errors
64
+ func (context *context) Process(data []float32, cb SegmentCallback) error {
65
+ if context.model.ctx == nil {
66
+ return ErrInternalAppError
67
+ }
68
+ // If the callback is defined then we force on single_segment mode
69
+ if cb != nil {
70
+ context.params.SetSingleSegment(true)
71
+ }
72
+
73
+ // We don't do parallel processing at the moment
74
+ processors := 0
75
+ if processors > 1 {
76
+ if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
77
+ if cb != nil {
78
+ num_segments := context.model.ctx.Whisper_full_n_segments()
79
+ s0 := num_segments - new
80
+ for i := s0; i < num_segments; i++ {
81
+ cb(toSegment(context.model.ctx, i))
82
+ }
83
+ }
84
+ }); err != nil {
85
+ return err
86
+ }
87
+ } else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
88
+ if cb != nil {
89
+ num_segments := context.model.ctx.Whisper_full_n_segments()
90
+ s0 := num_segments - new
91
+ for i := s0; i < num_segments; i++ {
92
+ cb(toSegment(context.model.ctx, i))
93
+ }
94
+ }
95
+ }); err != nil {
96
+ return err
97
+ }
98
+
99
+ // Return success
100
+ return nil
101
+ }
102
+
103
+ // Return the next segment of tokens
104
+ func (context *context) NextSegment() (Segment, error) {
105
+ if context.model.ctx == nil {
106
+ return Segment{}, ErrInternalAppError
107
+ }
108
+ if context.n >= context.model.ctx.Whisper_full_n_segments() {
109
+ return Segment{}, io.EOF
110
+ }
111
+
112
+ // Populate result
113
+ result := toSegment(context.model.ctx, context.n)
114
+
115
+ // Increment the cursor
116
+ context.n++
117
+
118
+ // Return success
119
+ return result, nil
120
+ }
121
+
122
+ ///////////////////////////////////////////////////////////////////////////////
123
+ // PRIVATE METHODS
124
+
125
+ func toSegment(ctx *whisper.Context, n int) Segment {
126
+ return Segment{
127
+ Num: n,
128
+ Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
129
+ Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
130
+ End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
131
+ Tokens: toTokens(ctx, n),
132
+ }
133
+ }
134
+
135
+ func toTokens(ctx *whisper.Context, n int) []Token {
136
+ result := make([]Token, ctx.Whisper_full_n_tokens(n))
137
+ for i := 0; i < len(result); i++ {
138
+ result[i] = Token{
139
+ Id: int(ctx.Whisper_full_get_token_id(n, i)),
140
+ Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)),
141
+ P: ctx.Whisper_full_get_token_p(n, i),
142
+ }
143
+ }
144
+ return result
145
+ }
bindings/go/pkg/whisper/context_test.go ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package whisper_test
2
+
3
+ import (
4
+ "os"
5
+ "testing"
6
+
7
+ // Packages
8
+ whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
9
+ assert "github.com/stretchr/testify/assert"
10
+ )
11
+
12
+ const (
13
+ ModelPath = "../../models/ggml-tiny.bin"
14
+ SamplePath = "../../samples/jfk.wav"
15
+ )
16
+
17
+ func Test_Whisper_000(t *testing.T) {
18
+ assert := assert.New(t)
19
+ if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
20
+ t.Skip("Skipping test, model not found:", ModelPath)
21
+ }
22
+ if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
23
+ t.Skip("Skipping test, sample not found:", SamplePath)
24
+ }
25
+
26
+ // Load model
27
+ model, err := whisper.New(ModelPath)
28
+ assert.NoError(err)
29
+ assert.NotNil(model)
30
+ assert.NoError(model.Close())
31
+
32
+ t.Log("languages=", model.Languages())
33
+ }
34
+
35
+ func Test_Whisper_001(t *testing.T) {
36
+ assert := assert.New(t)
37
+ if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
38
+ t.Skip("Skipping test, model not found:", ModelPath)
39
+ }
40
+ if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
41
+ t.Skip("Skipping test, sample not found:", SamplePath)
42
+ }
43
+
44
+ // Load model
45
+ model, err := whisper.New(ModelPath)
46
+ assert.NoError(err)
47
+ assert.NotNil(model)
48
+ defer model.Close()
49
+
50
+ // Get context for decoding
51
+ ctx, err := model.NewContext()
52
+ assert.NoError(err)
53
+ assert.NotNil(ctx)
54
+
55
+ }
bindings/go/pkg/whisper/doc.go ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ /*
2
+ This is the higher-level speech-to-text whisper.cpp API for go
3
+ */
4
+ package whisper
bindings/go/pkg/whisper/interface.go ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package whisper
2
+
3
+ import (
4
+ "io"
5
+ "time"
6
+ )
7
+
8
+ ///////////////////////////////////////////////////////////////////////////////
9
+ // TYPES
10
+
11
+ // SegmentCallback is the callback function for processing segments in real
12
+ // time. It is called during the Process function
13
+ type SegmentCallback func(Segment)
14
+
15
+ // Model is the interface to a whisper model. Create a new model with the
16
+ // function whisper.New(string)
17
+ type Model interface {
18
+ io.Closer
19
+
20
+ // Return a new speech-to-text context.
21
+ NewContext() (Context, error)
22
+
23
+ // Return all languages supported.
24
+ Languages() []string
25
+ }
26
+
27
+ // Context is the speach recognition context.
28
+ type Context interface {
29
+ SetLanguage(string) error // Set the language to use for speech recognition.
30
+ Language() string // Get language
31
+ SetSpeedup(bool) // Set speedup flag
32
+
33
+ // Process mono audio data and return any errors.
34
+ // If defined, newly generated segments are passed to the
35
+ // callback function during processing.
36
+ Process([]float32, SegmentCallback) error
37
+
38
+ // After process is called, return segments until the end of the stream
39
+ // is reached, when io.EOF is returned.
40
+ NextSegment() (Segment, error)
41
+ }
42
+
43
+ // Segment is the text result of a speech recognition.
44
+ type Segment struct {
45
+ // Segment Number
46
+ Num int
47
+
48
+ // Time beginning and end timestamps for the segment.
49
+ Start, End time.Duration
50
+
51
+ // The text of the segment.
52
+ Text string
53
+
54
+ // The tokens of the segment.
55
+ Tokens []Token
56
+ }
57
+
58
+ // Token is a text or special token
59
+ type Token struct {
60
+ Id int
61
+ Text string
62
+ P float32
63
+ }
bindings/go/pkg/whisper/model.go ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package whisper
2
+
3
+ import (
4
+ "fmt"
5
+ "os"
6
+ "runtime"
7
+
8
+ // Bindings
9
+ whisper "github.com/ggerganov/whisper.cpp/bindings/go"
10
+ )
11
+
12
+ ///////////////////////////////////////////////////////////////////////////////
13
+ // TYPES
14
+
15
+ type model struct {
16
+ path string
17
+ ctx *whisper.Context
18
+ }
19
+
20
+ // Make sure model adheres to the interface
21
+ var _ Model = (*model)(nil)
22
+
23
+ ///////////////////////////////////////////////////////////////////////////////
24
+ // LIFECYCLE
25
+
26
+ func New(path string) (*model, error) {
27
+ model := new(model)
28
+ if _, err := os.Stat(path); err != nil {
29
+ return nil, err
30
+ } else if ctx := whisper.Whisper_init(path); ctx == nil {
31
+ return nil, ErrUnableToLoadModel
32
+ } else {
33
+ model.ctx = ctx
34
+ model.path = path
35
+ }
36
+
37
+ // Return success
38
+ return model, nil
39
+ }
40
+
41
+ func (model *model) Close() error {
42
+ if model.ctx != nil {
43
+ model.ctx.Whisper_free()
44
+ }
45
+
46
+ // Release resources
47
+ model.ctx = nil
48
+
49
+ // Return success
50
+ return nil
51
+ }
52
+
53
+ ///////////////////////////////////////////////////////////////////////////////
54
+ // STRINGIFY
55
+
56
+ func (model *model) String() string {
57
+ str := "<whisper.model"
58
+ if model.ctx != nil {
59
+ str += fmt.Sprintf(" model=%q", model.path)
60
+ }
61
+ return str + ">"
62
+ }
63
+
64
+ ///////////////////////////////////////////////////////////////////////////////
65
+ // PUBLIC METHODS
66
+
67
+ // Return all recognized languages. Initially it is set to auto-detect
68
+ func (model *model) Languages() []string {
69
+ result := make([]string, 0, whisper.Whisper_lang_max_id())
70
+ for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
71
+ str := whisper.Whisper_lang_str(i)
72
+ if model.ctx.Whisper_lang_id(str) >= 0 {
73
+ result = append(result, str)
74
+ }
75
+ }
76
+ return result
77
+ }
78
+
79
+ func (model *model) NewContext() (Context, error) {
80
+ if model.ctx == nil {
81
+ return nil, ErrInternalAppError
82
+ }
83
+
84
+ // Create new context
85
+ params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
86
+ params.SetTranslate(false)
87
+ params.SetPrintSpecial(false)
88
+ params.SetPrintProgress(false)
89
+ params.SetPrintRealtime(false)
90
+ params.SetPrintTimestamps(false)
91
+ params.SetThreads(runtime.NumCPU())
92
+
93
+ // Return new context
94
+ return NewContext(model, params)
95
+ }
bindings/go/samples/jfk.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59dfb9a4acb36fe2a2affc14bacbee2920ff435cb13cc314a08c13f66ba7860e
3
+ size 352078
bindings/go/whisper.go ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package whisper
2
+
3
+ import (
4
+ "errors"
5
+ "unsafe"
6
+ )
7
+
8
+ ///////////////////////////////////////////////////////////////////////////////
9
+ // CGO
10
+
11
+ /*
12
+ #cgo CFLAGS: -I${SRCDIR}/../..
13
+ #cgo LDFLAGS: -L${SRCDIR}/build -lwhisper -lm -lstdc++
14
+ #cgo darwin LDFLAGS: -framework Accelerate
15
+ #include <whisper.h>
16
+ #include <stdlib.h>
17
+
18
+ extern void callNewSegment(void* user_data, int new);
19
+ extern bool callEncoderBegin(void* user_data);
20
+
21
+ // Text segment callback
22
+ // Called on every newly generated text segment
23
+ // Use the whisper_full_...() functions to obtain the text segments
24
+ static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
25
+ if(user_data != NULL && ctx != NULL) {
26
+ callNewSegment(user_data, n_new);
27
+ }
28
+ }
29
+
30
+ // Encoder begin callback
31
+ // If not NULL, called before the encoder starts
32
+ // If it returns false, the computation is aborted
33
+ static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
34
+ if(user_data != NULL && ctx != NULL) {
35
+ return callEncoderBegin(user_data);
36
+ }
37
+ return false;
38
+ }
39
+
40
+ // Get default parameters and set callbacks
41
+ static struct whisper_full_params whisper_full_default_params_cb(struct whisper_context* ctx, enum whisper_sampling_strategy strategy) {
42
+ struct whisper_full_params params = whisper_full_default_params(strategy);
43
+ params.new_segment_callback = whisper_new_segment_cb;
44
+ params.new_segment_callback_user_data = (void*)(ctx);
45
+ params.encoder_begin_callback = whisper_encoder_begin_cb;
46
+ params.encoder_begin_callback_user_data = (void*)(ctx);
47
+ return params;
48
+ }
49
+ */
50
+ import "C"
51
+
52
+ ///////////////////////////////////////////////////////////////////////////////
53
+ // TYPES
54
+
55
+ type (
56
+ Context C.struct_whisper_context
57
+ Token C.whisper_token
58
+ TokenData C.struct_whisper_token_data
59
+ SamplingStrategy C.enum_whisper_sampling_strategy
60
+ Params C.struct_whisper_full_params
61
+ )
62
+
63
+ ///////////////////////////////////////////////////////////////////////////////
64
+ // GLOBALS
65
+
66
+ const (
67
+ SAMPLING_GREEDY SamplingStrategy = C.WHISPER_SAMPLING_GREEDY
68
+ SAMPLING_BEAM_SEARCH SamplingStrategy = C.WHISPER_SAMPLING_BEAM_SEARCH
69
+ )
70
+
71
+ const (
72
+ SampleRate = C.WHISPER_SAMPLE_RATE // Expected sample rate, samples per second
73
+ SampleBits = uint16(unsafe.Sizeof(C.float(0))) * 8 // Sample size in bits
74
+ NumFFT = C.WHISPER_N_FFT
75
+ NumMEL = C.WHISPER_N_MEL
76
+ HopLength = C.WHISPER_HOP_LENGTH
77
+ ChunkSize = C.WHISPER_CHUNK_SIZE
78
+ )
79
+
80
+ var (
81
+ ErrTokenizerFailed = errors.New("whisper_tokenize failed")
82
+ ErrAutoDetectFailed = errors.New("whisper_lang_auto_detect failed")
83
+ ErrConversionFailed = errors.New("whisper_convert failed")
84
+ ErrInvalidLanguage = errors.New("invalid language")
85
+ )
86
+
87
+ ///////////////////////////////////////////////////////////////////////////////
88
+ // PUBLIC METHODS
89
+
90
+ // Allocates all memory needed for the model and loads the model from the given file.
91
+ // Returns NULL on failure.
92
+ func Whisper_init(path string) *Context {
93
+ cPath := C.CString(path)
94
+ defer C.free(unsafe.Pointer(cPath))
95
+ if ctx := C.whisper_init(cPath); ctx != nil {
96
+ return (*Context)(ctx)
97
+ } else {
98
+ return nil
99
+ }
100
+ }
101
+
102
+ // Frees all memory allocated by the model.
103
+ func (ctx *Context) Whisper_free() {
104
+ C.whisper_free((*C.struct_whisper_context)(ctx))
105
+ }
106
+
107
+ // Convert RAW PCM audio to log mel spectrogram.
108
+ // The resulting spectrogram is stored inside the provided whisper context.
109
+ func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
110
+ if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 {
111
+ return nil
112
+ } else {
113
+ return ErrConversionFailed
114
+ }
115
+ }
116
+
117
+ // This can be used to set a custom log mel spectrogram inside the provided whisper context.
118
+ // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
119
+ // n_mel must be 80
120
+ func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error {
121
+ if C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 {
122
+ return nil
123
+ } else {
124
+ return ErrConversionFailed
125
+ }
126
+ }
127
+
128
+ // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
129
+ // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
130
+ // offset can be used to specify the offset of the first frame in the spectrogram.
131
+ func (ctx *Context) Whisper_encode(offset, threads int) error {
132
+ if C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads)) == 0 {
133
+ return nil
134
+ } else {
135
+ return ErrConversionFailed
136
+ }
137
+ }
138
+
139
+ // Run the Whisper decoder to obtain the logits and probabilities for the next token.
140
+ // Make sure to call whisper_encode() first.
141
+ // tokens + n_tokens is the provided context for the decoder.
142
+ // n_past is the number of tokens to use from previous decoder calls.
143
+ func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error {
144
+ if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 {
145
+ return nil
146
+ } else {
147
+ return ErrConversionFailed
148
+ }
149
+ }
150
+
151
+ // whisper_sample_best() returns the token with the highest probability
152
+ func (ctx *Context) Whisper_sample_best() TokenData {
153
+ return TokenData(C.whisper_sample_best((*C.struct_whisper_context)(ctx)))
154
+ }
155
+
156
+ // whisper_sample_timestamp() returns the most probable timestamp token
157
+ func (ctx *Context) Whisper_sample_timestamp(is_initial bool) TokenData {
158
+ return TokenData(C.whisper_sample_timestamp((*C.struct_whisper_context)(ctx), C.bool(is_initial)))
159
+ }
160
+
161
+ // Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens.
162
+ // Returns the number of tokens on success
163
+ func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
164
+ cText := C.CString(text)
165
+ defer C.free(unsafe.Pointer(cText))
166
+ if n := C.whisper_tokenize((*C.struct_whisper_context)(ctx), cText, (*C.whisper_token)(&tokens[0]), C.int(len(tokens))); n >= 0 {
167
+ return int(n), nil
168
+ } else {
169
+ return 0, ErrTokenizerFailed
170
+ }
171
+ }
172
+
173
+ // Return the id of the specified language, returns -1 if not found
174
+ func (ctx *Context) Whisper_lang_id(lang string) int {
175
+ return int(C.whisper_lang_id(C.CString(lang)))
176
+ }
177
+
178
+ // Largest language id (i.e. number of available languages - 1)
179
+ func Whisper_lang_max_id() int {
180
+ return int(C.whisper_lang_max_id())
181
+ }
182
+
183
+ // Return the short string of the specified language id (e.g. 2 -> "de"),
184
+ // returns empty string if not found
185
+ func Whisper_lang_str(id int) string {
186
+ return C.GoString(C.whisper_lang_str(C.int(id)))
187
+ }
188
+
189
+ // Use mel data at offset_ms to try and auto-detect the spoken language
190
+ // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
191
+ // Returns the probabilities of all languages.
192
+ // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
193
+ func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float32, error) {
194
+ probs := make([]float32, Whisper_lang_max_id()+1)
195
+ if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 {
196
+ return nil, ErrAutoDetectFailed
197
+ } else {
198
+ return probs, nil
199
+ }
200
+ }
201
+
202
+ func (ctx *Context) Whisper_n_len() int {
203
+ return int(C.whisper_n_len((*C.struct_whisper_context)(ctx)))
204
+ }
205
+
206
+ func (ctx *Context) Whisper_n_vocab() int {
207
+ return int(C.whisper_n_vocab((*C.struct_whisper_context)(ctx)))
208
+ }
209
+
210
+ func (ctx *Context) Whisper_n_text_ctx() int {
211
+ return int(C.whisper_n_text_ctx((*C.struct_whisper_context)(ctx)))
212
+ }
213
+
214
+ func (ctx *Context) Whisper_is_multilingual() int {
215
+ return int(C.whisper_is_multilingual((*C.struct_whisper_context)(ctx)))
216
+ }
217
+
218
+ // The probabilities for the next token
219
+ //func (ctx *Whisper_context) Whisper_get_probs() []float32 {
220
+ // return (*[1 << 30]float32)(unsafe.Pointer(C.whisper_get_probs((*C.struct_whisper_context)(ctx))))[:ctx.Whisper_n_vocab()]
221
+ //}
222
+
223
+ // Token Id -> String. Uses the vocabulary in the provided context
224
+ func (ctx *Context) Whisper_token_to_str(token Token) string {
225
+ return C.GoString(C.whisper_token_to_str((*C.struct_whisper_context)(ctx), C.whisper_token(token)))
226
+ }
227
+
228
+ // Special tokens
229
+ func (ctx *Context) Whisper_token_eot() Token {
230
+ return Token(C.whisper_token_eot((*C.struct_whisper_context)(ctx)))
231
+ }
232
+
233
+ // Special tokens
234
+ func (ctx *Context) Whisper_token_sot() Token {
235
+ return Token(C.whisper_token_sot((*C.struct_whisper_context)(ctx)))
236
+ }
237
+
238
+ // Special tokens
239
+ func (ctx *Context) Whisper_token_prev() Token {
240
+ return Token(C.whisper_token_prev((*C.struct_whisper_context)(ctx)))
241
+ }
242
+
243
+ // Special tokens
244
+ func (ctx *Context) Whisper_token_solm() Token {
245
+ return Token(C.whisper_token_solm((*C.struct_whisper_context)(ctx)))
246
+ }
247
+
248
+ // Special tokens
249
+ func (ctx *Context) Whisper_token_not() Token {
250
+ return Token(C.whisper_token_not((*C.struct_whisper_context)(ctx)))
251
+ }
252
+
253
+ // Special tokens
254
+ func (ctx *Context) Whisper_token_beg() Token {
255
+ return Token(C.whisper_token_beg((*C.struct_whisper_context)(ctx)))
256
+ }
257
+
258
+ // Special tokens
259
+ func (ctx *Context) Whisper_token_lang(lang_id int) Token {
260
+ return Token(C.whisper_token_lang((*C.struct_whisper_context)(ctx), C.int(lang_id)))
261
+ }
262
+
263
+ // Task tokens
264
+ func Whisper_token_translate() Token {
265
+ return Token(C.whisper_token_translate())
266
+ }
267
+
268
+ // Task tokens
269
+ func Whisper_token_transcribe() Token {
270
+ return Token(C.whisper_token_transcribe())
271
+ }
272
+
273
+ // Performance information
274
+ func (ctx *Context) Whisper_print_timings() {
275
+ C.whisper_print_timings((*C.struct_whisper_context)(ctx))
276
+ }
277
+
278
+ // Performance information
279
+ func (ctx *Context) Whisper_reset_timings() {
280
+ C.whisper_reset_timings((*C.struct_whisper_context)(ctx))
281
+ }
282
+
283
+ // Print system information
284
+ func Whisper_print_system_info() string {
285
+ return C.GoString(C.whisper_print_system_info())
286
+ }
287
+
288
+ // Return default parameters for a strategy
289
+ func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Params {
290
+ // Get default parameters
291
+ return Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy)))
292
+ }
293
+
294
+ // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
295
+ // Uses the specified decoding strategy to obtain the text.
296
+ func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
297
+ registerEncoderBeginCallback(ctx, encoderBeginCallback)
298
+ registerNewSegmentCallback(ctx, newSegmentCallback)
299
+ defer registerEncoderBeginCallback(ctx, nil)
300
+ defer registerNewSegmentCallback(ctx, nil)
301
+ if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
302
+ return nil
303
+ } else {
304
+ return ErrConversionFailed
305
+ }
306
+ }
307
+
308
+ // Split the input audio in chunks and process each chunk separately using whisper_full()
309
+ // It seems this approach can offer some speedup in some cases.
310
+ // However, the transcription accuracy can be worse at the beginning and end of each chunk.
311
+ func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
312
+ registerEncoderBeginCallback(ctx, encoderBeginCallback)
313
+ registerNewSegmentCallback(ctx, newSegmentCallback)
314
+ defer registerEncoderBeginCallback(ctx, nil)
315
+ defer registerNewSegmentCallback(ctx, nil)
316
+
317
+ if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 {
318
+ return nil
319
+ } else {
320
+ return ErrConversionFailed
321
+ }
322
+ }
323
+
324
+ // Number of generated text segments.
325
+ // A segment can be a few words, a sentence, or even a paragraph.
326
+ func (ctx *Context) Whisper_full_n_segments() int {
327
+ return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx)))
328
+ }
329
+
330
+ // Get the start and end time of the specified segment.
331
+ func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 {
332
+ return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment)))
333
+ }
334
+
335
+ // Get the start and end time of the specified segment.
336
+ func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 {
337
+ return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment)))
338
+ }
339
+
340
+ // Get the text of the specified segment.
341
+ func (ctx *Context) Whisper_full_get_segment_text(segment int) string {
342
+ return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment)))
343
+ }
344
+
345
+ // Get number of tokens in the specified segment.
346
+ func (ctx *Context) Whisper_full_n_tokens(segment int) int {
347
+ return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment)))
348
+ }
349
+
350
+ // Get the token text of the specified token index in the specified segment.
351
+ func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string {
352
+ return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
353
+ }
354
+
355
+ // Get the token of the specified token index in the specified segment.
356
+ func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
357
+ return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
358
+ }
359
+
360
+ // Get token data for the specified token in the specified segment.
361
+ // This contains probabilities, timestamps, etc.
362
+ func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData {
363
+ return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
364
+ }
365
+
366
+ // Get the probability of the specified token in the specified segment.
367
+ func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
368
+ return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
369
+ }
370
+
371
+ ///////////////////////////////////////////////////////////////////////////////
372
+ // CALLBACKS
373
+
374
+ var (
375
+ cbNewSegment = make(map[unsafe.Pointer]func(int))
376
+ cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
377
+ )
378
+
379
+ func registerNewSegmentCallback(ctx *Context, fn func(int)) {
380
+ if fn == nil {
381
+ delete(cbNewSegment, unsafe.Pointer(ctx))
382
+ } else {
383
+ cbNewSegment[unsafe.Pointer(ctx)] = fn
384
+ }
385
+ }
386
+
387
+ func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
388
+ if fn == nil {
389
+ delete(cbEncoderBegin, unsafe.Pointer(ctx))
390
+ } else {
391
+ cbEncoderBegin[unsafe.Pointer(ctx)] = fn
392
+ }
393
+ }
394
+
395
+ //export callNewSegment
396
+ func callNewSegment(user_data unsafe.Pointer, new C.int) {
397
+ if fn, ok := cbNewSegment[user_data]; ok {
398
+ fn(int(new))
399
+ }
400
+ }
401
+
402
+ //export callEncoderBegin
403
+ func callEncoderBegin(user_data unsafe.Pointer) C.bool {
404
+ if fn, ok := cbEncoderBegin[user_data]; ok {
405
+ if fn() {
406
+ return C.bool(true)
407
+ } else {
408
+ return C.bool(false)
409
+ }
410
+ }
411
+ return true
412
+ }
bindings/go/whisper_test.go ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package whisper_test
2
+
3
+ import (
4
+ "os"
5
+ "runtime"
6
+ "testing"
7
+ "time"
8
+
9
+ // Packages
10
+ whisper "github.com/ggerganov/whisper.cpp/bindings/go"
11
+ wav "github.com/go-audio/wav"
12
+ assert "github.com/stretchr/testify/assert"
13
+ )
14
+
15
+ const (
16
+ ModelPath = "models/ggml-small.en.bin"
17
+ SamplePath = "samples/jfk.wav"
18
+ )
19
+
20
+ func Test_Whisper_000(t *testing.T) {
21
+ assert := assert.New(t)
22
+ if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
23
+ t.Skip("Skipping test, model not found:", ModelPath)
24
+ }
25
+ ctx := whisper.Whisper_init(ModelPath)
26
+ assert.NotNil(ctx)
27
+ ctx.Whisper_free()
28
+ }
29
+
30
+ func Test_Whisper_001(t *testing.T) {
31
+ assert := assert.New(t)
32
+ if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
33
+ t.Skip("Skipping test, model not found:", ModelPath)
34
+ }
35
+ if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
36
+ t.Skip("Skipping test, sample not found:", SamplePath)
37
+ }
38
+
39
+ // Open samples
40
+ fh, err := os.Open(SamplePath)
41
+ assert.NoError(err)
42
+ defer fh.Close()
43
+
44
+ // Read samples
45
+ d := wav.NewDecoder(fh)
46
+ buf, err := d.FullPCMBuffer()
47
+ assert.NoError(err)
48
+
49
+ // Run whisper
50
+ ctx := whisper.Whisper_init(ModelPath)
51
+ assert.NotNil(ctx)
52
+ defer ctx.Whisper_free()
53
+ assert.NoError(ctx.Whisper_full(ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY), buf.AsFloat32Buffer().Data, nil, nil))
54
+
55
+ // Print out tokens
56
+ num_segments := ctx.Whisper_full_n_segments()
57
+ assert.GreaterOrEqual(num_segments, 1)
58
+ for i := 0; i < num_segments; i++ {
59
+ str := ctx.Whisper_full_get_segment_text(i)
60
+ assert.NotEmpty(str)
61
+ t0 := time.Duration(ctx.Whisper_full_get_segment_t0(i)) * time.Millisecond
62
+ t1 := time.Duration(ctx.Whisper_full_get_segment_t1(i)) * time.Millisecond
63
+ t.Logf("[%6s->%-6s] %q", t0, t1, str)
64
+ }
65
+ }
66
+
67
+ func Test_Whisper_002(t *testing.T) {
68
+ assert := assert.New(t)
69
+ for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
70
+ str := whisper.Whisper_lang_str(i)
71
+ assert.NotEmpty(str)
72
+ t.Log(str)
73
+ }
74
+ }
75
+
76
+ func Test_Whisper_003(t *testing.T) {
77
+ threads := runtime.NumCPU()
78
+ assert := assert.New(t)
79
+ if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
80
+ t.Skip("Skipping test, model not found:", ModelPath)
81
+ }
82
+ if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
83
+ t.Skip("Skipping test, sample not found:", SamplePath)
84
+ }
85
+
86
+ // Open samples
87
+ fh, err := os.Open(SamplePath)
88
+ assert.NoError(err)
89
+ defer fh.Close()
90
+
91
+ // Read samples
92
+ d := wav.NewDecoder(fh)
93
+ buf, err := d.FullPCMBuffer()
94
+ assert.NoError(err)
95
+
96
+ // Make the model
97
+ ctx := whisper.Whisper_init(ModelPath)
98
+ assert.NotNil(ctx)
99
+ defer ctx.Whisper_free()
100
+
101
+ // Get MEL
102
+ assert.NoError(ctx.Whisper_pcm_to_mel(buf.AsFloat32Buffer().Data, threads))
103
+
104
+ // Get Languages
105
+ languages, err := ctx.Whisper_lang_auto_detect(0, threads)
106
+ assert.NoError(err)
107
+ for i, p := range languages {
108
+ t.Logf("%s: %f", whisper.Whisper_lang_str(i), p)
109
+ }
110
+ }