whisper : use flash attention in the encoder

This commit is contained in:
Georgi Gerganov 2024-05-14 14:49:58 +03:00
parent f56b8305c4
commit 7c94a11162
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -148,6 +148,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
} \
} while (0)
// Compile-time switches, both disabled here (commented out).
// WHISPER_USE_FLASH_ATTN guards an #ifdef branch in the encoder graph builder that
// calls ggml_flash_attn_ext; uncomment the line to compile that branch instead of
// the #else one. The use of WHISPER_USE_FLASH_FF is not visible in this excerpt.
//#define WHISPER_USE_FLASH_ATTN
//#define WHISPER_USE_FLASH_FF
#define WHISPER_MAX_DECODERS 8
#define WHISPER_MAX_NODES 4096
@ -1958,6 +1959,36 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
#ifdef WHISPER_USE_FLASH_ATTN
// Flash-attention path: compiled only when WHISPER_USE_FLASH_ATTN is defined.
//
// Destination tensor for K, of type wctx.itype and shape
// (n_state/n_head, GGML_PAD(n_ctx, 256), n_head).
// GGML_PAD presumably rounds n_ctx up to a multiple of 256 -- confirm against ggml.h.
struct ggml_tensor * Kpad = ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, GGML_PAD(n_ctx, 256), n_head);
// K is the copy node: Kcur is reshaped to (n_state/n_head, n_head, n_ctx), permuted
// with (0, 2, 1, 3), and copied into a view of Kpad that starts at offset 0 and
// spans (n_state/n_head, n_ctx, n_head), using Kpad's own strides nb[1] and nb[2].
// Only n_ctx entries of the padded dimension are covered by this view.
struct ggml_tensor * K =
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0, Kcur, n_state/n_head, n_head, n_ctx),
0, 2, 1, 3),
ggml_view_3d(ctx0,
Kpad,
n_state/n_head, n_ctx, n_head, Kpad->nb[1], Kpad->nb[2], 0));
// V is prepared the same way as K: a padded destination tensor, then a copy of the
// reshaped and permuted Vcur into a view covering n_ctx entries.
struct ggml_tensor * Vpad = ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, GGML_PAD(n_ctx, 256), n_head);
struct ggml_tensor * V =
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0, Vcur, n_state/n_head, n_head, n_ctx),
0, 2, 1, 3),
ggml_view_3d(ctx0,
Vpad,
n_state/n_head, n_ctx, n_head, Vpad->nb[1], Vpad->nb[2], 0));
// The attention call below takes Kpad/Vpad, not the copy nodes K/V, so the copies
// are added to the graph gf explicitly here (presumably so they are scheduled
// before the attention op -- confirm against ggml's graph-building rules).
ggml_build_forward_expand(gf, K);
ggml_build_forward_expand(gf, V);
// NOTE(review): the entries of Kpad/Vpad between n_ctx and GGML_PAD(n_ctx, 256) are
// not written anywhere in this block, yet the full padded tensors are passed below
// with nullptr as the fifth argument (the mask, per ggml's API -- confirm). Verify
// that the padding is zero-initialized or otherwise excluded from the attention.
cur = ggml_flash_attn_ext(ctx0, Q, Kpad, Vpad, nullptr, KQscale, 0.0f);
// View the attention output as a 2D tensor of shape (n_state, n_ctx).
cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
#else
struct ggml_tensor * K =
ggml_permute(ctx0,
ggml_cpy(ctx0,
@ -1987,6 +2018,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
#endif
}
// projection