From b6f09669a2a09097a3c4e42f6caf0f53e83559c3 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 13 Sep 2023 12:51:52 +0300
Subject: [PATCH] whisper : factor out alloc init in a function

---
 whisper.cpp | 123 +++++++++++++++++++++------------------------------
 1 file changed, 49 insertions(+), 74 deletions(-)

diff --git a/whisper.cpp b/whisper.cpp
index 6fec19ee..8cada57d 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -604,6 +604,31 @@ struct whisper_allocr {
     std::vector<uint8_t> data;
 };

+static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
+    return allocr.meta.size() + allocr.data.size();
+}
+
+// measure the memory usage of a graph and prepare the allocr's internal data buffer
+static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct ggml_cgraph *()> && get_graph) {
+    const int tensor_alignment = 32;
+
+    auto & alloc = allocr.alloc;
+    auto & meta  = allocr.meta;
+    auto & data  = allocr.data;
+
+    meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+
+    alloc = ggml_allocr_new_measure(tensor_alignment);
+
+    const size_t alloc_size = ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
+
+    ggml_allocr_free(alloc);
+
+    data.resize(alloc_size);
+
+    alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
+}
+
 static void whisper_allocr_free(struct whisper_allocr & allocr) {
     if (allocr.alloc) {
         ggml_allocr_free(allocr.alloc);
@@ -2799,100 +2824,50 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
     state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);

-    static const size_t tensor_alignment = 32;
-
     // conv allocator
     {
-        auto & alloc = state->alloc_conv.alloc;
-        auto & meta  = state->alloc_conv.meta;
-        auto & data  = state->alloc_conv.data;
+        whisper_allocr_graph_init(state->alloc_conv,
+                [&]() {
+                    return whisper_build_graph_conv(*ctx, *state, 0);
+                });

-        meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
-
-        alloc = ggml_allocr_new_measure(tensor_alignment);
-
-        ggml_cgraph * gf = whisper_build_graph_conv(*ctx, *state, 0);
-
-        const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
-
-        ggml_allocr_free(alloc);
-
-        log("%s: compute buffer (conv)   = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
-
-        data.resize(alloc_size);
-        alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
+        log("%s: compute buffer (conv)   = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
     }

     // encoder allocator
     if (!whisper_encode_external(*state)) {
-        auto & alloc = state->alloc_encode.alloc;
-        auto & meta  = state->alloc_encode.meta;
-        auto & data  = state->alloc_encode.data;
+        whisper_allocr_graph_init(state->alloc_encode,
+                [&]() {
+                    return whisper_build_graph_encoder(*ctx, *state);
+                });

-        meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
-
-        alloc = ggml_allocr_new_measure(tensor_alignment);
-
-        ggml_cgraph * gf = whisper_build_graph_encoder(*ctx, *state);
-
-        const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
-
-        ggml_allocr_free(alloc);
-
-        log("%s: compute buffer (encode) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
-
-        data.resize(alloc_size);
-        alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
+        log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
     }

     // cross allocator
     {
-        auto & alloc = state->alloc_cross.alloc;
-        auto & meta  = state->alloc_cross.meta;
-        auto & data  = state->alloc_cross.data;
+        whisper_allocr_graph_init(state->alloc_cross,
+                [&]() {
+                    return whisper_build_graph_cross(*ctx, *state);
+                });

-        meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
-
-        alloc = ggml_allocr_new_measure(tensor_alignment);
-
-        ggml_cgraph * gf = whisper_build_graph_cross(*ctx, *state);
-
-        const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
-
-        ggml_allocr_free(alloc);
-
-        log("%s: compute buffer (cross)  = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
-
-        data.resize(alloc_size);
-        alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
+        log("%s: compute buffer (cross)  = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
     }

     // decoder allocator
     {
-        auto & alloc = state->alloc_decode.alloc;
-        auto & meta  = state->alloc_decode.meta;
-        auto & data  = state->alloc_decode.data;
+        whisper_allocr_graph_init(state->alloc_decode,
+                [&]() {
+                    const auto & hparams = ctx->model.hparams;

-        meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+                    // TODO: make sure this is the worst-case scenario
+                    const int n_tokens = hparams.n_text_ctx;
+                    const int n_past   = 0;

-        alloc = ggml_allocr_new_measure(tensor_alignment);
+                    return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
+                });

-        const auto & hparams = ctx->model.hparams;
-
-        // TODO: make sure this is the worst-case scenario
-        const int n_tokens = hparams.n_text_ctx;
-        const int n_past   = 0;
-
-        ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
-
-        const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
-
-        ggml_allocr_free(alloc);
-
-        log("%s: compute buffer (decode) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
-
-        data.resize(alloc_size);
-        alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
+        log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
     }

 #ifdef GGML_USE_METAL
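
The patch repeats one idiom at every call site it touches: run the graph builder once under a measure allocator to learn the peak buffer size, then allocate the real buffer and rebuild the allocator over it. As a rough illustration of that two-pass pattern, here is a minimal, self-contained C++ sketch. The fake_graph, measure_graph, fake_allocr and fake_allocr_graph_init names are hypothetical stand-ins invented for this sketch; the real code uses ggml_cgraph and the ggml_allocr API shown in the diff.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <vector>

// Hypothetical stand-in for ggml_cgraph, only to keep the sketch self-contained.
struct fake_graph {
    size_t peak_bytes; // pretend this is the peak memory the graph needs
};

// Plays the role of ggml_allocr_alloc_graph() in measure mode: walk the graph
// once and report how large the compute buffer must be.
static size_t measure_graph(const fake_graph & g) {
    return g.peak_bytes;
}

// Plays the role of whisper_allocr: a backing buffer sized by a dry run.
struct fake_allocr {
    std::vector<uint8_t> data;
};

// The factored-out idiom: take a callback that builds the graph, run it once
// to measure, then size the real buffer with some alignment slack (mirroring
// tensor_alignment in the patch).
static void fake_allocr_graph_init(fake_allocr & allocr, std::function<fake_graph()> && get_graph) {
    const size_t tensor_alignment = 32;
    const size_t alloc_size = measure_graph(get_graph()) + tensor_alignment;
    allocr.data.resize(alloc_size);
}

int main() {
    fake_allocr allocr;

    // Deferring graph construction into a lambda is what lets one helper
    // serve the conv, encoder, cross and decoder graphs alike.
    fake_allocr_graph_init(allocr, []() { return fake_graph{ 4096 }; });

    std::printf("compute buffer = %zu bytes\n", allocr.data.size());
    return 0;
}

Taking the builder as a std::function rvalue rather than a prebuilt graph is the design choice that makes the helper reusable: each call site wraps its own whisper_build_graph_* call in a lambda, and the helper invokes it exactly once during the measure pass.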