diff --git a/whisper.cpp b/whisper.cpp
index 44bef4b6..49664525 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -148,7 +148,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
         } \
     } while (0)
 
-//#define WHISPER_USE_FLASH_ATTN
+#define WHISPER_USE_FLASH_ATTN
 //#define WHISPER_USE_FLASH_FF
 #define WHISPER_MAX_DECODERS 8
 #define WHISPER_MAX_NODES 4096
@@ -810,6 +810,9 @@ struct whisper_state {
     // shared between all decoders
     whisper_kv_cache kv_cross;
 
+    // padded buffer for flash-attention
+    whisper_kv_cache kv_pad;
+
     whisper_mel mel;
 
     whisper_batch batch;
@@ -903,14 +906,12 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
 }
 
 static bool kv_cache_init(
-        const struct whisper_hparams & hparams,
              struct whisper_kv_cache & cache,
                        ggml_backend_t backend,
                             ggml_type wtype,
+                              int64_t n_text_state,
+                              int64_t n_text_layer,
                                   int n_ctx) {
-    const int64_t n_text_state = hparams.n_text_state;
-    const int64_t n_text_layer = hparams.n_text_layer;
-
     const int64_t n_mem = n_text_layer*n_ctx;
     const int64_t n_elements = n_text_state*n_mem;
@@ -942,6 +943,8 @@ static bool kv_cache_init(
         return false;
     }
 
+    ggml_backend_buffer_clear(cache.buffer, 0);
+
     return true;
 }
@@ -1873,6 +1876,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     const int n_head  = hparams.n_audio_head;
     const int n_layer = hparams.n_audio_layer;
 
+    const int n_state_head = n_state/n_head;
+
+    auto & kv_pad = wstate.kv_pad;
+
+    WHISPER_ASSERT(!!kv_pad.ctx);
+
+    const int n_ctx_pad = GGML_PAD(n_ctx, 256);
+
     struct ggml_init_params params = {
         /*.mem_size   =*/ wstate.alloc_encode.meta.size(),
         /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
@@ -1885,7 +1896,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
     struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
 
-    const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
+    const float KQscale = 1.0f/sqrtf(float(n_state_head));
 
     // ===================================================================
     // NOTE: experimenting with partial evaluation of the encoder (ignore)
@@ -1935,14 +1946,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
             Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
 
-            //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state)/n_head, -0.25));
+            //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25));
 
             // note: no bias for Key
             struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
                     layer.attn_k_w,
                     cur);
 
-            //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state)/n_head, -0.25));
+            //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25));
 
             struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
                     layer.attn_v_w,
@@ -1956,31 +1967,31 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                 ggml_permute(ctx0,
                         ggml_cpy(ctx0,
                             Qcur,
-                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
+                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
 #ifdef WHISPER_USE_FLASH_ATTN
-            struct ggml_tensor * Kpad = ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, GGML_PAD(n_ctx, 256), n_head);
+            struct ggml_tensor * Kpad = ggml_reshape_3d(ctx0, kv_pad.k, n_state_head, n_ctx_pad, n_head);
 
             struct ggml_tensor * K =
                 ggml_cpy(ctx0,
                         ggml_permute(ctx0,
-                            ggml_reshape_3d(ctx0, Kcur, n_state/n_head, n_head, n_ctx),
+                            ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
                             0, 2, 1, 3),
                         ggml_view_3d(ctx0, Kpad,
-                            n_state/n_head, n_ctx, n_head, Kpad->nb[1], Kpad->nb[2], 0));
+                            n_state_head, n_ctx, n_head, Kpad->nb[1], Kpad->nb[2], 0));
 
-            struct ggml_tensor * Vpad = ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, GGML_PAD(n_ctx, 256), n_head);
+            struct ggml_tensor * Vpad = ggml_reshape_3d(ctx0, kv_pad.v, n_state_head, n_ctx_pad, n_head);
 
             struct ggml_tensor * V =
                 ggml_cpy(ctx0,
                         ggml_permute(ctx0,
-                            ggml_reshape_3d(ctx0, Vcur, n_state/n_head, n_head, n_ctx),
+                            ggml_reshape_3d(ctx0, Vcur, n_state_head, n_head, n_ctx),
                             0, 2, 1, 3),
                         ggml_view_3d(ctx0, Vpad,
-                            n_state/n_head, n_ctx, n_head, Vpad->nb[1], Vpad->nb[2], 0));
+                            n_state_head, n_ctx, n_head, Vpad->nb[1], Vpad->nb[2], 0));
 
             ggml_build_forward_expand(gf, K);
             ggml_build_forward_expand(gf, V);
@@ -1993,7 +2004,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                 ggml_permute(ctx0,
                         ggml_cpy(ctx0,
                             Kcur,
-                            ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
+                            ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             // K * Q
@@ -2006,9 +2017,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                 ggml_permute(ctx0,
                         ggml_reshape_3d(ctx0,
                             Vcur,
-                            n_state/n_head, n_head, n_ctx),
+                            n_state_head, n_head, n_ctx),
                         1, 2, 0, 3),
-                    ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
+                    ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head)
                 );
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
@@ -2117,6 +2128,8 @@ static struct ggml_cgraph * whisper_build_graph_cross(
     const int n_state = hparams.n_audio_state;
     const int n_head  = hparams.n_audio_head;
 
+    const int n_state_head = n_state/n_head;
+
     struct ggml_init_params params = {
         /*.mem_size   =*/ wstate.alloc_cross.meta.size(),
         /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
@@ -2129,7 +2142,7 @@ static struct ggml_cgraph * whisper_build_graph_cross(
 
     struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
 
-    const float  Kscale = pow(float(n_state) / n_head, -0.25);
+    const float  Kscale = pow(float(n_state_head), -0.25);
 
     for (int il = 0; il < model.hparams.n_text_layer; ++il) {
         auto & layer = model.layers_decoder[il];
@@ -2295,11 +2308,13 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     const int n_head  = hparams.n_text_head;
     const int n_layer = hparams.n_text_layer;
 
+    const int n_state_head = n_state/n_head;
+
     const int n_tokens = batch.n_tokens;
 
     const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
 
-    const int32_t n_kv    = worst_case ? n_ctx            : kv_self.n;
-    const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
+    const int32_t n_kv     = worst_case ? n_ctx            : kv_self.n;
+    const int32_t kv_head  = worst_case ? n_ctx - n_tokens : kv_self.head;
 
     //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
@@ -2321,7 +2336,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     ggml_set_name(position, "position");
     ggml_set_input(position);
 
-    const float KQscale = pow(float(n_state)/n_head, -0.25);
+    const float KQscale = pow(float(n_state_head), -0.25);
 
     struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
     ggml_set_name(KQ_mask, "KQ_mask");
@@ -2397,14 +2412,14 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
+                        ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_state/n_head, n_kv, n_head,
+                        n_state_head, n_kv, n_head,
                         ggml_element_size(kv_self.k)*n_state,
-                        ggml_element_size(kv_self.k)*n_state/n_head,
+                        ggml_element_size(kv_self.k)*n_state_head,
                         ggml_element_size(kv_self.k)*n_state*n_ctx*il);
 
             // K * Q
@@ -2414,9 +2429,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        n_kv, n_state/n_head, n_head,
+                        n_kv, n_state_head, n_head,
                         n_ctx*ggml_element_size(kv_self.v),
-                        n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
+                        n_ctx*ggml_element_size(kv_self.v)*n_state_head,
                         n_ctx*ggml_element_size(kv_self.v)*n_state*il);
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
@@ -2469,33 +2484,33 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
                 ggml_view_3d(ctx0, wstate.kv_cross.k,
-                        n_state/n_head, n_audio_ctx, n_head,
+                        n_state_head, n_audio_ctx, n_head,
                         ggml_element_size(wstate.kv_cross.k)*n_state,
-                        ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
+                        ggml_element_size(wstate.kv_cross.k)*n_state_head,
                         ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
 
             //struct ggml_tensor * Vcross =
             //    ggml_reshape_3d(ctx0,
             //            ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
-            //            n_state/n_head, n_head, n_audio_ctx);
+            //            n_state_head, n_head, n_audio_ctx);
 
             //struct ggml_tensor * V_trans =
             //    ggml_cpy(ctx0,
             //            ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
-            //            ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
+            //            ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state_head, n_head));
 
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, wstate.kv_cross.v,
-                        n_audio_ctx, n_state/n_head, n_head,
+                        n_audio_ctx, n_state_head, n_head,
                         n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
-                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
+                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head,
                         n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
 
             // ------
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
+                        ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
                         0, 2, 1, 3);
 
             // K * Q
@@ -2504,7 +2519,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
             //struct ggml_tensor * KQ_scaled =
             //    ggml_scale(ctx0,
             //            KQ,
-            //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
+            //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state_head)))
             //            );
 
             // no masking for cross-attention
@@ -3187,7 +3202,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     // in theory, there can be a case where this is not enough, but in practice it should always be enough
     const int factor = 3;
 
-    if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
+    if (!kv_cache_init(state->kv_self, ctx->backend, ctx->itype,
+                ctx->model.hparams.n_text_state,
+                ctx->model.hparams.n_text_layer,
+                ctx->model.hparams.n_text_ctx*factor)) {
         WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
@@ -3198,7 +3216,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
-    if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
+    if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype,
+                ctx->model.hparams.n_text_state,
+                ctx->model.hparams.n_text_layer,
+                ctx->model.hparams.n_audio_ctx)) {
         WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
@@ -3209,6 +3230,20 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
+    if (!kv_cache_init(state->kv_pad, ctx->backend, ctx->itype,
+                ctx->model.hparams.n_audio_state,
+                1,
+                GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
+        WHISPER_LOG_ERROR("%s: kv_cache_init() failed for flash-attention pad cache\n", __func__);
+        whisper_free_state(state);
+        return nullptr;
+    }
+
+    {
+        const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v);
+        WHISPER_LOG_INFO("%s: kv pad size   = %7.2f MB\n", __func__, memory_size / 1e6);
+    }
+
     // [EXPERIMENTAL] Token-level timestamps with DTW
     if (ctx->params.dtw_token_timestamps) {
         if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
@@ -3565,6 +3600,7 @@ void whisper_free_state(struct whisper_state * state) {
     if (state) {
         kv_cache_free(state->kv_self);
         kv_cache_free(state->kv_cross);
+        kv_cache_free(state->kv_pad);
 
 #ifdef WHISPER_USE_COREML
         if (state->ctx_coreml != nullptr) {
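
Reviewer notes (illustrative sketches, not part of the patch):

The new kv_pad cache is sized with GGML_PAD(n_audio_ctx, 256), i.e. the audio context rounded up to the next multiple of 256, matching the n_ctx_pad that the encoder graph now derives for the flash-attention K/V buffers. A minimal sketch of the rounding, using a local round_up helper rather than the ggml macro so it stands alone; for Whisper's n_audio_ctx = 1500 this yields 1536:

    #include <cstdio>

    // round_up(x, n): smallest multiple of n that is >= x, which is the
    // rounding GGML_PAD is used for in the patch
    static int round_up(int x, int n) {
        return (x + n - 1) / n * n;
    }

    int main() {
        printf("n_ctx_pad = %d\n", round_up(1500, 256)); // prints 1536
        return 0;
    }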
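The generalized kv_cache_init() still sizes a cache as two tensors (K and V) of n_state*n_layer*n_ctx elements each; decoupling it from whisper_hparams is what lets kv_pad reuse it with n_layer = 1, since the pad buffer is a single padded K/V scratch pair that each encoder layer writes into in turn. A standalone sketch of that arithmetic, with a hypothetical kv_cache_bytes helper, the "base" model hyperparameters (n_text_state = 512, n_text_layer = 6, n_text_ctx = 448, n_audio_ctx = 1500), and a 2-byte (F16) cache type:

    #include <cstdint>
    #include <cstdio>

    // mirrors the n_mem/n_elements math inside kv_cache_init()
    static int64_t kv_cache_bytes(int64_t n_state, int64_t n_layer, int64_t n_ctx, int64_t elt) {
        const int64_t n_mem      = n_layer*n_ctx;
        const int64_t n_elements = n_state*n_mem;
        return 2*n_elements*elt; // one K tensor + one V tensor
    }

    int main() {
        printf("kv self : %.2f MB\n", kv_cache_bytes(512, 6, 448*3, 2)/1e6); // factor = 3
        printf("kv cross: %.2f MB\n", kv_cache_bytes(512, 6, 1500,  2)/1e6);
        printf("kv pad  : %.2f MB\n", kv_cache_bytes(512, 1, 1536,  2)/1e6); // GGML_PAD(1500, 256)
        return 0;
    }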
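kv_cache_init() also gains a ggml_backend_buffer_clear(cache.buffer, 0) after allocation. The patch does not spell out the motivation, but a plausible reading for kv_pad: each encode only copies K/V rows for the first n_ctx of the n_ctx_pad positions into the pad views, so without an initial clear the tail rows would hold whatever the backend allocator left behind. A toy model of that write pattern (the head dimension 64 is the base-model value, used only for illustration):

    #include <vector>

    int main() {
        const int n_ctx = 1500, n_ctx_pad = 1536, n_state_head = 64;

        // cleared once at init, standing in for ggml_backend_buffer_clear(cache.buffer, 0)
        std::vector<float> kpad((size_t)n_ctx_pad*n_state_head, 0.0f);

        // per-encode: the ggml_cpy into the kv_pad view touches rows [0, n_ctx) only,
        // so rows [n_ctx, n_ctx_pad) keep their cleared value instead of stale garbage
        for (int t = 0; t < n_ctx; ++t) {
            for (int i = 0; i < n_state_head; ++i) {
                kpad[(size_t)t*n_state_head + i] = 0.1f; // stand-in for Kcur data
            }
        }
        return 0;
    }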
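The n_state_head renames are behavior-preserving: n_state is an exact multiple of n_head in every Whisper variant, so 1.0f/sqrtf(float(n_state)/n_head) and 1.0f/sqrtf(float(n_state_head)) compute the same 1/sqrt(d_head) attention scale, and the commented-out pow(..., -0.25) form is that same factor split across Q and K. A quick check with the base encoder dimensions:

    #include <cassert>
    #include <cmath>

    int main() {
        const int n_state = 512, n_head = 8;
        const int n_state_head = n_state/n_head; // 64

        // old and new spellings of the scale agree exactly
        assert(1.0f/sqrtf(float(n_state)/n_head) == 1.0f/sqrtf(float(n_state_head)));

        // applying d_head^-0.25 to Q and to K multiplies out to d_head^-0.5
        const float split = powf(float(n_state_head), -0.25f);
        assert(fabsf(split*split - 1.0f/sqrtf(float(n_state_head))) < 1e-7f);
        return 0;
    }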