whisper : remove ggml_repeat in the encoder

This commit is contained in:
Georgi Gerganov
2023-09-12 20:34:32 +03:00
parent cd476375b4
commit ec9a7db74c

View File

@ -1530,10 +1530,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
// cur = ln_0_w*cur + ln_0_b
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
cur),
ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
ggml_mul(ctx0, cur, layer.attn_ln_0_w),
layer.attn_ln_0_b);
}
// self-attention
@ -1542,11 +1540,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
layer.attn_q_w,
cur);
Qcur = ggml_add(ctx0,
ggml_repeat(ctx0,
layer.attn_q_b,
Qcur),
Qcur);
Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
//Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
@ -1561,11 +1555,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
layer.attn_v_w,
cur);
Vcur = ggml_add(ctx0,
ggml_repeat(ctx0,
layer.attn_v_b,
Vcur),
Vcur);
Vcur = ggml_add(ctx0, Vcur, layer.attn_v_b);
// ------
@ -1641,9 +1631,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
layer.attn_ln_1_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
cur);
cur = ggml_add(ctx0, cur, layer.attn_ln_1_b);
}
// add the input
@ -1659,10 +1647,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
// cur = mlp_ln_w*cur + mlp_ln_b
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, layer.mlp_ln_w, cur),
cur),
ggml_repeat(ctx0, layer.mlp_ln_b, cur));
ggml_mul(ctx0, cur, layer.mlp_ln_w),
layer.mlp_ln_b);
}
#ifdef WHISPER_USE_FLASH_FF
@ -1675,9 +1661,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
layer.mlp_0_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.mlp_0_b, cur),
cur);
cur = ggml_add(ctx0, cur, layer.mlp_0_b);
// GELU activation
cur = ggml_gelu(ctx0, cur);
@ -1687,9 +1671,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
layer.mlp_1_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.mlp_1_b, cur),
cur);
cur = ggml_add(ctx0, cur, layer.mlp_1_b);
#endif
}
@ -1704,10 +1686,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
// cur = ln_f_g*cur + ln_f_b
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.e_ln_w, cur),
cur),
ggml_repeat(ctx0, model.e_ln_b, cur));
ggml_mul(ctx0, cur, model.e_ln_w),
model.e_ln_b);
}
}
#ifdef WHISPER_USE_COREML