diff --git a/whisper.cpp b/whisper.cpp index 3192fbc6..5c14b43e 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -537,6 +538,7 @@ struct whisper_kv_cache { struct ggml_context * ctx; + // buf points to the memory allocated for both ggml_tensor 'k' and 'v' (see kv_cache_init) std::vector buf; int n; // number of tokens currently in the cache @@ -602,7 +604,7 @@ struct whisper_sequence { // TAGS: WHISPER_DECODER_INIT struct whisper_decoder { - // each decoders keeps its own KV-cache + // each decoder keeps its own KV-cache whisper_kv_cache kv_self; // the currently generated sequence of tokens @@ -622,6 +624,24 @@ struct whisper_decoder { std::vector tokens_tmp; // used for whisper_decode calls }; +// replace std::pair by using customized pair struct (reason: std::pair is very slow) +template +struct whisper_pair { + A first; + B second; + + // Define a constructor that takes two arguments. + whisper_pair(const A& a, const B& b) : first(a), second(b) {} + // Define a constructor that takes no argument. + whisper_pair() : first(A()), second(B()) {} +}; + +// beam-search helpers +struct kv_buf { + std::vector k; + std::vector v; +}; + struct whisper_state { int64_t t_sample_us = 0; int64_t t_encode_us = 0; @@ -641,6 +661,9 @@ struct whisper_state { whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; + // buffer for swapping KV caches between decoders during beam-search + std::vector kv_swap_bufs; + // memory buffers used by encode / decode contexts std::vector buf_compute; std::vector buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS]; @@ -655,7 +678,7 @@ struct whisper_state { std::vector prompt_past; // work container used to avoid memory allocations - std::vector> logits_id; + std::vector> logits_id; mutable std::mt19937 rng; // used for sampling at t > 0.0 @@ -3975,17 +3998,21 @@ static std::vector whisper_sample_token_topk( auto & logits_id = state.logits_id; - logits_id.clear(); + logits_id.resize(n_logits); for (int i = 0; i < n_logits; ++i) { - logits_id.push_back({ logits[i], i }); + logits_id[i].first = logits[i]; + logits_id[i].second = i; } - std::partial_sort( - logits_id.begin(), - logits_id.begin() + k, logits_id.end(), - [](const std::pair & a, const std::pair & b) { - return a.first > b.first; - }); + { + using pair_type = std::remove_reference::type::value_type; + std::partial_sort( + logits_id.begin(), + logits_id.begin() + k, logits_id.end(), + [](const pair_type & a, const pair_type & b) { + return a.first > b.first; + }); + } std::vector result; result.reserve(k); @@ -4080,6 +4107,115 @@ static void whisper_sequence_score( } } +static bool whisper_kv_swap_fast( + std::vector & view, + whisper_decoder src[], + std::vector & kv_swap_bufs, + const int & n_decoders) { + WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders); + + // (decoder->buffer->decoder or decoder->buffer + decoder->decoder) + std::set two_copy; // decoder indices require two copies to safely modify KV caches + + // (buffer->decoder or decoder->decoder) + std::set one_copy; // decoder indices require one copy to safely modify KV caches + + // (decoder<->decoder) + std::set p_swap_set; // decoder indices able to swap KV-cache pointers + std::vector> p_swap_vec; + p_swap_vec.reserve(n_decoders); + + // see https://github.com/ggerganov/whisper.cpp/wiki + for (int i = 0; i < n_decoders; i++) { + // zero-copy (no modification) + if (i == view[i] || view[i] < 0) { + continue; + } + + bool is_one_copy = true; + // since we modify data sequentially, we only consider decoder indices after current index + for (int j = i + 1; j < n_decoders; j++) { + if (i == view[j]) { + // detect symmetric diagram + if (j == view[i]) { + p_swap_set.insert(i); + p_swap_set.insert(j); + p_swap_vec.emplace_back(i, j); + } else { + two_copy.insert(i); + is_one_copy = false; + } + break; + } + } + if (is_one_copy) { + one_copy.insert(i); + } + } + + kv_swap_bufs.resize(n_decoders); + + for (int i = 0; i < n_decoders; i++) { + kv_swap_bufs[i].k.resize(ggml_nbytes(src[i].kv_self.k)); + kv_swap_bufs[i].v.resize(ggml_nbytes(src[i].kv_self.v)); + } + + for (auto & i : two_copy) { + // make a copy of KV caches + WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i); + memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size()); + memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size()); + } + + // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first + for (auto & i : two_copy) { + // skip the decoder indices that require pointer swapping + if (p_swap_set.find(i) != p_swap_set.end()) { + continue; + } + + if (two_copy.find(view[i]) != two_copy.end()) { + // modify KV caches of decoder using data from kv_swap_bufs + WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i); + memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size()); + memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size()); + } else { + // modify KV caches of decoder using data from correspond decoder KV caches directly + WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i); + memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k)); + memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v)); + } + } + + // then modify one-copy decoder KV caches + for (auto & i : one_copy) { + // skip the decoder indices that require pointer swapping + if (p_swap_set.find(i) != p_swap_set.end()) { + continue; + } + + if (two_copy.find(view[i]) != two_copy.end()) { + // modify KV caches of decoder using data from kv_swap_bufs + WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i); + memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size()); + memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size()); + } else { + // modify KV caches of decoder using data from correspond decoder KV caches directly + WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i); + memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k)); + memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v)); + } + } + + // swap the pointers + for (auto & i : p_swap_vec) { + WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second); + std::swap(src[i.first].kv_self, src[i.second].kv_self); + } + + return true; +} + int whisper_full_with_state( struct whisper_context * ctx, struct whisper_state * state, @@ -4243,14 +4379,6 @@ int whisper_full_with_state( std::vector prompt; prompt.reserve(whisper_n_text_ctx(ctx)); - // beam-search helpers - struct kv_buf { - std::vector k; - std::vector v; - }; - - std::vector kv_bufs; - struct beam_candidate { int decoder_idx; int seek_delta; @@ -4399,23 +4527,7 @@ int whisper_full_with_state( for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { const int64_t t_start_sample_us = ggml_time_us(); - // store the KV caches of all decoders when doing beam-search if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { - kv_bufs.resize(n_decoders_cur); - for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = state->decoders[j]; - - if (decoder.completed || decoder.failed) { - continue; - } - - kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k)); - kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v)); - - memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size()); - memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size()); - } - beam_candidates.clear(); } @@ -4463,6 +4575,7 @@ int whisper_full_with_state( }); uint32_t cur_c = 0; + std::vector decoder_idx(n_decoders_cur, -1); for (int j = 0; j < n_decoders_cur; ++j) { auto & decoder = state->decoders[j]; @@ -4481,12 +4594,13 @@ int whisper_full_with_state( decoder.seek_delta = cur.seek_delta; decoder.has_ts = cur.has_ts; - memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size()); - memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size()); - + decoder_idx[j] = cur.decoder_idx; WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n", __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all); } + + // update KV caches + whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur); } // update the decoder state