mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-01-03 20:48:59 +01:00
whisper : faster beam_search sampling via reduced KV cache copies (#1243)
* Faster `beam_search` sampling Refine the KV cache update logic for more intelligent and efficient updating. * Faster `whisper_sample_token_topk` * Update whisper.cpp * Update whisper.cpp * Update whisper.cpp * Reduce `memory allocation` * Add `pointer swapping` * Fixed some bugs * Update whisper.cpp * Apply suggestions from code review * Updated the logic for determining `two-copy` * Updated the logic for determining `two-copy` v2 * whisper : add debug logs + coding style --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
6ddc727fac
commit
9b14418863
178
whisper.cpp
178
whisper.cpp
@ -18,6 +18,7 @@
|
|||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -537,6 +538,7 @@ struct whisper_kv_cache {
|
|||||||
|
|
||||||
struct ggml_context * ctx;
|
struct ggml_context * ctx;
|
||||||
|
|
||||||
|
// buf points to the memory allocated for both ggml_tensor 'k' and 'v' (see kv_cache_init)
|
||||||
std::vector<uint8_t> buf;
|
std::vector<uint8_t> buf;
|
||||||
|
|
||||||
int n; // number of tokens currently in the cache
|
int n; // number of tokens currently in the cache
|
||||||
@ -602,7 +604,7 @@ struct whisper_sequence {
|
|||||||
|
|
||||||
// TAGS: WHISPER_DECODER_INIT
|
// TAGS: WHISPER_DECODER_INIT
|
||||||
struct whisper_decoder {
|
struct whisper_decoder {
|
||||||
// each decoders keeps its own KV-cache
|
// each decoder keeps its own KV-cache
|
||||||
whisper_kv_cache kv_self;
|
whisper_kv_cache kv_self;
|
||||||
|
|
||||||
// the currently generated sequence of tokens
|
// the currently generated sequence of tokens
|
||||||
@ -622,6 +624,24 @@ struct whisper_decoder {
|
|||||||
std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
|
std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// replace std::pair by using customized pair struct (reason: std::pair is very slow)
|
||||||
|
template<typename A, typename B>
|
||||||
|
struct whisper_pair {
|
||||||
|
A first;
|
||||||
|
B second;
|
||||||
|
|
||||||
|
// Define a constructor that takes two arguments.
|
||||||
|
whisper_pair(const A& a, const B& b) : first(a), second(b) {}
|
||||||
|
// Define a constructor that takes no argument.
|
||||||
|
whisper_pair() : first(A()), second(B()) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
// beam-search helpers
|
||||||
|
struct kv_buf {
|
||||||
|
std::vector<uint8_t> k;
|
||||||
|
std::vector<uint8_t> v;
|
||||||
|
};
|
||||||
|
|
||||||
struct whisper_state {
|
struct whisper_state {
|
||||||
int64_t t_sample_us = 0;
|
int64_t t_sample_us = 0;
|
||||||
int64_t t_encode_us = 0;
|
int64_t t_encode_us = 0;
|
||||||
@ -641,6 +661,9 @@ struct whisper_state {
|
|||||||
|
|
||||||
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
|
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
|
||||||
|
|
||||||
|
// buffer for swapping KV caches between decoders during beam-search
|
||||||
|
std::vector<kv_buf> kv_swap_bufs;
|
||||||
|
|
||||||
// memory buffers used by encode / decode contexts
|
// memory buffers used by encode / decode contexts
|
||||||
std::vector<uint8_t> buf_compute;
|
std::vector<uint8_t> buf_compute;
|
||||||
std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
|
std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
|
||||||
@ -655,7 +678,7 @@ struct whisper_state {
|
|||||||
std::vector<whisper_token> prompt_past;
|
std::vector<whisper_token> prompt_past;
|
||||||
|
|
||||||
// work container used to avoid memory allocations
|
// work container used to avoid memory allocations
|
||||||
std::vector<std::pair<double, whisper_vocab::id>> logits_id;
|
std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
|
||||||
|
|
||||||
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
||||||
|
|
||||||
@ -3975,17 +3998,21 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|||||||
|
|
||||||
auto & logits_id = state.logits_id;
|
auto & logits_id = state.logits_id;
|
||||||
|
|
||||||
logits_id.clear();
|
logits_id.resize(n_logits);
|
||||||
for (int i = 0; i < n_logits; ++i) {
|
for (int i = 0; i < n_logits; ++i) {
|
||||||
logits_id.push_back({ logits[i], i });
|
logits_id[i].first = logits[i];
|
||||||
|
logits_id[i].second = i;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
|
||||||
std::partial_sort(
|
std::partial_sort(
|
||||||
logits_id.begin(),
|
logits_id.begin(),
|
||||||
logits_id.begin() + k, logits_id.end(),
|
logits_id.begin() + k, logits_id.end(),
|
||||||
[](const std::pair<double, whisper_token> & a, const std::pair<double, whisper_token> & b) {
|
[](const pair_type & a, const pair_type & b) {
|
||||||
return a.first > b.first;
|
return a.first > b.first;
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<whisper_token_data> result;
|
std::vector<whisper_token_data> result;
|
||||||
result.reserve(k);
|
result.reserve(k);
|
||||||
@ -4080,6 +4107,115 @@ static void whisper_sequence_score(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool whisper_kv_swap_fast(
|
||||||
|
std::vector<int> & view,
|
||||||
|
whisper_decoder src[],
|
||||||
|
std::vector<kv_buf> & kv_swap_bufs,
|
||||||
|
const int & n_decoders) {
|
||||||
|
WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
|
||||||
|
|
||||||
|
// (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
|
||||||
|
std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
|
||||||
|
|
||||||
|
// (buffer->decoder or decoder->decoder)
|
||||||
|
std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
|
||||||
|
|
||||||
|
// (decoder<->decoder)
|
||||||
|
std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
|
||||||
|
std::vector<whisper_pair<int, int>> p_swap_vec;
|
||||||
|
p_swap_vec.reserve(n_decoders);
|
||||||
|
|
||||||
|
// see https://github.com/ggerganov/whisper.cpp/wiki
|
||||||
|
for (int i = 0; i < n_decoders; i++) {
|
||||||
|
// zero-copy (no modification)
|
||||||
|
if (i == view[i] || view[i] < 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_one_copy = true;
|
||||||
|
// since we modify data sequentially, we only consider decoder indices after current index
|
||||||
|
for (int j = i + 1; j < n_decoders; j++) {
|
||||||
|
if (i == view[j]) {
|
||||||
|
// detect symmetric diagram
|
||||||
|
if (j == view[i]) {
|
||||||
|
p_swap_set.insert(i);
|
||||||
|
p_swap_set.insert(j);
|
||||||
|
p_swap_vec.emplace_back(i, j);
|
||||||
|
} else {
|
||||||
|
two_copy.insert(i);
|
||||||
|
is_one_copy = false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (is_one_copy) {
|
||||||
|
one_copy.insert(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kv_swap_bufs.resize(n_decoders);
|
||||||
|
|
||||||
|
for (int i = 0; i < n_decoders; i++) {
|
||||||
|
kv_swap_bufs[i].k.resize(ggml_nbytes(src[i].kv_self.k));
|
||||||
|
kv_swap_bufs[i].v.resize(ggml_nbytes(src[i].kv_self.v));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto & i : two_copy) {
|
||||||
|
// make a copy of KV caches
|
||||||
|
WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
|
||||||
|
memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
|
||||||
|
memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
// since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
|
||||||
|
for (auto & i : two_copy) {
|
||||||
|
// skip the decoder indices that require pointer swapping
|
||||||
|
if (p_swap_set.find(i) != p_swap_set.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (two_copy.find(view[i]) != two_copy.end()) {
|
||||||
|
// modify KV caches of decoder using data from kv_swap_bufs
|
||||||
|
WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
|
||||||
|
memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
|
||||||
|
memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
|
||||||
|
} else {
|
||||||
|
// modify KV caches of decoder using data from correspond decoder KV caches directly
|
||||||
|
WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
|
||||||
|
memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
|
||||||
|
memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// then modify one-copy decoder KV caches
|
||||||
|
for (auto & i : one_copy) {
|
||||||
|
// skip the decoder indices that require pointer swapping
|
||||||
|
if (p_swap_set.find(i) != p_swap_set.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (two_copy.find(view[i]) != two_copy.end()) {
|
||||||
|
// modify KV caches of decoder using data from kv_swap_bufs
|
||||||
|
WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
|
||||||
|
memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
|
||||||
|
memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
|
||||||
|
} else {
|
||||||
|
// modify KV caches of decoder using data from correspond decoder KV caches directly
|
||||||
|
WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
|
||||||
|
memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
|
||||||
|
memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// swap the pointers
|
||||||
|
for (auto & i : p_swap_vec) {
|
||||||
|
WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
|
||||||
|
std::swap(src[i.first].kv_self, src[i.second].kv_self);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
int whisper_full_with_state(
|
int whisper_full_with_state(
|
||||||
struct whisper_context * ctx,
|
struct whisper_context * ctx,
|
||||||
struct whisper_state * state,
|
struct whisper_state * state,
|
||||||
@ -4243,14 +4379,6 @@ int whisper_full_with_state(
|
|||||||
std::vector<whisper_token> prompt;
|
std::vector<whisper_token> prompt;
|
||||||
prompt.reserve(whisper_n_text_ctx(ctx));
|
prompt.reserve(whisper_n_text_ctx(ctx));
|
||||||
|
|
||||||
// beam-search helpers
|
|
||||||
struct kv_buf {
|
|
||||||
std::vector<uint8_t> k;
|
|
||||||
std::vector<uint8_t> v;
|
|
||||||
};
|
|
||||||
|
|
||||||
std::vector<kv_buf> kv_bufs;
|
|
||||||
|
|
||||||
struct beam_candidate {
|
struct beam_candidate {
|
||||||
int decoder_idx;
|
int decoder_idx;
|
||||||
int seek_delta;
|
int seek_delta;
|
||||||
@ -4399,23 +4527,7 @@ int whisper_full_with_state(
|
|||||||
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
// store the KV caches of all decoders when doing beam-search
|
|
||||||
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
|
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
|
||||||
kv_bufs.resize(n_decoders_cur);
|
|
||||||
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
||||||
auto & decoder = state->decoders[j];
|
|
||||||
|
|
||||||
if (decoder.completed || decoder.failed) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k));
|
|
||||||
kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v));
|
|
||||||
|
|
||||||
memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size());
|
|
||||||
memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
beam_candidates.clear();
|
beam_candidates.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -4463,6 +4575,7 @@ int whisper_full_with_state(
|
|||||||
});
|
});
|
||||||
|
|
||||||
uint32_t cur_c = 0;
|
uint32_t cur_c = 0;
|
||||||
|
std::vector<int> decoder_idx(n_decoders_cur, -1);
|
||||||
|
|
||||||
for (int j = 0; j < n_decoders_cur; ++j) {
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
||||||
auto & decoder = state->decoders[j];
|
auto & decoder = state->decoders[j];
|
||||||
@ -4481,12 +4594,13 @@ int whisper_full_with_state(
|
|||||||
decoder.seek_delta = cur.seek_delta;
|
decoder.seek_delta = cur.seek_delta;
|
||||||
decoder.has_ts = cur.has_ts;
|
decoder.has_ts = cur.has_ts;
|
||||||
|
|
||||||
memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size());
|
decoder_idx[j] = cur.decoder_idx;
|
||||||
memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size());
|
|
||||||
|
|
||||||
WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
|
WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
|
||||||
__func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
|
__func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// update KV caches
|
||||||
|
whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur);
|
||||||
}
|
}
|
||||||
|
|
||||||
// update the decoder state
|
// update the decoder state
|
||||||
|
Loading…
Reference in New Issue
Block a user