mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-11-07 08:34:37 +01:00
talk-llama : fix session prompt load (#854)
This commit is contained in:
parent
b806420873
commit
0bf680fea2
@ -333,27 +333,10 @@ int main(int argc, char ** argv) {
|
||||
|
||||
prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
|
||||
|
||||
// evaluate the initial prompt
|
||||
|
||||
auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
|
||||
|
||||
printf("\n");
|
||||
printf("%s : initializing - please wait ...\n", __func__);
|
||||
|
||||
if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.verbose_prompt) {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s", prompt_llama.c_str());
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
// init session
|
||||
std::string path_session = params.path_session;
|
||||
std::vector<llama_token> session_tokens;
|
||||
auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
|
||||
|
||||
if (!path_session.empty()) {
|
||||
fprintf(stderr, "%s: attempting to load saved session from %s\n", __func__, path_session.c_str());
|
||||
@ -370,6 +353,9 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
session_tokens.resize(n_token_count_out);
|
||||
for (size_t i = 0; i < session_tokens.size(); i++) {
|
||||
embd_inp[i] = session_tokens[i];
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
|
||||
} else {
|
||||
@ -377,6 +363,22 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
// evaluate the initial prompt
|
||||
|
||||
printf("\n");
|
||||
printf("%s : initializing - please wait ...\n", __func__);
|
||||
|
||||
if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.verbose_prompt) {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s", prompt_llama.c_str());
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
// debug message about similarity of saved session, if applicable
|
||||
size_t n_matching_session_tokens = 0;
|
||||
if (session_tokens.size()) {
|
||||
@ -417,7 +419,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
int n_past = n_keep;
|
||||
int n_prev = 64; // TODO arg
|
||||
int n_session_consumed = 0;
|
||||
int n_session_consumed = !path_session.empty() && session_tokens.size() > 0 ? session_tokens.size() : 0;
|
||||
|
||||
std::vector<llama_token> embd;
|
||||
|
||||
@ -494,6 +496,11 @@ int main(int argc, char ** argv) {
|
||||
|
||||
embd = ::llama_tokenize(ctx_llama, text_heard, false);
|
||||
|
||||
// Append the new input tokens to the session_tokens vector
|
||||
if (!path_session.empty()) {
|
||||
session_tokens.insert(session_tokens.end(), tokens.begin(), tokens.end());
|
||||
}
|
||||
|
||||
// text inference
|
||||
bool done = false;
|
||||
std::string text_to_speak;
|
||||
@ -539,20 +546,21 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
if (embd.size() > 0 && !path_session.empty()) {
|
||||
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
|
||||
n_session_consumed = session_tokens.size();
|
||||
}
|
||||
|
||||
if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
//printf("n_iter = %d, n_past = %d, n_ctx = %d, n_keep = %d, n_prev = %d, embd.size() = %d\n", n_iter, n_past, n_ctx, n_keep, n_prev, (int) embd.size());
|
||||
|
||||
embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
|
||||
n_past += embd.size();
|
||||
if (embd.size() > 0 && !path_session.empty()) {
|
||||
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
|
||||
n_session_consumed = session_tokens.size();
|
||||
}
|
||||
|
||||
embd.clear();
|
||||
|
||||
if (done) break;
|
||||
|
Loading…
Reference in New Issue
Block a user