From 2f5a5a66dd22587ad9ff3a4203d2c5e64194e346 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 8 Mar 2024 12:04:43 +0200 Subject: [PATCH] talk-llama : use llama_decode instead of llama_eval --- examples/talk-llama/talk-llama.cpp | 36 ++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp index ddc9e765..4e1c1755 100644 --- a/examples/talk-llama/talk-llama.cpp +++ b/examples/talk-llama/talk-llama.cpp @@ -391,6 +391,8 @@ int main(int argc, char ** argv) { prompt_llama = ::replace(prompt_llama, "{4}", chat_symb); + llama_batch batch = llama_batch_init(llama_n_ctx(ctx_llama), 0, 1); + // init session std::string path_session = params.path_session; std::vector session_tokens; @@ -426,8 +428,21 @@ int main(int argc, char ** argv) { printf("\n"); printf("%s : initializing - please wait ...\n", __func__); - if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0)) { - fprintf(stderr, "%s : failed to eval\n", __func__); + // prepare batch + { + batch.n_tokens = embd_inp.size(); + + for (int i = 0; i < batch.n_tokens; i++) { + batch.token[i] = embd_inp[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = i == batch.n_tokens - 1; + } + } + + if (llama_decode(ctx_llama, batch)) { + fprintf(stderr, "%s : failed to decode\n", __func__); return 1; } @@ -647,8 +662,21 @@ int main(int argc, char ** argv) { n_session_consumed = session_tokens.size(); } - if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past)) { - fprintf(stderr, "%s : failed to eval\n", __func__); + // prepare batch + { + batch.n_tokens = embd.size(); + + for (int i = 0; i < batch.n_tokens; i++) { + batch.token[i] = embd[i]; + batch.pos[i] = n_past + i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = i == batch.n_tokens - 1; + } + } + + if (llama_decode(ctx_llama, batch)) { + fprintf(stderr, "%s : failed to decode\n", __func__); return 1; } }