forked from extern/whisper.cpp
ggml : fix bug in new soft max computation
This commit is contained in:
parent
44efbf7ff1
commit
f30b5d322c
29
ggml.c
29
ggml.c
@ -82,9 +82,16 @@ typedef void* thread_ret_t;
|
|||||||
/*#define GGML_PERF*/
|
/*#define GGML_PERF*/
|
||||||
#define GGML_DEBUG 0
|
#define GGML_DEBUG 0
|
||||||
#define GGML_GELU_FP16
|
#define GGML_GELU_FP16
|
||||||
|
|
||||||
#define GGML_SOFT_MAX_UNROLL 4
|
#define GGML_SOFT_MAX_UNROLL 4
|
||||||
#define GGML_VEC_DOT_UNROLL 4
|
#define GGML_VEC_DOT_UNROLL 4
|
||||||
|
|
||||||
|
#ifdef GGML_USE_ACCELERATE
|
||||||
|
// uncomment to use vDSP for soft max computation
|
||||||
|
// note: not sure if it is actually faster
|
||||||
|
//#define GGML_SOFT_MAX_ACCELERATE
|
||||||
|
#endif
|
||||||
|
|
||||||
#if UINTPTR_MAX == 0xFFFFFFFF
|
#if UINTPTR_MAX == 0xFFFFFFFF
|
||||||
#define GGML_MEM_ALIGN 4
|
#define GGML_MEM_ALIGN 4
|
||||||
#else
|
#else
|
||||||
@ -5975,7 +5982,12 @@ static void ggml_compute_forward_flash_attn_f32(
|
|||||||
|
|
||||||
float sum = 0.0f;
|
float sum = 0.0f;
|
||||||
{
|
{
|
||||||
#ifndef GGML_USE_ACCELERATE
|
#ifdef GGML_SOFT_MAX_ACCELERATE
|
||||||
|
max = -max;
|
||||||
|
vDSP_vsadd(S, 1, &max, S, 1, Mup);
|
||||||
|
vvexpf(S, S, &Mup);
|
||||||
|
ggml_vec_sum_f32(Mup, &sum, S);
|
||||||
|
#else
|
||||||
uint16_t scvt[GGML_SOFT_MAX_UNROLL];
|
uint16_t scvt[GGML_SOFT_MAX_UNROLL];
|
||||||
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
|
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
|
||||||
|
|
||||||
@ -5998,9 +6010,6 @@ static void ggml_compute_forward_flash_attn_f32(
|
|||||||
for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
|
for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
|
||||||
sum += sump[i];
|
sum += sump[i];
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
vvexpf(S, S, &Mup);
|
|
||||||
ggml_vec_sum_f32(Mup, &sum, S);
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -6202,7 +6211,12 @@ static void ggml_compute_forward_flash_attn_f16(
|
|||||||
|
|
||||||
float sum = 0.0f;
|
float sum = 0.0f;
|
||||||
{
|
{
|
||||||
#ifndef GGML_USE_ACCELERATE
|
#ifdef GGML_SOFT_MAX_ACCELERATE
|
||||||
|
max = -max;
|
||||||
|
vDSP_vsadd(S, 1, &max, S, 1, Mup);
|
||||||
|
vvexpf(S, S, &Mup);
|
||||||
|
ggml_vec_sum_f32(Mup, &sum, S);
|
||||||
|
#else
|
||||||
uint16_t scvt[GGML_SOFT_MAX_UNROLL];
|
uint16_t scvt[GGML_SOFT_MAX_UNROLL];
|
||||||
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
|
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
|
||||||
|
|
||||||
@ -6225,9 +6239,6 @@ static void ggml_compute_forward_flash_attn_f16(
|
|||||||
for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
|
for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
|
||||||
sum += sump[i];
|
sum += sump[i];
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
vvexpf(S, S, &Mup);
|
|
||||||
ggml_vec_sum_f32(Mup, &sum, S);
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -6244,7 +6255,7 @@ static void ggml_compute_forward_flash_attn_f16(
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
|
ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
|
||||||
|
|
||||||
for (int i = 0; i < M; i++) {
|
for (int i = 0; i < M; i++) {
|
||||||
S16[i] = GGML_FP32_TO_FP16(S[i]);
|
S16[i] = GGML_FP32_TO_FP16(S[i]);
|
||||||
|
Loading…
Reference in New Issue
Block a user