CPU/CUDA: Gemma 2 FlashAttention support (llama/8542)

* CPU/CUDA: Gemma 2 FlashAttention support

* apply logit_softcap to scale in kernel (see the sketch after this list)

* disable logit softcapping tests on Metal

* remove Metal check
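For context on "apply logit_softcap to scale in kernel": Gemma 2 caps attention logits as softcap * tanh(logit / softcap), and dividing the existing QK scale by the softcap once up front leaves a single multiply before the tanh in the inner loop. A minimal sketch of that identity in C; the function and variable names are illustrative, not ggml's actual kernel code:

    #include <math.h>

    // Gemma 2 logit softcapping smoothly bounds a raw attention logit to
    // (-softcap, +softcap):  capped = softcap * tanh(logit / softcap).
    // Folding the division by softcap into the existing QK scale keeps a
    // single multiply before the tanh.
    static float softcapped_logit(float qk_dot, float scale, float softcap) {
        const float scale_capped = scale / softcap; // fold softcap into scale
        return softcap * tanhf(qk_dot * scale_capped);
    }

Passing logit_softcap == 0.0f presumably skips this path and preserves the old uncapped behavior, which matches the widened API shown in the diff below.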
Author: Johannes Gäßler
Date: 2024-08-24 21:34:59 +02:00
Committed by: Georgi Gerganov
Parent: 9b16ddd3a5
Commit: 24d8534bd8

10 changed files with 304 additions and 63 deletions

@@ -1807,7 +1807,8 @@ extern "C" {
             struct ggml_tensor  * v,
             struct ggml_tensor  * mask,
             float                 scale,
-            float                 max_bias);
+            float                 max_bias,
+            float                 logit_softcap);
 
     GGML_API void ggml_flash_attn_ext_set_prec(
             struct ggml_tensor * a,
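For reference, a hedged call-site sketch of the widened ggml_flash_attn_ext signature (tensor setup omitted; head_dim is assumed to be defined by the caller, 50.0f is Gemma 2's reported attention softcap, and 0.0f disables softcapping):

    // Hypothetical call site; ctx, q, k, v and mask are assumed to be
    // built elsewhere with the usual ggml graph setup.
    struct ggml_tensor * kqv = ggml_flash_attn_ext(
            ctx, q, k, v, mask,
            1.0f/sqrtf((float) head_dim), // scale
            0.0f,                         // max_bias (ALiBi disabled)
            50.0f);                       // logit_softcap; 0.0f to disable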