Spaces:
Running
Running
Luis Herrera
committed on
talk-llama : fix session prompt load (#854)
Browse files
examples/talk-llama/talk-llama.cpp
CHANGED
|
@@ -333,27 +333,10 @@ int main(int argc, char ** argv) {
|
|
| 333 |
|
| 334 |
prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
|
| 335 |
|
| 336 |
-
// evaluate the initial prompt
|
| 337 |
-
|
| 338 |
-
auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
|
| 339 |
-
|
| 340 |
-
printf("\n");
|
| 341 |
-
printf("%s : initializing - please wait ...\n", __func__);
|
| 342 |
-
|
| 343 |
-
if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
|
| 344 |
-
fprintf(stderr, "%s : failed to eval\n", __func__);
|
| 345 |
-
return 1;
|
| 346 |
-
}
|
| 347 |
-
|
| 348 |
-
if (params.verbose_prompt) {
|
| 349 |
-
fprintf(stdout, "\n");
|
| 350 |
-
fprintf(stdout, "%s", prompt_llama.c_str());
|
| 351 |
-
fflush(stdout);
|
| 352 |
-
}
|
| 353 |
-
|
| 354 |
// init session
|
| 355 |
std::string path_session = params.path_session;
|
| 356 |
std::vector<llama_token> session_tokens;
|
|
|
|
| 357 |
|
| 358 |
if (!path_session.empty()) {
|
| 359 |
fprintf(stderr, "%s: attempting to load saved session from %s\n", __func__, path_session.c_str());
|
|
@@ -370,6 +353,9 @@ int main(int argc, char ** argv) {
|
|
| 370 |
return 1;
|
| 371 |
}
|
| 372 |
session_tokens.resize(n_token_count_out);
|
|
|
|
|
|
|
|
|
|
| 373 |
|
| 374 |
fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
|
| 375 |
} else {
|
|
@@ -377,6 +363,22 @@ int main(int argc, char ** argv) {
|
|
| 377 |
}
|
| 378 |
}
|
| 379 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
// debug message about similarity of saved session, if applicable
|
| 381 |
size_t n_matching_session_tokens = 0;
|
| 382 |
if (session_tokens.size()) {
|
|
@@ -417,7 +419,7 @@ int main(int argc, char ** argv) {
|
|
| 417 |
|
| 418 |
int n_past = n_keep;
|
| 419 |
int n_prev = 64; // TODO arg
|
| 420 |
-
int n_session_consumed = 0;
|
| 421 |
|
| 422 |
std::vector<llama_token> embd;
|
| 423 |
|
|
@@ -494,6 +496,11 @@ int main(int argc, char ** argv) {
|
|
| 494 |
|
| 495 |
embd = ::llama_tokenize(ctx_llama, text_heard, false);
|
| 496 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 497 |
// text inference
|
| 498 |
bool done = false;
|
| 499 |
std::string text_to_speak;
|
|
@@ -539,20 +546,21 @@ int main(int argc, char ** argv) {
|
|
| 539 |
}
|
| 540 |
}
|
| 541 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 542 |
if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
|
| 543 |
fprintf(stderr, "%s : failed to eval\n", __func__);
|
| 544 |
return 1;
|
| 545 |
}
|
| 546 |
}
|
| 547 |
|
| 548 |
-
//printf("n_iter = %d, n_past = %d, n_ctx = %d, n_keep = %d, n_prev = %d, embd.size() = %d\n", n_iter, n_past, n_ctx, n_keep, n_prev, (int) embd.size());
|
| 549 |
|
| 550 |
embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
|
| 551 |
n_past += embd.size();
|
| 552 |
-
|
| 553 |
-
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
|
| 554 |
-
n_session_consumed = session_tokens.size();
|
| 555 |
-
}
|
| 556 |
embd.clear();
|
| 557 |
|
| 558 |
if (done) break;
|
|
|
|
| 333 |
|
| 334 |
prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
|
| 335 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
// init session
|
| 337 |
std::string path_session = params.path_session;
|
| 338 |
std::vector<llama_token> session_tokens;
|
| 339 |
+
auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
|
| 340 |
|
| 341 |
if (!path_session.empty()) {
|
| 342 |
fprintf(stderr, "%s: attempting to load saved session from %s\n", __func__, path_session.c_str());
|
|
|
|
| 353 |
return 1;
|
| 354 |
}
|
| 355 |
session_tokens.resize(n_token_count_out);
|
| 356 |
+
for (size_t i = 0; i < session_tokens.size(); i++) {
|
| 357 |
+
embd_inp[i] = session_tokens[i];
|
| 358 |
+
}
|
| 359 |
|
| 360 |
fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
|
| 361 |
} else {
|
|
|
|
| 363 |
}
|
| 364 |
}
|
| 365 |
|
| 366 |
+
// evaluate the initial prompt
|
| 367 |
+
|
| 368 |
+
printf("\n");
|
| 369 |
+
printf("%s : initializing - please wait ...\n", __func__);
|
| 370 |
+
|
| 371 |
+
if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
|
| 372 |
+
fprintf(stderr, "%s : failed to eval\n", __func__);
|
| 373 |
+
return 1;
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
if (params.verbose_prompt) {
|
| 377 |
+
fprintf(stdout, "\n");
|
| 378 |
+
fprintf(stdout, "%s", prompt_llama.c_str());
|
| 379 |
+
fflush(stdout);
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
// debug message about similarity of saved session, if applicable
|
| 383 |
size_t n_matching_session_tokens = 0;
|
| 384 |
if (session_tokens.size()) {
|
|
|
|
| 419 |
|
| 420 |
int n_past = n_keep;
|
| 421 |
int n_prev = 64; // TODO arg
|
| 422 |
+
int n_session_consumed = !path_session.empty() && session_tokens.size() > 0 ? session_tokens.size() : 0;
|
| 423 |
|
| 424 |
std::vector<llama_token> embd;
|
| 425 |
|
|
|
|
| 496 |
|
| 497 |
embd = ::llama_tokenize(ctx_llama, text_heard, false);
|
| 498 |
|
| 499 |
+
// Append the new input tokens to the session_tokens vector
|
| 500 |
+
if (!path_session.empty()) {
|
| 501 |
+
session_tokens.insert(session_tokens.end(), tokens.begin(), tokens.end());
|
| 502 |
+
}
|
| 503 |
+
|
| 504 |
// text inference
|
| 505 |
bool done = false;
|
| 506 |
std::string text_to_speak;
|
|
|
|
| 546 |
}
|
| 547 |
}
|
| 548 |
|
| 549 |
+
if (embd.size() > 0 && !path_session.empty()) {
|
| 550 |
+
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
|
| 551 |
+
n_session_consumed = session_tokens.size();
|
| 552 |
+
}
|
| 553 |
+
|
| 554 |
if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
|
| 555 |
fprintf(stderr, "%s : failed to eval\n", __func__);
|
| 556 |
return 1;
|
| 557 |
}
|
| 558 |
}
|
| 559 |
|
|
|
|
| 560 |
|
| 561 |
embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
|
| 562 |
n_past += embd.size();
|
| 563 |
+
|
|
|
|
|
|
|
|
|
|
| 564 |
embd.clear();
|
| 565 |
|
| 566 |
if (done) break;
|