diff --git a/whisper.cpp b/whisper.cpp index d16492cd..471d9a85 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -642,7 +642,7 @@ struct whisper_allocr { }; static size_t whisper_allocr_size(struct whisper_allocr & allocr) { - return allocr.meta.size() + ggml_backend_buffer_get_size(allocr.buffer); + return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc); } // measure the memory usage of a graph and prepare the allocr's internal data buffer @@ -655,12 +655,19 @@ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backe meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead()); - const size_t alloc_size = ggml_allocr_alloc_graph(alloc, get_graph()); + ggml_allocr_alloc_graph(alloc, get_graph()); +} + +static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) { + auto & alloc = allocr.alloc; + auto & buffer = allocr.buffer; + + size_t size = ggml_allocr_max_size(alloc); ggml_allocr_free(alloc); - buffer = ggml_backend_alloc_buffer(backend, alloc_size); - alloc = ggml_allocr_new_from_buffer(buffer); + buffer = ggml_backend_alloc_buffer(backend, size); + alloc = ggml_allocr_new_from_buffer(buffer); } static void whisper_allocr_free(struct whisper_allocr & allocr) { @@ -2915,6 +2922,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0); } + whisper_allocr_graph_realloc(state->alloc_conv, ctx->backend); + whisper_allocr_graph_realloc(state->alloc_encode, ctx->backend); + whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend); + whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend); + state->rng = std::mt19937(0); return state;