mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-08-14 15:18:50 +02:00
whisper : add whisper_allocr to wrap ggml_allocr
This commit is contained in:
128
whisper.cpp
128
whisper.cpp
@ -596,6 +596,21 @@ struct kv_buf {
|
||||
std::vector<uint8_t> v;
|
||||
};
|
||||
|
||||
// ggml_allocr wrapper for whisper usage
|
||||
struct whisper_allocr {
    // handle to the underlying ggml allocator; stays nullptr until one is created
    // - this struct has no destructor, release it with whisper_allocr_free()
    struct ggml_allocr * alloc{nullptr};

    // backing storage for the meta info about the intermediate tensors
    // - passed as the mem_buffer of the ggml context used to build the graph
    std::vector<uint8_t> meta;

    // backing storage for the actual tensor data
    std::vector<uint8_t> data;
};
|
||||
|
||||
static void whisper_allocr_free(struct whisper_allocr & allocr) {
|
||||
if (allocr.alloc) {
|
||||
ggml_allocr_free(allocr.alloc);
|
||||
allocr.alloc = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
struct whisper_state {
|
||||
int64_t t_sample_us = 0;
|
||||
int64_t t_encode_us = 0;
|
||||
@ -622,25 +637,12 @@ struct whisper_state {
|
||||
std::vector<uint8_t> work_buffer;
|
||||
|
||||
// ggml-alloc:
|
||||
// - stores meta info about the intermediate tensors into the `meta_*` buffers
|
||||
// - stores the actual tensor data into the `data_*` buffers
|
||||
|
||||
ggml_allocr * alloc_conv = nullptr;
|
||||
ggml_allocr * alloc_encode = nullptr;
|
||||
ggml_allocr * alloc_cross = nullptr;
|
||||
ggml_allocr * alloc_decode = nullptr;
|
||||
|
||||
// meta data
|
||||
std::vector<uint8_t> meta_conv;
|
||||
std::vector<uint8_t> meta_encode;
|
||||
std::vector<uint8_t> meta_cross;
|
||||
std::vector<uint8_t> meta_decode;
|
||||
|
||||
// tensor data
|
||||
std::vector<uint8_t> data_conv;
|
||||
std::vector<uint8_t> data_encode;
|
||||
std::vector<uint8_t> data_cross;
|
||||
std::vector<uint8_t> data_decode;
|
||||
// - stores meta info about the intermediate tensors into the `meta` buffers
|
||||
// - stores the actual tensor data into the `data` buffers
|
||||
whisper_allocr alloc_conv;
|
||||
whisper_allocr alloc_encode;
|
||||
whisper_allocr alloc_cross;
|
||||
whisper_allocr alloc_decode;
|
||||
|
||||
// result of the encoder
|
||||
struct ggml_tensor * embd_conv = nullptr;
|
||||
@ -1437,8 +1439,8 @@ static struct ggml_cgraph * whisper_build_graph_conv(
|
||||
const int n_mels = hparams.n_mels;
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ wstate.meta_conv.size(),
|
||||
/*.mem_buffer =*/ wstate.meta_conv.data(),
|
||||
/*.mem_size =*/ wstate.alloc_conv.meta.size(),
|
||||
/*.mem_buffer =*/ wstate.alloc_conv.meta.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
@ -1446,7 +1448,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
ggml_allocr * alloc = wstate.alloc_conv;
|
||||
ggml_allocr * alloc = wstate.alloc_conv.alloc;
|
||||
|
||||
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
|
||||
ggml_allocr_alloc(alloc, mel);
|
||||
@ -1533,8 +1535,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||
const int n_layer = hparams.n_audio_layer;
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ wstate.meta_encode.size(),
|
||||
/*.mem_buffer =*/ wstate.meta_encode.data(),
|
||||
/*.mem_size =*/ wstate.alloc_encode.meta.size(),
|
||||
/*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
@ -1542,7 +1544,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
ggml_allocr * alloc = wstate.alloc_encode;
|
||||
ggml_allocr * alloc = wstate.alloc_encode.alloc;
|
||||
|
||||
struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
||||
ggml_allocr_alloc(alloc, KQscale);
|
||||
@ -1782,8 +1784,8 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
||||
const int n_head = hparams.n_audio_head;
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ wstate.meta_cross.size(),
|
||||
/*.mem_buffer =*/ wstate.meta_cross.data(),
|
||||
/*.mem_size =*/ wstate.alloc_cross.meta.size(),
|
||||
/*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
@ -1791,7 +1793,7 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
ggml_allocr * alloc = wstate.alloc_cross;
|
||||
ggml_allocr * alloc = wstate.alloc_cross.alloc;
|
||||
|
||||
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
|
||||
|
||||
@ -1859,7 +1861,7 @@ static bool whisper_encode_internal(
|
||||
|
||||
// conv
|
||||
{
|
||||
auto & alloc = wstate.alloc_conv;
|
||||
auto & alloc = wstate.alloc_conv.alloc;
|
||||
|
||||
ggml_allocr_reset(alloc);
|
||||
|
||||
@ -1874,7 +1876,7 @@ static bool whisper_encode_internal(
|
||||
|
||||
// encoder
|
||||
if (!whisper_encode_external(wstate)) {
|
||||
auto & alloc = wstate.alloc_encode;
|
||||
auto & alloc = wstate.alloc_encode.alloc;
|
||||
|
||||
ggml_allocr_reset(alloc);
|
||||
|
||||
@ -1896,7 +1898,7 @@ static bool whisper_encode_internal(
|
||||
|
||||
// cross
|
||||
{
|
||||
auto & alloc = wstate.alloc_cross;
|
||||
auto & alloc = wstate.alloc_cross.alloc;
|
||||
|
||||
ggml_allocr_reset(alloc);
|
||||
|
||||
@ -1949,8 +1951,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
//WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ wstate.meta_decode.size(),
|
||||
/*.mem_buffer =*/ wstate.meta_decode.data(),
|
||||
/*.mem_size =*/ wstate.alloc_decode.meta.size(),
|
||||
/*.mem_buffer =*/ wstate.alloc_decode.meta.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
@ -1958,7 +1960,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
ggml_allocr * alloc = wstate.alloc_decode;
|
||||
ggml_allocr * alloc = wstate.alloc_decode.alloc;
|
||||
|
||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
ggml_allocr_alloc(alloc, embd);
|
||||
@ -2292,7 +2294,7 @@ static bool whisper_decode_internal(
|
||||
|
||||
// decoder
|
||||
{
|
||||
auto & alloc = wstate.alloc_decode;
|
||||
auto & alloc = wstate.alloc_decode.alloc;
|
||||
|
||||
ggml_allocr_reset(alloc);
|
||||
|
||||
@ -2801,9 +2803,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
|
||||
// conv allocator
|
||||
{
|
||||
auto & alloc = state->alloc_conv;
|
||||
auto & meta = state->meta_conv;
|
||||
auto & data = state->data_conv;
|
||||
auto & alloc = state->alloc_conv.alloc;
|
||||
auto & meta = state->alloc_conv.meta;
|
||||
auto & data = state->alloc_conv.data;
|
||||
|
||||
meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
|
||||
|
||||
@ -2823,9 +2825,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
|
||||
// encoder allocator
|
||||
if (!whisper_encode_external(*state)) {
|
||||
auto & alloc = state->alloc_encode;
|
||||
auto & meta = state->meta_encode;
|
||||
auto & data = state->data_encode;
|
||||
auto & alloc = state->alloc_encode.alloc;
|
||||
auto & meta = state->alloc_encode.meta;
|
||||
auto & data = state->alloc_encode.data;
|
||||
|
||||
meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
|
||||
|
||||
@ -2845,9 +2847,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
|
||||
// cross allocator
|
||||
{
|
||||
auto & alloc = state->alloc_cross;
|
||||
auto & meta = state->meta_cross;
|
||||
auto & data = state->data_cross;
|
||||
auto & alloc = state->alloc_cross.alloc;
|
||||
auto & meta = state->alloc_cross.meta;
|
||||
auto & data = state->alloc_cross.data;
|
||||
|
||||
meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
|
||||
|
||||
@ -2867,9 +2869,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
|
||||
// decoder allocator
|
||||
{
|
||||
auto & alloc = state->alloc_decode;
|
||||
auto & meta = state->meta_decode;
|
||||
auto & data = state->data_decode;
|
||||
auto & alloc = state->alloc_decode.alloc;
|
||||
auto & meta = state->alloc_decode.meta;
|
||||
auto & data = state->alloc_decode.data;
|
||||
|
||||
meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
|
||||
|
||||
@ -2933,19 +2935,18 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
|
||||
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->meta_conv.data(), state->meta_conv.size(), 0));
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->meta_encode.data(), state->meta_encode.size(), 0));
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->meta_cross.data(), state->meta_cross.size(), 0));
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->meta_decode.data(), state->meta_decode.size(), 0));
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0));
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0));
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
|
||||
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->data_conv.data(), state->data_conv.size(), 0));
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->data_encode.data(), state->data_encode.size(), 0));
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->data_cross.data(), state->data_cross.size(), 0));
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->data_decode.data(), state->data_decode.size(), 0));
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0));
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0));
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
|
||||
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
|
||||
|
||||
// TODO: handle multiple decoders
|
||||
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
|
||||
#undef WHISPER_METAL_CHECK_BUF
|
||||
#endif
|
||||
@ -3171,17 +3172,10 @@ void whisper_free_state(struct whisper_state * state)
|
||||
}
|
||||
#endif
|
||||
|
||||
if (state->alloc_encode) {
|
||||
ggml_allocr_free(state->alloc_encode);
|
||||
}
|
||||
|
||||
if (state->alloc_cross) {
|
||||
ggml_allocr_free(state->alloc_cross);
|
||||
}
|
||||
|
||||
if (state->alloc_decode) {
|
||||
ggml_allocr_free(state->alloc_decode);
|
||||
}
|
||||
whisper_allocr_free(state->alloc_conv);
|
||||
whisper_allocr_free(state->alloc_decode);
|
||||
whisper_allocr_free(state->alloc_cross);
|
||||
whisper_allocr_free(state->alloc_encode);
|
||||
|
||||
delete state;
|
||||
}
|
||||
|
Reference in New Issue
Block a user