#include "llama-context.h"
|
|
|
|
#include "llama-impl.h"
|
|
#include "llama-io.h"
|
|
#include "llama-mmap.h"
|
|
#include "llama-model.h"
|
|
#include "llama-kv-cache.h"
|
|
|
|
#include <cassert>
|
|
#include <cstring>
|
|
#include <stdexcept>
|
|
#include <cinttypes>
|
|
#include <cmath>
|
|
|
|
//
|
|
// llama_context
|
|
//
|
|
|
|
llama_context::llama_context(
        const llama_model & model,
        llama_context_params params) :
    model(model) {
    LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);

    t_start_us = model.t_start_us;
    t_load_us  = model.t_load_us;

    const auto & hparams = model.hparams;

    cparams.n_seq_max        = std::max(1u, params.n_seq_max);
    cparams.n_threads        = params.n_threads;
    cparams.n_threads_batch  = params.n_threads_batch;
    cparams.yarn_ext_factor  = params.yarn_ext_factor;
    cparams.yarn_attn_factor = params.yarn_attn_factor;
    cparams.yarn_beta_fast   = params.yarn_beta_fast;
    cparams.yarn_beta_slow   = params.yarn_beta_slow;
    cparams.defrag_thold     = params.defrag_thold;
    cparams.embeddings       = params.embeddings;
    cparams.offload_kqv      = params.offload_kqv;
    cparams.flash_attn       = params.flash_attn;
    cparams.no_perf          = params.no_perf;
    cparams.pooling_type     = params.pooling_type;
    cparams.warmup           = false;

    cparams.n_ctx           = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
    cparams.rope_freq_base  = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
    cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;

    cparams.n_ctx_orig_yarn = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
                              hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
                                                             hparams.n_ctx_train;

    cparams.cb_eval           = params.cb_eval;
    cparams.cb_eval_user_data = params.cb_eval_user_data;

    auto rope_scaling_type = params.rope_scaling_type;
    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
        rope_scaling_type = hparams.rope_scaling_type_train;
    }

    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
        cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
    }

    if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
        cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
    }

    cparams.yarn_attn_factor *= hparams.rope_attn_factor;

    if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
        if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
            cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
        } else {
            cparams.pooling_type = hparams.pooling_type;
        }
    }

    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
        cparams.causal_attn = hparams.causal_attn;
    } else {
        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
    }

    // with causal attention, the batch size is limited by the context size
    cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;

    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
    // ref: https://github.com/ggerganov/llama.cpp/pull/5021
    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
    if (cparams.n_batch < GGML_KQ_MASK_PAD) {
        LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
        cparams.n_batch = GGML_KQ_MASK_PAD;
    }

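    // the micro-batch size cannot exceed the batch size; a value of 0 means "same as n_batch"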
    cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);

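    // each sequence gets an equal share of the total context window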
    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;

    LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
    LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
    LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
    LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
    LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
    LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
    LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
    LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);

    if (n_ctx_per_seq < hparams.n_ctx_train) {
        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
                __func__, n_ctx_per_seq, hparams.n_ctx_train);
    }

    if (n_ctx_per_seq > hparams.n_ctx_train) {
        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
                __func__, n_ctx_per_seq, hparams.n_ctx_train);
    }

    logits_all = params.logits_all;

    if (!hparams.vocab_only) {
        // GPU backends
        for (auto * dev : model.devices) {
            ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
            if (backend == nullptr) {
                throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev)));
            }
            backends.emplace_back(backend);
        }

        // add ACCEL backends (such as BLAS)
        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
                ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
                if (backend == nullptr) {
                    throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev)));
                }
                backends.emplace_back(backend);
            }
        }

        // add CPU backend
        backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
        if (backend_cpu == nullptr) {
            throw std::runtime_error("failed to initialize CPU backend");
        }
        backends.emplace_back(backend_cpu);

        // create a list of the set_n_threads functions in the backends
        for (auto & backend : backends) {
            ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
            ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
            if (reg) {
                auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
                if (ggml_backend_set_n_threads_fn) {
                    set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn);
                }
            }
        }

        llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data);

        // graph outputs buffer
        {
            // resized during inference when a batch uses more outputs
            if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) {
                throw std::runtime_error("failed to reserve initial output buffer");
            }

            LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__,
                    ggml_backend_buffer_name (buf_output.get()),
                    ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0);
        }
    }

    // init the memory module
    // TODO: for now, always create a unified KV cache
    if (!hparams.vocab_only) {
        kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));

        LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);

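        // round the context size up to the padding required by the KV cache implementation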
        cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));

        LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);

        uint32_t kv_size = cparams.n_ctx;
        ggml_type type_k = params.type_k;
        ggml_type type_v = params.type_v;

        if (llama_model_is_recurrent(&model)) {
            // Mamba needs at least as many KV cells as there are sequences kept at any time
            kv_size = std::max((uint32_t) 1, params.n_seq_max);
            // it's probably best to keep as much precision as possible for the states
            type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
            type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
        }

        GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
        GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);

        if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
            throw std::runtime_error("failed to initialize self-attention cache");
        }

        {
            const size_t memory_size_k = kv_self->size_k_bytes();
            const size_t memory_size_v = kv_self->size_v_bytes();

            LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
                    (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
                    ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
                    ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
        }
    }

    // init backends
    if (!hparams.vocab_only) {
        LLAMA_LOG_DEBUG("%s: enumerating backends\n", __func__);

        backend_buft.clear();
        backend_ptrs.clear();

        for (auto & backend : backends) {
            auto * buft = ggml_backend_get_default_buffer_type(backend.get());
            auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));

            if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) {
                // use the host buffer of the first device CPU for faster transfer of the intermediate state
                auto * dev = model.devices[0];
                auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
                if (host_buft) {
                    buft = host_buft;
                }
            }

            backend_buft.push_back(buft);
            backend_ptrs.push_back(backend.get());
        }

        LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());

        const size_t max_nodes = this->graph_max_nodes();

        LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);

        // buffer used to store the computation graph and the tensor meta data
        buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));

        // TODO: move these checks to ggml_backend_sched
        // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
        bool pipeline_parallel =
            model.n_devices() > 1 &&
            model.params.n_gpu_layers > (int) model.hparams.n_layer &&
            model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
            cparams.offload_kqv &&
            !model.has_tensor_overrides();

        // pipeline parallelism requires support for async compute and events in all devices
        if (pipeline_parallel) {
            for (auto & backend : backends) {
                auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
                if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
                    // ignore CPU backend
                    continue;
                }
                auto * dev = ggml_backend_get_device(backend.get());
                ggml_backend_dev_props props;
                ggml_backend_dev_get_props(dev, &props);
                if (!props.caps.async || !props.caps.events) {
                    // device does not support async compute or events
                    pipeline_parallel = false;
                    break;
                }
            }
        }

        sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));

        if (pipeline_parallel) {
            LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
        }
    }

    // reserve worst-case graph
    if (!hparams.vocab_only) {
        const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph

        // restore later
        // TODO: something cleaner
        const auto n_outputs_save = n_outputs;

        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);

        int n_splits_pp = -1;
        int n_nodes_pp  = -1;

        int n_splits_tg = -1;
        int n_nodes_tg  = -1;

        // simulate full KV cache
        kv_self->n = kv_self->size;

        cross.v_embd.clear();

        // reserve pp graph first so that buffers are only allocated once
        {
            llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

            // max number of outputs
            n_outputs = ubatch_pp.n_tokens;

            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);

            auto * gf = graph_init();
            graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);

            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
                throw std::runtime_error("failed to allocate compute pp buffers");
            }

            n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
            n_nodes_pp  = ggml_graph_n_nodes(gf);
        }

        // reserve with tg graph to get the number of splits and nodes
        {
            llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

            n_outputs = ubatch_tg.n_tokens;

            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);

            auto * gf = graph_init();
            graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);

            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
                throw std::runtime_error("failed to allocate compute tg buffers");
            }

            n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
            n_nodes_tg  = ggml_graph_n_nodes(gf);
        }

        // reserve again with pp graph to avoid ggml-alloc reallocations during inference
        {
            llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

            n_outputs = ubatch_pp.n_tokens;

            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);

            auto * gf = graph_init();
            graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);

            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
                throw std::runtime_error("failed to allocate compute pp buffers");
            }
        }

        n_outputs = n_outputs_save;

        for (size_t i = 0; i < backend_ptrs.size(); ++i) {
            ggml_backend_t backend = backend_ptrs[i];
            ggml_backend_buffer_type_t buft = backend_buft[i];
            size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend);
            if (size > 1) {
                LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
                        ggml_backend_buft_name(buft),
                        size / 1024.0 / 1024.0);
            }
        }

        if (n_nodes_pp == n_nodes_tg) {
            LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
        } else {
            LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
        }

        if (n_splits_pp == n_splits_tg) {
            LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
        } else {
            LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
        }
    }
}

llama_context::~llama_context() = default;

void llama_context::synchronize() {
    ggml_backend_sched_synchronize(sched.get());

    // FIXME: if multiple single tokens are evaluated without a synchronization,
    // the stats will be added to the prompt evaluation stats
    // this should only happen when using batch size 1 to evaluate a batch

    // add the evaluation to the stats
    if (n_queued_tokens == 1) {
        if (!cparams.no_perf) {
            t_eval_us += ggml_time_us() - t_compute_start_us;
        }
        n_eval++;
    } else if (n_queued_tokens > 1) {
        if (!cparams.no_perf) {
            t_p_eval_us += ggml_time_us() - t_compute_start_us;
        }
        n_p_eval += n_queued_tokens;
    }

    // get a more accurate load time, upon first eval
    if (n_queued_tokens > 0 && !has_evaluated_once) {
        t_load_us = ggml_time_us() - t_start_us;
        has_evaluated_once = true;
    }

    n_queued_tokens = 0;
    t_compute_start_us = 0;
}

const llama_model & llama_context::get_model() const {
    return model;
}

uint32_t llama_context::n_ctx() const {
    return cparams.n_ctx;
}

uint32_t llama_context::n_ctx_per_seq() const {
    return cparams.n_ctx / cparams.n_seq_max;
}

uint32_t llama_context::n_batch() const {
    return cparams.n_batch;
}

uint32_t llama_context::n_ubatch() const {
    return cparams.n_ubatch;
}

uint32_t llama_context::n_seq_max() const {
    return cparams.n_seq_max;
}

uint32_t llama_context::n_threads() const {
    return cparams.n_threads;
}

uint32_t llama_context::n_threads_batch() const {
    return cparams.n_threads_batch;
}

llama_kv_cache * llama_context::get_kv_self() {
    return kv_self.get();
}

const llama_kv_cache * llama_context::get_kv_self() const {
    return kv_self.get();
}

ggml_tensor * llama_context::build_rope_shift(
        ggml_context * ctx0,
        ggml_tensor * cur,
        ggml_tensor * shift,
        ggml_tensor * factors,
        float freq_base,
        float freq_scale,
        ggml_backend_buffer * bbuf) const {
    const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;

    const auto & yarn_ext_factor = cparams.yarn_ext_factor;
    const auto & yarn_beta_fast  = cparams.yarn_beta_fast;
    const auto & yarn_beta_slow  = cparams.yarn_beta_slow;

    const auto & hparams = model.hparams;

    const auto & n_rot     = hparams.n_rot;
    const auto & rope_type = hparams.rope_type;

    // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;

    ggml_tensor * tmp;

    if (ggml_is_quantized(cur->type)) {
        // dequantize to f32 -> RoPE -> quantize back
        tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);

        if (bbuf) {
            for (const auto & backend : backends) {
                // Figure out which backend KV cache belongs to
                if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) {
                    ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
                    break;
                }
            }
        }

        tmp = ggml_rope_ext_inplace(ctx0, tmp,
                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);

        tmp = ggml_cpy(ctx0, tmp, cur);
    } else {
        // we rotate only the first n_rot dimensions
        tmp = ggml_rope_ext_inplace(ctx0, cur,
                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
    }

    return tmp;
}

class llm_graph_input_k_shift : public llm_graph_input_i {
public:
    llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
    virtual ~llm_graph_input_k_shift() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * k_shift; // I32 [kv_size]

    const llama_kv_cache_unified * kv_self;
};

void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
    GGML_UNUSED(ubatch);

    if (k_shift) {
        assert(ggml_backend_buffer_is_host(k_shift->buffer));

        int32_t * data = (int32_t *) k_shift->data;

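        // fill with the accumulated position delta of each cell (0 when the cell has not been shifted)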
        for (uint32_t i = 0; i < kv_self->size; ++i) {
            data[i] = kv_self->cells[i].delta;
        }
    }
}

llm_graph_result_ptr llama_context::build_kv_self_shift(
        ggml_context * ctx0,
        ggml_cgraph * gf) const {
    auto res = std::make_unique<llm_graph_result>();

    const auto & hparams = model.hparams;

    const auto & n_layer = hparams.n_layer;

    const auto & n_embd_head_k = hparams.n_embd_head_k;
    //const auto & n_embd_head_v = hparams.n_embd_head_v;

    //GGML_ASSERT(kv_self->size == n_ctx);

    auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());

    inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
    ggml_set_input(inp->k_shift);

    for (uint32_t il = 0; il < n_layer; ++il) {
        const int64_t n_head_kv    = hparams.n_head_kv(il);
        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

        const bool is_swa = hparams.is_swa(il);

        // note: the swa rope params could become part of the cparams in the future
        // if we decide to make them configurable, like the non-sliding ones
        const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
        const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;

        ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);

        ggml_tensor * k =
            ggml_view_3d(ctx0, kv_self->k_l[il],
                n_embd_head_k, n_head_kv, kv_self->size,
                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
                0);

        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);

        ggml_build_forward_expand(gf, cur);
    }

    res->add_input(std::move(inp));

    return res;
}

llm_graph_result_ptr llama_context::build_kv_self_defrag(
        ggml_context * ctx0,
        ggml_cgraph * gf) const {
    auto res = std::make_unique<llm_graph_result>();

    const auto & hparams = model.hparams;

    const auto & ids = kv_self->defrag_info.ids;

#if 0
    // CPU defrag
    //
    // TODO: optimizations are possible:
    //       - multiple threads
    //       - avoid copying to the host memory when already there
    //
    // likely not worth the effort, as we have ggml_graph based defrag
    //

    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();

    const uint32_t kv_size = size;

    std::vector<uint8_t> buf_k;
    std::vector<uint8_t> buf_v;

    for (uint32_t il = 0; il < n_layer; ++il) {
        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
        const size_t k_size     = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);

        const size_t v_size_el = ggml_type_size(v_l[il]->type);
        const size_t v_size    = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);

        buf_k.resize(k_size);
        buf_v.resize(v_size);

        ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
        ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());

        // batch move [i, i+nm) to [id, id+nm)
        // note: cells can move only to a lower index
        for (uint32_t i = 0; i < n_kv; ++i) {
            const uint32_t id = ids[i];

            if (i == id || id == n_kv) {
                continue;
            }

            uint32_t nm = 1;

            while (i + nm < n_kv && ids[i + nm] == id + nm) {
                nm++;
            }

            // move keys
            {
                const int64_t os = i*k_size_row;
                const int64_t od = id*k_size_row;

                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
            }

            // move values (note: they are transposed)
            {
                const int64_t os = i;
                const int64_t od = id;

                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
                }
            }

            i += nm - 1;
        }

        ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
        ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
    }
#else
    for (uint32_t i = 0; i < ids.size(); ++i) {
        const uint32_t id = ids[i];

        if (i == id || id == ids.size()) {
            continue;
        }

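        // determine how many consecutive cells can be moved together as a single block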
        uint32_t nm = 1;

        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
            nm++;
        }

        for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

            ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
                    n_embd_k_gqa, nm,
                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));

            ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
                    n_embd_k_gqa, nm,
                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));

            ggml_tensor * view_v_src;
            ggml_tensor * view_v_dst;

            if (cparams.flash_attn) {
                // NOTE: the V cache is not transposed when using flash attention
                view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
                        n_embd_v_gqa, nm,
                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));

                view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
                        n_embd_v_gqa, nm,
                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
            } else {
                view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
                        nm, n_embd_v_gqa,
                        ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
                        ggml_row_size(kv_self->v_l[il]->type, i));

                view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
                        nm, n_embd_v_gqa,
                        ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
                        ggml_row_size(kv_self->v_l[il]->type, id));
            }

            ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
            ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
        }

        i += nm - 1;
    }

    //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
#endif

    return res;
}

void llama_context::kv_self_update() {
    auto & kv = kv_self;

    bool need_reserve = false;

    if (kv->has_shift) {
        if (!kv->get_can_shift()) {
            GGML_ABORT("The current context does not support K-shift");
        }

        LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);

        // apply K-shift if needed
        if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
            ggml_backend_sched_reset(sched.get());

            auto * gf = graph_init();

            auto res = build_kv_self_shift(ctx_compute.get(), gf);

            ggml_backend_sched_alloc_graph(sched.get(), gf);

            res->set_inputs(nullptr);

            graph_compute(gf, false);

            need_reserve = true;
        }

        {
            kv->has_shift = false;

            for (uint32_t i = 0; i < kv->size; ++i) {
                kv->cells[i].delta = 0;
            }
        }
    }

    // defragment the KV cache if needed
    if (kv->do_defrag) {
        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);

        if (kv->defrag_prepare(graph_max_nodes())) {
            ggml_backend_sched_reset(sched.get());

            auto * gf = graph_init();

            auto res = build_kv_self_defrag(ctx_compute.get(), gf);

            ggml_backend_sched_alloc_graph(sched.get(), gf);

            res->set_inputs(nullptr);

            graph_compute(gf, false);

            need_reserve = true;
        }

        kv->do_defrag = false;
    }

    // reserve a worst case graph if needed
    if (need_reserve) {
        LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);

        // build worst-case graph
        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

        // simulate full KV cache
        kv_self->n = kv_self->size;

        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

        auto * gf = graph_init();
        graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);

        // initialize scheduler with the worst-case graph
        ggml_backend_sched_reset(sched.get());
        if (!ggml_backend_sched_reserve(sched.get(), gf)) {
            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
        }
    }
}

enum llama_pooling_type llama_context::pooling_type() const {
    return cparams.pooling_type;
}

float * llama_context::get_logits() {
    // reorder logits for backward compatibility
    output_reorder();

    return logits;
}

float * llama_context::get_logits_ith(int32_t i) {
    int32_t j = -1;

    try {
        if (logits == nullptr) {
            throw std::runtime_error("no logits");
        }

        if (i < 0) {
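            // negative indices are relative to the end of the output buffer (-1 == last output)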
            j = n_outputs + i;
            if (j < 0) {
                throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
            }
        } else if ((size_t) i >= output_ids.size()) {
            throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
        } else {
            j = output_ids[i];
        }

        if (j < 0) {
            throw std::runtime_error(format("batch.logits[%d] != true", i));
        }
        if (j >= n_outputs) {
            // This should not happen
            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
        }

        return logits + j*model.vocab.n_tokens();
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG
        GGML_ABORT("fatal error");
#else
        return nullptr;
#endif
    }
}

float * llama_context::get_embeddings() {
    // reorder embeddings for backward compatibility
    output_reorder();

    return embd;
}

float * llama_context::get_embeddings_ith(int32_t i) {
    int32_t j = -1;

    try {
        if (embd == nullptr) {
            throw std::runtime_error("no embeddings");
        }

        if (i < 0) {
            j = n_outputs + i;
            if (j < 0) {
                throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
            }
        } else if ((size_t) i >= output_ids.size()) {
            throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
        } else {
            j = output_ids[i];
        }

        if (j < 0) {
            throw std::runtime_error(format("batch.logits[%d] != true", i));
        }
        if (j >= n_outputs) {
            // This should not happen
            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
        }

        return embd + j*model.hparams.n_embd;
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG
        GGML_ABORT("fatal error");
#else
        return nullptr;
#endif
    }
}

float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
    auto it = embd_seq.find(seq_id);
    if (it == embd_seq.end()) {
        return nullptr;
    }

    return it->second.data();
}

void llama_context::attach_threadpool(
        ggml_threadpool_t threadpool,
        ggml_threadpool_t threadpool_batch) {
    LLAMA_LOG_DEBUG("%s: call\n", __func__);

    this->threadpool       = threadpool;
    this->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
}

void llama_context::detach_threadpool() {
    LLAMA_LOG_DEBUG("%s: call\n", __func__);

    this->threadpool       = nullptr;
    this->threadpool_batch = nullptr;
}

void llama_context::set_n_threads(int32_t n_threads, int32_t n_threads_batch) {
    LLAMA_LOG_DEBUG("%s: n_threads = %d, n_threads_batch = %d\n", __func__, n_threads, n_threads_batch);

    cparams.n_threads       = n_threads;
    cparams.n_threads_batch = n_threads_batch;
}

void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) {
    LLAMA_LOG_DEBUG("%s: call\n", __func__);

    this->abort_callback      = abort_callback;
    this->abort_callback_data = abort_callback_data;

    for (auto & backend : backends) {
        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
        auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
        if (set_abort_callback_fn) {
            set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data);
        }
    }
}

void llama_context::set_embeddings(bool value) {
    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

    cparams.embeddings = value;
}

void llama_context::set_causal_attn(bool value) {
    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

    cparams.causal_attn = value;
}

void llama_context::set_warmup(bool value) {
    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

    cparams.warmup = value;
}

void llama_context::set_adapter_lora(
        llama_adapter_lora * adapter,
        float scale) {
    LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale);

    loras[adapter] = scale;
}

bool llama_context::rm_adapter_lora(
        llama_adapter_lora * adapter) {
    LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter);

    auto pos = loras.find(adapter);
    if (pos != loras.end()) {
        loras.erase(pos);
        return true;
    }

    return false;
}

void llama_context::clear_adapter_lora() {
    LLAMA_LOG_DEBUG("%s: call\n", __func__);

    loras.clear();
}

bool llama_context::apply_adapter_cvec(
        const float * data,
        size_t len,
        int32_t n_embd,
        int32_t il_start,
        int32_t il_end) {
    LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end);

    return cvec.apply(model, data, len, n_embd, il_start, il_end);
}

int llama_context::encode(llama_batch & inp_batch) {
    if (inp_batch.n_tokens == 0) {
        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
        return -1;
    }

    // temporary allocate memory for the input batch if needed
    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);

    const llama_batch & batch = batch_allocr.batch;
    const int32_t n_tokens = batch.n_tokens;

    const auto & hparams = model.hparams;

    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

    if (batch.token) {
        for (int32_t i = 0; i < n_tokens; ++i) {
            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                return -1;
            }
        }
    }

    // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
    GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");

    if (t_compute_start_us == 0) {
        t_compute_start_us = ggml_time_us();
    }

    n_queued_tokens += n_tokens;

    const int64_t n_embd = hparams.n_embd;

    sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);

    const llama_ubatch ubatch = sbatch.split_simple(n_tokens);

    // reserve output buffer
    if (output_reserve(n_tokens) < n_tokens) {
        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
        return -2;
    };

    for (int32_t i = 0; i < n_tokens; ++i) {
        output_ids[i] = i;
    }

    n_outputs = n_tokens;

    //batch_manager->prepare(ubatch);

    ggml_backend_sched_reset(sched.get());
    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

    const auto causal_attn_org = cparams.causal_attn;

    // always use non-causal attention for encoder graphs
    // TODO: this is a tmp solution until we have a proper way to support enc-dec models
    // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
    cparams.causal_attn = false;

    auto * gf = graph_init();
    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);

    ggml_backend_sched_alloc_graph(sched.get(), gf);

    res->set_inputs(&ubatch);

    cparams.causal_attn = causal_attn_org;

    const auto compute_status = graph_compute(gf, n_tokens > 1);
    switch (compute_status) {
        case GGML_STATUS_SUCCESS:
            break;
        case GGML_STATUS_ABORTED:
            return 2;
        case GGML_STATUS_ALLOC_FAILED:
            return -2;
        case GGML_STATUS_FAILED:
        default:
            return -3;
    }

    auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();

    // extract embeddings
    if (t_embd) {
        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
        GGML_ASSERT(backend_embd != nullptr);

        GGML_ASSERT(embd != nullptr);

        switch (cparams.pooling_type) {
            case LLAMA_POOLING_TYPE_NONE:
                {
                    // extract token embeddings
                    GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
                } break;
            case LLAMA_POOLING_TYPE_MEAN:
            case LLAMA_POOLING_TYPE_CLS:
            case LLAMA_POOLING_TYPE_LAST:
                {
                    // extract sequence embeddings
                    auto & embd_seq_out = embd_seq;
                    embd_seq_out.clear();

                    GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits

                    for (int32_t i = 0; i < n_tokens; i++) {
                        const llama_seq_id seq_id = ubatch.seq_id[i][0];
                        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                            continue;
                        }
                        embd_seq_out[seq_id].resize(n_embd);
                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                    }
                } break;
            case LLAMA_POOLING_TYPE_RANK:
                {
                    // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
                    // wait for an encoder model that requires this pooling type in order to test it
                    // https://github.com/ggerganov/llama.cpp/pull/9510
                    GGML_ABORT("RANK pooling not implemented yet");
                }
            case LLAMA_POOLING_TYPE_UNSPECIFIED:
                {
                    GGML_ABORT("unknown pooling type");
                }
        }
    }

    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
    // overlap with device computation.
    ggml_backend_sched_reset(sched.get());

    // TODO: hacky solution
    if (model.arch == LLM_ARCH_T5 && t_embd) {
        //cross.t_embd = t_embd;

        synchronize();

        cross.n_embd = t_embd->ne[0];
        cross.n_enc  = t_embd->ne[1];
        cross.v_embd.resize(cross.n_embd*cross.n_enc);
        memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));

        // remember the sequence ids used during the encoding - needed for cross attention later
        cross.seq_ids_enc.resize(n_tokens);
        for (int32_t i = 0; i < n_tokens; i++) {
            cross.seq_ids_enc[i].clear();
            for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
                llama_seq_id seq_id = ubatch.seq_id[i][s];
                cross.seq_ids_enc[i].insert(seq_id);
            }
        }
    }

    return 0;
}

int llama_context::decode(llama_batch & inp_batch) {
    if (inp_batch.n_tokens == 0) {
        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
        return -1;
    }

    // temporary allocate memory for the input batch if needed
    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);

    const llama_batch & batch = batch_allocr.batch;

    const auto & vocab   = model.vocab;
    const auto & hparams = model.hparams;

    const int32_t n_vocab = vocab.n_tokens();

    const int64_t n_tokens_all = batch.n_tokens;
    const int64_t n_embd       = hparams.n_embd;

    llama_kv_cache_guard kv_guard(kv_self.get());

    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

    if (batch.token) {
        for (int64_t i = 0; i < n_tokens_all; ++i) {
            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
                throw std::runtime_error("invalid token");
            }
        }
    }

    GGML_ASSERT(n_tokens_all <= cparams.n_batch);

    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");

    if (t_compute_start_us == 0) {
        t_compute_start_us = ggml_time_us();
    }
    n_queued_tokens += n_tokens_all;

    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

    embd_seq.clear();

    int64_t n_outputs_all = 0;

    // count outputs
    if (batch.logits && !embd_pooled) {
        for (uint32_t i = 0; i < n_tokens_all; ++i) {
            n_outputs_all += batch.logits[i] != 0;
        }
    } else if (logits_all || embd_pooled) {
        n_outputs_all = n_tokens_all;
    } else {
        // keep last output only
        n_outputs_all = 1;
    }

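    // true when an output is produced for every token in the batch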
    const bool logits_all = n_outputs_all == n_tokens_all;

    sbatch.from_batch(batch, n_embd,
            /* simple_split */ !kv_self->recurrent,
            /* logits_all   */ logits_all);

    // reserve output buffer
    if (output_reserve(n_outputs_all) < n_outputs_all) {
        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
        return -2;
    };

    // handle any pending defrags/shifts
    kv_self_update();

    int64_t n_outputs_prev = 0;

    while (sbatch.n_tokens > 0) {
        llama_ubatch ubatch = llama_ubatch();

        const auto & n_ubatch = cparams.n_ubatch;

        if (kv_self->recurrent) {
            if (embd_pooled) {
                // Pooled embeddings cannot be split across ubatches (yet)
                ubatch = sbatch.split_seq(cparams.n_ubatch);
            } else {
                // recurrent model architectures are easier to implement
                // with equal-length sequences
                ubatch = sbatch.split_equal(cparams.n_ubatch);
            }
        } else {
            ubatch = sbatch.split_simple(n_ubatch);
        }

        // count the outputs in this u_batch
        {
            int32_t n_outputs_new = 0;

            if (n_outputs_all == n_tokens_all) {
                n_outputs_new = ubatch.n_tokens;
            } else {
                GGML_ASSERT(ubatch.output);
                for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
                }
            }

            // needs to happen before the graph is built
            n_outputs = n_outputs_new;
        }

        // find KV slot
        {
            if (!kv_self->find_slot(ubatch)) {
                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);

                return 1;
            }

            if (!kv_self->recurrent) {
                // a heuristic, to avoid attending the full cache if it is not yet utilized
                // after enough generations, the benefit from this heuristic disappears
                // if we start defragmenting the cache, the benefit from this will be more important
                const uint32_t pad = kv_self->get_padding(cparams);
                kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
            }
        }

        //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);

        ggml_backend_sched_reset(sched.get());
        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

        auto * gf = graph_init();
        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);

        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);

        ggml_backend_sched_alloc_graph(sched.get(), gf);

        res->set_inputs(&ubatch);

        const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
        if (compute_status != GGML_STATUS_SUCCESS) {
            switch (compute_status) {
                case GGML_STATUS_ABORTED:
                    return 2;
                case GGML_STATUS_ALLOC_FAILED:
                    return -2;
                case GGML_STATUS_FAILED:
                default:
                    return -3;
            }
        }

        // plot the computation graph in dot format (for debugging purposes)
        //if (n_past%100 == 0) {
        //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
        //}

        auto * t_logits = cparams.embeddings ? nullptr         : res->get_logits();
        auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;

        if (t_embd && res->get_embd_pooled()) {
            t_embd = res->get_embd_pooled();
        }

        // extract logits
        if (t_logits && n_outputs > 0) {
            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
            GGML_ASSERT(backend_res != nullptr);
            GGML_ASSERT(logits != nullptr);

            float * logits_out = logits + n_outputs_prev*n_vocab;

            if (n_outputs) {
                GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
                GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
                ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
            }
        }

        // extract embeddings
        if (t_embd && n_outputs > 0) {
            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
            GGML_ASSERT(backend_embd != nullptr);

            switch (cparams.pooling_type) {
                case LLAMA_POOLING_TYPE_NONE:
                    {
                        // extract token embeddings
                        GGML_ASSERT(embd != nullptr);
                        float * embd_out = embd + n_outputs_prev*n_embd;

                        if (n_outputs) {
                            GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
                            GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
                        }
                    } break;
                case LLAMA_POOLING_TYPE_MEAN:
                case LLAMA_POOLING_TYPE_CLS:
                case LLAMA_POOLING_TYPE_LAST:
                    {
                        // extract sequence embeddings (cleared before processing each batch)
                        auto & embd_seq_out = embd_seq;

                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                                continue;
                            }
                            embd_seq_out[seq_id].resize(n_embd);
                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                        }
                    } break;
                case LLAMA_POOLING_TYPE_RANK:
                    {
                        // extract the rerank score - a single float per sequence
                        auto & embd_seq_out = embd_seq;

                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                                continue;
                            }
                            embd_seq_out[seq_id].resize(1);
                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
                        }
                    } break;
                case LLAMA_POOLING_TYPE_UNSPECIFIED:
                    {
                        GGML_ABORT("unknown pooling type");
                    }
            }
        }

        n_outputs_prev += n_outputs;
    }

    // finalize the batch processing
    kv_guard.commit();

    // set output mappings
    {
        bool sorted_output = true;

        GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);

        for (int64_t i = 0; i < n_outputs_all; ++i) {
            int64_t out_id = sbatch.out_ids[i];
            output_ids[out_id] = i;
            if (out_id != i) {
                sorted_output = false;
            }
        }

        if (sorted_output) {
            sbatch.out_ids.clear();
        }
    }

    // set to total number of outputs in the batch, for use in llama_get_logits_ith
    n_outputs = n_outputs_all;

    // wait for the computation to finish (automatically done when obtaining the model output)
    //synchronize();

    // decide if we need to defrag the kv cache
    if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
        // - do not defrag small contexts (i.e. < 2048 tokens)
        // - count the padding towards the number of used tokens
        const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;

        // queue defragmentation for next llama_kv_cache_update
        if (fragmentation > cparams.defrag_thold) {
            LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);

            kv_self->defrag();
        }
    }

    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
    // overlap with device computation.
    ggml_backend_sched_reset(sched.get());

    return 0;
}

//
// output
//

int32_t llama_context::output_reserve(int32_t n_outputs) {
    const auto & hparams = model.hparams;
    const auto & vocab   = model.vocab;

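    // keep room for at least one output per sequence (needed e.g. when pooling produces one embedding per sequence)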
const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
|
|
|
|
const auto n_batch = cparams.n_batch;
|
|
const auto n_vocab = vocab.n_tokens();
|
|
const auto n_embd = hparams.n_embd;
|
|
|
|
// TODO: use a per-batch flag for logits presence instead
|
|
bool has_logits = !cparams.embeddings;
|
|
bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
|
|
|
|
// TODO: hacky enc-dec support
|
|
if (model.arch == LLM_ARCH_T5) {
|
|
has_logits = true;
|
|
has_embd = true;
|
|
}
|
|
|
|
logits_size = has_logits ? n_vocab*n_outputs_max : 0;
|
|
embd_size = has_embd ? n_embd*n_outputs_max : 0;
|
|
|
|
if (output_ids.empty()) {
|
|
// init, never resized afterwards
|
|
output_ids.resize(n_batch);
|
|
}
|
|
|
|
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
|
|
const size_t new_size = (logits_size + embd_size) * sizeof(float);
|
|
|
|
// alloc only when more than the current capacity is required
|
|
// TODO: also consider shrinking the buffer
|
|
if (!buf_output || prev_size < new_size) {
|
|
if (buf_output) {
|
|
#ifndef NDEBUG
|
|
// This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
|
|
LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
|
|
#endif
|
|
buf_output = nullptr;
|
|
logits = nullptr;
|
|
embd = nullptr;
|
|
}
|
|
|
|
auto * buft = ggml_backend_cpu_buffer_type();
|
|
// try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
|
|
auto * output_dev = model.dev_output();
|
|
auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
|
|
if (output_dev_host_buft) {
|
|
buft = output_dev_host_buft;
|
|
}
|
|
buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size));
|
|
if (buf_output == nullptr) {
|
|
LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
|
|
return 0;
|
|
}
|
|
}
|
|
|
|
float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
|
|
|
|
logits = has_logits ? output_base : nullptr;
|
|
embd = has_embd ? output_base + logits_size : nullptr;
|
|
|
|
// set all ids as invalid (negative)
|
|
std::fill(output_ids.begin(), output_ids.end(), -1);
|
|
|
|
ggml_backend_buffer_clear(buf_output.get(), 0);
|
|
|
|
this->n_outputs = 0;
|
|
this->n_outputs_max = n_outputs_max;
|
|
|
|
return n_outputs_max;
|
|
}
|
|
|
|
void llama_context::output_reorder() {
|
|
auto & out_ids = sbatch.out_ids;
|
|
if (!out_ids.empty()) {
|
|
const uint32_t n_vocab = model.vocab.n_tokens();
|
|
const uint32_t n_embd = model.hparams.n_embd;
|
|
|
|
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
|
|
|
// TODO: is there something more efficient which also minimizes swaps?
|
|
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
|
for (int32_t i = 0; i < n_outputs - 1; ++i) {
|
|
int32_t j_min = i;
|
|
for (int32_t j = i + 1; j < n_outputs; ++j) {
|
|
if (out_ids[j] < out_ids[j_min]) {
|
|
j_min = j;
|
|
}
|
|
}
|
|
if (j_min == i) { continue; }
|
|
std::swap(out_ids[i], out_ids[j_min]);
|
|
if (logits_size > 0) {
|
|
for (uint32_t k = 0; k < n_vocab; k++) {
|
|
std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
|
|
}
|
|
}
|
|
if (embd_size > 0) {
|
|
for (uint32_t k = 0; k < n_embd; k++) {
|
|
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
|
|
}
|
|
}
|
|
}
|
|
std::fill(output_ids.begin(), output_ids.end(), -1);
|
|
for (int32_t i = 0; i < n_outputs; ++i) {
|
|
output_ids[out_ids[i]] = i;
|
|
}
|
|
out_ids.clear();
|
|
}
|
|
}
|
|
|
|
//
|
|
// graph
|
|
//
|
|
|
|
int32_t llama_context::graph_max_nodes() const {
|
|
return std::max<int32_t>(65536, 5*model.n_tensors());
|
|
}
|
|
|
|
ggml_cgraph * llama_context::graph_init() {
|
|
ggml_init_params params = {
|
|
/*.mem_size =*/ buf_compute_meta.size(),
|
|
/*.mem_buffer =*/ buf_compute_meta.data(),
|
|
/*.no_alloc =*/ true,
|
|
};
|
|
|
|
ctx_compute.reset(ggml_init(params));
|
|
|
|
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
|
|
}
|
|
|
|
llm_graph_result_ptr llama_context::graph_build(
|
|
ggml_context * ctx,
|
|
ggml_cgraph * gf,
|
|
const llama_ubatch & ubatch,
|
|
llm_graph_type gtype) {
|
|
return model.build_graph(
|
|
{
|
|
/*.ctx =*/ ctx,
|
|
/*.arch =*/ model.arch,
|
|
/*.hparams =*/ model.hparams,
|
|
/*.cparams =*/ cparams,
|
|
/*.ubatch =*/ ubatch,
|
|
/*.sched =*/ sched.get(),
|
|
/*.backend_cpu =*/ backend_cpu,
|
|
/*.cvec =*/ &cvec,
|
|
/*.loras =*/ &loras,
|
|
/*.memory =*/ kv_self.get(),
|
|
/*.cross =*/ &cross,
|
|
/*.n_outputs =*/ n_outputs,
|
|
/*.cb =*/ graph_get_cb(),
|
|
}, gf, gtype);
|
|
}
|
|
|
|
ggml_status llama_context::graph_compute(
|
|
ggml_cgraph * gf,
|
|
bool batched) {
|
|
int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads;
|
|
ggml_threadpool_t tp = batched ? threadpool_batch : threadpool;
|
|
|
|
if (backend_cpu != nullptr) {
|
|
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
|
|
auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
|
|
set_threadpool_fn(backend_cpu, tp);
|
|
}
|
|
|
|
// set the number of threads for all the backends
|
|
for (const auto & set_n_threads_fn : set_n_threads_fns) {
|
|
set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
|
|
}
|
|
|
|
auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf);
|
|
if (status != GGML_STATUS_SUCCESS) {
|
|
LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
|
|
}
|
|
|
|
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));
|
|
|
|
return status;
|
|
}
|
|
|
|
llm_graph_cb llama_context::graph_get_cb() const {
|
|
return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
|
|
if (il >= 0) {
|
|
ggml_format_name(cur, "%s-%d", name, il);
|
|
} else {
|
|
ggml_set_name(cur, name);
|
|
}
|
|
|
|
if (!cparams.offload_kqv) {
|
|
if (strcmp(name, "kqv_merged_cont") == 0) {
|
|
// all nodes between the KV store and the attention output are run on the CPU
|
|
ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
|
|
}
|
|
}
|
|
|
|
// norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
|
|
// FIXME: fix in ggml_backend_sched
|
|
const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer;
|
|
if (ubatch.n_tokens < 32 || full_offload) {
|
|
if (il != -1 && strcmp(name, "norm") == 0) {
|
|
const auto & dev_layer = model.dev_layer(il);
|
|
for (const auto & backend : backends) {
|
|
if (ggml_backend_get_device(backend.get()) == dev_layer) {
|
|
if (ggml_backend_supports_op(backend.get(), cur)) {
|
|
ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
};
|
|
}

//
// state save/load
//

class llama_io_write_dummy : public llama_io_write_i {
public:
    llama_io_write_dummy() = default;

    void write(const void * /* src */, size_t size) override {
        size_written += size;
    }

    void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
        size_written += size;
    }

    size_t n_bytes() override {
        return size_written;
    }

private:
    size_t size_written = 0;
};
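
// note: the dummy writer above only accumulates byte counts - it is what
//       state_get_size()/state_seq_get_size() below use to measure the
//       serialized state without copying any data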

class llama_io_write_buffer : public llama_io_write_i {
public:
    llama_io_write_buffer(
            uint8_t * p, size_t len) : ptr(p), buf_size(len) {}

    void write(const void * src, size_t size) override {
        if (size > buf_size) {
            throw std::runtime_error("unexpectedly reached end of buffer");
        }
        memcpy(ptr, src, size);
        ptr += size;
        size_written += size;
        buf_size -= size;
    }

    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
        if (size > buf_size) {
            throw std::runtime_error("unexpectedly reached end of buffer");
        }
        ggml_backend_tensor_get(tensor, ptr, offset, size);
        ptr += size;
        size_written += size;
        buf_size -= size;
    }

    size_t n_bytes() override {
        return size_written;
    }

private:
    uint8_t * ptr;
    size_t buf_size = 0;
    size_t size_written = 0;
};

class llama_io_read_buffer : public llama_io_read_i {
public:
    llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}

    const uint8_t * read(size_t size) override {
        const uint8_t * base_ptr = ptr;
        if (size > buf_size) {
            throw std::runtime_error("unexpectedly reached end of buffer");
        }
        ptr += size;
        size_read += size;
        buf_size -= size;
        return base_ptr;
    }

    void read_to(void * dst, size_t size) override {
        memcpy(dst, read(size), size);
    }

    size_t n_bytes() override {
        return size_read;
    }

private:
    const uint8_t * ptr;
    size_t buf_size = 0;
    size_t size_read = 0;
};

class llama_io_write_file : public llama_io_write_i {
public:
    llama_io_write_file(llama_file * f) : file(f) {}

    void write(const void * src, size_t size) override {
        file->write_raw(src, size);
        size_written += size;
    }

    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
        temp_buffer.resize(size);
        ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
        write(temp_buffer.data(), temp_buffer.size());
    }

    size_t n_bytes() override {
        return size_written;
    }

private:
    llama_file * file;
    size_t size_written = 0;
    std::vector<uint8_t> temp_buffer;
};

class llama_io_read_file : public llama_io_read_i {
public:
    llama_io_read_file(llama_file * f) : file(f) {}

    void read_to(void * dst, size_t size) override {
        file->read_raw(dst, size);
        size_read += size;
    }

    const uint8_t * read(size_t size) override {
        temp_buffer.resize(size);
        read_to(temp_buffer.data(), size);
        return temp_buffer.data();
    }

    size_t n_bytes() override {
        return size_read;
    }

private:
    llama_file * file;
    size_t size_read = 0;
    std::vector<uint8_t> temp_buffer;
};
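
// the adapters above implement llama_io_write_i/llama_io_read_i over either a
// caller-provided memory buffer or a llama_file, so the same state_write_data()/
// state_read_data() logic below can serialize to memory, to disk, or just count bytes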

size_t llama_context::state_get_size() {
    llama_io_write_dummy io;
    try {
        return state_write_data(io);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context::state_get_data(uint8_t * dst, size_t size) {
    llama_io_write_buffer io(dst, size);
    try {
        return state_write_data(io);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
    llama_io_read_buffer io(src, size);
    try {
        return state_read_data(io);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
    llama_io_write_dummy io;
    try {
        return state_seq_write_data(io, seq_id);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
    llama_io_write_buffer io(dst, size);
    try {
        return state_seq_write_data(io, seq_id);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
    llama_io_read_buffer io(src, size);
    try {
        return state_seq_read_data(io, seq_id);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
        return 0;
    }
}

bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
    llama_file file(filepath, "rb");

    // sanity checks
    {
        const uint32_t magic   = file.read_u32();
        const uint32_t version = file.read_u32();

        if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
            LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
            return false;
        }
    }

    // load the prompt
    {
        const uint32_t n_token_count = file.read_u32();

        if (n_token_count > n_token_capacity) {
            LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
            return false;
        }

        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
        *n_token_count_out = n_token_count;
    }

    // restore the context state
    {
        const size_t n_state_size_cur = file.size() - file.tell();

        llama_io_read_file io(&file);
        const size_t n_read = state_read_data(io);

        if (n_read != n_state_size_cur) {
            LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
            return false;
        }
    }

    return true;
}

bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) {
    llama_file file(filepath, "wb");

    file.write_u32(LLAMA_SESSION_MAGIC);
    file.write_u32(LLAMA_SESSION_VERSION);

    // save the prompt
    file.write_u32((uint32_t) n_token_count);
    file.write_raw(tokens, sizeof(llama_token) * n_token_count);

    // save the context state using stream saving
    llama_io_write_file io(&file);
    state_write_data(io);

    return true;
}
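
// resulting session file layout (as written above): u32 magic, u32 version,
// u32 token count, the prompt tokens, then the streamed context state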

size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
    llama_file file(filepath, "rb");

    // version checks
    {
        const uint32_t magic   = file.read_u32();
        const uint32_t version = file.read_u32();

        if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
            LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
            return 0;
        }
    }

    // load the prompt
    {
        const uint32_t n_token_count = file.read_u32();

        if (n_token_count > n_token_capacity) {
            LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
            return 0;
        }

        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
        *n_token_count_out = n_token_count;
    }

    // restore the context state
    {
        const size_t state_size = file.size() - file.tell();
        llama_io_read_file io(&file);
        const size_t nread = state_seq_read_data(io, seq_id);
        if (!nread) {
            LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
            return 0;
        }
        GGML_ASSERT(nread <= state_size);
        GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
    }

    return file.tell();
}

size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) {
    llama_file file(filepath, "wb");

    file.write_u32(LLAMA_STATE_SEQ_MAGIC);
    file.write_u32(LLAMA_STATE_SEQ_VERSION);

    // save the prompt
    file.write_u32((uint32_t) n_token_count);
    file.write_raw(tokens, sizeof(llama_token) * n_token_count);

    // save the context state using stream saving
    llama_io_write_file io(&file);
    state_seq_write_data(io, seq_id);

    const size_t res = file.tell();
    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());

    return res;
}
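
// the `sizeof(uint32_t) * 3` in the asserts above accounts for the fixed header
// written first: magic, version and the token count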

size_t llama_context::state_write_data(llama_io_write_i & io) {
    LLAMA_LOG_DEBUG("%s: writing state\n", __func__);

    // write model info
    {
        LLAMA_LOG_DEBUG("%s: - writing model info\n", __func__);

        const std::string arch_str = llm_arch_name(model.arch);
        io.write_string(arch_str);
        // TODO: add more model-specific info which should prevent loading the session file if not identical
    }

    // write output ids
    {
        LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);

        output_reorder();

        const auto n_outputs    = this->n_outputs;
        const auto & output_ids = this->output_ids;

        std::vector<int32_t> w_output_pos;

        GGML_ASSERT(n_outputs <= n_outputs_max);

        w_output_pos.resize(n_outputs);

        // build a more compact representation of the output ids
        for (size_t i = 0; i < n_batch(); ++i) {
            // map an output id to a position in the batch
            int32_t pos = output_ids[i];
            if (pos >= 0) {
                GGML_ASSERT(pos < n_outputs);
                w_output_pos[pos] = i;
            }
        }

        io.write(&n_outputs, sizeof(n_outputs));

        if (n_outputs) {
            io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
        }
    }

    // write logits
    {
        LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);

        const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());

        io.write(&logits_size, sizeof(logits_size));

        if (logits_size) {
            io.write(logits, logits_size * sizeof(float));
        }
    }

    // write embeddings
    {
        LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);

        const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);

        io.write(&embd_size, sizeof(embd_size));

        if (embd_size) {
            io.write(embd, embd_size * sizeof(float));
        }
    }

    LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
    kv_self->state_write(io);

    return io.n_bytes();
}

size_t llama_context::state_read_data(llama_io_read_i & io) {
    LLAMA_LOG_DEBUG("%s: reading state\n", __func__);

    // read model info
    {
        LLAMA_LOG_DEBUG("%s: - reading model info\n", __func__);

        const std::string cur_arch_str = llm_arch_name(model.arch);

        std::string arch_str;
        io.read_string(arch_str);
        if (cur_arch_str != arch_str) {
            throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
        }
        // TODO: add more info which needs to be identical but which is not verified otherwise
    }

    // read output ids
    {
        LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);

        auto n_outputs = this->n_outputs;
        io.read_to(&n_outputs, sizeof(n_outputs));

        if (n_outputs > output_reserve(n_outputs)) {
            throw std::runtime_error("could not reserve outputs");
        }

        std::vector<int32_t> output_pos;

        if (n_outputs) {
            output_pos.resize(n_outputs);
            io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));

            for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
                int32_t id = output_pos[i];
                if ((uint32_t) id >= n_batch()) {
                    throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
                }
                this->output_ids[id] = i;
            }

            this->n_outputs = n_outputs;
        }
    }

    // read logits
    {
        LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);

        uint64_t logits_size;
        io.read_to(&logits_size, sizeof(logits_size));

        if (this->logits_size < logits_size) {
            throw std::runtime_error("logits buffer too small");
        }

        if (logits_size) {
            io.read_to(this->logits, logits_size * sizeof(float));
        }
    }

    // read embeddings
    {
        LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);

        uint64_t embd_size;
        io.read_to(&embd_size, sizeof(embd_size));

        if (this->embd_size < embd_size) {
            throw std::runtime_error("embeddings buffer too small");
        }

        if (embd_size) {
            io.read_to(this->embd, embd_size * sizeof(float));
        }
    }

    LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
    kv_self->state_read(io);

    return io.n_bytes();
}
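
// note: state_read_data() must consume fields in exactly the same order that
//       state_write_data() emits them: model info, output ids, logits,
//       embeddings, and finally the KV cache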

size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
    GGML_UNUSED(seq_id);

    kv_self->state_write(io, seq_id);

    return io.n_bytes();
}

size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
    GGML_UNUSED(seq_id);

    kv_self->state_read(io, seq_id);

    return io.n_bytes();
}

//
// perf
//

llama_perf_context_data llama_context::perf_get_data() const {
    llama_perf_context_data data = {};

    data.t_start_ms  = 1e-3 * t_start_us;
    data.t_load_ms   = 1e-3 * t_load_us;
    data.t_p_eval_ms = 1e-3 * t_p_eval_us;
    data.t_eval_ms   = 1e-3 * t_eval_us;
    data.n_p_eval    = std::max(1, n_p_eval);
    data.n_eval      = std::max(1, n_eval);

    return data;
}

void llama_context::perf_reset() {
    t_start_us  = ggml_time_us();
    t_eval_us   = n_eval   = 0;
    t_p_eval_us = n_p_eval = 0;
}

//
// interface implementation
//

llama_context_params llama_context_default_params() {
    llama_context_params result = {
        /*.n_ctx               =*/ 512,
        /*.n_batch             =*/ 2048,
        /*.n_ubatch            =*/ 512,
        /*.n_seq_max           =*/ 1,
        /*.n_threads           =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
        /*.n_threads_batch     =*/ GGML_DEFAULT_N_THREADS,
        /*.rope_scaling_type   =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
        /*.pooling_type        =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
        /*.attention_type      =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
        /*.rope_freq_base      =*/ 0.0f,
        /*.rope_freq_scale     =*/ 0.0f,
        /*.yarn_ext_factor     =*/ -1.0f,
        /*.yarn_attn_factor    =*/ 1.0f,
        /*.yarn_beta_fast      =*/ 32.0f,
        /*.yarn_beta_slow      =*/ 1.0f,
        /*.yarn_orig_ctx       =*/ 0,
        /*.defrag_thold        =*/ -1.0f,
        /*.cb_eval             =*/ nullptr,
        /*.cb_eval_user_data   =*/ nullptr,
        /*.type_k              =*/ GGML_TYPE_F16,
        /*.type_v              =*/ GGML_TYPE_F16,
        /*.logits_all          =*/ false,
        /*.embeddings          =*/ false,
        /*.offload_kqv         =*/ true,
        /*.flash_attn          =*/ false,
        /*.no_perf             =*/ true,
        /*.abort_callback      =*/ nullptr,
        /*.abort_callback_data =*/ nullptr,
    };

    return result;
}
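
// example usage (sketch, not part of the API): start from the defaults and
// override only what is needed before creating a context
//
//   llama_context_params cparams = llama_context_default_params();
//   cparams.n_ctx      = 4096;
//   cparams.embeddings = true;
//   llama_context * lctx = llama_init_from_model(model, cparams);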

llama_context * llama_init_from_model(
        llama_model * model,
        llama_context_params params) {
    if (!model) {
        LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
        return nullptr;
    }

    if (params.n_batch == 0 && params.n_ubatch == 0) {
        LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__);
        return nullptr;
    }

    if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) {
        LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__);
        return nullptr;
    }

    if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
        params.flash_attn = false;
    }

    if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
        LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
        return nullptr;
    }

    try {
        auto * ctx = new llama_context(*model, params);
        return ctx;
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: failed to initialize the context: %s\n", __func__, err.what());
    }

    return nullptr;
}

// deprecated
llama_context * llama_new_context_with_model(
        llama_model * model,
        llama_context_params params) {
    return llama_init_from_model(model, params);
}

void llama_free(llama_context * ctx) {
    delete ctx;
}

uint32_t llama_n_ctx(const llama_context * ctx) {
    return ctx->n_ctx();
}

uint32_t llama_n_batch(const llama_context * ctx) {
    return ctx->n_batch();
}

uint32_t llama_n_ubatch(const llama_context * ctx) {
    return ctx->n_ubatch();
}

uint32_t llama_n_seq_max(const llama_context * ctx) {
    return ctx->n_seq_max();
}

const llama_model * llama_get_model(const llama_context * ctx) {
    return &ctx->get_model();
}

llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
    return ctx->get_kv_self();
}

void llama_kv_self_update(llama_context * ctx) {
    ctx->kv_self_update();
}

enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
    return ctx->pooling_type();
}

void llama_attach_threadpool(
        llama_context * ctx,
        ggml_threadpool_t threadpool,
        ggml_threadpool_t threadpool_batch) {
    ctx->attach_threadpool(threadpool, threadpool_batch);
}

void llama_detach_threadpool(llama_context * ctx) {
    ctx->detach_threadpool();
}

void llama_set_n_threads(llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
    ctx->set_n_threads(n_threads, n_threads_batch);
}

int32_t llama_n_threads(llama_context * ctx) {
    return ctx->n_threads();
}

int32_t llama_n_threads_batch(llama_context * ctx) {
    return ctx->n_threads_batch();
}

void llama_set_abort_callback(llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
    ctx->set_abort_callback(abort_callback, abort_callback_data);
}

void llama_set_embeddings(llama_context * ctx, bool embeddings) {
    ctx->set_embeddings(embeddings);
}

void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
    ctx->set_causal_attn(causal_attn);
}

void llama_set_warmup(llama_context * ctx, bool warmup) {
    ctx->set_warmup(warmup);
}

void llama_synchronize(llama_context * ctx) {
    ctx->synchronize();
}

float * llama_get_logits(llama_context * ctx) {
    ctx->synchronize();

    return ctx->get_logits();
}

float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
    ctx->synchronize();

    return ctx->get_logits_ith(i);
}

float * llama_get_embeddings(llama_context * ctx) {
    ctx->synchronize();

    return ctx->get_embeddings();
}

float * llama_get_embeddings_ith(llama_context * ctx, int32_t i) {
    ctx->synchronize();

    return ctx->get_embeddings_ith(i);
}

float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
    ctx->synchronize();

    return ctx->get_embeddings_seq(seq_id);
}
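
// note: the logits/embeddings getters above call synchronize() before reading,
//       so the returned buffers are valid even though graph_compute() submits
//       work to the backend scheduler asynchronously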

// llama adapter API

int32_t llama_set_adapter_lora(
        llama_context * ctx,
        llama_adapter_lora * adapter,
        float scale) {
    ctx->set_adapter_lora(adapter, scale);

    return 0;
}

int32_t llama_rm_adapter_lora(
        llama_context * ctx,
        llama_adapter_lora * adapter) {
    bool res = ctx->rm_adapter_lora(adapter);

    return res ? 0 : -1;
}

void llama_clear_adapter_lora(llama_context * ctx) {
    ctx->clear_adapter_lora();
}

int32_t llama_apply_adapter_cvec(
        llama_context * ctx,
        const float * data,
        size_t len,
        int32_t n_embd,
        int32_t il_start,
        int32_t il_end) {
    bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end);

    return res ? 0 : -1;
}

//
// kv cache view
//

llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
    const auto * kv = ctx->get_kv_self();
    if (kv == nullptr) {
        LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
        return {};
    }

    return llama_kv_cache_view_init(*kv, n_seq_max);
}

void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
    const auto * kv = ctx->get_kv_self();
    if (kv == nullptr) {
        LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
        return;
    }

    llama_kv_cache_view_update(view, kv);
}

//
// kv cache
//

// deprecated
int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
    return llama_kv_self_n_tokens(ctx);
}

int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
    const auto * kv = ctx->get_kv_self();
    if (!kv) {
        return 0;
    }

    return kv->get_n_tokens();
}

// deprecated
int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
    return llama_kv_self_used_cells(ctx);
}

int32_t llama_kv_self_used_cells(const llama_context * ctx) {
    const auto * kv = ctx->get_kv_self();
    if (!kv) {
        return 0;
    }

    return kv->get_used_cells();
}

// deprecated
void llama_kv_cache_clear(llama_context * ctx) {
    llama_kv_self_clear(ctx);
}

void llama_kv_self_clear(llama_context * ctx) {
    auto * kv = ctx->get_kv_self();
    if (!kv) {
        return;
    }

    kv->clear();
}

// deprecated
bool llama_kv_cache_seq_rm(
        llama_context * ctx,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1) {
    return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
}

bool llama_kv_self_seq_rm(
        llama_context * ctx,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1) {
    auto * kv = ctx->get_kv_self();
    if (!kv) {
        return true;
    }

    return kv->seq_rm(seq_id, p0, p1);
}

// deprecated
void llama_kv_cache_seq_cp(
        llama_context * ctx,
        llama_seq_id seq_id_src,
        llama_seq_id seq_id_dst,
        llama_pos p0,
        llama_pos p1) {
    return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
}

void llama_kv_self_seq_cp(
        llama_context * ctx,
        llama_seq_id seq_id_src,
        llama_seq_id seq_id_dst,
        llama_pos p0,
        llama_pos p1) {
    auto * kv = ctx->get_kv_self();
    if (!kv) {
        return;
    }

    return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}

// deprecated
void llama_kv_cache_seq_keep(
        llama_context * ctx,
        llama_seq_id seq_id) {
    return llama_kv_self_seq_keep(ctx, seq_id);
}

void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
    auto * kv = ctx->get_kv_self();
    if (!kv) {
        return;
    }

    return kv->seq_keep(seq_id);
}

// deprecated
void llama_kv_cache_seq_add(
        llama_context * ctx,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1,
        llama_pos delta) {
    return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
}

void llama_kv_self_seq_add(
        llama_context * ctx,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1,
        llama_pos delta) {
    auto * kv = ctx->get_kv_self();
    if (!kv) {
        return;
    }

    return kv->seq_add(seq_id, p0, p1, delta);
}

// deprecated
void llama_kv_cache_seq_div(
        llama_context * ctx,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1,
        int d) {
    return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
}

void llama_kv_self_seq_div(
        llama_context * ctx,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1,
        int d) {
    auto * kv = ctx->get_kv_self();
    if (!kv) {
        return;
    }

    return kv->seq_div(seq_id, p0, p1, d);
}

// deprecated
llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
    return llama_kv_self_seq_pos_max(ctx, seq_id);
}

llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
    const auto * kv = ctx->get_kv_self();
    if (!kv) {
        return 0;
    }

    return kv->seq_pos_max(seq_id);
}

// deprecated
void llama_kv_cache_defrag(llama_context * ctx) {
    return llama_kv_self_defrag(ctx);
}

void llama_kv_self_defrag(llama_context * ctx) {
    auto * kv = ctx->get_kv_self();
    if (!kv) {
        return;
    }

    return kv->defrag();
}

// deprecated
bool llama_kv_cache_can_shift(const llama_context * ctx) {
    return llama_kv_self_can_shift(ctx);
}

bool llama_kv_self_can_shift(const llama_context * ctx) {
    const auto * kv = ctx->get_kv_self();
    if (!kv) {
        return false;
    }

    return kv->get_can_shift();
}

// deprecated
void llama_kv_cache_update(llama_context * ctx) {
    llama_kv_self_update(ctx);
}

// llama state API

// deprecated
size_t llama_get_state_size(llama_context * ctx) {
    return llama_state_get_size(ctx);
}

// deprecated
size_t llama_copy_state_data(llama_context * ctx, uint8_t * dst) {
    return llama_state_get_data(ctx, dst, -1);
}

// deprecated
size_t llama_set_state_data(llama_context * ctx, const uint8_t * src) {
    return llama_state_set_data(ctx, src, -1);
}

// deprecated
bool llama_load_session_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
    return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
}

// deprecated
bool llama_save_session_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
    return llama_state_save_file(ctx, path_session, tokens, n_token_count);
}

// Returns the *actual* size of the state.
// Intended to be used when saving the state to a buffer.
size_t llama_state_get_size(llama_context * ctx) {
    return ctx->state_get_size();
}

size_t llama_state_get_data(llama_context * ctx, uint8_t * dst, size_t size) {
    ctx->synchronize();

    return ctx->state_get_data(dst, size);
}

// Sets the state reading from the specified source address
size_t llama_state_set_data(llama_context * ctx, const uint8_t * src, size_t size) {
    ctx->synchronize();

    return ctx->state_set_data(src, size);
}
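
// example usage (sketch, not part of the API): size the buffer first, then copy
// the state into it; a later llama_state_set_data() on a compatible context
// restores it (both calls return 0 on failure)
//
//   const size_t n_state = llama_state_get_size(ctx);
//   std::vector<uint8_t> buf(n_state);
//   const size_t n_copied = llama_state_get_data(ctx, buf.data(), buf.size());
//   ...
//   const size_t n_set = llama_state_set_data(ctx, buf.data(), buf.size());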

bool llama_state_load_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
    ctx->synchronize();

    try {
        return ctx->state_load_file(path_session, tokens_out, n_token_capacity, n_token_count_out);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
        return false;
    }
}

bool llama_state_save_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
    ctx->synchronize();

    try {
        return ctx->state_save_file(path_session, tokens, n_token_count);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
        return false;
    }
}
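
// example usage (sketch, not part of the API): persist the evaluated prompt and
// context state to a session file, then restore both in a later run
//
//   llama_state_save_file(ctx, "session.bin", prompt_tokens.data(), prompt_tokens.size());
//   ...
//   std::vector<llama_token> tokens(n_ctx);
//   size_t n_tokens = 0;
//   if (llama_state_load_file(ctx, "session.bin", tokens.data(), tokens.size(), &n_tokens)) {
//       // n_tokens now holds the number of prompt tokens saved with the state
//   }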

size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
    return ctx->state_seq_get_size(seq_id);
}

size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
    ctx->synchronize();

    return ctx->state_seq_get_data(seq_id, dst, size);
}

size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
    ctx->synchronize();

    return ctx->state_seq_set_data(seq_id, src, size);
}

size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
    ctx->synchronize();

    try {
        return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
    ctx->synchronize();

    try {
        return ctx->state_seq_load_file(dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
        return 0;
    }
}

///

int32_t llama_encode(
        llama_context * ctx,
        llama_batch batch) {
    const int ret = ctx->encode(batch);
    if (ret != 0) {
        LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
    }

    return ret;
}

int32_t llama_decode(
        llama_context * ctx,
        llama_batch batch) {
    const int ret = ctx->decode(batch);
    if (ret != 0) {
        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
    }

    return ret;
}

//
// perf
//

llama_perf_context_data llama_perf_context(const llama_context * ctx) {
    llama_perf_context_data data = {};

    if (ctx == nullptr) {
        return data;
    }

    data = ctx->perf_get_data();

    return data;
}

void llama_perf_context_print(const llama_context * ctx) {
    const auto data = llama_perf_context(ctx);

    const double t_end_ms = 1e-3 * ggml_time_us();

    LLAMA_LOG_INFO("%s:        load time = %10.2f ms\n", __func__, data.t_load_ms);
    LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
            __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
    LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
            __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
    LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
}

void llama_perf_context_reset(llama_context * ctx) {
    ctx->perf_reset();
}