whisper : refactor ggml-alloc init

Georgi Gerganov
2023-09-11 15:04:33 +03:00
parent 4d9acc60c3
commit 2770d46ef5

@@ -618,20 +618,26 @@ struct whisper_state {
     // buffer for swapping KV caches between decoders during beam-search
     std::vector<kv_buf> kv_swap_bufs;
 
-    // memory buffers used by encode / decode contexts
-    std::vector<uint8_t> buf_compute;
-
     // reusable buffer for `struct ggml_graph_plan.work_data`
     std::vector<uint8_t> work_buffer;
 
-    // ggml-alloc
-    std::vector<uint8_t> buf_encode;
-    std::vector<uint8_t> buf_encode_post;
-    std::vector<uint8_t> buf_decode;
-
+    // ggml-alloc:
+    // - stores meta info about the intermediate tensors into the `meta_*` buffers
+    // - stores the actual tensor data into the `data_*` buffers
     ggml_allocr * alloc_encode = NULL;
-    ggml_allocr * alloc_encode_post = NULL;
+    ggml_allocr * alloc_cross  = NULL;
     ggml_allocr * alloc_decode = NULL;
 
+    // meta data
+    std::vector<uint8_t> meta_encode;
+    std::vector<uint8_t> meta_cross;
+    std::vector<uint8_t> meta_decode;
+
+    // tensor data
+    std::vector<uint8_t> data_encode;
+    std::vector<uint8_t> data_cross;
+    std::vector<uint8_t> data_decode;
+
     // result of the encoder
     struct ggml_tensor * embd_enc = NULL;
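For context on the new layout: each `meta_*` vector only backs a ggml context created with `no_alloc = true`, so it holds tensor and graph bookkeeping, while the matching `data_*` vector later receives the actual tensor contents from ggml-alloc. A minimal sketch of that idea, using the ggml API as it appears in this file (the helper name is made up, not part of the commit):

#include <cstddef>
#include <cstdint>
#include <vector>

#include "ggml.h"

// Illustrative helper (not in the commit): create a ggml context whose backing
// buffer is used only for tensor/graph metadata, never for tensor data.
static struct ggml_context * new_meta_context(std::vector<uint8_t> & meta) {
    meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());

    struct ggml_init_params params = {
        /*.mem_size   =*/ meta.size(),
        /*.mem_buffer =*/ meta.data(),
        /*.no_alloc   =*/ true,   // tensors created in this context get no data pointers yet
    };

    return ggml_init(params);     // tensor data is assigned later by ggml_allocr_alloc_graph()
}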
@@ -1411,8 +1417,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     const int n_mels = hparams.n_mels;
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.buf_compute.size(),
-        /*.mem_buffer =*/ wstate.buf_compute.data(),
+        /*.mem_size   =*/ wstate.meta_encode.size(),
+        /*.mem_buffer =*/ wstate.meta_encode.data(),
         /*.no_alloc   =*/ true,
     };
@@ -1746,7 +1752,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 }
 
 // pre-compute cross-attention memory
-static struct ggml_cgraph * whisper_build_graph_encoder_post(
+static struct ggml_cgraph * whisper_build_graph_cross(
         whisper_context & wctx,
         whisper_state & wstate) {
     const auto & model = wctx.model;
@@ -1757,8 +1763,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder_post(
     const int n_head = hparams.n_audio_head;
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.buf_compute.size(),
-        /*.mem_buffer =*/ wstate.buf_compute.data(),
+        /*.mem_size   =*/ wstate.meta_cross.size(),
+        /*.mem_buffer =*/ wstate.meta_cross.data(),
         /*.no_alloc   =*/ true,
     };
@@ -1766,7 +1772,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder_post(
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    ggml_allocr * alloc = wstate.alloc_encode_post;
+    ggml_allocr * alloc = wstate.alloc_cross;
 
     struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
@@ -1863,13 +1869,13 @@ static bool whisper_encode_internal(
         //printf("n: %d\n", ggml_nelements(cur));
     }
 
-    // encoder_post
+    // cross
     {
-        auto & alloc = wstate.alloc_encode_post;
+        auto & alloc = wstate.alloc_cross;
 
         ggml_allocr_reset(alloc);
 
-        ggml_cgraph * gf = whisper_build_graph_encoder_post(wctx, wstate);
+        ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
 
         ggml_allocr_alloc_graph(alloc, gf);
@@ -1924,8 +1930,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.buf_compute.size(),
-        /*.mem_buffer =*/ wstate.buf_compute.data(),
+        /*.mem_size   =*/ wstate.meta_decode.size(),
+        /*.mem_buffer =*/ wstate.meta_decode.data(),
         /*.no_alloc   =*/ true,
     };
@@ -2733,8 +2739,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
-    log("debug CI - checkpoint 0\n");
-
     if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
         log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
         delete state;
@@ -2746,8 +2750,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
-    log("debug CI - checkpoint 1\n");
-
 #ifdef WHISPER_USE_COREML
     const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
@@ -2765,70 +2767,73 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     }
 #endif
 
-    log("debug CI - checkpoint 2\n");
-
     state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
 
-    log("debug CI - checkpoint 3\n");
-
     state->logits_id.reserve(ctx->model.hparams.n_vocab);
 
-    log("debug CI - checkpoint 4\n");
-
     // TAGS: WHISPER_DECODER_INIT
     state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
 
-    log("debug CI - checkpoint 5\n");
-
     state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
     state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
     state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
 
-    log("debug CI - checkpoint 6\n");
-
-    state->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
-
-    log("debug CI - checkpoint 7\n");
-
     static const size_t tensor_alignment = 32;
 
-    log("debug CI - checkpoint 8\n");
-
-    state->alloc_encode = ggml_allocr_new_measure(tensor_alignment);
-
-    log("debug CI - checkpoint 9\n");
-
-    state->alloc_encode_post = ggml_allocr_new_measure(tensor_alignment);
-
-    log("debug CI - checkpoint 10\n");
-
-    state->alloc_decode = ggml_allocr_new_measure(tensor_alignment);
-
-    log("debug CI - checkpoint 11\n");
-
     // encoder allocator
     {
+        auto & alloc = state->alloc_encode;
+        auto & meta  = state->meta_encode;
+        auto & data  = state->data_encode;
+
+        meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+
+        alloc = ggml_allocr_new_measure(tensor_alignment);
+
         ggml_cgraph * gf = whisper_build_graph_encoder(*ctx, *state, 0);
 
-        const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_encode, gf) + tensor_alignment;
+        const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
 
-        ggml_allocr_free(state->alloc_encode);
+        ggml_allocr_free(alloc);
 
-        log("%s: compute buffer (encode) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
+        log("%s: compute buffer (encode) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
 
-        state->buf_encode.resize(alloc_size);
-        state->alloc_encode = ggml_allocr_new(state->buf_encode.data(), state->buf_encode.size(), tensor_alignment);
+        data.resize(alloc_size);
+        alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
     }
 
-    // encoder_post allocator
+    // cross allocator
     {
-        ggml_cgraph * gf = whisper_build_graph_encoder_post(*ctx, *state);
+        auto & alloc = state->alloc_cross;
+        auto & meta  = state->meta_cross;
+        auto & data  = state->data_cross;
 
-        const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_encode_post, gf) + tensor_alignment;
+        meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
 
-        ggml_allocr_free(state->alloc_encode_post);
+        alloc = ggml_allocr_new_measure(tensor_alignment);
 
-        log("%s: compute buffer (encode_post) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
+        ggml_cgraph * gf = whisper_build_graph_cross(*ctx, *state);
 
-        state->buf_encode_post.resize(alloc_size);
-        state->alloc_encode_post = ggml_allocr_new(state->buf_encode_post.data(), state->buf_encode_post.size(), tensor_alignment);
+        const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
+
+        ggml_allocr_free(alloc);
+
+        log("%s: compute buffer (cross) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
+
+        data.resize(alloc_size);
+        alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
     }
 
     // decoder allocator
     {
+        auto & alloc = state->alloc_decode;
+        auto & meta  = state->meta_decode;
+        auto & data  = state->data_decode;
+
+        meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+
+        alloc = ggml_allocr_new_measure(tensor_alignment);
+
         const auto & hparams = ctx->model.hparams;
 
         // TODO: make sure this is the worst-case scenario
@@ -2837,13 +2842,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], NULL, n_tokens, n_past);
 
-        const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_decode, gf) + tensor_alignment;
+        const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
 
-        ggml_allocr_free(state->alloc_decode);
+        ggml_allocr_free(alloc);
 
-        log("%s: compute buffer (decode) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
+        log("%s: compute buffer (decode) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
 
-        state->buf_decode.resize(alloc_size);
-        state->alloc_decode = ggml_allocr_new(state->buf_decode.data(), state->buf_decode.size(), tensor_alignment);
+        data.resize(alloc_size);
+        alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
     }
 
     state->rng = std::mt19937(0);
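Each of the three allocator blocks above follows the same measure-then-allocate pattern: create a measuring allocator, build a worst-case graph, let ggml_allocr_alloc_graph() report the required size, free the measuring allocator, then size the data_* buffer and create the real allocator on top of it. A condensed sketch under those assumptions (the helper and its graph-builder parameter are illustrative, not part of the commit):

#include <cstddef>
#include <cstdint>
#include <vector>

#include "ggml.h"
#include "ggml-alloc.h"

// Illustrative helper (not in the commit): measure a worst-case graph once, then
// back the allocator with a buffer of exactly the measured size.
template <typename GraphBuilder>
static ggml_allocr * measure_and_init_allocr(std::vector<uint8_t> & data, GraphBuilder build_graph) {
    static const size_t tensor_alignment = 32;

    // measure pass: the dummy allocator only tracks how much memory the graph would need
    ggml_allocr * alloc = ggml_allocr_new_measure(tensor_alignment);
    const size_t alloc_size = ggml_allocr_alloc_graph(alloc, build_graph()) + tensor_alignment;
    ggml_allocr_free(alloc);

    // real pass: allocate the data buffer and create the allocator that is reused later
    data.resize(alloc_size);
    return ggml_allocr_new(data.data(), data.size(), tensor_alignment);
}

At run time the same allocators are then reused per call: ggml_allocr_reset(), rebuild the graph, ggml_allocr_alloc_graph(), as in the whisper_encode_internal() hunk above.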
@@ -3071,8 +3076,8 @@ void whisper_free_state(struct whisper_state * state)
            ggml_allocr_free(state->alloc_encode);
        }
 
-        if (state->alloc_encode_post) {
-            ggml_allocr_free(state->alloc_encode_post);
+        if (state->alloc_cross) {
+            ggml_allocr_free(state->alloc_cross);
         }
 
         if (state->alloc_decode) {