talk-llama : sync llama.cpp

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-07-12 16:26:16 +03:00
parent 6d64e4abf3
commit 6ddff4d96a
24 changed files with 2831 additions and 690 deletions

View File

@ -45,6 +45,9 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA3N, "gemma3n" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_MAMBA2, "mamba2" },
{ LLM_ARCH_JAMBA, "jamba" },
{ LLM_ARCH_FALCON_H1, "falcon-h1" },
{ LLM_ARCH_XVERSE, "xverse" },
{ LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_COHERE2, "cohere2" },
@ -70,6 +73,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_ARWKV7, "arwkv7" },
{ LLM_ARCH_GRANITE, "granite" },
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_GRANITE_HYBRID, "granitehybrid" },
{ LLM_ARCH_CHAMELEON, "chameleon" },
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
{ LLM_ARCH_PLM, "plm" },
@ -77,6 +81,9 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" },
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
{ LLM_ARCH_SMOLLM3, "smollm3" },
{ LLM_ARCH_LFM2, "lfm2" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -149,7 +156,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
- { LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@ -170,6 +176,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
{ LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
{ LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
{ LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" },
{ LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },
{ LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },
@ -182,6 +189,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },
{ LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" },
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
@ -1004,6 +1013,77 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{
LLM_ARCH_MAMBA2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{
LLM_ARCH_JAMBA,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" },
{ LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_FALCON_H1,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_XVERSE,
{
@ -1564,6 +1644,43 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_GRANITE_HYBRID,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
// mamba(2) ssm layers
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
// attention layers
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
// dense FFN
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
// moe FFN
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
// shared expert
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_CHAMELEON,
{
@ -1676,6 +1793,67 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_HUNYUAN_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_SMOLLM3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_LFM2,
{
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" },
{ LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
}
},
{
LLM_ARCH_UNKNOWN,
{
@ -1760,7 +1938,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}},
{LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
{LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
{LLM_TENSOR_SSM_DT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SSM_B_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@ -1839,6 +2021,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
{LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
};
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
@ -1894,6 +2079,7 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
bool llm_arch_is_recurrent(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_MAMBA:
case LLM_ARCH_MAMBA2:
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_RWKV7:
@ -1905,9 +2091,12 @@ bool llm_arch_is_recurrent(const llm_arch & arch) {
}
bool llm_arch_is_hybrid(const llm_arch & arch) {
- // TODO: There are currently no hybrid models! Once there are, this will be
- //       the place to identify them
switch (arch) {
+ case LLM_ARCH_JAMBA:
+ case LLM_ARCH_FALCON_H1:
+ case LLM_ARCH_GRANITE_HYBRID:
+ case LLM_ARCH_LFM2:
+ return true;
default:
return false;
}
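
The two predicates above are what the memory layer keys off when deciding which cache to build for a model. As a minimal sketch (the enum and selector function below are hypothetical illustrations; only llm_arch_is_recurrent() and llm_arch_is_hybrid() come from the code above):

enum class memory_kind { kv_unified, recurrent, hybrid };

// hypothetical helper, not part of llama.cpp
static memory_kind select_memory_kind(llm_arch arch) {
    if (llm_arch_is_hybrid(arch)) {
        return memory_kind::hybrid;    // attention KV cache + recurrent state (JAMBA, FALCON_H1, ...)
    }
    if (llm_arch_is_recurrent(arch)) {
        return memory_kind::recurrent; // recurrent state only (MAMBA, MAMBA2, RWKV*, ...)
    }
    return memory_kind::kv_unified;    // plain attention KV cache
}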

View File

@ -49,6 +49,9 @@ enum llm_arch {
LLM_ARCH_GEMMA3N,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_MAMBA2,
LLM_ARCH_JAMBA,
LLM_ARCH_FALCON_H1,
LLM_ARCH_XVERSE,
LLM_ARCH_COMMAND_R,
LLM_ARCH_COHERE2,
@ -74,6 +77,7 @@ enum llm_arch {
LLM_ARCH_ARWKV7,
LLM_ARCH_GRANITE,
LLM_ARCH_GRANITE_MOE,
LLM_ARCH_GRANITE_HYBRID,
LLM_ARCH_CHAMELEON,
LLM_ARCH_WAVTOKENIZER_DEC,
LLM_ARCH_PLM,
@ -81,6 +85,9 @@ enum llm_arch {
LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE,
LLM_ARCH_ERNIE4_5,
LLM_ARCH_HUNYUAN_MOE,
LLM_ARCH_SMOLLM3,
LLM_ARCH_LFM2,
LLM_ARCH_UNKNOWN,
};
@ -153,7 +160,6 @@ enum llm_kv {
LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
- LLM_KV_ATTENTION_LAYER_INDICES,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_SECTIONS,
@ -174,6 +180,7 @@ enum llm_kv {
LLM_KV_SSM_CONV_KERNEL,
LLM_KV_SSM_STATE_SIZE,
LLM_KV_SSM_TIME_STEP_RANK,
LLM_KV_SSM_GROUP_COUNT,
LLM_KV_SSM_DT_B_C_RMS,
LLM_KV_WKV_HEAD_SIZE,
@ -221,6 +228,8 @@ enum llm_kv {
LLM_KV_CLASSIFIER_OUTPUT_LABELS,
LLM_KV_SHORTCONV_L_CACHE,
// deprecated:
LLM_KV_TOKENIZER_PREFIX_ID,
LLM_KV_TOKENIZER_SUFFIX_ID,
@ -291,8 +300,12 @@ enum llm_tensor {
LLM_TENSOR_SSM_CONV1D,
LLM_TENSOR_SSM_X,
LLM_TENSOR_SSM_DT,
LLM_TENSOR_SSM_DT_NORM,
LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_B_NORM,
LLM_TENSOR_SSM_C_NORM,
LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
LLM_TENSOR_TIME_MIX_W0,
LLM_TENSOR_TIME_MIX_W1,
@ -386,6 +399,9 @@ enum llm_tensor {
LLM_TENSOR_POS_NET_ATTN_K,
LLM_TENSOR_POS_NET_ATTN_V,
LLM_TENSOR_POS_NET_ATTN_OUT,
LLM_TENSOR_SHORTCONV_CONV,
LLM_TENSOR_SHORTCONV_INPROJ,
LLM_TENSOR_SHORTCONV_OUTPROJ,
};
enum llm_tensor_layer {

View File

@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
// note: tracking the other way around is not necessary for now
//seq_cpl[s0][s1] = true;
has_cpl = true;
}
}
}
@ -405,6 +407,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
return n_outputs;
}
uint32_t llama_batch_allocr::get_n_used() const {
return n_used;
}
std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
return out_ids;
}
@ -420,6 +426,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
void llama_batch_allocr::split_reset() {
out_ids.clear();
n_used = 0;
used.clear();
used.resize(get_n_tokens(), false);
@ -444,6 +452,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
idxs.push_back(cur_idx);
used[cur_idx] = true;
++n_used;
++cur_idx;
@ -459,9 +468,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
return ubatch_add(idxs, idxs.size(), false);
}
- llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
+ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
+ if (sequential && has_cpl) {
+ LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
+ return {};
+ }
std::vector<seq_set_t> cur_seq_set;
+ llama_seq_id last_seq_id = -1;
// determine the non-overlapping sequence sets participating in this ubatch
for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (used[i]) {
@ -478,9 +495,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
}
}
// accept only increasing sequence ids
if (sequential) {
add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
}
if (add) {
cur_seq_set.push_back(seq_set[i]);
last_seq_id = batch.seq_id[i][0];
if (cur_seq_set.size() > n_ubatch) {
break;
}
@ -529,6 +553,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
idxs_per_seq[s].push_back(idx);
used[idx] = true;
++n_used;
++cur_idx[s];
}
@ -570,6 +595,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
idxs.push_back(cur_idx);
used[cur_idx] = true;
++n_used;
if (idxs.size() >= n_ubatch) {
break;
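
The new sequential flag and the n_used counter are consumed by the memory modules roughly as follows (a sketch modeled on the init_batch() loops that appear later in this commit, not verbatim code):

std::vector<llama_ubatch> ubatches;
while (true) {
    // sequential == true -> only accept ubatches whose sequence ids increase 0, 1, 2, ...
    llama_ubatch ubatch = balloc.split_equal(n_ubatch, /*sequential =*/ true);
    if (ubatch.n_tokens == 0) {
        break;
    }
    ubatches.push_back(std::move(ubatch));
}
if (balloc.get_n_used() < balloc.get_n_tokens()) {
    // not all tokens could be placed -> fall back to a different split strategy
}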

View File

@ -54,6 +54,7 @@ public:
uint32_t get_n_tokens() const;
uint32_t get_n_outputs() const;
uint32_t get_n_used() const;
// the array of output indices in the order they were encountered during the ubatch splitting
std::vector<int32_t> & get_out_ids();
@ -69,7 +70,8 @@ public:
llama_ubatch split_simple(uint32_t n_ubatch);
// make ubatches of equal-length sequences sets
- llama_ubatch split_equal(uint32_t n_ubatch);
+ // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
+ llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
// sequence-set-wise split - each ubatch contains a single sequence-set
llama_ubatch split_seq(uint32_t n_ubatch);
@ -112,6 +114,9 @@ private:
using pos_set_t = std::set<llama_pos>;
using seq_cpl_t = std::vector<bool>;
// helper flag to quickly determine if there are any coupled sequences in the batch
bool has_cpl;
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
@ -125,6 +130,8 @@ private:
// batch indices of the output
std::vector<int32_t> out_ids;
uint32_t n_used;
// used[i] indicates if token i has already been used in a previous ubatch
std::vector<bool> used;

View File

@ -64,6 +64,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "bailing", LLM_CHAT_TEMPLATE_BAILING }, { "bailing", LLM_CHAT_TEMPLATE_BAILING },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
}; };
llm_chat_template llm_chat_template_from_str(const std::string & name) { llm_chat_template llm_chat_template_from_str(const std::string & name) {
@ -185,6 +186,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_LLAMA4;
} else if (tmpl_contains("<|endofuserprompt|>")) {
return LLM_CHAT_TEMPLATE_DOTS1;
} else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@ -665,6 +668,18 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|response|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
// tencent/Hunyuan-A13B-Instruct
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "<|startoftext|>" << message->content << "<|extra_4|>";
} else if (role == "assistant") {
ss << "<|startoftext|>" << message->content << "<|eos|>";
} else {
ss << "<|startoftext|>" << message->content << "<|extra_0|>";
}
}
} else {
// template not supported
return -1;
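
For reference, applying the LLM_CHAT_TEMPLATE_HUNYUAN_MOE branch above to a short conversation (illustrative messages, not a test case from the repo) produces the following pieces, concatenated without separators; note that the hunk shows no add_ass handling for this template:

// {"system",    "You are a helpful assistant."} -> <|startoftext|>You are a helpful assistant.<|extra_4|>
// {"user",      "Hello"}                        -> <|startoftext|>Hello<|extra_0|>
// {"assistant", "Hi there!"}                    -> <|startoftext|>Hi there!<|eos|>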

View File

@ -44,6 +44,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_SMOLVLM,
LLM_CHAT_TEMPLATE_DOTS1,
LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
LLM_CHAT_TEMPLATE_UNKNOWN,
};

View File

@ -281,20 +281,23 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
}
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
- if (self_kq_mask) {
- mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
- }
+ mctx->set_input_k_idxs(self_k_idxs, ubatch);
+ mctx->set_input_v_idxs(self_v_idxs, ubatch);
+ mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
- if (self_kq_mask) {
- mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
- }
- if (self_kq_mask_swa) {
- mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
- }
+ mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+ mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
+ mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+ mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+ mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
+ mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
GGML_ASSERT(cross_kq_mask);
@ -333,27 +336,8 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
}
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
- if (self_kq_mask) {
- mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
- }
- const int64_t n_rs = mctx->get_recr()->get_n_rs();
- if (s_copy) {
- GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
- int32_t * data = (int32_t *) s_copy->data;
- // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
- for (uint32_t i = 0; i < n_rs; ++i) {
- data[i] = mctx->get_recr()->s_copy(i);
- }
- }
- }
- void llm_graph_input_one::set_input(const llama_ubatch *) {
- GGML_ASSERT(one && ggml_nelements(one) == 1);
- float f_one = 1.0f;
- ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
+ inp_attn->set_input(ubatch);
+ inp_rs->set_input(ubatch);
}
//
@ -987,33 +971,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
return pos_bias;
}
- llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
- const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
- auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
- {
- GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
- const auto n_kv = inp->mctx->get_attn()->get_n_kv();
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
- //cb(inp->self_kq_mask, "KQ_mask", -1);
- ggml_set_input(inp->self_kq_mask);
- inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
- }
- {
- const auto n_rs = mctx_cur->get_recr()->get_n_rs();
- inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
- ggml_set_input(inp->s_copy);
- }
- return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
- }
ggml_tensor * llm_graph_context::build_attn_mha(
ggml_cgraph * gf,
ggml_tensor * q,
@ -1135,8 +1092,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
- inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
- //cb(inp_kq_mask, "KQ_mask", -1);
+ inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->kq_mask);
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
@ -1188,8 +1144,12 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}
- llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
- const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
+ ggml_context * ctx0,
+ const llama_ubatch & ubatch,
+ const llama_hparams & hparams,
+ const llama_cparams & cparams,
+ const llama_kv_cache_unified_context * mctx_cur) {
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
@ -1197,14 +1157,25 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
const auto n_kv = mctx_cur->get_n_kv();
+ const auto n_tokens = ubatch.n_tokens;
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
- //cb(inp->self_kq_mask, "KQ_mask", -1);
+ inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+ inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
+ return inp;
+ }
+ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+ auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
}
@ -1226,12 +1197,15 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);
- const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+ const auto * mctx_cur = inp->mctx;
// store to KV cache
{
- ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
- ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+ const auto & k_idxs = inp->get_k_idxs();
+ const auto & v_idxs = inp->get_v_idxs();
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
}
const auto & kq_mask = inp->get_kq_mask();
@ -1282,7 +1256,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, v_cur);
}
- const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
+ const auto * mctx_iswa = inp->mctx;
const bool is_swa = hparams.is_swa(il);
@ -1290,11 +1264,15 @@ ggml_tensor * llm_graph_context::build_attn(
// optionally store to KV cache
if (k_cur) {
- ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+ const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
}
if (v_cur) {
- ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+ const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
}
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@ -1326,7 +1304,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
- inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+ inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->cross_kq_mask);
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@ -1376,56 +1354,9 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}
- ggml_tensor * llm_graph_context::build_attn(
- llm_graph_input_mem_hybrid * inp,
- ggml_cgraph * gf,
- ggml_tensor * wo,
- ggml_tensor * wo_b,
- ggml_tensor * q_cur,
- ggml_tensor * k_cur,
- ggml_tensor * v_cur,
- ggml_tensor * kq_b,
- ggml_tensor * v_mla,
- float kq_scale,
- int il) const {
- // these nodes are added to the graph together so that they are not reordered
- // by doing so, the number of splits in the graph is reduced
- ggml_build_forward_expand(gf, q_cur);
- ggml_build_forward_expand(gf, k_cur);
- ggml_build_forward_expand(gf, v_cur);
- const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
- // store to KV cache
- {
- ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
- ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
- }
- const auto & kq_mask = inp->get_kq_mask();
- ggml_tensor * q = q_cur;
- ggml_tensor * k = mctx_cur->get_k(ctx0, il);
- ggml_tensor * v = mctx_cur->get_v(ctx0, il);
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
- cb(cur, "kqv_out", il);
- if (wo) {
- cur = build_lora_mm(wo, cur);
- if (arch == LLM_ARCH_GLM4) {
- // GLM4 seems to have numerical issues with half-precision accumulators
- ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
- }
- }
- if (wo_b) {
- cur = ggml_add(ctx0, cur, wo_b);
- }
- return cur;
- }
+ // TODO: maybe separate the inner implementation into a separate function
+ // like with the non-sliding window equivalent
+ // once sliding-window hybrid caches are a thing.
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
@ -1434,8 +1365,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
{
const auto n_kv = mctx_cur->get_base()->get_n_kv();
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
- //cb(inp->self_kq_mask, "KQ_mask", -1);
+ inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+ inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@ -1446,8 +1379,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
- inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
- //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+ inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+ inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
+ inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@ -1466,7 +1401,7 @@ ggml_tensor * llm_graph_context::build_rs(
uint32_t kv_head,
uint32_t kv_size,
int32_t rs_zero,
- bool avoid_copies) const {
+ const llm_graph_get_rows_fn & get_state_rows) const {
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
@ -1475,19 +1410,11 @@ ggml_tensor * llm_graph_context::build_rs(
ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
- ggml_tensor * output_states;
- if (!avoid_copies) {
// copy states
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
// {state_size, kv_size} -> {state_size, n_seqs}
- output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
- ggml_build_forward_expand(gf, output_states);
- } else {
- // FIXME: make the gathering operation happen before the copy below
- // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
- output_states = states;
- }
+ ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+ ggml_build_forward_expand(gf, output_states);
// copy extra states which won't be changed further (between n_seqs and n_kv)
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
@ -1499,8 +1426,9 @@ ggml_tensor * llm_graph_context::build_rs(
return output_states;
}
- llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
- const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
+ ggml_context * ctx0,
+ const llama_memory_recurrent_context * mctx_cur) {
auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
@ -1509,6 +1437,14 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
ggml_set_input(inp->s_copy);
return inp;
}
llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
auto inp = build_rs_inp_impl(ctx0, mctx_cur);
return (llm_graph_input_rs *) res->add_input(std::move(inp));
}
@ -1518,22 +1454,10 @@ ggml_tensor * llm_graph_context::build_rs(
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
- bool avoid_copies) const {
- const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
- }
- ggml_tensor * llm_graph_context::build_rs(
- llm_graph_input_mem_hybrid * inp,
- ggml_cgraph * gf,
- ggml_tensor * s,
- int32_t state_size,
- int32_t n_seqs,
- bool avoid_copies) const {
- const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
+ const llm_graph_get_rows_fn & get_state_rows) const {
+ const auto * kv_state = inp->mctx;
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
}
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@ -1578,6 +1502,17 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
);
}
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
auto inp_rs = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
}
void llm_graph_context::build_pooling(
ggml_cgraph * gf,
ggml_tensor * cls,
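
The get_state_rows hook replaces the old avoid_copies flag: instead of skipping the gather, a caller now supplies the operation used to gather the state rows, with ggml_get_rows remaining the default. A hedged sketch of a call site (inp_rs and ssm_states_all are hypothetical names; the signature matches the build_rs() declaration in the header file below):

ggml_tensor * states = build_rs(inp_rs, gf, ssm_states_all, hparams.n_embd_s(), n_seqs,
    [](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
        // custom gather; returning ggml_get_rows(...) reproduces the default behavior
        return ggml_get_rows(ctx, states, ids);
    });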

View File

@ -228,8 +228,8 @@ public:
ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
- ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
- ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
+ ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
+ ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
const llama_hparams & hparams;
const llama_cparams & cparams;
@ -249,10 +249,16 @@ public:
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+ ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+ ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
- ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
- ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
const llama_hparams & hparams;
const llama_cparams & cparams;
@ -274,13 +280,23 @@ public:
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
+ ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+ ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+ ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
+ ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
- ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
- ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
- ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
- ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
+ ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
+ ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]
const llama_hparams & hparams;
const llama_cparams & cparams;
@ -297,8 +313,8 @@ public:
ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
- ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
- ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
+ ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
+ ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
const llama_cross * cross = nullptr;
};
@ -306,41 +322,25 @@ public:
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
llm_graph_input_mem_hybrid(
- const llama_hparams & hparams,
- const llama_cparams & cparams,
+ std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
+ std::unique_ptr<llm_graph_input_rs> inp_rs,
const llama_memory_hybrid_context * mctx) :
- hparams(hparams),
- cparams(cparams),
- mctx(mctx) {
- }
+ inp_attn(std::move(inp_attn)),
+ inp_rs(std::move(inp_rs)),
+ mctx(mctx) { }
virtual ~llm_graph_input_mem_hybrid() = default;
void set_input(const llama_ubatch * ubatch) override;
- ggml_tensor * s_copy; // I32 [kv_size]
- ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
- ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
- ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
- const llama_hparams & hparams;
- const llama_cparams & cparams;
+ std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
+ std::unique_ptr<llm_graph_input_rs> inp_rs;
+ llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
+ llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
const llama_memory_hybrid_context * mctx;
};
- // TODO: remove this when ggml_scale_add is implemented
- class llm_graph_input_one : public llm_graph_input_i {
- public:
- llm_graph_input_one() {}
- virtual ~llm_graph_input_one() = default;
- void set_input(const llama_ubatch *) override;
- ggml_tensor * one = nullptr; // F32
- };
//
// llm_graph_result
//
@ -424,6 +424,9 @@ struct llm_graph_params {
const llm_graph_cb & cb;
};
// used in build_rs to properly order writes and avoid unnecessary copies
using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
struct llm_graph_context {
const llm_arch arch;
@ -554,8 +557,6 @@ struct llm_graph_context {
ggml_tensor * build_inp_pos_bucket_dec() const;
ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
- llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
//
// attention
//
@ -631,18 +632,6 @@ struct llm_graph_context {
float kq_scale,
int il) const;
- ggml_tensor * build_attn(
- llm_graph_input_mem_hybrid * inp,
- ggml_cgraph * gf,
- ggml_tensor * wo,
- ggml_tensor * wo_b,
- ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
- ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
- ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
- ggml_tensor * kq_b,
- ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
- float kq_scale,
- int il) const;
//
// recurrent
//
@ -663,7 +652,7 @@ struct llm_graph_context {
uint32_t kv_head,
uint32_t kv_size,
int32_t rs_zero,
- bool avoid_copies = false) const;
+ const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
llm_graph_input_rs * build_rs_inp() const;
@ -673,15 +662,7 @@ struct llm_graph_context {
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
- bool avoid_copies = false) const;
+ const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
- ggml_tensor * build_rs(
- llm_graph_input_mem_hybrid * inp,
- ggml_cgraph * gf,
- ggml_tensor * s,
- int32_t state_size,
- int32_t n_seqs,
- bool avoid_copies = false) const;
ggml_tensor * build_rwkv_token_shift_load(
llm_graph_input_rs * inp,
@ -693,6 +674,11 @@ struct llm_graph_context {
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il) const;
//
// hybrid
//
llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
//
// pooling
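
Taken together, the header changes mean a hybrid block builder now goes through the composed input instead of the removed overloads, along these lines (sketch only; the tensor arguments are placeholders for whatever the model builder computed):

auto * inp = build_inp_mem_hybrid();

// attention layers reuse the unified-KV path via the embedded input ...
cur = build_attn(inp->get_attn(), gf, wo, wo_b, q_cur, k_cur, v_cur, nullptr, nullptr, kq_scale, il);

// ... and recurrent layers reuse the recurrent-state path
ggml_tensor * ssm = build_rs(inp->get_recr(), gf, ssm_states, state_size, n_seqs);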

View File

@ -71,9 +71,15 @@ uint32_t llama_hparams::n_embd_r() const {
return token_shift_count * n_embd;
}
if (n_shortconv_l_cache != 0) {
// for LFM2 models
return n_embd * (n_shortconv_l_cache - 1);
}
// TODO: maybe support other convolution strides than 1
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
- return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+ // Corresponds to Mamba's conv_states size
+ return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
}
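
A quick sanity check of the new formula (illustrative numbers, not taken from a particular GGUF):

// Mamba2-style layer: ssm_d_conv = 4, ssm_d_inner = 8192, ssm_n_group = 8, ssm_d_state = 128
//   n_embd_r() = (4 - 1) * (8192 + 2*8*128) = 3 * 10240 = 30720 floats per sequence
// LFM2 layer: n_embd = 2048, n_shortconv_l_cache = 3
//   n_embd_r() = 2048 * (3 - 1) = 4096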
uint32_t llama_hparams::n_embd_s() const {

View File

@ -55,6 +55,8 @@ struct llama_hparams {
struct llama_hparams_posnet posnet;
struct llama_hparams_convnext convnext;
uint32_t n_shortconv_l_cache = 0;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
@ -114,6 +116,7 @@ struct llama_hparams {
uint32_t ssm_d_inner = 0;
uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0;
uint32_t ssm_n_group = 0;
// for hybrid state space models
std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;

View File

@ -113,20 +113,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
ubatches.push_back(std::move(ubatch)); // NOLINT
}
- auto heads_base = kv_base->prepare(ubatches);
- if (heads_base.empty()) {
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
+ // failed to find a suitable split
break;
}
- auto heads_swa = kv_swa->prepare(ubatches);
- if (heads_swa.empty()) {
+ auto sinfos_base = kv_base->prepare(ubatches);
+ if (sinfos_base.empty()) {
break;
}
- assert(heads_base.size() == heads_swa.size());
+ auto sinfos_swa = kv_swa->prepare(ubatches);
+ if (sinfos_swa.empty()) {
+ break;
+ }
+ assert(sinfos_base.size() == sinfos_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_context>(
- this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+ this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while (false);
// if it fails, try equal split
@ -135,7 +140,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
std::vector<llama_ubatch> ubatches;
while (true) {
- auto ubatch = balloc.split_equal(n_ubatch);
+ auto ubatch = balloc.split_equal(n_ubatch, false);
if (ubatch.n_tokens == 0) {
break;
@ -144,20 +149,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
ubatches.push_back(std::move(ubatch)); // NOLINT
}
- auto heads_base = kv_base->prepare(ubatches);
- if (heads_base.empty()) {
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
+ // failed to find a suitable split
break;
}
- auto heads_swa = kv_swa->prepare(ubatches);
- if (heads_swa.empty()) {
+ auto sinfos_base = kv_base->prepare(ubatches);
+ if (sinfos_base.empty()) {
break;
}
- assert(heads_base.size() == heads_swa.size());
+ auto sinfos_swa = kv_swa->prepare(ubatches);
+ if (sinfos_swa.empty()) {
+ break;
+ }
+ assert(sinfos_base.size() == sinfos_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_context>(
- this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+ this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while (false);
// TODO: if we fail again, we should attempt different splitting strategies
@ -220,13 +230,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
- std::vector<uint32_t> heads_base,
- std::vector<uint32_t> heads_swa,
+ slot_info_vec_t sinfos_base,
+ slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
- ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
- ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
+ ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+ ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}

View File

@ -74,6 +74,8 @@ private:
class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
// used for errors
llama_kv_cache_unified_iswa_context(llama_memory_status status);
@ -90,8 +92,8 @@ public:
// used to create a batch processing context from a batch
llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
- std::vector<uint32_t> heads_base,
- std::vector<uint32_t> heads_swa,
+ slot_info_vec_t sinfos_base,
+ slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_iswa_context();

View File

@ -156,6 +156,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(
const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
if (!supports_set_rows) {
LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
}
} }
void llama_kv_cache_unified::clear(bool data) { void llama_kv_cache_unified::clear(bool data) {
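The new LLAMA_SET_ROWS toggle is a plain getenv/atoi read at construction time. A minimal standalone sketch of the same pattern (the helper name here is illustrative, not part of llama.cpp):

    #include <cstdio>
    #include <cstdlib>

    // read an integer flag from the environment, falling back to a default
    static int env_flag(const char * name, int def) {
        const char * v = std::getenv(name);
        return v ? std::atoi(v) : def;
    }

    int main() {
        const int supports_set_rows = env_flag("LLAMA_SET_ROWS", 0);
        if (!supports_set_rows) {
            std::fprintf(stderr, "LLAMA_SET_ROWS=0, using the ggml_cpy() fallback\n");
        }
        return 0;
    }

Setting LLAMA_SET_ROWS=1 in the environment enables the ggml_set_rows() path introduced by this change.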
@@ -353,13 +360,18 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
         ubatches.push_back(std::move(ubatch)); // NOLINT
     }

-    auto heads = prepare(ubatches);
-    if (heads.empty()) {
+    if (balloc.get_n_used() < balloc.get_n_tokens()) {
+        // failed to find a suitable split
+        break;
+    }
+
+    auto sinfos = prepare(ubatches);
+    if (sinfos.empty()) {
         break;
     }

     return std::make_unique<llama_kv_cache_unified_context>(
-        this, std::move(heads), std::move(ubatches));
+        this, std::move(sinfos), std::move(ubatches));
 } while (false);

 return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
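init_batch() relies on the do { ... } while (false) idiom: any failed step breaks out to a single shared failure return. A minimal sketch of that control flow, with a std::optional standing in for the prepared result (hypothetical names, not llama.cpp API):

    #include <memory>
    #include <optional>

    std::unique_ptr<int> init_batch_sketch(const std::optional<int> & prepared) {
        do {
            if (!prepared) {
                break; // any failure jumps past the success return
            }
            return std::make_unique<int>(*prepared); // success path
        } while (false);

        // all failures funnel here, analogous to LLAMA_MEMORY_STATUS_FAILED_PREPARE
        return nullptr;
    }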
@@ -402,12 +414,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
     return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
 }

-llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    llama_kv_cache_unified::ubatch_heads res;
+llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache_unified::slot_info_vec_t res;

     struct state {
         uint32_t head_old; // old position of the head, before placing the ubatch
-        uint32_t head_new; // new position of the head, after placing the ubatch
+
+        slot_info sinfo; // slot info for the ubatch
+
         llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
     };
@@ -418,26 +431,29 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::
     bool success = true;

     for (const auto & ubatch : ubatches) {
+        // non-continuous slots require support for ggml_set_rows()
+        const bool cont = supports_set_rows ? false : true;
+
         // only find a suitable slot for the ubatch. don't modify the cells yet
-        const int32_t head_new = find_slot(ubatch);
-        if (head_new < 0) {
+        const auto sinfo_new = find_slot(ubatch, cont);
+        if (sinfo_new.empty()) {
             success = false;
             break;
         }

         // remember the position that we found
-        res.push_back(head_new);
+        res.push_back(sinfo_new);

         // store the old state of the cells in the recovery stack
-        states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
+        states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});

         // now emplace the ubatch
-        apply_ubatch(head_new, ubatch);
+        apply_ubatch(sinfo_new, ubatch);
     }

     // iterate backwards and restore the cells to their original state
     for (auto it = states.rbegin(); it != states.rend(); ++it) {
-        cells.set(it->head_new, it->cells);
+        cells.set(it->sinfo.idxs, it->cells);
         head = it->head_old;
     }
@@ -539,7 +555,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
     return updated;
 }

-int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
+llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
     const uint32_t n_tokens = ubatch.n_tokens;

     uint32_t head_cur = this->head;
@@ -552,7 +568,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
     if (n_tokens > cells.size()) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-        return -1;
+        return { };
     }

     if (debug > 0) {
@@ -615,15 +631,26 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
     uint32_t n_tested = 0;

+    // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
+    // for non-continuous slots, we test the tokens one by one
+    const uint32_t n_test = cont ? n_tokens : 1;
+
+    slot_info res;
+
+    auto & idxs = res.idxs;
+
+    idxs.reserve(n_tokens);
+
     while (true) {
-        if (head_cur + n_tokens > cells.size()) {
+        if (head_cur + n_test > cells.size()) {
             n_tested += cells.size() - head_cur;
             head_cur = 0;
             continue;
         }

-        bool found = true;
-        for (uint32_t i = 0; i < n_tokens; i++) {
+        for (uint32_t i = 0; i < n_test; i++) {
+            const auto idx = head_cur;
+
             //const llama_pos    pos    = ubatch.pos[i];
             //const llama_seq_id seq_id = ubatch.seq_id[i][0];
@@ -633,19 +660,19 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             // - (disabled) mask causally, if the sequence is the same as the one we are inserting
             // - mask SWA, using current max pos for that sequence in the cache
             //   always insert in the cell with minimum pos
-            bool can_use = cells.is_empty(head_cur + i);
+            bool can_use = cells.is_empty(idx);

-            if (!can_use && cells.seq_count(head_cur + i) == 1) {
-                const llama_pos pos_cell = cells.pos_get(head_cur + i);
+            if (!can_use && cells.seq_count(idx) == 1) {
+                const llama_pos pos_cell = cells.pos_get(idx);

                 // (disabled) causal mask
                 // note: it's better to purge any "future" tokens beforehand
-                //if (cells.seq_has(head_cur + i, seq_id)) {
+                //if (cells.seq_has(idx, seq_id)) {
                 //    can_use = pos_cell >= pos;
                 //}

                 if (!can_use) {
-                    const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
+                    const llama_seq_id seq_id_cell = cells.seq_get(idx);

                     // SWA mask
                     if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
@@ -654,28 +681,39 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
                 }
             }

-            if (!can_use) {
-                found = false;
-                head_cur += i + 1;
-                n_tested += i + 1;
+            head_cur++;
+            n_tested++;
+
+            if (can_use) {
+                idxs.push_back(idx);
+            } else {
                 break;
             }
         }

-        if (found) {
+        if (idxs.size() == n_tokens) {
             break;
         }

+        if (cont) {
+            idxs.clear();
+        }
+
         if (n_tested >= cells.size()) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return -1;
+            return { };
         }
     }

-    return head_cur;
+    // we didn't find a suitable slot - return empty result
+    if (idxs.size() < n_tokens) {
+        res.clear();
+    }
+
+    return res;
 }
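The reworked search has two modes: with cont == true it still looks for one contiguous run of n_tokens usable cells, while with cont == false it collects any n_tokens usable cells one at a time. A simplified standalone version over a plain occupancy bitmap (no SWA or sequence checks, purely illustrative):

    #include <cstdint>
    #include <vector>

    std::vector<uint32_t> find_slot_sketch(const std::vector<bool> & used, uint32_t n, bool cont) {
        std::vector<uint32_t> idxs;
        const uint32_t size = (uint32_t) used.size();
        uint32_t head = 0, tested = 0;

        while (tested < size) {
            const uint32_t n_test = cont ? n : 1;
            if (head + n_test > size) { tested += size - head; head = 0; continue; }

            for (uint32_t i = 0; i < n_test; i++) {
                const uint32_t idx = head++;
                tested++;
                if (!used[idx]) { idxs.push_back(idx); } else { break; }
            }

            if (idxs.size() == n) return idxs;
            if (cont) idxs.clear(); // a contiguous run must be rebuilt from scratch
        }
        return {}; // no suitable slot found
    }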
-void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@@ -683,22 +721,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
         seq_pos_max_rm[s] = -1;
     }

-    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        if (!cells.is_empty(head_cur + i)) {
-            assert(cells.seq_count(head_cur + i) == 1);
+    assert(ubatch.n_tokens == sinfo.idxs.size());

-            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
-            const llama_pos    pos    = cells.pos_get(head_cur + i);
+    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+        const auto idx = sinfo.idxs.at(i);
+
+        if (!cells.is_empty(idx)) {
+            assert(cells.seq_count(idx) == 1);
+
+            const llama_seq_id seq_id = cells.seq_get(idx);
+            const llama_pos    pos    = cells.pos_get(idx);

             seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);

-            cells.rm(head_cur + i);
+            cells.rm(idx);
         }

-        cells.pos_set(head_cur + i, ubatch.pos[i]);
+        cells.pos_set(idx, ubatch.pos[i]);

         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
-            cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
+            cells.seq_add(idx, ubatch.seq_id[i][s]);
         }
     }
@@ -719,7 +761,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
     }

     // move the head at the end of the slot
-    head = head_cur + ubatch.n_tokens;
+    head = sinfo.idxs.back() + 1;
 }

 bool llama_kv_cache_unified::get_can_shift() const {
@@ -772,47 +814,133 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }

-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);

     auto * k = layers[ikv].k;

+    const int64_t n_embd_k_gqa = k->ne[0];
     const int64_t n_tokens = k_cur->ne[2];

+    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+
+    if (k_idxs && supports_set_rows) {
+        return ggml_set_rows(ctx, k, k_cur, k_idxs);
+    }
+
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    //       will be removed when ggml_set_rows() is adopted by all backends
+
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
-            n_tokens*hparams.n_embd_k_gqa(il),
-            ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
+            n_tokens*n_embd_k_gqa,
+            ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());

     return ggml_cpy(ctx, k_cur, k_view);
 }

-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);

     auto * v = layers[ikv].v;

+    const int64_t n_embd_v_gqa = v->ne[0];
     const int64_t n_tokens = v_cur->ne[2];

-    v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
+    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+
+    if (v_idxs && supports_set_rows) {
+        if (!v_trans) {
+            return ggml_set_rows(ctx, v, v_cur, v_idxs);
+        }
+
+        // the row becomes a single element
+        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+
+        // note: the V cache is transposed when not using flash attention
+        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+
+        // note: we can be more explicit here at the cost of extra cont
+        //       however, above we take advantage that a row of single element is always continuous regardless of the row stride
+        //v_cur = ggml_transpose(ctx, v_cur);
+        //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+
+        // we broadcast the KV indices n_embd_v_gqa times
+        // v      [1, n_kv,     n_embd_v_gqa]
+        // v_cur  [1, n_tokens, n_embd_v_gqa]
+        // v_idxs [n_tokens, 1, 1]
+        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
+    }
+
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    //       will be removed when ggml_set_rows() is adopted by all backends

     ggml_tensor * v_view = nullptr;

     if (!v_trans) {
         v_view = ggml_view_1d(ctx, v,
-                n_tokens*hparams.n_embd_v_gqa(il),
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
+                n_tokens*n_embd_v_gqa,
+                ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
     } else {
-        // note: the V cache is transposed when not using flash attention
-        v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
-                (v->ne[1])*ggml_element_size(v),
-                (head_cur)*ggml_element_size(v));
-
         v_cur = ggml_transpose(ctx, v_cur);
+
+        v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
+                (v->ne[1]    )*ggml_element_size(v),
+                (sinfo.head())*ggml_element_size(v));
     }

     return ggml_cpy(ctx, v_cur, v_view);
 }
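ggml_set_rows(dst, src, idxs) is, semantically, a row scatter: row i of src is written into row idxs[i] of dst, which is what lets a ubatch land in non-contiguous KV cells with a single graph op. A scalar model of those semantics (not the ggml implementation):

    #include <cstdint>
    #include <vector>

    // dst holds rows of row_size floats; write src row i into dst row idxs[i]
    void set_rows_sketch(std::vector<float> & dst, size_t row_size,
                         const std::vector<float> & src,
                         const std::vector<int64_t> & idxs) {
        for (size_t i = 0; i < idxs.size(); ++i) {
            for (size_t j = 0; j < row_size; ++j) {
                dst[(size_t) idxs[i]*row_size + j] = src[i*row_size + j];
            }
        }
    }

For the transposed V cache above, each row is first reshaped down to a single element, which is why the same per-token indices can be broadcast across all n_embd_v_gqa "rows".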
+ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(k_idxs);
+
+    return k_idxs;
+}
+
+ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(v_idxs);
+
+    return v_idxs;
+}
+
+void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = sinfo.idxs.at(i);
+    }
+}
+
+void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = sinfo.idxs.at(i);
+    }
+}
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const uint32_t n_tokens = ubatch->n_tokens;
@@ -1552,13 +1680,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         ubatch.seq_id[i] = &dest_seq_id;
     }

-    const auto head_cur = find_slot(ubatch);
-    if (head_cur < 0) {
+    const auto sinfo = find_slot(ubatch, true);
+    if (sinfo.empty()) {
         LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
         return false;
     }

-    apply_ubatch(head_cur, ubatch);
+    apply_ubatch(sinfo, ubatch);
+
+    const auto head_cur = sinfo.head();

     // keep the head at the old position because we will read the KV data into it in state_read_data()
     head = head_cur;
@@ -1744,7 +1874,11 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_stat
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
     n_kv = kv->get_size();
-    head = 0;
+
+    // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
+    sinfos.resize(1);
+    sinfos[0].idxs.resize(1);
+    sinfos[0].idxs[0] = 0;
 }

 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
@@ -1759,8 +1893,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv,
-        llama_kv_cache_unified::ubatch_heads heads,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+        llama_kv_cache_unified::slot_info_vec_t sinfos,
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
 }
 llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;

@@ -1768,7 +1902,7 @@ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
 bool llama_kv_cache_unified_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

-    if (++i_next >= ubatches.size()) {
+    if (++i_cur >= ubatches.size()) {
         return false;
     }
@@ -1785,10 +1919,9 @@ bool llama_kv_cache_unified_context::apply() {
         return true;
     }

-    kv->apply_ubatch(heads[i_next], ubatches[i_next]);
+    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);

     n_kv = kv->get_n_kv();
-    head = heads[i_next];

     return true;
 }
@@ -1800,7 +1933,7 @@ llama_memory_status llama_kv_cache_unified_context::get_status() const {
 const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

-    return ubatches[i_next];
+    return ubatches[i_cur];
 }

 uint32_t llama_kv_cache_unified_context::get_n_kv() const {
@@ -1815,18 +1948,34 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
     return kv->get_v(ctx, il, n_kv);
 }

-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
 }

-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_k_idxs(ctx, ubatch);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_v_idxs(ctx, ubatch);
 }

 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }

+void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
+void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     kv->set_input_kq_mask(dst, ubatch, causal_attn);
 }

View File

@@ -24,8 +24,6 @@ public:
     // this callback is used to filter out layers that should not be included in the cache
     using layer_filter_cb = std::function<bool(int32_t il)>;

-    using ubatch_heads = std::vector<uint32_t>;
-
     struct defrag_info {
         bool empty() const {
             return ids.empty();
@@ -37,6 +35,32 @@ public:
         std::vector<uint32_t> ids;
     };
+    // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
+    // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
+    struct slot_info {
+        // data for ggml_set_rows
+        using idx_vec_t = std::vector<uint32_t>;
+
+        idx_vec_t idxs;
+
+        uint32_t head() const {
+            return idxs.at(0);
+        }
+
+        bool empty() const {
+            return idxs.empty();
+        }
+
+        void clear() {
+            idxs.clear();
+        }
+
+        // TODO: implement
+        //std::vector<idx_vec_t> seq_idxs;
+    };
+
+    using slot_info_vec_t = std::vector<slot_info>;
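A slot_info decouples "which tokens" from "which cells": token i of the ubatch goes to cell idxs[i], and head() only keeps its old meaning when the indices happen to be contiguous. A small usage sketch (toy types, not the real cache):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct slot_info_sketch {
        std::vector<uint32_t> idxs;
        uint32_t head() const { return idxs.at(0); }
    };

    // place each token's position into the cell chosen for it
    void place_tokens(std::vector<int> & cells,
                      const slot_info_sketch & sinfo,
                      const std::vector<int> & token_pos) {
        assert(token_pos.size() == sinfo.idxs.size());
        for (size_t i = 0; i < token_pos.size(); ++i) {
            cells[sinfo.idxs[i]] = token_pos[i]; // token[i] -> cells[idxs[i]]
        }
    }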
     llama_kv_cache_unified(
             const llama_model & model,
             layer_filter_cb && filter,
@@ -102,30 +126,37 @@ public:
     ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;

     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;

     //
     // preparation API
     //

-    // find places for the provided ubatches in the cache, returns the head locations
+    // find places for the provided ubatches in the cache, returns the slot infos
     // return empty vector on failure
-    ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
+    slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);

     bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);

-    // return the cell position where we can insert the ubatch
-    // return -1 on failure to find a contiguous slot of kv cells
-    int32_t find_slot(const llama_ubatch & ubatch) const;
+    // find a slot of kv cells that can hold the ubatch
+    // if cont == true, then the slot must be continuous
+    // return empty slot_info on failure
+    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;

-    // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
-    void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
+    // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
+    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);

     //
-    // set_input API
+    // input API
     //

+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+
     void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_k_shift   (ggml_tensor * dst) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@@ -157,8 +188,13 @@ private:
     // SWA
     const uint32_t n_swa = 0;

+    // env: LLAMA_KV_CACHE_DEBUG
     int debug = 0;

+    // env: LLAMA_SET_ROWS (temporary)
+    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
+    int supports_set_rows = false;
+
     const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

     std::vector<ggml_context_ptr> ctxs;
@@ -211,7 +247,7 @@ private:
 class llama_kv_cache_unified_context : public llama_memory_context_i {
 public:
     // some shorthands
-    using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
     using defrag_info = llama_kv_cache_unified::defrag_info;

     // used for errors
@@ -231,7 +267,7 @@ public:
     // used to create a batch processing context from a batch
     llama_kv_cache_unified_context(
             llama_kv_cache_unified * kv,
-            ubatch_heads heads,
+            slot_info_vec_t sinfos,
             std::vector<llama_ubatch> ubatches);
     virtual ~llama_kv_cache_unified_context();

@@ -257,11 +293,16 @@ public:
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
+
+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;

     void set_input_k_shift   (ggml_tensor * dst) const;
     void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@@ -283,10 +324,10 @@ private:
     // batch processing context
     //

-    // the index of the next ubatch to process
-    size_t i_next = 0;
+    // the index of the cur ubatch to process
+    size_t i_cur = 0;

-    ubatch_heads heads;
+    slot_info_vec_t sinfos;

     std::vector<llama_ubatch> ubatches;
@@ -297,7 +338,4 @@ private:
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // as the cache gets filled, the benefit from this heuristic disappears
     int32_t n_kv;
-
-    // the beginning of the current slot in which the ubatch will be inserted
-    int32_t head;
 };
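The context then walks the prepared ubatches with i_cur, pairing each one with its slot info. A compressed sketch of the next()/apply() loop (toy members standing in for the real ones):

    #include <cstddef>
    #include <vector>

    struct ctx_sketch {
        size_t i_cur = 0;
        std::vector<int> sinfos;   // one entry per ubatch
        std::vector<int> ubatches;

        bool next() { return ++i_cur < ubatches.size(); }
    };

    void consume(ctx_sketch & ctx) {
        if (ctx.ubatches.empty()) return;
        do {
            // apply(): kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur])
            (void) ctx.sinfos[ctx.i_cur];
        } while (ctx.next());
    }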

View File

@@ -105,10 +105,30 @@ public:
         res.resize(n);

         for (uint32_t j = 0; j < n; ++j) {
-            res.pos[j] = pos[i + j];
-            res.seq[j] = seq[i + j];
+            const auto idx = i + j;

-            assert(shift[i + j] == 0);
+            res.pos[j] = pos[idx];
+            res.seq[j] = seq[idx];
+
+            assert(shift[idx] == 0);
+        }
+
+        return res;
+    }
+
+    // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
+    llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
+        llama_kv_cells_unified res;
+
+        res.resize(idxs.size());
+
+        for (uint32_t j = 0; j < idxs.size(); ++j) {
+            const auto idx = idxs[j];
+
+            res.pos[j] = pos[idx];
+            res.seq[j] = seq[idx];
+
+            assert(shift[idx] == 0);
         }

         return res;
@@ -119,26 +139,58 @@ public:
         assert(i + other.pos.size() <= pos.size());

         for (uint32_t j = 0; j < other.pos.size(); ++j) {
-            if (pos[i + j] == -1 && other.pos[j] != -1) {
+            const auto idx = i + j;
+
+            if (pos[idx] == -1 && other.pos[j] != -1) {
                 used.insert(i + j);
             }

-            if (pos[i + j] != -1 && other.pos[j] == -1) {
+            if (pos[idx] != -1 && other.pos[j] == -1) {
                 used.erase(i + j);
             }

-            if (pos[i + j] != -1) {
+            if (pos[idx] != -1) {
                 seq_pos_rm(i + j);
             }

-            pos[i + j] = other.pos[j];
-            seq[i + j] = other.seq[j];
+            pos[idx] = other.pos[j];
+            seq[idx] = other.seq[j];

-            if (pos[i + j] != -1) {
+            if (pos[idx] != -1) {
                 seq_pos_add(i + j);
             }

-            assert(shift[i + j] == 0);
+            assert(shift[idx] == 0);
+        }
+    }
+
+    // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
+    void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
+        assert(idxs.size() == other.pos.size());
+
+        for (uint32_t j = 0; j < other.pos.size(); ++j) {
+            const auto idx = idxs[j];
+
+            if (pos[idx] == -1 && other.pos[j] != -1) {
+                used.insert(idx);
+            }
+
+            if (pos[idx] != -1 && other.pos[j] == -1) {
+                used.erase(idx);
+            }
+
+            if (pos[idx] != -1) {
+                seq_pos_rm(idx);
+            }
+
+            pos[idx] = other.pos[j];
+            seq[idx] = other.seq[j];
+
+            if (pos[idx] != -1) {
+                seq_pos_add(idx);
+            }
+
+            assert(shift[idx] == 0);
         }
     }
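The index-vector overloads of cp() and set() form a gather/scatter pair: cp() pulls the state of arbitrary cells into a dense copy, and set() pushes such a copy back to the same indices, which is exactly what the prepare() rollback needs for non-contiguous slots. In miniature (toy type, not the real cells):

    #include <cstdint>
    #include <vector>

    struct toy_cells { std::vector<int> pos; };

    // gather: dense snapshot of the cells at 'idxs'
    toy_cells cp_sketch(const toy_cells & c, const std::vector<uint32_t> & idxs) {
        toy_cells res;
        res.pos.reserve(idxs.size());
        for (auto idx : idxs) res.pos.push_back(c.pos[idx]);
        return res;
    }

    // scatter: restore the snapshot to the same indices
    void set_sketch(toy_cells & c, const std::vector<uint32_t> & idxs, const toy_cells & other) {
        for (size_t j = 0; j < idxs.size(); ++j) c.pos[idxs[j]] = other.pos[j];
    }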

View File

@@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
             // if all tokens are output, split by sequence
             ubatch = balloc.split_seq(n_ubatch);
         } else {
-            ubatch = balloc.split_equal(n_ubatch);
+            ubatch = balloc.split_equal(n_ubatch, false);
         }

         if (ubatch.n_tokens == 0) {
@@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
         ubatches.push_back(std::move(ubatch)); // NOLINT
     }

+    if (balloc.get_n_used() < balloc.get_n_tokens()) {
+        // failed to find a suitable split
+        break;
+    }
+
     // prepare the recurrent batches first
     if (!mem_recr->prepare(ubatches)) {
         // TODO: will the recurrent cache be in an undefined context at this point?
@@ -195,11 +200,11 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
 llama_memory_hybrid_context::llama_memory_hybrid_context(
         llama_memory_hybrid * mem,
-        std::vector<uint32_t> heads_attn,
+        slot_info_vec_t sinfos_attn,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
+    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
     ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
     status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }

View File

@@ -92,6 +92,8 @@ private:
 class llama_memory_hybrid_context : public llama_memory_context_i {
 public:
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+
     // init failure
     explicit llama_memory_hybrid_context(llama_memory_status status);
@@ -107,7 +109,7 @@ public:
     // init success
     llama_memory_hybrid_context(
             llama_memory_hybrid * mem,
-            std::vector<uint32_t> heads_attn,
+            slot_info_vec_t sinfos_attn,
             std::vector<llama_ubatch> ubatches);

     ~llama_memory_hybrid_context() = default;

View File

@@ -25,9 +25,6 @@ llama_memory_recurrent::llama_memory_recurrent(
     uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;

-    LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
-            __func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
-
     head = 0;
     size = mem_size;
     used = 0;
@@ -84,7 +81,7 @@ llama_memory_recurrent::llama_memory_recurrent(
         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
-            throw std::runtime_error("failed to create ggml context for kv cache");
+            throw std::runtime_error("failed to create ggml context for rs cache");
         }

         ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
@@ -102,10 +99,10 @@ llama_memory_recurrent::llama_memory_recurrent(
         ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
         if (!buf) {
-            throw std::runtime_error("failed to allocate buffer for kv cache");
+            throw std::runtime_error("failed to allocate buffer for rs cache");
         }

         ggml_backend_buffer_clear(buf, 0);
-        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
+        LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
         bufs.emplace_back(buf);
     }
@@ -113,8 +110,8 @@ llama_memory_recurrent::llama_memory_recurrent(
     const size_t memory_size_r = size_r_bytes();
     const size_t memory_size_s = size_s_bytes();

-    LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
-            (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
+    LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
+            (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max,
             ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
             ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
 }
@@ -374,7 +371,7 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
             // if all tokens are output, split by sequence
             ubatch = balloc.split_seq(n_ubatch);
         } else {
-            ubatch = balloc.split_equal(n_ubatch);
+            ubatch = balloc.split_equal(n_ubatch, false);
         }

         if (ubatch.n_tokens == 0) {
@@ -384,6 +381,11 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
         ubatches.push_back(std::move(ubatch)); // NOLINT
     }

+    if (balloc.get_n_used() < balloc.get_n_tokens()) {
+        // failed to find a suitable split
+        break;
+    }
+
     if (!prepare(ubatches)) {
         break;
     }

File diff suppressed because it is too large

View File

@@ -32,17 +32,21 @@ enum llm_type {
     LLM_TYPE_190M,
     LLM_TYPE_220M,
     LLM_TYPE_250M,
+    LLM_TYPE_256M,
     LLM_TYPE_270M,
     LLM_TYPE_335M,
+    LLM_TYPE_350M,
     LLM_TYPE_410M,
     LLM_TYPE_450M,
     LLM_TYPE_475M,
+    LLM_TYPE_700M,
     LLM_TYPE_770M,
     LLM_TYPE_780M,
     LLM_TYPE_0_3B,
     LLM_TYPE_0_5B,
     LLM_TYPE_0_6B,
     LLM_TYPE_1B,
+    LLM_TYPE_1_2B,
     LLM_TYPE_1_3B,
     LLM_TYPE_1_4B,
     LLM_TYPE_1_5B,
@@ -94,6 +98,7 @@ enum llm_type {
     LLM_TYPE_57B_A14B,
     LLM_TYPE_17B_16E, // llama4 Scout
     LLM_TYPE_17B_128E, // llama4 Maverick
+    LLM_TYPE_A13B,
     LLM_TYPE_30B_A3B,
     LLM_TYPE_235B_A22B,
     LLM_TYPE_E2B,
@@ -153,6 +158,12 @@ struct llama_layer_convnext {
     struct ggml_tensor * gamma = nullptr;
 };

+struct llama_layer_shortconv {
+    struct ggml_tensor * in_proj  = nullptr;
+    struct ggml_tensor * conv     = nullptr;
+    struct ggml_tensor * out_proj = nullptr;
+};
+
 struct llama_layer {
     // normalization
     struct ggml_tensor * attn_norm = nullptr;
@@ -172,6 +183,10 @@ struct llama_layer {
     struct ggml_tensor * ffn_sub_norm = nullptr;
     struct ggml_tensor * attn_norm_cross = nullptr;
     struct ggml_tensor * attn_norm_enc = nullptr;
+    struct ggml_tensor * ssm_norm = nullptr;
+    struct ggml_tensor * ssm_dt_norm = nullptr;
+    struct ggml_tensor * ssm_b_norm = nullptr;
+    struct ggml_tensor * ssm_c_norm = nullptr;

     // attention
     struct ggml_tensor * wq = nullptr;
@@ -335,6 +350,8 @@ struct llama_layer {
     struct llama_layer_posnet posnet;

     struct llama_layer_convnext convnext;
+
+    struct llama_layer_shortconv shortconv;
 };

 struct llama_model {

View File

@@ -844,6 +844,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // do not quantize Mamba's small yet 2D weights
         // NOTE: can't use LLM_TN here because the layer number is not known
         quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
+        quantize &= name.find("shortconv.conv.weight") == std::string::npos;

         // do not quantize RWKV's small yet 2D weights
         quantize &= name.find("time_mix_first.weight") == std::string::npos;
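The exclusion is a plain substring test on the tensor name, mirroring the existing Mamba conv exclusion. In isolation (the helper name is illustrative):

    #include <string>

    // true if the tensor should stay in full precision (sketch of the checks above)
    bool keep_full_precision(const std::string & name) {
        return name.find("ssm_conv1d.weight")     != std::string::npos ||
               name.find("shortconv.conv.weight") != std::string::npos;
    }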

View File

@@ -351,6 +351,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
             break;
         case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
         case LLAMA_VOCAB_PRE_TYPE_QWEN2:
+        case LLAMA_VOCAB_PRE_TYPE_HUNYUAN:
             regex_exprs = {
                 // original regex from tokenizer.json
                 // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
@@ -1522,7 +1523,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "llama-v3"  ||
                 tokenizer_pre == "llama-bpe" ||
                 tokenizer_pre == "falcon3"   ||
-                tokenizer_pre == "pixtral") {
+                tokenizer_pre == "falcon-h1" ||
+                tokenizer_pre == "pixtral"   ||
+                tokenizer_pre == "midm-2.0"  ||
+                tokenizer_pre == "lfm2") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
             ignore_merges = true;
             add_bos = true;
@@ -1554,7 +1558,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "jina-de"    ||
                 tokenizer_pre == "gigachat"   ||
                 tokenizer_pre == "jina-v2-es" ||
-                tokenizer_pre == "jina-v2-de") {
+                tokenizer_pre == "jina-v2-de" ||
+                tokenizer_pre == "a.x-4.0") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
         } else if (
             tokenizer_pre == "jina-v1-en" ||
@@ -1656,6 +1661,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "seed-coder") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
             clean_spaces = false;
+        } else if (
+                tokenizer_pre == "hunyuan") {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
+            clean_spaces = false;
         } else {
             throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
         }
@@ -1839,6 +1848,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 || t.first == "<EOT>"
                 || t.first == "_<EOT>"
                 || t.first == "<end▁of▁sentence>" // DeepSeek
+                || t.first == "<end_of_utterance>" // smoldocling
                 ) {
             special_eot_id = t.second;
             if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -1998,6 +2008,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                || t.first == "<EOT>"
                || t.first == "_<EOT>"
                || t.first == "<|end_of_text|>"
+               || t.first == "<end_of_utterance>" // smoldocling
                ) {
            special_eog_ids.insert(t.second);
            if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {

View File

@@ -6,6 +6,47 @@
 #include <vector>
 #include <memory>

+// pre-tokenization types
+enum llama_vocab_pre_type {
+    LLAMA_VOCAB_PRE_TYPE_DEFAULT        = 0,
+    LLAMA_VOCAB_PRE_TYPE_LLAMA3         = 1,
+    LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM   = 2,
+    LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
+    LLAMA_VOCAB_PRE_TYPE_FALCON         = 4,
+    LLAMA_VOCAB_PRE_TYPE_MPT            = 5,
+    LLAMA_VOCAB_PRE_TYPE_STARCODER      = 6,
+    LLAMA_VOCAB_PRE_TYPE_GPT2           = 7,
+    LLAMA_VOCAB_PRE_TYPE_REFACT         = 8,
+    LLAMA_VOCAB_PRE_TYPE_COMMAND_R      = 9,
+    LLAMA_VOCAB_PRE_TYPE_STABLELM2      = 10,
+    LLAMA_VOCAB_PRE_TYPE_QWEN2          = 11,
+    LLAMA_VOCAB_PRE_TYPE_OLMO           = 12,
+    LLAMA_VOCAB_PRE_TYPE_DBRX           = 13,
+    LLAMA_VOCAB_PRE_TYPE_SMAUG          = 14,
+    LLAMA_VOCAB_PRE_TYPE_PORO           = 15,
+    LLAMA_VOCAB_PRE_TYPE_CHATGLM3       = 16,
+    LLAMA_VOCAB_PRE_TYPE_CHATGLM4       = 17,
+    LLAMA_VOCAB_PRE_TYPE_VIKING         = 18,
+    LLAMA_VOCAB_PRE_TYPE_JAIS           = 19,
+    LLAMA_VOCAB_PRE_TYPE_TEKKEN         = 20,
+    LLAMA_VOCAB_PRE_TYPE_SMOLLM         = 21,
+    LLAMA_VOCAB_PRE_TYPE_CODESHELL      = 22,
+    LLAMA_VOCAB_PRE_TYPE_BLOOM          = 23,
+    LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH   = 24,
+    LLAMA_VOCAB_PRE_TYPE_EXAONE         = 25,
+    LLAMA_VOCAB_PRE_TYPE_CHAMELEON      = 26,
+    LLAMA_VOCAB_PRE_TYPE_MINERVA        = 27,
+    LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM  = 28,
+    LLAMA_VOCAB_PRE_TYPE_GPT4O          = 29,
+    LLAMA_VOCAB_PRE_TYPE_SUPERBPE       = 30,
+    LLAMA_VOCAB_PRE_TYPE_TRILLION       = 31,
+    LLAMA_VOCAB_PRE_TYPE_BAILINGMOE     = 32,
+    LLAMA_VOCAB_PRE_TYPE_LLAMA4         = 33,
+    LLAMA_VOCAB_PRE_TYPE_PIXTRAL        = 34,
+    LLAMA_VOCAB_PRE_TYPE_SEED_CODER     = 35,
+    LLAMA_VOCAB_PRE_TYPE_HUNYUAN        = 36,
+};
+
 struct LLM_KV;
 struct llama_model_loader;

View File

@ -79,46 +79,6 @@ extern "C" {
LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
}; };
// pre-tokenization types
enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
LLAMA_VOCAB_PRE_TYPE_MPT = 5,
LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10,
LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11,
LLAMA_VOCAB_PRE_TYPE_OLMO = 12,
LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
LLAMA_VOCAB_PRE_TYPE_PORO = 15,
LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
};
enum llama_rope_type { enum llama_rope_type {
LLAMA_ROPE_TYPE_NONE = -1, LLAMA_ROPE_TYPE_NONE = -1,
LLAMA_ROPE_TYPE_NORM = 0, LLAMA_ROPE_TYPE_NORM = 0,