Spaces:
Running
Running
whisper : add whisper_state + default state on the whisper_context (#523)
Browse files* Added whisper state + default state on the whisper_context
* Fixed some examples and bindings
* Fixed whisper_n_len (which was used in some binding) and added whisper_n_len_from_state
* Fixed comments
* whisper : reuse kv_cache_free() and fix compiler warnings
* whisper : clean-up the API comments
---------
Co-authored-by: Sandro Hanea <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
- bindings/go/whisper.go +2 -2
- bindings/ruby/ext/ruby_whisper.cpp +1 -1
- examples/addon.node/addon.cpp +2 -2
- examples/main/main.cpp +2 -2
- whisper.cpp +576 -406
- whisper.h +118 -40
bindings/go/whisper.go
CHANGED
|
@@ -20,7 +20,7 @@ extern bool callEncoderBegin(void* user_data);
|
|
| 20 |
// Text segment callback
|
| 21 |
// Called on every newly generated text segment
|
| 22 |
// Use the whisper_full_...() functions to obtain the text segments
|
| 23 |
-
static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
|
| 24 |
if(user_data != NULL && ctx != NULL) {
|
| 25 |
callNewSegment(user_data, n_new);
|
| 26 |
}
|
|
@@ -29,7 +29,7 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void*
|
|
| 29 |
// Encoder begin callback
|
| 30 |
// If not NULL, called before the encoder starts
|
| 31 |
// If it returns false, the computation is aborted
|
| 32 |
-
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
|
| 33 |
if(user_data != NULL && ctx != NULL) {
|
| 34 |
return callEncoderBegin(user_data);
|
| 35 |
}
|
|
|
|
| 20 |
// Text segment callback
|
| 21 |
// Called on every newly generated text segment
|
| 22 |
// Use the whisper_full_...() functions to obtain the text segments
|
| 23 |
+
static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_state* state, int n_new, void* user_data) {
|
| 24 |
if(user_data != NULL && ctx != NULL) {
|
| 25 |
callNewSegment(user_data, n_new);
|
| 26 |
}
|
|
|
|
| 29 |
// Encoder begin callback
|
| 30 |
// If not NULL, called before the encoder starts
|
| 31 |
// If it returns false, the computation is aborted
|
| 32 |
+
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, struct whisper_state* state, void* user_data) {
|
| 33 |
if(user_data != NULL && ctx != NULL) {
|
| 34 |
return callEncoderBegin(user_data);
|
| 35 |
}
|
bindings/ruby/ext/ruby_whisper.cpp
CHANGED
|
@@ -199,7 +199,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
|
|
| 199 |
{
|
| 200 |
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
| 201 |
|
| 202 |
-
rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
|
| 203 |
bool is_aborted = *(bool*)user_data;
|
| 204 |
return !is_aborted;
|
| 205 |
};
|
|
|
|
| 199 |
{
|
| 200 |
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
| 201 |
|
| 202 |
+
rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
|
| 203 |
bool is_aborted = *(bool*)user_data;
|
| 204 |
return !is_aborted;
|
| 205 |
};
|
examples/addon.node/addon.cpp
CHANGED
|
@@ -72,7 +72,7 @@ int timestamp_to_sample(int64_t t, int n_samples) {
|
|
| 72 |
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
|
| 73 |
}
|
| 74 |
|
| 75 |
-
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
|
| 76 |
const auto & params = *((whisper_print_user_data *) user_data)->params;
|
| 77 |
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
|
| 78 |
|
|
@@ -260,7 +260,7 @@ int run(whisper_params ¶ms, std::vector<std::vector<std::string>> &result) {
|
|
| 260 |
{
|
| 261 |
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
| 262 |
|
| 263 |
-
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
|
| 264 |
bool is_aborted = *(bool*)user_data;
|
| 265 |
return !is_aborted;
|
| 266 |
};
|
|
|
|
| 72 |
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
|
| 73 |
}
|
| 74 |
|
| 75 |
+
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
|
| 76 |
const auto & params = *((whisper_print_user_data *) user_data)->params;
|
| 77 |
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
|
| 78 |
|
|
|
|
| 260 |
{
|
| 261 |
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
| 262 |
|
| 263 |
+
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
|
| 264 |
bool is_aborted = *(bool*)user_data;
|
| 265 |
return !is_aborted;
|
| 266 |
};
|
examples/main/main.cpp
CHANGED
|
@@ -193,7 +193,7 @@ struct whisper_print_user_data {
|
|
| 193 |
const std::vector<std::vector<float>> * pcmf32s;
|
| 194 |
};
|
| 195 |
|
| 196 |
-
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
|
| 197 |
const auto & params = *((whisper_print_user_data *) user_data)->params;
|
| 198 |
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
|
| 199 |
|
|
@@ -608,7 +608,7 @@ int main(int argc, char ** argv) {
|
|
| 608 |
{
|
| 609 |
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
| 610 |
|
| 611 |
-
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
|
| 612 |
bool is_aborted = *(bool*)user_data;
|
| 613 |
return !is_aborted;
|
| 614 |
};
|
|
|
|
| 193 |
const std::vector<std::vector<float>> * pcmf32s;
|
| 194 |
};
|
| 195 |
|
| 196 |
+
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
|
| 197 |
const auto & params = *((whisper_print_user_data *) user_data)->params;
|
| 198 |
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
|
| 199 |
|
|
|
|
| 608 |
{
|
| 609 |
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
| 610 |
|
| 611 |
+
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
|
| 612 |
bool is_aborted = *(bool*)user_data;
|
| 613 |
return !is_aborted;
|
| 614 |
};
|
whisper.cpp
CHANGED
|
@@ -547,13 +547,11 @@ struct whisper_decoder {
|
|
| 547 |
std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
|
| 548 |
};
|
| 549 |
|
| 550 |
-
struct
|
| 551 |
-
int64_t t_load_us = 0;
|
| 552 |
-
int64_t t_mel_us = 0;
|
| 553 |
int64_t t_sample_us = 0;
|
| 554 |
int64_t t_encode_us = 0;
|
| 555 |
int64_t t_decode_us = 0;
|
| 556 |
-
int64_t
|
| 557 |
|
| 558 |
int32_t n_sample = 0; // number of tokens sampled
|
| 559 |
int32_t n_encode = 0; // number of encoder calls
|
|
@@ -561,16 +559,10 @@ struct whisper_context {
|
|
| 561 |
int32_t n_fail_p = 0; // number of logprob threshold failures
|
| 562 |
int32_t n_fail_h = 0; // number of entropy threshold failures
|
| 563 |
|
| 564 |
-
ggml_type wtype; // weight type (FP32 or FP16)
|
| 565 |
-
|
| 566 |
-
whisper_mel mel;
|
| 567 |
-
|
| 568 |
-
whisper_model model;
|
| 569 |
-
whisper_vocab vocab;
|
| 570 |
-
|
| 571 |
// cross-attention KV cache for the decoders
|
| 572 |
// shared between all decoders
|
| 573 |
whisper_kv_cache kv_cross;
|
|
|
|
| 574 |
|
| 575 |
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
|
| 576 |
|
|
@@ -635,6 +627,18 @@ struct whisper_context {
|
|
| 635 |
}
|
| 636 |
};
|
| 637 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 638 |
template<typename T>
|
| 639 |
static void read_safe(whisper_model_loader * loader, T & dest) {
|
| 640 |
loader->read(loader->context, &dest, sizeof(T));
|
|
@@ -821,32 +825,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
| 821 |
wctx.model.buf = new std::vector<uint8_t>();
|
| 822 |
wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
|
| 823 |
|
| 824 |
-
|
| 825 |
-
|
| 826 |
-
return false;
|
| 827 |
-
}
|
| 828 |
-
|
| 829 |
-
{
|
| 830 |
-
const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v);
|
| 831 |
-
fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
|
| 832 |
-
}
|
| 833 |
-
|
| 834 |
-
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
|
| 835 |
-
fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
| 836 |
-
return false;
|
| 837 |
-
}
|
| 838 |
-
|
| 839 |
-
{
|
| 840 |
-
const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
|
| 841 |
-
fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
|
| 842 |
-
}
|
| 843 |
-
|
| 844 |
-
wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
|
| 845 |
-
|
| 846 |
-
wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type));
|
| 847 |
-
wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type));
|
| 848 |
-
wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type));
|
| 849 |
-
wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type));
|
| 850 |
}
|
| 851 |
|
| 852 |
// load mel filters
|
|
@@ -929,17 +909,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
| 929 |
vocab.id_to_token[i] = word;
|
| 930 |
}
|
| 931 |
}
|
| 932 |
-
|
| 933 |
-
wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
|
| 934 |
-
|
| 935 |
-
wctx.logits_id.reserve(n_vocab);
|
| 936 |
-
|
| 937 |
-
// TAGS: WHISPER_DECODER_INIT
|
| 938 |
-
wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
|
| 939 |
-
|
| 940 |
-
wctx.decoders[0].probs.reserve (vocab.n_vocab);
|
| 941 |
-
wctx.decoders[0].logits.reserve (vocab.n_vocab);
|
| 942 |
-
wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
|
| 943 |
}
|
| 944 |
|
| 945 |
size_t ctx_size = 0;
|
|
@@ -1339,33 +1308,34 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
| 1339 |
}
|
| 1340 |
}
|
| 1341 |
|
| 1342 |
-
wctx.rng = std::mt19937(0);
|
| 1343 |
-
|
| 1344 |
wctx.t_load_us = ggml_time_us() - t_start_us;
|
| 1345 |
|
| 1346 |
return true;
|
| 1347 |
}
|
| 1348 |
|
| 1349 |
-
// evaluate the encoder
|
| 1350 |
//
|
| 1351 |
// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
|
| 1352 |
// part of the transformer model and returns the encoded features
|
| 1353 |
//
|
| 1354 |
-
// -
|
|
|
|
| 1355 |
// - n_threads: number of threads to use
|
| 1356 |
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
|
| 1357 |
//
|
| 1358 |
-
static bool
|
| 1359 |
whisper_context & wctx,
|
|
|
|
| 1360 |
const int mel_offset,
|
| 1361 |
-
const int n_threads)
|
|
|
|
| 1362 |
const int64_t t_start_us = ggml_time_us();
|
| 1363 |
|
| 1364 |
const auto & model = wctx.model;
|
| 1365 |
-
const auto & mel_inp =
|
| 1366 |
const auto & hparams = model.hparams;
|
| 1367 |
|
| 1368 |
-
const int n_ctx =
|
| 1369 |
const int n_state = hparams.n_audio_state;
|
| 1370 |
const int n_head = hparams.n_audio_head;
|
| 1371 |
const int n_layer = hparams.n_audio_layer;
|
|
@@ -1374,12 +1344,12 @@ static bool whisper_encode(
|
|
| 1374 |
assert(mel_inp.n_mel == n_mels);
|
| 1375 |
|
| 1376 |
struct ggml_init_params params;
|
| 1377 |
-
params.mem_size =
|
| 1378 |
-
params.mem_buffer =
|
| 1379 |
|
| 1380 |
struct ggml_context * ctx0 = ggml_init(params);
|
| 1381 |
|
| 1382 |
-
|
| 1383 |
|
| 1384 |
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
|
| 1385 |
assert(mel->type == GGML_TYPE_F32);
|
|
@@ -1401,30 +1371,30 @@ static bool whisper_encode(
|
|
| 1401 |
|
| 1402 |
// convolution + gelu
|
| 1403 |
{
|
| 1404 |
-
|
| 1405 |
|
| 1406 |
cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
|
| 1407 |
cur = ggml_add(ctx0,
|
| 1408 |
-
|
| 1409 |
-
|
| 1410 |
-
|
| 1411 |
-
|
| 1412 |
|
| 1413 |
cur = ggml_gelu(ctx0, cur);
|
| 1414 |
|
| 1415 |
-
|
| 1416 |
|
| 1417 |
cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
|
| 1418 |
cur = ggml_add(ctx0,
|
| 1419 |
-
|
| 1420 |
-
|
| 1421 |
-
|
| 1422 |
-
|
| 1423 |
|
| 1424 |
cur = ggml_gelu(ctx0, cur);
|
| 1425 |
}
|
| 1426 |
|
| 1427 |
-
|
| 1428 |
|
| 1429 |
// ===================================================================
|
| 1430 |
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
|
@@ -1439,7 +1409,7 @@ static bool whisper_encode(
|
|
| 1439 |
//}
|
| 1440 |
|
| 1441 |
static int iter = 0;
|
| 1442 |
-
|
| 1443 |
const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
|
| 1444 |
const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
|
| 1445 |
|
|
@@ -1459,54 +1429,54 @@ static bool whisper_encode(
|
|
| 1459 |
|
| 1460 |
// norm
|
| 1461 |
{
|
| 1462 |
-
|
| 1463 |
|
| 1464 |
cur = ggml_norm(ctx0, inpL);
|
| 1465 |
|
| 1466 |
// cur = ln_0_w*cur + ln_0_b
|
| 1467 |
cur = ggml_add(ctx0,
|
| 1468 |
-
|
| 1469 |
-
|
| 1470 |
-
|
| 1471 |
-
|
| 1472 |
}
|
| 1473 |
|
| 1474 |
// self-attention
|
| 1475 |
{
|
| 1476 |
-
|
| 1477 |
|
| 1478 |
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
| 1479 |
-
|
| 1480 |
-
|
| 1481 |
|
| 1482 |
Qcur = ggml_add(ctx0,
|
| 1483 |
-
|
| 1484 |
-
|
| 1485 |
-
|
| 1486 |
-
|
| 1487 |
|
| 1488 |
//Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
| 1489 |
|
| 1490 |
// note: no bias for Key
|
| 1491 |
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
|
| 1492 |
-
|
| 1493 |
-
|
| 1494 |
|
| 1495 |
//Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
| 1496 |
|
| 1497 |
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
| 1498 |
-
|
| 1499 |
-
|
| 1500 |
|
| 1501 |
Vcur = ggml_add(ctx0,
|
| 1502 |
-
|
| 1503 |
-
|
| 1504 |
-
|
| 1505 |
-
|
| 1506 |
|
| 1507 |
// ------
|
| 1508 |
|
| 1509 |
-
|
| 1510 |
|
| 1511 |
#ifdef WHISPER_USE_FLASH_ATTN
|
| 1512 |
struct ggml_tensor * Q =
|
|
@@ -1583,29 +1553,29 @@ static bool whisper_encode(
|
|
| 1583 |
#endif
|
| 1584 |
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
| 1585 |
|
| 1586 |
-
|
| 1587 |
|
| 1588 |
cur = ggml_cpy(ctx0,
|
| 1589 |
-
|
| 1590 |
-
|
| 1591 |
}
|
| 1592 |
|
| 1593 |
// projection
|
| 1594 |
{
|
| 1595 |
-
|
| 1596 |
|
| 1597 |
cur = ggml_mul_mat(ctx0,
|
| 1598 |
-
|
| 1599 |
-
|
| 1600 |
|
| 1601 |
-
|
| 1602 |
|
| 1603 |
cur = ggml_add(ctx0,
|
| 1604 |
-
|
| 1605 |
-
|
| 1606 |
}
|
| 1607 |
|
| 1608 |
-
|
| 1609 |
|
| 1610 |
// add the input
|
| 1611 |
cur = ggml_add(ctx0, cur, inpL);
|
|
@@ -1616,61 +1586,61 @@ static bool whisper_encode(
|
|
| 1616 |
{
|
| 1617 |
// norm
|
| 1618 |
{
|
| 1619 |
-
|
| 1620 |
|
| 1621 |
cur = ggml_norm(ctx0, inpFF);
|
| 1622 |
|
| 1623 |
-
|
| 1624 |
|
| 1625 |
// cur = mlp_ln_w*cur + mlp_ln_b
|
| 1626 |
cur = ggml_add(ctx0,
|
| 1627 |
-
|
| 1628 |
-
|
| 1629 |
-
|
| 1630 |
-
|
| 1631 |
-
|
| 1632 |
|
| 1633 |
#ifdef WHISPER_USE_FLASH_FF
|
| 1634 |
-
|
| 1635 |
|
| 1636 |
cur = ggml_flash_ff(ctx0,
|
| 1637 |
-
|
| 1638 |
-
|
| 1639 |
#else
|
| 1640 |
-
|
| 1641 |
|
| 1642 |
// fully connected
|
| 1643 |
cur = ggml_mul_mat(ctx0,
|
| 1644 |
-
|
| 1645 |
-
|
| 1646 |
|
| 1647 |
-
|
| 1648 |
|
| 1649 |
cur = ggml_add(ctx0,
|
| 1650 |
-
|
| 1651 |
-
|
| 1652 |
|
| 1653 |
-
|
| 1654 |
|
| 1655 |
// GELU activation
|
| 1656 |
cur = ggml_gelu(ctx0, cur);
|
| 1657 |
|
| 1658 |
-
|
| 1659 |
|
| 1660 |
// projection
|
| 1661 |
cur = ggml_mul_mat(ctx0,
|
| 1662 |
-
|
| 1663 |
-
|
| 1664 |
|
| 1665 |
-
|
| 1666 |
|
| 1667 |
cur = ggml_add(ctx0,
|
| 1668 |
-
|
| 1669 |
-
|
| 1670 |
#endif
|
| 1671 |
-
|
| 1672 |
|
| 1673 |
-
|
| 1674 |
|
| 1675 |
inpL = ggml_add(ctx0, cur, inpFF);
|
| 1676 |
}
|
|
@@ -1679,21 +1649,21 @@ static bool whisper_encode(
|
|
| 1679 |
|
| 1680 |
// norm
|
| 1681 |
{
|
| 1682 |
-
|
| 1683 |
|
| 1684 |
cur = ggml_norm(ctx0, cur);
|
| 1685 |
|
| 1686 |
-
|
| 1687 |
|
| 1688 |
// cur = ln_f_g*cur + ln_f_b
|
| 1689 |
cur = ggml_add(ctx0,
|
| 1690 |
-
|
| 1691 |
-
|
| 1692 |
-
|
| 1693 |
-
|
| 1694 |
}
|
| 1695 |
|
| 1696 |
-
|
| 1697 |
|
| 1698 |
// run the computation
|
| 1699 |
{
|
|
@@ -1701,7 +1671,7 @@ static bool whisper_encode(
|
|
| 1701 |
gf.n_threads = n_threads;
|
| 1702 |
|
| 1703 |
ggml_build_forward_expand(&gf, cur);
|
| 1704 |
-
ggml_graph_compute
|
| 1705 |
|
| 1706 |
//ggml_graph_print(&gf);
|
| 1707 |
}
|
|
@@ -1731,34 +1701,34 @@ static bool whisper_encode(
|
|
| 1731 |
cur->src1 = nullptr;
|
| 1732 |
|
| 1733 |
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
| 1734 |
-
auto
|
| 1735 |
|
| 1736 |
-
|
| 1737 |
|
| 1738 |
-
struct ggml_tensor
|
| 1739 |
-
|
| 1740 |
-
|
| 1741 |
|
| 1742 |
-
Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
| 1743 |
|
| 1744 |
-
|
| 1745 |
|
| 1746 |
-
struct ggml_tensor
|
| 1747 |
-
|
| 1748 |
-
|
| 1749 |
|
| 1750 |
Vcross = ggml_add(ctx0,
|
| 1751 |
-
|
| 1752 |
-
|
| 1753 |
-
|
| 1754 |
-
|
| 1755 |
|
| 1756 |
-
|
| 1757 |
|
| 1758 |
-
//struct ggml_tensor * k = ggml_view_1d(ctx0,
|
| 1759 |
-
//struct ggml_tensor * v = ggml_view_1d(ctx0,
|
| 1760 |
-
struct ggml_tensor
|
| 1761 |
-
struct ggml_tensor
|
| 1762 |
|
| 1763 |
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
|
| 1764 |
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
|
|
@@ -1779,8 +1749,8 @@ static bool whisper_encode(
|
|
| 1779 |
|
| 1780 |
ggml_free(ctx0);
|
| 1781 |
|
| 1782 |
-
|
| 1783 |
-
|
| 1784 |
|
| 1785 |
return true;
|
| 1786 |
}
|
|
@@ -1795,8 +1765,9 @@ static bool whisper_encode(
|
|
| 1795 |
// - n_tokens: number of tokens in the prompt
|
| 1796 |
// - n_past: number of past tokens to prefix the prompt with
|
| 1797 |
//
|
| 1798 |
-
static bool
|
| 1799 |
whisper_context & wctx,
|
|
|
|
| 1800 |
whisper_decoder & decoder,
|
| 1801 |
const whisper_token * tokens,
|
| 1802 |
const int n_tokens,
|
|
@@ -1811,7 +1782,7 @@ static bool whisper_decode(
|
|
| 1811 |
|
| 1812 |
WHISPER_ASSERT(!!kv_self.ctx);
|
| 1813 |
|
| 1814 |
-
auto & logits_out =
|
| 1815 |
|
| 1816 |
const int n_vocab = hparams.n_vocab;
|
| 1817 |
|
|
@@ -1821,13 +1792,13 @@ static bool whisper_decode(
|
|
| 1821 |
const int n_layer = hparams.n_text_layer;
|
| 1822 |
|
| 1823 |
const int N = n_tokens;
|
| 1824 |
-
const int M =
|
| 1825 |
|
| 1826 |
//WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
|
| 1827 |
|
| 1828 |
struct ggml_init_params params;
|
| 1829 |
-
params.mem_size =
|
| 1830 |
-
params.mem_buffer =
|
| 1831 |
|
| 1832 |
struct ggml_context * ctx0 = ggml_init(params);
|
| 1833 |
|
|
@@ -1842,7 +1813,7 @@ static bool whisper_decode(
|
|
| 1842 |
((int32_t *) position->data)[i] = n_past + i;
|
| 1843 |
}
|
| 1844 |
|
| 1845 |
-
|
| 1846 |
|
| 1847 |
// token encoding + position encoding
|
| 1848 |
struct ggml_tensor * cur =
|
|
@@ -1857,7 +1828,7 @@ static bool whisper_decode(
|
|
| 1857 |
|
| 1858 |
// norm
|
| 1859 |
{
|
| 1860 |
-
|
| 1861 |
|
| 1862 |
cur = ggml_norm(ctx0, inpL);
|
| 1863 |
|
|
@@ -1871,7 +1842,7 @@ static bool whisper_decode(
|
|
| 1871 |
|
| 1872 |
// self-attention
|
| 1873 |
{
|
| 1874 |
-
|
| 1875 |
|
| 1876 |
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
| 1877 |
layer.attn_q_w,
|
|
@@ -1913,7 +1884,7 @@ static bool whisper_decode(
|
|
| 1913 |
|
| 1914 |
// ------
|
| 1915 |
|
| 1916 |
-
|
| 1917 |
|
| 1918 |
struct ggml_tensor * Q =
|
| 1919 |
ggml_permute(ctx0,
|
|
@@ -1929,12 +1900,12 @@ static bool whisper_decode(
|
|
| 1929 |
n_state/n_head, n_head, n_past + N),
|
| 1930 |
0, 2, 1, 3);
|
| 1931 |
|
| 1932 |
-
|
| 1933 |
|
| 1934 |
// K * Q
|
| 1935 |
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
| 1936 |
|
| 1937 |
-
|
| 1938 |
|
| 1939 |
//struct ggml_tensor * KQ_scaled =
|
| 1940 |
// ggml_scale(ctx0,
|
|
@@ -1944,11 +1915,11 @@ static bool whisper_decode(
|
|
| 1944 |
|
| 1945 |
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
|
| 1946 |
|
| 1947 |
-
|
| 1948 |
|
| 1949 |
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
|
| 1950 |
|
| 1951 |
-
|
| 1952 |
|
| 1953 |
struct ggml_tensor * V_trans =
|
| 1954 |
ggml_permute(ctx0,
|
|
@@ -1957,7 +1928,7 @@ static bool whisper_decode(
|
|
| 1957 |
n_state/n_head, n_head, n_past + N),
|
| 1958 |
1, 2, 0, 3);
|
| 1959 |
|
| 1960 |
-
|
| 1961 |
|
| 1962 |
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
|
| 1963 |
|
|
@@ -1970,31 +1941,31 @@ static bool whisper_decode(
|
|
| 1970 |
|
| 1971 |
// projection
|
| 1972 |
{
|
| 1973 |
-
|
| 1974 |
|
| 1975 |
cur = ggml_mul_mat(ctx0,
|
| 1976 |
layer.attn_ln_1_w,
|
| 1977 |
cur);
|
| 1978 |
|
| 1979 |
-
|
| 1980 |
|
| 1981 |
cur = ggml_add(ctx0,
|
| 1982 |
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
|
| 1983 |
cur);
|
| 1984 |
}
|
| 1985 |
|
| 1986 |
-
|
| 1987 |
|
| 1988 |
// add the input
|
| 1989 |
struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
|
| 1990 |
|
| 1991 |
// norm
|
| 1992 |
{
|
| 1993 |
-
|
| 1994 |
|
| 1995 |
cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
|
| 1996 |
|
| 1997 |
-
|
| 1998 |
|
| 1999 |
// cur = ln_0_w*cur + ln_0_b
|
| 2000 |
cur = ggml_add(ctx0,
|
|
@@ -2006,7 +1977,7 @@ static bool whisper_decode(
|
|
| 2006 |
|
| 2007 |
// cross-attention
|
| 2008 |
{
|
| 2009 |
-
|
| 2010 |
|
| 2011 |
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
| 2012 |
layer.cross_attn_q_w,
|
|
@@ -2023,19 +1994,19 @@ static bool whisper_decode(
|
|
| 2023 |
// Kcross is already scaled
|
| 2024 |
struct ggml_tensor * Kcross =
|
| 2025 |
ggml_reshape_3d(ctx0,
|
| 2026 |
-
ggml_view_1d(ctx0,
|
| 2027 |
n_state/n_head, n_head, M);
|
| 2028 |
|
| 2029 |
struct ggml_tensor * Vcross =
|
| 2030 |
ggml_reshape_3d(ctx0,
|
| 2031 |
-
ggml_view_1d(ctx0,
|
| 2032 |
n_state/n_head, n_head, M);
|
| 2033 |
|
| 2034 |
struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);
|
| 2035 |
|
| 2036 |
// ------
|
| 2037 |
|
| 2038 |
-
|
| 2039 |
|
| 2040 |
struct ggml_tensor * Q =
|
| 2041 |
ggml_permute(ctx0,
|
|
@@ -2046,7 +2017,7 @@ static bool whisper_decode(
|
|
| 2046 |
|
| 2047 |
struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
|
| 2048 |
|
| 2049 |
-
|
| 2050 |
|
| 2051 |
// K * Q
|
| 2052 |
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
|
@@ -2060,15 +2031,15 @@ static bool whisper_decode(
|
|
| 2060 |
// no masking for cross-attention
|
| 2061 |
//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
|
| 2062 |
|
| 2063 |
-
|
| 2064 |
|
| 2065 |
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
|
| 2066 |
|
| 2067 |
-
|
| 2068 |
|
| 2069 |
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
|
| 2070 |
|
| 2071 |
-
|
| 2072 |
|
| 2073 |
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
| 2074 |
|
|
@@ -2080,20 +2051,20 @@ static bool whisper_decode(
|
|
| 2080 |
|
| 2081 |
// projection
|
| 2082 |
{
|
| 2083 |
-
|
| 2084 |
|
| 2085 |
cur = ggml_mul_mat(ctx0,
|
| 2086 |
layer.cross_attn_ln_1_w,
|
| 2087 |
cur);
|
| 2088 |
|
| 2089 |
-
|
| 2090 |
|
| 2091 |
cur = ggml_add(ctx0,
|
| 2092 |
ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
|
| 2093 |
cur);
|
| 2094 |
}
|
| 2095 |
|
| 2096 |
-
|
| 2097 |
|
| 2098 |
// add the input
|
| 2099 |
cur = ggml_add(ctx0, cur, inpCA);
|
|
@@ -2104,11 +2075,11 @@ static bool whisper_decode(
|
|
| 2104 |
{
|
| 2105 |
// norm
|
| 2106 |
{
|
| 2107 |
-
|
| 2108 |
|
| 2109 |
cur = ggml_norm(ctx0, inpFF);
|
| 2110 |
|
| 2111 |
-
|
| 2112 |
|
| 2113 |
// cur = mlp_ln_w*cur + mlp_ln_b
|
| 2114 |
cur = ggml_add(ctx0,
|
|
@@ -2118,39 +2089,39 @@ static bool whisper_decode(
|
|
| 2118 |
ggml_repeat(ctx0, layer.mlp_ln_b, cur));
|
| 2119 |
}
|
| 2120 |
|
| 2121 |
-
|
| 2122 |
|
| 2123 |
// fully connected
|
| 2124 |
cur = ggml_mul_mat(ctx0,
|
| 2125 |
layer.mlp_0_w,
|
| 2126 |
cur);
|
| 2127 |
|
| 2128 |
-
|
| 2129 |
|
| 2130 |
cur = ggml_add(ctx0,
|
| 2131 |
ggml_repeat(ctx0, layer.mlp_0_b, cur),
|
| 2132 |
cur);
|
| 2133 |
|
| 2134 |
-
|
| 2135 |
|
| 2136 |
// GELU activation
|
| 2137 |
cur = ggml_gelu(ctx0, cur);
|
| 2138 |
|
| 2139 |
-
|
| 2140 |
|
| 2141 |
// projection
|
| 2142 |
cur = ggml_mul_mat(ctx0,
|
| 2143 |
layer.mlp_1_w,
|
| 2144 |
cur);
|
| 2145 |
|
| 2146 |
-
|
| 2147 |
|
| 2148 |
cur = ggml_add(ctx0,
|
| 2149 |
ggml_repeat(ctx0, layer.mlp_1_b, cur),
|
| 2150 |
cur);
|
| 2151 |
}
|
| 2152 |
|
| 2153 |
-
|
| 2154 |
|
| 2155 |
inpL = ggml_add(ctx0, cur, inpFF);
|
| 2156 |
}
|
|
@@ -2159,11 +2130,11 @@ static bool whisper_decode(
|
|
| 2159 |
|
| 2160 |
// norm
|
| 2161 |
{
|
| 2162 |
-
|
| 2163 |
|
| 2164 |
cur = ggml_norm(ctx0, cur);
|
| 2165 |
|
| 2166 |
-
|
| 2167 |
|
| 2168 |
cur = ggml_add(ctx0,
|
| 2169 |
ggml_mul(ctx0,
|
|
@@ -2172,7 +2143,7 @@ static bool whisper_decode(
|
|
| 2172 |
ggml_repeat(ctx0, model.d_ln_b, cur));
|
| 2173 |
}
|
| 2174 |
|
| 2175 |
-
|
| 2176 |
|
| 2177 |
// compute logits only for the last token
|
| 2178 |
// comment this line to compute logits for all N tokens
|
|
@@ -2181,7 +2152,7 @@ static bool whisper_decode(
|
|
| 2181 |
|
| 2182 |
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
|
| 2183 |
|
| 2184 |
-
|
| 2185 |
|
| 2186 |
// run the computation
|
| 2187 |
{
|
|
@@ -2208,8 +2179,8 @@ static bool whisper_decode(
|
|
| 2208 |
|
| 2209 |
ggml_free(ctx0);
|
| 2210 |
|
| 2211 |
-
|
| 2212 |
-
|
| 2213 |
|
| 2214 |
return true;
|
| 2215 |
}
|
|
@@ -2313,7 +2284,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
|
| 2313 |
|
| 2314 |
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
|
| 2315 |
static bool log_mel_spectrogram(
|
| 2316 |
-
|
| 2317 |
const float * samples,
|
| 2318 |
const int n_samples,
|
| 2319 |
const int /*sample_rate*/,
|
|
@@ -2433,7 +2404,7 @@ static bool log_mel_spectrogram(
|
|
| 2433 |
mel.data[i] = (mel.data[i] + 4.0)/4.0;
|
| 2434 |
}
|
| 2435 |
|
| 2436 |
-
|
| 2437 |
|
| 2438 |
return true;
|
| 2439 |
}
|
|
@@ -2507,7 +2478,56 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
|
|
| 2507 |
// interface implementation
|
| 2508 |
//
|
| 2509 |
|
| 2510 |
-
struct
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2511 |
whisper_model_loader loader = {};
|
| 2512 |
|
| 2513 |
fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
|
|
@@ -2535,10 +2555,10 @@ struct whisper_context * whisper_init_from_file(const char * path_model) {
|
|
| 2535 |
fin->close();
|
| 2536 |
};
|
| 2537 |
|
| 2538 |
-
return
|
| 2539 |
}
|
| 2540 |
|
| 2541 |
-
struct whisper_context *
|
| 2542 |
struct buf_context {
|
| 2543 |
uint8_t* buffer;
|
| 2544 |
size_t size;
|
|
@@ -2571,10 +2591,10 @@ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_s
|
|
| 2571 |
|
| 2572 |
loader.close = [](void * /*ctx*/) { };
|
| 2573 |
|
| 2574 |
-
return
|
| 2575 |
}
|
| 2576 |
|
| 2577 |
-
struct whisper_context *
|
| 2578 |
ggml_time_init();
|
| 2579 |
|
| 2580 |
whisper_context * ctx = new whisper_context;
|
|
@@ -2591,6 +2611,64 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
|
|
| 2591 |
return ctx;
|
| 2592 |
}
|
| 2593 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2594 |
void whisper_free(struct whisper_context * ctx) {
|
| 2595 |
if (ctx) {
|
| 2596 |
if (ctx->model.ctx) {
|
|
@@ -2599,20 +2677,29 @@ void whisper_free(struct whisper_context * ctx) {
|
|
| 2599 |
if (ctx->model.buf) {
|
| 2600 |
delete ctx->model.buf;
|
| 2601 |
}
|
| 2602 |
-
|
| 2603 |
-
|
| 2604 |
-
|
| 2605 |
-
for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
|
| 2606 |
-
if (ctx->decoders[i].kv_self.ctx) {
|
| 2607 |
-
ggml_free(ctx->decoders[i].kv_self.ctx);
|
| 2608 |
-
}
|
| 2609 |
-
}
|
| 2610 |
delete ctx;
|
| 2611 |
}
|
| 2612 |
}
|
| 2613 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2614 |
int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
|
| 2615 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2616 |
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
|
| 2617 |
return -1;
|
| 2618 |
}
|
|
@@ -2622,11 +2709,26 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
|
|
| 2622 |
|
| 2623 |
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
|
| 2624 |
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
|
| 2625 |
-
|
| 2626 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2627 |
return -1;
|
| 2628 |
}
|
| 2629 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2630 |
return 0;
|
| 2631 |
}
|
| 2632 |
|
|
@@ -2635,22 +2737,20 @@ int whisper_set_mel(
|
|
| 2635 |
const float * data,
|
| 2636 |
int n_len,
|
| 2637 |
int n_mel) {
|
| 2638 |
-
|
| 2639 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2640 |
return -1;
|
| 2641 |
}
|
| 2642 |
|
| 2643 |
-
ctx->mel.n_len = n_len;
|
| 2644 |
-
ctx->mel.n_mel = n_mel;
|
| 2645 |
-
|
| 2646 |
-
ctx->mel.data.resize(n_len*n_mel);
|
| 2647 |
-
memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float));
|
| 2648 |
-
|
| 2649 |
return 0;
|
| 2650 |
}
|
| 2651 |
|
| 2652 |
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
| 2653 |
-
if (!
|
| 2654 |
fprintf(stderr, "%s: failed to eval\n", __func__);
|
| 2655 |
return -1;
|
| 2656 |
}
|
|
@@ -2658,11 +2758,28 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
|
| 2658 |
return 0;
|
| 2659 |
}
|
| 2660 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2661 |
int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
|
| 2662 |
-
// TODO: add selected_decoder_id to
|
| 2663 |
const int selected_decoder_id = 0;
|
| 2664 |
|
| 2665 |
-
if (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2666 |
fprintf(stderr, "%s: failed to eval\n", __func__);
|
| 2667 |
return 1;
|
| 2668 |
}
|
|
@@ -2720,11 +2837,12 @@ const char * whisper_lang_str(int id) {
|
|
| 2720 |
return nullptr;
|
| 2721 |
}
|
| 2722 |
|
| 2723 |
-
int
|
| 2724 |
struct whisper_context * ctx,
|
| 2725 |
-
|
| 2726 |
-
|
| 2727 |
-
|
|
|
|
| 2728 |
const int seek = offset_ms/10;
|
| 2729 |
|
| 2730 |
if (seek < 0) {
|
|
@@ -2732,8 +2850,8 @@ int whisper_lang_auto_detect(
|
|
| 2732 |
return -1;
|
| 2733 |
}
|
| 2734 |
|
| 2735 |
-
if (seek >=
|
| 2736 |
-
fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms,
|
| 2737 |
return -2;
|
| 2738 |
}
|
| 2739 |
|
|
@@ -2745,17 +2863,17 @@ int whisper_lang_auto_detect(
|
|
| 2745 |
|
| 2746 |
const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
|
| 2747 |
|
| 2748 |
-
if (
|
| 2749 |
fprintf(stderr, "%s: failed to decode\n", __func__);
|
| 2750 |
return -7;
|
| 2751 |
}
|
| 2752 |
|
| 2753 |
-
auto & logits_id =
|
| 2754 |
logits_id.clear();
|
| 2755 |
|
| 2756 |
for (const auto & kv : g_lang) {
|
| 2757 |
const auto token_lang = whisper_token_lang(ctx, kv.second.first);
|
| 2758 |
-
logits_id.emplace_back(
|
| 2759 |
}
|
| 2760 |
|
| 2761 |
// sort descending
|
|
@@ -2794,8 +2912,20 @@ int whisper_lang_auto_detect(
|
|
| 2794 |
return logits_id[0].second;
|
| 2795 |
}
|
| 2796 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2797 |
int whisper_n_len(struct whisper_context * ctx) {
|
| 2798 |
-
return ctx->mel.n_len;
|
| 2799 |
}
|
| 2800 |
|
| 2801 |
int whisper_n_vocab(struct whisper_context * ctx) {
|
|
@@ -2815,7 +2945,12 @@ int whisper_is_multilingual(struct whisper_context * ctx) {
|
|
| 2815 |
}
|
| 2816 |
|
| 2817 |
float * whisper_get_logits(struct whisper_context * ctx) {
|
| 2818 |
-
return ctx->logits.data();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2819 |
}
|
| 2820 |
|
| 2821 |
const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
|
|
@@ -2861,24 +2996,29 @@ whisper_token whisper_token_transcribe(void) {
|
|
| 2861 |
void whisper_print_timings(struct whisper_context * ctx) {
|
| 2862 |
const int64_t t_end_us = ggml_time_us();
|
| 2863 |
|
| 2864 |
-
const int32_t n_sample = std::max(1, ctx->n_sample);
|
| 2865 |
-
const int32_t n_encode = std::max(1, ctx->n_encode);
|
| 2866 |
-
const int32_t n_decode = std::max(1, ctx->n_decode);
|
| 2867 |
-
|
| 2868 |
fprintf(stderr, "\n");
|
| 2869 |
-
fprintf(stderr, "%s:
|
| 2870 |
-
|
| 2871 |
-
|
| 2872 |
-
|
| 2873 |
-
|
| 2874 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2875 |
fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
|
| 2876 |
}
|
| 2877 |
|
| 2878 |
void whisper_reset_timings(struct whisper_context * ctx) {
|
| 2879 |
-
ctx->
|
| 2880 |
-
|
| 2881 |
-
|
|
|
|
|
|
|
| 2882 |
}
|
| 2883 |
|
| 2884 |
const char * whisper_print_system_info(void) {
|
|
@@ -2991,6 +3131,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
| 2991 |
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
|
| 2992 |
static void whisper_exp_compute_token_level_timestamps(
|
| 2993 |
struct whisper_context & ctx,
|
|
|
|
| 2994 |
int i_segment,
|
| 2995 |
float thold_pt,
|
| 2996 |
float thold_ptsum);
|
|
@@ -3023,8 +3164,8 @@ static inline bool should_split_on_word(const char * txt, bool split_on_word) {
|
|
| 3023 |
|
| 3024 |
// wrap the last segment to max_len characters
|
| 3025 |
// returns the number of new segments
|
| 3026 |
-
static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
|
| 3027 |
-
auto segment =
|
| 3028 |
|
| 3029 |
int res = 1;
|
| 3030 |
int acc = 0;
|
|
@@ -3046,24 +3187,24 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
|
|
| 3046 |
trim(text);
|
| 3047 |
}
|
| 3048 |
|
| 3049 |
-
|
| 3050 |
-
|
| 3051 |
-
|
| 3052 |
|
| 3053 |
-
|
| 3054 |
-
|
| 3055 |
-
|
| 3056 |
|
| 3057 |
// add tokens [i, end] to the new segment
|
| 3058 |
-
|
| 3059 |
-
|
| 3060 |
segment.tokens.begin() + i,
|
| 3061 |
segment.tokens.end());
|
| 3062 |
|
| 3063 |
acc = 0;
|
| 3064 |
text = "";
|
| 3065 |
|
| 3066 |
-
segment =
|
| 3067 |
i = -1;
|
| 3068 |
|
| 3069 |
res++;
|
|
@@ -3076,7 +3217,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
|
|
| 3076 |
if (split_on_word) {
|
| 3077 |
trim(text);
|
| 3078 |
}
|
| 3079 |
-
|
| 3080 |
|
| 3081 |
return res;
|
| 3082 |
}
|
|
@@ -3093,6 +3234,7 @@ static const std::vector<std::string> non_speech_tokens = {
|
|
| 3093 |
// - computes logprobs and probs
|
| 3094 |
static void whisper_process_logits(
|
| 3095 |
struct whisper_context & ctx,
|
|
|
|
| 3096 |
const struct whisper_full_params params,
|
| 3097 |
struct whisper_decoder & decoder,
|
| 3098 |
float temperature) {
|
|
@@ -3111,7 +3253,7 @@ static void whisper_process_logits(
|
|
| 3111 |
auto & logprobs = decoder.logprobs;
|
| 3112 |
{
|
| 3113 |
logits.resize(n_logits);
|
| 3114 |
-
memcpy(logits.data(),
|
| 3115 |
|
| 3116 |
if (temperature > 0.0f) {
|
| 3117 |
for (int i = 0; i < n_logits; i++) {
|
|
@@ -3149,7 +3291,7 @@ static void whisper_process_logits(
|
|
| 3149 |
logits[vocab.token_transcribe] = -INFINITY;
|
| 3150 |
|
| 3151 |
if (params.logits_filter_callback) {
|
| 3152 |
-
params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
|
| 3153 |
}
|
| 3154 |
|
| 3155 |
// suppress non-speech tokens
|
|
@@ -3310,6 +3452,7 @@ static void whisper_process_logits(
|
|
| 3310 |
|
| 3311 |
static whisper_token_data whisper_sample_token(
|
| 3312 |
whisper_context & ctx,
|
|
|
|
| 3313 |
const whisper_decoder & decoder,
|
| 3314 |
bool best) {
|
| 3315 |
whisper_token_data result = {
|
|
@@ -3354,7 +3497,7 @@ static whisper_token_data whisper_sample_token(
|
|
| 3354 |
} else {
|
| 3355 |
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
| 3356 |
|
| 3357 |
-
result.id = dist(
|
| 3358 |
result.p = probs[result.id];
|
| 3359 |
result.plog = logprobs[result.id];
|
| 3360 |
}
|
|
@@ -3364,13 +3507,14 @@ static whisper_token_data whisper_sample_token(
|
|
| 3364 |
result.pt = result.p;
|
| 3365 |
}
|
| 3366 |
|
| 3367 |
-
|
| 3368 |
|
| 3369 |
return result;
|
| 3370 |
}
|
| 3371 |
|
| 3372 |
static std::vector<whisper_token_data> whisper_sample_token_topk(
|
| 3373 |
whisper_context & ctx,
|
|
|
|
| 3374 |
const whisper_decoder & decoder,
|
| 3375 |
int k) {
|
| 3376 |
const auto & vocab = ctx.vocab;
|
|
@@ -3381,7 +3525,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
| 3381 |
|
| 3382 |
const int n_logits = vocab.n_vocab;
|
| 3383 |
|
| 3384 |
-
auto & logits_id =
|
| 3385 |
|
| 3386 |
logits_id.clear();
|
| 3387 |
for (int i = 0; i < n_logits; ++i) {
|
|
@@ -3434,7 +3578,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
| 3434 |
}
|
| 3435 |
}
|
| 3436 |
|
| 3437 |
-
|
| 3438 |
|
| 3439 |
return result;
|
| 3440 |
}
|
|
@@ -3488,24 +3632,25 @@ static void whisper_sequence_score(
|
|
| 3488 |
}
|
| 3489 |
}
|
| 3490 |
|
| 3491 |
-
int
|
| 3492 |
struct whisper_context * ctx,
|
| 3493 |
-
|
| 3494 |
-
|
| 3495 |
-
|
|
|
|
| 3496 |
// clear old results
|
| 3497 |
-
auto & result_all =
|
| 3498 |
|
| 3499 |
result_all.clear();
|
| 3500 |
|
| 3501 |
// compute log mel spectrogram
|
| 3502 |
if (params.speed_up) {
|
| 3503 |
-
if (
|
| 3504 |
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
|
| 3505 |
return -1;
|
| 3506 |
}
|
| 3507 |
} else {
|
| 3508 |
-
if (
|
| 3509 |
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
|
| 3510 |
return -2;
|
| 3511 |
}
|
|
@@ -3515,26 +3660,26 @@ int whisper_full(
|
|
| 3515 |
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
|
| 3516 |
std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
|
| 3517 |
|
| 3518 |
-
const auto lang_id =
|
| 3519 |
if (lang_id < 0) {
|
| 3520 |
fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
|
| 3521 |
return -3;
|
| 3522 |
}
|
| 3523 |
-
|
| 3524 |
params.language = whisper_lang_str(lang_id);
|
| 3525 |
|
| 3526 |
fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
|
| 3527 |
}
|
| 3528 |
|
| 3529 |
if (params.token_timestamps) {
|
| 3530 |
-
|
| 3531 |
-
|
| 3532 |
-
|
| 3533 |
-
|
| 3534 |
}
|
| 3535 |
|
| 3536 |
const int seek_start = params.offset_ms/10;
|
| 3537 |
-
const int seek_end = seek_start + (params.duration_ms == 0 ?
|
| 3538 |
|
| 3539 |
// if length of spectrogram is less than 1s (100 samples), then return
|
| 3540 |
// basically don't process anything that is less than 1s
|
|
@@ -3572,10 +3717,10 @@ int whisper_full(
|
|
| 3572 |
|
| 3573 |
// TAGS: WHISPER_DECODER_INIT
|
| 3574 |
for (int j = 1; j < n_decoders; j++) {
|
| 3575 |
-
auto & decoder =
|
| 3576 |
|
| 3577 |
if (decoder.kv_self.ctx == nullptr) {
|
| 3578 |
-
decoder.kv_self =
|
| 3579 |
if (!kv_cache_reinit(decoder.kv_self)) {
|
| 3580 |
fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
|
| 3581 |
return -4;
|
|
@@ -3583,7 +3728,7 @@ int whisper_full(
|
|
| 3583 |
|
| 3584 |
WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
|
| 3585 |
|
| 3586 |
-
decoder.sequence.tokens.reserve(
|
| 3587 |
|
| 3588 |
decoder.probs.resize (ctx->vocab.n_vocab);
|
| 3589 |
decoder.logits.resize (ctx->vocab.n_vocab);
|
|
@@ -3592,7 +3737,7 @@ int whisper_full(
|
|
| 3592 |
}
|
| 3593 |
|
| 3594 |
// the accumulated text context so far
|
| 3595 |
-
auto & prompt_past =
|
| 3596 |
if (params.no_context) {
|
| 3597 |
prompt_past.clear();
|
| 3598 |
}
|
|
@@ -3611,13 +3756,13 @@ int whisper_full(
|
|
| 3611 |
fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
|
| 3612 |
return -5;
|
| 3613 |
}
|
| 3614 |
-
|
| 3615 |
|
| 3616 |
// these tokens determine the task that will be performed
|
| 3617 |
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
|
| 3618 |
if (whisper_is_multilingual(ctx)) {
|
| 3619 |
const int lang_id = whisper_lang_id(params.language);
|
| 3620 |
-
|
| 3621 |
prompt_init.push_back(whisper_token_lang(ctx, lang_id));
|
| 3622 |
if (params.translate) {
|
| 3623 |
prompt_init.push_back(whisper_token_translate());
|
|
@@ -3669,14 +3814,14 @@ int whisper_full(
|
|
| 3669 |
}
|
| 3670 |
|
| 3671 |
if (params.encoder_begin_callback) {
|
| 3672 |
-
if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
|
| 3673 |
fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
|
| 3674 |
break;
|
| 3675 |
}
|
| 3676 |
}
|
| 3677 |
|
| 3678 |
// encode audio features starting at offset seek
|
| 3679 |
-
if (!
|
| 3680 |
fprintf(stderr, "%s: failed to encode\n", __func__);
|
| 3681 |
return -6;
|
| 3682 |
}
|
|
@@ -3717,7 +3862,7 @@ int whisper_full(
|
|
| 3717 |
|
| 3718 |
// TAGS: WHISPER_DECODER_INIT
|
| 3719 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 3720 |
-
auto & decoder =
|
| 3721 |
|
| 3722 |
decoder.kv_self.n = 0;
|
| 3723 |
|
|
@@ -3759,7 +3904,7 @@ int whisper_full(
|
|
| 3759 |
}
|
| 3760 |
WHISPER_PRINT_DEBUG("\n\n");
|
| 3761 |
|
| 3762 |
-
if (!
|
| 3763 |
fprintf(stderr, "%s: failed to decode\n", __func__);
|
| 3764 |
return -7;
|
| 3765 |
}
|
|
@@ -3767,24 +3912,24 @@ int whisper_full(
|
|
| 3767 |
{
|
| 3768 |
const int64_t t_start_sample_us = ggml_time_us();
|
| 3769 |
|
| 3770 |
-
whisper_process_logits(*ctx, params,
|
| 3771 |
|
| 3772 |
-
|
| 3773 |
|
| 3774 |
for (int j = 1; j < n_decoders_cur; ++j) {
|
| 3775 |
-
auto & decoder =
|
| 3776 |
|
| 3777 |
-
memcpy(decoder.kv_self.k->data,
|
| 3778 |
-
memcpy(decoder.kv_self.v->data,
|
| 3779 |
|
| 3780 |
decoder.kv_self.n += prompt.size();
|
| 3781 |
|
| 3782 |
-
memcpy(decoder.probs.data(),
|
| 3783 |
-
memcpy(decoder.logits.data(),
|
| 3784 |
-
memcpy(decoder.logprobs.data(),
|
| 3785 |
}
|
| 3786 |
|
| 3787 |
-
|
| 3788 |
}
|
| 3789 |
}
|
| 3790 |
|
|
@@ -3795,7 +3940,7 @@ int whisper_full(
|
|
| 3795 |
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
|
| 3796 |
kv_bufs.resize(n_decoders_cur);
|
| 3797 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 3798 |
-
auto & decoder =
|
| 3799 |
|
| 3800 |
if (decoder.completed || decoder.failed) {
|
| 3801 |
continue;
|
|
@@ -3813,7 +3958,7 @@ int whisper_full(
|
|
| 3813 |
|
| 3814 |
// generate new sequence candidates for each decoder
|
| 3815 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 3816 |
-
auto & decoder =
|
| 3817 |
|
| 3818 |
if (decoder.completed || decoder.failed) {
|
| 3819 |
continue;
|
|
@@ -3823,16 +3968,16 @@ int whisper_full(
|
|
| 3823 |
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
|
| 3824 |
{
|
| 3825 |
if (t_cur < 1e-6f) {
|
| 3826 |
-
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
|
| 3827 |
} else {
|
| 3828 |
-
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
|
| 3829 |
}
|
| 3830 |
|
| 3831 |
decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
|
| 3832 |
} break;
|
| 3833 |
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
|
| 3834 |
{
|
| 3835 |
-
const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
|
| 3836 |
|
| 3837 |
for (const auto & token : tokens_new) {
|
| 3838 |
beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
|
|
@@ -3857,7 +4002,7 @@ int whisper_full(
|
|
| 3857 |
uint32_t cur_c = 0;
|
| 3858 |
|
| 3859 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 3860 |
-
auto & decoder =
|
| 3861 |
|
| 3862 |
if (decoder.completed || decoder.failed) {
|
| 3863 |
continue;
|
|
@@ -3886,7 +4031,7 @@ int whisper_full(
|
|
| 3886 |
// - check if the sequence is failed
|
| 3887 |
// - update sliding window based on timestamp tokens
|
| 3888 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 3889 |
-
auto & decoder =
|
| 3890 |
|
| 3891 |
if (decoder.completed || decoder.failed) {
|
| 3892 |
continue;
|
|
@@ -3968,7 +4113,7 @@ int whisper_full(
|
|
| 3968 |
bool completed_all = true;
|
| 3969 |
|
| 3970 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 3971 |
-
auto & decoder =
|
| 3972 |
|
| 3973 |
if (decoder.completed || decoder.failed) {
|
| 3974 |
continue;
|
|
@@ -3982,11 +4127,11 @@ int whisper_full(
|
|
| 3982 |
}
|
| 3983 |
}
|
| 3984 |
|
| 3985 |
-
|
| 3986 |
|
| 3987 |
// obtain logits for the next token
|
| 3988 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 3989 |
-
auto & decoder =
|
| 3990 |
|
| 3991 |
if (decoder.failed || decoder.completed) {
|
| 3992 |
continue;
|
|
@@ -3997,7 +4142,7 @@ int whisper_full(
|
|
| 3997 |
|
| 3998 |
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
|
| 3999 |
|
| 4000 |
-
if (!
|
| 4001 |
fprintf(stderr, "%s: failed to decode\n", __func__);
|
| 4002 |
return -8;
|
| 4003 |
}
|
|
@@ -4005,11 +4150,11 @@ int whisper_full(
|
|
| 4005 |
{
|
| 4006 |
const int64_t t_start_sample_us = ggml_time_us();
|
| 4007 |
|
| 4008 |
-
whisper_process_logits(*ctx, params, decoder, t_cur);
|
| 4009 |
|
| 4010 |
++decoder.kv_self.n;
|
| 4011 |
|
| 4012 |
-
|
| 4013 |
}
|
| 4014 |
}
|
| 4015 |
}
|
|
@@ -4019,7 +4164,7 @@ int whisper_full(
|
|
| 4019 |
double best_score = -INFINITY;
|
| 4020 |
|
| 4021 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 4022 |
-
auto & decoder =
|
| 4023 |
|
| 4024 |
if (decoder.failed) {
|
| 4025 |
continue;
|
|
@@ -4036,7 +4181,7 @@ int whisper_full(
|
|
| 4036 |
__func__, j, decoder.sequence.entropy, params.entropy_thold);
|
| 4037 |
|
| 4038 |
decoder.failed = true;
|
| 4039 |
-
|
| 4040 |
|
| 4041 |
continue;
|
| 4042 |
}
|
|
@@ -4054,11 +4199,11 @@ int whisper_full(
|
|
| 4054 |
{
|
| 4055 |
bool success = true;
|
| 4056 |
|
| 4057 |
-
const auto & decoder =
|
| 4058 |
|
| 4059 |
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
|
| 4060 |
success = false;
|
| 4061 |
-
|
| 4062 |
}
|
| 4063 |
|
| 4064 |
if (success) {
|
|
@@ -4075,7 +4220,7 @@ int whisper_full(
|
|
| 4075 |
|
| 4076 |
// output results through a user-provided callback
|
| 4077 |
{
|
| 4078 |
-
const auto & best_decoder =
|
| 4079 |
|
| 4080 |
const auto seek_delta = best_decoder.seek_delta;
|
| 4081 |
const auto result_len = best_decoder.sequence.result_len;
|
|
@@ -4138,14 +4283,14 @@ int whisper_full(
|
|
| 4138 |
|
| 4139 |
if (params.token_timestamps) {
|
| 4140 |
whisper_exp_compute_token_level_timestamps(
|
| 4141 |
-
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
| 4142 |
|
| 4143 |
if (params.max_len > 0) {
|
| 4144 |
-
n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
|
| 4145 |
}
|
| 4146 |
}
|
| 4147 |
if (params.new_segment_callback) {
|
| 4148 |
-
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
|
| 4149 |
}
|
| 4150 |
}
|
| 4151 |
text = "";
|
|
@@ -4182,14 +4327,14 @@ int whisper_full(
|
|
| 4182 |
|
| 4183 |
if (params.token_timestamps) {
|
| 4184 |
whisper_exp_compute_token_level_timestamps(
|
| 4185 |
-
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
| 4186 |
|
| 4187 |
if (params.max_len > 0) {
|
| 4188 |
-
n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
|
| 4189 |
}
|
| 4190 |
}
|
| 4191 |
if (params.new_segment_callback) {
|
| 4192 |
-
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
|
| 4193 |
}
|
| 4194 |
}
|
| 4195 |
}
|
|
@@ -4204,6 +4349,15 @@ int whisper_full(
|
|
| 4204 |
return 0;
|
| 4205 |
}
|
| 4206 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4207 |
int whisper_full_parallel(
|
| 4208 |
struct whisper_context * ctx,
|
| 4209 |
struct whisper_full_params params,
|
|
@@ -4213,40 +4367,10 @@ int whisper_full_parallel(
|
|
| 4213 |
if (n_processors == 1) {
|
| 4214 |
return whisper_full(ctx, params, samples, n_samples);
|
| 4215 |
}
|
| 4216 |
-
|
| 4217 |
int ret = 0;
|
| 4218 |
|
| 4219 |
-
// prepare separate
|
| 4220 |
-
std::vector<
|
| 4221 |
-
|
| 4222 |
-
for (int i = 0; i < n_processors - 1; ++i) {
|
| 4223 |
-
auto & ctx_p = ctxs[i];
|
| 4224 |
-
|
| 4225 |
-
ctx_p = *ctx;
|
| 4226 |
-
|
| 4227 |
-
ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
|
| 4228 |
-
|
| 4229 |
-
ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab);
|
| 4230 |
-
|
| 4231 |
-
if (!kv_cache_reinit(ctx_p.kv_cross)) {
|
| 4232 |
-
fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);
|
| 4233 |
-
return false;
|
| 4234 |
-
}
|
| 4235 |
-
|
| 4236 |
-
// TAGS: WHISPER_DECODER_INIT
|
| 4237 |
-
for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
|
| 4238 |
-
if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
|
| 4239 |
-
fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
|
| 4240 |
-
return false;
|
| 4241 |
-
}
|
| 4242 |
-
|
| 4243 |
-
ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
|
| 4244 |
-
|
| 4245 |
-
ctx_p.decoders[j].probs.reserve (ctx_p.vocab.n_vocab);
|
| 4246 |
-
ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab);
|
| 4247 |
-
ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
|
| 4248 |
-
}
|
| 4249 |
-
}
|
| 4250 |
|
| 4251 |
const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
|
| 4252 |
const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
|
|
@@ -4256,6 +4380,9 @@ int whisper_full_parallel(
|
|
| 4256 |
|
| 4257 |
std::vector<std::thread> workers(n_processors - 1);
|
| 4258 |
for (int i = 0; i < n_processors - 1; ++i) {
|
|
|
|
|
|
|
|
|
|
| 4259 |
const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
|
| 4260 |
const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
|
| 4261 |
|
|
@@ -4268,13 +4395,17 @@ int whisper_full_parallel(
|
|
| 4268 |
params_cur.new_segment_callback = nullptr;
|
| 4269 |
params_cur.new_segment_callback_user_data = nullptr;
|
| 4270 |
|
| 4271 |
-
workers[i] = std::thread(
|
| 4272 |
}
|
| 4273 |
|
| 4274 |
{
|
| 4275 |
auto params_cur = params;
|
| 4276 |
|
| 4277 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4278 |
}
|
| 4279 |
|
| 4280 |
for (int i = 0; i < n_processors - 1; ++i) {
|
|
@@ -4283,45 +4414,43 @@ int whisper_full_parallel(
|
|
| 4283 |
|
| 4284 |
const int64_t offset_t = (int64_t) params.offset_ms/10.0;
|
| 4285 |
|
| 4286 |
-
// combine results into
|
| 4287 |
for (int i = 0; i < n_processors - 1; ++i) {
|
| 4288 |
-
auto
|
| 4289 |
|
| 4290 |
-
for (auto
|
| 4291 |
// correct the segment timestamp taking into account the offset
|
| 4292 |
-
result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
|
| 4293 |
-
result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
|
|
|
|
| 4294 |
|
| 4295 |
// make sure that segments are not overlapping
|
| 4296 |
-
if (!ctx->result_all.empty()) {
|
| 4297 |
-
result.t0 = std::max(result.t0, ctx->result_all.back().t1);
|
| 4298 |
}
|
| 4299 |
|
| 4300 |
-
ctx->result_all.push_back(std::move(result));
|
| 4301 |
|
| 4302 |
// call the new_segment_callback for each segment
|
| 4303 |
if (params.new_segment_callback) {
|
| 4304 |
-
params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
|
| 4305 |
}
|
| 4306 |
}
|
| 4307 |
|
| 4308 |
-
ctx->t_mel_us
|
| 4309 |
-
ctx->t_sample_us += ctxs[i].t_sample_us;
|
| 4310 |
-
ctx->t_encode_us += ctxs[i].t_encode_us;
|
| 4311 |
-
ctx->t_decode_us += ctxs[i].t_decode_us;
|
| 4312 |
|
| 4313 |
-
|
|
|
|
|
|
|
| 4314 |
|
| 4315 |
-
|
| 4316 |
-
kv_cache_free(ctx->decoders[j].kv_self);
|
| 4317 |
-
}
|
| 4318 |
}
|
| 4319 |
|
| 4320 |
// average the timings
|
| 4321 |
-
ctx->t_mel_us /= n_processors;
|
| 4322 |
-
ctx->t_sample_us /= n_processors;
|
| 4323 |
-
ctx->t_encode_us /= n_processors;
|
| 4324 |
-
ctx->t_decode_us /= n_processors;
|
| 4325 |
|
| 4326 |
// print information about the audio boundaries
|
| 4327 |
fprintf(stderr, "\n");
|
|
@@ -4334,44 +4463,84 @@ int whisper_full_parallel(
|
|
| 4334 |
return ret;
|
| 4335 |
}
|
| 4336 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4337 |
int whisper_full_n_segments(struct whisper_context * ctx) {
|
| 4338 |
-
return ctx->result_all.size();
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4339 |
}
|
| 4340 |
|
| 4341 |
int whisper_full_lang_id(struct whisper_context * ctx) {
|
| 4342 |
-
return ctx->lang_id;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4343 |
}
|
| 4344 |
|
| 4345 |
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
|
| 4346 |
-
return ctx->result_all[i_segment].t0;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4347 |
}
|
| 4348 |
|
| 4349 |
int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
|
| 4350 |
-
return ctx->result_all[i_segment].t1;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4351 |
}
|
| 4352 |
|
| 4353 |
const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
|
| 4354 |
-
return ctx->result_all[i_segment].text.c_str();
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4355 |
}
|
| 4356 |
|
| 4357 |
int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
|
| 4358 |
-
return ctx->result_all[i_segment].tokens.size();
|
| 4359 |
}
|
| 4360 |
|
| 4361 |
-
const char *
|
| 4362 |
-
return ctx->vocab.id_to_token[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4363 |
}
|
| 4364 |
|
| 4365 |
whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
|
| 4366 |
-
return ctx->result_all[i_segment].tokens[i_token].id;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4367 |
}
|
| 4368 |
|
| 4369 |
struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
|
| 4370 |
-
return ctx->result_all[i_segment].tokens[i_token];
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4371 |
}
|
| 4372 |
|
| 4373 |
float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
|
| 4374 |
-
return ctx->result_all[i_segment].tokens[i_token].p;
|
| 4375 |
}
|
| 4376 |
|
| 4377 |
// =================================================================================================
|
|
@@ -4583,13 +4752,14 @@ static std::vector<float> get_signal_energy(const float * signal, int n_samples,
|
|
| 4583 |
|
| 4584 |
static void whisper_exp_compute_token_level_timestamps(
|
| 4585 |
struct whisper_context & ctx,
|
|
|
|
| 4586 |
int i_segment,
|
| 4587 |
float thold_pt,
|
| 4588 |
float thold_ptsum) {
|
| 4589 |
-
auto & segment =
|
| 4590 |
auto & tokens = segment.tokens;
|
| 4591 |
|
| 4592 |
-
const int n_samples =
|
| 4593 |
|
| 4594 |
if (n_samples == 0) {
|
| 4595 |
fprintf(stderr, "%s: no signal data available\n", __func__);
|
|
@@ -4612,9 +4782,9 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
| 4612 |
return;
|
| 4613 |
}
|
| 4614 |
|
| 4615 |
-
auto & t_beg =
|
| 4616 |
-
auto & t_last =
|
| 4617 |
-
auto & tid_last =
|
| 4618 |
|
| 4619 |
for (int j = 0; j < n; ++j) {
|
| 4620 |
auto & token = tokens[j];
|
|
@@ -4737,15 +4907,15 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
| 4737 |
float sum = 0.0f;
|
| 4738 |
|
| 4739 |
for (int k = ss0; k < ss1; k++) {
|
| 4740 |
-
sum +=
|
| 4741 |
}
|
| 4742 |
|
| 4743 |
const float thold = 0.5*sum/ns;
|
| 4744 |
|
| 4745 |
{
|
| 4746 |
int k = s0;
|
| 4747 |
-
if (
|
| 4748 |
-
while (k > 0 &&
|
| 4749 |
k--;
|
| 4750 |
}
|
| 4751 |
tokens[j].t0 = sample_to_timestamp(k);
|
|
@@ -4755,7 +4925,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
| 4755 |
s0 = k;
|
| 4756 |
}
|
| 4757 |
} else {
|
| 4758 |
-
while (
|
| 4759 |
k++;
|
| 4760 |
}
|
| 4761 |
s0 = k;
|
|
@@ -4765,8 +4935,8 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
| 4765 |
|
| 4766 |
{
|
| 4767 |
int k = s1;
|
| 4768 |
-
if (
|
| 4769 |
-
while (k < n_samples - 1 &&
|
| 4770 |
k++;
|
| 4771 |
}
|
| 4772 |
tokens[j].t1 = sample_to_timestamp(k);
|
|
@@ -4776,7 +4946,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
| 4776 |
s1 = k;
|
| 4777 |
}
|
| 4778 |
} else {
|
| 4779 |
-
while (
|
| 4780 |
k--;
|
| 4781 |
}
|
| 4782 |
s1 = k;
|
|
|
|
| 547 |
std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
|
| 548 |
};
|
| 549 |
|
| 550 |
+
struct whisper_state {
|
|
|
|
|
|
|
| 551 |
int64_t t_sample_us = 0;
|
| 552 |
int64_t t_encode_us = 0;
|
| 553 |
int64_t t_decode_us = 0;
|
| 554 |
+
int64_t t_mel_us = 0;
|
| 555 |
|
| 556 |
int32_t n_sample = 0; // number of tokens sampled
|
| 557 |
int32_t n_encode = 0; // number of encoder calls
|
|
|
|
| 559 |
int32_t n_fail_p = 0; // number of logprob threshold failures
|
| 560 |
int32_t n_fail_h = 0; // number of entropy threshold failures
|
| 561 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 562 |
// cross-attention KV cache for the decoders
|
| 563 |
// shared between all decoders
|
| 564 |
whisper_kv_cache kv_cross;
|
| 565 |
+
whisper_mel mel;
|
| 566 |
|
| 567 |
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
|
| 568 |
|
|
|
|
| 627 |
}
|
| 628 |
};
|
| 629 |
|
| 630 |
+
struct whisper_context {
|
| 631 |
+
int64_t t_load_us = 0;
|
| 632 |
+
int64_t t_start_us = 0;
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 or FP16)
|
| 636 |
+
|
| 637 |
+
whisper_model model;
|
| 638 |
+
whisper_vocab vocab;
|
| 639 |
+
whisper_state * state = nullptr;
|
| 640 |
+
};
|
| 641 |
+
|
| 642 |
template<typename T>
|
| 643 |
static void read_safe(whisper_model_loader * loader, T & dest) {
|
| 644 |
loader->read(loader->context, &dest, sizeof(T));
|
|
|
|
| 825 |
wctx.model.buf = new std::vector<uint8_t>();
|
| 826 |
wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
|
| 827 |
|
| 828 |
+
// we skip initialization of the state until it is needed
|
| 829 |
+
// because it might be that state will always be provided externally.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 830 |
}
|
| 831 |
|
| 832 |
// load mel filters
|
|
|
|
| 909 |
vocab.id_to_token[i] = word;
|
| 910 |
}
|
| 911 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 912 |
}
|
| 913 |
|
| 914 |
size_t ctx_size = 0;
|
|
|
|
| 1308 |
}
|
| 1309 |
}
|
| 1310 |
|
|
|
|
|
|
|
| 1311 |
wctx.t_load_us = ggml_time_us() - t_start_us;
|
| 1312 |
|
| 1313 |
return true;
|
| 1314 |
}
|
| 1315 |
|
| 1316 |
+
// evaluate the encoder with the given state
|
| 1317 |
//
|
| 1318 |
// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
|
| 1319 |
// part of the transformer model and returns the encoded features
|
| 1320 |
//
|
| 1321 |
+
// - wctx: the model
|
| 1322 |
+
// - wstate: the state of the encoder
|
| 1323 |
// - n_threads: number of threads to use
|
| 1324 |
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
|
| 1325 |
//
|
| 1326 |
+
static bool whisper_encode_internal(
|
| 1327 |
whisper_context & wctx,
|
| 1328 |
+
whisper_state & wstate,
|
| 1329 |
const int mel_offset,
|
| 1330 |
+
const int n_threads){
|
| 1331 |
+
|
| 1332 |
const int64_t t_start_us = ggml_time_us();
|
| 1333 |
|
| 1334 |
const auto & model = wctx.model;
|
| 1335 |
+
const auto & mel_inp = wstate.mel;
|
| 1336 |
const auto & hparams = model.hparams;
|
| 1337 |
|
| 1338 |
+
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
| 1339 |
const int n_state = hparams.n_audio_state;
|
| 1340 |
const int n_head = hparams.n_audio_head;
|
| 1341 |
const int n_layer = hparams.n_audio_layer;
|
|
|
|
| 1344 |
assert(mel_inp.n_mel == n_mels);
|
| 1345 |
|
| 1346 |
struct ggml_init_params params;
|
| 1347 |
+
params.mem_size = wstate.buf_compute.size();
|
| 1348 |
+
params.mem_buffer = wstate.buf_compute.data();
|
| 1349 |
|
| 1350 |
struct ggml_context * ctx0 = ggml_init(params);
|
| 1351 |
|
| 1352 |
+
wstate.use_buf(ctx0, 0);
|
| 1353 |
|
| 1354 |
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
|
| 1355 |
assert(mel->type == GGML_TYPE_F32);
|
|
|
|
| 1371 |
|
| 1372 |
// convolution + gelu
|
| 1373 |
{
|
| 1374 |
+
wstate.use_buf(ctx0, 1);
|
| 1375 |
|
| 1376 |
cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
|
| 1377 |
cur = ggml_add(ctx0,
|
| 1378 |
+
ggml_repeat(ctx0,
|
| 1379 |
+
model.e_conv_1_b,
|
| 1380 |
+
cur),
|
| 1381 |
+
cur);
|
| 1382 |
|
| 1383 |
cur = ggml_gelu(ctx0, cur);
|
| 1384 |
|
| 1385 |
+
wstate.use_buf(ctx0, 0);
|
| 1386 |
|
| 1387 |
cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
|
| 1388 |
cur = ggml_add(ctx0,
|
| 1389 |
+
ggml_repeat(ctx0,
|
| 1390 |
+
model.e_conv_2_b,
|
| 1391 |
+
cur),
|
| 1392 |
+
cur);
|
| 1393 |
|
| 1394 |
cur = ggml_gelu(ctx0, cur);
|
| 1395 |
}
|
| 1396 |
|
| 1397 |
+
wstate.use_buf(ctx0, 3);
|
| 1398 |
|
| 1399 |
// ===================================================================
|
| 1400 |
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
|
|
|
| 1409 |
//}
|
| 1410 |
|
| 1411 |
static int iter = 0;
|
| 1412 |
+
|
| 1413 |
const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
|
| 1414 |
const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
|
| 1415 |
|
|
|
|
| 1429 |
|
| 1430 |
// norm
|
| 1431 |
{
|
| 1432 |
+
wstate.use_buf(ctx0, 0);
|
| 1433 |
|
| 1434 |
cur = ggml_norm(ctx0, inpL);
|
| 1435 |
|
| 1436 |
// cur = ln_0_w*cur + ln_0_b
|
| 1437 |
cur = ggml_add(ctx0,
|
| 1438 |
+
ggml_mul(ctx0,
|
| 1439 |
+
ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
|
| 1440 |
+
cur),
|
| 1441 |
+
ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
|
| 1442 |
}
|
| 1443 |
|
| 1444 |
// self-attention
|
| 1445 |
{
|
| 1446 |
+
wstate.use_buf(ctx0, 1);
|
| 1447 |
|
| 1448 |
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
| 1449 |
+
layer.attn_q_w,
|
| 1450 |
+
cur);
|
| 1451 |
|
| 1452 |
Qcur = ggml_add(ctx0,
|
| 1453 |
+
ggml_repeat(ctx0,
|
| 1454 |
+
layer.attn_q_b,
|
| 1455 |
+
Qcur),
|
| 1456 |
+
Qcur);
|
| 1457 |
|
| 1458 |
//Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
| 1459 |
|
| 1460 |
// note: no bias for Key
|
| 1461 |
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
|
| 1462 |
+
layer.attn_k_w,
|
| 1463 |
+
cur);
|
| 1464 |
|
| 1465 |
//Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
| 1466 |
|
| 1467 |
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
| 1468 |
+
layer.attn_v_w,
|
| 1469 |
+
cur);
|
| 1470 |
|
| 1471 |
Vcur = ggml_add(ctx0,
|
| 1472 |
+
ggml_repeat(ctx0,
|
| 1473 |
+
layer.attn_v_b,
|
| 1474 |
+
Vcur),
|
| 1475 |
+
Vcur);
|
| 1476 |
|
| 1477 |
// ------
|
| 1478 |
|
| 1479 |
+
wstate.use_buf(ctx0, 0);
|
| 1480 |
|
| 1481 |
#ifdef WHISPER_USE_FLASH_ATTN
|
| 1482 |
struct ggml_tensor * Q =
|
|
|
|
| 1553 |
#endif
|
| 1554 |
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
| 1555 |
|
| 1556 |
+
wstate.use_buf(ctx0, 1);
|
| 1557 |
|
| 1558 |
cur = ggml_cpy(ctx0,
|
| 1559 |
+
KQV_merged,
|
| 1560 |
+
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
|
| 1561 |
}
|
| 1562 |
|
| 1563 |
// projection
|
| 1564 |
{
|
| 1565 |
+
wstate.use_buf(ctx0, 0);
|
| 1566 |
|
| 1567 |
cur = ggml_mul_mat(ctx0,
|
| 1568 |
+
layer.attn_ln_1_w,
|
| 1569 |
+
cur);
|
| 1570 |
|
| 1571 |
+
wstate.use_buf(ctx0, 1);
|
| 1572 |
|
| 1573 |
cur = ggml_add(ctx0,
|
| 1574 |
+
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
|
| 1575 |
+
cur);
|
| 1576 |
}
|
| 1577 |
|
| 1578 |
+
wstate.use_buf(ctx0, 2);
|
| 1579 |
|
| 1580 |
// add the input
|
| 1581 |
cur = ggml_add(ctx0, cur, inpL);
|
|
|
|
| 1586 |
{
|
| 1587 |
// norm
|
| 1588 |
{
|
| 1589 |
+
wstate.use_buf(ctx0, 0);
|
| 1590 |
|
| 1591 |
cur = ggml_norm(ctx0, inpFF);
|
| 1592 |
|
| 1593 |
+
wstate.use_buf(ctx0, 1);
|
| 1594 |
|
| 1595 |
// cur = mlp_ln_w*cur + mlp_ln_b
|
| 1596 |
cur = ggml_add(ctx0,
|
| 1597 |
+
ggml_mul(ctx0,
|
| 1598 |
+
ggml_repeat(ctx0, layer.mlp_ln_w, cur),
|
| 1599 |
+
cur),
|
| 1600 |
+
ggml_repeat(ctx0, layer.mlp_ln_b, cur));
|
| 1601 |
+
}
|
| 1602 |
|
| 1603 |
#ifdef WHISPER_USE_FLASH_FF
|
| 1604 |
+
wstate.use_buf(ctx0, 0);
|
| 1605 |
|
| 1606 |
cur = ggml_flash_ff(ctx0,
|
| 1607 |
+
ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.wtype, n_state, n_ctx)),
|
| 1608 |
+
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
| 1609 |
#else
|
| 1610 |
+
wstate.use_buf(ctx0, 0);
|
| 1611 |
|
| 1612 |
// fully connected
|
| 1613 |
cur = ggml_mul_mat(ctx0,
|
| 1614 |
+
layer.mlp_0_w,
|
| 1615 |
+
cur);
|
| 1616 |
|
| 1617 |
+
wstate.use_buf(ctx0, 1);
|
| 1618 |
|
| 1619 |
cur = ggml_add(ctx0,
|
| 1620 |
+
ggml_repeat(ctx0, layer.mlp_0_b, cur),
|
| 1621 |
+
cur);
|
| 1622 |
|
| 1623 |
+
wstate.use_buf(ctx0, 0);
|
| 1624 |
|
| 1625 |
// GELU activation
|
| 1626 |
cur = ggml_gelu(ctx0, cur);
|
| 1627 |
|
| 1628 |
+
wstate.use_buf(ctx0, 1);
|
| 1629 |
|
| 1630 |
// projection
|
| 1631 |
cur = ggml_mul_mat(ctx0,
|
| 1632 |
+
layer.mlp_1_w,
|
| 1633 |
+
cur);
|
| 1634 |
|
| 1635 |
+
wstate.use_buf(ctx0, 0);
|
| 1636 |
|
| 1637 |
cur = ggml_add(ctx0,
|
| 1638 |
+
ggml_repeat(ctx0, layer.mlp_1_b, cur),
|
| 1639 |
+
cur);
|
| 1640 |
#endif
|
| 1641 |
+
}
|
| 1642 |
|
| 1643 |
+
wstate.use_buf(ctx0, 3);
|
| 1644 |
|
| 1645 |
inpL = ggml_add(ctx0, cur, inpFF);
|
| 1646 |
}
|
|
|
|
| 1649 |
|
| 1650 |
// norm
|
| 1651 |
{
|
| 1652 |
+
wstate.use_buf(ctx0, 0);
|
| 1653 |
|
| 1654 |
cur = ggml_norm(ctx0, cur);
|
| 1655 |
|
| 1656 |
+
wstate.use_buf(ctx0, 1);
|
| 1657 |
|
| 1658 |
// cur = ln_f_g*cur + ln_f_b
|
| 1659 |
cur = ggml_add(ctx0,
|
| 1660 |
+
ggml_mul(ctx0,
|
| 1661 |
+
ggml_repeat(ctx0, model.e_ln_w, cur),
|
| 1662 |
+
cur),
|
| 1663 |
+
ggml_repeat(ctx0, model.e_ln_b, cur));
|
| 1664 |
}
|
| 1665 |
|
| 1666 |
+
wstate.use_buf(ctx0, -1);
|
| 1667 |
|
| 1668 |
// run the computation
|
| 1669 |
{
|
|
|
|
| 1671 |
gf.n_threads = n_threads;
|
| 1672 |
|
| 1673 |
ggml_build_forward_expand(&gf, cur);
|
| 1674 |
+
ggml_graph_compute(ctx0, &gf);
|
| 1675 |
|
| 1676 |
//ggml_graph_print(&gf);
|
| 1677 |
}
|
|
|
|
| 1701 |
cur->src1 = nullptr;
|
| 1702 |
|
| 1703 |
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
| 1704 |
+
auto& layer = model.layers_decoder[il];
|
| 1705 |
|
| 1706 |
+
wstate.use_buf(ctx0, 0);
|
| 1707 |
|
| 1708 |
+
struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
|
| 1709 |
+
layer.cross_attn_k_w,
|
| 1710 |
+
cur);
|
| 1711 |
|
| 1712 |
+
Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25)));
|
| 1713 |
|
| 1714 |
+
wstate.use_buf(ctx0, 1);
|
| 1715 |
|
| 1716 |
+
struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
|
| 1717 |
+
layer.cross_attn_v_w,
|
| 1718 |
+
cur);
|
| 1719 |
|
| 1720 |
Vcross = ggml_add(ctx0,
|
| 1721 |
+
ggml_repeat(ctx0,
|
| 1722 |
+
layer.cross_attn_v_b,
|
| 1723 |
+
Vcross),
|
| 1724 |
+
Vcross);
|
| 1725 |
|
| 1726 |
+
wstate.use_buf(ctx0, -1);
|
| 1727 |
|
| 1728 |
+
//struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
|
| 1729 |
+
//struct ggml_tensor * v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
|
| 1730 |
+
struct ggml_tensor* k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
|
| 1731 |
+
struct ggml_tensor* v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx));
|
| 1732 |
|
| 1733 |
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
|
| 1734 |
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
|
|
|
|
| 1749 |
|
| 1750 |
ggml_free(ctx0);
|
| 1751 |
|
| 1752 |
+
wstate.t_encode_us += ggml_time_us() - t_start_us;
|
| 1753 |
+
wstate.n_encode++;
|
| 1754 |
|
| 1755 |
return true;
|
| 1756 |
}
|
|
|
|
| 1765 |
// - n_tokens: number of tokens in the prompt
|
| 1766 |
// - n_past: number of past tokens to prefix the prompt with
|
| 1767 |
//
|
| 1768 |
+
static bool whisper_decode_internal(
|
| 1769 |
whisper_context & wctx,
|
| 1770 |
+
whisper_state & wstate,
|
| 1771 |
whisper_decoder & decoder,
|
| 1772 |
const whisper_token * tokens,
|
| 1773 |
const int n_tokens,
|
|
|
|
| 1782 |
|
| 1783 |
WHISPER_ASSERT(!!kv_self.ctx);
|
| 1784 |
|
| 1785 |
+
auto & logits_out = wstate.logits;
|
| 1786 |
|
| 1787 |
const int n_vocab = hparams.n_vocab;
|
| 1788 |
|
|
|
|
| 1792 |
const int n_layer = hparams.n_text_layer;
|
| 1793 |
|
| 1794 |
const int N = n_tokens;
|
| 1795 |
+
const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
| 1796 |
|
| 1797 |
//WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
|
| 1798 |
|
| 1799 |
struct ggml_init_params params;
|
| 1800 |
+
params.mem_size = wstate.buf_compute.size();
|
| 1801 |
+
params.mem_buffer = wstate.buf_compute.data();
|
| 1802 |
|
| 1803 |
struct ggml_context * ctx0 = ggml_init(params);
|
| 1804 |
|
|
|
|
| 1813 |
((int32_t *) position->data)[i] = n_past + i;
|
| 1814 |
}
|
| 1815 |
|
| 1816 |
+
wstate.use_buf(ctx0, 3);
|
| 1817 |
|
| 1818 |
// token encoding + position encoding
|
| 1819 |
struct ggml_tensor * cur =
|
|
|
|
| 1828 |
|
| 1829 |
// norm
|
| 1830 |
{
|
| 1831 |
+
wstate.use_buf(ctx0, 0);
|
| 1832 |
|
| 1833 |
cur = ggml_norm(ctx0, inpL);
|
| 1834 |
|
|
|
|
| 1842 |
|
| 1843 |
// self-attention
|
| 1844 |
{
|
| 1845 |
+
wstate.use_buf(ctx0, 1);
|
| 1846 |
|
| 1847 |
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
| 1848 |
layer.attn_q_w,
|
|
|
|
| 1884 |
|
| 1885 |
// ------
|
| 1886 |
|
| 1887 |
+
wstate.use_buf(ctx0, 0);
|
| 1888 |
|
| 1889 |
struct ggml_tensor * Q =
|
| 1890 |
ggml_permute(ctx0,
|
|
|
|
| 1900 |
n_state/n_head, n_head, n_past + N),
|
| 1901 |
0, 2, 1, 3);
|
| 1902 |
|
| 1903 |
+
wstate.use_buf(ctx0, 1);
|
| 1904 |
|
| 1905 |
// K * Q
|
| 1906 |
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
| 1907 |
|
| 1908 |
+
wstate.use_buf(ctx0, 0);
|
| 1909 |
|
| 1910 |
//struct ggml_tensor * KQ_scaled =
|
| 1911 |
// ggml_scale(ctx0,
|
|
|
|
| 1915 |
|
| 1916 |
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
|
| 1917 |
|
| 1918 |
+
wstate.use_buf(ctx0, 1);
|
| 1919 |
|
| 1920 |
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
|
| 1921 |
|
| 1922 |
+
wstate.use_buf(ctx0, 0);
|
| 1923 |
|
| 1924 |
struct ggml_tensor * V_trans =
|
| 1925 |
ggml_permute(ctx0,
|
|
|
|
| 1928 |
n_state/n_head, n_head, n_past + N),
|
| 1929 |
1, 2, 0, 3);
|
| 1930 |
|
| 1931 |
+
wstate.use_buf(ctx0, 1);
|
| 1932 |
|
| 1933 |
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
|
| 1934 |
|
|
|
|
| 1941 |
|
| 1942 |
// projection
|
| 1943 |
{
|
| 1944 |
+
wstate.use_buf(ctx0, 0);
|
| 1945 |
|
| 1946 |
cur = ggml_mul_mat(ctx0,
|
| 1947 |
layer.attn_ln_1_w,
|
| 1948 |
cur);
|
| 1949 |
|
| 1950 |
+
wstate.use_buf(ctx0, 1);
|
| 1951 |
|
| 1952 |
cur = ggml_add(ctx0,
|
| 1953 |
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
|
| 1954 |
cur);
|
| 1955 |
}
|
| 1956 |
|
| 1957 |
+
wstate.use_buf(ctx0, 2);
|
| 1958 |
|
| 1959 |
// add the input
|
| 1960 |
struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
|
| 1961 |
|
| 1962 |
// norm
|
| 1963 |
{
|
| 1964 |
+
wstate.use_buf(ctx0, 0);
|
| 1965 |
|
| 1966 |
cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
|
| 1967 |
|
| 1968 |
+
wstate.use_buf(ctx0, 1);
|
| 1969 |
|
| 1970 |
// cur = ln_0_w*cur + ln_0_b
|
| 1971 |
cur = ggml_add(ctx0,
|
|
|
|
| 1977 |
|
| 1978 |
// cross-attention
|
| 1979 |
{
|
| 1980 |
+
wstate.use_buf(ctx0, 0);
|
| 1981 |
|
| 1982 |
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
| 1983 |
layer.cross_attn_q_w,
|
|
|
|
| 1994 |
// Kcross is already scaled
|
| 1995 |
struct ggml_tensor * Kcross =
|
| 1996 |
ggml_reshape_3d(ctx0,
|
| 1997 |
+
ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state),
|
| 1998 |
n_state/n_head, n_head, M);
|
| 1999 |
|
| 2000 |
struct ggml_tensor * Vcross =
|
| 2001 |
ggml_reshape_3d(ctx0,
|
| 2002 |
+
ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
|
| 2003 |
n_state/n_head, n_head, M);
|
| 2004 |
|
| 2005 |
struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);
|
| 2006 |
|
| 2007 |
// ------
|
| 2008 |
|
| 2009 |
+
wstate.use_buf(ctx0, 1);
|
| 2010 |
|
| 2011 |
struct ggml_tensor * Q =
|
| 2012 |
ggml_permute(ctx0,
|
|
|
|
| 2017 |
|
| 2018 |
struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
|
| 2019 |
|
| 2020 |
+
wstate.use_buf(ctx0, 0);
|
| 2021 |
|
| 2022 |
// K * Q
|
| 2023 |
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
|
|
|
| 2031 |
// no masking for cross-attention
|
| 2032 |
//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
|
| 2033 |
|
| 2034 |
+
wstate.use_buf(ctx0, 1);
|
| 2035 |
|
| 2036 |
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
|
| 2037 |
|
| 2038 |
+
wstate.use_buf(ctx0, 0);
|
| 2039 |
|
| 2040 |
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
|
| 2041 |
|
| 2042 |
+
wstate.use_buf(ctx0, 1);
|
| 2043 |
|
| 2044 |
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
| 2045 |
|
|
|
|
| 2051 |
|
| 2052 |
// projection
|
| 2053 |
{
|
| 2054 |
+
wstate.use_buf(ctx0, 0);
|
| 2055 |
|
| 2056 |
cur = ggml_mul_mat(ctx0,
|
| 2057 |
layer.cross_attn_ln_1_w,
|
| 2058 |
cur);
|
| 2059 |
|
| 2060 |
+
wstate.use_buf(ctx0, 1);
|
| 2061 |
|
| 2062 |
cur = ggml_add(ctx0,
|
| 2063 |
ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
|
| 2064 |
cur);
|
| 2065 |
}
|
| 2066 |
|
| 2067 |
+
wstate.use_buf(ctx0, 2);
|
| 2068 |
|
| 2069 |
// add the input
|
| 2070 |
cur = ggml_add(ctx0, cur, inpCA);
|
|
|
|
| 2075 |
{
|
| 2076 |
// norm
|
| 2077 |
{
|
| 2078 |
+
wstate.use_buf(ctx0, 0);
|
| 2079 |
|
| 2080 |
cur = ggml_norm(ctx0, inpFF);
|
| 2081 |
|
| 2082 |
+
wstate.use_buf(ctx0, 1);
|
| 2083 |
|
| 2084 |
// cur = mlp_ln_w*cur + mlp_ln_b
|
| 2085 |
cur = ggml_add(ctx0,
|
|
|
|
| 2089 |
ggml_repeat(ctx0, layer.mlp_ln_b, cur));
|
| 2090 |
}
|
| 2091 |
|
| 2092 |
+
wstate.use_buf(ctx0, 0);
|
| 2093 |
|
| 2094 |
// fully connected
|
| 2095 |
cur = ggml_mul_mat(ctx0,
|
| 2096 |
layer.mlp_0_w,
|
| 2097 |
cur);
|
| 2098 |
|
| 2099 |
+
wstate.use_buf(ctx0, 1);
|
| 2100 |
|
| 2101 |
cur = ggml_add(ctx0,
|
| 2102 |
ggml_repeat(ctx0, layer.mlp_0_b, cur),
|
| 2103 |
cur);
|
| 2104 |
|
| 2105 |
+
wstate.use_buf(ctx0, 0);
|
| 2106 |
|
| 2107 |
// GELU activation
|
| 2108 |
cur = ggml_gelu(ctx0, cur);
|
| 2109 |
|
| 2110 |
+
wstate.use_buf(ctx0, 1);
|
| 2111 |
|
| 2112 |
// projection
|
| 2113 |
cur = ggml_mul_mat(ctx0,
|
| 2114 |
layer.mlp_1_w,
|
| 2115 |
cur);
|
| 2116 |
|
| 2117 |
+
wstate.use_buf(ctx0, 0);
|
| 2118 |
|
| 2119 |
cur = ggml_add(ctx0,
|
| 2120 |
ggml_repeat(ctx0, layer.mlp_1_b, cur),
|
| 2121 |
cur);
|
| 2122 |
}
|
| 2123 |
|
| 2124 |
+
wstate.use_buf(ctx0, 3);
|
| 2125 |
|
| 2126 |
inpL = ggml_add(ctx0, cur, inpFF);
|
| 2127 |
}
|
|
|
|
| 2130 |
|
| 2131 |
// norm
|
| 2132 |
{
|
| 2133 |
+
wstate.use_buf(ctx0, 0);
|
| 2134 |
|
| 2135 |
cur = ggml_norm(ctx0, cur);
|
| 2136 |
|
| 2137 |
+
wstate.use_buf(ctx0, 1);
|
| 2138 |
|
| 2139 |
cur = ggml_add(ctx0,
|
| 2140 |
ggml_mul(ctx0,
|
|
|
|
| 2143 |
ggml_repeat(ctx0, model.d_ln_b, cur));
|
| 2144 |
}
|
| 2145 |
|
| 2146 |
+
wstate.use_buf(ctx0, 0);
|
| 2147 |
|
| 2148 |
// compute logits only for the last token
|
| 2149 |
// comment this line to compute logits for all N tokens
|
|
|
|
| 2152 |
|
| 2153 |
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
|
| 2154 |
|
| 2155 |
+
wstate.use_buf(ctx0, -1);
|
| 2156 |
|
| 2157 |
// run the computation
|
| 2158 |
{
|
|
|
|
| 2179 |
|
| 2180 |
ggml_free(ctx0);
|
| 2181 |
|
| 2182 |
+
wstate.t_decode_us += ggml_time_us() - t_start_us;
|
| 2183 |
+
wstate.n_decode++;
|
| 2184 |
|
| 2185 |
return true;
|
| 2186 |
}
|
|
|
|
| 2284 |
|
| 2285 |
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
|
| 2286 |
static bool log_mel_spectrogram(
|
| 2287 |
+
whisper_state & wstate,
|
| 2288 |
const float * samples,
|
| 2289 |
const int n_samples,
|
| 2290 |
const int /*sample_rate*/,
|
|
|
|
| 2404 |
mel.data[i] = (mel.data[i] + 4.0)/4.0;
|
| 2405 |
}
|
| 2406 |
|
| 2407 |
+
wstate.t_mel_us += ggml_time_us() - t_start_us;
|
| 2408 |
|
| 2409 |
return true;
|
| 2410 |
}
|
|
|
|
| 2478 |
// interface implementation
|
| 2479 |
//
|
| 2480 |
|
| 2481 |
+
struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
| 2482 |
+
whisper_state * state = new whisper_state;
|
| 2483 |
+
|
| 2484 |
+
const size_t scale = ctx->model.hparams.f16 ? 1 : 2;
|
| 2485 |
+
|
| 2486 |
+
|
| 2487 |
+
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) {
|
| 2488 |
+
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
| 2489 |
+
return nullptr;
|
| 2490 |
+
}
|
| 2491 |
+
|
| 2492 |
+
{
|
| 2493 |
+
const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
|
| 2494 |
+
fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
|
| 2495 |
+
}
|
| 2496 |
+
|
| 2497 |
+
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->wtype, ctx->model.hparams.n_audio_ctx)) {
|
| 2498 |
+
fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
| 2499 |
+
return nullptr;
|
| 2500 |
+
}
|
| 2501 |
+
|
| 2502 |
+
{
|
| 2503 |
+
const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v);
|
| 2504 |
+
fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
|
| 2505 |
+
}
|
| 2506 |
+
|
| 2507 |
+
|
| 2508 |
+
state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
|
| 2509 |
+
|
| 2510 |
+
state->logits_id.reserve(ctx->model.hparams.n_vocab);
|
| 2511 |
+
|
| 2512 |
+
// TAGS: WHISPER_DECODER_INIT
|
| 2513 |
+
state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
|
| 2514 |
+
|
| 2515 |
+
state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
|
| 2516 |
+
state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
|
| 2517 |
+
state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
|
| 2518 |
+
state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type)));
|
| 2519 |
+
|
| 2520 |
+
state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
|
| 2521 |
+
state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
|
| 2522 |
+
state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type));
|
| 2523 |
+
state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type));
|
| 2524 |
+
|
| 2525 |
+
state->rng = std::mt19937(0);
|
| 2526 |
+
|
| 2527 |
+
return state;
|
| 2528 |
+
}
|
| 2529 |
+
|
| 2530 |
+
struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
|
| 2531 |
whisper_model_loader loader = {};
|
| 2532 |
|
| 2533 |
fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
|
|
|
|
| 2555 |
fin->close();
|
| 2556 |
};
|
| 2557 |
|
| 2558 |
+
return whisper_init_no_state(&loader);
|
| 2559 |
}
|
| 2560 |
|
| 2561 |
+
struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
|
| 2562 |
struct buf_context {
|
| 2563 |
uint8_t* buffer;
|
| 2564 |
size_t size;
|
|
|
|
| 2591 |
|
| 2592 |
loader.close = [](void * /*ctx*/) { };
|
| 2593 |
|
| 2594 |
+
return whisper_init_no_state(&loader);
|
| 2595 |
}
|
| 2596 |
|
| 2597 |
+
struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) {
|
| 2598 |
ggml_time_init();
|
| 2599 |
|
| 2600 |
whisper_context * ctx = new whisper_context;
|
|
|
|
| 2611 |
return ctx;
|
| 2612 |
}
|
| 2613 |
|
| 2614 |
+
struct whisper_context * whisper_init_from_file(const char * path_model) {
|
| 2615 |
+
whisper_context * ctx = whisper_init_from_file_no_state(path_model);
|
| 2616 |
+
if (!ctx) {
|
| 2617 |
+
return nullptr;
|
| 2618 |
+
}
|
| 2619 |
+
|
| 2620 |
+
ctx->state = whisper_init_state(ctx);
|
| 2621 |
+
if (!ctx->state) {
|
| 2622 |
+
whisper_free(ctx);
|
| 2623 |
+
return nullptr;
|
| 2624 |
+
}
|
| 2625 |
+
|
| 2626 |
+
return ctx;
|
| 2627 |
+
}
|
| 2628 |
+
|
| 2629 |
+
struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
|
| 2630 |
+
whisper_context * ctx = whisper_init_from_buffer_no_state(buffer, buffer_size);
|
| 2631 |
+
if (!ctx) {
|
| 2632 |
+
return nullptr;
|
| 2633 |
+
}
|
| 2634 |
+
|
| 2635 |
+
ctx->state = whisper_init_state(ctx);
|
| 2636 |
+
if (!ctx->state) {
|
| 2637 |
+
whisper_free(ctx);
|
| 2638 |
+
return nullptr;
|
| 2639 |
+
}
|
| 2640 |
+
|
| 2641 |
+
return ctx;
|
| 2642 |
+
}
|
| 2643 |
+
|
| 2644 |
+
struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
|
| 2645 |
+
whisper_context * ctx = whisper_init_no_state(loader);
|
| 2646 |
+
if (!ctx) {
|
| 2647 |
+
return nullptr;
|
| 2648 |
+
}
|
| 2649 |
+
|
| 2650 |
+
ctx->state = whisper_init_state(ctx);
|
| 2651 |
+
if (!ctx->state) {
|
| 2652 |
+
whisper_free(ctx);
|
| 2653 |
+
return nullptr;
|
| 2654 |
+
}
|
| 2655 |
+
|
| 2656 |
+
return ctx;
|
| 2657 |
+
}
|
| 2658 |
+
|
| 2659 |
+
void whisper_free_state(struct whisper_state * state)
|
| 2660 |
+
{
|
| 2661 |
+
if (state) {
|
| 2662 |
+
kv_cache_free(state->kv_cross);
|
| 2663 |
+
|
| 2664 |
+
for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
|
| 2665 |
+
kv_cache_free(state->decoders[i].kv_self);
|
| 2666 |
+
}
|
| 2667 |
+
|
| 2668 |
+
delete state;
|
| 2669 |
+
}
|
| 2670 |
+
}
|
| 2671 |
+
|
| 2672 |
void whisper_free(struct whisper_context * ctx) {
|
| 2673 |
if (ctx) {
|
| 2674 |
if (ctx->model.ctx) {
|
|
|
|
| 2677 |
if (ctx->model.buf) {
|
| 2678 |
delete ctx->model.buf;
|
| 2679 |
}
|
| 2680 |
+
|
| 2681 |
+
whisper_free_state(ctx->state);
|
| 2682 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2683 |
delete ctx;
|
| 2684 |
}
|
| 2685 |
}
|
| 2686 |
|
| 2687 |
+
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
| 2688 |
+
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
|
| 2689 |
+
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
|
| 2690 |
+
return -1;
|
| 2691 |
+
}
|
| 2692 |
+
|
| 2693 |
+
return 0;
|
| 2694 |
+
}
|
| 2695 |
+
|
| 2696 |
int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
|
| 2697 |
+
return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
|
| 2698 |
+
}
|
| 2699 |
+
|
| 2700 |
+
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
|
| 2701 |
+
int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
| 2702 |
+
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, state->mel)) {
|
| 2703 |
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
|
| 2704 |
return -1;
|
| 2705 |
}
|
|
|
|
| 2709 |
|
| 2710 |
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
|
| 2711 |
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
|
| 2712 |
+
return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
|
| 2713 |
+
}
|
| 2714 |
+
|
| 2715 |
+
int whisper_set_mel_with_state(
|
| 2716 |
+
struct whisper_context * /*ctx*/,
|
| 2717 |
+
struct whisper_state * state,
|
| 2718 |
+
const float * data,
|
| 2719 |
+
int n_len,
|
| 2720 |
+
int n_mel) {
|
| 2721 |
+
if (n_mel != WHISPER_N_MEL) {
|
| 2722 |
+
fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
|
| 2723 |
return -1;
|
| 2724 |
}
|
| 2725 |
|
| 2726 |
+
state->mel.n_len = n_len;
|
| 2727 |
+
state->mel.n_mel = n_mel;
|
| 2728 |
+
|
| 2729 |
+
state->mel.data.resize(n_len*n_mel);
|
| 2730 |
+
memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
|
| 2731 |
+
|
| 2732 |
return 0;
|
| 2733 |
}
|
| 2734 |
|
|
|
|
| 2737 |
const float * data,
|
| 2738 |
int n_len,
|
| 2739 |
int n_mel) {
|
| 2740 |
+
return whisper_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel);
|
| 2741 |
+
}
|
| 2742 |
+
|
| 2743 |
+
int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
|
| 2744 |
+
if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
|
| 2745 |
+
fprintf(stderr, "%s: failed to eval\n", __func__);
|
| 2746 |
return -1;
|
| 2747 |
}
|
| 2748 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2749 |
return 0;
|
| 2750 |
}
|
| 2751 |
|
| 2752 |
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
| 2753 |
+
if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
|
| 2754 |
fprintf(stderr, "%s: failed to eval\n", __func__);
|
| 2755 |
return -1;
|
| 2756 |
}
|
|
|
|
| 2758 |
return 0;
|
| 2759 |
}
|
| 2760 |
|
| 2761 |
+
int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
|
| 2762 |
+
const int selected_decoder_id = 0;
|
| 2763 |
+
|
| 2764 |
+
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
|
| 2765 |
+
fprintf(stderr, "%s: failed to eval\n", __func__);
|
| 2766 |
+
return 1;
|
| 2767 |
+
}
|
| 2768 |
+
|
| 2769 |
+
return 0;
|
| 2770 |
+
}
|
| 2771 |
+
|
| 2772 |
int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
|
| 2773 |
+
// TODO: add selected_decoder_id to state
|
| 2774 |
const int selected_decoder_id = 0;
|
| 2775 |
|
| 2776 |
+
if (ctx->state == nullptr) {
|
| 2777 |
+
fprintf(stderr, "%s: ERROR state was not loaded.\n", __func__);
|
| 2778 |
+
return false;
|
| 2779 |
+
}
|
| 2780 |
+
|
| 2781 |
+
|
| 2782 |
+
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
|
| 2783 |
fprintf(stderr, "%s: failed to eval\n", __func__);
|
| 2784 |
return 1;
|
| 2785 |
}
|
|
|
|
| 2837 |
return nullptr;
|
| 2838 |
}
|
| 2839 |
|
| 2840 |
+
int whisper_lang_auto_detect_with_state(
|
| 2841 |
struct whisper_context * ctx,
|
| 2842 |
+
struct whisper_state * state,
|
| 2843 |
+
int offset_ms,
|
| 2844 |
+
int n_threads,
|
| 2845 |
+
float * lang_probs) {
|
| 2846 |
const int seek = offset_ms/10;
|
| 2847 |
|
| 2848 |
if (seek < 0) {
|
|
|
|
| 2850 |
return -1;
|
| 2851 |
}
|
| 2852 |
|
| 2853 |
+
if (seek >= state->mel.n_len) {
|
| 2854 |
+
fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len*10);
|
| 2855 |
return -2;
|
| 2856 |
}
|
| 2857 |
|
|
|
|
| 2863 |
|
| 2864 |
const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
|
| 2865 |
|
| 2866 |
+
if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
|
| 2867 |
fprintf(stderr, "%s: failed to decode\n", __func__);
|
| 2868 |
return -7;
|
| 2869 |
}
|
| 2870 |
|
| 2871 |
+
auto & logits_id = state->logits_id;
|
| 2872 |
logits_id.clear();
|
| 2873 |
|
| 2874 |
for (const auto & kv : g_lang) {
|
| 2875 |
const auto token_lang = whisper_token_lang(ctx, kv.second.first);
|
| 2876 |
+
logits_id.emplace_back(state->logits[token_lang], kv.second.first);
|
| 2877 |
}
|
| 2878 |
|
| 2879 |
// sort descending
|
|
|
|
| 2912 |
return logits_id[0].second;
|
| 2913 |
}
|
| 2914 |
|
| 2915 |
+
int whisper_lang_auto_detect(
|
| 2916 |
+
struct whisper_context * ctx,
|
| 2917 |
+
int offset_ms,
|
| 2918 |
+
int n_threads,
|
| 2919 |
+
float * lang_probs) {
|
| 2920 |
+
return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs);
|
| 2921 |
+
}
|
| 2922 |
+
|
| 2923 |
+
int whisper_n_len_from_state(struct whisper_state * state) {
|
| 2924 |
+
return state->mel.n_len;
|
| 2925 |
+
}
|
| 2926 |
+
|
| 2927 |
int whisper_n_len(struct whisper_context * ctx) {
|
| 2928 |
+
return ctx->state->mel.n_len;
|
| 2929 |
}
|
| 2930 |
|
| 2931 |
int whisper_n_vocab(struct whisper_context * ctx) {
|
|
|
|
| 2945 |
}
|
| 2946 |
|
| 2947 |
float * whisper_get_logits(struct whisper_context * ctx) {
|
| 2948 |
+
return ctx->state->logits.data();
|
| 2949 |
+
}
|
| 2950 |
+
|
| 2951 |
+
|
| 2952 |
+
float * whisper_get_logits_from_state(struct whisper_state * state) {
|
| 2953 |
+
return state->logits.data();
|
| 2954 |
}
|
| 2955 |
|
| 2956 |
const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
|
|
|
|
| 2996 |
void whisper_print_timings(struct whisper_context * ctx) {
|
| 2997 |
const int64_t t_end_us = ggml_time_us();
|
| 2998 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2999 |
fprintf(stderr, "\n");
|
| 3000 |
+
fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
|
| 3001 |
+
if (ctx->state != nullptr) {
|
| 3002 |
+
|
| 3003 |
+
const int32_t n_sample = std::max(1, ctx->state->n_sample);
|
| 3004 |
+
const int32_t n_encode = std::max(1, ctx->state->n_encode);
|
| 3005 |
+
const int32_t n_decode = std::max(1, ctx->state->n_decode);
|
| 3006 |
+
|
| 3007 |
+
fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
|
| 3008 |
+
fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
|
| 3009 |
+
fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
|
| 3010 |
+
fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
|
| 3011 |
+
fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
|
| 3012 |
+
}
|
| 3013 |
fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
|
| 3014 |
}
|
| 3015 |
|
| 3016 |
void whisper_reset_timings(struct whisper_context * ctx) {
|
| 3017 |
+
if (ctx->state != nullptr) {
|
| 3018 |
+
ctx->state->t_sample_us = 0;
|
| 3019 |
+
ctx->state->t_encode_us = 0;
|
| 3020 |
+
ctx->state->t_decode_us = 0;
|
| 3021 |
+
}
|
| 3022 |
}
|
| 3023 |
|
| 3024 |
const char * whisper_print_system_info(void) {
|
|
|
|
| 3131 |
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
|
| 3132 |
static void whisper_exp_compute_token_level_timestamps(
|
| 3133 |
struct whisper_context & ctx,
|
| 3134 |
+
struct whisper_state & state,
|
| 3135 |
int i_segment,
|
| 3136 |
float thold_pt,
|
| 3137 |
float thold_ptsum);
|
|
|
|
| 3164 |
|
| 3165 |
// wrap the last segment to max_len characters
|
| 3166 |
// returns the number of new segments
|
| 3167 |
+
static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) {
|
| 3168 |
+
auto segment = state.result_all.back();
|
| 3169 |
|
| 3170 |
int res = 1;
|
| 3171 |
int acc = 0;
|
|
|
|
| 3187 |
trim(text);
|
| 3188 |
}
|
| 3189 |
|
| 3190 |
+
state.result_all.back().text = std::move(text);
|
| 3191 |
+
state.result_all.back().t1 = token.t0;
|
| 3192 |
+
state.result_all.back().tokens.resize(i);
|
| 3193 |
|
| 3194 |
+
state.result_all.push_back({});
|
| 3195 |
+
state.result_all.back().t0 = token.t0;
|
| 3196 |
+
state.result_all.back().t1 = segment.t1;
|
| 3197 |
|
| 3198 |
// add tokens [i, end] to the new segment
|
| 3199 |
+
state.result_all.back().tokens.insert(
|
| 3200 |
+
state.result_all.back().tokens.end(),
|
| 3201 |
segment.tokens.begin() + i,
|
| 3202 |
segment.tokens.end());
|
| 3203 |
|
| 3204 |
acc = 0;
|
| 3205 |
text = "";
|
| 3206 |
|
| 3207 |
+
segment = state.result_all.back();
|
| 3208 |
i = -1;
|
| 3209 |
|
| 3210 |
res++;
|
|
|
|
| 3217 |
if (split_on_word) {
|
| 3218 |
trim(text);
|
| 3219 |
}
|
| 3220 |
+
state.result_all.back().text = std::move(text);
|
| 3221 |
|
| 3222 |
return res;
|
| 3223 |
}
|
|
|
|
| 3234 |
// - computes logprobs and probs
|
| 3235 |
static void whisper_process_logits(
|
| 3236 |
struct whisper_context & ctx,
|
| 3237 |
+
struct whisper_state & state,
|
| 3238 |
const struct whisper_full_params params,
|
| 3239 |
struct whisper_decoder & decoder,
|
| 3240 |
float temperature) {
|
|
|
|
| 3253 |
auto & logprobs = decoder.logprobs;
|
| 3254 |
{
|
| 3255 |
logits.resize(n_logits);
|
| 3256 |
+
memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float));
|
| 3257 |
|
| 3258 |
if (temperature > 0.0f) {
|
| 3259 |
for (int i = 0; i < n_logits; i++) {
|
|
|
|
| 3291 |
logits[vocab.token_transcribe] = -INFINITY;
|
| 3292 |
|
| 3293 |
if (params.logits_filter_callback) {
|
| 3294 |
+
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
|
| 3295 |
}
|
| 3296 |
|
| 3297 |
// suppress non-speech tokens
|
|
|
|
| 3452 |
|
| 3453 |
static whisper_token_data whisper_sample_token(
|
| 3454 |
whisper_context & ctx,
|
| 3455 |
+
whisper_state & state,
|
| 3456 |
const whisper_decoder & decoder,
|
| 3457 |
bool best) {
|
| 3458 |
whisper_token_data result = {
|
|
|
|
| 3497 |
} else {
|
| 3498 |
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
| 3499 |
|
| 3500 |
+
result.id = dist(state.rng);
|
| 3501 |
result.p = probs[result.id];
|
| 3502 |
result.plog = logprobs[result.id];
|
| 3503 |
}
|
|
|
|
| 3507 |
result.pt = result.p;
|
| 3508 |
}
|
| 3509 |
|
| 3510 |
+
state.n_sample++;
|
| 3511 |
|
| 3512 |
return result;
|
| 3513 |
}
|
| 3514 |
|
| 3515 |
static std::vector<whisper_token_data> whisper_sample_token_topk(
|
| 3516 |
whisper_context & ctx,
|
| 3517 |
+
whisper_state & state,
|
| 3518 |
const whisper_decoder & decoder,
|
| 3519 |
int k) {
|
| 3520 |
const auto & vocab = ctx.vocab;
|
|
|
|
| 3525 |
|
| 3526 |
const int n_logits = vocab.n_vocab;
|
| 3527 |
|
| 3528 |
+
auto & logits_id = state.logits_id;
|
| 3529 |
|
| 3530 |
logits_id.clear();
|
| 3531 |
for (int i = 0; i < n_logits; ++i) {
|
|
|
|
| 3578 |
}
|
| 3579 |
}
|
| 3580 |
|
| 3581 |
+
state.n_sample++;
|
| 3582 |
|
| 3583 |
return result;
|
| 3584 |
}
|
|
|
|
| 3632 |
}
|
| 3633 |
}
|
| 3634 |
|
| 3635 |
+
int whisper_full_with_state(
|
| 3636 |
struct whisper_context * ctx,
|
| 3637 |
+
struct whisper_state * state,
|
| 3638 |
+
struct whisper_full_params params,
|
| 3639 |
+
const float * samples,
|
| 3640 |
+
int n_samples) {
|
| 3641 |
// clear old results
|
| 3642 |
+
auto & result_all = state->result_all;
|
| 3643 |
|
| 3644 |
result_all.clear();
|
| 3645 |
|
| 3646 |
// compute log mel spectrogram
|
| 3647 |
if (params.speed_up) {
|
| 3648 |
+
if (whisper_pcm_to_mel_phase_vocoder_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
|
| 3649 |
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
|
| 3650 |
return -1;
|
| 3651 |
}
|
| 3652 |
} else {
|
| 3653 |
+
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
|
| 3654 |
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
|
| 3655 |
return -2;
|
| 3656 |
}
|
|
|
|
| 3660 |
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
|
| 3661 |
std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
|
| 3662 |
|
| 3663 |
+
const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
|
| 3664 |
if (lang_id < 0) {
|
| 3665 |
fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
|
| 3666 |
return -3;
|
| 3667 |
}
|
| 3668 |
+
state->lang_id = lang_id;
|
| 3669 |
params.language = whisper_lang_str(lang_id);
|
| 3670 |
|
| 3671 |
fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
|
| 3672 |
}
|
| 3673 |
|
| 3674 |
if (params.token_timestamps) {
|
| 3675 |
+
state->t_beg = 0;
|
| 3676 |
+
state->t_last = 0;
|
| 3677 |
+
state->tid_last = 0;
|
| 3678 |
+
state->energy = get_signal_energy(samples, n_samples, 32);
|
| 3679 |
}
|
| 3680 |
|
| 3681 |
const int seek_start = params.offset_ms/10;
|
| 3682 |
+
const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len_from_state(state) : params.duration_ms/10);
|
| 3683 |
|
| 3684 |
// if length of spectrogram is less than 1s (100 samples), then return
|
| 3685 |
// basically don't process anything that is less than 1s
|
|
|
|
| 3717 |
|
| 3718 |
// TAGS: WHISPER_DECODER_INIT
|
| 3719 |
for (int j = 1; j < n_decoders; j++) {
|
| 3720 |
+
auto & decoder = state->decoders[j];
|
| 3721 |
|
| 3722 |
if (decoder.kv_self.ctx == nullptr) {
|
| 3723 |
+
decoder.kv_self = state->decoders[0].kv_self;
|
| 3724 |
if (!kv_cache_reinit(decoder.kv_self)) {
|
| 3725 |
fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
|
| 3726 |
return -4;
|
|
|
|
| 3728 |
|
| 3729 |
WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
|
| 3730 |
|
| 3731 |
+
decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
|
| 3732 |
|
| 3733 |
decoder.probs.resize (ctx->vocab.n_vocab);
|
| 3734 |
decoder.logits.resize (ctx->vocab.n_vocab);
|
|
|
|
| 3737 |
}
|
| 3738 |
|
| 3739 |
// the accumulated text context so far
|
| 3740 |
+
auto & prompt_past = state->prompt_past;
|
| 3741 |
if (params.no_context) {
|
| 3742 |
prompt_past.clear();
|
| 3743 |
}
|
|
|
|
| 3756 |
fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
|
| 3757 |
return -5;
|
| 3758 |
}
|
| 3759 |
+
state->exp_n_audio_ctx = params.audio_ctx;
|
| 3760 |
|
| 3761 |
// these tokens determine the task that will be performed
|
| 3762 |
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
|
| 3763 |
if (whisper_is_multilingual(ctx)) {
|
| 3764 |
const int lang_id = whisper_lang_id(params.language);
|
| 3765 |
+
state->lang_id = lang_id;
|
| 3766 |
prompt_init.push_back(whisper_token_lang(ctx, lang_id));
|
| 3767 |
if (params.translate) {
|
| 3768 |
prompt_init.push_back(whisper_token_translate());
|
|
|
|
| 3814 |
}
|
| 3815 |
|
| 3816 |
if (params.encoder_begin_callback) {
|
| 3817 |
+
if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
|
| 3818 |
fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
|
| 3819 |
break;
|
| 3820 |
}
|
| 3821 |
}
|
| 3822 |
|
| 3823 |
// encode audio features starting at offset seek
|
| 3824 |
+
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
|
| 3825 |
fprintf(stderr, "%s: failed to encode\n", __func__);
|
| 3826 |
return -6;
|
| 3827 |
}
|
|
|
|
| 3862 |
|
| 3863 |
// TAGS: WHISPER_DECODER_INIT
|
| 3864 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 3865 |
+
auto & decoder = state->decoders[j];
|
| 3866 |
|
| 3867 |
decoder.kv_self.n = 0;
|
| 3868 |
|
|
|
|
| 3904 |
}
|
| 3905 |
WHISPER_PRINT_DEBUG("\n\n");
|
| 3906 |
|
| 3907 |
+
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
|
| 3908 |
fprintf(stderr, "%s: failed to decode\n", __func__);
|
| 3909 |
return -7;
|
| 3910 |
}
|
|
|
|
| 3912 |
{
|
| 3913 |
const int64_t t_start_sample_us = ggml_time_us();
|
| 3914 |
|
| 3915 |
+
whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
|
| 3916 |
|
| 3917 |
+
state->decoders[0].kv_self.n += prompt.size();
|
| 3918 |
|
| 3919 |
for (int j = 1; j < n_decoders_cur; ++j) {
|
| 3920 |
+
auto & decoder = state->decoders[j];
|
| 3921 |
|
| 3922 |
+
memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
|
| 3923 |
+
memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
|
| 3924 |
|
| 3925 |
decoder.kv_self.n += prompt.size();
|
| 3926 |
|
| 3927 |
+
memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
|
| 3928 |
+
memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
|
| 3929 |
+
memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
|
| 3930 |
}
|
| 3931 |
|
| 3932 |
+
state->t_sample_us += ggml_time_us() - t_start_sample_us;
|
| 3933 |
}
|
| 3934 |
}
|
| 3935 |
|
|
|
|
| 3940 |
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
|
| 3941 |
kv_bufs.resize(n_decoders_cur);
|
| 3942 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 3943 |
+
auto & decoder = state->decoders[j];
|
| 3944 |
|
| 3945 |
if (decoder.completed || decoder.failed) {
|
| 3946 |
continue;
|
|
|
|
| 3958 |
|
| 3959 |
// generate new sequence candidates for each decoder
|
| 3960 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 3961 |
+
auto & decoder = state->decoders[j];
|
| 3962 |
|
| 3963 |
if (decoder.completed || decoder.failed) {
|
| 3964 |
continue;
|
|
|
|
| 3968 |
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
|
| 3969 |
{
|
| 3970 |
if (t_cur < 1e-6f) {
|
| 3971 |
+
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
|
| 3972 |
} else {
|
| 3973 |
+
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
|
| 3974 |
}
|
| 3975 |
|
| 3976 |
decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
|
| 3977 |
} break;
|
| 3978 |
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
|
| 3979 |
{
|
| 3980 |
+
const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
|
| 3981 |
|
| 3982 |
for (const auto & token : tokens_new) {
|
| 3983 |
beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
|
|
|
|
| 4002 |
uint32_t cur_c = 0;
|
| 4003 |
|
| 4004 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 4005 |
+
auto & decoder = state->decoders[j];
|
| 4006 |
|
| 4007 |
if (decoder.completed || decoder.failed) {
|
| 4008 |
continue;
|
|
|
|
| 4031 |
// - check if the sequence is failed
|
| 4032 |
// - update sliding window based on timestamp tokens
|
| 4033 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 4034 |
+
auto & decoder = state->decoders[j];
|
| 4035 |
|
| 4036 |
if (decoder.completed || decoder.failed) {
|
| 4037 |
continue;
|
|
|
|
| 4113 |
bool completed_all = true;
|
| 4114 |
|
| 4115 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 4116 |
+
auto & decoder = state->decoders[j];
|
| 4117 |
|
| 4118 |
if (decoder.completed || decoder.failed) {
|
| 4119 |
continue;
|
|
|
|
| 4127 |
}
|
| 4128 |
}
|
| 4129 |
|
| 4130 |
+
state->t_sample_us += ggml_time_us() - t_start_sample_us;
|
| 4131 |
|
| 4132 |
// obtain logits for the next token
|
| 4133 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 4134 |
+
auto & decoder = state->decoders[j];
|
| 4135 |
|
| 4136 |
if (decoder.failed || decoder.completed) {
|
| 4137 |
continue;
|
|
|
|
| 4142 |
|
| 4143 |
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
|
| 4144 |
|
| 4145 |
+
if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
|
| 4146 |
fprintf(stderr, "%s: failed to decode\n", __func__);
|
| 4147 |
return -8;
|
| 4148 |
}
|
|
|
|
| 4150 |
{
|
| 4151 |
const int64_t t_start_sample_us = ggml_time_us();
|
| 4152 |
|
| 4153 |
+
whisper_process_logits(*ctx, *state, params, decoder, t_cur);
|
| 4154 |
|
| 4155 |
++decoder.kv_self.n;
|
| 4156 |
|
| 4157 |
+
state->t_sample_us += ggml_time_us() - t_start_sample_us;
|
| 4158 |
}
|
| 4159 |
}
|
| 4160 |
}
|
|
|
|
| 4164 |
double best_score = -INFINITY;
|
| 4165 |
|
| 4166 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 4167 |
+
auto & decoder = state->decoders[j];
|
| 4168 |
|
| 4169 |
if (decoder.failed) {
|
| 4170 |
continue;
|
|
|
|
| 4181 |
__func__, j, decoder.sequence.entropy, params.entropy_thold);
|
| 4182 |
|
| 4183 |
decoder.failed = true;
|
| 4184 |
+
state->n_fail_h++;
|
| 4185 |
|
| 4186 |
continue;
|
| 4187 |
}
|
|
|
|
| 4199 |
{
|
| 4200 |
bool success = true;
|
| 4201 |
|
| 4202 |
+
const auto & decoder = state->decoders[best_decoder_id];
|
| 4203 |
|
| 4204 |
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
|
| 4205 |
success = false;
|
| 4206 |
+
state->n_fail_p++;
|
| 4207 |
}
|
| 4208 |
|
| 4209 |
if (success) {
|
|
|
|
| 4220 |
|
| 4221 |
// output results through a user-provided callback
|
| 4222 |
{
|
| 4223 |
+
const auto & best_decoder = state->decoders[best_decoder_id];
|
| 4224 |
|
| 4225 |
const auto seek_delta = best_decoder.seek_delta;
|
| 4226 |
const auto result_len = best_decoder.sequence.result_len;
|
|
|
|
| 4283 |
|
| 4284 |
if (params.token_timestamps) {
|
| 4285 |
whisper_exp_compute_token_level_timestamps(
|
| 4286 |
+
*ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
| 4287 |
|
| 4288 |
if (params.max_len > 0) {
|
| 4289 |
+
n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
|
| 4290 |
}
|
| 4291 |
}
|
| 4292 |
if (params.new_segment_callback) {
|
| 4293 |
+
params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
|
| 4294 |
}
|
| 4295 |
}
|
| 4296 |
text = "";
|
|
|
|
| 4327 |
|
| 4328 |
if (params.token_timestamps) {
|
| 4329 |
whisper_exp_compute_token_level_timestamps(
|
| 4330 |
+
*ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
| 4331 |
|
| 4332 |
if (params.max_len > 0) {
|
| 4333 |
+
n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
|
| 4334 |
}
|
| 4335 |
}
|
| 4336 |
if (params.new_segment_callback) {
|
| 4337 |
+
params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
|
| 4338 |
}
|
| 4339 |
}
|
| 4340 |
}
|
|
|
|
| 4349 |
return 0;
|
| 4350 |
}
|
| 4351 |
|
| 4352 |
+
|
| 4353 |
+
int whisper_full(
|
| 4354 |
+
struct whisper_context * ctx,
|
| 4355 |
+
struct whisper_full_params params,
|
| 4356 |
+
const float * samples,
|
| 4357 |
+
int n_samples) {
|
| 4358 |
+
return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples);
|
| 4359 |
+
}
|
| 4360 |
+
|
| 4361 |
int whisper_full_parallel(
|
| 4362 |
struct whisper_context * ctx,
|
| 4363 |
struct whisper_full_params params,
|
|
|
|
| 4367 |
if (n_processors == 1) {
|
| 4368 |
return whisper_full(ctx, params, samples, n_samples);
|
| 4369 |
}
|
|
|
|
| 4370 |
int ret = 0;
|
| 4371 |
|
| 4372 |
+
// prepare separate states for each thread
|
| 4373 |
+
std::vector<whisper_state*> states;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4374 |
|
| 4375 |
const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
|
| 4376 |
const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
|
|
|
|
| 4380 |
|
| 4381 |
std::vector<std::thread> workers(n_processors - 1);
|
| 4382 |
for (int i = 0; i < n_processors - 1; ++i) {
|
| 4383 |
+
// create a new state for each thread
|
| 4384 |
+
states.push_back(whisper_init_state(ctx));
|
| 4385 |
+
|
| 4386 |
const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
|
| 4387 |
const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
|
| 4388 |
|
|
|
|
| 4395 |
params_cur.new_segment_callback = nullptr;
|
| 4396 |
params_cur.new_segment_callback_user_data = nullptr;
|
| 4397 |
|
| 4398 |
+
workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur);
|
| 4399 |
}
|
| 4400 |
|
| 4401 |
{
|
| 4402 |
auto params_cur = params;
|
| 4403 |
|
| 4404 |
+
// We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk.
|
| 4405 |
+
params_cur.print_realtime = false;
|
| 4406 |
+
|
| 4407 |
+
// Run the first transformation using default state but only for the first chunk.
|
| 4408 |
+
ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
|
| 4409 |
}
|
| 4410 |
|
| 4411 |
for (int i = 0; i < n_processors - 1; ++i) {
|
|
|
|
| 4414 |
|
| 4415 |
const int64_t offset_t = (int64_t) params.offset_ms/10.0;
|
| 4416 |
|
| 4417 |
+
// combine results into result_state->result_all from all other states
|
| 4418 |
for (int i = 0; i < n_processors - 1; ++i) {
|
| 4419 |
+
auto& results_i = states[i]->result_all;
|
| 4420 |
|
| 4421 |
+
for (auto& result : results_i) {
|
| 4422 |
// correct the segment timestamp taking into account the offset
|
| 4423 |
+
result.t0 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
|
| 4424 |
+
result.t1 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
|
| 4425 |
+
|
| 4426 |
|
| 4427 |
// make sure that segments are not overlapping
|
| 4428 |
+
if (!ctx->state->result_all.empty()) {
|
| 4429 |
+
result.t0 = std::max(result.t0, ctx->state->result_all.back().t1);
|
| 4430 |
}
|
| 4431 |
|
| 4432 |
+
ctx->state->result_all.push_back(std::move(result));
|
| 4433 |
|
| 4434 |
// call the new_segment_callback for each segment
|
| 4435 |
if (params.new_segment_callback) {
|
| 4436 |
+
params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data);
|
| 4437 |
}
|
| 4438 |
}
|
| 4439 |
|
| 4440 |
+
ctx->state->t_mel_us += states[i]->t_mel_us;
|
|
|
|
|
|
|
|
|
|
| 4441 |
|
| 4442 |
+
ctx->state->t_sample_us += states[i]->t_sample_us;
|
| 4443 |
+
ctx->state->t_encode_us += states[i]->t_encode_us;
|
| 4444 |
+
ctx->state->t_decode_us += states[i]->t_decode_us;
|
| 4445 |
|
| 4446 |
+
whisper_free_state(states[i]);
|
|
|
|
|
|
|
| 4447 |
}
|
| 4448 |
|
| 4449 |
// average the timings
|
| 4450 |
+
ctx->state->t_mel_us /= n_processors;
|
| 4451 |
+
ctx->state->t_sample_us /= n_processors;
|
| 4452 |
+
ctx->state->t_encode_us /= n_processors;
|
| 4453 |
+
ctx->state->t_decode_us /= n_processors;
|
| 4454 |
|
| 4455 |
// print information about the audio boundaries
|
| 4456 |
fprintf(stderr, "\n");
|
|
|
|
| 4463 |
return ret;
|
| 4464 |
}
|
| 4465 |
|
| 4466 |
+
int whisper_full_n_segments_from_state(struct whisper_state * state) {
|
| 4467 |
+
return state->result_all.size();
|
| 4468 |
+
}
|
| 4469 |
+
|
| 4470 |
int whisper_full_n_segments(struct whisper_context * ctx) {
|
| 4471 |
+
return ctx->state->result_all.size();
|
| 4472 |
+
}
|
| 4473 |
+
|
| 4474 |
+
int whisper_full_lang_id_from_state(struct whisper_state * state) {
|
| 4475 |
+
return state->lang_id;
|
| 4476 |
}
|
| 4477 |
|
| 4478 |
int whisper_full_lang_id(struct whisper_context * ctx) {
|
| 4479 |
+
return ctx->state->lang_id;
|
| 4480 |
+
}
|
| 4481 |
+
|
| 4482 |
+
int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
|
| 4483 |
+
return state->result_all[i_segment].t0;
|
| 4484 |
}
|
| 4485 |
|
| 4486 |
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
|
| 4487 |
+
return ctx->state->result_all[i_segment].t0;
|
| 4488 |
+
}
|
| 4489 |
+
|
| 4490 |
+
int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
|
| 4491 |
+
return state->result_all[i_segment].t1;
|
| 4492 |
}
|
| 4493 |
|
| 4494 |
int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
|
| 4495 |
+
return ctx->state->result_all[i_segment].t1;
|
| 4496 |
+
}
|
| 4497 |
+
|
| 4498 |
+
const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) {
|
| 4499 |
+
return state->result_all[i_segment].text.c_str();
|
| 4500 |
}
|
| 4501 |
|
| 4502 |
const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
|
| 4503 |
+
return ctx->state->result_all[i_segment].text.c_str();
|
| 4504 |
+
}
|
| 4505 |
+
|
| 4506 |
+
int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment) {
|
| 4507 |
+
return state->result_all[i_segment].tokens.size();
|
| 4508 |
}
|
| 4509 |
|
| 4510 |
int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
|
| 4511 |
+
return ctx->state->result_all[i_segment].tokens.size();
|
| 4512 |
}
|
| 4513 |
|
| 4514 |
+
const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token) {
|
| 4515 |
+
return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str();
|
| 4516 |
+
}
|
| 4517 |
+
|
| 4518 |
+
const char* whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
|
| 4519 |
+
return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str();
|
| 4520 |
+
}
|
| 4521 |
+
|
| 4522 |
+
whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token) {
|
| 4523 |
+
return state->result_all[i_segment].tokens[i_token].id;
|
| 4524 |
}
|
| 4525 |
|
| 4526 |
whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
|
| 4527 |
+
return ctx->state->result_all[i_segment].tokens[i_token].id;
|
| 4528 |
+
}
|
| 4529 |
+
|
| 4530 |
+
struct whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token) {
|
| 4531 |
+
return state->result_all[i_segment].tokens[i_token];
|
| 4532 |
}
|
| 4533 |
|
| 4534 |
struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
|
| 4535 |
+
return ctx->state->result_all[i_segment].tokens[i_token];
|
| 4536 |
+
}
|
| 4537 |
+
|
| 4538 |
+
float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token) {
|
| 4539 |
+
return state->result_all[i_segment].tokens[i_token].p;
|
| 4540 |
}
|
| 4541 |
|
| 4542 |
float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
|
| 4543 |
+
return ctx->state->result_all[i_segment].tokens[i_token].p;
|
| 4544 |
}
|
| 4545 |
|
| 4546 |
// =================================================================================================
|
|
|
|
| 4752 |
|
| 4753 |
static void whisper_exp_compute_token_level_timestamps(
|
| 4754 |
struct whisper_context & ctx,
|
| 4755 |
+
struct whisper_state & state,
|
| 4756 |
int i_segment,
|
| 4757 |
float thold_pt,
|
| 4758 |
float thold_ptsum) {
|
| 4759 |
+
auto & segment = state.result_all[i_segment];
|
| 4760 |
auto & tokens = segment.tokens;
|
| 4761 |
|
| 4762 |
+
const int n_samples = state.energy.size();
|
| 4763 |
|
| 4764 |
if (n_samples == 0) {
|
| 4765 |
fprintf(stderr, "%s: no signal data available\n", __func__);
|
|
|
|
| 4782 |
return;
|
| 4783 |
}
|
| 4784 |
|
| 4785 |
+
auto & t_beg = state.t_beg;
|
| 4786 |
+
auto & t_last = state.t_last;
|
| 4787 |
+
auto & tid_last = state.tid_last;
|
| 4788 |
|
| 4789 |
for (int j = 0; j < n; ++j) {
|
| 4790 |
auto & token = tokens[j];
|
|
|
|
| 4907 |
float sum = 0.0f;
|
| 4908 |
|
| 4909 |
for (int k = ss0; k < ss1; k++) {
|
| 4910 |
+
sum += state.energy[k];
|
| 4911 |
}
|
| 4912 |
|
| 4913 |
const float thold = 0.5*sum/ns;
|
| 4914 |
|
| 4915 |
{
|
| 4916 |
int k = s0;
|
| 4917 |
+
if (state.energy[k] > thold && j > 0) {
|
| 4918 |
+
while (k > 0 && state.energy[k] > thold) {
|
| 4919 |
k--;
|
| 4920 |
}
|
| 4921 |
tokens[j].t0 = sample_to_timestamp(k);
|
|
|
|
| 4925 |
s0 = k;
|
| 4926 |
}
|
| 4927 |
} else {
|
| 4928 |
+
while (state.energy[k] < thold && k < s1) {
|
| 4929 |
k++;
|
| 4930 |
}
|
| 4931 |
s0 = k;
|
|
|
|
| 4935 |
|
| 4936 |
{
|
| 4937 |
int k = s1;
|
| 4938 |
+
if (state.energy[k] > thold) {
|
| 4939 |
+
while (k < n_samples - 1 && state.energy[k] > thold) {
|
| 4940 |
k++;
|
| 4941 |
}
|
| 4942 |
tokens[j].t1 = sample_to_timestamp(k);
|
|
|
|
| 4946 |
s1 = k;
|
| 4947 |
}
|
| 4948 |
} else {
|
| 4949 |
+
while (state.energy[k] < thold && k > s0) {
|
| 4950 |
k--;
|
| 4951 |
}
|
| 4952 |
s1 = k;
|
whisper.h
CHANGED
|
@@ -66,6 +66,7 @@ extern "C" {
|
|
| 66 |
//
|
| 67 |
|
| 68 |
struct whisper_context;
|
|
|
|
| 69 |
|
| 70 |
typedef int whisper_token;
|
| 71 |
|
|
@@ -101,11 +102,20 @@ extern "C" {
|
|
| 101 |
WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
|
| 102 |
WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
|
| 103 |
|
| 104 |
-
//
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
// Convert RAW PCM audio to log mel spectrogram.
|
| 108 |
-
// The resulting spectrogram is stored inside the provided whisper context.
|
| 109 |
// Returns 0 on success
|
| 110 |
WHISPER_API int whisper_pcm_to_mel(
|
| 111 |
struct whisper_context * ctx,
|
|
@@ -113,17 +123,30 @@ extern "C" {
|
|
| 113 |
int n_samples,
|
| 114 |
int n_threads);
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
// Returns 0 on success
|
| 119 |
WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
|
| 120 |
-
struct whisper_context* ctx,
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
|
| 128 |
// n_mel must be 80
|
| 129 |
// Returns 0 on success
|
|
@@ -133,7 +156,14 @@ extern "C" {
|
|
| 133 |
int n_len,
|
| 134 |
int n_mel);
|
| 135 |
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
| 138 |
// offset can be used to specify the offset of the first frame in the spectrogram.
|
| 139 |
// Returns 0 on success
|
|
@@ -142,6 +172,12 @@ extern "C" {
|
|
| 142 |
int offset,
|
| 143 |
int n_threads);
|
| 144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
|
| 146 |
// Make sure to call whisper_encode() first.
|
| 147 |
// tokens + n_tokens is the provided context for the decoder.
|
|
@@ -155,6 +191,14 @@ extern "C" {
|
|
| 155 |
int n_past,
|
| 156 |
int n_threads);
|
| 157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
// Convert the provided text into tokens.
|
| 159 |
// The tokens pointer must be large enough to hold the resulting tokens.
|
| 160 |
// Returns the number of tokens on success, no more than n_max_tokens
|
|
@@ -190,17 +234,26 @@ extern "C" {
|
|
| 190 |
int n_threads,
|
| 191 |
float * lang_probs);
|
| 192 |
|
| 193 |
-
WHISPER_API int
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
// Token logits obtained from the last call to whisper_decode()
|
| 200 |
// The logits for the last token are stored in the last row
|
| 201 |
// Rows: n_tokens
|
| 202 |
// Cols: n_vocab
|
| 203 |
-
WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
|
|
|
|
| 204 |
|
| 205 |
// Token Id -> String. Uses the vocabulary in the provided context
|
| 206 |
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
|
|
@@ -218,7 +271,7 @@ extern "C" {
|
|
| 218 |
WHISPER_API whisper_token whisper_token_translate (void);
|
| 219 |
WHISPER_API whisper_token whisper_token_transcribe(void);
|
| 220 |
|
| 221 |
-
// Performance information
|
| 222 |
WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
|
| 223 |
WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
|
| 224 |
|
|
@@ -236,18 +289,19 @@ extern "C" {
|
|
| 236 |
// Text segment callback
|
| 237 |
// Called on every newly generated text segment
|
| 238 |
// Use the whisper_full_...() functions to obtain the text segments
|
| 239 |
-
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
|
| 240 |
|
| 241 |
// Encoder begin callback
|
| 242 |
// If not NULL, called before the encoder starts
|
| 243 |
// If it returns false, the computation is aborted
|
| 244 |
-
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
|
| 245 |
|
| 246 |
// Logits filter callback
|
| 247 |
// Can be used to modify the logits before sampling
|
| 248 |
// If not NULL, called after applying temperature to logits
|
| 249 |
typedef void (*whisper_logits_filter_callback)(
|
| 250 |
struct whisper_context * ctx,
|
|
|
|
| 251 |
const whisper_token_data * tokens,
|
| 252 |
int n_tokens,
|
| 253 |
float * logits,
|
|
@@ -334,6 +388,7 @@ extern "C" {
|
|
| 334 |
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
|
| 335 |
|
| 336 |
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
|
|
|
| 337 |
// Uses the specified decoding strategy to obtain the text.
|
| 338 |
WHISPER_API int whisper_full(
|
| 339 |
struct whisper_context * ctx,
|
|
@@ -341,7 +396,16 @@ extern "C" {
|
|
| 341 |
const float * samples,
|
| 342 |
int n_samples);
|
| 343 |
|
| 344 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
// It seems this approach can offer some speedup in some cases.
|
| 346 |
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
|
| 347 |
WHISPER_API int whisper_full_parallel(
|
|
@@ -351,33 +415,47 @@ extern "C" {
|
|
| 351 |
int n_samples,
|
| 352 |
int n_processors);
|
| 353 |
|
| 354 |
-
// Number of generated text segments
|
| 355 |
// A segment can be a few words, a sentence, or even a paragraph.
|
| 356 |
-
WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
|
|
|
|
| 357 |
|
| 358 |
-
// Language id associated with the
|
| 359 |
WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
|
| 360 |
|
| 361 |
-
//
|
| 362 |
-
WHISPER_API
|
| 363 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
|
| 365 |
-
// Get
|
| 366 |
-
WHISPER_API
|
|
|
|
| 367 |
|
| 368 |
-
// Get
|
| 369 |
-
WHISPER_API
|
|
|
|
| 370 |
|
| 371 |
-
|
| 372 |
-
WHISPER_API
|
| 373 |
-
WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
|
| 374 |
|
| 375 |
-
// Get token data for the specified token in the specified segment
|
| 376 |
// This contains probabilities, timestamps, etc.
|
| 377 |
-
WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
|
|
|
|
| 378 |
|
| 379 |
-
// Get the probability of the specified token in the specified segment
|
| 380 |
-
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
|
|
|
|
| 381 |
|
| 382 |
////////////////////////////////////////////////////////////////////////////
|
| 383 |
|
|
|
|
| 66 |
//
|
| 67 |
|
| 68 |
struct whisper_context;
|
| 69 |
+
struct whisper_state;
|
| 70 |
|
| 71 |
typedef int whisper_token;
|
| 72 |
|
|
|
|
| 102 |
WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
|
| 103 |
WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
|
| 104 |
|
| 105 |
+
// These are the same as the above, but the internal state of the context is not allocated automatically
|
| 106 |
+
// It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
|
| 107 |
+
WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model);
|
| 108 |
+
WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size);
|
| 109 |
+
WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader);
|
| 110 |
+
|
| 111 |
+
WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx);
|
| 112 |
+
|
| 113 |
+
// Frees all allocated memory
|
| 114 |
+
WHISPER_API void whisper_free (struct whisper_context * ctx);
|
| 115 |
+
WHISPER_API void whisper_free_state(struct whisper_state * state);
|
| 116 |
|
| 117 |
// Convert RAW PCM audio to log mel spectrogram.
|
| 118 |
+
// The resulting spectrogram is stored inside the default state of the provided whisper context.
|
| 119 |
// Returns 0 on success
|
| 120 |
WHISPER_API int whisper_pcm_to_mel(
|
| 121 |
struct whisper_context * ctx,
|
|
|
|
| 123 |
int n_samples,
|
| 124 |
int n_threads);
|
| 125 |
|
| 126 |
+
WHISPER_API int whisper_pcm_to_mel_with_state(
|
| 127 |
+
struct whisper_context * ctx,
|
| 128 |
+
struct whisper_state * state,
|
| 129 |
+
const float * samples,
|
| 130 |
+
int n_samples,
|
| 131 |
+
int n_threads);
|
| 132 |
+
|
| 133 |
+
// Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
|
| 134 |
+
// The resulting spectrogram is stored inside the default state of the provided whisper context.
|
| 135 |
// Returns 0 on success
|
| 136 |
WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
|
| 137 |
+
struct whisper_context * ctx,
|
| 138 |
+
const float * samples,
|
| 139 |
+
int n_samples,
|
| 140 |
+
int n_threads);
|
| 141 |
+
|
| 142 |
+
WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state(
|
| 143 |
+
struct whisper_context * ctx,
|
| 144 |
+
struct whisper_state * state,
|
| 145 |
+
const float * samples,
|
| 146 |
+
int n_samples,
|
| 147 |
+
int n_threads);
|
| 148 |
+
|
| 149 |
+
// This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context.
|
| 150 |
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
|
| 151 |
// n_mel must be 80
|
| 152 |
// Returns 0 on success
|
|
|
|
| 156 |
int n_len,
|
| 157 |
int n_mel);
|
| 158 |
|
| 159 |
+
WHISPER_API int whisper_set_mel_with_state(
|
| 160 |
+
struct whisper_context * ctx,
|
| 161 |
+
struct whisper_state * state,
|
| 162 |
+
const float * data,
|
| 163 |
+
int n_len,
|
| 164 |
+
int n_mel);
|
| 165 |
+
|
| 166 |
+
// Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context.
|
| 167 |
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
| 168 |
// offset can be used to specify the offset of the first frame in the spectrogram.
|
| 169 |
// Returns 0 on success
|
|
|
|
| 172 |
int offset,
|
| 173 |
int n_threads);
|
| 174 |
|
| 175 |
+
WHISPER_API int whisper_encode_with_state(
|
| 176 |
+
struct whisper_context * ctx,
|
| 177 |
+
struct whisper_state * state,
|
| 178 |
+
int offset,
|
| 179 |
+
int n_threads);
|
| 180 |
+
|
| 181 |
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
|
| 182 |
// Make sure to call whisper_encode() first.
|
| 183 |
// tokens + n_tokens is the provided context for the decoder.
|
|
|
|
| 191 |
int n_past,
|
| 192 |
int n_threads);
|
| 193 |
|
| 194 |
+
WHISPER_API int whisper_decode_with_state(
|
| 195 |
+
struct whisper_context * ctx,
|
| 196 |
+
struct whisper_state * state,
|
| 197 |
+
const whisper_token * tokens,
|
| 198 |
+
int n_tokens,
|
| 199 |
+
int n_past,
|
| 200 |
+
int n_threads);
|
| 201 |
+
|
| 202 |
// Convert the provided text into tokens.
|
| 203 |
// The tokens pointer must be large enough to hold the resulting tokens.
|
| 204 |
// Returns the number of tokens on success, no more than n_max_tokens
|
|
|
|
| 234 |
int n_threads,
|
| 235 |
float * lang_probs);
|
| 236 |
|
| 237 |
+
WHISPER_API int whisper_lang_auto_detect_with_state(
|
| 238 |
+
struct whisper_context * ctx,
|
| 239 |
+
struct whisper_state * state,
|
| 240 |
+
int offset_ms,
|
| 241 |
+
int n_threads,
|
| 242 |
+
float * lang_probs);
|
| 243 |
+
|
| 244 |
+
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
|
| 245 |
+
WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length
|
| 246 |
+
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
|
| 247 |
+
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
|
| 248 |
+
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
|
| 249 |
+
WHISPER_API int whisper_is_multilingual (struct whisper_context * ctx);
|
| 250 |
|
| 251 |
// Token logits obtained from the last call to whisper_decode()
|
| 252 |
// The logits for the last token are stored in the last row
|
| 253 |
// Rows: n_tokens
|
| 254 |
// Cols: n_vocab
|
| 255 |
+
WHISPER_API float * whisper_get_logits (struct whisper_context * ctx);
|
| 256 |
+
WHISPER_API float * whisper_get_logits_from_state(struct whisper_state * state);
|
| 257 |
|
| 258 |
// Token Id -> String. Uses the vocabulary in the provided context
|
| 259 |
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
|
|
|
|
| 271 |
WHISPER_API whisper_token whisper_token_translate (void);
|
| 272 |
WHISPER_API whisper_token whisper_token_transcribe(void);
|
| 273 |
|
| 274 |
+
// Performance information from the default state.
|
| 275 |
WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
|
| 276 |
WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
|
| 277 |
|
|
|
|
| 289 |
// Text segment callback
|
| 290 |
// Called on every newly generated text segment
|
| 291 |
// Use the whisper_full_...() functions to obtain the text segments
|
| 292 |
+
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data);
|
| 293 |
|
| 294 |
// Encoder begin callback
|
| 295 |
// If not NULL, called before the encoder starts
|
| 296 |
// If it returns false, the computation is aborted
|
| 297 |
+
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
|
| 298 |
|
| 299 |
// Logits filter callback
|
| 300 |
// Can be used to modify the logits before sampling
|
| 301 |
// If not NULL, called after applying temperature to logits
|
| 302 |
typedef void (*whisper_logits_filter_callback)(
|
| 303 |
struct whisper_context * ctx,
|
| 304 |
+
struct whisper_state * state,
|
| 305 |
const whisper_token_data * tokens,
|
| 306 |
int n_tokens,
|
| 307 |
float * logits,
|
|
|
|
| 388 |
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
|
| 389 |
|
| 390 |
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
| 391 |
+
// Not thread safe for same context
|
| 392 |
// Uses the specified decoding strategy to obtain the text.
|
| 393 |
WHISPER_API int whisper_full(
|
| 394 |
struct whisper_context * ctx,
|
|
|
|
| 396 |
const float * samples,
|
| 397 |
int n_samples);
|
| 398 |
|
| 399 |
+
WHISPER_API int whisper_full_with_state(
|
| 400 |
+
struct whisper_context * ctx,
|
| 401 |
+
struct whisper_state * state,
|
| 402 |
+
struct whisper_full_params params,
|
| 403 |
+
const float * samples,
|
| 404 |
+
int n_samples);
|
| 405 |
+
|
| 406 |
+
// Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
|
| 407 |
+
// Result is stored in the default state of the context
|
| 408 |
+
// Not thread safe if executed in parallel on the same context.
|
| 409 |
// It seems this approach can offer some speedup in some cases.
|
| 410 |
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
|
| 411 |
WHISPER_API int whisper_full_parallel(
|
|
|
|
| 415 |
int n_samples,
|
| 416 |
int n_processors);
|
| 417 |
|
| 418 |
+
// Number of generated text segments
|
| 419 |
// A segment can be a few words, a sentence, or even a paragraph.
|
| 420 |
+
WHISPER_API int whisper_full_n_segments (struct whisper_context * ctx);
|
| 421 |
+
WHISPER_API int whisper_full_n_segments_from_state(struct whisper_state * state);
|
| 422 |
|
| 423 |
+
// Language id associated with the context's default state
|
| 424 |
WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
|
| 425 |
|
| 426 |
+
// Language id associated with the provided state
|
| 427 |
+
WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state * state);
|
| 428 |
+
|
| 429 |
+
// Get the start and end time of the specified segment
|
| 430 |
+
WHISPER_API int64_t whisper_full_get_segment_t0 (struct whisper_context * ctx, int i_segment);
|
| 431 |
+
WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment);
|
| 432 |
+
|
| 433 |
+
WHISPER_API int64_t whisper_full_get_segment_t1 (struct whisper_context * ctx, int i_segment);
|
| 434 |
+
WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment);
|
| 435 |
+
|
| 436 |
+
// Get the text of the specified segment
|
| 437 |
+
WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment);
|
| 438 |
+
WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment);
|
| 439 |
|
| 440 |
+
// Get number of tokens in the specified segment
|
| 441 |
+
WHISPER_API int whisper_full_n_tokens (struct whisper_context * ctx, int i_segment);
|
| 442 |
+
WHISPER_API int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment);
|
| 443 |
|
| 444 |
+
// Get the token text of the specified token in the specified segment
|
| 445 |
+
WHISPER_API const char * whisper_full_get_token_text (struct whisper_context * ctx, int i_segment, int i_token);
|
| 446 |
+
WHISPER_API const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token);
|
| 447 |
|
| 448 |
+
WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
|
| 449 |
+
WHISPER_API whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token);
|
|
|
|
| 450 |
|
| 451 |
+
// Get token data for the specified token in the specified segment
|
| 452 |
// This contains probabilities, timestamps, etc.
|
| 453 |
+
WHISPER_API whisper_token_data whisper_full_get_token_data (struct whisper_context * ctx, int i_segment, int i_token);
|
| 454 |
+
WHISPER_API whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token);
|
| 455 |
|
| 456 |
+
// Get the probability of the specified token in the specified segment
|
| 457 |
+
WHISPER_API float whisper_full_get_token_p (struct whisper_context * ctx, int i_segment, int i_token);
|
| 458 |
+
WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token);
|
| 459 |
|
| 460 |
////////////////////////////////////////////////////////////////////////////
|
| 461 |
|