From f30b5d322c5786afc75a8a059e0b42dd6b1162b6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 7 Jan 2023 21:00:07 +0200 Subject: [PATCH] ggml : fix bug in new soft max computation --- ggml.c | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/ggml.c b/ggml.c index f4c96eb4..eefdcdd7 100644 --- a/ggml.c +++ b/ggml.c @@ -82,8 +82,15 @@ typedef void* thread_ret_t; /*#define GGML_PERF*/ #define GGML_DEBUG 0 #define GGML_GELU_FP16 + #define GGML_SOFT_MAX_UNROLL 4 -#define GGML_VEC_DOT_UNROLL 4 +#define GGML_VEC_DOT_UNROLL 4 + +#ifdef GGML_USE_ACCELERATE +// uncomment to use vDSP for soft max computation +// note: not sure if it is actually faster +//#define GGML_SOFT_MAX_ACCELERATE +#endif #if UINTPTR_MAX == 0xFFFFFFFF #define GGML_MEM_ALIGN 4 @@ -5975,7 +5982,12 @@ static void ggml_compute_forward_flash_attn_f32( float sum = 0.0f; { -#ifndef GGML_USE_ACCELERATE +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(S, 1, &max, S, 1, Mup); + vvexpf(S, S, &Mup); + ggml_vec_sum_f32(Mup, &sum, S); +#else uint16_t scvt[GGML_SOFT_MAX_UNROLL]; ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; @@ -5998,9 +6010,6 @@ static void ggml_compute_forward_flash_attn_f32( for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { sum += sump[i]; } -#else - vvexpf(S, S, &Mup); - ggml_vec_sum_f32(Mup, &sum, S); #endif } @@ -6202,7 +6211,12 @@ static void ggml_compute_forward_flash_attn_f16( float sum = 0.0f; { -#ifndef GGML_USE_ACCELERATE +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(S, 1, &max, S, 1, Mup); + vvexpf(S, S, &Mup); + ggml_vec_sum_f32(Mup, &sum, S); +#else uint16_t scvt[GGML_SOFT_MAX_UNROLL]; ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; @@ -6225,9 +6239,6 @@ static void ggml_compute_forward_flash_attn_f16( for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { sum += sump[i]; } -#else - vvexpf(S, S, &Mup); - ggml_vec_sum_f32(Mup, &sum, S); #endif } @@ -6244,7 +6255,7 @@ static void ggml_compute_forward_flash_attn_f16( #endif } - ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M); + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); for (int i = 0; i < M; i++) { S16[i] = GGML_FP32_TO_FP16(S[i]);