mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-05-31 23:15:38 +02:00
whisper : use flash attention in the encoder
This commit is contained in:
parent
f56b8305c4
commit
7c94a11162
32
whisper.cpp
32
whisper.cpp
@ -148,6 +148,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
//#define WHISPER_USE_FLASH_ATTN
|
||||
//#define WHISPER_USE_FLASH_FF
|
||||
#define WHISPER_MAX_DECODERS 8
|
||||
#define WHISPER_MAX_NODES 4096
|
||||
@ -1958,6 +1959,36 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
|
||||
0, 2, 1, 3);
|
||||
|
||||
#ifdef WHISPER_USE_FLASH_ATTN
|
||||
struct ggml_tensor * Kpad = ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, GGML_PAD(n_ctx, 256), n_head);
|
||||
|
||||
struct ggml_tensor * K =
|
||||
ggml_cpy(ctx0,
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0, Kcur, n_state/n_head, n_head, n_ctx),
|
||||
0, 2, 1, 3),
|
||||
ggml_view_3d(ctx0,
|
||||
Kpad,
|
||||
n_state/n_head, n_ctx, n_head, Kpad->nb[1], Kpad->nb[2], 0));
|
||||
|
||||
struct ggml_tensor * Vpad = ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, GGML_PAD(n_ctx, 256), n_head);
|
||||
|
||||
struct ggml_tensor * V =
|
||||
ggml_cpy(ctx0,
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0, Vcur, n_state/n_head, n_head, n_ctx),
|
||||
0, 2, 1, 3),
|
||||
ggml_view_3d(ctx0,
|
||||
Vpad,
|
||||
n_state/n_head, n_ctx, n_head, Vpad->nb[1], Vpad->nb[2], 0));
|
||||
|
||||
ggml_build_forward_expand(gf, K);
|
||||
ggml_build_forward_expand(gf, V);
|
||||
|
||||
cur = ggml_flash_attn_ext(ctx0, Q, Kpad, Vpad, nullptr, KQscale, 0.0f);
|
||||
|
||||
cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
|
||||
#else
|
||||
struct ggml_tensor * K =
|
||||
ggml_permute(ctx0,
|
||||
ggml_cpy(ctx0,
|
||||
@ -1987,6 +2018,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||
cur = ggml_cpy(ctx0,
|
||||
KQV_merged,
|
||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
|
||||
#endif
|
||||
}
|
||||
|
||||
// projection
|
||||
|
Loading…
x
Reference in New Issue
Block a user