talk-llama : sync llama.cpp
- Makefile +6 -4
- examples/CMakeLists.txt +4 -2
- examples/talk-llama/llama-grammar.cpp +721 -122
- examples/talk-llama/llama-grammar.h +120 -15
- examples/talk-llama/llama-impl.h +135 -1
- examples/talk-llama/llama-sampling.cpp +1373 -301
- examples/talk-llama/llama-sampling.h +20 -47
- examples/talk-llama/llama-vocab.cpp +141 -12
- examples/talk-llama/llama-vocab.h +12 -7
- examples/talk-llama/llama.cpp +0 -0
- examples/talk-llama/llama.h +218 -249
- examples/talk-llama/talk-llama.cpp +28 -44
- examples/talk-llama/unicode.cpp +1 -0
- src/whisper.cpp +2 -2
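
Note: taken together, the grammar changes in this sync replace the old llama_grammar_copy_impl / llama_grammar_sample_impl / llama_grammar_accept_token_impl entry points with a vocab-aware grammar object that can be built directly from a GBNF string. The sketch below shows how the new internal API (from the diffs that follow) fits together; the surrounding setup (vocab, cur_p, the token choice) is hypothetical and error handling is minimal.

    // Sketch only: wiring of the new grammar API from this sync.
    // Assumed to exist: a populated llama_vocab and a llama_token_data_array.
    #include "llama-grammar.h"

    void constrained_step(const llama_vocab * vocab, llama_token_data_array * cur_p) {
        // build a grammar from a GBNF string; parsing happens inside init_impl
        struct llama_grammar * grammar = llama_grammar_init_impl(vocab, "root ::= \"yes\" | \"no\"", "root");
        if (grammar == nullptr) {
            return; // parse error, missing "root" rule, or left recursion
        }

        // mask candidates the grammar cannot accept (sets their logit to -INFINITY)
        llama_grammar_apply_impl(*grammar, cur_p);

        // ... sample a token from cur_p ...
        const llama_token token = cur_p->data[0].id; // placeholder choice

        // advance the pushdown stacks past the accepted token
        llama_grammar_accept_impl(*grammar, token);

        llama_grammar_free_impl(grammar);
    }

llama_grammar_init_impl returns nullptr on failure, so callers must check the result before applying the grammar.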
Makefile
CHANGED
@@ -1080,10 +1080,12 @@ lsp: examples/lsp/lsp.cpp \
 	$(CXX) $(CXXFLAGS) $(CFLAGS_SDL) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LDFLAGS_SDL)
 
-talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp \
-	$(OBJ_GGML) $(OBJ_WHISPER) $(OBJ_COMMON) $(OBJ_SDL)
-	$(CXX) $(CXXFLAGS) $(CFLAGS_SDL) -c $< -o $(call GET_OBJ_FILE, $<)
-	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LDFLAGS_SDL)
+# TODO: disabled until update
+# https://github.com/ggerganov/whisper.cpp/issues/1818
+#talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp \
+#	$(OBJ_GGML) $(OBJ_WHISPER) $(OBJ_COMMON) $(OBJ_SDL)
+#	$(CXX) $(CXXFLAGS) $(CFLAGS_SDL) -c $< -o $(call GET_OBJ_FILE, $<)
+#	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LDFLAGS_SDL)
 
 talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp examples/talk-llama/llama-vocab.cpp examples/talk-llama/llama-grammar.cpp examples/talk-llama/llama-sampling.cpp examples/talk-llama/unicode.cpp examples/talk-llama/unicode-data.cpp \
 	$(OBJ_GGML) $(OBJ_WHISPER) $(OBJ_COMMON) $(OBJ_SDL)
examples/CMakeLists.txt
CHANGED
@@ -127,8 +127,10 @@ endif (WHISPER_SDL2)
 add_subdirectory(quantize)
 set_target_properties(quantize PROPERTIES FOLDER "examples")
 if (WHISPER_SDL2)
-    add_subdirectory(talk)
-    set_target_properties(talk PROPERTIES FOLDER "examples")
+    # TODO: disabled until update
+    # https://github.com/ggerganov/whisper.cpp/issues/1818
+    #add_subdirectory(talk)
+    #set_target_properties(talk PROPERTIES FOLDER "examples")
     add_subdirectory(talk-llama)
     set_target_properties(talk-llama PROPERTIES FOLDER "examples")
     add_subdirectory(lsp)
examples/talk-llama/llama-grammar.cpp
CHANGED
@@ -3,11 +3,31 @@
 #include "llama-vocab.h"
 #include "llama-sampling.h"
 
+#include <cmath>
 #include <algorithm>
+#include <stdexcept>
 
-//
-//
-std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
+//
+// helpers
+//
+
+// NOTE: assumes valid utf8 (but checks for overrun)
+static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
+    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+    uint8_t  first_byte = static_cast<uint8_t>(*src);
+    uint8_t  highbits   = first_byte >> 4;
+    int      len        = lookup[highbits];
+    uint8_t  mask       = (1 << (8 - len)) - 1;
+    uint32_t value      = first_byte & mask;
+    const char * end    = src + len; // may overrun!
+    const char * pos    = src + 1;
+    for ( ; pos < end && *pos; pos++) {
+        value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+    }
+    return std::make_pair(value, pos);
+}
+
+static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
         const std::string & src,
         llama_partial_utf8 partial_start) {
     static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };

@@ -40,7 +60,7 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
     while (*pos != 0) {
         uint8_t first_byte = static_cast<uint8_t>(*pos);
         uint8_t highbits   = first_byte >> 4;
-
+        n_remain = lookup[highbits] - 1;
 
         if (n_remain < 0) {
             // invalid sequence, abort

@@ -50,7 +70,7 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
     }
 
     uint8_t mask = (1 << (7 - n_remain)) - 1;
-
+    value = first_byte & mask;
 
     ++pos;
     while (*pos != 0 && n_remain > 0) {

@@ -67,12 +87,510 @@
     return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
 }
 
-const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
-    return grammar->rules;
+static bool is_digit_char(char c) {
+    return '0' <= c && c <= '9';
 }
 
-llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
-    return grammar->stacks;
+static bool is_word_char(char c) {
+    return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
+}
+
+static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
+    const char * pos   = src;
+    const char * end   = src + size;
+    uint32_t     value = 0;
+    for ( ; pos < end && *pos; pos++) {
+        value <<= 4;
+        char c = *pos;
+        if ('a' <= c && c <= 'f') {
+            value += c - 'a' + 10;
+        } else if ('A' <= c && c <= 'F') {
+            value += c - 'A' + 10;
+        } else if ('0' <= c && c <= '9') {
+            value += c - '0';
+        } else {
+            break;
+        }
+    }
+    if (pos != end) {
+        throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
+    }
+    return std::make_pair(value, pos);
+}
+
+static const char * parse_space(const char * src, bool newline_ok) {
+    const char * pos = src;
+    while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
+            (newline_ok && (*pos == '\r' || *pos == '\n'))) {
+        if (*pos == '#') {
+            while (*pos && *pos != '\r' && *pos != '\n') {
+                pos++;
+            }
+        } else {
+            pos++;
+        }
+    }
+    return pos;
+}
+
+static const char * parse_name(const char * src) {
+    const char * pos = src;
+    while (is_word_char(*pos)) {
+        pos++;
+    }
+    if (pos == src) {
+        throw std::runtime_error(std::string("expecting name at ") + src);
+    }
+    return pos;
+}
+
+static const char * parse_int(const char * src) {
+    const char * pos = src;
+    while (is_digit_char(*pos)) {
+        pos++;
+    }
+    if (pos == src) {
+        throw std::runtime_error(std::string("expecting integer at ") + src);
+    }
+    return pos;
+}
+
+static std::pair<uint32_t, const char *> parse_char(const char * src) {
+    if (*src == '\\') {
+        switch (src[1]) {
+            case 'x': return parse_hex(src + 2, 2);
+            case 'u': return parse_hex(src + 2, 4);
+            case 'U': return parse_hex(src + 2, 8);
+            case 't': return std::make_pair('\t', src + 2);
+            case 'r': return std::make_pair('\r', src + 2);
+            case 'n': return std::make_pair('\n', src + 2);
+            case '\\':
+            case '"':
+            case '[':
+            case ']':
+                return std::make_pair(src[1], src + 2);
+            default:
+                throw std::runtime_error(std::string("unknown escape at ") + src);
+        }
+    } else if (*src) {
+        return decode_utf8(src);
+    }
+    throw std::runtime_error("unexpected end of input");
+}
+
+static void print_grammar_char(FILE * file, uint32_t c) {
+    if (0x20 <= c && c <= 0x7f) {
+        fprintf(file, "%c", static_cast<char>(c));
+    } else {
+        // cop out of encoding UTF-8
+        fprintf(file, "<U+%04X>", c);
+    }
+}
+
+static bool is_char_element(llama_grammar_element elem) {
+    switch (elem.type) {
+        case LLAMA_GRETYPE_CHAR:           return true;
+        case LLAMA_GRETYPE_CHAR_NOT:       return true;
+        case LLAMA_GRETYPE_CHAR_ALT:       return true;
+        case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
+        case LLAMA_GRETYPE_CHAR_ANY:       return true;
+        default:                           return false;
+    }
+}
+
+static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
+    for (auto elem : rule) {
+        switch (elem.type) {
+            case LLAMA_GRETYPE_END:            fprintf(file, "END");            break;
+            case LLAMA_GRETYPE_ALT:            fprintf(file, "ALT");            break;
+            case LLAMA_GRETYPE_RULE_REF:       fprintf(file, "RULE_REF");       break;
+            case LLAMA_GRETYPE_CHAR:           fprintf(file, "CHAR");           break;
+            case LLAMA_GRETYPE_CHAR_NOT:       fprintf(file, "CHAR_NOT");       break;
+            case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
+            case LLAMA_GRETYPE_CHAR_ALT:       fprintf(file, "CHAR_ALT");       break;
+            case LLAMA_GRETYPE_CHAR_ANY:       fprintf(file, "CHAR_ANY");       break;
+        }
+        switch (elem.type) {
+            case LLAMA_GRETYPE_END:
+            case LLAMA_GRETYPE_ALT:
+            case LLAMA_GRETYPE_RULE_REF:
+                fprintf(file, "(%u) ", elem.value);
+                break;
+            case LLAMA_GRETYPE_CHAR:
+            case LLAMA_GRETYPE_CHAR_NOT:
+            case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+            case LLAMA_GRETYPE_CHAR_ALT:
+            case LLAMA_GRETYPE_CHAR_ANY:
+                fprintf(file, "(\"");
+                print_grammar_char(file, elem.value);
+                fprintf(file, "\") ");
+                break;
+        }
+    }
+    fprintf(file, "\n");
+}
+
+static void print_rule(
+        FILE     * file,
+        uint32_t   rule_id,
+        const llama_grammar_rule & rule,
+        const std::map<uint32_t, std::string> & symbol_id_names) {
+    if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
+        throw std::runtime_error(
+            "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
+    }
+    fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
+    for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
+        llama_grammar_element elem = rule[i];
+        switch (elem.type) {
+            case LLAMA_GRETYPE_END:
+                throw std::runtime_error(
+                    "unexpected end of rule: " + std::to_string(rule_id) + "," +
+                    std::to_string(i));
+            case LLAMA_GRETYPE_ALT:
+                fprintf(file, "| ");
+                break;
+            case LLAMA_GRETYPE_RULE_REF:
+                fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
+                break;
+            case LLAMA_GRETYPE_CHAR:
+                fprintf(file, "[");
+                print_grammar_char(file, elem.value);
+                break;
+            case LLAMA_GRETYPE_CHAR_NOT:
+                fprintf(file, "[^");
+                print_grammar_char(file, elem.value);
+                break;
+            case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+                if (i == 0 || !is_char_element(rule[i - 1])) {
+                    throw std::runtime_error(
+                        "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
+                        std::to_string(rule_id) + "," + std::to_string(i));
+                }
+                fprintf(file, "-");
+                print_grammar_char(file, elem.value);
+                break;
+            case LLAMA_GRETYPE_CHAR_ALT:
+                if (i == 0 || !is_char_element(rule[i - 1])) {
+                    throw std::runtime_error(
+                        "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
+                        std::to_string(rule_id) + "," + std::to_string(i));
+                }
+                print_grammar_char(file, elem.value);
+                break;
+            case LLAMA_GRETYPE_CHAR_ANY:
+                fprintf(file, ".");
+                break;
+        }
+        if (is_char_element(elem)) {
+            switch (rule[i + 1].type) {
+                case LLAMA_GRETYPE_CHAR_ALT:
+                case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+                case LLAMA_GRETYPE_CHAR_ANY:
+                    break;
+                default:
+                    fprintf(file, "] ");
+            }
+        }
+    }
+    fprintf(file, "\n");
+}
+
+//
+// implementation
+//
+
+uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) {
+    uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
+    auto result = symbol_ids.emplace(std::string(src, len), next_id);
+    return result.first->second;
+}
+
+uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) {
+    uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
+    symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
+    return next_id;
+}
+
+void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) {
+    if (rules.size() <= rule_id) {
+        rules.resize(rule_id + 1);
+    }
+    rules[rule_id] = rule;
+}
+
+const char * llama_grammar_parser::parse_alternates(
+        const char        * src,
+        const std::string & rule_name,
+        uint32_t            rule_id,
+        bool                is_nested) {
+    llama_grammar_rule rule;
+    const char * pos = parse_sequence(src, rule_name, rule, is_nested);
+    while (*pos == '|') {
+        rule.push_back({LLAMA_GRETYPE_ALT, 0});
+        pos = parse_space(pos + 1, true);
+        pos = parse_sequence(pos, rule_name, rule, is_nested);
+    }
+    rule.push_back({LLAMA_GRETYPE_END, 0});
+    add_rule(rule_id, rule);
+    return pos;
+}
+
+const char * llama_grammar_parser::parse_sequence(
+        const char         * src,
+        const std::string  & rule_name,
+        llama_grammar_rule & rule,
+        bool                 is_nested) {
+    size_t last_sym_start = rule.size();
+    const char * pos = src;
+
+    auto handle_repetitions = [&](int min_times, int max_times) {
+
+        if (last_sym_start == rule.size()) {
+            throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
+        }
+
+        // apply transformation to previous symbol (last_sym_start to end) according to
+        // the following rewrite rules:
+        // S{m,n} --> S S S (m times) S'(n-m)
+        //            S'(x)   ::= S S'(x-1) |
+        //            (... n-m definitions of these S' rules ...)
+        //            S'(1)   ::= S |
+        // S{m,} -->  S S S (m times) S'
+        //            S'      ::= S S' |
+        // S*     --> S{0,}
+        //        --> S'      ::= S S' |
+        // S+     --> S{1,}
+        //        --> S S'
+        //            S'      ::= S S' |
+        // S?     --> S{0,1}
+        //        --> S'
+        //            S'      ::= S |
+
+        llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
+        if (min_times == 0) {
+            rule.resize(last_sym_start);
+        } else {
+            // Repeat the previous elements (min_times - 1) times
+            for (int i = 1; i < min_times; i++) {
+                rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
+            }
+        }
+
+        uint32_t last_rec_rule_id = 0;
+        auto n_opt = max_times < 0 ? 1 : max_times - min_times;
+
+        llama_grammar_rule rec_rule(prev_rule);
+        for (int i = 0; i < n_opt; i++) {
+            rec_rule.resize(prev_rule.size());
+            uint32_t rec_rule_id = generate_symbol_id(rule_name);
+            if (i > 0 || max_times < 0) {
+                rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
+            }
+            rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
+            rec_rule.push_back({LLAMA_GRETYPE_END, 0});
+            add_rule(rec_rule_id, rec_rule);
+            last_rec_rule_id = rec_rule_id;
+        }
+        if (n_opt > 0) {
+            rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
+        }
+    };
+
+    while (*pos) {
+        if (*pos == '"') { // literal string
+            pos++;
+            last_sym_start = rule.size();
+            while (*pos != '"') {
+                if (!*pos) {
+                    throw std::runtime_error("unexpected end of input");
+                }
+                auto char_pair = parse_char(pos);
+                pos = char_pair.second;
+                rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
+            }
+            pos = parse_space(pos + 1, is_nested);
+        } else if (*pos == '[') { // char range(s)
+            pos++;
+            enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
+            if (*pos == '^') {
+                pos++;
+                start_type = LLAMA_GRETYPE_CHAR_NOT;
+            }
+            last_sym_start = rule.size();
+            while (*pos != ']') {
+                if (!*pos) {
+                    throw std::runtime_error("unexpected end of input");
+                }
+                auto char_pair = parse_char(pos);
+                pos = char_pair.second;
+                enum llama_gretype type = last_sym_start < rule.size()
+                    ? LLAMA_GRETYPE_CHAR_ALT
+                    : start_type;
+
+                rule.push_back({type, char_pair.first});
+                if (pos[0] == '-' && pos[1] != ']') {
+                    if (!pos[1]) {
+                        throw std::runtime_error("unexpected end of input");
+                    }
+                    auto endchar_pair = parse_char(pos + 1);
+                    pos = endchar_pair.second;
+                    rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
+                }
+            }
+            pos = parse_space(pos + 1, is_nested);
+        } else if (is_word_char(*pos)) { // rule reference
+            const char * name_end    = parse_name(pos);
+            uint32_t     ref_rule_id = get_symbol_id(pos, name_end - pos);
+            pos = parse_space(name_end, is_nested);
+            last_sym_start = rule.size();
+            rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
+        } else if (*pos == '(') { // grouping
+            // parse nested alternates into synthesized rule
+            pos = parse_space(pos + 1, true);
+            uint32_t sub_rule_id = generate_symbol_id(rule_name);
+            pos = parse_alternates(pos, rule_name, sub_rule_id, true);
+            last_sym_start = rule.size();
+            // output reference to synthesized rule
+            rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
+            if (*pos != ')') {
+                throw std::runtime_error(std::string("expecting ')' at ") + pos);
+            }
+            pos = parse_space(pos + 1, is_nested);
+        } else if (*pos == '.') { // any char
+            last_sym_start = rule.size();
+            rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
+            pos = parse_space(pos + 1, is_nested);
+        } else if (*pos == '*') {
+            pos = parse_space(pos + 1, is_nested);
+            handle_repetitions(0, -1);
+        } else if (*pos == '+') {
+            pos = parse_space(pos + 1, is_nested);
+            handle_repetitions(1, -1);
+        } else if (*pos == '?') {
+            pos = parse_space(pos + 1, is_nested);
+            handle_repetitions(0, 1);
+        } else if (*pos == '{') {
+            pos = parse_space(pos + 1, is_nested);
+
+            if (!is_digit_char(*pos)) {
+                throw std::runtime_error(std::string("expecting an int at ") + pos);
+            }
+            const char * int_end = parse_int(pos);
+            int min_times = std::stoul(std::string(pos, int_end - pos));
+            pos = parse_space(int_end, is_nested);
+
+            int max_times = -1;
+
+            if (*pos == '}') {
+                max_times = min_times;
+                pos = parse_space(pos + 1, is_nested);
+            } else if (*pos == ',') {
+                pos = parse_space(pos + 1, is_nested);
+
+                if (is_digit_char(*pos)) {
+                    const char * int_end = parse_int(pos);
+                    max_times = std::stoul(std::string(pos, int_end - pos));
+                    pos = parse_space(int_end, is_nested);
+                }
+
+                if (*pos != '}') {
+                    throw std::runtime_error(std::string("expecting '}' at ") + pos);
+                }
+                pos = parse_space(pos + 1, is_nested);
+            } else {
+                throw std::runtime_error(std::string("expecting ',' at ") + pos);
+            }
+            handle_repetitions(min_times, max_times);
+        } else {
+            break;
+        }
+    }
+    return pos;
+}
+
+const char * llama_grammar_parser::parse_rule(const char * src) {
+    const char * name_end = parse_name(src);
+    const char * pos      = parse_space(name_end, false);
+    size_t       name_len = name_end - src;
+    uint32_t     rule_id  = get_symbol_id(src, name_len);
+    const std::string name(src, name_len);
+
+    if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
+        throw std::runtime_error(std::string("expecting ::= at ") + pos);
+    }
+    pos = parse_space(pos + 3, true);
+
+    pos = parse_alternates(pos, name, rule_id, false);
+
+    if (*pos == '\r') {
+        pos += pos[1] == '\n' ? 2 : 1;
+    } else if (*pos == '\n') {
+        pos++;
+    } else if (*pos) {
+        throw std::runtime_error(std::string("expecting newline or end at ") + pos);
+    }
+    return parse_space(pos, true);
+}
+
+bool llama_grammar_parser::parse(const char * src) {
+    try {
+        const char * pos = parse_space(src, true);
+        while (*pos) {
+            pos = parse_rule(pos);
+        }
+        // Validate the state to ensure that all rules are defined
+        for (const auto & rule : rules) {
+            if (rule.empty()) {
+                throw std::runtime_error("Undefined rule");
+            }
+            for (const auto & elem : rule) {
+                if (elem.type == LLAMA_GRETYPE_RULE_REF) {
+                    // Ensure that the rule at that location exists
+                    if (elem.value >= rules.size() || rules[elem.value].empty()) {
+                        // Get the name of the rule that is missing
+                        for (const auto & kv : symbol_ids) {
+                            if (kv.second == elem.value) {
+                                throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } catch (const std::exception & err) {
+        fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
+        rules.clear();
+        return false;
+    }
+
+    return true;
+}
+
+void llama_grammar_parser::print(FILE * file) {
+    try {
+        std::map<uint32_t, std::string> symbol_id_names;
+        for (const auto & kv : symbol_ids) {
+            symbol_id_names[kv.second] = kv.first;
+        }
+        for (size_t i = 0, end = rules.size(); i < end; i++) {
+            // fprintf(file, "%zu: ", i);
+            // print_rule_binary(file, rules[i]);
+            print_rule(file, uint32_t(i), rules[i], symbol_id_names);
+            // fprintf(file, "\n");
+        }
+    } catch (const std::exception & err) {
+        fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
+    }
+}
+
+llama_grammar_stack llama_grammar_parser::c_rules() const {
+    llama_grammar_stack ret;
+    ret.reserve(rules.size());
+    for (const auto & rule : rules) {
+        ret.push_back(rule.data());
+    }
+    return ret;
 }
 
 // returns true iff pos points to the end of one of the definitions of a rule

@@ -89,7 +607,6 @@ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos)
 static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
         const llama_grammar_element * pos,
         const uint32_t                chr) {
-
     bool found            = false;
     bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
 

@@ -225,16 +742,93 @@ static void llama_grammar_advance_stack(
     }
 }
 
-// takes a set of possible pushdown stacks on a grammar, which are required to
-// be positioned at a character range (see `llama_grammar_advance_stack`), and
-// produces the N possible stacks if the given char is accepted at those
-// positions
+static llama_grammar_candidates llama_grammar_reject_candidates(
+        const llama_grammar_rules      & rules,
+        const llama_grammar_stacks     & stacks,
+        const llama_grammar_candidates & candidates) {
+    GGML_ASSERT(!stacks.empty()); // REVIEW
+
+    if (candidates.empty()) {
+        return {};
+    }
+
+    auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
+
+    for (size_t i = 1, size = stacks.size(); i < size; ++i) {
+        rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
+    }
+
+    return rejects;
+}
+
+static bool llama_grammar_detect_left_recursion(
+        const llama_grammar_rules & rules,
+        size_t rule_index,
+        std::vector<bool> * rules_visited,
+        std::vector<bool> * rules_in_progress,
+        std::vector<bool> * rules_may_be_empty) {
+    if ((*rules_in_progress)[rule_index]) {
+        return true;
+    }
+
+    (*rules_in_progress)[rule_index] = true;
+
+    const llama_grammar_rule & rule = rules[rule_index];
+
+    // First check if the rule might produce the empty string. This could be done combined with the second
+    // step but it's more readable as two steps.
+    bool at_rule_start = true;
+    for (size_t i = 0; i < rule.size(); i++) {
+        if (llama_grammar_is_end_of_sequence(&rule[i])) {
+            if (at_rule_start) {
+                (*rules_may_be_empty)[rule_index] = true;
+                break;
+            }
+            at_rule_start = true;
+        } else {
+            at_rule_start = false;
+        }
+    }
+
+    // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
+    // be empty)
+    bool recurse_into_nonterminal = true;
+    for (size_t i = 0; i < rule.size(); i++) {
+        if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
+            if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
+                return true;
+            }
+            if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
+                recurse_into_nonterminal = false;
+            }
+        } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
+            recurse_into_nonterminal = true;
+        } else {
+            recurse_into_nonterminal = false;
+        }
+    }
+
+    (*rules_in_progress)[rule_index] = false;
+    (*rules_visited)[rule_index] = true;
+
+    return false;
+}
+
+const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
+    return grammar->rules;
+}
+
+llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
+    return grammar->stacks;
+}
+
 void llama_grammar_accept(
         const llama_grammar_rules  & rules,
         const llama_grammar_stacks & stacks,
         const uint32_t chr,
-        llama_grammar_stacks & new_stacks) {
-    new_stacks.clear();
+        llama_grammar_stacks & stacks_new) {
+    stacks_new.clear();
+    stacks_new.reserve(stacks.size());
 
     for (const auto & stack : stacks) {
         if (stack.empty()) {

@@ -250,29 +844,11 @@ void llama_grammar_accept(
             if (!llama_grammar_is_end_of_sequence(pos)) {
                 new_stack.push_back(pos);
             }
-            llama_grammar_advance_stack(rules, new_stack, new_stacks);
+            llama_grammar_advance_stack(rules, new_stack, stacks_new);
         }
     }
 }
 
-static llama_grammar_candidates llama_grammar_reject_candidates(
-        const llama_grammar_rules      & rules,
-        const llama_grammar_stacks     & stacks,
-        const llama_grammar_candidates & candidates) {
-    GGML_ASSERT(!stacks.empty()); // REVIEW
-
-    if (candidates.empty()) {
-        return {};
-    }
-
-    auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
-
-    for (size_t i = 1, size = stacks.size(); i < size; ++i) {
-        rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
-    }
-    return rejects;
-}
-
 llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
         const llama_grammar_rules & rules,
         const llama_grammar_stack & stack,

@@ -328,72 +904,97 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
     return rejects;
 }
 
-static bool llama_grammar_detect_left_recursion(
-        const llama_grammar_rules & rules,
-        size_t rule_index,
-        std::vector<bool> * rules_visited,
-        std::vector<bool> * rules_in_progress,
-        std::vector<bool> * rules_may_be_empty) {
-    if ((*rules_in_progress)[rule_index]) {
-        return true;
-    }
-
-    (*rules_in_progress)[rule_index] = true;
-
-    const llama_grammar_rule & rule = rules[rule_index];
-
-    // First check if the rule might produce the empty string. This could be done combined with the second
-    // step but it's more readable as two steps.
-    bool at_rule_start = true;
-    for (size_t i = 0; i < rule.size(); i++) {
-        if (llama_grammar_is_end_of_sequence(&rule[i])) {
-            if (at_rule_start) {
-                (*rules_may_be_empty)[rule_index] = true;
-                break;
-            }
-            at_rule_start = true;
-        } else {
-            at_rule_start = false;
-        }
-    }
-
-    // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
-    // be empty)
-    bool recurse_into_nonterminal = true;
-    for (size_t i = 0; i < rule.size(); i++) {
-        if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
-            if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
-                return true;
-            }
-            if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
-                recurse_into_nonterminal = false;
-            }
-        } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
-            recurse_into_nonterminal = true;
-        } else {
-            recurse_into_nonterminal = false;
-        }
-    }
-
-    (*rules_in_progress)[rule_index] = false;
-    (*rules_visited)[rule_index] = true;
-
-    return false;
-}
-
-struct llama_grammar * llama_grammar_init_impl(
-        const llama_grammar_element ** rules,
-        size_t n_rules,
-        size_t start_rule_index) {
+////////////////////
+
+struct llama_grammar * llama_grammar_init_impl(
+        const struct llama_vocab * vocab,
+        const llama_grammar_element ** rules,
+        size_t n_rules,
+        size_t start_rule_index) {
+    const llama_grammar_element * pos;
+
+    // copy rule definitions into vectors
+    llama_grammar_rules vec_rules(n_rules);
+    for (size_t i = 0; i < n_rules; i++) {
+        for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
+            vec_rules[i].push_back(*pos);
+        }
+        vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
+    }
+
+    // Check for left recursion
+    std::vector<bool> rules_visited(n_rules);
+    std::vector<bool> rules_in_progress(n_rules);
+    std::vector<bool> rules_may_be_empty(n_rules);
+    for (size_t i = 0; i < n_rules; i++) {
+        if (rules_visited[i]) {
+            continue;
+        }
+        if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
+            LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
+            return nullptr;
+        }
+    }
+
+    // loop over alternates of start rule to build initial stacks
+    llama_grammar_stacks stacks;
+    pos = vec_rules[start_rule_index].data();
+    do {
+        llama_grammar_stack stack;
+        if (!llama_grammar_is_end_of_sequence(pos)) {
+            // if alternate is nonempty, add to stack
+            stack.push_back(pos);
+        }
+        llama_grammar_advance_stack(vec_rules, stack, stacks);
+        while (!llama_grammar_is_end_of_sequence(pos)) {
+            // scan to end of alternate def
+            pos++;
+        }
+        if (pos->type == LLAMA_GRETYPE_ALT) {
+            // there's another alternate def of this rule to process
+            pos++;
+        } else {
+            break;
+        }
+    } while (true);
+
+    // Important: vec_rules has to be moved here, not copied, because stacks contains
+    // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
+    // then the pointers would be invalidated when the local vec_rules goes out of scope.
+    return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
+}
+
+struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
+    llama_grammar_parser parser;
+
+    // if there is a grammar, parse it
+    if (!parser.parse(grammar_str)) {
+        return nullptr;
+    }
+
+    // will be empty (default) if there are parse errors
+    if (parser.rules.empty()) {
+        fprintf(stderr, "%s: failed to parse grammar\n", __func__);
+        return nullptr;
+    }
+
+    // Ensure that there is a "root" node.
+    if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) {
+        fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
+        return nullptr;
+    }
+
+    std::vector<const llama_grammar_element *> grammar_rules(parser.c_rules());
+
+    const size_t n_rules = grammar_rules.size();
+    const size_t start_rule_index = parser.symbol_ids.at(grammar_root);
+
     const llama_grammar_element * pos;
 
     // copy rule definitions into vectors
     llama_grammar_rules vec_rules(n_rules);
     for (size_t i = 0; i < n_rules; i++) {
-        for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
+        for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
             vec_rules[i].push_back(*pos);
         }
         vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});

@@ -438,22 +1039,26 @@ struct llama_grammar * llama_grammar_init_impl(
     // Important: vec_rules has to be moved here, not copied, because stacks contains
     // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
     // then the pointers would be invalidated when the local vec_rules goes out of scope.
-    return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
+    return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
 }
 
 void llama_grammar_free_impl(struct llama_grammar * grammar) {
+    if (grammar == nullptr) {
+        return;
+    }
+
     delete grammar;
 }
 
-struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
-    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
+struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
+    llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
 
     // redirect elements in stacks to point to new rules
     for (size_t is = 0; is < result->stacks.size(); is++) {
         for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
-            for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
-                for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
-                    if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
+            for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
+                for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
+                    if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
                         result->stacks[is][ie] = &result->rules[ir0][ir1];
                     }
                 }

@@ -464,14 +1069,11 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar)
     return result;
 }
 
-void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    GGML_ASSERT(grammar);
-    GGML_ASSERT(vocab);
-
-    int64_t t_start_sample_us = ggml_time_us();
+void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
+    GGML_ASSERT(grammar.vocab != nullptr);
 
     bool allow_eog = false;
-    for (const auto & stack : grammar->stacks) {
+    for (const auto & stack : grammar.stacks) {
         if (stack.empty()) {
             allow_eog = true;
             break;

@@ -479,40 +1081,38 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates)
         }
     }
 
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
-    candidates_decoded.reserve(candidates->size);
+    candidates_decoded.reserve(cur_p->size);
 
     llama_grammar_candidates candidates_grammar;
-    candidates_grammar.reserve(candidates->size);
+    candidates_grammar.reserve(cur_p->size);
 
-    for (size_t i = 0; i < candidates->size; ++i) {
-        const llama_token id      = candidates->data[i].id;
-        const std::string & piece = vocab->cache_token_to_piece.at(id);
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        const llama_token id      = cur_p->data[i].id;
+        const std::string & piece = grammar.vocab->cache_token_to_piece.at(id);
 
-        if (llama_token_is_eog_impl(*vocab, id)) {
+        if (llama_token_is_eog_impl(*grammar.vocab, id)) {
             if (!allow_eog) {
-                candidates->data[i].logit = -INFINITY;
+                cur_p->data[i].logit = -INFINITY;
             }
         } else if (piece.empty() || piece[0] == 0) {
-            candidates->data[i].logit = -INFINITY;
+            cur_p->data[i].logit = -INFINITY;
         } else {
-            candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
+            candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
             candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
         }
     }
 
-    const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
+    const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
     for (const auto & reject : rejects) {
-        candidates->data[reject.index].logit = -INFINITY;
+        cur_p->data[reject.index].logit = -INFINITY;
     }
-
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
 }
 
-void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
-    int64_t t_start_sample_us = ggml_time_us();
+void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
+    GGML_ASSERT(grammar.vocab != nullptr);
 
-    if (llama_token_is_eog_impl(*vocab, token)) {
-        for (const auto & stack : grammar->stacks) {
+    if (llama_token_is_eog_impl(*grammar.vocab, token)) {
+        for (const auto & stack : grammar.stacks) {
             if (stack.empty()) {
                 return;
             }

@@ -520,20 +1120,19 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token)
         GGML_ABORT("fatal error");
     }
 
-    const std::string & piece = vocab->cache_token_to_piece.at(token);
+    const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);
 
     // Note terminating 0 in decoded string
-    const auto   decoded     = decode_utf8(piece, grammar->partial_utf8);
+    const auto   decoded     = decode_utf8(piece, grammar.partial_utf8);
     const auto & code_points = decoded.first;
 
     llama_grammar_stacks stacks_new;
+
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        llama_grammar_accept(grammar->rules, grammar->stacks, *it, stacks_new);
-        grammar->stacks = stacks_new;
+        llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
+        grammar.stacks = std::move(stacks_new);
     }
 
-    grammar->partial_utf8 = decoded.second;
-    GGML_ASSERT(!grammar->stacks.empty());
-
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    grammar.partial_utf8 = decoded.second;
+    GGML_ASSERT(!grammar.stacks.empty());
 }
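
The handle_repetitions rewrite in parse_sequence above is easiest to see on a concrete rule. A minimal sketch, assuming the files in this diff are compiled into the same target (the synthesized rule names come from generate_symbol_id, so the exact root_N suffixes depend on symbol-id allocation order):

    // Sketch: parse a GBNF rule with a bounded repetition and print the
    // rewritten rule set, using the llama_grammar_parser added above.
    #include "llama-grammar.h"
    #include <cstdio>

    int main() {
        llama_grammar_parser parser;
        // [a]{2,4} is rewritten into two mandatory copies plus a chain of
        // optional helper rules (empty alternative = stop early):
        //   root   ::= [a] [a] root_2
        //   root_2 ::= [a] root_1 |
        //   root_1 ::= [a] |
        if (parser.parse("root ::= [a]{2,4}\n")) {
            parser.print(stdout);
        }
        return 0;
    }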
examples/talk-llama/llama-grammar.h
CHANGED
|
@@ -2,11 +2,115 @@
|
|
| 2 |
|
| 3 |
#include "llama-impl.h"
|
| 4 |
|
|
|
|
|
|
|
| 5 |
struct llama_vocab;
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
struct llama_grammar {
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
| 10 |
llama_grammar_stacks stacks;
|
| 11 |
|
| 12 |
// buffer for partially generated UTF-8 sequence from accepted tokens
|
|
@@ -17,23 +121,24 @@ struct llama_grammar {
|
|
| 17 |
// internal API
|
| 18 |
//
|
| 19 |
|
|
|
|
| 20 |
struct llama_grammar * llama_grammar_init_impl(
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
void llama_grammar_free_impl(struct llama_grammar * grammar);
|
| 26 |
|
| 27 |
-
struct llama_grammar *
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
llama_token_data_array * candidates);
|
| 34 |
|
| 35 |
-
void
|
| 36 |
-
struct llama_grammar
|
| 37 |
-
const struct llama_vocab * vocab,
|
| 38 |
-
const struct llama_sampling * smpl,
|
| 39 |
llama_token token);
|
|
|
|
| 2 |
|
| 3 |
#include "llama-impl.h"
|
| 4 |
|
| 5 |
+
#include <map>
|
| 6 |
+
|
| 7 |
struct llama_vocab;
|
| 8 |
+
|
| 9 |
+
// grammar element type
|
| 10 |
+
enum llama_gretype {
|
| 11 |
+
// end of rule definition
|
| 12 |
+
LLAMA_GRETYPE_END = 0,
|
| 13 |
+
|
| 14 |
+
// start of alternate definition for rule
|
| 15 |
+
LLAMA_GRETYPE_ALT = 1,
|
| 16 |
+
|
| 17 |
+
// non-terminal element: reference to rule
|
| 18 |
+
LLAMA_GRETYPE_RULE_REF = 2,
|
| 19 |
+
|
| 20 |
+
// terminal element: character (code point)
|
| 21 |
+
LLAMA_GRETYPE_CHAR = 3,
|
| 22 |
+
|
| 23 |
+
// inverse char(s) ([^a], [^a-b] [^abc])
|
| 24 |
+
LLAMA_GRETYPE_CHAR_NOT = 4,
|
| 25 |
+
|
| 26 |
+
// modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
|
| 27 |
+
// be an inclusive range ([a-z])
|
| 28 |
+
LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
|
| 29 |
+
|
| 30 |
+
// modifies a preceding LLAMA_GRETYPE_CHAR or
|
| 31 |
+
// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
|
| 32 |
+
LLAMA_GRETYPE_CHAR_ALT = 6,
|
| 33 |
+
|
| 34 |
+
// any character (.)
|
| 35 |
+
LLAMA_GRETYPE_CHAR_ANY = 7,
|
| 36 |
+
};
|
| 37 |
+
|
| 38 |
+
typedef struct llama_grammar_element {
|
| 39 |
+
enum llama_gretype type;
|
| 40 |
+
uint32_t value; // Unicode code point or rule ID
|
| 41 |
+
} llama_grammar_element;
|
| 42 |
+
|
| 43 |
+
struct llama_partial_utf8 {
|
| 44 |
+
uint32_t value; // bit value so far (unshifted)
|
| 45 |
+
int n_remain; // num bytes remaining; -1 indicates invalid sequence
|
| 46 |
+
};
|
| 47 |
+
+struct llama_grammar_candidate {
+    size_t           index;
+    const uint32_t * code_points;
+    llama_partial_utf8 partial_utf8;
+};
+
+using llama_grammar_rule  = std::vector<      llama_grammar_element>;
+using llama_grammar_stack = std::vector<const llama_grammar_element *>;
+
+using llama_grammar_rules      = std::vector<llama_grammar_rule>;
+using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
+using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
+
+const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
+      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
+
+// takes a set of possible pushdown stacks on a grammar, which are required to
+// be positioned at a character range (see `llama_grammar_advance_stack`), and
+// produces the N possible stacks if the given char is accepted at those
+// positions
+void llama_grammar_accept(
+        const llama_grammar_rules  & rules,
+        const llama_grammar_stacks & stacks,
+        uint32_t chr,
+        llama_grammar_stacks & stacks_new);
+
+std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
+        const llama_grammar_rules      & rules,
+        const llama_grammar_stack      & stack,
+        const llama_grammar_candidates & candidates);
+
+struct llama_grammar_parser {
+    std::map<std::string, uint32_t> symbol_ids;
+
+    llama_grammar_rules rules;
+
+    llama_grammar_stack c_rules() const;
+
+    uint32_t get_symbol_id(const char * src, size_t len);
+    uint32_t generate_symbol_id(const std::string & base_name);
+
+    void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);
+
+    const char * parse_alternates(
+            const char        * src,
+            const std::string & rule_name,
+            uint32_t            rule_id,
+            bool                is_nested);
+
+    const char * parse_sequence(
+            const char         * src,
+            const std::string  & rule_name,
+            llama_grammar_rule & rule,
+            bool                 is_nested);
+
+    const char * parse_rule(const char * src);
+
+    bool parse(const char * src);
+    void print(FILE * file);
+};
 
 struct llama_grammar {
+    // note: allow null vocab for testing (not great)
+    const llama_vocab * vocab;
+
+    const llama_grammar_rules  rules;  // TODO: shared ptr
     llama_grammar_stacks stacks;
 
     // buffer for partially generated UTF-8 sequence from accepted tokens
 
 // internal API
 //
 
+// note: needed for tests (not great)
 struct llama_grammar * llama_grammar_init_impl(
+        const struct llama_vocab * vocab,
+        const llama_grammar_element ** rules,
+        size_t n_rules,
+        size_t start_rule_index);
+
+struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
 
 void llama_grammar_free_impl(struct llama_grammar * grammar);
 
+struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);
 
+// TODO: move the API below as member functions of llama_grammar
+void llama_grammar_apply_impl(
+        const struct llama_grammar & grammar,
+              llama_token_data_array * cur_p);
 
+void llama_grammar_accept_impl(
+              struct llama_grammar & grammar,
         llama_token token);
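The declarations above are enough to drive the grammar as a plain recognizer: a string is accepted while at least one pushdown stack survives each code point. A minimal sketch, not part of this commit (`grammar_accepts_prefix` is a hypothetical helper built only on `llama_grammar_accept` as declared above):

// hypothetical driver over the internal grammar API above
static bool grammar_accepts_prefix(
        const llama_grammar_rules   & rules,
              llama_grammar_stacks  & stacks,
        const std::vector<uint32_t> & code_points) {
    for (const uint32_t chr : code_points) {
        llama_grammar_stacks stacks_new;
        llama_grammar_accept(rules, stacks, chr, stacks_new); // the N surviving stacks
        if (stacks_new.empty()) {
            return false; // no stack can consume this character
        }
        stacks = std::move(stacks_new);
    }
    return true;
}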
examples/talk-llama/llama-impl.h
CHANGED
@@ -1,8 +1,11 @@
 #pragma once
 
-#define LLAMA_API_INTERNAL
 #include "llama.h"
 
+#include <string>
+#include <vector>
+#include <stdexcept>
+
 #ifdef __GNUC__
 #ifdef __MINGW32__
 #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
@@ -21,14 +24,31 @@ LLAMA_ATTRIBUTE_FORMAT(2, 3)
 void llama_log_internal        (ggml_log_level level, const char * format, ...);
 void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data);
 
+#define LLAMA_LOG(...)       llama_log_internal(GGML_LOG_LEVEL_NONE , __VA_ARGS__)
 #define LLAMA_LOG_INFO(...)  llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
 #define LLAMA_LOG_WARN(...)  llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
 #define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+#define LLAMA_LOG_DEBUG(...) llama_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
+#define LLAMA_LOG_CONT(...)  llama_log_internal(GGML_LOG_LEVEL_CONT , __VA_ARGS__)
 
 //
 // helpers
 //
 
+struct time_meas {
+    time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
+
+    ~time_meas() {
+        if (t_start_us >= 0) {
+            t_acc += ggml_time_us() - t_start_us;
+        }
+    }
+
+    const int64_t t_start_us;
+
+    int64_t & t_acc;
+};
+
 static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
     if (search.empty()) {
         return;
@@ -45,3 +65,117 @@ static void replace_all(std::string & s, const std::string & search, const std::
     builder.append(s, last_pos, std::string::npos);
     s = std::move(builder);
 }
+
+const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
+    struct llama_context * ctx
+);
+
+// the ring buffer works similarly to std::deque, but with a fixed capacity
+template<typename T>
+struct ring_buffer {
+    ring_buffer(size_t cap) : capacity(cap), data(cap) {}
+
+    T & front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[first];
+    }
+
+    const T & front() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[first];
+    }
+
+    T & back() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
+    }
+
+    const T & back() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
+    }
+
+    void push_back(const T & value) {
+        if (capacity == 0) {
+            throw std::runtime_error("ring buffer: capacity is zero");
+        }
+
+        if (sz == capacity) {
+            // advance the start when buffer is full
+            first = (first + 1) % capacity;
+        } else {
+            sz++;
+        }
+        data[pos] = value;
+        pos = (pos + 1) % capacity;
+    }
+
+    T pop_front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        T value = data[first];
+        first = (first + 1) % capacity;
+        sz--;
+        return value;
+    }
+
+    //T & operator[](size_t i) {
+    //    if (i >= sz) {
+    //        throw std::runtime_error("ring buffer: index out of bounds");
+    //    }
+    //    return data[(first + i) % capacity];
+    //}
+
+    //const T & at(size_t i) const {
+    //    if (i >= sz) {
+    //        throw std::runtime_error("ring buffer: index out of bounds");
+    //    }
+    //    return data[(first + i) % capacity];
+    //}
+
+    const T & rat(size_t i) const {
+        if (i >= sz) {
+            throw std::runtime_error("ring buffer: index out of bounds");
+        }
+        return data[(first + sz - i - 1) % capacity];
+    }
+
+    std::vector<T> to_vector() const {
+        std::vector<T> result;
+        result.reserve(sz);
+        for (size_t i = 0; i < sz; i++) {
+            result.push_back(data[(first + i) % capacity]);
+        }
+        return result;
+    }
+
+    void clear() {
+        // here only reset the status of the buffer
+        sz    = 0;
+        first = 0;
+        pos   = 0;
+    }
+
+    bool empty() const {
+        return sz == 0;
+    }
+
+    size_t size() const {
+        return sz;
+    }
+
+    size_t capacity = 0;
+    size_t sz       = 0;
+    size_t first    = 0;
+    size_t pos      = 0;
+    std::vector<T> data;
+};
examples/talk-llama/llama-sampling.cpp
CHANGED
@@ -1,12 +1,53 @@
 #include "llama-sampling.h"
 
 #include <algorithm>
 #include <cstring>
 #include <ctime>
-#include <cfloat>
 #include <numeric>
 #include <unordered_map>
 
 static void llama_log_softmax(float * array, size_t size) {
     float max_l = *std::max_element(array, array + size);
     float sum = 0.f;
@@ -20,66 +61,52 @@ static void llama_log_softmax(float * array, size_t size) {
         array[i] = logf(array[i] / sum);
     }
 }
 
-void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
-    if (seed == LLAMA_DEFAULT_SEED) {
-        seed = time(NULL);
-    }
-
-    smpl->rng.seed(seed);
-}
-
-void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    GGML_ASSERT(candidates->size > 0);
-
-    const int64_t t_start_sample_us = ggml_time_us();
 
     // Sort the logits in descending order
-    if (!candidates->sorted) {
-        std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
             return a.logit > b.logit;
         });
-        candidates->sorted = true;
     }
 
-    float max_l = candidates->data[0].logit;
     float cum_sum = 0.0f;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        float p = expf(candidates->data[i].logit - max_l);
-        candidates->data[i].p = p;
         cum_sum += p;
     }
-    for (size_t i = 0; i < candidates->size; ++i) {
-        candidates->data[i].p /= cum_sum;
-    }
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
     // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
-    // if (k >= (int32_t)candidates->size) {
     //     return;
     // }
 
-    const int64_t t_start_sample_us = ggml_time_us();
-
     if (k <= 0) {
-        k = candidates->size;
     }
 
-    k = std::min(k, (int) candidates->size);
 
     // Sort scores in descending order
-    if (!candidates->sorted) {
         auto comp = [](const llama_token_data & a, const llama_token_data & b) {
             return a.logit > b.logit;
         };
         if (k <= 128) {
-            std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
         } else {
             constexpr int nbuckets = 128;
             constexpr float bucket_low = -10.0f;
@@ -87,11 +114,11 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
             constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
             constexpr float bucket_inter = -bucket_low * bucket_scale;
 
-            std::vector<int> bucket_idx(candidates->size);
             std::vector<int> histo(nbuckets, 0);
 
-            for (int i = 0; i < (int)candidates->size; ++i) {
-                const float val = candidates->data[i].logit;
                 int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
                 ib = std::max(0, std::min(nbuckets-1, ib));
                 bucket_idx[i] = ib;
@@ -101,20 +128,22 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
             int ib = nbuckets - 1;
             for ( ; ib >= 0; --ib) {
                 nhave += histo[ib];
-                if (nhave >= k) break;
             }
             std::vector<llama_token_data> tmp_tokens(nhave);
-            auto ptr = tmp_tokens.data();
             std::vector<llama_token_data*> bucket_ptrs;
             bucket_ptrs.reserve(nbuckets - ib);
             for (int j = nbuckets - 1; j >= ib; --j) {
                 bucket_ptrs.push_back(ptr);
                 ptr += histo[j];
             }
-            for (int i = 0; i < (int)candidates->size; ++i) {
                 int j = bucket_idx[i];
                 if (j >= ib) {
-                    *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
                 }
             }
@@ -127,125 +156,582 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
             }
             std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
 
-            std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
 
         }
-        candidates->sorted = true;
     }
-    candidates->size = k;
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
-    if (p >= 1.0f) {
         return;
     }
 
-    llama_sample_softmax_impl(smpl, candidates);
 
     // Compute the cumulative probabilities
     float cum_sum = 0.0f;
-    size_t last_idx = candidates->size;
 
-    for (size_t i = 0; i < candidates->size; ++i) {
-        cum_sum += candidates->data[i].p;
 
         // Check if the running sum is at least p or if we have kept at least min_keep tokens
         // we set the last index to i+1 to indicate that the current iterate should be included in the set
-        if (cum_sum >= p && i + 1 >= min_keep) {
             last_idx = i + 1;
             break;
         }
     }
 
     // Resize the output vector to keep only the top-p tokens
-    candidates->size = last_idx;
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
-    if (p <= 0.0f || !candidates->size) {
         return;
     }
 
-    const int64_t t_start_sample_us = ggml_time_us();
-
     bool min_p_applied = false;
 
-    // if the candidates aren't sorted, try the unsorted implementation first
-    if (!candidates->sorted) {
         std::vector<llama_token_data> filtered_tokens;
 
         float max_logit = -FLT_MAX;
-        for (size_t i = 0; i < candidates->size; ++i) {
-            max_logit = std::max(max_logit, candidates->data[i].logit);
         }
-        const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
 
-        for (size_t i = 0; i < candidates->size; ++i) {
-            if (candidates->data[i].logit >= min_logit) {
-                filtered_tokens.push_back(candidates->data[i]);
             }
         }
 
         // if we have enough values the operation was a success
-        if (filtered_tokens.size() >= min_keep) {
-            memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
             min_p_applied = true;
         }
     }
 
-    // if the candidates are sorted or the unsorted implementation failed, use this implementation
     if (!min_p_applied) {
         // Sort the logits in descending order
-        if (!candidates->sorted) {
-            std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
                 return a.logit > b.logit;
             });
-            candidates->sorted = true;
         }
 
-        const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
         size_t i = 1; // first token always matches
 
-        for (; i < candidates->size; ++i) {
-            if (candidates->data[i].logit < min_logit && i >= min_keep) {
                 break; // prob too small
             }
         }
 
         // Resize the output vector to keep only the matching tokens
-        candidates->size = i;
     }
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
-    if (z >= 1.0f || candidates->size <= 2) {
         return;
     }
 
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
-    const int64_t t_start_sample_us = ggml_time_us();
 
     // Compute the first and second derivatives
-    std::vector<float> first_derivatives(candidates->size - 1);
-    std::vector<float> second_derivatives(candidates->size - 2);
 
     for (size_t i = 0; i < first_derivatives.size(); ++i) {
-        first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
     }
     for (size_t i = 0; i < second_derivatives.size(); ++i) {
         second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
@@ -272,51 +758,86 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_
     }
 
     float cum_sum = 0.0f;
-    size_t last_idx = candidates->size;
     for (size_t i = 0; i < second_derivatives.size(); ++i) {
         cum_sum += second_derivatives[i];
 
         // Check if the running sum is greater than z or if we have kept at least min_keep tokens
-        if (cum_sum > z && i >= min_keep) {
             last_idx = i;
             break;
         }
     }
 
     // Resize the output vector to keep only the tokens above the tail location
-    candidates->size = last_idx;
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
     // Reference implementation:
     // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
-    if (p >= 1.0f) {
         return;
     }
 
     // Compute the softmax of logits and calculate entropy
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
-
-    const int64_t t_start_sample_us = ggml_time_us();
 
     float entropy = 0.0f;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        entropy += -candidates->data[i].p * logf(candidates->data[i].p);
     }
 
     // Compute the absolute difference between negative log probability and entropy for each candidate
     std::vector<float> shifted_scores;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
         shifted_scores.push_back(shifted_score);
     }
 
     // Sort tokens based on the shifted_scores and their corresponding indices
-    std::vector<size_t> indices(candidates->size);
     std::iota(indices.begin(), indices.end(), 0);
 
     std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
@@ -329,134 +850,618 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
 
     for (size_t i = 0; i < indices.size(); ++i) {
         size_t idx = indices[i];
-        cum_sum += candidates->data[idx].p;
 
         // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
-        if (cum_sum > p && i >= min_keep - 1) {
             last_idx = i + 1;
             break;
         }
     }
 
     // Resize the output vector to keep only the locally typical tokens
-    std::vector<llama_token_data> new_candidates;
     for (size_t i = 0; i < last_idx; ++i) {
         size_t idx = indices[i];
-        new_candidates.push_back(candidates->data[idx]);
     }
 
-    // Replace the data in candidates with the new_candidates data
-    std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
-    candidates->size = new_candidates.size();
-    candidates->sorted = false;
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
     ...
-#endif
     ...
 }
 
-void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
     ...
 }
 
-void llama_sample_repetition_penalties_impl(
     ...
         return;
     }
 
     // Create a frequency map to count occurrences of each token in last_tokens
-    std::unordered_map<llama_token, int> token_count;
     ...
     }
 
-    // Apply frequency and presence penalties to the candidates
-    for (size_t i = 0; i < candidates->size; ++i) {
-        const auto token_iter = token_count.find(candidates->data[i].id);
         if (token_iter == token_count.end()) {
             continue;
         }
@@ -465,171 +1470,238 @@ void llama_sample_repetition_penalties_impl(
 
         // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
-        if (candidates->data[i].logit <= 0) {
-            candidates->data[i].logit *= penalty_repeat;
         } else {
-            candidates->data[i].logit /= penalty_repeat;
         }
 
-        candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
     }
 
-    candidates->sorted = false;
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
     }
 }
 
-void llama_sample_apply_guidance_impl(
-        struct llama_sampling * smpl,
-                        float * logits,
-                        float * logits_guidance,
     float scale) {
-    GGML_ASSERT(smpl);
 
-    const auto t_start_sample_us = ggml_time_us();
-    const auto n_vocab = smpl->n_vocab;
 
-    llama_log_softmax(logits, n_vocab);
-    llama_log_softmax(logits_guidance, n_vocab);
 
     ...
 
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
 }
 
-llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
-    GGML_ASSERT(smpl);
-    const int32_t n_vocab = float(smpl->n_vocab);
-    int64_t t_start_sample_us = ggml_time_us();
 
     ...
 
-    float s_hat = 0.0;
-    float sum_ti_bi = 0.0;
-    float sum_ti_sq = 0.0;
-    for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
-        float t_i = logf(float(i + 2) / float(i + 1));
-        float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
-        sum_ti_bi += t_i * b_i;
-        sum_ti_sq += t_i * t_i;
     }
-    s_hat = sum_ti_bi / sum_ti_sq;
 
     ...
 
-    llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    llama_token X = llama_sample_token_impl(smpl, candidates);
-    t_start_sample_us = ggml_time_us();
 
-        return candidate.id == X;
-    }));
-    float observed_surprise = -log2f(candidates->data[X_idx].p);
-    float e = observed_surprise - tau;
 
-    *mu = *mu - eta * e;
 
     ...
 }
-    return X;
-}
 
-llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
     ...
 
-    *mu = *mu - eta * e;
 
     ...
 }
 
-llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
     ...
-    });
 
     ...
 }
 
-llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
-    GGML_ASSERT(smpl);
 
     ...
 
-    for (size_t i = 0; i < candidates->size; ++i) {
-        probs.push_back(candidates->data[i].p);
     }
 
-    int idx = dist(rng);
 
     ...
 
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    smpl->n_sample++;
 
    ...
 }
 
-llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
     ...
-}

(Note: runs of "..." above mark removed old-side lines that did not survive the page extraction; the surrounding removals are reconstructed from the visible fragments and the matching new-side code.)
 #include "llama-sampling.h"
 
+#include "llama-vocab.h"
+#include "llama-grammar.h"
+
 #include <algorithm>
+#include <cassert>
+#include <cfloat>
+#include <chrono>
+#include <cmath>
+#include <cstdlib>
 #include <cstring>
 #include <ctime>
 #include <numeric>
+#include <random>
 #include <unordered_map>
 
+static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
+    // iterator for the probabilities
+#ifdef __GNUC__
+    #pragma GCC diagnostic push
+    #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
+#endif
+
+    struct probs_iterator {
+        typedef std::input_iterator_tag iterator_category;
+        typedef float value_type;
+        typedef float * pointer;
+        typedef float & reference;
+        typedef ptrdiff_t difference_type;
+
+        const llama_token_data * data;
+
+        bool operator==(const probs_iterator & other) const { return data == other.data; }
+        bool operator!=(const probs_iterator & other) const { return data != other.data; }
+        const float & operator*() const { return data->p; }
+        probs_iterator & operator++() { ++data; return *this; }
+        probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
+    };
+
+#ifdef __GNUC__
+    #pragma GCC diagnostic pop
+#endif
+
+    std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});
+
+    return dist(rng);
+}
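Assuming `cur_p->data[i].p` already holds normalized probabilities (which is what `llama_sampler_softmax_impl` below guarantees, computing p_i = exp(logit_i - logit_max) / sum_j exp(logit_j - logit_max), with the max subtracted only to keep expf from overflowing), `std::discrete_distribution` returns index i with probability p_i / sum_j p_j. For example, with p = {0.7, 0.2, 0.1} the function returns 0 on roughly 70% of calls. The custom input iterator exists only so the distribution can read the `p` field straight out of the `llama_token_data` array without first copying the probabilities into a temporary vector.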
+/*
 static void llama_log_softmax(float * array, size_t size) {
     float max_l = *std::max_element(array, array + size);
     float sum = 0.f;
         array[i] = logf(array[i] / sum);
     }
 }
+*/
 
+static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
+    GGML_ASSERT(cur_p->size > 0);
 
     // Sort the logits in descending order
+    if (!cur_p->sorted) {
+        std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
             return a.logit > b.logit;
         });
+        cur_p->sorted = true;
     }
 
+    float max_l = cur_p->data[0].logit;
     float cum_sum = 0.0f;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float p = expf(cur_p->data[i].logit - max_l);
+        cur_p->data[i].p = p;
         cum_sum += p;
     }
 
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= cum_sum;
     }
 }
 
+static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
     // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
+    // if (k >= (int32_t)cur_p->size) {
     //     return;
     // }
 
     if (k <= 0) {
+        k = cur_p->size;
     }
 
+    k = std::min(k, (int) cur_p->size);
 
     // Sort scores in descending order
+    if (!cur_p->sorted) {
         auto comp = [](const llama_token_data & a, const llama_token_data & b) {
             return a.logit > b.logit;
         };
         if (k <= 128) {
+            std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
         } else {
             constexpr int nbuckets = 128;
             constexpr float bucket_low = -10.0f;
             constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
             constexpr float bucket_inter = -bucket_low * bucket_scale;
 
+            std::vector<int> bucket_idx(cur_p->size);
             std::vector<int> histo(nbuckets, 0);
 
+            for (int i = 0; i < (int)cur_p->size; ++i) {
+                const float val = cur_p->data[i].logit;
                 int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
                 ib = std::max(0, std::min(nbuckets-1, ib));
                 bucket_idx[i] = ib;
 
             int ib = nbuckets - 1;
             for ( ; ib >= 0; --ib) {
                 nhave += histo[ib];
+                if (nhave >= k) {
+                    break;
+                }
             }
             std::vector<llama_token_data> tmp_tokens(nhave);
+            auto * ptr = tmp_tokens.data();
             std::vector<llama_token_data*> bucket_ptrs;
             bucket_ptrs.reserve(nbuckets - ib);
             for (int j = nbuckets - 1; j >= ib; --j) {
                 bucket_ptrs.push_back(ptr);
                 ptr += histo[j];
             }
+            for (int i = 0; i < (int)cur_p->size; ++i) {
                 int j = bucket_idx[i];
                 if (j >= ib) {
+                    *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
                 }
             }
 
             }
             std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
 
+            std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
 
         }
+        cur_p->sorted = true;
     }
+    cur_p->size = k;
+}
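The k > 128 path above is a one-pass histogram (bucket) sort: each logit is mapped to one of `nbuckets` equal-width buckets via ib = bucket_scale * val + bucket_inter, which is algebraically (val - bucket_low) * nbuckets / (bucket_high - bucket_low) clamped to [0, nbuckets-1], so any logit at or below bucket_low lands in bucket 0 and anything at or above bucket_high lands in bucket nbuckets-1. Buckets are then consumed from the highest logit range downward until at least k tokens have been collected, and only the boundary bucket needs a real `std::partial_sort`; everything above it is already known to belong to the top k.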
+static uint32_t get_rng_seed(uint32_t seed) {
+    if (seed == LLAMA_DEFAULT_SEED) {
+        // use system clock if std::random_device is not a true RNG
+        static bool is_rd_prng = std::random_device().entropy() == 0;
+        if (is_rd_prng) {
+            return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
+        }
+        std::random_device rd;
+        return rd();
+    }
+    return seed;
+}
 
+// llama_sampler API
+
+const char * llama_sampler_name(const struct llama_sampler * smpl) {
+    if (!smpl->iface) {
+        return "(null)";
+    }
+
+    return smpl->iface->name(smpl);
+}
+
+void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
+    if (smpl->iface->accept) {
+        smpl->iface->accept(smpl, token);
+    }
+}
+
+void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
+    GGML_ASSERT(smpl->iface->apply);
+    smpl->iface->apply(smpl, cur_p);
+}
+
+void llama_sampler_reset(struct llama_sampler * smpl) {
+    if (smpl->iface->reset) {
+        smpl->iface->reset(smpl);
+    }
+}
+
+struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
+    if (smpl->iface->clone) {
+        return smpl->iface->clone(smpl);
+    }
+
+    if (smpl->ctx == nullptr) {
+        return new llama_sampler {
+            /* .iface = */ smpl->iface,
+            /* .ctx   = */ nullptr,
+        };
+    }
+
+    GGML_ABORT("the sampler does not support cloning");
+}
+
+void llama_sampler_free(struct llama_sampler * smpl) {
+    if (smpl == nullptr) {
         return;
     }
 
+    if (smpl->iface->free) {
+        smpl->iface->free(smpl);
+    }
+
+    delete smpl;
+}
+
+llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
+    const auto * logits = llama_get_logits_ith(ctx, idx);
+
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+    // TODO: do not allocate each time
+    std::vector<llama_token_data> cur;
+    cur.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = {
+        /* .data     = */ cur.data(),
+        /* .size     = */ cur.size(),
+        /* .selected = */ -1,
+        /* .sorted   = */ false,
+    };
+
+    llama_sampler_apply(smpl, &cur_p);
+
+    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
+
+    auto token = cur_p.data[cur_p.selected].id;
+
+    llama_sampler_accept(smpl, token);
+
+    return token;
+}
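The vtable-style `llama_sampler_i` makes samplers pluggable: only `apply` is mandatory (asserted in `llama_sampler_apply`), while `accept`, `reset`, `clone` and `free` are optional, and a null `clone` with a null `ctx` falls back to the generic copy path above. A hypothetical pass-through sampler, sketched here (not part of the commit) in the same pattern the built-in samplers below follow:

// hypothetical example: the minimum a new sampler must provide
static const char * my_noop_name(const struct llama_sampler * /*smpl*/) {
    return "noop";
}

static void my_noop_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * /*cur_p*/) {
    // leave the candidate array untouched
}

static struct llama_sampler_i my_noop_i = {
    /* .name   = */ my_noop_name,
    /* .accept = */ nullptr,
    /* .apply  = */ my_noop_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ nullptr, // stateless: ctx == nullptr, so the generic clone path applies
    /* .free   = */ nullptr,
};

static struct llama_sampler * my_noop_init() {
    return new llama_sampler {
        /* .iface = */ &my_noop_i,
        /* .ctx   = */ nullptr,
    };
}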
+// sampler chain
+
+static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
+    return "chain";
+}
+
+static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+    time_meas tm(chain->t_sample_us, chain->params.no_perf);
+
+    for (auto * smpl : chain->samplers) {
+        llama_sampler_accept(smpl, token);
+    }
+
+    chain->n_sample++;
+}
+
+static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+    time_meas tm(chain->t_sample_us, chain->params.no_perf);
+
+    for (auto * smpl : chain->samplers) {
+        llama_sampler_apply(smpl, cur_p);
+    }
+}
+
+static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+    for (auto * smpl : chain->samplers) {
+        llama_sampler_reset(smpl);
+    }
+
+    chain->t_sample_us = 0;
+    chain->n_sample    = 0;
+}
+
+static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
+    const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
+
+    auto * result = llama_sampler_chain_init(chain_src->params);
+
+    for (auto * smpl : chain_src->samplers) {
+        llama_sampler_chain_add(result, llama_sampler_clone(smpl));
+    }
+
+    return result;
+}
+
+static void llama_sampler_chain_free(struct llama_sampler * smpl) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+    for (auto * smpl : chain->samplers) {
+        llama_sampler_free(smpl);
+    }
+
+    delete chain;
+}
+
+static struct llama_sampler_i llama_sampler_chain_i = {
+    /* .name   = */ llama_sampler_chain_name,
+    /* .accept = */ llama_sampler_chain_accept,
+    /* .apply  = */ llama_sampler_chain_apply,
+    /* .reset  = */ llama_sampler_chain_reset,
+    /* .clone  = */ llama_sampler_chain_clone,
+    /* .free   = */ llama_sampler_chain_free,
+};
+
+struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_chain_i,
+        /* .ctx   = */ new llama_sampler_chain {
+            /* .params      = */ params,
+            /* .samplers    = */ {},
+            /* .t_sample_us = */ 0,
+            /* .n_sample    = */ 0,
+        },
+    };
+}
+
+void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
+    auto * p = (llama_sampler_chain *) chain->ctx;
+    p->samplers.push_back(smpl);
+}
+
+struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
+    const auto * p = (const llama_sampler_chain *) chain->ctx;
+
+    if (i < 0 || (size_t) i >= p->samplers.size()) {
+        return nullptr;
+    }
+
+    return p->samplers[i];
+}
+
+struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
+    auto * p = (llama_sampler_chain *) chain->ctx;
+
+    if (i < 0 || (size_t) i >= p->samplers.size()) {
+        return nullptr;
+    }
+
+    auto * result = p->samplers[i];
+    p->samplers.erase(p->samplers.begin() + i);
+
+    return result;
+}
+
+int llama_sampler_chain_n(const struct llama_sampler * chain) {
+    const auto * p = (const llama_sampler_chain *) chain->ctx;
+
+    return p->samplers.size();
+}
+
+//
+// samplers
+//
+
+// greedy
+
+static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
+    return "greedy";
+}
+
+static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
+    cur_p->selected = 0;
+    for (size_t i = 1; i < cur_p->size; ++i) {
+        if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
+            cur_p->selected = i;
+        }
+    }
+}
+
+static struct llama_sampler_i llama_sampler_greedy_i = {
+    /* .name   = */ llama_sampler_greedy_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_greedy_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ nullptr,
+    /* .free   = */ nullptr,
+};
+
+struct llama_sampler * llama_sampler_init_greedy() {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_greedy_i,
+        /* .ctx   = */ nullptr,
+    };
+}
+
+// dist
+
+struct llama_sampler_dist {
+    const uint32_t seed;
+          uint32_t seed_cur;
+
+    std::mt19937 rng;
+};
+
+static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
+    return "dist";
+}
+
+static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_dist *) smpl->ctx;
+    cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
+}
+
+static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
+    auto * result = llama_sampler_init_dist(ctx->seed);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_dist *) result->ctx;
+
+        result_ctx->rng = ctx->rng;
+    }
+
+    return result;
+}
+
+static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_dist *) smpl->ctx;
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
+}
+
+static void llama_sampler_dist_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_dist *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_dist_i = {
+    /* .name   = */ llama_sampler_dist_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_dist_apply,
+    /* .reset  = */ llama_sampler_dist_reset,
+    /* .clone  = */ llama_sampler_dist_clone,
+    /* .free   = */ llama_sampler_dist_free,
+};
+
+struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
+    auto seed_cur = get_rng_seed(seed);
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_dist_i,
+        /* .ctx   = */ new llama_sampler_dist {
+            /* .seed     = */ seed,
+            /* .seed_cur = */ seed_cur,
+            /* .rng      = */ std::mt19937(seed_cur),
+        },
+    };
+}
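Putting the pieces together, a consumer builds a chain once and samples from it per token. A hypothetical usage sketch (`llama_sampler_chain_default_params()` is assumed from the llama.h of this sync; `ctx` is an existing llama_context, and the `-1` index follows the llama.h convention of "logits of the last position"):

struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));  // truncate to the 40 best tokens
llama_sampler_chain_add(chain, llama_sampler_init_dist(1234)); // then draw from the remaining mass

// per generated token:
// llama_token tok = llama_sampler_sample(chain, ctx, -1);

llama_sampler_free(chain); // also frees the samplers that were added to it

Because `dist` owns its `std::mt19937` and `clone` copies the RNG state, a cloned chain continues the exact random sequence of the original, which is what makes speculative or branching decoding reproducible.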
+// softmax
+
+static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
+    return "softmax";
+}
+
+static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
+    llama_sampler_softmax_impl(cur_p);
+}
+
+static struct llama_sampler_i llama_sampler_softmax_i = {
+    /* .name   = */ llama_sampler_softmax_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_softmax_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ nullptr,
+    /* .free   = */ nullptr,
+};
+
+struct llama_sampler * llama_sampler_init_softmax() {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_softmax_i,
+        /* .ctx   = */ nullptr,
+    };
+}
+
+// top-k
+
+struct llama_sampler_top_k {
+    const int32_t k;
+};
+
+static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
+    return "top-k";
+}
+
+static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
+    llama_sampler_top_k_impl(cur_p, ctx->k);
+}
+
+static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
+    return llama_sampler_init_top_k(ctx->k);
+}
+
+static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_top_k *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_top_k_i = {
+    /* .name   = */ llama_sampler_top_k_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_top_k_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_top_k_clone,
+    /* .free   = */ llama_sampler_top_k_free,
+};
+
+struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_top_k_i,
+        /* .ctx   = */ new llama_sampler_top_k {
+            /* .k = */ k,
+        },
+    };
+}
+
+// top-p
+
+struct llama_sampler_top_p {
+    const float  p;
+    const size_t min_keep;
+};
+
+static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
+    return "top-p";
+}
 
+static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
+
+    if (ctx->p >= 1.0f) {
+        return;
+    }
+
+    llama_sampler_softmax_impl(cur_p);
 
     // Compute the cumulative probabilities
     float cum_sum = 0.0f;
+    size_t last_idx = cur_p->size;
 
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cum_sum += cur_p->data[i].p;
 
         // Check if the running sum is at least p or if we have kept at least min_keep tokens
         // we set the last index to i+1 to indicate that the current iterate should be included in the set
+        if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
            last_idx = i + 1;
            break;
        }
    }
 
     // Resize the output vector to keep only the top-p tokens
+    cur_p->size = last_idx;
+}
 
+static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
+    return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
+}
+
+static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_top_p *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_top_p_i = {
+    /* .name   = */ llama_sampler_top_p_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_top_p_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_top_p_clone,
+    /* .free   = */ llama_sampler_top_p_free,
+};
+
+struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_top_p_i,
+        /* .ctx   = */ new llama_sampler_top_p {
+            /* .p        = */ p,
+            /* .min_keep = */ min_keep,
+        },
+    };
+}
// min-p
|
| 614 |
+
|
| 615 |
+
struct llama_sampler_min_p {
|
| 616 |
+
const float p;
|
| 617 |
+
const size_t min_keep;
|
| 618 |
+
};
|
| 619 |
+
|
| 620 |
+
static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
|
| 621 |
+
return "min-p";
|
| 622 |
}
|
| 623 |
|
| 624 |
+
static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
| 625 |
+
const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
|
| 626 |
+
|
| 627 |
+
if (ctx->p <= 0.0f || !cur_p->size) {
|
| 628 |
return;
|
| 629 |
}
|
| 630 |
|
|
|
|
|
|
|
| 631 |
bool min_p_applied = false;
|
| 632 |
|
| 633 |
+
// if the cur_p aren't sorted, try the unsorted implementation first
|
| 634 |
+
if (!cur_p->sorted) {
|
| 635 |
std::vector<llama_token_data> filtered_tokens;
|
| 636 |
|
| 637 |
float max_logit = -FLT_MAX;
|
| 638 |
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
| 639 |
+
max_logit = std::max(max_logit, cur_p->data[i].logit);
|
| 640 |
}
|
| 641 |
+
const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max
|
| 642 |
|
| 643 |
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
| 644 |
+
if (cur_p->data[i].logit >= min_logit) {
|
| 645 |
+
filtered_tokens.push_back(cur_p->data[i]);
|
| 646 |
}
|
| 647 |
}
|
| 648 |
|
| 649 |
// if we have enough values the operation was a success
|
| 650 |
+
if (filtered_tokens.size() >= ctx->min_keep) {
|
| 651 |
+
memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
|
| 652 |
+
cur_p->size = filtered_tokens.size();
|
| 653 |
min_p_applied = true;
|
| 654 |
}
|
| 655 |
}
|
| 656 |
|
| 657 |
+
// if the cur_p are sorted or the unsorted implementation failed, use this implementation
|
| 658 |
if (!min_p_applied) {
|
| 659 |
// Sort the logits in descending order
|
| 660 |
+
if (!cur_p->sorted) {
|
| 661 |
+
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
|
| 662 |
return a.logit > b.logit;
|
| 663 |
});
|
| 664 |
+
cur_p->sorted = true;
|
| 665 |
}
|
| 666 |
|
| 667 |
+
const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
|
| 668 |
size_t i = 1; // first token always matches
|
| 669 |
|
| 670 |
+
for (; i < cur_p->size; ++i) {
|
| 671 |
+
if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
|
| 672 |
break; // prob too small
|
| 673 |
}
|
| 674 |
}
|
| 675 |
|
| 676 |
// Resize the output vector to keep only the matching tokens
|
| 677 |
+
cur_p->size = i;
|
| 678 |
}
|
| 679 |
+
}
|
| 680 |
|
| 681 |
+
static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) {
|
| 682 |
+
const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
|
| 683 |
+
return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
|
| 684 |
}
|
| 685 |
|
| 686 |
+
static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
|
| 687 |
+
delete (llama_sampler_min_p *) smpl->ctx;
|
| 688 |
+
}
|
| 689 |
+
|
| 690 |
+
static struct llama_sampler_i llama_sampler_min_p_i = {
|
| 691 |
+
/* .name = */ llama_sampler_min_p_name,
|
| 692 |
+
/* .accept = */ nullptr,
|
| 693 |
+
/* .apply = */ llama_sampler_min_p_apply,
|
| 694 |
+
/* .reset = */ nullptr,
|
| 695 |
+
/* .clone = */ llama_sampler_min_p_clone,
|
| 696 |
+
/* .free = */ llama_sampler_min_p_free,
|
| 697 |
+
};
|
| 698 |
+
|
| 699 |
+
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
|
| 700 |
+
return new llama_sampler {
|
| 701 |
+
/* .iface = */ &llama_sampler_min_p_i,
|
| 702 |
+
/* .ctx = */ new llama_sampler_min_p {
|
| 703 |
+
/* .p = */ p,
|
| 704 |
+
/* .min_keep = */ min_keep,
|
| 705 |
+
},
|
| 706 |
+
};
|
| 707 |
+
}
|
| 708 |
+
|
| 709 |
+
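The log-domain threshold above is worth spelling out: keeping tokens with p_i >= p * p_max is equivalent to keeping logit_i >= logit_max + log(p), since p_i / p_max = exp(logit_i - logit_max) for softmax probabilities. With, say, p = 0.05, log(0.05) is about -3.0, so a token survives only if its logit is within roughly 3 nats of the best one. This is why neither branch of the implementation ever needs to compute the softmax first.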
+// tail-free
+
+struct llama_sampler_tail_free {
+    const float  z;
+    const size_t min_keep;
+};
+
+static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) {
+    return "tail-free";
+}
+
+static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
+
+    if (ctx->z >= 1.0f || cur_p->size <= 2) {
         return;
     }

+    llama_sampler_softmax_impl(cur_p);

     // Compute the first and second derivatives
+    std::vector<float> first_derivatives(cur_p->size - 1);
+    std::vector<float> second_derivatives(cur_p->size - 2);

     for (size_t i = 0; i < first_derivatives.size(); ++i) {
+        first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
     }
     for (size_t i = 0; i < second_derivatives.size(); ++i) {
         second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
     [... unchanged lines 738-757 collapsed in the diff view ...]
     }

     float cum_sum = 0.0f;
+    size_t last_idx = cur_p->size;
     for (size_t i = 0; i < second_derivatives.size(); ++i) {
         cum_sum += second_derivatives[i];

         // Check if the running sum is greater than z or if we have kept at least min_keep tokens
+        if (cum_sum > ctx->z && i >= ctx->min_keep) {
             last_idx = i;
             break;
         }
     }

     // Resize the output vector to keep only the tokens above the tail location
+    cur_p->size = last_idx;
+}

+static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
+    return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
+}
+
+static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_tail_free *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_tail_free_i = {
+    /* .name   = */ llama_sampler_tail_free_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_tail_free_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_tail_free_clone,
+    /* .free   = */ llama_sampler_tail_free_free,
+};
+
+struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_tail_free_i,
+        /* .ctx   = */ new llama_sampler_tail_free {
+            /* .z        = */ z,
+            /* .min_keep = */ min_keep,
+        },
+    };
+}
+
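For reference, the cutoff that llama_sampler_tail_free_apply computes can be written as follows, assuming the collapsed lines above take absolute values of the second derivatives and normalize them to sum to 1 (the cumulative-sum test below requires a normalized distribution). With candidates sorted by descending probability p_0 >= p_1 >= ...:

    \hat d_i = \frac{\lvert p_i - 2p_{i+1} + p_{i+2} \rvert}{\sum_j \lvert p_j - 2p_{j+1} + p_{j+2} \rvert},
    \qquad
    k = \min\Bigl\{\, i \ge \text{min\_keep} \;:\; \sum_{j=0}^{i} \hat d_j > z \,\Bigr\}

and the first k candidates are kept (all of them if no such i exists).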
+// typical
+
+struct llama_sampler_typical {
+    const float  p;
+    const size_t min_keep;
+};
+
+static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
+    return "typical";
 }

+static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_typical *) smpl->ctx;
+
     // Reference implementation:
     // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
+    if (ctx->p >= 1.0f) {
         return;
     }

     // Compute the softmax of logits and calculate entropy
+    llama_sampler_softmax_impl(cur_p);

     float entropy = 0.0f;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
     }

     // Compute the absolute difference between negative log probability and entropy for each candidate
     std::vector<float> shifted_scores;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
         shifted_scores.push_back(shifted_score);
     }

     // Sort tokens based on the shifted_scores and their corresponding indices
+    std::vector<size_t> indices(cur_p->size);
     std::iota(indices.begin(), indices.end(), 0);

     std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
         return shifted_scores[a] < shifted_scores[b];
     });

     // Compute the cumulative sum of probabilities in sorted order
     float cum_sum = 0.0f;
     size_t last_idx = indices.size();

     for (size_t i = 0; i < indices.size(); ++i) {
         size_t idx = indices[i];
+        cum_sum += cur_p->data[idx].p;

         // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
+        if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
             last_idx = i + 1;
             break;
         }
     }

     // Resize the output vector to keep only the locally typical tokens
+    std::vector<llama_token_data> cur_p_new;
     for (size_t i = 0; i < last_idx; ++i) {
         size_t idx = indices[i];
+        cur_p_new.push_back(cur_p->data[idx]);
     }

+    // Replace the data in cur_p with the cur_p_new data
+    std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
+    cur_p->size = cur_p_new.size();
+    cur_p->sorted = false;
+}

+static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
+    return llama_sampler_init_typical(ctx->p, ctx->min_keep);
 }

+static void llama_sampler_typical_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_typical *) smpl->ctx;
+}

+static struct llama_sampler_i llama_sampler_typical_i = {
+    /* .name   = */ llama_sampler_typical_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_typical_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_typical_clone,
+    /* .free   = */ llama_sampler_typical_free,
+};
+
+struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_typical_i,
+        /* .ctx   = */ new llama_sampler_typical {
+            /* .p        = */ p,
+            /* .min_keep = */ min_keep,
+        },
+    };
+}
+
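In formula form, llama_sampler_typical_apply ranks candidates by how close their surprisal is to the entropy of the whole distribution:

    H = -\sum_j p_j \log p_j, \qquad s_i = \bigl\lvert -\log p_i - H \bigr\rvert

and keeps the lowest-s_i tokens until their cumulative probability exceeds p (respecting min_keep), which is the locally typical sampling criterion from the reference implementation linked above.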
+// temp
+
+struct llama_sampler_temp {
+    const float temp;
+};
+
+static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
+    return "temp";
+}
+
+static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_temp *) smpl->ctx;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].logit /= ctx->temp;
     }
+}

+static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
+    return llama_sampler_init_temp(ctx->temp);
+}

+static void llama_sampler_temp_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_temp *) smpl->ctx;
+}

+static struct llama_sampler_i llama_sampler_temp_i = {
+    /* .name   = */ llama_sampler_temp_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_temp_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_temp_clone,
+    /* .free   = */ llama_sampler_temp_free,
+};
+
+struct llama_sampler * llama_sampler_init_temp(float temp) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_temp_i,
+        /* .ctx   = */ new llama_sampler_temp {
+            /* .temp = */ temp,
+        },
+    };
+}
+
+// temp-ext
+
+struct llama_sampler_temp_ext {
+    const float temp;
+    const float delta;
+    const float exponent;
+};
+
+static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
+    return "temp-ext";
+}
+
+static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
+    if (ctx->delta > 0) {
+        const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
+        const float max_temp = ctx->temp + ctx->delta;
+        float exponent_val = ctx->exponent;
+
+        // no need to do anything if there is only one (or zero) candidates
+        if (cur_p->size <= 1) {
+            return;
+        }
+
+        // Calculate maximum possible entropy
+        float max_entropy = -logf(1.0f / cur_p->size);
+
+        llama_sampler_softmax_impl(cur_p);
+
+        // Calculate entropy of the softmax probabilities
+        float entropy = 0.0f;
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            float prob = cur_p->data[i].p;
+            if (prob > 0.0f) { // Ensure no log(0)
+                entropy -= prob * logf(prob);
+            }
+        }
+
+        // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
+        float normalized_entropy = entropy / max_entropy;
+
+        // Map the normalized entropy to the desired temperature range using the power function
+        float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
+
+#ifdef DEBUG
+        LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
+        LLAMA_LOG_INFO("Entropy: %f\n", entropy);
+        LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
+        LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
+        LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
+        LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
+#endif
+
+        // Apply the dynamically calculated temperature scaling
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].logit /= dyn_temp;
+        }
+
+        // Re-compute softmax probabilities after scaling logits with dynamic temperature
+        const double max_l_double = cur_p->data[0].logit;
+
+        double cum_sum_double = 0.0;
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            double p = exp(cur_p->data[i].logit - max_l_double);
+            cur_p->data[i].p = p; // Store the scaled probability
+            cum_sum_double += p;
+        }
+
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
+        }
+
+#ifdef DEBUG
+        // Print the updated top 25 probabilities after temperature scaling
+        LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
+        for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
+            LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
+        }
+#endif
+    } else {
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].logit /= ctx->temp;
         }
     }
+}

+static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
+    return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
+}

+static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_temp_ext *) smpl->ctx;
+}

+static struct llama_sampler_i llama_sampler_temp_ext_i = {
+    /* .name   = */ llama_sampler_temp_ext_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_temp_ext_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_temp_ext_clone,
+    /* .free   = */ llama_sampler_temp_ext_free,
+};
+
+struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_temp_ext_i,
+        /* .ctx   = */ new llama_sampler_temp_ext {
+            /* .temp     = */ temp,
+            /* .delta    = */ delta,
+            /* .exponent = */ exponent,
+        },
+    };
+}
+
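The dynamic-temperature mapping above in one line: with T_min = max(0, temp - delta), T_max = temp + delta, and N candidates,

    T_{\text{dyn}} = T_{\min} + (T_{\max} - T_{\min}) \cdot \left( \frac{H}{H_{\max}} \right)^{\text{exponent}},
    \qquad H_{\max} = \log N

so a peaked distribution (low entropy H) is sampled near T_min and a flat one near T_max.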
+// mirostat
+
+struct llama_sampler_mirostat {
+    const int32_t n_vocab;
+
+    const uint32_t seed;
+          uint32_t seed_cur;
+
+    const float tau;
+    const float eta;
+
+    const int32_t m;
+
+    float mu;

+    std::mt19937 rng;
+};
+
+static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
+    return "mirostat";
+}
+
+static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
+
+    llama_sampler_softmax_impl(cur_p);
+
+    // Estimate s_hat using the most probable m tokens
+    float s_hat = 0.0;
+    float sum_ti_bi = 0.0;
+    float sum_ti_sq = 0.0;
+    for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
+        float t_i = logf(float(i + 2) / float(i + 1));
+        float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
+        sum_ti_bi += t_i * b_i;
+        sum_ti_sq += t_i * t_i;
     }
+    s_hat = sum_ti_bi / sum_ti_sq;
+
+    // Compute k from the estimated s_hat and target surprise value
+    float epsilon_hat = s_hat - 1;
+    float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
+
+    llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
+    llama_sampler_softmax_impl(cur_p);
+
+    const int idx = llama_sample_dist(cur_p, ctx->rng);
+
+    cur_p->selected = idx;
+
+    float observed_surprise = -log2f(cur_p->data[idx].p);
+    float e = observed_surprise - ctx->tau;

+    // Update mu using the learning rate and error
+    ctx->mu = ctx->mu - ctx->eta * e;
+}
+
+static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
+    auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_mirostat *) result->ctx; // cast the clone's ctx, not the source's
+
+        result_ctx->mu  = ctx->mu;
+        result_ctx->rng = ctx->rng;
     }
+
+    return result;
+}
+
+static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
+    ctx->mu = 2.0f*ctx->tau;
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
+}
+
+static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_mirostat *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_mirostat_i = {
+    /* .name   = */ llama_sampler_mirostat_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_mirostat_apply,
+    /* .reset  = */ llama_sampler_mirostat_reset,
+    /* .clone  = */ llama_sampler_mirostat_clone,
+    /* .free   = */ llama_sampler_mirostat_free,
+};
+
+struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
+    auto seed_cur = get_rng_seed(seed);
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_mirostat_i,
+        /* .ctx   = */ new llama_sampler_mirostat {
+            /* .n_vocab  = */ n_vocab,
+            /* .seed     = */ seed,
+            /* .seed_cur = */ seed_cur,
+            /* .tau      = */ tau,
+            /* .eta      = */ eta,
+            /* .m        = */ m,
+            /* .mu       = */ 2.0f*tau,
+            /* .rng      = */ std::mt19937(seed_cur),
+        },
+    };
+}
+
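Both the k estimate and the feedback update above follow the Mirostat paper (Basu et al., ICLR 2021): with Zipf exponent estimate s_hat and epsilon_hat = s_hat - 1,

    k = \left( \frac{\hat\varepsilon \, 2^{\mu}}{1 - N^{-\hat\varepsilon}} \right)^{1/\hat s},
    \qquad
    \mu \leftarrow \mu - \eta \bigl( -\log_2 p(x) - \tau \bigr)

where N = n_vocab, x is the sampled token, and mu starts at 2*tau (see the reset and init functions above).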
+// mirostat v2
+
+struct llama_sampler_mirostat_v2 {
+    const uint32_t seed;
+          uint32_t seed_cur;
+
+    const float tau;
+    const float eta;
+
+    float mu;
+
+    std::mt19937 rng;
+};
+
+static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
+    return "mirostat-v2";
+}
+
+static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
+
+    llama_sampler_softmax_impl(cur_p);
+
+    // Truncate the words with surprise values greater than mu
+    cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
+        return -log2f(candidate.p) > ctx->mu;
+    }));
+
+    if (cur_p->size == 0) {
+        cur_p->size = 1;
     }

+    // Normalize the probabilities of the remaining words
+    llama_sampler_softmax_impl(cur_p);
+
+    const int idx = llama_sample_dist(cur_p, ctx->rng);
+
+    cur_p->selected = idx;
+
+    float observed_surprise = -log2f(cur_p->data[idx].p);
+    float e = observed_surprise - ctx->tau;
+
+    // Update mu using the learning rate and error
+    ctx->mu = ctx->mu - ctx->eta * e;
+}
+
+static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
+    ctx->mu = 2.0f*ctx->tau;
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
+}
+
+static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
+
+    auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
+
+        result_ctx->mu  = ctx->mu;
+        result_ctx->rng = ctx->rng;
     }

+    return result;
+}
+
+static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_mirostat_v2 *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
+    /* .name   = */ llama_sampler_mirostat_v2_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_mirostat_v2_apply,
+    /* .reset  = */ llama_sampler_mirostat_v2_reset,
+    /* .clone  = */ llama_sampler_mirostat_v2_clone,
+    /* .free   = */ llama_sampler_mirostat_v2_free,
+};
+
+struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
+    auto seed_cur = get_rng_seed(seed);
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_mirostat_v2_i,
+        /* .ctx   = */ new llama_sampler_mirostat_v2 {
+            /* .seed     = */ seed,
+            /* .seed_cur = */ seed_cur,
+            /* .tau      = */ tau,
+            /* .eta      = */ eta,
+            /* .mu       = */ 2.0f*tau,
+            /* .rng      = */ std::mt19937(seed_cur),
+        },
+    };
+}
+
+// grammar
+
+struct llama_sampler_grammar {
+    const struct llama_vocab * vocab;
+
+    std::string grammar_str;
+    std::string grammar_root;
+
+    struct llama_grammar * grammar;
+};
+
+static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) {
+    return "grammar";
+}
+
+static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
+    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+    if (ctx->grammar) {
+        llama_grammar_accept_impl(*ctx->grammar, token);
+    }
+}
+
+static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+    if (ctx->grammar) {
+        llama_grammar_apply_impl(*ctx->grammar, cur_p);
+    }
+}
+
+static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+    if (!ctx->grammar) {
+        return;
     }
+
+    auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
+
+    llama_grammar_free_impl(ctx->grammar);
+    ctx->grammar = grammar_new;
 }

+static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
+
+    auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_grammar *) result->ctx;

+        if (ctx->grammar) {
+            result_ctx->grammar_str  = ctx->grammar_str;
+            result_ctx->grammar_root = ctx->grammar_root;
+
+            result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
+        }
+    }
+
+    return result;
+}
+
+static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
+    const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+
+    if (ctx->grammar) {
+        llama_grammar_free_impl(ctx->grammar);
+    }
+
+    delete ctx;
+}
+
+static struct llama_sampler_i llama_sampler_grammar_i = {
+    /* .name   = */ llama_sampler_grammar_name,
+    /* .accept = */ llama_sampler_grammar_accept_impl,
+    /* .apply  = */ llama_sampler_grammar_apply,
+    /* .reset  = */ llama_sampler_grammar_reset,
+    /* .clone  = */ llama_sampler_grammar_clone,
+    /* .free   = */ llama_sampler_grammar_free,
+};
+
+struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
+    auto * ctx = new llama_sampler_grammar;
+
+    if (grammar_str != nullptr && grammar_str[0] != '\0') {
+        *ctx = {
+            /* .vocab        = */ &vocab,
+            /* .grammar_str  = */ grammar_str,
+            /* .grammar_root = */ grammar_root,
+            /* .grammar      = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
+        };
+    } else {
+        *ctx = {
+            /* .vocab        = */ &vocab,
+            /* .grammar_str  = */ {},
+            /* .grammar_root = */ {},
+            /* .grammar      = */ nullptr,
+        };
     }

+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_grammar_i,
+        /* .ctx   = */ ctx,
+    };
+}
+
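Passing a null or empty grammar_str yields an inert sampler (grammar == nullptr), which is exactly what the clone above constructs before copying state into it. A hedged usage sketch (assumes a loaded llama_vocab named vocab; the GBNF string is illustrative only):

    const char * gbnf = "root ::= [0-9]+";   // hypothetical grammar: digits only
    struct llama_sampler * g = llama_sampler_init_grammar_impl(vocab, gbnf, "root");
    // accepted tokens advance the grammar state; reset() rebuilds it from grammar_str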
+// penalties
+
+struct llama_sampler_penalties {
+    const int32_t     n_vocab;
+    const llama_token special_eos_id;
+    const llama_token linefeed_id;
+
+    const int32_t penalty_last_n;
+    const float   penalty_repeat;
+    const float   penalty_freq;
+    const float   penalty_present;
+
+    const bool penalize_nl;
+    const bool ignore_eos;
+
+    ring_buffer<llama_token> prev;
+};
+
+static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
+    return "penalties";
+}
+
+static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
+    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+    if (ctx->penalty_last_n == 0) {
+        return;
     }
+
+    ctx->prev.push_back(token);
 }

+static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+
+    if (ctx->ignore_eos) {
+        assert(ctx->special_eos_id >= 0);
+
+        // optimistically check if the candidates are not yet sorted/shuffled/truncated
+        if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
+            cur_p->data[ctx->special_eos_id].logit = -INFINITY;
+        } else {
+            // else, search for the special EOS token
+            for (size_t i = 0; i < cur_p->size; ++i) {
+                if (cur_p->data[i].id == ctx->special_eos_id) {
+                    cur_p->data[i].logit = -INFINITY;
+                    break;
+                }
+            }
+        }
+    }
+
+    if ((ctx->penalty_last_n == 0) ||
+        (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
         return;
     }

+    bool   nl_found = false;
+    size_t nl_idx = 0;
+    float  nl_logit = -INFINITY;
+    if (!ctx->penalize_nl) {
+        assert(ctx->linefeed_id >= 0);
+
+        // optimistically check if the candidates are not yet sorted/shuffled/truncated
+        if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
+            nl_found = true;
+            nl_idx = ctx->linefeed_id;
+            nl_logit = cur_p->data[ctx->linefeed_id].logit;
+        } else {
+            // else, search for the linefeed token
+            for (size_t i = 0; i < cur_p->size; ++i) {
+                if (cur_p->data[i].id == ctx->linefeed_id) {
+                    nl_found = true;
+                    nl_idx = i;
+                    nl_logit = cur_p->data[i].logit;
+                    break;
+                }
+            }
+        }
+    }

     // Create a frequency map to count occurrences of each token in last_tokens
+    // TODO: optimize this by maintaining the token count in the sampler context
+    using llama_token_cnt = std::unordered_map<llama_token, int>;
+    llama_token_cnt token_count;
+
+    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
+        token_count[ctx->prev.rat(i)]++;
     }

+    // Apply frequency and presence penalties to the cur_p
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        const auto token_iter = token_count.find(cur_p->data[i].id);
         if (token_iter == token_count.end()) {
             continue;
         }

         const int count = token_iter->second;

     // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
     // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
+        if (cur_p->data[i].logit <= 0) {
+            cur_p->data[i].logit *= ctx->penalty_repeat;
         } else {
+            cur_p->data[i].logit /= ctx->penalty_repeat;
         }

+        cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
     }

+    cur_p->sorted = false;

+    if (!ctx->penalize_nl && nl_found) {
+        // restore the logit of the newline token if it was penalized
+        cur_p->data[nl_idx].logit = nl_logit;
     }
 }

+static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+    ctx->prev.clear();
+}

+static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
+    auto * result = llama_sampler_init_penalties(
+            ctx->n_vocab,
+            ctx->special_eos_id,
+            ctx->linefeed_id,
+            ctx->penalty_last_n,
+            ctx->penalty_repeat,
+            ctx->penalty_freq,
+            ctx->penalty_present,
+            ctx->penalize_nl,
+            ctx->ignore_eos);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_penalties *) result->ctx;

+        result_ctx->prev = ctx->prev;
     }

+    return result;
 }

+static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_penalties *) smpl->ctx;
+}

+static struct llama_sampler_i llama_sampler_penalties_i = {
+    /* .name   = */ llama_sampler_penalties_name,
+    /* .accept = */ llama_sampler_penalties_accept,
+    /* .apply  = */ llama_sampler_penalties_apply,
+    /* .reset  = */ llama_sampler_penalties_reset,
+    /* .clone  = */ llama_sampler_penalties_clone,
+    /* .free   = */ llama_sampler_penalties_free,
+};
+
+struct llama_sampler * llama_sampler_init_penalties(
+        int32_t n_vocab,
+        llama_token special_eos_id,
+        llama_token linefeed_id,
+        int32_t penalty_last_n,
+        float   penalty_repeat,
+        float   penalty_freq,
+        float   penalty_present,
+        bool    penalize_nl,
+        bool    ignore_eos) {
+    if (linefeed_id == LLAMA_TOKEN_NULL) {
+        penalize_nl = true;
+    }

+    if (special_eos_id == LLAMA_TOKEN_NULL) {
+        ignore_eos = false;
     }

+    penalty_last_n = std::max(penalty_last_n, 0);
+
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_penalties_i,
+        /* .ctx   = */ new llama_sampler_penalties {
+            /* .n_vocab         = */ n_vocab,
+            /* .special_eos_id  = */ special_eos_id,
+            /* .linefeed_id     = */ linefeed_id,
+            /* .penalty_last_n  = */ penalty_last_n,
+            /* .penalty_repeat  = */ penalty_repeat,
+            /* .penalty_freq    = */ penalty_freq,
+            /* .penalty_present = */ penalty_present,
+            /* .penalize_nl     = */ penalize_nl,
+            /* .ignore_eos      = */ ignore_eos,
+            /* .prev            = */ ring_buffer<llama_token>(penalty_last_n),
+        },
+    };
+}

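The penalty arithmetic above, written out: let c_i be the number of times candidate i appears among the last penalty_last_n accepted tokens and r = penalty_repeat. Candidates with c_i = 0 are skipped; for the rest,

    \ell_i \leftarrow
    \begin{cases}
        \ell_i \cdot r, & \ell_i \le 0 \\
        \ell_i / r,     & \ell_i > 0
    \end{cases}
    \qquad\text{then}\qquad
    \ell_i \leftarrow \ell_i - c_i\,\alpha_{\text{freq}} - [c_i > 0]\,\alpha_{\text{present}}

with alpha_freq = penalty_freq and alpha_present = penalty_present. The case split (multiply vs. divide on the sign of the logit) is the fix described in the comment above: dividing a negative logit by r > 1 would make the token more likely, not less.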
+// logit-bias
+
+struct llama_sampler_logit_bias {
+    const int32_t n_vocab;

+    const std::vector<llama_logit_bias> logit_bias;

+    std::vector<llama_logit_bias> to_search;
+};
+
+static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
+    return "logit-bias";
 }

+static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
+
+    if (ctx->logit_bias.empty()) {
+        return;
+    }

+    ctx->to_search.clear();

+    // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
+    for (const auto & lb : ctx->logit_bias) {
+        if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
+            cur_p->data[lb.token].logit += lb.bias;
+        } else {
+            ctx->to_search.push_back(lb);
+        }
+    }

+    if (ctx->to_search.empty()) {
+        return;
     }

+    // search for the remaining candidates that were not found in the previous step
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        for (const auto & lb : ctx->to_search) {
+            if (cur_p->data[i].id == lb.token) {
+                cur_p->data[i].logit += lb.bias;
+                break;
+            }
+        }
     }
+}

+static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
+    return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
+}

+static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_logit_bias *) smpl->ctx;
+}

+static struct llama_sampler_i llama_sampler_logit_bias_i = {
+    /* .name   = */ llama_sampler_logit_bias_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_logit_bias_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_logit_bias_clone,
+    /* .free   = */ llama_sampler_logit_bias_free,
+};
+
+struct llama_sampler * llama_sampler_init_logit_bias(
+        int32_t n_vocab,
+        int32_t n_logit_bias,
+        const llama_logit_bias * logit_bias) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_logit_bias_i,
+        /* .ctx   = */ new llama_sampler_logit_bias {
+            /* .n_vocab    = */ n_vocab,
+            /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
+            /* .to_search  = */ {},
+        },
+    };
+}

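A small usage sketch for the constructor above (illustrative token ids only; n_vocab would come from the loaded vocab, and llama_logit_bias is the {token, bias} pair used by the apply function):

    const llama_logit_bias biases[2] = {
        { /* .token = */ 42, /* .bias = */ +2.0f },     // nudge token 42 upward
        { /* .token = */ 7,  /* .bias = */ -INFINITY }, // effectively ban token 7
    };
    struct llama_sampler * smpl = llama_sampler_init_logit_bias(n_vocab, 2, biases);

The two-pass apply is a design choice worth noting: when the candidate array is still in vocabulary order (idx == id), each bias is applied in O(1); only biases whose slot has been shuffled or truncated fall back to the linear search.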
+// utils
+
+uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
+    if (smpl->iface == &llama_sampler_dist_i) {
+        return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
     }

+    if (smpl->iface == &llama_sampler_mirostat_i) {
+        return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
+    }

+    if (smpl->iface == &llama_sampler_mirostat_v2_i) {
+        return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
+    }

+    if (smpl->iface == &llama_sampler_chain_i) {
+        const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
+        for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
+            const uint32_t seed = llama_sampler_get_seed(*it);
+            if (seed != LLAMA_DEFAULT_SEED) {
+                return seed;
+            }
+        }
     }
+
+    return LLAMA_DEFAULT_SEED;
 }

+// perf

+struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) {
+    struct llama_perf_sampler_data data = {};

+    if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
+        GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
     }

+    const auto * ctx = (const struct llama_sampler_chain *) chain->ctx;

+    data.t_sample_ms = 1e-3 * ctx->t_sample_us;
+    data.n_sample    = std::max(0, ctx->n_sample);

+    return data;
+}

+void llama_perf_sampler_print(const struct llama_sampler * chain) {
+    const auto data = llama_perf_sampler(chain);
+
+    LLAMA_LOG_INFO("%s:    sampling time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
 }

+void llama_perf_sampler_reset(struct llama_sampler * chain) {
+    if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
+        GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
+    }
+
+    auto * ctx = (struct llama_sampler_chain *) chain->ctx;
+
+    ctx->t_sample_us = ctx->n_sample = 0;
 }
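Taken together, a consumer builds one chain from these constructors and samples through it. A hedged end-to-end sketch, not part of this diff: llama_sampler_init_dist is assumed to be the dist sampler referenced by llama_sampler_get_seed above (defined earlier in this file), llama_sampler_sample is assumed to behave as in upstream llama.cpp, and n_vocab, eos_id, nl_id are placeholders for values taken from the loaded model:

    struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_penalties(n_vocab, eos_id, nl_id,
            /* penalty_last_n  = */ 64,   /* penalty_repeat  = */ 1.1f,
            /* penalty_freq    = */ 0.0f, /* penalty_present = */ 0.0f,
            /* penalize_nl     = */ false, /* ignore_eos     = */ false));
    llama_sampler_chain_add(chain, llama_sampler_init_typical(0.95f, /* min_keep = */ 1));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    // ... generate tokens with llama_sampler_sample(chain, ctx, -1), then:
    llama_perf_sampler_print(chain);  // reports the timing accumulated by the chain
    llama_sampler_free(chain);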
examples/talk-llama/llama-sampling.h
CHANGED

@@ -1,56 +1,29 @@
 #pragma once

-[...]
-    llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
-[...]
-    mutable int32_t n_sample = 0;
-[...]
-void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);
-
-void llama_sample_softmax_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates);
-void llama_sample_top_k_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
-void llama_sample_top_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_min_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
-void llama_sample_typical_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_entropy_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
-void llama_sample_temp_impl     (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
-
-void llama_sample_repetition_penalties_impl(
-        struct llama_sampling * smpl,
-        llama_token_data_array * candidates,
-        const llama_token * last_tokens,
-        size_t penalty_last_n,
-        float  penalty_repeat,
-        float  penalty_freq,
-        float  penalty_present);
-
-void llama_sample_apply_guidance_impl(
-        struct llama_sampling * smpl,
-        float * logits,
-        float * logits_guidance,
-        float   scale);
-
-llama_token llama_sample_token_mirostat_impl   (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
-llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
-llama_token llama_sample_token_greedy_impl     (struct llama_sampling * smpl, llama_token_data_array * candidates);
-llama_token llama_sample_token_with_rng_impl   (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
-llama_token llama_sample_token_impl            (struct llama_sampling * smpl, llama_token_data_array * candidates);
+// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
+
+#include "llama-grammar.h"
+
+#include <unordered_map>
+
+struct llama_vocab;
+struct llama_grammar;
+
+// sampler chain
+
+struct llama_sampler_chain {
+    llama_sampler_chain_params params;
+
+    std::vector<struct llama_sampler *> samplers;
+
+    // timing
+
+    mutable int64_t t_sample_us;
+
+    mutable int32_t n_sample;
+};
+
+struct llama_sampler * llama_sampler_init_grammar_impl(
+        const struct llama_vocab & vocab,
+        const char * grammar_str,
+        const char * grammar_root);

examples/talk-llama/llama-vocab.cpp
CHANGED

@@ -58,17 +58,17 @@ struct naive_trie {
         auto res = children.find(c);
         if (res != children.end()) {
             return res->second.get_longest_prefix(key, len, offset + 1);
-        } else {
-            return std::make_pair(key, offset);
         }
+
+        return std::make_pair(key, offset);
     }
-    struct naive_trie * traverse(const char c) {
+    const struct naive_trie * traverse(const char c) const {
         auto res = children.find(c);
         if (res != children.end()) {
             return &res->second;
-        } else {
-            return NULL;
         }
+
+        return NULL;
     }
     std::map<char, struct naive_trie> children;
     bool has_value;
@@ -843,7 +843,7 @@ struct llm_tokenizer_ugm {
         // traverse the token matcher trie to find a matching token
         bool single_codepoint_token_found = false;
         const struct best_tokenization & current_best = tokenization_results[input_offset];
-        struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
+        const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);

         while (prefix_offset <= input_len && node != NULL) {
             // check if we found valid token in prefix
@@ -963,7 +963,7 @@ private:
    /*
     * This structure is a view wrapper for XOR-compressed double array (XCDA)
     * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries.
-    *
+    * Each bit-packed entry contains:
     * - BASE array value in bits 10-30
     * - LCHECK array value in bits 0-7
     * - LEAF array value in bit 9
@@ -1097,6 +1097,111 @@ private:
     struct naive_trie token_matcher;
 };

+//
+// RWKV tokenizer
+//
+
+static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) {
+    std::vector<uint8_t> output;
+    output.reserve(escaped.size());
+
+    // Parser state
+    bool escaping = false;
+    uint8_t hex_remaining = 0;
+    uint8_t hex_acc = 0;
+
+    // Step through characters, performing parsing
+    for (const char & c : escaped) {
+        // If we're parsing a hex code, interpret the next character
+        if (hex_remaining != 0) {
+            uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0');
+            hex_acc = (hex_acc << 4) + value;
+
+            hex_remaining -= 1;
+            if (hex_remaining == 0) {
+                output.push_back(hex_acc);
+                hex_acc = 0;
+            }
+
+            continue;
+        }
+
+        // If we got an escape character, interpret it
+        if (escaping) {
+            if (c == 't') {
+                output.push_back('\t');
+            } else if (c == 'n') {
+                output.push_back('\n');
+            } else if (c == 'r') {
+                output.push_back('\r');
+            } else if (c == 'x') {
+                hex_remaining = 2;
+            } else {
+                output.push_back(c);
+            }
+
+            escaping = false;
+            continue;
+        }
+
+        if (c == '\\') {
+            escaping = true;
+            continue;
+        }
+
+        output.push_back(c);
+    }
+
+    return output;
+}
+
+struct llm_tokenizer_rwkv {
+    llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
+        // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
+        // For now, we decode the vocab here into the lookup we'll use for tokenization.
+
+        // build trie
+        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
+            const auto & token = vocab.id_to_token[id];
+            const auto data = llama_unescape_rwkv_token(token.text);
+            token_matcher.insert((const char *) data.data(), data.size(), id);
+        }
+    }
+
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        uint32_t position = 0;
+
+        while (position < text.size()) {
+            const struct naive_trie * node = token_matcher.traverse(text[position]);
+            if (node == NULL) {
+                // no matching token found, add unknown token
+                output.push_back(vocab.special_unk_id);
+                position += 1;
+                continue;
+            }
+
+            // traverse the trie to find the longest matching token
+            uint32_t token_id = 0;
+            uint32_t token_length = 0;
+            while (node != NULL) {
+                if (node->has_value) {
+                    token_id = node->value;
+                    token_length = position + 1;
+                }
+                node = node->traverse(text[++position]);
+            }
+
+            // add the longest matching token
+            output.push_back(token_id);
+            position = token_length;
+        }
+    }
+
+    const llama_vocab & vocab;
+
+    struct naive_trie token_matcher;
+};
+
 //
 // (de-) tokenize
 //
@@ -1401,6 +1506,23 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
                 output.push_back(vocab.special_eos_id);
             }
         } break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            {
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+
+                        llm_tokenizer_rwkv tokenizer(vocab);
+                        tokenizer.tokenize(raw_text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
         case LLAMA_VOCAB_TYPE_NONE:
             GGML_ABORT("fatal error");
     }
@@ -1448,11 +1570,7 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, lla
 }

 bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
-    return token != -1 && (
-        token == llama_token_eos_impl(vocab) ||
-        token == llama_token_eot_impl(vocab) ||
-        token == llama_token_eom_impl(vocab)
-    );
+    return token != -1 && vocab.special_eog_ids.count(token) > 0;
 }

 bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
@@ -1616,6 +1734,17 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
             }
             break;
         }
+        case LLAMA_VOCAB_TYPE_RWKV: {
+            std::vector<uint8_t> result = llama_unescape_rwkv_token(token_text);
+
+            // If we don't have enough space, return an error
+            if (result.size() > (size_t)length) {
+                return -(int)result.size();
+            }
+
+            memcpy(buf, result.data(), result.size());
+            return (int)result.size();
+        }
         default:
             GGML_ABORT("fatal error");
     }
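A worked example of the escape format llama_unescape_rwkv_token accepts (hypothetical input, traced through the parser above): the 8-character string a b \ n \ x 4 1 decodes to the four bytes 0x61 0x62 0x0a 0x41.

    std::string escaped = "ab\\n\\x41";              // backslash-escaped RWKV vocab entry
    std::vector<uint8_t> bytes = llama_unescape_rwkv_token(escaped);
    assert(bytes.size() == 4);
    assert(bytes[2] == '\n' && bytes[3] == 0x41);    // '\n' from \n, 0x41 from \x41

Note that the hex branch only handles lowercase digits (the test c >= 'a' selects the a-f range), so an uppercase \x4A would decode incorrectly; this assumes the RWKV vocab files use lowercase hex escapes.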
examples/talk-llama/llama-vocab.h
CHANGED

@@ -6,6 +6,7 @@
 #include <vector>
 #include <unordered_map>
 #include <map>
+#include <set>

 struct llama_vocab {
     using id = llama_token;
@@ -18,6 +19,8 @@ struct llama_vocab {
         tattr attr;
     };

+    uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
+
     enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
     enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;

@@ -47,12 +50,15 @@ struct llama_vocab {
     id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
     id special_eom_id = -1;

+    // set of all tokens that cause "end of generation"
+    std::set<id> special_eog_ids;
+
     // tokenizer flags
-    bool tokenizer_add_space_prefix = [...];
-    bool tokenizer_add_bos          = [...];
-    bool tokenizer_add_eos          = [...];
-    bool tokenizer_ignore_merges    = [...];
-    bool tokenizer_clean_spaces     = [...];
+    bool tokenizer_add_space_prefix           = false;
+    bool tokenizer_add_bos                    = false;
+    bool tokenizer_add_eos                    = false;
+    bool tokenizer_ignore_merges              = false;
+    bool tokenizer_clean_spaces               = false;  // clean_up_tokenization_spaces
     bool tokenizer_remove_extra_whitespaces   = false;
     bool tokenizer_escape_whitespaces         = true;
     bool tokenizer_treat_whitespace_as_suffix = false;
@@ -62,8 +68,6 @@ struct llama_vocab {
     int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
 };

-const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
-
 //
 // internal API
 //
@@ -76,6 +80,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
         bool add_special,
         bool parse_special = false);

+// TODO: move the API below as member functions of llama_vocab
 llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);

 const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);

examples/talk-llama/llama.cpp
CHANGED

The diff for this file is too large to render. See raw diff.

examples/talk-llama/llama.h
CHANGED
|
@@ -33,12 +33,15 @@
|
|
| 33 |
|
| 34 |
#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
|
| 35 |
|
|
|
|
|
|
|
|
|
|
| 36 |
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
|
| 37 |
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
| 38 |
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
|
| 39 |
|
| 40 |
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
| 41 |
-
#define LLAMA_SESSION_VERSION
|
| 42 |
|
| 43 |
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
| 44 |
#define LLAMA_STATE_SEQ_VERSION 2
|
|
@@ -53,8 +56,10 @@ extern "C" {
|
|
| 53 |
// TODO: show sample usage
|
| 54 |
//
|
| 55 |
|
|
|
|
| 56 |
struct llama_model;
|
| 57 |
struct llama_context;
|
|
|
|
| 58 |
|
| 59 |
typedef int32_t llama_pos;
|
| 60 |
typedef int32_t llama_token;
|
|
@@ -66,6 +71,7 @@ extern "C" {
|
|
| 66 |
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
|
| 67 |
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
|
| 68 |
LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
|
|
|
|
| 69 |
};
|
| 70 |
|
| 71 |
// pre-tokenization types
|
|
@@ -166,6 +172,8 @@ extern "C" {
|
|
| 166 |
LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors
|
| 167 |
LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors
|
| 168 |
LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
|
|
|
|
|
|
|
| 169 |
|
| 170 |
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
| 171 |
};
|
|
@@ -198,6 +206,7 @@ extern "C" {
|
|
| 198 |
LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
|
| 199 |
};
|
| 200 |
|
|
|
|
| 201 |
typedef struct llama_token_data {
|
| 202 |
llama_token id; // token id
|
| 203 |
float logit; // log-odds of the token
|
|
@@ -205,8 +214,10 @@ extern "C" {
     } llama_token_data;
 
     typedef struct llama_token_data_array {
+        // TODO: consider SoA
         llama_token_data * data;
         size_t size;
+        int64_t selected; // this is the index in the data array (i.e. not the token id)
         bool sorted;
     } llama_token_data_array;
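The new `selected` field is how a sampler chain communicates its pick back to the caller: build the array once, run `llama_sampler_apply`, then read `data[selected].id`. A hedged sketch of constructing the array from raw logits, assuming the header from this sync is on the include path:

#include "llama.h"

#include <vector>

// Builds a llama_token_data_array over the whole vocab from raw logits.
// After llama_sampler_apply(smpl, &cur_p), cur_p.selected indexes cur_p.data
// (it is an index into the array, not a token id).
static llama_token_data_array make_candidates(const float * logits, int32_t n_vocab,
                                              std::vector<llama_token_data> & storage) {
    storage.resize(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        storage[token_id] = llama_token_data { token_id, logits[token_id], 0.0f };
    }
    // selected = -1: nothing chosen yet; sorted = false: still in token-id order
    return llama_token_data_array { storage.data(), storage.size(), -1, false };
}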
@@ -267,9 +278,9 @@ extern "C" {
     enum llama_split_mode split_mode; // how to split the model across multiple GPUs
 
     // main_gpu interpretation depends on split_mode:
-    // LLAMA_SPLIT_NONE: the GPU that is used for the entire model
-    // LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results
-    // LLAMA_SPLIT_LAYER: ignored
+    // LLAMA_SPLIT_MODE_NONE: the GPU that is used for the entire model
+    // LLAMA_SPLIT_MODE_ROW: the GPU that is used for small tensors and intermediate results
+    // LLAMA_SPLIT_MODE_LAYER: ignored
     int32_t main_gpu;
 
     // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
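Since the comment spells out the `main_gpu` semantics per split mode, a short sketch of the common single-GPU configuration may help (field values are illustrative):

#include "llama.h"

// Pin the whole model to one device: with LLAMA_SPLIT_MODE_NONE, main_gpu
// selects the GPU; with LLAMA_SPLIT_MODE_LAYER it would be ignored.
static struct llama_model_params single_gpu_params(void) {
    struct llama_model_params mparams = llama_model_default_params();
    mparams.split_mode = LLAMA_SPLIT_MODE_NONE;
    mparams.main_gpu   = 0; // device index
    return mparams;
}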
@@ -299,13 +310,12 @@ extern "C" {
     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
     // https://github.com/ggerganov/llama.cpp/pull/7544
     struct llama_context_params {
-        uint32_t seed;              // RNG seed, -1 for random
         uint32_t n_ctx;             // text context, 0 = from model
         uint32_t n_batch;           // logical maximum batch size that can be submitted to llama_decode
         uint32_t n_ubatch;          // physical maximum batch size
         uint32_t n_seq_max;         // max number of sequences (i.e. distinct states for recurrent models)
-        uint32_t n_threads;         // number of threads to use for generation
-        uint32_t n_threads_batch;   // number of threads to use for batch processing
+        int32_t  n_threads;         // number of threads to use for generation
+        int32_t  n_threads_batch;   // number of threads to use for batch processing
 
         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
         enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
@@ -327,11 +337,13 @@ extern "C" {
         enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
         enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
 
-        // Keep the booleans together to avoid misalignment during copy-by-value.
+        // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
+        // TODO: move at the end of the struct
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
         bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
+        bool no_perf;     // whether to measure performance timings
 
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
@@ -355,56 +367,14 @@ extern "C" {
         void * kv_overrides;         // pointer to vector containing overrides
     } llama_model_quantize_params;
 
-    // grammar types
-    struct llama_grammar;
-
-    // grammar element type
-    enum llama_gretype {
-        // end of rule definition
-        LLAMA_GRETYPE_END            = 0,
-
-        // start of alternate definition for rule
-        LLAMA_GRETYPE_ALT            = 1,
-
-        // non-terminal element: reference to rule
-        LLAMA_GRETYPE_RULE_REF       = 2,
-
-        // terminal element: character (code point)
-        LLAMA_GRETYPE_CHAR           = 3,
-
-        // inverse char(s) ([^a], [^a-b] [^abc])
-        LLAMA_GRETYPE_CHAR_NOT       = 4,
-
-        // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
-        // be an inclusive range ([a-z])
-        LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
-
-        // modifies a preceding LLAMA_GRETYPE_CHAR or
-        // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
-        LLAMA_GRETYPE_CHAR_ALT       = 6,
-
-        // any character (.)
-        LLAMA_GRETYPE_CHAR_ANY       = 7,
-    };
-
-    typedef struct llama_grammar_element {
-        enum llama_gretype type;
-        uint32_t           value; // Unicode code point or rule ID
-    } llama_grammar_element;
-
-    // performance timing information
-    struct llama_timings {
-        double t_start_ms;
-        double t_end_ms;
-        double t_load_ms;
-        double t_sample_ms;
-        double t_p_eval_ms;
-        double t_eval_ms;
-
-        int32_t n_sample;
-        int32_t n_p_eval;
-        int32_t n_eval;
-    };
+    typedef struct llama_logit_bias {
+        llama_token token;
+        float bias;
+    } llama_logit_bias;
+
+    typedef struct llama_sampler_chain_params {
+        bool no_perf; // whether to measure performance timings
+    } llama_sampler_chain_params;
 
     // used in chat template
     typedef struct llama_chat_message {
@@ -416,8 +386,10 @@ extern "C" {
     struct llama_lora_adapter;
 
     // Helpers for getting default parameters
-    LLAMA_API struct llama_model_params   llama_model_default_params(void);
-    LLAMA_API struct llama_context_params llama_context_default_params(void);
+    // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
+    LLAMA_API struct llama_model_params          llama_model_default_params(void);
+    LLAMA_API struct llama_context_params        llama_context_default_params(void);
+    LLAMA_API struct llama_sampler_chain_params  llama_sampler_chain_default_params(void);
     LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
 
     // Initialize the llama + ggml backend
@@ -428,15 +400,23 @@ extern "C" {
     //optional:
     LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa);
 
+    // Optional: an auto threadpool gets created in ggml if not passed explicitly
+    LLAMA_API void llama_attach_threadpool(
+            struct llama_context * ctx,
+               ggml_threadpool_t   threadpool,
+               ggml_threadpool_t   threadpool_batch);
+    LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
+
     // Call once at the end of the program - currently only used for MPI
     LLAMA_API void llama_backend_free(void);
 
     LLAMA_API struct llama_model * llama_load_model_from_file(
                              const char * path_model,
-                             struct llama_model_params   params);
+            struct llama_model_params     params);
 
     LLAMA_API void llama_free_model(struct llama_model * model);
 
+    // TODO: rename to llama_init_from_model
     LLAMA_API struct llama_context * llama_new_context_with_model(
                      struct llama_model * model,
             struct llama_context_params   params);
@@ -452,22 +432,22 @@ extern "C" {
     LLAMA_API bool llama_supports_mlock      (void);
     LLAMA_API bool llama_supports_gpu_offload(void);
 
-    LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
-
     LLAMA_API uint32_t llama_n_ctx    (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch  (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max(const struct llama_context * ctx);
 
-    LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
-
-    LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
-    LLAMA_API enum llama_rope_type  llama_rope_type  (const struct llama_model * model);
-
     LLAMA_API int32_t llama_n_vocab    (const struct llama_model * model);
     LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
     LLAMA_API int32_t llama_n_embd     (const struct llama_model * model);
     LLAMA_API int32_t llama_n_layer    (const struct llama_model * model);
+    LLAMA_API int32_t llama_n_head     (const struct llama_model * model);
+
+    LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
+
+    LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
+    LLAMA_API enum llama_vocab_type   llama_vocab_type  (const struct llama_model * model);
+    LLAMA_API enum llama_rope_type    llama_rope_type   (const struct llama_model * model);
 
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
@@ -696,7 +676,7 @@ extern "C" {
     //
 
     // Returns the *actual* size in bytes of the state
-    // (rng, logits, embedding and kv_cache)
+    // (logits, embedding and kv_cache)
     // Only use when saving the state, not when restoring it, otherwise the size may be too small.
     LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
     LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
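Since the size is only valid at save time, the natural pattern is query, allocate, copy, in that order. A hedged sketch; the (ctx, dst, size) form of `llama_state_get_data` is assumed from this generation of llama.cpp and is not itself shown in the diff:

#include "llama.h"

#include <cstdint>
#include <vector>

// Query the state size right before saving, never before restoring: the value
// tracks the *current* state, per the header comment above. The (ctx, dst, size)
// signature of llama_state_get_data is an assumption about this llama.cpp version.
static std::vector<uint8_t> save_state(struct llama_context * ctx) {
    std::vector<uint8_t> buf(llama_state_get_size(ctx));
    const size_t written = llama_state_get_data(ctx, buf.data(), buf.size());
    buf.resize(written);
    return buf;
}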
@@ -837,13 +817,13 @@ extern "C" {
     // Set the number of threads used for decoding
     // n_threads is the number of threads used for generation (single token)
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
-    LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
+    LLAMA_API void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch);
 
     // Get the number of threads used for generation of a single token.
-    LLAMA_API uint32_t llama_n_threads(struct llama_context * ctx);
+    LLAMA_API int32_t llama_n_threads(struct llama_context * ctx);
 
     // Get the number of threads used for prompt and batch processing (multiple token).
-    LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
+    LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
 
     // Set whether the model is in embeddings mode or not
     // If true, embeddings will be returned but logits will not
@@ -999,121 +979,114 @@ extern "C" {
                       int32_t   length);
 
     //
-    // Grammar
+    // Sampling API
+    //
+    // Sample usage:
+    //
+    //    // prepare the sampling chain at the start
+    //    auto sparams = llama_sampler_chain_default_params();
+    //
+    //    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+    //
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50));
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1));
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8));
+    //
+    //    // typically, the chain should end with a sampler such as "greedy", "dist" or "mirostat"
+    //    // this sampler will be responsible to select the actual token
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_dist(seed));
+    //
+    //    ...
+    //
+    //    // decoding loop:
+    //    while (...) {
+    //        ...
+    //
+    //        llama_decode(ctx, batch);
+    //
+    //        // sample from the logits of the last token in the batch
+    //        const llama_token id = llama_sampler_sample(smpl, ctx, -1);
+    //
+    //        // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.)
+    //        llama_sampler_accept(smpl, id);
+    //        ...
+    //    }
+    //
+    //    llama_sampler_free(smpl);
+    //
+    // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
+    // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
     //
 
-    /// Initialize a llama_grammar.
-    ///
-    /// @param rules The rule elements of the grammar to initialize.
-    /// @param n_rules The number of rules.
-    /// @param start_rule_index The index of the root rule (the starting point of the grammar).
-    /// @return The initialized llama_grammar or nullptr if initialization failed.
-    LLAMA_API struct llama_grammar * llama_grammar_init(
-            const llama_grammar_element ** rules,
-                                 size_t    n_rules,
-                                 size_t    start_rule_index);
-
-    LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
-
-    LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
-
-    /// @details Apply constraints from grammar
-    LLAMA_API void llama_grammar_sample(
-            const struct llama_grammar * grammar,
-            const struct llama_context * ctx,
-                llama_token_data_array * candidates);
-    LLAMA_API DEPRECATED(void llama_sample_grammar(
-            struct llama_context * ctx,
-            llama_token_data_array * candidates,
-            const struct llama_grammar * grammar),
-        "use llama_grammar_sample instead");
-
-    /// @details Accepts the sampled token into the grammar
-    LLAMA_API void llama_grammar_accept_token(
-            struct llama_grammar * grammar,
-            struct llama_context * ctx,
-                     llama_token   token);
-
-    //
-    // Sampling functions
-    //
-
-    // Sets the current rng seed.
-    LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
-
-    /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
-    LLAMA_API void llama_sample_repetition_penalties(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-               const llama_token * last_tokens,
-                          size_t   penalty_last_n,
-                           float   penalty_repeat,
-                           float   penalty_freq,
-                           float   penalty_present);
-
-    /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
-    LLAMA_API void llama_sample_apply_guidance(
-              struct llama_context * ctx,
-                             float * logits,
-                             float * logits_guidance,
-                             float   scale);
+    typedef void * llama_sampler_context_t;
+
+    // user code can implement the interface below in order to create custom llama_sampler
+    struct llama_sampler_i {
+        const char *           (*name)  (const struct llama_sampler * smpl);                                 // can be NULL
+        void                   (*accept)(      struct llama_sampler * smpl, llama_token token);              // can be NULL
+        void                   (*apply) (      struct llama_sampler * smpl, llama_token_data_array * cur_p); // required
+        void                   (*reset) (      struct llama_sampler * smpl);                                 // can be NULL
+        struct llama_sampler * (*clone) (const struct llama_sampler * smpl);                                 // can be NULL if ctx is NULL
+        void                   (*free)  (      struct llama_sampler * smpl);                                 // can be NULL if ctx is NULL
+
+        // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
+        //void (*apply_ggml) (struct llama_sampler * smpl, ...);
+    };
+
+    struct llama_sampler {
+        struct llama_sampler_i  * iface;
+        llama_sampler_context_t   ctx;
+    };
+
+    // mirror of llama_sampler_i:
+    LLAMA_API const char *           llama_sampler_name  (const struct llama_sampler * smpl);
+    LLAMA_API void                   llama_sampler_accept(      struct llama_sampler * smpl, llama_token token);
+    LLAMA_API void                   llama_sampler_apply (      struct llama_sampler * smpl, llama_token_data_array * cur_p);
+    LLAMA_API void                   llama_sampler_reset (      struct llama_sampler * smpl);
+    LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl);
+    // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
+    LLAMA_API void                   llama_sampler_free  (      struct llama_sampler * smpl);
+
+    // llama_sampler_chain
+    // a type of llama_sampler that can chain multiple samplers one after another
+
+    LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params);
+
+    // important: takes ownership of the sampler object and will free it when llama_sampler_free is called
+    LLAMA_API void                   llama_sampler_chain_add(      struct llama_sampler * chain, struct llama_sampler * smpl);
+    LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
+    LLAMA_API int                    llama_sampler_chain_n  (const struct llama_sampler * chain);
+
+    // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
+    LLAMA_API struct llama_sampler * llama_sampler_chain_remove(   struct llama_sampler * chain, int32_t i);
+
+    // available samplers:
+
+    LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
+    LLAMA_API struct llama_sampler * llama_sampler_init_dist  (uint32_t seed);
 
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
-    LLAMA_API void llama_sample_softmax(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates);
+    /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
+    LLAMA_API struct llama_sampler * llama_sampler_init_softmax(void);
 
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-    LLAMA_API void llama_sample_top_k(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                         int32_t   k,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_top_k(int32_t k);
 
     /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-    LLAMA_API void llama_sample_top_p(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   p,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep);
 
     /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
-    LLAMA_API void llama_sample_min_p(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   p,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep);
 
     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
-    LLAMA_API void llama_sample_tail_free(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   z,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep);
 
     /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
-    LLAMA_API void llama_sample_typical(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   p,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_temp   (float t);
 
-    /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
-    LLAMA_API void llama_sample_entropy(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates_p,
-                           float   min_temp,
-                           float   max_temp,
-                           float   exponent_val);
-
-    LLAMA_API void llama_sample_temp(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   temp);
+    /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
+    LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext(float t, float delta, float exponent);
 
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
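Read as a whole, the sample-usage comment in the hunk above is nearly compilable; the following sketch fills it in as a standalone helper, using the exact values from the comment:

#include "llama.h"

// Builds the chain from the usage comment above: top-k -> top-p -> temperature,
// terminated by "dist", which performs the actual (seeded) token selection.
static struct llama_sampler * make_chain(uint32_t seed) {
    auto sparams = llama_sampler_chain_default_params();

    struct llama_sampler * smpl = llama_sampler_chain_init(sparams);

    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50));
    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9f, 1));
    llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8f));
    llama_sampler_chain_add(smpl, llama_sampler_init_dist (seed));

    return smpl; // release with llama_sampler_free(); the chain frees the samplers it owns
}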
@@ -1121,36 +1094,62 @@ extern "C" {
     /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
     /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
     /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
-    LLAMA_API llama_token llama_sample_token_mirostat(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   tau,
-                           float   eta,
-                         int32_t   m,
-                           float * mu);
+    LLAMA_API struct llama_sampler * llama_sampler_init_mirostat(
+                             int32_t   n_vocab,
+                            uint32_t   seed,
+                               float   tau,
+                               float   eta,
+                             int32_t   m);
 
     /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
     /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
     /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
-    LLAMA_API llama_token llama_sample_token_mirostat_v2(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   tau,
-                           float   eta,
-                           float * mu);
-
-    /// @details Selects the token with the highest probability.
-    ///          Does not compute the token probabilities. Use llama_sample_softmax() instead.
-    LLAMA_API llama_token llama_sample_token_greedy(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates);
-
-    /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
-    LLAMA_API llama_token llama_sample_token(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates);
+    LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2(
+                            uint32_t   seed,
+                               float   tau,
+                               float   eta);
+
+    LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
+            const struct llama_model * model,
+                          const char * grammar_str,
+                          const char * grammar_root);
+
+    LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
+                             int32_t   n_vocab,         // llama_n_vocab()
+                         llama_token   special_eos_id,  // llama_token_eos()
+                         llama_token   linefeed_id,     // llama_token_nl()
+                             int32_t   penalty_last_n,  // last n tokens to penalize (0 = disable penalty, -1 = context size)
+                               float   penalty_repeat,  // 1.0 = disabled
+                               float   penalty_freq,    // 0.0 = disabled
+                               float   penalty_present, // 0.0 = disabled
+                                bool   penalize_nl,     // consider newlines as a repeatable token
+                                bool   ignore_eos);     // ignore the end-of-sequence token
+
+    LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
+                             int32_t   n_vocab,
+                             int32_t   n_logit_bias,
+              const llama_logit_bias * logit_bias);
+
+    // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
+    LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
+
+    /// @details Sample and accept a token from the idx-th output of the last evaluation
+    //
+    // Shorthand for:
+    //
+    //    const auto * logits = llama_get_logits_ith(ctx, idx);
+    //    llama_token_data_array cur_p = { ... init from logits ... };
+    //    llama_sampler_apply(smpl, &cur_p);
+    //    auto token = cur_p.data[cur_p.selected].id;
+    //    llama_sampler_accept(smpl, token);
+    //    return token;
+    //
+    // Returns the sampled token
+    LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
+
+    // TODO: extend in the future
+    //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...);
 
     //
     // Model split
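Because `llama_sampler_i` is a plain function-pointer interface, user code can slot its own logic into a chain. A hedged sketch of a hypothetical "ban one token" sampler: only the required `apply` callback does real work; `accept` and `reset` stay NULL as the interface comments allow. One assumption not visible in the header: the library's `llama_sampler_free` is taken to release the `llama_sampler` struct itself after invoking the `free` callback, so the callback only releases `ctx`.

#include "llama.h"

#include <cmath>

struct ban_ctx {
    llama_token banned;
};

static struct llama_sampler * ban_clone(const struct llama_sampler * smpl);

static const char * ban_name(const struct llama_sampler * /*smpl*/) {
    return "ban-token";
}

static void ban_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    const auto * bctx = (const ban_ctx *) smpl->ctx;
    for (size_t i = 0; i < cur_p->size; ++i) {
        if (cur_p->data[i].id == bctx->banned) {
            cur_p->data[i].logit = -INFINITY; // zero probability after softmax
        }
    }
}

static void ban_free(struct llama_sampler * smpl) {
    // assumption: the library deletes the llama_sampler struct itself,
    // so this callback only releases the context
    delete (ban_ctx *) smpl->ctx;
}

static struct llama_sampler_i ban_iface = {
    /* .name   = */ ban_name,
    /* .accept = */ nullptr,
    /* .apply  = */ ban_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ ban_clone,
    /* .free   = */ ban_free,
};

static struct llama_sampler * ban_init(llama_token banned) {
    return new llama_sampler { &ban_iface, new ban_ctx { banned } };
}

static struct llama_sampler * ban_clone(const struct llama_sampler * smpl) {
    return ban_init(((const ban_ctx *) smpl->ctx)->banned);
}

Added to a chain with llama_sampler_chain_add(chain, ban_init(id)), the chain takes ownership and frees the sampler together with the rest.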
@@ -1166,12 +1165,6 @@ extern "C" {
     // Returns the split_prefix length.
     LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count);
 
-    // Performance information
-    LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
-
-    LLAMA_API void llama_print_timings(struct llama_context * ctx);
-    LLAMA_API void llama_reset_timings(struct llama_context * ctx);
-
     // Print system information
     LLAMA_API const char * llama_print_system_info(void);
@@ -1179,65 +1172,41 @@ extern "C" {
     // If this is not called, or NULL is supplied, everything is output on stderr.
     LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
 
-    LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
-
-#ifdef __cplusplus
-}
-#endif
-
-// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
-#ifdef LLAMA_API_INTERNAL
-
-#include <random>
-#include <string>
-#include <vector>
-
-struct ggml_tensor;
-
-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
-    struct llama_context * ctx
-);
-
-struct llama_partial_utf8 {
-    uint32_t value;    // bit value so far (unshifted)
-    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
-};
-
-struct llama_grammar_candidate {
-    size_t               index;
-    const uint32_t     * code_points;
-    llama_partial_utf8   partial_utf8;
-};
-
-using llama_grammar_rule  = std::vector<      llama_grammar_element>;
-using llama_grammar_stack = std::vector<const llama_grammar_element *>;
-
-using llama_grammar_rules      = std::vector<llama_grammar_rule>;
-using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
-using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
-
-void llama_grammar_accept(
-        const llama_grammar_rules  & rules,
-        const llama_grammar_stacks & stacks,
-        const uint32_t chr,
-              llama_grammar_stacks & new_stacks);
-
-std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
-        const llama_grammar_rules & rules,
-        const llama_grammar_stack & stack,
-        const llama_grammar_candidates & candidates);
-
-std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
-        const std::string & src,
-        llama_partial_utf8 partial_start);
-
-// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
-// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
-llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
-
-#endif // LLAMA_API_INTERNAL
+    //
+    // Performance utils
+    //
+    // NOTE: Used by llama.cpp examples, avoid using in third-party apps. Instead, do your own performance measurements.
+    //
+
+    struct llama_perf_context_data {
+        double t_start_ms;
+        double t_load_ms;
+        double t_p_eval_ms;
+        double t_eval_ms;
+
+        int32_t n_p_eval;
+        int32_t n_eval;
+    };
+
+    struct llama_perf_sampler_data {
+        double t_sample_ms;
+
+        int32_t n_sample;
+    };
+
+    LLAMA_API struct llama_perf_context_data llama_perf_context      (const struct llama_context * ctx);
+    LLAMA_API void                           llama_perf_context_print(const struct llama_context * ctx);
+    LLAMA_API void                           llama_perf_context_reset(      struct llama_context * ctx);
+
+    // NOTE: the following work only with samplers constructed via llama_sampler_chain_init
+    LLAMA_API struct llama_perf_sampler_data llama_perf_sampler      (const struct llama_sampler * chain);
+    LLAMA_API void                           llama_perf_sampler_print(const struct llama_sampler * chain);
+    LLAMA_API void                           llama_perf_sampler_reset(      struct llama_sampler * chain);
+
+    LLAMA_API void llama_perf_dump_yaml(FILE * stream, const struct llama_context * ctx);
+
+#ifdef __cplusplus
+}
+#endif
 
 #endif // LLAMA_H
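The perf structs added in the last llama.h hunk are plain data, so callers can derive their own metrics instead of relying on the *_print helpers. A small sketch computing generation speed:

#include "llama.h"

#include <cstdio>

// Reads the decode timings directly and derives tokens/second for the
// eval (generation) phase.
static void report_eval_speed(const struct llama_context * ctx) {
    const struct llama_perf_context_data data = llama_perf_context(ctx);
    if (data.t_eval_ms > 0.0) {
        printf("eval: %d tokens, %.2f tok/s\n", data.n_eval, 1e3 * data.n_eval / data.t_eval_ms);
    }
}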
examples/talk-llama/talk-llama.cpp
CHANGED

@@ -314,7 +314,6 @@ int main(int argc, char ** argv) {
 
     // tune these to your liking
     lcparams.n_ctx      = 2048;
-    lcparams.seed       = 1;
     lcparams.n_threads  = params.n_threads;
     lcparams.flash_attn = params.flash_attn;
@@ -402,6 +401,26 @@ int main(int argc, char ** argv) {
 
     llama_batch batch = llama_batch_init(llama_n_ctx(ctx_llama), 0, 1);
 
+    // init sampler
+    const float top_k = 5;
+    const float top_p = 0.80f;
+    const float temp  = 0.30f;
+
+    const int seed = 0;
+
+    auto sparams = llama_sampler_chain_default_params();
+
+    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+    if (temp > 0.0f) {
+        llama_sampler_chain_add(smpl, llama_sampler_init_top_k(top_k));
+        llama_sampler_chain_add(smpl, llama_sampler_init_top_p(top_p, 1));
+        llama_sampler_chain_add(smpl, llama_sampler_init_temp (temp));
+        llama_sampler_chain_add(smpl, llama_sampler_init_dist (seed));
+    } else {
+        llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
+    }
+
     // init session
     std::string path_session = params.path_session;
     std::vector<llama_token> session_tokens;
@@ -700,54 +719,13 @@ int main(int argc, char ** argv) {
 
         {
             // out of user input, sample next token
-            const float top_k = 5;
-            const float top_p = 0.80f;
-            const float temp  = 0.30f;
-            const float repeat_penalty = 1.1764f;
-
-            const int repeat_last_n = 256;
 
             if (!path_session.empty() && need_to_save_session) {
                 need_to_save_session = false;
                 llama_state_save_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.size());
             }
 
-            llama_token id = 0;
-
-            {
-                auto logits  = llama_get_logits(ctx_llama);
-                auto n_vocab = llama_n_vocab(model_llama);
-
-                logits[llama_token_eos(model_llama)] = 0;
-
-                std::vector<llama_token_data> candidates;
-                candidates.reserve(n_vocab);
-                for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-                    candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-                }
-
-                llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-                // apply repeat penalty
-                const float nl_logit = logits[llama_token_nl(model_llama)];
-
-                llama_sample_repetition_penalties(ctx_llama, &candidates_p,
-                        embd_inp.data() + std::max(0, n_past - repeat_last_n),
-                        repeat_last_n, repeat_penalty, 0.0, 0.0f);
-
-                logits[llama_token_nl(model_llama)] = nl_logit;
-
-                if (temp <= 0) {
-                    // Greedy sampling
-                    id = llama_sample_token_greedy(ctx_llama, &candidates_p);
-                } else {
-                    // Temperature sampling
-                    llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
-                    llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
-                    llama_sample_temp (ctx_llama, &candidates_p, temp);
-                    id = llama_sample_token(ctx_llama, &candidates_p);
-                }
-            }
+            const llama_token id = llama_sampler_sample(smpl, ctx_llama, -1);
 
             if (id != llama_token_eos(model_llama)) {
                 // add it to the context
@@ -797,8 +775,14 @@ int main(int argc, char ** argv) {
     whisper_print_timings(ctx_wsp);
     whisper_free(ctx_wsp);
 
-    llama_print_timings(ctx_llama);
+    llama_perf_sampler_print(smpl);
+    llama_perf_context_print(ctx_llama);
+
+    llama_sampler_free(smpl);
+    llama_batch_free(batch);
     llama_free(ctx_llama);
 
+    llama_backend_free();
+
     return 0;
 }
examples/talk-llama/unicode.cpp
CHANGED

@@ -5,6 +5,7 @@
 #include "unicode.h"
 #include "unicode-data.h"
 
+#include <algorithm>
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
src/whisper.cpp
CHANGED

@@ -177,7 +177,7 @@ static bool ggml_graph_compute_helper(
                         int   n_threads,
         ggml_abort_callback   abort_callback,
                        void * abort_callback_data) {
-    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr);
 
     plan.abort_callback      = abort_callback;
     plan.abort_callback_data = abort_callback_data;

@@ -2894,7 +2894,7 @@ static bool whisper_decode_internal(
             ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
         }
 
-    logits = gf->nodes[gf->n_nodes - 1];
+    logits = ggml_graph_node(gf, -1);
 
     if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
         return false;
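For reference, the extra argument in the `ggml_graph_plan` call above is the optional threadpool introduced by this ggml sync; `nullptr` keeps the previous behavior. A hedged one-function sketch (the exact parameter name in ggml may differ):

#include "ggml.h"

// Passing nullptr for the new threadpool argument falls back to ggml's
// internally managed threads, i.e. the old two-argument behavior.
static struct ggml_cplan plan_with_default_threads(struct ggml_cgraph * graph, int n_threads) {
    return ggml_graph_plan(graph, n_threads, /*threadpool =*/ nullptr);
}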