whisper : factor out alloc init in a function

Georgi Gerganov
2023-09-13 12:51:52 +03:00
parent 254b687239
commit b6f09669a2


@@ -604,6 +604,31 @@ struct whisper_allocr {
     std::vector<uint8_t> data;
 };
 
+static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
+    return allocr.meta.size() + allocr.data.size();
+}
+
+// measure the memory usage of a graph and prepare the allocr's internal data buffer
+static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct ggml_cgraph *()> && get_graph) {
+    const int tensor_alignment = 32;
+
+    auto & alloc = allocr.alloc;
+    auto & meta  = allocr.meta;
+    auto & data  = allocr.data;
+
+    meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+
+    alloc = ggml_allocr_new_measure(tensor_alignment);
+
+    const size_t alloc_size = ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
+
+    ggml_allocr_free(alloc);
+
+    data.resize(alloc_size);
+
+    alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
+}
+
 static void whisper_allocr_free(struct whisper_allocr & allocr) {
     if (allocr.alloc) {
         ggml_allocr_free(allocr.alloc);
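
For context, the new helper takes a callback so that graph construction stays at the call site while the measure/allocate boilerplate lives in one place. A minimal sketch of the intended call pattern, mirroring the whisper_init_state changes in the next hunk (whisper_allocr_graph_init, whisper_allocr_size and whisper_build_graph_conv are from this diff; state and ctx are the surrounding variables in whisper_init_state):

    // one-shot setup: measure the conv graph, size allocr.data, create the real allocator
    whisper_allocr_graph_init(state->alloc_conv,
            [&]() {
                // the callback only builds the graph; the helper runs the measure
                // pass, frees the measure allocator and allocates the real buffer
                return whisper_build_graph_conv(*ctx, *state, 0);
            });

    // meta + data size of the prepared allocator, reported in MB
    log("%s: compute buffer (conv) = %7.2f MB\n", __func__,
            whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
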
@@ -2799,100 +2824,50 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
     state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
 
-    static const size_t tensor_alignment = 32;
-
     // conv allocator
     {
-        auto & alloc = state->alloc_conv.alloc;
-        auto & meta  = state->alloc_conv.meta;
-        auto & data  = state->alloc_conv.data;
-
-        meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
-
-        alloc = ggml_allocr_new_measure(tensor_alignment);
-
-        ggml_cgraph * gf = whisper_build_graph_conv(*ctx, *state, 0);
-
-        const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
-
-        ggml_allocr_free(alloc);
-
-        log("%s: compute buffer (conv) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
-
-        data.resize(alloc_size);
-
-        alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
+        whisper_allocr_graph_init(state->alloc_conv,
+                [&]() {
+                    return whisper_build_graph_conv(*ctx, *state, 0);
+                });
+
+        log("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
     }
 
     // encoder allocator
     if (!whisper_encode_external(*state)) {
-        auto & alloc = state->alloc_encode.alloc;
-        auto & meta  = state->alloc_encode.meta;
-        auto & data  = state->alloc_encode.data;
-
-        meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
-
-        alloc = ggml_allocr_new_measure(tensor_alignment);
-
-        ggml_cgraph * gf = whisper_build_graph_encoder(*ctx, *state);
-
-        const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
-
-        ggml_allocr_free(alloc);
-
-        log("%s: compute buffer (encode) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
-
-        data.resize(alloc_size);
-
-        alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
+        whisper_allocr_graph_init(state->alloc_encode,
+                [&]() {
+                    return whisper_build_graph_encoder(*ctx, *state);
+                });
+
+        log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
     }
 
     // cross allocator
     {
-        auto & alloc = state->alloc_cross.alloc;
-        auto & meta  = state->alloc_cross.meta;
-        auto & data  = state->alloc_cross.data;
-
-        meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
-
-        alloc = ggml_allocr_new_measure(tensor_alignment);
-
-        ggml_cgraph * gf = whisper_build_graph_cross(*ctx, *state);
-
-        const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
-
-        ggml_allocr_free(alloc);
-
-        log("%s: compute buffer (cross) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
-
-        data.resize(alloc_size);
-
-        alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
+        whisper_allocr_graph_init(state->alloc_cross,
+                [&]() {
+                    return whisper_build_graph_cross(*ctx, *state);
+                });
+
+        log("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
     }
 
     // decoder allocator
     {
-        auto & alloc = state->alloc_decode.alloc;
-        auto & meta  = state->alloc_decode.meta;
-        auto & data  = state->alloc_decode.data;
-
-        meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
-
-        alloc = ggml_allocr_new_measure(tensor_alignment);
-
-        const auto & hparams = ctx->model.hparams;
-
-        // TODO: make sure this is the worst-case scenario
-        const int n_tokens = hparams.n_text_ctx;
-        const int n_past   = 0;
-
-        ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
-
-        const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
-
-        ggml_allocr_free(alloc);
-
-        log("%s: compute buffer (decode) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
-
-        data.resize(alloc_size);
-
-        alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
+        whisper_allocr_graph_init(state->alloc_decode,
+                [&]() {
+                    const auto & hparams = ctx->model.hparams;
+
+                    // TODO: make sure this is the worst-case scenario
+                    const int n_tokens = hparams.n_text_ctx;
+                    const int n_past   = 0;
+
+                    return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
+                });
+
+        log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
     }
 
 #ifdef GGML_USE_METAL
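
Follow-up note (not part of this commit): once whisper_allocr_graph_init has sized the buffers, the same allocator is reused on every evaluation by resetting it and re-allocating the freshly built graph into the existing data buffer. A hedged sketch of that reuse pattern; ggml_allocr_reset and ggml_allocr_alloc_graph are existing ggml-alloc calls, while allocr and build_graph are placeholders for the relevant whisper_allocr and graph builder:

    // per-evaluation reuse of a prepared whisper_allocr (sketch, assumed usage)
    ggml_allocr_reset(allocr.alloc);           // drop the previous tensor placements
    ggml_cgraph * gf = build_graph();          // rebuild the graph for this run
    ggml_allocr_alloc_graph(allocr.alloc, gf); // place its tensors inside allocr.data
    // ... the graph is then computed by the backend as before ...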