talk-llama : add --session support (#845)

* feat: adding session support

* readme: adding --session info in examples/talk-llama

* llama: adding session fixes

* readme: updating session doc

* talk-llama: update the value of need_to_save_session to true in order to save the session in the subsequent interaction

* talk-llama: adding missing function which updates session_tokens
This commit is contained in:
Luis Herrera 2023-05-01 12:18:10 -05:00 committed by GitHub
parent d375d73b2e
commit be5911a9f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 171 additions and 42 deletions

View File

@ -25,6 +25,20 @@ make talk-llama
- The `-mw` argument specifies the Whisper model that you would like to use. Recommended `base` or `small` for real-time experience - The `-mw` argument specifies the Whisper model that you would like to use. Recommended `base` or `small` for real-time experience
- The `-ml` argument specifies the LLaMA model that you would like to use. Read the instructions in https://github.com/ggerganov/llama.cpp for information about how to obtain a `ggml` compatible LLaMA model - The `-ml` argument specifies the LLaMA model that you would like to use. Read the instructions in https://github.com/ggerganov/llama.cpp for information about how to obtain a `ggml` compatible LLaMA model
## Session
The `talk-llama` tool supports session management to enable more coherent and continuous conversations. By maintaining context from previous interactions, it can better understand and respond to user requests in a more natural way.
To enable session support, use the `--session FILE` command line option when running the program. The `talk-llama` model state will be saved to the specified file after each interaction. If the file does not exist, it will be created. If the file exists, the model state will be loaded from it, allowing you to resume a previous session.
This feature is especially helpful for maintaining context in long conversations or when interacting with the AI assistant across multiple sessions. It ensures that the assistant remembers the previous interactions and can provide more relevant and contextual responses.
Example usage:
```bash
./talk-llama --session ./my-session-file -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
```
## TTS ## TTS
For best experience, this example needs a TTS tool to convert the generated text responses to voice. For best experience, this example needs a TTS tool to convert the generated text responses to voice.

View File

@ -2695,56 +2695,81 @@ std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_te
return ctx->model.tensors_by_name; return ctx->model.tensors_by_name;
} }
size_t llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
// TODO leverage mmap
llama_file file(path_session, "rb"); llama_file file(path_session, "rb");
const uint32_t magic = file.read_u32();
const uint32_t version = file.read_u32();
if (!(magic == 'ggsn' && version == 0)) { // sanity checks
fprintf(stderr, "%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version); {
return 0; const uint32_t magic = file.read_u32();
const uint32_t version = file.read_u32();
if (!(magic == LLAMA_SESSION_MAGIC && version == LLAMA_SESSION_VERSION)) {
fprintf(stderr, "%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
return false;
}
llama_hparams session_hparams;
file.read_raw(&session_hparams, sizeof(llama_hparams));
if (session_hparams != ctx->model.hparams) {
fprintf(stderr, "%s : model hparams didn't match from session file!\n", __func__);
return false;
}
} }
llama_hparams session_hparams; // load the prompt
file.read_raw(&session_hparams, sizeof(llama_hparams)); {
const uint32_t n_token_count = file.read_u32();
// REVIEW if (n_token_count > n_token_capacity) {
if (session_hparams != ctx->model.hparams) { fprintf(stderr, "%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
fprintf(stderr, "%s : model hparams didn't match from session file!\n", __func__); return false;
return 0; }
file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
*n_token_count_out = n_token_count;
} }
const uint32_t n_token_count = file.read_u32(); // restore the context state
LLAMA_ASSERT(n_token_capacity >= n_token_count); {
file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); const size_t n_state_size_cur = file.size - file.tell();
*n_token_count_out = n_token_count; const size_t n_state_size_exp = llama_get_state_size(ctx);
const size_t n_state_size = file.size - file.tell(); if (n_state_size_cur != n_state_size_exp) {
const size_t n_orig_state_size = llama_get_state_size(ctx); fprintf(stderr, "%s : the state size in session file didn't match! expected %zu, got %zu\n", __func__, n_state_size_exp, n_state_size_cur);
if (n_state_size != n_orig_state_size) { return false;
fprintf(stderr, "%s : failed to validate state size\n", __func__); }
std::vector<uint8_t> state_data(n_state_size_cur);
file.read_raw(state_data.data(), n_state_size_cur);
llama_set_state_data(ctx, state_data.data());
} }
std::unique_ptr<uint8_t[]> state_data(new uint8_t[n_state_size]);
file.read_raw(state_data.get(), n_state_size); return true;
return llama_set_state_data(ctx, state_data.get());
} }
size_t llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
// TODO save temp & swap
llama_file file(path_session, "wb"); llama_file file(path_session, "wb");
const size_t n_state_size = llama_get_state_size(ctx); file.write_u32(LLAMA_SESSION_MAGIC);
std::unique_ptr<uint8_t[]> state_data(new uint8_t[n_state_size]); file.write_u32(LLAMA_SESSION_VERSION);
llama_copy_state_data(ctx, state_data.get());
file.write_u32('ggsn'); // magic
file.write_u32(0); // version
file.write_raw(&ctx->model.hparams, sizeof(llama_hparams)); file.write_raw(&ctx->model.hparams, sizeof(llama_hparams));
file.write_u32((uint32_t) n_token_count); // REVIEW // save the prompt
file.write_u32((uint32_t) n_token_count);
file.write_raw(tokens, sizeof(llama_token) * n_token_count); file.write_raw(tokens, sizeof(llama_token) * n_token_count);
file.write_raw(state_data.get(), n_state_size); // save the context state
return n_state_size; // REVIEW {
const size_t n_state_size = llama_get_state_size(ctx);
std::vector<uint8_t> state_data(n_state_size);
llama_copy_state_data(ctx, state_data.data());
file.write_raw(state_data.data(), n_state_size);
}
return true;
} }

View File

@ -19,9 +19,11 @@
# define LLAMA_API # define LLAMA_API
#endif #endif
#define LLAMA_FILE_VERSION 1 #define LLAMA_FILE_VERSION 1
#define LLAMA_FILE_MAGIC 0x67676a74 // 'ggjt' in hex #define LLAMA_FILE_MAGIC 'ggjt'
#define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files #define LLAMA_FILE_MAGIC_UNVERSIONED 'ggml'
#define LLAMA_SESSION_MAGIC 'ggsn'
#define LLAMA_SESSION_VERSION 0
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
@ -138,9 +140,8 @@ extern "C" {
LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src); LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src);
// Save/load session file // Save/load session file
LLAMA_API size_t llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out); LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
LLAMA_API size_t llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count); LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
// Run the llama inference to obtain the logits and probabilities for the next token. // Run the llama inference to obtain the logits and probabilities for the next token.
// tokens + n_tokens is the provided batch of new tokens to process // tokens + n_tokens is the provided batch of new tokens to process
// n_past is the number of tokens to use from previous eval calls // n_past is the number of tokens to use from previous eval calls

View File

@ -52,6 +52,7 @@ struct whisper_params {
std::string speak = "./examples/talk-llama/speak.sh"; std::string speak = "./examples/talk-llama/speak.sh";
std::string prompt = ""; std::string prompt = "";
std::string fname_out; std::string fname_out;
std::string path_session = ""; // path to file for saving/loading model eval state
}; };
void whisper_print_usage(int argc, char ** argv, const whisper_params & params); void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -78,6 +79,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "--verbose-prompt") { params.verbose_prompt = true; } else if (arg == "--verbose-prompt") { params.verbose_prompt = true; }
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; } else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
else if (arg == "--session") { params.path_session = argv[++i];}
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; } else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; } else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
@ -124,6 +126,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " --n-parts-llama N [%-7d] num parts in llama model file\n", params.n_parts_llama); fprintf(stderr, " --n-parts-llama N [%-7d] num parts in llama model file\n", params.n_parts_llama);
fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str()); fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", ""); fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", "");
fprintf(stderr, " --session FNAME file to cache model state in (may be large!) (default: none)\n");
fprintf(stderr, " --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false"); fprintf(stderr, " --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false");
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, "\n"); fprintf(stderr, "\n");
@ -348,6 +351,57 @@ int main(int argc, char ** argv) {
fflush(stdout); fflush(stdout);
} }
// init session
std::string path_session = params.path_session;
std::vector<llama_token> session_tokens;
if (!path_session.empty()) {
fprintf(stderr, "%s: attempting to load saved session from %s\n", __func__, path_session.c_str());
// fopen to check for existing session
FILE * fp = std::fopen(path_session.c_str(), "rb");
if (fp != NULL) {
std::fclose(fp);
session_tokens.resize(lparams.n_ctx);
size_t n_token_count_out = 0;
if (!llama_load_session_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
fprintf(stderr, "%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
return 1;
}
session_tokens.resize(n_token_count_out);
fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
} else {
fprintf(stderr, "%s: session file does not exist, will create\n", __func__);
}
}
// debug message about similarity of saved session, if applicable
size_t n_matching_session_tokens = 0;
if (session_tokens.size()) {
for (llama_token id : session_tokens) {
if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) {
break;
}
n_matching_session_tokens++;
}
if (n_matching_session_tokens >= embd_inp.size()) {
fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
} else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
fprintf(stderr, "%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
__func__, n_matching_session_tokens, embd_inp.size());
} else {
fprintf(stderr, "%s: session file matches %zu / %zu tokens of prompt\n",
__func__, n_matching_session_tokens, embd_inp.size());
}
}
// HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
// if we loaded a session with at least 75% similarity. It's currently just used to speed up the
// initial prompt so it doesn't need to be an exact match.
bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);
printf("%s : done! start speaking in the microphone\n", __func__); printf("%s : done! start speaking in the microphone\n", __func__);
printf("\n"); printf("\n");
printf("%s%s", params.person.c_str(), chat_symb.c_str()); printf("%s%s", params.person.c_str(), chat_symb.c_str());
@ -363,6 +417,7 @@ int main(int argc, char ** argv) {
int n_past = n_keep; int n_past = n_keep;
int n_prev = 64; // TODO arg int n_prev = 64; // TODO arg
int n_session_consumed = 0;
std::vector<llama_token> embd; std::vector<llama_token> embd;
@ -450,7 +505,8 @@ int main(int argc, char ** argv) {
// insert n_left/2 tokens at the start of embd from last_n_tokens // insert n_left/2 tokens at the start of embd from last_n_tokens
embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end()); embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());
// stop saving session if we run out of context
path_session = "";
//printf("\n---\n"); //printf("\n---\n");
//printf("resetting: '"); //printf("resetting: '");
//for (int i = 0; i < (int) embd.size(); i++) { //for (int i = 0; i < (int) embd.size(); i++) {
@ -460,6 +516,29 @@ int main(int argc, char ** argv) {
//printf("\n---\n"); //printf("\n---\n");
} }
// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
// REVIEW
if (n_session_consumed < (int) session_tokens.size()) {
size_t i = 0;
for ( ; i < embd.size(); i++) {
if (embd[i] != session_tokens[n_session_consumed]) {
session_tokens.resize(n_session_consumed);
break;
}
n_past++;
n_session_consumed++;
if (n_session_consumed >= (int) session_tokens.size()) {
i++;
break;
}
}
if (i > 0) {
embd.erase(embd.begin(), embd.begin() + i);
}
}
if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) { if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return 1; return 1;
@ -470,6 +549,10 @@ int main(int argc, char ** argv) {
embd_inp.insert(embd_inp.end(), embd.begin(), embd.end()); embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
n_past += embd.size(); n_past += embd.size();
if (embd.size() > 0 && !path_session.empty()) {
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
n_session_consumed = session_tokens.size();
}
embd.clear(); embd.clear();
if (done) break; if (done) break;
@ -483,6 +566,11 @@ int main(int argc, char ** argv) {
const int repeat_last_n = 256; const int repeat_last_n = 256;
if (!path_session.empty() && need_to_save_session) {
need_to_save_session = false;
llama_save_session_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.size());
}
llama_token id = 0; llama_token id = 0;
{ {
@ -542,6 +630,7 @@ int main(int argc, char ** argv) {
done = true; done = true;
text_to_speak = ::replace(text_to_speak, antiprompt, ""); text_to_speak = ::replace(text_to_speak, antiprompt, "");
fflush(stdout); fflush(stdout);
need_to_save_session = true;
break; break;
} }
} }