Mirror of https://github.com/ggerganov/whisper.cpp.git (synced 2024-11-07 08:34:37 +01:00)
whisper : fix UB with measure buffers
This commit is contained in:
parent fc8565d0e2
commit 40c66036b6
whisper.cpp (20 lines changed)
@ -642,7 +642,7 @@ struct whisper_allocr {
|
|||||||
};
|
};
|
||||||
|
|
||||||
static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
|
static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
|
||||||
return allocr.meta.size() + ggml_backend_buffer_get_size(allocr.buffer);
|
return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
|
||||||
}
|
}
|
||||||
|
|
||||||
// measure the memory usage of a graph and prepare the allocr's internal data buffer
|
// measure the memory usage of a graph and prepare the allocr's internal data buffer
|
||||||
@ -655,12 +655,19 @@ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backe
|
|||||||
|
|
||||||
meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
|
meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
|
||||||
|
|
||||||
const size_t alloc_size = ggml_allocr_alloc_graph(alloc, get_graph());
|
ggml_allocr_alloc_graph(alloc, get_graph());
|
||||||
|
}
|
||||||
|
|
||||||
|
static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
|
||||||
|
auto & alloc = allocr.alloc;
|
||||||
|
auto & buffer = allocr.buffer;
|
||||||
|
|
||||||
|
size_t size = ggml_allocr_max_size(alloc);
|
||||||
|
|
||||||
ggml_allocr_free(alloc);
|
ggml_allocr_free(alloc);
|
||||||
|
|
||||||
buffer = ggml_backend_alloc_buffer(backend, alloc_size);
|
buffer = ggml_backend_alloc_buffer(backend, size);
|
||||||
alloc = ggml_allocr_new_from_buffer(buffer);
|
alloc = ggml_allocr_new_from_buffer(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void whisper_allocr_free(struct whisper_allocr & allocr) {
|
static void whisper_allocr_free(struct whisper_allocr & allocr) {
|
||||||
@ -2915,6 +2922,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|||||||
WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
|
WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
whisper_allocr_graph_realloc(state->alloc_conv, ctx->backend);
|
||||||
|
whisper_allocr_graph_realloc(state->alloc_encode, ctx->backend);
|
||||||
|
whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
|
||||||
|
whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
|
||||||
|
|
||||||
state->rng = std::mt19937(0);
|
state->rng = std::mt19937(0);
|
||||||
|
|
||||||
return state;
|
return state;
|
||||||
|
Loading…
Reference in New Issue
Block a user