From 7c94a111622b77e29b29a6935a76077f50609d6a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 14 May 2024 14:49:58 +0300 Subject: [PATCH] whisper : use flash attention in the encoder --- whisper.cpp | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/whisper.cpp b/whisper.cpp index ff4223da..44bef4b6 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -148,6 +148,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text } \ } while (0) +//#define WHISPER_USE_FLASH_ATTN //#define WHISPER_USE_FLASH_FF #define WHISPER_MAX_DECODERS 8 #define WHISPER_MAX_NODES 4096 @@ -1958,6 +1959,36 @@ static struct ggml_cgraph * whisper_build_graph_encoder( ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); +#ifdef WHISPER_USE_FLASH_ATTN + struct ggml_tensor * Kpad = ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, GGML_PAD(n_ctx, 256), n_head); + + struct ggml_tensor * K = + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, Kcur, n_state/n_head, n_head, n_ctx), + 0, 2, 1, 3), + ggml_view_3d(ctx0, + Kpad, + n_state/n_head, n_ctx, n_head, Kpad->nb[1], Kpad->nb[2], 0)); + + struct ggml_tensor * Vpad = ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, GGML_PAD(n_ctx, 256), n_head); + + struct ggml_tensor * V = + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, Vcur, n_state/n_head, n_head, n_ctx), + 0, 2, 1, 3), + ggml_view_3d(ctx0, + Vpad, + n_state/n_head, n_ctx, n_head, Vpad->nb[1], Vpad->nb[2], 0)); + + ggml_build_forward_expand(gf, K); + ggml_build_forward_expand(gf, V); + + cur = ggml_flash_attn_ext(ctx0, Q, Kpad, Vpad, nullptr, KQscale, 0.0f); + + cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx); +#else struct ggml_tensor * K = ggml_permute(ctx0, ggml_cpy(ctx0, @@ -1987,6 +2018,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder( cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); +#endif } // projection