whisper : fix beam-search with CUDA

This commit is contained in:
Georgi Gerganov 2023-11-10 12:41:11 +02:00
parent 3dfbe64911
commit dcf9511dbb
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -725,6 +725,7 @@ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backe
static void whisper_allocr_free(struct whisper_allocr & allocr) { static void whisper_allocr_free(struct whisper_allocr & allocr) {
if (allocr.alloc) { if (allocr.alloc) {
ggml_allocr_free(allocr.alloc); ggml_allocr_free(allocr.alloc);
ggml_backend_buffer_free(allocr.buffer);
allocr.alloc = nullptr; allocr.alloc = nullptr;
} }
} }
@ -765,6 +766,7 @@ struct whisper_state {
struct ggml_tensor * embd_conv = nullptr; struct ggml_tensor * embd_conv = nullptr;
struct ggml_tensor * embd_enc = nullptr; struct ggml_tensor * embd_enc = nullptr;
// TODO: helper until conv is implemented in CUDA
std::vector<float> inp_mel; std::vector<float> inp_mel;
// decode output (2-dimensional array: [n_tokens][n_vocab]) // decode output (2-dimensional array: [n_tokens][n_vocab])
@ -940,6 +942,7 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t back
static void kv_cache_free(struct whisper_kv_cache & cache) { static void kv_cache_free(struct whisper_kv_cache & cache) {
if (cache.ctx) { if (cache.ctx) {
ggml_free(cache.ctx); ggml_free(cache.ctx);
ggml_backend_buffer_free(cache.buffer);
cache.ctx = nullptr; cache.ctx = nullptr;
} }
} }
@ -3265,6 +3268,9 @@ void whisper_free(struct whisper_context * ctx) {
ggml_free(ctx->model.ctx); ggml_free(ctx->model.ctx);
} }
if (ctx->model.data) { if (ctx->model.data) {
ggml_backend_buffer_free(ctx->model.data->buffer_conv);
ggml_backend_buffer_free(ctx->model.data->buffer_main);
delete ctx->model.data; delete ctx->model.data;
} }
@ -4406,8 +4412,10 @@ static bool whisper_kv_swap_fast(
for (auto & i : two_copy) { for (auto & i : two_copy) {
// make a copy of KV caches // make a copy of KV caches
WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i); WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size()); //memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size()); //memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
ggml_backend_tensor_get(src[i].kv_self.k, kv_swap_bufs[i].k.data(), 0, kv_swap_bufs[i].k.size());
ggml_backend_tensor_get(src[i].kv_self.v, kv_swap_bufs[i].v.data(), 0, kv_swap_bufs[i].v.size());
} }
// since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
@ -4420,13 +4428,17 @@ static bool whisper_kv_swap_fast(
if (two_copy.find(view[i]) != two_copy.end()) { if (two_copy.find(view[i]) != two_copy.end()) {
// modify KV caches of decoder using data from kv_swap_bufs // modify KV caches of decoder using data from kv_swap_bufs
WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i); WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size()); //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size()); //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
} else { } else {
// modify KV caches of decoder using data from correspond decoder KV caches directly // modify KV caches of decoder using data from correspond decoder KV caches directly
WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i); WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k)); //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v)); //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
} }
} }
@ -4440,13 +4452,17 @@ static bool whisper_kv_swap_fast(
if (two_copy.find(view[i]) != two_copy.end()) { if (two_copy.find(view[i]) != two_copy.end()) {
// modify KV caches of decoder using data from kv_swap_bufs // modify KV caches of decoder using data from kv_swap_bufs
WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i); WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size()); //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size()); //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
} else { } else {
// modify KV caches of decoder using data from correspond decoder KV caches directly // modify KV caches of decoder using data from correspond decoder KV caches directly
WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i); WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k)); //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v)); //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
} }
} }
@ -4765,8 +4781,11 @@ int whisper_full_with_state(
for (int j = 1; j < n_decoders_cur; ++j) { for (int j = 1; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j]; auto & decoder = state->decoders[j];
memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k)); // TODO: fix CUDA
memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v)); //memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
//memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
ggml_backend_tensor_copy(state->decoders[0].kv_self.k, decoder.kv_self.k);
ggml_backend_tensor_copy(state->decoders[0].kv_self.v, decoder.kv_self.v);
decoder.kv_self.n += prompt.size(); decoder.kv_self.n += prompt.size();