whisper : remove ggml_repeat in the encoder

This commit is contained in:
Georgi Gerganov
2023-09-12 20:34:32 +03:00
parent cd476375b4
commit ec9a7db74c

View File

@ -1530,10 +1530,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
// cur = ln_0_w*cur + ln_0_b // cur = ln_0_w*cur + ln_0_b
cur = ggml_add(ctx0, cur = ggml_add(ctx0,
ggml_mul(ctx0, ggml_mul(ctx0, cur, layer.attn_ln_0_w),
ggml_repeat(ctx0, layer.attn_ln_0_w, cur), layer.attn_ln_0_b);
cur),
ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
} }
// self-attention // self-attention
@ -1542,11 +1540,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
layer.attn_q_w, layer.attn_q_w,
cur); cur);
Qcur = ggml_add(ctx0, Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
ggml_repeat(ctx0,
layer.attn_q_b,
Qcur),
Qcur);
//Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
@ -1561,11 +1555,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
layer.attn_v_w, layer.attn_v_w,
cur); cur);
Vcur = ggml_add(ctx0, Vcur = ggml_add(ctx0, Vcur, layer.attn_v_b);
ggml_repeat(ctx0,
layer.attn_v_b,
Vcur),
Vcur);
// ------ // ------
@ -1641,9 +1631,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
layer.attn_ln_1_w, layer.attn_ln_1_w,
cur); cur);
cur = ggml_add(ctx0, cur = ggml_add(ctx0, cur, layer.attn_ln_1_b);
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
cur);
} }
// add the input // add the input
@ -1659,10 +1647,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
// cur = mlp_ln_w*cur + mlp_ln_b // cur = mlp_ln_w*cur + mlp_ln_b
cur = ggml_add(ctx0, cur = ggml_add(ctx0,
ggml_mul(ctx0, ggml_mul(ctx0, cur, layer.mlp_ln_w),
ggml_repeat(ctx0, layer.mlp_ln_w, cur), layer.mlp_ln_b);
cur),
ggml_repeat(ctx0, layer.mlp_ln_b, cur));
} }
#ifdef WHISPER_USE_FLASH_FF #ifdef WHISPER_USE_FLASH_FF
@ -1675,9 +1661,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
layer.mlp_0_w, layer.mlp_0_w,
cur); cur);
cur = ggml_add(ctx0, cur = ggml_add(ctx0, cur, layer.mlp_0_b);
ggml_repeat(ctx0, layer.mlp_0_b, cur),
cur);
// GELU activation // GELU activation
cur = ggml_gelu(ctx0, cur); cur = ggml_gelu(ctx0, cur);
@ -1687,9 +1671,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
layer.mlp_1_w, layer.mlp_1_w,
cur); cur);
cur = ggml_add(ctx0, cur = ggml_add(ctx0, cur, layer.mlp_1_b);
ggml_repeat(ctx0, layer.mlp_1_b, cur),
cur);
#endif #endif
} }
@ -1704,10 +1686,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
// cur = ln_f_g*cur + ln_f_b // cur = ln_f_g*cur + ln_f_b
cur = ggml_add(ctx0, cur = ggml_add(ctx0,
ggml_mul(ctx0, ggml_mul(ctx0, cur, model.e_ln_w),
ggml_repeat(ctx0, model.e_ln_w, cur), model.e_ln_b);
cur),
ggml_repeat(ctx0, model.e_ln_b, cur));
} }
} }
#ifdef WHISPER_USE_COREML #ifdef WHISPER_USE_COREML