diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index c6bdd4fb3..885c56492 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2246,6 +2246,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_UNARY_OP_EXP:
                     ggml_cuda_op_exp(ctx, dst);
                     break;
+                case GGML_UNARY_OP_REGLU:
+                    ggml_cuda_op_reglu(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_GEGLU:
+                    ggml_cuda_op_geglu(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_SWIGLU:
+                    ggml_cuda_op_swiglu(ctx, dst);
+                    break;
                 default:
                     return false;
             }
@@ -3039,6 +3048,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_EXP:
                     return ggml_is_contiguous(op->src[0]);
+                case GGML_UNARY_OP_REGLU:
+                case GGML_UNARY_OP_GEGLU:
+                case GGML_UNARY_OP_SWIGLU:
+                    return ggml_is_contiguous_1(op->src[0]);
                 default:
                     return false;
             }
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 2c0375fbe..c98564a31 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -196,6 +196,62 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_log>(ctx, dst);
 }
 
+/* gated ops */
+
+template <float (*op)(float), typename T>
+static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int k, const int n, const int o) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    // perform base op on first half of row and multiply with gate in second half
+    const int j = (i / n) * o + (i % n);
+    dst[i] = (T)(op((float)x[j]) * (float)x[j + n]);
+}
+
+template <float (*op)(float), typename T>
+static void unary_gated_cuda(const T * x, T * dst, const int k, const int n, const int o, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
+    unary_gated_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k, n, o);
+}
+
+template <float (*op)(float)>
+void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
+    const int nc = src0->ne[0] / 2;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(dst->ne[0] >= nc);
+    GGML_ASSERT(ggml_nrows(dst) >= ggml_nrows(src0));
+
+    if (src0->type == GGML_TYPE_F16) {
+        unary_gated_cuda<op>((const half *)src0_d, (half *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(half), stream);
+    } else {
+        unary_gated_cuda<op>((const float *)src0_d, (float *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(float), stream);
+    }
+}
+
+void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_relu>(ctx, dst);
+}
+
+void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_gelu>(ctx, dst);
+}
+
+void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_silu>(ctx, dst);
+}
+
 /* silu_back */
 
 static __device__ __forceinline__ float op_silu_back(float grad, float x) {
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index 6686fc17e..d4533d24e 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -57,3 +57,9 @@ void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
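Note on the gated layout (not part of the patch): each source row holds the activation values in its first half and the gate values in its second half, so the kernel applies the base op to element j and multiplies by the gate at j + n, where n = ne[0]/2 and o is the source row stride in elements. Below is a minimal host-side sketch of that same indexing, assuming this layout; silu_ref and swiglu_ref are hypothetical names used only for illustration and are not ggml symbols.

// Minimal CPU reference for the gated-op indexing used by unary_gated_op_kernel.
// Illustrative only: silu_ref/swiglu_ref are hypothetical names, not ggml symbols.
#include <cmath>
#include <cstdio>
#include <vector>

static float silu_ref(float x) {
    return x / (1.0f + std::exp(-x)); // same math as the SiLU base op
}

// k: total output elements, n: output row width (nc), o: source row stride in elements
static void swiglu_ref(const std::vector<float> & x, std::vector<float> & dst,
                       const int k, const int n, const int o) {
    for (int i = 0; i < k; ++i) {
        const int j = (i / n) * o + (i % n); // activation element in the first half of the row
        dst[i] = silu_ref(x[j]) * x[j + n];  // gate sits n elements further along the same row
    }
}

int main() {
    const int n = 4, o = 8, rows = 2;        // two rows: 4 activations followed by 4 gates each
    std::vector<float> x(rows * o, 1.0f);
    std::vector<float> dst(rows * n);
    swiglu_ref(x, dst, rows * n, n, o);
    std::printf("%f\n", dst[0]);             // silu(1) * 1 ~= 0.731059
    return 0;
}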