whisper : fix beam-search with CUDA

This commit is contained in:
Georgi Gerganov 2023-11-10 12:41:11 +02:00
parent 3dfbe64911
commit dcf9511dbb
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -725,6 +725,7 @@ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backe
static void whisper_allocr_free(struct whisper_allocr & allocr) {
if (allocr.alloc) {
ggml_allocr_free(allocr.alloc);
ggml_backend_buffer_free(allocr.buffer);
allocr.alloc = nullptr;
}
}
@ -765,6 +766,7 @@ struct whisper_state {
struct ggml_tensor * embd_conv = nullptr;
struct ggml_tensor * embd_enc = nullptr;
// TODO: helper until conv is implemented in CUDA
std::vector<float> inp_mel;
// decode output (2-dimensional array: [n_tokens][n_vocab])
@ -940,6 +942,7 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t back
static void kv_cache_free(struct whisper_kv_cache & cache) {
if (cache.ctx) {
ggml_free(cache.ctx);
ggml_backend_buffer_free(cache.buffer);
cache.ctx = nullptr;
}
}
@ -3265,6 +3268,9 @@ void whisper_free(struct whisper_context * ctx) {
ggml_free(ctx->model.ctx);
}
if (ctx->model.data) {
ggml_backend_buffer_free(ctx->model.data->buffer_conv);
ggml_backend_buffer_free(ctx->model.data->buffer_main);
delete ctx->model.data;
}
@ -4406,8 +4412,10 @@ static bool whisper_kv_swap_fast(
for (auto & i : two_copy) {
// make a copy of KV caches
WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
//memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
//memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
ggml_backend_tensor_get(src[i].kv_self.k, kv_swap_bufs[i].k.data(), 0, kv_swap_bufs[i].k.size());
ggml_backend_tensor_get(src[i].kv_self.v, kv_swap_bufs[i].v.data(), 0, kv_swap_bufs[i].v.size());
}
// since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
@ -4420,13 +4428,17 @@ static bool whisper_kv_swap_fast(
if (two_copy.find(view[i]) != two_copy.end()) {
// modify KV caches of decoder using data from kv_swap_bufs
WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
//memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
//memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
} else {
// modify KV caches of decoder using data from correspond decoder KV caches directly
WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
//memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
//memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
}
}
@ -4440,13 +4452,17 @@ static bool whisper_kv_swap_fast(
if (two_copy.find(view[i]) != two_copy.end()) {
// modify KV caches of decoder using data from kv_swap_bufs
WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
//memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
//memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
} else {
// modify KV caches of decoder using data from correspond decoder KV caches directly
WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
//memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
//memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
}
}
@ -4765,8 +4781,11 @@ int whisper_full_with_state(
for (int j = 1; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
// TODO: fix CUDA
//memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
//memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
ggml_backend_tensor_copy(state->decoders[0].kv_self.k, decoder.kv_self.k);
ggml_backend_tensor_copy(state->decoders[0].kv_self.v, decoder.kv_self.v);
decoder.kv_self.n += prompt.size();