mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-20 01:38:04 +02:00
ggml : add ggml_gelu_erf() CUDA kernel (llama/13719)
* ggml : add ggml_gelu_erf() CUDA kernel * missing semicolon
This commit is contained in:
parent
f1576b2659
commit
2d6c6862f7
@ -2192,6 +2192,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_UNARY_OP_SILU:
|
||||
ggml_cuda_op_silu(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_GELU_ERF:
|
||||
ggml_cuda_op_gelu_erf(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
ggml_cuda_op_gelu_quick(ctx, dst);
|
||||
break;
|
||||
@ -2977,6 +2980,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
case GGML_UNARY_OP_HARDSWISH:
|
||||
case GGML_UNARY_OP_GELU_ERF:
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
|
@ -23,6 +23,12 @@ static __device__ __forceinline__ float op_gelu(float x) {
|
||||
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_gelu_erf(float x) {
|
||||
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
||||
|
||||
return 0.5f*x*(1.0f + erff(x*SQRT_2_INV));
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_gelu_quick(float x) {
|
||||
const float GELU_QUICK_COEF = -1.702f;
|
||||
|
||||
@ -134,6 +140,10 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_gelu>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_gelu_erf>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_gelu_quick>(ctx, dst);
|
||||
}
|
||||
|
@ -30,6 +30,8 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
Loading…
x
Reference in New Issue
Block a user