ggml : fix bug in new soft max computation

This commit is contained in:
Georgi Gerganov 2023-01-07 21:00:07 +02:00
parent 44efbf7ff1
commit f30b5d322c
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

31
ggml.c
View File

@ -82,8 +82,15 @@ typedef void* thread_ret_t;
/*#define GGML_PERF*/ /*#define GGML_PERF*/
#define GGML_DEBUG 0 #define GGML_DEBUG 0
#define GGML_GELU_FP16 #define GGML_GELU_FP16
#define GGML_SOFT_MAX_UNROLL 4 #define GGML_SOFT_MAX_UNROLL 4
#define GGML_VEC_DOT_UNROLL 4 #define GGML_VEC_DOT_UNROLL 4
#ifdef GGML_USE_ACCELERATE
// uncomment to use vDSP for soft max computation
// note: not sure if it is actually faster
//#define GGML_SOFT_MAX_ACCELERATE
#endif
#if UINTPTR_MAX == 0xFFFFFFFF #if UINTPTR_MAX == 0xFFFFFFFF
#define GGML_MEM_ALIGN 4 #define GGML_MEM_ALIGN 4
@ -5975,7 +5982,12 @@ static void ggml_compute_forward_flash_attn_f32(
float sum = 0.0f; float sum = 0.0f;
{ {
#ifndef GGML_USE_ACCELERATE #ifdef GGML_SOFT_MAX_ACCELERATE
max = -max;
vDSP_vsadd(S, 1, &max, S, 1, Mup);
vvexpf(S, S, &Mup);
ggml_vec_sum_f32(Mup, &sum, S);
#else
uint16_t scvt[GGML_SOFT_MAX_UNROLL]; uint16_t scvt[GGML_SOFT_MAX_UNROLL];
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
@ -5998,9 +6010,6 @@ static void ggml_compute_forward_flash_attn_f32(
for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
sum += sump[i]; sum += sump[i];
} }
#else
vvexpf(S, S, &Mup);
ggml_vec_sum_f32(Mup, &sum, S);
#endif #endif
} }
@ -6202,7 +6211,12 @@ static void ggml_compute_forward_flash_attn_f16(
float sum = 0.0f; float sum = 0.0f;
{ {
#ifndef GGML_USE_ACCELERATE #ifdef GGML_SOFT_MAX_ACCELERATE
max = -max;
vDSP_vsadd(S, 1, &max, S, 1, Mup);
vvexpf(S, S, &Mup);
ggml_vec_sum_f32(Mup, &sum, S);
#else
uint16_t scvt[GGML_SOFT_MAX_UNROLL]; uint16_t scvt[GGML_SOFT_MAX_UNROLL];
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
@ -6225,9 +6239,6 @@ static void ggml_compute_forward_flash_attn_f16(
for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
sum += sump[i]; sum += sump[i];
} }
#else
vvexpf(S, S, &Mup);
ggml_vec_sum_f32(Mup, &sum, S);
#endif #endif
} }
@ -6244,7 +6255,7 @@ static void ggml_compute_forward_flash_attn_f16(
#endif #endif
} }
ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M); ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
S16[i] = GGML_FP32_TO_FP16(S[i]); S16[i] = GGML_FP32_TO_FP16(S[i]);