KitaitiMakoto committed on
Commit ae07b89 · unverified · 1 Parent(s): bb2a01d

ruby : support new-segment callback (#2506)

* Add Params#new_segment_callback= method

* Add tests for Params#new_segment_callback=

* Group tests for #transcribe

* Don't use static for thread-safety

* Set new_segment_callback only when necessary

* Remove redundant check

* [skip ci] Add Ruby version README

* Revert "Group tests for #transcribe"

This reverts commit 71b65b00ccf1816c9ea8a247fb30f71bc09707d3.

* Revert "Add tests for Params#new_segment_callback="

This reverts commit 81e6df3bab7662da5379db51f28a989db7408c02.

* Add test for Context#full_n_segments

* Add Context#full_n_segments

* Add tests for lang API

* Add lang API

* Add tests for Context#full_lang_id API

* Add Context#full_lang_id

* Add abnormal test cases for lang

* Raise appropriate errors from lang APIs

* Add tests for Context#full_get_segment_t{0,1} API

* Add Context#full_get_segment_t{0,1}

* Add tests for Context#full_get_segment_speaker_turn_next API

* Add Context#full_get_segment_speaker_turn_next

* Add tests for Context#full_get_segment_text

* Add Context#full_get_segment_text

* Add tests for Params#new_segment_callback=

* Run new segment callback

* Split tests to multiple files

* Use container struct for new segment callback

* Add tests for Params#new_segment_callback_user_data=

* Add Whisper::Params#new_segment_callback_user_data=

* Add GC-related test for new segment callback

* Protect new segment callback related structs from GC

* Add meaningful test for build

* Rename: new_segment_callback_user_data -> new_segment_callback_container

* Add tests for Whisper::Segment

* Add Whisper::Segment and Whisper::Context#each_segment

* Extract c_ruby_whisper_callback_container_allocate()

* Add test for Whisper::Params#on_new_segment

* Add Whisper::Params#on_new_segment

* Assign symbol IDs to variables

* Make extsources.yaml simpler

* Update README

* Add document comments

* Add test for calling Whisper::Params#on_new_segment multiple times

* Add file dependencies to GitHub actions config and .gitignore

* Add more files to ext/.gitignore

.github/workflows/bindings-ruby.yml CHANGED
@@ -16,6 +16,9 @@ on:
  - ggml/src/ggml-quants.h
  - ggml/src/ggml-quants.c
  - ggml/src/ggml-cpu-impl.h
+ - ggml/src/ggml-metal.m
+ - ggml/src/ggml-metal.metal
+ - ggml/src/ggml-blas.cpp
  - ggml/include/ggml.h
  - ggml/include/ggml-alloc.h
  - ggml/include/ggml-backend.h
@@ -24,6 +27,8 @@ on:
  - ggml/include/ggml-metal.h
  - ggml/include/ggml-sycl.h
  - ggml/include/ggml-vulkan.h
+ - ggml/include/ggml-blas.h
+ - scripts/get-flags.mk
  - examples/dr_wav.h
  pull_request:
  paths:
@@ -41,6 +46,9 @@ on:
  - ggml/src/ggml-quants.h
  - ggml/src/ggml-quants.c
  - ggml/src/ggml-cpu-impl.h
+ - ggml/src/ggml-metal.m
+ - ggml/src/ggml-metal.metal
+ - ggml/src/ggml-blas.cpp
  - ggml/include/ggml.h
  - ggml/include/ggml-alloc.h
  - ggml/include/ggml-backend.h
@@ -49,6 +57,8 @@ on:
  - ggml/include/ggml-metal.h
  - ggml/include/ggml-sycl.h
  - ggml/include/ggml-vulkan.h
+ - ggml/include/ggml-blas.h
+ - scripts/get-flags.mk
  - examples/dr_wav.h

  jobs:
bindings/ruby/.gitignore CHANGED
@@ -1,4 +1,3 @@
- README.md
  LICENSE
  pkg/
  lib/whisper.*
bindings/ruby/README.md ADDED
@@ -0,0 +1,110 @@
+ whispercpp
+ ==========
+
+ ![whisper.cpp](https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg)
+
+ Ruby bindings for [whisper.cpp][], an interface to an automatic speech recognition model.
+
+ Installation
+ ------------
+
+ Install the gem and add it to the application's Gemfile by executing:
+
+     $ bundle add whispercpp
+
+ If bundler is not being used to manage dependencies, install the gem by executing:
+
+     $ gem install whispercpp
+
+ Usage
+ -----
+
+ ```ruby
+ require "whisper"
+
+ whisper = Whisper::Context.new("path/to/model.bin")
+
+ params = Whisper::Params.new
+ params.language = "en"
+ params.offset = 10_000
+ params.duration = 60_000
+ params.max_text_tokens = 300
+ params.translate = true
+ params.print_timestamps = false
+
+ whisper.transcribe("path/to/audio.wav", params) do |whole_text|
+   puts whole_text
+ end
+
+ ```
+
+ ### Preparing model ###
+
+ Use the script to download model file(s):
+
+ ```bash
+ git clone https://github.com/ggerganov/whisper.cpp.git
+ cd whisper.cpp
+ sh ./models/download-ggml-model.sh base.en
+ ```
+
+ Several types of models are available. See the [models][] page for details.
+
+ ### Preparing audio file ###
+
+ Currently, whisper.cpp accepts only 16-bit WAV files.
+
+ ### API ###
+
+ Once `Whisper::Context#transcribe` has been called, you can retrieve segments with `#each_segment`:
+
+ ```ruby
+ def format_time(time_ms)
+   sec, decimal_part = time_ms.divmod(1000)
+   min, sec = sec.divmod(60)
+   hour, min = min.divmod(60)
+   "%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part]
+ end
+
+ whisper.transcribe("path/to/audio.wav", params)
+
+ whisper.each_segment.with_index do |segment, index|
+   line = "[%{nth}: %{st} --> %{ed}] %{text}" % {
+     nth: index + 1,
+     st: format_time(segment.start_time),
+     ed: format_time(segment.end_time),
+     text: segment.text
+   }
+   line << " (speaker turned)" if segment.speaker_next_turn?
+   puts line
+ end
+
+ ```
+
+ You can also add a hook to params that is called on each new segment:
+
+ ```ruby
+ def format_time(time_ms)
+   sec, decimal_part = time_ms.divmod(1000)
+   min, sec = sec.divmod(60)
+   hour, min = min.divmod(60)
+   "%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part]
+ end
+
+ # Add hook before calling #transcribe
+ params.on_new_segment do |segment|
+   line = "[%{st} --> %{ed}] %{text}" % {
+     st: format_time(segment.start_time),
+     ed: format_time(segment.end_time),
+     text: segment.text
+   }
+   line << " (speaker turned)" if segment.speaker_next_turn?
+   puts line
+ end
+
+ whisper.transcribe("path/to/audio.wav", params)
+
+ ```
+
+ [whisper.cpp]: https://github.com/ggerganov/whisper.cpp
+ [models]: https://github.com/ggerganov/whisper.cpp/tree/master/models
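
Reviewer note: the `format_time` helper and the segment formatting from the README can be tried without a model file. The following is a minimal sketch, not part of the commit. `FakeSegment` and `format_segment` are hypothetical names; `FakeSegment` stands in for `Whisper::Segment` and is assumed to expose `start_time` and `end_time` in milliseconds, plus `text` and `speaker_next_turn?`, as the README examples use them.

```ruby
# Hypothetical stand-in for Whisper::Segment (assumption: times are in milliseconds).
FakeSegment = Struct.new(:start_time, :end_time, :text, :speaker_turn) do
  def speaker_next_turn?
    speaker_turn
  end
end

# Same helper as in the README: milliseconds -> "HH:MM:SS.mmm".
def format_time(time_ms)
  sec, decimal_part = time_ms.divmod(1000)
  min, sec = sec.divmod(60)
  hour, min = min.divmod(60)
  "%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part]
end

# Builds one output line the way the README's on_new_segment hook does.
def format_segment(segment)
  line = "[%{st} --> %{ed}] %{text}" % {
    st: format_time(segment.start_time),
    ed: format_time(segment.end_time),
    text: segment.text
  }
  line << " (speaker turned)" if segment.speaker_next_turn?
  line
end

segment = FakeSegment.new(3_723_004, 3_725_500, "hello", true)
puts format_segment(segment)
# prints "[01:02:03.004 --> 01:02:05.500] hello (speaker turned)"
```

With the real bindings, the block given to `params.on_new_segment` would presumably receive a `Whisper::Segment` in place of `FakeSegment`.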
bindings/ruby/Rakefile CHANGED
@@ -5,17 +5,16 @@ require "yaml"
  require "rake/testtask"

  extsources = YAML.load_file("extsources.yaml")
- extsources.each_pair do |src_dir, dests|
-   dests.each do |dest|
-     src = Pathname(src_dir)/File.basename(dest)
-
-     file src
-     file dest => src do |t|
-       cp t.source, t.name
-     end
+ SOURCES = FileList[]
+ extsources.each do |src|
+   basename = src.pathmap("%f")
+   dest = basename == "LICENSE" ? basename : basename.pathmap("ext/%f")
+   file src
+   file dest => src do |t|
+     cp t.source, t.name
    end
+   SOURCES.include dest
  end
- SOURCES = extsources.values.flatten
  CLEAN.include SOURCES
  CLEAN.include FileList[
    "ext/*.o",
bindings/ruby/ext/.gitignore CHANGED
@@ -11,6 +11,10 @@ ggml-backend.c
  ggml-backend.h
  ggml-common.h
  ggml-cpu-impl.h
+ ggml-metal.m
+ ggml-metal.metal
+ ggml-metal-embed.metal
+ ggml-blas.cpp
  ggml-cuda.h
  ggml-impl.h
  ggml-kompute.h
@@ -20,9 +24,12 @@ ggml-quants.c
  ggml-quants.h
  ggml-sycl.h
  ggml-vulkan.h
+ ggml-blas.h
+ get-flags.mk
  whisper.cpp
  whisper.h
  dr_wav.h
+ depend
  whisper.bundle
  whisper.so
  whisper.dll
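
Reviewer note: several of the names ignored here (for example `ggml-blas.cpp` and `get-flags.mk`) appear to be files that the Rakefile change in this commit copies into `ext/`. The source-to-destination rule in that Rakefile can be sketched in plain Ruby as follows. This is an illustration only: it assumes `extsources.yaml` is now a flat list of source paths, it uses `File.basename` in place of Rake's `pathmap("%f")` to avoid depending on Rake, and both `dest_for` and the example paths are hypothetical.

```ruby
# Maps a source path to the location the Rakefile copies it to:
# LICENSE stays at the gem root, everything else goes under ext/.
def dest_for(src)
  basename = File.basename(src)
  basename == "LICENSE" ? basename : File.join("ext", basename)
end

# Hypothetical entries; the real list lives in extsources.yaml.
sources = [
  "../../ggml/src/ggml-blas.cpp",
  "../../scripts/get-flags.mk",
  "../../LICENSE"
]

sources.each { |src| puts "#{src} -> #{dest_for(src)}" }
```

Because every destination is derived from the basename alone, two sources with the same file name in different directories would collide under this rule.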
bindings/ruby/ext/ruby_whisper.cpp CHANGED
@@ -36,12 +36,65 @@ VALUE mWhisper;
36
  VALUE cContext;
37
  VALUE cParams;
38

39
  static void ruby_whisper_free(ruby_whisper *rw) {
40
  if (rw->context) {
41
  whisper_free(rw->context);
42
  rw->context = NULL;
43
  }
44
  }
 
45
  static void ruby_whisper_params_free(ruby_whisper_params *rwp) {
46
  }
47
 
@@ -55,9 +108,13 @@ void rb_whisper_free(ruby_whisper *rw) {
55
  }
56
 
57
  void rb_whisper_params_mark(ruby_whisper_params *rwp) {
 
 
 
58
  }
59
 
60
  void rb_whisper_params_free(ruby_whisper_params *rwp) {
 
61
  ruby_whisper_params_free(rwp);
62
  free(rwp);
63
  }
@@ -69,13 +126,28 @@ static VALUE ruby_whisper_allocate(VALUE klass) {
69
  return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
70
  }
71
 
 
 
 
 
 
 
 
 
 
 
72
  static VALUE ruby_whisper_params_allocate(VALUE klass) {
73
  ruby_whisper_params *rwp;
74
  rwp = ALLOC(ruby_whisper_params);
75
  rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 
76
  return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
77
  }
78
 
 
 
 
 
79
  static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
80
  ruby_whisper *rw;
81
  VALUE whisper_model_file_path;
@@ -84,7 +156,7 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
84
  rb_scan_args(argc, argv, "01", &whisper_model_file_path);
85
  Data_Get_Struct(self, ruby_whisper, rw);
86
 
87
- if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) {
88
  rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
89
  }
90
  rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
@@ -94,10 +166,21 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
94
  return self;
95
  }
96
 
 
 
 
97
  /*
98
  * transcribe a single file
99
  * can emit to a block results
100
  *
 
 
 
 
 
 
 
 
101
  **/
102
  static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
103
  ruby_whisper *rw;
@@ -108,7 +191,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
108
  Data_Get_Struct(self, ruby_whisper, rw);
109
  Data_Get_Struct(params, ruby_whisper_params, rwp);
110
 
111
- if (!rb_respond_to(wave_file_path, rb_intern("to_s"))) {
112
  rb_raise(rb_eRuntimeError, "Expected file path to wave file");
113
  }
114
 
@@ -206,6 +289,33 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
206
  rwp->params.encoder_begin_callback_user_data = &is_aborted;
207
  }
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
210
  fprintf(stderr, "failed to process audio\n");
211
  return self;
@@ -216,15 +326,114 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
216
  const char * text = whisper_full_get_segment_text(rw->context, i);
217
  output = rb_str_concat(output, rb_str_new2(text));
218
  }
219
- VALUE idCall = rb_intern("call");
220
  if (blk != Qnil) {
221
  rb_funcall(blk, idCall, 1, output);
222
  }
223
  return self;
224
  }
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  /*
227
  * params.language = "auto" | "en", etc...
 
 
 
228
  */
229
  static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) {
230
  ruby_whisper_params *rwp;
@@ -236,6 +445,10 @@ static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) {
236
  }
237
  return value;
238
  }
 
 
 
 
239
  static VALUE ruby_whisper_params_get_language(VALUE self) {
240
  ruby_whisper_params *rwp;
241
  Data_Get_Struct(self, ruby_whisper_params, rwp);
@@ -245,72 +458,185 @@ static VALUE ruby_whisper_params_get_language(VALUE self) {
245
  return rb_str_new2("auto");
246
  }
247
  }
 
 
 
 
248
  static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) {
249
  BOOL_PARAMS_SETTER(self, translate, value)
250
  }
 
 
 
 
251
  static VALUE ruby_whisper_params_get_translate(VALUE self) {
252
  BOOL_PARAMS_GETTER(self, translate)
253
  }
 
 
 
 
254
  static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) {
255
  BOOL_PARAMS_SETTER(self, no_context, value)
256
  }
 
 
 
 
 
 
257
  static VALUE ruby_whisper_params_get_no_context(VALUE self) {
258
  BOOL_PARAMS_GETTER(self, no_context)
259
  }
 
 
 
 
260
  static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) {
261
  BOOL_PARAMS_SETTER(self, single_segment, value)
262
  }
 
 
 
 
 
 
263
  static VALUE ruby_whisper_params_get_single_segment(VALUE self) {
264
  BOOL_PARAMS_GETTER(self, single_segment)
265
  }
 
 
 
 
266
  static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) {
267
  BOOL_PARAMS_SETTER(self, print_special, value)
268
  }
 
 
 
 
 
 
269
  static VALUE ruby_whisper_params_get_print_special(VALUE self) {
270
  BOOL_PARAMS_GETTER(self, print_special)
271
  }
 
 
 
 
272
  static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) {
273
  BOOL_PARAMS_SETTER(self, print_progress, value)
274
  }
 
 
 
 
 
 
275
  static VALUE ruby_whisper_params_get_print_progress(VALUE self) {
276
  BOOL_PARAMS_GETTER(self, print_progress)
277
  }
 
 
 
 
278
  static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) {
279
  BOOL_PARAMS_SETTER(self, print_realtime, value)
280
  }
 
 
 
 
 
281
  static VALUE ruby_whisper_params_get_print_realtime(VALUE self) {
282
  BOOL_PARAMS_GETTER(self, print_realtime)
283
  }
 
 
 
 
284
  static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) {
285
  BOOL_PARAMS_SETTER(self, print_timestamps, value)
286
  }
 
 
 
 
 
 
287
  static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) {
288
  BOOL_PARAMS_GETTER(self, print_timestamps)
289
  }
 
 
 
 
290
  static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) {
291
  BOOL_PARAMS_SETTER(self, suppress_blank, value)
292
  }
 
 
 
 
 
 
293
  static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
294
  BOOL_PARAMS_GETTER(self, suppress_blank)
295
  }
 
 
 
 
296
  static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) {
297
  BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value)
298
  }
 
 
 
 
 
 
299
  static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) {
300
  BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens)
301
  }
 
 
 
 
 
 
302
  static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) {
303
  BOOL_PARAMS_GETTER(self, token_timestamps)
304
  }
 
 
 
 
305
  static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) {
306
  BOOL_PARAMS_SETTER(self, token_timestamps, value)
307
  }
 
 
 
 
 
 
308
  static VALUE ruby_whisper_params_get_split_on_word(VALUE self) {
309
  BOOL_PARAMS_GETTER(self, split_on_word)
310
  }
 
 
 
 
311
  static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
312
  BOOL_PARAMS_SETTER(self, split_on_word, value)
313
  }
 
 
 
 
 
 
314
  static VALUE ruby_whisper_params_get_diarize(VALUE self) {
315
  ruby_whisper_params *rwp;
316
  Data_Get_Struct(self, ruby_whisper_params, rwp);
@@ -320,6 +646,10 @@ static VALUE ruby_whisper_params_get_diarize(VALUE self) {
320
  return Qfalse;
321
  }
322
  }
 
 
 
 
323
  static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) {
324
  ruby_whisper_params *rwp;
325
  Data_Get_Struct(self, ruby_whisper_params, rwp);
@@ -331,22 +661,42 @@ static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) {
331
  return value;
332
  }
333
 
 
 
 
 
 
 
334
  static VALUE ruby_whisper_params_get_offset(VALUE self) {
335
  ruby_whisper_params *rwp;
336
  Data_Get_Struct(self, ruby_whisper_params, rwp);
337
  return INT2NUM(rwp->params.offset_ms);
338
  }
 
 
 
 
339
  static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) {
340
  ruby_whisper_params *rwp;
341
  Data_Get_Struct(self, ruby_whisper_params, rwp);
342
  rwp->params.offset_ms = NUM2INT(value);
343
  return value;
344
  }
 
 
 
 
 
 
345
  static VALUE ruby_whisper_params_get_duration(VALUE self) {
346
  ruby_whisper_params *rwp;
347
  Data_Get_Struct(self, ruby_whisper_params, rwp);
348
  return INT2NUM(rwp->params.duration_ms);
349
  }
 
 
 
 
350
  static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) {
351
  ruby_whisper_params *rwp;
352
  Data_Get_Struct(self, ruby_whisper_params, rwp);
@@ -354,27 +704,221 @@ static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) {
354
  return value;
355
  }
356
 
 
 
 
 
 
 
357
  static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) {
358
  ruby_whisper_params *rwp;
359
  Data_Get_Struct(self, ruby_whisper_params, rwp);
360
  return INT2NUM(rwp->params.n_max_text_ctx);
361
  }
 
 
 
 
362
  static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) {
363
  ruby_whisper_params *rwp;
364
  Data_Get_Struct(self, ruby_whisper_params, rwp);
365
  rwp->params.n_max_text_ctx = NUM2INT(value);
366
  return value;
367
  }
368
 
369
  void Init_whisper() {
 
 
 
 
 
370
  mWhisper = rb_define_module("Whisper");
371
  cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
372
  cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
373
 
 
 
 
 
 
374
  rb_define_alloc_func(cContext, ruby_whisper_allocate);
375
  rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
376
 
377
  rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
 
 
 
 
 
 
378
 
379
  rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
380
 
@@ -412,6 +956,20 @@ void Init_whisper() {
412
 
413
  rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0);
414
  rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1);
415
  }
416
  #ifdef __cplusplus
417
  }
 
36
  VALUE cContext;
37
  VALUE cParams;
38
 
39
+ static ID id_to_s;
40
+ static ID id_call;
41
+ static ID id___method__;
42
+ static ID id_to_enum;
43
+
44
+ /*
45
+ * call-seq:
46
+ * lang_max_id -> Integer
47
+ */
48
+ static VALUE ruby_whisper_s_lang_max_id(VALUE self) {
49
+ return INT2NUM(whisper_lang_max_id());
50
+ }
51
+
52
+ /*
53
+ * call-seq:
54
+ * lang_id(lang_name) -> Integer
55
+ */
56
+ static VALUE ruby_whisper_s_lang_id(VALUE self, VALUE lang) {
57
+ const char * lang_str = StringValueCStr(lang);
58
+ const int id = whisper_lang_id(lang_str);
59
+ if (-1 == id) {
60
+ rb_raise(rb_eArgError, "language not found: %s", lang_str);
61
+ }
62
+ return INT2NUM(id);
63
+ }
64
+
65
+ /*
66
+ * call-seq:
67
+ * lang_str(lang_id) -> String
68
+ */
69
+ static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) {
70
+ const int lang_id = NUM2INT(id);
71
+ const char * str = whisper_lang_str(lang_id);
72
+ if (nullptr == str) {
73
+ rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
74
+ }
75
+ return rb_str_new2(str);
76
+ }
77
+
78
+ /*
79
+ * call-seq:
80
+ * lang_str_full(lang_id) -> String
81
+ */
82
+ static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) {
83
+ const int lang_id = NUM2INT(id);
84
+ const char * str_full = whisper_lang_str_full(lang_id);
85
+ if (nullptr == str_full) {
86
+ rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
87
+ }
88
+ return rb_str_new2(str_full);
89
+ }
90
+
91
  static void ruby_whisper_free(ruby_whisper *rw) {
92
  if (rw->context) {
93
  whisper_free(rw->context);
94
  rw->context = NULL;
95
  }
96
  }
97
+
98
  static void ruby_whisper_params_free(ruby_whisper_params *rwp) {
99
  }
100
 
 
108
  }
109
 
110
  void rb_whisper_params_mark(ruby_whisper_params *rwp) {
111
+ rb_gc_mark(rwp->new_segment_callback_container->user_data);
112
+ rb_gc_mark(rwp->new_segment_callback_container->callback);
113
+ rb_gc_mark(rwp->new_segment_callback_container->callbacks);
114
  }
115
 
116
  void rb_whisper_params_free(ruby_whisper_params *rwp) {
117
+ // How to free user_data and callback only when not referred to by others?
118
  ruby_whisper_params_free(rwp);
119
  free(rwp);
120
  }
 
126
  return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
127
  }
128
 
129
+ static ruby_whisper_callback_container * rb_whisper_callback_container_allocate() {
130
+ ruby_whisper_callback_container *container;
131
+ container = ALLOC(ruby_whisper_callback_container);
132
+ container->context = nullptr;
133
+ container->user_data = Qnil;
134
+ container->callback = Qnil;
135
+ container->callbacks = rb_ary_new();
136
+ return container;
137
+ }
138
+
139
  static VALUE ruby_whisper_params_allocate(VALUE klass) {
140
  ruby_whisper_params *rwp;
141
  rwp = ALLOC(ruby_whisper_params);
142
  rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
143
+ rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
144
  return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
145
  }
146
 
147
+ /*
148
+ * call-seq:
149
+ * new("path/to/model.bin") -> Whisper::Context
150
+ */
151
  static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
152
  ruby_whisper *rw;
153
  VALUE whisper_model_file_path;
 
156
  rb_scan_args(argc, argv, "01", &whisper_model_file_path);
157
  Data_Get_Struct(self, ruby_whisper, rw);
158
 
159
+ if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
160
  rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
161
  }
162
  rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
 
166
  return self;
167
  }
168
 
169
+ // High level API
170
+ static VALUE rb_whisper_segment_initialize(VALUE context, int index);
171
+
172
  /*
173
  * transcribe a single file
174
  * can emit to a block results
175
  *
176
+ * params = Whisper::Params.new
177
+ * params.duration = 60_000
178
+ * whisper.transcribe "path/to/audio.wav", params do |text|
179
+ * puts text
180
+ * end
181
+ *
182
+ * call-seq:
183
+ * transcribe(path_to_audio, params) {|text| ...}
184
  **/
185
  static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
186
  ruby_whisper *rw;
 
191
  Data_Get_Struct(self, ruby_whisper, rw);
192
  Data_Get_Struct(params, ruby_whisper_params, rwp);
193
 
194
+ if (!rb_respond_to(wave_file_path, id_to_s)) {
195
  rb_raise(rb_eRuntimeError, "Expected file path to wave file");
196
  }
197
 
 
289
  rwp->params.encoder_begin_callback_user_data = &is_aborted;
290
  }
291
 
292
+ if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
293
+ rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
294
+ const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
295
+
296
+ // Currently, state is not passed to the callback because
297
+ // doing so would require resolving GC-related problems.
298
+ if (!NIL_P(container->callback)) {
299
+ rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
300
+ }
301
+ const long callbacks_len = RARRAY_LEN(container->callbacks);
302
+ if (0 == callbacks_len) {
303
+ return;
304
+ }
305
+ const int n_segments = whisper_full_n_segments_from_state(state);
306
+ for (int i = n_new; i > 0; i--) {
307
+ int i_segment = n_segments - i;
308
+ VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
309
+ for (int j = 0; j < callbacks_len; j++) {
310
+ VALUE cb = rb_ary_entry(container->callbacks, j);
311
+ rb_funcall(cb, id_call, 1, segment);
312
+ }
313
+ }
314
+ };
315
+ rwp->new_segment_callback_container->context = &self;
316
+ rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
317
+ }
318
+
319
  if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
320
  fprintf(stderr, "failed to process audio\n");
321
  return self;
 
326
  const char * text = whisper_full_get_segment_text(rw->context, i);
327
  output = rb_str_concat(output, rb_str_new2(text));
328
  }
329
+ VALUE idCall = id_call;
330
  if (blk != Qnil) {
331
  rb_funcall(blk, idCall, 1, output);
332
  }
333
  return self;
334
  }
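
Reviewer note: the new-segment lambda added to `ruby_whisper_transcribe` above walks the last `n_new` segments in order and hands each one to every registered block. A pure-Ruby model of that loop is sketched below. It is not the extension's code: `dispatch_new_segments` is a hypothetical name, plain strings stand in for `Whisper::Segment` objects, and the array length stands in for `whisper_full_n_segments_from_state(state)`.

```ruby
# Calls every callback once per newly added segment, oldest first.
# segments: all segments produced so far; n_new: how many were just added.
def dispatch_new_segments(segments, n_new, callbacks)
  n_segments = segments.length
  n_new.downto(1) do |i|
    segment = segments[n_segments - i] # same index arithmetic as the C++ loop
    callbacks.each { |cb| cb.call(segment) }
  end
end

seen = []
callbacks = [->(segment) { seen << segment }]
dispatch_new_segments(%w[a b c d], 2, callbacks)
p seen # prints ["c", "d"]
```

Counting `i` down from `n_new` to 1 means the index `n_segments - i` counts up, so the blocks see new segments in transcript order.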
335
 
336
+ /*
337
+ * Number of segments.
338
+ *
339
+ * call-seq:
340
+ * full_n_segments -> Integer
341
+ */
342
+ static VALUE ruby_whisper_full_n_segments(VALUE self) {
343
+ ruby_whisper *rw;
344
+ Data_Get_Struct(self, ruby_whisper, rw);
345
+ return INT2NUM(whisper_full_n_segments(rw->context));
346
+ }
347
+
348
+ /*
349
+ * Language ID, which can be converted to string by Whisper.lang_str and Whisper.lang_str_full.
350
+ *
351
+ * call-seq:
352
+ * full_lang_id -> Integer
353
+ */
354
+ static VALUE ruby_whisper_full_lang_id(VALUE self) {
355
+ ruby_whisper *rw;
356
+ Data_Get_Struct(self, ruby_whisper, rw);
357
+ return INT2NUM(whisper_full_lang_id(rw->context));
358
+ }
359
+
360
+ static int ruby_whisper_full_check_segment_index(const ruby_whisper * rw, const VALUE i_segment) {
361
+ const int c_i_segment = NUM2INT(i_segment);
362
+ if (c_i_segment < 0 || c_i_segment >= whisper_full_n_segments(rw->context)) {
363
+ rb_raise(rb_eIndexError, "segment index %d out of range", c_i_segment);
364
+ }
365
+ return c_i_segment;
366
+ }
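
Reviewer note: `ruby_whisper_full_check_segment_index` above rejects both negative and too-large indices, so the `full_get_segment_*` methods do not follow Ruby's convention of counting negative indices from the end. A pure-Ruby model of the check is sketched below. `check_segment_index` is a hypothetical name, and `n_segments` stands in for `whisper_full_n_segments(rw->context)`.

```ruby
# Returns the index unchanged when 0 <= index < n_segments,
# otherwise raises IndexError with the same message format as the C++ helper.
def check_segment_index(index, n_segments)
  if index < 0 || index >= n_segments
    raise IndexError, "segment index #{index} out of range"
  end
  index
end

puts check_segment_index(2, 3)

begin
  check_segment_index(-1, 3)
rescue IndexError => e
  puts e.message
end
```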
367
+
368
+ /*
369
+ * Start time of a segment indexed by +segment_index+ in centiseconds (units of 10 milliseconds).
370
+ *
371
+ * full_get_segment_t0(3) # => 1668 (16680 ms)
372
+ *
373
+ * call-seq:
374
+ * full_get_segment_t0(segment_index) -> Integer
375
+ */
376
+ static VALUE ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment) {
377
+ ruby_whisper *rw;
378
+ Data_Get_Struct(self, ruby_whisper, rw);
379
+ const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
380
+ const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment);
381
+ return INT2NUM(t0);
382
+ }
383
+
384
+ /*
385
+ * End time of a segment indexed by +segment_index+ in centiseconds (units of 10 milliseconds).
386
+ *
387
+ * full_get_segment_t1(3) # => 1668 (16680 ms)
388
+ *
389
+ * call-seq:
390
+ * full_get_segment_t1(segment_index) -> Integer
391
+ */
392
+ static VALUE ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment) {
393
+ ruby_whisper *rw;
394
+ Data_Get_Struct(self, ruby_whisper, rw);
395
+ const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
396
+ const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment);
397
+ return INT2NUM(t1);
398
+ }
399
+
400
+ /*
401
+ * Whether the segment following the one indexed by +segment_index+ is predicted to be a speaker turn.
402
+ *
403
+ * full_get_segment_speaker_turn_next(3) # => true
404
+ *
405
+ * call-seq:
406
+ * full_get_segment_speaker_turn_next(segment_index) -> bool
407
+ */
408
+ static VALUE ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment) {
409
+ ruby_whisper *rw;
410
+ Data_Get_Struct(self, ruby_whisper, rw);
411
+ const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
412
+ const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment);
413
+ return speaker_turn_next ? Qtrue : Qfalse;
414
+ }
415
+
416
+ /*
417
+ * Text of a segment indexed by +segment_index+.
418
+ *
419
+ * full_get_segment_text(3) # => "ask not what your country can do for you, ..."
420
+ *
421
+ * call-seq:
422
+ * full_get_segment_text(segment_index) -> String
423
+ */
424
+ static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) {
425
+ ruby_whisper *rw;
426
+ Data_Get_Struct(self, ruby_whisper, rw);
427
+ const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
428
+ const char * text = whisper_full_get_segment_text(rw->context, c_i_segment);
429
+ return rb_str_new2(text);
430
+ }
431
+
432
  /*
433
  * params.language = "auto" | "en", etc...
434
+ *
435
+ * call-seq:
436
+ * language = lang_name -> lang_name
437
  */
438
  static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) {
439
  ruby_whisper_params *rwp;
 
445
  }
446
  return value;
447
  }
448
+ /*
449
+ * call-seq:
450
+ * language -> String
451
+ */
452
  static VALUE ruby_whisper_params_get_language(VALUE self) {
453
  ruby_whisper_params *rwp;
454
  Data_Get_Struct(self, ruby_whisper_params, rwp);
 
458
  return rb_str_new2("auto");
459
  }
460
  }
461
+ /*
462
+ * call-seq:
463
+ * translate = do_translate -> do_translate
464
+ */
465
  static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) {
466
  BOOL_PARAMS_SETTER(self, translate, value)
467
  }
468
+ /*
469
+ * call-seq:
470
+ * translate -> bool
471
+ */
472
  static VALUE ruby_whisper_params_get_translate(VALUE self) {
473
  BOOL_PARAMS_GETTER(self, translate)
474
  }
475
+ /*
476
+ * call-seq:
477
+ * no_context = dont_use_context -> dont_use_context
478
+ */
479
  static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) {
480
  BOOL_PARAMS_SETTER(self, no_context, value)
481
  }
482
+ /*
483
+ * If true, does not use past transcription (if any) as initial prompt for the decoder.
484
+ *
485
+ * call-seq:
486
+ * no_context -> bool
487
+ */
488
  static VALUE ruby_whisper_params_get_no_context(VALUE self) {
489
  BOOL_PARAMS_GETTER(self, no_context)
490
  }
491
+ /*
492
+ * call-seq:
493
+ * single_segment = force_single -> force_single
494
+ */
495
  static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) {
496
  BOOL_PARAMS_SETTER(self, single_segment, value)
497
  }
498
+ /*
499
+ * If true, forces single segment output (useful for streaming).
500
+ *
501
+ * call-seq:
502
+ * single_segment -> bool
503
+ */
504
  static VALUE ruby_whisper_params_get_single_segment(VALUE self) {
505
  BOOL_PARAMS_GETTER(self, single_segment)
506
  }
507
+ /*
508
+ * call-seq:
509
+ * print_special = force_print -> force_print
510
+ */
511
  static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) {
512
  BOOL_PARAMS_SETTER(self, print_special, value)
513
  }
514
+ /*
515
+ * If true, prints special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.).
516
+ *
517
+ * call-seq:
518
+ * print_special -> bool
519
+ */
520
  static VALUE ruby_whisper_params_get_print_special(VALUE self) {
521
  BOOL_PARAMS_GETTER(self, print_special)
522
  }
523
+ /*
524
+ * call-seq:
525
+ * print_progress = force_print -> force_print
526
+ */
527
  static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) {
528
  BOOL_PARAMS_SETTER(self, print_progress, value)
529
  }
530
+ /*
531
+ * If true, prints progress information.
532
+ *
533
+ * call-seq:
534
+ * print_progress -> bool
535
+ */
536
  static VALUE ruby_whisper_params_get_print_progress(VALUE self) {
537
  BOOL_PARAMS_GETTER(self, print_progress)
538
  }
539
+ /*
540
+ * call-seq:
541
+ * print_realtime = force_print -> force_print
542
+ */
543
  static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) {
544
  BOOL_PARAMS_SETTER(self, print_realtime, value)
545
  }
546
+ /*
547
+ * If true, prints results from within whisper.cpp (avoid this; use a callback instead).
548
+ * call-seq:
549
+ * print_realtime -> bool
550
+ */
551
  static VALUE ruby_whisper_params_get_print_realtime(VALUE self) {
552
  BOOL_PARAMS_GETTER(self, print_realtime)
553
  }
554
+ /*
555
+ * call-seq:
556
+ * print_timestamps = force_print -> force_print
557
+ */
558
  static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) {
559
  BOOL_PARAMS_SETTER(self, print_timestamps, value)
560
  }
561
+ /*
562
+ * If true, prints timestamps for each text segment when printing realtime.
563
+ *
564
+ * call-seq:
565
+ * print_timestamps -> bool
566
+ */
567
  static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) {
568
  BOOL_PARAMS_GETTER(self, print_timestamps)
569
  }
570
+ /*
571
+ * call-seq:
572
+ * suppress_blank = force_suppress -> force_suppress
573
+ */
574
  static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) {
575
  BOOL_PARAMS_SETTER(self, suppress_blank, value)
576
  }
577
+ /*
578
+ * If true, suppresses blank outputs.
579
+ *
580
+ * call-seq:
581
+ * suppress_blank -> bool
582
+ */
583
  static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
584
  BOOL_PARAMS_GETTER(self, suppress_blank)
585
  }
586
+ /*
587
+ * call-seq:
588
+ * suppress_non_speech_tokens = force_suppress -> force_suppress
589
+ */
590
  static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) {
591
  BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value)
592
  }
593
+ /*
594
+ * If true, suppresses non-speech tokens.
595
+ *
596
+ * call-seq:
597
+ * suppress_non_speech_tokens -> bool
598
+ */
599
  static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) {
600
  BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens)
601
  }
602
+ /*
603
+ * If true, enables token-level timestamps.
604
+ *
605
+ * call-seq:
606
+ * token_timestamps -> bool
607
+ */
608
  static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) {
609
  BOOL_PARAMS_GETTER(self, token_timestamps)
610
  }
611
+ /*
612
+ * call-seq:
613
+ * token_timestamps = force_timestamps -> force_timestamps
614
+ */
615
  static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) {
616
  BOOL_PARAMS_SETTER(self, token_timestamps, value)
617
  }
618
+ /*
619
+ * If true, splits on words rather than on tokens (when used with max_len).
620
+ *
621
+ * call-seq:
622
+ * split_on_word -> bool
623
+ */
624
  static VALUE ruby_whisper_params_get_split_on_word(VALUE self) {
625
  BOOL_PARAMS_GETTER(self, split_on_word)
626
  }
627
+ /*
628
+ * call-seq:
629
+ * split_on_word = force_split -> force_split
630
+ */
631
  static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
632
  BOOL_PARAMS_SETTER(self, split_on_word, value)
633
  }
634
+ /*
635
+ * If true, enables diarization.
636
+ *
637
+ * call-seq:
638
+ * diarize -> bool
639
+ */
640
  static VALUE ruby_whisper_params_get_diarize(VALUE self) {
641
  ruby_whisper_params *rwp;
642
  Data_Get_Struct(self, ruby_whisper_params, rwp);
 
646
  return Qfalse;
647
  }
648
  }
649
+ /*
650
+ * call-seq:
651
+ * diarize = force_diarize -> force_diarize
652
+ */
653
  static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) {
654
  ruby_whisper_params *rwp;
655
  Data_Get_Struct(self, ruby_whisper_params, rwp);
 
661
  return value;
662
  }
663
 
664
+ /*
665
+ * Start offset in ms.
666
+ *
667
+ * call-seq:
668
+ * offset -> Integer
669
+ */
670
  static VALUE ruby_whisper_params_get_offset(VALUE self) {
671
  ruby_whisper_params *rwp;
672
  Data_Get_Struct(self, ruby_whisper_params, rwp);
673
  return INT2NUM(rwp->params.offset_ms);
674
  }
675
+ /*
676
+ * call-seq:
677
+ * offset = offset_ms -> offset_ms
678
+ */
679
  static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) {
680
  ruby_whisper_params *rwp;
681
  Data_Get_Struct(self, ruby_whisper_params, rwp);
682
  rwp->params.offset_ms = NUM2INT(value);
683
  return value;
684
  }
685
+ /*
686
+ * Audio duration to process in ms.
687
+ *
688
+ * call-seq:
689
+ * duration -> Integer
690
+ */
691
  static VALUE ruby_whisper_params_get_duration(VALUE self) {
692
  ruby_whisper_params *rwp;
693
  Data_Get_Struct(self, ruby_whisper_params, rwp);
694
  return INT2NUM(rwp->params.duration_ms);
695
  }
696
+ /*
697
+ * call-seq:
698
+ * duration = duration_ms -> duration_ms
699
+ */
700
  static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) {
701
  ruby_whisper_params *rwp;
702
  Data_Get_Struct(self, ruby_whisper_params, rwp);
 
704
  return value;
705
  }
706
 
707
+ /*
708
+ * Max tokens to use from past text as prompt for the decoder.
709
+ *
710
+ * call-seq:
711
+ * max_text_tokens -> Integer
712
+ */
713
  static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) {
714
  ruby_whisper_params *rwp;
715
  Data_Get_Struct(self, ruby_whisper_params, rwp);
716
  return INT2NUM(rwp->params.n_max_text_ctx);
717
  }
718
+ /*
719
+ * call-seq:
720
+ * max_text_tokens = n_tokens -> n_tokens
721
+ */
722
  static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) {
723
  ruby_whisper_params *rwp;
724
  Data_Get_Struct(self, ruby_whisper_params, rwp);
725
  rwp->params.n_max_text_ctx = NUM2INT(value);
726
  return value;
727
  }
728
+ /*
729
+ * Sets the new segment callback, which is called for every newly generated text segment.
730
+ *
731
+ * params.new_segment_callback = ->(context, _, n_new, user_data) {
732
+ * # ...
733
+ * }
734
+ *
735
+ * call-seq:
736
+ * new_segment_callback = callback -> callback
737
+ */
738
+ static VALUE ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value) {
739
+ ruby_whisper_params *rwp;
740
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
741
+ rwp->new_segment_callback_container->callback = value;
742
+ return value;
743
+ }
744
+ /*
745
+ * Sets the user data passed as the last argument to the new segment callback.
746
+ *
747
+ * call-seq:
748
+ * new_segment_callback_user_data = user_data -> user_data
749
+ */
750
+ static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value) {
751
+ ruby_whisper_params *rwp;
752
+ Data_Get_Struct(self, ruby_whisper_params, rwp);
753
+ rwp->new_segment_callback_container->user_data = value;
754
+ return value;
755
+ }
756
+
757
+ // High level API
758
+
759
+ typedef struct {
760
+ VALUE context;
761
+ int index;
762
+ } ruby_whisper_segment;
763
+
764
+ VALUE cSegment;
765
+
766
+ static void rb_whisper_segment_mark(ruby_whisper_segment *rws) {
767
+ rb_gc_mark(rws->context);
768
+ }
769
+
770
+ static VALUE ruby_whisper_segment_allocate(VALUE klass) {
771
+ ruby_whisper_segment *rws;
772
+ rws = ALLOC(ruby_whisper_segment);
773
+ return Data_Wrap_Struct(klass, rb_whisper_segment_mark, RUBY_DEFAULT_FREE, rws);
774
+ }
775
+
776
+ static VALUE rb_whisper_segment_initialize(VALUE context, int index) {
777
+ ruby_whisper_segment *rws;
778
+ const VALUE segment = ruby_whisper_segment_allocate(cSegment);
779
+ Data_Get_Struct(segment, ruby_whisper_segment, rws);
780
+ rws->context = context;
781
+ rws->index = index;
782
+ return segment;
783
+ }
784
+
785
+ /*
786
+ * Yields each Whisper::Segment:
787
+ *
788
+ * whisper.transcribe("path/to/audio.wav", params)
789
+ * whisper.each_segment do |segment|
790
+ * puts segment.text
791
+ * end
792
+ *
793
+ * Returns an Enumerator if no block given:
794
+ *
795
+ * whisper.transcribe("path/to/audio.wav", params)
796
+ * enum = whisper.each_segment
797
+ * enum.to_a # => [#<Whisper::Segment>, ...]
798
+ *
799
+ * call-seq:
800
+ * each_segment {|segment| ... }
801
+ * each_segment -> Enumerator
802
+ */
803
+ static VALUE ruby_whisper_each_segment(VALUE self) {
804
+ if (!rb_block_given_p()) {
805
+ const VALUE method_name = rb_funcall(self, id___method__, 0);
806
+ return rb_funcall(self, id_to_enum, 1, method_name);
807
+ }
808
+
809
+ ruby_whisper *rw;
810
+ Data_Get_Struct(self, ruby_whisper, rw);
811
+
812
+ const int n_segments = whisper_full_n_segments(rw->context);
813
+ for (int i = 0; i < n_segments; ++i) {
814
+ rb_yield(rb_whisper_segment_initialize(self, i));
815
+ }
816
+
817
+ return self;
818
+ }
819
+
820
+ /*
821
+ * Hook called on new segment. Yields each Whisper::Segment.
822
+ *
823
+ * params.on_new_segment do |segment|
824
+ * # ...
825
+ * end
826
+ *
827
+ * call-seq:
828
+ * on_new_segment {|segment| ... }
829
+ */
830
+ static VALUE ruby_whisper_params_on_new_segment(VALUE self) {
831
+ ruby_whisper_params *rws;
832
+ Data_Get_Struct(self, ruby_whisper_params, rws);
833
+ const VALUE blk = rb_block_proc();
834
+ rb_ary_push(rws->new_segment_callback_container->callbacks, blk);
835
+ return Qnil;
836
+ }
837
+
838
+ /*
839
+ * Start time in milliseconds.
840
+ *
841
+ * call-seq:
842
+ * start_time -> Integer
843
+ */
844
+ static VALUE ruby_whisper_segment_get_start_time(VALUE self) {
845
+ ruby_whisper_segment *rws;
846
+ Data_Get_Struct(self, ruby_whisper_segment, rws);
847
+ ruby_whisper *rw;
848
+ Data_Get_Struct(rws->context, ruby_whisper, rw);
849
+ const int64_t t0 = whisper_full_get_segment_t0(rw->context, rws->index);
850
+ // multiplying by 10 is safe from overflow because to_timestamp() in whisper.cpp does the same
851
+ return INT2NUM(t0 * 10);
852
+ }
853
+
854
+ /*
855
+ * End time in milliseconds.
856
+ *
857
+ * call-seq:
858
+ * end_time -> Integer
859
+ */
860
+ static VALUE ruby_whisper_segment_get_end_time(VALUE self) {
861
+ ruby_whisper_segment *rws;
862
+ Data_Get_Struct(self, ruby_whisper_segment, rws);
863
+ ruby_whisper *rw;
864
+ Data_Get_Struct(rws->context, ruby_whisper, rw);
865
+ const int64_t t1 = whisper_full_get_segment_t1(rw->context, rws->index);
866
+ // multiplying by 10 is safe from overflow because to_timestamp() in whisper.cpp does the same
867
+ return INT2NUM(t1 * 10);
868
+ }
869
+
870
+ /*
871
+ * Whether the next segment is predicted as a speaker turn.
872
+ *
873
+ * call-seq:
874
+ * speaker_next_turn? -> bool
875
+ */
876
+ static VALUE ruby_whisper_segment_get_speaker_turn_next(VALUE self) {
877
+ ruby_whisper_segment *rws;
878
+ Data_Get_Struct(self, ruby_whisper_segment, rws);
879
+ ruby_whisper *rw;
880
+ Data_Get_Struct(rws->context, ruby_whisper, rw);
881
+ return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? Qtrue : Qfalse;
882
+ }
883
+
884
+ /*
885
+ * call-seq:
886
+ * text -> String
887
+ */
888
+ static VALUE ruby_whisper_segment_get_text(VALUE self) {
889
+ ruby_whisper_segment *rws;
890
+ Data_Get_Struct(self, ruby_whisper_segment, rws);
891
+ ruby_whisper *rw;
892
+ Data_Get_Struct(rws->context, ruby_whisper, rw);
893
+ const char * text = whisper_full_get_segment_text(rw->context, rws->index);
894
+ return rb_str_new2(text);
895
+ }
896
 
897
  void Init_whisper() {
898
+ id_to_s = rb_intern("to_s");
899
+ id_call = rb_intern("call");
900
+ id___method__ = rb_intern("__method__");
901
+ id_to_enum = rb_intern("to_enum");
902
+
903
  mWhisper = rb_define_module("Whisper");
904
  cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
905
  cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
906
 
907
+ rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0);
908
+ rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);
909
+ rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
910
+ rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);
911
+
912
  rb_define_alloc_func(cContext, ruby_whisper_allocate);
913
  rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
914
 
915
  rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
916
+ rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0);
917
+ rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0);
918
+ rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1);
919
+ rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
920
+ rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
921
+ rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
922
 
923
  rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
924
 
 
956
 
957
  rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0);
958
  rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1);
959
+
960
+ rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1);
961
+ rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1);
962
+
963
+ // High level API
964
+ cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject);
965
+
966
+ rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate);
967
+ rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
968
+ rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
969
+ rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0);
970
+ rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
971
+ rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
972
+ rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
973
  }
974
  #ifdef __cplusplus
975
  }
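The high-level API above (each_segment returning an Enumerator when no block is given, and Segment converting whisper.cpp's 10 ms timestamp units to milliseconds) can be sketched in plain Ruby. This is only an illustration of the pattern, not the extension's implementation: FakeContext, its raw array, and the sample data are hypothetical stand-ins for the C structs.

```ruby
# Pure-Ruby stand-in for Whisper::Context#each_segment and Whisper::Segment.
# FakeContext and its sample data are hypothetical; the real classes live in the C extension.
class FakeContext
  # A segment only stores its context and index; values are looked up on demand.
  Segment = Struct.new(:context, :index) do
    def start_time
      context.raw[index][:t0] * 10 # 10 ms units -> milliseconds
    end

    def end_time
      context.raw[index][:t1] * 10
    end

    def text
      context.raw[index][:text]
    end
  end

  attr_reader :raw

  def initialize(raw)
    @raw = raw
  end

  def each_segment
    # Without a block, hand back an Enumerator over this same method.
    return to_enum(__method__) unless block_given?

    @raw.each_index { |i| yield Segment.new(self, i) }
    self
  end
end

context = FakeContext.new([
  { t0: 0, t1: 1100, text: "first" },
  { t0: 1100, t1: 1250, text: "second" }
])
context.each_segment { |s| puts "#{s.start_time}-#{s.end_time}: #{s.text}" }
```

Keeping only the context and an index in each segment mirrors the ruby_whisper_segment struct, which is why the C version marks the context for the GC.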
bindings/ruby/ext/ruby_whisper.h CHANGED
@@ -3,6 +3,13 @@
3
 
4
  #include "whisper.h"
5
 
 
 
 
 
 
 
 
6
  typedef struct {
7
  struct whisper_context *context;
8
  } ruby_whisper;
@@ -10,6 +17,7 @@ typedef struct {
10
  typedef struct {
11
  struct whisper_full_params params;
12
  bool diarize;
 
13
  } ruby_whisper_params;
14
 
15
  #endif
 
3
 
4
  #include "whisper.h"
5
 
6
+ typedef struct {
7
+ VALUE *context;
8
+ VALUE user_data;
9
+ VALUE callback;
10
+ VALUE callbacks;
11
+ } ruby_whisper_callback_container;
12
+
13
  typedef struct {
14
  struct whisper_context *context;
15
  } ruby_whisper;
 
17
  typedef struct {
18
  struct whisper_full_params params;
19
  bool diarize;
20
+ ruby_whisper_callback_container *new_segment_callback_container;
21
  } ruby_whisper_params;
22
 
23
  #endif
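The container struct above holds one low-level callback plus a list of high-level hooks. A rough Ruby model of how such a container might dispatch is sketched below. The dispatch method, its new_indices argument, and the array used in place of a Whisper::Segment are assumptions for illustration; the actual dispatch code is in the C extension and is not shown in this part of the diff.

```ruby
# Hypothetical Ruby model of ruby_whisper_callback_container; the real dispatch
# code is in the C extension and may differ.
CallbackContainer = Struct.new(:context, :user_data, :callback, :callbacks) do
  # new_indices: indices of the segments generated since the previous call.
  def dispatch(new_indices)
    # Low-level callback: (context, state, n_new, user_data); state is not modeled here.
    callback&.call(context, nil, new_indices.size, user_data)

    new_indices.each do |i|
      segment = [context, i] # stands in for a Whisper::Segment
      # Every registered hook receives the same segment object.
      callbacks.each { |hook| hook.call(segment) }
    end
  end
end

container = CallbackContainer.new(:fake_context, "my data", nil, [])
container.callback = ->(_context, _state, n_new, user_data) { puts "#{n_new} new, #{user_data}" }
container.callbacks << ->(segment) { puts "hook saw segment #{segment.last}" }
container.dispatch([0])
```

Passing one segment object to every hook matches what test_on_new_segment_twice checks with assert_same.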
bindings/ruby/extsources.yaml CHANGED
@@ -1,37 +1,29 @@
1
  ---
2
- ../../src:
3
- - ext/whisper.cpp
4
- ../../include:
5
- - ext/whisper.h
6
- ../../ggml/src:
7
- - ext/ggml.c
8
- - ext/ggml-impl.h
9
- - ext/ggml-aarch64.h
10
- - ext/ggml-aarch64.c
11
- - ext/ggml-alloc.c
12
- - ext/ggml-backend-impl.h
13
- - ext/ggml-backend.cpp
14
- - ext/ggml-common.h
15
- - ext/ggml-quants.h
16
- - ext/ggml-quants.c
17
- - ext/ggml-cpu-impl.h
18
- - ext/ggml-metal.m
19
- - ext/ggml-metal.metal
20
- - ext/ggml-blas.cpp
21
- ../../ggml/include:
22
- - ext/ggml.h
23
- - ext/ggml-alloc.h
24
- - ext/ggml-backend.h
25
- - ext/ggml-cuda.h
26
- - ext/ggml-kompute.h
27
- - ext/ggml-metal.h
28
- - ext/ggml-sycl.h
29
- - ext/ggml-vulkan.h
30
- - ext/ggml-blas.h
31
- ../../scripts:
32
- - ext/get-flags.mk
33
- ../../examples:
34
- - ext/dr_wav.h
35
- ../..:
36
- - README.md
37
- - LICENSE
 
1
  ---
2
+ - ../../src/whisper.cpp
3
+ - ../../include/whisper.h
4
+ - ../../ggml/src/ggml.c
5
+ - ../../ggml/src/ggml-impl.h
6
+ - ../../ggml/src/ggml-aarch64.h
7
+ - ../../ggml/src/ggml-aarch64.c
8
+ - ../../ggml/src/ggml-alloc.c
9
+ - ../../ggml/src/ggml-backend-impl.h
10
+ - ../../ggml/src/ggml-backend.cpp
11
+ - ../../ggml/src/ggml-common.h
12
+ - ../../ggml/src/ggml-quants.h
13
+ - ../../ggml/src/ggml-quants.c
14
+ - ../../ggml/src/ggml-cpu-impl.h
15
+ - ../../ggml/src/ggml-metal.m
16
+ - ../../ggml/src/ggml-metal.metal
17
+ - ../../ggml/src/ggml-blas.cpp
18
+ - ../../ggml/include/ggml.h
19
+ - ../../ggml/include/ggml-alloc.h
20
+ - ../../ggml/include/ggml-backend.h
21
+ - ../../ggml/include/ggml-cuda.h
22
+ - ../../ggml/include/ggml-kompute.h
23
+ - ../../ggml/include/ggml-metal.h
24
+ - ../../ggml/include/ggml-sycl.h
25
+ - ../../ggml/include/ggml-vulkan.h
26
+ - ../../ggml/include/ggml-blas.h
27
+ - ../../scripts/get-flags.mk
28
+ - ../../examples/dr_wav.h
29
+ - ../../LICENSE
 
bindings/ruby/tests/test_callback.rb ADDED
@@ -0,0 +1,76 @@
1
+ require "test/unit"
2
+ require "whisper"
3
+
4
+ class TestCallback < Test::Unit::TestCase
5
+ TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
6
+
7
+ def setup
8
+ @params = Whisper::Params.new
9
+ @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
10
+ @audio = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
11
+ end
12
+
13
+ def test_new_segment_callback
14
+ @params.new_segment_callback = ->(context, state, n_new, user_data) {
15
+ assert_kind_of Integer, n_new
16
+ assert n_new > 0
17
+ assert_same @whisper, context
18
+
19
+ n_segments = context.full_n_segments
20
+ n_new.times do |i|
21
+ i_segment = n_segments - n_new + i
22
+ start_time = context.full_get_segment_t0(i_segment) * 10
23
+ end_time = context.full_get_segment_t1(i_segment) * 10
24
+ text = context.full_get_segment_text(i_segment)
25
+
26
+ assert_kind_of Integer, start_time
27
+ assert start_time >= 0
28
+ assert_kind_of Integer, end_time
29
+ assert end_time > 0
30
+ assert_match /ask not what your country can do for you, ask what you can do for your country/, text if i_segment == 0
31
+ end
32
+ }
33
+
34
+ @whisper.transcribe(@audio, @params)
35
+ end
36
+
37
+ def test_new_segment_callback_closure
38
+ search_word = "what"
39
+ @params.new_segment_callback = ->(context, state, n_new, user_data) {
40
+ n_segments = context.full_n_segments
41
+ n_new.times do |i|
42
+ i_segment = n_segments - n_new + i
43
+ text = context.full_get_segment_text(i_segment)
44
+ if text.include?(search_word)
45
+ t0 = context.full_get_segment_t0(i_segment)
46
+ t1 = context.full_get_segment_t1(i_segment)
47
+ raise "search word '#{search_word}' found between #{t0} and #{t1}"
48
+ end
49
+ end
50
+ }
51
+
52
+ assert_raise RuntimeError do
53
+ @whisper.transcribe(@audio, @params)
54
+ end
55
+ end
56
+
57
+ def test_new_segment_callback_user_data
58
+ udata = Object.new
59
+ @params.new_segment_callback_user_data = udata
60
+ @params.new_segment_callback = ->(context, state, n_new, user_data) {
61
+ assert_same udata, user_data
62
+ }
63
+
64
+ @whisper.transcribe(@audio, @params)
65
+ end
66
+
67
+ def test_new_segment_callback_user_data_gc
68
+ @params.new_segment_callback_user_data = "My user data"
69
+ @params.new_segment_callback = ->(context, state, n_new, user_data) {
70
+ assert_equal "My user data", user_data
71
+ }
72
+ GC.start
73
+
74
+ assert_same @whisper, @whisper.transcribe(@audio, @params)
75
+ end
76
+ end
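Inside a new-segment callback, the freshly generated segments are the last n_new of the full_n_segments total. A small sketch of that index arithmetic follows; new_segment_indices is a hypothetical helper, not part of the gem.

```ruby
# new_segment_indices is a hypothetical helper, not part of the whispercpp gem.
# It returns the indices of the last n_new segments out of n_segments.
def new_segment_indices(n_segments, n_new)
  (n_segments - n_new...n_segments).to_a
end

p new_segment_indices(1, 1)
p new_segment_indices(5, 2)
```

In a real callback the two arguments would come from context.full_n_segments and the n_new parameter.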
bindings/ruby/tests/test_package.rb ADDED
@@ -0,0 +1,28 @@
1
+ require 'test/unit'
2
+ require 'tempfile'
3
+ require 'tmpdir'
4
+ require 'shellwords'
5
+
6
+ class TestPackage < Test::Unit::TestCase
7
+ def test_build
8
+ Tempfile.create do |file|
9
+ assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
10
+ assert file.size > 0
11
+ end
12
+ end
13
+
14
+ sub_test_case "Building binary on installation" do
15
+ def setup
16
+ system "rake", "build", exception: true
17
+ end
18
+
19
+ def test_install
20
+ filename = `rake -Tbuild`.match(/(whispercpp-(?:.+)\.gem)/)[1]
21
+ basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
22
+ Dir.mktmpdir do |dir|
23
+ system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true
24
+ assert_path_exist File.join(dir, "gems/whispercpp-1.3.0/lib", basename)
25
+ end
26
+ end
27
+ end
28
+ end
bindings/ruby/tests/test_params.rb ADDED
@@ -0,0 +1,112 @@
1
+ require 'whisper'
2
+
3
+ class TestParams < Test::Unit::TestCase
4
+ def setup
5
+ @params = Whisper::Params.new
6
+ end
7
+
8
+ def test_language
9
+ @params.language = "en"
10
+ assert_equal @params.language, "en"
11
+ @params.language = "auto"
12
+ assert_equal @params.language, "auto"
13
+ end
14
+
15
+ def test_offset
16
+ @params.offset = 10_000
17
+ assert_equal @params.offset, 10_000
18
+ @params.offset = 0
19
+ assert_equal @params.offset, 0
20
+ end
21
+
22
+ def test_duration
23
+ @params.duration = 60_000
24
+ assert_equal @params.duration, 60_000
25
+ @params.duration = 0
26
+ assert_equal @params.duration, 0
27
+ end
28
+
29
+ def test_max_text_tokens
30
+ @params.max_text_tokens = 300
31
+ assert_equal @params.max_text_tokens, 300
32
+ @params.max_text_tokens = 0
33
+ assert_equal @params.max_text_tokens, 0
34
+ end
35
+
36
+ def test_translate
37
+ @params.translate = true
38
+ assert @params.translate
39
+ @params.translate = false
40
+ assert !@params.translate
41
+ end
42
+
43
+ def test_no_context
44
+ @params.no_context = true
45
+ assert @params.no_context
46
+ @params.no_context = false
47
+ assert !@params.no_context
48
+ end
49
+
50
+ def test_single_segment
51
+ @params.single_segment = true
52
+ assert @params.single_segment
53
+ @params.single_segment = false
54
+ assert !@params.single_segment
55
+ end
56
+
57
+ def test_print_special
58
+ @params.print_special = true
59
+ assert @params.print_special
60
+ @params.print_special = false
61
+ assert !@params.print_special
62
+ end
63
+
64
+ def test_print_progress
65
+ @params.print_progress = true
66
+ assert @params.print_progress
67
+ @params.print_progress = false
68
+ assert !@params.print_progress
69
+ end
70
+
71
+ def test_print_realtime
72
+ @params.print_realtime = true
73
+ assert @params.print_realtime
74
+ @params.print_realtime = false
75
+ assert !@params.print_realtime
76
+ end
77
+
78
+ def test_print_timestamps
79
+ @params.print_timestamps = true
80
+ assert @params.print_timestamps
81
+ @params.print_timestamps = false
82
+ assert !@params.print_timestamps
83
+ end
84
+
85
+ def test_suppress_blank
86
+ @params.suppress_blank = true
87
+ assert @params.suppress_blank
88
+ @params.suppress_blank = false
89
+ assert !@params.suppress_blank
90
+ end
91
+
92
+ def test_suppress_non_speech_tokens
93
+ @params.suppress_non_speech_tokens = true
94
+ assert @params.suppress_non_speech_tokens
95
+ @params.suppress_non_speech_tokens = false
96
+ assert !@params.suppress_non_speech_tokens
97
+ end
98
+
99
+ def test_token_timestamps
100
+ @params.token_timestamps = true
101
+ assert @params.token_timestamps
102
+ @params.token_timestamps = false
103
+ assert !@params.token_timestamps
104
+ end
105
+
106
+ def test_split_on_word
107
+ @params.split_on_word = true
108
+ assert @params.split_on_word
109
+ @params.split_on_word = false
110
+ assert !@params.split_on_word
111
+ end
112
+ end
bindings/ruby/tests/test_segment.rb ADDED
@@ -0,0 +1,87 @@
1
+ require "test/unit"
2
+ require "whisper"
3
+
4
+ class TestSegment < Test::Unit::TestCase
5
+ TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
6
+
7
+ class << self
8
+ attr_reader :whisper
9
+
10
+ def startup
11
+ @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
12
+ params = Whisper::Params.new
13
+ params.print_timestamps = false
14
+ jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
15
+ @whisper.transcribe(jfk, params)
16
+ end
17
+ end
18
+
19
+ def test_iteration
20
+ whisper.each_segment do |segment|
21
+ assert_instance_of Whisper::Segment, segment
22
+ end
23
+ end
24
+
25
+ def test_enumerator
26
+ enum = whisper.each_segment
27
+ assert_instance_of Enumerator, enum
28
+ enum.to_a.each_with_index do |segment, index|
29
+ assert_instance_of Whisper::Segment, segment
30
+ assert_kind_of Integer, index
31
+ end
32
+ end
33
+
34
+ def test_start_time
35
+ i = 0
36
+ whisper.each_segment do |segment|
37
+ assert_equal 0, segment.start_time if i == 0
38
+ i += 1
39
+ end
40
+ end
41
+
42
+ def test_end_time
43
+ i = 0
44
+ whisper.each_segment do |segment|
45
+ assert_equal whisper.full_get_segment_t1(i) * 10, segment.end_time
46
+ i += 1
47
+ end
48
+ end
49
+
50
+ def test_on_new_segment
51
+ params = Whisper::Params.new
52
+ seg = nil
53
+ index = 0
54
+ params.on_new_segment do |segment|
55
+ assert_instance_of Whisper::Segment, segment
56
+ if index == 0
57
+ seg = segment
58
+ assert_equal 0, segment.start_time
59
+ assert_match /ask not what your country can do for you, ask what you can do for your country/, segment.text
60
+ end
61
+ index += 1
62
+ end
63
+ whisper.transcribe(File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav'), params)
64
+ assert_equal 0, seg.start_time
65
+ assert_match /ask not what your country can do for you, ask what you can do for your country/, seg.text
66
+ end
67
+
68
+ def test_on_new_segment_twice
69
+ params = Whisper::Params.new
70
+ seg = nil
71
+ params.on_new_segment do |segment|
72
+ seg = segment
73
+ return
74
+ end
75
+ params.on_new_segment do |segment|
76
+ assert_same seg, segment
77
+ return
78
+ end
79
+ whisper.transcribe(File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav'), params)
80
+ end
81
+
82
+ private
83
+
84
+ def whisper
85
+ self.class.whisper
86
+ end
87
+ end
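Since Whisper::Segment#start_time and #end_time return milliseconds, a hook registered with on_new_segment could print subtitle-style timestamps. The sketch below shows one way to format them; format_ms is a hypothetical helper, not something the gem provides.

```ruby
# format_ms is a hypothetical helper that renders milliseconds as HH:MM:SS.mmm.
def format_ms(ms)
  total_seconds, millis = ms.divmod(1000)
  total_minutes, seconds = total_seconds.divmod(60)
  hours, minutes = total_minutes.divmod(60)
  format("%02d:%02d:%02d.%03d", hours, minutes, seconds, millis)
end

# With the gem loaded, a hook might look like this (not executed here):
#   params.on_new_segment do |segment|
#     puts "[#{format_ms(segment.start_time)} --> #{format_ms(segment.end_time)}] #{segment.text}"
#   end
puts format_ms(11_000)
```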
bindings/ruby/tests/test_whisper.rb CHANGED
@@ -1,151 +1,99 @@
1
- TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
2
-
3
  require 'whisper'
4
  require 'test/unit'
5
- require 'tempfile'
6
- require 'tmpdir'
7
- require 'shellwords'
8
 
9
  class TestWhisper < Test::Unit::TestCase
 
 
10
  def setup
11
  @params = Whisper::Params.new
12
  end
13
 
14
- def test_language
15
- @params.language = "en"
16
- assert_equal @params.language, "en"
17
- @params.language = "auto"
18
- assert_equal @params.language, "auto"
19
- end
20
-
21
- def test_offset
22
- @params.offset = 10_000
23
- assert_equal @params.offset, 10_000
24
- @params.offset = 0
25
- assert_equal @params.offset, 0
26
- end
27
-
28
- def test_duration
29
- @params.duration = 60_000
30
- assert_equal @params.duration, 60_000
31
- @params.duration = 0
32
- assert_equal @params.duration, 0
33
- end
34
-
35
- def test_max_text_tokens
36
- @params.max_text_tokens = 300
37
- assert_equal @params.max_text_tokens, 300
38
- @params.max_text_tokens = 0
39
- assert_equal @params.max_text_tokens, 0
40
- end
41
-
42
- def test_translate
43
- @params.translate = true
44
- assert @params.translate
45
- @params.translate = false
46
- assert !@params.translate
47
- end
48
 
49
- def test_no_context
50
- @params.no_context = true
51
- assert @params.no_context
52
- @params.no_context = false
53
- assert !@params.no_context
54
  end
55
 
56
- def test_single_segment
57
- @params.single_segment = true
58
- assert @params.single_segment
59
- @params.single_segment = false
60
- assert !@params.single_segment
61
- end
62
 
63
- def test_print_special
64
- @params.print_special = true
65
- assert @params.print_special
66
- @params.print_special = false
67
- assert !@params.print_special
68
- end
 
 
69
 
70
- def test_print_progress
71
- @params.print_progress = true
72
- assert @params.print_progress
73
- @params.print_progress = false
74
- assert !@params.print_progress
75
- end
76
 
77
- def test_print_realtime
78
- @params.print_realtime = true
79
- assert @params.print_realtime
80
- @params.print_realtime = false
81
- assert !@params.print_realtime
82
- end
83
 
84
- def test_print_timestamps
85
- @params.print_timestamps = true
86
- assert @params.print_timestamps
87
- @params.print_timestamps = false
88
- assert !@params.print_timestamps
89
- end
90
 
91
- def test_suppress_blank
92
- @params.suppress_blank = true
93
- assert @params.suppress_blank
94
- @params.suppress_blank = false
95
- assert !@params.suppress_blank
96
- end
 
 
 
97
 
98
- def test_suppress_non_speech_tokens
99
- @params.suppress_non_speech_tokens = true
100
- assert @params.suppress_non_speech_tokens
101
- @params.suppress_non_speech_tokens = false
102
- assert !@params.suppress_non_speech_tokens
103
- end
 
 
104
 
105
- def test_token_timestamps
106
- @params.token_timestamps = true
107
- assert @params.token_timestamps
108
- @params.token_timestamps = false
109
- assert !@params.token_timestamps
110
- end
111
 
112
- def test_split_on_word
113
- @params.split_on_word = true
114
- assert @params.split_on_word
115
- @params.split_on_word = false
116
- assert !@params.split_on_word
117
  end
118
 
119
- def test_whisper
120
- @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
121
- params = Whisper::Params.new
122
- params.print_timestamps = false
123
-
124
- jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
125
- @whisper.transcribe(jfk, params) {|text|
126
- assert_match /ask not what your country can do for you, ask what you can do for your country/, text
127
- }
128
  end
129
 
130
- def test_build
131
- Tempfile.create do |file|
132
- assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
133
- assert_path_exist file.to_path
134
  end
135
  end
136
 
137
- sub_test_case "Building binary on installation" do
138
- def setup
139
- system "rake", "build", exception: true
 
140
  end
 
141
 
142
- def test_install
143
- filename = `rake -Tbuild`.match(/(whispercpp-(?:.+)\.gem)/)[1]
144
- basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
145
- Dir.mktmpdir do |dir|
146
- system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true
147
- assert_path_exist File.join(dir, "gems/whispercpp-1.3.0/lib", basename)
148
- end
149
  end
150
  end
151
  end
 
 
 
1
  require 'whisper'
2
  require 'test/unit'
 
 
 
3
 
4
  class TestWhisper < Test::Unit::TestCase
5
+ TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
6
+
7
  def setup
8
  @params = Whisper::Params.new
9
  end
10
 
11
+ def test_whisper
12
+ @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
13
+ params = Whisper::Params.new
14
+ params.print_timestamps = false
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
17
+ @whisper.transcribe(jfk, params) {|text|
18
+ assert_match /ask not what your country can do for you, ask what you can do for your country/, text
19
+ }
 
20
  end
 
+  sub_test_case "After transcription" do
+    class << self
+      attr_reader :whisper
 
+      def startup
+        @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
+        params = Whisper::Params.new
+        params.print_timestamps = false
+        jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
+        @whisper.transcribe(jfk, params)
+      end
+    end
 
+    def whisper
+      self.class.whisper
+    end
 
+    def test_full_n_segments
+      assert_equal 1, whisper.full_n_segments
+    end
 
+    def test_full_lang_id
+      assert_equal 0, whisper.full_lang_id
+    end
 
+    def test_full_get_segment_t0
+      assert_equal 0, whisper.full_get_segment_t0(0)
+      assert_raise IndexError do
+        whisper.full_get_segment_t0(whisper.full_n_segments)
+      end
+      assert_raise IndexError do
+        whisper.full_get_segment_t0(-1)
+      end
+    end
 
+    def test_full_get_segment_t1
+      t1 = whisper.full_get_segment_t1(0)
+      assert_kind_of Integer, t1
+      assert t1 > 0
+      assert_raise IndexError do
+        whisper.full_get_segment_t1(whisper.full_n_segments)
+      end
+    end
 
+    def test_full_get_segment_speaker_turn_next
+      assert_false whisper.full_get_segment_speaker_turn_next(0)
+    end
 
+    def test_full_get_segment_text
+      assert_match /ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0)
+    end
   end
 
+  def test_lang_max_id
+    assert_kind_of Integer, Whisper.lang_max_id
   end
 
+  def test_lang_id
+    assert_equal 0, Whisper.lang_id("en")
+    assert_raise ArgumentError do
+      Whisper.lang_id("non existing language")
     end
   end
 
+  def test_lang_str
+    assert_equal "en", Whisper.lang_str(0)
+    assert_raise IndexError do
+      Whisper.lang_str(Whisper.lang_max_id + 1)
     end
+  end
 
+  def test_lang_str_full
+    assert_equal "english", Whisper.lang_str_full(0)
+    assert_raise IndexError do
+      Whisper.lang_str_full(Whisper.lang_max_id + 1)
     end
   end
 end
bindings/ruby/whispercpp.gemspec CHANGED
@@ -9,7 +9,15 @@ Gem::Specification.new do |s|
   s.email = '[email protected]'
   s.extra_rdoc_files = ['LICENSE', 'README.md']
 
-  s.files = `git ls-files . -z`.split("\x0") + YAML.load_file("extsources.yaml").values.flatten
+  s.files = `git ls-files . -z`.split("\x0") +
+            YAML.load_file("extsources.yaml").collect {|file|
+              basename = File.basename(file)
+              if s.extra_rdoc_files.include?(basename)
+                basename
+              else
+                File.join("ext", basename)
+              end
+            }
 
   s.summary = %q{Ruby whisper.cpp bindings}
   s.test_files = ["tests/test_whisper.rb"]
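
The path mapping in the new `s.files` expression can be tried outside the gemspec. The snippet below is a standalone sketch, not code from this commit: `packaged_paths` is a hypothetical helper name and the sample source list is made up. It assumes `extsources.yaml` now loads as a flat list of relative paths, which is what the switch from `.values.flatten` to `.collect` suggests.

```ruby
# Hypothetical helper mirroring the block passed to #collect in the gemspec.
# Files named in extra_rdoc_files stay at the gem root; every other source
# is expected under ext/.
def packaged_paths(sources, extra_rdoc_files)
  sources.collect do |file|
    basename = File.basename(file)
    if extra_rdoc_files.include?(basename)
      basename
    else
      File.join("ext", basename)
    end
  end
end

# Made-up sample input standing in for the contents of extsources.yaml.
sample_sources = ["../../src/whisper.cpp", "../../ggml/include/ggml.h", "../../LICENSE", "../../README.md"]

p packaged_paths(sample_sources, ["LICENSE", "README.md"])
# => ["ext/whisper.cpp", "ext/ggml.h", "LICENSE", "README.md"]
```

Only the basename of each source is kept, so two sources with the same file name in different directories would map to the same packaged path.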