whisper : wip sched (not working yet)

commit bf4110dbcf
parent 005b8ccbf0
Georgi Gerganov  2023-11-09 19:07:54 +02:00

whisper.cpp

@@ -715,6 +715,11 @@ struct whisper_state {
     int32_t n_fail_p = 0; // number of logprob threshold failures
     int32_t n_fail_h = 0; // number of entropy threshold failures
 
+    // TODO: free this
+    struct ggml_context * ctx;
+
+    ggml_backend_buffer_t buffer;
+
     // cross-attention KV cache for the decoders
     // shared between all decoders
     whisper_kv_cache kv_cross;
@@ -725,9 +730,6 @@ struct whisper_state {
     // buffer for swapping KV caches between decoders during beam-search
     std::vector<kv_buf> kv_swap_bufs;
 
-    // reusable buffer for `struct ggml_graph_plan.work_data`
-    std::vector<uint8_t> work_buffer;
-
     // backend schedulers
     ggml_backend_sched_t sched_conv;
     ggml_backend_sched_t sched_encode;
@@ -735,8 +737,10 @@ struct whisper_state {
     ggml_backend_sched_t sched_decode;
 
     // result of the encoder
-    struct ggml_tensor * embd_conv = nullptr;
-    struct ggml_tensor * embd_enc  = nullptr;
+    struct ggml_tensor * inp_mel      = nullptr;
+    struct ggml_tensor * inp_kq_scale = nullptr;
+    struct ggml_tensor * out_conv     = nullptr;
+    struct ggml_tensor * out_enc      = nullptr;
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
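
These members replace embd_conv/embd_enc, which pointed into graph-owned memory, with state-owned I/O tensors: the inp_* tensors are written by the caller before a graph runs, and the out_* tensors carry results from one graph to the next. A minimal sketch of feeding the mel input; mel_data is a placeholder for the host-side spectrogram, not part of this commit:

    // inp_mel is allocated in a backend buffer (state->buffer), so it must be
    // written with the backend-aware setter rather than a plain memcpy
    ggml_backend_tensor_set(state->inp_mel, mel_data, 0, ggml_nbytes(state->inp_mel));
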
@@ -1546,7 +1550,8 @@ static struct ggml_cgraph * whisper_build_graph_conv(
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
+    //struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
+    struct ggml_tensor * mel = wstate.inp_mel;
 
     //ggml_allocr_alloc(alloc, mel);
     //assert(mel->type == GGML_TYPE_F32);
@@ -1588,9 +1593,8 @@ static struct ggml_cgraph * whisper_build_graph_conv(
                     cur);
             cur = ggml_gelu(ctx0, cur);
+            ggml_set_name(cur, "out_conv");
         }
 
-        wstate.embd_conv = cur;
-
     } else {
 #ifdef WHISPER_USE_COREML
         cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
@@ -1609,7 +1613,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
         }
 #endif
 
-        wstate.embd_enc = cur;
+        ggml_set_name(cur, "out_enc");
     }
 
     ggml_build_forward_expand(gf, cur);
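
The ggml_set_name calls replace the direct wstate.embd_* assignments: once a scheduler owns and reuses graph memory, results are looked up by name after each build instead of being captured by pointer. A sketch of the lookup, assuming it runs right after whisper_build_graph_conv returns gf:

    // locate the conv result in the built graph by its new name
    struct ggml_tensor * res = ggml_graph_get_tensor(gf, "out_conv");
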
@@ -1650,7 +1654,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     //    ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head));
     //}
 
-    struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
+    // TODO: use ggml_backend_tensor_copy to copy the data from the previous graph
+    struct ggml_tensor * cur = wstate.out_conv;
 
     // ===================================================================
     // NOTE: experimenting with partial evaluation of the encoder (ignore)
@@ -1670,7 +1675,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
 
     struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
 
     cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
-
     // ===================================================================
@@ -1851,7 +1855,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
     ggml_build_forward_expand(gf, cur);
 
-    wstate.embd_enc = cur;
+    //wstate.out_enc = cur;
 
     //ggml_graph_print(gf);
@@ -1893,7 +1897,8 @@ static struct ggml_cgraph * whisper_build_graph_cross(
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
+    // TODO: use ggml_backend_tensor_copy to copy the data from the previous graph
+    struct ggml_tensor * cur = wstate.out_enc;
 
     struct ggml_tensor * Kscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
     //ggml_allocr_alloc(alloc, Kscale);
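
Both TODOs point at the same missing glue: after one graph computes, its named result has to be copied into the state-owned tensor before the next graph reads it. A hedged sketch of that step for the conv to encoder handoff; the exact placement in the encode path is an assumption:

    // publish the conv output into the stable state tensor that
    // whisper_build_graph_encoder() reads from; ggml_backend_tensor_copy
    // handles src and dst living on different backends
    struct ggml_tensor * res = ggml_graph_get_tensor(gf, "out_conv");
    if (res != NULL && res != wstate.out_conv) {
        ggml_backend_tensor_copy(res, wstate.out_conv);
    }
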
@@ -2862,15 +2867,73 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     state->decoders[0].logits.reserve  (ctx->vocab.n_vocab);
     state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
 
+    // per-state context for input and output tensors
+    {
+        const size_t n_tensors = 100; // TODO: fix
+
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ ggml_tensor_overhead()*n_tensors,
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ true,
+        };
+
+        state->ctx = ggml_init(params);
+        if (!state->ctx) {
+            WHISPER_LOG_ERROR("%s: ggml_init() failed\n", __func__);
+            delete state;
+            return nullptr;
+        }
+    }
+
     auto & backends = ctx->backends;
 
+    // allocate input tensors
+    {
+        const auto & hparams = ctx->model.hparams;
+
+        const int n_ctx   = state->exp_n_audio_ctx > 0 ? state->exp_n_audio_ctx : hparams.n_audio_ctx;
+        const int n_mels  = hparams.n_mels;
+        const int n_state = hparams.n_audio_state;
+
+        auto & ctx = state->ctx;
+
+        state->inp_mel      = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*n_ctx, n_mels);
+        state->inp_kq_scale = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+
+        state->out_conv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ctx, n_state);
+        state->out_enc  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_state, n_ctx);
+
+        const size_t size_inp =
+            (ggml_tensor_overhead() + ggml_nbytes(state->inp_mel))      +
+            (ggml_tensor_overhead() + ggml_nbytes(state->inp_kq_scale)) +
+            (ggml_tensor_overhead() + ggml_nbytes(state->out_conv))     +
+            (ggml_tensor_overhead() + ggml_nbytes(state->out_enc));
+
+        ggml_backend_t backend_input = backends[0];
+
+        state->buffer = ggml_backend_alloc_buffer(backend_input, size_inp);
+
+        printf("%s: backend_in = %s (%zu bytes)\n", __func__, ggml_backend_name(backend_input), size_inp);
+
+        ggml_allocr * alloc = ggml_allocr_new_from_buffer(state->buffer);
+
+        ggml_allocr_alloc(alloc, state->inp_mel);
+        ggml_allocr_alloc(alloc, state->inp_kq_scale);
+        ggml_allocr_alloc(alloc, state->out_conv);
+        ggml_allocr_alloc(alloc, state->out_enc);
+
+        ggml_allocr_free(alloc);
+
+        // initialize KQ_scale
+        //float s = 1.0f/sqrtf(float(hparams.n_embd)/hparams.n_head);
+        //ggml_backend_tensor_set(model.KQ_scale, &s, 0, sizeof(s));
+    }
+
     // conv allocator
     {
         auto & sched = state->sched_conv;
 
         sched = ggml_backend_sched_new(backends.data(), backends.size());
 
-        ggml_backend_sched_init_measure(sched, whisper_build_graph_conv(*ctx, *state, 0));
+        struct ggml_cgraph * gf = whisper_build_graph_conv(*ctx, *state, 0);
+
+        ggml_backend_sched_init_measure(sched, gf);
 
         WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_sched_size(sched, backends) / 1024.0 / 1024.0);
     }
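
For context, a hedged sketch of how a conv evaluation could look once the wiring is complete; the placement, mel_data, and the call to ggml_backend_sched_graph_compute are assumptions rather than part of this commit. Note that the commented-out KQ_scale init above carries llama-style names (n_embd, n_head); the corresponding whisper hparams are n_audio_state and n_audio_head:

    // one-time: initialize the attention scale 1/sqrt(n_state/n_head)
    const float kq_scale = 1.0f/sqrtf(float(hparams.n_audio_state)/hparams.n_audio_head);
    ggml_backend_tensor_set(state->inp_kq_scale, &kq_scale, 0, sizeof(kq_scale));

    // per call: feed the input, rebuild the graph, and let the scheduler
    // allocate intermediates and dispatch splits across the backends
    ggml_backend_tensor_set(state->inp_mel, mel_data, 0, ggml_nbytes(state->inp_mel));

    struct ggml_cgraph * gf = whisper_build_graph_conv(*ctx, *state, 0);
    ggml_backend_sched_graph_compute(state->sched_conv, gf);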