CANN: Add broadcast for softmax and FA (#15208)

* refactor softmax

* fix fa

* fix mask shape

* format

* add comments

* Remove whitespace
hipudding
2025-08-11 22:50:31 +08:00
committed by GitHub
parent cf9e5648a7
commit be48528b06
2 changed files with 216 additions and 343 deletions


@@ -812,7 +812,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_cann_release_resources(ctx, src_trans_tensor);
return;
} else {
GGML_ABORT("Unsupport dst is not tontiguous.");
GGML_ABORT("Unsupport dst is not contiguous.");
}
}
ggml_cann_release_resources(ctx, acl_src, acl_dst);
@@ -1330,160 +1330,196 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
}
/**
 * @brief Applies the Alibi (Attention with Linear Biases) mechanism to the
 * @details This function implements the Alibi mechanism, which introduces
 *          learnable biases into the attention scores to simulate relative
 *          position encoding without the need for explicit positional
 *          embeddings.
 *
 * @param ctx          The backend CANN context for executing operations.
 * @param acl_src      The source tensor representing the query or key.
 * @param acl_position The position tensor containing relative positions.
 * @param acl_dst      The destination tensor where the result will be stored.
 * @param n_head       The number of attention heads.
 * @param src_ne       The dimensions of the source tensor.
 * @param src_nb0      The byte size of the first dimension of the source
 *                     tensor.
 * @param max_bias     The maximum bias value used in the Alibi mechanism.
 * @param dst          The destination tensor object for additional metadata.
 *
 * The function performs the following steps:
 * 1. Calculates the logarithm floor of the number of heads to determine the
 *    base for bias calculation.
 * 2. Initializes arrays with arithmetic sequences and fills them with bias
 *    values.
 * 3. Computes the bias tensor based on the calculated biases and arithmetic
 *    sequences.
 * 4. Reshapes the bias tensor to match the dimensions of the input tensors.
 * 5. Multiplies the position tensor by the bias tensor.
 * 6. Adds the result of the multiplication to the source tensor to produce the
 *    final output.
 */
/**
 * @brief Generate a range of values and apply a scalar base exponentiation.
 *
 * This function creates an evenly spaced sequence from `start` to `stop` (exclusive),
 * with step size `step`, stores it in a temporary buffer, and then computes:
 *
* @f[
* slope[i] = m^{\left( start + i \cdot step \right)}, \quad 0 \le i < size
* @f]
*
* The results are written to the provided @p slope_buffer.
*
* @param ctx CANN backend context for memory allocation and operator execution.
* @param slope_buffer Pointer to the output buffer (float array) for the computed slope values.
* @param m Scalar base for the exponentiation.
* @param size Number of elements in the generated sequence.
* @param start Starting exponent offset.
* @param stop Stopping exponent offset (exclusive).
* @param step Step size for the exponent increment.
*/
static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_position, aclTensor* acl_dst,
const int n_head, int64_t* src_ne, const size_t src_nb0,
float max_bias, ggml_tensor* dst) {
const int64_t ne2_ne3 = src_ne[2] * src_ne[3];
GGML_ASSERT(src_nb0 == sizeof(float));
GGML_ASSERT(n_head == src_ne[2]);
static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
float m, int64_t size, float start, float stop, float step){
int64_t ne[] = {size};
size_t nb[] = {sizeof(float)};
const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float));
void* arange_buffer = arange_allocator.get();
float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
aclTensor* arange_tensor = ggml_cann_create_tensor(
arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
aclnn_arange(ctx, arange_tensor, start, stop, step, size);
// init arange
ggml_cann_pool_alloc arange_allocator(ctx.pool(),
ne2_ne3 * ggml_type_size(dst->type));
void* tmp_arange_buffer = arange_allocator.get();
aclTensor* slope_tensor = ggml_cann_create_tensor(
slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
// arange1: [1, ..., n_heads_log2_floor+1)
float start = 1;
float stop = n_heads_log2_floor + 1;
float step = 1;
int64_t n_elements_arange = n_heads_log2_floor;
aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);
int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
size_t tmp_arange1_nb[] = {sizeof(dst->type)};
aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), tmp_arange1_ne, tmp_arange1_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
aclTensor* tmp_arange2_tensor = nullptr;
if (n_heads_log2_floor < ne2_ne3) {
// arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
start = 1;
stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
step = 2;
n_elements_arange = ne2_ne3 - n_heads_log2_floor;
int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
size_t tmp_arange2_nb[] = {sizeof(dst->type)};
aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
(char*)tmp_arange_buffer +
n_heads_log2_floor * ggml_type_size(dst->type),
ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
n_elements_arange);
}
// init mk_base
ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
ne2_ne3 * ggml_type_size(dst->type));
void* tmp_mk_base_buffer = mk_base_allocator.get();
int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
size_t tmp_mk_base1_nb[] = {sizeof(dst->type)};
aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), tmp_mk_base1_ne, tmp_mk_base1_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
aclTensor* tmp_mk_base2_tensor = nullptr;
if (n_heads_log2_floor < ne2_ne3) {
int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
size_t tmp_mk_base2_nb[] = {sizeof(dst->type)};
aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
(char*)tmp_mk_base_buffer +
n_heads_log2_floor * ggml_type_size(dst->type),
ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
}
// init mk
int64_t tmp_mk_base_ne[] = {ne2_ne3};
size_t tmp_mk_base_nb[] = {sizeof(dst->type)};
aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
// reshape mk
int64_t tmp_mk_ne[] = {1, 1, src_ne[2], src_ne[3]};
size_t tmp_mk_nb[GGML_MAX_DIMS];
tmp_mk_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
}
aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
ACL_FORMAT_ND);
// acl_position * mk
int64_t tmp_output_ne[] = {src_ne[0], src_ne[1], src_ne[2], src_ne[3]};
size_t tmp_output_nb[GGML_MAX_DIMS];
tmp_output_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
tmp_output_nb[i] = tmp_output_nb[i - 1] * tmp_output_ne[i - 1];
}
ggml_cann_pool_alloc output_allocator(ctx.pool(), ggml_nbytes(dst));
void* tmp_output_buffer = output_allocator.get();
aclTensor* tmp_output_tensor = ggml_cann_create_tensor(
tmp_output_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), tmp_output_ne, tmp_output_nb, GGML_MAX_DIMS,
ACL_FORMAT_ND);
aclnn_mul(ctx, acl_position, tmp_mk_tensor, tmp_output_tensor);
// add
aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst);
ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
tmp_arange_tensor, tmp_mk_tensor, tmp_output_tensor);
GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, sc, arange_tensor, slope_tensor);
ggml_cann_release_resources(ctx, sc, arange_tensor, slope_tensor);
}
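
For reference, the new helper amounts to an arange followed by a scalar-base power. A minimal CPU sketch of that computation (editor's illustration; `cpu_get_slope_inner` is a hypothetical name, not part of the commit):

#include <cmath>
#include <cstdint>

// Editor's sketch: scalar equivalent of aclnn_arange + PowScalarTensor as used
// by aclnn_get_slope_inner. Fills buf[i] = m^(start + i*step) for 0 <= i < size.
static void cpu_get_slope_inner(float * buf, float m, int64_t size, float start, float step) {
    for (int64_t i = 0; i < size; i++) {
        buf[i] = std::pow(m, start + (float) i * step);
    }
}
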
void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
/**
* @brief Compute slope values for multiple attention heads based on ALiBi bias parameters.
*
* This function generates slope values for each attention head according to the ALiBi
* (Attention with Linear Biases) method. It splits the computation into two ranges depending
* on whether the head index is less than @p n_head_log2 or not, and uses different base values
* (`m0` and `m1`) for the exponentiation.
*
* @f[
* slope[h] =
* \begin{cases}
* m_0^{(h + 1)}, & h < n\_head\_log2 \\
* m_1^{\left( 2 \cdot (h - n\_head\_log2) + 1 \right)}, & h \geq n\_head\_log2
* \end{cases}
* \quad , \quad \text{if } max\_bias > 0
* @f]
*
* If @p max_bias <= 0, all slope values are set to 1.0.
*
* @param ctx CANN backend context for memory allocation and operator execution.
* @param n_head Total number of attention heads.
* @param slope_buffer Pointer to the output buffer (float array) for storing slopes.
* @param max_bias Maximum bias value for slope computation.
*
*/
static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
void* slope_buffer, float max_bias) {
const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
// const float slope = (max_bias > 0.0f) ?
// h < n_head_log2 ?
// powf(m0, h + 1) :
// powf(m1, 2*(h - n_head_log2) + 1) :
// 1.0f;
// arange1
float start = 0 + 1;
float end = (n_head_log2 - 1) + 1;
float step = 1;
float count = n_head_log2;
// end needs to be +1 because aclnn uses a left-closed, right-open interval.
aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step);
if (n_head_log2 < n_head) {
// arange2
start = 2 * (n_head_log2 - n_head_log2) + 1;
end = 2 * ((n_head - 1) - n_head_log2) + 1;
step = 2;
count = n_head - n_head_log2;
aclnn_get_slope_inner(
ctx, (char *) slope_buffer + n_head_log2 * sizeof(float),
m1, count, start, end + 1, step);
}
}
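
In other words, the two aclnn_get_slope_inner calls produce the standard ALiBi slope schedule per head. A minimal CPU sketch of the intended result (editor's illustration; `cpu_get_slope` is a hypothetical helper and assumes n_head > 0):

#include <cmath>
#include <cstdint>

// Editor's sketch: CPU reference for the per-head ALiBi slopes described above.
// slope[h] = m0^(h+1) for h < n_head_log2, m1^(2*(h - n_head_log2) + 1) otherwise,
// and 1.0f everywhere when max_bias <= 0.
static void cpu_get_slope(float * slope, int64_t n_head, float max_bias) {
    const int n_head_log2 = 1u << (uint32_t) std::floor(std::log2((double) n_head));
    const float m0 = std::pow(2.0f, -max_bias / n_head_log2);
    const float m1 = std::pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
    for (int64_t h = 0; h < n_head; h++) {
        slope[h] = max_bias > 0.0f
                       ? (h < n_head_log2 ? std::pow(m0, (float) (h + 1))
                                          : std::pow(m1, (float) (2 * (h - n_head_log2) + 1)))
                       : 1.0f;
    }
}
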
/**
* @brief Add ALiBi (Attention with Linear Biases) positional biases to the attention mask.
*
* This function computes the ALiBi slopes for each attention head (if max_bias > 0),
* multiplies them with the attention mask to produce bias tensors, and adds these biases
* to the destination tensor (@p dst).
*
* The function performs necessary broadcasting of the mask and slope tensors to match
* the shape of the destination tensor, then applies element-wise multiplication and addition
* using CANN operators.
*
* @param ctx CANN backend context for memory management and operator execution.
* @param mask Input attention mask tensor, assumed to be contiguous.
* @param dst Destination tensor to which ALiBi biases will be added.
* @param dst_ptr Pointer to the memory of the destination tensor.
* @param max_bias Maximum bias value controlling the slope scaling.
*
* @note
* - Write data into dst_ptr using only the shape information of the dst tensor.
* - `GGML_MAX_DIMS + 2` is used to extend tensor dimensions for broadcasting.
*/
static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask,
ggml_tensor* dst, void* dst_ptr, float max_bias) {
void* slope_buffer = nullptr;
void* bias_buffer = nullptr;
if (max_bias > 0.0f) {
int64_t n_heads = dst->ne[2];
ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
slope_buffer = slope_allocator.get();
ggml_cann_pool_alloc bias_allocator(
ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst));
bias_buffer = bias_allocator.get();
aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias);
}
// broadcast for mask, slope and dst
int64_t nr2 = dst->ne[2] / mask->ne[2];
int64_t nr3 = dst->ne[3] / mask->ne[3];
// broadcast the mask across rows
int64_t mask_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], 1, mask->ne[3], 1 };
size_t mask_nb[] = {
mask_nb[0] = mask->nb[0], mask_nb[1] = mask->nb[1], mask_nb[2] = mask->nb[2],
mask_nb[3] = mask->nb[2], mask_nb[4] = mask->nb[3], mask_nb[5] = mask->nb[3]
};
int64_t dst_ne[] = { dst->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], nr3 };
size_t dst_nb[] = {
dst_nb[0] = dst->nb[0], dst_nb[1] = dst->nb[1], dst_nb[2] = dst->nb[2],
dst_nb[3] = dst->nb[2], dst_nb[4] = dst->nb[3], dst_nb[5] = dst->nb[3]
};
// slope is a 1 dim tensor, slope.ne2 == dst.ne2
int64_t slope_ne[] = { 1, 1, mask->ne[2], nr2, 1, 1 };
size_t slope_nb[GGML_MAX_DIMS + 2];
slope_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS + 2; i++) {
slope_nb[i] = slope_nb[i - 1] * slope_ne[i - 1];
}
aclTensor* acl_slope = ggml_cann_create_tensor(
slope_buffer, ACL_FLOAT, sizeof(float),
slope_ne, slope_nb, GGML_MAX_DIMS + 2);
aclTensor* acl_mask = ggml_cann_create_tensor(
mask, mask_ne, mask_nb, GGML_MAX_DIMS + 2);
// write data into dst_ptr using only the shape information of the dst tensor.
aclTensor* acl_dst = ggml_cann_create_tensor(
dst_ptr, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), dst_ne, dst_nb,
GGML_MAX_DIMS + 2);
if (max_bias > 0.0f) {
int64_t bias_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], 1 };
size_t bias_nb[GGML_MAX_DIMS + 2];
bias_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS + 2; i++) {
bias_nb[i] = bias_nb[i - 1] * bias_ne[i - 1];
}
aclTensor* bias_tensor = ggml_cann_create_tensor(
bias_buffer, ACL_FLOAT, sizeof(float),
bias_ne, bias_nb, GGML_MAX_DIMS + 2);
aclnn_mul(ctx, acl_slope, acl_mask, bias_tensor);
aclnn_add(ctx, acl_dst, bias_tensor);
ggml_cann_release_resources(ctx, bias_tensor);
} else {
aclnn_add(ctx, acl_dst, acl_mask);
}
ggml_cann_release_resources(ctx, acl_slope, acl_mask, acl_dst);
}
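
Conceptually, the routine adds a per-head-scaled copy of the mask to dst. A rough scalar sketch of the intended element-wise semantics (editor's illustration, not the CANN implementation; it assumes contiguous float tensors, a mask with the same ne0/ne1 as dst, and modulo-style ggml broadcasting over dims 2 and 3):

#include <cstdint>

// Editor's sketch: dst[i3][i2][i1][i0] += slope[i2] * mask[i3 % mask_ne3][i2 % mask_ne2][i1][i0],
// where slope[h] == 1.0f when max_bias <= 0 (cpu_add_alibi is a hypothetical name).
static void cpu_add_alibi(float * dst, const float * mask, const float * slope,
                          int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
                          int64_t mask_ne2, int64_t mask_ne3) {
    for (int64_t i3 = 0; i3 < ne3; i3++)
    for (int64_t i2 = 0; i2 < ne2; i2++)
    for (int64_t i1 = 0; i1 < ne1; i1++)
    for (int64_t i0 = 0; i0 < ne0; i0++) {
        const int64_t mi = (((i3 % mask_ne3) * mask_ne2 + i2 % mask_ne2) * ne1 + i1) * ne0 + i0;
        dst[((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0] += slope[i2] * mask[mi];
    }
}
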
void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_cann_dup(ctx, dst);
}
@@ -1501,118 +1537,41 @@ void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
* @param acl_dst The destination tensor where the softmax results will be
* stored.
*/
static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src,
int64_t dim, aclTensor* acl_dst) {
static void aclnn_softmax(ggml_backend_cann_context & ctx,
aclTensor* acl_src, int64_t dim, aclTensor * acl_dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, acl_src, dim, acl_dst);
}
void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor* src0 = dst->src[0];
ggml_tensor* src1 = dst->src[1]; // mask
aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (float*)dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float*)dst->op_params + 1, sizeof(float));
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
// input mul scale
aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);
ggml_cann_pool_alloc src_tensor_allocator(ctx.pool(), ggml_nbytes(src0));
void* src_tensor_buffer = src_tensor_allocator.get();
aclTensor* softmax_tensor = ggml_cann_create_tensor(
src_tensor_buffer, ggml_cann_type_mapping(src0->type),
ggml_element_size(src0), src0->ne, src0->nb,GGML_MAX_DIMS);
size_t n_bytes = ggml_nbytes(src0);
ggml_cann_pool_alloc mul_scale_allocator(ctx.pool(), n_bytes);
void* input_mul_scale_buffer = mul_scale_allocator.get();
aclTensor* acl_input_mul_scale_tensor = ggml_cann_create_tensor(
input_mul_scale_buffer, ACL_FLOAT, ggml_type_size(src0->type), src0->ne,
src0->nb, GGML_MAX_DIMS);
bool inplace = false;
aclnn_muls(ctx, acl_src0, scale, acl_input_mul_scale_tensor, inplace);
aclnn_muls(ctx, acl_src0, scale, softmax_tensor, false);
// mask
aclTensor* acl_src1_fp32_tensor = nullptr;
aclTensor* tmp_mask_tensor = nullptr;
ggml_cann_pool_alloc src1_fp32_allocator(ctx.pool());
if (src1) {
const bool use_f16 = src1->type == GGML_TYPE_F16;
if (use_f16) {
// cast to fp32
size_t n_bytes = ggml_nelements(src1) * sizeof(float_t);
size_t src1_fp32_nb[GGML_MAX_DIMS];
src1_fp32_nb[0] = sizeof(float_t);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
src1_fp32_nb[i] = src1_fp32_nb[i - 1] * src1->ne[i - 1];
}
src1_fp32_allocator.alloc(n_bytes);
void* src1_fp32_buffer = src1_fp32_allocator.get();
acl_src1_fp32_tensor = ggml_cann_create_tensor(
src1_fp32_buffer, ACL_FLOAT, sizeof(float), src1->ne,
src1_fp32_nb, GGML_MAX_DIMS);
aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT);
ggml_cann_release_resources(ctx, acl_src1);
} else {
acl_src1_fp32_tensor = ggml_cann_create_tensor(src1);
}
// broadcast the mask across rows, only use ne11 of ne01 in mask
if (src1->ne[1] != src0->ne[1]) {
// mask shape: [1,1,ne11,ne10]
int64_t tmp_mask_ne[] = {src0->ne[0], src0->ne[1], 1, 1};
size_t tmp_mask_nb[GGML_MAX_DIMS];
tmp_mask_nb[0] = sizeof(float_t);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
tmp_mask_nb[i] = tmp_mask_nb[i - 1] * tmp_mask_ne[i - 1];
}
tmp_mask_tensor = ggml_cann_create_tensor(
src1->data, ACL_FLOAT, sizeof(float), tmp_mask_ne, tmp_mask_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND);
}
// alibi
const int n_head = src0->ne[2];
const size_t src_nb0 = src0->nb[0];
n_bytes = ggml_nbytes(dst);
ggml_cann_pool_alloc output_allocator(ctx.pool(), n_bytes);
void* output_buffer = output_allocator.get();
aclTensor* alibi_output_tensor = ggml_cann_create_tensor(
output_buffer, ACL_FLOAT, ggml_type_size(dst->type), dst->ne,
dst->nb, GGML_MAX_DIMS);
if (max_bias <= 0.0f) {
// slope = 1.0
if (tmp_mask_tensor) {
aclnn_add(ctx, tmp_mask_tensor, acl_input_mul_scale_tensor,
alibi_output_tensor);
} else {
aclnn_add(ctx, acl_src1_fp32_tensor, acl_input_mul_scale_tensor,
alibi_output_tensor);
}
} else {
// slope != 1.0
if (tmp_mask_tensor) {
aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_mask_tensor,
alibi_output_tensor, n_head, src0->ne, src_nb0,
max_bias, dst);
} else {
aclnn_alibi(ctx, acl_input_mul_scale_tensor,
acl_src1_fp32_tensor, alibi_output_tensor, n_head,
src0->ne, src_nb0, max_bias, dst);
}
}
// softmax
aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst);
ggml_cann_release_resources(ctx, alibi_output_tensor);
} else {
aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst);
aclnn_add_alibi(ctx, src1, src0, src_tensor_buffer, max_bias);
}
ggml_cann_release_resources(ctx, acl_src0, acl_src1_fp32_tensor, acl_dst,
acl_scale, acl_input_mul_scale_tensor, tmp_mask_tensor);
// softmax
aclnn_softmax(ctx, softmax_tensor, 3, acl_dst);
ggml_cann_release_resources(ctx, acl_src0, acl_dst, acl_scale, softmax_tensor);
}
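
Put together, the refactored path keeps the usual ggml SOFT_MAX semantics: dst = softmax(scale * src0 + slope * mask) along the innermost dimension. A per-row scalar reference of that computation (editor's sketch; `softmax_row` is a hypothetical helper):

#include <cmath>
#include <cstdint>

// Editor's sketch: numerically stable softmax of (scale * x + slope * mask) over one row.
static void softmax_row(float * dst, const float * x, const float * mask,
                        float scale, float slope, int64_t n) {
    float maxv = -INFINITY;
    for (int64_t i = 0; i < n; i++) {
        dst[i] = scale * x[i] + (mask ? slope * mask[i] : 0.0f);
        maxv   = std::fmax(maxv, dst[i]);
    }
    float sum = 0.0f;
    for (int64_t i = 0; i < n; i++) {
        dst[i] = std::exp(dst[i] - maxv);
        sum   += dst[i];
    }
    for (int64_t i = 0; i < n; i++) {
        dst[i] /= sum;
    }
}
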
/**
@@ -3208,104 +3167,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
// Compute the slope if needed. Derived from ggml_cann_softmax().
if(maxBias != 0.0f){
// alibi
const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
const int64_t n_head = src0->ne[2];
const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor);
float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor);
// init arange
ggml_cann_pool_alloc arange_allocator(ctx.pool(),
ne2_ne3 * faElemSize);
void* tmp_arange_buffer = arange_allocator.get();
const int64_t n_heads = src0->ne[2];
ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
void* slope_buffer = slope_allocator.get();
aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias);
// arange1: [1, ..., n_heads_log2_floor+1)
float start = 1;
float stop = n_heads_log2_floor + 1;
float step = 1;
int64_t n_elements_arange = n_heads_log2_floor;
int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
size_t tmp_arange1_nb[] = {faElemSize};
aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
tmp_arange_buffer, faDataType, faElemSize,
tmp_arange1_ne, tmp_arange1_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
aclTensor* tmp_arange2_tensor = nullptr;
if (n_heads_log2_floor < ne2_ne3) {
// arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
start = 1;
stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
step = 2;
n_elements_arange = ne2_ne3 - n_heads_log2_floor;
int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
size_t tmp_arange2_nb[] = {faElemSize};
aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
(char*)tmp_arange_buffer +
n_heads_log2_floor * faElemSize,
faDataType, faElemSize,
tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
n_elements_arange);
int64_t slope_ne[] = {1, 1, n_heads, 1};
size_t slope_nb[GGML_MAX_DIMS];
slope_nb[0] = sizeof(float);
for(int i = 1;i<GGML_MAX_DIMS;i++) {
slope_nb[i] = slope_nb[i-1] * slope_ne[0];
}
// init mk_base
ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
ne2_ne3 * faElemSize);
void* tmp_mk_base_buffer = mk_base_allocator.get();
int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
size_t tmp_mk_base1_nb[] = {faElemSize};
aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
tmp_mk_base_buffer, faDataType, faElemSize,
tmp_mk_base1_ne, tmp_mk_base1_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclTensor* slope_tensor = ggml_cann_create_tensor(
slope_buffer, ACL_FLOAT, sizeof(float),
slope_ne, slope_nb, GGML_MAX_DIMS);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, slope_tensor);
aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
aclTensor* tmp_mk_base2_tensor = nullptr;
if (n_heads_log2_floor < ne2_ne3) {
int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
size_t tmp_mk_base2_nb[] = {faElemSize};
aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
(char*)tmp_mk_base_buffer +
n_heads_log2_floor * faElemSize,
faDataType, faElemSize,
tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
}
// init mk
int64_t tmp_mk_base_ne[] = {ne2_ne3};
size_t tmp_mk_base_nb[] = {faElemSize};
aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
tmp_mk_base_buffer, faDataType, faElemSize,
tmp_mk_base_ne, tmp_mk_base_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
tmp_arange_buffer, faDataType, faElemSize,
tmp_mk_base_ne, tmp_mk_base_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
// reshape mk
int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]};
size_t tmp_mk_nb[GGML_MAX_DIMS];
tmp_mk_nb[0] = faElemSize;
for (int i = 1; i < GGML_MAX_DIMS; i++) {
tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
}
aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
tmp_mk_base_buffer, faDataType, faElemSize,
tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
ACL_FORMAT_ND);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor);
ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
tmp_arange_tensor, tmp_mk_tensor);
ggml_cann_release_resources(ctx, slope_tensor);
}
}
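
In the flash-attention path the same aclnn_get_slope helper is now reused: the per-head slopes are computed once and multiplied into the broadcast PSE (mask) tensor in place, roughly pse[batch][head][row][col] *= slope[head]. A minimal sketch of that scaling (editor's illustration of the InplaceMul broadcast, not the CANN kernel; assumes a contiguous float layout):

#include <cstdint>

// Editor's sketch: what InplaceMul with a [1, 1, n_heads, 1] slope tensor does to a
// contiguous [n_batch][n_heads][rows][cols] pse tensor (hypothetical helper name).
static void scale_pse_by_slope(float * pse, const float * slope,
                               int64_t n_batch, int64_t n_heads, int64_t rows, int64_t cols) {
    for (int64_t b = 0; b < n_batch; b++)
    for (int64_t h = 0; h < n_heads; h++)
    for (int64_t i = 0; i < rows * cols; i++) {
        pse[(b * n_heads + h) * rows * cols + i] *= slope[h];
    }
}
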


@@ -2391,7 +2391,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// only support F32 and F16.
return false;
}
return true;
return ggml_is_contiguous(op);
} break;
case GGML_OP_CONT: {
// TODO: support GGML_TYPE_BF16
@@ -2456,8 +2456,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// value of paddingW should be at most half of kernelW
return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
}
case GGML_OP_SUM:
case GGML_OP_DUP:
return ggml_is_contiguous(op);
case GGML_OP_SUM:
case GGML_OP_IM2COL:
case GGML_OP_CONCAT:
case GGML_OP_REPEAT:
@@ -2503,9 +2504,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
if (op->src[2]) {
return false;
}
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
return true;
case GGML_OP_FLASH_ATTN_EXT:{
// derived from [ggml-cuda.cu]
if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
@@ -2532,11 +2531,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// DeepSeek MLA
return false;
}
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
if (op->src[0]->ne[3] != 1) {
return false;
}
float logitSoftcap = 0.0f;
memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
if(logitSoftcap != 0.0f) {