From bed5ad69dd6a45d9be46a1d8d09e760f438aa783 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sun, 10 Sep 2023 19:50:34 +0300
Subject: [PATCH] whisper : allocate encoder and decoder using ggml-alloc

---
 whisper.cpp | 100 +++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 88 insertions(+), 12 deletions(-)

diff --git a/whisper.cpp b/whisper.cpp
index 17785739..91818115 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -606,11 +606,17 @@ struct whisper_state {
     // memory buffers used by encode / decode contexts
     std::vector<uint8_t> buf_compute;
 
+    // ggml-alloc
     std::vector<uint8_t> buf_encode;
+    std::vector<uint8_t> buf_encode_post;
     std::vector<uint8_t> buf_decode;
 
-    ggml_allocr * alloc_encode = NULL;
-    ggml_allocr * alloc_decode = NULL;
+    ggml_allocr * alloc_encode      = NULL;
+    ggml_allocr * alloc_encode_post = NULL;
+    ggml_allocr * alloc_decode      = NULL;
+
+    // result of the encoder
+    struct ggml_tensor * embd_enc = NULL;
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
@@ -701,7 +707,7 @@ static bool kv_cache_init(
     const int64_t n_mem      = n_text_layer*n_ctx;
     const int64_t n_elements = n_text_state*n_mem;
 
-    const size_t mem_bytes = ggml_type_size(wtype)*n_elements;
+    const size_t mem_bytes = 2*(ggml_type_size(wtype)*n_elements + ggml_tensor_overhead());
 
     cache.buf.resize(mem_bytes);
 
@@ -1385,6 +1391,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     const int n_layer = hparams.n_audio_layer;
 
     const int n_mels = hparams.n_mels;
+    assert(mel_inp.n_mel == n_mels);
 
     struct ggml_init_params params = {
@@ -1397,9 +1404,11 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
+    ggml_allocr * alloc = wstate.alloc_encode;
+
     struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
     assert(mel->type == GGML_TYPE_F32);
-    {
+    if (!ggml_allocr_is_measure(alloc)) {
         float * dst = (float *) mel->data;
         memset(dst, 0, ggml_nbytes(mel));
@@ -1689,6 +1698,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     }
 #endif
 
+    wstate.embd_enc = cur;
+
     ////////////////////////////////////////////////////////////////////////////
 
     //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
@@ -1706,8 +1717,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 // pre-compute cross-attention memory
 static struct ggml_cgraph * whisper_build_graph_encoder_post(
         whisper_context & wctx,
-        whisper_state & wstate,
-        struct ggml_tensor * embd_enc) {
+        whisper_state & wstate) {
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
@@ -1725,7 +1735,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder_post(
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    struct ggml_tensor * cur = embd_enc;
+    //ggml_allocr * alloc = wstate.alloc_encode_post;
+
+    struct ggml_tensor * cur = wstate.embd_enc;
 
     // TODO: hack to disconnect the encoded features from the previous graph
     cur->op = GGML_OP_NONE;
@@ -1826,12 +1838,18 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
+    ggml_allocr * alloc = wstate.alloc_decode;
+
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-    memcpy(embd->data, tokens, N*ggml_element_size(embd));
+    if (!ggml_allocr_is_measure(alloc)) {
+        memcpy(embd->data, tokens, N*ggml_element_size(embd));
+    }
 
     struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-    for (int i = 0; i < N; ++i) {
-        ((int32_t *) position->data)[i] = n_past + i;
+    if (!ggml_allocr_is_measure(alloc)) {
+        for (int i = 0; i < N; ++i) {
+            ((int32_t *) position->data)[i] = n_past + i;
+        }
     }
 
     // token encoding + position encoding
@@ -2637,8 +2655,54 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     state->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
 
     static const size_t tensor_alignment = 32;
-    state->alloc_encode = ggml_allocr_new_measure(tensor_alignment);
-    state->alloc_decode = ggml_allocr_new_measure(tensor_alignment);
+
+    state->alloc_encode      = ggml_allocr_new_measure(tensor_alignment);
+    state->alloc_encode_post = ggml_allocr_new_measure(tensor_alignment);
+    state->alloc_decode      = ggml_allocr_new_measure(tensor_alignment);
+
+    // encoder allocator
+    {
+        ggml_cgraph * gf = whisper_build_graph_encoder(*ctx, *state, 0);
+
+        const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_encode, gf) + tensor_alignment;
+        ggml_allocr_free(state->alloc_encode);
+
+        log("%s: compute buffer (encode) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
+
+        state->buf_encode.resize(alloc_size);
+        state->alloc_encode = ggml_allocr_new(state->buf_encode.data(), state->buf_encode.size(), tensor_alignment);
+    }
+
+    // encoder_post allocator
+    {
+        ggml_cgraph * gf = whisper_build_graph_encoder_post(*ctx, *state);
+
+        const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_encode_post, gf) + tensor_alignment;
+        ggml_allocr_free(state->alloc_encode_post);
+
+        log("%s: compute buffer (encode_post) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
+
+        state->buf_encode_post.resize(alloc_size);
+        state->alloc_encode_post = ggml_allocr_new(state->buf_encode_post.data(), state->buf_encode_post.size(), tensor_alignment);
+    }
+
+    // decoder allocator
+    {
+        const auto & hparams = ctx->model.hparams;
+
+        const int n_tokens = hparams.n_text_ctx/2;
+        const int n_past   = hparams.n_text_ctx/2; // TODO: double-check
+
+        ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], NULL, n_tokens, n_past);
+
+        const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_decode, gf) + tensor_alignment;
+        ggml_allocr_free(state->alloc_decode);
+
+        log("%s: compute buffer (decode) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
+
+        state->buf_decode.resize(alloc_size);
+        state->alloc_decode = ggml_allocr_new(state->buf_decode.data(), state->buf_decode.size(), tensor_alignment);
+    }
 
     state->rng = std::mt19937(0);
 
@@ -2862,6 +2926,18 @@ void whisper_free_state(struct whisper_state * state)
         }
 #endif
 
+        if (state->alloc_encode) {
+            ggml_allocr_free(state->alloc_encode);
+        }
+
+        if (state->alloc_encode_post) {
+            ggml_allocr_free(state->alloc_encode_post);
+        }
+
+        if (state->alloc_decode) {
+            ggml_allocr_free(state->alloc_decode);
+        }
+
        delete state;
     }
 }
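
All three compute graphs above follow the same two-pass ggml-alloc pattern:
build the graph once under a "measure" allocator to obtain the worst-case
compute-buffer size, free the measure allocator, then recreate the allocator
on top of a real buffer of exactly that size. Below is a minimal sketch of
that lifecycle, using only the ggml_allocr_* calls that appear in this patch;
build_graph() is a hypothetical stand-in for the whisper_build_graph_*
functions (which in the patch fetch the allocator from whisper_state rather
than taking it as a parameter):

    #include <cstdint>
    #include <vector>

    #include "ggml.h"
    #include "ggml-alloc.h"

    static const size_t tensor_alignment = 32;

    // hypothetical graph builder; it must skip all writes to tensor->data
    // while ggml_allocr_is_measure(alloc) returns true, because a measure
    // allocator hands out bookkeeping offsets, not real memory
    ggml_cgraph * build_graph(ggml_allocr * alloc);

    ggml_allocr * create_allocr(std::vector<uint8_t> & buf) {
        // pass 1: measure - run the allocation logic over the graph to
        // find the peak amount of memory its tensors require
        ggml_allocr * alloc = ggml_allocr_new_measure(tensor_alignment);

        ggml_cgraph * gf = build_graph(alloc);

        const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
        ggml_allocr_free(alloc);

        // pass 2: back a fresh allocator with a buffer of the measured
        // size; later graph builds place their tensor data in this buffer
        buf.resize(alloc_size);
        return ggml_allocr_new(buf.data(), buf.size(), tensor_alignment);
    }

This is also why the input-filling code (the mel spectrogram in the encoder,
the embd/position tokens in the decoder) is now guarded by
if (!ggml_allocr_is_measure(alloc)): during the measure pass the tensors have
no backing memory, so writing through ->data would be invalid. The
kv_cache_init change is related but separate: the cache buffer is presumably
sized as 2*(ggml_type_size(wtype)*n_elements + ggml_tensor_overhead()) to
cover both the k and v tensors along with their per-tensor metadata.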