ggerganov committed on
Commit a40d0a7 · 1 Parent(s): 96e8b15

talk-llama : sync llama.cpp

Makefile CHANGED
@@ -785,7 +785,8 @@ OBJ_GGML += \
 	ggml/src/ggml.o \
 	ggml/src/ggml-alloc.o \
 	ggml/src/ggml-backend.o \
-	ggml/src/ggml-quants.o
+	ggml/src/ggml-quants.o \
+	ggml/src/ggml-aarch64.o
 
 OBJ_WHISPER += \
 	src/whisper.o
@@ -916,6 +917,13 @@ ggml/src/ggml-quants.o: \
 	ggml/src/ggml-common.h
 	$(CC) $(CFLAGS) -c $< -o $@
 
+ggml/src/ggml-aarch64.o: \
+	ggml/src/ggml-aarch64.c \
+	ggml/include/ggml.h \
+	ggml/src/ggml-aarch64.h \
+	ggml/src/ggml-common.h
+	$(CC) $(CFLAGS) -c $< -o $@
+
 ggml/src/ggml-blas.o: \
 	ggml/src/ggml-blas.cpp \
 	ggml/include/ggml-blas.h
@@ -1076,7 +1084,7 @@ talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp \
 	$(CXX) $(CXXFLAGS) $(CFLAGS_SDL) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LDFLAGS_SDL)
 
-talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp examples/talk-llama/unicode.cpp examples/talk-llama/unicode-data.cpp \
+talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp examples/talk-llama/llama-vocab.cpp examples/talk-llama/llama-grammar.cpp examples/talk-llama/llama-sampling.cpp examples/talk-llama/unicode.cpp examples/talk-llama/unicode-data.cpp \
 	$(OBJ_GGML) $(OBJ_WHISPER) $(OBJ_COMMON) $(OBJ_SDL)
 	$(CXX) $(CXXFLAGS) $(CFLAGS_SDL) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LDFLAGS_SDL)
examples/talk-llama/CMakeLists.txt CHANGED
@@ -1,7 +1,13 @@
 if (WHISPER_SDL2)
     # talk-llama
     set(TARGET talk-llama)
-    add_executable(${TARGET} talk-llama.cpp llama.cpp unicode.cpp unicode-data.cpp)
+    add_executable(${TARGET} talk-llama.cpp
+        llama.cpp
+        llama-vocab.cpp
+        llama-grammar.cpp
+        llama-sampling.cpp
+        unicode.cpp
+        unicode-data.cpp)
     target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
 
     if (WHISPER_CLBLAST)
examples/talk-llama/llama-grammar.cpp ADDED
@@ -0,0 +1,539 @@
1
+ #include "llama-grammar.h"
2
+
3
+ #include "llama-vocab.h"
4
+ #include "llama-sampling.h"
5
+
6
+ #include <algorithm>
7
+
8
+ // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
9
+ // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
10
+ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
11
+ const std::string & src,
12
+ llama_partial_utf8 partial_start) {
13
+ static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
14
+ const char * pos = src.c_str();
15
+ std::vector<uint32_t> code_points;
16
+
17
+ // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
18
+ code_points.reserve(src.size() + 1);
19
+ uint32_t value = partial_start.value;
20
+ int n_remain = partial_start.n_remain;
21
+
22
+ // continue previous decode, if applicable
23
+ while (*pos != 0 && n_remain > 0) {
24
+ uint8_t next_byte = static_cast<uint8_t>(*pos);
25
+ if ((next_byte >> 6) != 2) {
26
+ // invalid sequence, abort
27
+ code_points.push_back(0);
28
+ return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
29
+ }
30
+ value = (value << 6) + (next_byte & 0x3F);
31
+ ++pos;
32
+ --n_remain;
33
+ }
34
+
35
+ if (partial_start.n_remain > 0 && n_remain == 0) {
36
+ code_points.push_back(value);
37
+ }
38
+
39
+ // decode any subsequent utf-8 sequences, which may end in an incomplete one
40
+ while (*pos != 0) {
41
+ uint8_t first_byte = static_cast<uint8_t>(*pos);
42
+ uint8_t highbits = first_byte >> 4;
43
+ n_remain = lookup[highbits] - 1;
44
+
45
+ if (n_remain < 0) {
46
+ // invalid sequence, abort
47
+ code_points.clear();
48
+ code_points.push_back(0);
49
+ return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
50
+ }
51
+
52
+ uint8_t mask = (1 << (7 - n_remain)) - 1;
53
+ value = first_byte & mask;
54
+
55
+ ++pos;
56
+ while (*pos != 0 && n_remain > 0) {
57
+ value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
58
+ ++pos;
59
+ --n_remain;
60
+ }
61
+ if (n_remain == 0) {
62
+ code_points.push_back(value);
63
+ }
64
+ }
65
+ code_points.push_back(0);
66
+
67
+ return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
68
+ }
69
+
70
+ const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
71
+ return grammar->rules;
72
+ }
73
+
74
+ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
75
+ return grammar->stacks;
76
+ }
77
+
78
+ // returns true iff pos points to the end of one of the definitions of a rule
79
+ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
80
+ switch (pos->type) {
81
+ case LLAMA_GRETYPE_END: return true; // NOLINT
82
+ case LLAMA_GRETYPE_ALT: return true; // NOLINT
83
+ default: return false;
84
+ }
85
+ }
86
+
87
+ // returns true iff chr satisfies the char range at pos (regular or inverse range)
88
+ // asserts that pos is pointing to a char range element
89
+ static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
90
+ const llama_grammar_element * pos,
91
+ const uint32_t chr) {
92
+
93
+ bool found = false;
94
+ bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
95
+
96
+ GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT
97
+
98
+ do {
99
+ if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
100
+ // inclusive range, e.g. [a-z]
101
+ found = found || (pos->value <= chr && chr <= pos[1].value);
102
+ pos += 2;
103
+ } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
104
+ // Any character matches "."
105
+ found = true;
106
+ pos += 1;
107
+ } else {
108
+ // exact char match, e.g. [a] or "a"
109
+ found = found || pos->value == chr;
110
+ pos += 1;
111
+ }
112
+ } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
113
+
114
+ return std::make_pair(found == is_positive_char, pos);
115
+ }
116
+
117
+ // returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
118
+ // range at pos (regular or inverse range)
119
+ // asserts that pos is pointing to a char range element
120
+ static bool llama_grammar_match_partial_char(
121
+ const llama_grammar_element * pos,
122
+ const llama_partial_utf8 partial_utf8) {
123
+ bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
124
+ GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
125
+
126
+ uint32_t partial_value = partial_utf8.value;
127
+ int n_remain = partial_utf8.n_remain;
128
+
129
+ // invalid sequence or 7-bit char split across 2 bytes (overlong)
130
+ if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
131
+ return false;
132
+ }
133
+
134
+ // range of possible code points this partial UTF-8 sequence could complete to
135
+ uint32_t low = partial_value << (n_remain * 6);
136
+ uint32_t high = low | ((1 << (n_remain * 6)) - 1);
137
+
138
+ if (low == 0) {
139
+ if (n_remain == 2) {
140
+ low = 1 << 11;
141
+ } else if (n_remain == 3) {
142
+ low = 1 << 16;
143
+ }
144
+ }
145
+
146
+ do {
147
+ if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
148
+ // inclusive range, e.g. [a-z]
149
+ if (pos->value <= high && low <= pos[1].value) {
150
+ return is_positive_char;
151
+ }
152
+ pos += 2;
153
+ } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
154
+ // Any character matches "."
155
+ return true;
156
+ } else {
157
+ // exact char match, e.g. [a] or "a"
158
+ if (low <= pos->value && pos->value <= high) {
159
+ return is_positive_char;
160
+ }
161
+ pos += 1;
162
+ }
163
+ } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
164
+
165
+ return !is_positive_char;
166
+ }
167
+
168
+ // transforms a grammar pushdown stack into N possible stacks, all ending
169
+ // at a character range (terminal element)
170
+ static void llama_grammar_advance_stack(
171
+ const llama_grammar_rules & rules,
172
+ const llama_grammar_stack & stack,
173
+ llama_grammar_stacks & new_stacks) {
174
+ if (stack.empty()) {
175
+ if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
176
+ new_stacks.emplace_back(stack);
177
+ }
178
+ return;
179
+ }
180
+
181
+ const llama_grammar_element * pos = stack.back();
182
+
183
+ switch (pos->type) {
184
+ case LLAMA_GRETYPE_RULE_REF: {
185
+ const size_t rule_id = static_cast<size_t>(pos->value);
186
+ const llama_grammar_element * subpos = rules[rule_id].data();
187
+ do {
188
+ // init new stack without the top (pos)
189
+ llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
190
+ if (!llama_grammar_is_end_of_sequence(pos + 1)) {
191
+ // if this rule ref is followed by another element, add that to stack
192
+ new_stack.push_back(pos + 1);
193
+ }
194
+ if (!llama_grammar_is_end_of_sequence(subpos)) {
195
+ // if alternate is nonempty, add to stack
196
+ new_stack.push_back(subpos);
197
+ }
198
+ llama_grammar_advance_stack(rules, new_stack, new_stacks);
199
+ while (!llama_grammar_is_end_of_sequence(subpos)) {
200
+ // scan to end of alternate def
201
+ subpos++;
202
+ }
203
+ if (subpos->type == LLAMA_GRETYPE_ALT) {
204
+ // there's another alternate def of this rule to process
205
+ subpos++;
206
+ } else {
207
+ break;
208
+ }
209
+ } while (true);
210
+ break;
211
+ }
212
+ case LLAMA_GRETYPE_CHAR:
213
+ case LLAMA_GRETYPE_CHAR_NOT:
214
+ case LLAMA_GRETYPE_CHAR_ANY:
215
+ if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
216
+ // only add the stack if it's not a duplicate of one we already have
217
+ new_stacks.emplace_back(stack);
218
+ }
219
+ break;
220
+ default:
221
+ // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
222
+ // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
223
+ // those
224
+ GGML_ABORT("fatal error");
225
+ }
226
+ }
227
+
228
+ // takes a set of possible pushdown stacks on a grammar, which are required to
229
+ // be positioned at a character range (see `llama_grammar_advance_stack`), and
230
+ // produces the N possible stacks if the given char is accepted at those
231
+ // positions
232
+ void llama_grammar_accept(
233
+ const llama_grammar_rules & rules,
234
+ const llama_grammar_stacks & stacks,
235
+ const uint32_t chr,
236
+ llama_grammar_stacks & new_stacks) {
237
+ new_stacks.clear();
238
+
239
+ for (const auto & stack : stacks) {
240
+ if (stack.empty()) {
241
+ continue;
242
+ }
243
+
244
+ auto match = llama_grammar_match_char(stack.back(), chr);
245
+ if (match.first) {
246
+ const llama_grammar_element * pos = match.second;
247
+
248
+ // update top of stack to next element, if any
249
+ llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
250
+ if (!llama_grammar_is_end_of_sequence(pos)) {
251
+ new_stack.push_back(pos);
252
+ }
253
+ llama_grammar_advance_stack(rules, new_stack, new_stacks);
254
+ }
255
+ }
256
+ }
257
+
258
+ static llama_grammar_candidates llama_grammar_reject_candidates(
259
+ const llama_grammar_rules & rules,
260
+ const llama_grammar_stacks & stacks,
261
+ const llama_grammar_candidates & candidates) {
262
+ GGML_ASSERT(!stacks.empty()); // REVIEW
263
+
264
+ if (candidates.empty()) {
265
+ return {};
266
+ }
267
+
268
+ auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
269
+
270
+ for (size_t i = 1, size = stacks.size(); i < size; ++i) {
271
+ rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
272
+ }
273
+ return rejects;
274
+ }
275
+
276
+ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
277
+ const llama_grammar_rules & rules,
278
+ const llama_grammar_stack & stack,
279
+ const llama_grammar_candidates & candidates) {
280
+
281
+ llama_grammar_candidates rejects;
282
+ rejects.reserve(candidates.size());
283
+
284
+ if (stack.empty()) {
285
+ for (const auto & tok : candidates) {
286
+ if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
287
+ rejects.push_back(tok);
288
+ }
289
+ }
290
+ return rejects;
291
+ }
292
+
293
+ const llama_grammar_element * stack_pos = stack.back();
294
+
295
+ llama_grammar_candidates next_candidates;
296
+ next_candidates.reserve(candidates.size());
297
+
298
+ for (const auto & tok : candidates) {
299
+ if (*tok.code_points == 0) {
300
+ // reached end of full codepoints in token, reject iff it ended in a partial sequence
301
+ // that cannot satisfy this position in grammar
302
+ if (tok.partial_utf8.n_remain != 0 &&
303
+ !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
304
+ rejects.push_back(tok);
305
+ }
306
+ } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
307
+ next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
308
+ } else {
309
+ rejects.push_back(tok);
310
+ }
311
+ }
312
+
313
+ const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;
314
+
315
+ // update top of stack to next element, if any
316
+ llama_grammar_stack stack_after(stack.begin(), stack.end() - 1);
317
+ if (!llama_grammar_is_end_of_sequence(stack_pos_after)) {
318
+ stack_after.push_back(stack_pos_after);
319
+ }
320
+ llama_grammar_stacks next_stacks;
321
+ llama_grammar_advance_stack(rules, stack_after, next_stacks);
322
+
323
+ auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
324
+ for (const auto & tok : next_rejects) {
325
+ rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
326
+ }
327
+
328
+ return rejects;
329
+ }
330
+
331
+ static bool llama_grammar_detect_left_recursion(
332
+ const llama_grammar_rules & rules,
333
+ size_t rule_index,
334
+ std::vector<bool> * rules_visited,
335
+ std::vector<bool> * rules_in_progress,
336
+ std::vector<bool> * rules_may_be_empty) {
337
+ if ((*rules_in_progress)[rule_index]) {
338
+ return true;
339
+ }
340
+
341
+ (*rules_in_progress)[rule_index] = true;
342
+
343
+ const llama_grammar_rule & rule = rules[rule_index];
344
+
345
+ // First check if the rule might produce the empty string. This could be done combined with the second
346
+ // step but it's more readable as two steps.
347
+ bool at_rule_start = true;
348
+ for (size_t i = 0; i < rule.size(); i++) {
349
+ if (llama_grammar_is_end_of_sequence(&rule[i])) {
350
+ if (at_rule_start) {
351
+ (*rules_may_be_empty)[rule_index] = true;
352
+ break;
353
+ }
354
+ at_rule_start = true;
355
+ } else {
356
+ at_rule_start = false;
357
+ }
358
+ }
359
+
360
+ // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
361
+ // be empty)
362
+ bool recurse_into_nonterminal = true;
363
+ for (size_t i = 0; i < rule.size(); i++) {
364
+ if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
365
+ if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
366
+ return true;
367
+ }
368
+ if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
369
+ recurse_into_nonterminal = false;
370
+ }
371
+ } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
372
+ recurse_into_nonterminal = true;
373
+ } else {
374
+ recurse_into_nonterminal = false;
375
+ }
376
+ }
377
+
378
+ (*rules_in_progress)[rule_index] = false;
379
+ (*rules_visited)[rule_index] = true;
380
+ return false;
381
+ }
382
+
383
+ //
384
+ // grammar - external
385
+ //
386
+
387
+ struct llama_grammar * llama_grammar_init_impl(
388
+ const llama_grammar_element ** rules,
389
+ size_t n_rules,
390
+ size_t start_rule_index) {
391
+ const llama_grammar_element * pos;
392
+
393
+ // copy rule definitions into vectors
394
+ llama_grammar_rules vec_rules(n_rules);
395
+ for (size_t i = 0; i < n_rules; i++) {
396
+ for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
397
+ vec_rules[i].push_back(*pos);
398
+ }
399
+ vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
400
+ }
401
+
402
+ // Check for left recursion
403
+ std::vector<bool> rules_visited(n_rules);
404
+ std::vector<bool> rules_in_progress(n_rules);
405
+ std::vector<bool> rules_may_be_empty(n_rules);
406
+ for (size_t i = 0; i < n_rules; i++) {
407
+ if (rules_visited[i]) {
408
+ continue;
409
+ }
410
+ if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
411
+ LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
412
+ return nullptr;
413
+ }
414
+ }
415
+
416
+ // loop over alternates of start rule to build initial stacks
417
+ llama_grammar_stacks stacks;
418
+ pos = vec_rules[start_rule_index].data();
419
+ do {
420
+ llama_grammar_stack stack;
421
+ if (!llama_grammar_is_end_of_sequence(pos)) {
422
+ // if alternate is nonempty, add to stack
423
+ stack.push_back(pos);
424
+ }
425
+ llama_grammar_advance_stack(vec_rules, stack, stacks);
426
+ while (!llama_grammar_is_end_of_sequence(pos)) {
427
+ // scan to end of alternate def
428
+ pos++;
429
+ }
430
+ if (pos->type == LLAMA_GRETYPE_ALT) {
431
+ // there's another alternate def of this rule to process
432
+ pos++;
433
+ } else {
434
+ break;
435
+ }
436
+ } while (true);
437
+
438
+ // Important: vec_rules has to be moved here, not copied, because stacks contains
439
+ // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
440
+ // then the pointers would be invalidated when the local vec_rules goes out of scope.
441
+ return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
442
+ }
443
+
444
+ void llama_grammar_free_impl(struct llama_grammar * grammar) {
445
+ delete grammar;
446
+ }
447
+
448
+ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
449
+ llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
450
+
451
+ // redirect elements in stacks to point to new rules
452
+ for (size_t is = 0; is < result->stacks.size(); is++) {
453
+ for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
454
+ for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
455
+ for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
456
+ if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
457
+ result->stacks[is][ie] = &result->rules[ir0][ir1];
458
+ }
459
+ }
460
+ }
461
+ }
462
+ }
463
+
464
+ return result;
465
+ }
466
+
467
+ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
468
+ GGML_ASSERT(grammar);
469
+ GGML_ASSERT(vocab);
470
+
471
+ int64_t t_start_sample_us = ggml_time_us();
472
+
473
+ bool allow_eog = false;
474
+ for (const auto & stack : grammar->stacks) {
475
+ if (stack.empty()) {
476
+ allow_eog = true;
477
+ break;
478
+ }
479
+ }
480
+
481
+ std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
482
+ candidates_decoded.reserve(candidates->size);
483
+
484
+ llama_grammar_candidates candidates_grammar;
485
+ candidates_grammar.reserve(candidates->size);
486
+
487
+ for (size_t i = 0; i < candidates->size; ++i) {
488
+ const llama_token id = candidates->data[i].id;
489
+ const std::string & piece = vocab->cache_token_to_piece.at(id);
490
+
491
+ if (llama_token_is_eog_impl(*vocab, id)) {
492
+ if (!allow_eog) {
493
+ candidates->data[i].logit = -INFINITY;
494
+ }
495
+ } else if (piece.empty() || piece[0] == 0) {
496
+ candidates->data[i].logit = -INFINITY;
497
+ } else {
498
+ candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
499
+ candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
500
+ }
501
+ }
502
+
503
+ const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
504
+ for (const auto & reject : rejects) {
505
+ candidates->data[reject.index].logit = -INFINITY;
506
+ }
507
+
508
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
509
+ }
510
+
511
+ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
512
+ const int64_t t_start_sample_us = ggml_time_us();
513
+
514
+ if (llama_token_is_eog_impl(*vocab, token)) {
515
+ for (const auto & stack : grammar->stacks) {
516
+ if (stack.empty()) {
517
+ return;
518
+ }
519
+ }
520
+ GGML_ABORT("fatal error");
521
+ }
522
+
523
+ const std::string & piece = vocab->cache_token_to_piece.at(token);
524
+
525
+ // Note terminating 0 in decoded string
526
+ const auto decoded = decode_utf8(piece, grammar->partial_utf8);
527
+ const auto & code_points = decoded.first;
528
+
529
+ llama_grammar_stacks tmp_new_stacks;
530
+ for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
531
+ llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
532
+ grammar->stacks = tmp_new_stacks;
533
+ }
534
+
535
+ grammar->partial_utf8 = decoded.second;
536
+ GGML_ASSERT(!grammar->stacks.empty());
537
+
538
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
539
+ }
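Note on the decode_utf8 helper added above: it carries an unfinished multi-byte sequence across token boundaries as a (value, n_remain) pair and uses the high nibble of the lead byte to pick the sequence length. Below is a minimal, self-contained sketch of that same carry mechanism, for reference only; the struct and function names (partial_utf8, decode_chunk) are illustrative and are not part of this commit.

// partial_utf8_sketch.cpp — standalone illustration, not part of the commit.
// Compile with: g++ -std=c++11 partial_utf8_sketch.cpp
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

struct partial_utf8 {
    uint32_t value    = 0;  // bits accumulated so far
    int      n_remain = 0;  // continuation bytes still expected (-1 on error)
};

// Decode one chunk of UTF-8, carrying an incomplete trailing sequence in `state`.
static std::vector<uint32_t> decode_chunk(const std::string & src, partial_utf8 & state) {
    // high nibble of the lead byte -> total sequence length (0 = invalid lead byte)
    static const int lookup[16] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
    std::vector<uint32_t> out;
    for (size_t i = 0; i < src.size(); ++i) {
        const uint8_t b = static_cast<uint8_t>(src[i]);
        if (state.n_remain > 0) {
            // continuation byte expected: must look like 10xxxxxx
            if ((b >> 6) != 2) { state.n_remain = -1; return out; }
            state.value = (state.value << 6) | (b & 0x3F);
            if (--state.n_remain == 0) {
                out.push_back(state.value);
            }
            continue;
        }
        const int len = lookup[b >> 4];
        if (len == 0) { state.n_remain = -1; return out; }
        state.n_remain = len - 1;
        state.value    = b & ((1u << (7 - state.n_remain)) - 1); // payload bits of the lead byte
        if (state.n_remain == 0) {
            out.push_back(state.value); // plain ASCII byte
        }
    }
    return out;
}

int main() {
    // U+00E9 ("é") split across two chunks, as can happen across token boundaries.
    partial_utf8 st;
    const auto a = decode_chunk("\xC3", st); // incomplete: nothing emitted yet
    const auto b = decode_chunk("\xA9", st); // completes the code point
    std::printf("chunk 1 emitted %zu code points, chunk 2 emitted U+%04X\n",
                a.size(), b.empty() ? 0u : (unsigned) b[0]);
    return 0;
}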
examples/talk-llama/llama-grammar.h ADDED
@@ -0,0 +1,39 @@
1
+ #pragma once
2
+
3
+ #include "llama-impl.h"
4
+
5
+ struct llama_vocab;
6
+ struct llama_sampling;
7
+
8
+ struct llama_grammar {
9
+ const llama_grammar_rules rules;
10
+ llama_grammar_stacks stacks;
11
+
12
+ // buffer for partially generated UTF-8 sequence from accepted tokens
13
+ llama_partial_utf8 partial_utf8;
14
+ };
15
+
16
+ //
17
+ // internal API
18
+ //
19
+
20
+ struct llama_grammar * llama_grammar_init_impl(
21
+ const llama_grammar_element ** rules,
22
+ size_t n_rules,
23
+ size_t start_rule_index);
24
+
25
+ void llama_grammar_free_impl(struct llama_grammar * grammar);
26
+
27
+ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);
28
+
29
+ void llama_grammar_sample_impl(
30
+ const struct llama_grammar * grammar,
31
+ const struct llama_vocab * vocab,
32
+ const struct llama_sampling * smpl,
33
+ llama_token_data_array * candidates);
34
+
35
+ void llama_grammar_accept_token_impl(
36
+ struct llama_grammar * grammar,
37
+ const struct llama_vocab * vocab,
38
+ const struct llama_sampling * smpl,
39
+ llama_token token);
examples/talk-llama/llama-impl.h ADDED
@@ -0,0 +1,26 @@
1
+ #pragma once
2
+
3
+ #define LLAMA_API_INTERNAL
4
+ #include "llama.h"
5
+
6
+ #ifdef __GNUC__
7
+ #ifdef __MINGW32__
8
+ #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
9
+ #else
10
+ #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
11
+ #endif
12
+ #else
13
+ #define LLAMA_ATTRIBUTE_FORMAT(...)
14
+ #endif
15
+
16
+ //
17
+ // logging
18
+ //
19
+
20
+ LLAMA_ATTRIBUTE_FORMAT(2, 3)
21
+ void llama_log_internal (ggml_log_level level, const char * format, ...);
22
+ void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data);
23
+
24
+ #define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
25
+ #define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
26
+ #define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
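llama-impl.h marks llama_log_internal with LLAMA_ATTRIBUTE_FORMAT so that GCC/Clang (gnu_printf on MinGW) check the printf-style arguments against the format string at compile time. The snippet below is a minimal standalone sketch of the same pattern, for reference only; the names MY_ATTRIBUTE_FORMAT and my_log are illustrative and are not part of this commit.

// format_attribute_sketch.cpp — standalone illustration, not part of the commit.
#include <cstdarg>
#include <cstdio>

#ifdef __GNUC__
#   define MY_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#else
#   define MY_ATTRIBUTE_FORMAT(...)
#endif

// Argument 1 is the format string, variadic arguments start at position 2.
// llama_log_internal uses (2, 3) instead because its first parameter is the log level.
MY_ATTRIBUTE_FORMAT(1, 2)
static void my_log(const char * fmt, ...) {
    va_list ap;
    va_start(ap, fmt);
    vfprintf(stderr, fmt, ap);
    va_end(ap);
}

int main() {
    my_log("loaded %d tensors in %.2f s\n", 291, 0.73);
    // my_log("%s\n", 42); // with -Wformat, GCC/Clang would now warn about this mismatch
    return 0;
}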
examples/talk-llama/llama-sampling.cpp ADDED
@@ -0,0 +1,635 @@
1
+ #include "llama-sampling.h"
2
+
3
+ #include <algorithm>
4
+ #include <cstring>
5
+ #include <ctime>
6
+ #include <cfloat>
7
+ #include <numeric>
8
+ #include <unordered_map>
9
+
10
+ static void llama_log_softmax(float * array, size_t size) {
11
+ float max_l = *std::max_element(array, array + size);
12
+ float sum = 0.f;
13
+ for (size_t i = 0; i < size; ++i) {
14
+ float p = expf(array[i] - max_l);
15
+ sum += p;
16
+ array[i] = p;
17
+ }
18
+
19
+ for (size_t i = 0; i < size; ++i) {
20
+ array[i] = logf(array[i] / sum);
21
+ }
22
+ }
23
+
24
+ void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
25
+ if (seed == LLAMA_DEFAULT_SEED) {
26
+ seed = time(NULL);
27
+ }
28
+
29
+ smpl->rng.seed(seed);
30
+ }
31
+
32
+ void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
33
+ GGML_ASSERT(candidates->size > 0);
34
+
35
+ const int64_t t_start_sample_us = ggml_time_us();
36
+
37
+ // Sort the logits in descending order
38
+ if (!candidates->sorted) {
39
+ std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
40
+ return a.logit > b.logit;
41
+ });
42
+ candidates->sorted = true;
43
+ }
44
+
45
+ float max_l = candidates->data[0].logit;
46
+ float cum_sum = 0.0f;
47
+ for (size_t i = 0; i < candidates->size; ++i) {
48
+ float p = expf(candidates->data[i].logit - max_l);
49
+ candidates->data[i].p = p;
50
+ cum_sum += p;
51
+ }
52
+ for (size_t i = 0; i < candidates->size; ++i) {
53
+ candidates->data[i].p /= cum_sum;
54
+ }
55
+
56
+ if (smpl) {
57
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
58
+ }
59
+ }
60
+
61
+ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
62
+ // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
63
+ // if (k >= (int32_t)candidates->size) {
64
+ // return;
65
+ // }
66
+
67
+ const int64_t t_start_sample_us = ggml_time_us();
68
+
69
+ if (k <= 0) {
70
+ k = candidates->size;
71
+ }
72
+
73
+ k = std::max(k, (int) min_keep);
74
+ k = std::min(k, (int) candidates->size);
75
+
76
+ // Sort scores in descending order
77
+ if (!candidates->sorted) {
78
+ auto comp = [](const llama_token_data & a, const llama_token_data & b) {
79
+ return a.logit > b.logit;
80
+ };
81
+ if (k <= 128) {
82
+ std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
83
+ } else {
84
+ constexpr int nbuckets = 128;
85
+ constexpr float bucket_low = -10.0f;
86
+ constexpr float bucket_high = 10.0f;
87
+ constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
88
+ constexpr float bucker_inter = -bucket_low * bucket_scale;
89
+
90
+ std::vector<int> bucket_idx(candidates->size);
91
+ std::vector<int> histo(nbuckets, 0);
92
+
93
+ for (int i = 0; i < (int)candidates->size; ++i) {
94
+ const float val = candidates->data[i].logit;
95
+ int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
96
+ ib = std::max(0, std::min(nbuckets-1, ib));
97
+ bucket_idx[i] = ib;
98
+ ++histo[ib];
99
+ }
100
+ int nhave = 0;
101
+ int ib = nbuckets - 1;
102
+ for ( ; ib >= 0; --ib) {
103
+ nhave += histo[ib];
104
+ if (nhave >= k) break;
105
+ }
106
+ std::vector<llama_token_data> tmp_tokens(nhave);
107
+ auto ptr = tmp_tokens.data();
108
+ std::vector<llama_token_data*> bucket_ptrs;
109
+ bucket_ptrs.reserve(nbuckets - ib);
110
+ for (int j = nbuckets - 1; j >= ib; --j) {
111
+ bucket_ptrs.push_back(ptr);
112
+ ptr += histo[j];
113
+ }
114
+ for (int i = 0; i < (int)candidates->size; ++i) {
115
+ int j = bucket_idx[i];
116
+ if (j >= ib) {
117
+ *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
118
+ }
119
+ }
120
+
121
+ ptr = tmp_tokens.data();
122
+ int ndone = 0;
123
+ for (int j = nbuckets-1; j > ib; --j) {
124
+ std::sort(ptr, ptr + histo[j], comp);
125
+ ptr += histo[j];
126
+ ndone += histo[j];
127
+ }
128
+ std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
129
+
130
+ std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
131
+
132
+ }
133
+ candidates->sorted = true;
134
+ }
135
+ candidates->size = k;
136
+
137
+ if (smpl) {
138
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
139
+ }
140
+ }
141
+
142
+ void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
143
+ if (p >= 1.0f) {
144
+ return;
145
+ }
146
+
147
+ llama_sample_softmax_impl(smpl, candidates);
148
+
149
+ const int64_t t_start_sample_us = ggml_time_us();
150
+
151
+ // Compute the cumulative probabilities
152
+ float cum_sum = 0.0f;
153
+ size_t last_idx = candidates->size;
154
+
155
+ for (size_t i = 0; i < candidates->size; ++i) {
156
+ cum_sum += candidates->data[i].p;
157
+
158
+ // Check if the running sum is at least p or if we have kept at least min_keep tokens
159
+ // we set the last index to i+1 to indicate that the current iterate should be included in the set
160
+ if (cum_sum >= p && i + 1 >= min_keep) {
161
+ last_idx = i + 1;
162
+ break;
163
+ }
164
+ }
165
+
166
+ // Resize the output vector to keep only the top-p tokens
167
+ candidates->size = last_idx;
168
+
169
+ if (smpl) {
170
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
171
+ }
172
+ }
173
+
174
+ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
175
+ if (p <= 0.0f || !candidates->size) {
176
+ return;
177
+ }
178
+
179
+ const int64_t t_start_sample_us = ggml_time_us();
180
+
181
+ bool min_p_applied = false;
182
+
183
+ // if the candidates aren't sorted, try the unsorted implementation first
184
+ if (!candidates->sorted) {
185
+ std::vector<llama_token_data> filtered_tokens;
186
+
187
+ float max_logit = -FLT_MAX;
188
+ for (size_t i = 0; i < candidates->size; ++i) {
189
+ max_logit = std::max(max_logit, candidates->data[i].logit);
190
+ }
191
+ const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
192
+
193
+ for (size_t i = 0; i < candidates->size; ++i) {
194
+ if (candidates->data[i].logit >= min_logit) {
195
+ filtered_tokens.push_back(candidates->data[i]);
196
+ }
197
+ }
198
+
199
+ // if we have enough values the operation was a success
200
+ if (filtered_tokens.size() >= min_keep) {
201
+ memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
202
+ candidates->size = filtered_tokens.size();
203
+ min_p_applied = true;
204
+ }
205
+ }
206
+
207
+ // if the candidates are sorted or the unsorted implementation failed, use this implementation
208
+ if (!min_p_applied) {
209
+ // Sort the logits in descending order
210
+ if (!candidates->sorted) {
211
+ std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
212
+ return a.logit > b.logit;
213
+ });
214
+ candidates->sorted = true;
215
+ }
216
+
217
+ const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
218
+ size_t i = 1; // first token always matches
219
+
220
+ for (; i < candidates->size; ++i) {
221
+ if (candidates->data[i].logit < min_logit && i >= min_keep) {
222
+ break; // prob too small
223
+ }
224
+ }
225
+
226
+ // Resize the output vector to keep only the matching tokens
227
+ candidates->size = i;
228
+ }
229
+
230
+ if (smpl) {
231
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
232
+ }
233
+ }
234
+
235
+ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
236
+ if (z >= 1.0f || candidates->size <= 2) {
237
+ return;
238
+ }
239
+
240
+ llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
241
+ const int64_t t_start_sample_us = ggml_time_us();
242
+
243
+ // Compute the first and second derivatives
244
+ std::vector<float> first_derivatives(candidates->size - 1);
245
+ std::vector<float> second_derivatives(candidates->size - 2);
246
+
247
+ for (size_t i = 0; i < first_derivatives.size(); ++i) {
248
+ first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
249
+ }
250
+ for (size_t i = 0; i < second_derivatives.size(); ++i) {
251
+ second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
252
+ }
253
+
254
+ // Calculate absolute value of second derivatives
255
+ for (size_t i = 0; i < second_derivatives.size(); ++i) {
256
+ second_derivatives[i] = std::abs(second_derivatives[i]);
257
+ }
258
+
259
+ // Normalize the second derivatives
260
+ {
261
+ const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
262
+
263
+ if (second_derivatives_sum > 1e-6f) {
264
+ for (float & value : second_derivatives) {
265
+ value /= second_derivatives_sum;
266
+ }
267
+ } else {
268
+ for (float & value : second_derivatives) {
269
+ value = 1.0f / second_derivatives.size();
270
+ }
271
+ }
272
+ }
273
+
274
+ float cum_sum = 0.0f;
275
+ size_t last_idx = candidates->size;
276
+ for (size_t i = 0; i < second_derivatives.size(); ++i) {
277
+ cum_sum += second_derivatives[i];
278
+
279
+ // Check if the running sum is greater than z or if we have kept at least min_keep tokens
280
+ if (cum_sum > z && i >= min_keep) {
281
+ last_idx = i;
282
+ break;
283
+ }
284
+ }
285
+
286
+ // Resize the output vector to keep only the tokens above the tail location
287
+ candidates->size = last_idx;
288
+
289
+ if (smpl) {
290
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
291
+ }
292
+ }
293
+
294
+ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
295
+ // Reference implementation:
296
+ // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
297
+ if (p >= 1.0f) {
298
+ return;
299
+ }
300
+
301
+ // Compute the softmax of logits and calculate entropy
302
+ llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
303
+
304
+ const int64_t t_start_sample_us = ggml_time_us();
305
+
306
+ float entropy = 0.0f;
307
+ for (size_t i = 0; i < candidates->size; ++i) {
308
+ entropy += -candidates->data[i].p * logf(candidates->data[i].p);
309
+ }
310
+
311
+ // Compute the absolute difference between negative log probability and entropy for each candidate
312
+ std::vector<float> shifted_scores;
313
+ for (size_t i = 0; i < candidates->size; ++i) {
314
+ float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
315
+ shifted_scores.push_back(shifted_score);
316
+ }
317
+
318
+ // Sort tokens based on the shifted_scores and their corresponding indices
319
+ std::vector<size_t> indices(candidates->size);
320
+ std::iota(indices.begin(), indices.end(), 0);
321
+
322
+ std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
323
+ return shifted_scores[a] < shifted_scores[b];
324
+ });
325
+
326
+ // Compute the cumulative probabilities
327
+ float cum_sum = 0.0f;
328
+ size_t last_idx = indices.size();
329
+
330
+ for (size_t i = 0; i < indices.size(); ++i) {
331
+ size_t idx = indices[i];
332
+ cum_sum += candidates->data[idx].p;
333
+
334
+ // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
335
+ if (cum_sum > p && i >= min_keep - 1) {
336
+ last_idx = i + 1;
337
+ break;
338
+ }
339
+ }
340
+
341
+ // Resize the output vector to keep only the locally typical tokens
342
+ std::vector<llama_token_data> new_candidates;
343
+ for (size_t i = 0; i < last_idx; ++i) {
344
+ size_t idx = indices[i];
345
+ new_candidates.push_back(candidates->data[idx]);
346
+ }
347
+
348
+ // Replace the data in candidates with the new_candidates data
349
+ std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
350
+ candidates->size = new_candidates.size();
351
+ candidates->sorted = false;
352
+
353
+ if (smpl) {
354
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
355
+ }
356
+ }
357
+
358
+ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
359
+ const int64_t t_start_sample_us = ggml_time_us();
360
+
361
+ // no need to do anything if there is only one (or zero) candidates
362
+ if(candidates->size <= 1) {
363
+ return;
364
+ }
365
+
366
+ // Calculate maximum possible entropy
367
+ float max_entropy = -logf(1.0f / candidates->size);
368
+
369
+ llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
370
+
371
+ // Calculate entropy of the softmax probabilities
372
+ float entropy = 0.0f;
373
+ for (size_t i = 0; i < candidates->size; ++i) {
374
+ float prob = candidates->data[i].p;
375
+ if (prob > 0.0f) { // Ensure no log(0)
376
+ entropy -= prob * logf(prob);
377
+ }
378
+ }
379
+
380
+ // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above)
381
+ float normalized_entropy = entropy / max_entropy;
382
+
383
+ // Map the normalized entropy to the desired temperature range using the power function
384
+ float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
385
+
386
+ #ifdef DEBUG
387
+ LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
388
+ LLAMA_LOG_INFO("Entropy: %f\n", entropy);
389
+ LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
390
+ LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
391
+ LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
392
+ LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
393
+ #endif
394
+
395
+ // Apply the dynamically calculated temperature scaling
396
+ for (size_t i = 0; i < candidates->size; ++i) {
397
+ candidates->data[i].logit /= dyn_temp;
398
+ }
399
+
400
+ // Re-compute softmax probabilities after scaling logits with dynamic temperature
401
+ double max_l_double = candidates->data[0].logit;
402
+ double cum_sum_double = 0.0;
403
+ for (size_t i = 0; i < candidates->size; ++i) {
404
+ double p = exp(candidates->data[i].logit - max_l_double);
405
+ candidates->data[i].p = p; // Store the scaled probability
406
+ cum_sum_double += p;
407
+ }
408
+ for (size_t i = 0; i < candidates->size; ++i) {
409
+ candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
410
+ }
411
+
412
+ #ifdef DEBUG
413
+ // Print the updated top 25 probabilities after temperature scaling
414
+ LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
415
+ for (size_t i = 0; i < 25 && i < candidates->size; ++i) {
416
+ LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
417
+ }
418
+ #endif
419
+
420
+ if (smpl) {
421
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
422
+ }
423
+ }
424
+
425
+ void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
426
+ const int64_t t_start_sample_us = ggml_time_us();
427
+
428
+ for (size_t i = 0; i < candidates->size; ++i) {
429
+ candidates->data[i].logit /= temp;
430
+ }
431
+
432
+ if (smpl) {
433
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
434
+ }
435
+ }
436
+
437
+ void llama_sample_repetition_penalties_impl(
438
+ struct llama_sampling * smpl,
439
+ llama_token_data_array * candidates,
440
+ const llama_token * last_tokens,
441
+ size_t penalty_last_n,
442
+ float penalty_repeat,
443
+ float penalty_freq,
444
+ float penalty_present) {
445
+ if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
446
+ return;
447
+ }
448
+
449
+ const int64_t t_start_sample_us = ggml_time_us();
450
+
451
+ // Create a frequency map to count occurrences of each token in last_tokens
452
+ std::unordered_map<llama_token, int> token_count;
453
+ for (size_t i = 0; i < penalty_last_n; ++i) {
454
+ token_count[last_tokens[i]]++;
455
+ }
456
+
457
+ // Apply frequency and presence penalties to the candidates
458
+ for (size_t i = 0; i < candidates->size; ++i) {
459
+ const auto token_iter = token_count.find(candidates->data[i].id);
460
+ if (token_iter == token_count.end()) {
461
+ continue;
462
+ }
463
+
464
+ const int count = token_iter->second;
465
+
466
+ // The academic publication that described this technique only divided by the penalty, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
467
+ // The common fix for this problem is to multiply by the penalty instead of dividing.
468
+ if (candidates->data[i].logit <= 0) {
469
+ candidates->data[i].logit *= penalty_repeat;
470
+ } else {
471
+ candidates->data[i].logit /= penalty_repeat;
472
+ }
473
+
474
+ candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
475
+ }
476
+
477
+ candidates->sorted = false;
478
+
479
+ if (smpl) {
480
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
481
+ }
482
+ }
483
+
484
+ void llama_sample_apply_guidance_impl(
485
+ struct llama_sampling * smpl,
486
+ float * logits,
487
+ float * logits_guidance,
488
+ float scale) {
489
+ GGML_ASSERT(smpl);
490
+
491
+ const auto t_start_sample_us = ggml_time_us();
492
+ const auto n_vocab = smpl->n_vocab;
493
+
494
+ llama_log_softmax(logits, n_vocab);
495
+ llama_log_softmax(logits_guidance, n_vocab);
496
+
497
+ for (int i = 0; i < n_vocab; ++i) {
498
+ auto & l = logits[i];
499
+ const auto & g = logits_guidance[i];
500
+
501
+ l = scale * (l - g) + g;
502
+ }
503
+
504
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
505
+ }
506
+
507
+ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
508
+ GGML_ASSERT(smpl);
509
+
510
+ const int32_t n_vocab = float(smpl->n_vocab);
511
+
512
+ int64_t t_start_sample_us = ggml_time_us();
513
+
514
+ llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
515
+
516
+ // Estimate s_hat using the most probable m tokens
517
+ float s_hat = 0.0;
518
+ float sum_ti_bi = 0.0;
519
+ float sum_ti_sq = 0.0;
520
+ for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
521
+ float t_i = logf(float(i + 2) / float(i + 1));
522
+ float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
523
+ sum_ti_bi += t_i * b_i;
524
+ sum_ti_sq += t_i * t_i;
525
+ }
526
+ s_hat = sum_ti_bi / sum_ti_sq;
527
+
528
+ // Compute k from the estimated s_hat and target surprise value
529
+ float epsilon_hat = s_hat - 1;
530
+ float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
531
+
532
+ // Sample the next word X using top-k sampling
533
+ llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
534
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
535
+ llama_token X = llama_sample_token_impl(smpl, candidates);
536
+ t_start_sample_us = ggml_time_us();
537
+
538
+ // Compute error as the difference between observed surprise and target surprise value
539
+ size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
540
+ return candidate.id == X;
541
+ }));
542
+ float observed_surprise = -log2f(candidates->data[X_idx].p);
543
+ float e = observed_surprise - tau;
544
+
545
+ // Update mu using the learning rate and error
546
+ *mu = *mu - eta * e;
547
+
548
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
549
+ return X;
550
+ }
551
+
552
+ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
553
+ int64_t t_start_sample_us;
554
+ t_start_sample_us = ggml_time_us();
555
+
556
+ llama_sample_softmax_impl(smpl, candidates);
557
+
558
+ // Truncate the words with surprise values greater than mu
559
+ candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
560
+ return -log2f(candidate.p) > *mu;
561
+ }));
562
+
563
+ if (candidates->size == 0) {
564
+ candidates->size = 1;
565
+ }
566
+
567
+ if (smpl) {
568
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
569
+ }
570
+
571
+ // Normalize the probabilities of the remaining words
572
+ llama_sample_softmax_impl(smpl, candidates);
573
+
574
+ // Sample the next word X from the remaining words
575
+ llama_token X = llama_sample_token_impl(smpl, candidates);
576
+ t_start_sample_us = ggml_time_us();
577
+
578
+ // Compute error as the difference between observed surprise and target surprise value
579
+ size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
580
+ return candidate.id == X;
581
+ }));
582
+ float observed_surprise = -log2f(candidates->data[X_idx].p);
583
+ float e = observed_surprise - tau;
584
+
585
+ // Update mu using the learning rate and error
586
+ *mu = *mu - eta * e;
587
+
588
+ if (smpl) {
589
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
590
+ }
591
+ return X;
592
+ }
593
+
594
+ llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
595
+ const int64_t t_start_sample_us = ggml_time_us();
596
+
597
+ // Find max element
598
+ auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
599
+ return a.logit < b.logit;
600
+ });
601
+
602
+ llama_token result = max_iter->id;
603
+ if (smpl) {
604
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
605
+ smpl->n_sample++;
606
+ }
607
+ return result;
608
+ }
609
+
610
+ llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
611
+ GGML_ASSERT(smpl);
612
+
613
+ const int64_t t_start_sample_us = ggml_time_us();
614
+ llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
615
+
616
+ std::vector<float> probs;
617
+ probs.reserve(candidates->size);
618
+ for (size_t i = 0; i < candidates->size; ++i) {
619
+ probs.push_back(candidates->data[i].p);
620
+ }
621
+
622
+ std::discrete_distribution<> dist(probs.begin(), probs.end());
623
+ int idx = dist(rng);
624
+
625
+ llama_token result = candidates->data[idx].id;
626
+
627
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
628
+ smpl->n_sample++;
629
+
630
+ return result;
631
+ }
632
+
633
+ llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
634
+ return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
635
+ }
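The samplers added above all mutate a llama_token_data_array of (id, logit, p) entries in place, typically in a temperature → top-k → softmax → draw order chosen by the caller. The following self-contained sketch shows that same pipeline shape over a plain vector of logits, for reference only; the function name sample_top_k_temp is illustrative and is not part of this commit.

// sampling_pipeline_sketch.cpp — standalone illustration, not part of the commit.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <random>
#include <vector>

// Keep the k largest logits, apply temperature, softmax-normalize, then draw an
// index — the same shape as llama_sample_top_k_impl / llama_sample_temp_impl /
// llama_sample_softmax_impl / llama_sample_token_with_rng_impl.
static int sample_top_k_temp(std::vector<float> logits, int k, float temp, std::mt19937 & rng) {
    std::vector<int> idx(logits.size());
    std::iota(idx.begin(), idx.end(), 0);

    k = std::max(1, std::min<int>(k, (int) logits.size()));
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](int a, int b) { return logits[a] > logits[b]; });
    idx.resize(k);

    // temperature scaling + softmax over the surviving candidates
    std::vector<double> p(k);
    const double max_l = logits[idx[0]] / temp;
    double sum = 0.0;
    for (int i = 0; i < k; ++i) {
        p[i] = std::exp(logits[idx[i]] / temp - max_l);
        sum += p[i];
    }
    for (int i = 0; i < k; ++i) {
        p[i] /= sum;
    }

    std::discrete_distribution<int> dist(p.begin(), p.end());
    return idx[dist(rng)];
}

int main() {
    std::mt19937 rng(42);
    std::vector<float> logits = { 1.0f, 3.0f, 0.5f, 2.5f };
    const int tok = sample_top_k_temp(logits, /*k=*/2, /*temp=*/0.8f, rng);
    std::printf("sampled index: %d\n", tok);
    return 0;
}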
examples/talk-llama/llama-sampling.h ADDED
@@ -0,0 +1,56 @@
1
+ #pragma once
2
+
3
+ #include "llama-impl.h"
4
+
5
+ struct llama_sampling {
6
+ llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
7
+
8
+ std::mt19937 rng;
9
+
10
+ int32_t n_vocab = 0;
11
+
12
+ mutable int64_t t_sample_us = 0;
13
+ mutable int32_t n_sample = 0;
14
+
15
+ void reset_timings() const {
16
+ t_sample_us = 0;
17
+ n_sample = 0;
18
+ }
19
+ };
20
+
21
+ //
22
+ // internal API
23
+ //
24
+
25
+ void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);
26
+
27
+ void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
28
+ void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
29
+ void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
30
+ void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
31
+ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
32
+ void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
33
+ void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
34
+ void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
35
+
36
+ void llama_sample_repetition_penalties_impl(
37
+ struct llama_sampling * smpl,
38
+ llama_token_data_array * candidates,
39
+ const llama_token * last_tokens,
40
+ size_t penalty_last_n,
41
+ float penalty_repeat,
42
+ float penalty_freq,
43
+ float penalty_present);
44
+
45
+ void llama_sample_apply_guidance_impl(
46
+ struct llama_sampling * smpl,
47
+ float * logits,
48
+ float * logits_guidance,
49
+ float scale);
50
+
51
+ llama_token llama_sample_token_mirostat_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
52
+ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
53
+ llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
54
+ llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
55
+ llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
56
+
examples/talk-llama/llama-vocab.cpp ADDED
@@ -0,0 +1,1729 @@
1
+ #include "llama-vocab.h"
2
+
3
+ #include "unicode.h"
4
+
5
+ #include <algorithm>
6
+ #include <cassert>
7
+ #include <cfloat>
8
+ #include <climits>
9
+ #include <cstdarg>
10
+ #include <cstring>
11
+ #include <forward_list>
12
+ #include <queue>
13
+ #include <sstream>
14
+
15
+ //
16
+ // helpers
17
+ //
18
+
19
+ static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
20
+ std::string result;
21
+ for (size_t pos = 0; ; pos += search.length()) {
22
+ auto new_pos = s.find(search, pos);
23
+ if (new_pos == std::string::npos) {
24
+ result += s.substr(pos, s.size() - pos);
25
+ break;
26
+ }
27
+ result += s.substr(pos, new_pos - pos) + replace;
28
+ pos = new_pos;
29
+ }
30
+ s = std::move(result);
31
+ }
32
+
33
+ LLAMA_ATTRIBUTE_FORMAT(1, 2)
34
+ static std::string format(const char * fmt, ...) {
35
+ va_list ap;
36
+ va_list ap2;
37
+ va_start(ap, fmt);
38
+ va_copy(ap2, ap);
39
+ int size = vsnprintf(NULL, 0, fmt, ap);
40
+ GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
41
+ std::vector<char> buf(size + 1);
42
+ int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
43
+ GGML_ASSERT(size2 == size);
44
+ va_end(ap2);
45
+ va_end(ap);
46
+ return std::string(buf.data(), size);
47
+ }
48
+
49
+ struct naive_trie {
50
+ naive_trie() : has_value(false), value(0) {
51
+ }
52
+ void insert(const char * key, size_t len, int32_t value = 0) {
53
+ if (len == 0) {
54
+ this->has_value = true;
55
+ this->value = value;
56
+ return;
57
+ }
58
+ char c = key[0];
59
+ auto res = children.find(c);
60
+ if (res != children.end()) {
61
+ res->second.insert(key + 1, len - 1, value);
62
+ } else {
63
+ auto res = children.insert(std::make_pair(c, naive_trie()));
64
+ res.first->second.insert(key + 1, len - 1, value);
65
+ }
66
+ }
67
+ std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) {
68
+ if (len == 0 || offset == len) {
69
+ return std::make_pair(key, offset);
70
+ }
71
+ char c = key[offset];
72
+ auto res = children.find(c);
73
+ if (res != children.end()) {
74
+ return res->second.get_longest_prefix(key, len, offset + 1);
75
+ } else {
76
+ return std::make_pair(key, offset);
77
+ }
78
+ }
79
+ struct naive_trie * traverse(const char c) {
80
+ auto res = children.find(c);
81
+ if (res != children.end()) {
82
+ return &res->second;
83
+ } else {
84
+ return NULL;
85
+ }
86
+ }
87
+ std::map<char, struct naive_trie> children;
88
+ bool has_value;
89
+ llama_token value;
90
+ };
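
The trie above backs longest-prefix token matching in the UGM tokenizer further down. A minimal sketch of the same lookup semantics, using a plain std::map scan instead of the per-character trie walk (the toy vocabulary is invented for illustration):

    #include <cstdio>
    #include <map>
    #include <string>

    // toy longest-prefix match: keep the longest vocabulary entry that prefixes
    // the input (the trie does this in a single walk over the input bytes)
    static int longest_prefix_token(const std::map<std::string, int> & vocab,
                                    const std::string & input, size_t & match_len) {
        int best_id = -1;
        match_len = 0;
        for (const auto & kv : vocab) {
            const std::string & tok = kv.first;
            if (tok.size() > match_len && input.compare(0, tok.size(), tok) == 0) {
                best_id   = kv.second;
                match_len = tok.size();
            }
        }
        return best_id;
    }

    int main() {
        const std::map<std::string, int> vocab = { {"t", 1}, {"th", 2}, {"the", 3}, {"they", 4} };
        size_t len = 0;
        const int id = longest_prefix_token(vocab, "them", len);
        printf("matched token id %d with prefix length %zu\n", id, len); // id 3, length 3
        return 0;
    }
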
91
+
92
+ //
93
+ // impl
94
+ //
95
+
96
+ int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
97
+ GGML_ASSERT(token_left.find(' ') == std::string::npos);
98
+ GGML_ASSERT(token_left.find('\n') == std::string::npos);
99
+ GGML_ASSERT(token_right.find(' ') == std::string::npos);
100
+ GGML_ASSERT(token_right.find('\n') == std::string::npos);
101
+
102
+ auto it = bpe_ranks.find(std::make_pair(token_left, token_right));
103
+ if (it == bpe_ranks.end()) {
104
+ return -1;
105
+ }
106
+
107
+ return it->second;
108
+ }
109
+
110
+ static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
111
+ return vocab.type;
112
+ }
113
+
114
+ static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
115
+ GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
116
+ return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
117
+ }
118
+
119
+ static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
120
+ GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
121
+ return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
122
+ }
123
+
124
+ static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
125
+ GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
126
+ return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
127
+ }
128
+
129
+ static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
130
+ GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
131
+ return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
132
+ }
133
+
134
+ static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token id) {
135
+ GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
136
+ return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
137
+ }
138
+
139
+ static bool llama_is_unused_token(const llama_vocab & vocab, llama_token id) {
140
+ GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
141
+ return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED;
142
+ }
143
+
144
+ static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) {
145
+ GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
146
+ GGML_ASSERT(llama_is_byte_token(vocab, id));
147
+ const auto & token_data = vocab.id_to_token.at(id);
148
+ switch (llama_vocab_get_type(vocab)) {
149
+ case LLAMA_VOCAB_TYPE_SPM:
150
+ case LLAMA_VOCAB_TYPE_UGM: {
151
+ auto buf = token_data.text.substr(3, 2);
152
+ return strtol(buf.c_str(), NULL, 16);
153
+ }
154
+ case LLAMA_VOCAB_TYPE_BPE: {
155
+ GGML_ABORT("fatal error");
156
+ //return unicode_utf8_to_byte(token_data.text); // TODO: why is this here after GGML_ASSERT?
157
+ }
158
+ case LLAMA_VOCAB_TYPE_WPM: {
159
+ GGML_ABORT("fatal error");
160
+ }
161
+ default:
162
+ GGML_ABORT("fatal error");
163
+ }
164
+ }
165
+
166
+ static void llama_escape_whitespace(std::string & text) {
167
+ replace_all(text, " ", "\xe2\x96\x81");
168
+ }
169
+
170
+ static void llama_unescape_whitespace(std::string & word) {
171
+ replace_all(word, "\xe2\x96\x81", " ");
172
+ }
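
A quick round-trip sketch of the whitespace escaping used on the SPM path: spaces become U+2581 ("\xe2\x96\x81", the lower one eighth block) before tokenization and are restored on detokenization. A self-contained check, equivalent in effect to the two helpers above:

    #include <cassert>
    #include <string>

    int main() {
        const std::string text = "hello world";

        // escape: " " -> U+2581
        std::string escaped;
        for (char c : text) {
            if (c == ' ') escaped += "\xe2\x96\x81"; else escaped += c;
        }
        assert(escaped == "hello\xe2\x96\x81world");

        // unescape: U+2581 -> " "
        std::string restored = escaped;
        size_t p;
        while ((p = restored.find("\xe2\x96\x81")) != std::string::npos) {
            restored.replace(p, 3, " "); // the escaped space is 3 bytes of UTF-8
        }
        assert(restored == text);
        return 0;
    }
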
173
+
174
+ struct llm_symbol {
175
+ using index = int;
176
+ index prev;
177
+ index next;
178
+ const char * text;
179
+ size_t n;
180
+ };
181
+
182
+ static_assert(std::is_trivially_copyable<llm_symbol>::value, "llm_symbol is not trivially copyable");
183
+
184
+ //
185
+ // SPM tokenizer
186
+ // original implementation:
187
+ // https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
188
+ //
189
+
190
+ struct llm_bigram_spm {
191
+ struct comparator {
192
+ bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) {
193
+ return (l.score < r.score) || (l.score == r.score && l.left > r.left);
194
+ }
195
+ };
196
+ using queue_storage = std::vector<llm_bigram_spm>;
197
+ using queue = std::priority_queue<llm_bigram_spm, queue_storage, comparator>;
198
+ llm_symbol::index left;
199
+ llm_symbol::index right;
200
+ float score;
201
+ size_t size;
202
+ };
203
+
204
+ struct llm_tokenizer_spm {
205
+ llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {}
206
+
207
+ void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
208
+ // split string into utf8 chars
209
+ int index = 0;
210
+ size_t offs = 0;
211
+ while (offs < text.size()) {
212
+ llm_symbol sym;
213
+ size_t len = unicode_len_utf8(text[offs]);
214
+ sym.text = text.c_str() + offs;
215
+ sym.n = std::min(len, text.size() - offs);
216
+ offs += sym.n;
217
+ sym.prev = index - 1;
218
+ sym.next = offs == text.size() ? -1 : index + 1;
219
+ index++;
220
+ symbols.emplace_back(sym);
221
+ }
222
+
223
+ // seed the work queue with all possible 2-character tokens.
224
+ for (size_t i = 1; i < symbols.size(); ++i) {
225
+ try_add_bigram(i - 1, i);
226
+ }
227
+
228
+ // keep substituting the highest frequency pairs for as long as we can.
229
+ while (!work_queue.empty()) {
230
+ auto bigram = work_queue.top();
231
+ work_queue.pop();
232
+
233
+ auto & left_sym = symbols[bigram.left];
234
+ auto & right_sym = symbols[bigram.right];
235
+
236
+ // if one of the symbols already got merged, skip it.
237
+ if (left_sym.n == 0 || right_sym.n == 0 ||
238
+ left_sym.n + right_sym.n != bigram.size) {
239
+ continue;
240
+ }
241
+
242
+ // merge the right sym into the left one
243
+ left_sym.n += right_sym.n;
244
+ right_sym.n = 0;
245
+
246
+ //LLAMA_LOG_INFO("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size);
247
+
248
+ // remove the right sym from the chain
249
+ left_sym.next = right_sym.next;
250
+ if (right_sym.next >= 0) {
251
+ symbols[right_sym.next].prev = bigram.left;
252
+ }
253
+
254
+ // find more substitutions
255
+ try_add_bigram(left_sym.prev, bigram.left);
256
+ try_add_bigram(bigram.left, left_sym.next);
257
+ }
258
+
259
+ for (int i = 0; i != -1; i = symbols[i].next) {
260
+ auto & symbol = symbols[i];
261
+ resegment(symbol, output);
262
+ }
263
+ }
264
+
265
+ private:
266
+ void resegment(llm_symbol & symbol, std::vector<llama_vocab::id> & output) {
267
+ auto text = std::string(symbol.text, symbol.n);
268
+ auto token = vocab.token_to_id.find(text);
269
+
270
+ // Do we need to support is_unused?
271
+ if (token != vocab.token_to_id.end()) {
272
+ output.push_back((*token).second);
273
+ return;
274
+ }
275
+
276
+ const auto p = rev_merge.find(text);
277
+
278
+ if (p == rev_merge.end()) {
279
+ // output any symbols that did not form tokens as bytes.
280
+ output.reserve(output.size() + symbol.n);
281
+ for (int j = 0; j < (int)symbol.n; ++j) {
282
+ llama_vocab::id token_id = llama_byte_to_token_impl(vocab, symbol.text[j]);
283
+ output.push_back(token_id);
284
+ }
285
+ return;
286
+ }
287
+
288
+ resegment(symbols[p->second.first], output);
289
+ resegment(symbols[p->second.second], output);
290
+ }
291
+
292
+ void try_add_bigram(int left, int right) {
293
+ if (left == -1 || right == -1) {
294
+ return;
295
+ }
296
+
297
+ const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
298
+ auto token = vocab.token_to_id.find(text);
299
+
300
+ if (token == vocab.token_to_id.end()) {
301
+ return;
302
+ }
303
+
304
+ if (static_cast<size_t>((*token).second) >= vocab.id_to_token.size()) {
305
+ return;
306
+ }
307
+
308
+ const auto & tok_data = vocab.id_to_token[(*token).second];
309
+
310
+ llm_bigram_spm bigram;
311
+ bigram.left = left;
312
+ bigram.right = right;
313
+ bigram.score = tok_data.score;
314
+ bigram.size = text.size();
315
+
316
+ work_queue.push(bigram);
317
+
318
+ // Do we need to support is_unused?
319
+ rev_merge[text] = std::make_pair(left, right);
320
+ }
321
+
322
+ const llama_vocab & vocab;
323
+
324
+ std::vector<llm_symbol> symbols;
325
+ llm_bigram_spm::queue work_queue;
326
+
327
+ std::map<std::string, std::pair<int, int>> rev_merge;
328
+ };
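
A hedged illustration of the merge loop above: repeatedly replace the adjacent pair whose concatenation scores highest in the vocabulary. The real tokenizer drives this with a priority queue over a linked symbol list and falls back to byte tokens for unmatched pieces; this toy version simply rescans, with invented scores:

    #include <cstdio>
    #include <map>
    #include <string>
    #include <vector>

    int main() {
        // toy scores: higher is better (SPM scores are log probabilities)
        const std::map<std::string, float> score = {
            {"h", -5.0f}, {"e", -5.0f}, {"l", -5.0f}, {"o", -5.0f},
            {"he", -2.0f}, {"ll", -2.5f}, {"llo", -1.5f}, {"hello", -1.0f},
        };

        std::vector<std::string> pieces = {"h", "e", "l", "l", "o"};

        while (true) {
            int   best_i     = -1;
            float best_score = -1e30f;
            for (size_t i = 0; i + 1 < pieces.size(); ++i) {
                auto it = score.find(pieces[i] + pieces[i + 1]);
                if (it != score.end() && it->second > best_score) {
                    best_score = it->second;
                    best_i     = (int) i;
                }
            }
            if (best_i < 0) break; // no adjacent pair forms a known token
            pieces[best_i] += pieces[best_i + 1];
            pieces.erase(pieces.begin() + best_i + 1);
        }

        for (const auto & p : pieces) printf("'%s' ", p.c_str()); // expected: 'hello'
        printf("\n");
        return 0;
    }
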
329
+
330
+ //
331
+ // BPE tokenizer
332
+ // adapted from https://github.com/cmp-nct/ggllm.cpp [MIT License]
333
+ // tried to simplify unicode stuff, so most likely does not work 100% correctly!
334
+ //
335
+
336
+ // TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused
337
+
338
+ struct llm_bigram_bpe {
339
+ struct comparator {
340
+ bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const {
341
+ return l.rank > r.rank || (l.rank == r.rank && l.left > r.left);
342
+ }
343
+ };
344
+
345
+ using queue_storage = std::vector<llm_bigram_bpe>;
346
+ using queue = std::priority_queue<llm_bigram_bpe, queue_storage, comparator>;
347
+ llm_symbol::index left;
348
+ llm_symbol::index right;
349
+ std::string text;
350
+ int rank;
351
+ size_t size;
352
+ };
353
+
354
+ struct llm_tokenizer_bpe {
355
+ llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
356
+ GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
357
+ switch (vocab.type_pre) {
358
+ case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
359
+ regex_exprs = {
360
+ // original regex from tokenizer.json
361
+ //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
362
+
363
+ // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
364
+ "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
365
+ };
366
+ break;
367
+ case LLAMA_VOCAB_PRE_TYPE_DBRX:
368
+ case LLAMA_VOCAB_PRE_TYPE_SMAUG:
369
+ regex_exprs = {
370
+ // same as llama3
371
+ "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
372
+ };
373
+ break;
374
+ case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
375
+ regex_exprs = {
376
+ "[\r\n]",
377
+ "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
378
+ "\\s?[!-/:-~!-/:-~‘-‟ -。]+",
379
+ "\\s+$",
380
+ "[一-龥ࠀ-一가-퟿]+",
381
+ "\\p{N}+",
382
+ };
383
+ break;
384
+ case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
385
+ regex_exprs = {
386
+ "[\r\n]",
387
+ "\\s?\\p{L}+",
388
+ "\\s?\\p{P}+",
389
+ "[一-龥ࠀ-一가-퟿]+",
390
+ "\\p{N}",
391
+ };
392
+ break;
393
+ case LLAMA_VOCAB_PRE_TYPE_FALCON:
394
+ regex_exprs = {
395
+ "[\\p{P}\\$\\+<=>\\^~\\|`]+",
396
+ "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
397
+ "[0-9][0-9][0-9]",
398
+ };
399
+ break;
400
+ case LLAMA_VOCAB_PRE_TYPE_STARCODER:
401
+ case LLAMA_VOCAB_PRE_TYPE_REFACT:
402
+ case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
403
+ case LLAMA_VOCAB_PRE_TYPE_SMOLLM:
404
+ case LLAMA_VOCAB_PRE_TYPE_CODESHELL:
405
+ regex_exprs = {
406
+ "\\p{N}",
407
+ "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
408
+ };
409
+ break;
410
+ case LLAMA_VOCAB_PRE_TYPE_GPT2:
411
+ case LLAMA_VOCAB_PRE_TYPE_MPT:
412
+ case LLAMA_VOCAB_PRE_TYPE_OLMO:
413
+ case LLAMA_VOCAB_PRE_TYPE_JAIS:
414
+ regex_exprs = {
415
+ "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
416
+ };
417
+ break;
418
+ case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
419
+ case LLAMA_VOCAB_PRE_TYPE_QWEN2:
420
+ regex_exprs = {
421
+ // original regex from tokenizer.json
422
+ // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
423
+ "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
424
+ };
425
+ break;
426
+ case LLAMA_VOCAB_PRE_TYPE_PORO:
427
+ regex_exprs = {
428
+ " ?[^(\\s|.,!?…。,、।۔،)]+",
429
+ };
430
+ break;
431
+ case LLAMA_VOCAB_PRE_TYPE_CHATGLM4:
432
+ regex_exprs = {
433
+ "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
434
+ };
435
+ break;
436
+ case LLAMA_VOCAB_PRE_TYPE_VIKING:
437
+ regex_exprs = {
438
+ " ?[^(\\s|.,!?…。,、।۔،)]+",
439
+ "\\p{N}",
440
+ };
441
+ break;
442
+ case LLAMA_VOCAB_PRE_TYPE_TEKKEN:
443
+ // original regex from tokenizer.json
444
+ // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
445
+ regex_exprs = {
446
+ "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
447
+ };
448
+ break;
449
+ default:
450
+ // default regex for BPE tokenization pre-processing
451
+ regex_exprs = {
452
+ "[\\p{P}\\$\\+<=>\\^~\\|]+",
453
+ "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
454
+ "\\p{N}+",
455
+ "[0-9][0-9][0-9]",
456
+ };
457
+ break;
458
+ }
459
+ }
460
+
461
+ void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) const {
462
+ output.push_back(token_id);
463
+ }
464
+
465
+ bool append_bos(std::vector<llama_vocab::id> & output) const {
466
+ if (vocab.tokenizer_add_bos) {
467
+ GGML_ASSERT(vocab.special_bos_id != -1);
468
+ output.push_back(vocab.special_bos_id);
469
+ return true;
470
+ }
471
+ return false;
472
+ }
473
+
474
+ bool append_eos(std::vector<llama_vocab::id> & output) const {
475
+ if (vocab.tokenizer_add_eos) {
476
+ GGML_ASSERT(vocab.special_eos_id != -1);
477
+ output.push_back(vocab.special_eos_id);
478
+ return true;
479
+ }
480
+ return false;
481
+ }
482
+
483
+ void check_double_bos_eos(const std::vector<llama_vocab::id> & output) const {
484
+ if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
485
+ LLAMA_LOG_WARN(
486
+ "%s: Added a BOS token to the prompt as specified by the model but the prompt "
487
+ "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
488
+ "Are you sure this is what you want?\n", __FUNCTION__);
489
+ }
490
+ if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) {
491
+ LLAMA_LOG_WARN(
492
+ "%s: Added a EOS token to the prompt as specified by the model but the prompt "
493
+ "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "
494
+ "Are you sure this is what you want?\n", __FUNCTION__);
495
+ }
496
+ }
497
+
498
+ void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
499
+ int final_prev_index = -1;
500
+
501
+ const auto word_collection = unicode_regex_split(text, regex_exprs);
502
+
503
+ symbols_final.clear();
504
+
505
+ for (auto & word : word_collection) {
506
+ work_queue = llm_bigram_bpe::queue();
507
+ symbols.clear();
508
+
509
+ int index = 0;
510
+ size_t offset = 0;
511
+
512
+ if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
513
+ symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
514
+ offset = word.size();
515
+ }
516
+
517
+ while (offset < word.size()) {
518
+ llm_symbol sym;
519
+ size_t char_len = std::min(word.size() - offset, (size_t) unicode_len_utf8(word[offset]));
520
+ sym.text = word.c_str() + offset;
521
+ sym.n = char_len;
522
+ offset += sym.n;
523
+ sym.prev = index - 1;
524
+ sym.next = offset == word.size() ? -1 : index + 1;
525
+ index++;
526
+ symbols.emplace_back(sym);
527
+ }
528
+ for (size_t i = 1; i < symbols.size(); ++i) {
529
+ add_new_bigram(i - 1, i);
530
+ }
531
+
532
+ // build token(s)
533
+ while (!work_queue.empty()) {
534
+ auto bigram = work_queue.top();
535
+ work_queue.pop();
536
+
537
+ auto & left_symbol = symbols[bigram.left];
538
+ auto & right_symbol = symbols[bigram.right];
539
+
540
+ if (left_symbol.n == 0 || right_symbol.n == 0) {
541
+ continue;
542
+ }
543
+ std::string left_token = std::string(left_symbol.text, left_symbol.n);
544
+ std::string right_token = std::string(right_symbol.text, right_symbol.n);
545
+ if (left_token + right_token != bigram.text) {
546
+ continue; // Skip this bigram if it's outdated
547
+ }
548
+
549
+ // merge the right sym into the left one
550
+ left_symbol.n += right_symbol.n;
551
+ right_symbol.n = 0;
552
+
553
+ // remove the right sym from the chain
554
+ left_symbol.next = right_symbol.next;
555
+ if (right_symbol.next >= 0) {
556
+ symbols[right_symbol.next].prev = bigram.left;
557
+ }
558
+
559
+ add_new_bigram(left_symbol.prev, bigram.left); // left side of current symbol
560
+ add_new_bigram(bigram.left, left_symbol.next); // right side of current symbol
561
+ }
562
+
563
+ // add the finished tokens to the final list keeping correct order for next and prev
564
+ for (auto & sym : symbols) {
565
+ if (sym.n > 0) {
566
+ sym.prev = final_prev_index;
567
+ sym.next = -1;
568
+ if (final_prev_index != -1) {
569
+ symbols_final[final_prev_index].next = symbols_final.size();
570
+ }
571
+ symbols_final.emplace_back(sym);
572
+ final_prev_index = symbols_final.size() - 1;
573
+ }
574
+ }
575
+ }
576
+
577
+ symbols = symbols_final;
578
+
579
+ if (!symbols.empty()) {
580
+ for (int i = 0; i != -1; i = symbols[i].next) {
581
+ auto & symbol = symbols[i];
582
+ if (symbol.n == 0) {
583
+ continue;
584
+ }
585
+
586
+ const std::string str = std::string(symbol.text, symbol.n);
587
+ const auto token = vocab.token_to_id.find(str);
588
+
589
+ if (token == vocab.token_to_id.end()) {
590
+ for (auto j = str.begin(); j != str.end(); ++j) {
591
+ std::string byte_str(1, *j);
592
+ auto token_multibyte = vocab.token_to_id.find(byte_str);
593
+ if (token_multibyte != vocab.token_to_id.end()) {
594
+ output.push_back(token_multibyte->second);
595
+ }
596
+ }
597
+ } else {
598
+ output.push_back((*token).second);
599
+ }
600
+ }
601
+ }
602
+ }
603
+
604
+ private:
605
+ void add_new_bigram(int left, int right) {
606
+ if (left == -1 || right == -1) {
607
+ return;
608
+ }
609
+
610
+ std::string left_token = std::string(symbols[left].text, symbols[left].n);
611
+ std::string right_token = std::string(symbols[right].text, symbols[right].n);
612
+
613
+ int rank_found = -1;
614
+
615
+ rank_found = vocab.find_bpe_rank(left_token, right_token);
616
+
617
+ if (rank_found < 0) {
618
+ return;
619
+ }
620
+
621
+ llm_bigram_bpe bigram;
622
+
623
+ bigram.left = left;
624
+ bigram.right = right;
625
+ bigram.text = left_token + right_token;
626
+ bigram.size = left_token.size() + right_token.size();
627
+ bigram.rank = rank_found;
628
+
629
+ work_queue.push(bigram);
630
+ }
631
+
632
+ const llama_vocab & vocab;
633
+
634
+ std::vector<std::string> regex_exprs;
635
+
636
+ std::vector<llm_symbol> symbols;
637
+ std::vector<llm_symbol> symbols_final;
638
+
639
+ llm_bigram_bpe::queue work_queue;
640
+ };
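
The BPE merge loop follows the same shape, but is driven by integer merge ranks (lower rank merges earlier) looked up through find_bpe_rank. A minimal self-contained sketch with a made-up merge table:

    #include <climits>
    #include <cstdio>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    int main() {
        // toy merge ranks: lower value means the pair is merged earlier
        const std::map<std::pair<std::string, std::string>, int> ranks = {
            {{"l", "o"}, 0}, {{"lo", "w"}, 1}, {{"e", "r"}, 2}, {{"low", "er"}, 3},
        };

        std::vector<std::string> pieces = {"l", "o", "w", "e", "r"};

        while (true) {
            int best_rank = INT_MAX;
            int best_i    = -1;
            for (size_t i = 0; i + 1 < pieces.size(); ++i) {
                auto it = ranks.find({pieces[i], pieces[i + 1]});
                if (it != ranks.end() && it->second < best_rank) {
                    best_rank = it->second;
                    best_i    = (int) i;
                }
            }
            if (best_i < 0) break;
            pieces[best_i] += pieces[best_i + 1];
            pieces.erase(pieces.begin() + best_i + 1);
        }

        for (const auto & p : pieces) printf("'%s' ", p.c_str()); // expected: 'lower'
        printf("\n");
        return 0;
    }
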
641
+
642
+ //
643
+ // WPM tokenizer
644
+ //
645
+
646
+ struct llm_tokenizer_wpm {
647
+ llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
648
+
649
+ void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) const {
650
+ const auto & token_map = vocab.token_to_id;
651
+
652
+ // normalize and split by whitespace
653
+ std::vector<std::string> words = preprocess(text);
654
+
655
+ // bos token prepended already
656
+
657
+ // find the longest tokens that form the words
658
+ for (const std::string & word : words) {
659
+ // skip empty words
660
+ if (word.size() == 0) {
661
+ continue;
662
+ }
663
+
664
+ // prepend phantom space
665
+ const std::string word1 = "\xe2\x96\x81" + word;
666
+ const int n = word1.size();
667
+
668
+ const size_t current_tokens = output.size();
669
+
670
+ // we're at the start of a new word
671
+ // move through character position in word
672
+ for (int i = 0; i < n; ++i) {
673
+ // loop through possible match length
674
+ bool match = false;
675
+ for (int j = std::min(n, i + vocab.max_token_len + 1); j > i; j--) {
676
+ auto it = token_map.find(word1.substr(i, j - i));
677
+ if (it != token_map.end()) {
678
+ output.push_back(it->second);
679
+ match = true;
680
+ i = j - 1;
681
+ break;
682
+ }
683
+ }
684
+
685
+ if (!match) { // discard all
686
+ output.resize(current_tokens);
687
+ break; // and discard next tokens
688
+ }
689
+ }
690
+
691
+ // we didn't find any matches for this word
692
+ if (current_tokens == output.size()) {
693
+ output.push_back(vocab.special_unk_id);
694
+ }
695
+ }
696
+ }
697
+
698
+ // TODO: reduce string copies by using cpts_offs array
699
+ std::vector<std::string> preprocess(const std::string & text) const {
700
+ const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
701
+ std::vector<std::string> words(1, "");
702
+
703
+ for (const uint32_t cpt : cpts_nfd) {
704
+ const auto flags = unicode_cpt_flags(cpt);
705
+
706
+ if (flags.is_whitespace) {
707
+ if (words.back().size()) { // finish previous word if any
708
+ words.emplace_back();
709
+ }
710
+ continue;
711
+ }
712
+
713
+ assert (!flags.is_separator);
714
+ if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
715
+ continue;
716
+ }
717
+
718
+ const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
719
+ if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
720
+ if (words.back().size()) { // finish previous word if any
721
+ words.emplace_back();
722
+ }
723
+ words.back() = s; // single char word
724
+ words.emplace_back(); // start a new word
725
+ } else {
726
+ words.back() += s; // append char to word
727
+ }
728
+ }
729
+
730
+ if (!words.back().size()) {
731
+ words.pop_back();
732
+ }
733
+
734
+ return words;
735
+ }
736
+
737
+ static bool is_chinese_char(uint32_t cpt) {
738
+ return
739
+ (cpt >= 0x04E00 && cpt <= 0x09FFF) ||
740
+ (cpt >= 0x03400 && cpt <= 0x04DBF) ||
741
+ (cpt >= 0x20000 && cpt <= 0x2A6DF) ||
742
+ (cpt >= 0x2A700 && cpt <= 0x2B73F) ||
743
+ (cpt >= 0x2B740 && cpt <= 0x2B81F) ||
744
+ (cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
745
+ (cpt >= 0x0F900 && cpt <= 0x0FAFF) ||
746
+ (cpt >= 0x2F800 && cpt <= 0x2FA1F);
747
+ //(cpt >= 0x3000 && cpt <= 0x303F) ||
748
+ //(cpt >= 0xFF00 && cpt <= 0xFFEF);
749
+ }
750
+
751
+ const llama_vocab & vocab;
752
+ };
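
The WPM matcher above is a greedy longest-match over each normalized word with a phantom "\xe2\x96\x81" prefix: at every position it tries the longest vocabulary entry first and downgrades the whole word to the unknown token if any position fails. A small sketch of that loop with an invented vocabulary:

    #include <algorithm>
    #include <cstdio>
    #include <set>
    #include <string>
    #include <vector>

    int main() {
        const std::set<std::string> vocab = { "\xe2\x96\x81un", "break", "able" };
        const std::string word = "\xe2\x96\x81unbreakable"; // word with the phantom space prefix

        size_t max_len = 0;
        for (const auto & t : vocab) max_len = std::max(max_len, t.size());

        std::vector<std::string> pieces;
        const int n = (int) word.size();
        bool ok = true;
        for (int i = 0; i < n && ok; ) {
            bool match = false;
            // try the longest candidate first, then progressively shorter ones
            for (int j = std::min(n, i + (int) max_len); j > i; --j) {
                if (vocab.count(word.substr(i, j - i))) {
                    pieces.push_back(word.substr(i, j - i));
                    i = j;
                    match = true;
                    break;
                }
            }
            ok = match; // no match at this position -> the whole word becomes unknown
        }

        if (!ok) {
            printf("[UNK]\n");
        } else {
            for (const auto & p : pieces) printf("'%s' ", p.c_str()); // '▁un' 'break' 'able'
            printf("\n");
        }
        return 0;
    }
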
753
+
754
+ //
755
+ // UGM tokenizer
756
+ //
757
+
758
+ struct llm_tokenizer_ugm {
759
+ llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) {
760
+ if (vocab.precompiled_charsmap.size() > 0) {
761
+ size_t charsmap_offset = 0;
762
+
763
+ // First four bytes of precompiled_charsmap contains length of binary
764
+ // blob containing XOR-compressed compact double array (XCDA) entries
765
+ uint32_t xcda_blob_size = *(const uint32_t *) &vocab.precompiled_charsmap[0];
766
+ charsmap_offset += sizeof(xcda_blob_size);
767
+ if (xcda_blob_size + charsmap_offset >= vocab.precompiled_charsmap.size()) {
768
+ throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
769
+ }
770
+
771
+ // Next xcda_blob_size bytes contain entries of XOR-compressed compact
772
+ // double array (XCDA). Each entry is bit-packed into a 32-bit integer.
773
+ xcda_array = (const uint32_t *) &vocab.precompiled_charsmap[charsmap_offset];
774
+ xcda_array_size = xcda_blob_size / sizeof(uint32_t);
775
+ charsmap_offset += xcda_blob_size;
776
+
777
+ // Remaining bytes of precompiled charsmap contain null-terminated
778
+ // replacement strings for prefixes matched by the XCDA.
779
+ prefix_replacements = &vocab.precompiled_charsmap[charsmap_offset];
780
+ prefix_replacements_size = vocab.precompiled_charsmap.size() - charsmap_offset;
781
+ }
782
+
783
+ for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
784
+ const auto &token_data = vocab.id_to_token[id];
785
+
786
+ if (llama_is_normal_token(vocab, id)) {
787
+ min_score = std::min<float>(min_score, token_data.score);
788
+ max_score = std::max<float>(max_score, token_data.score);
789
+ }
790
+
791
+ if (llama_is_normal_token(vocab, id) ||
792
+ llama_is_user_defined_token(vocab, id) ||
793
+ llama_is_unused_token(vocab, id)) {
794
+ token_matcher.insert(token_data.text.data(), token_data.text.size(), id);
795
+ }
796
+
797
+ if (llama_is_user_defined_token(vocab, id)) {
798
+ user_defined_token_matcher.insert(token_data.text.data(), token_data.text.size());
799
+ }
800
+ }
801
+
802
+ unknown_token_score = min_score - unknown_token_score_penalty;
803
+ }
804
+
805
+ /* This implementation is based on SentencePiece optimized Viterbi algorithm for
806
+ * unigram language models. The general idea is to:
807
+ * - move along the input sequence in steps of one UTF code point,
808
+ * - at each step find all possible tokenizations of the prefix by
809
+ * traversing the tokens trie,
810
+ * - for each tokenization store the best one so far (by higher score)
811
+ * - use the position in sequence after given token as an index to store
812
+ * results
813
+ * - if there was no valid tokenization of the current UTF code point
814
+ * then use unknown token with additional score penalty
815
+ * After processing the whole sequence we backtrack from the end to get
816
+ * the best tokenization.
817
+ */
818
+ void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
819
+ // get current size of output (for reversal later)
820
+ size_t output_size = output.size();
821
+
822
+ // normalize the input first
823
+ std::string normalized;
824
+ normalize(text, &normalized);
825
+ size_t input_len = normalized.size();
826
+ if (input_len == 0) {
827
+ return;
828
+ }
829
+
830
+ // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
831
+ std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.special_unk_id, 0, -FLT_MAX});
832
+ // at the beginning tokenization score is zero
833
+ tokenization_results[0] = { vocab.special_unk_id, 0, 0 };
834
+
835
+ for (size_t input_offset = 0; input_offset < input_len;) {
836
+ size_t prefix_offset = input_offset;
837
+ // calculate how many code units are in the currently processed UTF code point
838
+ size_t n_utf8_code_units = std::min<size_t>(unicode_len_utf8(normalized[input_offset]), input_len - input_offset);
839
+
840
+ // traverse the token matcher trie to find a matching token
841
+ bool single_codepoint_token_found = false;
842
+ const struct best_tokenization & current_best = tokenization_results[input_offset];
843
+ struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
844
+
845
+ while (prefix_offset <= input_len && node != NULL) {
846
+ // check if we found valid token in prefix
847
+ if (node->has_value) {
848
+ // check if it corresponds to the whole UTF code point
849
+ if (prefix_offset - input_offset == n_utf8_code_units) {
850
+ single_codepoint_token_found = true;
851
+ }
852
+ llama_token token_id = node->value;
853
+ const auto & token_data = vocab.id_to_token[token_id];
854
+
855
+ // we set the user-defined token scores to 0 to make them more likely to be selected
856
+ // (normal token scores are log probabilities, so they are negative)
857
+ // score type is double here to make tokenization results exactly
858
+ // the same as in the HF tokenizer using SentencePiece
859
+ const double token_score = llama_is_user_defined_token(vocab, token_id) ? 0.0 : token_data.score;
860
+ const double challenger_score = current_best.score_sum + token_score;
861
+ struct best_tokenization & current_champ = tokenization_results[prefix_offset];
862
+ if (challenger_score > current_champ.score_sum) {
863
+ struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score };
864
+ current_champ = challenger;
865
+ }
866
+ }
867
+ node = node->traverse(normalized[prefix_offset++]);
868
+ }
869
+
870
+ // if we didn't find a valid token corresponding to the whole UTF code point
871
+ // then use unknown token as the tokenization of this UTF code point
872
+ if (!single_codepoint_token_found) {
873
+ const double challenger_score = current_best.score_sum + unknown_token_score;
874
+ prefix_offset = input_offset + n_utf8_code_units;
875
+ struct best_tokenization & current_champ = tokenization_results[prefix_offset];
876
+ if (challenger_score > current_champ.score_sum) {
877
+ struct best_tokenization challenger = { vocab.special_unk_id, input_offset, (float) challenger_score };
878
+ current_champ = challenger;
879
+ }
880
+ }
881
+
882
+ // move to the next UTF code point
883
+ input_offset += n_utf8_code_units;
884
+ }
885
+
886
+ // now backtrack from the end to gather token ids of the best tokenization
887
+ // merge sequences of consecutive unknown tokens into single unknown tokens
888
+ bool is_prev_unknown = false;
889
+ for (struct best_tokenization & tokenization = tokenization_results[input_len]; ; tokenization = tokenization_results[tokenization.input_offset]) {
890
+ bool is_unknown = tokenization.token_id == vocab.special_unk_id;
891
+ if (!(is_prev_unknown && is_unknown)) {
892
+ output.push_back(tokenization.token_id);
893
+ }
894
+ if (tokenization.input_offset == 0) {
895
+ break;
896
+ }
897
+ is_prev_unknown = is_unknown;
898
+ }
899
+
900
+ // reverse the output since we added tokens starting from the end of the input
901
+ std::reverse(output.begin() + output_size, output.end());
902
+ }
903
+
904
+ private:
905
+ const llama_vocab & vocab;
906
+
907
+ // helper structure for returning normalization results
908
+ struct normalization_result {
909
+ const char * normalized;
910
+ size_t normalized_len;
911
+ size_t consumed_input;
912
+ };
913
+
914
+ void normalize(const std::string& input, std::string * normalized) {
915
+ normalized->clear();
916
+ normalized->reserve(input.size() * 3);
917
+
918
+ const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " ";
919
+
920
+ bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
921
+ bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
922
+ bool shall_merge_spaces = vocab.tokenizer_remove_extra_whitespaces;
923
+
924
+ bool is_space_prepended = false;
925
+ bool processing_non_ws = false;
926
+
927
+ size_t input_len = input.size();
928
+
929
+ for (size_t input_offset = 0; input_offset < input_len; ) {
930
+ auto norm_res = normalize_prefix(input, input_offset);
931
+ for (size_t i = 0; i < norm_res.normalized_len; i++) {
932
+ char c = norm_res.normalized[i];
933
+ if (c != ' ') {
934
+ if (!processing_non_ws) {
935
+ processing_non_ws = true;
936
+ if ((shall_prepend_space && !is_space_prepended) || shall_merge_spaces) {
937
+ normalized->append(space);
938
+ is_space_prepended = true;
939
+ }
940
+ }
941
+ normalized->push_back(c);
942
+ } else {
943
+ if (processing_non_ws) {
944
+ processing_non_ws = false;
945
+ }
946
+ if (!shall_merge_spaces) {
947
+ normalized->append(space);
948
+ }
949
+ }
950
+ }
951
+
952
+ input_offset += norm_res.consumed_input;
953
+ }
954
+
955
+ if (shall_append_space) {
956
+ normalized->append(space);
957
+ }
958
+ }
959
+
960
+ /*
961
+ * This structure is a view wrapper for XOR-compressed double array (XCDA)
962
+ * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries.
963
+ * Each bit-packed entry contains:
964
+ * - BASE array value in bits 10-30
965
+ * - LCHECK array value in bits 0-7
966
+ * - LEAF array value in bit 9
967
+ * Entries containing indexes of replacement sequences have set bit 31
968
+ */
969
+ struct xcda_array_view {
970
+ public:
971
+ xcda_array_view(const uint32_t * xcda_array, size_t xcda_array_size) : xcda_array(xcda_array), xcda_array_size(xcda_array_size) {
972
+ }
973
+ uint32_t get_base(size_t index) {
974
+ uint32_t packed_node = get_node(index);
975
+ return (packed_node >> 10) << ((packed_node & (1U << 9)) >> 6);
976
+ }
977
+ uint32_t get_lcheck(size_t index) {
978
+ uint32_t packed_node = get_node(index);
979
+ return packed_node & ((1U << 31) | 0xff);
980
+ }
981
+ bool get_leaf(size_t index) {
982
+ uint32_t packed_node = get_node(index);
983
+ return (packed_node >> 8) & 1;
984
+ }
985
+ uint32_t get_value(size_t index) {
986
+ uint32_t packed_node = get_node(index);
987
+ return packed_node & ((1U << 31) - 1);
988
+ }
989
+ private:
990
+ uint32_t get_node(size_t index) {
991
+ if (index > xcda_array_size) {
992
+ throw std::runtime_error("Index out of array bounds in XCDA array!");
993
+ }
994
+ return xcda_array[index];
995
+ }
996
+ const uint32_t * xcda_array;
997
+ size_t xcda_array_size;
998
+ };
999
+
1000
+ struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
1001
+ if (input_offset == input.size()) {
1002
+ return { &input[input_offset], 0, 0 };
1003
+ }
1004
+
1005
+ // if input prefix matches some user-defined token return this token as normalization result
1006
+ auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
1007
+ if (user_defined_token_match.second > 0) {
1008
+ return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second };
1009
+ }
1010
+
1011
+ size_t longest_prefix_length = 0;
1012
+ size_t longest_prefix_offset = 0;
1013
+
1014
+ if (xcda_array_size > 0) {
1015
+ struct xcda_array_view xcda_view(xcda_array, xcda_array_size);
1016
+
1017
+ // Find the longest normalized sequence matching the input prefix by walking
1018
+ // the XOR-compressed compact double array (XCDA) starting from the root node
1019
+ // We find the index of the next node by calculating BASE[s] ^ c where s is
1020
+ // the index of the previous node and c is a numerical character value
1021
+ uint32_t node_index = 0;
1022
+ // get BASE of the root node
1023
+ node_index = xcda_view.get_base(node_index);
1024
+ for (size_t prefix_offset = input_offset; prefix_offset < input.size(); prefix_offset++) {
1025
+ unsigned char c = input[prefix_offset];
1026
+ if (c == 0) {
1027
+ break;
1028
+ }
1029
+ node_index ^= c;
1030
+ // if value of LCHECK is not c it means that this is not a child of
1031
+ // the previous node, so we stop matching
1032
+ if (xcda_view.get_lcheck(node_index) != c) {
1033
+ break;
1034
+ }
1035
+ bool is_leaf = xcda_view.get_leaf(node_index);
1036
+ // get BASE of the current node
1037
+ node_index ^= xcda_view.get_base(node_index);
1038
+ // if LEAF of the current node is true, it means that its BASE points to the node
1039
+ // containing index of replacement sequence for currently matched input prefix
1040
+ if (is_leaf)
1041
+ {
1042
+ longest_prefix_length = prefix_offset - input_offset + 1;
1043
+ // get index of replacement sequence for currently matched input prefix
1044
+ longest_prefix_offset = xcda_view.get_value(node_index);
1045
+ }
1046
+ }
1047
+ }
1048
+
1049
+ if (longest_prefix_length > 0) {
1050
+ // we have a match, so return the replacement sequence
1051
+ if (longest_prefix_offset >= prefix_replacements_size) {
1052
+ throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
1053
+ }
1054
+ const char * prefix_replacement = &prefix_replacements[longest_prefix_offset];
1055
+ return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length };
1056
+ } else {
1057
+ // check if the input prefix contains a valid sequence of UTF-8 code units
1058
+ try {
1059
+ // if yes, return this sequence unmodified
1060
+ size_t prefix_offset = input_offset;
1061
+ unicode_cpt_from_utf8(input, prefix_offset);
1062
+ return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset };
1063
+ } catch (std::invalid_argument & /*ex*/) {
1064
+ // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER
1065
+ return { "\xEF\xBF\xBD", 3, 1 };
1066
+ }
1067
+ }
1068
+ }
1069
+
1070
+ // escaped space symbol - U+2581 (Lower One Eighth Block)
1071
+ const std::string escaped_space = "\xE2\x96\x81";
1072
+
1073
+ const char * prefix_replacements = NULL;
1074
+ size_t prefix_replacements_size = 0;
1075
+
1076
+ const uint32_t * xcda_array = NULL;
1077
+ size_t xcda_array_size = 0;
1078
+
1079
+ struct naive_trie user_defined_token_matcher;
1080
+
1081
+ // this structure stores the best tokenization so far at input_offset
1082
+ struct best_tokenization {
1083
+ llama_token token_id;
1084
+ size_t input_offset;
1085
+ float score_sum;
1086
+ };
1087
+
1088
+ float min_score = FLT_MAX;
1089
+ float max_score = -FLT_MAX;
1090
+
1091
+ float unknown_token_score_penalty = 10.0;
1092
+ float unknown_token_score;
1093
+
1094
+ struct naive_trie token_matcher;
1095
+ };
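
A self-contained sketch of the Viterbi idea described in the comment inside the struct above: dynamic programming over positions in the input, keeping the best-scoring segmentation ending at each position and falling back to a penalized unknown token. The vocabulary and scores here are invented, and the real implementation additionally walks the token trie and advances per UTF-8 code point rather than per byte:

    #include <cfloat>
    #include <cstdio>
    #include <map>
    #include <string>
    #include <vector>

    int main() {
        // toy unigram vocabulary with log-probability-like scores
        const std::map<std::string, double> vocab = {
            {"un", -2.0}, {"related", -3.0}, {"rel", -4.0}, {"ated", -4.5},
        };
        const double unk_score = -10.0;
        const std::string text = "unrelated";
        const int n = (int) text.size();

        // best[i] = best total score of a segmentation of text[0..i)
        std::vector<double> best(n + 1, -DBL_MAX);
        std::vector<int>    prev(n + 1, -1); // start offset of the last piece
        best[0] = 0.0;

        for (int i = 0; i < n; ++i) {
            if (best[i] == -DBL_MAX) continue; // position not reachable
            bool any = false;
            for (const auto & kv : vocab) {
                const std::string & tok = kv.first;
                if (text.compare(i, tok.size(), tok) == 0) {
                    const int j = i + (int) tok.size();
                    if (best[i] + kv.second > best[j]) {
                        best[j] = best[i] + kv.second;
                        prev[j] = i;
                    }
                    any = true;
                }
            }
            if (!any) { // fall back to a single-byte unknown token with a penalty
                if (best[i] + unk_score > best[i + 1]) {
                    best[i + 1] = best[i] + unk_score;
                    prev[i + 1] = i;
                }
            }
        }

        // backtrack from the end and print the pieces in order
        std::vector<std::string> pieces;
        for (int j = n; j > 0; j = prev[j]) {
            pieces.push_back(text.substr(prev[j], j - prev[j]));
        }
        for (auto it = pieces.rbegin(); it != pieces.rend(); ++it) {
            printf("'%s' ", it->c_str()); // expected: 'un' 'related'
        }
        printf("\n");
        return 0;
    }
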
1096
+
1097
+ //
1098
+ // (de-) tokenize
1099
+ //
1100
+
1101
+ typedef enum FRAGMENT_BUFFER_VARIANT_TYPE {
1102
+ FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN,
1103
+ FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT
1104
+ } FRAGMENT_BUFFER_VARIANT_TYPE;
1105
+
1106
+ struct fragment_buffer_variant {
1107
+ fragment_buffer_variant(llama_vocab::id _token)
1108
+ :
1109
+ type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN),
1110
+ token(_token),
1111
+ raw_text(_dummy),
1112
+ offset(0),
1113
+ length(0) {}
1114
+
1115
+ fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length)
1116
+ :
1117
+ type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT),
1118
+ token((llama_vocab::id) - 1),
1119
+ raw_text(_raw_text),
1120
+ offset(_offset),
1121
+ length(_length){
1122
+ GGML_ASSERT(_offset >= 0);
1123
+ GGML_ASSERT(_length >= 1);
1124
+ GGML_ASSERT(offset + length <= raw_text.length());
1125
+ }
1126
+
1127
+ const FRAGMENT_BUFFER_VARIANT_TYPE type;
1128
+ const llama_vocab::id token;
1129
+ const std::string _dummy;
1130
+ const std::string & raw_text;
1131
+ const uint64_t offset;
1132
+ const uint64_t length;
1133
+ };
1134
+
1135
+ // #define PRETOKENIZERDEBUG
1136
+
1137
+ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) {
1138
+ // for each special token
1139
+ for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
1140
+ const auto & data = vocab.id_to_token[special_id];
1141
+ const auto & special_token = data.text;
1142
+
1143
+ if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) {
1144
+ // Ignore control and unknown tokens when parse_special == false
1145
+ continue;
1146
+ // User-defined tokens are still pre-tokenized before everything else
1147
+ // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726
1148
+ // This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.)
1149
+ }
1150
+
1151
+ // for each text fragment
1152
+ std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
1153
+ while (it != buffer.end()) {
1154
+ auto & fragment = (*it);
1155
+
1156
+ // if a fragment is text ( not yet processed )
1157
+ if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
1158
+ auto & raw_text = fragment.raw_text;
1159
+
1160
+ auto raw_text_base_offset = fragment.offset;
1161
+ auto raw_text_base_length = fragment.length;
1162
+
1163
+ // loop over the text
1164
+ while (true) {
1165
+ // find the first occurrence of a given special token in this fragment
1166
+ // passing offset argument only limits the "search area" but match coordinates
1167
+ // are still relative to the source full raw_text
1168
+ auto match = raw_text.find(special_token, raw_text_base_offset);
1169
+
1170
+ // no occurrences found, stop processing this fragment for a given special token
1171
+ if (match == std::string::npos) break;
1172
+
1173
+ // check if match is within bounds of offset <-> length
1174
+ if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break;
1175
+
1176
+ #ifdef PRETOKENIZERDEBUG
1177
+ LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
1178
+ #endif
1179
+ auto source = std::distance(buffer.begin(), it);
1180
+
1181
+ // if match is further than base offset
1182
+ // then we have some text to the left of it
1183
+ if (match > raw_text_base_offset) {
1184
+ // left
1185
+ const int64_t left_reminder_offset = raw_text_base_offset + 0;
1186
+ int64_t left_reminder_length = match - raw_text_base_offset;
1187
+
1188
+ if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) {
1189
+ while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) {
1190
+ left_reminder_length--;
1191
+ }
1192
+ }
1193
+
1194
+ if (left_reminder_length > 0) {
1195
+ buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
1196
+ it++;
1197
+ }
1198
+
1199
+ #ifdef PRETOKENIZERDEBUG
1200
+ LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
1201
+ #endif
1202
+ }
1203
+
1204
+ // special token
1205
+ buffer.emplace_after(it, special_id);
1206
+ it++;
1207
+
1208
+ // right
1209
+ if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
1210
+ int64_t right_reminder_offset = match + special_token.length();
1211
+ int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
1212
+
1213
+ if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
1214
+ while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
1215
+ right_reminder_offset++;
1216
+ right_reminder_length--;
1217
+ }
1218
+ }
1219
+
1220
+ if (right_reminder_length > 0) {
1221
+ buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
1222
+ it++;
1223
+ }
1224
+
1225
+ #ifdef PRETOKENIZERDEBUG
1226
+ LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
1227
+ #endif
1228
+
1229
+ if (source == 0) {
1230
+ buffer.erase_after(buffer.before_begin());
1231
+ } else {
1232
+ buffer.erase_after(std::next(buffer.begin(), (source-1)));
1233
+ }
1234
+
1235
+ // repeat for the right side
1236
+ raw_text_base_offset = right_reminder_offset;
1237
+ raw_text_base_length = right_reminder_length;
1238
+
1239
+ #ifdef PRETOKENIZERDEBUG
1240
+ LLAMA_LOG_WARN("RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
1241
+ #endif
1242
+ } else {
1243
+ if (source == 0) {
1244
+ buffer.erase_after(buffer.before_begin());
1245
+ } else {
1246
+ buffer.erase_after(std::next(buffer.begin(), (source-1)));
1247
+ }
1248
+ break;
1249
+ }
1250
+ }
1251
+ }
1252
+ it++;
1253
+ }
1254
+ }
1255
+ }
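
As a worked example of the partitioning above, the special token "<|eot|>" splits the raw text "Hello <|eot|> world" into three fragments: raw text, token, raw text. A minimal string-based sketch of that splitting for a single special token, without the fragment/offset bookkeeping or the LSTRIP/RSTRIP handling:

    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
        const std::string special = "<|eot|>";
        const std::string text    = "Hello <|eot|> world";

        std::vector<std::string> fragments;
        size_t pos = 0;
        while (true) {
            const size_t match = text.find(special, pos);
            if (match == std::string::npos) {
                if (pos < text.size()) fragments.push_back(text.substr(pos));
                break;
            }
            if (match > pos) fragments.push_back(text.substr(pos, match - pos));
            fragments.push_back(special); // in the real code this becomes a token fragment
            pos = match + special.size();
        }

        for (const auto & f : fragments) printf("[%s] ", f.c_str());
        printf("\n"); // [Hello ] [<|eot|>] [ world]
        return 0;
    }
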
1256
+
1257
+ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) {
1258
+ std::vector<llama_vocab::id> output;
1259
+ std::forward_list<fragment_buffer_variant> fragment_buffer;
1260
+
1261
+ if (!raw_text.empty()) {
1262
+ fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
1263
+ tokenizer_st_partition(vocab, fragment_buffer, parse_special);
1264
+ }
1265
+
1266
+ switch (vocab.type) {
1267
+ case LLAMA_VOCAB_TYPE_SPM:
1268
+ {
1269
+ // OG tokenizer behavior:
1270
+ //
1271
+ // tokenizer.encode('', add_special_tokens=True) returns [1]
1272
+ // tokenizer.encode('', add_special_tokens=False) returns []
1273
+
1274
+ bool is_prev_special = true; // prefix with space if first token
1275
+
1276
+ if (add_special && vocab.tokenizer_add_bos) {
1277
+ GGML_ASSERT(vocab.special_bos_id != -1);
1278
+ output.push_back(vocab.special_bos_id);
1279
+ is_prev_special = true;
1280
+ }
1281
+
1282
+ for (const auto & fragment : fragment_buffer) {
1283
+ if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
1284
+ auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
1285
+
1286
+ // prefix with space if previous is special
1287
+ if (vocab.tokenizer_add_space_prefix && is_prev_special) {
1288
+ raw_text = " " + raw_text;
1289
+ }
1290
+
1291
+ #ifdef PRETOKENIZERDEBUG
1292
+ LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
1293
+ #endif
1294
+ llm_tokenizer_spm tokenizer(vocab);
1295
+ llama_escape_whitespace(raw_text);
1296
+ tokenizer.tokenize(raw_text, output);
1297
+ is_prev_special = false;
1298
+ } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
1299
+ output.push_back(fragment.token);
1300
+ is_prev_special = true;
1301
+ }
1302
+ }
1303
+
1304
+ if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
1305
+ LLAMA_LOG_WARN(
1306
+ "%s: Added a BOS token to the prompt as specified by the model but the prompt "
1307
+ "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
1308
+ "Are you sure this is what you want?\n", __FUNCTION__);
1309
+ }
1310
+
1311
+ if (add_special && vocab.tokenizer_add_eos) {
1312
+ GGML_ASSERT(vocab.special_eos_id != -1);
1313
+ output.push_back(vocab.special_eos_id);
1314
+ }
1315
+ } break;
1316
+ case LLAMA_VOCAB_TYPE_BPE:
1317
+ {
1318
+ llm_tokenizer_bpe tokenizer(vocab);
1319
+
1320
+ if (add_special) {
1321
+ tokenizer.append_bos(output);
1322
+ }
1323
+ for (const auto & fragment : fragment_buffer) {
1324
+ if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
1325
+ auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
1326
+
1327
+ #ifdef PRETOKENIZERDEBUG
1328
+ LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
1329
+ #endif
1330
+ tokenizer.tokenize(raw_text, output);
1331
+ } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
1332
+ tokenizer.append(fragment.token, output);
1333
+ }
1334
+ }
1335
+
1336
+ if (add_special) {
1337
+ tokenizer.append_eos(output);
1338
+ tokenizer.check_double_bos_eos(output);
1339
+ }
1340
+ } break;
1341
+ case LLAMA_VOCAB_TYPE_WPM:
1342
+ {
1343
+ if (add_special) {
1344
+ GGML_ASSERT(vocab.special_cls_id != -1);
1345
+ output.push_back(vocab.special_cls_id);
1346
+ }
1347
+
1348
+ llm_tokenizer_wpm tokenizer(vocab);
1349
+
1350
+ for (const auto & fragment : fragment_buffer) {
1351
+ if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
1352
+ auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
1353
+
1354
+ #ifdef PRETOKENIZERDEBUG
1355
+ LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
1356
+ #endif
1357
+ tokenizer.tokenize(raw_text, output);
1358
+ } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
1359
+ output.push_back(fragment.token);
1360
+ }
1361
+ }
1362
+
1363
+ if (add_special) {
1364
+ GGML_ASSERT(vocab.special_sep_id != -1);
1365
+ output.push_back(vocab.special_sep_id);
1366
+ }
1367
+ } break;
1368
+ case LLAMA_VOCAB_TYPE_UGM:
1369
+ {
1370
+ llm_tokenizer_ugm tokenizer(vocab);
1371
+
1372
+ if (add_special && vocab.tokenizer_add_bos != 0) {
1373
+ GGML_ASSERT(vocab.special_bos_id != -1);
1374
+ output.push_back(vocab.special_bos_id);
1375
+ }
1376
+
1377
+ for (const auto & fragment : fragment_buffer) {
1378
+ if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
1379
+ auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
1380
+ #ifdef PRETOKENIZERDEBUG
1381
+ LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
1382
+ #endif
1383
+ tokenizer.tokenize(raw_text, output);
1384
+ } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
1385
+ output.push_back(fragment.token);
1386
+ }
1387
+ }
1388
+
1389
+ if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
1390
+ LLAMA_LOG_WARN(
1391
+ "%s: Added a BOS token to the prompt as specified by the model but the prompt "
1392
+ "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
1393
+ "Are you sure this is what you want?\n", __FUNCTION__);
1394
+ }
1395
+
1396
+ if (add_special && vocab.tokenizer_add_eos == 1) {
1397
+ GGML_ASSERT(vocab.special_eos_id != -1);
1398
+ output.push_back(vocab.special_eos_id);
1399
+ }
1400
+ } break;
1401
+ case LLAMA_VOCAB_TYPE_NONE:
1402
+ GGML_ABORT("fatal error");
1403
+ }
1404
+
1405
+ return output;
1406
+ }
1407
+
1408
+ llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch) {
1409
+ GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
1410
+ static const char * hex = "0123456789ABCDEF";
1411
+ switch (llama_vocab_get_type(vocab)) {
1412
+ case LLAMA_VOCAB_TYPE_SPM:
1413
+ case LLAMA_VOCAB_TYPE_UGM: {
1414
+ const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
1415
+ auto token = vocab.token_to_id.find(buf);
1416
+ if (token != vocab.token_to_id.end()) {
1417
+ return (*token).second;
1418
+ }
1419
+ // Try to fall back to just the byte as a string
1420
+ const char buf2[2] = { (char)ch, 0 };
1421
+ return vocab.token_to_id.at(buf2);
1422
+ }
1423
+ case LLAMA_VOCAB_TYPE_WPM:
1424
+ case LLAMA_VOCAB_TYPE_BPE: {
1425
+ return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
1426
+ }
1427
+ default:
1428
+ GGML_ABORT("fatal error");
1429
+ }
1430
+ }
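
For SPM/UGM vocabularies, raw bytes are encoded as tokens whose text is the literal string "<0xNN>", which is what the buf above constructs before the token_to_id lookup. A tiny standalone sketch of that formatting:

    #include <cstdint>
    #include <cstdio>
    #include <string>

    // format a byte the way SPM-style byte tokens are named, e.g. 0x41 -> "<0x41>"
    static std::string byte_token_text(uint8_t ch) {
        static const char * hex = "0123456789ABCDEF";
        return std::string("<0x") + hex[ch >> 4] + hex[ch & 15] + ">";
    }

    int main() {
        printf("%s\n", byte_token_text('A').c_str());  // <0x41>
        printf("%s\n", byte_token_text(0xE2).c_str()); // <0xE2>
        return 0;
    }
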
1431
+
1432
+ const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token) {
1433
+ GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
1434
+ return vocab.id_to_token[token].text.c_str();
1435
+ }
1436
+
1437
+ float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token) {
1438
+ GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
1439
+ return vocab.id_to_token[token].score;
1440
+ }
1441
+
1442
+ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token) {
1443
+ GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
1444
+ return vocab.id_to_token[token].attr;
1445
+ }
1446
+
1447
+ bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
1448
+ return token != -1 && (
1449
+ token == llama_token_eos_impl(vocab) ||
1450
+ token == llama_token_eot_impl(vocab) ||
1451
+ token == llama_token_eom_impl(vocab)
1452
+ );
1453
+ }
1454
+
1455
+ bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
1456
+ return llama_is_control_token(vocab, token);
1457
+ }
1458
+
1459
+ llama_token llama_token_bos_impl(const struct llama_vocab & vocab) {
1460
+ return vocab.special_bos_id;
1461
+ }
1462
+
1463
+ llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
1464
+ return vocab.special_eos_id;
1465
+ }
1466
+
1467
+ llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
1468
+ return vocab.special_cls_id;
1469
+ }
1470
+
1471
+ llama_token llama_token_sep_impl(const struct llama_vocab & vocab) {
1472
+ return vocab.special_sep_id;
1473
+ }
1474
+
1475
+ llama_token llama_token_nl_impl(const struct llama_vocab & vocab) {
1476
+ return vocab.linefeed_id;
1477
+ }
1478
+
1479
+ llama_token llama_token_pad_impl(const struct llama_vocab & vocab) {
1480
+ return vocab.special_pad_id;
1481
+ }
1482
+
1483
+ int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab) {
1484
+ return vocab.tokenizer_add_bos;
1485
+ }
1486
+
1487
+ int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab) {
1488
+ return vocab.tokenizer_add_eos;
1489
+ }
1490
+
1491
+ llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
1492
+ return vocab.special_prefix_id;
1493
+ }
1494
+
1495
+ llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
1496
+ return vocab.special_middle_id;
1497
+ }
1498
+
1499
+ llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
1500
+ return vocab.special_suffix_id;
1501
+ }
1502
+
1503
+ llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
1504
+ return vocab.special_eot_id;
1505
+ }
1506
+
1507
+ llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
1508
+ return vocab.special_eom_id;
1509
+ }
1510
+
1511
+ int32_t llama_tokenize_impl(
1512
+ const struct llama_vocab & vocab,
1513
+ const char * text,
1514
+ int32_t text_len,
1515
+ llama_token * tokens,
1516
+ int32_t n_tokens_max,
1517
+ bool add_special,
1518
+ bool parse_special) {
1519
+ auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special);
1520
+ if (n_tokens_max < (int) res.size()) {
1521
+ // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
1522
+ return -((int) res.size());
1523
+ }
1524
+
1525
+ for (size_t i = 0; i < res.size(); i++) {
1526
+ tokens[i] = res[i];
1527
+ }
1528
+
1529
+ return res.size();
1530
+ }
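
Given the return convention above (a negative value whose magnitude is the required token count when n_tokens_max is too small), a caller can size its buffer in two passes. A sketch of that pattern; it assumes the declarations from llama-vocab.h are visible in the translation unit:

    #include <cstdint>
    #include <string>
    #include <vector>

    #include "llama-vocab.h" // assumed to declare llama_vocab, llama_token and llama_tokenize_impl

    // two-pass caller: the first call discovers the required size, the second fills the buffer
    static std::vector<llama_token> tokenize_text(const llama_vocab & vocab, const std::string & text) {
        std::vector<llama_token> tokens(8); // deliberately small initial guess
        int32_t n = llama_tokenize_impl(vocab, text.c_str(), (int32_t) text.size(),
                                        tokens.data(), (int32_t) tokens.size(),
                                        /*add_special*/ true, /*parse_special*/ false);
        if (n < 0) {
            tokens.resize(-n); // the negative return carries the required count
            n = llama_tokenize_impl(vocab, text.c_str(), (int32_t) text.size(),
                                    tokens.data(), (int32_t) tokens.size(), true, false);
        }
        tokens.resize(n);
        return tokens;
    }
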
1531
+
1532
+ static std::string llama_decode_text(const std::string & text) {
1533
+ std::string decoded_text;
1534
+
1535
+ const auto cpts = unicode_cpts_from_utf8(text);
1536
+ for (const auto cpt : cpts) {
1537
+ const auto utf8 = unicode_cpt_to_utf8(cpt);
1538
+ try {
1539
+ decoded_text += unicode_utf8_to_byte(utf8);
1540
+ } catch (const std::out_of_range & /*e*/) {
1541
+ decoded_text += "[UNK_BYTE_0x";
1542
+ for (const auto c : utf8) {
1543
+ decoded_text += format("%02x", (uint8_t) c);
1544
+ }
1545
+ decoded_text += text + "]";
1546
+ }
1547
+ }
1548
+
1549
+ return decoded_text;
1550
+ }
1551
+
1552
+ // does not write null-terminator to buf
1553
+ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) {
1554
+ // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
1555
+ static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
1556
+ const llama_token_attr attr = llama_token_get_attr_impl(vocab, token);
1557
+ if (!special && (attr & attr_special)) {
1558
+ return 0;
1559
+ }
1560
+
1561
+ // copy piece chars to output text buffer
1562
+ // skip up to 'lstrip' leading spaces before copying
1563
+ auto _try_copy = [=] (const char * token, size_t size) -> int32_t {
1564
+ for (int32_t i = 0; i < lstrip && size && *token == ' '; ++i) {
1565
+ token++;
1566
+ size--;
1567
+ }
1568
+ if (length < (int32_t)size) {
1569
+ return -(int32_t) size;
1570
+ }
1571
+ memcpy(buf, token, size);
1572
+ return (int32_t) size;
1573
+ };
1574
+
1575
+ // if we have a cache - use it
1576
+ {
1577
+ const auto & cache = vocab.cache_token_to_piece;
1578
+
1579
+ if (!cache.empty()) {
1580
+ const auto & result = cache.at(token);
1581
+ return _try_copy(result.data(), result.size());
1582
+ }
1583
+ }
1584
+
1585
+ if (0 <= token && token < (int32_t) vocab.id_to_token.size()) {
1586
+ const std::string & token_text = vocab.id_to_token[token].text;
1587
+ switch (llama_vocab_get_type(vocab)) {
1588
+ case LLAMA_VOCAB_TYPE_WPM:
1589
+ case LLAMA_VOCAB_TYPE_SPM:
1590
+ case LLAMA_VOCAB_TYPE_UGM: {
1591
+ // NOTE: we accept all unsupported token types,
1592
+ // suppressing them like CONTROL tokens.
1593
+ if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
1594
+ return _try_copy(token_text.data(), token_text.size());
1595
+ } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
1596
+ std::string result = token_text;
1597
+ llama_unescape_whitespace(result);
1598
+ return _try_copy(result.data(), result.size());
1599
+ } else if (attr & LLAMA_TOKEN_ATTR_BYTE) {
1600
+ char byte = (char) llama_token_to_byte(vocab, token);
1601
+ return _try_copy((char*) &byte, 1);
1602
+ }
1603
+ break;
1604
+ }
1605
+ case LLAMA_VOCAB_TYPE_BPE: {
1606
+ // NOTE: we accept all unsupported token types,
1607
+ // suppressing them like CONTROL tokens.
1608
+ if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
1609
+ return _try_copy(token_text.data(), token_text.size());
1610
+ } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
1611
+ std::string result = llama_decode_text(token_text);
1612
+ return _try_copy(result.data(), result.size());
1613
+ }
1614
+ break;
1615
+ }
1616
+ default:
1617
+ GGML_ABORT("fatal error");
1618
+ }
1619
+ }
1620
+
1621
+ return 0;
1622
+ }
1623
+
1624
+ int32_t llama_detokenize_impl(
1625
+ const struct llama_vocab & vocab,
1626
+ const llama_token * tokens,
1627
+ int32_t n_tokens,
1628
+ char * text,
1629
+ int32_t text_len_max,
1630
+ bool remove_special,
1631
+ bool unparse_special) {
1632
+ int32_t avail = text_len_max;
1633
+ int32_t total = 0;
1634
+
1635
+ // remove the leading space
1636
+ bool remove_space = vocab.tokenizer_add_space_prefix;
1637
+
1638
+ if (remove_special && vocab.tokenizer_add_bos) {
1639
+ if (n_tokens > 0 && tokens[0] == vocab.special_bos_id) {
1640
+ remove_space = false;
1641
+ n_tokens--;
1642
+ tokens++;
1643
+ }
1644
+ }
1645
+
1646
+ if (remove_special && vocab.tokenizer_add_eos) {
1647
+ if (n_tokens > 0 && tokens[n_tokens-1] == vocab.special_eos_id) {
1648
+ n_tokens--;
1649
+ }
1650
+ }
1651
+
1652
+ for (int32_t i = 0; i < n_tokens; ++i) {
1653
+ GGML_ASSERT(avail >= 0);
1654
+ int32_t n_chars = llama_token_to_piece_impl(vocab, tokens[i], text, avail, remove_space, unparse_special);
1655
+ remove_space = false;
1656
+ if (n_chars < 0) {
1657
+ avail = 0;
1658
+ total -= n_chars;
1659
+ } else if (n_chars > 0) {
1660
+ avail -= n_chars;
1661
+ text += n_chars;
1662
+ total += n_chars;
1663
+ }
1664
+ }
1665
+
1666
+ if (total > text_len_max) {
1667
+ return -total;
1668
+ }
1669
+
1670
+ if (vocab.tokenizer_clean_spaces) {
1671
+ text -= total; // restart text
1672
+
1673
+ // first pass: characters ?!., //TODO: where do these characters come from?
1674
+ const int32_t total1 = total;
1675
+ total = total ? 1 : 0;
1676
+ for (int32_t i = 1; i < total1; ++i) {
1677
+ const char x = text[i];
1678
+ if (text[i - 1] == ' ') {
1679
+ if (x == '?' || x == '!' || x == '.' || x == ',') { // " ?", " !", " .", " ,"
1680
+ total--; // remove space
1681
+ }
1682
+ }
1683
+ text[total++] = x;
1684
+ }
1685
+
1686
+ // second pass: strip single apostrophe between spaces
1687
+ const int32_t total2 = total;
1688
+ total = total ? 1 : 0;
1689
+ for (int32_t i = 1; i < total2; ++i) {
1690
+ const char x = text[i];
1691
+ if (x == '\'' && i + 1 < total2 && text[i - 1] == ' ' && text[i + 1] == ' ') { // " ' "
1692
+ total--; // remove prev space
1693
+ text[++i] = '\0'; // remove next space
1694
+ }
1695
+ text[total++] = x;
1696
+ }
1697
+
1698
+ // third pass: apostrophe contractions //NOTE: this makes sense?
1699
+ const int32_t total3 = total;
1700
+ total = total ? 1 : 0;
1701
+ for (int32_t i = 1; i < total3; ++i) {
1702
+ const char x = text[i];
1703
+ if (text[i - 1] == ' ') {
1704
+ if (x == '\'' && i + 1 < total3) {
1705
+ const char x1 = text[i + 1];
1706
+ if (x1 == 't' || x1 == 'd') { // " 't", " 'd"
1707
+ //total--; // remove space
1708
+ } else if (x1 == 's' || x1 == 'm') { // " 's", " 'm"
1709
+ total--; // remove space
1710
+ } else if (i + 2 < total3) {
1711
+ const char x2 = text[i + 2];
1712
+ if ((x1 == 'l' && x2 == 'l')) { // " 'll"
1713
+ //total--; // remove space
1714
+ } else if ((x1 == 'r' && x2 == 'e') || (x1 == 'v' && x2 == 'e')) { // " 're", " 've"
1715
+ total--; // remove space
1716
+ } else {
1717
+ //total--; // remove space
1718
+ }
1719
+ } else {
1720
+ //total--; // remove space
1721
+ }
1722
+ }
1723
+ }
1724
+ text[total++] = x;
1725
+ }
1726
+ }
1727
+
1728
+ return total <= text_len_max ? total : -total;
1729
+ }
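
The negative-return convention used by llama_tokenize_impl above lends itself to a two-pass calling pattern: probe with an empty buffer to learn the required token count, then resize and call again. A minimal sketch, assuming a llama_vocab reference is at hand (application code would normally go through the public llama_tokenize() wrapper instead); tokenize_all is an illustrative helper name, not part of this file:

    static std::vector<llama_token> tokenize_all(const llama_vocab & vocab, const std::string & text, bool add_special) {
        // first pass: n_tokens_max == 0, so the impl returns -(required count) before writing anything
        int32_t n = llama_tokenize_impl(vocab, text.c_str(), (int32_t) text.size(), nullptr, 0, add_special, /*parse_special=*/false);
        std::vector<llama_token> tokens(n < 0 ? -n : n);
        if (n < 0) {
            // second pass: the buffer is now exactly large enough
            n = llama_tokenize_impl(vocab, text.c_str(), (int32_t) text.size(), tokens.data(), (int32_t) tokens.size(), add_special, /*parse_special=*/false);
        }
        tokens.resize(n > 0 ? n : 0);
        return tokens;
    }
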
examples/talk-llama/llama-vocab.h ADDED
@@ -0,0 +1,132 @@
1
+ #pragma once
2
+
3
+ #include "llama-impl.h"
4
+
5
+ #include <string>
6
+ #include <vector>
7
+ #include <unordered_map>
8
+ #include <map>
9
+
10
+ struct llama_vocab {
11
+ using id = llama_token;
12
+ using token = std::string;
13
+ using tattr = llama_token_attr;
14
+
15
+ struct token_data {
16
+ token text;
17
+ float score;
18
+ tattr attr;
19
+ };
20
+
21
+ enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
22
+ enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
23
+
24
+ int max_token_len = 0; // used for optimizing longest token search
25
+
26
+ std::unordered_map<token, id> token_to_id;
27
+ std::vector<token_data> id_to_token;
28
+
29
+ std::vector<id> cache_special_tokens;
30
+ std::vector<token> cache_token_to_piece; // llama_token_to_piece(special = true);
31
+
32
+ std::map<std::pair<std::string, std::string>, int> bpe_ranks;
33
+
34
+ // default LLaMA special tokens
35
+ id special_bos_id = 1;
36
+ id special_eos_id = 2;
37
+ id special_unk_id = 0;
38
+ id special_sep_id = -1;
39
+ id special_pad_id = -1;
40
+ id special_cls_id = -1;
41
+ id special_mask_id = -1;
42
+
43
+ id linefeed_id = 13;
44
+ id special_prefix_id = -1;
45
+ id special_suffix_id = -1;
46
+ id special_middle_id = -1;
47
+ id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
48
+ id special_eom_id = -1;
49
+
50
+ // tokenizer flags
51
+ bool tokenizer_add_space_prefix = false;
52
+ bool tokenizer_add_bos = false;
53
+ bool tokenizer_add_eos = false;
54
+ bool tokenizer_ignore_merges = false;
55
+ bool tokenizer_clean_spaces = false; // clean_up_tokenization_spaces
56
+ bool tokenizer_remove_extra_whitespaces = false;
57
+ bool tokenizer_escape_whitespaces = true;
58
+ bool tokenizer_treat_whitespace_as_suffix = false;
59
+
60
+ std::vector<char> precompiled_charsmap;
61
+
62
+ int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
63
+ };
64
+
65
+ const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
66
+
67
+ //
68
+ // internal API
69
+ //
70
+
71
+ // TODO: rename to llama_tokenize_impl
72
+ // TODO: This should probably be in llama.h
73
+ std::vector<llama_vocab::id> llama_tokenize_internal(
74
+ const llama_vocab & vocab,
75
+ std::string raw_text,
76
+ bool add_special,
77
+ bool parse_special = false);
78
+
79
+ llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
80
+
81
+ const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
82
+
83
+ float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token);
84
+
85
+ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token);
86
+
87
+ bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token);
88
+
89
+ bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token);
90
+
91
+ llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
92
+ llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
93
+ llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
94
+ llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
95
+ llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
96
+ llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
97
+
98
+ int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab);
99
+ int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab);
100
+
101
+ llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
102
+ llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
103
+ llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
104
+ llama_token llama_token_eot_impl (const struct llama_vocab & vocab);
105
+ llama_token llama_token_eom_impl (const struct llama_vocab & vocab);
106
+
107
+ int32_t llama_tokenize_impl(
108
+ const struct llama_vocab & vocab,
109
+ const char * text,
110
+ int32_t text_len,
111
+ llama_token * tokens,
112
+ int32_t n_tokens_max,
113
+ bool add_special,
114
+ bool parse_special);
115
+
116
+ // does not write null-terminator to buf
117
+ int32_t llama_token_to_piece_impl(
118
+ const struct llama_vocab & vocab,
119
+ llama_token token,
120
+ char * buf,
121
+ int32_t length,
122
+ int32_t lstrip,
123
+ bool special);
124
+
125
+ int32_t llama_detokenize_impl(
126
+ const struct llama_vocab & vocab,
127
+ const llama_token * tokens,
128
+ int32_t n_tokens,
129
+ char * text,
130
+ int32_t text_len_max,
131
+ bool remove_special,
132
+ bool unparse_special);
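
llama_detokenize_impl follows the same convention: when text_len_max is too small it returns the negated number of bytes required, so a caller can grow its buffer and retry. A rough sketch under that assumption (detokenize_all is an illustrative name, not part of this header):

    static std::string detokenize_all(const llama_vocab & vocab, const std::vector<llama_token> & tokens) {
        std::string text(64, '\0'); // small initial guess
        int32_t n = llama_detokenize_impl(vocab, tokens.data(), (int32_t) tokens.size(),
                                          &text[0], (int32_t) text.size(),
                                          /*remove_special=*/false, /*unparse_special=*/false);
        if (n < 0) {
            text.resize(-n); // negative return encodes the required size
            n = llama_detokenize_impl(vocab, tokens.data(), (int32_t) tokens.size(),
                                      &text[0], (int32_t) text.size(),
                                      /*remove_special=*/false, /*unparse_special=*/false);
        }
        text.resize(n > 0 ? n : 0);
        return text;
    }
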
examples/talk-llama/llama.cpp CHANGED
The diff for this file is too large to render. See raw diff
 
examples/talk-llama/llama.h CHANGED
@@ -33,17 +33,15 @@
33
 
34
  #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
35
 
36
- #define LLAMA_MAX_RNG_STATE (64*1024)
37
-
38
  #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
39
  #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
40
  #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
41
 
42
  #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
43
- #define LLAMA_SESSION_VERSION 6
44
 
45
  #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
46
- #define LLAMA_STATE_SEQ_VERSION 1
47
 
48
  #ifdef __cplusplus
49
  extern "C" {
@@ -92,6 +90,9 @@ extern "C" {
92
  LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
93
  LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
94
  LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
 
 
 
95
  };
96
 
97
  // note: these values should be synchronized with ggml_rope
@@ -133,7 +134,7 @@ extern "C" {
133
  LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
134
  LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
135
  LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
136
- LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
137
  // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
138
  // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
139
  LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
@@ -162,6 +163,9 @@ extern "C" {
162
  LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
163
  LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
164
  LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
 
 
 
165
 
166
  LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
167
  };
@@ -341,7 +345,7 @@ extern "C" {
341
  int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
342
  enum llama_ftype ftype; // quantize to this llama_ftype
343
  enum ggml_type output_tensor_type; // output tensor type
344
- enum ggml_type token_embedding_type; // itoken embeddings tensor type
345
  bool allow_requantize; // allow quantizing non-f32/f16 tensors
346
  bool quantize_output_tensor; // quantize output.weight
347
  bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
@@ -408,6 +412,9 @@ extern "C" {
408
  const char * content;
409
  } llama_chat_message;
410
 
 
 
 
411
  // Helpers for getting default parameters
412
  LLAMA_API struct llama_model_params llama_model_default_params(void);
413
  LLAMA_API struct llama_context_params llama_context_default_params(void);
@@ -507,18 +514,32 @@ extern "C" {
507
  const char * fname_out,
508
  const llama_model_quantize_params * params);
509
 
510
- // Apply a LoRA adapter to a loaded model
511
- // path_base_model is the path to a higher quality model to use as a base for
512
- // the layers modified by the adapter. Can be NULL to use the current loaded model.
513
- // The model needs to be reloaded before applying a new adapter, otherwise the adapter
514
- // will be applied on top of the previous one
515
- // Returns 0 on success
516
- LLAMA_API int32_t llama_model_apply_lora_from_file(
517
- const struct llama_model * model,
518
- const char * path_lora,
519
- float scale,
520
- const char * path_base_model,
521
- int32_t n_threads);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522
 
523
  // Apply a loaded control vector to a llama_context, or if data is NULL, clear
524
  // the currently loaded vector.
@@ -668,10 +689,11 @@ extern "C" {
668
  // State / sessions
669
  //
670
 
671
- // Returns the maximum size in bytes of the state (rng, logits, embedding
672
- // and kv_cache) - will often be smaller after compacting tokens
673
- LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
674
- LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx),
 
675
  "use llama_state_get_size instead");
676
 
677
  // Copies the state to the specified destination address.
@@ -679,7 +701,8 @@ extern "C" {
679
  // Returns the number of bytes copied
680
  LLAMA_API size_t llama_state_get_data(
681
  struct llama_context * ctx,
682
- uint8_t * dst);
 
683
  LLAMA_API DEPRECATED(size_t llama_copy_state_data(
684
  struct llama_context * ctx,
685
  uint8_t * dst),
@@ -689,7 +712,8 @@ extern "C" {
689
  // Returns the number of bytes read
690
  LLAMA_API size_t llama_state_set_data(
691
  struct llama_context * ctx,
692
- const uint8_t * src);
 
693
  LLAMA_API DEPRECATED(size_t llama_set_state_data(
694
  struct llama_context * ctx,
695
  const uint8_t * src),
@@ -731,6 +755,7 @@ extern "C" {
731
  LLAMA_API size_t llama_state_seq_get_data(
732
  struct llama_context * ctx,
733
  uint8_t * dst,
 
734
  llama_seq_id seq_id);
735
 
736
  // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
@@ -740,6 +765,7 @@ extern "C" {
740
  LLAMA_API size_t llama_state_seq_set_data(
741
  struct llama_context * ctx,
742
  const uint8_t * src,
 
743
  llama_seq_id dest_seq_id);
744
 
745
  LLAMA_API size_t llama_state_seq_save_file(
@@ -887,10 +913,10 @@ extern "C" {
887
  LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding
888
 
889
  // Returns -1 if unknown, 1 for true or 0 for false.
890
- LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model);
891
 
892
  // Returns -1 if unknown, 1 for true or 0 for false.
893
- LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
894
 
895
  // Codellama infill tokens
896
  LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
@@ -946,6 +972,10 @@ extern "C" {
946
  bool remove_special,
947
  bool unparse_special);
948
 
 
 
 
 
949
  /// Apply chat template. Inspired by hf apply_chat_template() on python.
950
  /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
951
  /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
@@ -984,6 +1014,23 @@ extern "C" {
984
 
985
  LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
986
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
987
  //
988
  // Sampling functions
989
  //
@@ -1065,12 +1112,6 @@ extern "C" {
1065
  llama_token_data_array * candidates,
1066
  float temp);
1067
 
1068
- /// @details Apply constraints from grammar
1069
- LLAMA_API void llama_sample_grammar(
1070
- struct llama_context * ctx,
1071
- llama_token_data_array * candidates,
1072
- const struct llama_grammar * grammar);
1073
-
1074
  /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
1075
  /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
1076
  /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@@ -1108,12 +1149,6 @@ extern "C" {
1108
  struct llama_context * ctx,
1109
  llama_token_data_array * candidates);
1110
 
1111
- /// @details Accepts the sampled token into the grammar
1112
- LLAMA_API void llama_grammar_accept_token(
1113
- struct llama_context * ctx,
1114
- struct llama_grammar * grammar,
1115
- llama_token token);
1116
-
1117
  //
1118
  // Model split
1119
  //
@@ -1156,38 +1191,45 @@ extern "C" {
1156
 
1157
  struct ggml_tensor;
1158
 
 
 
 
 
1159
  struct llama_partial_utf8 {
1160
  uint32_t value; // bit value so far (unshifted)
1161
  int n_remain; // num bytes remaining; -1 indicates invalid sequence
1162
  };
1163
 
1164
- struct llama_grammar {
1165
- const std::vector<std::vector<llama_grammar_element>> rules;
1166
- std::vector<std::vector<const llama_grammar_element *>> stacks;
1167
-
1168
- // buffer for partially generated UTF-8 sequence from accepted tokens
1169
- llama_partial_utf8 partial_utf8;
1170
- };
1171
-
1172
  struct llama_grammar_candidate {
1173
  size_t index;
1174
  const uint32_t * code_points;
1175
  llama_partial_utf8 partial_utf8;
1176
  };
1177
 
1178
- const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
1179
- struct llama_context * ctx
1180
- );
 
 
 
 
 
 
1181
 
1182
  void llama_grammar_accept(
1183
- const std::vector<std::vector<llama_grammar_element>> & rules,
1184
- const std::vector<std::vector<const llama_grammar_element *>> & stacks,
1185
- const uint32_t chr,
1186
- std::vector<std::vector<const llama_grammar_element *>> & new_stacks);
 
 
 
 
 
1187
 
1188
  std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
1189
  const std::string & src,
1190
- llama_partial_utf8 partial_start);
1191
 
1192
  // Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
1193
  // This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
 
33
 
34
  #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
35
 
 
 
36
  #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
37
  #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
38
  #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
39
 
40
  #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
41
+ #define LLAMA_SESSION_VERSION 8
42
 
43
  #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
44
+ #define LLAMA_STATE_SEQ_VERSION 2
45
 
46
  #ifdef __cplusplus
47
  extern "C" {
 
90
  LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
91
  LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
92
  LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
93
+ LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
94
+ LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
95
+ LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
96
  };
97
 
98
  // note: these values should be synchronized with ggml_rope
 
134
  LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
135
  LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
136
  LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
137
+ // LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
138
  // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
139
  // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
140
  LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
 
163
  LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
164
  LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
165
  LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
166
+ LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors
167
+ LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors
168
+ LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
169
 
170
  LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
171
  };
 
345
  int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
346
  enum llama_ftype ftype; // quantize to this llama_ftype
347
  enum ggml_type output_tensor_type; // output tensor type
348
+ enum ggml_type token_embedding_type; // token embeddings tensor type
349
  bool allow_requantize; // allow quantizing non-f32/f16 tensors
350
  bool quantize_output_tensor; // quantize output.weight
351
  bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
 
412
  const char * content;
413
  } llama_chat_message;
414
 
415
+ // lora adapter
416
+ struct llama_lora_adapter;
417
+
418
  // Helpers for getting default parameters
419
  LLAMA_API struct llama_model_params llama_model_default_params(void);
420
  LLAMA_API struct llama_context_params llama_context_default_params(void);
 
514
  const char * fname_out,
515
  const llama_model_quantize_params * params);
516
 
517
+ // Load a LoRA adapter from file
518
+ // The loaded adapter will be associated to the given model, and will be free when the model is deleted
519
+ LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
520
+ struct llama_model * model,
521
+ const char * path_lora);
522
+
523
+ // Add a loaded LoRA adapter to given context
524
+ // This will not modify model's weight
525
+ LLAMA_API int32_t llama_lora_adapter_set(
526
+ struct llama_context * ctx,
527
+ struct llama_lora_adapter * adapter,
528
+ float scale);
529
+
530
+ // Remove a specific LoRA adapter from given context
531
+ // Return -1 if the adapter is not present in the context
532
+ LLAMA_API int32_t llama_lora_adapter_remove(
533
+ struct llama_context * ctx,
534
+ struct llama_lora_adapter * adapter);
535
+
536
+ // Remove all LoRA adapters from given context
537
+ LLAMA_API void llama_lora_adapter_clear(
538
+ struct llama_context * ctx);
539
+
540
+ // Manually free a LoRA adapter
541
+ // Note: loaded adapters will be free when the associated model is deleted
542
+ LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter);
543
 
544
  // Apply a loaded control vector to a llama_context, or if data is NULL, clear
545
  // the currently loaded vector.
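
The file-based llama_model_apply_lora_from_file() call removed above is replaced by this small adapter lifecycle: load once against a model, attach to a context with a scale, detach or free when done. A hedged usage sketch (the adapter path and error handling are illustrative; model and ctx are assumed to have been created with the existing model/context APIs):

    struct llama_lora_adapter * adapter = llama_lora_adapter_init(model, "adapter.gguf"); // hypothetical path
    if (adapter != NULL) {
        llama_lora_adapter_set(ctx, adapter, 1.0f); // attach at full strength; model weights are not modified
        // ... run inference with the adapter active ...
        llama_lora_adapter_remove(ctx, adapter);    // detach from this context only
        llama_lora_adapter_free(adapter);           // optional: adapters are also freed with the model
    }
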
 
689
  // State / sessions
690
  //
691
 
692
+ // Returns the *actual* size in bytes of the state
693
+ // (rng, logits, embedding and kv_cache)
694
+ // Only use when saving the state, not when restoring it, otherwise the size may be too small.
695
+ LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
696
+ LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
697
  "use llama_state_get_size instead");
698
 
699
  // Copies the state to the specified destination address.
 
701
  // Returns the number of bytes copied
702
  LLAMA_API size_t llama_state_get_data(
703
  struct llama_context * ctx,
704
+ uint8_t * dst,
705
+ size_t size);
706
  LLAMA_API DEPRECATED(size_t llama_copy_state_data(
707
  struct llama_context * ctx,
708
  uint8_t * dst),
 
712
  // Returns the number of bytes read
713
  LLAMA_API size_t llama_state_set_data(
714
  struct llama_context * ctx,
715
+ const uint8_t * src,
716
+ size_t size);
717
  LLAMA_API DEPRECATED(size_t llama_set_state_data(
718
  struct llama_context * ctx,
719
  const uint8_t * src),
 
755
  LLAMA_API size_t llama_state_seq_get_data(
756
  struct llama_context * ctx,
757
  uint8_t * dst,
758
+ size_t size,
759
  llama_seq_id seq_id);
760
 
761
  // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
 
765
  LLAMA_API size_t llama_state_seq_set_data(
766
  struct llama_context * ctx,
767
  const uint8_t * src,
768
+ size_t size,
769
  llama_seq_id dest_seq_id);
770
 
771
  LLAMA_API size_t llama_state_seq_save_file(
 
913
  LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding
914
 
915
  // Returns -1 if unknown, 1 for true or 0 for false.
916
+ LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model);
917
 
918
  // Returns -1 if unknown, 1 for true or 0 for false.
919
+ LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
920
 
921
  // Codellama infill tokens
922
  LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
 
972
  bool remove_special,
973
  bool unparse_special);
974
 
975
+ //
976
+ // Chat templates
977
+ //
978
+
979
  /// Apply chat template. Inspired by hf apply_chat_template() on python.
980
  /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
981
  /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
 
1014
 
1015
  LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
1016
 
1017
+ /// @details Apply constraints from grammar
1018
+ LLAMA_API void llama_grammar_sample(
1019
+ const struct llama_grammar * grammar,
1020
+ const struct llama_context * ctx,
1021
+ llama_token_data_array * candidates);
1022
+ LLAMA_API DEPRECATED(void llama_sample_grammar(
1023
+ struct llama_context * ctx,
1024
+ llama_token_data_array * candidates,
1025
+ const struct llama_grammar * grammar),
1026
+ "use llama_grammar_sample instead");
1027
+
1028
+ /// @details Accepts the sampled token into the grammar
1029
+ LLAMA_API void llama_grammar_accept_token(
1030
+ struct llama_grammar * grammar,
1031
+ struct llama_context * ctx,
1032
+ llama_token token);
1033
+
1034
  //
1035
  // Sampling functions
1036
  //
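
Note the argument order in the relocated grammar calls: the grammar now comes first, with the context second. A sketch of a sampling step updated accordingly (the surrounding candidate handling uses the pre-existing sampling API and is assumed unchanged):

    // after logits have been collected into a llama_token_data_array `candidates`
    if (grammar != NULL) {
        llama_grammar_sample(grammar, ctx, &candidates);  // was: llama_sample_grammar(ctx, &candidates, grammar)
    }
    const llama_token id = llama_sample_token(ctx, &candidates);
    if (grammar != NULL) {
        llama_grammar_accept_token(grammar, ctx, id);     // was: llama_grammar_accept_token(ctx, grammar, id)
    }
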
 
1112
  llama_token_data_array * candidates,
1113
  float temp);
1114
 
 
 
 
 
 
 
1115
  /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
1116
  /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
1117
  /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
 
1149
  struct llama_context * ctx,
1150
  llama_token_data_array * candidates);
1151
 
 
 
 
 
 
 
1152
  //
1153
  // Model split
1154
  //
 
1191
 
1192
  struct ggml_tensor;
1193
 
1194
+ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
1195
+ struct llama_context * ctx
1196
+ );
1197
+
1198
  struct llama_partial_utf8 {
1199
  uint32_t value; // bit value so far (unshifted)
1200
  int n_remain; // num bytes remaining; -1 indicates invalid sequence
1201
  };
1202
 
 
 
 
 
 
 
 
 
1203
  struct llama_grammar_candidate {
1204
  size_t index;
1205
  const uint32_t * code_points;
1206
  llama_partial_utf8 partial_utf8;
1207
  };
1208
 
1209
+ using llama_grammar_rule = std::vector< llama_grammar_element>;
1210
+ using llama_grammar_stack = std::vector<const llama_grammar_element *>;
1211
+
1212
+ using llama_grammar_rules = std::vector<llama_grammar_rule>;
1213
+ using llama_grammar_stacks = std::vector<llama_grammar_stack>;
1214
+ using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
1215
+
1216
+ const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
1217
+ llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
1218
 
1219
  void llama_grammar_accept(
1220
+ const llama_grammar_rules & rules,
1221
+ const llama_grammar_stacks & stacks,
1222
+ const uint32_t chr,
1223
+ llama_grammar_stacks & new_stacks);
1224
+
1225
+ std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
1226
+ const llama_grammar_rules & rules,
1227
+ const llama_grammar_stack & stack,
1228
+ const llama_grammar_candidates & candidates);
1229
 
1230
  std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
1231
  const std::string & src,
1232
+ llama_partial_utf8 partial_start);
1233
 
1234
  // Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
1235
  // This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
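
With the explicit size parameters above, a session save/restore round trip passes the buffer capacity on every call, and llama_state_get_size() is only consulted at save time. A minimal sketch under those assumptions:

    // save
    std::vector<uint8_t> state(llama_state_get_size(ctx));
    const size_t n_written = llama_state_get_data(ctx, state.data(), state.size());

    // ... later, restore into a compatible context; the stored size, not llama_state_get_size(), bounds the read ...
    const size_t n_read = llama_state_set_data(ctx, state.data(), n_written);
    (void) n_read;
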
examples/talk-llama/unicode.cpp CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  #include "unicode.h"
2
  #include "unicode-data.h"
3
 
@@ -15,6 +19,12 @@
15
  #include <locale>
16
  #include <codecvt>
17
 
 
 
 
 
 
 
18
  static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
19
  std::string result;
20
  for (size_t i = 0; i < cps.size(); ++i) {
 
1
+ #if defined(_MSC_VER)
2
+ #define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
3
+ #endif
4
+
5
  #include "unicode.h"
6
  #include "unicode-data.h"
7
 
 
19
  #include <locale>
20
  #include <codecvt>
21
 
22
+ size_t unicode_len_utf8(char src) {
23
+ const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
24
+ uint8_t highbits = static_cast<uint8_t>(src) >> 4;
25
+ return lookup[highbits];
26
+ }
27
+
28
  static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
29
  std::string result;
30
  for (size_t i = 0; i < cps.size(); ++i) {
examples/talk-llama/unicode.h CHANGED
@@ -4,6 +4,8 @@
4
  #include <string>
5
  #include <vector>
6
 
 
 
7
  struct codepoint_flags {
8
  enum {
9
  UNDEFINED = 0x0001,
@@ -46,6 +48,7 @@ struct codepoint_flags {
46
  }
47
  };
48
 
 
49
 
50
  std::string unicode_cpt_to_utf8(uint32_t cp);
51
  uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
 
4
  #include <string>
5
  #include <vector>
6
 
7
+ // TODO: prefix all symbols with "llama_"
8
+
9
  struct codepoint_flags {
10
  enum {
11
  UNDEFINED = 0x0001,
 
48
  }
49
  };
50
 
51
+ size_t unicode_len_utf8(char src);
52
 
53
  std::string unicode_cpt_to_utf8(uint32_t cp);
54
  uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
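
The new unicode_len_utf8() helper maps a UTF-8 lead byte to its sequence length by indexing a 16-entry table with the byte's high nibble. A small illustrative sketch of walking a well-formed UTF-8 string one code point at a time with it (count_codepoints is not part of the header):

    static size_t count_codepoints(const std::string & s) {
        size_t n = 0;
        for (size_t i = 0; i < s.size(); i += unicode_len_utf8(s[i])) {
            ++n; // 1..4 bytes per code point, derived from the lead byte
        }
        return n;
    }
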
scripts/sync-llama.sh CHANGED
@@ -2,7 +2,8 @@
2
 
3
  cp -rpv ../llama.cpp/include/llama.h ./examples/talk-llama/llama.h
4
 
5
- cp -rpv ../llama.cpp/src/llama.cpp ./examples/talk-llama/llama.cpp
 
6
  cp -rpv ../llama.cpp/src/unicode.h ./examples/talk-llama/unicode.h
7
  cp -rpv ../llama.cpp/src/unicode.cpp ./examples/talk-llama/unicode.cpp
8
  cp -rpv ../llama.cpp/src/unicode-data.h ./examples/talk-llama/unicode-data.h
 
2
 
3
  cp -rpv ../llama.cpp/include/llama.h ./examples/talk-llama/llama.h
4
 
5
+ cp -rpv ../llama.cpp/src/llama*.cpp ./examples/talk-llama/
6
+ cp -rpv ../llama.cpp/src/llama*.h ./examples/talk-llama/
7
  cp -rpv ../llama.cpp/src/unicode.h ./examples/talk-llama/unicode.h
8
  cp -rpv ../llama.cpp/src/unicode.cpp ./examples/talk-llama/unicode.cpp
9
  cp -rpv ../llama.cpp/src/unicode-data.h ./examples/talk-llama/unicode-data.h