diff --git a/examples/talk-llama/CMakeLists.txt b/examples/talk-llama/CMakeLists.txt index e060ba7b..da190e33 100644 --- a/examples/talk-llama/CMakeLists.txt +++ b/examples/talk-llama/CMakeLists.txt @@ -17,6 +17,9 @@ if (WHISPER_SDL2) llama-impl.cpp llama-io.cpp llama-kv-cache.cpp + llama-kv-cache-unified.cpp + llama-kv-cache-unified-iswa.cpp + llama-kv-cache-recurrent.cpp llama-memory.cpp llama-mmap.cpp llama-model-loader.cpp diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index abf436ad..c0590e10 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -174,6 +174,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" }, { LLM_KV_CONVNEXT_BLOCK_COUNT, "%s.convnext.block_count" }, + { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, @@ -448,6 +450,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TOKEN_TYPES, "token_types" }, { LLM_TENSOR_POS_EMBD, "position_embd" }, { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 41a023da..930cb4ec 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -213,6 +213,8 @@ enum llm_kv { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, LLM_KV_CONVNEXT_BLOCK_COUNT, + LLM_KV_CLASSIFIER_OUTPUT_LABELS, + // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, diff --git a/examples/talk-llama/llama-batch.cpp b/examples/talk-llama/llama-batch.cpp index b98e3256..6a19a243 100644 --- a/examples/talk-llama/llama-batch.cpp +++ b/examples/talk-llama/llama-batch.cpp @@ -15,24 +15,31 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) { break; } } - ubatch_token.resize(!has_embd ? n_ubatch : 0); - ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0); - ubatch_pos.resize(n_ubatch); - ubatch_n_seq_id.resize(n_ubatch); - ubatch_seq_id.resize(n_ubatch); - ubatch_output.resize(n_ubatch); + + udatas.push_back({}); + + auto & udata = udatas.back(); + + udata.token.resize(!has_embd ? n_ubatch : 0); + udata.embd.resize(has_embd ? n_embd * n_ubatch : 0); + udata.pos.resize(n_ubatch); + udata.n_seq_id.resize(n_ubatch); + udata.seq_id.resize(n_ubatch); + udata.output.resize(n_ubatch); + llama_ubatch ubatch = { /*equal_seqs =*/ true, /*n_tokens =*/ 0, /*n_seq_tokens =*/ 0, /*n_seqs =*/ 0, - /*token =*/ !has_embd ? ubatch_token.data() : nullptr, - /*embd =*/ has_embd ? ubatch_embd.data() : nullptr, - /*pos =*/ ubatch_pos.data(), - /*n_seq_id =*/ ubatch_n_seq_id.data(), - /*seq_id =*/ ubatch_seq_id.data(), - /*output =*/ ubatch_output.data(), + /*token =*/ !has_embd ? udata.token.data() : nullptr, + /*embd =*/ has_embd ? udata.embd.data() : nullptr, + /*pos =*/ udata.pos.data(), + /*n_seq_id =*/ udata.n_seq_id.data(), + /*seq_id =*/ udata.seq_id.data(), + /*output =*/ udata.output.data(), }; + return ubatch; } diff --git a/examples/talk-llama/llama-batch.h b/examples/talk-llama/llama-batch.h index 6305051b..b8260b94 100644 --- a/examples/talk-llama/llama-batch.h +++ b/examples/talk-llama/llama-batch.h @@ -11,15 +11,15 @@ struct llama_ubatch { bool equal_seqs; // TODO: whole_seqs for embeddings? - uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) + uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) uint32_t n_seq_tokens; // tokens per sequence uint32_t n_seqs; llama_token * token; // [n_tokens] float * embd; // [n_embd, n_tokens] llama_pos * pos; // [n_tokens] - int32_t * n_seq_id; // [n_seqs] - llama_seq_id ** seq_id; // [n_seqs] + int32_t * n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence + llama_seq_id ** seq_id; // [n_seqs] // TODO: become llama_seq_id * seq_id; int8_t * output; // [n_tokens] }; @@ -49,13 +49,18 @@ struct llama_sbatch { const llama_batch * batch = nullptr; - // buffers for the ubatch - std::vector ubatch_token; - std::vector ubatch_embd; - std::vector ubatch_pos; - std::vector ubatch_n_seq_id; - std::vector ubatch_seq_id; - std::vector ubatch_output; + // buffers for the ubatches + // TODO: very hacky, this needs a complete rework + struct ubatch_data { + std::vector token; + std::vector embd; + std::vector pos; + std::vector n_seq_id; + std::vector seq_id; + std::vector output; + }; + + std::vector udatas; llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false); diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index e153351a..4ab57438 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -6,9 +6,10 @@ #include "llama-model.h" #include "llama-kv-cache.h" -#include -#include #include +#include +#include +#include // // llama_context @@ -122,6 +123,11 @@ llama_context::llama_context( __func__, n_ctx_per_seq, hparams.n_ctx_train); } + if (!params.swa_full && cparams.n_seq_max > 1) { + LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n", + __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573"); + } + if (!hparams.vocab_only) { // GPU backends for (auto * dev : model.devices) { @@ -259,15 +265,9 @@ llama_context::llama_context( // reserve worst-case graph if (!hparams.vocab_only && memory) { - const uint32_t n_seqs = 1; // TODO: worst-case number of sequences + const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - - // restore later - // TODO: something cleaner - const auto n_outputs_save = n_outputs; - LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); int n_splits_pp = -1; @@ -279,23 +279,17 @@ llama_context::llama_context( // simulate full KV cache llama_kv_cache * kv_self = static_cast(memory.get()); - kv_self->set_full(); + const auto kv_state = kv_self->init_full(); + if (!kv_state) { + throw std::runtime_error("failed to initialize KV cache"); + } cross.v_embd.clear(); // reserve pp graph first so that buffers are only allocated once { - llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - // max number of outputs - n_outputs = ubatch_pp.n_tokens; - - LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs); - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT); - - if (!ggml_backend_sched_reserve(sched.get(), gf)) { + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); + if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -305,16 +299,8 @@ llama_context::llama_context( // reserve with tg graph to get the number of splits and nodes { - llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - n_outputs = ubatch_tg.n_tokens; - - LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs); - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT); - - if (!ggml_backend_sched_reserve(sched.get(), gf)) { + auto * gf = graph_reserve(1, 1, 1, kv_state.get()); + if (!gf) { throw std::runtime_error("failed to allocate compute tg buffers"); } @@ -324,22 +310,12 @@ llama_context::llama_context( // reserve again with pp graph to avoid ggml-alloc reallocations during inference { - llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - n_outputs = ubatch_pp.n_tokens; - - LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs); - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT); - - if (!ggml_backend_sched_reserve(sched.get(), gf)) { + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); + if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } } - n_outputs = n_outputs_save; - for (size_t i = 0; i < backend_ptrs.size(); ++i) { ggml_backend_t backend = backend_ptrs[i]; ggml_backend_buffer_type_t buft = backend_buft[i]; @@ -453,36 +429,33 @@ const llama_kv_cache * llama_context::get_kv_self() const { return kv_self; } -void llama_context::kv_self_update() { - bool need_reserve = false; +bool llama_context::kv_self_update() { + if (!memory) { + return false; + } llama_kv_cache * kv_self = static_cast(memory.get()); - need_reserve = kv_self->update(*this); - - // reserve a worst case graph if needed - if (need_reserve) { - LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__); - - // build worst-case graph - uint32_t n_seqs = 1; // TODO: worst-case number of sequences - uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - - // simulate full KV cache - kv_self->set_full(); - - llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT); - - // initialize scheduler with the worst-case graph - ggml_backend_sched_reset(sched.get()); - if (!ggml_backend_sched_reserve(sched.get(), gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); - } + if (!kv_self->update(*this)) { + // no updates have been performed + return false; } + + // if the KV cache did any computation, we have to reserve a new worst-case graph + const auto kv_state = kv_self->init_full(); + if (!kv_state) { + throw std::runtime_error("failed to initialize KV cache"); + } + + const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); + if (!gf) { + LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__); + } + + return true; } enum llama_pooling_type llama_context::pooling_type() const { @@ -676,6 +649,49 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } +llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) { + if (mstate && !mstate->apply()) { + LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + auto * gf = graph_init(); + if (!gf) { + LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + + if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); + ret = GGML_STATUS_ALLOC_FAILED; + return nullptr; + } + + res->set_inputs(&ubatch); + + const auto status = graph_compute(gf, ubatch.n_tokens > 1); + if (status != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); + ret = status; + return nullptr; + } + + ret = GGML_STATUS_SUCCESS; + + return res; +} + int llama_context::encode(llama_batch & inp_batch) { if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); @@ -737,8 +753,6 @@ int llama_context::encode(llama_batch & inp_batch) { n_outputs = n_tokens; - //batch_manager->prepare(ubatch); - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); @@ -749,26 +763,18 @@ int llama_context::encode(llama_batch & inp_batch) { // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223 cparams.causal_attn = false; - auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER); - - ggml_backend_sched_alloc_graph(sched.get(), gf); - - res->set_inputs(&ubatch); + ggml_status status; + const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); cparams.causal_attn = causal_attn_org; - const auto compute_status = graph_compute(gf, n_tokens > 1); - switch (compute_status) { - case GGML_STATUS_SUCCESS: - break; - case GGML_STATUS_ABORTED: - return 2; - case GGML_STATUS_ALLOC_FAILED: - return -2; - case GGML_STATUS_FAILED: - default: - return -3; + if (!res) { + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); + } } auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); @@ -889,8 +895,6 @@ int llama_context::decode(llama_batch & inp_batch) { const int64_t n_tokens_all = batch.n_tokens; const int64_t n_embd = hparams.n_embd; - llama_kv_cache_guard kv_guard(kv_self); - GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT // TODO: move the validation to the llama_batch_allocr @@ -936,7 +940,48 @@ int llama_context::decode(llama_batch & inp_batch) { n_outputs_all = 1; } - llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all); + // handle any pending defrags/shifts + kv_self_update(); + + llama_memory_state_ptr kv_state; + + bool did_defrag = false; + + while (true) { + kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all); + if (!kv_state) { + return -2; + } + + switch (kv_state->get_status()) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + } break; + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + { + if (!did_defrag) { + did_defrag = true; + + kv_self->defrag_sched(-1.0f); + if (kv_self_update()) { + LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens); + + continue; + } + } + + LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens); + + return 1; + } + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + return -2; + } + } + + break; + } // reserve output buffer if (output_reserve(n_outputs_all) < n_outputs_all) { @@ -944,13 +989,10 @@ int llama_context::decode(llama_batch & inp_batch) { return -2; }; - // handle any pending defrags/shifts - kv_self_update(); - int64_t n_outputs_prev = 0; - while (sbatch.n_tokens > 0) { - llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled); + do { + const auto & ubatch = kv_state->get_ubatch(); // count the outputs in this u_batch { @@ -969,33 +1011,37 @@ int llama_context::decode(llama_batch & inp_batch) { n_outputs = n_outputs_new; } - // find KV slot - if (!kv_self->find_slot(ubatch)) { - return 1; - } - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); - auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER); + ggml_status status; + const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status); - // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + if (!res) { + // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache + llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits::max() }; - ggml_backend_sched_alloc_graph(sched.get(), gf); + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + const auto & seq_id = ubatch.seq_id[i][0]; - res->set_inputs(&ubatch); + pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]); + } - const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1); - if (compute_status != GGML_STATUS_SUCCESS) { - switch (compute_status) { - case GGML_STATUS_ABORTED: - return 2; - case GGML_STATUS_ALLOC_FAILED: - return -2; - case GGML_STATUS_FAILED: - default: - return -3; + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (pos_min[s] == std::numeric_limits::max()) { + continue; + } + + LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); + + llama_kv_self_seq_rm(this, s, pos_min[s], -1); + } + + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); } } @@ -1082,10 +1128,7 @@ int llama_context::decode(llama_batch & inp_batch) { } n_outputs_prev += n_outputs; - } - - // finalize the batch processing - kv_guard.commit(); + } while (kv_state->next()); // set to total number of outputs in the batch, for use in llama_get_logits_ith n_outputs = n_outputs_all; @@ -1094,7 +1137,7 @@ int llama_context::decode(llama_batch & inp_batch) { { bool sorted_output = true; - auto & out_ids = sbatch.out_ids; + auto & out_ids = kv_state->out_ids(); GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all); @@ -1254,11 +1297,52 @@ ggml_cgraph * llama_context::graph_init() { return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false); } +ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) { + LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); + + if (n_tokens % n_seqs != 0) { + n_tokens = (n_tokens / n_seqs) * n_seqs; + n_outputs = std::min(n_outputs, n_tokens); + + LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); + } + + // store the n_outputs as it is, and restore it afterwards + // TODO: not sure if needed, might simplify in the future by removing this + const auto save_n_outputs = this->n_outputs; + + this->n_outputs = n_outputs; + + llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + + auto * gf = graph_init(); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate); + + this->n_outputs = save_n_outputs; + + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__); + return nullptr; + } + + ggml_backend_sched_reset(sched.get()); + + // initialize scheduler with the specified graph + if (!ggml_backend_sched_reserve(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); + return nullptr; + } + + return gf; +} + llm_graph_result_ptr llama_context::graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - llm_graph_type gtype) { + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch, + llm_graph_type gtype, + const llama_memory_state_i * mstate) { return model.build_graph( { /*.ctx =*/ ctx, @@ -1270,7 +1354,7 @@ llm_graph_result_ptr llama_context::graph_build( /*.backend_cpu =*/ backend_cpu, /*.cvec =*/ &cvec, /*.loras =*/ &loras, - /*.memory =*/ memory.get(), + /*.mstate =*/ mstate, /*.cross =*/ &cross, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), @@ -1951,7 +2035,6 @@ void llama_context::opt_epoch_iter( llama_kv_cache * kv_self = static_cast(memory.get()); kv_self->clear(); - llama_kv_cache_guard kv_guard(kv_self); for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { batch.n_tokens = n_batch; @@ -1974,7 +2057,11 @@ void llama_context::opt_epoch_iter( int64_t n_outputs_all = n_tokens_all; - llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true); + auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true); + if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__); + break; + } // reserve output buffer if (output_reserve(n_outputs_all) < n_outputs_all) { @@ -1982,20 +2069,19 @@ void llama_context::opt_epoch_iter( GGML_ABORT("TODO: handle this error"); }; - for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) { - llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled); + uint32_t pos_batch = 0; + do { + const auto & ubatch = kv_state->get_ubatch(); n_outputs = ubatch.n_tokens; - // TODO: not sure if this is needed - if (!kv_self->find_slot(ubatch)) { - LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens); - - GGML_ABORT("TODO: handle this error"); + if (!kv_state->apply()) { + LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__); + break; } auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get()); struct ggml_context * ctx_compute_opt; { @@ -2010,6 +2096,7 @@ void llama_context::opt_epoch_iter( } ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits()); ggml_opt_alloc(opt_ctx, train); + res->set_inputs(&ubatch); { struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); @@ -2027,10 +2114,10 @@ void llama_context::opt_epoch_iter( callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start); } ggml_free(ctx_compute_opt); - } - } - kv_guard.commit(); + pos_batch += ubatch.n_tokens; + } while (kv_state->next()); + } } void llama_context::opt_epoch( @@ -2194,6 +2281,7 @@ llama_kv_cache * llama_get_kv_self(llama_context * ctx) { return ctx->get_kv_self(); } +// deprecated void llama_kv_self_update(llama_context * ctx) { ctx->kv_self_update(); } @@ -2448,6 +2536,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { return kv->seq_pos_max(seq_id); } +// deprecated void llama_kv_self_defrag(llama_context * ctx) { auto * kv = ctx->get_kv_self(); if (!kv) { @@ -2589,22 +2678,8 @@ int32_t llama_encode( int32_t llama_decode( llama_context * ctx, llama_batch batch) { - int ret = ctx->decode(batch); - - // defrag and try again - // TODO: distinguish return code when we are sure that even after defrag there is no space available - if (ret == 1) { - llama_kv_self_defrag(ctx); - ret = ctx->decode(batch); - - if (ret == 1) { - LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens); - - return ret; - } - } - - if (ret != 0) { + const int ret = ctx->decode(batch); + if (ret != 0 && ret != 1) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index c0ceacb1..3b880286 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -18,6 +18,9 @@ struct llama_kv_cache; class llama_io_read_i; class llama_io_write_i; +class llama_memory_i; +class llama_memory_state_i; + struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs llama_context( @@ -47,7 +50,9 @@ struct llama_context { llama_kv_cache * get_kv_self(); const llama_kv_cache * get_kv_self() const; - void kv_self_update(); + // return true of the KV cache was updated + // TODO: remove + bool kv_self_update(); enum llama_pooling_type pooling_type() const; @@ -88,6 +93,16 @@ struct llama_context { int32_t il_start, int32_t il_end); + // process a single ubatch with a specific graph type + // if memory_state is provided, it will be applied first to the context's memory + // ret contains the status of the graph computation + // returns nullptr only if ret != GGML_STATUS_SUCCESS + llm_graph_result_ptr process_ubatch( + const llama_ubatch & ubatch, + llm_graph_type gtype, + llama_memory_state_i * mstate, + ggml_status & ret); + int encode(llama_batch & inp_batch); int decode(llama_batch & inp_batch); @@ -180,16 +195,18 @@ public: ggml_cgraph * graph_init(); // returns the result of ggml_backend_sched_graph_compute_async execution - ggml_status graph_compute( - ggml_cgraph * gf, - bool batched); + ggml_status graph_compute(ggml_cgraph * gf, bool batched); + + // reserve a graph with a dummy ubatch of the specified size + ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate); private: llm_graph_result_ptr graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - llm_graph_type gtype); + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch, + llm_graph_type gtype, + const llama_memory_state_i * mstate); llm_graph_cb graph_get_cb() const; diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index cdd5887d..727e119e 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -3,7 +3,10 @@ #include "llama-impl.h" #include "llama-batch.h" #include "llama-cparams.h" -#include "llama-kv-cache.h" + +#include "llama-kv-cache-unified.h" +#include "llama-kv-cache-unified-iswa.h" +#include "llama-kv-cache-recurrent.h" #include #include @@ -83,7 +86,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) { void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) { if (pos_bucket) { - kv_self->set_input_pos_bucket(pos_bucket, ubatch); + kv_state->set_input_pos_bucket(pos_bucket, ubatch); } } @@ -234,7 +237,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); - const int64_t n_kv = kv_self->n; + const int64_t n_kv = kv_state->get_n_kv(); if (s_copy) { GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer)); @@ -242,7 +245,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n for (uint32_t i = 0; i < n_kv; ++i) { - data[i] = kv_self->s_copy(i); + data[i] = kv_state->s_copy(i); } } } @@ -250,7 +253,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); - const int64_t n_kv = kv_self->n; + const int64_t n_kv = kv_state->get_n_kv(); if (s_mask) { GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer)); @@ -258,7 +261,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) { // clear unused states for (int i = 0; i < n_kv; ++i) { - data[i] = kv_self->s_mask(i); + data[i] = kv_state->s_mask(i); } } } @@ -362,17 +365,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { if (self_kq_mask) { - kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } } void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { if (self_kq_mask) { - kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } if (self_kq_mask_swa) { - kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } } @@ -448,14 +451,14 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : backend_cpu (params.backend_cpu), cvec (params.cvec), loras (params.loras), - memory (params.memory), + mstate (params.mstate), cross (params.cross), cb_func (params.cb), res (std::make_unique()) { } int64_t llm_graph_context::n_pos_per_embd() const { - return arch == LLM_ARCH_QWEN2VL ? 4 : 1; + return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1; } void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const { @@ -954,11 +957,11 @@ ggml_tensor * llm_graph_context::build_inp_cls() const { } ggml_tensor * llm_graph_context::build_inp_s_copy() const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(kv_self); + auto inp = std::make_unique(kv_state); - const auto n_kv = kv_self->n; + const auto n_kv = kv_state->get_n_kv(); auto & cur = inp->s_copy; @@ -971,11 +974,11 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const { } ggml_tensor * llm_graph_context::build_inp_s_mask() const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(kv_self); + auto inp = std::make_unique(kv_state); - const auto n_kv = kv_self->n; + const auto n_kv = kv_state->get_n_kv(); auto & cur = inp->s_mask; @@ -1025,11 +1028,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const { } ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(hparams, kv_self); + auto inp = std::make_unique(hparams, kv_state); - const auto n_kv = kv_self->get_n(); + const auto n_kv = kv_state->get_n_kv(); auto & cur = inp->pos_bucket; @@ -1231,14 +1234,14 @@ ggml_tensor * llm_graph_context::build_attn( } llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(hparams, cparams, kv_self); + auto inp = std::make_unique(hparams, cparams, kv_state); { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); - const auto n_kv = kv_self->get_n(); + const auto n_kv = kv_state->get_n_kv(); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask, "KQ_mask", -1); @@ -1268,19 +1271,19 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); // store to KV cache { - ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il)); - ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); } const auto & kq_mask = inp->get_kq_mask(); ggml_tensor * q = q_cur; - ggml_tensor * k = kv_self->get_k(ctx0, il); - ggml_tensor * v = kv_self->get_v(ctx0, il); + ggml_tensor * k = kv_state->get_k(ctx0, il); + ggml_tensor * v = kv_state->get_v(ctx0, il); ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); @@ -1301,12 +1304,12 @@ ggml_tensor * llm_graph_context::build_attn( } llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { - const llama_kv_cache_unified_iswa * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(hparams, cparams, kv_self); + auto inp = std::make_unique(hparams, cparams, kv_state); { - const auto n_kv = kv_self->get_kv_base()->get_n(); + const auto n_kv = kv_state->get_base()->get_n_kv(); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask, "KQ_mask", -1); @@ -1318,7 +1321,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif { GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA"); - const auto n_kv = kv_self->get_kv_swa()->get_n(); + const auto n_kv = kv_state->get_swa()->get_n_kv(); inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1); @@ -1348,23 +1351,23 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); + const auto * kv_state_iswa = static_cast(mstate); + const bool is_swa = hparams.is_swa(il); - const llama_kv_cache_unified_iswa * kv_self = static_cast(memory); - - const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base(); + const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base(); // store to KV cache { - ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il)); - ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); } const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); ggml_tensor * q = q_cur; - ggml_tensor * k = kv->get_k(ctx0, il); - ggml_tensor * v = kv->get_v(ctx0, il); + ggml_tensor * k = kv_state->get_k(ctx0, il); + ggml_tensor * v = kv_state->get_v(ctx0, il); ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); @@ -1446,12 +1449,12 @@ ggml_tensor * llm_graph_context::build_copy_mask_state( ggml_tensor * state_mask, int32_t n_state, int32_t n_seqs) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - const auto n_kv = kv_self->n; - const auto kv_head = kv_self->head; + const auto n_kv = kv_state->get_n_kv(); + const auto kv_head = kv_state->get_head(); - ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size); + ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size()); // copy states // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv @@ -1478,13 +1481,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); const auto token_shift_count = hparams.token_shift_count; const int64_t n_seqs = ubatch.n_seqs; - ggml_tensor * token_shift_all = kv_self->k_l[il]; + ggml_tensor * token_shift_all = kv_state->get_k_l(il); ggml_tensor * token_shift = build_copy_mask_state( gf, token_shift_all, state_copy, state_mask, @@ -1499,19 +1502,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( ggml_tensor * token_shift, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); const auto token_shift_count = hparams.token_shift_count; const auto n_embd = hparams.n_embd; const int64_t n_seqs = ubatch.n_seqs; - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); return ggml_cpy( ctx0, ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0), - ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il])) + ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il))) ); } @@ -1562,20 +1565,25 @@ void llm_graph_context::build_pooling( ggml_tensor * inp_cls = build_inp_cls(); inp = ggml_get_rows(ctx0, inp, inp_cls); - // classification head - // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 - GGML_ASSERT(cls != nullptr); - GGML_ASSERT(cls_b != nullptr); + if (cls != nullptr && cls_b != nullptr) { + // classification head + // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b); + cur = ggml_tanh(ctx0, cur); - cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b); - cur = ggml_tanh(ctx0, cur); - - // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en - // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896 - if (cls_out) { + // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896 + if (cls_out) { + GGML_ASSERT(cls_out_b != nullptr); + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b); + } + } else if (cls_out) { + // Single layer classification head (direct projection) + // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476 GGML_ASSERT(cls_out_b != nullptr); - - cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b); + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b); + } else { + GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b"); } } break; default: diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 2b85bb25..d1c5dd1b 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -17,10 +17,11 @@ struct ggml_tensor; struct llama_ubatch; struct llama_cparams; -class llama_memory_i; -class llama_kv_cache_unified; -class llama_kv_cache_unified_iswa; -class llama_kv_cache_recurrent; +class llama_memory_state_i; + +class llama_kv_cache_unified_state; +class llama_kv_cache_unified_iswa_state; +class llama_kv_cache_recurrent_state; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -133,7 +134,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i { public: llm_graph_input_pos_bucket_kv( const llama_hparams & hparams, - const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {} + const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {} virtual ~llm_graph_input_pos_bucket_kv() = default; void set_input(const llama_ubatch * ubatch) override; @@ -141,7 +142,7 @@ public: ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch] const llama_hparams & hparams; - const llama_kv_cache_unified * kv_self; + const llama_kv_cache_unified_state * kv_state; }; class llm_graph_input_out_ids : public llm_graph_input_i { @@ -188,26 +189,26 @@ public: class llm_graph_input_s_copy : public llm_graph_input_i { public: - llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {} + llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {} virtual ~llm_graph_input_s_copy() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * s_copy; // I32 [kv_size] - const llama_kv_cache_recurrent * kv_self; + const llama_kv_cache_recurrent_state * kv_state; }; class llm_graph_input_s_mask : public llm_graph_input_i { public: - llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {} + llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {} virtual ~llm_graph_input_s_mask() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * s_mask; // F32 [1, n_kv] - const llama_kv_cache_recurrent * kv_self; + const llama_kv_cache_recurrent_state * kv_state; }; class llm_graph_input_cross_embd : public llm_graph_input_i { @@ -247,10 +248,10 @@ public: llm_graph_input_attn_kv_unified( const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache_unified * kv_self) : + const llama_kv_cache_unified_state * kv_state) : hparams(hparams), cparams(cparams), - kv_self(kv_self) { + kv_state(kv_state) { } ~llm_graph_input_attn_kv_unified() = default; @@ -264,7 +265,7 @@ public: const llama_hparams & hparams; const llama_cparams & cparams; - const llama_kv_cache_unified * kv_self; + const llama_kv_cache_unified_state * kv_state; }; class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { @@ -272,10 +273,10 @@ public: llm_graph_input_attn_kv_unified_iswa( const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache_unified_iswa * kv_self) : + const llama_kv_cache_unified_iswa_state * kv_state) : hparams(hparams), cparams(cparams), - kv_self(kv_self) { + kv_state(kv_state) { } ~llm_graph_input_attn_kv_unified_iswa() = default; @@ -292,7 +293,7 @@ public: const llama_hparams & hparams; const llama_cparams & cparams; - const llama_kv_cache_unified_iswa * kv_self; + const llama_kv_cache_unified_iswa_state * kv_state; }; class llm_graph_input_attn_cross : public llm_graph_input_i { @@ -383,10 +384,10 @@ struct llm_graph_params { ggml_backend_sched_t sched; ggml_backend_t backend_cpu; - const llama_adapter_cvec * cvec; - const llama_adapter_loras * loras; - const llama_memory_i * memory; - const llama_cross * cross; + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_state_i * mstate; + const llama_cross * cross; int32_t n_outputs; @@ -435,10 +436,10 @@ struct llm_graph_context { ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? - const llama_adapter_cvec * cvec; - const llama_adapter_loras * loras; - const llama_memory_i * memory; - const llama_cross * cross; + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_state_i * mstate; + const llama_cross * cross; const llm_graph_cb & cb_func; diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 2d72eab1..b2bcb8b0 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -131,6 +131,9 @@ struct llama_hparams { bool attn_soft_cap = false; bool use_kq_norm = true; + // for Classifiers + uint32_t n_cls_out = 1; + // llama4 uint32_t n_moe_layer_step = 0; uint32_t n_no_rope_layer_step = 4; diff --git a/examples/talk-llama/llama-kv-cache-recurrent.cpp b/examples/talk-llama/llama-kv-cache-recurrent.cpp new file mode 100644 index 00000000..641eab2f --- /dev/null +++ b/examples/talk-llama/llama-kv-cache-recurrent.cpp @@ -0,0 +1,1132 @@ +#include "llama-kv-cache-recurrent.h" + +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-model.h" + +#include +#include +#include +#include +#include + +// +// llama_kv_cache_recurrent +// + +llama_kv_cache_recurrent::llama_kv_cache_recurrent( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { + const int32_t n_layer = hparams.n_layer; + + LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n", + __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer); + + head = 0; + size = kv_size; + used = 0; + + cells.clear(); + cells.resize(kv_size); + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + + ctx_map[buft] = ctx; + ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + k_l.reserve(n_layer); + v_l.reserve(n_layer); + + for (int i = 0; i < n_layer; i++) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + + const char * dev_name = "CPU"; + + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); + + if (offload) { + auto * dev = model.dev_layer(i); + buft = ggml_backend_dev_buffer_type(dev); + + dev_name = ggml_backend_dev_name(dev); + } + + LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name); + + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { + throw std::runtime_error("failed to create ggml context for kv cache"); + } + + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + k_l.push_back(k); + v_l.push_back(v); + } + + // allocate tensors and initialize the buffers to avoid NaNs in the padding + for (auto it : ctx_map) { + auto * buft = it.first; + auto * ctx = it.second; + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + throw std::runtime_error("failed to allocate buffer for kv cache"); + } + ggml_backend_buffer_clear(buf, 0); + LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + bufs.emplace_back(buf); + } + + { + const size_t memory_size_k = size_k_bytes(); + const size_t memory_size_v = size_v_bytes(); + + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } +} + +void llama_kv_cache_recurrent::clear() { + for (int32_t i = 0; i < (int32_t) size; ++i) { + cells[i].pos = -1; + cells[i].seq_id.clear(); + cells[i].src = -1; + cells[i].tail = -1; + } + head = 0; + used = 0; + + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } +} + +bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + uint32_t new_head = size; + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // models like Mamba or RWKV can't have a state partially erased + if (seq_id >= (int64_t) size) { + // could be fatal + return false; + } + if (0 <= seq_id) { + int32_t & tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + const kv_cell & cell = cells[tail_id]; + // partial intersection is invalid + if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + return false; + } + // invalidate tails which will be cleared + if (p0 <= cell.pos && cell.pos < p1) { + tail_id = -1; + } + } + } else { + // seq_id is negative, then the range should include everything or nothing + if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + return false; + } + } + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].pos >= p0 && cells[i].pos < p1) { + if (seq_id < 0) { + cells[i].seq_id.clear(); + } else if (cells[i].has_seq_id(seq_id)) { + cells[i].seq_id.erase(seq_id); + } else { + continue; + } + if (cells[i].is_empty()) { + // keep count of the number of used cells + if (cells[i].pos >= 0) { + used--; + } + cells[i].pos = -1; + cells[i].src = -1; + if (new_head == size) { + new_head = i; + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != size && new_head < head) { + head = new_head; + } + + return true; +} + +void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (seq_id_src == seq_id_dst) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { + kv_cell & tail_src = cells[seq_id_src]; + kv_cell & tail_dst = cells[seq_id_dst]; + if (tail_dst.tail >= 0) { + // clear destination seq_id if it wasn't empty + kv_cell & cell_dst = cells[tail_dst.tail]; + + cell_dst.seq_id.erase(seq_id_dst); + tail_dst.tail = -1; + if (cell_dst.seq_id.empty()) { + cell_dst.pos = -1; + cell_dst.src = -1; + used -= 1; + } + } + if (tail_src.tail >= 0) { + kv_cell & cell_src = cells[tail_src.tail]; + + cell_src.seq_id.insert(seq_id_dst); + tail_dst.tail = tail_src.tail; + } + } +} + +void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) { + uint32_t new_head = size; + + for (uint32_t i = 0; i < size; ++i) { + if ((llama_seq_id) i != seq_id) { + cells[i].tail = -1; + } + + if (!cells[i].has_seq_id(seq_id)) { + if (cells[i].pos >= 0) { + used--; + } + + cells[i].pos = -1; + cells[i].src = -1; + cells[i].seq_id.clear(); + + if (new_head == size){ + new_head = i; + } + } else { + cells[i].seq_id.clear(); + cells[i].seq_id.insert(seq_id); + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != size && new_head < head) { + head = new_head; + } +} + +void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + if (shift == 0) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the + if (p0 == p1) { + return; + } + + // for Mamba-like or RWKV models, only the pos needs to be shifted + if (0 <= seq_id && seq_id < (int64_t) size) { + const int32_t tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + kv_cell & cell = cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos += shift; + } + } + } +} + +void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (d == 1) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) { + return; + } + + // for Mamba-like or RWKV models, only the pos needs to be changed + if (0 <= seq_id && seq_id < (int64_t) size) { + const int32_t tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + kv_cell & cell = cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos /= d; + } + } + } +} + +llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const { + llama_pos result = std::numeric_limits::max(); + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + result = std::min(result, cells[i].pos); + } + } + + if (result == std::numeric_limits::max()) { + result = -1; + } + + return result; +} + +llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { + llama_pos result = -1; + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + result = std::max(result, cells[i].pos); + } + } + + return result; +} + +llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { + GGML_UNUSED(embd_pooled); + + auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all); + + std::vector ubatches; + + while (sbatch.n_tokens > 0) { + llama_ubatch ubatch; + + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + ubatch = sbatch.split_seq(n_ubatch); + } else { + ubatch = sbatch.split_equal(n_ubatch); + } + + ubatches.push_back(ubatch); + } + + if (!prepare(ubatches)) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches)); +} + +llama_memory_state_ptr llama_kv_cache_recurrent::init_full() { + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); +} + +bool llama_kv_cache_recurrent::prepare(const std::vector & ubatches) { + // simply remember the full state because it is very small for this type of cache + // TODO: optimize + auto org_cells = cells; + auto org_used = used; + auto org_head = head; + + bool success = true; + + // TODO: here we have to verify that all ubatches can fit in the cells + // however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells + // during the compute of each ubatch. to reproduce, uncomment the following loop and run: + // + // $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8 + // + // recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed + // + GGML_UNUSED(ubatches); + //for (const auto & ubatch : ubatches) { + // if (!find_slot(ubatch)) { + // success = false; + // break; + // } + //} + + // restore the original state + cells = std::move(org_cells); + used = org_used; + head = org_head; + + return success; +} + +bool llama_kv_cache_recurrent::update(llama_context & lctx) { + GGML_UNUSED(lctx); + // noop + return false; +} + +void llama_kv_cache_recurrent::defrag_sched(float thold) { + GGML_UNUSED(thold); + // noop +} + +bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { + const uint32_t n_tokens = ubatch.n_tokens; + const uint32_t n_seqs = ubatch.n_seqs; + + const uint32_t n_seq_tokens = ubatch.n_seq_tokens; + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head > used + 2*n_tokens) { + head = 0; + } + + // For recurrent state architectures (like Mamba or RWKV), + // each cache cell can store the state for a whole sequence. + // A slot should be always be contiguous. + + // can only process batches with an equal number of new tokens in each sequence + GGML_ASSERT(ubatch.equal_seqs); + + int32_t min = size - 1; + int32_t max = 0; + + // everything should fit if all seq_ids are smaller than the max + for (uint32_t s = 0; s < n_seqs; ++s) { + const uint32_t n_seq_id = ubatch.n_seq_id[s]; + for (uint32_t j = 0; j < n_seq_id; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + + if (seq_id < 0 || (uint32_t) seq_id >= size) { + // too big seq_id + // TODO: would it be possible to resize the cache instead? + LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max); + return false; + } + if (j > 0) { + kv_cell & seq = cells[seq_id]; + if (seq.tail >= 0) { + kv_cell & cell = cells[seq.tail]; + // clear cells from seq_ids that become shared + // (should not normally happen, but let's handle it anyway) + cell.seq_id.erase(seq_id); + seq.tail = -1; + if (cell.seq_id.empty()) { + cell.pos = -1; + cell.src = -1; + used -= 1; + } + } + } + } + } + +#ifndef NDEBUG + { + std::vector tails_verif; + tails_verif.assign(size, -1); + for (uint32_t i = 0; i < size; ++i) { + kv_cell & cell = cells[i]; + for (llama_seq_id seq_id : cell.seq_id) { + if (tails_verif[seq_id] != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); + } + tails_verif[seq_id] = i; + } + } + for (uint32_t i = 0; i < size; ++i) { + if (tails_verif[i] != cells[i].tail) { + LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]); + } + } + } +#endif + + // find next empty cell + uint32_t next_empty_cell = head; + + for (uint32_t i = 0; i < size; ++i) { + if (next_empty_cell >= size) { next_empty_cell -= size; } + kv_cell & cell = cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + + // find usable cell range + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + kv_cell & seq_meta = cells[seq_id]; + bool has_cell = false; + if (seq_meta.tail >= 0) { + kv_cell & cell = cells[seq_meta.tail]; + GGML_ASSERT(cell.has_seq_id(seq_id)); + // does this seq_id "own" the cell? + if (cell.seq_id.size() == 1) { has_cell = true; } + } + if (!has_cell) { + kv_cell & empty_cell = cells[next_empty_cell]; + GGML_ASSERT(empty_cell.is_empty()); + // copy old tail into the empty cell + if (seq_meta.tail >= 0) { + kv_cell & orig_cell = cells[seq_meta.tail]; + empty_cell.pos = orig_cell.pos; + empty_cell.src = orig_cell.src; + orig_cell.seq_id.erase(seq_id); + empty_cell.seq_id.insert(seq_id); // will be overwritten + } + seq_meta.tail = next_empty_cell; + // find next empty cell + if (s + 1 < n_seqs) { + next_empty_cell += 1; + for (uint32_t i = 0; i < size; ++i) { + if (next_empty_cell >= size) { next_empty_cell -= size; } + kv_cell & cell = cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + } + } + if (min > seq_meta.tail) { min = seq_meta.tail; } + if (max < seq_meta.tail) { max = seq_meta.tail; } + } + + // gather and re-order + for (uint32_t s = 0; s < n_seqs; ++s) { + int32_t dst_id = s + min; + int32_t src_id = cells[ubatch.seq_id[s][0]].tail; + if (dst_id != src_id) { + kv_cell & dst_cell = cells[dst_id]; + kv_cell & src_cell = cells[src_id]; + + std::swap(dst_cell.pos, src_cell.pos); + std::swap(dst_cell.src, src_cell.src); + std::swap(dst_cell.seq_id, src_cell.seq_id); + + // swap tails (assuming they NEVER overlap) + for (const llama_seq_id seq_id : src_cell.seq_id) { + cells[seq_id].tail = src_id; + } + for (const llama_seq_id seq_id : dst_cell.seq_id) { + cells[seq_id].tail = dst_id; + } + } + } + + // update the pos of the used seqs + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + int32_t cell_id = s + min; + kv_cell & cell = cells[cell_id]; + + if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", + __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); + } + cell.pos = last_pos; + cell.seq_id.clear(); + for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + cell.seq_id.insert(seq_id); + cells[seq_id].tail = cell_id; + } + } + + // allow getting the range of used cells, from head to head + n + head = min; + n = max - min + 1; + used = std::count_if(cells.begin(), cells.end(), + [](const kv_cell & cell){ return !cell.is_empty(); }); + + // sanity check + return n >= n_seqs; +} + +bool llama_kv_cache_recurrent::get_can_shift() const { + return false; +} + +int32_t llama_kv_cache_recurrent::s_copy(int i) const { + const uint32_t cell_id = i + head; + + ////////////////////////////////////////////// + // TODO: this should not mutate the KV cache ! + kv_cell & cell = const_cast(cells[cell_id]); + + // prevent out-of-bound sources + if (cell.src < 0 || (uint32_t) cell.src >= size) { + cell.src = cell_id; + } + + int32_t res = cell.src; + + // TODO: do not mutate the KV cache + // ensure copy only happens once + if (cell.src != (int32_t) cell_id) { + cell.src = cell_id; + } + + return res; +} + +float llama_kv_cache_recurrent::s_mask(int i) const { + const uint32_t cell_id = i + head; + + ////////////////////////////////////////////// + // TODO: this should not mutate the KV cache ! + kv_cell & cell = const_cast(cells[cell_id]); + + float res = (float) (cell.src >= 0); + + // only clear once + if (cell.src < 0) { + cell.src = cell_id; + } + + return res; +} + +size_t llama_kv_cache_recurrent::total_size() const { + size_t size = 0; + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; +} + +size_t llama_kv_cache_recurrent::size_k_bytes() const { + size_t size_k_bytes = 0; + + for (const auto & k : k_l) { + size_k_bytes += ggml_nbytes(k); + } + + return size_k_bytes; +} + +size_t llama_kv_cache_recurrent::size_v_bytes() const { + size_t size_v_bytes = 0; + + for (const auto & v : v_l) { + size_v_bytes += ggml_nbytes(v); + } + + return size_v_bytes; +} + +void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + std::vector> cell_ranges; // ranges, from inclusive, to exclusive + uint32_t cell_count = 0; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = size; + for (uint32_t i = 0; i < size; ++i) { + const auto & cell = cells[i]; + if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { + ++cell_count; + if (cell_range_begin == size) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != size) { + cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = size; + } + } + } + if (cell_range_begin != size) { + cell_ranges.emplace_back(cell_range_begin, size); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + + io.write(&cell_count, sizeof(cell_count)); + + state_write_meta(io, cell_ranges, seq_id); + state_write_data(io, cell_ranges); +} + +void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + uint32_t cell_count; + io.read_to(&cell_count, sizeof(cell_count)); + + bool res = true; + + res = res && state_read_meta(io, cell_count, seq_id); + res = res && state_read_data(io, cell_count); + + if (!res) { + if (seq_id == -1) { + clear(); + } else { + seq_rm(seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); + } +} + +void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + const auto & cell = cells[i]; + const llama_pos pos = cell.pos; + const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; + + io.write(&pos, sizeof(pos)); + io.write(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id) { + for (auto seq_id : cell.seq_id) { + io.write(&seq_id, sizeof(seq_id)); + } + } + } + } +} + +void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { + const uint32_t v_trans = 0; + const uint32_t n_layer = hparams.n_layer; + + io.write(&v_trans, sizeof(v_trans)); + io.write(&n_layer, sizeof(n_layer)); + + std::vector tmp_buf; + + // Iterate and write all the keys first, each row is a cell + // Get whole range at a time + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Write key type + const int32_t k_type_i = (int32_t)k_l[il]->type; + io.write(&k_type_i, sizeof(k_type_i)); + + // Write row size of key + const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + io.write(&k_size_row, sizeof(k_size_row)); + + // Read each range of cells of k_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * k_size_row; + io.write_tensor(k_l[il], range.first * k_size_row, buf_size); + } + } + + if (!v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)v_l[il]->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write row size of value + const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + io.write(&v_size_row, sizeof(v_size_row)); + + // Read each range of cells of v_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * v_size_row; + io.write_tensor(v_l[il], range.first * v_size_row, buf_size); + } + } + } else { + // When v is transposed, we also need the element size and get the element ranges from each row + const uint32_t kv_size = size; + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)v_l[il]->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const uint32_t v_size_el = ggml_type_size(v_l[il]->type); + io.write(&v_size_el, sizeof(v_size_el)); + + // Write GQA embedding size + io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + const size_t buf_size = range_size * v_size_el; + io.write_tensor(v_l[il], src_offset, buf_size); + } + } + } + } +} + +bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { + if (dest_seq_id != -1) { + // single sequence + + seq_rm(dest_seq_id, -1, -1); + + llama_sbatch sbatch; + llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); + + batch.n_tokens = cell_count; + batch.n_seq_tokens = cell_count; + batch.n_seqs = 1; + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id != 0) { + LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + return false; + } + + batch.pos[i] = pos; + } + batch.n_seq_id[0] = 1; + batch.seq_id[0] = &dest_seq_id; + + if (!find_slot(batch)) { + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + return false; + } + + // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(head + cell_count <= size); + GGML_ASSERT(cells[head].pos == batch.pos[0]); + GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); + GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); + GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); + } else { + // whole KV cache restore + + if (cell_count > size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } + + clear(); + + for (uint32_t i = 0; i < cell_count; ++i) { + kv_cell & cell = cells[i]; + + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + cell.pos = pos; + + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + + // TODO: llama_kv_cache_recurrent should have a notion of max sequences + //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { + if (seq_id < 0) { + //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); + return false; + } + + cell.seq_id.insert(seq_id); + + int32_t & tail = cells[seq_id].tail; + if (tail != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); + return false; + } + tail = i; + } + } + + head = 0; + used = cell_count; + } + + for (uint32_t i = 0; i < cell_count; ++i) { + uint32_t cell_id = head + i; + // make sure the recurrent states will keep their restored state + cells[cell_id].src = cell_id; + } + + return true; +} + +bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) { + uint32_t v_trans; + uint32_t n_layer; + io.read_to(&v_trans, sizeof(v_trans)); + io.read_to(&n_layer, sizeof(n_layer)); + + if (n_layer != hparams.n_layer) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); + return false; + } + if (cell_count > size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); + return false; + } + if (false != (bool) v_trans) { + LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); + return false; + } + + // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Read type of key + int32_t k_type_i_ref; + io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + const int32_t k_type_i = (int32_t) k_l[il]->type; + if (k_type_i != k_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t k_size_row_ref; + io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + if (k_size_row != k_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + } + } + + if (!v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read row size of value + uint64_t v_size_row_ref; + io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read element size of value + uint32_t v_size_el_ref; + io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + const size_t v_size_el = ggml_type_size(v_l[il]->type); + if (v_size_el != v_size_el_ref) { + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); + return false; + } + + // Read GQA embedding size + uint32_t n_embd_v_gqa_ref; + io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); + return false; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (head + j * size) * v_size_el; + ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + } + } + } + } + + return true; +} + +// +// llama_kv_cache_recurrent_state +// + +llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) { +} + +llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv, + llama_sbatch sbatch, + std::vector ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {} + +llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default; + +bool llama_kv_cache_recurrent_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_recurrent_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + kv->find_slot(ubatches[i_next]); + + return true; +} + +std::vector & llama_kv_cache_recurrent_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_recurrent_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +uint32_t llama_kv_cache_recurrent_state::get_n_kv() const { + return is_full ? kv->size : kv->n; +} + +uint32_t llama_kv_cache_recurrent_state::get_head() const { + return is_full ? 0 : kv->head; +} + +uint32_t llama_kv_cache_recurrent_state::get_size() const { + return kv->size; +} + +ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const { + return kv->k_l[il]; +} + +ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const { + return kv->v_l[il]; +} + +int32_t llama_kv_cache_recurrent_state::s_copy(int i) const { + return kv->s_copy(i); +} + +float llama_kv_cache_recurrent_state::s_mask(int i) const { + return kv->s_mask(i); +} diff --git a/examples/talk-llama/llama-kv-cache-recurrent.h b/examples/talk-llama/llama-kv-cache-recurrent.h new file mode 100644 index 00000000..a178ae85 --- /dev/null +++ b/examples/talk-llama/llama-kv-cache-recurrent.h @@ -0,0 +1,191 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-kv-cache.h" + +#include +#include + +// +// llama_kv_cache_recurrent +// + +// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i +// see the implementation of llama_kv_cache_unified_state_i for an example how to do it +class llama_kv_cache_recurrent : public llama_kv_cache { +public: + llama_kv_cache_recurrent( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max); + + ~llama_kv_cache_recurrent() = default; + + // + // llama_memory_i + // + + void clear() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // + // llama_kv_cache + // + + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; + + llama_memory_state_ptr init_full() override; + + bool update(llama_context & lctx) override; + + void defrag_sched(float thold) override; + + bool prepare(const std::vector & ubatches); + + // find a contiguous slot of kv cells and emplace the ubatch there + bool find_slot(const llama_ubatch & ubatch); + + bool get_can_shift() const override; + + // TODO: temporary methods - they are not really const as they do const_cast<>, fix this + int32_t s_copy(int i) const; + float s_mask(int i) const; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) + uint32_t size = 0; // total number of cells, shared across all sequences + uint32_t used = 0; // used cells (i.e. at least one seq_id) + + // computed before each graph build + uint32_t n = 0; + + // TODO: optimize for recurrent state needs + struct kv_cell { + llama_pos pos = -1; + int32_t src = -1; // used to copy states + int32_t tail = -1; + + std::set seq_id; + + bool has_seq_id(const llama_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } + + bool is_empty() const { + return seq_id.empty(); + } + + bool is_same_seq(const kv_cell & other) const { + return seq_id == other.seq_id; + } + }; + + std::vector cells; + + std::vector k_l; // per layer + std::vector v_l; + +private: + //const llama_model & model; + const llama_hparams & hparams; + + const uint32_t n_seq_max = 1; + + std::vector ctxs; + std::vector bufs; + + size_t total_size() const; + + size_t size_k_bytes() const; + size_t size_v_bytes() const; + + void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; + + bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t cell_count); +}; + +class llama_kv_cache_recurrent_state : public llama_memory_state_i { +public: + // used for errors + llama_kv_cache_recurrent_state(llama_memory_status status); + + // used to create a full-cache state + llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv); + + // used to create a state from a batch + llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv, + llama_sbatch sbatch, + std::vector ubatches); + + virtual ~llama_kv_cache_recurrent_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_recurrent_state specific API + // + + uint32_t get_n_kv() const; + uint32_t get_head() const; + uint32_t get_size() const; + + ggml_tensor * get_k_l(int32_t il) const; + ggml_tensor * get_v_l(int32_t il) const; + + int32_t s_copy(int i) const; + float s_mask(int i) const; + +private: + const llama_memory_status status; + + llama_kv_cache_recurrent * kv; + + llama_sbatch sbatch; + + size_t i_next = 0; + + std::vector ubatches; + + // + // data needed for building the compute graph for the current ubatch: + // TODO: extract all the state like `head` and `n` here + // + + const bool is_full = false; +}; diff --git a/examples/talk-llama/llama-kv-cache-unified-iswa.cpp b/examples/talk-llama/llama-kv-cache-unified-iswa.cpp new file mode 100644 index 00000000..0eb04563 --- /dev/null +++ b/examples/talk-llama/llama-kv-cache-unified-iswa.cpp @@ -0,0 +1,249 @@ +#include "llama-kv-cache-unified-iswa.h" + +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-model.h" + +#include +#include + +// +// llama_kv_cache_unified_iswa +// + +llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool swa_full, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_ubatch, + uint32_t n_pad) : hparams(model.hparams) { + llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; + llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; + + const uint32_t size_base = kv_size; + + uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad)); + + // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size + if (swa_full) { + LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n", + __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + + size_swa = size_base; + } + + LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); + + kv_base = std::make_unique( + model, std::move(filter_base), type_k, type_v, + v_trans, offload, size_base, n_seq_max, n_pad, + 0, LLAMA_SWA_TYPE_NONE); + + LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); + + kv_swa = std::make_unique( + model, std::move(filter_swa), type_k, type_v, + v_trans, offload, size_swa, n_seq_max, n_pad, + hparams.n_swa, hparams.swa_type); +} + +void llama_kv_cache_unified_iswa::clear() { + kv_base->clear(); + kv_swa ->clear(); +} + +bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + bool res = true; + + res = res & kv_base->seq_rm(seq_id, p0, p1); + res = res & kv_swa ->seq_rm(seq_id, p0, p1); + + return res; +} + +void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) { + kv_base->seq_keep(seq_id); + kv_swa ->seq_keep(seq_id); +} + +void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + kv_base->seq_add(seq_id, p0, p1, shift); + kv_swa ->seq_add(seq_id, p0, p1, shift); +} + +void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + kv_base->seq_div(seq_id, p0, p1, d); + kv_swa ->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const { + // the base cache is a superset of the SWA cache, so we can just check the SWA cache + return kv_swa->seq_pos_min(seq_id); +} + +llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { + return kv_swa->seq_pos_max(seq_id); +} + +llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { + GGML_UNUSED(embd_pooled); + + // TODO: if we fail with split_simple, we should attempt different splitting strategies + // but to do that properly, we first have to refactor the batches to be more flexible + + auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); + + std::vector ubatches; + + while (sbatch.n_tokens > 0) { + auto ubatch = sbatch.split_simple(n_ubatch); + + ubatches.push_back(ubatch); + } + + auto heads_base = kv_base->prepare(ubatches); + if (heads_base.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + auto heads_swa = kv_swa->prepare(ubatches); + if (heads_swa.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + assert(heads_base.size() == heads_swa.size()); + + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, + this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches)); +} + +llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() { + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); +} + +bool llama_kv_cache_unified_iswa::update(llama_context & lctx) { + bool res = false; + + res = res | kv_base->update(lctx); + res = res | kv_swa ->update(lctx); + + return res; +} + +void llama_kv_cache_unified_iswa::defrag_sched(float thold) { + kv_base->defrag_sched(thold); + kv_swa ->defrag_sched(thold); +} + +bool llama_kv_cache_unified_iswa::get_can_shift() const { + return kv_base->get_size() == kv_swa->get_size(); +} + +void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + kv_base->state_write(io, seq_id); + kv_swa ->state_write(io, seq_id); +} + +void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + kv_base->state_read(io, seq_id); + kv_swa ->state_read(io, seq_id); +} + +llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const { + return kv_base.get(); +} + +llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const { + return kv_swa.get(); +} + +// +// llama_kv_cache_unified_iswa_state +// + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv) : status(status) { + state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base())); + state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ())); +} + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv, + llama_sbatch sbatch, + std::vector heads_base, + std::vector heads_swa, + std::vector ubatches) + : status(status), + sbatch(std::move(sbatch)), + ubatches(std::move(ubatches)) { + // note: here we copy the ubatches. not sure if this is ideal + state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches)); + state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches)); + } + +llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default; + +bool llama_kv_cache_unified_iswa_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + state_base->next(); + state_swa ->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_unified_iswa_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + bool res = true; + + res = res & state_base->apply(); + res = res & state_swa ->apply(); + + return res; +} + +std::vector & llama_kv_cache_unified_iswa_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + return ubatches[i_next]; +} + +const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return state_base.get(); +} + +const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return state_swa.get(); +} diff --git a/examples/talk-llama/llama-kv-cache-unified-iswa.h b/examples/talk-llama/llama-kv-cache-unified-iswa.h new file mode 100644 index 00000000..8b067da0 --- /dev/null +++ b/examples/talk-llama/llama-kv-cache-unified-iswa.h @@ -0,0 +1,136 @@ +#pragma once + +#include "llama-kv-cache-unified.h" + +#include + +// +// llama_kv_cache_unified_iswa +// + +// utilizes two instances of llama_kv_cache_unified +// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers + +class llama_kv_cache_unified_iswa : public llama_kv_cache { +public: + llama_kv_cache_unified_iswa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool swa_full, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_ubatch, + uint32_t n_pad); + + ~llama_kv_cache_unified_iswa() = default; + + // + // llama_memory_i + // + + void clear() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // + // llama_kv_cache + // + + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; + + llama_memory_state_ptr init_full() override; + + bool update(llama_context & lctx) override; + + void defrag_sched(float thold) override; + + bool get_can_shift() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + // + // llama_kv_cache_unified_iswa specific API + // + + llama_kv_cache_unified * get_base() const; + llama_kv_cache_unified * get_swa () const; + +private: + const llama_hparams & hparams; + + std::unique_ptr kv_base; + std::unique_ptr kv_swa; +}; + +class llama_kv_cache_unified_iswa_state : public llama_memory_state_i { +public: + // used for errors + llama_kv_cache_unified_iswa_state(llama_memory_status status); + + // used to create a full-cache state + llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv); + + // used to create a state from a batch + llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv, + llama_sbatch sbatch, + std::vector heads_base, + std::vector heads_swa, + std::vector ubatches); + + virtual ~llama_kv_cache_unified_iswa_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_unified_iswa_state specific API + // + + const llama_kv_cache_unified_state * get_base() const; + const llama_kv_cache_unified_state * get_swa() const; + +private: + const llama_memory_status status; + + //llama_kv_cache_unified_iswa * kv; + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector ubatches; + + std::unique_ptr state_base; + std::unique_ptr state_swa; +}; diff --git a/examples/talk-llama/llama-kv-cache-unified.cpp b/examples/talk-llama/llama-kv-cache-unified.cpp new file mode 100644 index 00000000..a8171547 --- /dev/null +++ b/examples/talk-llama/llama-kv-cache-unified.cpp @@ -0,0 +1,1717 @@ +#include "llama-kv-cache-unified.h" + +#include "llama-impl.h" +#include "llama-model.h" +#include "llama-context.h" + +#include +#include +#include +#include +#include +#include + +// +// llama_kv_cache_unified +// + +llama_kv_cache_unified::llama_kv_cache_unified( + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type) : + model(model), hparams(model.hparams), v_trans(v_trans), + n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { + + GGML_ASSERT(kv_size % n_pad == 0); + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + + ctx_map[buft] = ctx; + ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + head = 0; + + cells.resize(kv_size); + + for (uint32_t il = 0; il < hparams.n_layer; il++) { + if (filter && !filter(il)) { + LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il); + continue; + } + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + const char * dev_name = "CPU"; + + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); + + if (offload) { + auto * dev = model.dev_layer(il); + buft = ggml_backend_dev_buffer_type(dev); + + dev_name = ggml_backend_dev_name(dev); + } + + LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name); + + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { + throw std::runtime_error("failed to create ggml context for kv cache"); + } + + ggml_tensor * k; + ggml_tensor * v; + + k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size); + v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size); + + ggml_format_name(k, "cache_k_l%d", il); + ggml_format_name(v, "cache_v_l%d", il); + + map_layer_ids[il] = layers.size(); + layers.push_back({ il, k, v }); + } + + // allocate tensors and initialize the buffers to avoid NaNs in the padding + for (auto it : ctx_map) { + auto * buft = it.first; + auto * ctx = it.second; + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + throw std::runtime_error("failed to allocate buffer for kv cache"); + } + + LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + + ggml_backend_buffer_clear(buf, 0); + bufs.emplace_back(buf); + } + + { + const size_t memory_size_k = size_k_bytes(); + const size_t memory_size_v = size_v_bytes(); + + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } +} + +void llama_kv_cache_unified::clear() { + cells.reset(); + + head = 0; + + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } +} + +bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + uint32_t new_head = cells.size(); + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) { + if (new_head == cells.size()) { + new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cells.size() && new_head < head) { + head = new_head; + } + + return true; +} + +void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (seq_id_src == seq_id_dst) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id_src)) { + cells.seq_add(i, seq_id_dst); + } + } +} + +void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { + uint32_t new_head = cells.size(); + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (cells.seq_keep(i, seq_id)) { + if (new_head == cells.size()) { + new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cells.size() && new_head < head) { + head = new_head; + } +} + +void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + if (shift == 0) { + return; + } + + uint32_t new_head = cells.size(); + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over all cells. + if (p0 == p1) { + return; + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id)) { + if (cells.pos_add(i, shift)) { + if (new_head == cells.size()) { + new_head = i; + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + head = new_head != cells.size() ? new_head : 0; +} + +void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (d == 1) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) { + return; + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id)) { + cells.pos_div(i, d); + } + } +} + +llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { + return cells.seq_pos_min(seq_id); +} + +llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { + return cells.seq_pos_max(seq_id); +} + +llama_memory_state_ptr llama_kv_cache_unified::init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) { + GGML_UNUSED(embd_pooled); + + auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); + + std::vector ubatches; + while (sbatch.n_tokens > 0) { + ubatches.push_back(sbatch.split_simple(n_ubatch)); + } + + auto heads = prepare(ubatches); + if (heads.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, + this, std::move(sbatch), std::move(heads), std::move(ubatches)); +} + +llama_memory_state_ptr llama_kv_cache_unified::init_full() { + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); +} + +std::vector llama_kv_cache_unified::prepare(const std::vector & ubatches) { + std::vector res; + + struct state { + uint32_t head_old; // old position of the head, before placing the ubatch + uint32_t head_new; // new position of the head, after placing the ubatch + + llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch + }; + + // remember the old state of the cells so we can restore it in the end + std::vector states; + + bool success = true; + + for (const auto & ubatch : ubatches) { + // only find a suitable slot for the ubatch. don't modify the cells yet + const int32_t head_new = find_slot(ubatch); + if (head_new < 0) { + success = false; + break; + } + + // remeber the position that we found + res.push_back(head_new); + + // store the old state of the cells in the recovery stack + states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)}); + + // now emplace the ubatch + apply_ubatch(head_new, ubatch); + } + + // iterate backwards and restore the cells to their original state + for (auto it = states.rbegin(); it != states.rend(); ++it) { + cells.set(it->head_new, it->cells); + head = it->head_old; + } + + if (!success) { + return {}; + } + + return res; +} + +bool llama_kv_cache_unified::update(llama_context & lctx) { + bool updated = false; + + auto * sched = lctx.get_sched(); + + if (cells.get_has_shift()) { + if (!get_can_shift()) { + GGML_ABORT("The current KV cache / model configuration does not support K-shift"); + } + + LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__); + + // apply K-shift if needed + if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { + ggml_backend_sched_reset(sched); + + auto * gf = lctx.graph_init(); + + auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__); + return updated; + } + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__); + return updated; + } + + res->set_inputs(nullptr); + + if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__); + return updated; + } + + updated = true; + } + + cells.reset_shift(); + } + + if (do_defrag) { + LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); + + if (defrag_prepare(lctx.graph_max_nodes())) { + ggml_backend_sched_reset(sched); + + auto * gf = lctx.graph_init(); + + auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__); + return updated; + } + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__); + return updated; + } + + res->set_inputs(nullptr); + + if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__); + return updated; + } + + updated = true; + } + + do_defrag = false; + } + + return updated; +} + +void llama_kv_cache_unified::defrag_sched(float thold) { + const auto n_kv = cells.used_max_p1(); + + // - do not defrag small contexts (i.e. < 2048 tokens) + // - count the padding towards the number of used tokens + const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f; + + // queue defragmentation for next llama_kv_cache_update + if (fragmentation > thold) { + LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); + + do_defrag = true; + } +} + +int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { + const uint32_t n_tokens = ubatch.n_tokens; + + uint32_t head_cur = this->head; + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head_cur > cells.get_used() + 2*ubatch.n_tokens) { + head_cur = 0; + } + + // otherwise, one cell per token. + + if (n_tokens > cells.size()) { + LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); + return -1; + } + +//#define FIND_SLOT_DEBUG 1 +#if FIND_SLOT_DEBUG + LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", cells.used_max_p1(), cells.get_used(), head, n_swa); + + // for debugging + { + std::string ss; + if (n_swa > 0) { + for (uint32_t i = 0; i < cells.size(); ++i) { + if (cells.is_empty(i)) { + ss += '.'; + } else { + ss += std::to_string(cells.seq_get(i)); + } + if (i%256 == 255) { + ss += '\n'; + } + } + } + LLAMA_LOG_WARN("\n%s\n", ss.c_str()); + } + + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (cells.seq_pos_min(s) < 0) { + continue; + } + + LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[%d] = %5d, max[%d] = %5d\n", n_swa, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s)); + } +#endif + + uint32_t n_tested = 0; + + while (true) { + if (head_cur + n_tokens > cells.size()) { + n_tested += cells.size() - head_cur; + head_cur = 0; + continue; + } + + // keep track of what the minimum sequence positions would be if we accept the ubatch + llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES]; + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + seq_pos_min[s] = cells.seq_pos_min(s); + } + + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + const llama_pos pos = ubatch.pos[i]; + const llama_seq_id seq_id = ubatch.seq_id[i][0]; + + // can we use this cell? either: + // - the cell is empty + // - the cell is occupied only by one sequence: + // - mask causally, if the sequence is the same as the one we are inserting + // - mask SWA, using current max pos for that sequence in the cache + // always insert in the cell with minimum pos + bool can_use = cells.is_empty(head_cur + i); + + if (!can_use && cells.seq_count(head_cur + i) == 1) { + const llama_pos pos_cell = cells.pos_get(head_cur + i); + + // causal mask + if (cells.seq_has(head_cur + i, seq_id)) { + can_use = pos_cell >= pos; + } + + if (!can_use) { + const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i); + + // SWA mask + // note: we insert only in the cell with minimum pos in order to preserve the invariant that + // all positions between [pos_min, pos_max] for each sequence will be present in the cache + // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 + if (pos_cell == seq_pos_min[seq_id_cell] && + is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { + seq_pos_min[seq_id_cell]++; + can_use = true; + } + } + } + + if (!can_use) { + found = false; + head_cur += i + 1; + n_tested += i + 1; + break; + } + } + + if (found) { + break; + } + + if (n_tested >= cells.size()) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return -1; + } + } + + return head_cur; +} + +void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) { + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + if (!cells.is_empty(head_cur + i)) { + cells.rm(head_cur + i); + } + + cells.pos_set(head_cur + i, ubatch.pos[i]); + + for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) { + cells.seq_add(head_cur + i, ubatch.seq_id[i][j]); + } + } + + // move the head at the end of the slot + head = head_cur + ubatch.n_tokens; +} + +bool llama_kv_cache_unified::get_can_shift() const { + return true; +} + +uint32_t llama_kv_cache_unified::get_size() const { + return cells.size(); +} + +uint32_t llama_kv_cache_unified::get_n_kv() const { + return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); +} + +ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * k = layers[ikv].k; + + return ggml_view_3d(ctx, k, + hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, + ggml_row_size(k->type, hparams.n_embd_head_k), + ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), + 0); +} + +ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * v = layers[ikv].v; + + if (!v_trans) { + // note: v->nb[1] <= v->nb[2] + return ggml_view_3d(ctx, v, + hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, + ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2] + 0); + } + + // note: v->nb[1] > v->nb[2] + return ggml_view_3d(ctx, v, + n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, + ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, v->ne[1]), // v->nb[2] + 0); +} + +ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * k = layers[ikv].k; + + const int64_t n_tokens = k_cur->ne[2]; + + ggml_tensor * k_view = ggml_view_1d(ctx, k, + n_tokens*hparams.n_embd_k_gqa(il), + ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur); + + return ggml_cpy(ctx, k_cur, k_view); +} + +ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * v = layers[ikv].v; + + const int64_t n_tokens = v_cur->ne[2]; + + v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens); + + ggml_tensor * v_view = nullptr; + + if (!v_trans) { + v_view = ggml_view_1d(ctx, v, + n_tokens*hparams.n_embd_v_gqa(il), + ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur); + } else { + // note: the V cache is transposed when not using flash attention + v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il), + (v->ne[1])*ggml_element_size(v), + (head_cur)*ggml_element_size(v)); + + v_cur = ggml_transpose(ctx, v_cur); + } + + return ggml_cpy(ctx, v_cur, v_view); +} + +void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + float * data = (float *) dst->data; + + const auto n_kv = dst->ne[0]; + + // Use only the previous KV cells of the correct sequence for each token of the ubatch. + // It's assumed that if a token in the batch has multiple sequences, they are equivalent. + // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: + // Causal mask: + // xxx------- + // xxxx------ + // xxxxx----- + // Non-causal mask: + // xxxxx----- + // xxxxx----- + // xxxxx----- + // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 + for (int h = 0; h < 1; ++h) { + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[s][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j]; + + for (uint32_t i = 0; i < n_kv; ++i) { + float f = 0.0f; + + bool masked = false; + + if (cells.is_empty(i)) { + masked = true; + } else { + const llama_pos p0 = cells.pos_get(i); + + // mask the token if not the same sequence + masked = masked || (!cells.seq_has(i, seq_id)); + + // mask future tokens + masked = masked || (causal_attn && p0 > p1); + + // apply SWA if any + masked = masked || (is_masked_swa(p0, p1)); + + if (!masked && hparams.use_alibi) { + f = -std::abs(p0 - p1); + } + } + + if (masked) { + f = -INFINITY; + } + + data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; + } + } + } + + // mask padded tokens + if (data) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (uint32_t j = 0; j < n_kv; ++j) { + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } + } + } +} + +void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + int32_t * data = (int32_t *) dst->data; + + for (uint32_t i = 0; i < cells.size(); ++i) { + data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i); + } +} + +void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { + const int64_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing + + int32_t * data = (int32_t *) dst->data; + + const int32_t n_kv = dst->ne[0]; + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_kv; ++i) { + // the position when the cells is empty is irrelevant - it will be masked out later in the attention + const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i); + + data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false); + } + } + } +} + +size_t llama_kv_cache_unified::total_size() const { + size_t size = 0; + + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; +} + +size_t llama_kv_cache_unified::size_k_bytes() const { + size_t size_k_bytes = 0; + + for (const auto & layer : layers) { + size_k_bytes += ggml_nbytes(layer.k); + } + + return size_k_bytes; +} + +size_t llama_kv_cache_unified::size_v_bytes() const { + size_t size_v_bytes = 0; + + for (const auto & layer : layers) { + size_v_bytes += ggml_nbytes(layer.v); + } + + return size_v_bytes; +} + +ggml_tensor * llama_kv_cache_unified::build_rope_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale) const { + const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; + + const auto & yarn_ext_factor = cparams.yarn_ext_factor; + const auto & yarn_beta_fast = cparams.yarn_beta_fast; + const auto & yarn_beta_slow = cparams.yarn_beta_slow; + + const auto & n_rot = hparams.n_rot; + const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE + // @ngxson : this is a workaround + // for M-RoPE, we want to rotate the whole vector when doing KV shift + // a normal RoPE should work, we just need to use the correct ordering + // ref: https://github.com/ggml-org/llama.cpp/pull/13870 + ? LLAMA_ROPE_TYPE_NEOX + : hparams.rope_type; + + // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. + // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. + const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 + ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) + : cparams.yarn_attn_factor; + + ggml_tensor * tmp; + + if (ggml_is_quantized(cur->type)) { + // dequantize to f32 -> RoPE -> quantize back + tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); + + tmp = ggml_rope_ext(ctx, tmp, + shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + + tmp = ggml_cpy(ctx, tmp, cur); + } else { + // we rotate only the first n_rot dimensions + tmp = ggml_rope_ext_inplace(ctx, cur, + shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + } + + return tmp; +} + +class llm_graph_input_k_shift : public llm_graph_input_i { +public: + llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + virtual ~llm_graph_input_k_shift() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * k_shift; // I32 [kv_size] + + const llama_kv_cache_unified * kv_self; +}; + +void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + if (k_shift) { + kv_self->set_input_k_shift(k_shift); + } +} + +llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const { + auto res = std::make_unique(); + + const auto & n_embd_head_k = hparams.n_embd_head_k; + //const auto & n_embd_head_v = hparams.n_embd_head_v; + + //GGML_ASSERT(kv_self->size == n_ctx); + + auto inp = std::make_unique(this); + + inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx); + ggml_set_input(inp->k_shift); + + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + ggml_tensor * k = + ggml_view_3d(ctx, layer.k, + n_embd_head_k, n_head_kv, cells.size(), + ggml_row_size(layer.k->type, n_embd_head_k), + ggml_row_size(layer.k->type, n_embd_k_gqa), + 0); + + ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); + + ggml_build_forward_expand(gf, cur); + } + + res->add_input(std::move(inp)); + + return res; +} + +llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const { + auto res = std::make_unique(); + + const auto & ids = defrag_info.ids; + +#if 0 + // CPU defrag + // + // TODO: optimizations are possible: + // - multiple threads + // - avoid copying to the host memory when already there + // + // likely not worth the effort, as we have ggml_graph based defrag + // + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + + const uint32_t kv_size = size; + + std::vector buf_k; + std::vector buf_v; + + for (uint32_t il = 0; il < n_layer; ++il) { + const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size); + + const size_t v_size_el = ggml_type_size(v_l[il]->type); + const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size); + + buf_k.resize(k_size); + buf_v.resize(v_size); + + ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size()); + ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size()); + + // batch move [i, i+nm) to [id, id+nm) + // note: cells can move only to a lower index + for (uint32_t i = 0; i < n_kv; ++i) { + const uint32_t id = ids[i]; + + if (i == id || id == n_kv) { + continue; + } + + uint32_t nm = 1; + + while (i + nm < n_kv && ids[i + nm] == id + nm) { + nm++; + } + + // move keys + { + const int64_t os = i*k_size_row; + const int64_t od = id*k_size_row; + + memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row); + } + + // move values (note: they are transposed) + { + const int64_t os = i; + const int64_t od = id; + + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el); + } + } + + i += nm - 1; + } + + ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size()); + ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); + } +#else + for (uint32_t i = 0; i < ids.size(); ++i) { + const uint32_t id = ids[i]; + + if (i == id || id == ids.size()) { + continue; + } + + uint32_t nm = 1; + + while (i + nm < ids.size() && ids[i + nm] == id + nm) { + nm++; + } + + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + + ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k, + n_embd_k_gqa, nm, + ggml_row_size(layer.k->type, n_embd_k_gqa), + ggml_row_size(layer.k->type, n_embd_k_gqa*i)); + + ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k, + n_embd_k_gqa, nm, + ggml_row_size(layer.k->type, n_embd_k_gqa), + ggml_row_size(layer.k->type, n_embd_k_gqa*id)); + + ggml_tensor * view_v_src; + ggml_tensor * view_v_dst; + + if (cparams.flash_attn) { + // NOTE: the V cache is not transposed when using flash attention + view_v_src = ggml_view_2d(ctx, layer.v, + n_embd_v_gqa, nm, + ggml_row_size(layer.v->type, n_embd_v_gqa), + ggml_row_size(layer.v->type, n_embd_v_gqa*i)); + + view_v_dst = ggml_view_2d(ctx, layer.v, + n_embd_v_gqa, nm, + ggml_row_size(layer.v->type, n_embd_v_gqa), + ggml_row_size(layer.v->type, n_embd_v_gqa*id)); + } else { + view_v_src = ggml_view_2d(ctx, layer.v, + nm, n_embd_v_gqa, + ggml_row_size(layer.v->type, cells.size()), + ggml_row_size(layer.v->type, i)); + + view_v_dst = ggml_view_2d(ctx, layer.v, + nm, n_embd_v_gqa, + ggml_row_size(layer.v->type, cells.size()), + ggml_row_size(layer.v->type, id)); + } + + ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst)); + ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst)); + } + + i += nm - 1; + } + + //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); +#endif + + return res; +} + +bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { + const uint32_t n_layer = layers.size(); + + const uint32_t n_kv = cells.used_max_p1(); + const uint32_t n_used = cells.get_used(); + + assert(n_used <= n_kv); + + //const int64_t t_start = ggml_time_us(); + + // number of cells moved + uint32_t n_moves = 0; + + // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag) + // - source view, destination view, copy operation + // - x2 for keys and values + //const uint32_t max_moves = max_nodes()/(6*n_layer); + // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 + const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer); + + // determine which KV cells to move where + // + // cell i moves to ids[i] + // + // if ids[i] == i || ids[i] == n_kv, then cell i is not moved + // + auto & ids = defrag_info.ids; + + ids.clear(); + ids.resize(n_kv, n_kv); + + for (uint32_t i0 = 0; i0 < n_used; ++i0) { + if (!cells.is_empty(i0)) { + ids[i0] = i0; + + continue; + } + + // found a hole - fill it with data from the end of the cache + + uint32_t nh = 1; + + // determine the size of the hole + while (i0 + nh < n_used && cells.is_empty(i0 + nh)) { + nh++; + } + + uint32_t nf = 0; + uint32_t is = n_kv - 1; + + // starting from the end, find nh non-empty cells + for (; is > i0; --is) { + if (cells.is_empty(is) || ids[is] != n_kv) { + continue; + } + + // non-empty cell which is not yet moved + nf++; + + if (nf == nh) { + break; + } + } + + // this can only happen if `n_used` is not accurate, which would be a bug + GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh"); + + nf = 0; + + uint32_t i1 = is; + + // are we moving a continuous block of memory? + bool cont = false; + + // should we stop searching for the next move? + bool stop = false; + + // go back and move the nf cells to the hole + for (; i1 < n_kv; ++i1) { + if (cells.is_empty(i1) || ids[i1] != n_kv) { + if (n_moves == max_moves) { + stop = true; + break; + } + + cont = false; + continue; + } + + // this cell goes to (i0 + nf) + ids[i1] = i0 + nf; + + // move the cell meta data + cells.mv(i1, i0 + nf); + + head = n_used; + + if (!cont) { + n_moves++; + cont = true; + } + + nf++; + + if (nf == nh) { + break; + } + } + + if (stop || n_moves == max_moves) { + break; + } + + //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh); + + i0 += nh - 1; + } + + if (n_moves == 0) { + return false; + } + + LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves); + + LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer); + + return true; +} + +bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { + assert(p0 >= 0 && p1 >= 0); + + switch (swa_type) { + case LLAMA_SWA_TYPE_NONE: + { + } break; + case LLAMA_SWA_TYPE_STANDARD: + { + if (p1 - p0 >= (int32_t) n_swa) { + return true; + } + } break; + case LLAMA_SWA_TYPE_CHUNKED: + { + const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; + + if (p0 < pos_chunk_start) { + return true; + } + } break; + } + + return false; +} + +void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + std::vector> cell_ranges; // ranges, from inclusive, to exclusive + uint32_t cell_count = 0; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = cells.size(); + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { + ++cell_count; + if (cell_range_begin == cells.size()) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != cells.size()) { + cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = cells.size(); + } + } + } + + if (cell_range_begin != cells.size()) { + cell_ranges.emplace_back(cell_range_begin, cells.size()); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + + io.write(&cell_count, sizeof(cell_count)); + + state_write_meta(io, cell_ranges, seq_id); + state_write_data(io, cell_ranges); +} + +void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + uint32_t cell_count; + io.read_to(&cell_count, sizeof(cell_count)); + + bool res = true; + res = res && state_read_meta(io, cell_count, seq_id); + res = res && state_read_data(io, cell_count); + + if (!res) { + if (seq_id == -1) { + clear(); + } else { + seq_rm(seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); + } +} + +void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + std::vector seq_ids; + + for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) { + if (cur == seq_id || seq_id == -1) { + if (cells.seq_has(i, cur)) { + seq_ids.push_back(cur); + } + } + } + + const llama_pos pos = cells.pos_get(i); + const uint32_t n_seq_id = seq_ids.size(); + + io.write(&pos, sizeof(pos)); + io.write(&n_seq_id, sizeof(n_seq_id)); + + for (const auto & seq_id : seq_ids) { + io.write(&seq_id, sizeof(seq_id)); + } + } + } +} + +void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { + const uint32_t v_trans = this->v_trans ? 1 : 0; + const uint32_t n_layer = layers.size(); + + io.write(&v_trans, sizeof(v_trans)); + io.write(&n_layer, sizeof(n_layer)); + + std::vector tmp_buf; + + // Iterate and write all the keys first, each row is a cell + // Get whole range at a time + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Write key type + const int32_t k_type_i = (int32_t)layer.k->type; + io.write(&k_type_i, sizeof(k_type_i)); + + // Write row size of key + const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); + io.write(&k_size_row, sizeof(k_size_row)); + + // Read each range of cells of k_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * k_size_row; + io.write_tensor(layer.k, range.first * k_size_row, buf_size); + } + } + + if (!v_trans) { + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)layer.v->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write row size of value + const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); + io.write(&v_size_row, sizeof(v_size_row)); + + // Read each range of cells of v_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * v_size_row; + io.write_tensor(layer.v, range.first * v_size_row, buf_size); + } + } + } else { + // When v is transposed, we also need the element size and get the element ranges from each row + const uint32_t kv_size = cells.size(); + + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)layer.v->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const uint32_t v_size_el = ggml_type_size(layer.v->type); + io.write(&v_size_el, sizeof(v_size_el)); + + // Write GQA embedding size + io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + const size_t buf_size = range_size * v_size_el; + io.write_tensor(layer.v, src_offset, buf_size); + } + } + } + } +} + +bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { + if (dest_seq_id != -1) { + // single sequence + + seq_rm(dest_seq_id, -1, -1); + + llama_sbatch sbatch; + llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); + + batch.n_tokens = cell_count; + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id != 1) { + LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + return false; + } + + // read the sequence id, but directly discard it - we will use dest_seq_id instead + { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + } + + batch.pos[i] = pos; + batch.n_seq_id[i] = n_seq_id; + batch.seq_id[i] = &dest_seq_id; + } + + const auto head_cur = find_slot(batch); + if (head_cur < 0) { + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + return false; + } + + apply_ubatch(head_cur, batch); + + // keep the head at the old position because we will read the KV data into it in state_read_data() + head = head_cur; + + // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(head_cur + cell_count <= cells.size()); + GGML_ASSERT(cells.pos_get(head_cur) == batch.pos[0]); + GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]); + GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id)); + GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id)); + } else { + // whole KV cache restore + + if (cell_count > cells.size()) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } + + clear(); + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + cells.pos_set(i, pos); + + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + + if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) { + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max); + return false; + } + + cells.seq_add(i, seq_id); + } + } + + head = 0; + } + + return true; +} + +bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) { + uint32_t v_trans; + uint32_t n_layer; + + io.read_to(&v_trans, sizeof(v_trans)); + io.read_to(&n_layer, sizeof(n_layer)); + + if (n_layer != layers.size()) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); + return false; + } + + if (cell_count > cells.size()) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size()); + return false; + } + + if (this->v_trans != (bool) v_trans) { + LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); + return false; + } + + // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Read type of key + int32_t k_type_i_ref; + io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + const int32_t k_type_i = (int32_t) layer.k->type; + if (k_type_i != k_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t k_size_row_ref; + io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); + if (k_size_row != k_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + } + } + + if (!this->v_trans) { + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)layer.v->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read row size of value + uint64_t v_size_row_ref; + io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)layer.v->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read element size of value + uint32_t v_size_el_ref; + io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + const size_t v_size_el = ggml_type_size(layer.v->type); + if (v_size_el != v_size_el_ref) { + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); + return false; + } + + // Read GQA embedding size + uint32_t n_embd_v_gqa_ref; + io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); + return false; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (head + j * cells.size()) * v_size_el; + ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + } + } + } + } + + return true; +} + +// +// llama_kv_cache_unified_state +// + +llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_unified_state::llama_kv_cache_unified_state( + llama_memory_status status, + llama_kv_cache_unified * kv) : status(status), kv(kv) { + n_kv = kv->get_size(); + head = 0; + } + +llama_kv_cache_unified_state::llama_kv_cache_unified_state( + llama_memory_status status, + llama_kv_cache_unified * kv, + llama_sbatch sbatch, + std::vector heads, + std::vector ubatches) + : status(status), + kv(kv), + sbatch(std::move(sbatch)), + heads(std::move(heads)), + ubatches(std::move(ubatches)) { + } + +llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default; + +bool llama_kv_cache_unified_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_unified_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + kv->apply_ubatch(heads[i_next], ubatches[i_next]); + + n_kv = kv->get_n_kv(); + head = heads[i_next]; + + return true; +} + +std::vector & llama_kv_cache_unified_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_unified_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +uint32_t llama_kv_cache_unified_state::get_n_kv() const { + return n_kv; +} + +ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const { + return kv->get_k(ctx, il, n_kv); +} + +ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const { + return kv->get_v(ctx, il, n_kv); +} + +ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { + return kv->cpy_k(ctx, k_cur, il, head); +} + +ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { + return kv->cpy_v(ctx, v_cur, il, head); +} + +void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const { + kv->set_input_k_shift(dst); +} + +void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + kv->set_input_kq_mask(dst, ubatch, causal_attn); +} + +void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { + kv->set_input_pos_bucket(dst, ubatch); +} + +uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { + // the FA kernels require padding to avoid extra runtime boundary checks + return cparams.flash_attn ? 256u : 32u; +} diff --git a/examples/talk-llama/llama-kv-cache-unified.h b/examples/talk-llama/llama-kv-cache-unified.h new file mode 100644 index 00000000..1f1d44b9 --- /dev/null +++ b/examples/talk-llama/llama-kv-cache-unified.h @@ -0,0 +1,278 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-kv-cache.h" +#include "llama-kv-cells.h" + +#include +#include + +struct llama_cparams; +struct llama_hparams; +struct llama_model; +struct llama_context; + +// +// llama_kv_cache_unified +// + +class llama_kv_cache_unified : public llama_kv_cache { +public: + static uint32_t get_padding(const llama_cparams & cparams); + + // this callback is used to filter out layers that should not be included in the cache + using layer_filter_cb = std::function; + + llama_kv_cache_unified( + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type); + + ~llama_kv_cache_unified() = default; + + // + // llama_memory_i + // + + void clear() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // + // llama_kv_cache + // + + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; + + llama_memory_state_ptr init_full() override; + + bool update(llama_context & lctx) override; + + void defrag_sched(float thold) override; + + bool get_can_shift() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + // + // llama_kv_cache_unified specific API + // + + uint32_t get_size() const; + + // + // graph_build API + // + + uint32_t get_n_kv() const; + + // get views of the current state of the cache + ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const; + + // store k_cur and v_cur in the cache based on the provided head location + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const; + + // + // preparation API + // + + // find places for the provided ubatches in the cache, returns the head locations + // return empty vector on failure + std::vector prepare(const std::vector & ubatches); + + // return the cell position where we can insert the ubatch + // return -1 on failure to find a contiguous slot of kv cells + int32_t find_slot(const llama_ubatch & ubatch) const; + + // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens) + void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch); + + // + // set_input API + // + + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; + void set_input_k_shift (ggml_tensor * dst) const; + void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + +private: + const llama_model & model; + const llama_hparams & hparams; + + struct kv_layer { + // layer index in the model + // note: can be different from the layer index in the KV cache + uint32_t il; + + ggml_tensor * k; + ggml_tensor * v; + }; + + bool do_defrag = false; + bool v_trans = true; // the value tensor is transposed + + // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) + // note: this is not part of the KV state and it's only used to speed-up the find_slot() method + uint32_t head = 0; + + const uint32_t n_seq_max = 1; + + // required padding + const uint32_t n_pad = 1; + + // SWA + const uint32_t n_swa = 0; + + const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; + + std::vector ctxs; + std::vector bufs; + + llama_kv_cells_unified cells; + + std::vector layers; + + // model layer id -> KV cache layer id + std::unordered_map map_layer_ids; + + // defrag + struct { + std::vector ids; + } defrag_info; + + // return true if cells have been moved + bool defrag_prepare(int32_t n_max_nodes); + + size_t total_size() const; + + size_t size_k_bytes() const; + size_t size_v_bytes() const; + + bool is_masked_swa(llama_pos p0, llama_pos p1) const; + + ggml_tensor * build_rope_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale) const; + + llm_graph_result_ptr build_graph_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const; + + llm_graph_result_ptr build_graph_defrag( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const; + + void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; + + bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t cell_count); +}; + +class llama_kv_cache_unified_state : public llama_memory_state_i { +public: + // used for errors + llama_kv_cache_unified_state(llama_memory_status status); + + // used to create a full-cache state + llama_kv_cache_unified_state( + llama_memory_status status, + llama_kv_cache_unified * kv); + + // used to create a state from a batch + llama_kv_cache_unified_state( + llama_memory_status status, + llama_kv_cache_unified * kv, + llama_sbatch sbatch, + std::vector heads, + std::vector ubatches); + + virtual ~llama_kv_cache_unified_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_unified_state specific API + // + + uint32_t get_n_kv() const; + + // get views of the current state of the cache + ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; + + // store k_cur and v_cur in the cache based on the provided head location + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const; + + void set_input_k_shift(ggml_tensor * dst) const; + + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; + void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + +private: + const llama_memory_status status; + + llama_kv_cache_unified * kv; + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector heads; + std::vector ubatches; + + // + // data needed for building the compute graph for the current ubatch: + // + + // a heuristic, to avoid attending the full cache if it is not yet utilized + // as the cache gets filled, the benefit from this heuristic disappears + int32_t n_kv; + + // the beginning of the current slot in which the ubatch will be inserted + int32_t head; +}; diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index 4a42d6ec..aefd23e3 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -1,2739 +1 @@ #include "llama-kv-cache.h" - -#include "llama-impl.h" -#include "llama-batch.h" -#include "llama-cparams.h" -#include "llama-model.h" -#include "llama-context.h" - -#include -#include -#include -#include -#include -#include - -// -// llama_kv_cache_unified -// - -uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { - // the FA kernels require padding to avoid extra runtime boundary checks - return cparams.flash_attn ? 256u : 32u; -} - -llama_kv_cache_unified::llama_kv_cache_unified( - const llama_model & model, - layer_filter_cb && filter, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max, - uint32_t n_pad, - uint32_t n_swa, - llama_swa_type swa_type) : - model(model), hparams(model.hparams), v_trans(v_trans), - n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { - - GGML_ASSERT(kv_size % n_pad == 0); - - // create a context for each buffer type - std::map ctx_map; - auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { - auto it = ctx_map.find(buft); - if (it == ctx_map.end()) { - ggml_init_params params = { - /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()), - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - - ggml_context * ctx = ggml_init(params); - if (!ctx) { - return nullptr; - } - - ctx_map[buft] = ctx; - ctxs.emplace_back(ctx); - - return ctx; - } - - return it->second; - }; - - head = 0; - - cells.resize(kv_size); - - for (uint32_t il = 0; il < hparams.n_layer; il++) { - if (filter && !filter(il)) { - LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il); - continue; - } - - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - const char * dev_name = "CPU"; - - ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); - - if (offload) { - auto * dev = model.dev_layer(il); - buft = ggml_backend_dev_buffer_type(dev); - - dev_name = ggml_backend_dev_name(dev); - } - - LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name); - - ggml_context * ctx = ctx_for_buft(buft); - if (!ctx) { - throw std::runtime_error("failed to create ggml context for kv cache"); - } - - ggml_tensor * k; - ggml_tensor * v; - - k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size); - v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size); - - ggml_format_name(k, "cache_k_l%d", il); - ggml_format_name(v, "cache_v_l%d", il); - - map_layer_ids[il] = layers.size(); - layers.push_back({ il, k, v }); - } - - // allocate tensors and initialize the buffers to avoid NaNs in the padding - for (auto it : ctx_map) { - auto * buft = it.first; - auto * ctx = it.second; - - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); - if (!buf) { - throw std::runtime_error("failed to allocate buffer for kv cache"); - } - - LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); - - ggml_backend_buffer_clear(buf, 0); - bufs.emplace_back(buf); - } - - { - const size_t memory_size_k = size_k_bytes(); - const size_t memory_size_v = size_v_bytes(); - - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, - ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), - ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); - } -} - -void llama_kv_cache_unified::clear() { - cells.reset(); - - head = 0; - - for (auto & buf : bufs) { - ggml_backend_buffer_clear(buf.get(), 0); - } -} - -bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - uint32_t new_head = cells.size(); - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.pos_in(i, p0, p1)) { - continue; - } - - if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) { - if (new_head == cells.size()) { - new_head = i; - } - } - } - - // If we freed up a slot, set head to it so searching can start there. - if (new_head != cells.size() && new_head < head) { - head = new_head; - } - - return true; -} - -void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - if (seq_id_src == seq_id_dst) { - return; - } - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.pos_in(i, p0, p1)) { - continue; - } - - if (cells.seq_has(i, seq_id_src)) { - cells.seq_add(i, seq_id_dst); - } - } -} - -void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { - uint32_t new_head = cells.size(); - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (cells.seq_keep(i, seq_id)) { - if (new_head == cells.size()) { - new_head = i; - } - } - } - - // If we freed up a slot, set head to it so searching can start there. - if (new_head != cells.size() && new_head < head) { - head = new_head; - } -} - -void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { - if (shift == 0) { - return; - } - - uint32_t new_head = cells.size(); - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // If there is no range then return early to avoid looping over all cells. - if (p0 == p1) { - return; - } - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.pos_in(i, p0, p1)) { - continue; - } - - if (cells.seq_has(i, seq_id)) { - if (cells.pos_add(i, shift)) { - if (new_head == cells.size()) { - new_head = i; - } - } - } - } - - // If we freed up a slot, set head to it so searching can start there. - // Otherwise we just start the next search from the beginning. - head = new_head != cells.size() ? new_head : 0; -} - -void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - if (d == 1) { - return; - } - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // If there is no range then return early to avoid looping over the cache. - if (p0 == p1) { - return; - } - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.pos_in(i, p0, p1)) { - continue; - } - - if (cells.seq_has(i, seq_id)) { - cells.pos_div(i, d); - } - } -} - -llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { - return cells.seq_pos_min(seq_id); -} - -llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { - return cells.seq_pos_max(seq_id); -} - -void llama_kv_cache_unified::restore() { - for (auto & state : recovery.states) { - cells.set(state.i, state.cells); - } - - recovery.clear(); -} - -void llama_kv_cache_unified::commit() { - if (recovery.states.empty()) { - LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n", - __func__, "https://github.com/ggml-org/llama.cpp/pull/13194"); - return; - } - - recovery.clear(); -} - -bool llama_kv_cache_unified::update(llama_context & lctx) { - bool need_reserve = false; - - auto * sched = lctx.get_sched(); - - if (cells.get_has_shift()) { - if (!get_can_shift()) { - GGML_ABORT("The current KV cache / model configuration does not support K-shift"); - } - - LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__); - - // apply K-shift if needed - if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { - ggml_backend_sched_reset(sched); - - auto * gf = lctx.graph_init(); - - auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf); - - ggml_backend_sched_alloc_graph(sched, gf); - - res->set_inputs(nullptr); - - lctx.graph_compute(gf, false); - - need_reserve = true; - } - - cells.reset_shift(); - } - - if (do_defrag) { - LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); - - if (defrag_prepare(lctx.graph_max_nodes())) { - ggml_backend_sched_reset(sched); - - auto * gf = lctx.graph_init(); - - auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf); - - ggml_backend_sched_alloc_graph(sched, gf); - - res->set_inputs(nullptr); - - lctx.graph_compute(gf, false); - - need_reserve = true; - } - - do_defrag = false; - } - - return need_reserve; -} - -void llama_kv_cache_unified::defrag_sched(float thold) { - // - do not defrag small contexts (i.e. < 2048 tokens) - // - count the padding towards the number of used tokens - const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f; - - // queue defragmentation for next llama_kv_cache_update - if (fragmentation > thold) { - LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); - - do_defrag = true; - } -} - -void llama_kv_cache_unified::set_full() { - n = cells.size(); - - // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not - // affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views. - // we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so - // setting it to 0 is the simplest way to achieve that - // ref: https://github.com/ggml-org/llama.cpp/issues/13359 - head = 0; -} - -llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) { - return llama_sbatch(batch, hparams.n_embd, true, logits_all); -} - -llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { - GGML_UNUSED(embd_pooled); - return sbatch.split_simple(n_ubatch); -} - -bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { - const uint32_t n_tokens = ubatch.n_tokens; - - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (head > cells.get_used() + 2*ubatch.n_tokens) { - head = 0; - } - - // otherwise, one cell per token. - - if (n_tokens > cells.size()) { - LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); - return false; - } - -//#define FIND_SLOT_DEBUG 1 -#if FIND_SLOT_DEBUG - LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa); - - // for debugging - { - std::string ss; - if (n_swa > 0) { - for (uint32_t i = 0; i < size; ++i) { - if (cells.is_empty(i)) { - ss += '.'; - } else { - ss += 'x'; - } - if (i%256 == 255) { - ss += '\n'; - } - } - } - LLAMA_LOG_WARN("\n%s\n", ss.c_str()); - } -#endif - - uint32_t n_tested = 0; - - while (true) { - if (head + n_tokens > cells.size()) { - n_tested += cells.size() - head; - head = 0; - continue; - } - - bool found = true; - for (uint32_t i = 0; i < n_tokens; i++) { - // TODO: improve to accept cells that are masked by the SWA - if (!cells.is_empty(head + i)) { - found = false; - head += i + 1; - n_tested += i + 1; - break; - } - } - - if (found) { - break; - } - - if (n_tested >= cells.size()) { - //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return false; - } - } - - // store the old state of the cells in the recovery stack - recovery.states.push_back({head, cells.cp(head, n_tokens)}); - - for (uint32_t i = 0; i < n_tokens; ++i) { - cells.pos_set(head + i, ubatch.pos[i]); - - for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) { - cells.seq_add(head + i, ubatch.seq_id[i][j]); - } - } - - // a heuristic, to avoid attending the full cache if it is not yet utilized - // after enough generations, the benefit from this heuristic disappears - // if we start defragmenting the cache, the benefit from this will be more important - n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); - -#ifdef FIND_SLOT_DEBUG - LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa); -#endif - - return true; -} - -bool llama_kv_cache_unified::get_can_shift() const { - return true; -} - -uint32_t llama_kv_cache_unified::get_n() const { - return n; -} - -uint32_t llama_kv_cache_unified::get_size() const { - return cells.size(); -} - -ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const { - const int32_t ikv = map_layer_ids.at(il); - - auto * k = layers[ikv].k; - - return ggml_view_3d(ctx, k, - hparams.n_embd_head_k, hparams.n_head_kv(il), n, - ggml_row_size(k->type, hparams.n_embd_head_k), - ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), - 0); -} - -ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) const { - const int32_t ikv = map_layer_ids.at(il); - - auto * v = layers[ikv].v; - - if (!v_trans) { - // note: v->nb[1] <= v->nb[2] - return ggml_view_3d(ctx, v, - hparams.n_embd_head_v, hparams.n_head_kv(il), n, - ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2] - 0); - } - - // note: v->nb[1] > v->nb[2] - return ggml_view_3d(ctx, v, - n, hparams.n_head_kv(il), hparams.n_embd_head_v, - ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, v->ne[1]), // v->nb[2] - 0); -} - -ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { - const int32_t ikv = map_layer_ids.at(il); - - auto * k = layers[ikv].k; - - const int64_t n_tokens = k_cur->ne[2]; - - ggml_tensor * k_view = ggml_view_1d(ctx, k, - n_tokens*hparams.n_embd_k_gqa(il), - ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head); - - return ggml_cpy(ctx, k_cur, k_view); -} - -ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { - const int32_t ikv = map_layer_ids.at(il); - - auto * v = layers[ikv].v; - - const int64_t n_tokens = v_cur->ne[2]; - - v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens); - - ggml_tensor * v_view = nullptr; - - if (!v_trans) { - v_view = ggml_view_1d(ctx, v, - n_tokens*hparams.n_embd_v_gqa(il), - ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head); - } else { - // note: the V cache is transposed when not using flash attention - v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il), - (v->ne[1])*ggml_element_size(v), - ( head)*ggml_element_size(v)); - - v_cur = ggml_transpose(ctx, v_cur); - } - - return ggml_cpy(ctx, v_cur, v_view); -} - -void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax) { - // no pruning is needed when the cache does not use SWA - GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache"); - - int n_attended = 0; - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.seq_has(i, seq_id)) { - continue; - } - - const llama_pos p0 = cells.pos_get(i); - - if (p0 <= pmin && !is_masked_swa(p0, pmin)) { - n_attended++; - } - - if (is_masked_swa(p0, pmax)) { - cells.seq_rm(i, seq_id); - } - } - - if (n_attended < std::min(n_swa, pmin)) { - LLAMA_LOG_WARN("%s: partial SWA cache detected - possible loss of information, pmin = %d, n_attended = %d, n_swa = %d\n", __func__, pmin, n_attended, n_swa); - } -} - -void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { - const int64_t n_tokens = ubatch->n_tokens; - const int64_t n_seq_tokens = ubatch->n_seq_tokens; - const int64_t n_seqs = ubatch->n_seqs; - - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - float * data = (float *) dst->data; - - const int64_t n_kv = n; - - // Use only the previous KV cells of the correct sequence for each token of the ubatch. - // It's assumed that if a token in the batch has multiple sequences, they are equivalent. - // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: - // Causal mask: - // xxx------- - // xxxx------ - // xxxxx----- - // Non-causal mask: - // xxxxx----- - // xxxxx----- - // xxxxx----- - // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 - for (int h = 0; h < 1; ++h) { - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch->seq_id[s][0]; - - for (int j = 0; j < n_seq_tokens; ++j) { - const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j]; - - for (int i = 0; i < n_kv; ++i) { - float f = 0.0f; - - bool masked = false; - - if (cells.is_empty(i)) { - masked = true; - } else { - const llama_pos p0 = cells.pos_get(i); - - // mask the token if not the same sequence - masked = masked || (!cells.seq_has(i, seq_id)); - - // mask future tokens - masked = masked || (causal_attn && p0 > p1); - - // apply SWA if any - masked = masked || (is_masked_swa(p0, p1)); - - if (!masked && hparams.use_alibi) { - f = -std::abs(p0 - p1); - } - } - - if (masked) { - f = -INFINITY; - } - - data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; - } - } - } - - // mask padded tokens - if (data) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { - data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; - } - } - } - } -} - -void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - - int32_t * data = (int32_t *) dst->data; - - for (uint32_t i = 0; i < cells.size(); ++i) { - data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i); - } -} - -void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { - const int64_t n_tokens = ubatch->n_tokens; - - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing - - int32_t * data = (int32_t *) dst->data; - - const int64_t n_kv = n; - - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - for (int i = 0; i < n_kv; ++i) { - // the position when the cells is empty is irrelevant - it will be masked out later in the attention - const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i); - - data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false); - } - } - } -} - -size_t llama_kv_cache_unified::total_size() const { - size_t size = 0; - - for (const auto & buf : bufs) { - size += ggml_backend_buffer_get_size(buf.get()); - } - - return size; -} - -size_t llama_kv_cache_unified::size_k_bytes() const { - size_t size_k_bytes = 0; - - for (const auto & layer : layers) { - size_k_bytes += ggml_nbytes(layer.k); - } - - return size_k_bytes; -} - -size_t llama_kv_cache_unified::size_v_bytes() const { - size_t size_v_bytes = 0; - - for (const auto & layer : layers) { - size_v_bytes += ggml_nbytes(layer.v); - } - - return size_v_bytes; -} - -ggml_tensor * llama_kv_cache_unified::build_rope_shift( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale) const { - const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; - - const auto & yarn_ext_factor = cparams.yarn_ext_factor; - const auto & yarn_beta_fast = cparams.yarn_beta_fast; - const auto & yarn_beta_slow = cparams.yarn_beta_slow; - - const auto & n_rot = hparams.n_rot; - const auto & rope_type = hparams.rope_type; - - // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. - // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. - const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor; - - ggml_tensor * tmp; - - if (ggml_is_quantized(cur->type)) { - // dequantize to f32 -> RoPE -> quantize back - tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); - - tmp = ggml_rope_ext(ctx, tmp, - shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); - - tmp = ggml_cpy(ctx, tmp, cur); - } else { - // we rotate only the first n_rot dimensions - tmp = ggml_rope_ext_inplace(ctx, cur, - shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); - } - - return tmp; -} - -class llm_graph_input_k_shift : public llm_graph_input_i { -public: - llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} - virtual ~llm_graph_input_k_shift() = default; - - void set_input(const llama_ubatch * ubatch) override; - - ggml_tensor * k_shift; // I32 [kv_size] - - const llama_kv_cache_unified * kv_self; -}; - -void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { - GGML_UNUSED(ubatch); - - if (k_shift) { - kv_self->set_input_k_shift(k_shift); - } -} - -llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf) const { - auto res = std::make_unique(); - - const auto & n_embd_head_k = hparams.n_embd_head_k; - //const auto & n_embd_head_v = hparams.n_embd_head_v; - - //GGML_ASSERT(kv_self->size == n_ctx); - - auto inp = std::make_unique(this); - - inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx); - ggml_set_input(inp->k_shift); - - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const int64_t n_head_kv = hparams.n_head_kv(il); - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - - const float freq_base_l = model.get_rope_freq_base (cparams, il); - const float freq_scale_l = model.get_rope_freq_scale(cparams, il); - - ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); - - ggml_tensor * k = - ggml_view_3d(ctx, layer.k, - n_embd_head_k, n_head_kv, cells.size(), - ggml_row_size(layer.k->type, n_embd_head_k), - ggml_row_size(layer.k->type, n_embd_k_gqa), - 0); - - ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); - - ggml_build_forward_expand(gf, cur); - } - - res->add_input(std::move(inp)); - - return res; -} - -llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf) const { - auto res = std::make_unique(); - - const auto & ids = defrag_info.ids; - -#if 0 - // CPU defrag - // - // TODO: optimizations are possible: - // - multiple threads - // - avoid copying to the host memory when already there - // - // likely not worth the effort, as we have ggml_graph based defrag - // - - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - - const uint32_t kv_size = size; - - std::vector buf_k; - std::vector buf_v; - - for (uint32_t il = 0; il < n_layer; ++il) { - const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); - const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size); - - const size_t v_size_el = ggml_type_size(v_l[il]->type); - const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size); - - buf_k.resize(k_size); - buf_v.resize(v_size); - - ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size()); - ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size()); - - // batch move [i, i+nm) to [id, id+nm) - // note: cells can move only to a lower index - for (uint32_t i = 0; i < n_kv; ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == n_kv) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < n_kv && ids[i + nm] == id + nm) { - nm++; - } - - // move keys - { - const int64_t os = i*k_size_row; - const int64_t od = id*k_size_row; - - memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row); - } - - // move values (note: they are transposed) - { - const int64_t os = i; - const int64_t od = id; - - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el); - } - } - - i += nm - 1; - } - - ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size()); - ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); - } -#else - for (uint32_t i = 0; i < ids.size(); ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == ids.size()) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < ids.size() && ids[i + nm] == id + nm) { - nm++; - } - - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); - - ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k, - n_embd_k_gqa, nm, - ggml_row_size(layer.k->type, n_embd_k_gqa), - ggml_row_size(layer.k->type, n_embd_k_gqa*i)); - - ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k, - n_embd_k_gqa, nm, - ggml_row_size(layer.k->type, n_embd_k_gqa), - ggml_row_size(layer.k->type, n_embd_k_gqa*id)); - - ggml_tensor * view_v_src; - ggml_tensor * view_v_dst; - - if (cparams.flash_attn) { - // NOTE: the V cache is not transposed when using flash attention - view_v_src = ggml_view_2d(ctx, layer.v, - n_embd_v_gqa, nm, - ggml_row_size(layer.v->type, n_embd_v_gqa), - ggml_row_size(layer.v->type, n_embd_v_gqa*i)); - - view_v_dst = ggml_view_2d(ctx, layer.v, - n_embd_v_gqa, nm, - ggml_row_size(layer.v->type, n_embd_v_gqa), - ggml_row_size(layer.v->type, n_embd_v_gqa*id)); - } else { - view_v_src = ggml_view_2d(ctx, layer.v, - nm, n_embd_v_gqa, - ggml_row_size(layer.v->type, cells.size()), - ggml_row_size(layer.v->type, i)); - - view_v_dst = ggml_view_2d(ctx, layer.v, - nm, n_embd_v_gqa, - ggml_row_size(layer.v->type, cells.size()), - ggml_row_size(layer.v->type, id)); - } - - ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst)); - ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst)); - } - - i += nm - 1; - } - - //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); -#endif - - return res; -} - -bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { - const uint32_t n_layer = layers.size(); - - const uint32_t n_kv = cells.used_max_p1(); - const uint32_t n_used = cells.get_used(); - - assert(n_used <= n_kv); - - //const int64_t t_start = ggml_time_us(); - - // number of cells moved - uint32_t n_moves = 0; - - // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag) - // - source view, destination view, copy operation - // - x2 for keys and values - //const uint32_t max_moves = max_nodes()/(6*n_layer); - // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 - const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer); - - // determine which KV cells to move where - // - // cell i moves to ids[i] - // - // if ids[i] == i || ids[i] == n_kv, then cell i is not moved - // - auto & ids = defrag_info.ids; - - ids.clear(); - ids.resize(n_kv, n_kv); - - for (uint32_t i0 = 0; i0 < n_used; ++i0) { - if (!cells.is_empty(i0)) { - ids[i0] = i0; - - continue; - } - - // found a hole - fill it with data from the end of the cache - - uint32_t nh = 1; - - // determine the size of the hole - while (i0 + nh < n_used && cells.is_empty(i0 + nh)) { - nh++; - } - - uint32_t nf = 0; - uint32_t is = n_kv - 1; - - // starting from the end, find nh non-empty cells - for (; is > i0; --is) { - if (cells.is_empty(is) || ids[is] != n_kv) { - continue; - } - - // non-empty cell which is not yet moved - nf++; - - if (nf == nh) { - break; - } - } - - // this can only happen if `n_used` is not accurate, which would be a bug - GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh"); - - nf = 0; - - uint32_t i1 = is; - - // are we moving a continuous block of memory? - bool cont = false; - - // should we stop searching for the next move? - bool stop = false; - - // go back and move the nf cells to the hole - for (; i1 < n_kv; ++i1) { - if (cells.is_empty(i1) || ids[i1] != n_kv) { - if (n_moves == max_moves) { - stop = true; - break; - } - - cont = false; - continue; - } - - // this cell goes to (i0 + nf) - ids[i1] = i0 + nf; - - // move the cell meta data - cells.mv(i1, i0 + nf); - - head = n_used; - - if (!cont) { - n_moves++; - cont = true; - } - - nf++; - - if (nf == nh) { - break; - } - } - - if (stop || n_moves == max_moves) { - break; - } - - //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh); - - i0 += nh - 1; - } - - if (n_moves == 0) { - return false; - } - - LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves); - - LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer); - - return true; -} - -bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { - assert(p0 >= 0 && p1 >= 0); - - switch (swa_type) { - case LLAMA_SWA_TYPE_NONE: - { - } break; - case LLAMA_SWA_TYPE_STANDARD: - { - if (p1 - p0 >= (int32_t) n_swa) { - return true; - } - } break; - case LLAMA_SWA_TYPE_CHUNKED: - { - const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; - - if (p0 < pos_chunk_start) { - return true; - } - } break; - } - - return false; -} - -void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { - std::vector> cell_ranges; // ranges, from inclusive, to exclusive - uint32_t cell_count = 0; - - // Count the number of cells with the specified seq_id - // Find all the ranges of cells with this seq id (or all, when -1) - uint32_t cell_range_begin = cells.size(); - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { - ++cell_count; - if (cell_range_begin == cells.size()) { - cell_range_begin = i; - } - } else { - if (cell_range_begin != cells.size()) { - cell_ranges.emplace_back(cell_range_begin, i); - cell_range_begin = cells.size(); - } - } - } - - if (cell_range_begin != cells.size()) { - cell_ranges.emplace_back(cell_range_begin, cells.size()); - } - - // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count - uint32_t cell_count_check = 0; - for (const auto & range : cell_ranges) { - cell_count_check += range.second - range.first; - } - GGML_ASSERT(cell_count == cell_count_check); - - io.write(&cell_count, sizeof(cell_count)); - - state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); -} - -void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) { - uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); - - bool res = true; - res = res && state_read_meta(io, cell_count, seq_id); - res = res && state_read_data(io, cell_count); - - if (!res) { - if (seq_id == -1) { - clear(); - } else { - seq_rm(seq_id, -1, -1); - } - throw std::runtime_error("failed to restore kv cache"); - } -} - -void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { - for (const auto & range : cell_ranges) { - for (uint32_t i = range.first; i < range.second; ++i) { - std::vector seq_ids; - - for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) { - if (cur == seq_id || seq_id == -1) { - if (cells.seq_has(i, cur)) { - seq_ids.push_back(cur); - } - } - } - - const llama_pos pos = cells.pos_get(i); - const uint32_t n_seq_id = seq_ids.size(); - - io.write(&pos, sizeof(pos)); - io.write(&n_seq_id, sizeof(n_seq_id)); - - for (const auto & seq_id : seq_ids) { - io.write(&seq_id, sizeof(seq_id)); - } - } - } -} - -void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { - const uint32_t v_trans = this->v_trans ? 1 : 0; - const uint32_t n_layer = layers.size(); - - io.write(&v_trans, sizeof(v_trans)); - io.write(&n_layer, sizeof(n_layer)); - - std::vector tmp_buf; - - // Iterate and write all the keys first, each row is a cell - // Get whole range at a time - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - - // Write key type - const int32_t k_type_i = (int32_t)layer.k->type; - io.write(&k_type_i, sizeof(k_type_i)); - - // Write row size of key - const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); - io.write(&k_size_row, sizeof(k_size_row)); - - // Read each range of cells of k_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t buf_size = range_size * k_size_row; - io.write_tensor(layer.k, range.first * k_size_row, buf_size); - } - } - - if (!v_trans) { - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Write value type - const int32_t v_type_i = (int32_t)layer.v->type; - io.write(&v_type_i, sizeof(v_type_i)); - - // Write row size of value - const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); - io.write(&v_size_row, sizeof(v_size_row)); - - // Read each range of cells of v_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t buf_size = range_size * v_size_row; - io.write_tensor(layer.v, range.first * v_size_row, buf_size); - } - } - } else { - // When v is transposed, we also need the element size and get the element ranges from each row - const uint32_t kv_size = cells.size(); - - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Write value type - const int32_t v_type_i = (int32_t)layer.v->type; - io.write(&v_type_i, sizeof(v_type_i)); - - // Write element size - const uint32_t v_size_el = ggml_type_size(layer.v->type); - io.write(&v_size_el, sizeof(v_size_el)); - - // Write GQA embedding size - io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); - - // For each row, we get the element values of each cell - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t src_offset = (range.first + j * kv_size) * v_size_el; - const size_t buf_size = range_size * v_size_el; - io.write_tensor(layer.v, src_offset, buf_size); - } - } - } - } -} - -bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { - if (dest_seq_id != -1) { - // single sequence - - seq_rm(dest_seq_id, -1, -1); - - llama_sbatch sbatch; - llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); - - batch.n_tokens = cell_count; - - for (uint32_t i = 0; i < cell_count; ++i) { - llama_pos pos; - uint32_t n_seq_id; - - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); - - if (n_seq_id != 1) { - LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); - return false; - } - - // read the sequence id, but directly discard it - we will use dest_seq_id instead - { - llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); - } - - batch.pos[i] = pos; - batch.n_seq_id[i] = n_seq_id; - batch.seq_id[i] = &dest_seq_id; - } - - if (!find_slot(batch)) { - LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); - return false; - } - - commit(); - - // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) - // Assume that this is one contiguous block of cells - GGML_ASSERT(head + cell_count <= cells.size()); - GGML_ASSERT(cells.pos_get(head) == batch.pos[0]); - GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]); - GGML_ASSERT(cells.seq_has(head, dest_seq_id)); - GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id)); - } else { - // whole KV cache restore - - if (cell_count > cells.size()) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); - return false; - } - - clear(); - - for (uint32_t i = 0; i < cell_count; ++i) { - llama_pos pos; - uint32_t n_seq_id; - - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); - - cells.pos_set(i, pos); - - for (uint32_t j = 0; j < n_seq_id; ++j) { - llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); - - if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) { - LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max); - return false; - } - - cells.seq_add(i, seq_id); - } - } - - head = 0; - } - - return true; -} - -bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) { - uint32_t v_trans; - uint32_t n_layer; - - io.read_to(&v_trans, sizeof(v_trans)); - io.read_to(&n_layer, sizeof(n_layer)); - - if (n_layer != layers.size()) { - LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); - return false; - } - if (cell_count > cells.size()) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size()); - return false; - } - if (this->v_trans != (bool) v_trans) { - LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); - return false; - } - - // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - - // Read type of key - int32_t k_type_i_ref; - io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); - const int32_t k_type_i = (int32_t) layer.k->type; - if (k_type_i != k_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); - return false; - } - - // Read row size of key - uint64_t k_size_row_ref; - io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); - const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); - if (k_size_row != k_size_row_ref) { - LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); - return false; - } - - if (cell_count) { - // Read and set the keys for the whole cell range - ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); - } - } - - if (!this->v_trans) { - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Read type of value - int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)layer.v->type; - if (v_type_i != v_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return false; - } - - // Read row size of value - uint64_t v_size_row_ref; - io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); - const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); - if (v_size_row != v_size_row_ref) { - LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); - return false; - } - - if (cell_count) { - // Read and set the values for the whole cell range - ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); - } - } - } else { - // For each layer, read the values for each cell (transposed) - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Read type of value - int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)layer.v->type; - if (v_type_i != v_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return false; - } - - // Read element size of value - uint32_t v_size_el_ref; - io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); - const size_t v_size_el = ggml_type_size(layer.v->type); - if (v_size_el != v_size_el_ref) { - LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); - return false; - } - - // Read GQA embedding size - uint32_t n_embd_v_gqa_ref; - io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); - if (n_embd_v_gqa != n_embd_v_gqa_ref) { - LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); - return false; - } - - if (cell_count) { - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (head + j * cells.size()) * v_size_el; - ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); - } - } - } - } - - return true; -} - -// -// llama_kv_cache_unified_iswa -// - -llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - bool offload, - bool swa_full, - uint32_t kv_size, - uint32_t n_seq_max, - uint32_t n_batch, - uint32_t n_pad) : hparams(model.hparams) { - llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; - llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; - - const uint32_t size_base = kv_size; - - uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad)); - - // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning - if (swa_full) { - LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n", - __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); - - size_swa = size_base; - do_prune = false; - } - - LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); - - kv_base = std::make_unique( - model, std::move(filter_base), type_k, type_v, - v_trans, offload, size_base, n_seq_max, n_pad, - 0, LLAMA_SWA_TYPE_NONE); - - LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); - - kv_swa = std::make_unique( - model, std::move(filter_swa), type_k, type_v, - v_trans, offload, size_swa, n_seq_max, n_pad, - hparams.n_swa, hparams.swa_type); -} - -void llama_kv_cache_unified_iswa::clear() { - kv_base->clear(); - kv_swa ->clear(); -} - -bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - bool res = true; - - res = res & kv_base->seq_rm(seq_id, p0, p1); - res = res & kv_swa ->seq_rm(seq_id, p0, p1); - - return res; -} - -void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1); - kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1); -} - -void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) { - kv_base->seq_keep(seq_id); - kv_swa ->seq_keep(seq_id); -} - -void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { - kv_base->seq_add(seq_id, p0, p1, shift); - kv_swa ->seq_add(seq_id, p0, p1, shift); -} - -void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - kv_base->seq_div(seq_id, p0, p1, d); - kv_swa ->seq_div(seq_id, p0, p1, d); -} - -llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const { - // the base cache is a superset of the SWA cache, so we can just check the SWA cache - return kv_swa->seq_pos_min(seq_id); -} - -llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { - return kv_swa->seq_pos_max(seq_id); -} - -void llama_kv_cache_unified_iswa::restore() { - kv_base->restore(); - kv_swa ->restore(); -} - -void llama_kv_cache_unified_iswa::commit() { - kv_base->commit(); - kv_swa ->commit(); - - // slide the attention window, forgetting/pruning old tokens that are outside the window - if (do_prune) { - for (const auto & [seq_id, entry] : pending.pos) { - kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax); - } - - } - - pending.clear(); -} - -bool llama_kv_cache_unified_iswa::update(llama_context & lctx) { - bool res = true; - - res = res & kv_base->update(lctx); - res = res & kv_swa ->update(lctx); - - return res; -} - -void llama_kv_cache_unified_iswa::defrag_sched(float thold) { - kv_base->defrag_sched(thold); - kv_swa ->defrag_sched(thold); -} - -void llama_kv_cache_unified_iswa::set_full() { - kv_base->set_full(); - kv_swa ->set_full(); -} - -llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) { - pending.clear(); - - if (do_prune) { - for (int i = 0; i < batch.n_tokens; ++i) { - for (int s = 0; s < batch.n_seq_id[i]; ++s) { - const llama_seq_id seq_id = batch.seq_id[i][s]; - const llama_pos pos = batch.pos[i]; - - if (pending.pos.find(seq_id) == pending.pos.end()) { - pending.pos[seq_id].pmin = pos; - pending.pos[seq_id].pmax = pos; - } else { - pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos); - pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos); - } - } - } - } - - return llama_sbatch(batch, hparams.n_embd, true, logits_all); -} - -llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { - GGML_UNUSED(embd_pooled); - return sbatch.split_simple(n_ubatch); -} - -bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) { - bool res = true; - - res = res & kv_base->find_slot(batch); - res = res & kv_swa ->find_slot(batch); - - return res; -} - -bool llama_kv_cache_unified_iswa::get_can_shift() const { - return kv_base->get_size() == kv_swa->get_size(); -} - -void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { - kv_base->state_write(io, seq_id); - kv_swa ->state_write(io, seq_id); -} - -void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) { - kv_base->state_read(io, seq_id); - kv_swa ->state_read(io, seq_id); -} - -llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_base() const { - return kv_base.get(); -} - -llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const { - return kv_swa.get(); -} - -// -// llama_kv_cache_recurrent -// - -llama_kv_cache_recurrent::llama_kv_cache_recurrent( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { - const int32_t n_layer = hparams.n_layer; - - LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n", - __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer); - - head = 0; - size = kv_size; - used = 0; - - cells.clear(); - cells.resize(kv_size); - - // create a context for each buffer type - std::map ctx_map; - auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { - auto it = ctx_map.find(buft); - if (it == ctx_map.end()) { - ggml_init_params params = { - /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - - ggml_context * ctx = ggml_init(params); - if (!ctx) { - return nullptr; - } - - ctx_map[buft] = ctx; - ctxs.emplace_back(ctx); - - return ctx; - } - - return it->second; - }; - - k_l.reserve(n_layer); - v_l.reserve(n_layer); - - for (int i = 0; i < n_layer; i++) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); - - const char * dev_name = "CPU"; - - ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); - - if (offload) { - auto * dev = model.dev_layer(i); - buft = ggml_backend_dev_buffer_type(dev); - - dev_name = ggml_backend_dev_name(dev); - } - - LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name); - - ggml_context * ctx = ctx_for_buft(buft); - if (!ctx) { - throw std::runtime_error("failed to create ggml context for kv cache"); - } - - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); - ggml_format_name(k, "cache_k_l%d", i); - ggml_format_name(v, "cache_v_l%d", i); - k_l.push_back(k); - v_l.push_back(v); - } - - // allocate tensors and initialize the buffers to avoid NaNs in the padding - for (auto it : ctx_map) { - auto * buft = it.first; - auto * ctx = it.second; - - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); - if (!buf) { - throw std::runtime_error("failed to allocate buffer for kv cache"); - } - ggml_backend_buffer_clear(buf, 0); - LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); - bufs.emplace_back(buf); - } - - { - const size_t memory_size_k = size_k_bytes(); - const size_t memory_size_v = size_v_bytes(); - - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), - ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), - ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); - } -} - -void llama_kv_cache_recurrent::clear() { - for (int32_t i = 0; i < (int32_t) size; ++i) { - cells[i].pos = -1; - cells[i].seq_id.clear(); - cells[i].src = -1; - cells[i].tail = -1; - } - head = 0; - used = 0; - - for (auto & buf : bufs) { - ggml_backend_buffer_clear(buf.get(), 0); - } -} - -bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - uint32_t new_head = size; - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // models like Mamba or RWKV can't have a state partially erased - if (seq_id >= (int64_t) size) { - // could be fatal - return false; - } - if (0 <= seq_id) { - int32_t & tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - const kv_cell & cell = cells[tail_id]; - // partial intersection is invalid - if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { - return false; - } - // invalidate tails which will be cleared - if (p0 <= cell.pos && cell.pos < p1) { - tail_id = -1; - } - } - } else { - // seq_id is negative, then the range should include everything or nothing - if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { - return false; - } - } - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].pos >= p0 && cells[i].pos < p1) { - if (seq_id < 0) { - cells[i].seq_id.clear(); - } else if (cells[i].has_seq_id(seq_id)) { - cells[i].seq_id.erase(seq_id); - } else { - continue; - } - if (cells[i].is_empty()) { - // keep count of the number of used cells - if (cells[i].pos >= 0) { - used--; - } - cells[i].pos = -1; - cells[i].src = -1; - if (new_head == size) { - new_head = i; - } - } - } - } - - // If we freed up a slot, set head to it so searching can start there. - if (new_head != size && new_head < head) { - head = new_head; - } - - return true; -} - -void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - if (seq_id_src == seq_id_dst) { - return; - } - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { - kv_cell & tail_src = cells[seq_id_src]; - kv_cell & tail_dst = cells[seq_id_dst]; - if (tail_dst.tail >= 0) { - // clear destination seq_id if it wasn't empty - kv_cell & cell_dst = cells[tail_dst.tail]; - - cell_dst.seq_id.erase(seq_id_dst); - tail_dst.tail = -1; - if (cell_dst.seq_id.empty()) { - cell_dst.pos = -1; - cell_dst.src = -1; - used -= 1; - } - } - if (tail_src.tail >= 0) { - kv_cell & cell_src = cells[tail_src.tail]; - - cell_src.seq_id.insert(seq_id_dst); - tail_dst.tail = tail_src.tail; - } - } -} - -void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) { - uint32_t new_head = size; - - for (uint32_t i = 0; i < size; ++i) { - if ((llama_seq_id) i != seq_id) { - cells[i].tail = -1; - } - - if (!cells[i].has_seq_id(seq_id)) { - if (cells[i].pos >= 0) { - used--; - } - - cells[i].pos = -1; - cells[i].src = -1; - cells[i].seq_id.clear(); - - if (new_head == size){ - new_head = i; - } - } else { - cells[i].seq_id.clear(); - cells[i].seq_id.insert(seq_id); - } - } - - // If we freed up a slot, set head to it so searching can start there. - if (new_head != size && new_head < head) { - head = new_head; - } -} - -void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { - if (shift == 0) { - return; - } - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // If there is no range then return early to avoid looping over the - if (p0 == p1) { - return; - } - - // for Mamba-like or RWKV models, only the pos needs to be shifted - if (0 <= seq_id && seq_id < (int64_t) size) { - const int32_t tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - kv_cell & cell = cells[tail_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos += shift; - } - } - } -} - -void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - if (d == 1) { - return; - } - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // If there is no range then return early to avoid looping over the cache. - if (p0 == p1) { - return; - } - - // for Mamba-like or RWKV models, only the pos needs to be changed - if (0 <= seq_id && seq_id < (int64_t) size) { - const int32_t tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - kv_cell & cell = cells[tail_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos /= d; - } - } - } -} - -llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const { - llama_pos result = std::numeric_limits::max(); - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id)) { - result = std::min(result, cells[i].pos); - } - } - - if (result == std::numeric_limits::max()) { - result = -1; - } - - return result; -} - -llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { - llama_pos result = -1; - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id)) { - result = std::max(result, cells[i].pos); - } - } - - return result; -} - -void llama_kv_cache_recurrent::restore() { - if (pending.ranges.empty()) { - return; - } - - seq_rm(-1, -1, -1); -} - -void llama_kv_cache_recurrent::commit() { - pending.ranges.clear(); -} - -bool llama_kv_cache_recurrent::update(llama_context & ctx) { - GGML_UNUSED(ctx); - return false; -} - -void llama_kv_cache_recurrent::defrag_sched(float thold) { - GGML_UNUSED(thold); - // noop -} - -void llama_kv_cache_recurrent::set_full() { - n = size; - head = 0; -} - -llama_sbatch llama_kv_cache_recurrent::sbatch_init( - const llama_batch & batch, - bool logits_all) { - return llama_sbatch(batch, hparams.n_embd, false, logits_all); -} - -llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { - if (embd_pooled) { - // Pooled embeddings cannot be split across ubatches (yet) - return sbatch.split_seq(n_ubatch); - } - - return sbatch.split_equal(n_ubatch); -} - -bool llama_kv_cache_recurrent::find_slot( - const llama_ubatch & ubatch) { - const uint32_t n_tokens = ubatch.n_tokens; - const uint32_t n_seqs = ubatch.n_seqs; - - const uint32_t n_seq_tokens = ubatch.n_seq_tokens; - - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (head > used + 2*n_tokens) { - head = 0; - } - - // For recurrent state architectures (like Mamba or RWKV), - // each cache cell can store the state for a whole sequence. - // A slot should be always be contiguous. - - // can only process batches with an equal number of new tokens in each sequence - GGML_ASSERT(ubatch.equal_seqs); - - int32_t min = size - 1; - int32_t max = 0; - - // everything should fit if all seq_ids are smaller than the max - for (uint32_t s = 0; s < n_seqs; ++s) { - const uint32_t n_seq_id = ubatch.n_seq_id[s]; - for (uint32_t j = 0; j < n_seq_id; ++j) { - const llama_seq_id seq_id = ubatch.seq_id[s][j]; - - if (seq_id < 0 || (uint32_t) seq_id >= size) { - // too big seq_id - // TODO: would it be possible to resize the cache instead? - LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max); - return false; - } - if (j > 0) { - kv_cell & seq = cells[seq_id]; - if (seq.tail >= 0) { - kv_cell & cell = cells[seq.tail]; - // clear cells from seq_ids that become shared - // (should not normally happen, but let's handle it anyway) - cell.seq_id.erase(seq_id); - seq.tail = -1; - if (cell.seq_id.empty()) { - cell.pos = -1; - cell.src = -1; - used -= 1; - } - } - } - } - } - -#ifndef NDEBUG - { - std::vector tails_verif; - tails_verif.assign(size, -1); - for (uint32_t i = 0; i < size; ++i) { - kv_cell & cell = cells[i]; - for (llama_seq_id seq_id : cell.seq_id) { - if (tails_verif[seq_id] != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); - } - tails_verif[seq_id] = i; - } - } - for (uint32_t i = 0; i < size; ++i) { - if (tails_verif[i] != cells[i].tail) { - LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]); - } - } - } -#endif - - // find next empty cell - uint32_t next_empty_cell = head; - - for (uint32_t i = 0; i < size; ++i) { - if (next_empty_cell >= size) { next_empty_cell -= size; } - kv_cell & cell = cells[next_empty_cell]; - if (cell.is_empty()) { break; } - next_empty_cell += 1; - } - - // find usable cell range - for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - kv_cell & seq_meta = cells[seq_id]; - bool has_cell = false; - if (seq_meta.tail >= 0) { - kv_cell & cell = cells[seq_meta.tail]; - GGML_ASSERT(cell.has_seq_id(seq_id)); - // does this seq_id "own" the cell? - if (cell.seq_id.size() == 1) { has_cell = true; } - } - if (!has_cell) { - kv_cell & empty_cell = cells[next_empty_cell]; - GGML_ASSERT(empty_cell.is_empty()); - // copy old tail into the empty cell - if (seq_meta.tail >= 0) { - kv_cell & orig_cell = cells[seq_meta.tail]; - empty_cell.pos = orig_cell.pos; - empty_cell.src = orig_cell.src; - orig_cell.seq_id.erase(seq_id); - empty_cell.seq_id.insert(seq_id); // will be overwritten - } - seq_meta.tail = next_empty_cell; - // find next empty cell - if (s + 1 < n_seqs) { - next_empty_cell += 1; - for (uint32_t i = 0; i < size; ++i) { - if (next_empty_cell >= size) { next_empty_cell -= size; } - kv_cell & cell = cells[next_empty_cell]; - if (cell.is_empty()) { break; } - next_empty_cell += 1; - } - } - } - if (min > seq_meta.tail) { min = seq_meta.tail; } - if (max < seq_meta.tail) { max = seq_meta.tail; } - } - - // gather and re-order - for (uint32_t s = 0; s < n_seqs; ++s) { - int32_t dst_id = s + min; - int32_t src_id = cells[ubatch.seq_id[s][0]].tail; - if (dst_id != src_id) { - kv_cell & dst_cell = cells[dst_id]; - kv_cell & src_cell = cells[src_id]; - - std::swap(dst_cell.pos, src_cell.pos); - std::swap(dst_cell.src, src_cell.src); - std::swap(dst_cell.seq_id, src_cell.seq_id); - - // swap tails (assuming they NEVER overlap) - for (const llama_seq_id seq_id : src_cell.seq_id) { - cells[seq_id].tail = src_id; - } - for (const llama_seq_id seq_id : dst_cell.seq_id) { - cells[seq_id].tail = dst_id; - } - } - } - - // update the pos of the used seqs - for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; - int32_t cell_id = s + min; - kv_cell & cell = cells[cell_id]; - - if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { - // What should happen when the pos backtracks or skips a value? - // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", - __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); - } - cell.pos = last_pos; - cell.seq_id.clear(); - for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { - const llama_seq_id seq_id = ubatch.seq_id[s][j]; - cell.seq_id.insert(seq_id); - cells[seq_id].tail = cell_id; - } - } - - // allow getting the range of used cells, from head to head + n - head = min; - n = max - min + 1; - used = std::count_if(cells.begin(), cells.end(), - [](const kv_cell & cell){ return !cell.is_empty(); }); - - // sanity check - return n >= n_seqs; -} - -bool llama_kv_cache_recurrent::get_can_shift() const { - return false; -} - -int32_t llama_kv_cache_recurrent::s_copy(int i) const { - const uint32_t cell_id = i + head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - kv_cell & cell = const_cast(cells[cell_id]); - - // prevent out-of-bound sources - if (cell.src < 0 || (uint32_t) cell.src >= size) { - cell.src = cell_id; - } - - int32_t res = cell.src; - - // TODO: do not mutate the KV cache - // ensure copy only happens once - if (cell.src != (int32_t) cell_id) { - cell.src = cell_id; - } - - return res; -} - -float llama_kv_cache_recurrent::s_mask(int i) const { - const uint32_t cell_id = i + head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - kv_cell & cell = const_cast(cells[cell_id]); - - float res = (float) (cell.src >= 0); - - // only clear once - if (cell.src < 0) { - cell.src = cell_id; - } - - return res; -} - -uint32_t llama_kv_cache_recurrent::cell_max() const { - for (uint32_t i = size; i > 0; --i) { - const kv_cell & cell = cells[i - 1]; - - if (cell.pos >= 0 && !cell.is_empty()) { - return i; - } - } - - return 0; -} - -size_t llama_kv_cache_recurrent::total_size() const { - size_t size = 0; - for (const auto & buf : bufs) { - size += ggml_backend_buffer_get_size(buf.get()); - } - - return size; -} - -size_t llama_kv_cache_recurrent::size_k_bytes() const { - size_t size_k_bytes = 0; - - for (const auto & k : k_l) { - size_k_bytes += ggml_nbytes(k); - } - - return size_k_bytes; -} - -size_t llama_kv_cache_recurrent::size_v_bytes() const { - size_t size_v_bytes = 0; - - for (const auto & v : v_l) { - size_v_bytes += ggml_nbytes(v); - } - - return size_v_bytes; -} - -void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { - std::vector> cell_ranges; // ranges, from inclusive, to exclusive - uint32_t cell_count = 0; - - // Count the number of cells with the specified seq_id - // Find all the ranges of cells with this seq id (or all, when -1) - uint32_t cell_range_begin = size; - for (uint32_t i = 0; i < size; ++i) { - const auto & cell = cells[i]; - if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { - ++cell_count; - if (cell_range_begin == size) { - cell_range_begin = i; - } - } else { - if (cell_range_begin != size) { - cell_ranges.emplace_back(cell_range_begin, i); - cell_range_begin = size; - } - } - } - if (cell_range_begin != size) { - cell_ranges.emplace_back(cell_range_begin, size); - } - - // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count - uint32_t cell_count_check = 0; - for (const auto & range : cell_ranges) { - cell_count_check += range.second - range.first; - } - GGML_ASSERT(cell_count == cell_count_check); - - io.write(&cell_count, sizeof(cell_count)); - - state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); -} - -void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) { - uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); - - bool res = true; - - res = res && state_read_meta(io, cell_count, seq_id); - res = res && state_read_data(io, cell_count); - - if (!res) { - if (seq_id == -1) { - clear(); - } else { - seq_rm(seq_id, -1, -1); - } - throw std::runtime_error("failed to restore kv cache"); - } -} - -void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { - for (const auto & range : cell_ranges) { - for (uint32_t i = range.first; i < range.second; ++i) { - const auto & cell = cells[i]; - const llama_pos pos = cell.pos; - const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; - - io.write(&pos, sizeof(pos)); - io.write(&n_seq_id, sizeof(n_seq_id)); - - if (n_seq_id) { - for (auto seq_id : cell.seq_id) { - io.write(&seq_id, sizeof(seq_id)); - } - } - } - } -} - -void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { - const uint32_t v_trans = 0; - const uint32_t n_layer = hparams.n_layer; - - io.write(&v_trans, sizeof(v_trans)); - io.write(&n_layer, sizeof(n_layer)); - - std::vector tmp_buf; - - // Iterate and write all the keys first, each row is a cell - // Get whole range at a time - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - - // Write key type - const int32_t k_type_i = (int32_t)k_l[il]->type; - io.write(&k_type_i, sizeof(k_type_i)); - - // Write row size of key - const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); - io.write(&k_size_row, sizeof(k_size_row)); - - // Read each range of cells of k_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t buf_size = range_size * k_size_row; - io.write_tensor(k_l[il], range.first * k_size_row, buf_size); - } - } - - if (!v_trans) { - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Write value type - const int32_t v_type_i = (int32_t)v_l[il]->type; - io.write(&v_type_i, sizeof(v_type_i)); - - // Write row size of value - const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); - io.write(&v_size_row, sizeof(v_size_row)); - - // Read each range of cells of v_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t buf_size = range_size * v_size_row; - io.write_tensor(v_l[il], range.first * v_size_row, buf_size); - } - } - } else { - // When v is transposed, we also need the element size and get the element ranges from each row - const uint32_t kv_size = size; - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Write value type - const int32_t v_type_i = (int32_t)v_l[il]->type; - io.write(&v_type_i, sizeof(v_type_i)); - - // Write element size - const uint32_t v_size_el = ggml_type_size(v_l[il]->type); - io.write(&v_size_el, sizeof(v_size_el)); - - // Write GQA embedding size - io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); - - // For each row, we get the element values of each cell - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t src_offset = (range.first + j * kv_size) * v_size_el; - const size_t buf_size = range_size * v_size_el; - io.write_tensor(v_l[il], src_offset, buf_size); - } - } - } - } -} - -bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { - if (dest_seq_id != -1) { - // single sequence - - seq_rm(dest_seq_id, -1, -1); - - llama_sbatch sbatch; - llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); - - batch.n_tokens = cell_count; - batch.n_seq_tokens = cell_count; - batch.n_seqs = 1; - - for (uint32_t i = 0; i < cell_count; ++i) { - llama_pos pos; - uint32_t n_seq_id; - - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); - - if (n_seq_id != 0) { - LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); - return false; - } - - batch.pos[i] = pos; - } - batch.n_seq_id[0] = 1; - batch.seq_id[0] = &dest_seq_id; - if (!find_slot(batch)) { - LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); - return false; - } - commit(); - - // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) - // Assume that this is one contiguous block of cells - GGML_ASSERT(head + cell_count <= size); - GGML_ASSERT(cells[head].pos == batch.pos[0]); - GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); - GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); - GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); - } else { - // whole KV cache restore - - if (cell_count > size) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); - return false; - } - - clear(); - - for (uint32_t i = 0; i < cell_count; ++i) { - kv_cell & cell = cells[i]; - - llama_pos pos; - uint32_t n_seq_id; - - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); - - cell.pos = pos; - - for (uint32_t j = 0; j < n_seq_id; ++j) { - llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); - - // TODO: llama_kv_cache_recurrent should have a notion of max sequences - //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { - if (seq_id < 0) { - //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); - LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); - return false; - } - - cell.seq_id.insert(seq_id); - - int32_t & tail = cells[seq_id].tail; - if (tail != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); - return false; - } - tail = i; - } - } - - head = 0; - used = cell_count; - } - - for (uint32_t i = 0; i < cell_count; ++i) { - uint32_t cell_id = head + i; - // make sure the recurrent states will keep their restored state - cells[cell_id].src = cell_id; - } - - return true; -} - -bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) { - uint32_t v_trans; - uint32_t n_layer; - io.read_to(&v_trans, sizeof(v_trans)); - io.read_to(&n_layer, sizeof(n_layer)); - - if (n_layer != hparams.n_layer) { - LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); - return false; - } - if (cell_count > size) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); - return false; - } - if (false != (bool) v_trans) { - LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); - return false; - } - - // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - - // Read type of key - int32_t k_type_i_ref; - io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); - const int32_t k_type_i = (int32_t) k_l[il]->type; - if (k_type_i != k_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); - return false; - } - - // Read row size of key - uint64_t k_size_row_ref; - io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); - const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); - if (k_size_row != k_size_row_ref) { - LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); - return false; - } - - if (cell_count) { - // Read and set the keys for the whole cell range - ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); - } - } - - if (!v_trans) { - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Read type of value - int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)v_l[il]->type; - if (v_type_i != v_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return false; - } - - // Read row size of value - uint64_t v_size_row_ref; - io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); - const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); - if (v_size_row != v_size_row_ref) { - LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); - return false; - } - - if (cell_count) { - // Read and set the values for the whole cell range - ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); - } - } - } else { - // For each layer, read the values for each cell (transposed) - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Read type of value - int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)v_l[il]->type; - if (v_type_i != v_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return false; - } - - // Read element size of value - uint32_t v_size_el_ref; - io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); - const size_t v_size_el = ggml_type_size(v_l[il]->type); - if (v_size_el != v_size_el_ref) { - LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); - return false; - } - - // Read GQA embedding size - uint32_t n_embd_v_gqa_ref; - io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); - if (n_embd_v_gqa != n_embd_v_gqa_ref) { - LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); - return false; - } - - if (cell_count) { - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (head + j * size) * v_size_el; - ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); - } - } - } - } - - return true; -} diff --git a/examples/talk-llama/llama-kv-cache.h b/examples/talk-llama/llama-kv-cache.h index ce6261e4..2d04705f 100644 --- a/examples/talk-llama/llama-kv-cache.h +++ b/examples/talk-llama/llama-kv-cache.h @@ -2,60 +2,34 @@ #include "llama.h" #include "llama-io.h" -#include "llama-graph.h" #include "llama-memory.h" -#include "llama-kv-cells.h" - -#include "ggml-cpp.h" - -#include -#include -#include - -struct llama_cparams; -struct llama_hparams; -struct llama_ubatch; -struct llama_sbatch; -struct llama_model; -struct llama_context; struct llama_kv_cache : public llama_memory_i { virtual ~llama_kv_cache() = default; - // call if batch processing fails - restores the cache state - virtual void restore() = 0; + // split the input batch into a set of ubatches and verify that they can fit into the cache + // return a state object containing the ubatches and KV cache state required to process them + // check the llama_memory_state_i::get_status() for the result + virtual llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) = 0; - // call after successful batch processing - clears any pending state - virtual void commit() = 0; + // simulate full cache, used for allocating worst-case compute buffers + virtual llama_memory_state_ptr init_full() = 0; // process any pending defrag/shift/etc. operations // optionally call once before processing a new batch + // return true if any operations were performed virtual bool update(llama_context & lctx) = 0; // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing + // TODO: change to + // llama_memory_state_ptr init_defrag(float thold) = 0; + // virtual void defrag_sched(float thold) = 0; - // simulate full cache, used for allocating worst-case compute buffers - // TODO: remove - virtual void set_full() = 0; - - // - // batch processing - // - - // ============================================================================================================= - // TODO: refactor and simplify this [TAG: KV_API] - - virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0; - - // different KV caches require different batch splitting strategies - virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0; - - // find an empty slot of size "n_tokens" in the cache - virtual bool find_slot(const llama_ubatch & batch) = 0; - - // ============================================================================================================= - // getters virtual bool get_can_shift() const = 0; @@ -68,435 +42,3 @@ struct llama_kv_cache : public llama_memory_i { virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0; virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; }; - -// -// llama_kv_cache_guard -// - -struct llama_kv_cache_guard { - llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {} - - ~llama_kv_cache_guard() { - kv->restore(); - } - - void commit() { - kv->commit(); - } - -private: - llama_kv_cache * kv; -}; - -// -// llama_kv_cache_unified -// - -class llama_kv_cache_unified : public llama_kv_cache { -public: - static uint32_t get_padding(const llama_cparams & cparams); - - // this callback is used to filter out layers that should not be included in the cache - using layer_filter_cb = std::function; - - llama_kv_cache_unified( - const llama_model & model, - layer_filter_cb && filter, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max, - uint32_t n_pad, - uint32_t n_swa, - llama_swa_type swa_type); - - ~llama_kv_cache_unified() = default; - - // - // llama_memory_i - // - - void clear() override; - - bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; - void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; - void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; - void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; - - llama_pos seq_pos_min(llama_seq_id seq_id) const override; - llama_pos seq_pos_max(llama_seq_id seq_id) const override; - - // - // llama_kv_cache - // - - void restore() override; - void commit() override; - - bool update(llama_context & ctx) override; - - void defrag_sched(float thold) override; - - void set_full() override; - - llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; - llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; - - // updates the cache head - // Note: On success, it's important that cache.head points - // to the first cell of the slot. - bool find_slot(const llama_ubatch & batch) override; - - bool get_can_shift() const override; - - // state write/load - - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - - // - // llama_kv_cache_unified specific API - // - - uint32_t get_n() const; - uint32_t get_size() const; - - // get views of the current state of the cache - ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; - ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; - - // store k_cur and v_cur in the cache based on the current head location - ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const; - ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const; - - void prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax); - - void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; - void set_input_k_shift (ggml_tensor * dst) const; - void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; - -private: - const llama_model & model; - const llama_hparams & hparams; - - struct kv_layer { - // layer index in the model - // note: can be different from the layer index in the KV cache - uint32_t il; - - ggml_tensor * k; - ggml_tensor * v; - }; - - bool do_defrag = false; - bool v_trans = true; // the value tensor is transposed - - uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) - - // computed before each graph build - // TODO: cells should start to maintain this value dynamically based on the edits - uint32_t n = 0; - - const uint32_t n_seq_max = 1; - - // required padding - const uint32_t n_pad = 1; - - // SWA - const uint32_t n_swa = 0; - - const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; - - std::vector ctxs; - std::vector bufs; - - llama_kv_cells_unified cells; - - std::vector layers; - - // model layer id -> KV cache layer id - std::unordered_map map_layer_ids; - - // recovery information used to restore the KV cells to their original state in case of a failure - // TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation - // to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API] - struct { - void clear() { - states.clear(); - } - - struct state { - uint32_t i; - - llama_kv_cells_unified cells; - }; - - // stack with the partial states before each ubatch - std::vector states; - } recovery; - - // defrag - struct { - std::vector ids; - } defrag_info; - - // return true if cells have been moved - bool defrag_prepare(int32_t n_max_nodes); - - size_t total_size() const; - - size_t size_k_bytes() const; - size_t size_v_bytes() const; - - bool is_masked_swa(llama_pos p0, llama_pos p1) const; - - ggml_tensor * build_rope_shift( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale) const; - - llm_graph_result_ptr build_graph_shift( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf) const; - - llm_graph_result_ptr build_graph_defrag( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf) const; - - void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; - void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; - - bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); - bool state_read_data(llama_io_read_i & io, uint32_t cell_count); -}; - -// -// llama_kv_cache_unified_iswa -// - -// utilizes two instances of llama_kv_cache_unified -// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers -// upon successful commit, the SWA cache removes old tokens outside the n_swa window - -class llama_kv_cache_unified_iswa : public llama_kv_cache { -public: - llama_kv_cache_unified_iswa( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - bool offload, - bool swa_full, - uint32_t kv_size, - uint32_t n_seq_max, - uint32_t n_batch, - uint32_t n_pad); - - ~llama_kv_cache_unified_iswa() = default; - - // - // llama_memory_i - // - - void clear() override; - - bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; - void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; - void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; - void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; - - llama_pos seq_pos_min(llama_seq_id seq_id) const override; - llama_pos seq_pos_max(llama_seq_id seq_id) const override; - - // - // llama_kv_cache - // - - void restore() override; - void commit() override; - - bool update(llama_context & ctx) override; - - void defrag_sched(float thold) override; - - void set_full() override; - - llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; - llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; - - bool find_slot(const llama_ubatch & batch) override; - - bool get_can_shift() const override; - - // state write/load - - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - - // - // llama_kv_cache_unified_iswa specific API - // - - llama_kv_cache_unified * get_kv_base() const; - llama_kv_cache_unified * get_kv_swa () const; - -private: - const llama_hparams & hparams; - - bool do_prune = true; - - struct { - struct entry { - llama_pos pmin; - llama_pos pmax; - }; - - void clear() { - pos.clear(); - } - - // used to perform SWA pruning of old tokens - std::unordered_map pos; - } pending; - - std::unique_ptr kv_base; - std::unique_ptr kv_swa; -}; - -// -// llama_kv_cache_recurrent -// - -class llama_kv_cache_recurrent : public llama_kv_cache { -public: - struct kv_cell { - llama_pos pos = -1; - int32_t src = -1; // used to copy states - int32_t tail = -1; - - std::set seq_id; - - bool has_seq_id(const llama_seq_id & id) const { - return seq_id.find(id) != seq_id.end(); - } - - bool is_empty() const { - return seq_id.empty(); - } - - bool is_same_seq(const kv_cell & other) const { - return seq_id == other.seq_id; - } - }; - - llama_kv_cache_recurrent( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max); - - ~llama_kv_cache_recurrent() = default; - - // - // llama_memory_i - // - - void clear() override; - - bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; - void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; - void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; - void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; - - llama_pos seq_pos_min(llama_seq_id seq_id) const override; - llama_pos seq_pos_max(llama_seq_id seq_id) const override; - - // - // llama_kv_cache - // - - void restore() override; - void commit() override; - - bool update(llama_context & ctx) override; - - void defrag_sched(float thold) override; - - void set_full() override; - - llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; - llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; - - bool find_slot(const llama_ubatch & batch) override; - - bool get_can_shift() const override; - - // TODO: temporary methods - they are not really const as they do const_cast<>, fix this - int32_t s_copy(int i) const; - float s_mask(int i) const; - - // state write/load - - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - - uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) - uint32_t size = 0; // total number of cells, shared across all sequences - uint32_t used = 0; // used cells (i.e. at least one seq_id) - - // computed before each graph build - uint32_t n = 0; - - std::vector cells; - - std::vector k_l; // per layer - std::vector v_l; - -private: - //const llama_model & model; - const llama_hparams & hparams; - - // commit/restore cache - // TODO: rework for recurrent cache - struct slot_range { - uint32_t c0 = 0; // note: these are cell indices, not sequence positions - uint32_t c1 = 0; - }; - - // pending cell updates that are not yet committed - struct { - std::vector ranges; - } pending; - - const uint32_t n_seq_max = 1; - - std::vector ctxs; - std::vector bufs; - - // find how many cells are currently in use - uint32_t cell_max() const; - - size_t total_size() const; - - size_t size_k_bytes() const; - size_t size_v_bytes() const; - - void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; - void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; - - bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); - bool state_read_data(llama_io_read_i & io, uint32_t cell_count); -}; diff --git a/examples/talk-llama/llama-kv-cells.h b/examples/talk-llama/llama-kv-cells.h index dbbd03fc..9e2c4d92 100644 --- a/examples/talk-llama/llama-kv-cells.h +++ b/examples/talk-llama/llama-kv-cells.h @@ -68,12 +68,6 @@ public: // the index of the last cell that is used + 1 // return 0 if no cells are used uint32_t used_max_p1() const { -#if 0 - if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin()); - if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin()); - if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin()); -#endif - return used.empty() ? 0 : *used.rbegin() + 1; } @@ -144,6 +138,19 @@ public: } } + // clear a non-empty cell + void rm(uint32_t i) { + assert(i < pos.size()); + assert(pos[i] != -1); + + seq_pos_rm(i); + + pos[i] = -1; + seq[i].reset(); + + used.erase(i); + } + // note: call only if the cell has seq_id // return true if the cell becomes empty bool seq_rm(uint32_t i, llama_seq_id seq_id) { @@ -196,6 +203,15 @@ public: return false; } + // number of different sequences in the cell + int seq_count(uint32_t i) const { + assert(i < pos.size()); + assert(pos[i] != -1); + + return seq[i].count(); + } + + // check if the cell contains seq_id bool seq_has(uint32_t i, llama_seq_id seq_id) const { assert(i < pos.size()); assert(seq_id >= 0); @@ -213,6 +229,20 @@ public: seq_pos[seq_id].insert(pos[i]); } + // return the sequence id of this cell + // note: call only for cells with exactly one sequence + llama_seq_id seq_get(uint32_t i) const { + assert(seq[i].count() == 1); + + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (seq[i].test(s)) { + return s; + } + } + + return -1; + } + // the minimum position of sequence seq_id currently present in any of the cells // return -1 if the sequence is not present llama_pos seq_pos_min(llama_seq_id seq_id) const { @@ -268,6 +298,7 @@ public: void pos_set(uint32_t i, llama_pos p) { assert(i < pos.size()); assert(pos[i] == -1); + assert(seq[i].none()); pos[i] = p; diff --git a/examples/talk-llama/llama-memory.h b/examples/talk-llama/llama-memory.h index a2d25043..b3799d66 100644 --- a/examples/talk-llama/llama-memory.h +++ b/examples/talk-llama/llama-memory.h @@ -2,6 +2,11 @@ #include "llama.h" +#include +#include + +struct llama_ubatch; + struct llama_memory_params { // kv cache ggml_type type_k; @@ -30,3 +35,42 @@ public: virtual bool get_can_edit() const = 0; }; + +enum llama_memory_status { + LLAMA_MEMORY_STATUS_SUCCESS = 0, + LLAMA_MEMORY_STATUS_FAILED_PREPARE, + LLAMA_MEMORY_STATUS_FAILED_COMPUTE, +}; + +// the interface for managing the memory state during batch processing +// this interface is implemented per memory type. see: +// - llama_kv_cache_unified_state +// - llama_kv_cache_unified_iswa_state +// ... +// +// the only method that can mutate the memory and the memory state is llama_memory_i::apply() +// +// TODO: rename to llama_memory_context_i ? +class llama_memory_state_i { +public: + virtual ~llama_memory_state_i() = default; + + // consume the current ubatch from the state and proceed to the next one + // return false if we are done + virtual bool next() = 0; + + // apply the memory state for the current ubatch to the memory object + // return false on failure + virtual bool apply() = 0; + + // TODO: this might get reworked in the future when refactoring llama_batch + virtual std::vector & out_ids() = 0; + + // get the current ubatch + virtual const llama_ubatch & get_ubatch() const = 0; + + // get the status of the memory state + virtual llama_memory_status get_status() const = 0; +}; + +using llama_memory_state_ptr = std::unique_ptr; diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index e99f5309..50264a69 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -5,7 +5,10 @@ #include "llama-batch.h" #include "llama-cparams.h" #include "llama-model-loader.h" -#include "llama-kv-cache.h" + +#include "llama-kv-cache-unified.h" +#include "llama-kv-cache-unified-iswa.h" +#include "llama-kv-cache-recurrent.h" #include "ggml-cpp.h" @@ -683,6 +686,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false); switch (hparams.n_layer) { case 3: @@ -2113,7 +2117,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_NOMIC_BERT_MOE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); if (arch == LLM_ARCH_BERT) { pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); @@ -2121,8 +2125,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED); - cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {1}, TENSOR_NOT_REQUIRED); + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); } tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); @@ -2131,7 +2135,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - if (arch == LLM_ARCH_BERT) { + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + if (!layer.wqkv) { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); @@ -2140,12 +2147,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - } else { - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - } - - if (arch == LLM_ARCH_NOMIC_BERT_MOE) { - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); } layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); @@ -5887,8 +5888,10 @@ struct llm_build_bert : public llm_graph_context { inpL = build_inp_embd(model.tok_embd); // token types are hardcoded to zero ("Sentence A") - ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); - inpL = ggml_add(ctx0, inpL, type_row0); + if (model.type_embd) { + ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); + inpL = ggml_add(ctx0, inpL, type_row0); + } if (model.arch == LLM_ARCH_BERT) { inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL); } @@ -5909,36 +5912,11 @@ struct llm_build_bert : public llm_graph_context { ggml_tensor * Vcur; // self-attention - if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) { - Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); - - if (model.layers[il].attn_q_norm) { - Qcur = build_norm(Qcur, - model.layers[il].attn_q_norm, - model.layers[il].attn_q_norm_b, - LLM_NORM, il); - } - - Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); - - if (model.layers[il].attn_k_norm) { - Kcur = build_norm(Kcur, - model.layers[il].attn_k_norm, - model.layers[il].attn_k_norm_b, - LLM_NORM, il); - } - - Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } else { - // compute Q and K and RoPE them + if (model.layers[il].wqkv) { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); - if (model.arch == LLM_ARCH_NOMIC_BERT_MOE) { + if (model.layers[il].bqkv) { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); } @@ -5946,11 +5924,32 @@ struct llm_build_bert : public llm_graph_context { Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } else { + Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); + Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); + Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); + } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, + model.layers[il].attn_q_norm_b, + LLM_NORM, il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, + model.layers[il].attn_k_norm_b, + LLM_NORM, il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // RoPE + if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) { Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -8896,9 +8895,9 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_inner = hparams.ssm_d_inner; @@ -8916,8 +8915,8 @@ struct llm_build_mamba : public llm_graph_context { GGML_ASSERT(ubatch.equal_seqs); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - ggml_tensor * conv_states_all = kv_self->k_l[il]; - ggml_tensor * ssm_states_all = kv_self->v_l[il]; + ggml_tensor * conv_states_all = kv_state->get_k_l(il); + ggml_tensor * ssm_states_all = kv_state->get_v_l(il); // (ab)using the KV cache to store the states ggml_tensor * conv = build_copy_mask_state( @@ -11644,7 +11643,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -11654,7 +11653,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { const auto n_head = n_embd / head_size; const auto n_head_kv = hparams.n_head_kv(il); - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); const auto & layer = model.layers[il]; @@ -11766,7 +11765,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { } ggml_tensor * wkv_state = build_copy_mask_state( - gf, kv_self->v_l[il], state_copy, state_mask, + gf, kv_state->get_v_l(il), state_copy, state_mask, hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output; @@ -11785,9 +11784,9 @@ struct llm_build_rwkv6_base : public llm_graph_context { wkv_state, ggml_view_1d( ctx0, - kv_self->v_l[il], + kv_state->get_v_l(il), hparams.n_embd_v_s() * n_seqs, - hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il]) + hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il)) ) ) ); @@ -12040,7 +12039,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { ggml_tensor *& first_layer_value, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -12049,7 +12048,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { const auto head_count = n_embd / head_size; const auto n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); const auto & layer = model.layers[il]; @@ -12120,7 +12119,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); ggml_tensor * wkv_state = build_copy_mask_state( - gf, kv_self->v_l[il], state_copy, state_mask, + gf, kv_state->get_v_l(il), state_copy, state_mask, hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); @@ -12134,9 +12133,9 @@ struct llm_build_rwkv7_base : public llm_graph_context { wkv_state, ggml_view_1d( ctx0, - kv_self->v_l[il], + kv_state->get_v_l(il), hparams.n_embd_v_s() * n_seqs, - hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il]) + hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il)) ) ) ); @@ -13234,7 +13233,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, params.swa_full, cparams.n_ctx, cparams.n_seq_max, - cparams.n_batch, + cparams.n_ubatch, padding); } else { GGML_ASSERT(!hparams.is_swa_any()); @@ -13266,7 +13265,6 @@ llm_graph_result_ptr llama_model::build_graph( switch (arch) { case LLM_ARCH_LLAMA: - case LLM_ARCH_MINICPM: { llm = std::make_unique(*this, params, gf); } break; @@ -13507,6 +13505,7 @@ llm_graph_result_ptr llama_model::build_graph( } break; case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_MINICPM: { llm = std::make_unique(*this, params, gf); } break; @@ -13597,6 +13596,10 @@ int32_t llama_model_n_head_kv(const llama_model * model) { return model->hparams.n_head_kv(); } +int32_t llama_model_n_swa(const llama_model * model) { + return model->hparams.n_swa; +} + // deprecated int32_t llama_n_ctx_train(const llama_model * model) { return llama_model_n_ctx_train(model); diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 01762bea..da0f652c 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -259,9 +259,9 @@ extern "C" { llama_token * token; float * embd; llama_pos * pos; - int32_t * n_seq_id; - llama_seq_id ** seq_id; - int8_t * logits; // TODO: rename this to "output" + int32_t * n_seq_id; // TODO: remove, should belong to only 1 sequence + llama_seq_id ** seq_id; // TODO: become llama_seq_id * seq_id; + int8_t * logits; // TODO: rename this to "output" } llama_batch; enum llama_model_kv_override_type { @@ -366,6 +366,8 @@ extern "C" { bool no_perf; // measure performance timings bool op_offload; // offload host tensor operations to device bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) + // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases + // ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573 }; // model quantization parameters @@ -502,6 +504,7 @@ extern "C" { LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); // Get the model's RoPE frequency scaling factor LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); @@ -652,7 +655,6 @@ extern "C" { // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly: // - lazily on next llama_decode() - // - explicitly with llama_kv_self_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) LLAMA_API void llama_kv_self_seq_add( @@ -665,7 +667,6 @@ extern "C" { // Integer division of the positions by factor of `d > 1` // If the KV cache is RoPEd, the KV data is updated accordingly: // - lazily on next llama_decode() - // - explicitly with llama_kv_self_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) LLAMA_API void llama_kv_self_seq_div( @@ -677,12 +678,14 @@ extern "C" { // Returns the smallest position present in the KV cache for the specified sequence // This is typically non-zero only for SWA caches + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache // Return -1 if the sequence is empty LLAMA_API llama_pos llama_kv_self_seq_pos_min( struct llama_context * ctx, llama_seq_id seq_id); // Returns the largest position present in the KV cache for the specified sequence + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache // Return -1 if the sequence is empty LLAMA_API llama_pos llama_kv_self_seq_pos_max( struct llama_context * ctx, @@ -691,14 +694,15 @@ extern "C" { // Defragment the KV cache // This will be applied: // - lazily on next llama_decode() - // - explicitly with llama_kv_self_update() - LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx); + LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx), + "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'"); // Check if the context supports KV cache shifting LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx); // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) - LLAMA_API void llama_kv_self_update(struct llama_context * ctx); + LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx), + "simply remove this call, updates are applied lazily on the next llama_decode()"); // // State / sessions