whisper : add kv_pad

This commit is contained in:
Georgi Gerganov 2024-05-14 17:16:37 +03:00
parent 7c94a11162
commit 2877b026cf
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -148,7 +148,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
} \
} while (0)
//#define WHISPER_USE_FLASH_ATTN
#define WHISPER_USE_FLASH_ATTN
//#define WHISPER_USE_FLASH_FF
#define WHISPER_MAX_DECODERS 8
#define WHISPER_MAX_NODES 4096
@@ -810,6 +810,9 @@ struct whisper_state {
// shared between all decoders
whisper_kv_cache kv_cross;
// padded buffer for flash-attention
whisper_kv_cache kv_pad;
whisper_mel mel;
whisper_batch batch;
@@ -903,14 +906,12 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
}
static bool kv_cache_init(
const struct whisper_hparams & hparams,
struct whisper_kv_cache & cache,
ggml_backend_t backend,
ggml_type wtype,
int64_t n_text_state,
int64_t n_text_layer,
int n_ctx) {
const int64_t n_text_state = hparams.n_text_state;
const int64_t n_text_layer = hparams.n_text_layer;
const int64_t n_mem = n_text_layer*n_ctx;
const int64_t n_elements = n_text_state*n_mem;
@@ -942,6 +943,8 @@ static bool kv_cache_init(
return false;
}
ggml_backend_buffer_clear(cache.buffer, 0);
return true;
}
@@ -1873,6 +1876,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
const int n_head = hparams.n_audio_head;
const int n_layer = hparams.n_audio_layer;
const int n_state_head = n_state/n_head;
auto & kv_pad = wstate.kv_pad;
WHISPER_ASSERT(!!kv_pad.ctx);
const int n_ctx_pad = GGML_PAD(n_ctx, 256);
struct ggml_init_params params = {
/*.mem_size =*/ wstate.alloc_encode.meta.size(),
/*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
@@ -1885,7 +1896,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
const float KQscale = 1.0f/sqrtf(float(n_state_head));
// ===================================================================
// NOTE: experimenting with partial evaluation of the encoder (ignore)
@@ -1935,14 +1946,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
//Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state)/n_head, -0.25));
//Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25));
// note: no bias for Key
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
layer.attn_k_w,
cur);
//Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state)/n_head, -0.25));
//Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25));
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
layer.attn_v_w,
@@ -1956,31 +1967,31 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
ggml_permute(ctx0,
ggml_cpy(ctx0,
Qcur,
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)),
0, 2, 1, 3);
#ifdef WHISPER_USE_FLASH_ATTN
struct ggml_tensor * Kpad = ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, GGML_PAD(n_ctx, 256), n_head);
struct ggml_tensor * Kpad = ggml_reshape_3d(ctx0, kv_pad.k, n_state_head, n_ctx_pad, n_head);
struct ggml_tensor * K =
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0, Kcur, n_state/n_head, n_head, n_ctx),
ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
0, 2, 1, 3),
ggml_view_3d(ctx0,
Kpad,
n_state/n_head, n_ctx, n_head, Kpad->nb[1], Kpad->nb[2], 0));
n_state_head, n_ctx, n_head, Kpad->nb[1], Kpad->nb[2], 0));
struct ggml_tensor * Vpad = ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, GGML_PAD(n_ctx, 256), n_head);
struct ggml_tensor * Vpad = ggml_reshape_3d(ctx0, kv_pad.v, n_state_head, n_ctx_pad, n_head);
struct ggml_tensor * V =
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0, Vcur, n_state/n_head, n_head, n_ctx),
ggml_reshape_3d(ctx0, Vcur, n_state_head, n_head, n_ctx),
0, 2, 1, 3),
ggml_view_3d(ctx0,
Vpad,
n_state/n_head, n_ctx, n_head, Vpad->nb[1], Vpad->nb[2], 0));
n_state_head, n_ctx, n_head, Vpad->nb[1], Vpad->nb[2], 0));
ggml_build_forward_expand(gf, K);
ggml_build_forward_expand(gf, V);
@@ -1993,7 +2004,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
ggml_permute(ctx0,
ggml_cpy(ctx0,
Kcur,
ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)),
0, 2, 1, 3);
// K * Q
@@ -2006,9 +2017,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
Vcur,
n_state/n_head, n_head, n_ctx),
n_state_head, n_head, n_ctx),
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head)
);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
@@ -2117,6 +2128,8 @@ static struct ggml_cgraph * whisper_build_graph_cross(
const int n_state = hparams.n_audio_state;
const int n_head = hparams.n_audio_head;
const int n_state_head = n_state/n_head;
struct ggml_init_params params = {
/*.mem_size =*/ wstate.alloc_cross.meta.size(),
/*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
@@ -2129,7 +2142,7 @@ static struct ggml_cgraph * whisper_build_graph_cross(
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
const float Kscale = pow(float(n_state) / n_head, -0.25);
const float Kscale = pow(float(n_state_head), -0.25);
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
auto & layer = model.layers_decoder[il];
@@ -2295,6 +2308,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
const int n_head = hparams.n_text_head;
const int n_layer = hparams.n_text_layer;
const int n_state_head = n_state/n_head;
const int n_tokens = batch.n_tokens;
const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
@@ -2321,7 +2336,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
ggml_set_name(position, "position");
ggml_set_input(position);
const float KQscale = pow(float(n_state)/n_head, -0.25);
const float KQscale = pow(float(n_state_head), -0.25);
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
ggml_set_name(KQ_mask, "KQ_mask");
@@ -2397,14 +2412,14 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
0, 2, 1, 3);
struct ggml_tensor * K =
ggml_view_3d(ctx0, kv_self.k,
n_state/n_head, n_kv, n_head,
n_state_head, n_kv, n_head,
ggml_element_size(kv_self.k)*n_state,
ggml_element_size(kv_self.k)*n_state/n_head,
ggml_element_size(kv_self.k)*n_state_head,
ggml_element_size(kv_self.k)*n_state*n_ctx*il);
// K * Q
@@ -2414,9 +2429,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
n_kv, n_state/n_head, n_head,
n_kv, n_state_head, n_head,
n_ctx*ggml_element_size(kv_self.v),
n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
n_ctx*ggml_element_size(kv_self.v)*n_state_head,
n_ctx*ggml_element_size(kv_self.v)*n_state*il);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
@@ -2469,33 +2484,33 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
// Kcross is already scaled
struct ggml_tensor * Kcross =
ggml_view_3d(ctx0, wstate.kv_cross.k,
n_state/n_head, n_audio_ctx, n_head,
n_state_head, n_audio_ctx, n_head,
ggml_element_size(wstate.kv_cross.k)*n_state,
ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
ggml_element_size(wstate.kv_cross.k)*n_state_head,
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
//struct ggml_tensor * Vcross =
// ggml_reshape_3d(ctx0,
// ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
// n_state/n_head, n_head, n_audio_ctx);
// n_state_head, n_head, n_audio_ctx);
//struct ggml_tensor * V_trans =
// ggml_cpy(ctx0,
// ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
// ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
// ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state_head, n_head));
struct ggml_tensor * V =
ggml_view_3d(ctx0, wstate.kv_cross.v,
n_audio_ctx, n_state/n_head, n_head,
n_audio_ctx, n_state_head, n_head,
n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head,
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
// ------
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
0, 2, 1, 3);
// K * Q
@@ -2504,7 +2519,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
//struct ggml_tensor * KQ_scaled =
// ggml_scale(ctx0,
// KQ,
// ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
// ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state_head)))
// );
// no masking for cross-attention
@@ -3187,7 +3202,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// in theory, there can be a case where this is not enough, but in practice it should always be enough
const int factor = 3;
if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
if (!kv_cache_init(state->kv_self, ctx->backend, ctx->itype,
ctx->model.hparams.n_text_state,
ctx->model.hparams.n_text_layer,
ctx->model.hparams.n_text_ctx*factor)) {
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
whisper_free_state(state);
return nullptr;
@@ -3198,7 +3216,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
}
if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype,
ctx->model.hparams.n_text_state,
ctx->model.hparams.n_text_layer,
ctx->model.hparams.n_audio_ctx)) {
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
whisper_free_state(state);
return nullptr;
@@ -3209,6 +3230,20 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
}
if (!kv_cache_init(state->kv_pad, ctx->backend, ctx->itype,
ctx->model.hparams.n_audio_state,
1,
GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for kv_pad cache\n", __func__);
whisper_free_state(state);
return nullptr;
}
{
const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v);
WHISPER_LOG_INFO("%s: kv pad size = %7.2f MB\n", __func__, memory_size / 1e6);
}
// [EXPERIMENTAL] Token-level timestamps with DTW
if (ctx->params.dtw_token_timestamps) {
if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
@@ -3565,6 +3600,7 @@ void whisper_free_state(struct whisper_state * state) {
if (state) {
kv_cache_free(state->kv_self);
kv_cache_free(state->kv_cross);
kv_cache_free(state->kv_pad);
#ifdef WHISPER_USE_COREML
if (state->ctx_coreml != nullptr) {