talk-llama : sync llama.cpp

Georgi Gerganov
2025-07-01 12:21:09 +03:00
parent c4ea72be9a
commit 1f816de7da
24 changed files with 1456 additions and 430 deletions

View File

@ -42,6 +42,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" },
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_GEMMA3N, "gemma3n" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
@ -75,6 +76,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" },
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -932,6 +934,42 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_GEMMA3N,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
{ LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" },
{ LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" },
{ LLM_TENSOR_PER_LAYER_PROJ_NORM, "per_layer_proj_norm" },
{ LLM_TENSOR_ALTUP_UNEMBD_PROJ, "altup_unembd_proj" },
{ LLM_TENSOR_ALTUP_PROJ, "altup_proj" },
{ LLM_TENSOR_PER_LAYER_INP_GATE, "blk.%d.inp_gate" },
{ LLM_TENSOR_PER_LAYER_PROJ, "blk.%d.proj" },
{ LLM_TENSOR_PER_LAYER_POST_NORM, "blk.%d.post_norm" },
{ LLM_TENSOR_ALTUP_CORRECT_COEF, "blk.%d.altup_correct_coef" },
{ LLM_TENSOR_ALTUP_CORRECT_SCALE, "blk.%d.altup_correct_scale" },
{ LLM_TENSOR_ALTUP_PREDICT_COEF, "blk.%d.altup_predict_coef" },
{ LLM_TENSOR_ALTUP_ROUTER, "blk.%d.altup_router" },
{ LLM_TENSOR_ALTUP_ROUTER_NORM, "blk.%d.altup_router_norm" },
{ LLM_TENSOR_LAUREL_L, "blk.%d.laurel_l" },
{ LLM_TENSOR_LAUREL_R, "blk.%d.laurel_r" },
{ LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" },
},
},
{
LLM_ARCH_STARCODER2,
{
@ -1621,6 +1659,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
}
},
{
LLM_ARCH_ERNIE4_5,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_UNKNOWN,
{
@ -1749,6 +1804,23 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
// altup / laurel (gemma 3n)
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ALTUP_CORRECT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_CORRECT_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ALTUP_PREDICT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_ROUTER, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_ROUTER_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_LAUREL_L, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_LAUREL_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
// this tensor is loaded for T5, but never used
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
{LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},

View File

@ -46,6 +46,7 @@ enum llm_arch {
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
LLM_ARCH_GEMMA3,
LLM_ARCH_GEMMA3N,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,
@ -79,6 +80,7 @@ enum llm_arch {
LLM_ARCH_BAILINGMOE,
LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE,
LLM_ARCH_ERNIE4_5,
LLM_ARCH_UNKNOWN,
};
@ -269,6 +271,22 @@ enum llm_tensor {
LLM_TENSOR_LAYER_OUT_NORM,
LLM_TENSOR_POST_ATTN_NORM,
LLM_TENSOR_POST_MLP_NORM,
LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n
LLM_TENSOR_PER_LAYER_MODEL_PROJ, // gemma3n
LLM_TENSOR_PER_LAYER_INP_GATE, // gemma3n
LLM_TENSOR_PER_LAYER_PROJ, // gemma3n
LLM_TENSOR_PER_LAYER_PROJ_NORM, // gemma3n
LLM_TENSOR_PER_LAYER_POST_NORM, // gemma3n
LLM_TENSOR_ALTUP_PROJ, // gemma3n
LLM_TENSOR_ALTUP_UNEMBD_PROJ, // gemma3n
LLM_TENSOR_ALTUP_CORRECT_COEF, // gemma3n
LLM_TENSOR_ALTUP_CORRECT_SCALE, // gemma3n
LLM_TENSOR_ALTUP_PREDICT_COEF, // gemma3n
LLM_TENSOR_ALTUP_ROUTER, // gemma3n
LLM_TENSOR_ALTUP_ROUTER_NORM, // gemma3n
LLM_TENSOR_LAUREL_L, // gemma3n
LLM_TENSOR_LAUREL_R, // gemma3n
LLM_TENSOR_LAUREL_POST_NORM, // gemma3n
LLM_TENSOR_SSM_IN,
LLM_TENSOR_SSM_CONV1D,
LLM_TENSOR_SSM_X,

View File

@ -244,22 +244,35 @@ bool llama_batch_allocr::init(
continue;
}
if (memory) {
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
if (p0 >= 0) {
bool ok = true;
if (batch.token) {
if (seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
return false;
if (seq_pos_min(s) != p0 + 1) {
ok = false;
}
} else {
assert(batch.embd);
// for embeddings (typically used as vision input), we allow them to have repeating positions
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
if (seq_pos_min(s) != memory->seq_pos_max(s) && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
return false;
if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
ok = false;
}
}
if (!ok) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
__func__, s, s, p0, s, seq_pos_min(s));
return false;
}
}
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
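
For reference, the rule this error message describes, as a minimal standalone sketch (plain C++, not the actual llama.cpp validation code): with X the last position stored in the memory for a sequence and Y the first position of that sequence in the incoming batch, token batches must satisfy Y = X + 1, while embedding batches may also repeat the last position (Y = X).

    #include <cassert>

    // Sketch of the position-consistency rule from the error message above.
    // p0      : last position stored in the memory (KV cache) for the sequence, or -1 if empty
    // y_min   : smallest position of that sequence in the incoming batch
    // is_embd : true for embedding inputs (e.g. vision), which may repeat the last position
    static bool seq_positions_ok(int p0, int y_min, bool is_embd) {
        if (p0 < 0) {
            return true; // nothing stored yet -> no constraint from the memory
        }
        if (is_embd) {
            return y_min == p0 || y_min == p0 + 1;
        }
        return y_min == p0 + 1; // tokens must continue exactly where the memory left off
    }

    int main() {
        assert( seq_positions_ok(41, 42, /*is_embd=*/false)); // consecutive -> accepted
        assert(!seq_positions_ok(41, 45, /*is_embd=*/false)); // gap -> rejected
        assert( seq_positions_ok(41, 41, /*is_embd=*/true));  // embeddings may repeat the last position
        return 0;
    }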

View File

@ -528,12 +528,17 @@ int32_t llm_chat_apply_template(
}
} else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
// this template requires the model to have "\n\n" as EOT token
for (auto message : chat) {
std::string role(message->role);
if (role == "user") {
ss << "User: " << message->content << "\n\nAssistant:";
} else {
ss << message->content << "\n\n";
for (size_t i = 0; i < chat.size(); i++) {
std::string role(chat[i]->role);
if (role == "system") {
ss << "System: " << trim(chat[i]->content) << "\n\n";
} else if (role == "user") {
ss << "User: " << trim(chat[i]->content) << "\n\n";
if (i == chat.size() - 1) {
ss << "Assistant:";
}
} else if (role == "assistant") {
ss << "Assistant: " << trim(chat[i]->content) << "\n\n";
}
}
} else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
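
To make the new RWKV_WORLD behaviour concrete, the following standalone sketch reproduces the same rendering loop over a toy message list and prints the resulting prompt (the trim() helper here is a local stand-in, not the llama.cpp one):

    #include <cstdio>
    #include <string>
    #include <vector>

    // local stand-in for the trim() helper used by the template code
    static std::string trim(const std::string & s) {
        const auto a = s.find_first_not_of(" \t\r\n");
        const auto b = s.find_last_not_of(" \t\r\n");
        return a == std::string::npos ? "" : s.substr(a, b - a + 1);
    }

    int main() {
        struct msg { std::string role, content; };
        const std::vector<msg> chat = {
            {"system",    "You are concise."},
            {"user",      "Hello!"},
            {"assistant", "Hi, how can I help?"},
            {"user",      "Summarize RWKV."},
        };

        std::string ss;
        for (size_t i = 0; i < chat.size(); i++) {
            const std::string & role = chat[i].role;
            if (role == "system") {
                ss += "System: " + trim(chat[i].content) + "\n\n";
            } else if (role == "user") {
                ss += "User: " + trim(chat[i].content) + "\n\n";
                if (i == chat.size() - 1) {
                    ss += "Assistant:"; // only the final user turn opens the assistant reply
                }
            } else if (role == "assistant") {
                ss += "Assistant: " + trim(chat[i].content) + "\n\n";
            }
        }
        std::printf("%s\n", ss.c_str());
    }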

View File

@ -280,8 +280,8 @@ llama_context::llama_context(
// simulate full KV cache
const auto mstate = memory->init_full();
if (!mstate) {
const auto mctx = memory->init_full();
if (!mctx) {
throw std::runtime_error("failed to initialize KV cache");
}
@ -289,7 +289,7 @@ llama_context::llama_context(
// reserve pp graph first so that buffers are only allocated once
{
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
}
@ -300,7 +300,7 @@ llama_context::llama_context(
// reserve with tg graph to get the number of splits and nodes
{
auto * gf = graph_reserve(1, 1, 1, mstate.get());
auto * gf = graph_reserve(1, 1, 1, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute tg buffers");
}
@ -311,7 +311,7 @@ llama_context::llama_context(
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
{
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
}
@ -444,8 +444,8 @@ bool llama_context::kv_self_update(bool optimize) {
optimize |= memory_force_optimize;
memory_force_optimize = false;
const auto mstate = memory->init_update(this, optimize);
switch (mstate->get_status()) {
const auto mctx = memory->init_update(this, optimize);
switch (mctx->get_status()) {
case LLAMA_MEMORY_STATUS_SUCCESS:
{
// noop
@ -463,22 +463,22 @@ bool llama_context::kv_self_update(bool optimize) {
}
}
if (!mstate->apply()) {
if (!mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
}
}
// if the memory module did any computation, we have to reserve a new worst-case graph
{
const auto mstate = memory->init_full();
if (!mstate) {
throw std::runtime_error("failed to initialize memory state");
const auto mctx = memory->init_full();
if (!mctx) {
throw std::runtime_error("failed to initialize memory context");
}
const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) {
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
}
@ -678,9 +678,9 @@ bool llama_context::apply_adapter_cvec(
return cvec.apply(model, data, len, n_embd, il_start, il_end);
}
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
if (mstate && !mstate->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
if (mctx && !mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
}
@ -692,7 +692,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
return nullptr;
}
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
if (!res) {
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
ret = GGML_STATUS_FAILED;
@ -933,21 +933,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
// handle any pending defrags/shifts
kv_self_update(false);
llama_memory_state_ptr mstate;
llama_memory_context_ptr mctx;
while (true) {
mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
if (!mstate) {
mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
if (!mctx) {
return -2;
}
switch (mstate->get_status()) {
switch (mctx->get_status()) {
case LLAMA_MEMORY_STATUS_SUCCESS:
{
} break;
case LLAMA_MEMORY_STATUS_NO_UPDATE:
{
LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
return -2;
}
@ -987,7 +987,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
int64_t n_outputs_prev = 0;
do {
const auto & ubatch = mstate->get_ubatch();
const auto & ubatch = mctx->get_ubatch();
// count the outputs in this ubatch
{
@ -1009,7 +1009,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
ggml_status status;
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@ -1018,7 +1018,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
pos_min[s] = std::numeric_limits<llama_pos>::max();
}
// TODO: fix sequence indexing
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
const auto & seq_id = ubatch.seq_id[i][0];
@ -1126,7 +1125,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
}
n_outputs_prev += n_outputs;
} while (mstate->next());
} while (mctx->next());
// set to total number of outputs in the batch, for use in llama_get_logits_ith
n_outputs = n_outputs_all;
@ -1292,7 +1291,7 @@ ggml_cgraph * llama_context::graph_init() {
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
}
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
if (n_tokens % n_seqs != 0) {
@ -1312,7 +1311,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
this->n_outputs = save_n_outputs;
@ -1333,11 +1332,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
}
llm_graph_result_ptr llama_context::graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype,
const llama_memory_state_i * mstate) {
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype,
const llama_memory_context_i * mctx) {
return model.build_graph(
{
/*.ctx =*/ ctx,
@ -1349,7 +1348,7 @@ llm_graph_result_ptr llama_context::graph_build(
/*.backend_cpu =*/ backend_cpu,
/*.cvec =*/ &cvec,
/*.loras =*/ &loras,
/*.mstate =*/ mstate,
/*.mctx =*/ mctx,
/*.cross =*/ &cross,
/*.n_outputs =*/ n_outputs,
/*.cb =*/ graph_get_cb(),
@ -2042,8 +2041,8 @@ void llama_context::opt_epoch_iter(
uint32_t n_outputs_all = n_tokens_all;
auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true);
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
break;
}
@ -2056,17 +2055,17 @@ void llama_context::opt_epoch_iter(
uint32_t pos_batch = 0;
do {
const auto & ubatch = mstate->get_ubatch();
const auto & ubatch = mctx->get_ubatch();
n_outputs = ubatch.n_tokens;
if (!mstate->apply()) {
LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
if (!mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
break;
}
auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
struct ggml_context * ctx_compute_opt;
{
@ -2101,7 +2100,7 @@ void llama_context::opt_epoch_iter(
ggml_free(ctx_compute_opt);
pos_batch += ubatch.n_tokens;
} while (mstate->next());
} while (mctx->next());
}
}
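
The mstate -> mctx rename is mechanical, but the consumer pattern it names is worth spelling out: decode() obtains a memory context for the batch, then iterates its ubatches, calling apply() before each graph computation and next() to advance. A minimal standalone sketch of that iteration pattern with hypothetical stub types (not the llama.cpp interfaces):

    #include <cstdio>
    #include <vector>

    // hypothetical stubs mirroring the calls llama_context::decode makes on the memory context
    struct ubatch_stub { int n_tokens; };

    struct memory_context_stub {
        std::vector<ubatch_stub> ubatches;
        size_t i_next = 0;

        bool                apply()            { return true; }              // prepare the memory for the current ubatch
        const ubatch_stub & get_ubatch() const { return ubatches[i_next]; }
        bool                next()             { return ++i_next < ubatches.size(); }
    };

    int main() {
        memory_context_stub mctx;   // stands in for memory->init_batch(...)
        mctx.ubatches = {{8}, {5}}; // pretend the batch was split into two ubatches

        do {
            const ubatch_stub & ub = mctx.get_ubatch();
            if (!mctx.apply()) {
                break;              // mirrors the error path around mctx->apply()
            }
            std::printf("processing ubatch with %d tokens\n", ub.n_tokens);
        } while (mctx.next());      // mirrors: } while (mctx->next());
        return 0;
    }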

View File

@ -18,7 +18,7 @@ class llama_io_read_i;
class llama_io_write_i;
struct llama_memory_i;
struct llama_memory_state_i;
struct llama_memory_context_i;
struct llama_context {
// init scheduler and compute buffers, reserve worst-case graphs
@ -93,14 +93,14 @@ struct llama_context {
int32_t il_end);
// process a single ubatch with a specific graph type
// if memory_state is provided, it will be applied first to the context's memory
// if memory_context is provided, it will be applied first to the context's memory
// ret contains the status of the graph computation
// returns nullptr only if ret != GGML_STATUS_SUCCESS
llm_graph_result_ptr process_ubatch(
const llama_ubatch & ubatch,
llm_graph_type gtype,
llama_memory_state_i * mstate,
ggml_status & ret);
const llama_ubatch & ubatch,
llm_graph_type gtype,
llama_memory_context_i * mctx,
ggml_status & ret);
int encode(const llama_batch & batch_inp);
int decode(const llama_batch & batch_inp);
@ -197,15 +197,15 @@ public:
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
// reserve a graph with a dummy ubatch of the specified size
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
private:
llm_graph_result_ptr graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype,
const llama_memory_state_i * mstate);
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype,
const llama_memory_context_i * mctx);
llm_graph_cb graph_get_cb() const;

View File

@ -87,7 +87,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
if (pos_bucket) {
kv_state->set_input_pos_bucket(pos_bucket, ubatch);
mctx->set_input_pos_bucket(pos_bucket, ubatch);
}
}
@ -221,7 +221,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);
const int64_t n_rs = mem_state->get_n_rs();
const int64_t n_rs = mctx->get_n_rs();
if (s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@ -229,7 +229,7 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_rs; ++i) {
data[i] = mem_state->s_copy(i);
data[i] = mctx->s_copy(i);
}
}
}
@ -282,17 +282,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
}
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
if (self_kq_mask_swa) {
kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
}
@ -334,10 +334,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@ -345,11 +345,17 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_rs; ++i) {
data[i] = mem_state->get_state_recr()->s_copy(i);
data[i] = mctx->get_recr()->s_copy(i);
}
}
}
void llm_graph_input_one::set_input(const llama_ubatch *) {
GGML_ASSERT(one && ggml_nelements(one) == 1);
float f_one = 1.0f;
ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
}
//
// llm_graph_context
//
@ -389,7 +395,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
backend_cpu (params.backend_cpu),
cvec (params.cvec),
loras (params.loras),
mstate (params.mstate),
mctx (params.mctx),
cross (params.cross),
cb_func (params.cb),
res (std::make_unique<llm_graph_result>()) {
@ -554,12 +560,20 @@ ggml_tensor * llm_graph_context::build_ffn(
switch (type_op) {
case LLM_FFN_SILU:
{
if (gate && type_gate == LLM_FFN_PAR) {
cur = ggml_swiglu_split(ctx0, cur, tmp);
cb(cur, "ffn_swiglu", il);
type_gate = LLM_FFN_SEQ;
} else {
cur = ggml_silu(ctx0, cur);
cb(cur, "ffn_silu", il);
} break;
case LLM_FFN_GELU:
{
if (gate && type_gate == LLM_FFN_PAR) {
cur = ggml_geglu_split(ctx0, cur, tmp);
cb(cur, "ffn_geglu", il);
type_gate = LLM_FFN_SEQ;
} else {
cur = ggml_gelu(ctx0, cur);
cb(cur, "ffn_gelu", il);
if (act_scales != NULL) {
@ -568,7 +582,11 @@ ggml_tensor * llm_graph_context::build_ffn(
}
} break;
case LLM_FFN_RELU:
{
if (gate && type_gate == LLM_FFN_PAR) {
cur = ggml_reglu_split(ctx0, cur, tmp);
cb(cur, "ffn_reglu", il);
type_gate = LLM_FFN_SEQ;
} else {
cur = ggml_relu(ctx0, cur);
cb(cur, "ffn_relu", il);
} break;
@ -582,32 +600,19 @@ ggml_tensor * llm_graph_context::build_ffn(
} break;
case LLM_FFN_SWIGLU:
{
// Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
int64_t split_point = cur->ne[0] / 2;
// TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
x0 = ggml_silu(ctx0, x0);
cb(cur, "ffn_silu", il);
cur = ggml_mul(ctx0, x0, x1);
cb(cur, "ffn_mul", il);
cur = ggml_swiglu(ctx0, cur);
cb(cur, "ffn_swiglu", il);
} break;
case LLM_FFN_GEGLU:
{
// Split into two equal parts
int64_t split_point = cur->ne[0] / 2;
// TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
x0 = ggml_gelu(ctx0, x0);
cb(x0, "ffn_gelu", il);
cur = ggml_mul(ctx0, x0, x1);
cur = ggml_geglu(ctx0, cur);
cb(cur, "ffn_geglu", il);
} break;
case LLM_FFN_REGLU:
{
cur = ggml_reglu(ctx0, cur);
cb(cur, "ffn_reglu", il);
} break;
}
if (gate && type_gate == LLM_FFN_PAR) {
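
The removed hand-written split/activate/multiply code above spells out what the fused ggml_swiglu / ggml_geglu / ggml_reglu calls are expected to compute: split the projection in half along the feature dimension, apply SiLU / GELU / ReLU to the first half, and multiply elementwise by the second half. A scalar sketch of the SwiGLU case (plain C++, independent of ggml):

    #include <cassert>
    #include <cmath>
    #include <vector>

    // Scalar sketch of the SwiGLU semantics previously written out by hand:
    // split x = [a | b] along the feature dimension and return silu(a) * b.
    static std::vector<float> swiglu(const std::vector<float> & x) {
        assert(x.size() % 2 == 0);
        const size_t half = x.size() / 2;
        std::vector<float> out(half);
        for (size_t i = 0; i < half; i++) {
            const float a = x[i];
            const float b = x[half + i];
            const float silu_a = a / (1.0f + std::exp(-a)); // SiLU(a) = a * sigmoid(a)
            out[i] = silu_a * b;                            // gate with the second half
        }
        return out;
    }

    int main() {
        const std::vector<float> x = {1.0f, -2.0f, 0.5f, 3.0f}; // a = {1, -2}, b = {0.5, 3}
        const std::vector<float> y = swiglu(x);
        assert(y.size() == 2);
        return 0;
    }

The GEGLU and REGLU branches differ only in the activation applied to the first half.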
@ -737,12 +742,18 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
switch (type_op) {
case LLM_FFN_SILU:
{
if (gate_exps) {
cur = ggml_swiglu_split(ctx0, cur, up);
cb(cur, "ffn_moe_swiglu", il);
} else {
cur = ggml_silu(ctx0, cur);
cb(cur, "ffn_moe_silu", il);
} break;
case LLM_FFN_GELU:
{
if (gate_exps) {
cur = ggml_geglu_split(ctx0, cur, up);
cb(cur, "ffn_moe_geglu", il);
} else {
cur = ggml_gelu(ctx0, cur);
cb(cur, "ffn_moe_gelu", il);
} break;
@ -750,11 +761,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
GGML_ABORT("fatal error");
}
if (gate_exps) {
cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
cb(cur, "ffn_moe_gate_par", il);
}
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
cb(experts, "ffn_moe_down", il);
@ -950,11 +956,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
}
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
const auto n_kv = kv_state->get_n_kv();
const auto n_kv = mctx_cur->get_n_kv();
auto & cur = inp->pos_bucket;
@ -982,14 +988,14 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
}
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
const auto n_kv = inp->mctx->get_attn()->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
@ -999,7 +1005,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
}
{
const auto n_rs = mem_state->get_state_recr()->get_n_rs();
const auto n_rs = mctx_cur->get_recr()->get_n_rs();
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
ggml_set_input(inp->s_copy);
@ -1183,14 +1189,14 @@ ggml_tensor * llm_graph_context::build_attn(
}
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
const auto n_kv = kv_state->get_n_kv();
const auto n_kv = mctx_cur->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
@ -1220,19 +1226,19 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
// store to KV cache
{
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
}
const auto & kq_mask = inp->get_kq_mask();
ggml_tensor * q = q_cur;
ggml_tensor * k = kv_state->get_k(ctx0, il);
ggml_tensor * v = kv_state->get_v(ctx0, il);
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);
@ -1267,26 +1273,35 @@ ggml_tensor * llm_graph_context::build_attn(
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, q_cur);
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);
const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
if (k_cur) {
ggml_build_forward_expand(gf, k_cur);
}
if (v_cur) {
ggml_build_forward_expand(gf, v_cur);
}
const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
const bool is_swa = hparams.is_swa(il);
const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
// store to KV cache
{
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
// optionally store to KV cache
if (k_cur) {
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
}
if (v_cur) {
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
}
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
ggml_tensor * q = q_cur;
ggml_tensor * k = kv_state->get_k(ctx0, il);
ggml_tensor * v = kv_state->get_v(ctx0, il);
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);
@ -1379,19 +1394,19 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);
const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
// store to KV cache
{
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
}
const auto & kq_mask = inp->get_kq_mask();
ggml_tensor * q = q_cur;
ggml_tensor * k = kv_state->get_k(ctx0, il);
ggml_tensor * v = kv_state->get_v(ctx0, il);
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);
@ -1412,12 +1427,12 @@ ggml_tensor * llm_graph_context::build_attn(
}
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
{
const auto n_kv = kv_state->get_base()->get_n_kv();
const auto n_kv = mctx_cur->get_base()->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
@ -1429,7 +1444,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
{
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
const auto n_kv = kv_state->get_swa()->get_n_kv();
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@ -1485,11 +1500,11 @@ ggml_tensor * llm_graph_context::build_rs(
}
llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
const auto n_rs = kv_state->get_n_rs();
const auto n_rs = mctx_cur->get_n_rs();
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
ggml_set_input(inp->s_copy);
@ -1504,9 +1519,9 @@ ggml_tensor * llm_graph_context::build_rs(
int32_t state_size,
int32_t n_seqs,
bool avoid_copies) const {
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
}
ggml_tensor * llm_graph_context::build_rs(
@ -1516,9 +1531,9 @@ ggml_tensor * llm_graph_context::build_rs(
int32_t state_size,
int32_t n_seqs,
bool avoid_copies) const {
const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
}
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@ -1526,13 +1541,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
ggml_cgraph * gf,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
const auto token_shift_count = hparams.token_shift_count;
const int64_t n_seqs = ubatch.n_seqs;
ggml_tensor * token_shift_all = kv_state->get_r_l(il);
ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
ggml_tensor * token_shift = build_rs(
inp, gf, token_shift_all,
@ -1547,19 +1562,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
const auto token_shift_count = hparams.token_shift_count;
const auto n_embd = hparams.n_embd;
const int64_t n_seqs = ubatch.n_seqs;
const auto kv_head = kv_state->get_head();
const auto kv_head = mctx_cur->get_head();
return ggml_cpy(
ctx0,
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
);
}

View File

@ -17,12 +17,12 @@ struct ggml_tensor;
struct llama_ubatch;
struct llama_cparams;
struct llama_memory_state_i;
struct llama_memory_context_i;
class llama_kv_cache_unified_state;
class llama_kv_cache_unified_iswa_state;
class llama_memory_recurrent_state;
class llama_memory_hybrid_state;
class llama_kv_cache_unified_context;
class llama_kv_cache_unified_iswa_context;
class llama_memory_recurrent_context;
class llama_memory_hybrid_context;
// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
@ -38,6 +38,7 @@ enum llm_ffn_op_type {
LLM_FFN_RELU_SQR,
LLM_FFN_SWIGLU,
LLM_FFN_GEGLU,
LLM_FFN_REGLU,
};
enum llm_ffn_gate_type {
@ -136,7 +137,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
public:
llm_graph_input_pos_bucket_kv(
const llama_hparams & hparams,
const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
virtual ~llm_graph_input_pos_bucket_kv() = default;
void set_input(const llama_ubatch * ubatch) override;
@ -144,7 +145,8 @@ public:
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
const llama_hparams & hparams;
const llama_kv_cache_unified_state * kv_state;
const llama_kv_cache_unified_context * mctx;
};
class llm_graph_input_out_ids : public llm_graph_input_i {
@ -191,14 +193,14 @@ public:
class llm_graph_input_rs : public llm_graph_input_i {
public:
llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
virtual ~llm_graph_input_rs() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_copy; // I32 [kv_size]
const llama_memory_recurrent_state * mem_state;
const llama_memory_recurrent_context * mctx;
};
class llm_graph_input_cross_embd : public llm_graph_input_i {
@ -238,10 +240,10 @@ public:
llm_graph_input_attn_kv_unified(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified_state * kv_state) :
const llama_kv_cache_unified_context * mctx) :
hparams(hparams),
cparams(cparams),
kv_state(kv_state) {
mctx(mctx) {
}
~llm_graph_input_attn_kv_unified() = default;
@ -255,7 +257,7 @@ public:
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_kv_cache_unified_state * kv_state;
const llama_kv_cache_unified_context * mctx;
};
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@ -263,10 +265,10 @@ public:
llm_graph_input_attn_kv_unified_iswa(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified_iswa_state * kv_state) :
const llama_kv_cache_unified_iswa_context * mctx) :
hparams(hparams),
cparams(cparams),
kv_state(kv_state) {
mctx(mctx) {
}
~llm_graph_input_attn_kv_unified_iswa() = default;
@ -283,7 +285,7 @@ public:
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_kv_cache_unified_iswa_state * kv_state;
const llama_kv_cache_unified_iswa_context * mctx;
};
class llm_graph_input_attn_cross : public llm_graph_input_i {
@ -306,10 +308,10 @@ public:
llm_graph_input_mem_hybrid(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_memory_hybrid_state * mem_state) :
const llama_memory_hybrid_context * mctx) :
hparams(hparams),
cparams(cparams),
mem_state(mem_state) {
mctx(mctx) {
}
virtual ~llm_graph_input_mem_hybrid() = default;
@ -325,7 +327,18 @@ public:
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_memory_hybrid_state * mem_state;
const llama_memory_hybrid_context * mctx;
};
// TODO: remove this when ggml_scale_add is implemented
class llm_graph_input_one : public llm_graph_input_i {
public:
llm_graph_input_one() {}
virtual ~llm_graph_input_one() = default;
void set_input(const llama_ubatch *) override;
ggml_tensor * one = nullptr; // F32
};
//
@ -401,10 +414,10 @@ struct llm_graph_params {
ggml_backend_sched_t sched;
ggml_backend_t backend_cpu;
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_state_i * mstate;
const llama_cross * cross;
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_context_i * mctx;
const llama_cross * cross;
uint32_t n_outputs;
@ -453,16 +466,17 @@ struct llm_graph_context {
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_state_i * mstate;
const llama_cross * cross;
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_context_i * mctx;
const llama_cross * cross;
const llm_graph_cb & cb_func;
std::unique_ptr<llm_graph_result> res;
llm_graph_context(const llm_graph_params & params);
virtual ~llm_graph_context() = default;
void cb(ggml_tensor * cur, const char * name, int il) const;
@ -588,14 +602,15 @@ struct llm_graph_context {
llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified_iswa * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,

View File

@ -143,6 +143,12 @@ struct llama_hparams {
uint32_t n_attn_temp_floor_scale = 8192;
float f_attn_temp_scale = 0.1;
// gemma3n altup
uint32_t n_altup = 4; // altup_num_inputs
uint32_t i_altup_act = 0; // altup_active_idx
uint32_t laurel_rank = 64;
uint32_t n_embd_altup = 256;
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;

View File

@ -95,7 +95,7 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
return kv_swa->seq_pos_max(seq_id);
}
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
GGML_UNUSED(embd_all);
// first try simple split
@ -125,7 +125,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc
assert(heads_base.size() == heads_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_state>(
return std::make_unique<llama_kv_cache_unified_iswa_context>(
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
} while (false);
@ -156,22 +156,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc
assert(heads_base.size() == heads_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_state>(
return std::make_unique<llama_kv_cache_unified_iswa_context>(
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
} while (false);
// TODO: if we fail again, we should attempt different splitting strategies
// but to do that properly, we first have to refactor the batches to be more flexible
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
}
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
}
bool llama_kv_cache_unified_iswa::get_can_shift() const {
@ -197,46 +197,46 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
}
//
// llama_kv_cache_unified_iswa_state
// llama_kv_cache_unified_iswa_context
//
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv) :
state_base(kv->get_base()->init_full()),
state_swa (kv->get_swa ()->init_full()),
status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
ctx_base(kv->get_base()->init_full()),
ctx_swa (kv->get_swa ()->init_full()),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
llama_context * lctx,
bool optimize) :
state_base(kv->get_base()->init_update(lctx, optimize)),
state_swa (kv->get_swa ()->init_update(lctx, optimize)),
status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
ctx_base(kv->get_base()->init_update(lctx, optimize)),
ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
std::vector<uint32_t> heads_base,
std::vector<uint32_t> heads_swa,
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
state_base(new llama_kv_cache_unified_state(kv->get_base(), std::move(heads_base), this->ubatches)),
state_swa (new llama_kv_cache_unified_state(kv->get_swa (), std::move(heads_swa), this->ubatches)),
status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
bool llama_kv_cache_unified_iswa_state::next() {
bool llama_kv_cache_unified_iswa_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
state_base->next();
state_swa ->next();
ctx_base->next();
ctx_swa ->next();
if (++i_next >= ubatches.size()) {
return false;
@ -245,35 +245,35 @@ bool llama_kv_cache_unified_iswa_state::next() {
return true;
}
bool llama_kv_cache_unified_iswa_state::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
bool llama_kv_cache_unified_iswa_context::apply() {
assert(!llama_memory_status_is_fail(status));
bool res = true;
res = res & state_base->apply();
res = res & state_swa ->apply();
res = res & ctx_base->apply();
res = res & ctx_swa ->apply();
return res;
}
llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
return status;
}
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next];
}
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
}
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
}

View File

@ -31,14 +31,14 @@ public:
// llama_memory_i
//
llama_memory_state_ptr init_batch(
llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;
llama_memory_state_ptr init_full() override;
llama_memory_context_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
@ -72,32 +72,32 @@ private:
std::unique_ptr<llama_kv_cache_unified> kv_swa;
};
class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
public:
// used for errors
llama_kv_cache_unified_iswa_state(llama_memory_status status);
llama_kv_cache_unified_iswa_context(llama_memory_status status);
// used to create a full-cache state
llama_kv_cache_unified_iswa_state(
// used to create a full-cache context
llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv);
// used to create an update state
llama_kv_cache_unified_iswa_state(
// used to create an update context
llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
llama_context * lctx,
bool optimize);
// used to create a state from a batch
llama_kv_cache_unified_iswa_state(
// used to create a batch processing context from a batch
llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
std::vector<uint32_t> heads_base,
std::vector<uint32_t> heads_swa,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_iswa_state();
virtual ~llama_kv_cache_unified_iswa_context();
//
// llama_memory_state_i
// llama_memory_context_i
//
bool next() override;
@ -107,11 +107,11 @@ public:
const llama_ubatch & get_ubatch() const override;
//
// llama_kv_cache_unified_iswa_state specific API
// llama_kv_cache_unified_iswa_context specific API
//
const llama_kv_cache_unified_state * get_base() const;
const llama_kv_cache_unified_state * get_swa() const;
const llama_kv_cache_unified_context * get_base() const;
const llama_kv_cache_unified_context * get_swa() const;
private:
//llama_kv_cache_unified_iswa * kv;
@ -121,8 +121,8 @@ private:
std::vector<llama_ubatch> ubatches;
const llama_memory_state_ptr state_base;
const llama_memory_state_ptr state_swa;
const llama_memory_context_ptr ctx_base;
const llama_memory_context_ptr ctx_swa;
const llama_memory_status status;
};

View File

@ -33,13 +33,19 @@ llama_kv_cache_unified::llama_kv_cache_unified(
GGML_ASSERT(kv_size % n_pad == 0);
// TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
auto n_layer_cache = hparams.n_layer;
if (model.arch == LLM_ARCH_GEMMA3N) {
n_layer_cache = 20;
}
// create a context for each buffer type
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
ggml_init_params params = {
/*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
/*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
@ -62,7 +68,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
cells.resize(kv_size);
for (uint32_t il = 0; il < hparams.n_layer; il++) {
for (uint32_t il = 0; il < n_layer_cache; il++) {
if (filter && !filter(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
continue;
@ -102,6 +108,26 @@ llama_kv_cache_unified::llama_kv_cache_unified(
layers.push_back({ il, k, v });
}
// TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
if (model.arch == LLM_ARCH_GEMMA3N) {
LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
if (filter && !filter(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
continue;
}
const bool is_swa = hparams.is_swa(il);
const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
map_layer_ids[il] = map_layer_ids[il_reuse];
LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
}
}
// allocate tensors and initialize the buffers to avoid NaNs in the padding
for (auto it : ctx_map) {
auto * buft = it.first;
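
Per the hunk above, Gemma 3n allocates KV cells only for the first 20 layers (n_layer_cache) and maps every later layer onto the last cached SWA or non-SWA layer. A minimal sketch of that mapping (standalone; the SWA pattern below is made up purely for illustration):

    #include <cstdio>

    // Sketch of the Gemma 3n layer-reuse mapping introduced above:
    // only the first n_layer_cache layers own KV cells; later layers reuse
    // the last cached SWA layer (n_layer_cache - 2) or non-SWA layer (n_layer_cache - 1).
    static int kv_layer_for(int il, int n_layer_cache, bool is_swa) {
        if (il < n_layer_cache) {
            return il; // the layer has its own KV cells
        }
        return n_layer_cache - (is_swa ? 2 : 1);
    }

    int main() {
        const int n_layer_cache = 20; // value hard-coded for LLM_ARCH_GEMMA3N in this change
        for (int il = 18; il < 24; il++) {
            const bool is_swa = (il % 5 != 0); // hypothetical SWA pattern, for illustration only
            std::printf("layer %2d -> kv layer %2d (is_swa = %d)\n",
                        il, kv_layer_for(il, n_layer_cache, is_swa), (int) is_swa);
        }
        return 0;
    }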
@ -307,7 +333,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
return cells.seq_pos_max(seq_id);
}
llama_memory_state_ptr llama_kv_cache_unified::init_batch(
llama_memory_context_ptr llama_kv_cache_unified::init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) {
@ -332,18 +358,18 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
break;
}
return std::make_unique<llama_kv_cache_unified_state>(
return std::make_unique<llama_kv_cache_unified_context>(
this, std::move(heads), std::move(ubatches));
} while (false);
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
llama_memory_state_ptr llama_kv_cache_unified::init_full() {
return std::make_unique<llama_kv_cache_unified_state>(this);
llama_memory_context_ptr llama_kv_cache_unified::init_full() {
return std::make_unique<llama_kv_cache_unified_context>(this);
}
llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
bool do_shift = get_has_shift();
defrag_info dinfo;
@ -373,7 +399,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx,
}
}
return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
}
llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
@ -1710,18 +1736,18 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
}
//
// llama_kv_cache_unified_state
// llama_kv_cache_unified_context
//
llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
n_kv = kv->get_size();
head = 0;
}
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_context * lctx,
bool do_shift,
@ -1731,15 +1757,15 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
}
}
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_kv_cache_unified::ubatch_heads heads,
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
}
llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
bool llama_kv_cache_unified_state::next() {
bool llama_kv_cache_unified_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
if (++i_next >= ubatches.size()) {
@ -1749,8 +1775,8 @@ bool llama_kv_cache_unified_state::next() {
return true;
}
bool llama_kv_cache_unified_state::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
bool llama_kv_cache_unified_context::apply() {
assert(!llama_memory_status_is_fail(status));
// no ubatches -> this is a KV cache update
if (ubatches.empty()) {
@ -1767,45 +1793,45 @@ bool llama_kv_cache_unified_state::apply() {
return true;
}
llama_memory_status llama_kv_cache_unified_state::get_status() const {
llama_memory_status llama_kv_cache_unified_context::get_status() const {
return status;
}
const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const {
const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next];
}
uint32_t llama_kv_cache_unified_state::get_n_kv() const {
uint32_t llama_kv_cache_unified_context::get_n_kv() const {
return n_kv;
}
ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const {
ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
return kv->get_k(ctx, il, n_kv);
}
ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const {
ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
return kv->get_v(ctx, il, n_kv);
}
ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
return kv->cpy_k(ctx, k_cur, il, head);
}
ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
return kv->cpy_v(ctx, v_cur, il, head);
}
void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const {
void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
kv->set_input_k_shift(dst);
}
void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
kv->set_input_kq_mask(dst, ubatch, causal_attn);
}
void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_pos_bucket(dst, ubatch);
}


@ -56,14 +56,14 @@ public:
// llama_memory_i
//
llama_memory_state_ptr init_batch(
llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;
llama_memory_state_ptr init_full() override;
llama_memory_context_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
@ -208,36 +208,36 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
class llama_kv_cache_unified_state : public llama_memory_state_i {
class llama_kv_cache_unified_context : public llama_memory_context_i {
public:
// some shorthands
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
using defrag_info = llama_kv_cache_unified::defrag_info;
// used for errors
llama_kv_cache_unified_state(llama_memory_status status);
llama_kv_cache_unified_context(llama_memory_status status);
// used to create a full-cache state
llama_kv_cache_unified_state(
// used to create a full-cache context
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv);
// used to create an update state
llama_kv_cache_unified_state(
// used to create an update context
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_context * lctx,
bool do_shift,
defrag_info dinfo);
// used to create a decode state from a batch
llama_kv_cache_unified_state(
// used to create a batch processing context from a batch
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
ubatch_heads heads,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_state();
virtual ~llama_kv_cache_unified_context();
//
// llama_memory_state_i
// llama_memory_context_i
//
bool next() override;
@ -247,7 +247,7 @@ public:
const llama_ubatch & get_ubatch() const override;
//
// llama_kv_cache_unified_state specific API
// llama_kv_cache_unified_context specific API
//
uint32_t get_n_kv() const;
@ -272,7 +272,7 @@ private:
llama_context * lctx;
//
// update state
// update context
//
bool do_shift = false;
@ -280,7 +280,7 @@ private:
defrag_info dinfo;
//
// batch processing state
// batch processing context
//
// the index of the next ubatch to process


@ -7,6 +7,7 @@
#include <cassert>
#include <vector>
#include <set>
#include <map>
// meta information about KV cells that can be part of multiple sequences at the same time
// TODO: add unit tests
@ -164,7 +165,7 @@ public:
assert(seq_id >= 0);
seq[i].reset(seq_id);
seq_pos[seq_id].erase(pos[i]);
seq_pos_dec(seq_id, pos[i]);
if (seq[i].none()) {
pos[i] = -1;
@ -187,7 +188,7 @@ public:
seq[i].reset();
seq[i].set(seq_id);
seq_pos[seq_id].insert(pos[i]);
seq_pos_inc(seq_id, pos[i]);
return false;
}
@ -232,7 +233,7 @@ public:
assert(!seq[i].test(seq_id));
seq[i].set(seq_id);
seq_pos[seq_id].insert(pos[i]);
seq_pos_inc(seq_id, pos[i]);
}
// return the sequence id of this cell
@ -259,7 +260,9 @@ public:
return -1;
}
return *seq_pos[seq_id].begin();
assert(seq_pos[seq_id].begin()->second > 0);
return seq_pos[seq_id].begin()->first;
}
// the maximum position of sequence seq_id currently present in any of the cells
@ -272,7 +275,9 @@ public:
return -1;
}
return *seq_pos[seq_id].rbegin();
assert(seq_pos[seq_id].rbegin()->second > 0);
return seq_pos[seq_id].rbegin()->first;
}
// note: call only if the cell is not empty
@ -389,17 +394,36 @@ private:
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
std::vector<seq_set_t> seq;
// the set seq_pos[s] tells us which positions are currently present for sequence s
// the map seq_pos[s][p] tells us how many times the position p is currently present for sequence s
// if the position p is not present, seq_pos[s][p] is not set
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
//
// note that we cannot use a std::set because in some cases a position can occur more than once for the same seq:
// - when performing a cache reuse via (rm + add)
// - some vision models have input embeddings with repeating positions
//
std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
// helper functions for updating `seq_pos`, one cell at a time:
void seq_pos_dec(llama_seq_id s, llama_pos p) {
auto it = seq_pos[s].find(p);
assert(it != seq_pos[s].end());
if (--it->second == 0) {
seq_pos[s].erase(it);
}
}
void seq_pos_inc(llama_seq_id s, llama_pos p) {
seq_pos[s][p]++;
}
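// Illustrative sketch (not part of this change): the two helpers above amount to a
// reference-counted multiset of positions built on std::map. The nested struct below
// is hypothetical and only mirrors that behavior for clarity (<map> and <cassert>
// are already included by this header).
struct pos_multiset {
    std::map<int, int> counts; // position -> number of cells currently holding it

    void inc(int p) { counts[p]++; }

    void dec(int p) {
        auto it = counts.find(p);
        assert(it != counts.end());
        if (--it->second == 0) {
            counts.erase(it); // drop the key once no cell references this position
        }
    }

    // valid only while the multiset is non-empty
    int min_pos() const { return counts.begin()->first;  }
    int max_pos() const { return counts.rbegin()->first; }
};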
// remove cell i
void seq_pos_rm(uint32_t i) {
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq[i].test(s)) {
seq_pos[s].erase(pos[i]);
seq_pos_dec(s, pos[i]);
}
}
}
@ -408,7 +432,7 @@ private:
void seq_pos_add(uint32_t i) {
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq[i].test(s)) {
seq_pos[s].insert(pos[i]);
seq_pos_inc(s, pos[i]);
}
}
}


@ -56,7 +56,7 @@ llama_memory_hybrid::llama_memory_hybrid(
n_seq_max
)) {}
llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
do {
balloc.split_reset();
@ -82,31 +82,31 @@ llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ball
// prepare the recurrent batches first
if (!mem_recr->prepare(ubatches)) {
// TODO: will the recurrent cache be in an undefined state at this point?
// TODO: will the recurrent cache be in an undefined context at this point?
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
// prepare the attention cache
auto heads_attn = mem_attn->prepare(ubatches);
if (heads_attn.empty()) {
LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
return std::make_unique<llama_memory_hybrid_state>(
return std::make_unique<llama_memory_hybrid_context>(
this, std::move(heads_attn), std::move(ubatches));
} while(false);
return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
llama_memory_state_ptr llama_memory_hybrid::init_full() {
return std::make_unique<llama_memory_hybrid_state>(this);
llama_memory_context_ptr llama_memory_hybrid::init_full() {
return std::make_unique<llama_memory_hybrid_context>(this);
}
llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize);
llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
}
bool llama_memory_hybrid::get_can_shift() const {
@ -176,39 +176,39 @@ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
return mem_recr.get();
}
llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}
llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}
llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
state_attn(mem->get_mem_attn()->init_full()),
state_recr(mem->get_mem_recr()->init_full()),
status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
ctx_attn(mem->get_mem_attn()->init_full()),
ctx_recr(mem->get_mem_recr()->init_full()),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
llama_memory_hybrid_state::llama_memory_hybrid_state(
llama_memory_hybrid_context::llama_memory_hybrid_context(
llama_memory_hybrid * mem,
llama_context * lctx,
bool optimize) :
state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
llama_memory_hybrid_state::llama_memory_hybrid_state(
llama_memory_hybrid_context::llama_memory_hybrid_context(
llama_memory_hybrid * mem,
std::vector<uint32_t> heads_attn,
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), this->ubatches)),
status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
bool llama_memory_hybrid_state::next() {
bool llama_memory_hybrid_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
state_attn->next();
state_recr->next();
ctx_attn->next();
ctx_recr->next();
if (++i_next >= ubatches.size()) {
return false;
@ -217,30 +217,30 @@ bool llama_memory_hybrid_state::next() {
return true;
}
bool llama_memory_hybrid_state::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
bool llama_memory_hybrid_context::apply() {
assert(!llama_memory_status_is_fail(status));
bool res = true;
res = res & state_attn->apply();
res = res & state_recr->apply();
res = res & ctx_attn->apply();
res = res & ctx_recr->apply();
return res;
}
llama_memory_status llama_memory_hybrid_state::get_status() const {
llama_memory_status llama_memory_hybrid_context::get_status() const {
return status;
}
const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const {
const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next];
}
const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const {
return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
}
const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
}
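// Hedged usage sketch (assumption, not part of the diff): graph-building code for a
// hybrid model would reach the two sub-contexts through the getters above. The
// function name is hypothetical.
static void example_use_hybrid_context(const llama_memory_context_i * mctx) {
    const auto * mctx_hyb  = static_cast<const llama_memory_hybrid_context *>(mctx);
    const auto * mctx_attn = mctx_hyb->get_attn(); // unified KV cache view for the attention layers
    const auto * mctx_recr = mctx_hyb->get_recr(); // recurrent state view for the recurrent layers
    (void) mctx_attn;
    (void) mctx_recr;
}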


@ -49,14 +49,14 @@ public:
// llama_memory_i
//
llama_memory_state_ptr init_batch(
llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;
llama_memory_state_ptr init_full() override;
llama_memory_context_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
@ -90,27 +90,27 @@ private:
const std::unique_ptr<llama_memory_recurrent> mem_recr;
};
class llama_memory_hybrid_state : public llama_memory_state_i {
class llama_memory_hybrid_context : public llama_memory_context_i {
public:
// init failure
explicit llama_memory_hybrid_state(llama_memory_status status);
explicit llama_memory_hybrid_context(llama_memory_status status);
// init full
explicit llama_memory_hybrid_state(llama_memory_hybrid * mem);
explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
// init update
explicit llama_memory_hybrid_state(
explicit llama_memory_hybrid_context(
llama_memory_hybrid * mem,
llama_context * lctx,
bool optimize);
// init success
llama_memory_hybrid_state(
llama_memory_hybrid_context(
llama_memory_hybrid * mem,
std::vector<uint32_t> heads_attn,
std::vector<llama_ubatch> ubatches);
~llama_memory_hybrid_state() = default;
~llama_memory_hybrid_context() = default;
bool next() override;
bool apply() override;
@ -119,11 +119,11 @@ public:
const llama_ubatch & get_ubatch() const override;
//
// llama_memory_hybrid_state
// llama_memory_hybrid_context
//
const llama_kv_cache_unified_state * get_state_attn() const;
const llama_memory_recurrent_state * get_state_recr() const;
const llama_kv_cache_unified_context * get_attn() const;
const llama_memory_recurrent_context * get_recr() const;
private:
// the index of the next ubatch to process
@ -131,8 +131,8 @@ private:
std::vector<llama_ubatch> ubatches;
const llama_memory_state_ptr state_attn;
const llama_memory_state_ptr state_recr;
const llama_memory_context_ptr ctx_attn;
const llama_memory_context_ptr ctx_recr;
const llama_memory_status status;
};


@ -362,42 +362,47 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
return result;
}
llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
std::vector<llama_ubatch> ubatches;
llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
do {
balloc.split_reset();
while (true) {
llama_ubatch ubatch;
std::vector<llama_ubatch> ubatches;
while (true) {
llama_ubatch ubatch;
if (embd_all) {
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
ubatch = balloc.split_equal(n_ubatch);
if (embd_all) {
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
ubatch = balloc.split_equal(n_ubatch);
}
if (ubatch.n_tokens == 0) {
break;
}
ubatches.push_back(std::move(ubatch)); // NOLINT
}
if (ubatch.n_tokens == 0) {
if (!prepare(ubatches)) {
break;
}
ubatches.push_back(std::move(ubatch)); // NOLINT
}
return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
} while (false);
if (!prepare(ubatches)) {
return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
return std::make_unique<llama_memory_recurrent_state>(this, std::move(ubatches));
return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
llama_memory_state_ptr llama_memory_recurrent::init_full() {
return std::make_unique<llama_memory_recurrent_state>(this);
llama_memory_context_ptr llama_memory_recurrent::init_full() {
return std::make_unique<llama_memory_recurrent_context>(this);
}
llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
GGML_UNUSED(lctx);
GGML_UNUSED(optimize);
return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_NO_UPDATE);
}
bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
@ -1040,22 +1045,22 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
}
//
// llama_memory_recurrent_state
// llama_memory_recurrent_context
//
llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {}
llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
llama_memory_recurrent_state::llama_memory_recurrent_state(
llama_memory_recurrent_context::llama_memory_recurrent_context(
llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
}
llama_memory_recurrent_state::llama_memory_recurrent_state(
llama_memory_recurrent_context::llama_memory_recurrent_context(
llama_memory_recurrent * mem,
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
llama_memory_recurrent_state::~llama_memory_recurrent_state() = default;
llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
bool llama_memory_recurrent_state::next() {
bool llama_memory_recurrent_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
if (++i_next >= ubatches.size()) {
@ -1065,48 +1070,56 @@ bool llama_memory_recurrent_state::next() {
return true;
}
bool llama_memory_recurrent_state::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
bool llama_memory_recurrent_context::apply() {
assert(!llama_memory_status_is_fail(status));
// no ubatches -> this is an update
if (ubatches.empty()) {
// recurrent cache never performs updates
assert(status == LLAMA_MEMORY_STATUS_NO_UPDATE);
return true;
}
mem->find_slot(ubatches[i_next]);
return true;
}
llama_memory_status llama_memory_recurrent_state::get_status() const {
llama_memory_status llama_memory_recurrent_context::get_status() const {
return status;
}
const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const {
const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next];
}
uint32_t llama_memory_recurrent_state::get_n_rs() const {
uint32_t llama_memory_recurrent_context::get_n_rs() const {
return is_full ? mem->size : mem->n;
}
uint32_t llama_memory_recurrent_state::get_head() const {
uint32_t llama_memory_recurrent_context::get_head() const {
return is_full ? 0 : mem->head;
}
int32_t llama_memory_recurrent_state::get_rs_z() const {
int32_t llama_memory_recurrent_context::get_rs_z() const {
return is_full ? 0 : mem->rs_z;
}
uint32_t llama_memory_recurrent_state::get_size() const {
uint32_t llama_memory_recurrent_context::get_size() const {
return mem->size;
}
ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const {
ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
return mem->r_l[il];
}
ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const {
ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
return mem->s_l[il];
}
int32_t llama_memory_recurrent_state::s_copy(int i) const {
int32_t llama_memory_recurrent_context::s_copy(int i) const {
return mem->cells[i + mem->head].src0;
}


@ -11,8 +11,8 @@
// llama_memory_recurrent
//
// TODO: extract the cache state used for graph computation into llama_memory_recurrent_state_i
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it
// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
// see the implementation of llama_kv_cache_unified_context_i for an example of how to do it
class llama_memory_recurrent : public llama_memory_i {
public:
@ -34,14 +34,14 @@ public:
// llama_memory_i
//
llama_memory_state_ptr init_batch(
llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;
llama_memory_state_ptr init_full() override;
llama_memory_context_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
void clear(bool data) override;
@ -125,24 +125,24 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
class llama_memory_recurrent_state : public llama_memory_state_i {
class llama_memory_recurrent_context : public llama_memory_context_i {
public:
// used for errors
llama_memory_recurrent_state(llama_memory_status status);
llama_memory_recurrent_context(llama_memory_status status);
// used to create a full-cache state
llama_memory_recurrent_state(
// used to create a full-cache or update context
llama_memory_recurrent_context(
llama_memory_recurrent * mem);
// used to create a state from a batch
llama_memory_recurrent_state(
// used to create a batch processing context from a batch
llama_memory_recurrent_context(
llama_memory_recurrent * mem,
std::vector<llama_ubatch> ubatches);
virtual ~llama_memory_recurrent_state();
virtual ~llama_memory_recurrent_context();
//
// llama_memory_state_i
// llama_memory_context_i
//
bool next() override;
@ -152,7 +152,7 @@ public:
const llama_ubatch & get_ubatch() const override;
//
// llama_memory_recurrent_state specific API
// llama_memory_recurrent_context specific API
//
uint32_t get_n_rs() const;


@ -40,3 +40,20 @@ llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_me
// if either status has an update, then the combined status has an update
return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
}
bool llama_memory_status_is_fail(llama_memory_status status) {
switch (status) {
case LLAMA_MEMORY_STATUS_SUCCESS:
case LLAMA_MEMORY_STATUS_NO_UPDATE:
{
return false;
}
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
{
return true;
}
}
return false;
}
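// Hedged usage sketch (not part of the diff): a caller can now test for both failure
// values with a single call instead of comparing against each enum value. The function
// name is hypothetical.
static bool example_prepare_batch(llama_memory_i & mem, llama_batch_allocr & balloc,
                                  uint32_t n_ubatch, bool embd_all) {
    auto mctx = mem.init_batch(balloc, n_ubatch, embd_all);
    if (llama_memory_status_is_fail(mctx->get_status())) {
        LLAMA_LOG_ERROR("%s: failed to prepare the ubatches\n", __func__);
        return false;
    }
    return true;
}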


@ -3,7 +3,6 @@
#include "llama.h"
#include <memory>
#include <vector>
struct llama_ubatch;
@ -28,23 +27,24 @@ enum llama_memory_status {
LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
};
// helper function for combining the status of two memory states
// helper function for combining the status of two memory contexts
// useful for implementing hybrid memory types (e.g. iSWA)
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
// the interface for managing the memory state during batch processing
// helper function for checking if a memory status indicates a failure
bool llama_memory_status_is_fail(llama_memory_status status);
// the interface for managing the memory context during batch processing
// this interface is implemented per memory type. see:
// - llama_kv_cache_unified_state
// - llama_kv_cache_unified_iswa_state
// - llama_kv_cache_unified_context
// - llama_kv_cache_unified_iswa_context
// ...
//
// the only method that can mutate the memory and the memory state is llama_memory_i::apply()
//
// TODO: rename to llama_memory_context_i ?
struct llama_memory_state_i {
virtual ~llama_memory_state_i() = default;
// the only method that should mutate the memory and the memory context is llama_memory_i::apply()
struct llama_memory_context_i {
virtual ~llama_memory_context_i() = default;
// consume the current ubatch from the state and proceed to the next one
// consume the current ubatch from the context and proceed to the next one
// return false if we are done
virtual bool next() = 0;
@ -55,11 +55,11 @@ struct llama_memory_state_i {
// get the current ubatch
virtual const llama_ubatch & get_ubatch() const = 0;
// get the status of the memory state - used for error handling and checking if any updates would be applied
// get the status of the memory context - used for error handling and checking if any updates would be applied
virtual llama_memory_status get_status() const = 0;
};
using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
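// Hedged sketch (assumption, based only on the interface above): one plausible driver
// loop over a memory context. The graph evaluation itself is elided, the function name
// is hypothetical, and llama_memory_i (which provides init_batch) is declared below;
// the sketch sits here only for illustration.
static inline bool example_process_batch(llama_memory_i & mem, llama_batch_allocr & balloc,
                                         uint32_t n_ubatch, bool embd_all) {
    llama_memory_context_ptr mctx = mem.init_batch(balloc, n_ubatch, embd_all);
    if (llama_memory_status_is_fail(mctx->get_status())) {
        return false; // failed to prepare the ubatches
    }
    do {
        if (!mctx->apply()) { // the only call that mutates the memory
            return false;
        }
        const llama_ubatch & ubatch = mctx->get_ubatch(); // current ubatch to evaluate
        (void) ubatch; // ... build and compute the graph for `ubatch` here ...
    } while (mctx->next()); // consume and advance; returns false when done
    return true;
}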
// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
@ -67,19 +67,19 @@ struct llama_memory_i {
virtual ~llama_memory_i() = default;
// split the input batch into a set of ubatches and verify that they can fit into the cache
// return a state object containing the ubatches and KV cache state required to process them
// check the llama_memory_state_i::get_status() for the result
virtual llama_memory_state_ptr init_batch(
// return a context object containing the ubatches and memory state required to process them
// check the llama_memory_context_i::get_status() for the result
virtual llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) = 0;
// simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_state_ptr init_full() = 0;
virtual llama_memory_context_ptr init_full() = 0;
// prepare for any pending memory updates, such as shifts, defrags, etc.
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
// getters
virtual bool get_can_shift() const = 0;


@ -47,6 +47,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_475M: return "475M";
case LLM_TYPE_770M: return "770M";
case LLM_TYPE_780M: return "780M";
case LLM_TYPE_0_3B: return "0.3B";
case LLM_TYPE_0_5B: return "0.5B";
case LLM_TYPE_0_6B: return "0.6B";
case LLM_TYPE_1B: return "1B";
@ -103,6 +104,8 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
case LLM_TYPE_30B_A3B: return "30B.A3B";
case LLM_TYPE_235B_A22B: return "235B.A22B";
case LLM_TYPE_E2B: return "E2B";
case LLM_TYPE_E4B: return "E4B";
default: return "?B";
}
}
@ -1017,6 +1020,24 @@ void llama_model::load_hparams(llama_model_loader & ml) {
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
: 1.0f / std::sqrt(float(hparams.n_embd_head_k));
} break;
case LLM_ARCH_GEMMA3N:
{
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.set_swa_pattern(5);
hparams.rope_freq_base_train_swa = 10000.0f;
hparams.rope_freq_scale_train_swa = 1.0f;
hparams.f_attention_scale = 1.0f;
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 30: type = LLM_TYPE_E2B; break;
case 35: type = LLM_TYPE_E4B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@ -1484,6 +1505,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_ERNIE4_5:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 18: type = LLM_TYPE_0_3B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
default: throw std::runtime_error("unsupported model architecture");
}
@ -2950,6 +2979,62 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
case LLM_ARCH_GEMMA3N:
{
const int64_t n_altup = hparams.n_altup;
const int64_t laurel_rank = hparams.laurel_rank;
const int64_t n_embd_altup = hparams.n_embd_altup;
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0);
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0);
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
// altup & laurel
layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0);
layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0);
layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0);
layer.altup_correct_coef = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0);
layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0);
layer.altup_predict_coef = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), {n_altup, n_altup * n_altup}, 0);
layer.altup_router = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0);
layer.altup_router_norm = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0);
layer.laurel_l = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0);
layer.laurel_r = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0);
layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
case LLM_ARCH_STARCODER2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -4268,6 +4353,40 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_ERNIE4_5:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
// optional bias tensors
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
@ -8980,6 +9099,442 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
}
};
struct llm_build_gemma3n_iswa : public llm_graph_context {
const llama_model & model;
ggml_cgraph * gf;
const int64_t n_embd_head;
const int64_t n_embd_altup;
const int64_t n_altup;
const int i_altup_act;
const int n_layer_kv = 20; // number of layers having KV [KV_REUSE]
const int n_layer_sparsity = 10; // number of layers using activation sparsity
const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
ggml_tensor * one; // tensor containing a single element with value 1.0f
llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf)
: llm_graph_context(params),
model(model),
gf(gf),
n_embd_head(model.hparams.n_embd_head_k),
n_embd_altup(model.hparams.n_embd_altup),
n_altup(model.hparams.n_altup),
i_altup_act(model.hparams.i_altup_act) {
ggml_tensor * cur;
ggml_tensor * inpL;
// TODO: remove this when ggml_scale_add is implemented
one = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
{
auto inp = std::make_unique<llm_graph_input_one>();
inp->one = one;
res->add_input(std::move(inp));
}
inpL = build_inp_embd(model.tok_embd);
// important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
if (ubatch.token) {
inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
cb(inpL, "inp_scaled", -1);
}
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
// TODO: is causal == true correct? might need some changes
auto * inp_attn = build_attn_inp_kv_unified_iswa();
// inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
// inpL now has only 1 altup, project it to the rest of the altups
// these "added" altups will be concat to the last dim of inpL
{
ggml_tensor * target_magnitude = calc_magnitude(inpL);
ggml_tensor * inp_repeated = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
ggml_tensor * altup_added = ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1]
ggml_tensor * new_magnitude = calc_magnitude(altup_added);
altup_added = ggml_div(ctx0,
ggml_mul(ctx0, altup_added, target_magnitude),
new_magnitude);
inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
cb(inpL, "inp_stacked", -1);
}
// inpL now has shape: [n_embd, n_tokens, n_altup]
// inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
for (int il = 0; il < n_layer; ++il) {
// this block is made to closely resemble Gemma3p5DecoderLayer in the python code
const bool has_kv = (il < n_layer_kv);
const float freq_base_l = model.get_rope_freq_base (cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
ggml_tensor * cur = inpL; // [n_embd, n_tokens, n_altup]
ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]
// predicted value will go through self-attention and laurel
ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
cur = active_prediction;
cb(cur, "active_prediction", il);
// norm
cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// laurel
ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]
// self-attention
if (has_kv) {
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);
cb(Qcur, "Qcur_normed", il);
cb(Kcur, "Kcur_normed", il);
cb(Vcur, "Vcur_normed", il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur_pos", il);
cb(Kcur, "Kcur_pos", il);
cur = build_attn(inp_attn, gf,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
} else {
// this layer has no KV cache of its own (KV reuse, see n_layer_kv above)
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur_pos", il);
cur = build_attn(inp_attn, gf,
model.layers[il].wo, NULL,
Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
}
cur = build_norm(cur,
model.layers[il].attn_post_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_post_norm", il);
cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens]
cb(cur, "attn_gated", il);
ggml_tensor * attn_laurel = ggml_scale(ctx0,
ggml_add(ctx0, cur, laurel_out),
1.0f / sqrtf(2.0f)); // [n_embd, n_tokens]
cb(attn_laurel, "attn_laurel", il);
cur = build_norm(attn_laurel,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
// feed-forward network
{
ggml_tensor * up_proj = build_lora_mm(model.layers[il].ffn_up, cur);
ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);
if (il < n_layer_sparsity) {
// apply activation sparsity
gate_proj = gaussian_topk(gate_proj);
}
gate_proj = ggml_gelu(ctx0, gate_proj);
cur = ggml_mul(ctx0, up_proj, gate_proj);
cur = build_lora_mm(model.layers[il].ffn_down, cur);
cb(cur, "ffn_out", il);
}
cur = build_norm(cur,
model.layers[il].ffn_post_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "ffn_post_norm", il);
ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens]
cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);
ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup]
ggml_tensor * first_prediction; // [n_embd, n_tokens]
{
first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
cb(first_prediction, "first_prediction_gated", il);
ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
cb(first_prediction, "first_prediction_scaled", il);
first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens]
first_prediction = build_norm(first_prediction,
model.layers[il].per_layer_post_norm, NULL,
LLM_NORM_RMS, il);
cb(first_prediction, "first_prediction_out", il);
}
// equivalent to python code: corrected_predictions[1:] += first_prediction
{
ggml_tensor * slice_first = view_2d_slice(corrected, 0);
ggml_tensor * slice_rest = ggml_view_3d(ctx0, corrected, n_embd, n_tokens, n_altup - 1,
ggml_row_size(corrected->type, n_embd),
ggml_row_size(corrected->type, n_embd*n_tokens),
n_embd*n_tokens*ggml_element_size(corrected));
ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1]
corrected = ggml_concat(ctx0, slice_first, tmp, 2); // [n_embd, n_tokens, n_altup]
}
cur = corrected; // [n_embd, n_tokens, n_altup]
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL; // [n_embd, n_tokens, n_altup]
// cur now has multiple altup(s), we want to merge them back to 1 altup
{
ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
// do a view to skip the first slice (active altup)
ggml_tensor * alt_slice = ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1,
ggml_row_size(cur->type, n_embd),
ggml_row_size(cur->type, n_embd*n_tokens),
n_embd*n_tokens*ggml_element_size(cur));
ggml_tensor * altup_unembd = ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1]
ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
altup_unembd = ggml_div(ctx0,
ggml_mul(ctx0, altup_unembd, target_magnitude),
new_magnitude);
cb(altup_unembd, "altup_unembd", -1);
// equivalent to torch.mean(hidden_states, dim=0)
cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
for (int i = 0; i < n_altup - 1; ++i) {
cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
}
cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
cb(cur, "unembd_merged", -1);
}
// cur now has shape: [n_embd, n_tokens]
// TODO: move this to right after the last KV layer
{
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
{
// final logit soft-capping
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
cur = ggml_tanh(ctx0, cur);
cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
}
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
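// Minimal scalar sketch (illustration only) of the final logit soft-capping applied in
// the constructor above: each logit is squashed into (-cap, +cap) while staying nearly
// linear around zero. Requires <cmath>; the function name is hypothetical.
static float softcap_logit(float logit, float cap) {
    return cap * tanhf(logit / cap); // same scale -> tanh -> scale sequence as the graph code
}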
ggml_tensor * calc_magnitude(ggml_tensor * x) {
return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
}
// get a 2D slice view from a 3D tensor; the idx corresponds to the 3rd dim
ggml_tensor * view_2d_slice(ggml_tensor * x, int idx) {
GGML_ASSERT(idx < (int)x->ne[2]);
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1],
ggml_row_size(x->type, x->ne[0]),
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
}
// equivalent to get_per_layer_inputs() in python code
// output shape: [n_embd_altup, n_layer, n_tokens]
ggml_tensor * get_per_layer_inputs() {
auto inp = std::make_unique<llm_graph_input_embd>();
ggml_tensor * inp_per_layer;
if (ubatch.token) {
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
ggml_set_input(inp->tokens);
res->t_tokens = inp->tokens;
inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float)n_embd_altup));
cb(inp_per_layer, "inp_per_layer_selected", -1);
} else {
GGML_ABORT("TODO: support embd input");
}
res->add_input(std::move(inp));
return inp_per_layer;
}
// equivalent to project_per_layer_inputs() in python code
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
// output shape: [n_embd_altup, n_tokens, n_layer]
ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
const float per_layer_projection_scale = 1.0f / sqrtf((float)n_embd);
const float per_layer_input_scale = 1.0f / sqrtf(2.0f);
ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
per_layer_proj = build_norm(per_layer_proj,
model.per_layer_proj_norm, NULL,
LLM_NORM_RMS, -1); // [n_embd_altup, n_layer, n_tokens]
cb(per_layer_proj, "per_layer_proj", -1);
inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj);
inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
cb(inp_per_layer, "inp_per_layer", -1);
// permute to shape: [n_embd_altup, n_tokens, n_layer]
inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
return inp_per_layer;
}
// input cur shape: [n_embd, n_tokens]
// output shape: [n_embd, n_tokens]
ggml_tensor * laurel(ggml_tensor * cur, int il) {
ggml_tensor * tmp = cur;
tmp = build_lora_mm(model.layers[il].laurel_l, tmp);
tmp = build_lora_mm(model.layers[il].laurel_r, tmp);
tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il);
tmp = ggml_add(ctx0, tmp, cur);
cb(tmp, "laurel_out", il);
return tmp;
}
// input x shape: [n_embd, n_tokens]
// output shape: [n_embd, n_tokens]
ggml_tensor * gaussian_topk(ggml_tensor * x) {
ggml_tensor * mean = ggml_mean(ctx0, x);
ggml_tensor * std = ggml_sqrt(ctx0, ggml_scale(ctx0,
ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))),
1.0f / (float)(x->ne[0] - 1)
));
ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul));
return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x));
}
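// Hedged standalone sketch (not part of the diff) of the activation-sparsity cutoff
// above, on a plain float vector: values below mean + std_mul * std are clamped to 0,
// which keeps roughly the top 5% of activations under a normality assumption
// (std_mul = icdf(0.95)). Requires <vector>, <cmath>, <algorithm>; the name is hypothetical.
static std::vector<float> gaussian_topk_ref(const std::vector<float> & x,
                                            float std_mul = 1.6448533535003662f) {
    const float n = (float) x.size();
    float mean = 0.0f;
    for (float v : x) mean += v;
    mean /= n;
    float var = 0.0f;
    for (float v : x) var += (v - mean) * (v - mean);
    var /= (n - 1.0f); // matches the 1/(ne[0] - 1) scaling in the graph code
    const float cutoff = mean + std_mul * std::sqrt(var);
    std::vector<float> out(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        out[i] = std::max(0.0f, x[i] - cutoff); // relu(x - cutoff)
    }
    return out;
}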
//
// altup functions
//
// equivalent to compute_router_modalities() in python code
// input x shape: [n_embd, n_tokens]
// output shape: [n_altup, n_tokens]
ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il) {
ggml_tensor * router_inputs = build_norm(x,
model.layers[il].altup_router_norm, NULL,
LLM_NORM_RMS, il);
// router_input_scale
router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float)n_embd);
ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs);
return ggml_tanh(ctx0, output); // [n_altup, n_tokens]
}
// input cur shape: [n_embd, n_tokens, n_altup]
// output shape: [n_embd, n_tokens, n_altup]
ggml_tensor * altup_predict(ggml_tensor * cur, int il) {
ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens]
ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
cb(modalities, "modalities", il);
ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
cb(all_coefs, "all_coefs", il);
// the first dim now has n_altup^2 elements, so we reshape it to 2D (ending up with a 3D tensor)
all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);
// permute to [n_altup, n_embd, n_tokens]
ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
ggml_tensor * predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_embd, n_altup, n_tokens]
// final shape must be the same as cur: [n_embd, n_tokens, n_altup]
predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3));
predictions = ggml_add(ctx0, predictions, cur);
cb(predictions, "predictions", il);
return predictions;
}
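// Hedged scalar sketch (illustration only) of the prediction step above, for a single
// token and a single embedding element: every altup stream j gets a routed mix of all
// streams added on top of its current value. coefs[k + j*n_altup] plays the role of
// all_coefs after the reshape; the function name is hypothetical.
static void altup_predict_ref(const float * cur,       // [n_altup] values of one element
                              const float * coefs,     // [n_altup * n_altup]
                              float       * predicted, // [n_altup] output
                              int           n_altup) {
    for (int j = 0; j < n_altup; ++j) {
        float mix = 0.0f;
        for (int k = 0; k < n_altup; ++k) {
            mix += coefs[k + j*n_altup] * cur[k];
        }
        predicted[j] = cur[j] + mix;
    }
}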
// input predictions shape: [n_embd, n_tokens, n_altup]
// input activated shape: [n_embd, n_tokens]
// output shape: [n_embd, n_tokens, n_altup]
ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) {
ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
cb(modalities, "modalities", il);
ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
cb(innovation, "innovation", il);
ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens]
all_coefs = ggml_add(ctx0, all_coefs, one);
cb(all_coefs, "all_coefs", il);
all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup]
all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup]
innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup]
corrected = ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup]
cb(corrected, "corrected", il);
return corrected;
}
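// Hedged scalar sketch (illustration only) of the correction step above, for a single
// token and a single embedding element: every altup stream j is nudged toward the
// activated value by its routed coefficient (the +1.0f mirrors the `one` tensor).
// The function name is hypothetical.
static void altup_correct_ref(const float * predictions, // [n_altup]
                              float         activated,
                              const float * coefs,       // [n_altup] router output
                              float       * corrected,   // [n_altup] output
                              int           n_altup,
                              int           i_altup_act) {
    const float innovation = activated - predictions[i_altup_act];
    for (int j = 0; j < n_altup; ++j) {
        corrected[j] = predictions[j] + (coefs[j] + 1.0f) * innovation;
    }
}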
};
// TODO: move up next to build_starcoder
struct llm_build_starcoder2 : public llm_graph_context {
llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
@ -9171,9 +9726,9 @@ struct llm_build_mamba : public llm_graph_context {
ggml_tensor * cur,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
const auto kv_head = kv_state->get_head();
const auto kv_head = mctx_cur->get_head();
const int64_t d_conv = hparams.ssm_d_conv;
const int64_t d_inner = hparams.ssm_d_inner;
@ -9191,8 +9746,8 @@ struct llm_build_mamba : public llm_graph_context {
GGML_ASSERT(ubatch.equal_seqs);
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
ggml_tensor * conv_states_all = kv_state->get_r_l(il);
ggml_tensor * ssm_states_all = kv_state->get_s_l(il);
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
// (ab)using the KV cache to store the states
ggml_tensor * conv = build_rs(
@ -11916,7 +12471,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
ggml_tensor * x_prev,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs;
@ -11926,7 +12481,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
const auto n_head = n_embd / head_size;
const auto n_head_kv = hparams.n_head_kv(il);
const auto kv_head = kv_state->get_head();
const auto kv_head = mctx_cur->get_head();
const auto & layer = model.layers[il];
@ -12038,7 +12593,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
}
ggml_tensor * wkv_state = build_rs(
inp, gf, kv_state->get_s_l(il),
inp, gf, mctx_cur->get_s_l(il),
hparams.n_embd_s(), n_seqs);
ggml_tensor * wkv_output;
@ -12057,9 +12612,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
wkv_state,
ggml_view_1d(
ctx0,
kv_state->get_s_l(il),
mctx_cur->get_s_l(il),
hparams.n_embd_s() * n_seqs,
hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
)
)
);
@ -12313,7 +12868,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
ggml_tensor *& first_layer_value,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs;
@ -12322,7 +12877,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
const auto head_count = n_embd / head_size;
const auto n_seq_tokens = ubatch.n_seq_tokens;
const auto kv_head = kv_state->get_head();
const auto kv_head = mctx_cur->get_head();
const auto & layer = model.layers[il];
@ -12393,7 +12948,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
ggml_tensor * wkv_state = build_rs(
inp, gf, kv_state->get_s_l(il),
inp, gf, mctx_cur->get_s_l(il),
hparams.n_embd_s(), n_seqs);
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@ -12407,9 +12962,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
wkv_state,
ggml_view_1d(
ctx0,
kv_state->get_s_l(il),
mctx_cur->get_s_l(il),
hparams.n_embd_s() * n_seqs,
hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
)
)
);
@ -13613,6 +14168,136 @@ struct llm_build_dots1 : public llm_graph_context {
}
};
struct llm_build_ernie4_5 : public llm_graph_context {
llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
// norm
{
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
}
// self-attention
{
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn, gf,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
{
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
};
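As a reading aid, the llm_build_ernie4_5 builder above is a plain pre-norm decoder block; a schematic of the per-layer dataflow it constructs (tensor names mirror the code above, bq/bk/bv are the optional biases, and this is a sketch rather than additional source):
//   x_norm   = rms_norm(x) * attn_norm
//   attn_out = wo( attn( rope(wq(x_norm) + bq?), rope(wk(x_norm) + bk?), wv(x_norm) + bv? ) )
//   h        = x + attn_out
//   h_norm   = rms_norm(h) * ffn_norm
//   out      = h + ffn_down( silu(ffn_gate(h_norm)) ⊙ ffn_up(h_norm) )   // SwiGLU, ⊙ = elementwise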
struct llm_build_arcee : public llm_graph_context {
llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
@ -13974,6 +14659,10 @@ llm_graph_result_ptr llama_model::build_graph(
{
llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
} break;
case LLM_ARCH_GEMMA3N:
{
llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params, gf);
} break;
case LLM_ARCH_STARCODER2:
{
llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
@ -14119,6 +14808,10 @@ llm_graph_result_ptr llama_model::build_graph(
{
llm = std::make_unique<llm_build_arcee>(*this, params, gf);
} break;
case LLM_ARCH_ERNIE4_5:
{
llm = std::make_unique<llm_build_ernie4_5>(*this, params, gf);
} break;
default:
GGML_ABORT("fatal error");
}
@ -14270,6 +14963,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_BAILINGMOE:
case LLM_ARCH_NEO_BERT:
case LLM_ARCH_ARCEE:
case LLM_ARCH_ERNIE4_5:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2
@ -14295,6 +14989,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
case LLM_ARCH_GEMMA3:
case LLM_ARCH_GEMMA3N:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:
@ -14377,7 +15072,7 @@ const char * llama_model_chat_template(const llama_model * model, const char * n
// do not extend this list unless absolutely necessary
// Mistral-Small-2503 does not have built-in chat template
llama_vocab_pre_type pre_type = model->vocab.get_pre_type();
if (pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
if (!name && pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
return "mistral-v7-tekken";
}

View File

@ -39,6 +39,7 @@ enum llm_type {
LLM_TYPE_475M,
LLM_TYPE_770M,
LLM_TYPE_780M,
LLM_TYPE_0_3B,
LLM_TYPE_0_5B,
LLM_TYPE_0_6B,
LLM_TYPE_1B,
@ -95,6 +96,8 @@ enum llm_type {
LLM_TYPE_17B_128E, // llama4 Maverick
LLM_TYPE_30B_A3B,
LLM_TYPE_235B_A22B,
LLM_TYPE_E2B,
LLM_TYPE_E4B,
};
std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type);
@ -316,6 +319,19 @@ struct llama_layer {
struct ggml_tensor * ffn_up_scale = nullptr;
struct ggml_tensor * ffn_down_scale = nullptr;
// altup & laurel
struct ggml_tensor * per_layer_inp_gate = nullptr;
struct ggml_tensor * per_layer_proj = nullptr;
struct ggml_tensor * per_layer_post_norm = nullptr;
struct ggml_tensor * altup_correct_coef = nullptr;
struct ggml_tensor * altup_correct_scale = nullptr;
struct ggml_tensor * altup_predict_coef = nullptr;
struct ggml_tensor * altup_router = nullptr;
struct ggml_tensor * altup_router_norm = nullptr;
struct ggml_tensor * laurel_l = nullptr;
struct ggml_tensor * laurel_r = nullptr;
struct ggml_tensor * laurel_post_norm = nullptr;
struct llama_layer_posnet posnet;
struct llama_layer_convnext convnext;
@ -354,6 +370,13 @@ struct llama_model {
struct ggml_tensor * conv1d = nullptr;
struct ggml_tensor * conv1d_b = nullptr;
// gemma3n altup
struct ggml_tensor * tok_embd_per_layer = nullptr;
struct ggml_tensor * altup_proj = nullptr;
struct ggml_tensor * altup_unembd_proj = nullptr;
struct ggml_tensor * per_layer_model_proj = nullptr;
struct ggml_tensor * per_layer_proj_norm = nullptr;
std::vector<llama_layer> layers;
llama_model_params params;

View File

@ -1,5 +1,4 @@
#include "llama-quant.h"
#include "llama-impl.h"
#include "llama-model.h"
#include "llama-model-loader.h"
@ -27,6 +26,56 @@ static void zeros(std::ofstream & file, size_t n) {
}
}
static std::string remap_layer(const std::string & orig_name, const std::vector<int> & prune, std::map<int, std::string> & mapped, int & next_id) {
if (prune.empty()) {
return orig_name;
}
static const std::regex pattern(R"(blk\.(\d+)\.)");
if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
const int blk = std::stoi(match[1]);
std::string new_name = orig_name;
if (mapped.count(blk)) {
// Already mapped, do nothing
} else if (std::find(prune.begin(), prune.end(), blk) != prune.end()) {
mapped[blk] = "";
} else if (blk < prune.front()) {
mapped[blk] = std::to_string(blk);
next_id = blk + 1;
} else {
mapped[blk] = std::to_string(next_id);
++next_id;
}
return mapped[blk].empty() ? mapped[blk] : new_name.replace(match.position(1), match.length(1), mapped[blk]);
}
return orig_name;
}
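A worked example may help here; assuming a hypothetical prune list of {2, 3} and tensors visited in ascending block order, remap_layer behaves as sketched below:
//   "blk.0.attn_q.weight" -> "blk.0.attn_q.weight"   (below the first pruned block: kept verbatim)
//   "blk.2.attn_q.weight" -> ""                       (pruned: the caller drops the tensor)
//   "blk.4.attn_q.weight" -> "blk.2.attn_q.weight"    (renumbered via next_id so block ids stay contiguous)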
static std::string remap_imatrix (const std::string & orig_name, const std::map<int, std::string> & mapped) {
if (mapped.empty()) {
return orig_name;
}
static const std::regex pattern(R"(blk\.(\d+)\.)");
if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
const std::string blk(match[1]);
std::string new_name = orig_name;
for (const auto & p : mapped) {
if (p.second == blk) {
LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
}
}
GGML_ABORT("\n%s: imatrix mapping error for %s\n", __func__, orig_name.c_str());
}
return orig_name;
}
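remap_imatrix appears to be the inverse lookup, so that importance data recorded against the original block ids is still found after renaming; continuing the hypothetical {2, 3} example:
//   remap_imatrix("blk.2.attn_q.weight", mapped) -> "blk.4.attn_q.weight"   // imatrix entries keep the pre-prune ids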
struct quantize_state_impl {
const llama_model & model;
const llama_model_quantize_params * params;
@ -174,7 +223,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
new_type = GGML_TYPE_Q6_K;
}
}
} else if (name == "token_embd.weight") {
} else if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") {
if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
new_type = qs.params->token_embedding_type;
} else {
@ -568,6 +617,11 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
const size_t align = GGUF_DEFAULT_ALIGNMENT;
gguf_context_ptr ctx_out { gguf_init_empty() };
std::vector<int> prune_list = {};
if (params->prune_layers) {
prune_list = *static_cast<const std::vector<int> *>(params->prune_layers);
}
// copy the KV pairs from the input file
gguf_set_kv (ctx_out.get(), ml.meta.get());
gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
@ -597,12 +651,32 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
}
}
std::map<int, std::string> mapped;
int blk_id = 0;
int pruned_attention_w = 0;
// make a list of weights
std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
tensors.reserve(ml.weights_map.size());
for (const auto & it : ml.weights_map) {
const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
if (remapped_name.empty()) {
if (it.first.find("attn_v.weight") != std::string::npos ||
it.first.find("attn_qkv.weight") != std::string::npos ||
it.first.find("attn_kv_b.weight") != std::string::npos) {
pruned_attention_w++;
}
LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
continue;
} else if (remapped_name != it.first) {
ggml_set_name(it.second.tensor, remapped_name.c_str());
LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
}
tensors.push_back(&it.second);
}
if (!prune_list.empty()) {
gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), blk_id);
}
// keep_split requires that the weights are sorted by split index
if (params->keep_split) {
@ -640,7 +714,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
if (llama_model_has_encoder(&model)) {
n_attn_layer *= 3;
}
GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
}
size_t total_size_org = 0;
@ -681,7 +755,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
for (size_t i = 0; i < ctx_outs.size(); ++i) {
gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors);
gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), (int32_t)tensors.size());
}
}
@ -756,6 +830,13 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// NOTE: can't use LLM_TN here because the layer number is not known
quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
// these are very small (e.g. 4x4)
quantize &= name.find("altup") == std::string::npos;
quantize &= name.find("laurel") == std::string::npos;
// these are not too big so keep them as they are
quantize &= name.find("per_layer_model_proj") == std::string::npos;
// do not quantize positional embeddings and token types (BERT)
quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight");
quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
@ -832,7 +913,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
const float * imatrix = nullptr;
if (imatrix_data) {
auto it = imatrix_data->find(tensor->name);
auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped));
if (it == imatrix_data->end()) {
LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
} else {
@ -947,6 +1028,7 @@ llama_model_quantize_params llama_model_quantize_default_params() {
/*.imatrix =*/ nullptr,
/*.kv_overrides =*/ nullptr,
/*.tensor_type =*/ nullptr,
/*.prune_layers =*/ nullptr
};
return result;

View File

@ -390,6 +390,7 @@ extern "C" {
void * imatrix; // pointer to importance matrix data
void * kv_overrides; // pointer to vector containing overrides
void * tensor_types; // pointer to vector containing tensor types
void * prune_layers; // pointer to vector containing layer indices to prune
} llama_model_quantize_params;
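Since the new prune_layers field follows the same convention as kv_overrides and tensor_types (an opaque pointer to a C++ std::vector), a caller-side sketch might look like the following; the file names and pruned indices are placeholders:
    // prune blocks 2 and 3 while quantizing (hypothetical values)
    std::vector<int> prune = {2, 3};
    llama_model_quantize_params qparams = llama_model_quantize_default_params();
    qparams.prune_layers = &prune;
    llama_model_quantize("model-f16.gguf", "model-quant.gguf", &qparams);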
typedef struct llama_logit_bias {
@ -943,12 +944,14 @@ extern "C" {
// Requires the context to have a memory.
// For encode-decoder contexts, processes the batch using the decoder.
// Positive return values do not mean a fatal error, but rather a warning.
// Upon non-zero return values, the memory state is restored to the state before this call
// Upon fatal error or abort, the ubatches that managed to be processed will remain in the memory state of the context
// To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
// Upon other return values, the memory state is restored to the state before this call
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increasing the context)
// 2 - aborted
// 2 - aborted (processed ubatches will remain in the context's memory)
// -1 - invalid input batch
// < -1 - error
// < -1 - fatal error (processed ubatches will remain in the context's memory)
LLAMA_API int32_t llama_decode(
struct llama_context * ctx,
struct llama_batch batch);
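Given the return-value contract spelled out above, a minimal caller-side sketch (ctx and batch are assumed to exist already; this is illustrative, not part of the diff):
    const int32_t ret = llama_decode(ctx, batch);
    if (ret == 1) {
        // no KV slot found: retry with a smaller batch or a larger context
    } else if (ret == 2 || ret < -1) {
        // aborted or fatal error: some ubatches may already sit in the context's memory;
        // inspect llama_memory_seq_pos_min()/llama_memory_seq_pos_max() before recovering
    } else if (ret == -1) {
        // invalid input batch
    }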