whisper : alternative way to handle the external encoders

This commit is contained in:
Georgi Gerganov 2024-02-12 16:32:26 +02:00
parent 74c260fe34
commit f25edade2b
No known key found for this signature in database
GPG Key ID: BF970631944C16B7

View File

@ -1659,19 +1659,16 @@ static struct ggml_cgraph * whisper_build_graph_conv(
ggml_set_name(cur, "embd_conv");
wstate.embd_conv = cur;
} else {
// keep the "mel" tensor alive - we will use it to store the input data for the external encoders
// TODO: is there a better way to do this
mel = ggml_scale(ctx0, mel, 1.0f);
ggml_build_forward_expand(gf, mel);
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
// transform the "mel" tensor to "embd_enc" via a sequence of ggml ops
// these are not actually executed when using external encoder
// necessary only to prepare tensors with the appropriate memory sizes
cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1); // (conv)
cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1); // (conv)
cur = ggml_add(ctx0, model.e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur))); // (cross)
ggml_set_name(cur, "embd_enc");
ggml_set_output(cur);
wstate.embd_enc = cur;
// TODO: without this op, the "embd_enc" tensor ends up being not allocated
// is there a better fix?
cur = ggml_scale(ctx0, cur, 1.0f);
wstate.embd_enc = cur;
}
ggml_build_forward_expand(gf, cur);
@ -1702,14 +1699,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
//ggml_allocr * alloc = wstate.alloc_encode.alloc;
//struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_ctx, n_state);
//ggml_allocr_alloc(alloc, cur);
//if (!ggml_allocr_is_measure(alloc)) {
// ggml_backend_tensor_copy(wstate.embd_conv, cur);
//}
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
@ -1951,14 +1940,6 @@ static struct ggml_cgraph * whisper_build_graph_cross(
ggml_cgraph * gf = ggml_new_graph(ctx0);
//ggml_allocr * alloc = wstate.alloc_cross.alloc;
//struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
//ggml_allocr_alloc(alloc, cur);
//if (!ggml_allocr_is_measure(alloc)) {
// ggml_backend_tensor_copy(wstate.embd_enc, cur);
//}
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
const float Kscale = pow(float(n_state) / n_head, -0.25);