From 419b8a640221a2eb4a8a8e2dbd479ba659433f7d Mon Sep 17 00:00:00 2001
From: katsu560
Date: Sat, 17 Dec 2022 08:42:30 +0900
Subject: [PATCH] Add AVX,AVX2 support for ggml_vec_scale_f32

---
 ggml.c | 40 +++++++++++++++++++++++++++++++++++++++-
 1 file changed, 39 insertions(+), 1 deletion(-)

diff --git a/ggml.c b/ggml.c
index c5780ed..f1d2b25 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1118,7 +1118,45 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
 #endif
 }
 
-inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
+//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
+inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
+#if defined(__AVX__) || defined(__AVX2__)
+    // AVX 256-bit
+    const int n32 = (n & ~31);
+
+    const __m256 v4 = _mm256_set1_ps(v);
+
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        y0 = _mm256_loadu_ps(y + i + 0);
+        y1 = _mm256_loadu_ps(y + i + 8);
+        y2 = _mm256_loadu_ps(y + i + 16);
+        y3 = _mm256_loadu_ps(y + i + 24);
+
+        y0 = _mm256_mul_ps(y0, v4);
+        y1 = _mm256_mul_ps(y1, v4);
+        y2 = _mm256_mul_ps(y2, v4);
+        y3 = _mm256_mul_ps(y3, v4);
+
+        _mm256_storeu_ps(y + i + 0, y0);
+        _mm256_storeu_ps(y + i + 8, y1);
+        _mm256_storeu_ps(y + i + 16, y2);
+        _mm256_storeu_ps(y + i + 24, y3);
+    }
+
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        y[i] *= v;
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] *= v;
+    }
+#endif
+}
+
 inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s); }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
 inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }
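
Note (illustration, not part of the patch): `n & ~31` rounds n down to the
nearest multiple of 32, so each iteration of the main loop scales four
8-float AVX registers; for n = 70, for example, n32 = 64 and indices 64..69
take the scalar "leftovers" path. The unaligned loads and stores
(_mm256_loadu_ps / _mm256_storeu_ps) mean y does not need to be 32-byte
aligned. Below is a minimal standalone harness to sanity-check both paths.
The file name test_scale.c and the build line are illustrative assumptions,
and the function is duplicated into the harness only because an
`inline static` symbol cannot be linked from another translation unit;
build with something like `gcc -O2 -mavx test_scale.c -o test_scale`.

// test_scale.c -- hypothetical standalone harness, not part of ggml.
#include <immintrin.h>
#include <math.h>
#include <stdio.h>

// Copy of the patched function from ggml.c (comments condensed).
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
#if defined(__AVX__) || defined(__AVX2__)
    const int n32 = (n & ~31);           // n rounded down to a multiple of 32
    const __m256 v4 = _mm256_set1_ps(v); // broadcast v to all 8 lanes

    __m256 y0, y1, y2, y3;

    for (int i = 0; i < n32; i += 32) {  // 4 registers x 8 floats per step
        y0 = _mm256_loadu_ps(y + i + 0);
        y1 = _mm256_loadu_ps(y + i + 8);
        y2 = _mm256_loadu_ps(y + i + 16);
        y3 = _mm256_loadu_ps(y + i + 24);

        y0 = _mm256_mul_ps(y0, v4);
        y1 = _mm256_mul_ps(y1, v4);
        y2 = _mm256_mul_ps(y2, v4);
        y3 = _mm256_mul_ps(y3, v4);

        _mm256_storeu_ps(y + i + 0, y0);
        _mm256_storeu_ps(y + i + 8, y1);
        _mm256_storeu_ps(y + i + 16, y2);
        _mm256_storeu_ps(y + i + 24, y3);
    }

    for (int i = n32; i < n; ++i) {      // leftovers
        y[i] *= v;
    }
#else
    for (int i = 0; i < n; ++i) {        // scalar fallback
        y[i] *= v;
    }
#endif
}

int main(void) {
    // 70 floats: 64 go through the unrolled AVX loop (two iterations of 32),
    // the remaining 6 exercise the leftover loop.
    enum { N = 70 };
    float y[N], ref[N];
    for (int i = 0; i < N; ++i) {
        y[i] = ref[i] = 0.5f*(float)i;
    }

    const float v = 3.0f;
    ggml_vec_scale_f32(N, y, v);

    for (int i = 0; i < N; ++i) {
        ref[i] *= v; // scalar reference
        if (fabsf(y[i] - ref[i]) > 1e-6f) {
            printf("mismatch at %d: %f != %f\n", i, y[i], ref[i]);
            return 1;
        }
    }
    printf("ok: %d elements scaled\n", N);
    return 0;
}

A design note: the unrolling over four independent __m256 registers lets the
multiplies overlap in the pipeline instead of serializing on one register,
which is the usual reason for this 4x unroll pattern in ggml's vector loops.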