Committed by sandrohanea (Sandro Hanea) and ggerganov (Georgi Gerganov)
Commit fa946a3 · unverified · 1 parent: d144017

whisper : add whisper_state + default state on the whisper_context (#523)


* Added whisper state + default state on the whisper_context

* Fixed some examples and bindings

* Fixed whisper_n_len (which was used in some bindings) and added whisper_n_len_from_state

* Fixed comments

* whisper : reuse kv_cache_free() and fix compiler warnings

* whisper : clean-up the API comments

---------

Co-authored-by: Sandro Hanea <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>

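In short: the per-inference data (mel spectrogram, KV caches, decoders, results, timings) moves out of whisper_context into the new whisper_state, the existing entry points keep working through a default state stored on the context, and every callback gains a whisper_state pointer. A minimal sketch of the updated callback shapes as they appear in the diffs below (typedef names assumed from whisper.h):

    // the whisper_state* parameter is the addition in this PR; everything else is unchanged
    typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data);
    typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
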
bindings/go/whisper.go CHANGED
@@ -20,7 +20,7 @@ extern bool callEncoderBegin(void* user_data);
 // Text segment callback
 // Called on every newly generated text segment
 // Use the whisper_full_...() functions to obtain the text segments
-static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
+static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_state* state, int n_new, void* user_data) {
     if(user_data != NULL && ctx != NULL) {
         callNewSegment(user_data, n_new);
     }
@@ -29,7 +29,7 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void*
 // Encoder begin callback
 // If not NULL, called before the encoder starts
 // If it returns false, the computation is aborted
-static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
+static bool whisper_encoder_begin_cb(struct whisper_context* ctx, struct whisper_state* state, void* user_data) {
     if(user_data != NULL && ctx != NULL) {
         return callEncoderBegin(user_data);
     }
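Note that the Go-facing API is unchanged: the cgo shims above accept the new whisper_state pointer and simply ignore it, so existing Go callbacks keep working.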
bindings/ruby/ext/ruby_whisper.cpp CHANGED
@@ -199,7 +199,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
     {
         static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
 
-        rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
+        rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
            bool is_aborted = *(bool*)user_data;
            return !is_aborted;
        };
examples/addon.node/addon.cpp CHANGED
@@ -72,7 +72,7 @@ int timestamp_to_sample(int64_t t, int n_samples) {
     return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
 }
 
-void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
+void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
     const auto & params  = *((whisper_print_user_data *) user_data)->params;
     const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
 
@@ -260,7 +260,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
     {
         static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
 
-        wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
+        wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
            bool is_aborted = *(bool*)user_data;
            return !is_aborted;
        };
examples/main/main.cpp CHANGED
@@ -193,7 +193,7 @@ struct whisper_print_user_data {
     const std::vector<std::vector<float>> * pcmf32s;
 };
 
-void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
+void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
     const auto & params  = *((whisper_print_user_data *) user_data)->params;
     const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
 
@@ -608,7 +608,7 @@ int main(int argc, char ** argv) {
     {
         static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
 
-        wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
+        wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
            bool is_aborted = *(bool*)user_data;
            return !is_aborted;
        };
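The "NOTE: this should be atomic" comment above predates this PR. For reference, a hedged sketch of the same abort pattern with the data race actually fixed (hypothetical example, not part of this commit; it only relies on the new callback signature shown above):

    #include <atomic>

    static std::atomic<bool> g_is_aborted{false}; // set to true from another thread to abort

    // matches the new whisper_encoder_begin_callback signature
    static bool encoder_begin_cb(struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
        const auto * aborted = static_cast<std::atomic<bool> *>(user_data);
        return !aborted->load(); // returning false aborts the computation
    }

    // usage: wparams.encoder_begin_callback           = encoder_begin_cb;
    //        wparams.encoder_begin_callback_user_data = &g_is_aborted;
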
whisper.cpp CHANGED
@@ -547,13 +547,11 @@ struct whisper_decoder {
     std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
 };
 
-struct whisper_context {
-    int64_t t_load_us = 0;
-    int64_t t_mel_us = 0;
+struct whisper_state {
     int64_t t_sample_us = 0;
     int64_t t_encode_us = 0;
     int64_t t_decode_us = 0;
-    int64_t t_start_us = 0;
+    int64_t t_mel_us = 0;
 
     int32_t n_sample = 0; // number of tokens sampled
     int32_t n_encode = 0; // number of encoder calls
@@ -561,16 +559,10 @@ struct whisper_context {
     int32_t n_fail_p = 0; // number of logprob threshold failures
     int32_t n_fail_h = 0; // number of entropy threshold failures
 
-    ggml_type wtype; // weight type (FP32 or FP16)
-
-    whisper_mel mel;
-
-    whisper_model model;
-    whisper_vocab vocab;
-
     // cross-attention KV cache for the decoders
     // shared between all decoders
     whisper_kv_cache kv_cross;
+    whisper_mel mel;
 
     whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
 
@@ -635,6 +627,18 @@ struct whisper_context {
     }
 };
 
+struct whisper_context {
+    int64_t t_load_us = 0;
+    int64_t t_start_us = 0;
+
+
+    ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 or FP16)
+
+    whisper_model model;
+    whisper_vocab vocab;
+    whisper_state * state = nullptr;
+};
+
 template<typename T>
 static void read_safe(whisper_model_loader * loader, T & dest) {
     loader->read(loader->context, &dest, sizeof(T));
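What the split buys: whisper_context now holds only the immutable model (weights, vocab), while every buffer an inference pass writes to lives in whisper_state, so several transcriptions can share one loaded model. A hedged usage sketch (whisper_init_state and whisper_full_with_state are the entry points this PR introduces; error handling omitted):

    whisper_context * ctx = whisper_init_from_file("ggml-base.en.bin"); // loads weights + a default state

    whisper_state * st_a = whisper_init_state(ctx); // per-stream mel, KV caches, decoders, results
    whisper_state * st_b = whisper_init_state(ctx); // shares the read-only weights in ctx

    // each call mutates only its own state, so these can run on separate threads
    whisper_full_with_state(ctx, st_a, params, pcm_a, n_a);
    whisper_full_with_state(ctx, st_b, params, pcm_b, n_b);

    whisper_free_state(st_a);
    whisper_free_state(st_b);
    whisper_free(ctx); // also releases the default state
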
@@ -821,32 +825,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         wctx.model.buf = new std::vector<uint8_t>();
         wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
 
-        if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
-            fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
-            return false;
-        }
-
-        {
-            const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v);
-            fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
-        }
-
-        if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
-            fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
-            return false;
-        }
-
-        {
-            const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
-            fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
-        }
-
-        wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
-
-        wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type));
-        wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type));
-        wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type));
-        wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type));
+        // we skip initialization of the state until it is needed
+        // because it might be that state will always be provided externally.
     }
 
     // load mel filters
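The allocations removed here (self/cross KV caches, compute and scratch buffers) move into state creation, which is deferred as the new comment explains. A hedged sketch of the resulting loader split, assuming the _no_state names this PR introduces:

    // weights only - no KV caches or scratch buffers are allocated yet
    whisper_context * ctx = whisper_init_from_file_no_state("ggml-base.en.bin");

    // allocate the per-inference buffers only when (and if) they are needed
    whisper_state * state = whisper_init_state(ctx);
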
@@ -929,17 +909,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 vocab.id_to_token[i] = word;
             }
         }
-
-        wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
-
-        wctx.logits_id.reserve(n_vocab);
-
-        // TAGS: WHISPER_DECODER_INIT
-        wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
-
-        wctx.decoders[0].probs.reserve   (vocab.n_vocab);
-        wctx.decoders[0].logits.reserve  (vocab.n_vocab);
-        wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
     }
 
     size_t ctx_size = 0;
@@ -1339,33 +1308,34 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
-    wctx.rng = std::mt19937(0);
-
     wctx.t_load_us = ggml_time_us() - t_start_us;
 
     return true;
 }
 
-// evaluate the encoder
+// evaluate the encoder with the given state
 //
 // given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
 // part of the transformer model and returns the encoded features
 //
-//   - model:      the model
+//   - wctx:       the model
+//   - wstate:     the state of the encoder
 //   - n_threads:  number of threads to use
 //   - mel_offset: offset in the mel spectrogram (i.e. audio offset)
 //
-static bool whisper_encode(
+static bool whisper_encode_internal(
         whisper_context & wctx,
+          whisper_state & wstate,
         const int mel_offset,
-        const int n_threads) {
+        const int n_threads){
+
     const int64_t t_start_us = ggml_time_us();
 
     const auto & model   = wctx.model;
-    const auto & mel_inp = wctx.mel;
+    const auto & mel_inp = wstate.mel;
     const auto & hparams = model.hparams;
 
-    const int n_ctx   = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
+    const int n_ctx   = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
     const int n_state = hparams.n_audio_state;
     const int n_head  = hparams.n_audio_head;
     const int n_layer = hparams.n_audio_layer;
@@ -1374,12 +1344,12 @@ static bool whisper_encode(
     assert(mel_inp.n_mel == n_mels);
 
     struct ggml_init_params params;
-    params.mem_size   = wctx.buf_compute.size();
-    params.mem_buffer = wctx.buf_compute.data();
+    params.mem_size   = wstate.buf_compute.size();
+    params.mem_buffer = wstate.buf_compute.data();
 
     struct ggml_context * ctx0 = ggml_init(params);
 
-    wctx.use_buf(ctx0, 0);
+    wstate.use_buf(ctx0, 0);
 
     struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
     assert(mel->type == GGML_TYPE_F32);
@@ -1401,30 +1371,30 @@ static bool whisper_encode(
 
     // convolution + gelu
     {
-        wctx.use_buf(ctx0, 1);
+        wstate.use_buf(ctx0, 1);
 
         cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
         cur = ggml_add(ctx0,
-            ggml_repeat(ctx0,
-                model.e_conv_1_b,
-                cur),
-            cur);
+                ggml_repeat(ctx0,
+                    model.e_conv_1_b,
+                    cur),
+                cur);
 
         cur = ggml_gelu(ctx0, cur);
 
-        wctx.use_buf(ctx0, 0);
+        wstate.use_buf(ctx0, 0);
 
         cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
         cur = ggml_add(ctx0,
-            ggml_repeat(ctx0,
-                model.e_conv_2_b,
-                cur),
-            cur);
+                ggml_repeat(ctx0,
+                    model.e_conv_2_b,
+                    cur),
+                cur);
 
         cur = ggml_gelu(ctx0, cur);
     }
 
-    wctx.use_buf(ctx0, 3);
+    wstate.use_buf(ctx0, 3);
 
     // ===================================================================
     // NOTE: experimenting with partial evaluation of the encoder (ignore)
@@ -1439,7 +1409,7 @@ static bool whisper_encode(
     //}
 
     static int iter = 0;
-
+
     const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
     const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
@@ -1459,54 +1429,54 @@ static bool whisper_encode(
 
         // norm
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_norm(ctx0, inpL);
 
             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
-                ggml_mul(ctx0,
-                    ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
-                    cur),
-                ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
+                    ggml_mul(ctx0,
+                        ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
+                        cur),
+                    ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
         }
 
         // self-attention
         {
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
-                layer.attn_q_w,
-                cur);
+                    layer.attn_q_w,
+                    cur);
 
             Qcur = ggml_add(ctx0,
-                ggml_repeat(ctx0,
-                    layer.attn_q_b,
-                    Qcur),
-                Qcur);
+                    ggml_repeat(ctx0,
+                        layer.attn_q_b,
+                        Qcur),
+                    Qcur);
 
             //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
             // note: no bias for Key
             struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
-                layer.attn_k_w,
-                cur);
+                    layer.attn_k_w,
+                    cur);
 
             //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
             struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
-                layer.attn_v_w,
-                cur);
+                    layer.attn_v_w,
+                    cur);
 
             Vcur = ggml_add(ctx0,
-                ggml_repeat(ctx0,
-                    layer.attn_v_b,
-                    Vcur),
-                Vcur);
+                    ggml_repeat(ctx0,
+                        layer.attn_v_b,
+                        Vcur),
+                    Vcur);
 
             // ------
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
 #ifdef WHISPER_USE_FLASH_ATTN
             struct ggml_tensor * Q =
@@ -1583,29 +1553,29 @@ static bool whisper_encode(
 #endif
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             cur = ggml_cpy(ctx0,
-                KQV_merged,
-                ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
+                    KQV_merged,
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
         }
 
         // projection
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_mul_mat(ctx0,
-                layer.attn_ln_1_w,
-                cur);
+                    layer.attn_ln_1_w,
+                    cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             cur = ggml_add(ctx0,
-                ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
-                cur);
+                    ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
+                    cur);
         }
 
-        wctx.use_buf(ctx0, 2);
+        wstate.use_buf(ctx0, 2);
 
         // add the input
         cur = ggml_add(ctx0, cur, inpL);
@@ -1616,61 +1586,61 @@ static bool whisper_encode(
         {
             // norm
             {
-                wctx.use_buf(ctx0, 0);
+                wstate.use_buf(ctx0, 0);
 
                 cur = ggml_norm(ctx0, inpFF);
 
-                wctx.use_buf(ctx0, 1);
+                wstate.use_buf(ctx0, 1);
 
                 // cur = mlp_ln_w*cur + mlp_ln_b
                 cur = ggml_add(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, layer.mlp_ln_w, cur),
-                        cur),
-                    ggml_repeat(ctx0, layer.mlp_ln_b, cur));
-            }
+                        ggml_mul(ctx0,
+                            ggml_repeat(ctx0, layer.mlp_ln_w, cur),
+                            cur),
+                        ggml_repeat(ctx0, layer.mlp_ln_b, cur));
+            }
 
 #ifdef WHISPER_USE_FLASH_FF
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_flash_ff(ctx0,
-                ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)),
-                layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
+                    ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)),
+                    layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
 #else
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             // fully connected
             cur = ggml_mul_mat(ctx0,
-                layer.mlp_0_w,
-                cur);
+                    layer.mlp_0_w,
+                    cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             cur = ggml_add(ctx0,
-                ggml_repeat(ctx0, layer.mlp_0_b, cur),
-                cur);
+                    ggml_repeat(ctx0, layer.mlp_0_b, cur),
+                    cur);
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             // GELU activation
             cur = ggml_gelu(ctx0, cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             // projection
             cur = ggml_mul_mat(ctx0,
-                layer.mlp_1_w,
-                cur);
+                    layer.mlp_1_w,
+                    cur);
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_add(ctx0,
-                ggml_repeat(ctx0, layer.mlp_1_b, cur),
-                cur);
+                    ggml_repeat(ctx0, layer.mlp_1_b, cur),
+                    cur);
 #endif
-        }
+        }
 
-        wctx.use_buf(ctx0, 3);
+        wstate.use_buf(ctx0, 3);
 
         inpL = ggml_add(ctx0, cur, inpFF);
     }
@@ -1679,21 +1649,21 @@ static bool whisper_encode(
 
     // norm
     {
-        wctx.use_buf(ctx0, 0);
+        wstate.use_buf(ctx0, 0);
 
         cur = ggml_norm(ctx0, cur);
 
-        wctx.use_buf(ctx0, 1);
+        wstate.use_buf(ctx0, 1);
 
         // cur = ln_f_g*cur + ln_f_b
         cur = ggml_add(ctx0,
-            ggml_mul(ctx0,
-                ggml_repeat(ctx0, model.e_ln_w, cur),
-                cur),
-            ggml_repeat(ctx0, model.e_ln_b, cur));
+                ggml_mul(ctx0,
+                    ggml_repeat(ctx0, model.e_ln_w, cur),
+                    cur),
+                ggml_repeat(ctx0, model.e_ln_b, cur));
     }
 
-    wctx.use_buf(ctx0, -1);
+    wstate.use_buf(ctx0, -1);
 
     // run the computation
     {
@@ -1701,7 +1671,7 @@ static bool whisper_encode(
         gf.n_threads = n_threads;
 
         ggml_build_forward_expand(&gf, cur);
-        ggml_graph_compute       (ctx0, &gf);
+        ggml_graph_compute(ctx0, &gf);
 
         //ggml_graph_print(&gf);
     }
@@ -1731,34 +1701,34 @@ static bool whisper_encode(
         cur->src1 = nullptr;
 
         for (int il = 0; il < model.hparams.n_text_layer; ++il) {
-            auto & layer = model.layers_decoder[il];
+            auto & layer = model.layers_decoder[il];
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
-                layer.cross_attn_k_w,
-                cur);
+                    layer.cross_attn_k_w,
+                    cur);
 
-            Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+            Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
-                layer.cross_attn_v_w,
-                cur);
+                    layer.cross_attn_v_w,
+                    cur);
 
             Vcross = ggml_add(ctx0,
-                ggml_repeat(ctx0,
-                    layer.cross_attn_v_b,
-                    Vcross),
-                Vcross);
+                    ggml_repeat(ctx0,
+                        layer.cross_attn_v_b,
+                        Vcross),
+                    Vcross);
 
-            wctx.use_buf(ctx0, -1);
+            wstate.use_buf(ctx0, -1);
 
-            //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
-            //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
-            struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx));
-            struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx));
+            struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
+            struct ggml_tensor * v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx));
 
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
@@ -1779,8 +1749,8 @@ static bool whisper_encode(
 
     ggml_free(ctx0);
 
-    wctx.t_encode_us += ggml_time_us() - t_start_us;
-    wctx.n_encode++;
+    wstate.t_encode_us += ggml_time_us() - t_start_us;
+    wstate.n_encode++;
 
     return true;
 }
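Taken together, the hunks above split the internal encoder entry point in two: the context supplies the read-only weights and the state supplies every buffer the pass writes to. From the diff:

    static bool whisper_encode_internal(
            whisper_context & wctx,    // read-only: model weights, vocab
              whisper_state & wstate,  // mutable: mel, kv_cross, scratch buffers, timings
                  const int   mel_offset,
                  const int   n_threads);
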
@@ -1795,8 +1765,9 @@ static bool whisper_encode(
 //   - n_tokens: number of tokens in the prompt
 //   - n_past:   number of past tokens to prefix the prompt with
 //
-static bool whisper_decode(
+static bool whisper_decode_internal(
         whisper_context & wctx,
+          whisper_state & wstate,
         whisper_decoder & decoder,
     const whisper_token * tokens,
               const int   n_tokens,
@@ -1811,7 +1782,7 @@ static bool whisper_decode(
 
     WHISPER_ASSERT(!!kv_self.ctx);
 
-    auto & logits_out = wctx.logits;
+    auto & logits_out = wstate.logits;
 
     const int n_vocab = hparams.n_vocab;
 
@@ -1821,13 +1792,13 @@ static bool whisper_decode(
     const int n_layer = hparams.n_text_layer;
 
     const int N = n_tokens;
-    const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
+    const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
 
     //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
 
     struct ggml_init_params params;
-    params.mem_size   = wctx.buf_compute.size();
-    params.mem_buffer = wctx.buf_compute.data();
+    params.mem_size   = wstate.buf_compute.size();
+    params.mem_buffer = wstate.buf_compute.data();
 
     struct ggml_context * ctx0 = ggml_init(params);
 
@@ -1842,7 +1813,7 @@ static bool whisper_decode(
         ((int32_t *) position->data)[i] = n_past + i;
     }
 
-    wctx.use_buf(ctx0, 3);
+    wstate.use_buf(ctx0, 3);
 
     // token encoding + position encoding
     struct ggml_tensor * cur =
@@ -1857,7 +1828,7 @@ static bool whisper_decode(
 
         // norm
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_norm(ctx0, inpL);
 
@@ -1871,7 +1842,7 @@ static bool whisper_decode(
 
         // self-attention
         {
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                 layer.attn_q_w,
@@ -1913,7 +1884,7 @@ static bool whisper_decode(
 
             // ------
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
@@ -1929,12 +1900,12 @@ static bool whisper_decode(
                             n_state/n_head, n_head, n_past + N),
                 0, 2, 1, 3);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             //struct ggml_tensor * KQ_scaled =
             //    ggml_scale(ctx0,
@@ -1944,11 +1915,11 @@ static bool whisper_decode(
 
             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             struct ggml_tensor * V_trans =
                 ggml_permute(ctx0,
@@ -1957,7 +1928,7 @@ static bool whisper_decode(
                             n_state/n_head, n_head, n_past + N),
                 1, 2, 0, 3);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
 
@@ -1970,31 +1941,31 @@ static bool whisper_decode(
 
         // projection
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_mul_mat(ctx0,
                 layer.attn_ln_1_w,
                 cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             cur = ggml_add(ctx0,
                 ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
                 cur);
         }
 
-        wctx.use_buf(ctx0, 2);
+        wstate.use_buf(ctx0, 2);
 
         // add the input
         struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
 
         // norm
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
@@ -2006,7 +1977,7 @@ static bool whisper_decode(
 
         // cross-attention
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                 layer.cross_attn_q_w,
@@ -2023,19 +1994,19 @@ static bool whisper_decode(
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
                 ggml_reshape_3d(ctx0,
-                    ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state),
+                    ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state),
                     n_state/n_head, n_head, M);
 
             struct ggml_tensor * Vcross =
                 ggml_reshape_3d(ctx0,
-                    ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state),
+                    ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
                     n_state/n_head, n_head, M);
 
             struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);
 
             // ------
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
@@ -2046,7 +2017,7 @@ static bool whisper_decode(
 
             struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
@@ -2060,15 +2031,15 @@ static bool whisper_decode(
             // no masking for cross-attention
             //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
@@ -2080,20 +2051,20 @@ static bool whisper_decode(
 
         // projection
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_mul_mat(ctx0,
                 layer.cross_attn_ln_1_w,
                 cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             cur = ggml_add(ctx0,
                 ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
                 cur);
         }
 
-        wctx.use_buf(ctx0, 2);
+        wstate.use_buf(ctx0, 2);
 
         // add the input
         cur = ggml_add(ctx0, cur, inpCA);
@@ -2104,11 +2075,11 @@ static bool whisper_decode(
     {
         // norm
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_norm(ctx0, inpFF);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             // cur = mlp_ln_w*cur + mlp_ln_b
             cur = ggml_add(ctx0,
@@ -2118,39 +2089,39 @@ static bool whisper_decode(
                 ggml_repeat(ctx0, layer.mlp_ln_b, cur));
         }
 
-        wctx.use_buf(ctx0, 0);
+        wstate.use_buf(ctx0, 0);
 
         // fully connected
         cur = ggml_mul_mat(ctx0,
             layer.mlp_0_w,
            cur);
 
-        wctx.use_buf(ctx0, 1);
+        wstate.use_buf(ctx0, 1);
 
         cur = ggml_add(ctx0,
            ggml_repeat(ctx0, layer.mlp_0_b, cur),
            cur);
 
-        wctx.use_buf(ctx0, 0);
+        wstate.use_buf(ctx0, 0);
 
         // GELU activation
         cur = ggml_gelu(ctx0, cur);
 
-        wctx.use_buf(ctx0, 1);
+        wstate.use_buf(ctx0, 1);
 
         // projection
         cur = ggml_mul_mat(ctx0,
            layer.mlp_1_w,
            cur);
 
-        wctx.use_buf(ctx0, 0);
+        wstate.use_buf(ctx0, 0);
 
         cur = ggml_add(ctx0,
            ggml_repeat(ctx0, layer.mlp_1_b, cur),
            cur);
     }
 
-    wctx.use_buf(ctx0, 3);
+    wstate.use_buf(ctx0, 3);
 
     inpL = ggml_add(ctx0, cur, inpFF);
 }
@@ -2159,11 +2130,11 @@ static bool whisper_decode(
 
     // norm
     {
-        wctx.use_buf(ctx0, 0);
+        wstate.use_buf(ctx0, 0);
 
         cur = ggml_norm(ctx0, cur);
 
-        wctx.use_buf(ctx0, 1);
+        wstate.use_buf(ctx0, 1);
 
         cur = ggml_add(ctx0,
             ggml_mul(ctx0,
@@ -2172,7 +2143,7 @@ static bool whisper_decode(
             ggml_repeat(ctx0, model.d_ln_b, cur));
     }
 
-    wctx.use_buf(ctx0, 0);
+    wstate.use_buf(ctx0, 0);
 
     // compute logits only for the last token
     // comment this line to compute logits for all N tokens
@@ -2181,7 +2152,7 @@ static bool whisper_decode(
 
     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
 
-    wctx.use_buf(ctx0, -1);
+    wstate.use_buf(ctx0, -1);
 
     // run the computation
     {
@@ -2208,8 +2179,8 @@ static bool whisper_decode(
 
     ggml_free(ctx0);
 
-    wctx.t_decode_us += ggml_time_us() - t_start_us;
-    wctx.n_decode++;
+    wstate.t_decode_us += ggml_time_us() - t_start_us;
+    wstate.n_decode++;
 
     return true;
 }
@@ -2313,7 +2284,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
 
 // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
 static bool log_mel_spectrogram(
-        whisper_context & wctx,
+          whisper_state & wstate,
         const float * samples,
         const int   n_samples,
         const int   /*sample_rate*/,
@@ -2433,7 +2404,7 @@ static bool log_mel_spectrogram(
         mel.data[i] = (mel.data[i] + 4.0)/4.0;
     }
 
-    wctx.t_mel_us += ggml_time_us() - t_start_us;
+    wstate.t_mel_us += ggml_time_us() - t_start_us;
 
     return true;
 }
@@ -2507,7 +2478,56 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
 // interface implementation
 //
 
+[... ~50 added lines not captured in this view ...]
+
-struct whisper_context * whisper_init_from_file(const char * path_model) {
+struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
     whisper_model_loader loader = {};
 
     fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
@@ -2535,10 +2555,10 @@ struct whisper_context * whisper_init_from_file(const char * path_model) {
         fin->close();
     };
 
-    return whisper_init(&loader);
+    return whisper_init_no_state(&loader);
 }
 
-struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
+struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
     struct buf_context {
         uint8_t* buffer;
         size_t size;
@@ -2571,10 +2591,10 @@ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_s
 
     loader.close = [](void * /*ctx*/) { };
 
-    return whisper_init(&loader);
+    return whisper_init_no_state(&loader);
 }
 
-struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
+struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) {
     ggml_time_init();
 
     whisper_context * ctx = new whisper_context;
@@ -2591,6 +2611,64 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
     return ctx;
 }
 
+[... ~58 added lines not captured in this view ...]
+
 void whisper_free(struct whisper_context * ctx) {
     if (ctx) {
         if (ctx->model.ctx) {
@@ -2599,20 +2677,29 @@ void whisper_free(struct whisper_context * ctx) {
         if (ctx->model.buf) {
             delete ctx->model.buf;
         }
-        if (ctx->kv_cross.ctx) {
-            ggml_free(ctx->kv_cross.ctx);
-        }
-        for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
-            if (ctx->decoders[i].kv_self.ctx) {
-                ggml_free(ctx->decoders[i].kv_self.ctx);
-            }
-        }
         delete ctx;
     }
 }
 
+[... added lines not captured in this view ...]
 int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
-    if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
+[... added lines not captured in this view ...]
         fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }
@@ -2622,11 +2709,26 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
 
 // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
 int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
-    if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
-        fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
+[... added lines not captured in this view ...]
         return -1;
     }
 
+[... added lines not captured in this view ...]
     return 0;
 }
 
@@ -2635,22 +2737,20 @@ int whisper_set_mel(
         const float * data,
               int   n_len,
               int   n_mel) {
-    if (n_mel != WHISPER_N_MEL) {
-        fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
+[... added lines not captured in this view ...]
         return -1;
     }
 
-    ctx->mel.n_len = n_len;
-    ctx->mel.n_mel = n_mel;
-
-    ctx->mel.data.resize(n_len*n_mel);
-    memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float));
-
     return 0;
 }
 
 int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
-    if (!whisper_encode(*ctx, offset, n_threads)) {
+    if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return -1;
     }
@@ -2658,11 +2758,28 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
     return 0;
 }
 
+[... added lines not captured in this view ...]
+
 int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
-    // TODO: add selected_decoder_id to context
+    // TODO: add selected_decoder_id to state
     const int selected_decoder_id = 0;
 
-    if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
+    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
    }
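The bodies of the functions added in the hunks above (whisper_init_state and the stateful wrappers) were not captured in this view. A hedged reconstruction of the intended lifecycle, based only on the names this PR introduces: the stateful whisper_init wraps whisper_init_no_state and attaches a default state, which is what the unchanged public API runs against:

    // hedged sketch - not the verbatim commit; whisper_init_state() performs the
    // KV-cache and scratch-buffer allocations that moved out of whisper_model_load()
    struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
        whisper_context * ctx = whisper_init_no_state(loader); // load weights only
        if (!ctx) {
            return nullptr;
        }

        ctx->state = whisper_init_state(ctx); // default state, kept on the context
        if (!ctx->state) {
            whisper_free(ctx);
            return nullptr;
        }

        return ctx;
    }
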
@@ -2720,11 +2837,12 @@ const char * whisper_lang_str(int id) {
     return nullptr;
 }
 
-int whisper_lang_auto_detect(
+int whisper_lang_auto_detect_with_state(
         struct whisper_context * ctx,
-        int offset_ms,
-        int n_threads,
-        float * lang_probs) {
+        struct whisper_state * state,
+        int offset_ms,
+        int n_threads,
+        float * lang_probs) {
     const int seek = offset_ms/10;
 
     if (seek < 0) {
@@ -2732,8 +2850,8 @@ int whisper_lang_auto_detect(
         return -1;
     }
 
-    if (seek >= ctx->mel.n_len) {
-        fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
+    if (seek >= state->mel.n_len) {
+        fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len*10);
         return -2;
     }
 
@@ -2745,17 +2863,17 @@ int whisper_lang_auto_detect(
 
     const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
 
-    if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
+    if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
         fprintf(stderr, "%s: failed to decode\n", __func__);
         return -7;
     }
 
-    auto & logits_id = ctx->logits_id;
+    auto & logits_id = state->logits_id;
     logits_id.clear();
 
     for (const auto & kv : g_lang) {
         const auto token_lang = whisper_token_lang(ctx, kv.second.first);
-        logits_id.emplace_back(ctx->logits[token_lang], kv.second.first);
+        logits_id.emplace_back(state->logits[token_lang], kv.second.first);
     }
 
     // sort descending
@@ -2794,8 +2912,20 @@ int whisper_lang_auto_detect(
     return logits_id[0].second;
 }
 
+[... added lines not captured in this view ...]
+
 int whisper_n_len(struct whisper_context * ctx) {
-    return ctx->mel.n_len;
+    return ctx->state->mel.n_len;
 }
 
 int whisper_n_vocab(struct whisper_context * ctx) {
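The commit message notes that whisper_n_len() was kept because some bindings use it, with a state-based variant added alongside. A hedged sketch of the pair implied by the hunk above:

    int whisper_n_len_from_state(struct whisper_state * state) {
        return state->mel.n_len;
    }

    int whisper_n_len(struct whisper_context * ctx) {
        return ctx->state->mel.n_len; // forwards to the default state
    }
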
@@ -2815,7 +2945,12 @@ int whisper_is_multilingual(struct whisper_context * ctx) {
 }
 
 float * whisper_get_logits(struct whisper_context * ctx) {
-    return ctx->logits.data();
+    return ctx->state->logits.data();
+[... added lines not captured in this view ...]
 }
 
 const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
@@ -2861,24 +2996,29 @@ whisper_token whisper_token_transcribe(void) {
 void whisper_print_timings(struct whisper_context * ctx) {
     const int64_t t_end_us = ggml_time_us();
 
-    const int32_t n_sample = std::max(1, ctx->n_sample);
-    const int32_t n_encode = std::max(1, ctx->n_encode);
-    const int32_t n_decode = std::max(1, ctx->n_decode);
-
     fprintf(stderr, "\n");
-    fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->n_fail_p, ctx->n_fail_h);
-    fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f);
-    fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f);
-    fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_sample_us, n_sample, 1e-3f*ctx->t_sample_us/n_sample);
-    fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_encode_us, n_encode, 1e-3f*ctx->t_encode_us/n_encode);
-    fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_decode_us, n_decode, 1e-3f*ctx->t_decode_us/n_decode);
+[... added lines not captured in this view ...]
     fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
 }
 
 void whisper_reset_timings(struct whisper_context * ctx) {
-    ctx->t_sample_us = 0;
-    ctx->t_encode_us = 0;
-    ctx->t_decode_us = 0;
+[... added lines not captured in this view ...]
 }
 
 const char * whisper_print_system_info(void) {
@@ -2991,6 +3131,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
 static void whisper_exp_compute_token_level_timestamps(
         struct whisper_context & ctx,
+          struct whisper_state & state,
               int   i_segment,
             float   thold_pt,
             float   thold_ptsum);
@@ -3023,8 +3164,8 @@ static inline bool should_split_on_word(const char * txt, bool split_on_word) {
 
 // wrap the last segment to max_len characters
 // returns the number of new segments
-static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
-    auto segment = ctx.result_all.back();
+static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) {
+    auto segment = state.result_all.back();
 
     int res = 1;
     int acc = 0;
@@ -3046,24 +3187,24 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
             trim(text);
         }
 
-        ctx.result_all.back().text = std::move(text);
-        ctx.result_all.back().t1 = token.t0;
-        ctx.result_all.back().tokens.resize(i);
+        state.result_all.back().text = std::move(text);
+        state.result_all.back().t1 = token.t0;
+        state.result_all.back().tokens.resize(i);
 
-        ctx.result_all.push_back({});
-        ctx.result_all.back().t0 = token.t0;
-        ctx.result_all.back().t1 = segment.t1;
+        state.result_all.push_back({});
+        state.result_all.back().t0 = token.t0;
+        state.result_all.back().t1 = segment.t1;
 
         // add tokens [i, end] to the new segment
-        ctx.result_all.back().tokens.insert(
-            ctx.result_all.back().tokens.end(),
+        state.result_all.back().tokens.insert(
+            state.result_all.back().tokens.end(),
             segment.tokens.begin() + i,
             segment.tokens.end());
 
         acc = 0;
         text = "";
 
-        segment = ctx.result_all.back();
+        segment = state.result_all.back();
         i = -1;
 
         res++;
@@ -3076,7 +3217,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
     if (split_on_word) {
         trim(text);
     }
-    ctx.result_all.back().text = std::move(text);
+    state.result_all.back().text = std::move(text);
 
     return res;
 }
@@ -3093,6 +3234,7 @@ static const std::vector<std::string> non_speech_tokens = {
 // - computes logprobs and probs
 static void whisper_process_logits(
         struct whisper_context & ctx,
+          struct whisper_state & state,
         const struct whisper_full_params   params,
               struct whisper_decoder & decoder,
                              float   temperature) {
@@ -3111,7 +3253,7 @@ static void whisper_process_logits(
     auto & logprobs = decoder.logprobs;
     {
         logits.resize(n_logits);
-        memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
+        memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float));
 
         if (temperature > 0.0f) {
             for (int i = 0; i < n_logits; i++) {
@@ -3149,7 +3291,7 @@ static void whisper_process_logits(
         logits[vocab.token_transcribe] = -INFINITY;
 
         if (params.logits_filter_callback) {
-            params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
+            params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
         }
 
         // suppress non-speech tokens
@@ -3310,6 +3452,7 @@ static void whisper_process_logits(
 
 static whisper_token_data whisper_sample_token(
             whisper_context & ctx,
+              whisper_state & state,
       const whisper_decoder & decoder,
                        bool   best) {
     whisper_token_data result = {
@@ -3354,7 +3497,7 @@ static whisper_token_data whisper_sample_token(
     } else {
         std::discrete_distribution<> dist(probs.begin(), probs.end());
 
-        result.id = dist(ctx.rng);
+        result.id = dist(state.rng);
         result.p    = probs[result.id];
         result.plog = logprobs[result.id];
     }
@@ -3364,13 +3507,14 @@ static whisper_token_data whisper_sample_token(
         result.pt = result.p;
     }
 
-    ctx.n_sample++;
+    state.n_sample++;
 
     return result;
 }
 
 static std::vector<whisper_token_data> whisper_sample_token_topk(
             whisper_context & ctx,
+              whisper_state & state,
       const whisper_decoder & decoder,
                         int   k) {
     const auto & vocab = ctx.vocab;
@@ -3381,7 +3525,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
 
     const int n_logits = vocab.n_vocab;
 
-    auto & logits_id = ctx.logits_id;
+    auto & logits_id = state.logits_id;
 
     logits_id.clear();
     for (int i = 0; i < n_logits; ++i) {
@@ -3434,7 +3578,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
         }
     }
 
-    ctx.n_sample++;
+    state.n_sample++;
 
     return result;
 }
@@ -3488,24 +3632,25 @@ static void whisper_sequence_score(
     }
 }
 
-int whisper_full(
+int whisper_full_with_state(
         struct whisper_context * ctx,
-        struct whisper_full_params   params,
-        const float * samples,
-        int   n_samples) {
+          struct whisper_state * state,
+        struct whisper_full_params   params,
+        const float * samples,
+        int   n_samples) {
     // clear old results
-    auto & result_all = ctx->result_all;
+    auto & result_all = state->result_all;
 
     result_all.clear();
 
     // compute log mel spectrogram
     if (params.speed_up) {
-        if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
+        if (whisper_pcm_to_mel_phase_vocoder_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
             fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
             return -1;
         }
     } else {
-        if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
+        if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
            fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
            return -2;
        }
@@ -3515,26 +3660,26 @@ int whisper_full(
     if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
         std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
 
-        const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
+        const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
         if (lang_id < 0) {
             fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
             return -3;
         }
-        ctx->lang_id = lang_id;
+        state->lang_id = lang_id;
         params.language = whisper_lang_str(lang_id);
 
         fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
     }
 
     if (params.token_timestamps) {
-        ctx->t_beg    = 0;
-        ctx->t_last   = 0;
-        ctx->tid_last = 0;
-        ctx->energy = get_signal_energy(samples, n_samples, 32);
+        state->t_beg    = 0;
+        state->t_last   = 0;
+        state->tid_last = 0;
+        state->energy = get_signal_energy(samples, n_samples, 32);
     }
 
     const int seek_start = params.offset_ms/10;
-    const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10);
+    const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len_from_state(state) : params.duration_ms/10);
 
     // if length of spectrogram is less than 1s (100 samples), then return
     // basically don't process anything that is less than 1s
@@ -3572,10 +3717,10 @@ int whisper_full(
 
     // TAGS: WHISPER_DECODER_INIT
     for (int j = 1; j < n_decoders; j++) {
-        auto & decoder = ctx->decoders[j];
+        auto & decoder = state->decoders[j];
 
         if (decoder.kv_self.ctx == nullptr) {
-            decoder.kv_self = ctx->decoders[0].kv_self;
+            decoder.kv_self = state->decoders[0].kv_self;
             if (!kv_cache_reinit(decoder.kv_self)) {
                 fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
                 return -4;
@@ -3583,7 +3728,7 @@ int whisper_full(
 
             WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
 
-            decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
+            decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
 
             decoder.probs.resize   (ctx->vocab.n_vocab);
             decoder.logits.resize  (ctx->vocab.n_vocab);
@@ -3592,7 +3737,7 @@ int whisper_full(
     }
 
     // the accumulated text context so far
-    auto & prompt_past = ctx->prompt_past;
+    auto & prompt_past = state->prompt_past;
     if (params.no_context) {
         prompt_past.clear();
     }
@@ -3611,13 +3756,13 @@ int whisper_full(
         fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
         return -5;
     }
-    ctx->exp_n_audio_ctx = params.audio_ctx;
+    state->exp_n_audio_ctx = params.audio_ctx;
 
     // these tokens determine the task that will be performed
     std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
     if (whisper_is_multilingual(ctx)) {
         const int lang_id = whisper_lang_id(params.language);
-        ctx->lang_id = lang_id;
+        state->lang_id = lang_id;
         prompt_init.push_back(whisper_token_lang(ctx, lang_id));
         if (params.translate) {
             prompt_init.push_back(whisper_token_translate());
@@ -3669,14 +3814,14 @@ int whisper_full(
     }
 
     if (params.encoder_begin_callback) {
-        if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
+        if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
             fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
             break;
         }
     }
 
     // encode audio features starting at offset seek
-    if (!whisper_encode(*ctx, seek, params.n_threads)) {
+    if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
         fprintf(stderr, "%s: failed to encode\n", __func__);
         return -6;
     }
@@ -3717,7 +3862,7 @@ int whisper_full(
 
     // TAGS: WHISPER_DECODER_INIT
     for (int j = 0; j < n_decoders_cur; ++j) {
-        auto & decoder = ctx->decoders[j];
+        auto & decoder = state->decoders[j];
 
         decoder.kv_self.n = 0;
 
@@ -3759,7 +3904,7 @@ int whisper_full(
     }
     WHISPER_PRINT_DEBUG("\n\n");
 
-    if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
+    if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
         fprintf(stderr, "%s: failed to decode\n", __func__);
         return -7;
     }
@@ -3767,24 +3912,24 @@ int whisper_full(
     {
         const int64_t t_start_sample_us = ggml_time_us();
 
-        whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur);
+        whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
 
-        ctx->decoders[0].kv_self.n += prompt.size();
+        state->decoders[0].kv_self.n += prompt.size();
 
         for (int j = 1; j < n_decoders_cur; ++j) {
-            auto & decoder = ctx->decoders[j];
+            auto & decoder = state->decoders[j];
 
-            memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
-            memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
+            memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
+            memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
 
             decoder.kv_self.n += prompt.size();
 
-            memcpy(decoder.probs.data(),    ctx->decoders[0].probs.data(),    decoder.probs.size()*sizeof(decoder.probs[0]));
-            memcpy(decoder.logits.data(),   ctx->decoders[0].logits.data(),   decoder.logits.size()*sizeof(decoder.logits[0]));
-            memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
+            memcpy(decoder.probs.data(),    state->decoders[0].probs.data(),    decoder.probs.size()*sizeof(decoder.probs[0]));
+            memcpy(decoder.logits.data(),   state->decoders[0].logits.data(),   decoder.logits.size()*sizeof(decoder.logits[0]));
+            memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
         }
 
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+        state->t_sample_us += ggml_time_us() - t_start_sample_us;
     }
 }
 
@@ -3795,7 +3940,7 @@ int whisper_full(
     if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
         kv_bufs.resize(n_decoders_cur);
         for (int j = 0; j < n_decoders_cur; ++j) {
-            auto & decoder = ctx->decoders[j];
+            auto & decoder = state->decoders[j];
 
             if (decoder.completed || decoder.failed) {
                 continue;
@@ -3813,7 +3958,7 @@ int whisper_full(
 
     // generate new sequence candidates for each decoder
     for (int j = 0; j < n_decoders_cur; ++j) {
-        auto & decoder = ctx->decoders[j];
+        auto & decoder = state->decoders[j];
 
         if (decoder.completed || decoder.failed) {
             continue;
@@ -3823,16 +3968,16 @@ int whisper_full(
             case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
                 {
                     if (t_cur < 1e-6f) {
-                        decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
+                        decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
                     } else {
-                        decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
+                        decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
                     }
 
                     decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
                 } break;
             case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
                 {
-                    const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
+                    const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
 
                     for (const auto & token : tokens_new) {
                         beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
@@ -3857,7 +4002,7 @@ int whisper_full(
     uint32_t cur_c = 0;
 
     for (int j = 0; j < n_decoders_cur; ++j) {
-        auto & decoder = ctx->decoders[j];
+        auto & decoder = state->decoders[j];
 
         if (decoder.completed || decoder.failed) {
             continue;
@@ -3886,7 +4031,7 @@ int whisper_full(
     // - check if the sequence is failed
     // - update sliding window based on timestamp tokens
     for (int j = 0; j < n_decoders_cur; ++j) {
-        auto & decoder = ctx->decoders[j];
+        auto & decoder = state->decoders[j];
 
         if (decoder.completed || decoder.failed) {
             continue;
@@ -3968,7 +4113,7 @@ int whisper_full(
     bool completed_all = true;
 
     for (int j = 0; j < n_decoders_cur; ++j) {
-        auto & decoder = ctx->decoders[j];
+        auto & decoder = state->decoders[j];
 
         if (decoder.completed || decoder.failed) {
             continue;
@@ -3982,11 +4127,11 @@ int whisper_full(
         }
     }
 
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    state->t_sample_us += ggml_time_us() - t_start_sample_us;
 
     // obtain logits for the next token
     for (int j = 0; j < n_decoders_cur; ++j) {
-        auto & decoder = ctx->decoders[j];
+        auto & decoder = state->decoders[j];
 
         if (decoder.failed || decoder.completed) {
             continue;
@@ -3997,7 +4142,7 @@ int whisper_full(
 
         //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
 
-        if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
+        if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
             fprintf(stderr, "%s: failed to decode\n", __func__);
             return -8;
         }
@@ -4005,11 +4150,11 @@ int whisper_full(
     {
         const int64_t t_start_sample_us = ggml_time_us();
 
-        whisper_process_logits(*ctx, params, decoder, t_cur);
+        whisper_process_logits(*ctx, *state, params, decoder, t_cur);
 
         ++decoder.kv_self.n;
 
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+        state->t_sample_us += ggml_time_us() - t_start_sample_us;
     }
 }
@@ -4019,7 +4164,7 @@ int whisper_full(
     double best_score = -INFINITY;
 
     for (int j = 0; j < n_decoders_cur; ++j) {
-        auto & decoder = ctx->decoders[j];
+        auto & decoder = state->decoders[j];
 
         if (decoder.failed) {
             continue;
@@ -4036,7 +4181,7 @@ int whisper_full(
                 __func__, j, decoder.sequence.entropy, params.entropy_thold);
 
             decoder.failed = true;
-            ctx->n_fail_h++;
+            state->n_fail_h++;
 
             continue;
         }
@@ -4054,11 +4199,11 @@ int whisper_full(
     {
         bool success = true;
 
-        const auto & decoder = ctx->decoders[best_decoder_id];
+        const auto & decoder = state->decoders[best_decoder_id];
 
         if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
             success = false;
-            ctx->n_fail_p++;
+            state->n_fail_p++;
         }
 
         if (success) {
@@ -4075,7 +4220,7 @@ int whisper_full(
 
     // output results through a user-provided callback
     {
-        const auto & best_decoder = ctx->decoders[best_decoder_id];
+        const auto & best_decoder = state->decoders[best_decoder_id];
 
         const auto seek_delta = best_decoder.seek_delta;
         const auto result_len = best_decoder.sequence.result_len;
@@ -4138,14 +4283,14 @@ int whisper_full(
 
     if (params.token_timestamps) {
         whisper_exp_compute_token_level_timestamps(
-            *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+            *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
 
         if (params.max_len > 0) {
-            n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
+            n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
         }
     }
     if (params.new_segment_callback) {
-        params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
+        params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
     }
 }
 text = "";
@@ -4182,14 +4327,14 @@ int whisper_full(
 
     if (params.token_timestamps) {
         whisper_exp_compute_token_level_timestamps(
-            *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+            *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
 
         if (params.max_len > 0) {
-            n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
+            n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
         }
     }
     if (params.new_segment_callback) {
-        params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
+        params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
    }
 }
@@ -4204,6 +4349,15 @@ int whisper_full(
     return 0;
 }
 
+[... added lines not captured in this view ...]
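The added lines here are presumably the whisper_full() wrapper that keeps the old public signature by forwarding to the default state; a hedged sketch:

    int whisper_full(
            struct whisper_context * ctx,
            struct whisper_full_params params,
            const float * samples,
            int n_samples) {
        return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples);
    }
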
 int whisper_full_parallel(
         struct whisper_context * ctx,
         struct whisper_full_params params,
@@ -4213,40 +4367,10 @@ int whisper_full_parallel(
     if (n_processors == 1) {
         return whisper_full(ctx, params, samples, n_samples);
     }
-
     int ret = 0;
 
-    // prepare separate contexts for each thread
-    std::vector<struct whisper_context> ctxs(n_processors - 1);
-
-    for (int i = 0; i < n_processors - 1; ++i) {
-        auto & ctx_p = ctxs[i];
-
-        ctx_p = *ctx;
-
-        ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
-
-        ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab);
-
-        if (!kv_cache_reinit(ctx_p.kv_cross)) {
-            fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);
-            return false;
-        }
-
-        // TAGS: WHISPER_DECODER_INIT
-        for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
-            if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
-                fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
-                return false;
-            }
-
-            ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
-
-            ctx_p.decoders[j].probs.reserve   (ctx_p.vocab.n_vocab);
-            ctx_p.decoders[j].logits.reserve  (ctx_p.vocab.n_vocab);
-            ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
-        }
-    }
+[... added lines not captured in this view ...]
 
     const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
     const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
@@ -4256,6 +4380,9 @@ int whisper_full_parallel(
 
     std::vector<std::thread> workers(n_processors - 1);
     for (int i = 0; i < n_processors - 1; ++i) {
+[... added lines not captured in this view ...]
         const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
         const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
 
@@ -4268,13 +4395,17 @@ int whisper_full_parallel(
         params_cur.new_segment_callback = nullptr;
         params_cur.new_segment_callback_user_data = nullptr;
 
-        workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur);
+[... added lines not captured in this view ...]
     }
 
     {
         auto params_cur = params;
 
-        ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
+[... added lines not captured in this view ...]
     }
 
     for (int i = 0; i < n_processors - 1; ++i) {
@@ -4283,45 +4414,43 @@ int whisper_full_parallel(
 
     const int64_t offset_t = (int64_t) params.offset_ms/10.0;
 
-    // combine results into ctx->result_all
+    // combine results into ctx->state->result_all
     for (int i = 0; i < n_processors - 1; ++i) {
-        auto & results_i = ctxs[i].result_all;
+        auto & results_i = states[i]->result_all;
 
        for (auto & result : results_i) {
            // correct the segment timestamp taking into account the offset
-            result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
-            result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
+            result.t0 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
+            result.t1 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
 
            // make sure that segments are not overlapping
-            if (!ctx->result_all.empty()) {
-                result.t0 = std::max(result.t0, ctx->result_all.back().t1);
+            if (!ctx->state->result_all.empty()) {
+                result.t0 = std::max(result.t0, ctx->state->result_all.back().t1);
            }
 
-            ctx->result_all.push_back(std::move(result));
+            ctx->state->result_all.push_back(std::move(result));
 
            // call the new_segment_callback for each segment
            if (params.new_segment_callback) {
-                params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
+                params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data);
            }
        }
 
-        ctx->t_mel_us    += ctxs[i].t_mel_us;
-        ctx->t_sample_us += ctxs[i].t_sample_us;
-        ctx->t_encode_us += ctxs[i].t_encode_us;
-        ctx->t_decode_us += ctxs[i].t_decode_us;
+        ctx->state->t_mel_us    += states[i]->t_mel_us;
+        ctx->state->t_sample_us += states[i]->t_sample_us;
+        ctx->state->t_encode_us += states[i]->t_encode_us;
+        ctx->state->t_decode_us += states[i]->t_decode_us;
 
-        kv_cache_free(ctx->kv_cross);
-
-        for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
-            kv_cache_free(ctx->decoders[j].kv_self);
-        }
+[... added lines not captured in this view ...]
     }
 
     // average the timings
-    ctx->t_mel_us    /= n_processors;
-    ctx->t_sample_us /= n_processors;
-    ctx->t_encode_us /= n_processors;
-    ctx->t_decode_us /= n_processors;
+        ctx->state->t_mel_us    /= n_processors;
+        ctx->state->t_sample_us /= n_processors;
+        ctx->state->t_encode_us /= n_processors;
+        ctx->state->t_decode_us /= n_processors;
 
     // print information about the audio boundaries
     fprintf(stderr, "\n");
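Design note: before this PR, whisper_full_parallel deep-copied the whole context per worker thread and re-initialized every KV cache by hand; with the split it only needs one extra state per worker over the shared context, which is also why the cleanup can reuse kv_cache_free()/whisper_free_state() (per the commit message). A hedged sketch of the per-worker setup the elided additions likely contain:

    std::vector<whisper_state *> states;
    for (int i = 0; i < n_processors - 1; ++i) {
        states.push_back(whisper_init_state(ctx)); // one state per worker; weights shared via ctx
    }
    // each worker then runs whisper_full_with_state(ctx, states[i], ...) on its slice of samples,
    // while the main thread processes the first chunk with the default state ctx->state
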
@@ -4334,44 +4463,84 @@ int whisper_full_parallel(
     return ret;
 }
 
+[... this hunk also adds lines interleaved below - per the commit message these are the
+    *_from_state variants of the accessors - which were not captured in this view ...]
+
 int whisper_full_n_segments(struct whisper_context * ctx) {
-    return ctx->result_all.size();
+    return ctx->state->result_all.size();
 }
 
 int whisper_full_lang_id(struct whisper_context * ctx) {
-    return ctx->lang_id;
+    return ctx->state->lang_id;
 }
 
 int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
-    return ctx->result_all[i_segment].t0;
+    return ctx->state->result_all[i_segment].t0;
 }
 
 int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
-    return ctx->result_all[i_segment].t1;
+    return ctx->state->result_all[i_segment].t1;
 }
 
 const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
-    return ctx->result_all[i_segment].text.c_str();
+    return ctx->state->result_all[i_segment].text.c_str();
 }
 
 int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
-    return ctx->result_all[i_segment].tokens.size();
+    return ctx->state->result_all[i_segment].tokens.size();
 }
 
-const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
-    return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
+const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
+    return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str();
 }
 
 whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
-    return ctx->result_all[i_segment].tokens[i_token].id;
+    return ctx->state->result_all[i_segment].tokens[i_token].id;
 }
 
 struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
-    return ctx->result_all[i_segment].tokens[i_token];
+    return ctx->state->result_all[i_segment].tokens[i_token];
 }
 
 float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
-    return ctx->result_all[i_segment].tokens[i_token].p;
+    return ctx->state->result_all[i_segment].tokens[i_token].p;
 }
 
 // =================================================================================================
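The elided additions in the hunk above follow one pattern (hedged, matching the whisper_n_len_from_state precedent named in the commit message): each getter keeps its old signature by reading through ctx->state, and gains a _from_state twin that reads a caller-provided state directly, e.g.:

    int whisper_full_lang_id_from_state(struct whisper_state * state) {
        return state->lang_id;
    }

    int whisper_full_lang_id(struct whisper_context * ctx) {
        return ctx->state->lang_id;
    }
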
@@ -4583,13 +4752,14 @@ static std::vector<float> get_signal_energy(const float * signal, int n_samples,
 
 static void whisper_exp_compute_token_level_timestamps(
         struct whisper_context & ctx,
+          struct whisper_state & state,
               int   i_segment,
             float   thold_pt,
             float   thold_ptsum) {
-    auto & segment = ctx.result_all[i_segment];
+    auto & segment = state.result_all[i_segment];
     auto & tokens  = segment.tokens;
 
-    const int n_samples = ctx.energy.size();
+    const int n_samples = state.energy.size();
 
     if (n_samples == 0) {
         fprintf(stderr, "%s: no signal data available\n", __func__);
@@ -4612,9 +4782,9 @@ static void whisper_exp_compute_token_level_timestamps(
         return;
     }
 
-    auto & t_beg    = ctx.t_beg;
-    auto & t_last   = ctx.t_last;
-    auto & tid_last = ctx.tid_last;
+    auto & t_beg    = state.t_beg;
+    auto & t_last   = state.t_last;
+    auto & tid_last = state.tid_last;
 
     for (int j = 0; j < n; ++j) {
         auto & token = tokens[j];
@@ -4737,15 +4907,15 @@ static void whisper_exp_compute_token_level_timestamps(
             float sum = 0.0f;
 
             for (int k = ss0; k < ss1; k++) {
-                sum += ctx.energy[k];
+                sum += state.energy[k];
             }
 
             const float thold = 0.5*sum/ns;
 
             {
                 int k = s0;
-                if (ctx.energy[k] > thold && j > 0) {
-                    while (k > 0 && ctx.energy[k] > thold) {
+                if (state.energy[k] > thold && j > 0) {
+                    while (k > 0 && state.energy[k] > thold) {
                         k--;
                     }
                     tokens[j].t0 = sample_to_timestamp(k);
@@ -4755,7 +4925,7 @@ static void whisper_exp_compute_token_level_timestamps(
                         s0 = k;
                     }
                 } else {
-                    while (ctx.energy[k] < thold && k < s1) {
+                    while (state.energy[k] < thold && k < s1) {
                         k++;
                     }
                     s0 = k;
@@ -4765,8 +4935,8 @@ static void whisper_exp_compute_token_level_timestamps(
 
             {
                 int k = s1;
-                if (ctx.energy[k] > thold) {
-                    while (k < n_samples - 1 && ctx.energy[k] > thold) {
+                if (state.energy[k] > thold) {
+                    while (k < n_samples - 1 && state.energy[k] > thold) {
                         k++;
                     }
                     tokens[j].t1 = sample_to_timestamp(k);
@@ -4776,7 +4946,7 @@ static void whisper_exp_compute_token_level_timestamps(
                     s1 = k;
                 }
             } else {
-                while (ctx.energy[k] < thold && k > s0) {
+                while (state.energy[k] < thold && k > s0) {
                     k--;
                 }
                 s1 = k;
547
  std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
548
  };
549
 
550
+ struct whisper_state {
 
 
551
  int64_t t_sample_us = 0;
552
  int64_t t_encode_us = 0;
553
  int64_t t_decode_us = 0;
554
+ int64_t t_mel_us = 0;
555
 
556
  int32_t n_sample = 0; // number of tokens sampled
557
  int32_t n_encode = 0; // number of encoder calls
 
559
  int32_t n_fail_p = 0; // number of logprob threshold failures
560
  int32_t n_fail_h = 0; // number of entropy threshold failures
561
 
562
  // cross-attention KV cache for the decoders
563
  // shared between all decoders
564
  whisper_kv_cache kv_cross;
565
+ whisper_mel mel;
566
 
567
  whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
568
 
 
627
  }
628
  };
629
 
630
+ struct whisper_context {
631
+ int64_t t_load_us = 0;
632
+ int64_t t_start_us = 0;
633
+
634
+
635
+ ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 or FP16)
636
+
637
+ whisper_model model;
638
+ whisper_vocab vocab;
639
+ whisper_state * state = nullptr;
640
+ };
641
+
642
  template<typename T>
643
  static void read_safe(whisper_model_loader * loader, T & dest) {
644
  loader->read(loader->context, &dest, sizeof(T));
 
825
  wctx.model.buf = new std::vector<uint8_t>();
826
  wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
827
 
828
+ // we skip initialization of the state until it is needed
829
+ // because the state might always be provided externally.
830
  }
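
To illustrate the lazy-state flow described in the comment above, here is a minimal sketch of deferring state creation to the caller (the model path and the pcmf32 buffer are placeholders; error checks elided):

    struct whisper_context * ctx = whisper_init_from_file_no_state("models/ggml-base.en.bin");

    // the state is allocated only when we actually want to run inference
    struct whisper_state * state = whisper_init_state(ctx);

    whisper_full_with_state(ctx, state, whisper_full_default_params(WHISPER_SAMPLING_GREEDY),
                            pcmf32.data(), (int) pcmf32.size());

    whisper_free_state(state);
    whisper_free(ctx);
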
831
 
832
  // load mel filters
 
909
  vocab.id_to_token[i] = word;
910
  }
911
  }
912
  }
913
 
914
  size_t ctx_size = 0;
 
1308
  }
1309
  }
1310
 
 
 
1311
  wctx.t_load_us = ggml_time_us() - t_start_us;
1312
 
1313
  return true;
1314
  }
1315
 
1316
+ // evaluate the encoder with the given state
1317
  //
1318
  // given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
1319
  // part of the transformer model and returns the encoded features
1320
  //
1321
+ // - wctx: the model
1322
+ // - wstate: the inference state to use (mel, KV caches, working buffers)
1323
  // - n_threads: number of threads to use
1324
  // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
1325
  //
1326
+ static bool whisper_encode_internal(
1327
  whisper_context & wctx,
1328
+ whisper_state & wstate,
1329
  const int mel_offset,
1330
+ const int n_threads) {
1331
+
1332
  const int64_t t_start_us = ggml_time_us();
1333
 
1334
  const auto & model = wctx.model;
1335
+ const auto & mel_inp = wstate.mel;
1336
  const auto & hparams = model.hparams;
1337
 
1338
+ const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
1339
  const int n_state = hparams.n_audio_state;
1340
  const int n_head = hparams.n_audio_head;
1341
  const int n_layer = hparams.n_audio_layer;
 
1344
  assert(mel_inp.n_mel == n_mels);
1345
 
1346
  struct ggml_init_params params;
1347
+ params.mem_size = wstate.buf_compute.size();
1348
+ params.mem_buffer = wstate.buf_compute.data();
1349
 
1350
  struct ggml_context * ctx0 = ggml_init(params);
1351
 
1352
+ wstate.use_buf(ctx0, 0);
1353
 
1354
  struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
1355
  assert(mel->type == GGML_TYPE_F32);
 
1371
 
1372
  // convolution + gelu
1373
  {
1374
+ wstate.use_buf(ctx0, 1);
1375
 
1376
  cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
1377
  cur = ggml_add(ctx0,
1378
+ ggml_repeat(ctx0,
1379
+ model.e_conv_1_b,
1380
+ cur),
1381
+ cur);
1382
 
1383
  cur = ggml_gelu(ctx0, cur);
1384
 
1385
+ wstate.use_buf(ctx0, 0);
1386
 
1387
  cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
1388
  cur = ggml_add(ctx0,
1389
+ ggml_repeat(ctx0,
1390
+ model.e_conv_2_b,
1391
+ cur),
1392
+ cur);
1393
 
1394
  cur = ggml_gelu(ctx0, cur);
1395
  }
1396
 
1397
+ wstate.use_buf(ctx0, 3);
1398
 
1399
  // ===================================================================
1400
  // NOTE: experimenting with partial evaluation of the encoder (ignore)
 
1409
  //}
1410
 
1411
  static int iter = 0;
1412
+
1413
  const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
1414
  const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
1415
 
 
1429
 
1430
  // norm
1431
  {
1432
+ wstate.use_buf(ctx0, 0);
1433
 
1434
  cur = ggml_norm(ctx0, inpL);
1435
 
1436
  // cur = ln_0_w*cur + ln_0_b
1437
  cur = ggml_add(ctx0,
1438
+ ggml_mul(ctx0,
1439
+ ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
1440
+ cur),
1441
+ ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
1442
  }
1443
 
1444
  // self-attention
1445
  {
1446
+ wstate.use_buf(ctx0, 1);
1447
 
1448
  struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
1449
+ layer.attn_q_w,
1450
+ cur);
1451
 
1452
  Qcur = ggml_add(ctx0,
1453
+ ggml_repeat(ctx0,
1454
+ layer.attn_q_b,
1455
+ Qcur),
1456
+ Qcur);
1457
 
1458
  //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1459
 
1460
  // note: no bias for Key
1461
  struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
1462
+ layer.attn_k_w,
1463
+ cur);
1464
 
1465
  //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1466
 
1467
  struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
1468
+ layer.attn_v_w,
1469
+ cur);
1470
 
1471
  Vcur = ggml_add(ctx0,
1472
+ ggml_repeat(ctx0,
1473
+ layer.attn_v_b,
1474
+ Vcur),
1475
+ Vcur);
1476
 
1477
  // ------
1478
 
1479
+ wstate.use_buf(ctx0, 0);
1480
 
1481
  #ifdef WHISPER_USE_FLASH_ATTN
1482
  struct ggml_tensor * Q =
 
1553
  #endif
1554
  struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
1555
 
1556
+ wstate.use_buf(ctx0, 1);
1557
 
1558
  cur = ggml_cpy(ctx0,
1559
+ KQV_merged,
1560
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
1561
  }
1562
 
1563
  // projection
1564
  {
1565
+ wstate.use_buf(ctx0, 0);
1566
 
1567
  cur = ggml_mul_mat(ctx0,
1568
+ layer.attn_ln_1_w,
1569
+ cur);
1570
 
1571
+ wstate.use_buf(ctx0, 1);
1572
 
1573
  cur = ggml_add(ctx0,
1574
+ ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
1575
+ cur);
1576
  }
1577
 
1578
+ wstate.use_buf(ctx0, 2);
1579
 
1580
  // add the input
1581
  cur = ggml_add(ctx0, cur, inpL);
 
1586
  {
1587
  // norm
1588
  {
1589
+ wstate.use_buf(ctx0, 0);
1590
 
1591
  cur = ggml_norm(ctx0, inpFF);
1592
 
1593
+ wstate.use_buf(ctx0, 1);
1594
 
1595
  // cur = mlp_ln_w*cur + mlp_ln_b
1596
  cur = ggml_add(ctx0,
1597
+ ggml_mul(ctx0,
1598
+ ggml_repeat(ctx0, layer.mlp_ln_w, cur),
1599
+ cur),
1600
+ ggml_repeat(ctx0, layer.mlp_ln_b, cur));
1601
+ }
1602
 
1603
  #ifdef WHISPER_USE_FLASH_FF
1604
+ wstate.use_buf(ctx0, 0);
1605
 
1606
  cur = ggml_flash_ff(ctx0,
1607
+ ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)),
1608
+ layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
1609
  #else
1610
+ wstate.use_buf(ctx0, 0);
1611
 
1612
  // fully connected
1613
  cur = ggml_mul_mat(ctx0,
1614
+ layer.mlp_0_w,
1615
+ cur);
1616
 
1617
+ wstate.use_buf(ctx0, 1);
1618
 
1619
  cur = ggml_add(ctx0,
1620
+ ggml_repeat(ctx0, layer.mlp_0_b, cur),
1621
+ cur);
1622
 
1623
+ wstate.use_buf(ctx0, 0);
1624
 
1625
  // GELU activation
1626
  cur = ggml_gelu(ctx0, cur);
1627
 
1628
+ wstate.use_buf(ctx0, 1);
1629
 
1630
  // projection
1631
  cur = ggml_mul_mat(ctx0,
1632
+ layer.mlp_1_w,
1633
+ cur);
1634
 
1635
+ wstate.use_buf(ctx0, 0);
1636
 
1637
  cur = ggml_add(ctx0,
1638
+ ggml_repeat(ctx0, layer.mlp_1_b, cur),
1639
+ cur);
1640
  #endif
1641
+ }
1642
 
1643
+ wstate.use_buf(ctx0, 3);
1644
 
1645
  inpL = ggml_add(ctx0, cur, inpFF);
1646
  }
 
1649
 
1650
  // norm
1651
  {
1652
+ wstate.use_buf(ctx0, 0);
1653
 
1654
  cur = ggml_norm(ctx0, cur);
1655
 
1656
+ wstate.use_buf(ctx0, 1);
1657
 
1658
  // cur = ln_f_g*cur + ln_f_b
1659
  cur = ggml_add(ctx0,
1660
+ ggml_mul(ctx0,
1661
+ ggml_repeat(ctx0, model.e_ln_w, cur),
1662
+ cur),
1663
+ ggml_repeat(ctx0, model.e_ln_b, cur));
1664
  }
1665
 
1666
+ wstate.use_buf(ctx0, -1);
1667
 
1668
  // run the computation
1669
  {
 
1671
  gf.n_threads = n_threads;
1672
 
1673
  ggml_build_forward_expand(&gf, cur);
1674
+ ggml_graph_compute(ctx0, &gf);
1675
 
1676
  //ggml_graph_print(&gf);
1677
  }
 
1701
  cur->src1 = nullptr;
1702
 
1703
  for (int il = 0; il < model.hparams.n_text_layer; ++il) {
1704
+ auto & layer = model.layers_decoder[il];
1705
 
1706
+ wstate.use_buf(ctx0, 0);
1707
 
1708
+ struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
1709
+ layer.cross_attn_k_w,
1710
+ cur);
1711
 
1712
+ Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25)));
1713
 
1714
+ wstate.use_buf(ctx0, 1);
1715
 
1716
+ struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
1717
+ layer.cross_attn_v_w,
1718
+ cur);
1719
 
1720
  Vcross = ggml_add(ctx0,
1721
+ ggml_repeat(ctx0,
1722
+ layer.cross_attn_v_b,
1723
+ Vcross),
1724
+ Vcross);
1725
 
1726
+ wstate.use_buf(ctx0, -1);
1727
 
1728
+ //struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
1729
+ //struct ggml_tensor * v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
1730
+ struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
1731
+ struct ggml_tensor * v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx));
1732
 
1733
  ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
1734
  ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
 
1749
 
1750
  ggml_free(ctx0);
1751
 
1752
+ wstate.t_encode_us += ggml_time_us() - t_start_us;
1753
+ wstate.n_encode++;
1754
 
1755
  return true;
1756
  }
 
1765
  // - n_tokens: number of tokens in the prompt
1766
  // - n_past: number of past tokens to prefix the prompt with
1767
  //
1768
+ static bool whisper_decode_internal(
1769
  whisper_context & wctx,
1770
+ whisper_state & wstate,
1771
  whisper_decoder & decoder,
1772
  const whisper_token * tokens,
1773
  const int n_tokens,
 
1782
 
1783
  WHISPER_ASSERT(!!kv_self.ctx);
1784
 
1785
+ auto & logits_out = wstate.logits;
1786
 
1787
  const int n_vocab = hparams.n_vocab;
1788
 
 
1792
  const int n_layer = hparams.n_text_layer;
1793
 
1794
  const int N = n_tokens;
1795
+ const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
1796
 
1797
  //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
1798
 
1799
  struct ggml_init_params params;
1800
+ params.mem_size = wstate.buf_compute.size();
1801
+ params.mem_buffer = wstate.buf_compute.data();
1802
 
1803
  struct ggml_context * ctx0 = ggml_init(params);
1804
 
 
1813
  ((int32_t *) position->data)[i] = n_past + i;
1814
  }
1815
 
1816
+ wstate.use_buf(ctx0, 3);
1817
 
1818
  // token encoding + position encoding
1819
  struct ggml_tensor * cur =
 
1828
 
1829
  // norm
1830
  {
1831
+ wstate.use_buf(ctx0, 0);
1832
 
1833
  cur = ggml_norm(ctx0, inpL);
1834
 
 
1842
 
1843
  // self-attention
1844
  {
1845
+ wstate.use_buf(ctx0, 1);
1846
 
1847
  struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
1848
  layer.attn_q_w,
 
1884
 
1885
  // ------
1886
 
1887
+ wstate.use_buf(ctx0, 0);
1888
 
1889
  struct ggml_tensor * Q =
1890
  ggml_permute(ctx0,
 
1900
  n_state/n_head, n_head, n_past + N),
1901
  0, 2, 1, 3);
1902
 
1903
+ wstate.use_buf(ctx0, 1);
1904
 
1905
  // K * Q
1906
  struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
1907
 
1908
+ wstate.use_buf(ctx0, 0);
1909
 
1910
  //struct ggml_tensor * KQ_scaled =
1911
  // ggml_scale(ctx0,
 
1915
 
1916
  struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
1917
 
1918
+ wstate.use_buf(ctx0, 1);
1919
 
1920
  struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
1921
 
1922
+ wstate.use_buf(ctx0, 0);
1923
 
1924
  struct ggml_tensor * V_trans =
1925
  ggml_permute(ctx0,
 
1928
  n_state/n_head, n_head, n_past + N),
1929
  1, 2, 0, 3);
1930
 
1931
+ wstate.use_buf(ctx0, 1);
1932
 
1933
  struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
1934
 
 
1941
 
1942
  // projection
1943
  {
1944
+ wstate.use_buf(ctx0, 0);
1945
 
1946
  cur = ggml_mul_mat(ctx0,
1947
  layer.attn_ln_1_w,
1948
  cur);
1949
 
1950
+ wstate.use_buf(ctx0, 1);
1951
 
1952
  cur = ggml_add(ctx0,
1953
  ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
1954
  cur);
1955
  }
1956
 
1957
+ wstate.use_buf(ctx0, 2);
1958
 
1959
  // add the input
1960
  struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
1961
 
1962
  // norm
1963
  {
1964
+ wstate.use_buf(ctx0, 0);
1965
 
1966
  cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
1967
 
1968
+ wstate.use_buf(ctx0, 1);
1969
 
1970
  // cur = ln_0_w*cur + ln_0_b
1971
  cur = ggml_add(ctx0,
 
1977
 
1978
  // cross-attention
1979
  {
1980
+ wstate.use_buf(ctx0, 0);
1981
 
1982
  struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
1983
  layer.cross_attn_q_w,
 
1994
  // Kcross is already scaled
1995
  struct ggml_tensor * Kcross =
1996
  ggml_reshape_3d(ctx0,
1997
+ ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state),
1998
  n_state/n_head, n_head, M);
1999
 
2000
  struct ggml_tensor * Vcross =
2001
  ggml_reshape_3d(ctx0,
2002
+ ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
2003
  n_state/n_head, n_head, M);
2004
 
2005
  struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);
2006
 
2007
  // ------
2008
 
2009
+ wstate.use_buf(ctx0, 1);
2010
 
2011
  struct ggml_tensor * Q =
2012
  ggml_permute(ctx0,
 
2017
 
2018
  struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
2019
 
2020
+ wstate.use_buf(ctx0, 0);
2021
 
2022
  // K * Q
2023
  struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
2031
  // no masking for cross-attention
2032
  //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
2033
 
2034
+ wstate.use_buf(ctx0, 1);
2035
 
2036
  struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
2037
 
2038
+ wstate.use_buf(ctx0, 0);
2039
 
2040
  struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
2041
 
2042
+ wstate.use_buf(ctx0, 1);
2043
 
2044
  struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2045
 
 
2051
 
2052
  // projection
2053
  {
2054
+ wstate.use_buf(ctx0, 0);
2055
 
2056
  cur = ggml_mul_mat(ctx0,
2057
  layer.cross_attn_ln_1_w,
2058
  cur);
2059
 
2060
+ wstate.use_buf(ctx0, 1);
2061
 
2062
  cur = ggml_add(ctx0,
2063
  ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
2064
  cur);
2065
  }
2066
 
2067
+ wstate.use_buf(ctx0, 2);
2068
 
2069
  // add the input
2070
  cur = ggml_add(ctx0, cur, inpCA);
 
2075
  {
2076
  // norm
2077
  {
2078
+ wstate.use_buf(ctx0, 0);
2079
 
2080
  cur = ggml_norm(ctx0, inpFF);
2081
 
2082
+ wstate.use_buf(ctx0, 1);
2083
 
2084
  // cur = mlp_ln_w*cur + mlp_ln_b
2085
  cur = ggml_add(ctx0,
 
2089
  ggml_repeat(ctx0, layer.mlp_ln_b, cur));
2090
  }
2091
 
2092
+ wstate.use_buf(ctx0, 0);
2093
 
2094
  // fully connected
2095
  cur = ggml_mul_mat(ctx0,
2096
  layer.mlp_0_w,
2097
  cur);
2098
 
2099
+ wstate.use_buf(ctx0, 1);
2100
 
2101
  cur = ggml_add(ctx0,
2102
  ggml_repeat(ctx0, layer.mlp_0_b, cur),
2103
  cur);
2104
 
2105
+ wstate.use_buf(ctx0, 0);
2106
 
2107
  // GELU activation
2108
  cur = ggml_gelu(ctx0, cur);
2109
 
2110
+ wstate.use_buf(ctx0, 1);
2111
 
2112
  // projection
2113
  cur = ggml_mul_mat(ctx0,
2114
  layer.mlp_1_w,
2115
  cur);
2116
 
2117
+ wstate.use_buf(ctx0, 0);
2118
 
2119
  cur = ggml_add(ctx0,
2120
  ggml_repeat(ctx0, layer.mlp_1_b, cur),
2121
  cur);
2122
  }
2123
 
2124
+ wstate.use_buf(ctx0, 3);
2125
 
2126
  inpL = ggml_add(ctx0, cur, inpFF);
2127
  }
 
2130
 
2131
  // norm
2132
  {
2133
+ wstate.use_buf(ctx0, 0);
2134
 
2135
  cur = ggml_norm(ctx0, cur);
2136
 
2137
+ wstate.use_buf(ctx0, 1);
2138
 
2139
  cur = ggml_add(ctx0,
2140
  ggml_mul(ctx0,
 
2143
  ggml_repeat(ctx0, model.d_ln_b, cur));
2144
  }
2145
 
2146
+ wstate.use_buf(ctx0, 0);
2147
 
2148
  // compute logits only for the last token
2149
  // comment this line to compute logits for all N tokens
 
2152
 
2153
  struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
2154
 
2155
+ wstate.use_buf(ctx0, -1);
2156
 
2157
  // run the computation
2158
  {
 
2179
 
2180
  ggml_free(ctx0);
2181
 
2182
+ wstate.t_decode_us += ggml_time_us() - t_start_us;
2183
+ wstate.n_decode++;
2184
 
2185
  return true;
2186
  }
 
2284
 
2285
  // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
2286
  static bool log_mel_spectrogram(
2287
+ whisper_state & wstate,
2288
  const float * samples,
2289
  const int n_samples,
2290
  const int /*sample_rate*/,
 
2404
  mel.data[i] = (mel.data[i] + 4.0)/4.0;
2405
  }
2406
 
2407
+ wstate.t_mel_us += ggml_time_us() - t_start_us;
2408
 
2409
  return true;
2410
  }
 
2478
  // interface implementation
2479
  //
2480
 
2481
+ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2482
+ whisper_state * state = new whisper_state;
2483
+
2484
+ const size_t scale = ctx->model.hparams.f16 ? 1 : 2;
2485
+
2486
+
2487
+ if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) {
2488
+ fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
2489
+ return nullptr;
2490
+ }
2491
+
2492
+ {
2493
+ const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
2494
+ fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2495
+ }
2496
+
2497
+ if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->wtype, ctx->model.hparams.n_audio_ctx)) {
2498
+ fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
2499
+ return nullptr;
2500
+ }
2501
+
2502
+ {
2503
+ const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v);
2504
+ fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2505
+ }
2506
+
2507
+
2508
+ state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
2509
+
2510
+ state->logits_id.reserve(ctx->model.hparams.n_vocab);
2511
+
2512
+ // TAGS: WHISPER_DECODER_INIT
2513
+ state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
2514
+
2515
+ state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
2516
+ state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
2517
+ state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
2518
+ state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type)));
2519
+
2520
+ state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
2521
+ state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
2522
+ state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type));
2523
+ state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type));
2524
+
2525
+ state->rng = std::mt19937(0);
2526
+
2527
+ return state;
2528
+ }
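
Because whisper_init_state() allocates all mutable buffers (KV caches, mel, logits, scratch) per state, a single loaded model can back several independent states. A sketch, not part of this commit:

    // two inference states sharing the same read-only weights in ctx;
    // each owns its KV caches, mel spectrogram, logits and result segments
    struct whisper_state * state_a = whisper_init_state(ctx);
    struct whisper_state * state_b = whisper_init_state(ctx);

    // ... run whisper_full_with_state() on each, e.g. from different threads ...

    whisper_free_state(state_a);
    whisper_free_state(state_b);
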
2529
+
2530
+ struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
2531
  whisper_model_loader loader = {};
2532
 
2533
  fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
 
2555
  fin->close();
2556
  };
2557
 
2558
+ return whisper_init_no_state(&loader);
2559
  }
2560
 
2561
+ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
2562
  struct buf_context {
2563
  uint8_t* buffer;
2564
  size_t size;
 
2591
 
2592
  loader.close = [](void * /*ctx*/) { };
2593
 
2594
+ return whisper_init_no_state(&loader);
2595
  }
2596
 
2597
+ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) {
2598
  ggml_time_init();
2599
 
2600
  whisper_context * ctx = new whisper_context;
 
2611
  return ctx;
2612
  }
2613
 
2614
+ struct whisper_context * whisper_init_from_file(const char * path_model) {
2615
+ whisper_context * ctx = whisper_init_from_file_no_state(path_model);
2616
+ if (!ctx) {
2617
+ return nullptr;
2618
+ }
2619
+
2620
+ ctx->state = whisper_init_state(ctx);
2621
+ if (!ctx->state) {
2622
+ whisper_free(ctx);
2623
+ return nullptr;
2624
+ }
2625
+
2626
+ return ctx;
2627
+ }
2628
+
2629
+ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
2630
+ whisper_context * ctx = whisper_init_from_buffer_no_state(buffer, buffer_size);
2631
+ if (!ctx) {
2632
+ return nullptr;
2633
+ }
2634
+
2635
+ ctx->state = whisper_init_state(ctx);
2636
+ if (!ctx->state) {
2637
+ whisper_free(ctx);
2638
+ return nullptr;
2639
+ }
2640
+
2641
+ return ctx;
2642
+ }
2643
+
2644
+ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
2645
+ whisper_context * ctx = whisper_init_no_state(loader);
2646
+ if (!ctx) {
2647
+ return nullptr;
2648
+ }
2649
+
2650
+ ctx->state = whisper_init_state(ctx);
2651
+ if (!ctx->state) {
2652
+ whisper_free(ctx);
2653
+ return nullptr;
2654
+ }
2655
+
2656
+ return ctx;
2657
+ }
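
The classic initializers keep their old behavior by attaching a default state right away, so existing callers need no changes:

    // unchanged call; internally equivalent to whisper_init_from_file_no_state() + whisper_init_state()
    struct whisper_context * ctx = whisper_init_from_file("models/ggml-base.en.bin");
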
2658
+
2659
+ void whisper_free_state(struct whisper_state * state)
2660
+ {
2661
+ if (state) {
2662
+ kv_cache_free(state->kv_cross);
2663
+
2664
+ for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
2665
+ kv_cache_free(state->decoders[i].kv_self);
2666
+ }
2667
+
2668
+ delete state;
2669
+ }
2670
+ }
2671
+
2672
  void whisper_free(struct whisper_context * ctx) {
2673
  if (ctx) {
2674
  if (ctx->model.ctx) {
 
2677
  if (ctx->model.buf) {
2678
  delete ctx->model.buf;
2679
  }
2680
+
2681
+ whisper_free_state(ctx->state);
2682
+
2683
  delete ctx;
2684
  }
2685
  }
2686
 
2687
+ int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
2688
+ if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
2689
+ fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
2690
+ return -1;
2691
+ }
2692
+
2693
+ return 0;
2694
+ }
2695
+
2696
  int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
2697
+ return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
2698
+ }
2699
+
2700
+ // same as whisper_pcm_to_mel_with_state, but applies a Phase Vocoder to speed up the audio x2
2701
+ int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
2702
+ if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, state->mel)) {
2703
  fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
2704
  return -1;
2705
  }
 
2709
 
2710
  // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
2711
  int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
2712
+ return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
2713
+ }
2714
+
2715
+ int whisper_set_mel_with_state(
2716
+ struct whisper_context * /*ctx*/,
2717
+ struct whisper_state * state,
2718
+ const float * data,
2719
+ int n_len,
2720
+ int n_mel) {
2721
+ if (n_mel != WHISPER_N_MEL) {
2722
+ fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
2723
  return -1;
2724
  }
2725
 
2726
+ state->mel.n_len = n_len;
2727
+ state->mel.n_mel = n_mel;
2728
+
2729
+ state->mel.data.resize(n_len*n_mel);
2730
+ memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
2731
+
2732
  return 0;
2733
  }
2734
 
 
2737
  const float * data,
2738
  int n_len,
2739
  int n_mel) {
2740
+ return whisper_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel);
2741
+ }
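
A short sketch of feeding an externally computed spectrogram into a specific state (mel_data and n_frames are hypothetical placeholders; n_mel must be 80):

    // mel_data points to n_frames * WHISPER_N_MEL floats computed outside of whisper.cpp
    if (whisper_set_mel_with_state(ctx, state, mel_data, n_frames, WHISPER_N_MEL) != 0) {
        fprintf(stderr, "failed to set the mel spectrogram\n");
    }
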
2742
+
2743
+ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
2744
+ if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
2745
+ fprintf(stderr, "%s: failed to eval\n", __func__);
2746
  return -1;
2747
  }
2748
 
2749
  return 0;
2750
  }
2751
 
2752
  int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
2753
+ if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
2754
  fprintf(stderr, "%s: failed to eval\n", __func__);
2755
  return -1;
2756
  }
 
2758
  return 0;
2759
  }
2760
 
2761
+ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
2762
+ const int selected_decoder_id = 0;
2763
+
2764
+ if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
2765
+ fprintf(stderr, "%s: failed to eval\n", __func__);
2766
+ return 1;
2767
+ }
2768
+
2769
+ return 0;
2770
+ }
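
Putting the per-state primitives together, a minimal manual encode/decode pass could look like this (a sketch; pcm/n_pcm are placeholders, 4 threads assumed, whisper_get_logits_from_state() is defined further below):

    whisper_pcm_to_mel_with_state(ctx, state, pcm, n_pcm, 4); // PCM -> log mel, stored in the state
    whisper_encode_with_state(ctx, state, 0, 4);              // encoder pass over the spectrogram

    whisper_token prompt[1] = { whisper_token_sot(ctx) };
    whisper_decode_with_state(ctx, state, prompt, 1, 0, 4);   // one decoder step

    const float * logits = whisper_get_logits_from_state(state); // n_vocab logits for the next token
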
2771
+
2772
  int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
2773
+ // TODO: add selected_decoder_id to state
2774
  const int selected_decoder_id = 0;
2775
 
2776
+ if (ctx->state == nullptr) {
2777
+ fprintf(stderr, "%s: ERROR state was not loaded.\n", __func__);
2778
+ return false;
2779
+ }
2780
+
2781
+
2782
+ if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
2783
  fprintf(stderr, "%s: failed to eval\n", __func__);
2784
  return 1;
2785
  }
 
2837
  return nullptr;
2838
  }
2839
 
2840
+ int whisper_lang_auto_detect_with_state(
2841
  struct whisper_context * ctx,
2842
+ struct whisper_state * state,
2843
+ int offset_ms,
2844
+ int n_threads,
2845
+ float * lang_probs) {
2846
  const int seek = offset_ms/10;
2847
 
2848
  if (seek < 0) {
 
2850
  return -1;
2851
  }
2852
 
2853
+ if (seek >= state->mel.n_len) {
2854
+ fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len*10);
2855
  return -2;
2856
  }
2857
 
 
2863
 
2864
  const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
2865
 
2866
+ if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
2867
  fprintf(stderr, "%s: failed to decode\n", __func__);
2868
  return -7;
2869
  }
2870
 
2871
+ auto & logits_id = state->logits_id;
2872
  logits_id.clear();
2873
 
2874
  for (const auto & kv : g_lang) {
2875
  const auto token_lang = whisper_token_lang(ctx, kv.second.first);
2876
+ logits_id.emplace_back(state->logits[token_lang], kv.second.first);
2877
  }
2878
 
2879
  // sort descending
 
2912
  return logits_id[0].second;
2913
  }
2914
 
2915
+ int whisper_lang_auto_detect(
2916
+ struct whisper_context * ctx,
2917
+ int offset_ms,
2918
+ int n_threads,
2919
+ float * lang_probs) {
2920
+ return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs);
2921
+ }
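
For example, detecting the language on a given state once its spectrogram has been computed (a sketch):

    std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);

    const int lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, 4, probs.data());
    if (lang_id >= 0) {
        printf("detected language: %s (p = %f)\n", whisper_lang_str(lang_id), probs[lang_id]);
    }
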
2922
+
2923
+ int whisper_n_len_from_state(struct whisper_state * state) {
2924
+ return state->mel.n_len;
2925
+ }
2926
+
2927
  int whisper_n_len(struct whisper_context * ctx) {
2928
+ return ctx->state->mel.n_len;
2929
  }
2930
 
2931
  int whisper_n_vocab(struct whisper_context * ctx) {
 
2945
  }
2946
 
2947
  float * whisper_get_logits(struct whisper_context * ctx) {
2948
+ return ctx->state->logits.data();
2949
+ }
2950
+
2951
+
2952
+ float * whisper_get_logits_from_state(struct whisper_state * state) {
2953
+ return state->logits.data();
2954
  }
2955
 
2956
  const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
 
2996
  void whisper_print_timings(struct whisper_context * ctx) {
2997
  const int64_t t_end_us = ggml_time_us();
2998
 
 
 
 
 
2999
  fprintf(stderr, "\n");
3000
+ fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
3001
+ if (ctx->state != nullptr) {
3002
+
3003
+ const int32_t n_sample = std::max(1, ctx->state->n_sample);
3004
+ const int32_t n_encode = std::max(1, ctx->state->n_encode);
3005
+ const int32_t n_decode = std::max(1, ctx->state->n_decode);
3006
+
3007
+ fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
3008
+ fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
3009
+ fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
3010
+ fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
3011
+ fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
3012
+ }
3013
  fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
3014
  }
3015
 
3016
  void whisper_reset_timings(struct whisper_context * ctx) {
3017
+ if (ctx->state != nullptr) {
3018
+ ctx->state->t_sample_us = 0;
3019
+ ctx->state->t_encode_us = 0;
3020
+ ctx->state->t_decode_us = 0;
3021
+ }
3022
  }
3023
 
3024
  const char * whisper_print_system_info(void) {
 
3131
  static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
3132
  static void whisper_exp_compute_token_level_timestamps(
3133
  struct whisper_context & ctx,
3134
+ struct whisper_state & state,
3135
  int i_segment,
3136
  float thold_pt,
3137
  float thold_ptsum);
 
3164
 
3165
  // wrap the last segment to max_len characters
3166
  // returns the number of new segments
3167
+ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) {
3168
+ auto segment = state.result_all.back();
3169
 
3170
  int res = 1;
3171
  int acc = 0;
 
3187
  trim(text);
3188
  }
3189
 
3190
+ state.result_all.back().text = std::move(text);
3191
+ state.result_all.back().t1 = token.t0;
3192
+ state.result_all.back().tokens.resize(i);
3193
 
3194
+ state.result_all.push_back({});
3195
+ state.result_all.back().t0 = token.t0;
3196
+ state.result_all.back().t1 = segment.t1;
3197
 
3198
  // add tokens [i, end] to the new segment
3199
+ state.result_all.back().tokens.insert(
3200
+ state.result_all.back().tokens.end(),
3201
  segment.tokens.begin() + i,
3202
  segment.tokens.end());
3203
 
3204
  acc = 0;
3205
  text = "";
3206
 
3207
+ segment = state.result_all.back();
3208
  i = -1;
3209
 
3210
  res++;
 
3217
  if (split_on_word) {
3218
  trim(text);
3219
  }
3220
+ state.result_all.back().text = std::move(text);
3221
 
3222
  return res;
3223
  }
 
3234
  // - computes logprobs and probs
3235
  static void whisper_process_logits(
3236
  struct whisper_context & ctx,
3237
+ struct whisper_state & state,
3238
  const struct whisper_full_params params,
3239
  struct whisper_decoder & decoder,
3240
  float temperature) {
 
3253
  auto & logprobs = decoder.logprobs;
3254
  {
3255
  logits.resize(n_logits);
3256
+ memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float));
3257
 
3258
  if (temperature > 0.0f) {
3259
  for (int i = 0; i < n_logits; i++) {
 
3291
  logits[vocab.token_transcribe] = -INFINITY;
3292
 
3293
  if (params.logits_filter_callback) {
3294
+ params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
3295
  }
3296
 
3297
  // suppress non-speech tokens
 
3452
 
3453
  static whisper_token_data whisper_sample_token(
3454
  whisper_context & ctx,
3455
+ whisper_state & state,
3456
  const whisper_decoder & decoder,
3457
  bool best) {
3458
  whisper_token_data result = {
 
3497
  } else {
3498
  std::discrete_distribution<> dist(probs.begin(), probs.end());
3499
 
3500
+ result.id = dist(state.rng);
3501
  result.p = probs[result.id];
3502
  result.plog = logprobs[result.id];
3503
  }
 
3507
  result.pt = result.p;
3508
  }
3509
 
3510
+ state.n_sample++;
3511
 
3512
  return result;
3513
  }
3514
 
3515
  static std::vector<whisper_token_data> whisper_sample_token_topk(
3516
  whisper_context & ctx,
3517
+ whisper_state & state,
3518
  const whisper_decoder & decoder,
3519
  int k) {
3520
  const auto & vocab = ctx.vocab;
 
3525
 
3526
  const int n_logits = vocab.n_vocab;
3527
 
3528
+ auto & logits_id = state.logits_id;
3529
 
3530
  logits_id.clear();
3531
  for (int i = 0; i < n_logits; ++i) {
 
3578
  }
3579
  }
3580
 
3581
+ state.n_sample++;
3582
 
3583
  return result;
3584
  }
 
3632
  }
3633
  }
3634
 
3635
+ int whisper_full_with_state(
3636
  struct whisper_context * ctx,
3637
+ struct whisper_state * state,
3638
+ struct whisper_full_params params,
3639
+ const float * samples,
3640
+ int n_samples) {
3641
  // clear old results
3642
+ auto & result_all = state->result_all;
3643
 
3644
  result_all.clear();
3645
 
3646
  // compute log mel spectrogram
3647
  if (params.speed_up) {
3648
+ if (whisper_pcm_to_mel_phase_vocoder_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
3649
  fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
3650
  return -1;
3651
  }
3652
  } else {
3653
+ if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
3654
  fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
3655
  return -2;
3656
  }
 
3660
  if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
3661
  std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
3662
 
3663
+ const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
3664
  if (lang_id < 0) {
3665
  fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
3666
  return -3;
3667
  }
3668
+ state->lang_id = lang_id;
3669
  params.language = whisper_lang_str(lang_id);
3670
 
3671
  fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
3672
  }
3673
 
3674
  if (params.token_timestamps) {
3675
+ state->t_beg = 0;
3676
+ state->t_last = 0;
3677
+ state->tid_last = 0;
3678
+ state->energy = get_signal_energy(samples, n_samples, 32);
3679
  }
3680
 
3681
  const int seek_start = params.offset_ms/10;
3682
+ const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len_from_state(state) : params.duration_ms/10);
3683
 
3684
  // if length of spectrogram is less than 1s (100 samples), then return
3685
  // basically don't process anything that is less than 1s
 
3717
 
3718
  // TAGS: WHISPER_DECODER_INIT
3719
  for (int j = 1; j < n_decoders; j++) {
3720
+ auto & decoder = state->decoders[j];
3721
 
3722
  if (decoder.kv_self.ctx == nullptr) {
3723
+ decoder.kv_self = state->decoders[0].kv_self;
3724
  if (!kv_cache_reinit(decoder.kv_self)) {
3725
  fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
3726
  return -4;
 
3728
 
3729
  WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
3730
 
3731
+ decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
3732
 
3733
  decoder.probs.resize (ctx->vocab.n_vocab);
3734
  decoder.logits.resize (ctx->vocab.n_vocab);
 
3737
  }
3738
 
3739
  // the accumulated text context so far
3740
+ auto & prompt_past = state->prompt_past;
3741
  if (params.no_context) {
3742
  prompt_past.clear();
3743
  }
 
3756
  fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
3757
  return -5;
3758
  }
3759
+ state->exp_n_audio_ctx = params.audio_ctx;
3760
 
3761
  // these tokens determine the task that will be performed
3762
  std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
3763
  if (whisper_is_multilingual(ctx)) {
3764
  const int lang_id = whisper_lang_id(params.language);
3765
+ state->lang_id = lang_id;
3766
  prompt_init.push_back(whisper_token_lang(ctx, lang_id));
3767
  if (params.translate) {
3768
  prompt_init.push_back(whisper_token_translate());
 
3814
  }
3815
 
3816
  if (params.encoder_begin_callback) {
3817
+ if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
3818
  fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
3819
  break;
3820
  }
3821
  }
3822
 
3823
  // encode audio features starting at offset seek
3824
+ if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
3825
  fprintf(stderr, "%s: failed to encode\n", __func__);
3826
  return -6;
3827
  }
 
3862
 
3863
  // TAGS: WHISPER_DECODER_INIT
3864
  for (int j = 0; j < n_decoders_cur; ++j) {
3865
+ auto & decoder = state->decoders[j];
3866
 
3867
  decoder.kv_self.n = 0;
3868
 
 
3904
  }
3905
  WHISPER_PRINT_DEBUG("\n\n");
3906
 
3907
+ if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
3908
  fprintf(stderr, "%s: failed to decode\n", __func__);
3909
  return -7;
3910
  }
 
3912
  {
3913
  const int64_t t_start_sample_us = ggml_time_us();
3914
 
3915
+ whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
3916
 
3917
+ state->decoders[0].kv_self.n += prompt.size();
3918
 
3919
  for (int j = 1; j < n_decoders_cur; ++j) {
3920
+ auto & decoder = state->decoders[j];
3921
 
3922
+ memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
3923
+ memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
3924
 
3925
  decoder.kv_self.n += prompt.size();
3926
 
3927
+ memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
3928
+ memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
3929
+ memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
3930
  }
3931
 
3932
+ state->t_sample_us += ggml_time_us() - t_start_sample_us;
3933
  }
3934
  }
3935
 
 
3940
  if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
3941
  kv_bufs.resize(n_decoders_cur);
3942
  for (int j = 0; j < n_decoders_cur; ++j) {
3943
+ auto & decoder = state->decoders[j];
3944
 
3945
  if (decoder.completed || decoder.failed) {
3946
  continue;
 
3958
 
3959
  // generate new sequence candidates for each decoder
3960
  for (int j = 0; j < n_decoders_cur; ++j) {
3961
+ auto & decoder = state->decoders[j];
3962
 
3963
  if (decoder.completed || decoder.failed) {
3964
  continue;
 
3968
  case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
3969
  {
3970
  if (t_cur < 1e-6f) {
3971
+ decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
3972
  } else {
3973
+ decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
3974
  }
3975
 
3976
  decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
3977
  } break;
3978
  case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
3979
  {
3980
+ const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
3981
 
3982
  for (const auto & token : tokens_new) {
3983
  beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
 
4002
  uint32_t cur_c = 0;
4003
 
4004
  for (int j = 0; j < n_decoders_cur; ++j) {
4005
+ auto & decoder = state->decoders[j];
4006
 
4007
  if (decoder.completed || decoder.failed) {
4008
  continue;
 
4031
  // - check if the sequence is failed
4032
  // - update sliding window based on timestamp tokens
4033
  for (int j = 0; j < n_decoders_cur; ++j) {
4034
+ auto & decoder = state->decoders[j];
4035
 
4036
  if (decoder.completed || decoder.failed) {
4037
  continue;
 
4113
  bool completed_all = true;
4114
 
4115
  for (int j = 0; j < n_decoders_cur; ++j) {
4116
+ auto & decoder = state->decoders[j];
4117
 
4118
  if (decoder.completed || decoder.failed) {
4119
  continue;
 
4127
  }
4128
  }
4129
 
4130
+ state->t_sample_us += ggml_time_us() - t_start_sample_us;
4131
 
4132
  // obtain logits for the next token
4133
  for (int j = 0; j < n_decoders_cur; ++j) {
4134
+ auto & decoder = state->decoders[j];
4135
 
4136
  if (decoder.failed || decoder.completed) {
4137
  continue;
 
4142
 
4143
  //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
4144
 
4145
+ if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
4146
  fprintf(stderr, "%s: failed to decode\n", __func__);
4147
  return -8;
4148
  }
 
4150
  {
4151
  const int64_t t_start_sample_us = ggml_time_us();
4152
 
4153
+ whisper_process_logits(*ctx, *state, params, decoder, t_cur);
4154
 
4155
  ++decoder.kv_self.n;
4156
 
4157
+ state->t_sample_us += ggml_time_us() - t_start_sample_us;
4158
  }
4159
  }
4160
  }
 
4164
  double best_score = -INFINITY;
4165
 
4166
  for (int j = 0; j < n_decoders_cur; ++j) {
4167
+ auto & decoder = state->decoders[j];
4168
 
4169
  if (decoder.failed) {
4170
  continue;
 
4181
  __func__, j, decoder.sequence.entropy, params.entropy_thold);
4182
 
4183
  decoder.failed = true;
4184
+ state->n_fail_h++;
4185
 
4186
  continue;
4187
  }
 
4199
  {
4200
  bool success = true;
4201
 
4202
+ const auto & decoder = state->decoders[best_decoder_id];
4203
 
4204
  if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
4205
  success = false;
4206
+ state->n_fail_p++;
4207
  }
4208
 
4209
  if (success) {
 
4220
 
4221
  // output results through a user-provided callback
4222
  {
4223
+ const auto & best_decoder = state->decoders[best_decoder_id];
4224
 
4225
  const auto seek_delta = best_decoder.seek_delta;
4226
  const auto result_len = best_decoder.sequence.result_len;
 
4283
 
4284
  if (params.token_timestamps) {
4285
  whisper_exp_compute_token_level_timestamps(
4286
+ *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
4287
 
4288
  if (params.max_len > 0) {
4289
+ n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
4290
  }
4291
  }
4292
  if (params.new_segment_callback) {
4293
+ params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
4294
  }
4295
  }
4296
  text = "";
 
4327
 
4328
  if (params.token_timestamps) {
4329
  whisper_exp_compute_token_level_timestamps(
4330
+ *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
4331
 
4332
  if (params.max_len > 0) {
4333
+ n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
4334
  }
4335
  }
4336
  if (params.new_segment_callback) {
4337
+ params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
4338
  }
4339
  }
4340
  }
 
4349
  return 0;
4350
  }
4351
 
4352
+
4353
+ int whisper_full(
4354
+ struct whisper_context * ctx,
4355
+ struct whisper_full_params params,
4356
+ const float * samples,
4357
+ int n_samples) {
4358
+ return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples);
4359
+ }
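
This split is what enables the headline use case: several transcriptions running concurrently on one shared context, each with its own state. A sketch (state_a/state_b come from whisper_init_state(), pcm_a/pcm_b are two independent recordings):

    auto work = [ctx](whisper_state * st, const float * pcm, int n) {
        whisper_full_with_state(ctx, st, whisper_full_default_params(WHISPER_SAMPLING_GREEDY), pcm, n);
    };

    std::thread t_a(work, state_a, pcm_a.data(), (int) pcm_a.size());
    std::thread t_b(work, state_b, pcm_b.data(), (int) pcm_b.size());

    t_a.join();
    t_b.join();
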
4360
+
4361
  int whisper_full_parallel(
4362
  struct whisper_context * ctx,
4363
  struct whisper_full_params params,
 
4367
  if (n_processors == 1) {
4368
  return whisper_full(ctx, params, samples, n_samples);
4369
  }
 
4370
  int ret = 0;
4371
 
4372
+ // prepare separate states for each thread
4373
+ std::vector<whisper_state *> states;
4374
 
4375
  const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
4376
  const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
 
4380
 
4381
  std::vector<std::thread> workers(n_processors - 1);
4382
  for (int i = 0; i < n_processors - 1; ++i) {
4383
+ // create a new state for each thread
4384
+ states.push_back(whisper_init_state(ctx));
4385
+
4386
  const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
4387
  const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
4388
 
 
4395
  params_cur.new_segment_callback = nullptr;
4396
  params_cur.new_segment_callback_user_data = nullptr;
4397
 
4398
+ workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur);
4399
  }
4400
 
4401
  {
4402
  auto params_cur = params;
4403
 
4404
+ // We need to disable print_realtime for this chunk as well, otherwise only the first chunk would be shown in real time.
4405
+ params_cur.print_realtime = false;
4406
+
4407
+ // Transcribe the first chunk on the current thread, using the default state of the context.
4408
+ ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
4409
  }
4410
 
4411
  for (int i = 0; i < n_processors - 1; ++i) {
 
4414
 
4415
  const int64_t offset_t = (int64_t) params.offset_ms/10.0;
4416
 
4417
+ // combine results into ctx->state->result_all from all other states
4418
  for (int i = 0; i < n_processors - 1; ++i) {
4419
+ auto & results_i = states[i]->result_all;
4420
 
4421
+ for (auto & result : results_i) {
4422
  // correct the segment timestamp taking into account the offset
4423
+ result.t0 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
4424
+ result.t1 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
4425
+
4426
 
4427
  // make sure that segments are not overlapping
4428
+ if (!ctx->state->result_all.empty()) {
4429
+ result.t0 = std::max(result.t0, ctx->state->result_all.back().t1);
4430
  }
4431
 
4432
+ ctx->state->result_all.push_back(std::move(result));
4433
 
4434
  // call the new_segment_callback for each segment
4435
  if (params.new_segment_callback) {
4436
+ params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data);
4437
  }
4438
  }
4439
 
4440
+ ctx->state->t_mel_us += states[i]->t_mel_us;
4441
 
4442
+ ctx->state->t_sample_us += states[i]->t_sample_us;
4443
+ ctx->state->t_encode_us += states[i]->t_encode_us;
4444
+ ctx->state->t_decode_us += states[i]->t_decode_us;
4445
 
4446
+ whisper_free_state(states[i]);
 
 
4447
  }
4448
 
4449
  // average the timings
4450
+ ctx->state->t_mel_us /= n_processors;
4451
+ ctx->state->t_sample_us /= n_processors;
4452
+ ctx->state->t_encode_us /= n_processors;
4453
+ ctx->state->t_decode_us /= n_processors;
4454
 
4455
  // print information about the audio boundaries
4456
  fprintf(stderr, "\n");
 
4463
  return ret;
4464
  }
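
From the caller's point of view nothing changes here: whisper_full_parallel() keeps its signature and now simply creates one fresh state per extra processor, merging the per-state results back into ctx->state:

    // same call as before; internally uses n_processors - 1 extra whisper_state instances
    whisper_full_parallel(ctx, params, pcmf32.data(), (int) pcmf32.size(), 4);
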
4465
 
4466
+ int whisper_full_n_segments_from_state(struct whisper_state * state) {
4467
+ return state->result_all.size();
4468
+ }
4469
+
4470
  int whisper_full_n_segments(struct whisper_context * ctx) {
4471
+ return ctx->state->result_all.size();
4472
+ }
4473
+
4474
+ int whisper_full_lang_id_from_state(struct whisper_state * state) {
4475
+ return state->lang_id;
4476
  }
4477
 
4478
  int whisper_full_lang_id(struct whisper_context * ctx) {
4479
+ return ctx->state->lang_id;
4480
+ }
4481
+
4482
+ int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
4483
+ return state->result_all[i_segment].t0;
4484
  }
4485
 
4486
  int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
4487
+ return ctx->state->result_all[i_segment].t0;
4488
+ }
4489
+
4490
+ int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
4491
+ return state->result_all[i_segment].t1;
4492
  }
4493
 
4494
  int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
4495
+ return ctx->state->result_all[i_segment].t1;
4496
+ }
4497
+
4498
+ const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) {
4499
+ return state->result_all[i_segment].text.c_str();
4500
  }
4501
 
4502
  const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
4503
+ return ctx->state->result_all[i_segment].text.c_str();
4504
+ }
4505
+
4506
+ int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment) {
4507
+ return state->result_all[i_segment].tokens.size();
4508
  }
4509
 
4510
  int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
4511
+ return ctx->state->result_all[i_segment].tokens.size();
4512
  }
4513
 
4514
+ const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token) {
4515
+ return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str();
4516
+ }
4517
+
4518
+ const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
4519
+ return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str();
4520
+ }
4521
+
4522
+ whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token) {
4523
+ return state->result_all[i_segment].tokens[i_token].id;
4524
  }
4525
 
4526
  whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
4527
+ return ctx->state->result_all[i_segment].tokens[i_token].id;
4528
+ }
4529
+
4530
+ struct whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token) {
4531
+ return state->result_all[i_segment].tokens[i_token];
4532
  }
4533
 
4534
  struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
4535
+ return ctx->state->result_all[i_segment].tokens[i_token];
4536
+ }
4537
+
4538
+ float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token) {
4539
+ return state->result_all[i_segment].tokens[i_token].p;
4540
  }
4541
 
4542
  float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
4543
+ return ctx->state->result_all[i_segment].tokens[i_token].p;
4544
  }
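
With the _from_state accessors, results can be read without going through the context, e.g. after a whisper_full_with_state() call (a sketch):

    const int n_segments = whisper_full_n_segments_from_state(state);
    for (int i = 0; i < n_segments; ++i) {
        printf("[%8lld -> %8lld] %s\n",
            (long long) whisper_full_get_segment_t0_from_state(state, i),
            (long long) whisper_full_get_segment_t1_from_state(state, i),
            whisper_full_get_segment_text_from_state(state, i));
    }
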
4545
 
4546
  // =================================================================================================
 
4752
 
4753
  static void whisper_exp_compute_token_level_timestamps(
4754
  struct whisper_context & ctx,
4755
+ struct whisper_state & state,
4756
  int i_segment,
4757
  float thold_pt,
4758
  float thold_ptsum) {
4759
+ auto & segment = state.result_all[i_segment];
4760
  auto & tokens = segment.tokens;
4761
 
4762
+ const int n_samples = state.energy.size();
4763
 
4764
  if (n_samples == 0) {
4765
  fprintf(stderr, "%s: no signal data available\n", __func__);
 
4782
  return;
4783
  }
4784
 
4785
+ auto & t_beg = state.t_beg;
4786
+ auto & t_last = state.t_last;
4787
+ auto & tid_last = state.tid_last;
4788
 
4789
  for (int j = 0; j < n; ++j) {
4790
  auto & token = tokens[j];
 
4907
  float sum = 0.0f;
4908
 
4909
  for (int k = ss0; k < ss1; k++) {
4910
+ sum += state.energy[k];
4911
  }
4912
 
4913
  const float thold = 0.5*sum/ns;
4914
 
4915
  {
4916
  int k = s0;
4917
+ if (state.energy[k] > thold && j > 0) {
4918
+ while (k > 0 && state.energy[k] > thold) {
4919
  k--;
4920
  }
4921
  tokens[j].t0 = sample_to_timestamp(k);
 
4925
  s0 = k;
4926
  }
4927
  } else {
4928
+ while (state.energy[k] < thold && k < s1) {
4929
  k++;
4930
  }
4931
  s0 = k;
 
4935
 
4936
  {
4937
  int k = s1;
4938
+ if (state.energy[k] > thold) {
4939
+ while (k < n_samples - 1 && state.energy[k] > thold) {
4940
  k++;
4941
  }
4942
  tokens[j].t1 = sample_to_timestamp(k);
 
4946
  s1 = k;
4947
  }
4948
  } else {
4949
+ while (state.energy[k] < thold && k > s0) {
4950
  k--;
4951
  }
4952
  s1 = k;
whisper.h CHANGED
@@ -66,6 +66,7 @@ extern "C" {
66
  //
67
 
68
  struct whisper_context;
 
69
 
70
  typedef int whisper_token;
71
 
@@ -101,11 +102,20 @@ extern "C" {
101
  WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
102
  WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
103
 
104
- // Frees all memory allocated by the model.
105
- WHISPER_API void whisper_free(struct whisper_context * ctx);
106
 
107
  // Convert RAW PCM audio to log mel spectrogram.
108
- // The resulting spectrogram is stored inside the provided whisper context.
109
  // Returns 0 on success
110
  WHISPER_API int whisper_pcm_to_mel(
111
  struct whisper_context * ctx,
@@ -113,17 +123,30 @@ extern "C" {
113
  int n_samples,
114
  int n_threads);
115
 
116
- // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
117
- // The resulting spectrogram is stored inside the provided whisper context.
118
  // Returns 0 on success
119
  WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
120
- struct whisper_context* ctx,
121
- const float* samples,
122
- int n_samples,
123
- int n_threads);
124
-
125
-
126
127
  // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
128
  // n_mel must be 80
129
  // Returns 0 on success
@@ -133,7 +156,14 @@ extern "C" {
133
  int n_len,
134
  int n_mel);
135
 
136
- // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
137
  // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
138
  // offset can be used to specify the offset of the first frame in the spectrogram.
139
  // Returns 0 on success
@@ -142,6 +172,12 @@ extern "C" {
142
  int offset,
143
  int n_threads);
144
 
145
  // Run the Whisper decoder to obtain the logits and probabilities for the next token.
146
  // Make sure to call whisper_encode() first.
147
  // tokens + n_tokens is the provided context for the decoder.
@@ -155,6 +191,14 @@ extern "C" {
155
  int n_past,
156
  int n_threads);
157
 
158
  // Convert the provided text into tokens.
159
  // The tokens pointer must be large enough to hold the resulting tokens.
160
  // Returns the number of tokens on success, no more than n_max_tokens
@@ -190,17 +234,26 @@ extern "C" {
190
  int n_threads,
191
  float * lang_probs);
192
 
193
- WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
194
- WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
195
- WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
196
- WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
197
- WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
198
 
199
  // Token logits obtained from the last call to whisper_decode()
200
  // The logits for the last token are stored in the last row
201
  // Rows: n_tokens
202
  // Cols: n_vocab
203
- WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
 
204
 
205
  // Token Id -> String. Uses the vocabulary in the provided context
206
  WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
@@ -218,7 +271,7 @@ extern "C" {
218
  WHISPER_API whisper_token whisper_token_translate (void);
219
  WHISPER_API whisper_token whisper_token_transcribe(void);
220
 
221
- // Performance information
222
  WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
223
  WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
224
 
@@ -236,18 +289,19 @@ extern "C" {
236
  // Text segment callback
237
  // Called on every newly generated text segment
238
  // Use the whisper_full_...() functions to obtain the text segments
239
- typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
240
 
241
  // Encoder begin callback
242
  // If not NULL, called before the encoder starts
243
  // If it returns false, the computation is aborted
244
- typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
245
 
246
  // Logits filter callback
247
  // Can be used to modify the logits before sampling
248
  // If not NULL, called after applying temperature to logits
249
  typedef void (*whisper_logits_filter_callback)(
250
  struct whisper_context * ctx,
 
251
  const whisper_token_data * tokens,
252
  int n_tokens,
253
  float * logits,
@@ -334,6 +388,7 @@ extern "C" {
334
  WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
335
 
336
  // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
 
337
  // Uses the specified decoding strategy to obtain the text.
338
  WHISPER_API int whisper_full(
339
  struct whisper_context * ctx,
@@ -341,7 +396,16 @@ extern "C" {
341
  const float * samples,
342
  int n_samples);
343
 
344
- // Split the input audio in chunks and process each chunk separately using whisper_full()
 
 
 
 
 
 
 
 
 
345
  // It seems this approach can offer some speedup in some cases.
346
  // However, the transcription accuracy can be worse at the beginning and end of each chunk.
347
  WHISPER_API int whisper_full_parallel(
@@ -351,33 +415,47 @@ extern "C" {
351
  int n_samples,
352
  int n_processors);
353
 
354
- // Number of generated text segments.
355
  // A segment can be a few words, a sentence, or even a paragraph.
356
- WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
 
357
 
358
- // Language id associated with the current context
359
  WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
360
 
361
- // Get the start and end time of the specified segment.
362
- WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment);
363
- WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);
 
 
 
 
 
 
 
 
 
 
364
 
365
- // Get the text of the specified segment.
366
- WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment);
 
367
 
368
- // Get number of tokens in the specified segment.
369
- WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment);
 
370
 
371
- // Get the token text of the specified token in the specified segment.
372
- WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
373
- WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
374
 
375
- // Get token data for the specified token in the specified segment.
376
  // This contains probabilities, timestamps, etc.
377
- WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
 
378
 
379
- // Get the probability of the specified token in the specified segment.
380
- WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
 
381
 
382
  ////////////////////////////////////////////////////////////////////////////
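Editor's note: taken together, the `_no_state` loaders, `whisper_init_state()`, and the `_with_state`/`_from_state` variants let a single loaded model drive several independent decoding states, which is the stated motivation for this change (#523). A minimal usage sketch against the declarations above; it assumes that distinct states may run concurrently on one shared context, and the audio helper `load_pcmf32()` and the model path are placeholders, not part of the API:

```cpp
#include <cstdio>
#include <thread>
#include <vector>

#include "whisper.h"

// Hypothetical helper: returns 16 kHz mono float PCM samples.
std::vector<float> load_pcmf32(const char * path);

int main() {
    // Load the model once, without allocating the default state.
    struct whisper_context * ctx = whisper_init_from_file_no_state("ggml-base.en.bin");
    if (!ctx) {
        return 1;
    }

    const std::vector<float> pcm_a = load_pcmf32("a.wav");
    const std::vector<float> pcm_b = load_pcmf32("b.wav");

    // One state per in-flight transcription; the weights in ctx are shared.
    struct whisper_state * st_a = whisper_init_state(ctx);
    struct whisper_state * st_b = whisper_init_state(ctx);

    whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    std::thread th_a([&] { whisper_full_with_state(ctx, st_a, params, pcm_a.data(), (int) pcm_a.size()); });
    std::thread th_b([&] { whisper_full_with_state(ctx, st_b, params, pcm_b.data(), (int) pcm_b.size()); });
    th_a.join();
    th_b.join();

    // Results are read back through the *_from_state accessors.
    for (int i = 0; i < whisper_full_n_segments_from_state(st_a); ++i) {
        printf("a: %s\n", whisper_full_get_segment_text_from_state(st_a, i));
    }

    whisper_free_state(st_a);
    whisper_free_state(st_b);
    whisper_free(ctx);
    return 0;
}
```

Note that `whisper_full()` keeps its old single-call behavior by operating on the context's default state, so existing callers are unaffected.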
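Editor's note: each callback typedef gained a `struct whisper_state *` parameter, which is why the Go and Ruby bindings earlier in this commit were updated in lockstep. A sketch of a segment callback written against the new `whisper_new_segment_callback` typedef; the printing logic is illustrative, and the `new_segment_callback` field is assumed to carry over unchanged from the existing `whisper_full_params`:

```cpp
#include <cstdio>

#include "whisper.h"

// The state that produced the new segments is now passed alongside the
// context, so the callback can use the *_from_state accessors directly.
static void on_new_segment(struct whisper_context * /*ctx*/, struct whisper_state * state, int n_new, void * /*user_data*/) {
    const int n = whisper_full_n_segments_from_state(state);
    for (int i = n - n_new; i < n; ++i) {
        printf("[%lld -> %lld] %s\n",
               (long long) whisper_full_get_segment_t0_from_state(state, i),
               (long long) whisper_full_get_segment_t1_from_state(state, i),
               whisper_full_get_segment_text_from_state(state, i));
    }
}

// Installation is unchanged:
//
//     whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
//     params.new_segment_callback = on_new_segment;
```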