talk-llama : sync llama.cpp
@@ -1,5 +1,8 @@
 #include "llama-context.h"
 
+#include "llama-impl.h"
+#include "llama-mmap.h"
+
 #include <cassert>
 #include <cmath>
 #include <cstring>
@@ -467,11 +470,12 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
 size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
     const auto & cparams = lctx.cparams;
     const auto & hparams = lctx.model.hparams;
+    const auto & vocab   = lctx.model.vocab;
 
     const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
 
     const auto n_batch = cparams.n_batch;
-    const auto n_vocab = hparams.n_vocab;
+    const auto n_vocab = vocab.n_tokens();
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
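Note: the recurring change in this sync is that the vocabulary size is now read from the llama_vocab object (vocab.n_tokens()) rather than from the model hyperparameters (hparams.n_vocab). A minimal sketch of the public-API analogue for talk-llama callers, assuming the llama.h shipped with the same sync (the function names are an assumption, not part of this diff):

    // sketch: vocab size via the vocab object instead of the model hparams
    #include "llama.h"

    int32_t vocab_size(const struct llama_model * model) {
        const struct llama_vocab * vocab = llama_model_get_vocab(model);
        return llama_vocab_n_tokens(vocab); // previously: llama_n_vocab(model)
    }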
@@ -504,7 +508,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
 
         auto * buft = ggml_backend_cpu_buffer_type();
         // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
-        auto * output_dev = lctx.model.dev_output.dev;
+        auto * output_dev = lctx.model.dev_output();
         auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
         if (output_dev_host_buft) {
             buft = output_dev_host_buft;
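Note: this hunk's logic prefers the pinned host buffer type of the device that holds the output tensor, because device-to-host copies of the logits are faster from pinned memory; pageable CPU memory is only the fallback. A standalone sketch of that selection, using only the ggml-backend calls already visible above:

    // sketch of the buffer-type selection above (ggml-backend API)
    #include "ggml-backend.h"

    static ggml_backend_buffer_type_t pick_output_buft(ggml_backend_dev_t output_dev) {
        // pinned host memory of the output device, if it provides one
        ggml_backend_buffer_type_t host_buft =
            output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
        // otherwise fall back to pageable CPU memory
        return host_buft ? host_buft : ggml_backend_cpu_buffer_type();
    }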
@@ -538,7 +542,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
 void llama_output_reorder(struct llama_context & ctx) {
     std::vector<size_t> & out_ids = ctx.sbatch.out_ids;
     if (!out_ids.empty()) {
-        const uint32_t n_vocab = ctx.model.hparams.n_vocab;
+        const uint32_t n_vocab = ctx.model.vocab.n_tokens();
         const uint32_t n_embd  = ctx.model.hparams.n_embd;
 
         const int32_t n_outputs = ctx.n_outputs;
@@ -722,7 +726,7 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
             throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
         }
 
-        return ctx->logits + j*ctx->model.hparams.n_vocab;
+        return ctx->logits + j*ctx->model.vocab.n_tokens();
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
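Note: llama_get_logits_ith() indexes a flat, row-major buffer of n_outputs x n_vocab floats, so row j starts at logits + j*n_vocab; the refactor only changes where n_vocab comes from. A hedged usage sketch (ctx, i, and n_vocab assumed to be in scope):

    // sketch: reading the logits row for output i after llama_decode()
    const float * row = llama_get_logits_ith(ctx, i); // row i, n_vocab floats
    for (int32_t tok = 0; tok < n_vocab; ++tok) {
        // row[tok] is the logit of token id tok for output row i
    }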
@@ -882,7 +886,7 @@ struct llama_data_write {
     }
 
     void write_logits(const struct llama_context * ctx) {
-        const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab);
+        const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.vocab.n_tokens());
 
         write(&logits_size, sizeof(logits_size));