mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-02-24 06:02:04 +01:00
Add AVX,AVX2 support for ggml_vec_scale_f32
This commit is contained in:
parent
1eb81f863f
commit
419b8a6402
40
ggml.c
40
ggml.c
@ -1118,7 +1118,45 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
|
//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
|
||||||
|
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
|
||||||
|
#if defined(__AVX__) || defined(__AVX2__)
|
||||||
|
// AVX 256-bit
|
||||||
|
const int n32 = (n & ~31);
|
||||||
|
|
||||||
|
const __m256 v4 = _mm256_set1_ps(v);
|
||||||
|
|
||||||
|
__m256 y0, y1, y2, y3;
|
||||||
|
|
||||||
|
for (int i = 0; i < n32; i += 32) {
|
||||||
|
y0 = _mm256_loadu_ps(y + i + 0);
|
||||||
|
y1 = _mm256_loadu_ps(y + i + 8);
|
||||||
|
y2 = _mm256_loadu_ps(y + i + 16);
|
||||||
|
y3 = _mm256_loadu_ps(y + i + 24);
|
||||||
|
|
||||||
|
y0 = _mm256_mul_ps(y0, v4);
|
||||||
|
y1 = _mm256_mul_ps(y1, v4);
|
||||||
|
y2 = _mm256_mul_ps(y2, v4);
|
||||||
|
y3 = _mm256_mul_ps(y3, v4);
|
||||||
|
|
||||||
|
_mm256_storeu_ps(y + i + 0, y0);
|
||||||
|
_mm256_storeu_ps(y + i + 8, y1);
|
||||||
|
_mm256_storeu_ps(y + i + 16, y2);
|
||||||
|
_mm256_storeu_ps(y + i + 24, y3);
|
||||||
|
}
|
||||||
|
|
||||||
|
// leftovers
|
||||||
|
for (int i = n32; i < n; ++i) {
|
||||||
|
y[i] *= v;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
// scalar
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
y[i] *= v;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s); }
|
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s); }
|
||||||
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
|
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
|
||||||
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }
|
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }
|
||||||
|
Loading…
Reference in New Issue
Block a user