CANN: Implement GLU ops (#14884)

Implement REGLU, GEGLU, SWIGLU ops according to #14158
Author: hipudding
Date: 2025-07-26 17:56:18 +08:00
Committed by: GitHub
Parent: 9b8f3c6c77
Commit: 11dd5a44eb
4 changed files with 194 additions and 40 deletions
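
For reference, the three gated activations follow one pattern: the input supplies a value half and a gate half (either as two tensors, or packed into a single tensor that is split along dim 0), the unary op is applied to the value, and the result is multiplied by the gate:

REGLU:  dst = ReLU(value) * gate
GEGLU:  dst = GELU(value) * gate
SWIGLU: dst = SiLU(value) * gate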

View File

@@ -77,6 +77,8 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne,
for (int i = 0; i < final_dims; i++) {
acl_storage_len += (acl_ne[i] - 1) * acl_stride[i];
}
size_t elem_offset = offset / ggml_element_size(tensor);
acl_storage_len += elem_offset;
// Reverse ne and stride.
std::reverse(acl_ne, acl_ne + final_dims);
@@ -84,7 +86,7 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne,
aclTensor* acl_tensor = aclCreateTensor(
acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
elem_offset, format, &acl_storage_len, 1,
tensor->data);
return acl_tensor;
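
The hunk above folds the element offset into the storage length reported to ACL: the span of a view must be measured from the start of its backing buffer, not from the view's origin. A minimal standalone sketch of the arithmetic (plain C++; the function name is hypothetical, and the real ggml_cann_create_tensor additionally handles type mapping and dimension reversal):

#include <cstdint>

// Hypothetical illustration of the storage-length computation above:
// extents (ne) and strides (in elements) plus an element offset into
// the backing buffer give the number of elements the view can reach.
static int64_t acl_view_storage_len(const int64_t* ne, const int64_t* stride,
                                    int64_t elem_offset, int dims) {
    int64_t len = 1;  // element at the view's origin (initializer not shown in the hunk; 1 assumed)
    for (int i = 0; i < dims; i++) {
        len += (ne[i] - 1) * stride[i];  // farthest reachable index per dimension
    }
    return len + elem_offset;  // span measured from the buffer start
}

The packed GLU path added below depends on this: its gate view starts ne[0] elements into src0's buffer, so without the offset term the reported storage length would be ne[0] elements short.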

View File

@@ -99,7 +99,7 @@ void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclT
}
}
void ggml_cann_op_unary(
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_tensor* src = dst->src[0];
@@ -111,6 +111,42 @@ void ggml_cann_op_unary(
ggml_cann_release_resources(ctx, acl_src, acl_dst);
}
void ggml_cann_op_unary_gated(
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_tensor* src0 = dst->src[0];
ggml_tensor* src1 = dst->src[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
aclTensor *acl_src0 = nullptr, *acl_src1 = nullptr;
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
acl_src0 = ggml_cann_create_tensor(src0);
acl_src1 = ggml_cann_create_tensor(src1);
} else {
int64_t ne[] = {src0->ne[0] / 2, src0->ne[1], src0->ne[2], src0->ne[3]};
size_t nb[] = {src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]};
acl_src0 = ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, 0);
acl_src1 = ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, ne[0] * ggml_element_size(src0));
if (swapped) {
std::swap(acl_src0, acl_src1);
}
}
unary_op(ctx, acl_src0, acl_dst);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst, acl_src1);
// acl_src1 is created in both branches (from src1, or as the gate view of
// src0), so it must be released unconditionally.
ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
}
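
As a host-side sanity reference for what the two ACL calls compute (the unary op into acl_dst, then an in-place multiply by the gate), here is a hedged sketch of one packed row; the helpers are hypothetical and not part of the backend:

#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the packed GLU path: a row of 2*n
// elements holds [value | gate]; `swapped` flips which half is which,
// mirroring dst->op_params[1].
static std::vector<float> glu_row_ref(const std::vector<float>& row,
                                      float (*act)(float), bool swapped) {
    const size_t n = row.size() / 2;
    const float* value = row.data() + (swapped ? n : 0);
    const float* gate  = row.data() + (swapped ? 0 : n);
    std::vector<float> out(n);
    for (size_t i = 0; i < n; i++) {
        out[i] = act(value[i]) * gate[i];  // mirrors unary_op + InplaceMul
    }
    return out;
}

// SWIGLU's activation: silu(x) = x * sigmoid(x); REGLU and GEGLU
// substitute relu and gelu respectively.
static float silu_ref(float x) { return x / (1.0f + std::exp(-x)); }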
/**
* @brief Repeats elements of a tensor along each dimension according to the
* specified repeat array.

View File

@@ -1098,7 +1098,7 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
* @param dst The destination tensor. Its src[0] is treated as the input tensor.
*/
template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
void ggml_cann_op_unary(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_tensor* src = dst->src[0];
aclTensor* acl_src = ggml_cann_create_tensor(src);
@@ -1109,49 +1109,125 @@ template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
}
/**
* @brief Applies a unary operation to a ggml tensor using the CANN backend.
*
* @details This function applies a unary operation to the input tensor using
* a user-provided lambda or callable `unary_op`. The lambda receives the
* CANN backend context and two ACL tensors: the source and the destination.
*
* Internally, this function handles the conversion from GGML tensors to ACL
* tensors, calls the provided unary op, and manages resource cleanup. The
* input is taken from `dst->src[0]`, and the result is written to `dst`.
*
* This utility simplifies writing unary op wrappers by abstracting tensor
* preparation.
*
* @param unary_op A callable that performs the unary operation using CANN ACL APIs.
* @param ctx The CANN context for operation execution.
* @param dst The destination ggml_tensor where the result will be stored.
* The input tensor is taken from `dst->src[0]`.
*
* @see GGML_CANN_CALL_OP_UNARY
*/
void ggml_cann_op_unary(
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
ggml_backend_cann_context& ctx, ggml_tensor* dst);
/**
* @brief Applies a gated (GLU-style) unary operation using the CANN backend.
*
* @details This function performs a gated activation such as GEGLU or ReGLU.
* It supports two input modes:
*
* 1. **Dual input mode**: `dst->src[0]` and `dst->src[1]` are both valid tensors.
* These are used directly as the value and gate tensors.
*
* 2. **Packed input mode**: Only `dst->src[0]` is valid, and it is assumed to
* contain a concatenation of value and gate along the first dimension. This tensor
* will be split into two equal halves to form the value and gate inputs.
*
* The function applies a user-provided unary operation (e.g., GELU) to the value tensor,
* then multiplies the result in-place with the gate tensor:
*
* @code
* dst = unary_op(value) * gate;
* @endcode
*
* The `swapped` parameter (from `dst->op_params[1]`) allows flipping the
* order of value/gate in the packed input case.
*
* @param unary_op A callable that performs the unary operation using CANN ACL APIs.
* It receives (ctx, acl_value_tensor, acl_output_tensor).
* @param ctx The CANN context used for execution.
* @param dst The destination ggml_tensor. Source tensors are in `dst->src[0]` and optionally `src[1]`.
*
* @see GGML_CANN_CALL_OP_UNARY_GATED
*/
void ggml_cann_op_unary_gated(
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
ggml_backend_cann_context& ctx, ggml_tensor* dst);
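
For context, the two input modes correspond to the two flavors of the GLU graph API that #14158 introduced on the ggml side; a hedged usage sketch (function names assumed from that PR, check ggml.h for the exact signatures):

// Packed mode: one tensor carries [value | gate] along dim 0, so the
// backend sees dst->src[1] == NULL and splits src0 in half:
//   ggml_tensor * out = ggml_swiglu(ctx, cur);
//
// Dual mode: value and gate arrive as separate tensors (dst->src[1] != NULL):
//   ggml_tensor * out = ggml_swiglu_split(ctx, value, gate);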
/**
* @brief Helper macro to call a unary ACL operator via ggml_cann_op_unary.
*
* This macro wraps the specified ACLNN unary operator name into a lambda expression,
* and passes it to `ggml_cann_op_unary`, which handles the common logic for executing
* unary ops in the CANN backend.
*
* Internally, this macro expands to a lambda like:
* @code
* [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
* GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
* };
* @endcode
*
* This lambda is then passed to `ggml_cann_op_unary`, which applies the operation.
*
* @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
*
* @see ggml_cann_op_unary
* @see GGML_CANN_CALL_ACLNN_OP
*/
#define GGML_CANN_CALL_OP_UNARY(OP_NAME) \
do { \
auto lambda = [](ggml_backend_cann_context& ctx, \
aclTensor* acl_src, \
aclTensor* acl_dst) { \
GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
}; \
ggml_cann_op_unary(lambda, ctx, dst); \
} \
while (0)
/**
* @brief Helper macro to call a gated unary ACL operator via ggml_cann_op_unary_gated.
*
* This macro wraps the specified ACLNN unary operator name into a lambda expression,
* and passes it to `ggml_cann_op_unary_gated`, which handles the common logic for
* executing gated unary ops in the CANN backend.
*
* Internally, this macro expands to a lambda like:
* @code
* [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
* GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
* };
* @endcode
*
* This lambda is then passed to `ggml_cann_op_unary_gated`, which applies the operation.
*
* @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
*
* @see ggml_cann_op_unary_gated
* @see GGML_CANN_CALL_ACLNN_OP
*/
#define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME) \
do { \
auto lambda = [](ggml_backend_cann_context& ctx, \
aclTensor* acl_src, \
aclTensor* acl_dst) { \
GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
}; \
ggml_cann_op_unary_gated(lambda, ctx, dst); \
} \
while (0)
#endif // CANN_ACLNN_OPS

View File

@@ -1681,16 +1681,18 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
case GGML_OP_UNARY:
switch (ggml_get_unary_op(dst)) {
case GGML_UNARY_OP_ABS:
GGML_CANN_CALL_OP_UNARY(Abs);
break;
case GGML_UNARY_OP_NEG:
GGML_CANN_CALL_OP_UNARY(Neg);
break;
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_ERF:
// aclnnGelu uses the erf-based formulation of GELU.
GGML_CANN_CALL_OP_UNARY(Gelu);
break;
case GGML_UNARY_OP_SILU:
GGML_CANN_CALL_OP_UNARY(Silu);
break;
case GGML_UNARY_OP_GELU_QUICK: {
auto lambda = [](ggml_backend_cann_context& ctx,
@@ -1698,31 +1700,31 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
aclTensor* acl_dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
};
ggml_cann_op_unary(lambda, ctx, dst);
} break;
case GGML_UNARY_OP_TANH:
GGML_CANN_CALL_OP_UNARY(Tanh);
break;
case GGML_UNARY_OP_RELU:
GGML_CANN_CALL_OP_UNARY(Relu);
break;
case GGML_UNARY_OP_SIGMOID:
GGML_CANN_CALL_OP_UNARY(Sigmoid);
break;
case GGML_UNARY_OP_HARDSIGMOID:
GGML_CANN_CALL_OP_UNARY(Hardsigmoid);
break;
case GGML_UNARY_OP_HARDSWISH:
GGML_CANN_CALL_OP_UNARY(Hardswish);
break;
case GGML_UNARY_OP_EXP:
GGML_CANN_CALL_OP_UNARY(Exp);
break;
case GGML_UNARY_OP_ELU:
ggml_cann_elu(ctx, dst);
break;
case GGML_UNARY_OP_SGN:
GGML_CANN_CALL_OP_UNARY(Sign);
break;
case GGML_UNARY_OP_STEP:
ggml_cann_step(ctx, dst);
@@ -1731,6 +1733,31 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
return false;
}
break;
case GGML_OP_GLU:
switch (ggml_get_glu_op(dst)) {
case GGML_GLU_OP_REGLU:
GGML_CANN_CALL_OP_UNARY_GATED(Relu);
break;
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_GEGLU_ERF:
// aclnnGelu uses the erf-based formulation of GELU.
GGML_CANN_CALL_OP_UNARY_GATED(Gelu);
break;
case GGML_GLU_OP_SWIGLU:
GGML_CANN_CALL_OP_UNARY_GATED(Silu);
break;
case GGML_GLU_OP_GEGLU_QUICK: {
auto lambda = [](ggml_backend_cann_context& ctx,
aclTensor* acl_src,
aclTensor* acl_dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
};
ggml_cann_op_unary_gated(lambda, ctx, dst);
} break;
default:
return false;
}
break;
case GGML_OP_NORM:
ggml_cann_norm(ctx, dst);
break;
@@ -1773,7 +1800,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
ggml_cann_binary_op<aclnn_mul>(ctx, dst);
break;
case GGML_OP_SQRT:
GGML_CANN_CALL_OP_UNARY(Sqrt);
break;
case GGML_OP_CLAMP:
ggml_cann_clamp(ctx, dst);
@@ -1818,16 +1845,16 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
ggml_cann_argmax(ctx, dst);
break;
case GGML_OP_COS:
ggml_cann_op_unary<aclnn_cos>(ctx, dst);
break;
case GGML_OP_SIN:
ggml_cann_op_unary<aclnn_sin>(ctx, dst);
break;
case GGML_OP_CONV_TRANSPOSE_1D:
ggml_cann_conv_transpose_1d(ctx, dst);
break;
case GGML_OP_LOG:
GGML_CANN_CALL_OP_UNARY(Log);
break;
case GGML_OP_MEAN:
ggml_cann_mean(ctx, dst);
@@ -2101,10 +2128,23 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_SGN:
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_GELU_ERF:
return true;
default:
return false;
}
case GGML_OP_GLU:
switch (ggml_get_glu_op(op)) {
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return true;
default:
return false;
}
break;
case GGML_OP_MUL_MAT: {
switch (op->src[0]->type) {
case GGML_TYPE_F16: