From 91096daa1aabccc5cf7e4a5d19871df24d0ae770 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 14 Nov 2023 16:57:28 +0200
Subject: [PATCH] whisper : full batched decoding support

---
 whisper.cpp | 99 +++++++++++++++++++++++++++++++++--------------------
 1 file changed, 62 insertions(+), 37 deletions(-)

diff --git a/whisper.cpp b/whisper.cpp
index 2d164180..e42026e4 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -688,6 +688,7 @@ struct whisper_decoder {
     // grammar parse state of generated sequence of tokens
     whisper_grammar grammar;
 
+    int i_batch;    // the index of the token in the current batch
     int seek_delta; // the window shift found so far based on the decoded timestamp tokens
 
     bool failed;    // has the current segment failed to decode?
@@ -2228,7 +2229,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     ggml_allocr * alloc = wstate.alloc_decode.alloc;
 
-    const int n_ctx   = hparams.n_text_ctx;
+    const int n_ctx   = kv_self.size;
     const int n_state = hparams.n_text_state;
     const int n_head  = hparams.n_text_head;
     const int n_layer = hparams.n_text_layer;
@@ -2569,7 +2570,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     // compute logits only for the last token
     // comment this line to compute logits for all n_tokens
     // might be useful in the future
-    cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
+    //cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
 
     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
 
@@ -2602,22 +2603,26 @@ static bool whisper_decode_internal(
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
-    const int n_vocab = hparams.n_vocab;
+    const int n_vocab  = hparams.n_vocab;
+    const int n_tokens = batch.n_tokens;
 
     auto & logits_out = wstate.logits;
 
     struct ggml_tensor * logits;
 
-    auto & kv_self = wstate.kv_self;
+    // find KV slot for the batch
+    {
+        auto & kv_self = wstate.kv_self;
 
-    if (!whisper_kv_cache_find_slot(kv_self, batch)) {
-        return 1;
+        if (!whisper_kv_cache_find_slot(kv_self, batch)) {
+            return false;
+        }
+
+        kv_self.n = whisper_kv_cache_cell_max(kv_self);
+        //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
+        //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
     }
 
-    kv_self.n = whisper_kv_cache_cell_max(kv_self);
-    //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
-    //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
-
     // decoder
     {
         auto & alloc = wstate.alloc_decode.alloc;
@@ -2633,15 +2638,13 @@ static bool whisper_decode_internal(
         ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
-    // extract logits for all N tokens
-    //logits_out.resize(n_tokens*n_vocab);
-    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
-    //ggml_backend_tensor_get(logits, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), sizeof(float)*n_vocab);
-
-    // extract logits only for the last token
-    logits_out.resize(n_vocab);
-    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
-    ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab);
+    logits_out.resize(n_tokens*n_vocab);
+    for (int i = 0; i < n_tokens; i++) {
+        if (batch.logits[i] == 0) {
+            continue;
+        }
+        ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab);
+    }
 
     if (batch.n_tokens > 1) {
         //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
@@ -3074,7 +3077,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     state->backend = whisper_backend_init(ctx->params);
 
-    if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
+    // TODO: determine how large the cache should be
+    const int factor = 2;
+
+    if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
         WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         delete state;
         return nullptr;
@@ -4566,7 +4572,7 @@ static void whisper_process_logits(
     auto & logprobs = decoder.logprobs;
     {
         logits.resize(n_logits);
-        memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float));
+        memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float));
 
         if (temperature > 0.0f) {
             for (int i = 0; i < n_logits; i++) {
@@ -5317,6 +5323,8 @@ int whisper_full_with_state(
         {
             const int64_t t_start_sample_us = ggml_time_us();
 
+            state->decoders[0].i_batch = prompt.size() - 1;
+
             whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
 
             for (int j = 1; j < n_decoders_cur; ++j) {
@@ -5384,7 +5392,6 @@ int whisper_full_with_state(
                         });
 
                     uint32_t cur_c = 0;
-                    std::vector<int> decoder_idx(n_decoders_cur, -1);
 
                     for (int j = 0; j < n_decoders_cur; ++j) {
                         auto & decoder = state->decoders[j];
@@ -5408,8 +5415,6 @@ int whisper_full_with_state(
                         decoder.sequence = cur.sequence;
                         decoder.grammar  = cur.grammar;
 
-                        decoder_idx[j] = cur.decoder_idx;
-
                         whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
 
                         WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
@@ -5535,32 +5540,52 @@ int whisper_full_with_state(
                 state->t_sample_us += ggml_time_us() - t_start_sample_us;
 
                 // obtain logits for the next token
-                for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = state->decoders[j];
+                {
+                    auto & batch = state->batch;
 
-                    if (decoder.failed || decoder.completed) {
-                        continue;
-                    }
+                    batch.n_tokens = 0;
 
-                    //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
-
-                    // TODO: use batch
                     const int n_past = prompt.size() + i;
 
-                    whisper_batch_prep_legacy(state->batch, &decoder.sequence.tokens.back().id, 1, n_past, j);
+                    for (int j = 0; j < n_decoders_cur; ++j) {
+                        auto & decoder = state->decoders[j];
+
+                        if (decoder.failed || decoder.completed) {
+                            continue;
+                        }
+
+                        //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
+
+                        decoder.i_batch = batch.n_tokens;
+
+                        batch.token   [batch.n_tokens]    = decoder.sequence.tokens.back().id;
+                        batch.pos     [batch.n_tokens]    = n_past;
+                        batch.n_seq_id[batch.n_tokens]    = 1;
+                        batch.seq_id  [batch.n_tokens][0] = j;
+                        batch.logits  [batch.n_tokens]    = 1;
+                        batch.n_tokens++;
+                    }
+
+                    assert(batch.n_tokens > 0);
 
                     if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                         WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                         return -8;
                    }
 
-                    {
-                        const int64_t t_start_sample_us = ggml_time_us();
+                    const int64_t t_start_sample_us = ggml_time_us();
+
+                    for (int j = 0; j < n_decoders_cur; ++j) {
+                        auto & decoder = state->decoders[j];
+
+                        if (decoder.failed || decoder.completed) {
+                            continue;
+                        }
 
                         whisper_process_logits(*ctx, *state, params, decoder, t_cur);
-
-                        state->t_sample_us += ggml_time_us() - t_start_sample_us;
                    }
+
+                    state->t_sample_us += ggml_time_us() - t_start_sample_us;
                 }
             }
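
Note: the following is a small, self-contained sketch of the batching pattern this patch introduces. The types below (toy_batch, toy_decoder) are simplified stand-ins, not the real whisper_batch / whisper_decoder structs from whisper.cpp. The idea it illustrates: each still-active decoder appends exactly one token to a shared batch, marks that row with logits = 1, and records its row index in i_batch so it can later read its own n_vocab-sized slice of the flattened logits buffer, mirroring state.logits.data() + decoder.i_batch*n_logits in whisper_process_logits().

// Toy illustration only -- names and fields are simplified stand-ins,
// not the actual whisper.cpp types.
#include <cassert>
#include <cstdio>
#include <vector>

struct toy_batch {
    int              n_tokens = 0;
    std::vector<int> token;   // token id per row
    std::vector<int> pos;     // position in the sequence per row
    std::vector<int> seq_id;  // owning sequence (decoder index) per row
    std::vector<int> logits;  // 1 if logits should be extracted for this row

    explicit toy_batch(int n_max) : token(n_max), pos(n_max), seq_id(n_max), logits(n_max) {}
};

struct toy_decoder {
    bool failed     = false;
    bool completed  = false;
    int  last_token = 0;
    int  i_batch    = -1; // row of this decoder's token in the current batch
};

int main() {
    const int n_vocab    = 8; // tiny stand-in vocabulary
    const int n_decoders = 3;
    const int n_past     = 5; // tokens already stored in the KV cache

    std::vector<toy_decoder> decoders(n_decoders);
    decoders[1].completed = true; // e.g. this beam already emitted end-of-text

    // build one batch containing a single new token per active decoder
    toy_batch batch(n_decoders);
    for (int j = 0; j < n_decoders; ++j) {
        auto & decoder = decoders[j];
        if (decoder.failed || decoder.completed) {
            continue;
        }
        decoder.i_batch = batch.n_tokens;

        batch.token [batch.n_tokens] = decoder.last_token;
        batch.pos   [batch.n_tokens] = n_past;
        batch.seq_id[batch.n_tokens] = j;
        batch.logits[batch.n_tokens] = 1;
        batch.n_tokens++;
    }
    assert(batch.n_tokens > 0);

    // stand-in for the decode call: one flat buffer of n_tokens*n_vocab logits
    std::vector<float> logits(batch.n_tokens*n_vocab, 0.0f);

    // each decoder reads only its own row, analogous to
    // state.logits.data() + decoder.i_batch*n_logits in whisper_process_logits()
    for (int j = 0; j < n_decoders; ++j) {
        const auto & decoder = decoders[j];
        if (decoder.failed || decoder.completed) {
            continue;
        }
        const float * row = logits.data() + decoder.i_batch*n_vocab;
        printf("decoder %d reads logits row %d (first value %.1f)\n", j, decoder.i_batch, row[0]);
    }

    return 0;
}

Requesting logits per batch row and indexing them with i_batch replaces the old "extract logits only for the last token" path, which is what lets all active decoders (e.g. beam-search candidates) share a single decoder graph evaluation per generated token.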