From 949ab6328d0e492abc49072ff8c866580213c56e Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 10 Sep 2023 19:23:06 +0300
Subject: [PATCH] whisper : factor out graph builds

---
 Makefile    |   5 +
 whisper.cpp | 546 +++++++++++++++++-----------------------------------
 2 files changed, 182 insertions(+), 369 deletions(-)

diff --git a/Makefile b/Makefile
index fb0d8d5e..605d8d48 100644
--- a/Makefile
+++ b/Makefile
@@ -295,6 +295,11 @@ $(info )
 ggml.o: ggml.c ggml.h ggml-cuda.h
 	$(CC) $(CFLAGS) -c $< -o $@
 
+ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h
+	$(CC) $(CFLAGS) -c $< -o $@
+
+WHISPER_OBJ += ggml-alloc.o
+
 whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
diff --git a/whisper.cpp b/whisper.cpp
index e3a6900f..17785739 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -12,6 +12,7 @@
 #endif
 
 #include "ggml.h"
+#include "ggml-alloc.h"
 
 #include <algorithm>
 #include <cassert>
@@ -119,9 +120,6 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 //#define WHISPER_USE_FLASH_FF
 #define WHISPER_MAX_DECODERS 16
 
-#define WHISPER_USE_SCRATCH
-#define WHISPER_MAX_SCRATCH_BUFFERS 16
-
 // available whisper models
 enum e_model {
     MODEL_UNKNOWN,
@@ -236,38 +234,7 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
 
 static const size_t MB = 1ull*1024*1024;
 
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
-    { MODEL_TINY, 62ull*MB },
-    { MODEL_BASE, 80ull*MB },
-    { MODEL_SMALL, 120ull*MB },
-    { MODEL_MEDIUM, 158ull*MB },
-    { MODEL_LARGE, 198ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
-    { MODEL_TINY, 18ull*MB },
-    { MODEL_BASE, 24ull*MB },
-    { MODEL_SMALL, 36ull*MB },
-    { MODEL_MEDIUM, 48ull*MB },
-    { MODEL_LARGE, 60ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
-    { MODEL_TINY, 4ull*MB },
-    { MODEL_BASE, 4ull*MB },
-    { MODEL_SMALL, 6ull*MB },
-    { MODEL_MEDIUM, 7ull*MB },
-    { MODEL_LARGE, 9ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
-    { MODEL_TINY, 4ull*MB },
-    { MODEL_BASE, 4ull*MB },
-    { MODEL_SMALL, 6ull*MB },
-    { MODEL_MEDIUM, 7ull*MB },
-    { MODEL_LARGE, 9ull*MB },
-};
-
+// TODO: avoid using GGUF
 static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
     { GGML_TYPE_F32,
         {
@@ -334,38 +301,6 @@ static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
     },
 };
 
-static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
-    { MODEL_TINY, 3ull*MB },
-    { MODEL_BASE, 6ull*MB },
-    { MODEL_SMALL, 16ull*MB },
-    { MODEL_MEDIUM, 43ull*MB },
-    { MODEL_LARGE, 71ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
-    { MODEL_TINY, 9ull*MB },
-    { MODEL_BASE, 18ull*MB },
-    { MODEL_SMALL, 53ull*MB },
-    { MODEL_MEDIUM, 141ull*MB },
-    { MODEL_LARGE, 235ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
-    { MODEL_TINY, 30ull*MB },
-    { MODEL_BASE, 38ull*MB },
-    { MODEL_SMALL, 56ull*MB },
-    { MODEL_MEDIUM, 74ull*MB },
-    { MODEL_LARGE, 94ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_DECODE = {
-    { MODEL_TINY, 3ull*MB },
-    { MODEL_BASE, 5ull*MB },
-    { MODEL_SMALL, 10ull*MB },
-    { MODEL_MEDIUM, 18ull*MB },
-    { MODEL_LARGE, 27ull*MB },
-};
-
 struct whisper_mel {
     int n_len;
     int n_len_org;
@@ -670,10 +605,12 @@ struct whisper_state {
 
     // memory buffers used by encode / decode contexts
     std::vector<uint8_t> buf_compute;
-    std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
-
-    int    buf_last = 0;
-    size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 };
+
+    std::vector<uint8_t> buf_encode;
+    std::vector<uint8_t> buf_decode;
+
+    ggml_allocr * alloc_encode = NULL;
+    ggml_allocr * alloc_decode = NULL;
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
@@ -709,37 +646,6 @@ struct whisper_state {
 
     // [EXPERIMENTAL] speed-up techniques
    int32_t 
exp_n_audio_ctx = 0; // 0 - use default - - void use_buf(struct ggml_context * ctx, int i) { -#if defined(WHISPER_USE_SCRATCH) - size_t last_size = 0; - - if (i == -1) { - last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, }); - } else { - auto & buf = buf_scratch[i]; - last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), }); - } - - if (buf_last >= 0) { - buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size); - } - - buf_last = i; -#else - (void) i; - (void) ctx; -#endif - } - - size_t get_buf_max_mem(int i) const { -#if defined(WHISPER_USE_SCRATCH) - return buf_max_size[i]; -#else - (void) i; - return 0; -#endif - } }; struct whisper_context { @@ -786,10 +692,17 @@ static void read_safe(whisper_model_loader * loader, T & dest) { static bool kv_cache_init( const struct whisper_hparams & hparams, - const size_t mem_bytes, struct whisper_kv_cache & cache, ggml_type wtype, int n_ctx) { + const int64_t n_text_state = hparams.n_text_state; + const int64_t n_text_layer = hparams.n_text_layer; + + const int64_t n_mem = n_text_layer*n_ctx; + const int64_t n_elements = n_text_state*n_mem; + + const size_t mem_bytes = ggml_type_size(wtype)*n_elements; + cache.buf.resize(mem_bytes); struct ggml_init_params params = { @@ -805,12 +718,6 @@ static bool kv_cache_init( return false; } - const int n_text_state = hparams.n_text_state; - const int n_text_layer = hparams.n_text_layer; - - const int n_mem = n_text_layer*n_ctx; - const int n_elements = n_text_state*n_mem; - cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); @@ -953,22 +860,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // print memory requirements { - // this is the total memory required to run the inference - const size_t mem_required = - MEM_REQ_SCRATCH0.at(model.type) + - MEM_REQ_SCRATCH1.at(model.type) + - MEM_REQ_SCRATCH2.at(model.type) + - MEM_REQ_SCRATCH3.at(model.type) + - scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) + - scale*MEM_REQ_KV_CROSS.at(model.type) + - scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)); - - // this is the memory required by one decoder - const size_t mem_required_decoder = - scale*MEM_REQ_KV_SELF.at(model.type); - - log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__, - mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0); + // TODO + //log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__, + // mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0); } // initialize all memory buffers @@ -1477,24 +1371,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con return true; } -// evaluate the encoder with the given state -// -// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder -// part of the transformer model and returns the encoded features -// -// - wctx: the model -// - wstate: the state of the encoder -// - n_threads: number of threads to use -// - mel_offset: offset in the mel spectrogram (i.e. 
audio offset) -// -static bool whisper_encode_internal( +static struct ggml_cgraph * whisper_build_graph_encoder( whisper_context & wctx, whisper_state & wstate, - const int mel_offset, - const int n_threads){ - - const int64_t t_start_us = ggml_time_us(); - + const int mel_offset) { const auto & model = wctx.model; const auto & mel_inp = wstate.mel; const auto & hparams = model.hparams; @@ -1510,12 +1390,12 @@ static bool whisper_encode_internal( struct ggml_init_params params = { /*.mem_size =*/ wstate.buf_compute.size(), /*.mem_buffer =*/ wstate.buf_compute.data(), - /*.no_alloc =*/ false, + /*.no_alloc =*/ true, }; struct ggml_context * ctx0 = ggml_init(params); - wstate.use_buf(ctx0, 0); + ggml_cgraph * gf = ggml_new_graph(ctx0); struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); assert(mel->type == GGML_TYPE_F32); @@ -1550,8 +1430,6 @@ static bool whisper_encode_internal( if (!use_coreml && !use_openvino) { // convolution + gelu { - wstate.use_buf(ctx0, 1); - cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, @@ -1561,8 +1439,6 @@ static bool whisper_encode_internal( cur = ggml_gelu(ctx0, cur); - wstate.use_buf(ctx0, 0); - cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, @@ -1573,8 +1449,6 @@ static bool whisper_encode_internal( cur = ggml_gelu(ctx0, cur); } - wstate.use_buf(ctx0, 3); - // =================================================================== // NOTE: experimenting with partial evaluation of the encoder (ignore) //static int iter = -1; @@ -1608,8 +1482,6 @@ static bool whisper_encode_internal( // norm { - wstate.use_buf(ctx0, 0); - cur = ggml_norm(ctx0, inpL, hparams.eps); // cur = ln_0_w*cur + ln_0_b @@ -1622,8 +1494,6 @@ static bool whisper_encode_internal( // self-attention { - wstate.use_buf(ctx0, 1); - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.attn_q_w, cur); @@ -1655,8 +1525,6 @@ static bool whisper_encode_internal( // ------ - wstate.use_buf(ctx0, 0); - #ifdef WHISPER_USE_FLASH_ATTN struct ggml_tensor * Q = ggml_permute(ctx0, @@ -1722,8 +1590,6 @@ static bool whisper_encode_internal( #endif struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - wstate.use_buf(ctx0, 1); - cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); @@ -1731,21 +1597,15 @@ static bool whisper_encode_internal( // projection { - wstate.use_buf(ctx0, 0); - cur = ggml_mul_mat(ctx0, layer.attn_ln_1_w, cur); - wstate.use_buf(ctx0, 1); - cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.attn_ln_1_b, cur), cur); } - wstate.use_buf(ctx0, 2); - // add the input cur = ggml_add(ctx0, cur, inpL); @@ -1755,12 +1615,8 @@ static bool whisper_encode_internal( { // norm { - wstate.use_buf(ctx0, 0); - cur = ggml_norm(ctx0, inpFF, hparams.eps); - wstate.use_buf(ctx0, 1); - // cur = mlp_ln_w*cur + mlp_ln_b cur = ggml_add(ctx0, ggml_mul(ctx0, @@ -1770,47 +1626,33 @@ static bool whisper_encode_internal( } #ifdef WHISPER_USE_FLASH_FF - wstate.use_buf(ctx0, 0); - cur = ggml_flash_ff(ctx0, ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)), layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); #else - wstate.use_buf(ctx0, 0); - // fully connected cur = ggml_mul_mat(ctx0, layer.mlp_0_w, cur); - wstate.use_buf(ctx0, 1); - cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.mlp_0_b, cur), cur); - wstate.use_buf(ctx0, 0); - // GELU activation cur = ggml_gelu(ctx0, cur); - wstate.use_buf(ctx0, 1); - // 
projection cur = ggml_mul_mat(ctx0, layer.mlp_1_w, cur); - wstate.use_buf(ctx0, 0); - cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.mlp_1_b, cur), cur); #endif } - wstate.use_buf(ctx0, 3); - inpL = ggml_add(ctx0, cur, inpFF); } @@ -1818,12 +1660,8 @@ static bool whisper_encode_internal( // norm { - wstate.use_buf(ctx0, 0); - cur = ggml_norm(ctx0, cur, hparams.eps); - wstate.use_buf(ctx0, 1); - // cur = ln_f_g*cur + ln_f_b cur = ggml_add(ctx0, ggml_mul(ctx0, @@ -1832,22 +1670,10 @@ static bool whisper_encode_internal( ggml_repeat(ctx0, model.e_ln_b, cur)); } - wstate.use_buf(ctx0, -1); - - // run the computation - { - struct ggml_cgraph gf = {}; - - ggml_build_forward_expand (&gf, cur); - ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); - - //ggml_graph_print(&gf); - } + ggml_build_forward_expand (gf, cur); } #ifdef WHISPER_USE_COREML else if (use_coreml) { - wstate.use_buf(ctx0, -1); - cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx); whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data); @@ -1855,8 +1681,6 @@ static bool whisper_encode_internal( #endif #ifdef WHISPER_USE_OPENVINO else if (use_openvino) { - wstate.use_buf(ctx0, -1); - cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx); if (!whisper_openvino_encode(wstate.ctx_openvino, mel, cur)) { @@ -1865,69 +1689,6 @@ static bool whisper_encode_internal( } #endif - // cur - //{ - // printf("ne0 = %d\n", cur->ne[0]); - // printf("ne1 = %d\n", cur->ne[1]); - // for (int i = 0; i < 10; ++i) { - // printf("%8.4f ", ((float *)(cur->data))[i]); - // } - // printf("... "); - // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) { - // printf("%8.4f ", ((float *)(cur->data))[i]); - // } - // printf("\n"); - //} - - // pre-compute cross-attention memory - { - struct ggml_cgraph gf = {}; - - // TODO: hack to disconnect the encoded features from the previous graph - cur->op = GGML_OP_NONE; - cur->src[0] = nullptr; - cur->src[1] = nullptr; - - for (int il = 0; il < model.hparams.n_text_layer; ++il) { - auto& layer = model.layers_decoder[il]; - - wstate.use_buf(ctx0, 0); - - struct ggml_tensor* Kcross = ggml_mul_mat(ctx0, - layer.cross_attn_k_w, - cur); - - Kcross = ggml_scale_inplace(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25))); - - wstate.use_buf(ctx0, 1); - - struct ggml_tensor* Vcross = ggml_mul_mat(ctx0, - layer.cross_attn_v_w, - cur); - - Vcross = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.cross_attn_v_b, - Vcross), - Vcross); - - wstate.use_buf(ctx0, -1); - - Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx)); - - struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); - struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state, - ( n_ctx)*ggml_element_size(wstate.kv_cross.v), - (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state); - - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k)); - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v)); - } - - ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); - //ggml_graph_print(&gf); - } - //////////////////////////////////////////////////////////////////////////// //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, @@ -1939,32 +1700,105 @@ static bool whisper_encode_internal( ggml_free(ctx0); + return gf; +} + +// pre-compute cross-attention memory +static struct ggml_cgraph * whisper_build_graph_encoder_post( + whisper_context & wctx, + whisper_state & wstate, 
+ struct ggml_tensor * embd_enc) { + const auto & model = wctx.model; + const auto & hparams = model.hparams; + + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_state = hparams.n_audio_state; + const int n_head = hparams.n_audio_head; + + struct ggml_init_params params = { + /*.mem_size =*/ wstate.buf_compute.size(), + /*.mem_buffer =*/ wstate.buf_compute.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur = embd_enc; + + // TODO: hack to disconnect the encoded features from the previous graph + cur->op = GGML_OP_NONE; + cur->src[0] = nullptr; + cur->src[1] = nullptr; + + for (int il = 0; il < model.hparams.n_text_layer; ++il) { + auto & layer = model.layers_decoder[il]; + + struct ggml_tensor* Kcross = ggml_mul_mat(ctx0, + layer.cross_attn_k_w, + cur); + + Kcross = ggml_scale_inplace(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25))); + + struct ggml_tensor* Vcross = ggml_mul_mat(ctx0, + layer.cross_attn_v_w, + cur); + + Vcross = ggml_add(ctx0, + ggml_repeat(ctx0, + layer.cross_attn_v_b, + Vcross), + Vcross); + + Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx)); + + struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); + struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state, + ( n_ctx)*ggml_element_size(wstate.kv_cross.v), + (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v)); + } + + ggml_free(ctx0); + + return gf; +} + +// evaluate the encoder with the given state +// +// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder +// part of the transformer model and returns the encoded features +// +// - wctx: the model +// - wstate: the state of the encoder +// - n_threads: number of threads to use +// - mel_offset: offset in the mel spectrogram (i.e. 
audio offset) +// +static bool whisper_encode_internal( + whisper_context & wctx, + whisper_state & wstate, + const int mel_offset, + const int n_threads) { + const int64_t t_start_us = ggml_time_us(); + + // ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); + wstate.t_encode_us += ggml_time_us() - t_start_us; wstate.n_encode++; return true; } -// evaluate the decoder -// -// given text prompt + audio features -> computes the logits for the next token -// -// - model: the model -// - n_threads: number of threads to use -// - tokens: text prompt -// - n_tokens: number of tokens in the prompt -// - n_past: number of past tokens to prefix the prompt with -// -static bool whisper_decode_internal( - whisper_context & wctx, - whisper_state & wstate, - whisper_decoder & decoder, - const whisper_token * tokens, - const int n_tokens, - const int n_past, - const int n_threads) { - const int64_t t_start_us = ggml_time_us(); - +static struct ggml_cgraph * whisper_build_graph_decoder( + whisper_context & wctx, + whisper_state & wstate, + whisper_decoder & decoder, + const whisper_token * tokens, + int n_tokens, + int n_past) { const auto & model = wctx.model; const auto & hparams = model.hparams; @@ -1972,10 +1806,6 @@ static bool whisper_decode_internal( WHISPER_ASSERT(!!kv_self.ctx); - auto & logits_out = wstate.logits; - - const int n_vocab = hparams.n_vocab; - const int n_ctx = hparams.n_text_ctx; const int n_state = hparams.n_text_state; const int n_head = hparams.n_text_head; @@ -1989,12 +1819,12 @@ static bool whisper_decode_internal( struct ggml_init_params params = { /*.mem_size =*/ wstate.buf_compute.size(), /*.mem_buffer =*/ wstate.buf_compute.data(), - /*.no_alloc =*/ false, + /*.no_alloc =*/ true, }; struct ggml_context * ctx0 = ggml_init(params); - struct ggml_cgraph gf = {}; + ggml_cgraph * gf = ggml_new_graph(ctx0); struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); memcpy(embd->data, tokens, N*ggml_element_size(embd)); @@ -2004,8 +1834,6 @@ static bool whisper_decode_internal( ((int32_t *) position->data)[i] = n_past + i; } - wstate.use_buf(ctx0, 3); - // token encoding + position encoding struct ggml_tensor * cur = ggml_add(ctx0, @@ -2019,8 +1847,6 @@ static bool whisper_decode_internal( // norm { - wstate.use_buf(ctx0, 0); - cur = ggml_norm(ctx0, inpL, hparams.eps); // cur = ln_0_w*cur + ln_0_b @@ -2071,14 +1897,12 @@ static bool whisper_decode_internal( ( n_ctx)*ggml_element_size(kv_self.v), (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v)); - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); } // ------ - wstate.use_buf(ctx0, 0); - struct ggml_tensor * Q = ggml_permute(ctx0, ggml_cpy(ctx0, @@ -2093,8 +1917,6 @@ static bool whisper_decode_internal( n_state/n_head, n_head, n_past + N), 0, 2, 1, 3); - wstate.use_buf(ctx0, 1); - // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); @@ -2126,28 +1948,20 @@ static bool whisper_decode_internal( // projection { - wstate.use_buf(ctx0, 0); - cur = ggml_mul_mat(ctx0, layer.attn_ln_1_w, cur); - wstate.use_buf(ctx0, 1); - cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.attn_ln_1_b, cur), cur); } - wstate.use_buf(ctx0, 2); - // add the input struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL); // norm { - wstate.use_buf(ctx0, 0); - cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here 
// cur = ln_0_w*cur + ln_0_b @@ -2232,21 +2046,15 @@ static bool whisper_decode_internal( // projection { - wstate.use_buf(ctx0, 0); - cur = ggml_mul_mat(ctx0, layer.cross_attn_ln_1_w, cur); - wstate.use_buf(ctx0, 1); - cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur), cur); } - wstate.use_buf(ctx0, 2); - // add the input cur = ggml_add(ctx0, cur, inpCA); @@ -2256,12 +2064,8 @@ static bool whisper_decode_internal( { // norm { - wstate.use_buf(ctx0, 0); - cur = ggml_norm(ctx0, inpFF, hparams.eps); - wstate.use_buf(ctx0, 1); - // cur = mlp_ln_w*cur + mlp_ln_b cur = ggml_add(ctx0, ggml_mul(ctx0, @@ -2270,40 +2074,28 @@ static bool whisper_decode_internal( ggml_repeat(ctx0, layer.mlp_ln_b, cur)); } - wstate.use_buf(ctx0, 0); - // fully connected cur = ggml_mul_mat(ctx0, layer.mlp_0_w, cur); - wstate.use_buf(ctx0, 1); - cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.mlp_0_b, cur), cur); - wstate.use_buf(ctx0, 0); - // GELU activation cur = ggml_gelu(ctx0, cur); - wstate.use_buf(ctx0, 1); - // projection cur = ggml_mul_mat(ctx0, layer.mlp_1_w, cur); - wstate.use_buf(ctx0, 0); - cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.mlp_1_b, cur), cur); } - wstate.use_buf(ctx0, 3); - inpL = ggml_add(ctx0, cur, inpFF); } @@ -2311,12 +2103,8 @@ static bool whisper_decode_internal( // norm { - wstate.use_buf(ctx0, 0); - cur = ggml_norm(ctx0, cur, hparams.eps); - wstate.use_buf(ctx0, 1); - cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.d_ln_w, cur), @@ -2324,8 +2112,6 @@ static bool whisper_decode_internal( ggml_repeat(ctx0, model.d_ln_b, cur)); } - wstate.use_buf(ctx0, 0); - // compute logits only for the last token // comment this line to compute logits for all N tokens // might be useful in the future @@ -2333,39 +2119,63 @@ static bool whisper_decode_internal( struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); - wstate.use_buf(ctx0, -1); - - // run the computation - { - ggml_build_forward_expand (&gf, logits); - ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); - } - - // extract logits for all N tokens - //logits_out.resize(N*n_vocab); - //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab); - - // extract logits only for the last token - logits_out.resize(n_vocab); - memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab); - - if (N > 1) { - //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, - // ggml_used_mem(ctx0)/1024.0/1024.0, - // wstate.get_buf_max_mem(0)/1024.0/1024.0, - // wstate.get_buf_max_mem(1)/1024.0/1024.0, - // wstate.get_buf_max_mem(2)/1024.0/1024.0, - // wstate.get_buf_max_mem(3)/1024.0/1024.0); - } + ggml_build_forward_expand(gf, logits); ggml_free(ctx0); - wstate.t_decode_us += ggml_time_us() - t_start_us; - wstate.n_decode++; + return gf; +} + +// evaluate the decoder +// +// given text prompt + audio features -> computes the logits for the next token +// +// - model: the model +// - n_threads: number of threads to use +// - tokens: text prompt +// - n_tokens: number of tokens in the prompt +// - n_past: number of past tokens to prefix the prompt with +// +static bool whisper_decode_internal( + whisper_context & wctx, + whisper_state & wstate, + whisper_decoder & decoder, + const whisper_token * tokens, + const int n_tokens, + const int n_past, + const int n_threads) { + //const int64_t t_start_us = ggml_time_us(); + + //auto & logits_out = wstate.logits; + + //const int n_vocab = hparams.n_vocab; + + // ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); + + //// extract logits 
for all N tokens + ////logits_out.resize(N*n_vocab); + ////memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab); + + //// extract logits only for the last token + //logits_out.resize(n_vocab); + //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab); + + //if (N > 1) { + // //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + // // ggml_used_mem(ctx0)/1024.0/1024.0, + // // wstate.get_buf_max_mem(0)/1024.0/1024.0, + // // wstate.get_buf_max_mem(1)/1024.0/1024.0, + // // wstate.get_buf_max_mem(2)/1024.0/1024.0, + // // wstate.get_buf_max_mem(3)/1024.0/1024.0); + //} + + //wstate.t_decode_us += ggml_time_us() - t_start_us; + //wstate.n_decode++; return true; } + // 500 -> 00:05.000 // 6000 -> 01:00.000 static std::string to_timestamp(int64_t t, bool comma = false) { @@ -2774,9 +2584,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { fill_sin_cos_table(); whisper_state * state = new whisper_state; - const size_t scale = ctx->model.hparams.ftype ? 1 : 2; - - if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) { + if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) { log("%s: kv_cache_init() failed for self-attention cache\n", __func__); delete state; return nullptr; @@ -2787,7 +2595,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } - if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) { + if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) { log("%s: kv_cache_init() failed for cross-attention cache\n", __func__); delete state; return nullptr; @@ -2825,12 +2633,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { state->decoders[0].probs.reserve(ctx->vocab.n_vocab); state->decoders[0].logits.reserve(ctx->vocab.n_vocab); state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab); - state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type))); - state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type)); - state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type)); - state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type)); - state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type)); + state->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead()); + + static const size_t tensor_alignment = 32; + state->alloc_encode = ggml_allocr_new_measure(tensor_alignment); + state->alloc_decode = ggml_allocr_new_measure(tensor_alignment); state->rng = std::mt19937(0);
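
Note on the state of this commit: the refactor is intentionally incomplete at this point. The graph builders construct their graphs with no_alloc = true, whisper_encode_internal and whisper_decode_internal are reduced to stubs whose ggml_graph_compute_with_ctx calls are commented out, and the two measure allocators created at the end of whisper_init_state are not yet consulted anywhere. Below is a minimal sketch of how the ggml-alloc measure-then-compute pattern is typically completed. It assumes the ggml-alloc API of this period (ggml_allocr_new_measure, ggml_allocr_alloc_graph, ggml_allocr_new, ggml_allocr_reset) plus a compute helper in the spirit of whisper.cpp's ggml_graph_compute_helper with a work_buffer member; those last two names, and the exact call sites, are assumptions rather than part of this patch.

    // one-time measure pass (e.g. at the end of whisper_init_state):
    // build the graph once under a measure allocator to learn the worst-case
    // size of the compute buffer, then back a real allocator with that buffer
    {
        ggml_allocr * alloc = ggml_allocr_new_measure(tensor_alignment);

        ggml_cgraph * gf = whisper_build_graph_encoder(*ctx, *state, 0);

        // returns the total buffer size this graph needs
        const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;

        ggml_allocr_free(alloc);

        state->buf_encode.resize(alloc_size);
        state->alloc_encode = ggml_allocr_new(state->buf_encode.data(), state->buf_encode.size(), tensor_alignment);
    }

    // per-call compute pass (e.g. inside whisper_encode_internal):
    // reset the allocator, rebuild the graph, let the allocator place every
    // intermediate tensor inside buf_encode, then run the graph
    {
        ggml_allocr_reset(wstate.alloc_encode);

        ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate, mel_offset);

        ggml_allocr_alloc_graph(wstate.alloc_encode, gf);

        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
    }

The decoder side follows the same shape with alloc_decode, buf_decode and whisper_build_graph_decoder.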
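
One loose end the sketch above does not cover: with no_alloc = true, tensors created inside the builders have data == NULL at build time, so writes such as memcpy(embd->data, tokens, N*ggml_element_size(embd)) in whisper_build_graph_decoder cannot remain as they appear in this patch. The usual resolution (again assuming the same ggml-alloc API, with the state's allocator reachable from the builder) is to allocate each input tensor explicitly and skip the write during the measure pass:

    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
    ggml_allocr_alloc(wstate.alloc_decode, embd);

    // the measure allocator only assigns placeholder addresses, so input
    // data can only be written once the tensor has real backing memory
    if (!ggml_allocr_is_measure(wstate.alloc_decode)) {
        memcpy(embd->data, tokens, N*ggml_element_size(embd));
    }

The same guard applies to the mel input in whisper_build_graph_encoder and to the position tensor in the decoder.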