whisper : add whisper_allocr to wrap ggml_allocr

Georgi Gerganov
2023-09-13 11:58:19 +03:00
parent b19888cfb4
commit 254b687239
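This commit replaces the four parallel alloc_* / meta_* / data_* fields in whisper_state with one whisper_allocr wrapper per graph (conv, encode, cross, decode), plus a whisper_allocr_free helper. As a rough orientation, a wrapper like this is typically sized and created with the ggml-alloc measure API of this period as sketched below; the helper name, the std::function graph builder, the alignment value and the size padding are illustrative assumptions, not code from this commit.

// Illustrative sketch only (assumes ggml.h, ggml-alloc.h, <functional> and the
// whisper_allocr struct introduced below); not code from this commit.
static void whisper_allocr_graph_init_sketch(whisper_allocr & allocr,
                                             std::function<ggml_cgraph *()> get_graph) {
    const size_t tensor_alignment = 32; // assumed alignment

    // `meta` backs a no_alloc ggml context used only for graph/tensor metadata
    allocr.meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());

    // measure pass: build the graph once to learn how much tensor data it needs
    allocr.alloc = ggml_allocr_new_measure(tensor_alignment);
    const size_t alloc_size = ggml_allocr_alloc_graph(allocr.alloc, get_graph()) + tensor_alignment;
    ggml_allocr_free(allocr.alloc);

    // real allocator backed by the `data` buffer
    allocr.data.resize(alloc_size);
    allocr.alloc = ggml_allocr_new(allocr.data.data(), allocr.data.size(), tensor_alignment);
}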

whisper.cpp

@@ -596,6 +596,21 @@ struct kv_buf {
std::vector<uint8_t> v;
};
+ // ggml_allocr wrapper for whisper usage
+ struct whisper_allocr {
+ ggml_allocr * alloc = nullptr;
+ std::vector<uint8_t> meta;
+ std::vector<uint8_t> data;
+ };
+ static void whisper_allocr_free(struct whisper_allocr & allocr) {
+ if (allocr.alloc) {
+ ggml_allocr_free(allocr.alloc);
+ allocr.alloc = nullptr;
+ }
+ }
struct whisper_state {
int64_t t_sample_us = 0;
int64_t t_encode_us = 0;
@@ -622,25 +637,12 @@ struct whisper_state {
std::vector<uint8_t> work_buffer;
// ggml-alloc:
- // - stores meta info about the intermediate tensors into the `meta_*` buffers
- // - stores the actual tensor data into the `data_*` buffers
- ggml_allocr * alloc_conv = nullptr;
- ggml_allocr * alloc_encode = nullptr;
- ggml_allocr * alloc_cross = nullptr;
- ggml_allocr * alloc_decode = nullptr;
- // meta data
- std::vector<uint8_t> meta_conv;
- std::vector<uint8_t> meta_encode;
- std::vector<uint8_t> meta_cross;
- std::vector<uint8_t> meta_decode;
- // tensor data
- std::vector<uint8_t> data_conv;
- std::vector<uint8_t> data_encode;
- std::vector<uint8_t> data_cross;
- std::vector<uint8_t> data_decode;
+ // - stores meta info about the intermediate tensors into the `meta` buffers
+ // - stores the actual tensor data into the `data` buffers
+ whisper_allocr alloc_conv;
+ whisper_allocr alloc_encode;
+ whisper_allocr alloc_cross;
+ whisper_allocr alloc_decode;
// result of the encoder
struct ggml_tensor * embd_conv = nullptr;
@@ -1437,8 +1439,8 @@ static struct ggml_cgraph * whisper_build_graph_conv(
const int n_mels = hparams.n_mels;
struct ggml_init_params params = {
- /*.mem_size =*/ wstate.meta_conv.size(),
- /*.mem_buffer =*/ wstate.meta_conv.data(),
+ /*.mem_size =*/ wstate.alloc_conv.meta.size(),
+ /*.mem_buffer =*/ wstate.alloc_conv.meta.data(),
/*.no_alloc =*/ true,
};
@@ -1446,7 +1448,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
ggml_cgraph * gf = ggml_new_graph(ctx0);
- ggml_allocr * alloc = wstate.alloc_conv;
+ ggml_allocr * alloc = wstate.alloc_conv.alloc;
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
ggml_allocr_alloc(alloc, mel);
@@ -1533,8 +1535,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
const int n_layer = hparams.n_audio_layer;
struct ggml_init_params params = {
- /*.mem_size =*/ wstate.meta_encode.size(),
- /*.mem_buffer =*/ wstate.meta_encode.data(),
+ /*.mem_size =*/ wstate.alloc_encode.meta.size(),
+ /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
/*.no_alloc =*/ true,
};
@@ -1542,7 +1544,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
ggml_cgraph * gf = ggml_new_graph(ctx0);
- ggml_allocr * alloc = wstate.alloc_encode;
+ ggml_allocr * alloc = wstate.alloc_encode.alloc;
struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(alloc, KQscale);
@@ -1782,8 +1784,8 @@ static struct ggml_cgraph * whisper_build_graph_cross(
const int n_head = hparams.n_audio_head;
struct ggml_init_params params = {
- /*.mem_size =*/ wstate.meta_cross.size(),
- /*.mem_buffer =*/ wstate.meta_cross.data(),
+ /*.mem_size =*/ wstate.alloc_cross.meta.size(),
+ /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
/*.no_alloc =*/ true,
};
@@ -1791,7 +1793,7 @@ static struct ggml_cgraph * whisper_build_graph_cross(
ggml_cgraph * gf = ggml_new_graph(ctx0);
- ggml_allocr * alloc = wstate.alloc_cross;
+ ggml_allocr * alloc = wstate.alloc_cross.alloc;
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
@@ -1859,7 +1861,7 @@ static bool whisper_encode_internal(
// conv
{
- auto & alloc = wstate.alloc_conv;
+ auto & alloc = wstate.alloc_conv.alloc;
ggml_allocr_reset(alloc);
@@ -1874,7 +1876,7 @@ static bool whisper_encode_internal(
// encoder
if (!whisper_encode_external(wstate)) {
- auto & alloc = wstate.alloc_encode;
+ auto & alloc = wstate.alloc_encode.alloc;
ggml_allocr_reset(alloc);
@@ -1896,7 +1898,7 @@ static bool whisper_encode_internal(
// cross
{
- auto & alloc = wstate.alloc_cross;
+ auto & alloc = wstate.alloc_cross.alloc;
ggml_allocr_reset(alloc);
@@ -1949,8 +1951,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
//WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
struct ggml_init_params params = {
- /*.mem_size =*/ wstate.meta_decode.size(),
- /*.mem_buffer =*/ wstate.meta_decode.data(),
+ /*.mem_size =*/ wstate.alloc_decode.meta.size(),
+ /*.mem_buffer =*/ wstate.alloc_decode.meta.data(),
/*.no_alloc =*/ true,
};
@@ -1958,7 +1960,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
ggml_cgraph * gf = ggml_new_graph(ctx0);
- ggml_allocr * alloc = wstate.alloc_decode;
+ ggml_allocr * alloc = wstate.alloc_decode.alloc;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
ggml_allocr_alloc(alloc, embd);
@@ -2292,7 +2294,7 @@ static bool whisper_decode_internal(
// decoder
{
- auto & alloc = wstate.alloc_decode;
+ auto & alloc = wstate.alloc_decode.alloc;
ggml_allocr_reset(alloc);
@@ -2801,9 +2803,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// conv allocator
{
- auto & alloc = state->alloc_conv;
- auto & meta = state->meta_conv;
- auto & data = state->data_conv;
+ auto & alloc = state->alloc_conv.alloc;
+ auto & meta = state->alloc_conv.meta;
+ auto & data = state->alloc_conv.data;
meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
@@ -2823,9 +2825,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// encoder allocator
if (!whisper_encode_external(*state)) {
- auto & alloc = state->alloc_encode;
- auto & meta = state->meta_encode;
- auto & data = state->data_encode;
+ auto & alloc = state->alloc_encode.alloc;
+ auto & meta = state->alloc_encode.meta;
+ auto & data = state->alloc_encode.data;
meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
@@ -2845,9 +2847,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// cross allocator
{
- auto & alloc = state->alloc_cross;
- auto & meta = state->meta_cross;
- auto & data = state->data_cross;
+ auto & alloc = state->alloc_cross.alloc;
+ auto & meta = state->alloc_cross.meta;
+ auto & data = state->alloc_cross.data;
meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
@@ -2867,9 +2869,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// decoder allocator
{
- auto & alloc = state->alloc_decode;
- auto & meta = state->meta_decode;
- auto & data = state->data_decode;
+ auto & alloc = state->alloc_decode.alloc;
+ auto & meta = state->alloc_decode.meta;
+ auto & data = state->alloc_decode.data;
meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
@@ -2933,19 +2935,18 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->meta_conv.data(), state->meta_conv.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->meta_encode.data(), state->meta_encode.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->meta_cross.data(), state->meta_cross.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->meta_decode.data(), state->meta_decode.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->data_conv.data(), state->data_conv.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->data_encode.data(), state->data_encode.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->data_cross.data(), state->data_cross.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->data_decode.data(), state->data_decode.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
// TODO: handle multiple decoders
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
#undef WHISPER_METAL_CHECK_BUF
#endif
@@ -3171,17 +3172,10 @@ void whisper_free_state(struct whisper_state * state)
}
#endif
- if (state->alloc_encode) {
- ggml_allocr_free(state->alloc_encode);
- }
- if (state->alloc_cross) {
- ggml_allocr_free(state->alloc_cross);
- }
- if (state->alloc_decode) {
- ggml_allocr_free(state->alloc_decode);
- }
+ whisper_allocr_free(state->alloc_conv);
+ whisper_allocr_free(state->alloc_decode);
+ whisper_allocr_free(state->alloc_cross);
+ whisper_allocr_free(state->alloc_encode);
delete state;
}
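For reference, the per-call pattern that whisper_encode_internal and whisper_decode_internal follow after this change is: reset the wrapper's allocator, rebuild the graph into `meta`, then let ggml-alloc place the remaining tensors inside `data`. The sketch below illustrates that sequence; the helper name and the std::function builder are assumptions, and the actual graph computation (CPU or Metal) is omitted.

// Illustrative sketch only; `build_graph` stands in for
// whisper_build_graph_conv / _encoder / _cross / _decoder.
static void whisper_allocr_run_sketch(whisper_allocr & allocr,
                                      std::function<ggml_cgraph *()> build_graph) {
    // the same meta/data buffers are reused on every call
    ggml_allocr_reset(allocr.alloc);

    // graph building records tensor metadata in `meta` (no_alloc context);
    // input tensors are placed explicitly via ggml_allocr_alloc while building
    ggml_cgraph * gf = build_graph();

    // assign offsets inside `data` for the remaining intermediate tensors
    ggml_allocr_alloc_graph(allocr.alloc, gf);

    // ... compute the graph (backend-specific, not shown) ...
}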