Luis Herrera committed on
Commit
6eca3b7
·
unverified ·
1 Parent(s): 1251039

talk-llama : fix session prompt load (#854)

Browse files
Files changed (1) hide show
  1. examples/talk-llama/talk-llama.cpp +32 -24
examples/talk-llama/talk-llama.cpp CHANGED
@@ -333,27 +333,10 @@ int main(int argc, char ** argv) {
333
 
334
  prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
335
 
336
- // evaluate the initial prompt
337
-
338
- auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
339
-
340
- printf("\n");
341
- printf("%s : initializing - please wait ...\n", __func__);
342
-
343
- if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
344
- fprintf(stderr, "%s : failed to eval\n", __func__);
345
- return 1;
346
- }
347
-
348
- if (params.verbose_prompt) {
349
- fprintf(stdout, "\n");
350
- fprintf(stdout, "%s", prompt_llama.c_str());
351
- fflush(stdout);
352
- }
353
-
354
  // init session
355
  std::string path_session = params.path_session;
356
  std::vector<llama_token> session_tokens;
 
357
 
358
  if (!path_session.empty()) {
359
  fprintf(stderr, "%s: attempting to load saved session from %s\n", __func__, path_session.c_str());
@@ -370,6 +353,9 @@ int main(int argc, char ** argv) {
370
  return 1;
371
  }
372
  session_tokens.resize(n_token_count_out);
 
 
 
373
 
374
  fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
375
  } else {
@@ -377,6 +363,22 @@ int main(int argc, char ** argv) {
377
  }
378
  }
379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
  // debug message about similarity of saved session, if applicable
381
  size_t n_matching_session_tokens = 0;
382
  if (session_tokens.size()) {
@@ -417,7 +419,7 @@ int main(int argc, char ** argv) {
417
 
418
  int n_past = n_keep;
419
  int n_prev = 64; // TODO arg
420
- int n_session_consumed = 0;
421
 
422
  std::vector<llama_token> embd;
423
 
@@ -494,6 +496,11 @@ int main(int argc, char ** argv) {
494
 
495
  embd = ::llama_tokenize(ctx_llama, text_heard, false);
496
 
 
 
 
 
 
497
  // text inference
498
  bool done = false;
499
  std::string text_to_speak;
@@ -539,20 +546,21 @@ int main(int argc, char ** argv) {
539
  }
540
  }
541
 
 
 
 
 
 
542
  if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
543
  fprintf(stderr, "%s : failed to eval\n", __func__);
544
  return 1;
545
  }
546
  }
547
 
548
- //printf("n_iter = %d, n_past = %d, n_ctx = %d, n_keep = %d, n_prev = %d, embd.size() = %d\n", n_iter, n_past, n_ctx, n_keep, n_prev, (int) embd.size());
549
 
550
  embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
551
  n_past += embd.size();
552
- if (embd.size() > 0 && !path_session.empty()) {
553
- session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
554
- n_session_consumed = session_tokens.size();
555
- }
556
  embd.clear();
557
 
558
  if (done) break;
 
333
 
334
  prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  // init session
337
  std::string path_session = params.path_session;
338
  std::vector<llama_token> session_tokens;
339
+ auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
340
 
341
  if (!path_session.empty()) {
342
  fprintf(stderr, "%s: attempting to load saved session from %s\n", __func__, path_session.c_str());
 
353
  return 1;
354
  }
355
  session_tokens.resize(n_token_count_out);
356
+ for (size_t i = 0; i < session_tokens.size(); i++) {
357
+ embd_inp[i] = session_tokens[i];
358
+ }
359
 
360
  fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
361
  } else {
 
363
  }
364
  }
365
 
366
+ // evaluate the initial prompt
367
+
368
+ printf("\n");
369
+ printf("%s : initializing - please wait ...\n", __func__);
370
+
371
+ if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
372
+ fprintf(stderr, "%s : failed to eval\n", __func__);
373
+ return 1;
374
+ }
375
+
376
+ if (params.verbose_prompt) {
377
+ fprintf(stdout, "\n");
378
+ fprintf(stdout, "%s", prompt_llama.c_str());
379
+ fflush(stdout);
380
+ }
381
+
382
  // debug message about similarity of saved session, if applicable
383
  size_t n_matching_session_tokens = 0;
384
  if (session_tokens.size()) {
 
419
 
420
  int n_past = n_keep;
421
  int n_prev = 64; // TODO arg
422
+ int n_session_consumed = !path_session.empty() && session_tokens.size() > 0 ? session_tokens.size() : 0;
423
 
424
  std::vector<llama_token> embd;
425
 
 
496
 
497
  embd = ::llama_tokenize(ctx_llama, text_heard, false);
498
 
499
+ // Append the new input tokens to the session_tokens vector
500
+ if (!path_session.empty()) {
501
+ session_tokens.insert(session_tokens.end(), tokens.begin(), tokens.end());
502
+ }
503
+
504
  // text inference
505
  bool done = false;
506
  std::string text_to_speak;
 
546
  }
547
  }
548
 
549
+ if (embd.size() > 0 && !path_session.empty()) {
550
+ session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
551
+ n_session_consumed = session_tokens.size();
552
+ }
553
+
554
  if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
555
  fprintf(stderr, "%s : failed to eval\n", __func__);
556
  return 1;
557
  }
558
  }
559
 
 
560
 
561
  embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
562
  n_past += embd.size();
563
+
 
 
 
564
  embd.clear();
565
 
566
  if (done) break;