whisper : factor out alloc init in a function

This commit is contained in:
Georgi Gerganov
2023-09-13 12:51:52 +03:00
parent 254b687239
commit b6f09669a2

View File

@ -604,6 +604,31 @@ struct whisper_allocr {
std::vector<uint8_t> data;
};
static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
return allocr.meta.size() + allocr.data.size();
}
// measure the memory usage of a graph and prepare the allocr's internal data buffer
static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct ggml_cgraph *()> && get_graph) {
const int tensor_alignment = 32;
auto & alloc = allocr.alloc;
auto & meta = allocr.meta;
auto & data = allocr.data;
meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
alloc = ggml_allocr_new_measure(tensor_alignment);
const size_t alloc_size = ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
ggml_allocr_free(alloc);
data.resize(alloc_size);
alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
}
static void whisper_allocr_free(struct whisper_allocr & allocr) {
if (allocr.alloc) {
ggml_allocr_free(allocr.alloc);
@ -2799,100 +2824,50 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
static const size_t tensor_alignment = 32;
// conv allocator
{
auto & alloc = state->alloc_conv.alloc;
auto & meta = state->alloc_conv.meta;
auto & data = state->alloc_conv.data;
whisper_allocr_graph_init(state->alloc_conv,
[&]() {
return whisper_build_graph_conv(*ctx, *state, 0);
});
meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
alloc = ggml_allocr_new_measure(tensor_alignment);
ggml_cgraph * gf = whisper_build_graph_conv(*ctx, *state, 0);
const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
ggml_allocr_free(alloc);
log("%s: compute buffer (conv) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
data.resize(alloc_size);
alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
log("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
}
// encoder allocator
if (!whisper_encode_external(*state)) {
auto & alloc = state->alloc_encode.alloc;
auto & meta = state->alloc_encode.meta;
auto & data = state->alloc_encode.data;
whisper_allocr_graph_init(state->alloc_encode,
[&]() {
return whisper_build_graph_encoder(*ctx, *state);
});
meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
alloc = ggml_allocr_new_measure(tensor_alignment);
ggml_cgraph * gf = whisper_build_graph_encoder(*ctx, *state);
const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
ggml_allocr_free(alloc);
log("%s: compute buffer (encode) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
data.resize(alloc_size);
alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
}
// cross allocator
{
auto & alloc = state->alloc_cross.alloc;
auto & meta = state->alloc_cross.meta;
auto & data = state->alloc_cross.data;
whisper_allocr_graph_init(state->alloc_cross,
[&]() {
return whisper_build_graph_cross(*ctx, *state);
});
meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
alloc = ggml_allocr_new_measure(tensor_alignment);
ggml_cgraph * gf = whisper_build_graph_cross(*ctx, *state);
const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
ggml_allocr_free(alloc);
log("%s: compute buffer (cross) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
data.resize(alloc_size);
alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
log("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
}
// decoder allocator
{
auto & alloc = state->alloc_decode.alloc;
auto & meta = state->alloc_decode.meta;
auto & data = state->alloc_decode.data;
whisper_allocr_graph_init(state->alloc_decode,
[&]() {
const auto & hparams = ctx->model.hparams;
meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
// TODO: make sure this is the worst-case scenario
const int n_tokens = hparams.n_text_ctx;
const int n_past = 0;
alloc = ggml_allocr_new_measure(tensor_alignment);
return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
});
const auto & hparams = ctx->model.hparams;
// TODO: make sure this is the worst-case scenario
const int n_tokens = hparams.n_text_ctx;
const int n_past = 0;
ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
ggml_allocr_free(alloc);
log("%s: compute buffer (decode) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
data.resize(alloc_size);
alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
}
#ifdef GGML_USE_METAL