Spaces:
Sleeping
Sleeping
main : add stereo-channel-based diarization (#64)
Browse files- examples/main/main.cpp +71 -10
examples/main/main.cpp
CHANGED
|
@@ -36,6 +36,10 @@ std::string to_timestamp(int64_t t, bool comma = false) {
|
|
| 36 |
return std::string(buf);
|
| 37 |
}
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
// helper function to replace substrings
|
| 40 |
void replace_all(std::string & s, const std::string & search, const std::string & replace) {
|
| 41 |
for (size_t pos = 0; ; pos += replace.length()) {
|
|
@@ -60,6 +64,7 @@ struct whisper_params {
|
|
| 60 |
|
| 61 |
bool speed_up = false;
|
| 62 |
bool translate = false;
|
|
|
|
| 63 |
bool output_txt = false;
|
| 64 |
bool output_vtt = false;
|
| 65 |
bool output_srt = false;
|
|
@@ -99,6 +104,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
| 99 |
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
|
| 100 |
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
| 101 |
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
|
|
|
| 102 |
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
|
| 103 |
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
|
| 104 |
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
|
|
@@ -135,6 +141,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
|
|
| 135 |
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
|
| 136 |
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
| 137 |
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
|
|
|
| 138 |
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
|
| 139 |
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
|
| 140 |
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
|
|
@@ -148,8 +155,15 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
|
|
| 148 |
fprintf(stderr, "\n");
|
| 149 |
}
|
| 150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
|
| 152 |
-
const
|
|
|
|
| 153 |
|
| 154 |
const int n_segments = whisper_full_n_segments(ctx);
|
| 155 |
|
|
@@ -186,6 +200,33 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
|
|
| 186 |
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
| 187 |
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
| 188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
if (params.print_colors) {
|
| 190 |
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
| 191 |
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
|
@@ -201,13 +242,13 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
|
|
| 201 |
|
| 202 |
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
| 203 |
|
| 204 |
-
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
|
| 205 |
}
|
| 206 |
printf("\n");
|
| 207 |
} else {
|
| 208 |
const char * text = whisper_full_get_segment_text(ctx, i);
|
| 209 |
|
| 210 |
-
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
|
| 211 |
}
|
| 212 |
}
|
| 213 |
}
|
|
@@ -235,7 +276,7 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
|
|
| 235 |
std::ofstream fout(fname);
|
| 236 |
if (!fout.is_open()) {
|
| 237 |
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
| 238 |
-
return
|
| 239 |
}
|
| 240 |
|
| 241 |
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
|
@@ -425,6 +466,7 @@ int main(int argc, char ** argv) {
|
|
| 425 |
const auto fname_inp = params.fname_inp[f];
|
| 426 |
|
| 427 |
std::vector<float> pcmf32; // mono-channel F32 PCM
|
|
|
|
| 428 |
|
| 429 |
// WAV input
|
| 430 |
{
|
|
@@ -453,22 +495,27 @@ int main(int argc, char ** argv) {
|
|
| 453 |
}
|
| 454 |
else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
|
| 455 |
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
| 456 |
-
return
|
| 457 |
}
|
| 458 |
|
| 459 |
if (wav.channels != 1 && wav.channels != 2) {
|
| 460 |
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
|
| 461 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 462 |
}
|
| 463 |
|
| 464 |
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
|
| 465 |
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
|
| 466 |
-
return
|
| 467 |
}
|
| 468 |
|
| 469 |
if (wav.bitsPerSample != 16) {
|
| 470 |
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
|
| 471 |
-
return
|
| 472 |
}
|
| 473 |
|
| 474 |
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
|
|
@@ -489,6 +536,18 @@ int main(int argc, char ** argv) {
|
|
| 489 |
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
|
| 490 |
}
|
| 491 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
}
|
| 493 |
|
| 494 |
// print system information
|
|
@@ -540,15 +599,17 @@ int main(int argc, char ** argv) {
|
|
| 540 |
|
| 541 |
wparams.speed_up = params.speed_up;
|
| 542 |
|
|
|
|
|
|
|
| 543 |
// this callback is called on each new segment
|
| 544 |
if (!wparams.print_realtime) {
|
| 545 |
wparams.new_segment_callback = whisper_print_segment_callback;
|
| 546 |
-
wparams.new_segment_callback_user_data = &
|
| 547 |
}
|
| 548 |
|
| 549 |
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
|
| 550 |
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
|
| 551 |
-
return
|
| 552 |
}
|
| 553 |
}
|
| 554 |
|
|
|
|
| 36 |
return std::string(buf);
|
| 37 |
}
|
| 38 |
|
| 39 |
+
int timestamp_to_sample(int64_t t, int n_samples) {
|
| 40 |
+
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
// helper function to replace substrings
|
| 44 |
void replace_all(std::string & s, const std::string & search, const std::string & replace) {
|
| 45 |
for (size_t pos = 0; ; pos += replace.length()) {
|
|
|
|
| 64 |
|
| 65 |
bool speed_up = false;
|
| 66 |
bool translate = false;
|
| 67 |
+
bool diarize = false;
|
| 68 |
bool output_txt = false;
|
| 69 |
bool output_vtt = false;
|
| 70 |
bool output_srt = false;
|
|
|
|
| 104 |
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
|
| 105 |
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
| 106 |
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
| 107 |
+
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
|
| 108 |
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
|
| 109 |
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
|
| 110 |
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
|
|
|
|
| 141 |
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
|
| 142 |
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
| 143 |
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
| 144 |
+
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
|
| 145 |
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
|
| 146 |
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
|
| 147 |
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
|
|
|
|
| 155 |
fprintf(stderr, "\n");
|
| 156 |
}
|
| 157 |
|
| 158 |
+
struct whisper_print_user_data {
|
| 159 |
+
const whisper_params * params;
|
| 160 |
+
|
| 161 |
+
const std::vector<std::vector<float>> * pcmf32s;
|
| 162 |
+
};
|
| 163 |
+
|
| 164 |
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
|
| 165 |
+
const auto & params = *((whisper_print_user_data *) user_data)->params;
|
| 166 |
+
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
|
| 167 |
|
| 168 |
const int n_segments = whisper_full_n_segments(ctx);
|
| 169 |
|
|
|
|
| 200 |
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
| 201 |
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
| 202 |
|
| 203 |
+
std::string speaker = "";
|
| 204 |
+
|
| 205 |
+
if (params.diarize && pcmf32s.size() == 2) {
|
| 206 |
+
const int64_t n_samples = pcmf32s[0].size();
|
| 207 |
+
|
| 208 |
+
const int64_t is0 = timestamp_to_sample(t0, n_samples);
|
| 209 |
+
const int64_t is1 = timestamp_to_sample(t1, n_samples);
|
| 210 |
+
|
| 211 |
+
double energy0 = 0.0f;
|
| 212 |
+
double energy1 = 0.0f;
|
| 213 |
+
|
| 214 |
+
for (int64_t j = is0; j < is1; j++) {
|
| 215 |
+
energy0 += fabs(pcmf32s[0][j]);
|
| 216 |
+
energy1 += fabs(pcmf32s[1][j]);
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
if (energy0 > 1.1*energy1) {
|
| 220 |
+
speaker = "(speaker 0)";
|
| 221 |
+
} else if (energy1 > 1.1*energy0) {
|
| 222 |
+
speaker = "(speaker 1)";
|
| 223 |
+
} else {
|
| 224 |
+
speaker = "(speaker ?)";
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
if (params.print_colors) {
|
| 231 |
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
| 232 |
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
|
|
|
| 242 |
|
| 243 |
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
| 244 |
|
| 245 |
+
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
|
| 246 |
}
|
| 247 |
printf("\n");
|
| 248 |
} else {
|
| 249 |
const char * text = whisper_full_get_segment_text(ctx, i);
|
| 250 |
|
| 251 |
+
printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
|
| 252 |
}
|
| 253 |
}
|
| 254 |
}
|
|
|
|
| 276 |
std::ofstream fout(fname);
|
| 277 |
if (!fout.is_open()) {
|
| 278 |
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
| 279 |
+
return false;
|
| 280 |
}
|
| 281 |
|
| 282 |
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
|
|
|
| 466 |
const auto fname_inp = params.fname_inp[f];
|
| 467 |
|
| 468 |
std::vector<float> pcmf32; // mono-channel F32 PCM
|
| 469 |
+
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
|
| 470 |
|
| 471 |
// WAV input
|
| 472 |
{
|
|
|
|
| 495 |
}
|
| 496 |
else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
|
| 497 |
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
| 498 |
+
return 5;
|
| 499 |
}
|
| 500 |
|
| 501 |
if (wav.channels != 1 && wav.channels != 2) {
|
| 502 |
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
|
| 503 |
+
return 6;
|
| 504 |
+
}
|
| 505 |
+
|
| 506 |
+
if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
|
| 507 |
+
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
|
| 508 |
+
return 6;
|
| 509 |
}
|
| 510 |
|
| 511 |
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
|
| 512 |
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
|
| 513 |
+
return 8;
|
| 514 |
}
|
| 515 |
|
| 516 |
if (wav.bitsPerSample != 16) {
|
| 517 |
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
|
| 518 |
+
return 9;
|
| 519 |
}
|
| 520 |
|
| 521 |
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
|
|
|
|
| 536 |
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
|
| 537 |
}
|
| 538 |
}
|
| 539 |
+
|
| 540 |
+
if (params.diarize) {
|
| 541 |
+
// convert to stereo, float
|
| 542 |
+
pcmf32s.resize(2);
|
| 543 |
+
|
| 544 |
+
pcmf32s[0].resize(n);
|
| 545 |
+
pcmf32s[1].resize(n);
|
| 546 |
+
for (int i = 0; i < n; i++) {
|
| 547 |
+
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
|
| 548 |
+
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
|
| 549 |
+
}
|
| 550 |
+
}
|
| 551 |
}
|
| 552 |
|
| 553 |
// print system information
|
|
|
|
| 599 |
|
| 600 |
wparams.speed_up = params.speed_up;
|
| 601 |
|
| 602 |
+
whisper_print_user_data user_data = { ¶ms, &pcmf32s };
|
| 603 |
+
|
| 604 |
// this callback is called on each new segment
|
| 605 |
if (!wparams.print_realtime) {
|
| 606 |
wparams.new_segment_callback = whisper_print_segment_callback;
|
| 607 |
+
wparams.new_segment_callback_user_data = &user_data;
|
| 608 |
}
|
| 609 |
|
| 610 |
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
|
| 611 |
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
|
| 612 |
+
return 10;
|
| 613 |
}
|
| 614 |
}
|
| 615 |
|