whisper : fix UB with measure buffers

This commit is contained in:
Georgi Gerganov 2023-11-11 18:35:23 +02:00
parent fc8565d0e2
commit 40c66036b6
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -642,7 +642,7 @@ struct whisper_allocr {
}; };
static size_t whisper_allocr_size(struct whisper_allocr & allocr) { static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
return allocr.meta.size() + ggml_backend_buffer_get_size(allocr.buffer); return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
} }
// measure the memory usage of a graph and prepare the allocr's internal data buffer // measure the memory usage of a graph and prepare the allocr's internal data buffer
@ -655,12 +655,19 @@ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backe
meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead()); meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
const size_t alloc_size = ggml_allocr_alloc_graph(alloc, get_graph()); ggml_allocr_alloc_graph(alloc, get_graph());
}
static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
auto & alloc = allocr.alloc;
auto & buffer = allocr.buffer;
size_t size = ggml_allocr_max_size(alloc);
ggml_allocr_free(alloc); ggml_allocr_free(alloc);
buffer = ggml_backend_alloc_buffer(backend, alloc_size); buffer = ggml_backend_alloc_buffer(backend, size);
alloc = ggml_allocr_new_from_buffer(buffer); alloc = ggml_allocr_new_from_buffer(buffer);
} }
static void whisper_allocr_free(struct whisper_allocr & allocr) { static void whisper_allocr_free(struct whisper_allocr & allocr) {
@ -2915,6 +2922,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0); WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
} }
whisper_allocr_graph_realloc(state->alloc_conv, ctx->backend);
whisper_allocr_graph_realloc(state->alloc_encode, ctx->backend);
whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
state->rng = std::mt19937(0); state->rng = std::mt19937(0);
return state; return state;