whisper : offload the Encoder to Metal

This commit is contained in:
Georgi Gerganov
2023-09-13 00:09:44 +03:00
parent ec9a7db74c
commit 3074a7ff14

View File

@ -625,21 +625,25 @@ struct whisper_state {
// - stores meta info about the intermediate tensors into the `meta_*` buffers // - stores meta info about the intermediate tensors into the `meta_*` buffers
// - stores the actual tensor data into the `data_*` buffers // - stores the actual tensor data into the `data_*` buffers
ggml_allocr * alloc_conv = nullptr;
ggml_allocr * alloc_encode = nullptr; ggml_allocr * alloc_encode = nullptr;
ggml_allocr * alloc_cross = nullptr; ggml_allocr * alloc_cross = nullptr;
ggml_allocr * alloc_decode = nullptr; ggml_allocr * alloc_decode = nullptr;
// meta data // meta data
std::vector<uint8_t> meta_conv;
std::vector<uint8_t> meta_encode; std::vector<uint8_t> meta_encode;
std::vector<uint8_t> meta_cross; std::vector<uint8_t> meta_cross;
std::vector<uint8_t> meta_decode; std::vector<uint8_t> meta_decode;
// tensor data // tensor data
std::vector<uint8_t> data_conv;
std::vector<uint8_t> data_encode; std::vector<uint8_t> data_encode;
std::vector<uint8_t> data_cross; std::vector<uint8_t> data_cross;
std::vector<uint8_t> data_decode; std::vector<uint8_t> data_decode;
// result of the encoder // result of the encoder
struct ggml_tensor * embd_conv = nullptr;
struct ggml_tensor * embd_enc = nullptr; struct ggml_tensor * embd_enc = nullptr;
// decode output (2-dimensional array: [n_tokens][n_vocab]) // decode output (2-dimensional array: [n_tokens][n_vocab])
@ -1401,7 +1405,23 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
return true; return true;
} }
static struct ggml_cgraph * whisper_build_graph_encoder( static bool whisper_encode_external(const whisper_state & wstate) {
#ifndef WHISPER_USE_COREML
const bool use_coreml = false;
#else
const bool use_coreml = wstate.ctx_coreml != nullptr;
#endif
#ifndef WHISPER_USE_OPENVINO
const bool use_openvino = false;
#else
const bool use_openvino = wstate.ctx_openvino != nullptr;
#endif
return use_coreml || use_openvino;
}
static struct ggml_cgraph * whisper_build_graph_conv(
whisper_context & wctx, whisper_context & wctx,
whisper_state & wstate, whisper_state & wstate,
const int mel_offset) { const int mel_offset) {
@ -1410,15 +1430,13 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
const int n_state = hparams.n_audio_state; const int n_state = hparams.n_audio_state; GGML_UNUSED(n_state);
const int n_head = hparams.n_audio_head;
const int n_layer = hparams.n_audio_layer;
const int n_mels = hparams.n_mels; const int n_mels = hparams.n_mels;
struct ggml_init_params params = { struct ggml_init_params params = {
/*.mem_size =*/ wstate.meta_encode.size(), /*.mem_size =*/ wstate.meta_conv.size(),
/*.mem_buffer =*/ wstate.meta_encode.data(), /*.mem_buffer =*/ wstate.meta_conv.data(),
/*.no_alloc =*/ true, /*.no_alloc =*/ true,
}; };
@ -1426,7 +1444,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
ggml_cgraph * gf = ggml_new_graph(ctx0); ggml_cgraph * gf = ggml_new_graph(ctx0);
ggml_allocr * alloc = wstate.alloc_encode; ggml_allocr * alloc = wstate.alloc_conv;
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
ggml_allocr_alloc(alloc, mel); ggml_allocr_alloc(alloc, mel);
@ -1448,30 +1466,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
} }
} }
ggml_build_forward_expand(gf, mel); struct ggml_tensor * cur = nullptr;
struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); if (!whisper_encode_external(wstate)) {
ggml_allocr_alloc(alloc, KQscale);
if (!ggml_allocr_is_measure(alloc)) {
ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head));
}
struct ggml_tensor * cur;
#ifndef WHISPER_USE_COREML
const bool use_coreml = false;
#else
const bool use_coreml = wstate.ctx_coreml != nullptr;
#endif
#ifndef WHISPER_USE_OPENVINO
const bool use_openvino = false;
#else
const bool use_openvino = wstate.ctx_openvino != nullptr;
#endif
if (!use_coreml && !use_openvino) {
// convolution + gelu // convolution + gelu
{ {
cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1); cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
@ -1493,6 +1490,67 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
cur = ggml_gelu(ctx0, cur); cur = ggml_gelu(ctx0, cur);
} }
wstate.embd_conv = cur;
} else {
#ifdef WHISPER_USE_COREML
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
ggml_allocr_alloc(alloc, cur);
if (!ggml_allocr_is_measure(alloc)) {
whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
}
#endif
#ifdef WHISPER_USE_OPENVINO
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
ggml_allocr_alloc(alloc, cur);
if (!ggml_allocr_is_measure(alloc)) {
whisper_openvino_encode(wstate.ctx_openvino, mel, cur);
}
#endif
wstate.embd_enc = cur;
}
ggml_build_forward_expand(gf, cur);
ggml_free(ctx0);
return gf;
}
static struct ggml_cgraph * whisper_build_graph_encoder(
whisper_context & wctx,
whisper_state & wstate) {
const auto & model = wctx.model;
const auto & hparams = model.hparams;
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
const int n_state = hparams.n_audio_state;
const int n_head = hparams.n_audio_head;
const int n_layer = hparams.n_audio_layer;
struct ggml_init_params params = {
/*.mem_size =*/ wstate.meta_encode.size(),
/*.mem_buffer =*/ wstate.meta_encode.data(),
/*.no_alloc =*/ true,
};
struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph * gf = ggml_new_graph(ctx0);
ggml_allocr * alloc = wstate.alloc_encode;
struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(alloc, KQscale);
if (!ggml_allocr_is_measure(alloc)) {
ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head));
}
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
// =================================================================== // ===================================================================
// NOTE: experimenting with partial evaluation of the encoder (ignore) // NOTE: experimenting with partial evaluation of the encoder (ignore)
//static int iter = -1; //static int iter = -1;
@ -1512,7 +1570,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur)); cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
// =================================================================== // ===================================================================
@ -1689,27 +1747,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
ggml_mul(ctx0, cur, model.e_ln_w), ggml_mul(ctx0, cur, model.e_ln_w),
model.e_ln_b); model.e_ln_b);
} }
}
#ifdef WHISPER_USE_COREML
else if (use_coreml) {
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
ggml_allocr_alloc(alloc, cur);
if (!ggml_allocr_is_measure(alloc)) {
whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
}
}
#endif
#ifdef WHISPER_USE_OPENVINO
else if (use_openvino) {
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
ggml_allocr_alloc(alloc, cur);
if (!ggml_allocr_is_measure(alloc)) {
whisper_openvino_encode(wstate.ctx_openvino, mel, cur);
}
}
#endif
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
@ -1818,17 +1855,38 @@ static bool whisper_encode_internal(
const int n_threads) { const int n_threads) {
const int64_t t_start_us = ggml_time_us(); const int64_t t_start_us = ggml_time_us();
// encoder // conv
{ {
auto & alloc = wstate.alloc_conv;
ggml_allocr_reset(alloc);
ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
ggml_allocr_alloc_graph(alloc, gf);
if (!whisper_encode_external(wstate)) {
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
}
}
// encoder
if (!whisper_encode_external(wstate)) {
auto & alloc = wstate.alloc_encode; auto & alloc = wstate.alloc_encode;
ggml_allocr_reset(alloc); ggml_allocr_reset(alloc);
ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate, mel_offset); ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);
ggml_allocr_alloc_graph(alloc, gf); ggml_allocr_alloc_graph(alloc, gf);
#ifdef WHISPER_USE_COREML #ifdef GGML_USE_METAL
if (wstate.ctx_metal) {
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
ggml_metal_graph_compute(wstate.ctx_metal, gf);
} else {
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
}
#else #else
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
#endif #endif
@ -1845,7 +1903,7 @@ static bool whisper_encode_internal(
ggml_allocr_alloc_graph(alloc, gf); ggml_allocr_alloc_graph(alloc, gf);
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
if (wstate.ctx_metal && false) { if (wstate.ctx_metal) {
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
ggml_metal_graph_compute(wstate.ctx_metal, gf); ggml_metal_graph_compute(wstate.ctx_metal, gf);
} else { } else {
@ -2739,8 +2797,30 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
static const size_t tensor_alignment = 32; static const size_t tensor_alignment = 32;
// encoder allocator // conv allocator
{ {
auto & alloc = state->alloc_conv;
auto & meta = state->meta_conv;
auto & data = state->data_conv;
meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
alloc = ggml_allocr_new_measure(tensor_alignment);
ggml_cgraph * gf = whisper_build_graph_conv(*ctx, *state, 0);
const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
ggml_allocr_free(alloc);
log("%s: compute buffer (conv) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
data.resize(alloc_size);
alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
}
// encoder allocator
if (!whisper_encode_external(*state)) {
auto & alloc = state->alloc_encode; auto & alloc = state->alloc_encode;
auto & meta = state->meta_encode; auto & meta = state->meta_encode;
auto & data = state->data_encode; auto & data = state->data_encode;
@ -2749,7 +2829,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
alloc = ggml_allocr_new_measure(tensor_alignment); alloc = ggml_allocr_new_measure(tensor_alignment);
ggml_cgraph * gf = whisper_build_graph_encoder(*ctx, *state, 0); ggml_cgraph * gf = whisper_build_graph_encoder(*ctx, *state);
const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment; const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
@ -2851,10 +2931,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size)); WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->meta_conv.data(), state->meta_conv.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->meta_encode.data(), state->meta_encode.size(), 0)); WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->meta_encode.data(), state->meta_encode.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->meta_cross.data(), state->meta_cross.size(), 0)); WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->meta_cross.data(), state->meta_cross.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->meta_decode.data(), state->meta_decode.size(), 0)); WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->meta_decode.data(), state->meta_decode.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->data_conv.data(), state->data_conv.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->data_encode.data(), state->data_encode.size(), 0)); WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->data_encode.data(), state->data_encode.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->data_cross.data(), state->data_cross.size(), 0)); WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->data_cross.data(), state->data_cross.size(), 0));
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->data_decode.data(), state->data_decode.size(), 0)); WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->data_decode.data(), state->data_decode.size(), 0));