From af6f67b251dc78acdcd76d5f47df40ca454ec332 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sun, 10 Sep 2023 20:09:17 +0300
Subject: [PATCH] whisper : ggml-alloc is now supported

---
 whisper.cpp | 109 ++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 88 insertions(+), 21 deletions(-)

diff --git a/whisper.cpp b/whisper.cpp
index 91818115..6ceea676 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -120,6 +120,21 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 //#define WHISPER_USE_FLASH_FF
 #define WHISPER_MAX_DECODERS 16
 
+//
+// ggml helpers
+//
+
+static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+
+    if (plan.work_size > 0) {
+        buf.resize(plan.work_size);
+        plan.work_data = buf.data();
+    }
+
+    ggml_graph_compute(graph, &plan);
+}
+
 // available whisper models
 enum e_model {
     MODEL_UNKNOWN,
@@ -606,6 +621,9 @@ struct whisper_state {
     // memory buffers used by encode / decode contexts
     std::vector<uint8_t> buf_compute;
 
+    // reusable buffer for `struct ggml_graph_plan.work_data`
+    std::vector<uint8_t> work_buffer;
+
     // ggml-alloc
     std::vector<uint8_t> buf_encode;
     std::vector<uint8_t> buf_encode_post;
@@ -1407,6 +1425,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     ggml_allocr * alloc = wstate.alloc_encode;
 
     struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
+    ggml_allocr_alloc(alloc, mel);
+
     assert(mel->type == GGML_TYPE_F32);
     if (!ggml_allocr_is_measure(alloc)) {
         float * dst = (float *) mel->data;
@@ -1796,6 +1816,32 @@ static bool whisper_encode_internal(
         const int n_threads) {
     const int64_t t_start_us = ggml_time_us();
 
+    // encoder
+    {
+        auto & alloc = wstate.alloc_encode;
+
+        ggml_allocr_reset(alloc);
+
+        ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate, mel_offset);
+
+        ggml_allocr_alloc_graph(alloc, gf);
+
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+    }
+
+    // encoder_post
+    {
+        auto & alloc = wstate.alloc_encode_post;
+
+        ggml_allocr_reset(alloc);
+
+        ggml_cgraph * gf = whisper_build_graph_encoder_post(wctx, wstate);
+
+        ggml_allocr_alloc_graph(alloc, gf);
+
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+    }
+
     // ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
     wstate.t_encode_us += ggml_time_us() - t_start_us;
@@ -1841,11 +1887,15 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     ggml_allocr * alloc = wstate.alloc_decode;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    ggml_allocr_alloc(alloc, embd);
+
     if (!ggml_allocr_is_measure(alloc)) {
         memcpy(embd->data, tokens, N*ggml_element_size(embd));
     }
 
     struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    ggml_allocr_alloc(alloc, position);
+
     if (!ggml_allocr_is_measure(alloc)) {
         for (int i = 0; i < N; ++i) {
             ((int32_t *) position->data)[i] = n_past + i;
@@ -2162,33 +2212,51 @@ static bool whisper_decode_internal(
         const int n_tokens,
         const int n_past,
         const int n_threads) {
-    //const int64_t t_start_us = ggml_time_us();
+    const int64_t t_start_us = ggml_time_us();
 
-    //auto & logits_out = wstate.logits;
+    const auto & model   = wctx.model;
+    const auto & hparams = model.hparams;
 
-    //const int n_vocab = hparams.n_vocab;
+    const int n_vocab = hparams.n_vocab;
 
-    // ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    auto & logits_out = wstate.logits;
 
-    //// extract logits for all N tokens
-    ////logits_out.resize(N*n_vocab);
-    ////memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
+    struct ggml_tensor * logits;
 
-    //// extract logits only for the last token
-    //logits_out.resize(n_vocab);
-    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
+    // decoder
+    {
+        auto & alloc = wstate.alloc_decode;
 
-    //if (N > 1) {
-    //    //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
-    //    //        ggml_used_mem(ctx0)/1024.0/1024.0,
-    //    //        wstate.get_buf_max_mem(0)/1024.0/1024.0,
-    //    //        wstate.get_buf_max_mem(1)/1024.0/1024.0,
-    //    //        wstate.get_buf_max_mem(2)/1024.0/1024.0,
-    //    //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
-    //}
+        ggml_allocr_reset(alloc);
 
-    //wstate.t_decode_us += ggml_time_us() - t_start_us;
-    //wstate.n_decode++;
+        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
+
+        ggml_allocr_alloc_graph(alloc, gf);
+
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+
+        logits = gf->nodes[gf->n_nodes - 1];
+    }
+
+    // extract logits for all N tokens
+    //logits_out.resize(N*n_vocab);
+    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
+
+    // extract logits only for the last token
+    logits_out.resize(n_vocab);
+    memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
+
+    if (n_tokens > 1) {
+        //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
+        //        ggml_used_mem(ctx0)/1024.0/1024.0,
+        //        wstate.get_buf_max_mem(0)/1024.0/1024.0,
+        //        wstate.get_buf_max_mem(1)/1024.0/1024.0,
+        //        wstate.get_buf_max_mem(2)/1024.0/1024.0,
+        //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
+    }
+
+    wstate.t_decode_us += ggml_time_us() - t_start_us;
+    wstate.n_decode++;
 
     return true;
 }
@@ -2759,7 +2827,6 @@ int whisper_ctx_init_openvino_encoder(
 }
 
 struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
-
     log("%s: loading model from '%s'\n", __func__, path_model);
 
     auto fin = std::ifstream(path_model, std::ios::binary);
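
Note (not part of the patch): the hunks above use allocators (wstate.alloc_encode,
alloc_encode_post, alloc_decode) that are created outside this diff. Below is a
minimal sketch of how such an allocator is typically set up with the two-pass
ggml-alloc API the patch relies on. The helper name create_allocr_for_graph and
the alignment value are illustrative; only the ggml_allocr_* calls are the real API.

#include "ggml.h"
#include "ggml-alloc.h"

#include <cstdint>
#include <functional>
#include <vector>

static ggml_allocr * create_allocr_for_graph(
        std::vector<uint8_t> & buf,
        const std::function<ggml_cgraph *()> & build_graph) {
    static const size_t tensor_alignment = 32;

    // pass 1: a "measure" allocator walks the graph and only records the peak
    // memory required; ggml_allocr_is_measure() returns true during this pass,
    // which is why the graph builders above skip writing tensor data behind
    // that check - the tensors have no real backing memory yet
    ggml_allocr * alloc = ggml_allocr_new_measure(tensor_alignment);

    const size_t alloc_size = ggml_allocr_alloc_graph(alloc, build_graph()) + tensor_alignment;

    ggml_allocr_free(alloc);

    // pass 2: back a fresh allocator with a real buffer of the measured size;
    // each eval then calls ggml_allocr_reset() and re-allocates the graph
    // tensors into this same buffer, as the new encode/decode paths do
    buf.resize(alloc_size);

    return ggml_allocr_new(buf.data(), buf.size(), tensor_alignment);
}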
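
A second self-contained sketch shows what ggml_graph_compute_helper replaces:
with the ggml_graph_compute_with_ctx calls gone, every eval must plan the graph
and hand it scratch memory explicitly, which is what the reusable
wstate.work_buffer is for. The toy graph below uses only the ggml API of this
vintage (ggml_build_forward still returns a graph by value here).

#include "ggml.h"

#include <cstdint>
#include <cstdio>
#include <vector>

// the helper introduced by the patch, copied verbatim
static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);

    if (plan.work_size > 0) {
        buf.resize(plan.work_size);
        plan.work_data = buf.data();
    }

    ggml_graph_compute(graph, &plan);
}

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };

    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * c = ggml_add(ctx, a, b);

    for (int i = 0; i < 4; ++i) {
        ((float *) a->data)[i] = 1.0f*i;
        ((float *) b->data)[i] = 2.0f*i;
    }

    struct ggml_cgraph gf = ggml_build_forward(c);

    // reused across evals, exactly like wstate.work_buffer
    std::vector<uint8_t> work_buffer;

    ggml_graph_compute_helper(work_buffer, &gf, /*n_threads =*/ 4);

    printf("c[3] = %f\n", ((float *) c->data)[3]); // 1*3 + 2*3 = 9

    ggml_free(ctx);

    return 0;
}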