ggerganov committed on
Commit f91f98d · 1 Parent(s): 14b5848

talk-llama : sync llama.cpp

Makefile CHANGED
@@ -1080,10 +1080,12 @@ lsp: examples/lsp/lsp.cpp \
1080
  $(CXX) $(CXXFLAGS) $(CFLAGS_SDL) -c $< -o $(call GET_OBJ_FILE, $<)
1081
  $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LDFLAGS_SDL)
1082
 
1083
- talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp \
1084
- $(OBJ_GGML) $(OBJ_WHISPER) $(OBJ_COMMON) $(OBJ_SDL)
1085
- $(CXX) $(CXXFLAGS) $(CFLAGS_SDL) -c $< -o $(call GET_OBJ_FILE, $<)
1086
- $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LDFLAGS_SDL)
 
 
1087
 
1088
  talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp examples/talk-llama/llama-vocab.cpp examples/talk-llama/llama-grammar.cpp examples/talk-llama/llama-sampling.cpp examples/talk-llama/unicode.cpp examples/talk-llama/unicode-data.cpp \
1089
  $(OBJ_GGML) $(OBJ_WHISPER) $(OBJ_COMMON) $(OBJ_SDL)
 
1080
  $(CXX) $(CXXFLAGS) $(CFLAGS_SDL) -c $< -o $(call GET_OBJ_FILE, $<)
1081
  $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LDFLAGS_SDL)
1082
 
1083
+ # TODO: disabled until update
1084
+ # https://github.com/ggerganov/whisper.cpp/issues/1818
1085
+ #talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp \
1086
+ # $(OBJ_GGML) $(OBJ_WHISPER) $(OBJ_COMMON) $(OBJ_SDL)
1087
+ # $(CXX) $(CXXFLAGS) $(CFLAGS_SDL) -c $< -o $(call GET_OBJ_FILE, $<)
1088
+ # $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LDFLAGS_SDL)
1089
 
1090
  talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp examples/talk-llama/llama-vocab.cpp examples/talk-llama/llama-grammar.cpp examples/talk-llama/llama-sampling.cpp examples/talk-llama/unicode.cpp examples/talk-llama/unicode-data.cpp \
1091
  $(OBJ_GGML) $(OBJ_WHISPER) $(OBJ_COMMON) $(OBJ_SDL)
examples/CMakeLists.txt CHANGED
@@ -127,8 +127,10 @@ endif (WHISPER_SDL2)
127
  add_subdirectory(quantize)
128
  set_target_properties(quantize PROPERTIES FOLDER "examples")
129
  if (WHISPER_SDL2)
130
- add_subdirectory(talk)
131
- set_target_properties(talk PROPERTIES FOLDER "examples")
 
 
132
  add_subdirectory(talk-llama)
133
  set_target_properties(talk-llama PROPERTIES FOLDER "examples")
134
  add_subdirectory(lsp)
 
127
  add_subdirectory(quantize)
128
  set_target_properties(quantize PROPERTIES FOLDER "examples")
129
  if (WHISPER_SDL2)
130
+ # TODO: disabled until update
131
+ # https://github.com/ggerganov/whisper.cpp/issues/1818
132
+ #add_subdirectory(talk)
133
+ #set_target_properties(talk PROPERTIES FOLDER "examples")
134
  add_subdirectory(talk-llama)
135
  set_target_properties(talk-llama PROPERTIES FOLDER "examples")
136
  add_subdirectory(lsp)
examples/talk-llama/llama-grammar.cpp CHANGED
@@ -3,11 +3,31 @@
3
  #include "llama-vocab.h"
4
  #include "llama-sampling.h"
5
 
 
6
  #include <algorithm>
 
7
 
8
- // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
9
- // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
10
- std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
11
  const std::string & src,
12
  llama_partial_utf8 partial_start) {
13
  static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
@@ -40,7 +60,7 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
40
  while (*pos != 0) {
41
  uint8_t first_byte = static_cast<uint8_t>(*pos);
42
  uint8_t highbits = first_byte >> 4;
43
- n_remain = lookup[highbits] - 1;
44
 
45
  if (n_remain < 0) {
46
  // invalid sequence, abort
@@ -50,7 +70,7 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
50
  }
51
 
52
  uint8_t mask = (1 << (7 - n_remain)) - 1;
53
- value = first_byte & mask;
54
 
55
  ++pos;
56
  while (*pos != 0 && n_remain > 0) {
@@ -67,12 +87,510 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
67
  return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
68
  }
69
 
70
- const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
71
- return grammar->rules;
72
  }
73
 
74
- llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
75
- return grammar->stacks;
76
  }
77
 
78
  // returns true iff pos points to the end of one of the definitions of a rule
@@ -89,7 +607,6 @@ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos)
89
  static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
90
  const llama_grammar_element * pos,
91
  const uint32_t chr) {
92
-
93
  bool found = false;
94
  bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
95
 
@@ -225,16 +742,93 @@ static void llama_grammar_advance_stack(
225
  }
226
  }
227
 
228
- // takes a set of possible pushdown stacks on a grammar, which are required to
229
- // be positioned at a character range (see `llama_grammar_advance_stack`), and
230
- // produces the N possible stacks if the given char is accepted at those
231
- // positions
232
  void llama_grammar_accept(
233
  const llama_grammar_rules & rules,
234
  const llama_grammar_stacks & stacks,
235
  const uint32_t chr,
236
- llama_grammar_stacks & new_stacks) {
237
- new_stacks.clear();
 
238
 
239
  for (const auto & stack : stacks) {
240
  if (stack.empty()) {
@@ -250,29 +844,11 @@ void llama_grammar_accept(
250
  if (!llama_grammar_is_end_of_sequence(pos)) {
251
  new_stack.push_back(pos);
252
  }
253
- llama_grammar_advance_stack(rules, new_stack, new_stacks);
254
  }
255
  }
256
  }
257
 
258
- static llama_grammar_candidates llama_grammar_reject_candidates(
259
- const llama_grammar_rules & rules,
260
- const llama_grammar_stacks & stacks,
261
- const llama_grammar_candidates & candidates) {
262
- GGML_ASSERT(!stacks.empty()); // REVIEW
263
-
264
- if (candidates.empty()) {
265
- return {};
266
- }
267
-
268
- auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
269
-
270
- for (size_t i = 1, size = stacks.size(); i < size; ++i) {
271
- rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
272
- }
273
- return rejects;
274
- }
275
-
276
  llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
277
  const llama_grammar_rules & rules,
278
  const llama_grammar_stack & stack,
@@ -328,72 +904,97 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
328
  return rejects;
329
  }
330
 
331
- static bool llama_grammar_detect_left_recursion(
332
- const llama_grammar_rules & rules,
333
- size_t rule_index,
334
- std::vector<bool> * rules_visited,
335
- std::vector<bool> * rules_in_progress,
336
- std::vector<bool> * rules_may_be_empty) {
337
- if ((*rules_in_progress)[rule_index]) {
338
- return true;
339
- }
340
 
341
- (*rules_in_progress)[rule_index] = true;
342
 
343
- const llama_grammar_rule & rule = rules[rule_index];
 
345
- // First check if the rule might produce the empty string. This could be done combined with the second
346
- // step but it's more readable as two steps.
347
- bool at_rule_start = true;
348
- for (size_t i = 0; i < rule.size(); i++) {
349
- if (llama_grammar_is_end_of_sequence(&rule[i])) {
350
- if (at_rule_start) {
351
- (*rules_may_be_empty)[rule_index] = true;
352
- break;
353
- }
354
- at_rule_start = true;
355
- } else {
356
- at_rule_start = false;
357
  }
358
  }
359
 
360
- // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
361
- // be empty)
362
- bool recurse_into_nonterminal = true;
363
- for (size_t i = 0; i < rule.size(); i++) {
364
- if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
365
- if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
366
- return true;
367
- }
368
- if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
369
- recurse_into_nonterminal = false;
370
- }
371
- } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
372
- recurse_into_nonterminal = true;
373
  } else {
374
- recurse_into_nonterminal = false;
375
  }
376
- }
377
 
378
- (*rules_in_progress)[rule_index] = false;
379
- (*rules_visited)[rule_index] = true;
380
- return false;
 
381
  }
382
 
383
- //
384
- // grammar - external
385
- //
386
 
387
- struct llama_grammar * llama_grammar_init_impl(
388
- const llama_grammar_element ** rules,
389
- size_t n_rules,
390
- size_t start_rule_index) {
391
  const llama_grammar_element * pos;
392
 
393
  // copy rule definitions into vectors
394
  llama_grammar_rules vec_rules(n_rules);
395
  for (size_t i = 0; i < n_rules; i++) {
396
- for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
397
  vec_rules[i].push_back(*pos);
398
  }
399
  vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
@@ -438,22 +1039,26 @@ struct llama_grammar * llama_grammar_init_impl(
438
  // Important: vec_rules has to be moved here, not copied, because stacks contains
439
  // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
440
  // then the pointers would be invalidated when the local vec_rules goes out of scope.
441
- return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
442
  }
443
 
444
  void llama_grammar_free_impl(struct llama_grammar * grammar) {
445
  delete grammar;
446
  }
447
 
448
- struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
449
- llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
450
 
451
  // redirect elements in stacks to point to new rules
452
  for (size_t is = 0; is < result->stacks.size(); is++) {
453
  for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
454
- for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
455
- for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
456
- if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
457
  result->stacks[is][ie] = &result->rules[ir0][ir1];
458
  }
459
  }
@@ -464,14 +1069,11 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * gram
464
  return result;
465
  }
466
 
467
- void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
468
- GGML_ASSERT(grammar);
469
- GGML_ASSERT(vocab);
470
-
471
- int64_t t_start_sample_us = ggml_time_us();
472
 
473
  bool allow_eog = false;
474
- for (const auto & stack : grammar->stacks) {
475
  if (stack.empty()) {
476
  allow_eog = true;
477
  break;
@@ -479,40 +1081,38 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc
479
  }
480
 
481
  std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
482
- candidates_decoded.reserve(candidates->size);
483
 
484
  llama_grammar_candidates candidates_grammar;
485
- candidates_grammar.reserve(candidates->size);
486
 
487
- for (size_t i = 0; i < candidates->size; ++i) {
488
- const llama_token id = candidates->data[i].id;
489
- const std::string & piece = vocab->cache_token_to_piece.at(id);
490
 
491
- if (llama_token_is_eog_impl(*vocab, id)) {
492
  if (!allow_eog) {
493
- candidates->data[i].logit = -INFINITY;
494
  }
495
  } else if (piece.empty() || piece[0] == 0) {
496
- candidates->data[i].logit = -INFINITY;
497
  } else {
498
- candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
499
  candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
500
  }
501
  }
502
 
503
- const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
504
  for (const auto & reject : rejects) {
505
- candidates->data[reject.index].logit = -INFINITY;
506
  }
507
-
508
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
509
  }
510
 
511
- void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
512
- const int64_t t_start_sample_us = ggml_time_us();
513
 
514
- if (llama_token_is_eog_impl(*vocab, token)) {
515
- for (const auto & stack : grammar->stacks) {
516
  if (stack.empty()) {
517
  return;
518
  }
@@ -520,20 +1120,19 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc
520
  GGML_ABORT("fatal error");
521
  }
522
 
523
- const std::string & piece = vocab->cache_token_to_piece.at(token);
524
 
525
  // Note terminating 0 in decoded string
526
- const auto decoded = decode_utf8(piece, grammar->partial_utf8);
527
  const auto & code_points = decoded.first;
528
 
529
- llama_grammar_stacks tmp_new_stacks;
 
530
  for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
531
- llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
532
- grammar->stacks = tmp_new_stacks;
533
  }
534
 
535
- grammar->partial_utf8 = decoded.second;
536
- GGML_ASSERT(!grammar->stacks.empty());
537
-
538
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
539
  }
 
3
  #include "llama-vocab.h"
4
  #include "llama-sampling.h"
5
 
6
+ #include <cmath>
7
  #include <algorithm>
8
+ #include <stdexcept>
9
 
10
+ //
11
+ // helpers
12
+ //
13
+
14
+ // NOTE: assumes valid utf8 (but checks for overrun)
15
+ static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
16
+ static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
17
+ uint8_t first_byte = static_cast<uint8_t>(*src);
18
+ uint8_t highbits = first_byte >> 4;
19
+ int len = lookup[highbits];
20
+ uint8_t mask = (1 << (8 - len)) - 1;
21
+ uint32_t value = first_byte & mask;
22
+ const char * end = src + len; // may overrun!
23
+ const char * pos = src + 1;
24
+ for ( ; pos < end && *pos; pos++) {
25
+ value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
26
+ }
27
+ return std::make_pair(value, pos);
28
+ }
29
+
30
+ static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
31
  const std::string & src,
32
  llama_partial_utf8 partial_start) {
33
  static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
 
60
  while (*pos != 0) {
61
  uint8_t first_byte = static_cast<uint8_t>(*pos);
62
  uint8_t highbits = first_byte >> 4;
63
+ n_remain = lookup[highbits] - 1;
64
 
65
  if (n_remain < 0) {
66
  // invalid sequence, abort
 
70
  }
71
 
72
  uint8_t mask = (1 << (7 - n_remain)) - 1;
73
+ value = first_byte & mask;
74
 
75
  ++pos;
76
  while (*pos != 0 && n_remain > 0) {
 
87
  return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
88
  }
89
 
90
+ static bool is_digit_char(char c) {
91
+ return '0' <= c && c <= '9';
92
  }
93
 
94
+ static bool is_word_char(char c) {
95
+ return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
96
+ }
97
+
98
+ static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
99
+ const char * pos = src;
100
+ const char * end = src + size;
101
+ uint32_t value = 0;
102
+ for ( ; pos < end && *pos; pos++) {
103
+ value <<= 4;
104
+ char c = *pos;
105
+ if ('a' <= c && c <= 'f') {
106
+ value += c - 'a' + 10;
107
+ } else if ('A' <= c && c <= 'F') {
108
+ value += c - 'A' + 10;
109
+ } else if ('0' <= c && c <= '9') {
110
+ value += c - '0';
111
+ } else {
112
+ break;
113
+ }
114
+ }
115
+ if (pos != end) {
116
+ throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
117
+ }
118
+ return std::make_pair(value, pos);
119
+ }
120
+
121
+ static const char * parse_space(const char * src, bool newline_ok) {
122
+ const char * pos = src;
123
+ while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
124
+ (newline_ok && (*pos == '\r' || *pos == '\n'))) {
125
+ if (*pos == '#') {
126
+ while (*pos && *pos != '\r' && *pos != '\n') {
127
+ pos++;
128
+ }
129
+ } else {
130
+ pos++;
131
+ }
132
+ }
133
+ return pos;
134
+ }
135
+
136
+ static const char * parse_name(const char * src) {
137
+ const char * pos = src;
138
+ while (is_word_char(*pos)) {
139
+ pos++;
140
+ }
141
+ if (pos == src) {
142
+ throw std::runtime_error(std::string("expecting name at ") + src);
143
+ }
144
+ return pos;
145
+ }
146
+
147
+ static const char * parse_int(const char * src) {
148
+ const char * pos = src;
149
+ while (is_digit_char(*pos)) {
150
+ pos++;
151
+ }
152
+ if (pos == src) {
153
+ throw std::runtime_error(std::string("expecting integer at ") + src);
154
+ }
155
+ return pos;
156
+ }
157
+
158
+ static std::pair<uint32_t, const char *> parse_char(const char * src) {
159
+ if (*src == '\\') {
160
+ switch (src[1]) {
161
+ case 'x': return parse_hex(src + 2, 2);
162
+ case 'u': return parse_hex(src + 2, 4);
163
+ case 'U': return parse_hex(src + 2, 8);
164
+ case 't': return std::make_pair('\t', src + 2);
165
+ case 'r': return std::make_pair('\r', src + 2);
166
+ case 'n': return std::make_pair('\n', src + 2);
167
+ case '\\':
168
+ case '"':
169
+ case '[':
170
+ case ']':
171
+ return std::make_pair(src[1], src + 2);
172
+ default:
173
+ throw std::runtime_error(std::string("unknown escape at ") + src);
174
+ }
175
+ } else if (*src) {
176
+ return decode_utf8(src);
177
+ }
178
+ throw std::runtime_error("unexpected end of input");
179
+ }
180
+
181
+ static void print_grammar_char(FILE * file, uint32_t c) {
182
+ if (0x20 <= c && c <= 0x7f) {
183
+ fprintf(file, "%c", static_cast<char>(c));
184
+ } else {
185
+ // cop out of encoding UTF-8
186
+ fprintf(file, "<U+%04X>", c);
187
+ }
188
+ }
189
+
190
+ static bool is_char_element(llama_grammar_element elem) {
191
+ switch (elem.type) {
192
+ case LLAMA_GRETYPE_CHAR: return true;
193
+ case LLAMA_GRETYPE_CHAR_NOT: return true;
194
+ case LLAMA_GRETYPE_CHAR_ALT: return true;
195
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
196
+ case LLAMA_GRETYPE_CHAR_ANY: return true;
197
+ default: return false;
198
+ }
199
+ }
200
+
201
+ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
202
+ for (auto elem : rule) {
203
+ switch (elem.type) {
204
+ case LLAMA_GRETYPE_END: fprintf(file, "END"); break;
205
+ case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
206
+ case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
207
+ case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
208
+ case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
209
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
210
+ case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
211
+ case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
212
+ }
213
+ switch (elem.type) {
214
+ case LLAMA_GRETYPE_END:
215
+ case LLAMA_GRETYPE_ALT:
216
+ case LLAMA_GRETYPE_RULE_REF:
217
+ fprintf(file, "(%u) ", elem.value);
218
+ break;
219
+ case LLAMA_GRETYPE_CHAR:
220
+ case LLAMA_GRETYPE_CHAR_NOT:
221
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER:
222
+ case LLAMA_GRETYPE_CHAR_ALT:
223
+ case LLAMA_GRETYPE_CHAR_ANY:
224
+ fprintf(file, "(\"");
225
+ print_grammar_char(file, elem.value);
226
+ fprintf(file, "\") ");
227
+ break;
228
+ }
229
+ }
230
+ fprintf(file, "\n");
231
+ }
232
+
233
+ static void print_rule(
234
+ FILE * file,
235
+ uint32_t rule_id,
236
+ const llama_grammar_rule & rule,
237
+ const std::map<uint32_t, std::string> & symbol_id_names) {
238
+ if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
239
+ throw std::runtime_error(
240
+ "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
241
+ }
242
+ fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
243
+ for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
244
+ llama_grammar_element elem = rule[i];
245
+ switch (elem.type) {
246
+ case LLAMA_GRETYPE_END:
247
+ throw std::runtime_error(
248
+ "unexpected end of rule: " + std::to_string(rule_id) + "," +
249
+ std::to_string(i));
250
+ case LLAMA_GRETYPE_ALT:
251
+ fprintf(file, "| ");
252
+ break;
253
+ case LLAMA_GRETYPE_RULE_REF:
254
+ fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
255
+ break;
256
+ case LLAMA_GRETYPE_CHAR:
257
+ fprintf(file, "[");
258
+ print_grammar_char(file, elem.value);
259
+ break;
260
+ case LLAMA_GRETYPE_CHAR_NOT:
261
+ fprintf(file, "[^");
262
+ print_grammar_char(file, elem.value);
263
+ break;
264
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER:
265
+ if (i == 0 || !is_char_element(rule[i - 1])) {
266
+ throw std::runtime_error(
267
+ "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
268
+ std::to_string(rule_id) + "," + std::to_string(i));
269
+ }
270
+ fprintf(file, "-");
271
+ print_grammar_char(file, elem.value);
272
+ break;
273
+ case LLAMA_GRETYPE_CHAR_ALT:
274
+ if (i == 0 || !is_char_element(rule[i - 1])) {
275
+ throw std::runtime_error(
276
+ "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
277
+ std::to_string(rule_id) + "," + std::to_string(i));
278
+ }
279
+ print_grammar_char(file, elem.value);
280
+ break;
281
+ case LLAMA_GRETYPE_CHAR_ANY:
282
+ fprintf(file, ".");
283
+ break;
284
+ }
285
+ if (is_char_element(elem)) {
286
+ switch (rule[i + 1].type) {
287
+ case LLAMA_GRETYPE_CHAR_ALT:
288
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER:
289
+ case LLAMA_GRETYPE_CHAR_ANY:
290
+ break;
291
+ default:
292
+ fprintf(file, "] ");
293
+ }
294
+ }
295
+ }
296
+ fprintf(file, "\n");
297
+ }
298
+
299
+ //
300
+ // implementation
301
+ //
302
+
303
+ uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) {
304
+ uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
305
+ auto result = symbol_ids.emplace(std::string(src, len), next_id);
306
+ return result.first->second;
307
+ }
308
+
309
+ uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) {
310
+ uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
311
+ symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
312
+ return next_id;
313
+ }
314
+
315
+ void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) {
316
+ if (rules.size() <= rule_id) {
317
+ rules.resize(rule_id + 1);
318
+ }
319
+ rules[rule_id] = rule;
320
+ }
321
+
322
+ const char * llama_grammar_parser::parse_alternates(
323
+ const char * src,
324
+ const std::string & rule_name,
325
+ uint32_t rule_id,
326
+ bool is_nested) {
327
+ llama_grammar_rule rule;
328
+ const char * pos = parse_sequence(src, rule_name, rule, is_nested);
329
+ while (*pos == '|') {
330
+ rule.push_back({LLAMA_GRETYPE_ALT, 0});
331
+ pos = parse_space(pos + 1, true);
332
+ pos = parse_sequence(pos, rule_name, rule, is_nested);
333
+ }
334
+ rule.push_back({LLAMA_GRETYPE_END, 0});
335
+ add_rule(rule_id, rule);
336
+ return pos;
337
+ }
338
+
339
+ const char * llama_grammar_parser::parse_sequence(
340
+ const char * src,
341
+ const std::string & rule_name,
342
+ llama_grammar_rule & rule,
343
+ bool is_nested) {
344
+ size_t last_sym_start = rule.size();
345
+ const char * pos = src;
346
+
347
+ auto handle_repetitions = [&](int min_times, int max_times) {
348
+
349
+ if (last_sym_start == rule.size()) {
350
+ throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
351
+ }
352
+
353
+ // apply transformation to previous symbol (last_sym_start to end) according to
354
+ // the following rewrite rules:
355
+ // S{m,n} --> S S S (m times) S'(n-m)
356
+ // S'(x) ::= S S'(x-1) |
357
+ // (... n-m definitions of these S' rules ...)
358
+ // S'(1) ::= S |
359
+ // S{m,} --> S S S (m times) S'
360
+ // S' ::= S S' |
361
+ // S* --> S{0,}
362
+ // --> S' ::= S S' |
363
+ // S+ --> S{1,}
364
+ // --> S S'
365
+ // S' ::= S S' |
366
+ // S? --> S{0,1}
367
+ // --> S'
368
+ // S' ::= S |
369
+
370
+ llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
371
+ if (min_times == 0) {
372
+ rule.resize(last_sym_start);
373
+ } else {
374
+ // Repeat the previous elements (min_times - 1) times
375
+ for (int i = 1; i < min_times; i++) {
376
+ rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
377
+ }
378
+ }
379
+
380
+ uint32_t last_rec_rule_id = 0;
381
+ auto n_opt = max_times < 0 ? 1 : max_times - min_times;
382
+
383
+ llama_grammar_rule rec_rule(prev_rule);
384
+ for (int i = 0; i < n_opt; i++) {
385
+ rec_rule.resize(prev_rule.size());
386
+ uint32_t rec_rule_id = generate_symbol_id( rule_name);
387
+ if (i > 0 || max_times < 0) {
388
+ rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
389
+ }
390
+ rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
391
+ rec_rule.push_back({LLAMA_GRETYPE_END, 0});
392
+ add_rule( rec_rule_id, rec_rule);
393
+ last_rec_rule_id = rec_rule_id;
394
+ }
395
+ if (n_opt > 0) {
396
+ rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
397
+ }
398
+ };
399
+
400
+ while (*pos) {
401
+ if (*pos == '"') { // literal string
402
+ pos++;
403
+ last_sym_start = rule.size();
404
+ while (*pos != '"') {
405
+ if (!*pos) {
406
+ throw std::runtime_error("unexpected end of input");
407
+ }
408
+ auto char_pair = parse_char(pos);
409
+ pos = char_pair.second;
410
+ rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
411
+ }
412
+ pos = parse_space(pos + 1, is_nested);
413
+ } else if (*pos == '[') { // char range(s)
414
+ pos++;
415
+ enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
416
+ if (*pos == '^') {
417
+ pos++;
418
+ start_type = LLAMA_GRETYPE_CHAR_NOT;
419
+ }
420
+ last_sym_start = rule.size();
421
+ while (*pos != ']') {
422
+ if (!*pos) {
423
+ throw std::runtime_error("unexpected end of input");
424
+ }
425
+ auto char_pair = parse_char(pos);
426
+ pos = char_pair.second;
427
+ enum llama_gretype type = last_sym_start < rule.size()
428
+ ? LLAMA_GRETYPE_CHAR_ALT
429
+ : start_type;
430
+
431
+ rule.push_back({type, char_pair.first});
432
+ if (pos[0] == '-' && pos[1] != ']') {
433
+ if (!pos[1]) {
434
+ throw std::runtime_error("unexpected end of input");
435
+ }
436
+ auto endchar_pair = parse_char(pos + 1);
437
+ pos = endchar_pair.second;
438
+ rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
439
+ }
440
+ }
441
+ pos = parse_space(pos + 1, is_nested);
442
+ } else if (is_word_char(*pos)) { // rule reference
443
+ const char * name_end = parse_name(pos);
444
+ uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
445
+ pos = parse_space(name_end, is_nested);
446
+ last_sym_start = rule.size();
447
+ rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
448
+ } else if (*pos == '(') { // grouping
449
+ // parse nested alternates into synthesized rule
450
+ pos = parse_space(pos + 1, true);
451
+ uint32_t sub_rule_id = generate_symbol_id(rule_name);
452
+ pos = parse_alternates(pos, rule_name, sub_rule_id, true);
453
+ last_sym_start = rule.size();
454
+ // output reference to synthesized rule
455
+ rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
456
+ if (*pos != ')') {
457
+ throw std::runtime_error(std::string("expecting ')' at ") + pos);
458
+ }
459
+ pos = parse_space(pos + 1, is_nested);
460
+ } else if (*pos == '.') { // any char
461
+ last_sym_start = rule.size();
462
+ rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
463
+ pos = parse_space(pos + 1, is_nested);
464
+ } else if (*pos == '*') {
465
+ pos = parse_space(pos + 1, is_nested);
466
+ handle_repetitions(0, -1);
467
+ } else if (*pos == '+') {
468
+ pos = parse_space(pos + 1, is_nested);
469
+ handle_repetitions(1, -1);
470
+ } else if (*pos == '?') {
471
+ pos = parse_space(pos + 1, is_nested);
472
+ handle_repetitions(0, 1);
473
+ } else if (*pos == '{') {
474
+ pos = parse_space(pos + 1, is_nested);
475
+
476
+ if (!is_digit_char(*pos)) {
477
+ throw std::runtime_error(std::string("expecting an int at ") + pos);
478
+ }
479
+ const char * int_end = parse_int(pos);
480
+ int min_times = std::stoul(std::string(pos, int_end - pos));
481
+ pos = parse_space(int_end, is_nested);
482
+
483
+ int max_times = -1;
484
+
485
+ if (*pos == '}') {
486
+ max_times = min_times;
487
+ pos = parse_space(pos + 1, is_nested);
488
+ } else if (*pos == ',') {
489
+ pos = parse_space(pos + 1, is_nested);
490
+
491
+ if (is_digit_char(*pos)) {
492
+ const char * int_end = parse_int(pos);
493
+ max_times = std::stoul(std::string(pos, int_end - pos));
494
+ pos = parse_space(int_end, is_nested);
495
+ }
496
+
497
+ if (*pos != '}') {
498
+ throw std::runtime_error(std::string("expecting '}' at ") + pos);
499
+ }
500
+ pos = parse_space(pos + 1, is_nested);
501
+ } else {
502
+ throw std::runtime_error(std::string("expecting ',' at ") + pos);
503
+ }
504
+ handle_repetitions(min_times, max_times);
505
+ } else {
506
+ break;
507
+ }
508
+ }
509
+ return pos;
510
+ }
511
+
512
+ const char * llama_grammar_parser::parse_rule(const char * src) {
513
+ const char * name_end = parse_name(src);
514
+ const char * pos = parse_space(name_end, false);
515
+ size_t name_len = name_end - src;
516
+ uint32_t rule_id = get_symbol_id(src, name_len);
517
+ const std::string name(src, name_len);
518
+
519
+ if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
520
+ throw std::runtime_error(std::string("expecting ::= at ") + pos);
521
+ }
522
+ pos = parse_space(pos + 3, true);
523
+
524
+ pos = parse_alternates(pos, name, rule_id, false);
525
+
526
+ if (*pos == '\r') {
527
+ pos += pos[1] == '\n' ? 2 : 1;
528
+ } else if (*pos == '\n') {
529
+ pos++;
530
+ } else if (*pos) {
531
+ throw std::runtime_error(std::string("expecting newline or end at ") + pos);
532
+ }
533
+ return parse_space(pos, true);
534
+ }
535
+
536
+ bool llama_grammar_parser::parse(const char * src) {
537
+ try {
538
+ const char * pos = parse_space(src, true);
539
+ while (*pos) {
540
+ pos = parse_rule(pos);
541
+ }
542
+ // Validate the state to ensure that all rules are defined
543
+ for (const auto & rule : rules) {
544
+ if (rule.empty()) {
545
+ throw std::runtime_error("Undefined rule");
546
+ }
547
+ for (const auto & elem : rule) {
548
+ if (elem.type == LLAMA_GRETYPE_RULE_REF) {
549
+ // Ensure that the rule at that location exists
550
+ if (elem.value >= rules.size() || rules[elem.value].empty()) {
551
+ // Get the name of the rule that is missing
552
+ for (const auto & kv : symbol_ids) {
553
+ if (kv.second == elem.value) {
554
+ throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
555
+ }
556
+ }
557
+ }
558
+ }
559
+ }
560
+ }
561
+ } catch (const std::exception & err) {
562
+ fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
563
+ rules.clear();
564
+ return false;
565
+ }
566
+
567
+ return true;
568
+ }
569
+
570
+ void llama_grammar_parser::print(FILE * file) {
571
+ try {
572
+ std::map<uint32_t, std::string> symbol_id_names;
573
+ for (const auto & kv : symbol_ids) {
574
+ symbol_id_names[kv.second] = kv.first;
575
+ }
576
+ for (size_t i = 0, end = rules.size(); i < end; i++) {
577
+ // fprintf(file, "%zu: ", i);
578
+ // print_rule_binary(file, rules[i]);
579
+ print_rule(file, uint32_t(i), rules[i], symbol_id_names);
580
+ // fprintf(file, "\n");
581
+ }
582
+ } catch (const std::exception & err) {
583
+ fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
584
+ }
585
+ }
586
+
587
+ llama_grammar_stack llama_grammar_parser::c_rules() const {
588
+ llama_grammar_stack ret;
589
+ ret.reserve(rules.size());
590
+ for (const auto & rule : rules) {
591
+ ret.push_back(rule.data());
592
+ }
593
+ return ret;
594
  }
595
 
596
  // returns true iff pos points to the end of one of the definitions of a rule
 
607
  static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
608
  const llama_grammar_element * pos,
609
  const uint32_t chr) {
 
610
  bool found = false;
611
  bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
612
 
 
742
  }
743
  }
744
 
745
+ static llama_grammar_candidates llama_grammar_reject_candidates(
746
+ const llama_grammar_rules & rules,
747
+ const llama_grammar_stacks & stacks,
748
+ const llama_grammar_candidates & candidates) {
749
+ GGML_ASSERT(!stacks.empty()); // REVIEW
750
+
751
+ if (candidates.empty()) {
752
+ return {};
753
+ }
754
+
755
+ auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
756
+
757
+ for (size_t i = 1, size = stacks.size(); i < size; ++i) {
758
+ rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
759
+ }
760
+
761
+ return rejects;
762
+ }
763
+
764
+ static bool llama_grammar_detect_left_recursion(
765
+ const llama_grammar_rules & rules,
766
+ size_t rule_index,
767
+ std::vector<bool> * rules_visited,
768
+ std::vector<bool> * rules_in_progress,
769
+ std::vector<bool> * rules_may_be_empty) {
770
+ if ((*rules_in_progress)[rule_index]) {
771
+ return true;
772
+ }
773
+
774
+ (*rules_in_progress)[rule_index] = true;
775
+
776
+ const llama_grammar_rule & rule = rules[rule_index];
777
+
778
+ // First check if the rule might produce the empty string. This could be done combined with the second
779
+ // step but it's more readable as two steps.
780
+ bool at_rule_start = true;
781
+ for (size_t i = 0; i < rule.size(); i++) {
782
+ if (llama_grammar_is_end_of_sequence(&rule[i])) {
783
+ if (at_rule_start) {
784
+ (*rules_may_be_empty)[rule_index] = true;
785
+ break;
786
+ }
787
+ at_rule_start = true;
788
+ } else {
789
+ at_rule_start = false;
790
+ }
791
+ }
792
+
793
+ // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
794
+ // be empty)
795
+ bool recurse_into_nonterminal = true;
796
+ for (size_t i = 0; i < rule.size(); i++) {
797
+ if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
798
+ if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
799
+ return true;
800
+ }
801
+ if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
802
+ recurse_into_nonterminal = false;
803
+ }
804
+ } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
805
+ recurse_into_nonterminal = true;
806
+ } else {
807
+ recurse_into_nonterminal = false;
808
+ }
809
+ }
810
+
811
+ (*rules_in_progress)[rule_index] = false;
812
+ (*rules_visited)[rule_index] = true;
813
+
814
+ return false;
815
+ }
816
+
817
+ const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
818
+ return grammar->rules;
819
+ }
820
+
821
+ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
822
+ return grammar->stacks;
823
+ }
824
+
825
  void llama_grammar_accept(
826
  const llama_grammar_rules & rules,
827
  const llama_grammar_stacks & stacks,
828
  const uint32_t chr,
829
+ llama_grammar_stacks & stacks_new) {
830
+ stacks_new.clear();
831
+ stacks_new.reserve(stacks.size());
832
 
833
  for (const auto & stack : stacks) {
834
  if (stack.empty()) {
 
844
  if (!llama_grammar_is_end_of_sequence(pos)) {
845
  new_stack.push_back(pos);
846
  }
847
+ llama_grammar_advance_stack(rules, new_stack, stacks_new);
848
  }
849
  }
850
  }
851
 
852
  llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
853
  const llama_grammar_rules & rules,
854
  const llama_grammar_stack & stack,
 
904
  return rejects;
905
  }
906
 
907
+ ////////////////////
908
 
909
+ struct llama_grammar * llama_grammar_init_impl(
910
+ const struct llama_vocab * vocab,
911
+ const llama_grammar_element ** rules,
912
+ size_t n_rules,
913
+ size_t start_rule_index) {
914
+ const llama_grammar_element * pos;
915
 
916
+ // copy rule definitions into vectors
917
+ llama_grammar_rules vec_rules(n_rules);
918
+ for (size_t i = 0; i < n_rules; i++) {
919
+ for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
920
+ vec_rules[i].push_back(*pos);
921
+ }
922
+ vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
923
+ }
924
 
925
+ // Check for left recursion
926
+ std::vector<bool> rules_visited(n_rules);
927
+ std::vector<bool> rules_in_progress(n_rules);
928
+ std::vector<bool> rules_may_be_empty(n_rules);
929
+ for (size_t i = 0; i < n_rules; i++) {
930
+ if (rules_visited[i]) {
931
+ continue;
932
+ }
933
+ if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
934
+ LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
935
+ return nullptr;
 
936
  }
937
  }
938
 
939
+ // loop over alternates of start rule to build initial stacks
940
+ llama_grammar_stacks stacks;
941
+ pos = vec_rules[start_rule_index].data();
942
+ do {
943
+ llama_grammar_stack stack;
944
+ if (!llama_grammar_is_end_of_sequence(pos)) {
945
+ // if alternate is nonempty, add to stack
946
+ stack.push_back(pos);
947
+ }
948
+ llama_grammar_advance_stack(vec_rules, stack, stacks);
949
+ while (!llama_grammar_is_end_of_sequence(pos)) {
950
+ // scan to end of alternate def
951
+ pos++;
952
+ }
953
+ if (pos->type == LLAMA_GRETYPE_ALT) {
954
+ // there's another alternate def of this rule to process
955
+ pos++;
956
  } else {
957
+ break;
958
  }
959
+ } while (true);
960
 
961
+ // Important: vec_rules has to be moved here, not copied, because stacks contains
962
+ // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
963
+ // then the pointers would be invalidated when the local vec_rules goes out of scope.
964
+ return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
965
  }
966
 
967
+ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
968
+ llama_grammar_parser parser;
969
+
970
+ // if there is a grammar, parse it
971
+ if (!parser.parse(grammar_str)) {
972
+ return nullptr;
973
+ }
974
+
975
+ // will be empty (default) if there are parse errors
976
+ if (parser.rules.empty()) {
977
+ fprintf(stderr, "%s: failed to parse grammar\n", __func__);
978
+ return nullptr;
979
+ }
980
+
981
+ // Ensure that there is a "root" node.
982
+ if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) {
983
+ fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
984
+ return nullptr;
985
+ }
986
+
987
+ std::vector<const llama_grammar_element *> grammar_rules(parser.c_rules());
988
+
989
+ const size_t n_rules = grammar_rules.size();
990
+ const size_t start_rule_index = parser.symbol_ids.at(grammar_root);
991
 
992
  const llama_grammar_element * pos;
993
 
994
  // copy rule definitions into vectors
995
  llama_grammar_rules vec_rules(n_rules);
996
  for (size_t i = 0; i < n_rules; i++) {
997
+ for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
998
  vec_rules[i].push_back(*pos);
999
  }
1000
  vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
 
1039
  // Important: vec_rules has to be moved here, not copied, because stacks contains
1040
  // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
1041
  // then the pointers would be invalidated when the local vec_rules goes out of scope.
1042
+ return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
1043
  }
1044
 
1045
  void llama_grammar_free_impl(struct llama_grammar * grammar) {
1046
+ if (grammar == nullptr) {
1047
+ return;
1048
+ }
1049
+
1050
  delete grammar;
1051
  }
1052
 
1053
+ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
1054
+ llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
1055
 
1056
  // redirect elements in stacks to point to new rules
1057
  for (size_t is = 0; is < result->stacks.size(); is++) {
1058
  for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
1059
+ for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
1060
+ for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
1061
+ if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
1062
  result->stacks[is][ie] = &result->rules[ir0][ir1];
1063
  }
1064
  }
 
1069
  return result;
1070
  }
1071
 
1072
+ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
1073
+ GGML_ASSERT(grammar.vocab != nullptr);
 
 
 
1074
 
1075
  bool allow_eog = false;
1076
+ for (const auto & stack : grammar.stacks) {
1077
  if (stack.empty()) {
1078
  allow_eog = true;
1079
  break;
 
1081
  }
1082
 
1083
  std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
1084
+ candidates_decoded.reserve(cur_p->size);
1085
 
1086
  llama_grammar_candidates candidates_grammar;
1087
+ candidates_grammar.reserve(cur_p->size);
1088
 
1089
+ for (size_t i = 0; i < cur_p->size; ++i) {
1090
+ const llama_token id = cur_p->data[i].id;
1091
+ const std::string & piece = grammar.vocab->cache_token_to_piece.at(id);
1092
 
1093
+ if (llama_token_is_eog_impl(*grammar.vocab, id)) {
1094
  if (!allow_eog) {
1095
+ cur_p->data[i].logit = -INFINITY;
1096
  }
1097
  } else if (piece.empty() || piece[0] == 0) {
1098
+ cur_p->data[i].logit = -INFINITY;
1099
  } else {
1100
+ candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
1101
  candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
1102
  }
1103
  }
1104
 
1105
+ const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
1106
  for (const auto & reject : rejects) {
1107
+ cur_p->data[reject.index].logit = -INFINITY;
1108
  }
 
 
1109
  }
1110
 
1111
+ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
1112
+ GGML_ASSERT(grammar.vocab != nullptr);
1113
 
1114
+ if (llama_token_is_eog_impl(*grammar.vocab, token)) {
1115
+ for (const auto & stack : grammar.stacks) {
1116
  if (stack.empty()) {
1117
  return;
1118
  }
 
1120
  GGML_ABORT("fatal error");
1121
  }
1122
 
1123
+ const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);
1124
 
1125
  // Note terminating 0 in decoded string
1126
+ const auto decoded = decode_utf8(piece, grammar.partial_utf8);
1127
  const auto & code_points = decoded.first;
1128
 
1129
+ llama_grammar_stacks stacks_new;
1130
+
1131
  for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1132
+ llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
1133
+ grammar.stacks = std::move(stacks_new);
1134
  }
1135
 
1136
+ grammar.partial_utf8 = decoded.second;
1137
+ GGML_ASSERT(!grammar.stacks.empty());
 
 
1138
  }
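
The bulk of this sync brings the GBNF grammar parser (llama_grammar_parser) and the string-based llama_grammar_init_impl(vocab, grammar_str, grammar_root) constructor into examples/talk-llama. A minimal sketch of how these pieces fit together is shown below; the grammar string and the standalone main() are illustrative assumptions, not part of the commit, and the program would have to be linked against the synced llama sources.

```cpp
// Sketch only: assumes llama-grammar.h is on the include path and that the
// program links against the synced llama sources from examples/talk-llama.
#include "llama-grammar.h"

#include <cstdio>

int main() {
    // illustrative GBNF grammar, not taken from the commit
    const char * gbnf = "root ::= \"yes\" | \"no\"\n";

    // parse the grammar text into rules + symbol ids
    llama_grammar_parser parser;
    if (!parser.parse(gbnf)) {
        fprintf(stderr, "failed to parse grammar\n");
        return 1;
    }

    // dump the parsed rules back in human-readable form
    parser.print(stderr);

    // the sampling path would instead use the vocab-aware constructor, e.g.
    //   struct llama_grammar * grammar = llama_grammar_init_impl(vocab, gbnf, "root");
    // which returns nullptr on parse errors or when no "root" rule is defined.

    return 0;
}
```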
examples/talk-llama/llama-grammar.h CHANGED
@@ -2,11 +2,115 @@
2
 
3
  #include "llama-impl.h"
4
 
 
 
5
  struct llama_vocab;
6
- struct llama_sampling;
7
 
8
  struct llama_grammar {
9
- const llama_grammar_rules rules;
 
 
 
10
  llama_grammar_stacks stacks;
11
 
12
  // buffer for partially generated UTF-8 sequence from accepted tokens
@@ -17,23 +121,24 @@ struct llama_grammar {
17
  // internal API
18
  //
19
 
 
20
  struct llama_grammar * llama_grammar_init_impl(
21
- const llama_grammar_element ** rules,
22
- size_t n_rules,
23
- size_t start_rule_index);
 
 
 
24
 
25
  void llama_grammar_free_impl(struct llama_grammar * grammar);
26
 
27
- struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);
28
 
29
- void llama_grammar_sample_impl(
30
- const struct llama_grammar * grammar,
31
- const struct llama_vocab * vocab,
32
- const struct llama_sampling * smpl,
33
- llama_token_data_array * candidates);
34
 
35
- void llama_grammar_accept_token_impl(
36
- struct llama_grammar * grammar,
37
- const struct llama_vocab * vocab,
38
- const struct llama_sampling * smpl,
39
  llama_token token);
 
2
 
3
  #include "llama-impl.h"
4
 
5
+ #include <map>
6
+
7
  struct llama_vocab;
8
+
9
+ // grammar element type
10
+ enum llama_gretype {
11
+ // end of rule definition
12
+ LLAMA_GRETYPE_END = 0,
13
+
14
+ // start of alternate definition for rule
15
+ LLAMA_GRETYPE_ALT = 1,
16
+
17
+ // non-terminal element: reference to rule
18
+ LLAMA_GRETYPE_RULE_REF = 2,
19
+
20
+ // terminal element: character (code point)
21
+ LLAMA_GRETYPE_CHAR = 3,
22
+
23
+ // inverse char(s) ([^a], [^a-b] [^abc])
24
+ LLAMA_GRETYPE_CHAR_NOT = 4,
25
+
26
+ // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
27
+ // be an inclusive range ([a-z])
28
+ LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
29
+
30
+ // modifies a preceding LLAMA_GRETYPE_CHAR or
31
+ // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
32
+ LLAMA_GRETYPE_CHAR_ALT = 6,
33
+
34
+ // any character (.)
35
+ LLAMA_GRETYPE_CHAR_ANY = 7,
36
+ };
37
+
38
+ typedef struct llama_grammar_element {
39
+ enum llama_gretype type;
40
+ uint32_t value; // Unicode code point or rule ID
41
+ } llama_grammar_element;
42
+
43
+ struct llama_partial_utf8 {
44
+ uint32_t value; // bit value so far (unshifted)
45
+ int n_remain; // num bytes remaining; -1 indicates invalid sequence
46
+ };
47
+
48
+ struct llama_grammar_candidate {
49
+ size_t index;
50
+ const uint32_t * code_points;
51
+ llama_partial_utf8 partial_utf8;
52
+ };
53
+
54
+ using llama_grammar_rule = std::vector< llama_grammar_element>;
55
+ using llama_grammar_stack = std::vector<const llama_grammar_element *>;
56
+
57
+ using llama_grammar_rules = std::vector<llama_grammar_rule>;
58
+ using llama_grammar_stacks = std::vector<llama_grammar_stack>;
59
+ using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
60
+
61
+ const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
62
+ llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
63
+
64
+ // takes a set of possible pushdown stacks on a grammar, which are required to
65
+ // be positioned at a character range (see `llama_grammar_advance_stack`), and
66
+ // produces the N possible stacks if the given char is accepted at those
67
+ // positions
68
+ void llama_grammar_accept(
69
+ const llama_grammar_rules & rules,
70
+ const llama_grammar_stacks & stacks,
71
+ uint32_t chr,
72
+ llama_grammar_stacks & stacks_new);
73
+
74
+ std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
75
+ const llama_grammar_rules & rules,
76
+ const llama_grammar_stack & stack,
77
+ const llama_grammar_candidates & candidates);
78
+
79
+ struct llama_grammar_parser {
80
+ std::map<std::string, uint32_t> symbol_ids;
81
+
82
+ llama_grammar_rules rules;
83
+
84
+ llama_grammar_stack c_rules() const;
85
+
86
+ uint32_t get_symbol_id(const char * src, size_t len);
87
+ uint32_t generate_symbol_id(const std::string & base_name);
88
+
89
+ void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);
90
+
91
+ const char * parse_alternates(
92
+ const char * src,
93
+ const std::string & rule_name,
94
+ uint32_t rule_id,
95
+ bool is_nested);
96
+
97
+ const char * parse_sequence(
98
+ const char * src,
99
+ const std::string & rule_name,
100
+ llama_grammar_rule & rule,
101
+ bool is_nested);
102
+
103
+ const char * parse_rule(const char * src);
104
+
105
+ bool parse(const char * src);
106
+ void print(FILE * file);
107
+ };
108
 
109
  struct llama_grammar {
110
+ // note: allow null vocab for testing (not great)
111
+ const llama_vocab * vocab;
112
+
113
+ const llama_grammar_rules rules; // TODO: shared ptr
114
  llama_grammar_stacks stacks;
115
 
116
  // buffer for partially generated UTF-8 sequence from accepted tokens
 
121
  // internal API
122
  //
123
 
124
+ // note: needed for tests (not great)
125
  struct llama_grammar * llama_grammar_init_impl(
126
+ const struct llama_vocab * vocab,
127
+ const llama_grammar_element ** rules,
128
+ size_t n_rules,
129
+ size_t start_rule_index);
130
+
131
+ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
132
 
133
  void llama_grammar_free_impl(struct llama_grammar * grammar);
134
 
135
+ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);
136
 
137
+ // TODO: move the API below as member functions of llama_grammar
138
+ void llama_grammar_apply_impl(
139
+ const struct llama_grammar & grammar,
140
+ llama_token_data_array * cur_p);
 
141
 
142
+ void llama_grammar_accept_impl(
143
+ struct llama_grammar & grammar,
 
 
144
  llama_token token);
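
Because the header now exposes llama_gretype, llama_grammar_element, and the test-oriented llama_grammar_init_impl overload (which, per the comments above, tolerates a null vocab), a rule set can also be hand-encoded without going through the parser. The sketch below encodes the single rule root ::= "a"; the standalone main() and the build setup are assumptions for illustration.

```cpp
// Sketch only: assumes llama-grammar.h is available and the program links
// against the synced llama sources; the null vocab relies on the header's
// "allow null vocab for testing" note.
#include "llama-grammar.h"

int main() {
    // root ::= "a"   (a single terminal character followed by END)
    const llama_grammar_element rule_root[] = {
        { LLAMA_GRETYPE_CHAR, 'a' },
        { LLAMA_GRETYPE_END,  0   },
    };

    const llama_grammar_element * rules[] = { rule_root };

    struct llama_grammar * grammar = llama_grammar_init_impl(
        /*vocab            =*/ nullptr,
        /*rules            =*/ rules,
        /*n_rules          =*/ 1,
        /*start_rule_index =*/ 0);

    if (grammar == nullptr) {
        return 1; // init fails, e.g. when left recursion is detected
    }

    llama_grammar_free_impl(grammar);
    return 0;
}
```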
examples/talk-llama/llama-impl.h CHANGED
@@ -1,8 +1,11 @@
1
  #pragma once
2
 
3
- #define LLAMA_API_INTERNAL
4
  #include "llama.h"
5
6
  #ifdef __GNUC__
7
  #ifdef __MINGW32__
8
  #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
@@ -21,14 +24,31 @@ LLAMA_ATTRIBUTE_FORMAT(2, 3)
21
  void llama_log_internal (ggml_log_level level, const char * format, ...);
22
  void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data);
23
 
 
24
  #define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
25
  #define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
26
  #define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
 
 
27
 
28
  //
29
  // helpers
30
  //
31
32
  static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
33
  if (search.empty()) {
34
  return;
@@ -45,3 +65,117 @@ static void replace_all(std::string & s, const std::string & search, const std::
45
  builder.append(s, last_pos, std::string::npos);
46
  s = std::move(builder);
47
  }
1
  #pragma once
2
 
 
3
  #include "llama.h"
4
 
5
+ #include <string>
6
+ #include <vector>
7
+ #include <stdexcept>
8
+
9
  #ifdef __GNUC__
10
  #ifdef __MINGW32__
11
  #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
 
24
  void llama_log_internal (ggml_log_level level, const char * format, ...);
25
  void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data);
26
 
27
+ #define LLAMA_LOG(...) llama_log_internal(GGML_LOG_LEVEL_NONE , __VA_ARGS__)
28
  #define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
29
  #define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
30
  #define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
31
+ #define LLAMA_LOG_DEBUG(...) llama_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
32
+ #define LLAMA_LOG_CONT(...) llama_log_internal(GGML_LOG_LEVEL_CONT , __VA_ARGS__)
33
 
34
  //
35
  // helpers
36
  //
37
 
38
+ struct time_meas {
39
+ time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
40
+
41
+ ~time_meas() {
42
+ if (t_start_us >= 0) {
43
+ t_acc += ggml_time_us() - t_start_us;
44
+ }
45
+ }
46
+
47
+ const int64_t t_start_us;
48
+
49
+ int64_t & t_acc;
50
+ };
51
+
52
  static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
53
  if (search.empty()) {
54
  return;
 
65
  builder.append(s, last_pos, std::string::npos);
66
  s = std::move(builder);
67
  }
68
+
69
+ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
70
+ struct llama_context * ctx
71
+ );
72
+
73
+ // the ring buffer works similarly to std::deque, but with a fixed capacity
74
+ template<typename T>
75
+ struct ring_buffer {
76
+ ring_buffer(size_t cap) : capacity(cap), data(cap) {}
77
+
78
+ T & front() {
79
+ if (sz == 0) {
80
+ throw std::runtime_error("ring buffer is empty");
81
+ }
82
+ return data[first];
83
+ }
84
+
85
+ const T & front() const {
86
+ if (sz == 0) {
87
+ throw std::runtime_error("ring buffer is empty");
88
+ }
89
+ return data[first];
90
+ }
91
+
92
+ T & back() {
93
+ if (sz == 0) {
94
+ throw std::runtime_error("ring buffer is empty");
95
+ }
96
+ return data[pos];
97
+ }
98
+
99
+ const T & back() const {
100
+ if (sz == 0) {
101
+ throw std::runtime_error("ring buffer is empty");
102
+ }
103
+ return data[pos];
104
+ }
105
+
106
+ void push_back(const T & value) {
107
+ if (capacity == 0) {
108
+ throw std::runtime_error("ring buffer: capacity is zero");
109
+ }
110
+
111
+ if (sz == capacity) {
112
+ // advance the start when buffer is full
113
+ first = (first + 1) % capacity;
114
+ } else {
115
+ sz++;
116
+ }
117
+ data[pos] = value;
118
+ pos = (pos + 1) % capacity;
119
+ }
120
+
121
+ T pop_front() {
122
+ if (sz == 0) {
123
+ throw std::runtime_error("ring buffer is empty");
124
+ }
125
+ T value = data[first];
126
+ first = (first + 1) % capacity;
127
+ sz--;
128
+ return value;
129
+ }
130
+
131
+ //T & operator[](size_t i) {
132
+ // if (i >= sz) {
133
+ // throw std::runtime_error("ring buffer: index out of bounds");
134
+ // }
135
+ // return data[(first + i) % capacity];
136
+ //}
137
+
138
+ //const T & at(size_t i) const {
139
+ // if (i >= sz) {
140
+ // throw std::runtime_error("ring buffer: index out of bounds");
141
+ // }
142
+ // return data[(first + i) % capacity];
143
+ //}
144
+
145
+ const T & rat(size_t i) const {
146
+ if (i >= sz) {
147
+ throw std::runtime_error("ring buffer: index out of bounds");
148
+ }
149
+ return data[(first + sz - i - 1) % capacity];
150
+ }
151
+
152
+ std::vector<T> to_vector() const {
153
+ std::vector<T> result;
154
+ result.reserve(sz);
155
+ for (size_t i = 0; i < sz; i++) {
156
+ result.push_back(data[(first + i) % capacity]);
157
+ }
158
+ return result;
159
+ }
160
+
161
+ void clear() {
162
+ // here only reset the status of the buffer
163
+ sz = 0;
164
+ first = 0;
165
+ pos = 0;
166
+ }
167
+
168
+ bool empty() const {
169
+ return sz == 0;
170
+ }
171
+
172
+ size_t size() const {
173
+ return sz;
174
+ }
175
+
176
+ size_t capacity = 0;
177
+ size_t sz = 0;
178
+ size_t first = 0;
179
+ size_t pos = 0;
180
+ std::vector<T> data;
181
+ };
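
llama-impl.h now also carries two small utilities used elsewhere in the synced code: time_meas, which accumulates elapsed microseconds into a counter when it goes out of scope, and the fixed-capacity ring_buffer<T>. A usage sketch follows; the main() wrapper is an assumption for illustration, and ggml's timer is assumed to be initialized by the host application.

```cpp
// Sketch only: assumes llama-impl.h (and its llama.h/ggml dependencies) is
// on the include path and the program links against the synced sources.
#include "llama-impl.h"

#include <cstdio>

int main() {
    int64_t t_acc_us = 0;
    {
        time_meas tm(t_acc_us);   // accumulates elapsed us into t_acc_us on scope exit

        ring_buffer<int> rb(3);   // fixed capacity of 3

        rb.push_back(1);
        rb.push_back(2);
        rb.push_back(3);
        rb.push_back(4);          // full: the oldest element (1) is dropped

        // logical contents are now {2, 3, 4}
        printf("front  = %d\n", rb.front());    // 2 (oldest)
        printf("rat(0) = %d\n", rb.rat(0));     // 4 (most recently pushed)

        for (int v : rb.to_vector()) {          // oldest -> newest: 2 3 4
            printf("%d ", v);
        }
        printf("\n");

        int popped = rb.pop_front();            // removes and returns the oldest (2)
        printf("popped = %d, size = %zu\n", popped, rb.size()); // size is now 2
    }
    printf("timed section took %lld us\n", (long long) t_acc_us);

    return 0;
}
```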
examples/talk-llama/llama-sampling.cpp CHANGED
@@ -1,12 +1,53 @@
1
  #include "llama-sampling.h"
2
3
  #include <algorithm>
4
  #include <cstring>
5
  #include <ctime>
6
- #include <cfloat>
7
  #include <numeric>
 
8
  #include <unordered_map>
9
10
  static void llama_log_softmax(float * array, size_t size) {
11
  float max_l = *std::max_element(array, array + size);
12
  float sum = 0.f;
@@ -20,66 +61,52 @@ static void llama_log_softmax(float * array, size_t size) {
20
  array[i] = logf(array[i] / sum);
21
  }
22
  }
 
23
 
24
- void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
25
- if (seed == LLAMA_DEFAULT_SEED) {
26
- seed = time(NULL);
27
- }
28
-
29
- smpl->rng.seed(seed);
30
- }
31
-
32
- void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
33
- GGML_ASSERT(candidates->size > 0);
34
-
35
- const int64_t t_start_sample_us = ggml_time_us();
36
 
37
  // Sort the logits in descending order
38
- if (!candidates->sorted) {
39
- std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
40
  return a.logit > b.logit;
41
  });
42
- candidates->sorted = true;
43
  }
44
 
45
- float max_l = candidates->data[0].logit;
46
  float cum_sum = 0.0f;
47
- for (size_t i = 0; i < candidates->size; ++i) {
48
- float p = expf(candidates->data[i].logit - max_l);
49
- candidates->data[i].p = p;
 
50
  cum_sum += p;
51
  }
52
- for (size_t i = 0; i < candidates->size; ++i) {
53
- candidates->data[i].p /= cum_sum;
54
- }
55
 
56
- if (smpl) {
57
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
58
  }
59
  }
60
 
61
- void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
62
  // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
63
- // if (k >= (int32_t)candidates->size) {
64
  // return;
65
  // }
66
 
67
- const int64_t t_start_sample_us = ggml_time_us();
68
-
69
  if (k <= 0) {
70
- k = candidates->size;
71
  }
72
 
73
- k = std::max(k, (int) min_keep);
74
- k = std::min(k, (int) candidates->size);
75
 
76
  // Sort scores in descending order
77
- if (!candidates->sorted) {
78
  auto comp = [](const llama_token_data & a, const llama_token_data & b) {
79
  return a.logit > b.logit;
80
  };
81
  if (k <= 128) {
82
- std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
83
  } else {
84
  constexpr int nbuckets = 128;
85
  constexpr float bucket_low = -10.0f;
@@ -87,11 +114,11 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
87
  constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
88
  constexpr float bucket_inter = -bucket_low * bucket_scale;
89
 
90
- std::vector<int> bucket_idx(candidates->size);
91
  std::vector<int> histo(nbuckets, 0);
92
 
93
- for (int i = 0; i < (int)candidates->size; ++i) {
94
- const float val = candidates->data[i].logit;
95
  int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
96
  ib = std::max(0, std::min(nbuckets-1, ib));
97
  bucket_idx[i] = ib;
@@ -101,20 +128,22 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
101
  int ib = nbuckets - 1;
102
  for ( ; ib >= 0; --ib) {
103
  nhave += histo[ib];
104
- if (nhave >= k) break;
 
 
105
  }
106
  std::vector<llama_token_data> tmp_tokens(nhave);
107
- auto ptr = tmp_tokens.data();
108
  std::vector<llama_token_data*> bucket_ptrs;
109
  bucket_ptrs.reserve(nbuckets - ib);
110
  for (int j = nbuckets - 1; j >= ib; --j) {
111
  bucket_ptrs.push_back(ptr);
112
  ptr += histo[j];
113
  }
114
- for (int i = 0; i < (int)candidates->size; ++i) {
115
  int j = bucket_idx[i];
116
  if (j >= ib) {
117
- *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
118
  }
119
  }
120
 
@@ -127,125 +156,582 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
127
  }
128
  std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
129
 
130
- std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
131
 
132
  }
133
- candidates->sorted = true;
134
  }
135
- candidates->size = k;
 
136
 
137
- if (smpl) {
138
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
139
  }
 
140
  }
141
 
142
- void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
143
- if (p >= 1.0f) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  return;
145
  }
146
 
147
- llama_sample_softmax_impl(smpl, candidates);
148
 
149
- const int64_t t_start_sample_us = ggml_time_us();
150
 
151
  // Compute the cumulative probabilities
152
  float cum_sum = 0.0f;
153
- size_t last_idx = candidates->size;
154
 
155
- for (size_t i = 0; i < candidates->size; ++i) {
156
- cum_sum += candidates->data[i].p;
157
 
158
  // Check if the running sum is at least p or if we have kept at least min_keep tokens
159
  // we set the last index to i+1 to indicate that the current iterate should be included in the set
160
- if (cum_sum >= p && i + 1 >= min_keep) {
161
  last_idx = i + 1;
162
  break;
163
  }
164
  }
165
 
166
  // Resize the output vector to keep only the top-p tokens
167
- candidates->size = last_idx;
 
168
 
169
- if (smpl) {
170
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
171
- }
172
  }
173
 
174
- void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
175
- if (p <= 0.0f || !candidates->size) {
 
 
176
  return;
177
  }
178
 
179
- const int64_t t_start_sample_us = ggml_time_us();
180
-
181
  bool min_p_applied = false;
182
 
183
- // if the candidates aren't sorted, try the unsorted implementation first
184
- if (!candidates->sorted) {
185
  std::vector<llama_token_data> filtered_tokens;
186
 
187
  float max_logit = -FLT_MAX;
188
- for (size_t i = 0; i < candidates->size; ++i) {
189
- max_logit = std::max(max_logit, candidates->data[i].logit);
190
  }
191
- const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
192
 
193
- for (size_t i = 0; i < candidates->size; ++i) {
194
- if (candidates->data[i].logit >= min_logit) {
195
- filtered_tokens.push_back(candidates->data[i]);
196
  }
197
  }
198
 
199
  // if we have enough values the operation was a success
200
- if (filtered_tokens.size() >= min_keep) {
201
- memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
202
- candidates->size = filtered_tokens.size();
203
  min_p_applied = true;
204
  }
205
  }
206
 
207
- // if the candidates are sorted or the unsorted implementation failed, use this implementation
208
  if (!min_p_applied) {
209
  // Sort the logits in descending order
210
- if (!candidates->sorted) {
211
- std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
212
  return a.logit > b.logit;
213
  });
214
- candidates->sorted = true;
215
  }
216
 
217
- const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
218
  size_t i = 1; // first token always matches
219
 
220
- for (; i < candidates->size; ++i) {
221
- if (candidates->data[i].logit < min_logit && i >= min_keep) {
222
  break; // prob too small
223
  }
224
  }
225
 
226
  // Resize the output vector to keep only the matching tokens
227
- candidates->size = i;
228
  }
 
229
 
230
- if (smpl) {
231
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
232
- }
233
  }
234
 
235
- void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
236
- if (z >= 1.0f || candidates->size <= 2) {
237
  return;
238
  }
239
 
240
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
241
- const int64_t t_start_sample_us = ggml_time_us();
242
 
243
  // Compute the first and second derivatives
244
- std::vector<float> first_derivatives(candidates->size - 1);
245
- std::vector<float> second_derivatives(candidates->size - 2);
246
 
247
  for (size_t i = 0; i < first_derivatives.size(); ++i) {
248
- first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
249
  }
250
  for (size_t i = 0; i < second_derivatives.size(); ++i) {
251
  second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
@@ -272,51 +758,86 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_
272
  }
273
 
274
  float cum_sum = 0.0f;
275
- size_t last_idx = candidates->size;
276
  for (size_t i = 0; i < second_derivatives.size(); ++i) {
277
  cum_sum += second_derivatives[i];
278
 
279
  // Check if the running sum is greater than z or if we have kept at least min_keep tokens
280
- if (cum_sum > z && i >= min_keep) {
281
  last_idx = i;
282
  break;
283
  }
284
  }
285
 
286
  // Resize the output vector to keep only the tokens above the tail location
287
- candidates->size = last_idx;
 
288
 
289
- if (smpl) {
290
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
291
- }
292
  }
293
 
294
- void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
 
 
295
  // Reference implementation:
296
  // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
297
- if (p >= 1.0f) {
298
  return;
299
  }
300
 
301
  // Compute the softmax of logits and calculate entropy
302
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
303
-
304
- const int64_t t_start_sample_us = ggml_time_us();
305
 
306
  float entropy = 0.0f;
307
- for (size_t i = 0; i < candidates->size; ++i) {
308
- entropy += -candidates->data[i].p * logf(candidates->data[i].p);
309
  }
310
 
311
  // Compute the absolute difference between negative log probability and entropy for each candidate
312
  std::vector<float> shifted_scores;
313
- for (size_t i = 0; i < candidates->size; ++i) {
314
- float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
315
  shifted_scores.push_back(shifted_score);
316
  }
317
 
318
  // Sort tokens based on the shifted_scores and their corresponding indices
319
- std::vector<size_t> indices(candidates->size);
320
  std::iota(indices.begin(), indices.end(), 0);
321
 
322
  std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
@@ -329,134 +850,618 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
329
 
330
  for (size_t i = 0; i < indices.size(); ++i) {
331
  size_t idx = indices[i];
332
- cum_sum += candidates->data[idx].p;
333
 
334
  // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
335
- if (cum_sum > p && i >= min_keep - 1) {
336
  last_idx = i + 1;
337
  break;
338
  }
339
  }
340
 
341
  // Resize the output vector to keep only the locally typical tokens
342
- std::vector<llama_token_data> new_candidates;
343
  for (size_t i = 0; i < last_idx; ++i) {
344
  size_t idx = indices[i];
345
- new_candidates.push_back(candidates->data[idx]);
346
  }
347
 
348
- // Replace the data in candidates with the new_candidates data
349
- std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
350
- candidates->size = new_candidates.size();
351
- candidates->sorted = false;
 
352
 
353
- if (smpl) {
354
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
355
- }
356
  }
357
 
358
- void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
359
- const int64_t t_start_sample_us = ggml_time_us();
 
360
 
361
- // no need to do anything if there is only one (or zero) candidates
362
- if(candidates->size <= 1) {
363
- return;
364
  }
 
365
 
366
- // Calculate maximum possible entropy
367
- float max_entropy = -logf(1.0f / candidates->size);
 
 
368
 
369
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
 
 
370
 
371
- // Calculate entropy of the softmax probabilities
372
- float entropy = 0.0f;
373
- for (size_t i = 0; i < candidates->size; ++i) {
374
- float prob = candidates->data[i].p;
375
- if (prob > 0.0f) { // Ensure no log(0)
376
- entropy -= prob * logf(prob);
377
  }
378
  }
 
379
 
380
- // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above)
381
- float normalized_entropy = entropy / max_entropy;
 
 
382
 
383
- // Map the normalized entropy to the desired temperature range using the power function
384
- float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
 
385
 
386
- #ifdef DEBUG
387
- LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
388
- LLAMA_LOG_INFO("Entropy: %f\n", entropy);
389
- LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
390
- LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
391
- LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
392
- LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
393
- #endif
394
 
395
- // Apply the dynamically calculated temperature scaling
396
- for (size_t i = 0; i < candidates->size; ++i) {
397
- candidates->data[i].logit /= dyn_temp;
398
  }
399
 
400
- // Re-compute softmax probabilities after scaling logits with dynamic temperature
401
- double max_l_double = candidates->data[0].logit;
402
- double cum_sum_double = 0.0;
403
- for (size_t i = 0; i < candidates->size; ++i) {
404
- double p = exp(candidates->data[i].logit - max_l_double);
405
- candidates->data[i].p = p; // Store the scaled probability
406
- cum_sum_double += p;
407
  }
408
- for (size_t i = 0; i < candidates->size; ++i) {
409
- candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
410
  }
411
 
412
- #ifdef DEBUG
413
- // Print the updated top 25 probabilities after temperature scaling
414
- LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
415
- for (size_t i = 0; i < 25 && i < candidates->size; ++i) {
416
- LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
417
  }
418
- #endif
419
 
420
- if (smpl) {
421
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
422
  }
423
  }
424
 
425
- void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
426
- const int64_t t_start_sample_us = ggml_time_us();
 
 
 
 
 
 
427
 
428
- for (size_t i = 0; i < candidates->size; ++i) {
429
- candidates->data[i].logit /= temp;
430
  }
431
 
432
- if (smpl) {
433
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
434
  }
 
 
435
  }
436
 
437
- void llama_sample_repetition_penalties_impl(
438
- struct llama_sampling * smpl,
439
- llama_token_data_array * candidates,
440
- const llama_token * last_tokens,
441
- size_t penalty_last_n,
442
- float penalty_repeat,
443
- float penalty_freq,
444
- float penalty_present) {
445
- if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
446
  return;
447
  }
448
 
449
- const int64_t t_start_sample_us = ggml_time_us();
450
 
451
  // Create a frequency map to count occurrences of each token in last_tokens
452
- std::unordered_map<llama_token, int> token_count;
453
- for (size_t i = 0; i < penalty_last_n; ++i) {
454
- token_count[last_tokens[i]]++;
 
 
 
455
  }
456
 
457
- // Apply frequency and presence penalties to the candidates
458
- for (size_t i = 0; i < candidates->size; ++i) {
459
- const auto token_iter = token_count.find(candidates->data[i].id);
460
  if (token_iter == token_count.end()) {
461
  continue;
462
  }
@@ -465,171 +1470,238 @@ void llama_sample_repetition_penalties_impl(
465
 
466
  // The academic publication that described this technique only divided by the penalty, but that would make tokens with negative logits more likely, which is obviously wrong.
467
  // The common fix for this problem is to multiply by the penalty instead of dividing.
468
- if (candidates->data[i].logit <= 0) {
469
- candidates->data[i].logit *= penalty_repeat;
470
  } else {
471
- candidates->data[i].logit /= penalty_repeat;
472
  }
473
 
474
- candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
475
  }
476
 
477
- candidates->sorted = false;
478
 
479
- if (smpl) {
480
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
 
481
  }
482
  }
483
 
484
- void llama_sample_apply_guidance_impl(
485
- struct llama_sampling * smpl,
486
- float * logits,
487
- float * logits_guidance,
488
- float scale) {
489
- GGML_ASSERT(smpl);
490
-
491
- const auto t_start_sample_us = ggml_time_us();
492
- const auto n_vocab = smpl->n_vocab;
493
-
494
- llama_log_softmax(logits, n_vocab);
495
- llama_log_softmax(logits_guidance, n_vocab);
496
 
497
- for (int i = 0; i < n_vocab; ++i) {
498
- auto & l = logits[i];
499
- const auto & g = logits_guidance[i];
500
 
501
- l = scale * (l - g) + g;
502
  }
503
 
504
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
505
  }
506
 
507
- llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
508
- GGML_ASSERT(smpl);
509
-
510
- const int32_t n_vocab = float(smpl->n_vocab);
511
-
512
- int64_t t_start_sample_us = ggml_time_us();
513
 
514
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
515
 
516
- // Estimate s_hat using the most probable m tokens
517
- float s_hat = 0.0;
518
- float sum_ti_bi = 0.0;
519
- float sum_ti_sq = 0.0;
520
- for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
521
- float t_i = logf(float(i + 2) / float(i + 1));
522
- float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
523
- sum_ti_bi += t_i * b_i;
524
- sum_ti_sq += t_i * t_i;
525
  }
526
- s_hat = sum_ti_bi / sum_ti_sq;
527
 
528
- // Compute k from the estimated s_hat and target surprise value
529
- float epsilon_hat = s_hat - 1;
530
- float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
531
 
532
- // Sample the next word X using top-k sampling
533
- llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
534
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
535
- llama_token X = llama_sample_token_impl(smpl, candidates);
536
- t_start_sample_us = ggml_time_us();
537
 
538
- // Compute error as the difference between observed surprise and target surprise value
539
- size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
540
- return candidate.id == X;
541
- }));
542
- float observed_surprise = -log2f(candidates->data[X_idx].p);
543
- float e = observed_surprise - tau;
544
 
545
- // Update mu using the learning rate and error
546
- *mu = *mu - eta * e;
547
 
548
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
549
- return X;
 
 
 
550
  }
551
 
552
- llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
553
- int64_t t_start_sample_us;
554
- t_start_sample_us = ggml_time_us();
 
 
 
555
 
556
- llama_sample_softmax_impl(smpl, candidates);
557
 
558
- // Truncate the words with surprise values greater than mu
559
- candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
560
- return -log2f(candidate.p) > *mu;
561
- }));
 
 
 
 
562
 
563
- if (candidates->size == 0) {
564
- candidates->size = 1;
565
  }
566
 
567
- if (smpl) {
568
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
569
  }
 
570
 
571
- // Normalize the probabilities of the remaining words
572
- llama_sample_softmax_impl(smpl, candidates);
 
 
573
 
574
- // Sample the next word X from the remaining words
575
- llama_token X = llama_sample_token_impl(smpl, candidates);
576
- t_start_sample_us = ggml_time_us();
577
 
578
- // Compute error as the difference between observed surprise and target surprise value
579
- size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
580
- return candidate.id == X;
581
- }));
582
- float observed_surprise = -log2f(candidates->data[X_idx].p);
583
- float e = observed_surprise - tau;
584
 
585
- // Update mu using the learning rate and error
586
- *mu = *mu - eta * e;
587
 
588
- if (smpl) {
589
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
 
590
  }
591
- return X;
592
- }
593
 
594
- llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
595
- const int64_t t_start_sample_us = ggml_time_us();
 
596
 
597
- // Find max element
598
- auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
599
- return a.logit < b.logit;
600
- });
601
 
602
- llama_token result = max_iter->id;
603
- if (smpl) {
604
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
605
- smpl->n_sample++;
 
 
 
 
606
  }
607
- return result;
 
608
  }
609
 
610
- llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
611
- GGML_ASSERT(smpl);
612
 
613
- const int64_t t_start_sample_us = ggml_time_us();
614
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
615
 
616
- std::vector<float> probs;
617
- probs.reserve(candidates->size);
618
- for (size_t i = 0; i < candidates->size; ++i) {
619
- probs.push_back(candidates->data[i].p);
620
  }
621
 
622
- std::discrete_distribution<> dist(probs.begin(), probs.end());
623
- int idx = dist(rng);
624
 
625
- llama_token result = candidates->data[idx].id;
 
626
 
627
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
628
- smpl->n_sample++;
629
 
630
- return result;
 
 
 
 
631
  }
632
 
633
- llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
634
- return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
635
  }
 
1
  #include "llama-sampling.h"
2
 
3
+ #include "llama-vocab.h"
4
+ #include "llama-grammar.h"
5
+
6
  #include <algorithm>
7
+ #include <cassert>
8
+ #include <cfloat>
9
+ #include <chrono>
10
+ #include <cmath>
11
+ #include <cstdlib>
12
  #include <cstring>
13
  #include <ctime>
 
14
  #include <numeric>
15
+ #include <random>
16
  #include <unordered_map>
17
 
18
+ static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
19
+ // iterator for the probabilities
20
+ #ifdef __GNUC__
21
+ #pragma GCC diagnostic push
22
+ #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
23
+ #endif
24
+
25
+ struct probs_iterator {
26
+ typedef std::input_iterator_tag iterator_category;
27
+ typedef float value_type;
28
+ typedef float * pointer;
29
+ typedef float & reference;
30
+ typedef ptrdiff_t difference_type;
31
+
32
+ const llama_token_data * data;
33
+
34
+ bool operator==(const probs_iterator & other) const { return data == other.data; }
35
+ bool operator!=(const probs_iterator & other) const { return data != other.data; }
36
+ const float & operator*() const { return data->p; }
37
+ probs_iterator & operator++() { ++data; return *this; }
38
+ probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
39
+ };
40
+
41
+ #ifdef __GNUC__
42
+ #pragma GCC diagnostic pop
43
+ #endif
44
+
45
+ std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});
46
+
47
+ return dist(rng);
48
+ }
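The helper above draws an index with probability proportional to each candidate's p; the custom iterator only avoids copying the probabilities out of llama_token_data. A standalone sketch of the same idea with plain floats:

#include <cstdio>
#include <random>
#include <vector>

int main() {
    std::vector<float> probs = { 0.7f, 0.2f, 0.1f };   // hypothetical, already-normalized token probabilities
    std::mt19937 rng(1234);
    std::discrete_distribution<int> dist(probs.begin(), probs.end());
    std::printf("sampled index: %d\n", dist(rng));     // 0 roughly 70% of the time
    return 0;
}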
49
+
50
+ /*
51
  static void llama_log_softmax(float * array, size_t size) {
52
  float max_l = *std::max_element(array, array + size);
53
  float sum = 0.f;
 
61
  array[i] = logf(array[i] / sum);
62
  }
63
  }
64
+ */
65
 
66
+ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
67
+ GGML_ASSERT(cur_p->size > 0);
68
 
69
  // Sort the logits in descending order
70
+ if (!cur_p->sorted) {
71
+ std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
72
  return a.logit > b.logit;
73
  });
74
+ cur_p->sorted = true;
75
  }
76
 
77
+ float max_l = cur_p->data[0].logit;
78
  float cum_sum = 0.0f;
79
+
80
+ for (size_t i = 0; i < cur_p->size; ++i) {
81
+ float p = expf(cur_p->data[i].logit - max_l);
82
+ cur_p->data[i].p = p;
83
  cum_sum += p;
84
  }
 
 
 
85
 
86
+ for (size_t i = 0; i < cur_p->size; ++i) {
87
+ cur_p->data[i].p /= cum_sum;
88
  }
89
  }
90
 
91
+ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
92
  // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
93
+ // if (k >= (int32_t)cur_p->size) {
94
  // return;
95
  // }
96
 
 
 
97
  if (k <= 0) {
98
+ k = cur_p->size;
99
  }
100
 
101
+ k = std::min(k, (int) cur_p->size);
 
102
 
103
  // Sort scores in descending order
104
+ if (!cur_p->sorted) {
105
  auto comp = [](const llama_token_data & a, const llama_token_data & b) {
106
  return a.logit > b.logit;
107
  };
108
  if (k <= 128) {
109
+ std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
110
  } else {
111
  constexpr int nbuckets = 128;
112
  constexpr float bucket_low = -10.0f;
 
114
  constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
115
  constexpr float bucket_inter = -bucket_low * bucket_scale;
116
 
117
+ std::vector<int> bucket_idx(cur_p->size);
118
  std::vector<int> histo(nbuckets, 0);
119
 
120
+ for (int i = 0; i < (int)cur_p->size; ++i) {
121
+ const float val = cur_p->data[i].logit;
122
  int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
123
  ib = std::max(0, std::min(nbuckets-1, ib));
124
  bucket_idx[i] = ib;
 
128
  int ib = nbuckets - 1;
129
  for ( ; ib >= 0; --ib) {
130
  nhave += histo[ib];
131
+ if (nhave >= k) {
132
+ break;
133
+ }
134
  }
135
  std::vector<llama_token_data> tmp_tokens(nhave);
136
+ auto * ptr = tmp_tokens.data();
137
  std::vector<llama_token_data*> bucket_ptrs;
138
  bucket_ptrs.reserve(nbuckets - ib);
139
  for (int j = nbuckets - 1; j >= ib; --j) {
140
  bucket_ptrs.push_back(ptr);
141
  ptr += histo[j];
142
  }
143
+ for (int i = 0; i < (int)cur_p->size; ++i) {
144
  int j = bucket_idx[i];
145
  if (j >= ib) {
146
+ *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
147
  }
148
  }
149
 
 
156
  }
157
  std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
158
 
159
+ std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
160
 
161
  }
162
+ cur_p->sorted = true;
163
  }
164
+ cur_p->size = k;
165
+ }
166
 
167
+ static uint32_t get_rng_seed(uint32_t seed) {
168
+ if (seed == LLAMA_DEFAULT_SEED) {
169
+ // use system clock if std::random_device is not a true RNG
170
+ static bool is_rd_prng = std::random_device().entropy() == 0;
171
+ if (is_rd_prng) {
172
+ return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
173
+ }
174
+ std::random_device rd;
175
+ return rd();
176
  }
177
+ return seed;
178
  }
179
 
180
+ // llama_sampler API
181
+
182
+ const char * llama_sampler_name(const struct llama_sampler * smpl) {
183
+ if (!smpl->iface) {
184
+ return "(null)";
185
+ }
186
+
187
+ return smpl->iface->name(smpl);
188
+ }
189
+
190
+ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
191
+ if (smpl->iface->accept) {
192
+ smpl->iface->accept(smpl, token);
193
+ }
194
+ }
195
+
196
+ void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
197
+ GGML_ASSERT(smpl->iface->apply);
198
+ smpl->iface->apply(smpl, cur_p);
199
+ }
200
+
201
+ void llama_sampler_reset(struct llama_sampler * smpl) {
202
+ if (smpl->iface->reset) {
203
+ smpl->iface->reset(smpl);
204
+ }
205
+ }
206
+
207
+ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
208
+ if (smpl->iface->clone) {
209
+ return smpl->iface->clone(smpl);
210
+ }
211
+
212
+ if (smpl->ctx == nullptr) {
213
+ return new llama_sampler {
214
+ /* .iface = */ smpl->iface,
215
+ /* .ctx = */ nullptr,
216
+ };
217
+ }
218
+
219
+ GGML_ABORT("the sampler does not support cloning");
220
+ }
221
+
222
+ void llama_sampler_free(struct llama_sampler * smpl) {
223
+ if (smpl == nullptr) {
224
  return;
225
  }
226
 
227
+ if (smpl->iface->free) {
228
+ smpl->iface->free(smpl);
229
+ }
230
+
231
+ delete smpl;
232
+ }
233
+
234
+ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
235
+ const auto * logits = llama_get_logits_ith(ctx, idx);
236
+
237
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx));
238
+
239
+ // TODO: do not allocate each time
240
+ std::vector<llama_token_data> cur;
241
+ cur.reserve(n_vocab);
242
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
243
+ cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
244
+ }
245
+
246
+ llama_token_data_array cur_p = {
247
+ /* .data = */ cur.data(),
248
+ /* .size = */ cur.size(),
249
+ /* .selected = */ -1,
250
+ /* .sorted = */ false,
251
+ };
252
+
253
+ llama_sampler_apply(smpl, &cur_p);
254
+
255
+ GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
256
+
257
+ auto token = cur_p.data[cur_p.selected].id;
258
+
259
+ llama_sampler_accept(smpl, token);
260
+
261
+ return token;
262
+ }
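A hedged usage sketch of llama_sampler_sample(), assuming ctx is a llama_context whose most recent llama_decode() produced logits (model and context setup are outside this diff):

struct llama_sampler * smpl = llama_sampler_init_greedy();

const llama_token tok = llama_sampler_sample(smpl, ctx, -1);   // -1: use the logits of the last token
// the selected token has already been passed to llama_sampler_accept() internally

llama_sampler_free(smpl);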
263
+
264
+ // sampler chain
265
+
266
+ static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
267
+ return "chain";
268
+ }
269
+
270
+ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
271
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
272
+
273
+ time_meas tm(chain->t_sample_us, chain->params.no_perf);
274
+
275
+ for (auto * smpl : chain->samplers) {
276
+ llama_sampler_accept(smpl, token);
277
+ }
278
+
279
+ chain->n_sample++;
280
+ }
281
+
282
+ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
283
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
284
+
285
+ time_meas tm(chain->t_sample_us, chain->params.no_perf);
286
+
287
+ for (auto * smpl : chain->samplers) {
288
+ llama_sampler_apply(smpl, cur_p);
289
+ }
290
+ }
291
+
292
+ static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
293
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
294
+
295
+ for (auto * smpl : chain->samplers) {
296
+ llama_sampler_reset(smpl);
297
+ }
298
+
299
+ chain->t_sample_us = 0;
300
+ chain->n_sample = 0;
301
+ }
302
+
303
+ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
304
+ const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
305
+
306
+ auto * result = llama_sampler_chain_init(chain_src->params);
307
+
308
+ for (auto * smpl : chain_src->samplers) {
309
+ llama_sampler_chain_add(result, llama_sampler_clone(smpl));
310
+ }
311
+
312
+ return result;
313
+ }
314
+
315
+ static void llama_sampler_chain_free(struct llama_sampler * smpl) {
316
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
317
+
318
+ for (auto * smpl : chain->samplers) {
319
+ llama_sampler_free(smpl);
320
+ }
321
+
322
+ delete chain;
323
+ }
324
+
325
+ static struct llama_sampler_i llama_sampler_chain_i = {
326
+ /* .name = */ llama_sampler_chain_name,
327
+ /* .accept = */ llama_sampler_chain_accept,
328
+ /* .apply = */ llama_sampler_chain_apply,
329
+ /* .reset = */ llama_sampler_chain_reset,
330
+ /* .clone = */ llama_sampler_chain_clone,
331
+ /* .free = */ llama_sampler_chain_free,
332
+ };
333
+
334
+ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
335
+ return new llama_sampler {
336
+ /* .iface = */ &llama_sampler_chain_i,
337
+ /* .ctx = */ new llama_sampler_chain {
338
+ /* .params = */ params,
339
+ /* .samplers = */ {},
340
+ /* .t_sample_us = */ 0,
341
+ /* .n_sample = */ 0,
342
+ },
343
+ };
344
+ }
345
+
346
+ void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
347
+ auto * p = (llama_sampler_chain *) chain->ctx;
348
+ p->samplers.push_back(smpl);
349
+ }
350
+
351
+ struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
352
+ const auto * p = (const llama_sampler_chain *) chain->ctx;
353
+
354
+ if (i < 0 || (size_t) i >= p->samplers.size()) {
355
+ return nullptr;
356
+ }
357
+
358
+ return p->samplers[i];
359
+ }
360
+
361
+ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
362
+ auto * p = (llama_sampler_chain *) chain->ctx;
363
+
364
+ if (i < 0 || (size_t) i >= p->samplers.size()) {
365
+ return nullptr;
366
+ }
367
+
368
+ auto * result = p->samplers[i];
369
+ p->samplers.erase(p->samplers.begin() + i);
370
+
371
+ return result;
372
+ }
373
+
374
+ int llama_sampler_chain_n(const struct llama_sampler * chain) {
375
+ const auto * p = (const llama_sampler_chain *) chain->ctx;
376
+
377
+ return p->samplers.size();
378
+ }
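A sketch of composing a chain with the API above; llama_sampler_chain_default_params() is assumed from llama.h and the parameter values are illustrative:

struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

// llama_sampler_apply()/llama_sampler_sample() now run the members in insertion order;
// llama_sampler_chain_n(chain) == 3 and llama_sampler_chain_get(chain, 0) is the top-k member

llama_sampler_free(chain);   // llama_sampler_chain_free() also frees the added samplers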
379
+
380
+ //
381
+ // samplers
382
+ //
383
+
384
+ // greedy
385
+
386
+ static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
387
+ return "greedy";
388
+ }
389
+
390
+ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
391
+ cur_p->selected = 0;
392
+ for (size_t i = 1; i < cur_p->size; ++i) {
393
+ if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
394
+ cur_p->selected = i;
395
+ }
396
+ }
397
+ }
398
+
399
+ static struct llama_sampler_i llama_sampler_greedy_i = {
400
+ /* .name = */ llama_sampler_greedy_name,
401
+ /* .accept = */ nullptr,
402
+ /* .apply = */ llama_sampler_greedy_apply,
403
+ /* .reset = */ nullptr,
404
+ /* .clone = */ nullptr,
405
+ /* .free = */ nullptr,
406
+ };
407
+
408
+ struct llama_sampler * llama_sampler_init_greedy() {
409
+ return new llama_sampler {
410
+ /* .iface = */ &llama_sampler_greedy_i,
411
+ /* .ctx = */ nullptr,
412
+ };
413
+ }
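Greedy selection is a plain argmax over the logits, so it can be exercised without a model; a minimal sketch with a hand-built candidate array (llama.h is assumed to be included, and the token ids and logits are made up):

llama_token_data cur[3] = {
    { /* id */ 10, /* logit */ 0.1f, /* p */ 0.0f },
    { /* id */ 11, /* logit */ 2.5f, /* p */ 0.0f },
    { /* id */ 12, /* logit */ 1.0f, /* p */ 0.0f },
};
llama_token_data_array cur_p = { cur, 3, /* selected */ -1, /* sorted */ false };

struct llama_sampler * greedy = llama_sampler_init_greedy();
llama_sampler_apply(greedy, &cur_p);
// cur_p.selected == 1, the index of the highest logit (token id 11)
llama_sampler_free(greedy);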
414
+
415
+ // dist
416
+
417
+ struct llama_sampler_dist {
418
+ const uint32_t seed;
419
+ uint32_t seed_cur;
420
+
421
+ std::mt19937 rng;
422
+ };
423
+
424
+ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
425
+ return "dist";
426
+ }
427
+
428
+ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
429
+ auto * ctx = (llama_sampler_dist *) smpl->ctx;
430
+ cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
431
+ }
432
+
433
+ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
434
+ const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
435
+ auto * result = llama_sampler_init_dist(ctx->seed);
436
+
437
+ // copy the state
438
+ {
439
+ auto * result_ctx = (llama_sampler_dist *) result->ctx;
440
+
441
+ result_ctx->rng = ctx->rng;
442
+ }
443
+
444
+ return result;
445
+ }
446
+
447
+ static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
448
+ auto * ctx = (llama_sampler_dist *) smpl->ctx;
449
+ ctx->seed_cur = get_rng_seed(ctx->seed);
450
+ ctx->rng.seed(ctx->seed_cur);
451
+ }
452
+
453
+ static void llama_sampler_dist_free(struct llama_sampler * smpl) {
454
+ delete (llama_sampler_dist *) smpl->ctx;
455
+ }
456
+
457
+ static struct llama_sampler_i llama_sampler_dist_i = {
458
+ /* .name = */ llama_sampler_dist_name,
459
+ /* .accept = */ nullptr,
460
+ /* .apply = */ llama_sampler_dist_apply,
461
+ /* .reset = */ llama_sampler_dist_reset,
462
+ /* .clone = */ llama_sampler_dist_clone,
463
+ /* .free = */ llama_sampler_dist_free,
464
+ };
465
+
466
+ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
467
+ auto seed_cur = get_rng_seed(seed);
468
+ return new llama_sampler {
469
+ /* .iface = */ &llama_sampler_dist_i,
470
+ /* .ctx = */ new llama_sampler_dist {
471
+ /* .seed = */ seed,
472
+ /* .seed_cur = */ seed_cur,
473
+ /* .rng = */ std::mt19937(seed_cur),
474
+ },
475
+ };
476
+ }
477
+
478
+ // softmax
479
+
480
+ static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
481
+ return "softmax";
482
+ }
483
+
484
+ static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
485
+ llama_sampler_softmax_impl(cur_p);
486
+ }
487
+
488
+ static struct llama_sampler_i llama_sampler_softmax_i = {
489
+ /* .name = */ llama_sampler_softmax_name,
490
+ /* .accept = */ nullptr,
491
+ /* .apply = */ llama_sampler_softmax_apply,
492
+ /* .reset = */ nullptr,
493
+ /* .clone = */ nullptr,
494
+ /* .free = */ nullptr,
495
+ };
496
+
497
+ struct llama_sampler * llama_sampler_init_softmax() {
498
+ return new llama_sampler {
499
+ /* .iface = */ &llama_sampler_softmax_i,
500
+ /* .ctx = */ nullptr,
501
+ };
502
+ }
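For reference, llama_sampler_softmax_impl() computes the numerically stable softmax p_i = exp(logit_i - max_logit) / sum_j exp(logit_j - max_logit) and leaves the candidates sorted by logit; a small sketch using the same hand-built cur_p as in the greedy example:

struct llama_sampler * sm = llama_sampler_init_softmax();
// after llama_sampler_apply(sm, &cur_p) on logits {2, 1, 0}, the probabilities are
// roughly {0.665, 0.245, 0.090} and cur_p.sorted == true
llama_sampler_free(sm);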
503
+
504
+ // top-k
505
+
506
+ struct llama_sampler_top_k {
507
+ const int32_t k;
508
+ };
509
+
510
+ static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
511
+ return "top-k";
512
+ }
513
+
514
+ static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
515
+ const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
516
+ llama_sampler_top_k_impl(cur_p, ctx->k);
517
+ }
518
+
519
+ static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
520
+ const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
521
+ return llama_sampler_init_top_k(ctx->k);
522
+ }
523
+
524
+ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
525
+ delete (llama_sampler_top_k *) smpl->ctx;
526
+ }
527
+
528
+ static struct llama_sampler_i llama_sampler_top_k_i = {
529
+ /* .name = */ llama_sampler_top_k_name,
530
+ /* .accept = */ nullptr,
531
+ /* .apply = */ llama_sampler_top_k_apply,
532
+ /* .reset = */ nullptr,
533
+ /* .clone = */ llama_sampler_top_k_clone,
534
+ /* .free = */ llama_sampler_top_k_free,
535
+ };
536
+
537
+ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
538
+ return new llama_sampler {
539
+ /* .iface = */ &llama_sampler_top_k_i,
540
+ /* .ctx = */ new llama_sampler_top_k {
541
+ /* .k = */ k,
542
+ },
543
+ };
544
+ }
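A minimal sketch of the top-k behaviour with a hand-built candidate array (llama.h assumed); k <= 0 keeps everything, and the bucket-sort path above is only taken for k > 128:

llama_token_data cur[4] = {
    { 1, 4.0f, 0.0f }, { 2, 1.0f, 0.0f }, { 3, 3.0f, 0.0f }, { 4, 2.0f, 0.0f },
};
llama_token_data_array cur_p = { cur, 4, -1, false };

struct llama_sampler * top_k = llama_sampler_init_top_k(2);
llama_sampler_apply(top_k, &cur_p);
// cur_p.size == 2 and the surviving candidates are the logits 4.0 and 3.0, sorted descending
llama_sampler_free(top_k);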
545
+
546
+ // top-p
547
+
548
+ struct llama_sampler_top_p {
549
+ const float p;
550
+ const size_t min_keep;
551
+ };
552
+
553
+ static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
554
+ return "top-p";
555
+ }
556
 
557
+ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
558
+ const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
559
+
560
+ if (ctx->p >= 1.0f) {
561
+ return;
562
+ }
563
+
564
+ llama_sampler_softmax_impl(cur_p);
565
 
566
  // Compute the cumulative probabilities
567
  float cum_sum = 0.0f;
568
+ size_t last_idx = cur_p->size;
569
 
570
+ for (size_t i = 0; i < cur_p->size; ++i) {
571
+ cum_sum += cur_p->data[i].p;
572
 
573
  // Check if the running sum is at least p or if we have kept at least min_keep tokens
574
  // we set the last index to i+1 to indicate that the current iterate should be included in the set
575
+ if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
576
  last_idx = i + 1;
577
  break;
578
  }
579
  }
580
 
581
  // Resize the output vector to keep only the top-p tokens
582
+ cur_p->size = last_idx;
583
+ }
584
 
585
+ static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
586
+ const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
587
+ return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
588
+ }
589
+
590
+ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
591
+ delete (llama_sampler_top_p *) smpl->ctx;
592
+ }
593
+
594
+ static struct llama_sampler_i llama_sampler_top_p_i = {
595
+ /* .name = */ llama_sampler_top_p_name,
596
+ /* .accept = */ nullptr,
597
+ /* .apply = */ llama_sampler_top_p_apply,
598
+ /* .reset = */ nullptr,
599
+ /* .clone = */ llama_sampler_top_p_clone,
600
+ /* .free = */ llama_sampler_top_p_free,
601
+ };
602
+
603
+ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
604
+ return new llama_sampler {
605
+ /* .iface = */ &llama_sampler_top_p_i,
606
+ /* .ctx = */ new llama_sampler_top_p {
607
+ /* .p = */ p,
608
+ /* .min_keep = */ min_keep,
609
+ },
610
+ };
611
+ }
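Worked example of the cut-off above: with already-normalized probabilities {0.5, 0.3, 0.15, 0.05} and p = 0.9, the cumulative sums are 0.5, 0.8, 0.95, 1.0, so cum_sum >= p is first reached at the third candidate and three tokens are kept (min_keep permitting):

struct llama_sampler * top_p = llama_sampler_init_top_p(0.9f, /* min_keep */ 1);
// keeps the smallest prefix of candidates whose probabilities sum to at least 0.9
llama_sampler_free(top_p);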
612
+
613
+ // min-p
614
+
615
+ struct llama_sampler_min_p {
616
+ const float p;
617
+ const size_t min_keep;
618
+ };
619
+
620
+ static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
621
+ return "min-p";
622
  }
623
 
624
+ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
625
+ const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
626
+
627
+ if (ctx->p <= 0.0f || !cur_p->size) {
628
  return;
629
  }
630
 
 
 
631
  bool min_p_applied = false;
632
 
633
+ // if the candidates in cur_p aren't sorted, try the unsorted implementation first
634
+ if (!cur_p->sorted) {
635
  std::vector<llama_token_data> filtered_tokens;
636
 
637
  float max_logit = -FLT_MAX;
638
+ for (size_t i = 0; i < cur_p->size; ++i) {
639
+ max_logit = std::max(max_logit, cur_p->data[i].logit);
640
  }
641
+ const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max
642
 
643
+ for (size_t i = 0; i < cur_p->size; ++i) {
644
+ if (cur_p->data[i].logit >= min_logit) {
645
+ filtered_tokens.push_back(cur_p->data[i]);
646
  }
647
  }
648
 
649
  // if we have enough values the operation was a success
650
+ if (filtered_tokens.size() >= ctx->min_keep) {
651
+ memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
652
+ cur_p->size = filtered_tokens.size();
653
  min_p_applied = true;
654
  }
655
  }
656
 
657
+ // if cur_p is already sorted or the unsorted implementation failed, use this implementation
658
  if (!min_p_applied) {
659
  // Sort the logits in descending order
660
+ if (!cur_p->sorted) {
661
+ std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
662
  return a.logit > b.logit;
663
  });
664
+ cur_p->sorted = true;
665
  }
666
 
667
+ const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
668
  size_t i = 1; // first token always matches
669
 
670
+ for (; i < cur_p->size; ++i) {
671
+ if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
672
  break; // prob too small
673
  }
674
  }
675
 
676
  // Resize the output vector to keep only the matching tokens
677
+ cur_p->size = i;
678
  }
679
+ }
680
 
681
+ static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) {
682
+ const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
683
+ return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
684
  }
685
 
686
+ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
687
+ delete (llama_sampler_min_p *) smpl->ctx;
688
+ }
689
+
690
+ static struct llama_sampler_i llama_sampler_min_p_i = {
691
+ /* .name = */ llama_sampler_min_p_name,
692
+ /* .accept = */ nullptr,
693
+ /* .apply = */ llama_sampler_min_p_apply,
694
+ /* .reset = */ nullptr,
695
+ /* .clone = */ llama_sampler_min_p_clone,
696
+ /* .free = */ llama_sampler_min_p_free,
697
+ };
698
+
699
+ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
700
+ return new llama_sampler {
701
+ /* .iface = */ &llama_sampler_min_p_i,
702
+ /* .ctx = */ new llama_sampler_min_p {
703
+ /* .p = */ p,
704
+ /* .min_keep = */ min_keep,
705
+ },
706
+ };
707
+ }
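The logit threshold above works because softmax probabilities satisfy p_i proportional to exp(logit_i), so p_i >= p * p_max is equivalent to logit_i >= max_logit + log(p); this is why min-p can be applied without computing the softmax first. For example:

struct llama_sampler * min_p = llama_sampler_init_min_p(0.05f, /* min_keep */ 1);
// drops every candidate whose probability is below 5% of the most probable token's,
// i.e. every candidate with logit_i < max_logit + logf(0.05f)
llama_sampler_free(min_p);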
708
+
709
+ // tail-free
710
+
711
+ struct llama_sampler_tail_free {
712
+ const float z;
713
+ const size_t min_keep;
714
+ };
715
+
716
+ static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) {
717
+ return "tail-free";
718
+ }
719
+
720
+ static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
721
+ const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
722
+
723
+ if (ctx->z >= 1.0f || cur_p->size <= 2) {
724
  return;
725
  }
726
 
727
+ llama_sampler_softmax_impl(cur_p);
 
728
 
729
  // Compute the first and second derivatives
730
+ std::vector<float> first_derivatives(cur_p->size - 1);
731
+ std::vector<float> second_derivatives(cur_p->size - 2);
732
 
733
  for (size_t i = 0; i < first_derivatives.size(); ++i) {
734
+ first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
735
  }
736
  for (size_t i = 0; i < second_derivatives.size(); ++i) {
737
  second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
 
758
  }
759
 
760
  float cum_sum = 0.0f;
761
+ size_t last_idx = cur_p->size;
762
  for (size_t i = 0; i < second_derivatives.size(); ++i) {
763
  cum_sum += second_derivatives[i];
764
 
765
  // Check if the running sum is greater than z or if we have kept at least min_keep tokens
766
+ if (cum_sum > ctx->z && i >= ctx->min_keep) {
767
  last_idx = i;
768
  break;
769
  }
770
  }
771
 
772
  // Resize the output vector to keep only the tokens above the tail location
773
+ cur_p->size = last_idx;
774
+ }
775
 
776
+ static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
777
+ const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
778
+ return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
779
+ }
780
+
781
+ static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
782
+ delete (llama_sampler_tail_free *) smpl->ctx;
783
+ }
784
+
785
+ static struct llama_sampler_i llama_sampler_tail_free_i = {
786
+ /* .name = */ llama_sampler_tail_free_name,
787
+ /* .accept = */ nullptr,
788
+ /* .apply = */ llama_sampler_tail_free_apply,
789
+ /* .reset = */ nullptr,
790
+ /* .clone = */ llama_sampler_tail_free_clone,
791
+ /* .free = */ llama_sampler_tail_free_free,
792
+ };
793
+
794
+ struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
795
+ return new llama_sampler {
796
+ /* .iface = */ &llama_sampler_tail_free_i,
797
+ /* .ctx = */ new llama_sampler_tail_free {
798
+ /* .z = */ z,
799
+ /*. min_keep = */ min_keep,
800
+ },
801
+ };
802
+ }
803
+
804
+ // typical
805
+
806
+ struct llama_sampler_typical {
807
+ const float p;
808
+ const size_t min_keep;
809
+ };
810
+
811
+ static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
812
+ return "typical";
813
  }
814
 
815
+ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
816
+ const auto * ctx = (llama_sampler_typical *) smpl->ctx;
817
+
818
  // Reference implementation:
819
  // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
820
+ if (ctx->p >= 1.0f) {
821
  return;
822
  }
823
 
824
  // Compute the softmax of logits and calculate entropy
825
+ llama_sampler_softmax_impl(cur_p);
 
 
826
 
827
  float entropy = 0.0f;
828
+ for (size_t i = 0; i < cur_p->size; ++i) {
829
+ entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
830
  }
831
 
832
  // Compute the absolute difference between negative log probability and entropy for each candidate
833
  std::vector<float> shifted_scores;
834
+ for (size_t i = 0; i < cur_p->size; ++i) {
835
+ float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
836
  shifted_scores.push_back(shifted_score);
837
  }
838
 
839
  // Sort tokens based on the shifted_scores and their corresponding indices
840
+ std::vector<size_t> indices(cur_p->size);
841
  std::iota(indices.begin(), indices.end(), 0);
842
 
843
  std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
 
850
 
851
  for (size_t i = 0; i < indices.size(); ++i) {
852
  size_t idx = indices[i];
853
+ cum_sum += cur_p->data[idx].p;
854
 
855
  // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
856
+ if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
857
  last_idx = i + 1;
858
  break;
859
  }
860
  }
861
 
862
  // Resize the output vector to keep only the locally typical tokens
863
+ std::vector<llama_token_data> cur_p_new;
864
  for (size_t i = 0; i < last_idx; ++i) {
865
  size_t idx = indices[i];
866
+ cur_p_new.push_back(cur_p->data[idx]);
867
  }
868
 
869
+ // Replace the data in cur_p with the cur_p_new data
870
+ std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
871
+ cur_p->size = cur_p_new.size();
872
+ cur_p->sorted = false;
873
+ }
874
 
875
+ static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
876
+ const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
877
+ return llama_sampler_init_typical(ctx->p, ctx->min_keep);
878
  }
879
 
880
+ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
881
+ delete (llama_sampler_typical *) smpl->ctx;
882
+ }
883
 
884
+ static struct llama_sampler_i llama_sampler_typical_i = {
885
+ /* .name = */ llama_sampler_typical_name,
886
+ /* .accept = */ nullptr,
887
+ /* .apply = */ llama_sampler_typical_apply,
888
+ /* .reset = */ nullptr,
889
+ /* .clone = */ llama_sampler_typical_clone,
890
+ /* .free = */ llama_sampler_typical_free,
891
+ };
892
+
893
+ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
894
+ return new llama_sampler {
895
+ /* .iface = */ &llama_sampler_typical_i,
896
+ /* .ctx = */ new llama_sampler_typical {
897
+ /* .p = */ p,
898
+ /* .min_keep = */ min_keep,
899
+ },
900
+ };
901
+ }
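Locally typical sampling scores each candidate by how far its surprisal is from the distribution's entropy H = -sum_i p_i * log(p_i), i.e. shifted_score_i = |-log(p_i) - H|, and keeps the closest ones until their cumulative probability exceeds p:

struct llama_sampler * typical = llama_sampler_init_typical(0.95f, /* min_keep */ 1);
// keeps the candidates whose -log(p_i) is nearest the entropy until their total probability exceeds 0.95
llama_sampler_free(typical);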
902
+
903
+ // temp
904
+
905
+ struct llama_sampler_temp {
906
+ const float temp;
907
+ };
908
+
909
+ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
910
+ return "temp";
911
+ }
912
+
913
+ static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
914
+ const auto * ctx = (llama_sampler_temp *) smpl->ctx;
915
+ for (size_t i = 0; i < cur_p->size; ++i) {
916
+ cur_p->data[i].logit /= ctx->temp;
917
  }
918
+ }
919
 
920
+ static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
921
+ const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
922
+ return llama_sampler_init_temp(ctx->temp);
923
+ }
924
 
925
+ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
926
+ delete (llama_sampler_temp *) smpl->ctx;
927
+ }
928
 
929
+ static struct llama_sampler_i llama_sampler_temp_i = {
930
+ /* .name = */ llama_sampler_temp_name,
931
+ /* .accept = */ nullptr,
932
+ /* .apply = */ llama_sampler_temp_apply,
933
+ /* .reset = */ nullptr,
934
+ /* .clone = */ llama_sampler_temp_clone,
935
+ /* .free = */ llama_sampler_temp_free,
936
+ };
937
+
938
+ struct llama_sampler * llama_sampler_init_temp(float temp) {
939
+ return new llama_sampler {
940
+ /* .iface = */ &llama_sampler_temp_i,
941
+ /* .ctx = */ new llama_sampler_temp {
942
+ /* .temp = */ temp,
943
+ },
944
+ };
945
+ }
946
+
947
+ // temp-ext
948
+
949
+ struct llama_sampler_temp_ext {
950
+ const float temp;
951
+ const float delta;
952
+ const float exponent;
953
+ };
954
+
955
+ static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
956
+ return "temp-ext";
957
+ }
958
+
959
+ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
960
+ const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
961
+ if (ctx->delta > 0) {
962
+ const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
963
+ const float max_temp = ctx->temp + ctx->delta;
964
+ float exponent_val = ctx->exponent;
965
+
966
+ // no need to do anything if there is only one (or zero) candidates
967
+ if (cur_p->size <= 1) {
968
+ return;
969
+ }
970
+
971
+ // Calculate maximum possible entropy
972
+ float max_entropy = -logf(1.0f / cur_p->size);
973
+
974
+ llama_sampler_softmax_impl(cur_p);
975
+
976
+ // Calculate entropy of the softmax probabilities
977
+ float entropy = 0.0f;
978
+ for (size_t i = 0; i < cur_p->size; ++i) {
979
+ float prob = cur_p->data[i].p;
980
+ if (prob > 0.0f) { // Ensure no log(0)
981
+ entropy -= prob * logf(prob);
982
+ }
983
+ }
984
+
985
+ // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
986
+ float normalized_entropy = entropy / max_entropy;
987
+
988
+ // Map the normalized entropy to the desired temperature range using the power function
989
+ float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
990
+
991
+ #ifdef DEBUG
992
+ LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
993
+ LLAMA_LOG_INFO("Entropy: %f\n", entropy);
994
+ LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
995
+ LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
996
+ LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
997
+ LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
998
+ #endif
999
+
1000
+ // Apply the dynamically calculated temperature scaling
1001
+ for (size_t i = 0; i < cur_p->size; ++i) {
1002
+ cur_p->data[i].logit /= dyn_temp;
1003
+ }
1004
+
1005
+ // Re-compute softmax probabilities after scaling logits with dynamic temperature
1006
+ const double max_l_double = cur_p->data[0].logit;
1007
+
1008
+ double cum_sum_double = 0.0;
1009
+ for (size_t i = 0; i < cur_p->size; ++i) {
1010
+ double p = exp(cur_p->data[i].logit - max_l_double);
1011
+ cur_p->data[i].p = p; // Store the scaled probability
1012
+ cum_sum_double += p;
1013
+ }
1014
+
1015
+ for (size_t i = 0; i < cur_p->size; ++i) {
1016
+ cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
1017
+ }
1018
+
1019
+ #ifdef DEBUG
1020
+ // Print the updated top 25 probabilities after temperature scaling
1021
+ LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
1022
+ for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
1023
+ LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
1024
+ }
1025
+ #endif
1026
+ } else {
1027
+ for (size_t i = 0; i < cur_p->size; ++i) {
1028
+ cur_p->data[i].logit /= ctx->temp;
1029
  }
1030
  }
1031
+ }
1032
 
1033
+ static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) {
1034
+ const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
1035
+ return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
1036
+ }
1037
 
1038
+ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
1039
+ delete (llama_sampler_temp_ext *) smpl->ctx;
1040
+ }
1041
 
1042
+ static struct llama_sampler_i llama_sampler_temp_ext_i = {
1043
+ /* .name = */ llama_sampler_temp_ext_name,
1044
+ /* .accept = */ nullptr,
1045
+ /* .apply = */ llama_sampler_temp_ext_apply,
1046
+ /* .reset = */ nullptr,
1047
+ /* .clone = */ llama_sampler_temp_ext_clone,
1048
+ /* .free = */ llama_sampler_temp_ext_free,
1049
+ };
1050
+
1051
+ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
1052
+ return new llama_sampler {
1053
+ /* .iface = */ &llama_sampler_temp_ext_i,
1054
+ /* .ctx = */ new llama_sampler_temp_ext {
1055
+ /* .temp = */ temp,
1056
+ /* .delta = */ delta,
1057
+ /* .exponent = */ exponent,
1058
+ },
1059
+ };
1060
+ }
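With delta > 0 the sampler above maps the normalized entropy of the current distribution into [max(0, temp - delta), temp + delta] via dyn_temp = min_temp + (max_temp - min_temp) * (entropy / max_entropy)^exponent, so flat distributions are sampled hotter and peaked ones colder; delta <= 0 reduces to the plain temperature sampler:

struct llama_sampler * temp_ext = llama_sampler_init_temp_ext(0.8f, 0.5f, 1.0f);
// the effective temperature now varies between 0.3 (confident distributions) and 1.3 (uncertain ones)
llama_sampler_free(temp_ext);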
1061
+
1062
+ // mirostat
1063
+
1064
+ struct llama_sampler_mirostat {
1065
+ const int32_t n_vocab;
1066
+
1067
+ const uint32_t seed;
1068
+ uint32_t seed_cur;
1069
+
1070
+ const float tau;
1071
+ const float eta;
1072
+
1073
+ const int32_t m;
1074
+
1075
+ float mu;
1076
 
1077
+ std::mt19937 rng;
1078
+ };
1079
+
1080
+ static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
1081
+ return "mirostat";
1082
+ }
1083
+
1084
+ static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1085
+ auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
1086
+
1087
+ llama_sampler_softmax_impl(cur_p);
1088
+
1089
+ // Estimate s_hat using the most probable m tokens
1090
+ float s_hat = 0.0;
1091
+ float sum_ti_bi = 0.0;
1092
+ float sum_ti_sq = 0.0;
1093
+ for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
1094
+ float t_i = logf(float(i + 2) / float(i + 1));
1095
+ float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
1096
+ sum_ti_bi += t_i * b_i;
1097
+ sum_ti_sq += t_i * t_i;
1098
  }
1099
+ s_hat = sum_ti_bi / sum_ti_sq;
1100
+
1101
+ // Compute k from the estimated s_hat and target surprise value
1102
+ float epsilon_hat = s_hat - 1;
1103
+ float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
1104
+
1105
+ llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
1106
+ llama_sampler_softmax_impl(cur_p);
1107
+
1108
+ const int idx = llama_sample_dist(cur_p, ctx->rng);
1109
+
1110
+ cur_p->selected = idx;
1111
+
1112
+ float observed_surprise = -log2f(cur_p->data[idx].p);
1113
+ float e = observed_surprise - ctx->tau;
1114
 
1115
+ // Update mu using the learning rate and error
1116
+ ctx->mu = ctx->mu - ctx->eta * e;
1117
+ }
1118
+
1119
+ static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) {
1120
+ const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
1121
+ auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
1122
+
1123
+ // copy the state
1124
+ {
1125
+ auto * result_ctx = (llama_sampler_mirostat *) result->ctx;
1126
+
1127
+ result_ctx->mu = ctx->mu;
1128
+ result_ctx->rng = ctx->rng;
1129
  }
1130
+
1131
+ return result;
1132
+ }
1133
+
1134
+ static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
1135
+ auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
1136
+ ctx->mu = 2.0f*ctx->tau;
1137
+ ctx->seed_cur = get_rng_seed(ctx->seed);
1138
+ ctx->rng.seed(ctx->seed_cur);
1139
+ }
1140
+
1141
+ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
1142
+ delete (llama_sampler_mirostat *) smpl->ctx;
1143
+ }
1144
+
1145
+ static struct llama_sampler_i llama_sampler_mirostat_i = {
1146
+ /* .name = */ llama_sampler_mirostat_name,
1147
+ /* .accept = */ nullptr,
1148
+ /* .apply = */ llama_sampler_mirostat_apply,
1149
+ /* .reset = */ llama_sampler_mirostat_reset,
1150
+ /* .clone = */ llama_sampler_mirostat_clone,
1151
+ /* .free = */ llama_sampler_mirostat_free,
1152
+ };
1153
+
1154
+ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
1155
+ auto seed_cur = get_rng_seed(seed);
1156
+ return new llama_sampler {
1157
+ /* .iface = */ &llama_sampler_mirostat_i,
1158
+ /* .ctx = */ new llama_sampler_mirostat {
1159
+ /* .n_vocab = */ n_vocab,
1160
+ /* .seed = */ seed,
1161
+ /* .seed_cur = */ seed_cur,
1162
+ /* .tau = */ tau,
1163
+ /* .eta = */ eta,
1164
+ /* .m = */ m,
1165
+ /* .mu = */ 2.0f*tau,
1166
+ /* .rng = */ std::mt19937(seed_cur),
1167
+ },
1168
+ };
1169
+ }
1170
+
1171
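The mirostat sampler above is a feedback controller: after each pick it measures the observed surprise -log2(p) of the sampled token and moves mu toward the target tau by eta times the error. A standalone sketch of just that update, with made-up probabilities (not part of the synced sources):

    // illustrative sketch, not part of the synced sources
    #include <cmath>
    #include <cstdio>

    int main() {
        const float tau = 5.0f;    // target surprise, in bits
        const float eta = 0.1f;    // learning rate
        float mu = 2.0f * tau;     // initialized exactly as in llama_sampler_init_mirostat above

        const float p_sampled[] = {0.02f, 0.10f, 0.04f};  // hypothetical probabilities of the sampled tokens
        for (float p : p_sampled) {
            const float observed_surprise = -log2f(p);
            const float e = observed_surprise - tau;       // error vs. the target
            mu -= eta * e;                                 // same update as ctx->mu above
            printf("surprise = %.3f  ->  mu = %.3f\n", observed_surprise, mu);
        }
        return 0;
    }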
+ // mirostat v2
1172
+
1173
+ struct llama_sampler_mirostat_v2 {
1174
+ const uint32_t seed;
1175
+ uint32_t seed_cur;
1176
+
1177
+ const float tau;
1178
+ const float eta;
1179
+
1180
+ float mu;
1181
+
1182
+ std::mt19937 rng;
1183
+ };
1184
+
1185
+ static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
1186
+ return "mirostat-v2";
1187
+ }
1188
+
1189
+ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1190
+ auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
1191
+
1192
+ llama_sampler_softmax_impl(cur_p);
1193
+
1194
+ // Truncate the words with surprise values greater than mu
1195
+ cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
1196
+ return -log2f(candidate.p) > ctx->mu;
1197
+ }));
1198
+
1199
+ if (cur_p->size == 0) {
1200
+ cur_p->size = 1;
1201
  }
1202
 
1203
+ // Normalize the probabilities of the remaining words
1204
+ llama_sampler_softmax_impl(cur_p);
1205
+
1206
+ const int idx = llama_sample_dist(cur_p, ctx->rng);
1207
+
1208
+ cur_p->selected = idx;
1209
+
1210
+ float observed_surprise = -log2f(cur_p->data[idx].p);
1211
+ float e = observed_surprise - ctx->tau;
1212
+
1213
+ // Update mu using the learning rate and error
1214
+ ctx->mu = ctx->mu - ctx->eta * e;
1215
+ }
1216
+
1217
+ static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
1218
+ auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
1219
+ ctx->mu = 2.0f*ctx->tau;
1220
+ ctx->seed_cur = get_rng_seed(ctx->seed);
1221
+ ctx->rng.seed(ctx->seed_cur);
1222
+ }
1223
+
1224
+ static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
1225
+ const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
1226
+
1227
+ auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
1228
+
1229
+ // copy the state
1230
+ {
1231
+ auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
1232
+
1233
+ result_ctx->mu = ctx->mu;
1234
+ result_ctx->rng = ctx->rng;
1235
  }
 
1236
 
1237
+ return result;
1238
+ }
1239
+
1240
+ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
1241
+ delete (llama_sampler_mirostat_v2 *) smpl->ctx;
1242
+ }
1243
+
1244
+ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
1245
+ /* .name = */ llama_sampler_mirostat_v2_name,
1246
+ /* .accept = */ nullptr,
1247
+ /* .apply = */ llama_sampler_mirostat_v2_apply,
1248
+ /* .reset = */ llama_sampler_mirostat_v2_reset,
1249
+ /* .clone = */ llama_sampler_mirostat_v2_clone,
1250
+ /* .free = */ llama_sampler_mirostat_v2_free,
1251
+ };
1252
+
1253
+ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
1254
+ auto seed_cur = get_rng_seed(seed);
1255
+ return new llama_sampler {
1256
+ /* .iface = */ &llama_sampler_mirostat_v2_i,
1257
+ /* .ctx = */ new llama_sampler_mirostat_v2 {
1258
+ /* .seed = */ seed,
1259
+ /* .seed_cur = */ seed_cur,
1260
+ /* .tau = */ tau,
1261
+ /* .eta = */ eta,
1262
+ /* .mu = */ 2.0f*tau,
1263
+ /* .rng = */ std::mt19937(seed_cur),
1264
+ },
1265
+ };
1266
+ }
1267
+
1268
+ // grammar
1269
+
1270
+ struct llama_sampler_grammar {
1271
+ const struct llama_vocab * vocab;
1272
+
1273
+ std::string grammar_str;
1274
+ std::string grammar_root;
1275
+
1276
+ struct llama_grammar * grammar;
1277
+ };
1278
+
1279
+ static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) {
1280
+ return "grammar";
1281
+ }
1282
+
1283
+ static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
1284
+ auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1285
+ if (ctx->grammar) {
1286
+ llama_grammar_accept_impl(*ctx->grammar, token);
1287
+ }
1288
+ }
1289
+
1290
+ static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1291
+ auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1292
+ if (ctx->grammar) {
1293
+ llama_grammar_apply_impl(*ctx->grammar, cur_p);
1294
+ }
1295
+ }
1296
+
1297
+ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
1298
+ auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1299
+ if (!ctx->grammar) {
1300
+ return;
1301
  }
1302
+
1303
+ auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
1304
+
1305
+ llama_grammar_free_impl(ctx->grammar);
1306
+ ctx->grammar = grammar_new;
1307
  }
1308
 
1309
+ static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
1310
+ const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
1311
+
1312
+ auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
1313
+
1314
+ // copy the state
1315
+ {
1316
+ auto * result_ctx = (llama_sampler_grammar *) result->ctx;
1317
 
1318
+ if (ctx->grammar) {
1319
+ result_ctx->grammar_str = ctx->grammar_str;
1320
+ result_ctx->grammar_root = ctx->grammar_root;
1321
+
1322
+ result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
1323
+ }
1324
+ }
1325
+
1326
+ return result;
1327
+ }
1328
+
1329
+ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
1330
+ const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1331
+
1332
+ if (ctx->grammar) {
1333
+ llama_grammar_free_impl(ctx->grammar);
1334
+ }
1335
+
1336
+ delete ctx;
1337
+ }
1338
+
1339
+ static struct llama_sampler_i llama_sampler_grammar_i = {
1340
+ /* .name = */ llama_sampler_grammar_name,
1341
+ /* .accept = */ llama_sampler_grammar_accept_impl,
1342
+ /* .apply = */ llama_sampler_grammar_apply,
1343
+ /* .reset = */ llama_sampler_grammar_reset,
1344
+ /* .clone = */ llama_sampler_grammar_clone,
1345
+ /* .free = */ llama_sampler_grammar_free,
1346
+ };
1347
+
1348
+ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
1349
+ auto * ctx = new llama_sampler_grammar;
1350
+
1351
+ if (grammar_str != nullptr && grammar_str[0] != '\0') {
1352
+ *ctx = {
1353
+ /* .vocab = */ &vocab,
1354
+ /* .grammar_str = */ grammar_str,
1355
+ /* .grammar_root = */ grammar_root,
1356
+ /* .grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
1357
+ };
1358
+ } else {
1359
+ *ctx = {
1360
+ /* .vocab = */ &vocab,
1361
+ /* .grammar_str = */ {},
1362
+ /* .grammar_root = */ {},
1363
+ /* .grammar = */ nullptr,
1364
+ };
1365
  }
1366
 
1367
+ return new llama_sampler {
1368
+ /* .iface = */ &llama_sampler_grammar_i,
1369
+ /* .ctx = */ ctx,
1370
+ };
1371
+ }
1372
+
1373
+ // penalties
1374
+
1375
+ struct llama_sampler_penalties {
1376
+ const int32_t n_vocab;
1377
+ const llama_token special_eos_id;
1378
+ const llama_token linefeed_id;
1379
+
1380
+ const int32_t penalty_last_n;
1381
+ const float penalty_repeat;
1382
+ const float penalty_freq;
1383
+ const float penalty_present;
1384
+
1385
+ const bool penalize_nl;
1386
+ const bool ignore_eos;
1387
+
1388
+ ring_buffer<llama_token> prev;
1389
+ };
1390
+
1391
+ static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
1392
+ return "penalties";
1393
+ }
1394
+
1395
+ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
1396
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1397
+ if (ctx->penalty_last_n == 0) {
1398
+ return;
1399
  }
1400
+
1401
+ ctx->prev.push_back(token);
1402
  }
1403
 
1404
+ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1405
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1406
+
1407
+ if (ctx->ignore_eos) {
1408
+ assert(ctx->special_eos_id >= 0);
1409
+
1410
+ // optimistically check if the candidates are not yet sorted/shuffled/truncated
1411
+ if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
1412
+ cur_p->data[ctx->special_eos_id].logit = -INFINITY;
1413
+ } else {
1414
+ // else, search for the special EOS token
1415
+ for (size_t i = 0; i < cur_p->size; ++i) {
1416
+ if (cur_p->data[i].id == ctx->special_eos_id) {
1417
+ cur_p->data[i].logit = -INFINITY;
1418
+ break;
1419
+ }
1420
+ }
1421
+ }
1422
+ }
1423
+
1424
+ if ((ctx->penalty_last_n == 0) ||
1425
+ (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
1426
  return;
1427
  }
1428
 
1429
+ bool nl_found = false;
1430
+ size_t nl_idx = 0;
1431
+ float nl_logit = -INFINITY;
1432
+ if (!ctx->penalize_nl) {
1433
+ assert(ctx->linefeed_id >= 0);
1434
+
1435
+ // optimistically check if the candidates are not yet sorted/shuffled/truncated
1436
+ if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
1437
+ nl_found = true;
1438
+ nl_idx = ctx->linefeed_id;
1439
+ nl_logit = cur_p->data[ctx->linefeed_id].logit;
1440
+ } else {
1441
+ // else, search for the linefeed token
1442
+ for (size_t i = 0; i < cur_p->size; ++i) {
1443
+ if (cur_p->data[i].id == ctx->linefeed_id) {
1444
+ nl_found = true;
1445
+ nl_idx = i;
1446
+ nl_logit = cur_p->data[i].logit;
1447
+ break;
1448
+ }
1449
+ }
1450
+ }
1451
+ }
1452
 
1453
  // Create a frequency map to count occurrences of each token in last_tokens
1454
+ // TODO: optimize this by maintaining the token count in the sampler context
1455
+ using llama_token_cnt = std::unordered_map<llama_token, int>;
1456
+ llama_token_cnt token_count;
1457
+
1458
+ for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
1459
+ token_count[ctx->prev.rat(i)]++;
1460
  }
1461
 
1462
+ // Apply frequency and presence penalties to the cur_p
1463
+ for (size_t i = 0; i < cur_p->size; ++i) {
1464
+ const auto token_iter = token_count.find(cur_p->data[i].id);
1465
  if (token_iter == token_count.end()) {
1466
  continue;
1467
  }
  const int count = token_iter->second;
 
1470
 
1471
  // The academic publication that described this technique actually just divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
1472
  // This is a common fix for this problem, which is to multiply by the penalty instead of dividing.
1473
+ if (cur_p->data[i].logit <= 0) {
1474
+ cur_p->data[i].logit *= ctx->penalty_repeat;
1475
  } else {
1476
+ cur_p->data[i].logit /= ctx->penalty_repeat;
1477
  }
1478
 
1479
+ cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
1480
  }
1481
 
1482
+ cur_p->sorted = false;
1483
 
1484
+ if (!ctx->penalize_nl && nl_found) {
1485
+ // restore the logit of the newline token if it was penalized
1486
+ cur_p->data[nl_idx].logit = nl_logit;
1487
  }
1488
  }
1489
 
1490
+ static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
1491
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1492
+ ctx->prev.clear();
1493
+ }
 
 
 
 
 
 
 
 
1494
 
1495
+ static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
1496
+ const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
1497
+ auto * result = llama_sampler_init_penalties(
1498
+ ctx->n_vocab,
1499
+ ctx->special_eos_id,
1500
+ ctx->linefeed_id,
1501
+ ctx->penalty_last_n,
1502
+ ctx->penalty_repeat,
1503
+ ctx->penalty_freq,
1504
+ ctx->penalty_present,
1505
+ ctx->penalize_nl,
1506
+ ctx->ignore_eos);
1507
+
1508
+ // copy the state
1509
+ {
1510
+ auto * result_ctx = (llama_sampler_penalties *) result->ctx;
1511
 
1512
+ result_ctx->prev = ctx->prev;
1513
  }
1514
 
1515
+ return result;
1516
  }
1517
 
1518
+ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
1519
+ delete (llama_sampler_penalties *) smpl->ctx;
1520
+ }
 
 
 
1521
 
1522
+ static struct llama_sampler_i llama_sampler_penalties_i = {
1523
+ /* .name = */ llama_sampler_penalties_name,
1524
+ /* .accept = */ llama_sampler_penalties_accept,
1525
+ /* .apply = */ llama_sampler_penalties_apply,
1526
+ /* .reset = */ llama_sampler_penalties_reset,
1527
+ /* .clone = */ llama_sampler_penalties_clone,
1528
+ /* .free = */ llama_sampler_penalties_free,
1529
+ };
1530
+
1531
+ struct llama_sampler * llama_sampler_init_penalties(
1532
+ int32_t n_vocab,
1533
+ llama_token special_eos_id,
1534
+ llama_token linefeed_id,
1535
+ int32_t penalty_last_n,
1536
+ float penalty_repeat,
1537
+ float penalty_freq,
1538
+ float penalty_present,
1539
+ bool penalize_nl,
1540
+ bool ignore_eos) {
1541
+ if (linefeed_id == LLAMA_TOKEN_NULL) {
1542
+ penalize_nl = true;
1543
+ }
1544
 
1545
+ if (special_eos_id == LLAMA_TOKEN_NULL) {
1546
+ ignore_eos = false;
 
 
 
 
 
 
 
1547
  }
 
1548
 
1549
+ penalty_last_n = std::max(penalty_last_n, 0);
1550
+
1551
+ return new llama_sampler {
1552
+ /* .iface = */ &llama_sampler_penalties_i,
1553
+ /* .ctx = */ new llama_sampler_penalties {
1554
+ /* .n_vocab = */ n_vocab,
1555
+ /* .special_eos_id = */ special_eos_id,
1556
+ /* .linefeed_id = */ linefeed_id,
1557
+ /* .penalty_last_n = */ penalty_last_n,
1558
+ /* .penalty_repeat = */ penalty_repeat,
1559
+ /* .penalty_freq = */ penalty_freq,
1560
+ /* .penalty_present = */ penalty_present,
1561
+ /* .penalize_nl = */ penalize_nl,
1562
+ /* .ignore_eos = */ ignore_eos,
1563
+ /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
1564
+ },
1565
+ };
1566
+ }
1567
 
1568
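The penalties sampler above applies three adjustments to every candidate seen in the last-n window: a sign-aware repetition penalty on the logit, a frequency penalty proportional to the count, and a flat presence penalty. A standalone sketch of the arithmetic on hypothetical values (not part of the synced sources):

    // illustrative sketch, not part of the synced sources
    #include <cstdio>

    static float penalize(float logit, int count, float penalty_repeat, float penalty_freq, float penalty_present) {
        if (count == 0) {
            return logit;  // token not present in the last-n window -> untouched
        }
        // sign-aware repetition penalty: multiply negative logits, divide positive ones
        if (logit <= 0) {
            logit *= penalty_repeat;
        } else {
            logit /= penalty_repeat;
        }
        // frequency penalty grows with the count, presence penalty is a flat offset
        logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
        return logit;
    }

    int main() {
        printf("%.3f\n", penalize( 2.0f, 3, 1.1f, 0.2f, 0.5f));  // 2/1.1 - 3*0.2 - 0.5 ~= 0.718
        printf("%.3f\n", penalize(-1.0f, 1, 1.1f, 0.2f, 0.5f));  // -1*1.1 - 0.2 - 0.5 = -1.800
        return 0;
    }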
+ // logit-bias
 
 
 
 
1569
 
1570
+ struct llama_sampler_logit_bias {
1571
+ const int32_t n_vocab;
 
 
 
 
1572
 
1573
+ const std::vector<llama_logit_bias> logit_bias;
 
1574
 
1575
+ std::vector<llama_logit_bias> to_search;
1576
+ };
1577
+
1578
+ static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
1579
+ return "logit-bias";
1580
  }
1581
 
1582
+ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1583
+ auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
1584
+
1585
+ if (ctx->logit_bias.empty()) {
1586
+ return;
1587
+ }
1588
 
1589
+ ctx->to_search.clear();
1590
 
1591
+ // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
1592
+ for (const auto & lb : ctx->logit_bias) {
1593
+ if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
1594
+ cur_p->data[lb.token].logit += lb.bias;
1595
+ } else {
1596
+ ctx->to_search.push_back(lb);
1597
+ }
1598
+ }
1599
 
1600
+ if (ctx->to_search.empty()) {
1601
+ return;
1602
  }
1603
 
1604
+ // search for the remaining candidates that were not found in the previous step
1605
+ for (size_t i = 0; i < cur_p->size; ++i) {
1606
+ for (const auto & lb : ctx->to_search) {
1607
+ if (cur_p->data[i].id == lb.token) {
1608
+ cur_p->data[i].logit += lb.bias;
1609
+ break;
1610
+ }
1611
+ }
1612
  }
1613
+ }
1614
 
1615
+ static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
1616
+ const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
1617
+ return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
1618
+ }
1619
 
1620
+ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
1621
+ delete (llama_sampler_logit_bias *) smpl->ctx;
1622
+ }
1623
 
1624
+ static struct llama_sampler_i llama_sampler_logit_bias_i = {
1625
+ /* .name = */ llama_sampler_logit_bias_name,
1626
+ /* .accept = */ nullptr,
1627
+ /* .apply = */ llama_sampler_logit_bias_apply,
1628
+ /* .reset = */ nullptr,
1629
+ /* .clone = */ llama_sampler_logit_bias_clone,
1630
+ /* .free = */ llama_sampler_logit_bias_free,
1631
+ };
1632
+
1633
+ struct llama_sampler * llama_sampler_init_logit_bias(
1634
+ int32_t n_vocab,
1635
+ int32_t n_logit_bias,
1636
+ const llama_logit_bias * logit_bias) {
1637
+ return new llama_sampler {
1638
+ /* .iface = */ &llama_sampler_logit_bias_i,
1639
+ /* .ctx = */ new llama_sampler_logit_bias {
1640
+ /* .n_vocab = */ n_vocab,
1641
+ /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
1642
+ /* .to_search = */ {},
1643
+ },
1644
+ };
1645
+ }
1646
 
1647
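A hedged usage sketch for the logit-bias sampler defined above. It assumes the synced llama.h declares llama_logit_bias as a { token, bias } pair (the declaration is outside the shown hunks) and uses made-up token ids; it is not part of the synced sources.

    // hedged sketch: suppress one token, boost another (token ids are hypothetical)
    #include <cmath>
    #include "llama.h"

    static struct llama_sampler * make_bias_sampler(const struct llama_model * model) {
        const llama_logit_bias biases[] = {
            { /* token = */ 2,  /* bias = */ -INFINITY },  // never sample this token
            { /* token = */ 13, /* bias = */ 2.0f      },  // make this one more likely
        };
        // the sampler copies the array into its own std::vector, so a local array is fine
        return llama_sampler_init_logit_bias(llama_n_vocab(model), 2, biases);
    }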
+ // utils
 
1648
 
1649
+ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
1650
+ if (smpl->iface == &llama_sampler_dist_i) {
1651
+ return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
1652
  }
 
 
1653
 
1654
+ if (smpl->iface == &llama_sampler_mirostat_i) {
1655
+ return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
1656
+ }
1657
 
1658
+ if (smpl->iface == &llama_sampler_mirostat_v2_i) {
1659
+ return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
1660
+ }
 
1661
 
1662
+ if (smpl->iface == &llama_sampler_chain_i) {
1663
+ const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
1664
+ for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
1665
+ const uint32_t seed = llama_sampler_get_seed(*it);
1666
+ if (seed != LLAMA_DEFAULT_SEED) {
1667
+ return seed;
1668
+ }
1669
+ }
1670
  }
1671
+
1672
+ return LLAMA_DEFAULT_SEED;
1673
  }
1674
 
1675
+ // perf
 
1676
 
1677
+ struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) {
1678
+ struct llama_perf_sampler_data data = {};
1679
 
1680
+ if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
1681
+ GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
 
 
1682
  }
1683
 
1684
+ const auto * ctx = (const struct llama_sampler_chain *) chain->ctx;
 
1685
 
1686
+ data.t_sample_ms = 1e-3 * ctx->t_sample_us;
1687
+ data.n_sample = std::max(0, ctx->n_sample);
1688
 
1689
+ return data;
1690
+ }
1691
 
1692
+ void llama_perf_sampler_print(const struct llama_sampler * chain) {
1693
+ const auto data = llama_perf_sampler(chain);
1694
+
1695
+ LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
1696
+ __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
1697
  }
1698
 
1699
+ void llama_perf_sampler_reset(struct llama_sampler * chain) {
1700
+ if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
1701
+ GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
1702
+ }
1703
+
1704
+ auto * ctx = (struct llama_sampler_chain *) chain->ctx;
1705
+
1706
+ ctx->t_sample_us = ctx->n_sample = 0;
1707
  }
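Taken together, these objects replace the removed llama_sample_* context calls with free-standing samplers composed into a chain. A hedged sketch of how calling code such as talk-llama can build one; it assumes llama_sampler_chain_default_params, llama_sampler_chain_init, llama_sampler_chain_add, llama_sampler_init_top_k/_temp/_dist, llama_sampler_sample and llama_sampler_free are declared in the synced llama.h (not all of them are visible in the hunks shown here).

    // hedged sketch: a minimal sampling chain with the refactored API
    #include "llama.h"

    static llama_token sample_next(struct llama_context * ctx) {
        struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

        llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
        llama_sampler_chain_add(chain, llama_sampler_init_temp (0.80f));
        llama_sampler_chain_add(chain, llama_sampler_init_dist (LLAMA_DEFAULT_SEED));

        // sample from the logits of the last decoded token
        const llama_token id = llama_sampler_sample(chain, ctx, -1);

        llama_sampler_free(chain);
        return id;
    }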
examples/talk-llama/llama-sampling.h CHANGED
@@ -1,56 +1,29 @@
1
  #pragma once
2
 
3
- #include "llama-impl.h"
4
 
5
- struct llama_sampling {
6
- llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
7
 
8
- std::mt19937 rng;
9
 
10
- int32_t n_vocab = 0;
 
11
 
12
- mutable int64_t t_sample_us = 0;
13
- mutable int32_t n_sample = 0;
14
 
15
- void reset_timings() const {
16
- t_sample_us = 0;
17
- n_sample = 0;
18
- }
19
- };
 
20
 
21
- //
22
- // internal API
23
- //
24
-
25
- void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);
26
-
27
- void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
28
- void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
29
- void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
30
- void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
31
- void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
32
- void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
33
- void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
34
- void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
35
-
36
- void llama_sample_repetition_penalties_impl(
37
- struct llama_sampling * smpl,
38
- llama_token_data_array * candidates,
39
- const llama_token * last_tokens,
40
- size_t penalty_last_n,
41
- float penalty_repeat,
42
- float penalty_freq,
43
- float penalty_present);
44
-
45
- void llama_sample_apply_guidance_impl(
46
- struct llama_sampling * smpl,
47
- float * logits,
48
- float * logits_guidance,
49
- float scale);
50
-
51
- llama_token llama_sample_token_mirostat_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
52
- llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
53
- llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
54
- llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
55
- llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
56
 
 
 
 
 
 
1
  #pragma once
2
 
3
+ // TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
4
 
5
+ #include "llama-grammar.h"
 
6
 
7
+ #include <unordered_map>
8
 
9
+ struct llama_vocab;
10
+ struct llama_grammar;
11
 
12
+ // sampler chain
 
13
 
14
+ struct llama_sampler_chain {
15
+ llama_sampler_chain_params params;
16
+
17
+ std::vector<struct llama_sampler *> samplers;
18
+
19
+ // timing
20
 
21
+ mutable int64_t t_sample_us;
22
+
23
+ mutable int32_t n_sample;
24
+ };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ struct llama_sampler * llama_sampler_init_grammar_impl(
27
+ const struct llama_vocab & vocab,
28
+ const char * grammar_str,
29
+ const char * grammar_root);
examples/talk-llama/llama-vocab.cpp CHANGED
@@ -58,17 +58,17 @@ struct naive_trie {
58
  auto res = children.find(c);
59
  if (res != children.end()) {
60
  return res->second.get_longest_prefix(key, len, offset + 1);
61
- } else {
62
- return std::make_pair(key, offset);
63
  }
 
 
64
  }
65
- struct naive_trie * traverse(const char c) {
66
  auto res = children.find(c);
67
  if (res != children.end()) {
68
  return &res->second;
69
- } else {
70
- return NULL;
71
  }
 
 
72
  }
73
  std::map<char, struct naive_trie> children;
74
  bool has_value;
@@ -843,7 +843,7 @@ struct llm_tokenizer_ugm {
843
  // traverse the token matcher trie to find a matching token
844
  bool single_codepoint_token_found = false;
845
  const struct best_tokenization & current_best = tokenization_results[input_offset];
846
- struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
847
 
848
  while (prefix_offset <= input_len && node != NULL) {
849
  // check if we found valid token in prefix
@@ -963,7 +963,7 @@ private:
963
  /*
964
  * This structure is a view wrapper for XOR-compressed double array (XCDA)
965
  * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries.
966
- * Eeach bit-packed entry contains:
967
  * - BASE array value in bits 10-30
968
  * - LCHECK array value in bits 0-7
969
  * - LEAF array value in bit 9
@@ -1097,6 +1097,111 @@ private:
1097
  struct naive_trie token_matcher;
1098
  };
1100
  //
1101
  // (de-) tokenize
1102
  //
@@ -1401,6 +1506,23 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
1401
  output.push_back(vocab.special_eos_id);
1402
  }
1403
  } break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1404
  case LLAMA_VOCAB_TYPE_NONE:
1405
  GGML_ABORT("fatal error");
1406
  }
@@ -1448,11 +1570,7 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, lla
1448
  }
1449
 
1450
  bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
1451
- return token != -1 && (
1452
- token == llama_token_eos_impl(vocab) ||
1453
- token == llama_token_eot_impl(vocab) ||
1454
- token == llama_token_eom_impl(vocab)
1455
- );
1456
  }
1457
 
1458
  bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
@@ -1616,6 +1734,17 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
1616
  }
1617
  break;
1618
  }
 
 
 
 
 
 
 
 
 
 
 
1619
  default:
1620
  GGML_ABORT("fatal error");
1621
  }
 
58
  auto res = children.find(c);
59
  if (res != children.end()) {
60
  return res->second.get_longest_prefix(key, len, offset + 1);
 
 
61
  }
62
+
63
+ return std::make_pair(key, offset);
64
  }
65
+ const struct naive_trie * traverse(const char c) const {
66
  auto res = children.find(c);
67
  if (res != children.end()) {
68
  return &res->second;
 
 
69
  }
70
+
71
+ return NULL;
72
  }
73
  std::map<char, struct naive_trie> children;
74
  bool has_value;
 
843
  // traverse the token matcher trie to find a matching token
844
  bool single_codepoint_token_found = false;
845
  const struct best_tokenization & current_best = tokenization_results[input_offset];
846
+ const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
847
 
848
  while (prefix_offset <= input_len && node != NULL) {
849
  // check if we found valid token in prefix
 
963
  /*
964
  * This structure is a view wrapper for XOR-compressed double array (XCDA)
965
  * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries.
966
+ * Each bit-packed entry contains:
967
  * - BASE array value in bits 10-30
968
  * - LCHECK array value in bits 0-7
969
  * - LEAF array value in bit 9
 
1097
  struct naive_trie token_matcher;
1098
  };
1099
 
1100
+ //
1101
+ // RWKV tokenizer
1102
+ //
1103
+
1104
+ static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) {
1105
+ std::vector<uint8_t> output;
1106
+ output.reserve(escaped.size());
1107
+
1108
+ // Parser state
1109
+ bool escaping = false;
1110
+ uint8_t hex_remaining = 0;
1111
+ uint8_t hex_acc = 0;
1112
+
1113
+ // Step through characters, performing parsing
1114
+ for (const char & c : escaped) {
1115
+ // If we're parsing a hex code, interpret the next character
1116
+ if (hex_remaining != 0) {
1117
+ uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0');
1118
+ hex_acc = (hex_acc << 4) + value;
1119
+
1120
+ hex_remaining -= 1;
1121
+ if (hex_remaining == 0) {
1122
+ output.push_back(hex_acc);
1123
+ hex_acc = 0;
1124
+ }
1125
+
1126
+ continue;
1127
+ }
1128
+
1129
+ // If we got an escape character, interpret it
1130
+ if (escaping) {
1131
+ if (c == 't') {
1132
+ output.push_back('\t');
1133
+ } else if (c == 'n') {
1134
+ output.push_back('\n');
1135
+ } else if (c == 'r') {
1136
+ output.push_back('\r');
1137
+ } else if (c == 'x') {
1138
+ hex_remaining = 2;
1139
+ } else {
1140
+ output.push_back(c);
1141
+ }
1142
+
1143
+ escaping = false;
1144
+ continue;
1145
+ }
1146
+
1147
+ if (c == '\\') {
1148
+ escaping = true;
1149
+ continue;
1150
+ }
1151
+
1152
+ output.push_back(c);
1153
+ }
1154
+
1155
+ return output;
1156
+ }
1157
+
1158
+ struct llm_tokenizer_rwkv {
1159
+ llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
1160
+ // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
1161
+ // For now, we decode the vocab here into the lookup we'll use for tokenization.
1162
+
1163
+ // build trie
1164
+ for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
1165
+ const auto & token = vocab.id_to_token[id];
1166
+ const auto data = llama_unescape_rwkv_token(token.text);
1167
+ token_matcher.insert((const char *) data.data(), data.size(), id);
1168
+ }
1169
+ }
1170
+
1171
+ void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
1172
+ uint32_t position = 0;
1173
+
1174
+ while (position < text.size()) {
1175
+ const struct naive_trie * node = token_matcher.traverse(text[position]);
1176
+ if (node == NULL) {
1177
+ // no matching token found, add unknown token
1178
+ output.push_back(vocab.special_unk_id);
1179
+ position += 1;
1180
+ continue;
1181
+ }
1182
+
1183
+ // traverse the trie to find the longest matching token
1184
+ uint32_t token_id = 0;
1185
+ uint32_t token_length = 0;
1186
+ while (node != NULL) {
1187
+ if (node->has_value) {
1188
+ token_id = node->value;
1189
+ token_length = position + 1;
1190
+ }
1191
+ node = node->traverse(text[++position]);
1192
+ }
1193
+
1194
+ // add the longest matching token
1195
+ output.push_back(token_id);
1196
+ position = token_length;
1197
+ }
1198
+ }
1199
+
1200
+ const llama_vocab & vocab;
1201
+
1202
+ struct naive_trie token_matcher;
1203
+ };
1204
+
1205
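The RWKV tokenizer above performs greedy longest-prefix matching over a byte-level trie built from the unescaped vocabulary. A toy illustration of the same matching strategy using plain strings instead of the trie (hypothetical vocabulary, not part of the synced sources):

    // illustrative sketch: greedy longest-match tokenization, as used by llm_tokenizer_rwkv above
    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
        const std::vector<std::string> vocab = {"a", "ab", "abc", "b", "bc"};  // hypothetical
        const std::string text = "abcbc";

        size_t pos = 0;
        while (pos < text.size()) {
            size_t best_id  = 0;
            size_t best_len = 0;
            for (size_t id = 0; id < vocab.size(); ++id) {
                const std::string & tok = vocab[id];
                if (tok.size() > best_len && text.compare(pos, tok.size(), tok) == 0) {
                    best_id  = id;
                    best_len = tok.size();
                }
            }
            if (best_len == 0) {
                pos += 1;    // the real tokenizer emits special_unk_id here
                continue;
            }
            printf("token %zu ('%s')\n", best_id, vocab[best_id].c_str());
            pos += best_len;
        }
        return 0;
    }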
  //
1206
  // (de-) tokenize
1207
  //
 
1506
  output.push_back(vocab.special_eos_id);
1507
  }
1508
  } break;
1509
+ case LLAMA_VOCAB_TYPE_RWKV:
1510
+ {
1511
+ for (const auto & fragment : fragment_buffer) {
1512
+ if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
1513
+ auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
1514
+
1515
+ #ifdef PRETOKENIZERDEBUG
1516
+ LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
1517
+ #endif
1518
+
1519
+ llm_tokenizer_rwkv tokenizer(vocab);
1520
+ tokenizer.tokenize(raw_text, output);
1521
+ } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
1522
+ output.push_back(fragment.token);
1523
+ }
1524
+ }
1525
+ } break;
1526
  case LLAMA_VOCAB_TYPE_NONE:
1527
  GGML_ABORT("fatal error");
1528
  }
 
1570
  }
1571
 
1572
  bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
1573
+ return token != -1 && vocab.special_eog_ids.count(token) > 0;
 
 
 
 
1574
  }
1575
 
1576
  bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
 
1734
  }
1735
  break;
1736
  }
1737
+ case LLAMA_VOCAB_TYPE_RWKV: {
1738
+ std::vector<uint8_t> result = llama_unescape_rwkv_token(token_text);
1739
+
1740
+ // If we don't have enough space, return an error
1741
+ if (result.size() > (size_t)length) {
1742
+ return -(int)result.size();
1743
+ }
1744
+
1745
+ memcpy(buf, result.data(), result.size());
1746
+ return (int)result.size();
1747
+ }
1748
  default:
1749
  GGML_ABORT("fatal error");
1750
  }
examples/talk-llama/llama-vocab.h CHANGED
@@ -6,6 +6,7 @@
6
  #include <vector>
7
  #include <unordered_map>
8
  #include <map>
 
9
 
10
  struct llama_vocab {
11
  using id = llama_token;
@@ -18,6 +19,8 @@ struct llama_vocab {
18
  tattr attr;
19
  };
20
 
 
 
21
  enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
22
  enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
23
 
@@ -47,12 +50,15 @@ struct llama_vocab {
47
  id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
48
  id special_eom_id = -1;
49
 
 
 
 
50
  // tokenizer flags
51
- bool tokenizer_add_space_prefix = false;
52
- bool tokenizer_add_bos = false;
53
- bool tokenizer_add_eos = false;
54
- bool tokenizer_ignore_merges = false;
55
- bool tokenizer_clean_spaces = false; // clean_up_tokenization_spaces
56
  bool tokenizer_remove_extra_whitespaces = false;
57
  bool tokenizer_escape_whitespaces = true;
58
  bool tokenizer_treat_whitespace_as_suffix = false;
@@ -62,8 +68,6 @@ struct llama_vocab {
62
  int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
63
  };
64
 
65
- const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
66
-
67
  //
68
  // internal API
69
  //
@@ -76,6 +80,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
76
  bool add_special,
77
  bool parse_special = false);
78
 
 
79
  llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
80
 
81
  const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
 
6
  #include <vector>
7
  #include <unordered_map>
8
  #include <map>
9
+ #include <set>
10
 
11
  struct llama_vocab {
12
  using id = llama_token;
 
19
  tattr attr;
20
  };
21
 
22
+ uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
23
+
24
  enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
25
  enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
26
 
 
50
  id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
51
  id special_eom_id = -1;
52
 
53
+ // set of all tokens that cause "end of generation"
54
+ std::set<id> special_eog_ids;
55
+
56
  // tokenizer flags
57
+ bool tokenizer_add_space_prefix = false;
58
+ bool tokenizer_add_bos = false;
59
+ bool tokenizer_add_eos = false;
60
+ bool tokenizer_ignore_merges = false;
61
+ bool tokenizer_clean_spaces = false; // clean_up_tokenization_spaces
62
  bool tokenizer_remove_extra_whitespaces = false;
63
  bool tokenizer_escape_whitespaces = true;
64
  bool tokenizer_treat_whitespace_as_suffix = false;
 
68
  int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
69
  };
70
 
 
 
71
  //
72
  // internal API
73
  //
 
80
  bool add_special,
81
  bool parse_special = false);
82
 
83
+ // TODO: move the API below as member functions of llama_vocab
84
  llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
85
 
86
  const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
examples/talk-llama/llama.cpp CHANGED
The diff for this file is too large to render. See raw diff
 
examples/talk-llama/llama.h CHANGED
@@ -33,12 +33,15 @@
33
 
34
  #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
35
 
 
 
 
36
  #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
37
  #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
38
  #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
39
 
40
  #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
41
- #define LLAMA_SESSION_VERSION 8
42
 
43
  #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
44
  #define LLAMA_STATE_SEQ_VERSION 2
@@ -53,8 +56,10 @@ extern "C" {
53
  // TODO: show sample usage
54
  //
55
 
 
56
  struct llama_model;
57
  struct llama_context;
 
58
 
59
  typedef int32_t llama_pos;
60
  typedef int32_t llama_token;
@@ -66,6 +71,7 @@ extern "C" {
66
  LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
67
  LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
68
  LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
 
69
  };
70
 
71
  // pre-tokenization types
@@ -166,6 +172,8 @@ extern "C" {
166
  LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors
167
  LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors
168
  LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
 
 
169
 
170
  LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
171
  };
@@ -198,6 +206,7 @@ extern "C" {
198
  LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
199
  };
200
 
 
201
  typedef struct llama_token_data {
202
  llama_token id; // token id
203
  float logit; // log-odds of the token
@@ -205,8 +214,10 @@ extern "C" {
205
  } llama_token_data;
206
 
207
  typedef struct llama_token_data_array {
 
208
  llama_token_data * data;
209
  size_t size;
 
210
  bool sorted;
211
  } llama_token_data_array;
212
 
@@ -267,9 +278,9 @@ extern "C" {
267
  enum llama_split_mode split_mode; // how to split the model across multiple GPUs
268
 
269
  // main_gpu interpretation depends on split_mode:
270
- // LLAMA_SPLIT_NONE: the GPU that is used for the entire model
271
- // LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results
272
- // LLAMA_SPLIT_LAYER: ignored
273
  int32_t main_gpu;
274
 
275
  // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
@@ -299,13 +310,12 @@ extern "C" {
299
  // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
300
  // https://github.com/ggerganov/llama.cpp/pull/7544
301
  struct llama_context_params {
302
- uint32_t seed; // RNG seed, -1 for random
303
  uint32_t n_ctx; // text context, 0 = from model
304
  uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
305
  uint32_t n_ubatch; // physical maximum batch size
306
  uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
307
- uint32_t n_threads; // number of threads to use for generation
308
- uint32_t n_threads_batch; // number of threads to use for batch processing
309
 
310
  enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
311
  enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
@@ -327,11 +337,13 @@ extern "C" {
327
  enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
328
  enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
329
 
330
- // Keep the booleans together to avoid misalignment during copy-by-value.
 
331
  bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
332
  bool embeddings; // if true, extract embeddings (together with logits)
333
  bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
334
  bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
 
335
 
336
  // Abort callback
337
  // if it returns true, execution of llama_decode() will be aborted
@@ -355,56 +367,14 @@ extern "C" {
355
  void * kv_overrides; // pointer to vector containing overrides
356
  } llama_model_quantize_params;
357
 
358
- // grammar types
359
- struct llama_grammar;
360
-
361
- // grammar element type
362
- enum llama_gretype {
363
- // end of rule definition
364
- LLAMA_GRETYPE_END = 0,
365
-
366
- // start of alternate definition for rule
367
- LLAMA_GRETYPE_ALT = 1,
368
-
369
- // non-terminal element: reference to rule
370
- LLAMA_GRETYPE_RULE_REF = 2,
371
-
372
- // terminal element: character (code point)
373
- LLAMA_GRETYPE_CHAR = 3,
374
-
375
- // inverse char(s) ([^a], [^a-b] [^abc])
376
- LLAMA_GRETYPE_CHAR_NOT = 4,
377
-
378
- // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
379
- // be an inclusive range ([a-z])
380
- LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
381
-
382
- // modifies a preceding LLAMA_GRETYPE_CHAR or
383
- // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
384
- LLAMA_GRETYPE_CHAR_ALT = 6,
385
-
386
- // any character (.)
387
- LLAMA_GRETYPE_CHAR_ANY = 7,
388
- };
389
-
390
- typedef struct llama_grammar_element {
391
- enum llama_gretype type;
392
- uint32_t value; // Unicode code point or rule ID
393
- } llama_grammar_element;
394
-
395
- // performance timing information
396
- struct llama_timings {
397
- double t_start_ms;
398
- double t_end_ms;
399
- double t_load_ms;
400
- double t_sample_ms;
401
- double t_p_eval_ms;
402
- double t_eval_ms;
403
 
404
- int32_t n_sample;
405
- int32_t n_p_eval;
406
- int32_t n_eval;
407
- };
408
 
409
  // used in chat template
410
  typedef struct llama_chat_message {
@@ -416,8 +386,10 @@ extern "C" {
416
  struct llama_lora_adapter;
417
 
418
  // Helpers for getting default parameters
419
- LLAMA_API struct llama_model_params llama_model_default_params(void);
420
- LLAMA_API struct llama_context_params llama_context_default_params(void);
 
 
421
  LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
422
 
423
  // Initialize the llama + ggml backend
@@ -428,15 +400,23 @@ extern "C" {
428
  //optional:
429
  LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa);
430
 
 
 
 
 
 
 
 
431
  // Call once at the end of the program - currently only used for MPI
432
  LLAMA_API void llama_backend_free(void);
433
 
434
  LLAMA_API struct llama_model * llama_load_model_from_file(
435
  const char * path_model,
436
- struct llama_model_params params);
437
 
438
  LLAMA_API void llama_free_model(struct llama_model * model);
439
 
 
440
  LLAMA_API struct llama_context * llama_new_context_with_model(
441
  struct llama_model * model,
442
  struct llama_context_params params);
@@ -452,22 +432,22 @@ extern "C" {
452
  LLAMA_API bool llama_supports_mlock (void);
453
  LLAMA_API bool llama_supports_gpu_offload(void);
454
 
455
- LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
456
-
457
  LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
458
  LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
459
  LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
460
  LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
461
 
462
- LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
463
-
464
- LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
465
- LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
466
-
467
  LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
468
  LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
469
  LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
470
  LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
 
 
 
 
 
 
 
471
 
472
  // Get the model's RoPE frequency scaling factor
473
  LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
@@ -696,7 +676,7 @@ extern "C" {
696
  //
697
 
698
  // Returns the *actual* size in bytes of the state
699
- // (rng, logits, embedding and kv_cache)
700
  // Only use when saving the state, not when restoring it, otherwise the size may be too small.
701
  LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
702
  LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
@@ -837,13 +817,13 @@ extern "C" {
837
  // Set the number of threads used for decoding
838
  // n_threads is the number of threads used for generation (single token)
839
  // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
840
- LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
841
 
842
  // Get the number of threads used for generation of a single token.
843
- LLAMA_API uint32_t llama_n_threads(struct llama_context * ctx);
844
 
845
  // Get the number of threads used for prompt and batch processing (multiple token).
846
- LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
847
 
848
  // Set whether the model is in embeddings mode or not
849
  // If true, embeddings will be returned but logits will not
@@ -999,121 +979,114 @@ extern "C" {
999
  int32_t length);
1000
 
1001
  //
1002
- // Grammar
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1003
  //
1004
 
1005
- /// Initialize a llama_grammar.
1006
- ///
1007
- /// @param rules The rule elements of the grammar to initialize.
1008
- /// @param n_rules The number of rules.
1009
- /// @param start_rule_index The index of the root rule (the starting point of the grammar).
1010
- /// @return The initialized llama_grammar or nullptr if initialization failed.
1011
- LLAMA_API struct llama_grammar * llama_grammar_init(
1012
- const llama_grammar_element ** rules,
1013
- size_t n_rules,
1014
- size_t start_rule_index);
1015
-
1016
- LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
1017
-
1018
- LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
1019
-
1020
- /// @details Apply constraints from grammar
1021
- LLAMA_API void llama_grammar_sample(
1022
- const struct llama_grammar * grammar,
1023
- const struct llama_context * ctx,
1024
- llama_token_data_array * candidates);
1025
- LLAMA_API DEPRECATED(void llama_sample_grammar(
1026
- struct llama_context * ctx,
1027
- llama_token_data_array * candidates,
1028
- const struct llama_grammar * grammar),
1029
- "use llama_grammar_sample instead");
1030
 
1031
- /// @details Accepts the sampled token into the grammar
1032
- LLAMA_API void llama_grammar_accept_token(
1033
- struct llama_grammar * grammar,
1034
- struct llama_context * ctx,
1035
- llama_token token);
 
 
 
1036
 
1037
- //
1038
- // Sampling functions
1039
- //
 
 
 
 
 
1040
 
1041
- // Sets the current rng seed.
1042
- LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
 
 
 
 
 
 
1043
 
1044
- /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
1045
- /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
1046
- LLAMA_API void llama_sample_repetition_penalties(
1047
- struct llama_context * ctx,
1048
- llama_token_data_array * candidates,
1049
- const llama_token * last_tokens,
1050
- size_t penalty_last_n,
1051
- float penalty_repeat,
1052
- float penalty_freq,
1053
- float penalty_present);
1054
-
1055
- /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
1056
- /// @param logits Logits extracted from the original generation context.
1057
- /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
1058
- /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
1059
- LLAMA_API void llama_sample_apply_guidance(
1060
- struct llama_context * ctx,
1061
- float * logits,
1062
- float * logits_guidance,
1063
- float scale);
1064
 
1065
  /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
1066
- LLAMA_API void llama_sample_softmax(
1067
- struct llama_context * ctx,
1068
- llama_token_data_array * candidates);
1069
 
1070
  /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
1071
- LLAMA_API void llama_sample_top_k(
1072
- struct llama_context * ctx,
1073
- llama_token_data_array * candidates,
1074
- int32_t k,
1075
- size_t min_keep);
1076
 
1077
  /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
1078
- LLAMA_API void llama_sample_top_p(
1079
- struct llama_context * ctx,
1080
- llama_token_data_array * candidates,
1081
- float p,
1082
- size_t min_keep);
1083
 
1084
  /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
1085
- LLAMA_API void llama_sample_min_p(
1086
- struct llama_context * ctx,
1087
- llama_token_data_array * candidates,
1088
- float p,
1089
- size_t min_keep);
1090
 
1091
  /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
1092
- LLAMA_API void llama_sample_tail_free(
1093
- struct llama_context * ctx,
1094
- llama_token_data_array * candidates,
1095
- float z,
1096
- size_t min_keep);
1097
 
1098
  /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
1099
- LLAMA_API void llama_sample_typical(
1100
- struct llama_context * ctx,
1101
- llama_token_data_array * candidates,
1102
- float p,
1103
- size_t min_keep);
1104
 
1105
- /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
1106
- LLAMA_API void llama_sample_entropy(
1107
- struct llama_context * ctx,
1108
- llama_token_data_array * candidates_p,
1109
- float min_temp,
1110
- float max_temp,
1111
- float exponent_val);
1112
-
1113
- LLAMA_API void llama_sample_temp(
1114
- struct llama_context * ctx,
1115
- llama_token_data_array * candidates,
1116
- float temp);
1117
 
1118
  /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
1119
  /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
@@ -1121,36 +1094,62 @@ extern "C" {
1121
  /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
1122
  /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
1123
  /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
1124
- LLAMA_API llama_token llama_sample_token_mirostat(
1125
- struct llama_context * ctx,
1126
- llama_token_data_array * candidates,
1127
- float tau,
1128
- float eta,
1129
- int32_t m,
1130
- float * mu);
1131
 
1132
  /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
1133
  /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
1134
  /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
1135
  /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
1136
  /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
1137
- LLAMA_API llama_token llama_sample_token_mirostat_v2(
1138
- struct llama_context * ctx,
1139
- llama_token_data_array * candidates,
1140
- float tau,
1141
- float eta,
1142
- float * mu);
1143
-
1144
- /// @details Selects the token with the highest probability.
1145
- /// Does not compute the token probabilities. Use llama_sample_softmax() instead.
1146
- LLAMA_API llama_token llama_sample_token_greedy(
1147
- struct llama_context * ctx,
1148
- llama_token_data_array * candidates);
1149
-
1150
- /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
1151
- LLAMA_API llama_token llama_sample_token(
1152
- struct llama_context * ctx,
1153
- llama_token_data_array * candidates);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1154
 
1155
  //
1156
  // Model split
@@ -1166,12 +1165,6 @@ extern "C" {
1166
  // Returns the split_prefix length.
1167
  LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count);
1168
 
1169
- // Performance information
1170
- LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
1171
-
1172
- LLAMA_API void llama_print_timings(struct llama_context * ctx);
1173
- LLAMA_API void llama_reset_timings(struct llama_context * ctx);
1174
-
1175
  // Print system information
1176
  LLAMA_API const char * llama_print_system_info(void);
1177
 
@@ -1179,65 +1172,41 @@ extern "C" {
1179
  // If this is not called, or NULL is supplied, everything is output on stderr.
1180
  LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
1181
 
1182
- LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
1183
-
1184
- #ifdef __cplusplus
1185
- }
1186
- #endif
1187
-
1188
- // Internal API to be implemented by llama.cpp and used by tests/benchmarks only
1189
- #ifdef LLAMA_API_INTERNAL
1190
-
1191
- #include <random>
1192
- #include <string>
1193
- #include <vector>
1194
-
1195
- struct ggml_tensor;
1196
-
1197
- const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
1198
- struct llama_context * ctx
1199
- );
1200
-
1201
- struct llama_partial_utf8 {
1202
- uint32_t value; // bit value so far (unshifted)
1203
- int n_remain; // num bytes remaining; -1 indicates invalid sequence
1204
- };
1205
-
1206
- struct llama_grammar_candidate {
1207
- size_t index;
1208
- const uint32_t * code_points;
1209
- llama_partial_utf8 partial_utf8;
1210
- };
1211
 
1212
- using llama_grammar_rule = std::vector< llama_grammar_element>;
1213
- using llama_grammar_stack = std::vector<const llama_grammar_element *>;
 
 
 
1214
 
1215
- using llama_grammar_rules = std::vector<llama_grammar_rule>;
1216
- using llama_grammar_stacks = std::vector<llama_grammar_stack>;
1217
- using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
1218
 
1219
- const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
1220
- llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
1221
 
1222
- void llama_grammar_accept(
1223
- const llama_grammar_rules & rules,
1224
- const llama_grammar_stacks & stacks,
1225
- const uint32_t chr,
1226
- llama_grammar_stacks & new_stacks);
1227
 
1228
- std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
1229
- const llama_grammar_rules & rules,
1230
- const llama_grammar_stack & stack,
1231
- const llama_grammar_candidates & candidates);
1232
 
1233
- std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
1234
- const std::string & src,
1235
- llama_partial_utf8 partial_start);
 
1236
 
1237
- // Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
1238
- // This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
1239
- llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
1240
 
1241
- #endif // LLAMA_API_INTERNAL
 
 
1242
 
1243
  #endif // LLAMA_H
 
33
 
34
  #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
35
 
36
+ // TODO: use everywhere in the implementation
37
+ #define LLAMA_TOKEN_NULL -1
38
+
39
  #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
40
  #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
41
  #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
42
 
43
  #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
44
+ #define LLAMA_SESSION_VERSION 9
45
 
46
  #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
47
  #define LLAMA_STATE_SEQ_VERSION 2
 
56
  // TODO: show sample usage
57
  //
58
 
59
+ // struct llama_vocab; // TODO: add in the future
60
  struct llama_model;
61
  struct llama_context;
62
+ struct llama_sampler;
63
 
64
  typedef int32_t llama_pos;
65
  typedef int32_t llama_token;
 
71
  LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
72
  LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
73
  LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
74
+ LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
75
  };
76
 
77
  // pre-tokenization types
 
172
  LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors
173
  LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors
174
  LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
175
+ LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors
176
+ LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors
177
 
178
  LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
179
  };
 
206
  LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
207
  };
208
 
209
+ // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
210
  typedef struct llama_token_data {
211
  llama_token id; // token id
212
  float logit; // log-odds of the token
 
214
  } llama_token_data;
215
 
216
  typedef struct llama_token_data_array {
217
+ // TODO: consider SoA
218
  llama_token_data * data;
219
  size_t size;
220
+ int64_t selected; // this is the index in the data array (i.e. not the token id)
221
  bool sorted;
222
  } llama_token_data_array;
223
 
 
278
  enum llama_split_mode split_mode; // how to split the model across multiple GPUs
279
 
280
  // main_gpu interpretation depends on split_mode:
281
+ // LLAMA_SPLIT_MODE_NONE: the GPU that is used for the entire model
282
+ // LLAMA_SPLIT_MODE_ROW: the GPU that is used for small tensors and intermediate results
283
+ // LLAMA_SPLIT_MODE_LAYER: ignored
284
  int32_t main_gpu;
285
 
286
  // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
 
310
  // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
311
  // https://github.com/ggerganov/llama.cpp/pull/7544
312
  struct llama_context_params {
 
313
  uint32_t n_ctx; // text context, 0 = from model
314
  uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
315
  uint32_t n_ubatch; // physical maximum batch size
316
  uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
317
+ int32_t n_threads; // number of threads to use for generation
318
+ int32_t n_threads_batch; // number of threads to use for batch processing
319
 
320
  enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
321
  enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
 
337
  enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
338
  enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
339
 
340
+ // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
341
+ // TODO: move at the end of the struct
342
  bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
343
  bool embeddings; // if true, extract embeddings (together with logits)
344
  bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
345
  bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
346
+ bool no_perf; // whether to measure performance timings
347
 
348
  // Abort callback
349
  // if it returns true, execution of llama_decode() will be aborted
 
367
  void * kv_overrides; // pointer to vector containing overrides
368
  } llama_model_quantize_params;
369
 
370
+ typedef struct llama_logit_bias {
371
+ llama_token token;
372
+ float bias;
373
+ } llama_logit_bias;
 
 
374
 
375
+ typedef struct llama_sampler_chain_params {
376
+ bool no_perf; // whether to measure performance timings
377
+ } llama_sampler_chain_params;
 
378
 
379
  // used in chat template
380
  typedef struct llama_chat_message {
 
386
  struct llama_lora_adapter;
387
 
388
  // Helpers for getting default parameters
389
+ // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
390
+ LLAMA_API struct llama_model_params llama_model_default_params(void);
391
+ LLAMA_API struct llama_context_params llama_context_default_params(void);
392
+ LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void);
393
  LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
394
 
395
  // Initialize the llama + ggml backend
 
400
  //optional:
401
  LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa);
402
 
403
+ // Optional: an auto threadpool gets created in ggml if not passed explicitly
404
+ LLAMA_API void llama_attach_threadpool(
405
+ struct llama_context * ctx,
406
+ ggml_threadpool_t threadpool,
407
+ ggml_threadpool_t threadpool_batch);
408
+ LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
409
+
410
  // Call once at the end of the program - currently only used for MPI
411
  LLAMA_API void llama_backend_free(void);
412
 
413
  LLAMA_API struct llama_model * llama_load_model_from_file(
414
  const char * path_model,
415
+ struct llama_model_params params);
416
 
417
  LLAMA_API void llama_free_model(struct llama_model * model);
418
 
419
+ // TODO: rename to llama_init_from_model
420
  LLAMA_API struct llama_context * llama_new_context_with_model(
421
  struct llama_model * model,
422
  struct llama_context_params params);
 
432
  LLAMA_API bool llama_supports_mlock (void);
433
  LLAMA_API bool llama_supports_gpu_offload(void);
434
 
 
 
435
  LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
436
  LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
437
  LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
438
  LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
439
 
 
 
440
  LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
441
  LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
442
  LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
443
  LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
444
+ LLAMA_API int32_t llama_n_head (const struct llama_model * model);
445
+
446
+ LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
447
+
448
+ LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
449
+ LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
450
+ LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
451
 
452
  // Get the model's RoPE frequency scaling factor
453
  LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
 
676
  //
677
 
678
  // Returns the *actual* size in bytes of the state
679
+ // (logits, embedding and kv_cache)
680
  // Only use when saving the state, not when restoring it, otherwise the size may be too small.
681
  LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
682
  LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
 
817
  // Set the number of threads used for decoding
818
  // n_threads is the number of threads used for generation (single token)
819
  // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
820
+ LLAMA_API void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch);
821
 
822
  // Get the number of threads used for generation of a single token.
823
+ LLAMA_API int32_t llama_n_threads(struct llama_context * ctx);
824
 
825
  // Get the number of threads used for prompt and batch processing (multiple token).
826
+ LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
827
 
828
  // Set whether the model is in embeddings mode or not
829
  // If true, embeddings will be returned but logits will not
 
979
  int32_t length);
980
 
981
  //
982
+ // Sampling API
983
+ //
984
+ // Sample usage:
985
+ //
986
+ // // prepare the sampling chain at the start
987
+ // auto sparams = llama_sampler_chain_default_params();
988
+ //
989
+ // llama_sampler * smpl = llama_sampler_chain_init(sparams);
990
+ //
991
+ // llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50));
992
+ // llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1));
993
+ // llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8));
994
+ //
995
+ // // typically, the chain should end with a sampler such as "greedy", "dist" or "mirostat"
996
+ // // this sampler will be responsible to select the actual token
997
+ // llama_sampler_chain_add(smpl, llama_sampler_init_dist(seed));
998
+ //
999
+ // ...
1000
+ //
1001
+ // // decoding loop:
1002
+ // while (...) {
1003
+ // ...
1004
+ //
1005
+ // llama_decode(ctx, batch);
1006
+ //
1007
+ // // sample from the logits of the last token in the batch
1008
+ // const llama_token id = llama_sampler_sample(smpl, ctx, -1);
1009
+ //
1010
+ // // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.)
1011
+ // llama_sampler_accept(smpl, id);
1012
+ // ...
1013
+ // }
1014
+ //
1015
+ // llama_sampler_free(smpl);
1016
+ //
1017
+ // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
1018
+ // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
1019
  //
1020
 
1021
+ typedef void * llama_sampler_context_t;
 
1022
 
1023
+ // user code can implement the interface below in order to create custom llama_sampler
1024
+ struct llama_sampler_i {
1025
+ const char * (*name) (const struct llama_sampler * smpl); // can be NULL
1026
+ void (*accept)( struct llama_sampler * smpl, llama_token token); // can be NULL
1027
+ void (*apply) ( struct llama_sampler * smpl, llama_token_data_array * cur_p); // required
1028
+ void (*reset) ( struct llama_sampler * smpl); // can be NULL
1029
+ struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
1030
+ void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL
1031
 
1032
+ // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
1033
+ //void (*apply_ggml) (struct llama_sampler * smpl, ...);
1034
+ };
1035
+
1036
+ struct llama_sampler {
1037
+ struct llama_sampler_i * iface;
1038
+ llama_sampler_context_t ctx;
1039
+ };
1040
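As an illustration of the interface above, here is a minimal sketch (not part of this commit) of a user-defined sampler that masks one "banned" token. The ban_token_* names are hypothetical, and the lifetime handling is an assumption: the free callback only releases the custom ctx, on the assumption that llama_sampler_free / the owning chain deletes the llama_sampler object itself.

    // requires "llama.h" and <cmath> (for INFINITY)
    struct ban_token_ctx {
        llama_token banned;
    };

    static struct llama_sampler * ban_token_init(llama_token banned); // forward decl, used by clone()

    static const char * ban_token_name(const struct llama_sampler * /*smpl*/) {
        return "ban-token";
    }

    static void ban_token_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
        const auto * bctx = (const ban_token_ctx *) smpl->ctx;
        for (size_t i = 0; i < cur_p->size; ++i) {
            if (cur_p->data[i].id == bctx->banned) {
                cur_p->data[i].logit = -INFINITY; // never selected by downstream samplers
            }
        }
    }

    static struct llama_sampler * ban_token_clone(const struct llama_sampler * smpl) {
        return ban_token_init(((const ban_token_ctx *) smpl->ctx)->banned);
    }

    static void ban_token_free(struct llama_sampler * smpl) {
        // assumption: llama_sampler_free() / the owning chain releases the llama_sampler itself,
        // so only the custom context is deleted here
        delete (ban_token_ctx *) smpl->ctx;
    }

    static struct llama_sampler_i ban_token_iface = {
        /*.name   =*/ ban_token_name,
        /*.accept =*/ nullptr,          // no per-token state to update
        /*.apply  =*/ ban_token_apply,
        /*.reset  =*/ nullptr,          // nothing to reset
        /*.clone  =*/ ban_token_clone,
        /*.free   =*/ ban_token_free,
    };

    static struct llama_sampler * ban_token_init(llama_token banned) {
        return new llama_sampler {
            /*.iface =*/ &ban_token_iface,
            /*.ctx   =*/ new ban_token_ctx { banned },
        };
    }

    // usage: llama_sampler_chain_add(chain, ban_token_init(some_token_id));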
 
1041
+ // mirror of llama_sampler_i:
1042
+ LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
1043
+ LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
1044
+ LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
1045
+ LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl);
1046
+ LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl);
1047
+ // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
1048
+ LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl);
1049
 
1050
+ // llama_sampler_chain
1051
+ // a type of llama_sampler that can chain multiple samplers one after another
1052
+
1053
+ LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params);
1054
+
1055
+ // important: takes ownership of the sampler object and will free it when llama_sampler_free is called
1056
+ LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl);
1057
+ LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
1058
+ LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain);
1059
+
1060
+ // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
1061
+ LLAMA_API struct llama_sampler * llama_sampler_chain_remove( struct llama_sampler * chain, int32_t i);
1062
+
1063
+ // available samplers:
1064
+
1065
+ LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void);
1066
+ LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
 
 
 
1067
 
1068
  /// @details Sorts candidate tokens by their logits in descending order and calculates probabilities based on logits.
1069
+ /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
1070
+ LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
 
1071
 
1072
  /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
1073
+ LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
 
 
 
 
1074
 
1075
  /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
1076
+ LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep);
 
 
 
 
1077
 
1078
  /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
1079
+ LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep);
 
 
 
 
1080
 
1081
  /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
1082
+ LLAMA_API struct llama_sampler * llama_sampler_init_tail_free (float z, size_t min_keep);
 
 
 
 
1083
 
1084
  /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
1085
+ LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep);
1086
+ LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t);
 
 
 
1087
 
1088
+ /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
1089
+ LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);
 
1090
 
1091
  /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
1092
  /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
 
1094
  /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
1095
  /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
1096
  /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
1097
+ LLAMA_API struct llama_sampler * llama_sampler_init_mirostat(
1098
+ int32_t n_vocab,
1099
+ uint32_t seed,
1100
+ float tau,
1101
+ float eta,
1102
+ int32_t m);
 
1103
 
1104
  /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
1105
  /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
1106
  /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
1107
  /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
1108
  /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
1109
+ LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2(
1110
+ uint32_t seed,
1111
+ float tau,
1112
+ float eta);
1113
+
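As a usage sketch (not part of this commit; the seed/tau/eta values are arbitrary and `ctx` is assumed to be an existing llama_context), a chain that lets Mirostat 2.0 pick the final token could be assembled like this:

    // build a sampler chain that ends with Mirostat 2.0 as the token selector
    auto sparams = llama_sampler_chain_default_params();
    llama_sampler * smpl = llama_sampler_chain_init(sparams);

    llama_sampler_chain_add(smpl, llama_sampler_init_mirostat_v2(/*seed =*/ 1234, /*tau =*/ 5.0f, /*eta =*/ 0.1f));

    // later, inside the decode loop:
    // const llama_token id = llama_sampler_sample(smpl, ctx, -1);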
1114
+ LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
1115
+ const struct llama_model * model,
1116
+ const char * grammar_str,
1117
+ const char * grammar_root);
1118
+
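A sketch of attaching the grammar sampler (illustrative only, not from this commit; the GBNF string is a hypothetical example and `smpl`/`model` are assumed to exist):

    // constrain generation to a literal "yes" or "no"
    const char * grammar_str  = "root ::= \"yes\" | \"no\"";
    const char * grammar_root = "root";

    llama_sampler_chain_add(smpl, llama_sampler_init_grammar(model, grammar_str, grammar_root));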
1119
+ LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
1120
+ int32_t n_vocab, // llama_n_vocab()
1121
+ llama_token special_eos_id, // llama_token_eos()
1122
+ llama_token linefeed_id, // llama_token_nl()
1123
+ int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size)
1124
+ float penalty_repeat, // 1.0 = disabled
1125
+ float penalty_freq, // 0.0 = disabled
1126
+ float penalty_present, // 0.0 = disabled
1127
+ bool penalize_nl, // consider newlines as a repeatable token
1128
+ bool ignore_eos); // ignore the end-of-sequence token
1129
+
1130
+ LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
1131
+ int32_t n_vocab,
1132
+ int32_t n_logit_bias,
1133
+ const llama_logit_bias * logit_bias);
1134
+
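For illustration (a sketch, not part of the commit): the penalties sampler can reproduce the repeat-penalty settings that this commit removes from talk-llama.cpp (repeat_last_n = 256, repeat_penalty = 1.1764), and llama_sampler_init_logit_bias can nudge or ban individual tokens. The token ids 123/456 are placeholders, and `smpl`/`model` are assumed to exist:

    llama_sampler_chain_add(smpl, llama_sampler_init_penalties(
        llama_n_vocab  (model),
        llama_token_eos(model),
        llama_token_nl (model),
        /*penalty_last_n  =*/ 256,
        /*penalty_repeat  =*/ 1.1764f,
        /*penalty_freq    =*/ 0.0f,
        /*penalty_present =*/ 0.0f,
        /*penalize_nl     =*/ false,
        /*ignore_eos      =*/ false));

    const llama_logit_bias biases[] = {
        { /*token =*/ 123,  1.5f },      // placeholder id, made more likely
        { /*token =*/ 456, -INFINITY },  // placeholder id, effectively banned (INFINITY from <cmath>)
    };
    llama_sampler_chain_add(smpl, llama_sampler_init_logit_bias(llama_n_vocab(model), 2, biases));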
1135
+
1136
+ // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
1137
+ LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
1138
+
1139
+ /// @details Sample and accept a token from the idx-th output of the last evaluation
1140
+ //
1141
+ // Shorthand for:
1142
+ // const auto * logits = llama_get_logits_ith(ctx, idx);
1143
+ // llama_token_data_array cur_p = { ... init from logits ... };
1144
+ // llama_sampler_apply(smpl, &cur_p);
1145
+ // auto token = cur_p.data[cur_p.selected].id;
1146
+ // llama_sampler_accept(smpl, token);
1147
+ // return token;
1148
+ // Returns the sampled token
1149
+ LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
1150
+
1151
+ // TODO: extend in the future
1152
+ //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...);
1153
 
1154
  //
1155
  // Model split
 
1165
  // Returns the split_prefix length.
1166
  LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count);
1167
 
 
1168
  // Print system information
1169
  LLAMA_API const char * llama_print_system_info(void);
1170
 
 
1172
  // If this is not called, or NULL is supplied, everything is output on stderr.
1173
  LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
1174
 
1175
+ //
1176
+ // Performance utils
1177
+ //
1178
+ // NOTE: Used by llama.cpp examples, avoid using in third-party apps. Instead, do your own performance measurements.
1179
+ //
 
1180
 
1181
+ struct llama_perf_context_data {
1182
+ double t_start_ms;
1183
+ double t_load_ms;
1184
+ double t_p_eval_ms;
1185
+ double t_eval_ms;
1186
 
1187
+ int32_t n_p_eval;
1188
+ int32_t n_eval;
1189
+ };
1190
 
1191
+ struct llama_perf_sampler_data {
1192
+ double t_sample_ms;
1193
 
1194
+ int32_t n_sample;
1195
+ };
 
 
 
1196
 
1197
+ LLAMA_API struct llama_perf_context_data llama_perf_context (const struct llama_context * ctx);
1198
+ LLAMA_API void llama_perf_context_print(const struct llama_context * ctx);
1199
+ LLAMA_API void llama_perf_context_reset( struct llama_context * ctx);
 
1200
 
1201
+ // NOTE: the following work only with samplers constructed via llama_sampler_chain_init
1202
+ LLAMA_API struct llama_perf_sampler_data llama_perf_sampler (const struct llama_sampler * chain);
1203
+ LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain);
1204
+ LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain);
1205
 
1206
+ LLAMA_API void llama_perf_dump_yaml(FILE * stream, const struct llama_context * ctx);
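A usage sketch for the data getters above (not from this commit; `ctx` and `smpl` are assumed to exist, with `smpl` created via llama_sampler_chain_init, and fprintf requires <cstdio>):

    const llama_perf_context_data pd = llama_perf_context(ctx);
    fprintf(stderr, "load: %.2f ms, prompt: %d tok / %.2f ms, eval: %d tok / %.2f ms\n",
            pd.t_load_ms, pd.n_p_eval, pd.t_p_eval_ms, pd.n_eval, pd.t_eval_ms);

    const llama_perf_sampler_data sd = llama_perf_sampler(smpl);
    fprintf(stderr, "sampling: %d samples / %.2f ms\n", sd.n_sample, sd.t_sample_ms);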
 
 
1207
 
1208
+ #ifdef __cplusplus
1209
+ }
1210
+ #endif
1211
 
1212
  #endif // LLAMA_H
examples/talk-llama/talk-llama.cpp CHANGED
@@ -314,7 +314,6 @@ int main(int argc, char ** argv) {
314
 
315
  // tune these to your liking
316
  lcparams.n_ctx = 2048;
317
- lcparams.seed = 1;
318
  lcparams.n_threads = params.n_threads;
319
  lcparams.flash_attn = params.flash_attn;
320
 
@@ -402,6 +401,26 @@ int main(int argc, char ** argv) {
402
 
403
  llama_batch batch = llama_batch_init(llama_n_ctx(ctx_llama), 0, 1);
404
 
405
  // init session
406
  std::string path_session = params.path_session;
407
  std::vector<llama_token> session_tokens;
@@ -700,54 +719,13 @@ int main(int argc, char ** argv) {
700
 
701
  {
702
  // out of user input, sample next token
703
- const float top_k = 5;
704
- const float top_p = 0.80f;
705
- const float temp = 0.30f;
706
- const float repeat_penalty = 1.1764f;
707
-
708
- const int repeat_last_n = 256;
709
 
710
  if (!path_session.empty() && need_to_save_session) {
711
  need_to_save_session = false;
712
  llama_state_save_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.size());
713
  }
714
 
715
- llama_token id = 0;
716
-
717
- {
718
- auto logits = llama_get_logits(ctx_llama);
719
- auto n_vocab = llama_n_vocab(model_llama);
720
-
721
- logits[llama_token_eos(model_llama)] = 0;
722
-
723
- std::vector<llama_token_data> candidates;
724
- candidates.reserve(n_vocab);
725
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
726
- candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
727
- }
728
-
729
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
730
-
731
- // apply repeat penalty
732
- const float nl_logit = logits[llama_token_nl(model_llama)];
733
-
734
- llama_sample_repetition_penalties(ctx_llama, &candidates_p,
735
- embd_inp.data() + std::max(0, n_past - repeat_last_n),
736
- repeat_last_n, repeat_penalty, 0.0, 0.0f);
737
-
738
- logits[llama_token_nl(model_llama)] = nl_logit;
739
-
740
- if (temp <= 0) {
741
- // Greedy sampling
742
- id = llama_sample_token_greedy(ctx_llama, &candidates_p);
743
- } else {
744
- // Temperature sampling
745
- llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
746
- llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
747
- llama_sample_temp (ctx_llama, &candidates_p, temp);
748
- id = llama_sample_token(ctx_llama, &candidates_p);
749
- }
750
- }
751
 
752
  if (id != llama_token_eos(model_llama)) {
753
  // add it to the context
@@ -797,8 +775,14 @@ int main(int argc, char ** argv) {
797
  whisper_print_timings(ctx_wsp);
798
  whisper_free(ctx_wsp);
799
 
800
- llama_print_timings(ctx_llama);
 
801
  llama_free(ctx_llama);
802
 
 
 
803
  return 0;
804
  }
 
314
 
315
  // tune these to your liking
316
  lcparams.n_ctx = 2048;
 
317
  lcparams.n_threads = params.n_threads;
318
  lcparams.flash_attn = params.flash_attn;
319
 
 
401
 
402
  llama_batch batch = llama_batch_init(llama_n_ctx(ctx_llama), 0, 1);
403
 
404
+ // init sampler
405
+ const float top_k = 5;
406
+ const float top_p = 0.80f;
407
+ const float temp = 0.30f;
408
+
409
+ const int seed = 0;
410
+
411
+ auto sparams = llama_sampler_chain_default_params();
412
+
413
+ llama_sampler * smpl = llama_sampler_chain_init(sparams);
414
+
415
+ if (temp > 0.0f) {
416
+ llama_sampler_chain_add(smpl, llama_sampler_init_top_k(top_k));
417
+ llama_sampler_chain_add(smpl, llama_sampler_init_top_p(top_p, 1));
418
+ llama_sampler_chain_add(smpl, llama_sampler_init_temp (temp));
419
+ llama_sampler_chain_add(smpl, llama_sampler_init_dist (seed));
420
+ } else {
421
+ llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
422
+ }
423
+
424
  // init session
425
  std::string path_session = params.path_session;
426
  std::vector<llama_token> session_tokens;
 
719
 
720
  {
721
  // out of user input, sample next token
 
722
 
723
  if (!path_session.empty() && need_to_save_session) {
724
  need_to_save_session = false;
725
  llama_state_save_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.size());
726
  }
727
 
728
+ const llama_token id = llama_sampler_sample(smpl, ctx_llama, -1);
 
729
 
730
  if (id != llama_token_eos(model_llama)) {
731
  // add it to the context
 
775
  whisper_print_timings(ctx_wsp);
776
  whisper_free(ctx_wsp);
777
 
778
+ llama_perf_sampler_print(smpl);
779
+ llama_perf_context_print(ctx_llama);
780
+
781
+ llama_sampler_free(smpl);
782
+ llama_batch_free(batch);
783
  llama_free(ctx_llama);
784
 
785
+ llama_backend_free();
786
+
787
  return 0;
788
  }
examples/talk-llama/unicode.cpp CHANGED
@@ -5,6 +5,7 @@
5
  #include "unicode.h"
6
  #include "unicode-data.h"
7
 
 
8
  #include <cassert>
9
  #include <cstddef>
10
  #include <cstdint>
 
5
  #include "unicode.h"
6
  #include "unicode-data.h"
7
 
8
+ #include <algorithm>
9
  #include <cassert>
10
  #include <cstddef>
11
  #include <cstdint>
src/whisper.cpp CHANGED
@@ -177,7 +177,7 @@ static bool ggml_graph_compute_helper(
177
  int n_threads,
178
  ggml_abort_callback abort_callback,
179
  void * abort_callback_data) {
180
- struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
181
 
182
  plan.abort_callback = abort_callback;
183
  plan.abort_callback_data = abort_callback_data;
@@ -2894,7 +2894,7 @@ static bool whisper_decode_internal(
2894
  ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
2895
  }
2896
 
2897
- logits = gf->nodes[gf->n_nodes - 1];
2898
 
2899
  if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
2900
  return false;
 
177
  int n_threads,
178
  ggml_abort_callback abort_callback,
179
  void * abort_callback_data) {
180
+ struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr);
181
 
182
  plan.abort_callback = abort_callback;
183
  plan.abort_callback_data = abort_callback_data;
 
2894
  ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
2895
  }
2896
 
2897
+ logits = ggml_graph_node(gf, -1);
2898
 
2899
  if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
2900
  return false;