Spaces:
Running
Running
talk-llama : sync llama.cpp
Browse files
examples/talk-llama/llama-sampling.cpp
CHANGED
|
@@ -63,6 +63,30 @@ static void llama_log_softmax(float * array, size_t size) {
|
|
| 63 |
}
|
| 64 |
*/
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
|
| 67 |
GGML_ASSERT(cur_p->size > 0);
|
| 68 |
|
|
@@ -427,6 +451,9 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
|
|
| 427 |
|
| 428 |
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
| 429 |
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
|
|
|
|
|
|
|
|
|
| 430 |
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
|
| 431 |
}
|
| 432 |
|
|
@@ -912,9 +939,8 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*
|
|
| 912 |
|
| 913 |
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
| 914 |
const auto * ctx = (llama_sampler_temp *) smpl->ctx;
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
}
|
| 918 |
}
|
| 919 |
|
| 920 |
static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
|
|
@@ -961,6 +987,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
|
|
| 961 |
if (ctx->delta > 0) {
|
| 962 |
const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
|
| 963 |
const float max_temp = ctx->temp + ctx->delta;
|
|
|
|
| 964 |
float exponent_val = ctx->exponent;
|
| 965 |
|
| 966 |
// no need to do anything if there is only one (or zero) candidates
|
|
@@ -998,9 +1025,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
|
|
| 998 |
#endif
|
| 999 |
|
| 1000 |
// Apply the dynamically calculated temperature scaling
|
| 1001 |
-
|
| 1002 |
-
cur_p->data[i].logit /= dyn_temp;
|
| 1003 |
-
}
|
| 1004 |
|
| 1005 |
// Re-compute softmax probabilities after scaling logits with dynamic temperature
|
| 1006 |
const double max_l_double = cur_p->data[0].logit;
|
|
@@ -1024,9 +1049,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
|
|
| 1024 |
}
|
| 1025 |
#endif
|
| 1026 |
} else {
|
| 1027 |
-
|
| 1028 |
-
cur_p->data[i].logit /= ctx->temp;
|
| 1029 |
-
}
|
| 1030 |
}
|
| 1031 |
}
|
| 1032 |
|
|
@@ -1059,6 +1082,101 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
|
|
| 1059 |
};
|
| 1060 |
}
|
| 1061 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1062 |
// mirostat
|
| 1063 |
|
| 1064 |
struct llama_sampler_mirostat {
|
|
@@ -1565,6 +1683,397 @@ struct llama_sampler * llama_sampler_init_penalties(
|
|
| 1565 |
};
|
| 1566 |
}
|
| 1567 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1568 |
// logit-bias
|
| 1569 |
|
| 1570 |
struct llama_sampler_logit_bias {
|
|
@@ -1644,6 +2153,229 @@ struct llama_sampler * llama_sampler_init_logit_bias(
|
|
| 1644 |
};
|
| 1645 |
}
|
| 1646 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1647 |
// utils
|
| 1648 |
|
| 1649 |
uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
|
|
|
|
| 63 |
}
|
| 64 |
*/
|
| 65 |
|
| 66 |
+
static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
|
| 67 |
+
if (temp <= 0.0f) {
|
| 68 |
+
// find the token with the highest logit and set the rest to -inf
|
| 69 |
+
size_t max_i = 0;
|
| 70 |
+
float max_l = cur_p->data[0].logit;
|
| 71 |
+
|
| 72 |
+
for (size_t i = 1; i < cur_p->size; ++i) {
|
| 73 |
+
if (cur_p->data[i ].logit > max_l) {
|
| 74 |
+
cur_p->data[max_i].logit = -INFINITY;
|
| 75 |
+
max_i = i;
|
| 76 |
+
max_l = cur_p->data[i].logit;
|
| 77 |
+
} else {
|
| 78 |
+
cur_p->data[i].logit = -INFINITY;
|
| 79 |
+
}
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
return;
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
| 86 |
+
cur_p->data[i].logit /= temp;
|
| 87 |
+
}
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
|
| 91 |
GGML_ASSERT(cur_p->size > 0);
|
| 92 |
|
|
|
|
| 451 |
|
| 452 |
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
| 453 |
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
| 454 |
+
|
| 455 |
+
llama_sampler_softmax_impl(cur_p);
|
| 456 |
+
|
| 457 |
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
|
| 458 |
}
|
| 459 |
|
|
|
|
| 939 |
|
| 940 |
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
| 941 |
const auto * ctx = (llama_sampler_temp *) smpl->ctx;
|
| 942 |
+
|
| 943 |
+
llama_sampler_temp_impl(cur_p, ctx->temp);
|
|
|
|
| 944 |
}
|
| 945 |
|
| 946 |
static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
|
|
|
|
| 987 |
if (ctx->delta > 0) {
|
| 988 |
const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
|
| 989 |
const float max_temp = ctx->temp + ctx->delta;
|
| 990 |
+
|
| 991 |
float exponent_val = ctx->exponent;
|
| 992 |
|
| 993 |
// no need to do anything if there is only one (or zero) candidates
|
|
|
|
| 1025 |
#endif
|
| 1026 |
|
| 1027 |
// Apply the dynamically calculated temperature scaling
|
| 1028 |
+
llama_sampler_temp_impl(cur_p, dyn_temp);
|
|
|
|
|
|
|
| 1029 |
|
| 1030 |
// Re-compute softmax probabilities after scaling logits with dynamic temperature
|
| 1031 |
const double max_l_double = cur_p->data[0].logit;
|
|
|
|
| 1049 |
}
|
| 1050 |
#endif
|
| 1051 |
} else {
|
| 1052 |
+
llama_sampler_temp_impl(cur_p, ctx->temp);
|
|
|
|
|
|
|
| 1053 |
}
|
| 1054 |
}
|
| 1055 |
|
|
|
|
| 1082 |
};
|
| 1083 |
}
|
| 1084 |
|
| 1085 |
+
// xtc
|
| 1086 |
+
|
| 1087 |
+
struct llama_sampler_xtc {
|
| 1088 |
+
const float probability;
|
| 1089 |
+
const float threshold;
|
| 1090 |
+
const size_t min_keep;
|
| 1091 |
+
|
| 1092 |
+
const uint32_t seed;
|
| 1093 |
+
uint32_t seed_cur;
|
| 1094 |
+
|
| 1095 |
+
std::mt19937 rng;
|
| 1096 |
+
};
|
| 1097 |
+
|
| 1098 |
+
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
|
| 1099 |
+
return "xtc";
|
| 1100 |
+
}
|
| 1101 |
+
|
| 1102 |
+
static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
| 1103 |
+
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
| 1104 |
+
|
| 1105 |
+
if (ctx->probability <= 0.0f
|
| 1106 |
+
|| ctx->threshold > 0.5f
|
| 1107 |
+
|| cur_p->size < 2) {
|
| 1108 |
+
return;
|
| 1109 |
+
}
|
| 1110 |
+
|
| 1111 |
+
std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
|
| 1112 |
+
float chance = distribution(ctx->rng);
|
| 1113 |
+
if (chance > ctx->probability) return;
|
| 1114 |
+
|
| 1115 |
+
// in case it's not sorted/recalculated yet
|
| 1116 |
+
llama_sampler_softmax_impl(cur_p);
|
| 1117 |
+
|
| 1118 |
+
int pos_last = 0;
|
| 1119 |
+
|
| 1120 |
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
| 1121 |
+
if (cur_p->data[i].p >= ctx->threshold) {
|
| 1122 |
+
pos_last = i;
|
| 1123 |
+
} else break;
|
| 1124 |
+
}
|
| 1125 |
+
|
| 1126 |
+
if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
|
| 1127 |
+
cur_p->data += pos_last;
|
| 1128 |
+
cur_p->size -= pos_last;
|
| 1129 |
+
}
|
| 1130 |
+
}
|
| 1131 |
+
|
| 1132 |
+
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
|
| 1133 |
+
const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
|
| 1134 |
+
auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);
|
| 1135 |
+
|
| 1136 |
+
// copy the state
|
| 1137 |
+
{
|
| 1138 |
+
auto * result_ctx = (llama_sampler_xtc *) result->ctx;
|
| 1139 |
+
|
| 1140 |
+
result_ctx->rng = ctx->rng;
|
| 1141 |
+
}
|
| 1142 |
+
|
| 1143 |
+
return result;
|
| 1144 |
+
}
|
| 1145 |
+
|
| 1146 |
+
static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
|
| 1147 |
+
delete (llama_sampler_xtc *) smpl->ctx;
|
| 1148 |
+
}
|
| 1149 |
+
|
| 1150 |
+
static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
|
| 1151 |
+
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
| 1152 |
+
ctx->seed_cur = get_rng_seed(ctx->seed);
|
| 1153 |
+
ctx->rng.seed(ctx->seed_cur);
|
| 1154 |
+
}
|
| 1155 |
+
|
| 1156 |
+
static struct llama_sampler_i llama_sampler_xtc_i = {
|
| 1157 |
+
/* .name = */ llama_sampler_xtc_name,
|
| 1158 |
+
/* .accept = */ nullptr,
|
| 1159 |
+
/* .apply = */ llama_sample_xtc_apply,
|
| 1160 |
+
/* .reset = */ llama_sampler_xtc_reset,
|
| 1161 |
+
/* .clone = */ llama_sampler_xtc_clone,
|
| 1162 |
+
/* .free = */ llama_sampler_xtc_free,
|
| 1163 |
+
};
|
| 1164 |
+
|
| 1165 |
+
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
|
| 1166 |
+
auto seed_cur = get_rng_seed(seed);
|
| 1167 |
+
return new llama_sampler {
|
| 1168 |
+
/* .iface = */ &llama_sampler_xtc_i,
|
| 1169 |
+
/* .ctx = */ new llama_sampler_xtc {
|
| 1170 |
+
/* .probability = */ p,
|
| 1171 |
+
/* .threshold = */ t,
|
| 1172 |
+
/* .min_keep = */ min_keep,
|
| 1173 |
+
/* .seed = */ seed,
|
| 1174 |
+
/* .seed_cur = */ seed_cur,
|
| 1175 |
+
/* .rng = */ std::mt19937(seed_cur),
|
| 1176 |
+
},
|
| 1177 |
+
};
|
| 1178 |
+
}
|
| 1179 |
+
|
| 1180 |
// mirostat
|
| 1181 |
|
| 1182 |
struct llama_sampler_mirostat {
|
|
|
|
| 1683 |
};
|
| 1684 |
}
|
| 1685 |
|
| 1686 |
+
// DRY
|
| 1687 |
+
|
| 1688 |
+
struct llama_sampler_dry {
|
| 1689 |
+
int32_t total_context_size;
|
| 1690 |
+
|
| 1691 |
+
const float dry_multiplier;
|
| 1692 |
+
const float dry_base;
|
| 1693 |
+
const int32_t dry_allowed_length;
|
| 1694 |
+
const int32_t dry_penalty_last_n;
|
| 1695 |
+
|
| 1696 |
+
std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
|
| 1697 |
+
std::vector<int> dry_repeat_count;
|
| 1698 |
+
std::unordered_map<llama_token, int> dry_max_token_repeat;
|
| 1699 |
+
ring_buffer<llama_token> last_tokens;
|
| 1700 |
+
};
|
| 1701 |
+
|
| 1702 |
+
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
|
| 1703 |
+
static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
|
| 1704 |
+
for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
|
| 1705 |
+
std::string word = llama_detokenize(vocab, {token_id}, true);
|
| 1706 |
+
if (word.find(str) != std::string::npos) {
|
| 1707 |
+
token_sequences.emplace(token_id, std::vector<llama_token>());
|
| 1708 |
+
} else {
|
| 1709 |
+
size_t word_len = word.size(), str_len = str.size();
|
| 1710 |
+
size_t pos = -1;
|
| 1711 |
+
while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
|
| 1712 |
+
bool match = true;
|
| 1713 |
+
size_t i;
|
| 1714 |
+
for (i = 1; i < str_len && i + pos < word_len; ++i) {
|
| 1715 |
+
if (word[pos + i] != str[i]) {
|
| 1716 |
+
match = false;
|
| 1717 |
+
break;
|
| 1718 |
+
}
|
| 1719 |
+
}
|
| 1720 |
+
if (match) {
|
| 1721 |
+
std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
|
| 1722 |
+
if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
|
| 1723 |
+
tokenization.resize(max_tail_len);
|
| 1724 |
+
}
|
| 1725 |
+
|
| 1726 |
+
// Ensure we don't already have a duplicate matching tokenization
|
| 1727 |
+
auto its = token_sequences.equal_range(token_id);
|
| 1728 |
+
bool found = false;
|
| 1729 |
+
for (auto it = its.first; it != its.second; ++it) {
|
| 1730 |
+
if (tokenization == it->second) {
|
| 1731 |
+
found = true;
|
| 1732 |
+
break;
|
| 1733 |
+
}
|
| 1734 |
+
}
|
| 1735 |
+
if (!found) {
|
| 1736 |
+
token_sequences.emplace(token_id, tokenization);
|
| 1737 |
+
}
|
| 1738 |
+
}
|
| 1739 |
+
}
|
| 1740 |
+
}
|
| 1741 |
+
}
|
| 1742 |
+
}
|
| 1743 |
+
|
| 1744 |
+
static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) {
|
| 1745 |
+
return "dry";
|
| 1746 |
+
}
|
| 1747 |
+
|
| 1748 |
+
static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
|
| 1749 |
+
auto * ctx = (llama_sampler_dry *) smpl->ctx;
|
| 1750 |
+
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
|
| 1751 |
+
return;
|
| 1752 |
+
}
|
| 1753 |
+
|
| 1754 |
+
ctx->last_tokens.push_back(token);
|
| 1755 |
+
}
|
| 1756 |
+
|
| 1757 |
+
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
|
| 1758 |
+
static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
| 1759 |
+
auto * ctx = (llama_sampler_dry *) smpl->ctx;
|
| 1760 |
+
|
| 1761 |
+
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
|
| 1762 |
+
return;
|
| 1763 |
+
}
|
| 1764 |
+
|
| 1765 |
+
int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0);
|
| 1766 |
+
int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size);
|
| 1767 |
+
|
| 1768 |
+
if (last_n_repeat <= ctx->dry_allowed_length) {
|
| 1769 |
+
return;
|
| 1770 |
+
}
|
| 1771 |
+
|
| 1772 |
+
ctx->dry_repeat_count.assign(last_n_repeat, 0);
|
| 1773 |
+
ctx->dry_max_token_repeat.clear();
|
| 1774 |
+
|
| 1775 |
+
// Step 1: Look for restart sequences to limit the maximum repetition length.
|
| 1776 |
+
// Work backwards through the context looking for any token that begins a restart sequence.
|
| 1777 |
+
//
|
| 1778 |
+
// The collection `restart_sequences` is a mapping from a "head" token to all "tail"
|
| 1779 |
+
// sequences that together comprise a restart sequence. This allows us to quickly check
|
| 1780 |
+
// whether each token is the head of a complete sequence. Most restart sequences are actually
|
| 1781 |
+
// a single token, and for these the "tail" is an empty vector.
|
| 1782 |
+
//
|
| 1783 |
+
// If the token is a "head", test all restart sequences that begin with this token
|
| 1784 |
+
// (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
|
| 1785 |
+
// 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
|
| 1786 |
+
// longest matching sequence (if any) is used to limit the maximum repetition length.
|
| 1787 |
+
//
|
| 1788 |
+
// Note that in the case case of a short sequence contained in a longer one, this might fail to
|
| 1789 |
+
// find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
|
| 1790 |
+
// restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
|
| 1791 |
+
// 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
|
| 1792 |
+
//
|
| 1793 |
+
// This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
|
| 1794 |
+
// have already clamped the maximum tail sequence length when generating `restart_sequences`.
|
| 1795 |
+
// With clamping, this scan is O(N) in the context length.
|
| 1796 |
+
|
| 1797 |
+
int rep_limit = last_n_repeat;
|
| 1798 |
+
for (int i = 0; i < last_n_repeat; ++i) {
|
| 1799 |
+
llama_token token = ctx->last_tokens.rat(i);
|
| 1800 |
+
auto its = ctx->dry_processed_breakers.equal_range(token);
|
| 1801 |
+
if (its.first == ctx->dry_processed_breakers.end()) {
|
| 1802 |
+
continue;
|
| 1803 |
+
}
|
| 1804 |
+
int longest_match = -1;
|
| 1805 |
+
for (auto it = its.first; it != its.second; ++it) {
|
| 1806 |
+
// Note that (*it) does not contain the head character, so seq_len will be
|
| 1807 |
+
// the restart sequence length minus 1.
|
| 1808 |
+
// In the common case of a single-token restart sequence, (*it) will be empty
|
| 1809 |
+
// and we will trivially match.
|
| 1810 |
+
int seq_len = (int)it->second.size();
|
| 1811 |
+
if (seq_len > longest_match && seq_len <= (int)i) {
|
| 1812 |
+
bool match = true;
|
| 1813 |
+
for (int offset = 0; offset < seq_len; ++offset) {
|
| 1814 |
+
// The -1 when indexing `last_tokens` is because we already matched the head.
|
| 1815 |
+
if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) {
|
| 1816 |
+
match = false;
|
| 1817 |
+
break;
|
| 1818 |
+
}
|
| 1819 |
+
}
|
| 1820 |
+
if (match) {
|
| 1821 |
+
longest_match = seq_len;
|
| 1822 |
+
}
|
| 1823 |
+
}
|
| 1824 |
+
}
|
| 1825 |
+
if (longest_match >= 0) {
|
| 1826 |
+
// We found a restart sequence starting `i` tokens from the end and continuing for
|
| 1827 |
+
// `longest_match` tokens.
|
| 1828 |
+
rep_limit = i - longest_match;
|
| 1829 |
+
break;
|
| 1830 |
+
}
|
| 1831 |
+
}
|
| 1832 |
+
if (rep_limit < ctx->dry_allowed_length) {
|
| 1833 |
+
return;
|
| 1834 |
+
}
|
| 1835 |
+
|
| 1836 |
+
// Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
|
| 1837 |
+
// the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
|
| 1838 |
+
// elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
|
| 1839 |
+
//
|
| 1840 |
+
// This algorithm is not currently documented on Wikipedia, but there is a clear description here:
|
| 1841 |
+
// https://ivanyu.me/blog/2014/10/15/z-algorithm/
|
| 1842 |
+
//
|
| 1843 |
+
// The code below is adapted from the public domain implementation by the same author here:
|
| 1844 |
+
// https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
|
| 1845 |
+
//
|
| 1846 |
+
// Example:
|
| 1847 |
+
// Last N tokens: a b c c b c y a b c
|
| 1848 |
+
// Repeat counts: 0 0 3 1 0 2 0 0 0 0
|
| 1849 |
+
// ^
|
| 1850 |
+
// This `3` means that the last three tokens of the context (a b c) also appear here.
|
| 1851 |
+
//
|
| 1852 |
+
// This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
|
| 1853 |
+
// for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
|
| 1854 |
+
// repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
|
| 1855 |
+
// ensure that the inner while loops only examine each token in the context once as the outer
|
| 1856 |
+
// for loop iterates over the context.
|
| 1857 |
+
|
| 1858 |
+
{
|
| 1859 |
+
const int last = last_n_repeat - 1;
|
| 1860 |
+
int rt = 0, lt = 0;
|
| 1861 |
+
|
| 1862 |
+
for (int k = 1; k < last_n_repeat; ++k) {
|
| 1863 |
+
if (k > rt) {
|
| 1864 |
+
// If k is outside the current Z-box, do naive computation.
|
| 1865 |
+
int n = 0;
|
| 1866 |
+
while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) {
|
| 1867 |
+
++n;
|
| 1868 |
+
}
|
| 1869 |
+
ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
|
| 1870 |
+
if (n > 0) {
|
| 1871 |
+
lt = k;
|
| 1872 |
+
rt = k+n-1;
|
| 1873 |
+
}
|
| 1874 |
+
} else {
|
| 1875 |
+
// If k is inside the current Z-box, consider two cases.
|
| 1876 |
+
|
| 1877 |
+
int p = k - lt; // Pair index.
|
| 1878 |
+
int right_part_len = rt - k + 1;
|
| 1879 |
+
|
| 1880 |
+
if (ctx->dry_repeat_count[last - p] < right_part_len) {
|
| 1881 |
+
int n = std::min(ctx->dry_repeat_count[last - p], rep_limit);
|
| 1882 |
+
ctx->dry_repeat_count[last - k] = n;
|
| 1883 |
+
} else {
|
| 1884 |
+
int i = rt + 1;
|
| 1885 |
+
while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) {
|
| 1886 |
+
i += 1;
|
| 1887 |
+
}
|
| 1888 |
+
|
| 1889 |
+
int n = std::min(i - k, rep_limit);
|
| 1890 |
+
ctx->dry_repeat_count[last - k] = n;
|
| 1891 |
+
lt = k;
|
| 1892 |
+
rt = i - 1;
|
| 1893 |
+
}
|
| 1894 |
+
}
|
| 1895 |
+
}
|
| 1896 |
+
}
|
| 1897 |
+
|
| 1898 |
+
// Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
|
| 1899 |
+
// that would be generated by emitting each new token that would extend a sequence.
|
| 1900 |
+
//
|
| 1901 |
+
// Following the same example as above:
|
| 1902 |
+
// Last N tokens: a b c c b c y a b c
|
| 1903 |
+
// Repeat counts: 0 0 3 1 0 2 0 0 0 0
|
| 1904 |
+
//
|
| 1905 |
+
// For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
|
| 1906 |
+
// c: 3 -> 4 (from `a b c` to `a b c c`)
|
| 1907 |
+
// b: 1 -> 2 (from `c` to `c b`)
|
| 1908 |
+
// y: 2 -> 3 (from `b c` to `b c y`)
|
| 1909 |
+
|
| 1910 |
+
for (int i = 0; i < last_n_repeat - 1; ++i) {
|
| 1911 |
+
int repeat_len = ctx->dry_repeat_count[i];
|
| 1912 |
+
if (repeat_len >= ctx->dry_allowed_length) {
|
| 1913 |
+
// This token ends a repeat, so the next token would continue one.
|
| 1914 |
+
// By convention, the value of `repeat_len` only includes the tokens currently
|
| 1915 |
+
// in the context, not the new token that would be added.
|
| 1916 |
+
llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i);
|
| 1917 |
+
// Track the maximum sequence ending in this token.
|
| 1918 |
+
const auto& it = ctx->dry_max_token_repeat.find(token);
|
| 1919 |
+
if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
|
| 1920 |
+
ctx->dry_max_token_repeat[token] = repeat_len;
|
| 1921 |
+
}
|
| 1922 |
+
}
|
| 1923 |
+
}
|
| 1924 |
+
|
| 1925 |
+
// Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.
|
| 1926 |
+
|
| 1927 |
+
// Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
|
| 1928 |
+
// Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
|
| 1929 |
+
const float FLOAT_MAX_LOG = 88.7228391f;
|
| 1930 |
+
int max_exponent = 0;
|
| 1931 |
+
if (ctx->dry_base > 1.000001f) {
|
| 1932 |
+
max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base);
|
| 1933 |
+
}
|
| 1934 |
+
|
| 1935 |
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
| 1936 |
+
const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id);
|
| 1937 |
+
if (af_kvp != ctx->dry_max_token_repeat.end()) {
|
| 1938 |
+
// Check all sequence breakers starting with this token
|
| 1939 |
+
auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id);
|
| 1940 |
+
bool is_single_token_breaker = false;
|
| 1941 |
+
|
| 1942 |
+
for (auto it = range.first; it != range.second; ++it) {
|
| 1943 |
+
if (it->second.empty()) {
|
| 1944 |
+
is_single_token_breaker = true;
|
| 1945 |
+
break;
|
| 1946 |
+
}
|
| 1947 |
+
}
|
| 1948 |
+
|
| 1949 |
+
// Apply penalty only if it's not a single-token sequence breaker
|
| 1950 |
+
if (!is_single_token_breaker) {
|
| 1951 |
+
int repeat_exp = af_kvp->second - ctx->dry_allowed_length;
|
| 1952 |
+
if (max_exponent > 0 && repeat_exp > max_exponent) {
|
| 1953 |
+
repeat_exp = max_exponent;
|
| 1954 |
+
}
|
| 1955 |
+
float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp);
|
| 1956 |
+
cur_p->data[i].logit -= penalty;
|
| 1957 |
+
}
|
| 1958 |
+
}
|
| 1959 |
+
}
|
| 1960 |
+
|
| 1961 |
+
cur_p->sorted = false;
|
| 1962 |
+
}
|
| 1963 |
+
|
| 1964 |
+
static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
|
| 1965 |
+
auto * ctx = (llama_sampler_dry *) smpl->ctx;
|
| 1966 |
+
ctx->last_tokens.clear();
|
| 1967 |
+
ctx->dry_repeat_count.clear();
|
| 1968 |
+
ctx->dry_max_token_repeat.clear();
|
| 1969 |
+
}
|
| 1970 |
+
|
| 1971 |
+
static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
|
| 1972 |
+
const auto * ctx = (llama_sampler_dry *) smpl->ctx;
|
| 1973 |
+
|
| 1974 |
+
// nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying
|
| 1975 |
+
auto * result = llama_sampler_init_dry(nullptr, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
|
| 1976 |
+
// Copy the state, including the processed breakers
|
| 1977 |
+
{
|
| 1978 |
+
auto * result_ctx = (llama_sampler_dry *) result->ctx;
|
| 1979 |
+
result_ctx->dry_processed_breakers = ctx->dry_processed_breakers;
|
| 1980 |
+
result_ctx->dry_repeat_count = ctx->dry_repeat_count;
|
| 1981 |
+
result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat;
|
| 1982 |
+
result_ctx->last_tokens = ctx->last_tokens;
|
| 1983 |
+
}
|
| 1984 |
+
|
| 1985 |
+
return result;
|
| 1986 |
+
}
|
| 1987 |
+
|
| 1988 |
+
static void llama_sampler_dry_free(struct llama_sampler * smpl) {
|
| 1989 |
+
delete (llama_sampler_dry *) smpl->ctx;
|
| 1990 |
+
}
|
| 1991 |
+
|
| 1992 |
+
static struct llama_sampler_i llama_sampler_dry_i = {
|
| 1993 |
+
/* .name = */ llama_sampler_dry_name,
|
| 1994 |
+
/* .accept = */ llama_sampler_dry_accept,
|
| 1995 |
+
/* .apply = */ llama_sampler_dry_apply,
|
| 1996 |
+
/* .reset = */ llama_sampler_dry_reset,
|
| 1997 |
+
/* .clone = */ llama_sampler_dry_clone,
|
| 1998 |
+
/* .free = */ llama_sampler_dry_free,
|
| 1999 |
+
};
|
| 2000 |
+
|
| 2001 |
+
struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
|
| 2002 |
+
int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
|
| 2003 |
+
std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
|
| 2004 |
+
const int MAX_CHAR_LEN = 40;
|
| 2005 |
+
const int MAX_SEQ_LEN = 20;
|
| 2006 |
+
|
| 2007 |
+
const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
|
| 2008 |
+
|
| 2009 |
+
if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
|
| 2010 |
+
// Process sequence breakers
|
| 2011 |
+
for (size_t i = 0; i < num_breakers; ++i) {
|
| 2012 |
+
if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
|
| 2013 |
+
LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
|
| 2014 |
+
continue;
|
| 2015 |
+
}
|
| 2016 |
+
|
| 2017 |
+
std::string sequence_break(seq_breakers[i]);
|
| 2018 |
+
if (sequence_break.empty()) {
|
| 2019 |
+
LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n");
|
| 2020 |
+
continue;
|
| 2021 |
+
}
|
| 2022 |
+
|
| 2023 |
+
if (sequence_break.size() > MAX_CHAR_LEN) {
|
| 2024 |
+
LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN);
|
| 2025 |
+
sequence_break.resize(MAX_CHAR_LEN);
|
| 2026 |
+
}
|
| 2027 |
+
|
| 2028 |
+
get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
|
| 2029 |
+
}
|
| 2030 |
+
}
|
| 2031 |
+
|
| 2032 |
+
return new llama_sampler {
|
| 2033 |
+
/* .iface = */ &llama_sampler_dry_i,
|
| 2034 |
+
/* .ctx = */ new llama_sampler_dry {
|
| 2035 |
+
/* .total_context_size = */ context_size,
|
| 2036 |
+
/* .dry_multiplier = */ dry_multiplier,
|
| 2037 |
+
/* .dry_base = */ dry_base,
|
| 2038 |
+
/* .dry_allowed_length = */ dry_allowed_length,
|
| 2039 |
+
/* .dry_penalty_last_n = */ dry_penalty_last_n,
|
| 2040 |
+
/* .dry_processed_breakers = */ std::move(processed_breakers),
|
| 2041 |
+
/* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
|
| 2042 |
+
/* .dry_max_token_repeat = */ {},
|
| 2043 |
+
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
|
| 2044 |
+
},
|
| 2045 |
+
};
|
| 2046 |
+
}
|
| 2047 |
+
|
| 2048 |
+
// wrapper for test-sampling.cpp
|
| 2049 |
+
struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
|
| 2050 |
+
llama_vocab dummy_vocab;
|
| 2051 |
+
auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
|
| 2052 |
+
auto * ctx = (llama_sampler_dry *) result->ctx;
|
| 2053 |
+
|
| 2054 |
+
// Process the token-based sequence breakers
|
| 2055 |
+
ctx->dry_processed_breakers.clear();
|
| 2056 |
+
if (seq_breakers.empty()) {
|
| 2057 |
+
LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n");
|
| 2058 |
+
} else {
|
| 2059 |
+
for (const auto& breaker : seq_breakers) {
|
| 2060 |
+
if (breaker.empty()) {
|
| 2061 |
+
LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n");
|
| 2062 |
+
continue;
|
| 2063 |
+
}
|
| 2064 |
+
llama_token head_token = breaker[0];
|
| 2065 |
+
std::vector<llama_token> tail_tokens(breaker.begin() + 1, breaker.end());
|
| 2066 |
+
ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens));
|
| 2067 |
+
}
|
| 2068 |
+
|
| 2069 |
+
if (ctx->dry_processed_breakers.empty()) {
|
| 2070 |
+
LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n");
|
| 2071 |
+
}
|
| 2072 |
+
}
|
| 2073 |
+
|
| 2074 |
+
return result;
|
| 2075 |
+
}
|
| 2076 |
+
|
| 2077 |
// logit-bias
|
| 2078 |
|
| 2079 |
struct llama_sampler_logit_bias {
|
|
|
|
| 2153 |
};
|
| 2154 |
}
|
| 2155 |
|
| 2156 |
+
// infill
|
| 2157 |
+
|
| 2158 |
+
//#define GGML_DEBUG_SAMPLER_INFILL
|
| 2159 |
+
|
| 2160 |
+
struct llama_sampler_infill {
|
| 2161 |
+
const struct llama_vocab * vocab;
|
| 2162 |
+
|
| 2163 |
+
std::vector<char> buf0;
|
| 2164 |
+
std::vector<char> buf1;
|
| 2165 |
+
};
|
| 2166 |
+
|
| 2167 |
+
static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
|
| 2168 |
+
return "infill";
|
| 2169 |
+
}
|
| 2170 |
+
|
| 2171 |
+
static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
| 2172 |
+
auto * ctx = (llama_sampler_infill *) smpl->ctx;
|
| 2173 |
+
|
| 2174 |
+
llama_sampler_softmax_impl(cur_p);
|
| 2175 |
+
|
| 2176 |
+
#if defined(GGML_DEBUG_SAMPLER_INFILL)
|
| 2177 |
+
#define LOG_DBG_CUR LLAMA_LOG_DEBUG
|
| 2178 |
+
#else
|
| 2179 |
+
#define LOG_DBG_CUR(...)
|
| 2180 |
+
#endif
|
| 2181 |
+
|
| 2182 |
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
| 2183 |
+
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
|
| 2184 |
+
}
|
| 2185 |
+
|
| 2186 |
+
float p_txt_sum = 0.0f;
|
| 2187 |
+
float p_eog_sum = 0.0f;
|
| 2188 |
+
|
| 2189 |
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
| 2190 |
+
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
|
| 2191 |
+
p_eog_sum += cur_p->data[i].p;
|
| 2192 |
+
} else {
|
| 2193 |
+
p_txt_sum += cur_p->data[i].p;
|
| 2194 |
+
}
|
| 2195 |
+
}
|
| 2196 |
+
|
| 2197 |
+
const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
|
| 2198 |
+
|
| 2199 |
+
LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
|
| 2200 |
+
|
| 2201 |
+
if (3*p_eog_sum*cur_p->size > p_txt_sum) {
|
| 2202 |
+
LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
|
| 2203 |
+
|
| 2204 |
+
// keep just the EOG tokens
|
| 2205 |
+
const auto size_org = cur_p->size;
|
| 2206 |
+
|
| 2207 |
+
cur_p->size = 0;
|
| 2208 |
+
|
| 2209 |
+
float p_sum = 0.0f;
|
| 2210 |
+
|
| 2211 |
+
for (size_t i = 0; i < size_org; ++i) {
|
| 2212 |
+
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
|
| 2213 |
+
p_sum += cur_p->data[i].p;
|
| 2214 |
+
|
| 2215 |
+
cur_p->data[cur_p->size++] = cur_p->data[i];
|
| 2216 |
+
}
|
| 2217 |
+
}
|
| 2218 |
+
|
| 2219 |
+
// normalize probs
|
| 2220 |
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
| 2221 |
+
cur_p->data[i].p /= p_sum;
|
| 2222 |
+
}
|
| 2223 |
+
|
| 2224 |
+
return;
|
| 2225 |
+
}
|
| 2226 |
+
|
| 2227 |
+
size_t n_combined = 0; GGML_UNUSED(n_combined);
|
| 2228 |
+
|
| 2229 |
+
// combine tokens with common prefix
|
| 2230 |
+
for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
|
| 2231 |
+
for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
|
| 2232 |
+
if (cur_p->data[i0].logit == -INFINITY) {
|
| 2233 |
+
break;
|
| 2234 |
+
}
|
| 2235 |
+
|
| 2236 |
+
if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
|
| 2237 |
+
continue;
|
| 2238 |
+
}
|
| 2239 |
+
|
| 2240 |
+
int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
|
| 2241 |
+
if (len0 < 0) {
|
| 2242 |
+
ctx->buf0.resize(len0);
|
| 2243 |
+
len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
|
| 2244 |
+
assert(len0 > 0);
|
| 2245 |
+
}
|
| 2246 |
+
|
| 2247 |
+
int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
|
| 2248 |
+
if (len1 < 0) {
|
| 2249 |
+
ctx->buf1.resize(len1);
|
| 2250 |
+
len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
|
| 2251 |
+
assert(len1 > 0);
|
| 2252 |
+
}
|
| 2253 |
+
|
| 2254 |
+
// token i0 is a prefix of token i1
|
| 2255 |
+
if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
|
| 2256 |
+
int dst = i0;
|
| 2257 |
+
int src = i1;
|
| 2258 |
+
|
| 2259 |
+
// merge into the token with higher probability
|
| 2260 |
+
if (cur_p->data[i1].p > cur_p->data[i0].p) {
|
| 2261 |
+
std::swap(dst, src);
|
| 2262 |
+
}
|
| 2263 |
+
|
| 2264 |
+
cur_p->data[dst].p += cur_p->data[src].p;
|
| 2265 |
+
cur_p->data[src].logit = -INFINITY;
|
| 2266 |
+
cur_p->data[src].p = 0.0f;
|
| 2267 |
+
|
| 2268 |
+
n_combined++;
|
| 2269 |
+
}
|
| 2270 |
+
}
|
| 2271 |
+
}
|
| 2272 |
+
|
| 2273 |
+
size_t n_non_eog = 0;
|
| 2274 |
+
|
| 2275 |
+
size_t size_org = cur_p->size;
|
| 2276 |
+
|
| 2277 |
+
float p_sum = 0.0f;
|
| 2278 |
+
float thold = 0.2f;
|
| 2279 |
+
|
| 2280 |
+
cur_p->size = 0;
|
| 2281 |
+
|
| 2282 |
+
LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
|
| 2283 |
+
|
| 2284 |
+
for (size_t i = 0; i < size_org; ++i) {
|
| 2285 |
+
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
|
| 2286 |
+
|
| 2287 |
+
if (cur_p->data[i].p < thold && !is_eog) {
|
| 2288 |
+
continue;
|
| 2289 |
+
}
|
| 2290 |
+
|
| 2291 |
+
if (!is_eog) {
|
| 2292 |
+
++n_non_eog;
|
| 2293 |
+
}
|
| 2294 |
+
|
| 2295 |
+
p_sum += cur_p->data[i].p;
|
| 2296 |
+
|
| 2297 |
+
// keep this token
|
| 2298 |
+
cur_p->data[cur_p->size++] = cur_p->data[i];
|
| 2299 |
+
}
|
| 2300 |
+
|
| 2301 |
+
LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
|
| 2302 |
+
|
| 2303 |
+
// if no non-EOG tokens are left -> reduce cur_p to single EOT token
|
| 2304 |
+
if (n_non_eog == 0) {
|
| 2305 |
+
cur_p->size = 1;
|
| 2306 |
+
cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
|
| 2307 |
+
cur_p->data[0].logit = 1.0f;
|
| 2308 |
+
|
| 2309 |
+
return;
|
| 2310 |
+
}
|
| 2311 |
+
|
| 2312 |
+
// normalize probs
|
| 2313 |
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
| 2314 |
+
cur_p->data[i].p /= p_sum;
|
| 2315 |
+
|
| 2316 |
+
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
|
| 2317 |
+
}
|
| 2318 |
+
|
| 2319 |
+
size_org = cur_p->size;
|
| 2320 |
+
p_sum = 0.0f;
|
| 2321 |
+
thold = 1.0/(n_non_eog + 1);
|
| 2322 |
+
|
| 2323 |
+
cur_p->size = 0;
|
| 2324 |
+
|
| 2325 |
+
LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
|
| 2326 |
+
|
| 2327 |
+
for (size_t i = 0; i < size_org; ++i) {
|
| 2328 |
+
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
|
| 2329 |
+
|
| 2330 |
+
if (cur_p->data[i].p < thold && !is_eog) {
|
| 2331 |
+
continue;
|
| 2332 |
+
}
|
| 2333 |
+
|
| 2334 |
+
p_sum += cur_p->data[i].p;
|
| 2335 |
+
|
| 2336 |
+
cur_p->data[cur_p->size++] = cur_p->data[i];
|
| 2337 |
+
}
|
| 2338 |
+
|
| 2339 |
+
// normalize probs
|
| 2340 |
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
| 2341 |
+
cur_p->data[i].p /= p_sum;
|
| 2342 |
+
|
| 2343 |
+
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
|
| 2344 |
+
}
|
| 2345 |
+
|
| 2346 |
+
#undef LOG_DBG_CUR
|
| 2347 |
+
}
|
| 2348 |
+
|
| 2349 |
+
static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
|
| 2350 |
+
const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
|
| 2351 |
+
return llama_sampler_init_infill_impl(*ctx->vocab);
|
| 2352 |
+
}
|
| 2353 |
+
|
| 2354 |
+
static void llama_sampler_infill_free(struct llama_sampler * smpl) {
|
| 2355 |
+
delete (llama_sampler_infill *) smpl->ctx;
|
| 2356 |
+
}
|
| 2357 |
+
|
| 2358 |
+
static struct llama_sampler_i llama_sampler_infill_i = {
|
| 2359 |
+
/* .name = */ llama_sampler_infill_name,
|
| 2360 |
+
/* .accept = */ nullptr,
|
| 2361 |
+
/* .apply = */ llama_sampler_infill_apply,
|
| 2362 |
+
/* .reset = */ nullptr,
|
| 2363 |
+
/* .clone = */ llama_sampler_infill_clone,
|
| 2364 |
+
/* .free = */ llama_sampler_infill_free,
|
| 2365 |
+
};
|
| 2366 |
+
|
| 2367 |
+
struct llama_sampler * llama_sampler_init_infill_impl(
|
| 2368 |
+
const struct llama_vocab & vocab) {
|
| 2369 |
+
return new llama_sampler {
|
| 2370 |
+
/* .iface = */ &llama_sampler_infill_i,
|
| 2371 |
+
/* .ctx = */ new llama_sampler_infill {
|
| 2372 |
+
/* .vocab = */ &vocab,
|
| 2373 |
+
/* .buf0 = */ std::vector<char>(512),
|
| 2374 |
+
/* .buf1 = */ std::vector<char>(512),
|
| 2375 |
+
},
|
| 2376 |
+
};
|
| 2377 |
+
}
|
| 2378 |
+
|
| 2379 |
// utils
|
| 2380 |
|
| 2381 |
uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
|
examples/talk-llama/llama-sampling.h
CHANGED
|
@@ -4,8 +4,6 @@
|
|
| 4 |
|
| 5 |
#include "llama-grammar.h"
|
| 6 |
|
| 7 |
-
#include <unordered_map>
|
| 8 |
-
|
| 9 |
struct llama_vocab;
|
| 10 |
struct llama_grammar;
|
| 11 |
|
|
@@ -27,3 +25,24 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
|
|
| 27 |
const struct llama_vocab & vocab,
|
| 28 |
const char * grammar_str,
|
| 29 |
const char * grammar_root);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
#include "llama-grammar.h"
|
| 6 |
|
|
|
|
|
|
|
| 7 |
struct llama_vocab;
|
| 8 |
struct llama_grammar;
|
| 9 |
|
|
|
|
| 25 |
const struct llama_vocab & vocab,
|
| 26 |
const char * grammar_str,
|
| 27 |
const char * grammar_root);
|
| 28 |
+
|
| 29 |
+
struct llama_sampler * llama_sampler_init_infill_impl(
|
| 30 |
+
const struct llama_vocab & vocab);
|
| 31 |
+
|
| 32 |
+
struct llama_sampler * llama_sampler_init_dry_impl(
|
| 33 |
+
const struct llama_vocab & vocab,
|
| 34 |
+
int32_t context_size,
|
| 35 |
+
float dry_multiplier,
|
| 36 |
+
float dry_base,
|
| 37 |
+
int32_t dry_allowed_length,
|
| 38 |
+
int32_t dry_penalty_last_n,
|
| 39 |
+
const char ** seq_breakers,
|
| 40 |
+
size_t num_breakers);
|
| 41 |
+
|
| 42 |
+
struct llama_sampler * llama_sampler_init_dry_testing(
|
| 43 |
+
int32_t context_size,
|
| 44 |
+
float dry_multiplier,
|
| 45 |
+
float dry_base,
|
| 46 |
+
int32_t dry_allowed_length,
|
| 47 |
+
int32_t dry_penalty_last_n,
|
| 48 |
+
const std::vector<std::vector<llama_token>>& seq_breakers);
|
examples/talk-llama/llama-vocab.cpp
CHANGED
|
@@ -221,7 +221,7 @@ struct llm_tokenizer_spm_session {
|
|
| 221 |
}
|
| 222 |
|
| 223 |
// seed the work queue with all possible 2-character tokens.
|
| 224 |
-
for (
|
| 225 |
try_add_bigram(i - 1, i);
|
| 226 |
}
|
| 227 |
|
|
@@ -563,7 +563,7 @@ struct llm_tokenizer_bpe_session {
|
|
| 563 |
index++;
|
| 564 |
symbols.emplace_back(sym);
|
| 565 |
}
|
| 566 |
-
for (
|
| 567 |
add_new_bigram(i - 1, i);
|
| 568 |
}
|
| 569 |
|
|
@@ -1663,6 +1663,14 @@ llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
|
|
| 1663 |
return vocab.special_eos_id;
|
| 1664 |
}
|
| 1665 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1666 |
llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
|
| 1667 |
return vocab.special_cls_id;
|
| 1668 |
}
|
|
@@ -1688,23 +1696,39 @@ bool llama_add_eos_token_impl(const struct llama_vocab & vocab) {
|
|
| 1688 |
}
|
| 1689 |
|
| 1690 |
llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
|
| 1691 |
-
return vocab.
|
| 1692 |
}
|
| 1693 |
|
| 1694 |
llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
|
| 1695 |
-
return vocab.
|
| 1696 |
}
|
| 1697 |
|
| 1698 |
llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
|
| 1699 |
-
return vocab.
|
| 1700 |
}
|
| 1701 |
|
| 1702 |
-
llama_token
|
| 1703 |
-
return vocab.
|
| 1704 |
}
|
| 1705 |
|
| 1706 |
-
llama_token
|
| 1707 |
-
return vocab.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1708 |
}
|
| 1709 |
|
| 1710 |
int32_t llama_tokenize_impl(
|
|
@@ -1942,3 +1966,19 @@ int32_t llama_detokenize_impl(
|
|
| 1942 |
|
| 1943 |
return total <= text_len_max ? total : -total;
|
| 1944 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
}
|
| 222 |
|
| 223 |
// seed the work queue with all possible 2-character tokens.
|
| 224 |
+
for (int i = 1; i < (int) symbols.size(); ++i) {
|
| 225 |
try_add_bigram(i - 1, i);
|
| 226 |
}
|
| 227 |
|
|
|
|
| 563 |
index++;
|
| 564 |
symbols.emplace_back(sym);
|
| 565 |
}
|
| 566 |
+
for (int i = 1; i < (int) symbols.size(); ++i) {
|
| 567 |
add_new_bigram(i - 1, i);
|
| 568 |
}
|
| 569 |
|
|
|
|
| 1663 |
return vocab.special_eos_id;
|
| 1664 |
}
|
| 1665 |
|
| 1666 |
+
llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
|
| 1667 |
+
return vocab.special_eot_id;
|
| 1668 |
+
}
|
| 1669 |
+
|
| 1670 |
+
llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
|
| 1671 |
+
return vocab.special_eom_id;
|
| 1672 |
+
}
|
| 1673 |
+
|
| 1674 |
llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
|
| 1675 |
return vocab.special_cls_id;
|
| 1676 |
}
|
|
|
|
| 1696 |
}
|
| 1697 |
|
| 1698 |
llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
|
| 1699 |
+
return vocab.special_fim_pre_id;
|
| 1700 |
}
|
| 1701 |
|
| 1702 |
llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
|
| 1703 |
+
return vocab.special_fim_mid_id;
|
| 1704 |
}
|
| 1705 |
|
| 1706 |
llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
|
| 1707 |
+
return vocab.special_fim_suf_id;
|
| 1708 |
}
|
| 1709 |
|
| 1710 |
+
llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab) {
|
| 1711 |
+
return vocab.special_fim_pre_id;
|
| 1712 |
}
|
| 1713 |
|
| 1714 |
+
llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab) {
|
| 1715 |
+
return vocab.special_fim_suf_id;
|
| 1716 |
+
}
|
| 1717 |
+
|
| 1718 |
+
llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab) {
|
| 1719 |
+
return vocab.special_fim_mid_id;
|
| 1720 |
+
}
|
| 1721 |
+
|
| 1722 |
+
llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab) {
|
| 1723 |
+
return vocab.special_fim_pad_id;
|
| 1724 |
+
}
|
| 1725 |
+
|
| 1726 |
+
llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab) {
|
| 1727 |
+
return vocab.special_fim_rep_id;
|
| 1728 |
+
}
|
| 1729 |
+
|
| 1730 |
+
llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab) {
|
| 1731 |
+
return vocab.special_fim_sep_id;
|
| 1732 |
}
|
| 1733 |
|
| 1734 |
int32_t llama_tokenize_impl(
|
|
|
|
| 1966 |
|
| 1967 |
return total <= text_len_max ? total : -total;
|
| 1968 |
}
|
| 1969 |
+
|
| 1970 |
+
std::string llama_detokenize(const struct llama_vocab & vocab, const std::vector<llama_token> & tokens, bool special) {
|
| 1971 |
+
std::string text;
|
| 1972 |
+
text.resize(std::max(text.capacity(), tokens.size()));
|
| 1973 |
+
int32_t n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
|
| 1974 |
+
if (n_chars < 0) {
|
| 1975 |
+
text.resize(-n_chars);
|
| 1976 |
+
n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
|
| 1977 |
+
GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
|
| 1978 |
+
}
|
| 1979 |
+
|
| 1980 |
+
text.resize(n_chars);
|
| 1981 |
+
|
| 1982 |
+
// NOTE: the original tokenizer decodes bytes after collecting the pieces.
|
| 1983 |
+
return text;
|
| 1984 |
+
}
|
examples/talk-llama/llama-vocab.h
CHANGED
|
@@ -37,20 +37,26 @@ struct llama_vocab {
|
|
| 37 |
std::map<std::pair<std::string, std::string>, int> bpe_ranks;
|
| 38 |
|
| 39 |
// default LLaMA special tokens
|
|
|
|
| 40 |
id special_bos_id = 1;
|
| 41 |
id special_eos_id = 2;
|
|
|
|
|
|
|
| 42 |
id special_unk_id = 0;
|
| 43 |
-
id special_sep_id =
|
| 44 |
-
id special_pad_id =
|
| 45 |
-
id special_cls_id =
|
| 46 |
-
id special_mask_id =
|
| 47 |
-
|
| 48 |
-
id linefeed_id
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
id
|
| 52 |
-
id
|
| 53 |
-
id
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
// set of all tokens that cause "end of generation"
|
| 56 |
std::set<id> special_eog_ids;
|
|
@@ -104,19 +110,26 @@ bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token t
|
|
| 104 |
|
| 105 |
llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
|
| 106 |
llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
|
|
|
|
|
|
|
| 107 |
llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
|
| 108 |
llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
|
| 109 |
llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
|
| 110 |
llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
|
| 111 |
|
| 112 |
-
bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
|
| 113 |
-
bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
|
| 114 |
-
|
| 115 |
llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
|
| 116 |
llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
|
| 117 |
llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
|
| 118 |
-
|
| 119 |
-
llama_token
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
int32_t llama_tokenize_impl(
|
| 122 |
const struct llama_vocab & vocab,
|
|
@@ -136,6 +149,12 @@ int32_t llama_token_to_piece_impl(
|
|
| 136 |
int32_t lstrip,
|
| 137 |
bool special);
|
| 138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
int32_t llama_detokenize_impl(
|
| 140 |
const struct llama_vocab & vocab,
|
| 141 |
const llama_token * tokens,
|
|
@@ -144,3 +163,8 @@ int32_t llama_detokenize_impl(
|
|
| 144 |
int32_t text_len_max,
|
| 145 |
bool remove_special,
|
| 146 |
bool unparse_special);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
std::map<std::pair<std::string, std::string>, int> bpe_ranks;
|
| 38 |
|
| 39 |
// default LLaMA special tokens
|
| 40 |
+
// TODO: should we set all of these to LLAMA_TOKEN_NULL?
|
| 41 |
id special_bos_id = 1;
|
| 42 |
id special_eos_id = 2;
|
| 43 |
+
id special_eot_id = LLAMA_TOKEN_NULL;
|
| 44 |
+
id special_eom_id = LLAMA_TOKEN_NULL;
|
| 45 |
id special_unk_id = 0;
|
| 46 |
+
id special_sep_id = LLAMA_TOKEN_NULL;
|
| 47 |
+
id special_pad_id = LLAMA_TOKEN_NULL;
|
| 48 |
+
id special_cls_id = LLAMA_TOKEN_NULL;
|
| 49 |
+
id special_mask_id = LLAMA_TOKEN_NULL;
|
| 50 |
+
|
| 51 |
+
id linefeed_id = 13;
|
| 52 |
+
|
| 53 |
+
// fim tokens
|
| 54 |
+
id special_fim_pre_id = LLAMA_TOKEN_NULL;
|
| 55 |
+
id special_fim_suf_id = LLAMA_TOKEN_NULL;
|
| 56 |
+
id special_fim_mid_id = LLAMA_TOKEN_NULL;
|
| 57 |
+
id special_fim_pad_id = LLAMA_TOKEN_NULL;
|
| 58 |
+
id special_fim_rep_id = LLAMA_TOKEN_NULL; // repo
|
| 59 |
+
id special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator
|
| 60 |
|
| 61 |
// set of all tokens that cause "end of generation"
|
| 62 |
std::set<id> special_eog_ids;
|
|
|
|
| 110 |
|
| 111 |
llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
|
| 112 |
llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
|
| 113 |
+
llama_token llama_token_eot_impl(const struct llama_vocab & vocab);
|
| 114 |
+
llama_token llama_token_eom_impl(const struct llama_vocab & vocab);
|
| 115 |
llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
|
| 116 |
llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
|
| 117 |
llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
|
| 118 |
llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
|
| 119 |
|
|
|
|
|
|
|
|
|
|
| 120 |
llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
|
| 121 |
llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
|
| 122 |
llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
|
| 123 |
+
|
| 124 |
+
llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab);
|
| 125 |
+
llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab);
|
| 126 |
+
llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab);
|
| 127 |
+
llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab);
|
| 128 |
+
llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab);
|
| 129 |
+
llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab);
|
| 130 |
+
|
| 131 |
+
bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
|
| 132 |
+
bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
|
| 133 |
|
| 134 |
int32_t llama_tokenize_impl(
|
| 135 |
const struct llama_vocab & vocab,
|
|
|
|
| 149 |
int32_t lstrip,
|
| 150 |
bool special);
|
| 151 |
|
| 152 |
+
// check if token0 is contained as a prefix in token1
|
| 153 |
+
bool llama_token_is_prefix_impl(
|
| 154 |
+
const struct llama_vocab & vocab,
|
| 155 |
+
llama_token token0,
|
| 156 |
+
llama_token token1);
|
| 157 |
+
|
| 158 |
int32_t llama_detokenize_impl(
|
| 159 |
const struct llama_vocab & vocab,
|
| 160 |
const llama_token * tokens,
|
|
|
|
| 163 |
int32_t text_len_max,
|
| 164 |
bool remove_special,
|
| 165 |
bool unparse_special);
|
| 166 |
+
|
| 167 |
+
std::string llama_detokenize(
|
| 168 |
+
const struct llama_vocab & vocab,
|
| 169 |
+
const std::vector<llama_token> & tokens,
|
| 170 |
+
bool special);
|
examples/talk-llama/llama.cpp
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/talk-llama/llama.h
CHANGED
|
@@ -217,6 +217,7 @@ extern "C" {
|
|
| 217 |
|
| 218 |
typedef struct llama_token_data_array {
|
| 219 |
// TODO: consider SoA
|
|
|
|
| 220 |
llama_token_data * data;
|
| 221 |
size_t size;
|
| 222 |
int64_t selected; // this is the index in the data array (i.e. not the token id)
|
|
@@ -232,8 +233,11 @@ extern "C" {
|
|
| 232 |
// - token : the token ids of the input (used when embd is NULL)
|
| 233 |
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
|
| 234 |
// - pos : the positions of the respective token in the sequence
|
|
|
|
| 235 |
// - seq_id : the sequence to which the respective token belongs
|
|
|
|
| 236 |
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
|
|
|
|
| 237 |
//
|
| 238 |
typedef struct llama_batch {
|
| 239 |
int32_t n_tokens;
|
|
@@ -244,15 +248,6 @@ extern "C" {
|
|
| 244 |
int32_t * n_seq_id;
|
| 245 |
llama_seq_id ** seq_id;
|
| 246 |
int8_t * logits; // TODO: rename this to "output"
|
| 247 |
-
|
| 248 |
-
// NOTE: helpers for smooth API transition - can be deprecated in the future
|
| 249 |
-
// for future-proof code, use the above fields instead and ignore everything below
|
| 250 |
-
//
|
| 251 |
-
// pos[i] = all_pos_0 + i*all_pos_1
|
| 252 |
-
//
|
| 253 |
-
llama_pos all_pos_0; // used if pos == NULL
|
| 254 |
-
llama_pos all_pos_1; // used if pos == NULL
|
| 255 |
-
llama_seq_id all_seq_id; // used if seq_id == NULL
|
| 256 |
} llama_batch;
|
| 257 |
|
| 258 |
enum llama_model_kv_override_type {
|
|
@@ -433,6 +428,7 @@ extern "C" {
|
|
| 433 |
LLAMA_API bool llama_supports_mmap (void);
|
| 434 |
LLAMA_API bool llama_supports_mlock (void);
|
| 435 |
LLAMA_API bool llama_supports_gpu_offload(void);
|
|
|
|
| 436 |
|
| 437 |
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
|
| 438 |
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
|
|
@@ -775,15 +771,15 @@ extern "C" {
|
|
| 775 |
// Decoding
|
| 776 |
//
|
| 777 |
|
| 778 |
-
// Return batch for single sequence of tokens
|
|
|
|
|
|
|
| 779 |
//
|
| 780 |
// NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
|
| 781 |
//
|
| 782 |
LLAMA_API struct llama_batch llama_batch_get_one(
|
| 783 |
llama_token * tokens,
|
| 784 |
-
int32_t n_tokens
|
| 785 |
-
llama_pos pos_0,
|
| 786 |
-
llama_seq_id seq_id);
|
| 787 |
|
| 788 |
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
|
| 789 |
// Each token can be assigned up to n_seq_max sequence ids
|
|
@@ -896,6 +892,7 @@ extern "C" {
|
|
| 896 |
// Special tokens
|
| 897 |
LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
|
| 898 |
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
|
|
|
|
| 899 |
LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
|
| 900 |
LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
|
| 901 |
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
|
|
@@ -904,11 +901,17 @@ extern "C" {
|
|
| 904 |
LLAMA_API bool llama_add_bos_token(const struct llama_model * model);
|
| 905 |
LLAMA_API bool llama_add_eos_token(const struct llama_model * model);
|
| 906 |
|
| 907 |
-
//
|
| 908 |
-
LLAMA_API llama_token llama_token_prefix(const struct llama_model * model)
|
| 909 |
-
LLAMA_API llama_token llama_token_middle(const struct llama_model * model)
|
| 910 |
-
LLAMA_API llama_token llama_token_suffix(const struct llama_model * model)
|
| 911 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 912 |
|
| 913 |
//
|
| 914 |
// Tokenization
|
|
@@ -1067,12 +1070,13 @@ extern "C" {
|
|
| 1067 |
|
| 1068 |
// available samplers:
|
| 1069 |
|
| 1070 |
-
LLAMA_API struct llama_sampler * llama_sampler_init_greedy
|
| 1071 |
-
LLAMA_API struct llama_sampler * llama_sampler_init_dist
|
| 1072 |
|
| 1073 |
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
| 1074 |
/// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
|
| 1075 |
-
LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void)
|
|
|
|
| 1076 |
|
| 1077 |
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
| 1078 |
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
|
|
@@ -1088,11 +1092,16 @@ extern "C" {
|
|
| 1088 |
|
| 1089 |
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
| 1090 |
LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep);
|
|
|
|
|
|
|
| 1091 |
LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t);
|
| 1092 |
|
| 1093 |
/// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
|
| 1094 |
LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);
|
| 1095 |
|
|
|
|
|
|
|
|
|
|
| 1096 |
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
| 1097 |
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
| 1098 |
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
|
@@ -1132,11 +1141,43 @@ extern "C" {
|
|
| 1132 |
bool penalize_nl, // consider newlines as a repeatable token
|
| 1133 |
bool ignore_eos); // ignore the end-of-sequence token
|
| 1134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1135 |
LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
|
| 1136 |
int32_t n_vocab,
|
| 1137 |
int32_t n_logit_bias,
|
| 1138 |
const llama_logit_bias * logit_bias);
|
| 1139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1140 |
|
| 1141 |
// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
|
| 1142 |
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
|
|
|
|
| 217 |
|
| 218 |
typedef struct llama_token_data_array {
|
| 219 |
// TODO: consider SoA
|
| 220 |
+
// NOTE: this pointer can be modified by the samplers
|
| 221 |
llama_token_data * data;
|
| 222 |
size_t size;
|
| 223 |
int64_t selected; // this is the index in the data array (i.e. not the token id)
|
|
|
|
| 233 |
// - token : the token ids of the input (used when embd is NULL)
|
| 234 |
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
|
| 235 |
// - pos : the positions of the respective token in the sequence
|
| 236 |
+
// (if set to NULL, the token position will be tracked automatically by llama_decode)
|
| 237 |
// - seq_id : the sequence to which the respective token belongs
|
| 238 |
+
// (if set to NULL, the sequence ID will be assumed to be 0)
|
| 239 |
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
|
| 240 |
+
// (if set to NULL, only the logits for last token will be returned)
|
| 241 |
//
|
| 242 |
typedef struct llama_batch {
|
| 243 |
int32_t n_tokens;
|
|
|
|
| 248 |
int32_t * n_seq_id;
|
| 249 |
llama_seq_id ** seq_id;
|
| 250 |
int8_t * logits; // TODO: rename this to "output"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
} llama_batch;
|
| 252 |
|
| 253 |
enum llama_model_kv_override_type {
|
|
|
|
| 428 |
LLAMA_API bool llama_supports_mmap (void);
|
| 429 |
LLAMA_API bool llama_supports_mlock (void);
|
| 430 |
LLAMA_API bool llama_supports_gpu_offload(void);
|
| 431 |
+
LLAMA_API bool llama_supports_rpc (void);
|
| 432 |
|
| 433 |
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
|
| 434 |
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
|
|
|
|
| 771 |
// Decoding
|
| 772 |
//
|
| 773 |
|
| 774 |
+
// Return batch for single sequence of tokens
|
| 775 |
+
// The sequence ID will be fixed to 0
|
| 776 |
+
// The position of the tokens will be tracked automatically by llama_decode
|
| 777 |
//
|
| 778 |
// NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
|
| 779 |
//
|
| 780 |
LLAMA_API struct llama_batch llama_batch_get_one(
|
| 781 |
llama_token * tokens,
|
| 782 |
+
int32_t n_tokens);
|
|
|
|
|
|
|
| 783 |
|
| 784 |
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
|
| 785 |
// Each token can be assigned up to n_seq_max sequence ids
|
|
|
|
| 892 |
// Special tokens
|
| 893 |
LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
|
| 894 |
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
|
| 895 |
+
LLAMA_API llama_token llama_token_eot(const struct llama_model * model); // end-of-turn
|
| 896 |
LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
|
| 897 |
LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
|
| 898 |
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
|
|
|
|
| 901 |
LLAMA_API bool llama_add_bos_token(const struct llama_model * model);
|
| 902 |
LLAMA_API bool llama_add_eos_token(const struct llama_model * model);
|
| 903 |
|
| 904 |
+
// infill tokens
|
| 905 |
+
DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_model * model), "use llama_token_fim_pre instead");
|
| 906 |
+
DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_model * model), "use llama_token_fim_mid instead");
|
| 907 |
+
DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_model * model), "use llama_token_fim_suf instead");
|
| 908 |
+
|
| 909 |
+
LLAMA_API llama_token llama_token_fim_pre(const struct llama_model * model);
|
| 910 |
+
LLAMA_API llama_token llama_token_fim_suf(const struct llama_model * model);
|
| 911 |
+
LLAMA_API llama_token llama_token_fim_mid(const struct llama_model * model);
|
| 912 |
+
LLAMA_API llama_token llama_token_fim_pad(const struct llama_model * model);
|
| 913 |
+
LLAMA_API llama_token llama_token_fim_rep(const struct llama_model * model);
|
| 914 |
+
LLAMA_API llama_token llama_token_fim_sep(const struct llama_model * model);
|
| 915 |
|
| 916 |
//
|
| 917 |
// Tokenization
|
|
|
|
| 1070 |
|
| 1071 |
// available samplers:
|
| 1072 |
|
| 1073 |
+
LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
|
| 1074 |
+
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
|
| 1075 |
|
| 1076 |
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
| 1077 |
/// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
|
| 1078 |
+
DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
|
| 1079 |
+
"will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)");
|
| 1080 |
|
| 1081 |
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
| 1082 |
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
|
|
|
|
| 1092 |
|
| 1093 |
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
| 1094 |
LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep);
|
| 1095 |
+
|
| 1096 |
+
/// #details Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf
|
| 1097 |
LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t);
|
| 1098 |
|
| 1099 |
/// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
|
| 1100 |
LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);
|
| 1101 |
|
| 1102 |
+
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
|
| 1103 |
+
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
|
| 1104 |
+
|
| 1105 |
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
| 1106 |
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
| 1107 |
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
|
|
|
| 1141 |
bool penalize_nl, // consider newlines as a repeatable token
|
| 1142 |
bool ignore_eos); // ignore the end-of-sequence token
|
| 1143 |
|
| 1144 |
+
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
|
| 1145 |
+
LLAMA_API struct llama_sampler * llama_sampler_init_dry(
|
| 1146 |
+
const struct llama_model * model,
|
| 1147 |
+
float dry_multiplier,
|
| 1148 |
+
float dry_base,
|
| 1149 |
+
int32_t dry_allowed_length,
|
| 1150 |
+
int32_t dry_penalty_last_n,
|
| 1151 |
+
const char ** seq_breakers,
|
| 1152 |
+
size_t num_breakers);
|
| 1153 |
+
|
| 1154 |
LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
|
| 1155 |
int32_t n_vocab,
|
| 1156 |
int32_t n_logit_bias,
|
| 1157 |
const llama_logit_bias * logit_bias);
|
| 1158 |
|
| 1159 |
+
// this sampler is meant to be used for fill-in-the-middle infilling
|
| 1160 |
+
// it's supposed to be used after top_k + top_p sampling
|
| 1161 |
+
//
|
| 1162 |
+
// 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
|
| 1163 |
+
// 2. combine probs of tokens that have the same prefix
|
| 1164 |
+
//
|
| 1165 |
+
// example:
|
| 1166 |
+
//
|
| 1167 |
+
// - before:
|
| 1168 |
+
// "hel": 0.5
|
| 1169 |
+
// "hell": 0.2
|
| 1170 |
+
// "hello": 0.1
|
| 1171 |
+
// "dummy": 0.1
|
| 1172 |
+
//
|
| 1173 |
+
// - after:
|
| 1174 |
+
// "hel": 0.8
|
| 1175 |
+
// "dummy": 0.1
|
| 1176 |
+
//
|
| 1177 |
+
// 3. discard non-EOG tokens with low prob
|
| 1178 |
+
// 4. if no tokens are left -> pick EOT
|
| 1179 |
+
//
|
| 1180 |
+
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);
|
| 1181 |
|
| 1182 |
// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
|
| 1183 |
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
|
examples/talk-llama/unicode-data.cpp
CHANGED
|
@@ -2311,7 +2311,7 @@ const std::unordered_set<uint32_t> unicode_set_whitespace = {
|
|
| 2311 |
0x003000,
|
| 2312 |
};
|
| 2313 |
|
| 2314 |
-
// list is always in ascending order, to enable binary
|
| 2315 |
const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_lowercase = {
|
| 2316 |
{0x000041, 0x000061},
|
| 2317 |
{0x000042, 0x000062},
|
|
@@ -3748,7 +3748,7 @@ const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_lowercase
|
|
| 3748 |
{0x01E921, 0x01E943},
|
| 3749 |
};
|
| 3750 |
|
| 3751 |
-
// list is always in ascending order, to enable binary
|
| 3752 |
const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_uppercase = {
|
| 3753 |
{0x000061, 0x000041},
|
| 3754 |
{0x000062, 0x000042},
|
|
|
|
| 2311 |
0x003000,
|
| 2312 |
};
|
| 2313 |
|
| 2314 |
+
// list is always in ascending order, to enable binary search
|
| 2315 |
const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_lowercase = {
|
| 2316 |
{0x000041, 0x000061},
|
| 2317 |
{0x000042, 0x000062},
|
|
|
|
| 3748 |
{0x01E921, 0x01E943},
|
| 3749 |
};
|
| 3750 |
|
| 3751 |
+
// list is always in ascending order, to enable binary search
|
| 3752 |
const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_uppercase = {
|
| 3753 |
{0x000061, 0x000041},
|
| 3754 |
{0x000062, 0x000042},
|