ggerganov committed on
Commit
b5e16ed
·
unverified ·
1 Parent(s): 64508b4

main : add stereo-channel-based diarization (#64)

Browse files
Files changed (1) hide show
  1. examples/main/main.cpp +71 -10
examples/main/main.cpp CHANGED
@@ -36,6 +36,10 @@ std::string to_timestamp(int64_t t, bool comma = false) {
36
  return std::string(buf);
37
  }
38
 
 
 
 
 
39
  // helper function to replace substrings
40
  void replace_all(std::string & s, const std::string & search, const std::string & replace) {
41
  for (size_t pos = 0; ; pos += replace.length()) {
@@ -60,6 +64,7 @@ struct whisper_params {
60
 
61
  bool speed_up = false;
62
  bool translate = false;
 
63
  bool output_txt = false;
64
  bool output_vtt = false;
65
  bool output_srt = false;
@@ -99,6 +104,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
99
  else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
100
  else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
101
  else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
 
102
  else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
103
  else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
104
  else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
@@ -135,6 +141,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
135
  fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
136
  fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
137
  fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
 
138
  fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
139
  fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
140
  fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
@@ -148,8 +155,15 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
148
  fprintf(stderr, "\n");
149
  }
150
 
 
 
 
 
 
 
151
  void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
152
- const whisper_params & params = *(whisper_params *) user_data;
 
153
 
154
  const int n_segments = whisper_full_n_segments(ctx);
155
 
@@ -186,6 +200,33 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
186
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
187
  const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  if (params.print_colors) {
190
  printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
191
  for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
@@ -201,13 +242,13 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
201
 
202
  const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
203
 
204
- printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
205
  }
206
  printf("\n");
207
  } else {
208
  const char * text = whisper_full_get_segment_text(ctx, i);
209
 
210
- printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
211
  }
212
  }
213
  }
@@ -235,7 +276,7 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
235
  std::ofstream fout(fname);
236
  if (!fout.is_open()) {
237
  fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
238
- return 9;
239
  }
240
 
241
  fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
@@ -425,6 +466,7 @@ int main(int argc, char ** argv) {
425
  const auto fname_inp = params.fname_inp[f];
426
 
427
  std::vector<float> pcmf32; // mono-channel F32 PCM
 
428
 
429
  // WAV input
430
  {
@@ -453,22 +495,27 @@ int main(int argc, char ** argv) {
453
  }
454
  else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
455
  fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
456
- return 4;
457
  }
458
 
459
  if (wav.channels != 1 && wav.channels != 2) {
460
  fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
461
- return 5;
 
 
 
 
 
462
  }
463
 
464
  if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
465
  fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
466
- return 6;
467
  }
468
 
469
  if (wav.bitsPerSample != 16) {
470
  fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
471
- return 7;
472
  }
473
 
474
  const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
@@ -489,6 +536,18 @@ int main(int argc, char ** argv) {
489
  pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
490
  }
491
  }
 
 
 
 
 
 
 
 
 
 
 
 
492
  }
493
 
494
  // print system information
@@ -540,15 +599,17 @@ int main(int argc, char ** argv) {
540
 
541
  wparams.speed_up = params.speed_up;
542
 
 
 
543
  // this callback is called on each new segment
544
  if (!wparams.print_realtime) {
545
  wparams.new_segment_callback = whisper_print_segment_callback;
546
- wparams.new_segment_callback_user_data = &params;
547
  }
548
 
549
  if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
550
  fprintf(stderr, "%s: failed to process audio\n", argv[0]);
551
- return 8;
552
  }
553
  }
554
 
 
36
  return std::string(buf);
37
  }
38
 
39
+ int timestamp_to_sample(int64_t t, int n_samples) {
40
+ return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
41
+ }
42
+
43
  // helper function to replace substrings
44
  void replace_all(std::string & s, const std::string & search, const std::string & replace) {
45
  for (size_t pos = 0; ; pos += replace.length()) {
 
64
 
65
  bool speed_up = false;
66
  bool translate = false;
67
+ bool diarize = false;
68
  bool output_txt = false;
69
  bool output_vtt = false;
70
  bool output_srt = false;
 
104
  else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
105
  else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
106
  else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
107
+ else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
108
  else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
109
  else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
110
  else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
 
141
  fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
142
  fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
143
  fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
144
+ fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
145
  fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
146
  fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
147
  fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
 
155
  fprintf(stderr, "\n");
156
  }
157
 
158
+ struct whisper_print_user_data { // bundle handed to whisper_print_segment_callback through its void * user_data argument
159
+ const whisper_params * params; // non-owning; the callback reads options such as diarize and print_colors from it
160
+
161
+ const std::vector<std::vector<float>> * pcmf32s; // non-owning; per-channel float PCM, the callback uses it for speaker labels only when it holds exactly 2 channels
162
+ };
163
+
164
  void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
165
+ const auto & params = *((whisper_print_user_data *) user_data)->params;
166
+ const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
167
 
168
  const int n_segments = whisper_full_n_segments(ctx);
169
 
 
200
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
201
  const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
202
 
203
+ std::string speaker = "";
204
+
205
+ if (params.diarize && pcmf32s.size() == 2) {
206
+ const int64_t n_samples = pcmf32s[0].size();
207
+
208
+ const int64_t is0 = timestamp_to_sample(t0, n_samples);
209
+ const int64_t is1 = timestamp_to_sample(t1, n_samples);
210
+
211
+ double energy0 = 0.0f;
212
+ double energy1 = 0.0f;
213
+
214
+ for (int64_t j = is0; j < is1; j++) {
215
+ energy0 += fabs(pcmf32s[0][j]);
216
+ energy1 += fabs(pcmf32s[1][j]);
217
+ }
218
+
219
+ if (energy0 > 1.1*energy1) {
220
+ speaker = "(speaker 0)";
221
+ } else if (energy1 > 1.1*energy0) {
222
+ speaker = "(speaker 1)";
223
+ } else {
224
+ speaker = "(speaker ?)";
225
+ }
226
+
227
+ //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
228
+ }
229
+
230
  if (params.print_colors) {
231
  printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
232
  for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
 
242
 
243
  const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
244
 
245
+ printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
246
  }
247
  printf("\n");
248
  } else {
249
  const char * text = whisper_full_get_segment_text(ctx, i);
250
 
251
+ printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
252
  }
253
  }
254
  }
 
276
  std::ofstream fout(fname);
277
  if (!fout.is_open()) {
278
  fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
279
+ return false;
280
  }
281
 
282
  fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
 
466
  const auto fname_inp = params.fname_inp[f];
467
 
468
  std::vector<float> pcmf32; // mono-channel F32 PCM
469
+ std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
470
 
471
  // WAV input
472
  {
 
495
  }
496
  else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
497
  fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
498
+ return 5;
499
  }
500
 
501
  if (wav.channels != 1 && wav.channels != 2) {
502
  fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
503
+ return 6;
504
+ }
505
+
506
+ if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
507
+ fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
508
+ return 6;
509
  }
510
 
511
  if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
512
  fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
513
+ return 8;
514
  }
515
 
516
  if (wav.bitsPerSample != 16) {
517
  fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
518
+ return 9;
519
  }
520
 
521
  const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
 
536
  pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
537
  }
538
  }
539
+
540
+ if (params.diarize) {
541
+ // convert to stereo, float
542
+ pcmf32s.resize(2);
543
+
544
+ pcmf32s[0].resize(n);
545
+ pcmf32s[1].resize(n);
546
+ for (int i = 0; i < n; i++) {
547
+ pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
548
+ pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
549
+ }
550
+ }
551
  }
552
 
553
  // print system information
 
599
 
600
  wparams.speed_up = params.speed_up;
601
 
602
+ whisper_print_user_data user_data = { &params, &pcmf32s };
603
+
604
  // this callback is called on each new segment
605
  if (!wparams.print_realtime) {
606
  wparams.new_segment_callback = whisper_print_segment_callback;
607
+ wparams.new_segment_callback_user_data = &user_data;
608
  }
609
 
610
  if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
611
  fprintf(stderr, "%s: failed to process audio\n", argv[0]);
612
+ return 10;
613
  }
614
  }
615