mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-27 03:33:46 -04:00
cuda : add ELU support (#14657)
This commit is contained in:
@ -2303,6 +2303,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||||||
case GGML_UNARY_OP_EXP:
|
case GGML_UNARY_OP_EXP:
|
||||||
ggml_cuda_op_exp(ctx, dst);
|
ggml_cuda_op_exp(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_UNARY_OP_ELU:
|
||||||
|
ggml_cuda_op_elu(ctx, dst);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -3116,6 +3119,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||||||
case GGML_UNARY_OP_GELU_QUICK:
|
case GGML_UNARY_OP_GELU_QUICK:
|
||||||
case GGML_UNARY_OP_TANH:
|
case GGML_UNARY_OP_TANH:
|
||||||
case GGML_UNARY_OP_EXP:
|
case GGML_UNARY_OP_EXP:
|
||||||
|
case GGML_UNARY_OP_ELU:
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]);
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
|
@ -83,6 +83,10 @@ static __device__ __forceinline__ float op_log(float x) {
|
|||||||
return logf(x);
|
return logf(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float op_elu(float x) {
|
||||||
|
return (x > 0.f) ? x : expm1f(x);
|
||||||
|
}
|
||||||
|
|
||||||
template <float (*op)(float), typename T>
|
template <float (*op)(float), typename T>
|
||||||
static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
|
static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
@ -196,6 +200,9 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||||||
ggml_cuda_op_unary<op_log>(ctx, dst);
|
ggml_cuda_op_unary<op_log>(ctx, dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_unary<op_elu>(ctx, dst);
|
||||||
|
}
|
||||||
/* gated ops */
|
/* gated ops */
|
||||||
|
|
||||||
template <float (*op)(float), typename T>
|
template <float (*op)(float), typename T>
|
||||||
|
@ -59,6 +59,8 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|||||||
|
|
||||||
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
Reference in New Issue
Block a user