talk-llama : sync llama.cpp
@@ -47,6 +47,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_475M: return "475M";
case LLM_TYPE_770M: return "770M";
case LLM_TYPE_780M: return "780M";
case LLM_TYPE_0_3B: return "0.3B";
case LLM_TYPE_0_5B: return "0.5B";
case LLM_TYPE_0_6B: return "0.6B";
case LLM_TYPE_1B: return "1B";
@@ -103,6 +104,8 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
case LLM_TYPE_30B_A3B: return "30B.A3B";
case LLM_TYPE_235B_A22B: return "235B.A22B";
case LLM_TYPE_E2B: return "E2B";
case LLM_TYPE_E4B: return "E4B";
default: return "?B";
}
}
@@ -1017,6 +1020,24 @@ void llama_model::load_hparams(llama_model_loader & ml) {
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
: 1.0f / std::sqrt(float(hparams.n_embd_head_k));
} break;
case LLM_ARCH_GEMMA3N:
{
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.set_swa_pattern(5);

hparams.rope_freq_base_train_swa = 10000.0f;
hparams.rope_freq_scale_train_swa = 1.0f;
hparams.f_attention_scale = 1.0f;

ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

switch (hparams.n_layer) {
case 30: type = LLM_TYPE_E2B; break;
case 35: type = LLM_TYPE_E4B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -1484,6 +1505,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_ERNIE4_5:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 18: type = LLM_TYPE_0_3B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
default: throw std::runtime_error("unsupported model architecture");
}

@@ -2950,6 +2979,62 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
case LLM_ARCH_GEMMA3N:
{
const int64_t n_altup = hparams.n_altup;
const int64_t laurel_rank = hparams.laurel_rank;
const int64_t n_embd_altup = hparams.n_embd_altup;

output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}

tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);

altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0);
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0);

output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);

for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);

layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);

layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);

// altup & laurel
layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0);
layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0);
layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0);
layer.altup_correct_coef = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0);
layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0);
layer.altup_predict_coef = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), {n_altup, n_altup * n_altup}, 0);
layer.altup_router = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0);
layer.altup_router_norm = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0);
layer.laurel_l = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0);
layer.laurel_r = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0);
layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
case LLM_ARCH_STARCODER2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4268,6 +4353,40 @@ bool llama_model::load_tensors(llama_model_loader & ml) {

layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));

layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_ERNIE4_5:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}

for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);

// optional bias tensors
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);

layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
@@ -8980,6 +9099,442 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
}
};

struct llm_build_gemma3n_iswa : public llm_graph_context {
const llama_model & model;
ggml_cgraph * gf;

const int64_t n_embd_head;
const int64_t n_embd_altup;
const int64_t n_altup;
const int i_altup_act;
const int n_layer_kv = 20; // number of layers having KV [KV_REUSE]
const int n_layer_sparsity = 10; // number of layers using activation sparsity
const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)

ggml_tensor * one; // containing single element 1.0f

llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf)
: llm_graph_context(params),
model(model),
gf(gf),
n_embd_head(model.hparams.n_embd_head_k),
n_embd_altup(model.hparams.n_embd_altup),
n_altup(model.hparams.n_altup),
i_altup_act(model.hparams.i_altup_act) {
ggml_tensor * cur;
ggml_tensor * inpL;

// TODO: remove this when ggml_scale_add is implemented
one = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
{
auto inp = std::make_unique<llm_graph_input_one>();
inp->one = one;
res->add_input(std::move(inp));
}

inpL = build_inp_embd(model.tok_embd);

// important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
if (ubatch.token) {
inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
cb(inpL, "inp_scaled", -1);
}

// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

// TODO: is causal == true correct? might need some changes
auto * inp_attn = build_attn_inp_kv_unified_iswa();

// inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());

// inpL now has only 1 altup, project it to the rest of the altups
// these "added" altups will be concat to the last dim of inpL
{
ggml_tensor * target_magnitude = calc_magnitude(inpL);
ggml_tensor * inp_repeated = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
ggml_tensor * altup_added = ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1]
ggml_tensor * new_magnitude = calc_magnitude(altup_added);
altup_added = ggml_div(ctx0,
ggml_mul(ctx0, altup_added, target_magnitude),
new_magnitude);
inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
cb(inpL, "inp_stacked", -1);
}

// inpL now has shape: [n_embd, n_tokens, n_altup]
// inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]

for (int il = 0; il < n_layer; ++il) {
// this block is made to be closely resemble Gemma3p5DecoderLayer on python code
const bool has_kv = (il < n_layer_kv);

const float freq_base_l = model.get_rope_freq_base (cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);

ggml_tensor * cur = inpL; // [n_embd, n_tokens, n_altup]
ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]

// predicted value will go through self-attention and laurel
ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
cur = active_prediction;
cb(cur, "active_prediction", il);

// norm
cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);

// laurel
ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]

// self-attention
if (has_kv) {
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);

ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);

ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);

cb(Qcur, "Qcur_normed", il);
cb(Kcur, "Kcur_normed", il);
cb(Vcur, "Vcur_normed", il);

Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);

Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);

cb(Qcur, "Qcur_pos", il);
cb(Kcur, "Kcur_pos", il);

cur = build_attn(inp_attn, gf,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
} else {
// no KV layers
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);

Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);

Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur_pos", il);

cur = build_attn(inp_attn, gf,
model.layers[il].wo, NULL,
Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
}

cur = build_norm(cur,
model.layers[il].attn_post_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_post_norm", il);

cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens]
cb(cur, "attn_gated", il);

ggml_tensor * attn_laurel = ggml_scale(ctx0,
ggml_add(ctx0, cur, laurel_out),
1.0f / sqrtf(2.0f)); // [n_embd, n_tokens]
cb(attn_laurel, "attn_laurel", il);

cur = build_norm(attn_laurel,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);

// feed-forward network
{
ggml_tensor * up_proj = build_lora_mm(model.layers[il].ffn_up, cur);
ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);

if (il < n_layer_sparsity) {
// apply activation sparsity
gate_proj = gaussian_topk(gate_proj);
}
gate_proj = ggml_gelu(ctx0, gate_proj);

cur = ggml_mul(ctx0, up_proj, gate_proj);
cur = build_lora_mm(model.layers[il].ffn_down, cur);
cb(cur, "ffn_out", il);
}

cur = build_norm(cur,
model.layers[il].ffn_post_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "ffn_post_norm", il);

ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens]
cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);

ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup]

ggml_tensor * first_prediction; // [n_embd, n_tokens]
{
first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
cb(first_prediction, "first_prediction_gated", il);
ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
cb(first_prediction, "first_prediction_scaled", il);

first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens]
first_prediction = build_norm(first_prediction,
model.layers[il].per_layer_post_norm, NULL,
LLM_NORM_RMS, il);
cb(first_prediction, "first_prediction_out", il);
}

// equivalent to python code: corrected_predictions[1:] += first_prediction
{
ggml_tensor * slice_first = view_2d_slice(corrected, 0);
ggml_tensor * slice_rest = ggml_view_3d(ctx0, corrected, n_embd, n_tokens, n_altup - 1,
ggml_row_size(corrected->type, n_embd),
ggml_row_size(corrected->type, n_embd*n_tokens),
n_embd*n_tokens*ggml_element_size(corrected));
ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1]
corrected = ggml_concat(ctx0, slice_first, tmp, 2); // [n_embd, n_tokens, n_altup]
}

cur = corrected; // [n_embd, n_tokens, n_altup]
cur = build_cvec(cur, il);
cb(cur, "l_out", il);

// input for next layer
inpL = cur;
}

cur = inpL; // [n_embd, n_tokens, n_altup]

// cur now has multiple altup(s), we want to merge them back to 1 altup
{
ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
// do a view to skip the first slice (active altup)
ggml_tensor * alt_slice = ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1,
ggml_row_size(cur->type, n_embd),
ggml_row_size(cur->type, n_embd*n_tokens),
n_embd*n_tokens*ggml_element_size(cur));
ggml_tensor * altup_unembd = ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1]
ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
altup_unembd = ggml_div(ctx0,
ggml_mul(ctx0, altup_unembd, target_magnitude),
new_magnitude);
cb(altup_unembd, "altup_unembd", -1);

// equivalent to torch.mean(hidden_states, dim=0)
cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
for (int i = 0; i < n_altup - 1; ++i) {
cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
}
cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
cb(cur, "unembd_merged", -1);
}

// cur now has shape: [n_embd, n_tokens]

// TODO: move this to right after the last KV layer
{
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}

cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);

cb(cur, "result_norm", -1);
res->t_embd = cur;

cur = build_lora_mm(model.output, cur);

{
// final logit soft-capping
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
cur = ggml_tanh(ctx0, cur);
cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
}

cb(cur, "result_output", -1);
res->t_logits = cur;

ggml_build_forward_expand(gf, cur);
}

ggml_tensor * calc_magnitude(ggml_tensor * x) {
return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
}

// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
ggml_tensor * view_2d_slice(ggml_tensor * x, int idx) {
GGML_ASSERT(idx < (int)x->ne[2]);
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1],
ggml_row_size(x->type, x->ne[0]),
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
}

// equivalent to get_per_layer_inputs() in python code
// output shape: [n_embd_altup, n_layer, n_tokens]
ggml_tensor * get_per_layer_inputs() {
auto inp = std::make_unique<llm_graph_input_embd>();
ggml_tensor * inp_per_layer;
if (ubatch.token) {
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
ggml_set_input(inp->tokens);
res->t_tokens = inp->tokens;
inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float)n_embd_altup));
cb(inp_per_layer, "inp_per_layer_selected", -1);
} else {
GGML_ABORT("TODO: support embd input");
}
res->add_input(std::move(inp));
return inp_per_layer;
}

// equivalent to project_per_layer_inputs() in python code
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
// output shape: [n_embd_altup, n_tokens, n_layer]
ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
const float per_layer_projection_scale = 1.0f / sqrtf((float)n_embd);
const float per_layer_input_scale = 1.0f / sqrtf(2.0f);

ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
per_layer_proj = build_norm(per_layer_proj,
model.per_layer_proj_norm, NULL,
LLM_NORM_RMS, -1); // [n_embd_altup, n_layer, n_tokens]
cb(per_layer_proj, "per_layer_proj", -1);

inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj);
inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
cb(inp_per_layer, "inp_per_layer", -1);

// permute to shape: [n_embd_altup, n_tokens, n_layer]
inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
return inp_per_layer;
}

// input cur shape: [n_altup, n_tokens]
// output shape: [n_altup, n_tokens]
ggml_tensor * laurel(ggml_tensor * cur, int il) {
ggml_tensor * tmp = cur;
tmp = build_lora_mm(model.layers[il].laurel_l, tmp);
tmp = build_lora_mm(model.layers[il].laurel_r, tmp);
tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il);
tmp = ggml_add(ctx0, tmp, cur);
cb(tmp, "laurel_out", il);
return tmp;
}

// input x shape: [n_embd, n_tokens]
// output shape: [n_embd, n_tokens]
ggml_tensor * gaussian_topk(ggml_tensor * x) {
ggml_tensor * mean = ggml_mean(ctx0, x);
ggml_tensor * std = ggml_sqrt(ctx0, ggml_scale(ctx0,
ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))),
1.0f / (float)(x->ne[0] - 1)
));
ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul));
return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x));
}

//
// altup functions
//

// equivalent to compute_router_modalities() in python code
// input x shape: [n_embd, n_tokens]
// output shape: [n_altup, n_tokens]
ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il) {
ggml_tensor * router_inputs = build_norm(x,
model.layers[il].altup_router_norm, NULL,
LLM_NORM_RMS, il);

// router_input_scale
router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float)n_embd);

ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs);
return ggml_tanh(ctx0, output); // [n_altup, n_tokens]
}

// input cur shape: [n_embd, n_tokens, n_altup]
// output shape: [n_embd, n_tokens, n_altup]
ggml_tensor * altup_predict(ggml_tensor * cur, int il) {
ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens]
ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
cb(modalities, "modalities", il);

ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
cb(all_coefs, "all_coefs", il);
// first dim now having n_altup^2 elements, we reshape it to 2D (so we end up with 3D tensor)
all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);

// permute to [n_altup, n_embd, n_tokens]
ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
ggml_tensor * predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_altup, n_embd, n_tokens]

// final shape must be the same as cur: [n_embd, n_tokens, n_altup]
predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3));
predictions = ggml_add(ctx0, predictions, cur);
cb(predictions, "predictions", il);

return predictions;
}

// input predictions shape: [n_embd, n_tokens, n_altup]
// input activated shape: [n_embd, n_tokens]
// output shape: [n_embd, n_tokens, n_altup]
ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) {
ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
cb(modalities, "modalities", il);

ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
cb(innovation, "innovation", il);

ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens]
all_coefs = ggml_add(ctx0, all_coefs, one);
cb(all_coefs, "all_coefs", il);
all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup]
all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup]

innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup]
corrected = ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup]
cb(corrected, "corrected", il);

return corrected;
}
};

// TODO: move up next to build_starcoder
struct llm_build_starcoder2 : public llm_graph_context {
llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
@@ -9171,9 +9726,9 @@ struct llm_build_mamba : public llm_graph_context {
ggml_tensor * cur,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

const auto kv_head = kv_state->get_head();
const auto kv_head = mctx_cur->get_head();

const int64_t d_conv = hparams.ssm_d_conv;
const int64_t d_inner = hparams.ssm_d_inner;
@@ -9191,8 +9746,8 @@ struct llm_build_mamba : public llm_graph_context {
GGML_ASSERT(ubatch.equal_seqs);
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

ggml_tensor * conv_states_all = kv_state->get_r_l(il);
ggml_tensor * ssm_states_all = kv_state->get_s_l(il);
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);

// (ab)using the KV cache to store the states
ggml_tensor * conv = build_rs(
@@ -11916,7 +12471,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
ggml_tensor * x_prev,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs;
@@ -11926,7 +12481,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
const auto n_head = n_embd / head_size;
const auto n_head_kv = hparams.n_head_kv(il);

const auto kv_head = kv_state->get_head();
const auto kv_head = mctx_cur->get_head();

const auto & layer = model.layers[il];

@@ -12038,7 +12593,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
}

ggml_tensor * wkv_state = build_rs(
inp, gf, kv_state->get_s_l(il),
inp, gf, mctx_cur->get_s_l(il),
hparams.n_embd_s(), n_seqs);

ggml_tensor * wkv_output;
@@ -12057,9 +12612,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
wkv_state,
ggml_view_1d(
ctx0,
kv_state->get_s_l(il),
mctx_cur->get_s_l(il),
hparams.n_embd_s() * n_seqs,
hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
)
)
);
@@ -12313,7 +12868,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
ggml_tensor *& first_layer_value,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs;
@@ -12322,7 +12877,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
const auto head_count = n_embd / head_size;
const auto n_seq_tokens = ubatch.n_seq_tokens;

const auto kv_head = kv_state->get_head();
const auto kv_head = mctx_cur->get_head();

const auto & layer = model.layers[il];

@@ -12393,7 +12948,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);

ggml_tensor * wkv_state = build_rs(
inp, gf, kv_state->get_s_l(il),
inp, gf, mctx_cur->get_s_l(il),
hparams.n_embd_s(), n_seqs);

ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@@ -12407,9 +12962,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
wkv_state,
ggml_view_1d(
ctx0,
kv_state->get_s_l(il),
mctx_cur->get_s_l(il),
hparams.n_embd_s() * n_seqs,
hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
)
)
);
@@ -13613,6 +14168,136 @@ struct llm_build_dots1 : public llm_graph_context {
}
};

struct llm_build_ernie4_5 : public llm_graph_context {
llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;

GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);

ggml_tensor * cur;
ggml_tensor * inpL;

inpL = build_inp_embd(model.tok_embd);

// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();

for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;

// norm
{
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
}

// self-attention
{
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}

ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}

ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);

Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);

cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);

cur = build_attn(inp_attn, gf,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
}

if (il == n_layer - 1) {
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}

ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);

// feed-forward network
{
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);

cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
}

cur = ggml_add(ctx0, cur, ffn_inp);

cur = build_cvec(cur, il);
cb(cur, "l_out", il);

// input for next layer
inpL = cur;
}

cur = inpL;

cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);

cb(cur, "result_norm", -1);
res->t_embd = cur;

// lm_head
cur = build_lora_mm(model.output, cur);

cb(cur, "result_output", -1);
res->t_logits = cur;

ggml_build_forward_expand(gf, cur);
}
};

struct llm_build_arcee : public llm_graph_context {
llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -13974,6 +14659,10 @@ llm_graph_result_ptr llama_model::build_graph(
{
llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
} break;
case LLM_ARCH_GEMMA3N:
{
llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params, gf);
} break;
case LLM_ARCH_STARCODER2:
{
llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
@@ -14119,6 +14808,10 @@ llm_graph_result_ptr llama_model::build_graph(
{
llm = std::make_unique<llm_build_arcee>(*this, params, gf);
} break;
case LLM_ARCH_ERNIE4_5:
{
llm = std::make_unique<llm_build_ernie4_5>(*this, params, gf);
} break;
default:
GGML_ABORT("fatal error");
}
@@ -14270,6 +14963,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_BAILINGMOE:
case LLM_ARCH_NEO_BERT:
case LLM_ARCH_ARCEE:
case LLM_ARCH_ERNIE4_5:
return LLAMA_ROPE_TYPE_NORM;

// the pairs of head values are offset by n_rot/2
@@ -14295,6 +14989,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
case LLM_ARCH_GEMMA3:
case LLM_ARCH_GEMMA3N:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:
@@ -14377,7 +15072,7 @@ const char * llama_model_chat_template(const llama_model * model, const char * n
// do not extend this list unless absolutely necessary
// Mistral-Small-2503 does not have built-in chat template
llama_vocab_pre_type pre_type = model->vocab.get_pre_type();
if (pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
if (!name && pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
return "mistral-v7-tekken";
}