mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-11-07 08:34:37 +01:00
whisper : allocate encoder results in dedicated buffer
This commit is contained in:
parent
de4d067f1e
commit
c6c94de43a
43
whisper.cpp
43
whisper.cpp
@ -793,6 +793,9 @@ struct whisper_state {
|
||||
struct ggml_tensor * embd_conv = nullptr;
|
||||
struct ggml_tensor * embd_enc = nullptr;
|
||||
|
||||
ggml_context * ctx_embd = nullptr;
|
||||
ggml_backend_buffer_t buffer_embd = nullptr;
|
||||
|
||||
// helpers for GPU offloading
|
||||
std::vector<float> inp_mel;
|
||||
std::vector<float> inp_mask;
|
||||
@ -1669,15 +1672,9 @@ static struct ggml_cgraph * whisper_build_graph_conv(
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
}
|
||||
|
||||
ggml_set_name(cur, "embd_conv");
|
||||
wstate.embd_conv = cur;
|
||||
cur = ggml_cpy(ctx0, cur, wstate.embd_conv);
|
||||
} else {
|
||||
ggml_build_forward_expand(gf, mel);
|
||||
|
||||
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
|
||||
|
||||
ggml_set_name(cur, "embd_enc");
|
||||
wstate.embd_enc = cur;
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
@ -1708,7 +1705,10 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
|
||||
|
||||
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
|
||||
// TODO: this still triggers the assert:
|
||||
//struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
|
||||
|
||||
struct ggml_tensor * cur = wstate.embd_conv;
|
||||
|
||||
const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
|
||||
|
||||
@ -1908,9 +1908,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||
model.e_ln_b);
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
cur = ggml_cpy(ctx0, cur, wstate.embd_enc);
|
||||
|
||||
wstate.embd_enc = cur;
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
//ggml_graph_print(gf);
|
||||
|
||||
@ -1949,7 +1949,7 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
|
||||
struct ggml_tensor * cur = wstate.embd_enc;
|
||||
|
||||
const float Kscale = pow(float(n_state) / n_head, -0.25);
|
||||
|
||||
@ -3001,6 +3001,27 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
|
||||
}
|
||||
|
||||
// encoder results
|
||||
{
|
||||
ggml_init_params init_params = {
|
||||
/* .mem_size */ 2*ggml_tensor_overhead(),
|
||||
/* .mem_buffer */ nullptr,
|
||||
/* .no_alloc */ true,
|
||||
};
|
||||
state->ctx_embd = ggml_init(init_params);
|
||||
|
||||
state->embd_enc = ggml_new_tensor_2d(state->ctx_embd, GGML_TYPE_F32, ctx->model.hparams.n_audio_state, ctx->model.hparams.n_audio_ctx);
|
||||
state->embd_conv = ggml_new_tensor_2d(state->ctx_embd, GGML_TYPE_F32, ctx->model.hparams.n_audio_ctx, ctx->model.hparams.n_audio_state);
|
||||
|
||||
ggml_set_name(state->embd_enc, "embd_enc");
|
||||
ggml_set_name(state->embd_conv, "embd_conv");
|
||||
|
||||
state->buffer_embd = ggml_backend_alloc_ctx_tensors_from_buft(state->ctx_embd, ggml_backend_get_default_buffer_type(ctx->backend));
|
||||
|
||||
WHISPER_LOG_INFO("%s: %s enc results size = %.2f MiB\n", __func__,
|
||||
ggml_backend_buffer_name(state->buffer_embd), ggml_backend_buffer_get_size(state->buffer_embd) / 1e6);
|
||||
}
|
||||
|
||||
#ifdef WHISPER_USE_COREML
|
||||
const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user