From 254b687239e6fde1be03c40a3999fb4968414977 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 13 Sep 2023 11:58:19 +0300
Subject: [PATCH] whisper : add whisper_allocr to wrap ggml_allocr

---
Reviewer notes (this area is ignored by `git am`):

 * Introduces `struct whisper_allocr`, which groups a `ggml_allocr *` with the
   `meta` and `data` buffers that were previously three separate
   `whisper_state` fields per graph (conv / encode / cross / decode).
 * Adds `whisper_allocr_free()`, which frees the allocator if it is set and
   then resets the pointer to nullptr.
 * `whisper_free_state()` now calls `whisper_allocr_free()` for all four
   allocators; the code removed by the last hunk released only the encode,
   cross and decode allocators.
 * Drops the "TODO: handle multiple decoders" comment next to the Metal
   "kv_self_0" buffer registration.
 * NOTE(review): this copy was recovered from a whitespace-collapsed paste.
   Line breaks were rebuilt from the hunk headers; indentation, column
   alignment and the `<uint8_t>` template arguments are reconstructed and
   should be confirmed against the tree (apply with `--ignore-whitespace` if
   needed). The author e-mail address was not recoverable.

 whisper.cpp | 128 +++++++++++++++++++++++++---------------------------
 1 file changed, 61 insertions(+), 67 deletions(-)

diff --git a/whisper.cpp b/whisper.cpp
index fc70bf95..6fec19ee 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -596,6 +596,21 @@ struct kv_buf {
     std::vector<uint8_t> v;
 };
 
+// ggml_allocr wrapper for whisper usage
+struct whisper_allocr {
+    ggml_allocr * alloc = nullptr;
+
+    std::vector<uint8_t> meta;
+    std::vector<uint8_t> data;
+};
+
+static void whisper_allocr_free(struct whisper_allocr & allocr) {
+    if (allocr.alloc) {
+        ggml_allocr_free(allocr.alloc);
+        allocr.alloc = nullptr;
+    }
+}
+
 struct whisper_state {
     int64_t t_sample_us = 0;
     int64_t t_encode_us = 0;
@@ -622,25 +637,12 @@ struct whisper_state {
     std::vector<uint8_t> work_buffer;
 
     // ggml-alloc:
-    // - stores meta info about the intermediate tensors into the `meta_*` buffers
-    // - stores the actual tensor data into the `data_*` buffers
-
-    ggml_allocr * alloc_conv   = nullptr;
-    ggml_allocr * alloc_encode = nullptr;
-    ggml_allocr * alloc_cross  = nullptr;
-    ggml_allocr * alloc_decode = nullptr;
-
-    // meta data
-    std::vector<uint8_t> meta_conv;
-    std::vector<uint8_t> meta_encode;
-    std::vector<uint8_t> meta_cross;
-    std::vector<uint8_t> meta_decode;
-
-    // tensor data
-    std::vector<uint8_t> data_conv;
-    std::vector<uint8_t> data_encode;
-    std::vector<uint8_t> data_cross;
-    std::vector<uint8_t> data_decode;
+    // - stores meta info about the intermediate tensors into the `meta` buffers
+    // - stores the actual tensor data into the `data` buffers
+    whisper_allocr alloc_conv;
+    whisper_allocr alloc_encode;
+    whisper_allocr alloc_cross;
+    whisper_allocr alloc_decode;
 
     // result of the encoder
     struct ggml_tensor * embd_conv = nullptr;
@@ -1437,8 +1439,8 @@ static struct ggml_cgraph * whisper_build_graph_conv(
     const int n_mels = hparams.n_mels;
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.meta_conv.size(),
-        /*.mem_buffer =*/ wstate.meta_conv.data(),
+        /*.mem_size   =*/ wstate.alloc_conv.meta.size(),
+        /*.mem_buffer =*/ wstate.alloc_conv.meta.data(),
         /*.no_alloc   =*/ true,
     };
 
@@ -1446,7 +1448,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    ggml_allocr * alloc = wstate.alloc_conv;
+    ggml_allocr * alloc = wstate.alloc_conv.alloc;
 
     struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
     ggml_allocr_alloc(alloc, mel);
@@ -1533,8 +1535,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     const int n_layer = hparams.n_audio_layer;
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.meta_encode.size(),
-        /*.mem_buffer =*/ wstate.meta_encode.data(),
+        /*.mem_size   =*/ wstate.alloc_encode.meta.size(),
+        /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
         /*.no_alloc   =*/ true,
     };
 
@@ -1542,7 +1544,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    ggml_allocr * alloc = wstate.alloc_encode;
+    ggml_allocr * alloc = wstate.alloc_encode.alloc;
 
     struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
     ggml_allocr_alloc(alloc, KQscale);
@@ -1782,8 +1784,8 @@ static struct ggml_cgraph * whisper_build_graph_cross(
     const int n_head = hparams.n_audio_head;
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.meta_cross.size(),
-        /*.mem_buffer =*/ wstate.meta_cross.data(),
+        /*.mem_size   =*/ wstate.alloc_cross.meta.size(),
+        /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
         /*.no_alloc   =*/ true,
     };
 
@@ -1791,7 +1793,7 @@ static struct ggml_cgraph * whisper_build_graph_cross(
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    ggml_allocr * alloc = wstate.alloc_cross;
+    ggml_allocr * alloc = wstate.alloc_cross.alloc;
 
     struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
 
@@ -1859,7 +1861,7 @@ static bool whisper_encode_internal(
 
     // conv
     {
-        auto & alloc = wstate.alloc_conv;
+        auto & alloc = wstate.alloc_conv.alloc;
 
         ggml_allocr_reset(alloc);
 
@@ -1874,7 +1876,7 @@ static bool whisper_encode_internal(
 
     // encoder
     if (!whisper_encode_external(wstate)) {
-        auto & alloc = wstate.alloc_encode;
+        auto & alloc = wstate.alloc_encode.alloc;
 
         ggml_allocr_reset(alloc);
 
@@ -1896,7 +1898,7 @@ static bool whisper_encode_internal(
 
     // cross
     {
-        auto & alloc = wstate.alloc_cross;
+        auto & alloc = wstate.alloc_cross.alloc;
 
         ggml_allocr_reset(alloc);
 
@@ -1949,8 +1951,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.meta_decode.size(),
-        /*.mem_buffer =*/ wstate.meta_decode.data(),
+        /*.mem_size   =*/ wstate.alloc_decode.meta.size(),
+        /*.mem_buffer =*/ wstate.alloc_decode.meta.data(),
         /*.no_alloc   =*/ true,
     };
 
@@ -1958,7 +1960,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    ggml_allocr * alloc = wstate.alloc_decode;
+    ggml_allocr * alloc = wstate.alloc_decode.alloc;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     ggml_allocr_alloc(alloc, embd);
@@ -2292,7 +2294,7 @@ static bool whisper_decode_internal(
 
     // decoder
     {
-        auto & alloc = wstate.alloc_decode;
+        auto & alloc = wstate.alloc_decode.alloc;
 
         ggml_allocr_reset(alloc);
 
@@ -2801,9 +2803,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // conv allocator
     {
-        auto & alloc = state->alloc_conv;
-        auto & meta  = state->meta_conv;
-        auto & data  = state->data_conv;
+        auto & alloc = state->alloc_conv.alloc;
+        auto & meta  = state->alloc_conv.meta;
+        auto & data  = state->alloc_conv.data;
 
         meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
 
@@ -2823,9 +2825,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // encoder allocator
     if (!whisper_encode_external(*state)) {
-        auto & alloc = state->alloc_encode;
-        auto & meta  = state->meta_encode;
-        auto & data  = state->data_encode;
+        auto & alloc = state->alloc_encode.alloc;
+        auto & meta  = state->alloc_encode.meta;
+        auto & data  = state->alloc_encode.data;
 
         meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
 
@@ -2845,9 +2847,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // cross allocator
     {
-        auto & alloc = state->alloc_cross;
-        auto & meta  = state->meta_cross;
-        auto & data  = state->data_cross;
+        auto & alloc = state->alloc_cross.alloc;
+        auto & meta  = state->alloc_cross.meta;
+        auto & data  = state->alloc_cross.data;
 
         meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
 
@@ -2867,9 +2869,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // decoder allocator
     {
-        auto & alloc = state->alloc_decode;
-        auto & meta  = state->meta_decode;
-        auto & data  = state->data_decode;
+        auto & alloc = state->alloc_decode.alloc;
+        auto & meta  = state->alloc_decode.meta;
+        auto & data  = state->alloc_decode.data;
 
         meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
 
@@ -2933,19 +2935,18 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
 
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv",   state->meta_conv.data(),   state->meta_conv.size(),   0));
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->meta_encode.data(), state->meta_encode.size(), 0));
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross",  state->meta_cross.data(),  state->meta_cross.size(),  0));
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->meta_decode.data(), state->meta_decode.size(), 0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv",   state->alloc_conv.meta.data(),   state->alloc_conv.meta.size(),   0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross",  state->alloc_cross.meta.data(),  state->alloc_cross.meta.size(),  0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
 
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv",   state->data_conv.data(),   state->data_conv.size(),   0));
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->data_encode.data(), state->data_encode.size(), 0));
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross",  state->data_cross.data(),  state->data_cross.size(),  0));
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->data_decode.data(), state->data_decode.size(), 0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv",   state->alloc_conv.data.data(),   state->alloc_conv.data.size(),   0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross",  state->alloc_cross.data.data(),  state->alloc_cross.data.size(),  0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
 
     WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
 
-    // TODO: handle multiple decoders
     WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
 #undef WHISPER_METAL_CHECK_BUF
 #endif
@@ -3171,17 +3172,10 @@ void whisper_free_state(struct whisper_state * state)
         }
 #endif
 
-        if (state->alloc_encode) {
-            ggml_allocr_free(state->alloc_encode);
-        }
-
-        if (state->alloc_cross) {
-            ggml_allocr_free(state->alloc_cross);
-        }
-
-        if (state->alloc_decode) {
-            ggml_allocr_free(state->alloc_decode);
-        }
+        whisper_allocr_free(state->alloc_conv);
+        whisper_allocr_free(state->alloc_decode);
+        whisper_allocr_free(state->alloc_cross);
+        whisper_allocr_free(state->alloc_encode);
 
         delete state;
     }