mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-04-18 16:28:33 +02:00
talk-llama : use llama_decode instead of llama_eval
This commit is contained in:
parent
8e409d1113
commit
2f5a5a66dd
@ -391,6 +391,8 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
|
prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
|
||||||
|
|
||||||
|
llama_batch batch = llama_batch_init(llama_n_ctx(ctx_llama), 0, 1);
|
||||||
|
|
||||||
// init session
|
// init session
|
||||||
std::string path_session = params.path_session;
|
std::string path_session = params.path_session;
|
||||||
std::vector<llama_token> session_tokens;
|
std::vector<llama_token> session_tokens;
|
||||||
@ -426,8 +428,21 @@ int main(int argc, char ** argv) {
|
|||||||
printf("\n");
|
printf("\n");
|
||||||
printf("%s : initializing - please wait ...\n", __func__);
|
printf("%s : initializing - please wait ...\n", __func__);
|
||||||
|
|
||||||
if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0)) {
|
// prepare batch
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
{
|
||||||
|
batch.n_tokens = embd_inp.size();
|
||||||
|
|
||||||
|
for (int i = 0; i < batch.n_tokens; i++) {
|
||||||
|
batch.token[i] = embd_inp[i];
|
||||||
|
batch.pos[i] = i;
|
||||||
|
batch.n_seq_id[i] = 1;
|
||||||
|
batch.seq_id[i][0] = 0;
|
||||||
|
batch.logits[i] = i == batch.n_tokens - 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (llama_decode(ctx_llama, batch)) {
|
||||||
|
fprintf(stderr, "%s : failed to decode\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -647,8 +662,21 @@ int main(int argc, char ** argv) {
|
|||||||
n_session_consumed = session_tokens.size();
|
n_session_consumed = session_tokens.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past)) {
|
// prepare batch
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
{
|
||||||
|
batch.n_tokens = embd.size();
|
||||||
|
|
||||||
|
for (int i = 0; i < batch.n_tokens; i++) {
|
||||||
|
batch.token[i] = embd[i];
|
||||||
|
batch.pos[i] = n_past + i;
|
||||||
|
batch.n_seq_id[i] = 1;
|
||||||
|
batch.seq_id[i][0] = 0;
|
||||||
|
batch.logits[i] = i == batch.n_tokens - 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (llama_decode(ctx_llama, batch)) {
|
||||||
|
fprintf(stderr, "%s : failed to decode\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user