ggerganov committed on
Commit 8177527 · 1 Parent(s): 80b373b

talk-llama : sync llama.cpp

examples/talk-llama/llama-sampling.cpp CHANGED
@@ -63,6 +63,30 @@ static void llama_log_softmax(float * array, size_t size) {
  }
  */
 
  static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
  GGML_ASSERT(cur_p->size > 0);
 
@@ -427,6 +451,9 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
 
  static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
  auto * ctx = (llama_sampler_dist *) smpl->ctx;
  cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
  }
 
@@ -912,9 +939,8 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*
 
  static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
  const auto * ctx = (llama_sampler_temp *) smpl->ctx;
- for (size_t i = 0; i < cur_p->size; ++i) {
- cur_p->data[i].logit /= ctx->temp;
- }
  }
 
  static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
@@ -961,6 +987,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
  if (ctx->delta > 0) {
  const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
  const float max_temp = ctx->temp + ctx->delta;
  float exponent_val = ctx->exponent;
 
  // no need to do anything if there is only one (or zero) candidates
@@ -998,9 +1025,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
  #endif
 
  // Apply the dynamically calculated temperature scaling
- for (size_t i = 0; i < cur_p->size; ++i) {
- cur_p->data[i].logit /= dyn_temp;
- }
 
  // Re-compute softmax probabilities after scaling logits with dynamic temperature
  const double max_l_double = cur_p->data[0].logit;
@@ -1024,9 +1049,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
  }
  #endif
  } else {
- for (size_t i = 0; i < cur_p->size; ++i) {
- cur_p->data[i].logit /= ctx->temp;
- }
  }
  }
 
@@ -1059,6 +1082,101 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
  };
  }
 
  // mirostat
 
  struct llama_sampler_mirostat {
@@ -1565,6 +1683,397 @@ struct llama_sampler * llama_sampler_init_penalties(
  };
  }
 
  // logit-bias
 
  struct llama_sampler_logit_bias {
@@ -1644,6 +2153,229 @@ struct llama_sampler * llama_sampler_init_logit_bias(
  };
  }
 
  // utils
 
  uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
 
63
  }
64
  */
65
 
66
+ static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
67
+ if (temp <= 0.0f) {
68
+ // find the token with the highest logit and set the rest to -inf
69
+ size_t max_i = 0;
70
+ float max_l = cur_p->data[0].logit;
71
+
72
+ for (size_t i = 1; i < cur_p->size; ++i) {
73
+ if (cur_p->data[i ].logit > max_l) {
74
+ cur_p->data[max_i].logit = -INFINITY;
75
+ max_i = i;
76
+ max_l = cur_p->data[i].logit;
77
+ } else {
78
+ cur_p->data[i].logit = -INFINITY;
79
+ }
80
+ }
81
+
82
+ return;
83
+ }
84
+
85
+ for (size_t i = 0; i < cur_p->size; ++i) {
86
+ cur_p->data[i].logit /= temp;
87
+ }
88
+ }
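// Illustrative example (not from the upstream sources): for logits {1.0, 3.0, 2.0}, temp = 0.5
// rescales them to {2.0, 6.0, 4.0}, while temp <= 0.0f keeps only the maximum and masks the
// rest, i.e. {-INFINITY, 3.0, -INFINITY}, so any later softmax/dist step behaves greedily.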
89
+
90
  static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
91
  GGML_ASSERT(cur_p->size > 0);
92
 
 
451
 
452
  static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
453
  auto * ctx = (llama_sampler_dist *) smpl->ctx;
454
+
455
+ llama_sampler_softmax_impl(cur_p);
456
+
457
  cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
458
  }
459
 
 
939
 
940
  static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
941
  const auto * ctx = (llama_sampler_temp *) smpl->ctx;
942
+
943
+ llama_sampler_temp_impl(cur_p, ctx->temp);
 
944
  }
945
 
946
  static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
 
987
  if (ctx->delta > 0) {
988
  const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
989
  const float max_temp = ctx->temp + ctx->delta;
990
+
991
  float exponent_val = ctx->exponent;
992
 
993
  // no need to do anything if there is only one (or zero) candidates
 
1025
  #endif
1026
 
1027
  // Apply the dynamically calculated temperature scaling
1028
+ llama_sampler_temp_impl(cur_p, dyn_temp);
 
 
1029
 
1030
  // Re-compute softmax probabilities after scaling logits with dynamic temperature
1031
  const double max_l_double = cur_p->data[0].logit;
 
1049
  }
1050
  #endif
1051
  } else {
1052
+ llama_sampler_temp_impl(cur_p, ctx->temp);
 
 
1053
  }
1054
  }
1055
 
 
1082
  };
1083
  }
1084
 
1085
+ // xtc
1086
+
1087
+ struct llama_sampler_xtc {
1088
+ const float probability;
1089
+ const float threshold;
1090
+ const size_t min_keep;
1091
+
1092
+ const uint32_t seed;
1093
+ uint32_t seed_cur;
1094
+
1095
+ std::mt19937 rng;
1096
+ };
1097
+
1098
+ static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
1099
+ return "xtc";
1100
+ }
1101
+
1102
+ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1103
+ auto * ctx = (llama_sampler_xtc *) smpl->ctx;
1104
+
1105
+ if (ctx->probability <= 0.0f
1106
+ || ctx->threshold > 0.5f
1107
+ || cur_p->size < 2) {
1108
+ return;
1109
+ }
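// Note on the check above (assumption, not stated in the upstream sources): after softmax at most
// one candidate can have p >= 0.5, so a threshold above 0.5 can never select more than one token
// and the sampler would remove nothing; returning early effectively treats such values as "XTC off".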
1110
+
1111
+ std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
1112
+ float chance = distribution(ctx->rng);
1113
+ if (chance > ctx->probability) return;
1114
+
1115
+ // in case it's not sorted/recalculated yet
1116
+ llama_sampler_softmax_impl(cur_p);
1117
+
1118
+ int pos_last = 0;
1119
+
1120
+ for (size_t i = 0; i < cur_p->size; ++i) {
1121
+ if (cur_p->data[i].p >= ctx->threshold) {
1122
+ pos_last = i;
1123
+ } else break;
1124
+ }
1125
+
1126
+ if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
1127
+ cur_p->data += pos_last;
1128
+ cur_p->size -= pos_last;
1129
+ }
1130
+ }
1131
+
1132
+ static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
1133
+ const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
1134
+ auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);
1135
+
1136
+ // copy the state
1137
+ {
1138
+ auto * result_ctx = (llama_sampler_xtc *) result->ctx;
1139
+
1140
+ result_ctx->rng = ctx->rng;
1141
+ }
1142
+
1143
+ return result;
1144
+ }
1145
+
1146
+ static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
1147
+ delete (llama_sampler_xtc *) smpl->ctx;
1148
+ }
1149
+
1150
+ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
1151
+ auto * ctx = (llama_sampler_xtc *) smpl->ctx;
1152
+ ctx->seed_cur = get_rng_seed(ctx->seed);
1153
+ ctx->rng.seed(ctx->seed_cur);
1154
+ }
1155
+
1156
+ static struct llama_sampler_i llama_sampler_xtc_i = {
1157
+ /* .name = */ llama_sampler_xtc_name,
1158
+ /* .accept = */ nullptr,
1159
+ /* .apply = */ llama_sample_xtc_apply,
1160
+ /* .reset = */ llama_sampler_xtc_reset,
1161
+ /* .clone = */ llama_sampler_xtc_clone,
1162
+ /* .free = */ llama_sampler_xtc_free,
1163
+ };
1164
+
1165
+ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
1166
+ auto seed_cur = get_rng_seed(seed);
1167
+ return new llama_sampler {
1168
+ /* .iface = */ &llama_sampler_xtc_i,
1169
+ /* .ctx = */ new llama_sampler_xtc {
1170
+ /* .probability = */ p,
1171
+ /* .threshold = */ t,
1172
+ /* .min_keep = */ min_keep,
1173
+ /* .seed = */ seed,
1174
+ /* .seed_cur = */ seed_cur,
1175
+ /* .rng = */ std::mt19937(seed_cur),
1176
+ },
1177
+ };
1178
+ }
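// Minimal usage sketch (illustrative only, assuming the llama_sampler_chain API declared in
// llama.h; the parameter values are hypothetical):
//
//   llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
//   llama_sampler_chain_add(chain, llama_sampler_init_xtc(/*p=*/0.50f, /*t=*/0.10f, /*min_keep=*/1, seed));
//   llama_sampler_chain_add(chain, llama_sampler_init_dist(seed));
//   const llama_token id = llama_sampler_sample(chain, lctx, /*idx=*/-1);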
1179
+
1180
  // mirostat
1181
 
1182
  struct llama_sampler_mirostat {
 
1683
  };
1684
  }
1685
 
1686
+ // DRY
1687
+
1688
+ struct llama_sampler_dry {
1689
+ int32_t total_context_size;
1690
+
1691
+ const float dry_multiplier;
1692
+ const float dry_base;
1693
+ const int32_t dry_allowed_length;
1694
+ const int32_t dry_penalty_last_n;
1695
+
1696
+ std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
1697
+ std::vector<int> dry_repeat_count;
1698
+ std::unordered_map<llama_token, int> dry_max_token_repeat;
1699
+ ring_buffer<llama_token> last_tokens;
1700
+ };
1701
+
1702
+ // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
1703
+ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
1704
+ for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
1705
+ std::string word = llama_detokenize(vocab, {token_id}, true);
1706
+ if (word.find(str) != std::string::npos) {
1707
+ token_sequences.emplace(token_id, std::vector<llama_token>());
1708
+ } else {
1709
+ size_t word_len = word.size(), str_len = str.size();
1710
+ size_t pos = -1;
1711
+ while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
1712
+ bool match = true;
1713
+ size_t i;
1714
+ for (i = 1; i < str_len && i + pos < word_len; ++i) {
1715
+ if (word[pos + i] != str[i]) {
1716
+ match = false;
1717
+ break;
1718
+ }
1719
+ }
1720
+ if (match) {
1721
+ std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
1722
+ if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
1723
+ tokenization.resize(max_tail_len);
1724
+ }
1725
+
1726
+ // Ensure we don't already have a duplicate matching tokenization
1727
+ auto its = token_sequences.equal_range(token_id);
1728
+ bool found = false;
1729
+ for (auto it = its.first; it != its.second; ++it) {
1730
+ if (tokenization == it->second) {
1731
+ found = true;
1732
+ break;
1733
+ }
1734
+ }
1735
+ if (!found) {
1736
+ token_sequences.emplace(token_id, tokenization);
1737
+ }
1738
+ }
1739
+ }
1740
+ }
1741
+ }
1742
+ }
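// Illustrative example with a hypothetical vocabulary: for the breaker "\n\n", a token whose piece
// is "foo\n" does not contain the full breaker, but its trailing "\n" matches the start of it, so
// the remaining "\n" is tokenized and stored as that head token's tail sequence; a token whose
// piece already contains "\n\n" is stored with an empty tail.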
1743
+
1744
+ static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) {
1745
+ return "dry";
1746
+ }
1747
+
1748
+ static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
1749
+ auto * ctx = (llama_sampler_dry *) smpl->ctx;
1750
+ if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
1751
+ return;
1752
+ }
1753
+
1754
+ ctx->last_tokens.push_back(token);
1755
+ }
1756
+
1757
+ // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
1758
+ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1759
+ auto * ctx = (llama_sampler_dry *) smpl->ctx;
1760
+
1761
+ if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
1762
+ return;
1763
+ }
1764
+
1765
+ int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0);
1766
+ int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size);
1767
+
1768
+ if (last_n_repeat <= ctx->dry_allowed_length) {
1769
+ return;
1770
+ }
1771
+
1772
+ ctx->dry_repeat_count.assign(last_n_repeat, 0);
1773
+ ctx->dry_max_token_repeat.clear();
1774
+
1775
+ // Step 1: Look for restart sequences to limit the maximum repetition length.
1776
+ // Work backwards through the context looking for any token that begins a restart sequence.
1777
+ //
1778
+ // The collection `restart_sequences` is a mapping from a "head" token to all "tail"
1779
+ // sequences that together comprise a restart sequence. This allows us to quickly check
1780
+ // whether each token is the head of a complete sequence. Most restart sequences are actually
1781
+ // a single token, and for these the "tail" is an empty vector.
1782
+ //
1783
+ // If the token is a "head", test all restart sequences that begin with this token
1784
+ // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
1785
+ // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
1786
+ // longest matching sequence (if any) is used to limit the maximum repetition length.
1787
+ //
1788
+ // Note that in the case of a short sequence contained in a longer one, this might fail to
1789
+ // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
1790
+ // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
1791
+ // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
1792
+ //
1793
+ // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
1794
+ // have already clamped the maximum tail sequence length when generating `restart_sequences`.
1795
+ // With clamping, this scan is O(N) in the context length.
1796
+
1797
+ int rep_limit = last_n_repeat;
1798
+ for (int i = 0; i < last_n_repeat; ++i) {
1799
+ llama_token token = ctx->last_tokens.rat(i);
1800
+ auto its = ctx->dry_processed_breakers.equal_range(token);
1801
+ if (its.first == ctx->dry_processed_breakers.end()) {
1802
+ continue;
1803
+ }
1804
+ int longest_match = -1;
1805
+ for (auto it = its.first; it != its.second; ++it) {
1806
+ // Note that (*it) does not contain the head character, so seq_len will be
1807
+ // the restart sequence length minus 1.
1808
+ // In the common case of a single-token restart sequence, (*it) will be empty
1809
+ // and we will trivially match.
1810
+ int seq_len = (int)it->second.size();
1811
+ if (seq_len > longest_match && seq_len <= (int)i) {
1812
+ bool match = true;
1813
+ for (int offset = 0; offset < seq_len; ++offset) {
1814
+ // The -1 when indexing `last_tokens` is because we already matched the head.
1815
+ if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) {
1816
+ match = false;
1817
+ break;
1818
+ }
1819
+ }
1820
+ if (match) {
1821
+ longest_match = seq_len;
1822
+ }
1823
+ }
1824
+ }
1825
+ if (longest_match >= 0) {
1826
+ // We found a restart sequence starting `i` tokens from the end and continuing for
1827
+ // `longest_match` tokens.
1828
+ rep_limit = i - longest_match;
1829
+ break;
1830
+ }
1831
+ }
1832
+ if (rep_limit < ctx->dry_allowed_length) {
1833
+ return;
1834
+ }
1835
+
1836
+ // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
1837
+ // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
1838
+ // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
1839
+ //
1840
+ // This algorithm is not currently documented on Wikipedia, but there is a clear description here:
1841
+ // https://ivanyu.me/blog/2014/10/15/z-algorithm/
1842
+ //
1843
+ // The code below is adapted from the public domain implementation by the same author here:
1844
+ // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
1845
+ //
1846
+ // Example:
1847
+ // Last N tokens: a b c c b c y a b c
1848
+ // Repeat counts: 0 0 3 1 0 2 0 0 0 0
1849
+ // ^
1850
+ // This `3` means that the last three tokens of the context (a b c) also appear here.
1851
+ //
1852
+ // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
1853
+ // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
1854
+ // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
1855
+ // ensure that the inner while loops only examine each token in the context once as the outer
1856
+ // for loop iterates over the context.
1857
+
1858
+ {
1859
+ const int last = last_n_repeat - 1;
1860
+ int rt = 0, lt = 0;
1861
+
1862
+ for (int k = 1; k < last_n_repeat; ++k) {
1863
+ if (k > rt) {
1864
+ // If k is outside the current Z-box, do naive computation.
1865
+ int n = 0;
1866
+ while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) {
1867
+ ++n;
1868
+ }
1869
+ ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
1870
+ if (n > 0) {
1871
+ lt = k;
1872
+ rt = k+n-1;
1873
+ }
1874
+ } else {
1875
+ // If k is inside the current Z-box, consider two cases.
1876
+
1877
+ int p = k - lt; // Pair index.
1878
+ int right_part_len = rt - k + 1;
1879
+
1880
+ if (ctx->dry_repeat_count[last - p] < right_part_len) {
1881
+ int n = std::min(ctx->dry_repeat_count[last - p], rep_limit);
1882
+ ctx->dry_repeat_count[last - k] = n;
1883
+ } else {
1884
+ int i = rt + 1;
1885
+ while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) {
1886
+ i += 1;
1887
+ }
1888
+
1889
+ int n = std::min(i - k, rep_limit);
1890
+ ctx->dry_repeat_count[last - k] = n;
1891
+ lt = k;
1892
+ rt = i - 1;
1893
+ }
1894
+ }
1895
+ }
1896
+ }
1897
+
1898
+ // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
1899
+ // that would be generated by emitting each new token that would extend a sequence.
1900
+ //
1901
+ // Following the same example as above:
1902
+ // Last N tokens: a b c c b c y a b c
1903
+ // Repeat counts: 0 0 3 1 0 2 0 0 0 0
1904
+ //
1905
+ // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
1906
+ // c: 3 -> 4 (from `a b c` to `a b c c`)
1907
+ // b: 1 -> 2 (from `c` to `c b`)
1908
+ // y: 2 -> 3 (from `b c` to `b c y`)
1909
+
1910
+ for (int i = 0; i < last_n_repeat - 1; ++i) {
1911
+ int repeat_len = ctx->dry_repeat_count[i];
1912
+ if (repeat_len >= ctx->dry_allowed_length) {
1913
+ // This token ends a repeat, so the next token would continue one.
1914
+ // By convention, the value of `repeat_len` only includes the tokens currently
1915
+ // in the context, not the new token that would be added.
1916
+ llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i);
1917
+ // Track the maximum sequence ending in this token.
1918
+ const auto& it = ctx->dry_max_token_repeat.find(token);
1919
+ if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
1920
+ ctx->dry_max_token_repeat[token] = repeat_len;
1921
+ }
1922
+ }
1923
+ }
1924
+
1925
+ // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.
1926
+
1927
+ // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
1928
+ // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
1929
+ const float FLOAT_MAX_LOG = 88.7228391f;
1930
+ int max_exponent = 0;
1931
+ if (ctx->dry_base > 1.000001f) {
1932
+ max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base);
1933
+ }
1934
+
1935
+ for (size_t i = 0; i < cur_p->size; ++i) {
1936
+ const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id);
1937
+ if (af_kvp != ctx->dry_max_token_repeat.end()) {
1938
+ // Check all sequence breakers starting with this token
1939
+ auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id);
1940
+ bool is_single_token_breaker = false;
1941
+
1942
+ for (auto it = range.first; it != range.second; ++it) {
1943
+ if (it->second.empty()) {
1944
+ is_single_token_breaker = true;
1945
+ break;
1946
+ }
1947
+ }
1948
+
1949
+ // Apply penalty only if it's not a single-token sequence breaker
1950
+ if (!is_single_token_breaker) {
1951
+ int repeat_exp = af_kvp->second - ctx->dry_allowed_length;
1952
+ if (max_exponent > 0 && repeat_exp > max_exponent) {
1953
+ repeat_exp = max_exponent;
1954
+ }
1955
+ float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp);
1956
+ cur_p->data[i].logit -= penalty;
1957
+ }
1958
+ }
1959
+ }
1960
+
1961
+ cur_p->sorted = false;
1962
+ }
1963
+
1964
+ static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
1965
+ auto * ctx = (llama_sampler_dry *) smpl->ctx;
1966
+ ctx->last_tokens.clear();
1967
+ ctx->dry_repeat_count.clear();
1968
+ ctx->dry_max_token_repeat.clear();
1969
+ }
1970
+
1971
+ static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
1972
+ const auto * ctx = (llama_sampler_dry *) smpl->ctx;
1973
+
1974
+ // nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying
1975
+ auto * result = llama_sampler_init_dry(nullptr, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
1976
+ // Copy the state, including the processed breakers
1977
+ {
1978
+ auto * result_ctx = (llama_sampler_dry *) result->ctx;
1979
+ result_ctx->dry_processed_breakers = ctx->dry_processed_breakers;
1980
+ result_ctx->dry_repeat_count = ctx->dry_repeat_count;
1981
+ result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat;
1982
+ result_ctx->last_tokens = ctx->last_tokens;
1983
+ }
1984
+
1985
+ return result;
1986
+ }
1987
+
1988
+ static void llama_sampler_dry_free(struct llama_sampler * smpl) {
1989
+ delete (llama_sampler_dry *) smpl->ctx;
1990
+ }
1991
+
1992
+ static struct llama_sampler_i llama_sampler_dry_i = {
1993
+ /* .name = */ llama_sampler_dry_name,
1994
+ /* .accept = */ llama_sampler_dry_accept,
1995
+ /* .apply = */ llama_sampler_dry_apply,
1996
+ /* .reset = */ llama_sampler_dry_reset,
1997
+ /* .clone = */ llama_sampler_dry_clone,
1998
+ /* .free = */ llama_sampler_dry_free,
1999
+ };
2000
+
2001
+ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
2002
+ int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
2003
+ std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
2004
+ const int MAX_CHAR_LEN = 40;
2005
+ const int MAX_SEQ_LEN = 20;
2006
+
2007
+ const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
2008
+
2009
+ if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
2010
+ // Process sequence breakers
2011
+ for (size_t i = 0; i < num_breakers; ++i) {
2012
+ if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
2013
+ LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
2014
+ continue;
2015
+ }
2016
+
2017
+ std::string sequence_break(seq_breakers[i]);
2018
+ if (sequence_break.empty()) {
2019
+ LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n");
2020
+ continue;
2021
+ }
2022
+
2023
+ if (sequence_break.size() > MAX_CHAR_LEN) {
2024
+ LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN);
2025
+ sequence_break.resize(MAX_CHAR_LEN);
2026
+ }
2027
+
2028
+ get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
2029
+ }
2030
+ }
2031
+
2032
+ return new llama_sampler {
2033
+ /* .iface = */ &llama_sampler_dry_i,
2034
+ /* .ctx = */ new llama_sampler_dry {
2035
+ /* .total_context_size = */ context_size,
2036
+ /* .dry_multiplier = */ dry_multiplier,
2037
+ /* .dry_base = */ dry_base,
2038
+ /* .dry_allowed_length = */ dry_allowed_length,
2039
+ /* .dry_penalty_last_n = */ dry_penalty_last_n,
2040
+ /* .dry_processed_breakers = */ std::move(processed_breakers),
2041
+ /* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
2042
+ /* .dry_max_token_repeat = */ {},
2043
+ /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
2044
+ },
2045
+ };
2046
+ }
2047
+
2048
+ // wrapper for test-sampling.cpp
2049
+ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
2050
+ llama_vocab dummy_vocab;
2051
+ auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
2052
+ auto * ctx = (llama_sampler_dry *) result->ctx;
2053
+
2054
+ // Process the token-based sequence breakers
2055
+ ctx->dry_processed_breakers.clear();
2056
+ if (seq_breakers.empty()) {
2057
+ LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n");
2058
+ } else {
2059
+ for (const auto& breaker : seq_breakers) {
2060
+ if (breaker.empty()) {
2061
+ LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n");
2062
+ continue;
2063
+ }
2064
+ llama_token head_token = breaker[0];
2065
+ std::vector<llama_token> tail_tokens(breaker.begin() + 1, breaker.end());
2066
+ ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens));
2067
+ }
2068
+
2069
+ if (ctx->dry_processed_breakers.empty()) {
2070
+ LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n");
2071
+ }
2072
+ }
2073
+
2074
+ return result;
2075
+ }
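// Minimal usage sketch of the public DRY entry point (illustrative only; the declaration is in
// llama.h and the parameter values/breakers below are hypothetical):
//
//   const char * breakers[] = { "\n", ":", "\"", "*" };
//   llama_sampler_chain_add(chain,
//       llama_sampler_init_dry(model, /*dry_multiplier=*/0.8f, /*dry_base=*/1.75f,
//                              /*dry_allowed_length=*/2, /*dry_penalty_last_n=*/-1, breakers, 4));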
2076
+
2077
  // logit-bias
2078
 
2079
  struct llama_sampler_logit_bias {
 
2153
  };
2154
  }
2155
 
2156
+ // infill
2157
+
2158
+ //#define GGML_DEBUG_SAMPLER_INFILL
2159
+
2160
+ struct llama_sampler_infill {
2161
+ const struct llama_vocab * vocab;
2162
+
2163
+ std::vector<char> buf0;
2164
+ std::vector<char> buf1;
2165
+ };
2166
+
2167
+ static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
2168
+ return "infill";
2169
+ }
2170
+
2171
+ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2172
+ auto * ctx = (llama_sampler_infill *) smpl->ctx;
2173
+
2174
+ llama_sampler_softmax_impl(cur_p);
2175
+
2176
+ #if defined(GGML_DEBUG_SAMPLER_INFILL)
2177
+ #define LOG_DBG_CUR LLAMA_LOG_DEBUG
2178
+ #else
2179
+ #define LOG_DBG_CUR(...)
2180
+ #endif
2181
+
2182
+ for (size_t i = 0; i < cur_p->size; ++i) {
2183
+ LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2184
+ }
2185
+
2186
+ float p_txt_sum = 0.0f;
2187
+ float p_eog_sum = 0.0f;
2188
+
2189
+ for (size_t i = 0; i < cur_p->size; ++i) {
2190
+ if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
2191
+ p_eog_sum += cur_p->data[i].p;
2192
+ } else {
2193
+ p_txt_sum += cur_p->data[i].p;
2194
+ }
2195
+ }
2196
+
2197
+ const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
2198
+
2199
+ LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
2200
+
2201
+ if (3*p_eog_sum*cur_p->size > p_txt_sum) {
2202
+ LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
2203
+
2204
+ // keep just the EOG tokens
2205
+ const auto size_org = cur_p->size;
2206
+
2207
+ cur_p->size = 0;
2208
+
2209
+ float p_sum = 0.0f;
2210
+
2211
+ for (size_t i = 0; i < size_org; ++i) {
2212
+ if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
2213
+ p_sum += cur_p->data[i].p;
2214
+
2215
+ cur_p->data[cur_p->size++] = cur_p->data[i];
2216
+ }
2217
+ }
2218
+
2219
+ // normalize probs
2220
+ for (size_t i = 0; i < cur_p->size; ++i) {
2221
+ cur_p->data[i].p /= p_sum;
2222
+ }
2223
+
2224
+ return;
2225
+ }
2226
+
2227
+ size_t n_combined = 0; GGML_UNUSED(n_combined);
2228
+
2229
+ // combine tokens with common prefix
2230
+ for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
2231
+ for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
2232
+ if (cur_p->data[i0].logit == -INFINITY) {
2233
+ break;
2234
+ }
2235
+
2236
+ if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
2237
+ continue;
2238
+ }
2239
+
2240
+ int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
2241
+ if (len0 < 0) {
2242
+ ctx->buf0.resize(len0);
2243
+ len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
2244
+ assert(len0 > 0);
2245
+ }
2246
+
2247
+ int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
2248
+ if (len1 < 0) {
2249
+ ctx->buf1.resize(len1);
2250
+ len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
2251
+ assert(len1 > 0);
2252
+ }
2253
+
2254
+ // token i0 is a prefix of token i1
2255
+ if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
2256
+ int dst = i0;
2257
+ int src = i1;
2258
+
2259
+ // merge into the token with higher probability
2260
+ if (cur_p->data[i1].p > cur_p->data[i0].p) {
2261
+ std::swap(dst, src);
2262
+ }
2263
+
2264
+ cur_p->data[dst].p += cur_p->data[src].p;
2265
+ cur_p->data[src].logit = -INFINITY;
2266
+ cur_p->data[src].p = 0.0f;
2267
+
2268
+ n_combined++;
2269
+ }
2270
+ }
2271
+ }
2272
+
2273
+ size_t n_non_eog = 0;
2274
+
2275
+ size_t size_org = cur_p->size;
2276
+
2277
+ float p_sum = 0.0f;
2278
+ float thold = 0.2f;
2279
+
2280
+ cur_p->size = 0;
2281
+
2282
+ LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
2283
+
2284
+ for (size_t i = 0; i < size_org; ++i) {
2285
+ const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
2286
+
2287
+ if (cur_p->data[i].p < thold && !is_eog) {
2288
+ continue;
2289
+ }
2290
+
2291
+ if (!is_eog) {
2292
+ ++n_non_eog;
2293
+ }
2294
+
2295
+ p_sum += cur_p->data[i].p;
2296
+
2297
+ // keep this token
2298
+ cur_p->data[cur_p->size++] = cur_p->data[i];
2299
+ }
2300
+
2301
+ LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
2302
+
2303
+ // if no non-EOG tokens are left -> reduce cur_p to single EOT token
2304
+ if (n_non_eog == 0) {
2305
+ cur_p->size = 1;
2306
+ cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
2307
+ cur_p->data[0].logit = 1.0f;
2308
+
2309
+ return;
2310
+ }
2311
+
2312
+ // normalize probs
2313
+ for (size_t i = 0; i < cur_p->size; ++i) {
2314
+ cur_p->data[i].p /= p_sum;
2315
+
2316
+ LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2317
+ }
2318
+
2319
+ size_org = cur_p->size;
2320
+ p_sum = 0.0f;
2321
+ thold = 1.0/(n_non_eog + 1);
2322
+
2323
+ cur_p->size = 0;
2324
+
2325
+ LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
2326
+
2327
+ for (size_t i = 0; i < size_org; ++i) {
2328
+ const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
2329
+
2330
+ if (cur_p->data[i].p < thold && !is_eog) {
2331
+ continue;
2332
+ }
2333
+
2334
+ p_sum += cur_p->data[i].p;
2335
+
2336
+ cur_p->data[cur_p->size++] = cur_p->data[i];
2337
+ }
2338
+
2339
+ // normalize probs
2340
+ for (size_t i = 0; i < cur_p->size; ++i) {
2341
+ cur_p->data[i].p /= p_sum;
2342
+
2343
+ LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2344
+ }
2345
+
2346
+ #undef LOG_DBG_CUR
2347
+ }
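// Worked example of the EOG decision above (illustrative numbers): with 4 candidates,
// p_eog_sum = 0.1 and p_txt_sum = 0.9, the check 3*p_eog_sum*size > p_txt_sum gives
// 3*0.1*4 = 1.2 > 0.9, so only the EOG tokens are kept and renormalized; with p_eog_sum = 0.05
// the left-hand side is 0.6 < 0.95, and the sampler instead merges common-prefix tokens and
// applies the two probability thresholds.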
2348
+
2349
+ static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
2350
+ const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
2351
+ return llama_sampler_init_infill_impl(*ctx->vocab);
2352
+ }
2353
+
2354
+ static void llama_sampler_infill_free(struct llama_sampler * smpl) {
2355
+ delete (llama_sampler_infill *) smpl->ctx;
2356
+ }
2357
+
2358
+ static struct llama_sampler_i llama_sampler_infill_i = {
2359
+ /* .name = */ llama_sampler_infill_name,
2360
+ /* .accept = */ nullptr,
2361
+ /* .apply = */ llama_sampler_infill_apply,
2362
+ /* .reset = */ nullptr,
2363
+ /* .clone = */ llama_sampler_infill_clone,
2364
+ /* .free = */ llama_sampler_infill_free,
2365
+ };
2366
+
2367
+ struct llama_sampler * llama_sampler_init_infill_impl(
2368
+ const struct llama_vocab & vocab) {
2369
+ return new llama_sampler {
2370
+ /* .iface = */ &llama_sampler_infill_i,
2371
+ /* .ctx = */ new llama_sampler_infill {
2372
+ /* .vocab = */ &vocab,
2373
+ /* .buf0 = */ std::vector<char>(512),
2374
+ /* .buf1 = */ std::vector<char>(512),
2375
+ },
2376
+ };
2377
+ }
2378
+
2379
  // utils
2380
 
2381
  uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
examples/talk-llama/llama-sampling.h CHANGED
@@ -4,8 +4,6 @@
 
  #include "llama-grammar.h"
 
- #include <unordered_map>
-
  struct llama_vocab;
  struct llama_grammar;
 
@@ -27,3 +25,24 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
  const struct llama_vocab & vocab,
  const char * grammar_str,
  const char * grammar_root);
4
 
5
  #include "llama-grammar.h"
6
 
 
 
7
  struct llama_vocab;
8
  struct llama_grammar;
9
 
 
25
  const struct llama_vocab & vocab,
26
  const char * grammar_str,
27
  const char * grammar_root);
28
+
29
+ struct llama_sampler * llama_sampler_init_infill_impl(
30
+ const struct llama_vocab & vocab);
31
+
32
+ struct llama_sampler * llama_sampler_init_dry_impl(
33
+ const struct llama_vocab & vocab,
34
+ int32_t context_size,
35
+ float dry_multiplier,
36
+ float dry_base,
37
+ int32_t dry_allowed_length,
38
+ int32_t dry_penalty_last_n,
39
+ const char ** seq_breakers,
40
+ size_t num_breakers);
41
+
42
+ struct llama_sampler * llama_sampler_init_dry_testing(
43
+ int32_t context_size,
44
+ float dry_multiplier,
45
+ float dry_base,
46
+ int32_t dry_allowed_length,
47
+ int32_t dry_penalty_last_n,
48
+ const std::vector<std::vector<llama_token>>& seq_breakers);
examples/talk-llama/llama-vocab.cpp CHANGED
@@ -221,7 +221,7 @@ struct llm_tokenizer_spm_session {
  }
 
  // seed the work queue with all possible 2-character tokens.
- for (size_t i = 1; i < symbols.size(); ++i) {
  try_add_bigram(i - 1, i);
  }
 
@@ -563,7 +563,7 @@ struct llm_tokenizer_bpe_session {
  index++;
  symbols.emplace_back(sym);
  }
- for (size_t i = 1; i < symbols.size(); ++i) {
  add_new_bigram(i - 1, i);
  }
 
@@ -1663,6 +1663,14 @@ llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
  return vocab.special_eos_id;
  }
 
  llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
  return vocab.special_cls_id;
  }
@@ -1688,23 +1696,39 @@ bool llama_add_eos_token_impl(const struct llama_vocab & vocab) {
  }
 
  llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
- return vocab.special_prefix_id;
  }
 
  llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
- return vocab.special_middle_id;
  }
 
  llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
- return vocab.special_suffix_id;
  }
 
- llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
- return vocab.special_eot_id;
  }
 
- llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
- return vocab.special_eom_id;
  }
 
  int32_t llama_tokenize_impl(
@@ -1942,3 +1966,19 @@ int32_t llama_detokenize_impl(
 
  return total <= text_len_max ? total : -total;
  }
221
  }
222
 
223
  // seed the work queue with all possible 2-character tokens.
224
+ for (int i = 1; i < (int) symbols.size(); ++i) {
225
  try_add_bigram(i - 1, i);
226
  }
227
 
 
563
  index++;
564
  symbols.emplace_back(sym);
565
  }
566
+ for (int i = 1; i < (int) symbols.size(); ++i) {
567
  add_new_bigram(i - 1, i);
568
  }
569
 
 
1663
  return vocab.special_eos_id;
1664
  }
1665
 
1666
+ llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
1667
+ return vocab.special_eot_id;
1668
+ }
1669
+
1670
+ llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
1671
+ return vocab.special_eom_id;
1672
+ }
1673
+
1674
  llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
1675
  return vocab.special_cls_id;
1676
  }
 
1696
  }
1697
 
1698
  llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
1699
+ return vocab.special_fim_pre_id;
1700
  }
1701
 
1702
  llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
1703
+ return vocab.special_fim_mid_id;
1704
  }
1705
 
1706
  llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
1707
+ return vocab.special_fim_suf_id;
1708
  }
1709
 
1710
+ llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab) {
1711
+ return vocab.special_fim_pre_id;
1712
  }
1713
 
1714
+ llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab) {
1715
+ return vocab.special_fim_suf_id;
1716
+ }
1717
+
1718
+ llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab) {
1719
+ return vocab.special_fim_mid_id;
1720
+ }
1721
+
1722
+ llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab) {
1723
+ return vocab.special_fim_pad_id;
1724
+ }
1725
+
1726
+ llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab) {
1727
+ return vocab.special_fim_rep_id;
1728
+ }
1729
+
1730
+ llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab) {
1731
+ return vocab.special_fim_sep_id;
1732
  }
1733
 
1734
  int32_t llama_tokenize_impl(
 
1966
 
1967
  return total <= text_len_max ? total : -total;
1968
  }
1969
+
1970
+ std::string llama_detokenize(const struct llama_vocab & vocab, const std::vector<llama_token> & tokens, bool special) {
1971
+ std::string text;
1972
+ text.resize(std::max(text.capacity(), tokens.size()));
1973
+ int32_t n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
1974
+ if (n_chars < 0) {
1975
+ text.resize(-n_chars);
1976
+ n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
1977
+ GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
1978
+ }
1979
+
1980
+ text.resize(n_chars);
1981
+
1982
+ // NOTE: the original tokenizer decodes bytes after collecting the pieces.
1983
+ return text;
1984
+ }
examples/talk-llama/llama-vocab.h CHANGED
@@ -37,20 +37,26 @@ struct llama_vocab {
  std::map<std::pair<std::string, std::string>, int> bpe_ranks;
 
  // default LLaMA special tokens
  id special_bos_id = 1;
  id special_eos_id = 2;
  id special_unk_id = 0;
- id special_sep_id = -1;
- id special_pad_id = -1;
- id special_cls_id = -1;
- id special_mask_id = -1;
-
- id linefeed_id = 13;
- id special_prefix_id = -1;
- id special_suffix_id = -1;
- id special_middle_id = -1;
- id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
- id special_eom_id = -1;
 
  // set of all tokens that cause "end of generation"
  std::set<id> special_eog_ids;
@@ -104,19 +110,26 @@ bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token t
 
  llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
  llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
  llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
  llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
  llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
  llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
 
- bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
- bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
-
  llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
  llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
  llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
- llama_token llama_token_eot_impl (const struct llama_vocab & vocab);
- llama_token llama_token_eom_impl (const struct llama_vocab & vocab);
 
  int32_t llama_tokenize_impl(
  const struct llama_vocab & vocab,
@@ -136,6 +149,12 @@ int32_t llama_token_to_piece_impl(
  int32_t lstrip,
  bool special);
 
  int32_t llama_detokenize_impl(
  const struct llama_vocab & vocab,
  const llama_token * tokens,
@@ -144,3 +163,8 @@ int32_t llama_detokenize_impl(
  int32_t text_len_max,
  bool remove_special,
  bool unparse_special);
37
  std::map<std::pair<std::string, std::string>, int> bpe_ranks;
38
 
39
  // default LLaMA special tokens
40
+ // TODO: should we set all of these to LLAMA_TOKEN_NULL?
41
  id special_bos_id = 1;
42
  id special_eos_id = 2;
43
+ id special_eot_id = LLAMA_TOKEN_NULL;
44
+ id special_eom_id = LLAMA_TOKEN_NULL;
45
  id special_unk_id = 0;
46
+ id special_sep_id = LLAMA_TOKEN_NULL;
47
+ id special_pad_id = LLAMA_TOKEN_NULL;
48
+ id special_cls_id = LLAMA_TOKEN_NULL;
49
+ id special_mask_id = LLAMA_TOKEN_NULL;
50
+
51
+ id linefeed_id = 13;
52
+
53
+ // fim tokens
54
+ id special_fim_pre_id = LLAMA_TOKEN_NULL;
55
+ id special_fim_suf_id = LLAMA_TOKEN_NULL;
56
+ id special_fim_mid_id = LLAMA_TOKEN_NULL;
57
+ id special_fim_pad_id = LLAMA_TOKEN_NULL;
58
+ id special_fim_rep_id = LLAMA_TOKEN_NULL; // repo
59
+ id special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator
60
 
61
  // set of all tokens that cause "end of generation"
62
  std::set<id> special_eog_ids;
 
110
 
111
  llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
112
  llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
113
+ llama_token llama_token_eot_impl(const struct llama_vocab & vocab);
114
+ llama_token llama_token_eom_impl(const struct llama_vocab & vocab);
115
  llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
116
  llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
117
  llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
118
  llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
119
 
 
 
 
120
  llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
121
  llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
122
  llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
123
+
124
+ llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab);
125
+ llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab);
126
+ llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab);
127
+ llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab);
128
+ llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab);
129
+ llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab);
130
+
131
+ bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
132
+ bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
133
 
134
  int32_t llama_tokenize_impl(
135
  const struct llama_vocab & vocab,
 
149
  int32_t lstrip,
150
  bool special);
151
 
152
+ // check if token0 is contained as a prefix in token1
153
+ bool llama_token_is_prefix_impl(
154
+ const struct llama_vocab & vocab,
155
+ llama_token token0,
156
+ llama_token token1);
157
+
158
  int32_t llama_detokenize_impl(
159
  const struct llama_vocab & vocab,
160
  const llama_token * tokens,
 
163
  int32_t text_len_max,
164
  bool remove_special,
165
  bool unparse_special);
166
+
167
+ std::string llama_detokenize(
168
+ const struct llama_vocab & vocab,
169
+ const std::vector<llama_token> & tokens,
170
+ bool special);
examples/talk-llama/llama.cpp CHANGED
The diff for this file is too large to render. See raw diff
 
examples/talk-llama/llama.h CHANGED
@@ -217,6 +217,7 @@ extern "C" {
 
  typedef struct llama_token_data_array {
  // TODO: consider SoA
  llama_token_data * data;
  size_t size;
  int64_t selected; // this is the index in the data array (i.e. not the token id)
@@ -232,8 +233,11 @@ extern "C" {
  // - token : the token ids of the input (used when embd is NULL)
  // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
  // - pos : the positions of the respective token in the sequence
  // - seq_id : the sequence to which the respective token belongs
  // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
  //
  typedef struct llama_batch {
  int32_t n_tokens;
@@ -244,15 +248,6 @@ extern "C" {
  int32_t * n_seq_id;
  llama_seq_id ** seq_id;
  int8_t * logits; // TODO: rename this to "output"
-
- // NOTE: helpers for smooth API transition - can be deprecated in the future
- // for future-proof code, use the above fields instead and ignore everything below
- //
- // pos[i] = all_pos_0 + i*all_pos_1
- //
- llama_pos all_pos_0; // used if pos == NULL
- llama_pos all_pos_1; // used if pos == NULL
- llama_seq_id all_seq_id; // used if seq_id == NULL
  } llama_batch;
 
  enum llama_model_kv_override_type {
@@ -433,6 +428,7 @@ extern "C" {
  LLAMA_API bool llama_supports_mmap (void);
  LLAMA_API bool llama_supports_mlock (void);
  LLAMA_API bool llama_supports_gpu_offload(void);
 
  LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
  LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
@@ -775,15 +771,15 @@ extern "C" {
  // Decoding
  //
 
- // Return batch for single sequence of tokens starting at pos_0
  //
  // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
  //
  LLAMA_API struct llama_batch llama_batch_get_one(
  llama_token * tokens,
- int32_t n_tokens,
- llama_pos pos_0,
- llama_seq_id seq_id);
 
  // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
  // Each token can be assigned up to n_seq_max sequence ids
@@ -896,6 +892,7 @@ extern "C" {
  // Special tokens
  LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
  LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
  LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
  LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
  LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
@@ -904,11 +901,17 @@ extern "C" {
  LLAMA_API bool llama_add_bos_token(const struct llama_model * model);
  LLAMA_API bool llama_add_eos_token(const struct llama_model * model);
 
- // Codellama infill tokens
- LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
- LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
- LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
- LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
 
  //
  // Tokenization
@@ -1067,12 +1070,13 @@ extern "C" {
 
  // available samplers:
 
- LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void);
- LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
 
  /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
  /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
- LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
 
  /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
  LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
@@ -1088,11 +1092,16 @@ extern "C" {
 
  /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
  LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep);
  LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t);
 
  /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
  LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);
 
  /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
  /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
  /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@@ -1132,11 +1141,43 @@ extern "C" {
  bool penalize_nl, // consider newlines as a repeatable token
  bool ignore_eos); // ignore the end-of-sequence token
 
  LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
  int32_t n_vocab,
  int32_t n_logit_bias,
  const llama_logit_bias * logit_bias);
 
  // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
  LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
217
 
218
  typedef struct llama_token_data_array {
219
  // TODO: consider SoA
220
+ // NOTE: this pointer can be modified by the samplers
221
  llama_token_data * data;
222
  size_t size;
223
  int64_t selected; // this is the index in the data array (i.e. not the token id)
 
233
  // - token : the token ids of the input (used when embd is NULL)
234
  // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
235
  // - pos : the positions of the respective token in the sequence
236
+ // (if set to NULL, the token position will be tracked automatically by llama_decode)
237
  // - seq_id : the sequence to which the respective token belongs
238
+ // (if set to NULL, the sequence ID will be assumed to be 0)
239
  // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
240
+ // (if set to NULL, only the logits for last token will be returned)
241
  //
242
  typedef struct llama_batch {
243
  int32_t n_tokens;
 
248
  int32_t * n_seq_id;
249
  llama_seq_id ** seq_id;
250
  int8_t * logits; // TODO: rename this to "output"
 
 
 
 
 
 
 
 
 
251
  } llama_batch;
252
 
253
  enum llama_model_kv_override_type {
 
428
  LLAMA_API bool llama_supports_mmap (void);
429
  LLAMA_API bool llama_supports_mlock (void);
430
  LLAMA_API bool llama_supports_gpu_offload(void);
431
+ LLAMA_API bool llama_supports_rpc (void);
432
 
433
  LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
434
  LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
 
771
  // Decoding
772
  //
773
 
774
+ // Return batch for single sequence of tokens
775
+ // The sequence ID will be fixed to 0
776
+ // The position of the tokens will be tracked automatically by llama_decode
777
  //
778
  // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
779
  //
780
  LLAMA_API struct llama_batch llama_batch_get_one(
781
  llama_token * tokens,
782
+ int32_t n_tokens);
 
 
783
 
784
  // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
785
  // Each token can be assigned up to n_seq_max sequence ids
 
892
  // Special tokens
893
  LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
894
  LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
895
+ LLAMA_API llama_token llama_token_eot(const struct llama_model * model); // end-of-turn
896
  LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
897
  LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
898
  LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
 
901
  LLAMA_API bool llama_add_bos_token(const struct llama_model * model);
902
  LLAMA_API bool llama_add_eos_token(const struct llama_model * model);
903
 
904
+ // infill tokens
905
+ DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_model * model), "use llama_token_fim_pre instead");
906
+ DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_model * model), "use llama_token_fim_mid instead");
907
+ DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_model * model), "use llama_token_fim_suf instead");
908
+
909
+ LLAMA_API llama_token llama_token_fim_pre(const struct llama_model * model);
910
+ LLAMA_API llama_token llama_token_fim_suf(const struct llama_model * model);
911
+ LLAMA_API llama_token llama_token_fim_mid(const struct llama_model * model);
912
+ LLAMA_API llama_token llama_token_fim_pad(const struct llama_model * model);
913
+ LLAMA_API llama_token llama_token_fim_rep(const struct llama_model * model);
914
+ LLAMA_API llama_token llama_token_fim_sep(const struct llama_model * model);
915
 
916
  //
917
  // Tokenization
 
1070
 
1071
  // available samplers:
1072
 
1073
+ LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
1074
+ LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
1075
 
1076
  /// @details Sorts candidate tokens by their logits in descending order and calculates probabilities based on logits.
1077
  /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
1078
+ DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
1079
+ "will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)");
1080
 
1081
  /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
1082
  LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
 
1092
 
1093
  /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
1094
  LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep);
1095
+
1096
+ /// @details Updates the logits l_i' = l_i/t. When t <= 0.0f, the maximum logit is kept at its original value, the rest are set to -inf
1097
  LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t);
1098
 
1099
  /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
1100
  LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);
1101
 
1102
+ /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
1103
+ LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
1104
+
1105
  /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
1106
  /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
1107
  /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
 
1141
  bool penalize_nl, // consider newlines as a repeatable token
1142
  bool ignore_eos); // ignore the end-of-sequence token
1143
 
1144
+ /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
1145
+ LLAMA_API struct llama_sampler * llama_sampler_init_dry(
1146
+ const struct llama_model * model,
1147
+ float dry_multiplier,
1148
+ float dry_base,
1149
+ int32_t dry_allowed_length,
1150
+ int32_t dry_penalty_last_n,
1151
+ const char ** seq_breakers,
1152
+ size_t num_breakers);
1153
+
1154
  LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
1155
  int32_t n_vocab,
1156
  int32_t n_logit_bias,
1157
  const llama_logit_bias * logit_bias);
1158
 
1159
+ // this sampler is meant to be used for fill-in-the-middle infilling
1160
+ // it's supposed to be used after top_k + top_p sampling
1161
+ //
1162
+ // 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
1163
+ // 2. combine probs of tokens that have the same prefix
1164
+ //
1165
+ // example:
1166
+ //
1167
+ // - before:
1168
+ // "hel": 0.5
1169
+ // "hell": 0.2
1170
+ // "hello": 0.1
1171
+ // "dummy": 0.1
1172
+ //
1173
+ // - after:
1174
+ // "hel": 0.8
1175
+ // "dummy": 0.1
1176
+ //
1177
+ // 3. discard non-EOG tokens with low prob
1178
+ // 4. if no tokens are left -> pick EOT
1179
+ //
1180
+ LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);
1181
 
1182
  // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
1183
  LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
examples/talk-llama/unicode-data.cpp CHANGED
@@ -2311,7 +2311,7 @@ const std::unordered_set<uint32_t> unicode_set_whitespace = {
  0x003000,
  };
 
- // list is always in ascending order, to enable binary searh
  const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_lowercase = {
  {0x000041, 0x000061},
  {0x000042, 0x000062},
@@ -3748,7 +3748,7 @@ const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_lowercase
  {0x01E921, 0x01E943},
  };
 
- // list is always in ascending order, to enable binary searh
  const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_uppercase = {
  {0x000061, 0x000041},
  {0x000062, 0x000042},
 
2311
  0x003000,
2312
  };
2313
 
2314
+ // list is always in ascending order, to enable binary search
2315
  const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_lowercase = {
2316
  {0x000041, 0x000061},
2317
  {0x000042, 0x000062},
 
3748
  {0x01E921, 0x01E943},
3749
  };
3750
 
3751
+ // list is always in ascending order, to enable binary search
3752
  const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_uppercase = {
3753
  {0x000061, 0x000041},
3754
  {0x000062, 0x000042},