mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-02 16:05:35 +02:00
whisper : use flash attention in the encoder
This commit is contained in:
parent
f56b8305c4
commit
7c94a11162
32
whisper.cpp
32
whisper.cpp
@ -148,6 +148,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
|
|||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
|
//#define WHISPER_USE_FLASH_ATTN
|
||||||
//#define WHISPER_USE_FLASH_FF
|
//#define WHISPER_USE_FLASH_FF
|
||||||
#define WHISPER_MAX_DECODERS 8
|
#define WHISPER_MAX_DECODERS 8
|
||||||
#define WHISPER_MAX_NODES 4096
|
#define WHISPER_MAX_NODES 4096
|
||||||
@ -1958,6 +1959,36 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||||||
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
|
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
|
||||||
0, 2, 1, 3);
|
0, 2, 1, 3);
|
||||||
|
|
||||||
|
#ifdef WHISPER_USE_FLASH_ATTN
|
||||||
|
struct ggml_tensor * Kpad = ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, GGML_PAD(n_ctx, 256), n_head);
|
||||||
|
|
||||||
|
struct ggml_tensor * K =
|
||||||
|
ggml_cpy(ctx0,
|
||||||
|
ggml_permute(ctx0,
|
||||||
|
ggml_reshape_3d(ctx0, Kcur, n_state/n_head, n_head, n_ctx),
|
||||||
|
0, 2, 1, 3),
|
||||||
|
ggml_view_3d(ctx0,
|
||||||
|
Kpad,
|
||||||
|
n_state/n_head, n_ctx, n_head, Kpad->nb[1], Kpad->nb[2], 0));
|
||||||
|
|
||||||
|
struct ggml_tensor * Vpad = ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, GGML_PAD(n_ctx, 256), n_head);
|
||||||
|
|
||||||
|
struct ggml_tensor * V =
|
||||||
|
ggml_cpy(ctx0,
|
||||||
|
ggml_permute(ctx0,
|
||||||
|
ggml_reshape_3d(ctx0, Vcur, n_state/n_head, n_head, n_ctx),
|
||||||
|
0, 2, 1, 3),
|
||||||
|
ggml_view_3d(ctx0,
|
||||||
|
Vpad,
|
||||||
|
n_state/n_head, n_ctx, n_head, Vpad->nb[1], Vpad->nb[2], 0));
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, K);
|
||||||
|
ggml_build_forward_expand(gf, V);
|
||||||
|
|
||||||
|
cur = ggml_flash_attn_ext(ctx0, Q, Kpad, Vpad, nullptr, KQscale, 0.0f);
|
||||||
|
|
||||||
|
cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
|
||||||
|
#else
|
||||||
struct ggml_tensor * K =
|
struct ggml_tensor * K =
|
||||||
ggml_permute(ctx0,
|
ggml_permute(ctx0,
|
||||||
ggml_cpy(ctx0,
|
ggml_cpy(ctx0,
|
||||||
@ -1987,6 +2018,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||||||
cur = ggml_cpy(ctx0,
|
cur = ggml_cpy(ctx0,
|
||||||
KQV_merged,
|
KQV_merged,
|
||||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
|
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// projection
|
// projection
|
||||||
|
Loading…
x
Reference in New Issue
Block a user