talk-llama : sync llama.cpp

This commit is contained in:
Georgi Gerganov
2025-01-14 09:53:50 +02:00
parent 19d95f9f9a
commit 99b011a9f5
26 changed files with 5788 additions and 5093 deletions

View File

@ -1,5 +1,8 @@
#include "llama-context.h"
#include "llama-impl.h"
#include "llama-mmap.h"
#include <cassert>
#include <cmath>
#include <cstring>
@ -467,11 +470,12 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams;
const auto & vocab = lctx.model.vocab;
const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
const auto n_batch = cparams.n_batch;
const auto n_vocab = hparams.n_vocab;
const auto n_vocab = vocab.n_tokens();
const auto n_embd = hparams.n_embd;
// TODO: use a per-batch flag for logits presence instead
@ -504,7 +508,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
auto * buft = ggml_backend_cpu_buffer_type();
// try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
auto * output_dev = lctx.model.dev_output.dev;
auto * output_dev = lctx.model.dev_output();
auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
if (output_dev_host_buft) {
buft = output_dev_host_buft;
@ -538,7 +542,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
void llama_output_reorder(struct llama_context & ctx) {
std::vector<size_t> & out_ids = ctx.sbatch.out_ids;
if (!out_ids.empty()) {
const uint32_t n_vocab = ctx.model.hparams.n_vocab;
const uint32_t n_vocab = ctx.model.vocab.n_tokens();
const uint32_t n_embd = ctx.model.hparams.n_embd;
const int32_t n_outputs = ctx.n_outputs;
@ -722,7 +726,7 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
}
return ctx->logits + j*ctx->model.hparams.n_vocab;
return ctx->logits + j*ctx->model.vocab.n_tokens();
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG
@ -882,7 +886,7 @@ struct llama_data_write {
}
void write_logits(const struct llama_context * ctx) {
const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab);
const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.vocab.n_tokens());
write(&logits_size, sizeof(logits_size));