diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 07d6b8b67..0b409ce87 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -812,7 +812,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_cann_release_resources(ctx, src_trans_tensor); return; } else { - GGML_ABORT("Unsupport dst is not tontiguous."); + GGML_ABORT("Unsupport dst is not contiguous."); } } ggml_cann_release_resources(ctx, acl_src, acl_dst); @@ -1330,160 +1330,196 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, } /** - * @brief Applies the Alibi (Attention with Linear Biases) mechanism to the - * @details This function implements the Alibi mechanism, which introduces - * learnable biases into the attention scores to simulate relative - * position encoding without the need for explicit positional - * embeddings. + * @brief Generate a range of values and apply a scalar base exponentiation. * - * @param ctx The backend CANN context for executing operations. - * @param acl_src The source tensor representing the query or key. - * @param acl_position The position tensor containing relative positions. - * @param acl_dst The destination tensor where the result will be stored. - * @param n_head The number of attention heads. - * @param src_ne The dimensions of the source tensor. - * @param src_nb0 The byte size of the first dimension of the source - tensor. - * @param max_bias The maximum bias value used in the Alibi mechanism. - * @param dst The destination tensor object for additional metadata. + * This function creates an evenly spaced sequence from `start` to `stop` (exclusive), + * with step size `step`, stores it in a temporary buffer, and then computes: * - * The function performs the following steps: - * 1. Calculates the logarithm floor of the number of heads to determine the - base for bias calculation. - * 2. Initializes arrays with arithmetic sequences and fills them with bias - values. - * 3. Computes the bias tensor based on the calculated biases and arithmetic - sequences. - * 4. Reshapes the bias tensor to match the dimensions of the input tensors. - * 5. Multiplies the position tensor by the bias tensor. - * 6. Adds the result of the multiplication to the source tensor to produce the - final output. + * @f[ + * slope[i] = m^{\left( start + i \cdot step \right)}, \quad 0 \le i < size + * @f] + * + * The results are written to the provided @p slope_buffer. + * + * @param ctx CANN backend context for memory allocation and operator execution. + * @param slope_buffer Pointer to the output buffer (float array) for the computed slope values. + * @param m Scalar base for the exponentiation. + * @param size Number of elements in the generated sequence. + * @param start Starting exponent offset. + * @param stop Stopping exponent offset (exclusive). + * @param step Step size for the exponent increment. */ -static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_position, aclTensor* acl_dst, - const int n_head, int64_t* src_ne, const size_t src_nb0, - float max_bias, ggml_tensor* dst) { - const int64_t ne2_ne3 = src_ne[2] * src_ne[3]; - GGML_ASSERT(src_nb0 == sizeof(float)); - GGML_ASSERT(n_head == src_ne[2]); +static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer, + float m, int64_t size, float start, float stop, float step){ + int64_t ne[] = {size}; + size_t nb[] = {sizeof(float)}; - const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head)); + ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float)); + void* arange_buffer = arange_allocator.get(); - float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); + aclTensor* arange_tensor = ggml_cann_create_tensor( + arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1); + aclnn_arange(ctx, arange_tensor, start, stop, step, size); - // init arange - ggml_cann_pool_alloc arange_allocator(ctx.pool(), - ne2_ne3 * ggml_type_size(dst->type)); - void* tmp_arange_buffer = arange_allocator.get(); + aclTensor* slope_tensor = ggml_cann_create_tensor( + slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1); - // arange1: [1, ..., n_heads_log2_floor+1) - float start = 1; - float stop = n_heads_log2_floor + 1; - float step = 1; - int64_t n_elements_arange = n_heads_log2_floor; + aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT); - int64_t tmp_arange1_ne[] = {n_heads_log2_floor}; - size_t tmp_arange1_nb[] = {sizeof(dst->type)}; - aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor( - tmp_arange_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), tmp_arange1_ne, tmp_arange1_nb, - GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - - aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange); - - aclTensor* tmp_arange2_tensor = nullptr; - if (n_heads_log2_floor < ne2_ne3) { - // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1) - start = 1; - stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1; - step = 2; - n_elements_arange = ne2_ne3 - n_heads_log2_floor; - int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor}; - size_t tmp_arange2_nb[] = {sizeof(dst->type)}; - - aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor( - (char*)tmp_arange_buffer + - n_heads_log2_floor * ggml_type_size(dst->type), - ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), - tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step, - n_elements_arange); - } - - // init mk_base - ggml_cann_pool_alloc mk_base_allocator(ctx.pool(), - ne2_ne3 * ggml_type_size(dst->type)); - void* tmp_mk_base_buffer = mk_base_allocator.get(); - int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor}; - size_t tmp_mk_base1_nb[] = {sizeof(dst->type)}; - aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor( - tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), tmp_mk_base1_ne, tmp_mk_base1_nb, - GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - - aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor); - - aclTensor* tmp_mk_base2_tensor = nullptr; - if (n_heads_log2_floor < ne2_ne3) { - int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor}; - size_t tmp_mk_base2_nb[] = {sizeof(dst->type)}; - aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor( - (char*)tmp_mk_base_buffer + - n_heads_log2_floor * ggml_type_size(dst->type), - ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), - tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor); - } - - // init mk - int64_t tmp_mk_base_ne[] = {ne2_ne3}; - size_t tmp_mk_base_nb[] = {sizeof(dst->type)}; - aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor( - tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb, - GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - aclTensor* tmp_arange_tensor = ggml_cann_create_tensor( - tmp_arange_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb, - GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor); - - // reshape mk - int64_t tmp_mk_ne[] = {1, 1, src_ne[2], src_ne[3]}; - size_t tmp_mk_nb[GGML_MAX_DIMS]; - tmp_mk_nb[0] = ggml_type_size(dst->type); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1]; - } - aclTensor* tmp_mk_tensor = ggml_cann_create_tensor( - tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS, - ACL_FORMAT_ND); - - // acl_position * mk - int64_t tmp_output_ne[] = {src_ne[0], src_ne[1], src_ne[2], src_ne[3]}; - size_t tmp_output_nb[GGML_MAX_DIMS]; - tmp_output_nb[0] = ggml_type_size(dst->type); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - tmp_output_nb[i] = tmp_output_nb[i - 1] * tmp_output_ne[i - 1]; - } - ggml_cann_pool_alloc output_allocator(ctx.pool(), ggml_nbytes(dst)); - void* tmp_output_buffer = output_allocator.get(); - aclTensor* tmp_output_tensor = ggml_cann_create_tensor( - tmp_output_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), tmp_output_ne, tmp_output_nb, GGML_MAX_DIMS, - ACL_FORMAT_ND); - aclnn_mul(ctx, acl_position, tmp_mk_tensor, tmp_output_tensor); - - // add - aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst); - ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor, - tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor, - tmp_arange_tensor, tmp_mk_tensor, tmp_output_tensor); + GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, sc, arange_tensor, slope_tensor); + ggml_cann_release_resources(ctx, sc, arange_tensor, slope_tensor); } -void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) { +/** + * @brief Compute slope values for multiple attention heads based on ALiBi bias parameters. + * + * This function generates slope values for each attention head according to the ALiBi + * (Attention with Linear Biases) method. It splits the computation into two ranges depending + * on whether the head index is less than @p n_head_log2 or not, and uses different base values + * (`m0` and `m1`) for the exponentiation. + * + * @f[ + * slope[h] = + * \begin{cases} + * m_0^{(h + 1)}, & h < n\_head\_log2 \\ + * m_1^{\left( 2 \cdot (h - n\_head\_log2) + 1 \right)}, & h \geq n\_head\_log2 + * \end{cases} + * \quad , \quad \text{if } max\_bias > 0 + * @f] + * + * If @p max_bias <= 0, all slope values are set to 1.0. + * + * @param ctx CANN backend context for memory allocation and operator execution. + * @param n_head Total number of attention heads. + * @param slope_buffer Pointer to the output buffer (float array) for storing slopes. + * @param max_bias Maximum bias value for slope computation. + * +*/ +static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head, + void* slope_buffer, float max_bias) { + const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + float m0 = powf(2.0f, -(max_bias) / n_head_log2); + float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + // const float slope = (max_bias > 0.0f) ? + // h < n_head_log2 ? + // powf(m0, h + 1) : + // powf(m1, 2*(h - n_head_log2) + 1) : + // 1.0f; + // arange1 + float start = 0 + 1; + float end = (n_head_log2 - 1) + 1; + float step = 1; + float count = n_head_log2; + // end needs to be +1 because aclnn uses a left-closed, right-open interval. + aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step); + if (n_head_log2 < n_head) { + // arange2 + start = 2 * (n_head_log2 - n_head_log2) + 1; + end = 2 * ((n_head - 1) - n_head_log2) + 1; + step = 2; + count = n_head - n_head_log2; + aclnn_get_slope_inner( + ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), + m1, count, start, end + 1, step); + } +} + +/** + * @brief Add ALiBi (Attention with Linear Biases) positional biases to the attention mask. + * + * This function computes the ALiBi slopes for each attention head (if max_bias > 0), + * multiplies them with the attention mask to produce bias tensors, and adds these biases + * to the destination tensor (@p dst). + * + * The function performs necessary broadcasting of the mask and slope tensors to match + * the shape of the destination tensor, then applies element-wise multiplication and addition + * using CANN operators. + * + * @param ctx CANN backend context for memory management and operator execution. + * @param mask Input attention mask tensor, assumed to be contiguous. + * @param dst Destination tensor to which ALiBi biases will be added. + * @param dst_ptr Pointer to the memory of the destination tensor. + * @param max_bias Maximum bias value controlling the slope scaling. + * + * @note + * - Write data into dst_ptr using only the shape information of the dst tensor. + * - `GGML_MAX_DIMS + 2` is used to extend tensor dimensions for broadcasting. + */ +static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, + ggml_tensor* dst, void* dst_ptr, float max_bias) { + void* slope_buffer = nullptr; + void* bias_buffer = nullptr; + + if (max_bias > 0.0f) { + int64_t n_heads = dst->ne[2]; + ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float)); + slope_buffer = slope_allocator.get(); + ggml_cann_pool_alloc bias_allocator( + ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst)); + bias_buffer = bias_allocator.get(); + aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias); + } + + // broadcast for mask, slop and dst; + int64_t nr2 = dst->ne[2] / mask->ne[2]; + int64_t nr3 = dst->ne[3] / mask->ne[3]; + + // broadcast the mask across rows + int64_t mask_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], 1, mask->ne[3], 1 }; + size_t mask_nb[] = { + mask_nb[0] = mask->nb[0], mask_nb[1] = mask->nb[1], mask_nb[2] = mask->nb[2], + mask_nb[3] = mask->nb[2], mask_nb[4] = mask->nb[3], mask_nb[5] = mask->nb[3] + }; + + int64_t dst_ne[] = { dst->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], nr3 }; + size_t dst_nb[] = { + dst_nb[0] = dst->nb[0], dst_nb[1] = dst->nb[1], dst_nb[2] = dst->nb[2], + dst_nb[3] = dst->nb[2], dst_nb[4] = dst->nb[3], dst_nb[5] = dst->nb[3] + }; + + // slope is a 1 dim tensor, slope.ne2 == dst.ne2 + int64_t slope_ne[] = { 1, 1, mask->ne[2], nr2, 1, 1 }; + size_t slope_nb[GGML_MAX_DIMS + 2]; + slope_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS + 2; i++) { + slope_nb[i] = slope_nb[i - 1] * slope_ne[i - 1]; + } + + aclTensor* acl_slope = ggml_cann_create_tensor( + slope_buffer, ACL_FLOAT, sizeof(float), + slope_ne, slope_nb, GGML_MAX_DIMS + 2); + aclTensor* acl_mask = ggml_cann_create_tensor( + mask, mask_ne, mask_nb, GGML_MAX_DIMS + 2); + + // write data into dst_ptr using only the shape information of the dst tensor. + aclTensor* acl_dst = ggml_cann_create_tensor( + dst_ptr, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), dst_ne, dst_nb, + GGML_MAX_DIMS + 2); + + if (max_bias > 0.0f) { + int64_t bias_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], 1 }; + size_t bias_nb[GGML_MAX_DIMS + 2]; + bias_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS + 2; i++) { + bias_nb[i] = bias_nb[i - 1] * bias_ne[i - 1]; + } + aclTensor* bias_tensor = ggml_cann_create_tensor( + bias_buffer, ACL_FLOAT, sizeof(float), + bias_ne, bias_nb, GGML_MAX_DIMS + 2); + + aclnn_mul(ctx, acl_slope, acl_mask, bias_tensor); + aclnn_add(ctx, acl_dst, bias_tensor); + ggml_cann_release_resources(ctx, bias_tensor); + } else { + aclnn_add(ctx, acl_dst, acl_mask); + } + ggml_cann_release_resources(ctx, acl_slope, acl_mask, acl_dst); +} + +void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_cann_dup(ctx, dst); } @@ -1501,118 +1537,41 @@ void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) { * @param acl_dst The destination tensor where the softmax results will be * stored. */ -static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src, - int64_t dim, aclTensor* acl_dst) { +static void aclnn_softmax(ggml_backend_cann_context & ctx, + aclTensor* acl_src, int64_t dim, aclTensor * acl_dst) { GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, acl_src, dim, acl_dst); } -void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { +void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor* src0 = dst->src[0]; ggml_tensor* src1 = dst->src[1]; // mask aclTensor* acl_src0 = ggml_cann_create_tensor(src0); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor* acl_dst = ggml_cann_create_tensor(dst); - float scale = 1.0f; + float scale = 1.0f; float max_bias = 0.0f; - memcpy(&scale, (float*)dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float*)dst->op_params + 1, sizeof(float)); + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); // input mul scale aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT); + ggml_cann_pool_alloc src_tensor_allocator(ctx.pool(), ggml_nbytes(src0)); + void* src_tensor_buffer = src_tensor_allocator.get(); + aclTensor* softmax_tensor = ggml_cann_create_tensor( + src_tensor_buffer, ggml_cann_type_mapping(src0->type), + ggml_element_size(src0), src0->ne, src0->nb,GGML_MAX_DIMS); - size_t n_bytes = ggml_nbytes(src0); - ggml_cann_pool_alloc mul_scale_allocator(ctx.pool(), n_bytes); - void* input_mul_scale_buffer = mul_scale_allocator.get(); - aclTensor* acl_input_mul_scale_tensor = ggml_cann_create_tensor( - input_mul_scale_buffer, ACL_FLOAT, ggml_type_size(src0->type), src0->ne, - src0->nb, GGML_MAX_DIMS); - - bool inplace = false; - aclnn_muls(ctx, acl_src0, scale, acl_input_mul_scale_tensor, inplace); + aclnn_muls(ctx, acl_src0, scale, softmax_tensor, false); // mask - aclTensor* acl_src1_fp32_tensor = nullptr; - aclTensor* tmp_mask_tensor = nullptr; - ggml_cann_pool_alloc src1_fp32_allocator(ctx.pool()); if (src1) { - const bool use_f16 = src1->type == GGML_TYPE_F16; - if (use_f16) { - // cast to fp32 - size_t n_bytes = ggml_nelements(src1) * sizeof(float_t); - size_t src1_fp32_nb[GGML_MAX_DIMS]; - src1_fp32_nb[0] = sizeof(float_t); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - src1_fp32_nb[i] = src1_fp32_nb[i - 1] * src1->ne[i - 1]; - } - src1_fp32_allocator.alloc(n_bytes); - void* src1_fp32_buffer = src1_fp32_allocator.get(); - acl_src1_fp32_tensor = ggml_cann_create_tensor( - src1_fp32_buffer, ACL_FLOAT, sizeof(float), src1->ne, - src1_fp32_nb, GGML_MAX_DIMS); - aclTensor* acl_src1 = ggml_cann_create_tensor(src1); - aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT); - ggml_cann_release_resources(ctx, acl_src1); - } else { - acl_src1_fp32_tensor = ggml_cann_create_tensor(src1); - } - - // broadcast the mask across rows, only use ne11 of ne01 in mask - if (src1->ne[1] != src0->ne[1]) { - // mask shape: [1,1,ne11,ne10] - int64_t tmp_mask_ne[] = {src0->ne[0], src0->ne[1], 1, 1}; - size_t tmp_mask_nb[GGML_MAX_DIMS]; - tmp_mask_nb[0] = sizeof(float_t); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - tmp_mask_nb[i] = tmp_mask_nb[i - 1] * tmp_mask_ne[i - 1]; - } - tmp_mask_tensor = ggml_cann_create_tensor( - src1->data, ACL_FLOAT, sizeof(float), tmp_mask_ne, tmp_mask_nb, - GGML_MAX_DIMS, ACL_FORMAT_ND); - } - - // alibi - const int n_head = src0->ne[2]; - const size_t src_nb0 = src0->nb[0]; - - n_bytes = ggml_nbytes(dst); - ggml_cann_pool_alloc output_allocator(ctx.pool(), n_bytes); - void* output_buffer = output_allocator.get(); - aclTensor* alibi_output_tensor = ggml_cann_create_tensor( - output_buffer, ACL_FLOAT, ggml_type_size(dst->type), dst->ne, - dst->nb, GGML_MAX_DIMS); - if (max_bias <= 0.0f) { - // slope = 1.0 - if (tmp_mask_tensor) { - aclnn_add(ctx, tmp_mask_tensor, acl_input_mul_scale_tensor, - alibi_output_tensor); - } else { - aclnn_add(ctx, acl_src1_fp32_tensor, acl_input_mul_scale_tensor, - alibi_output_tensor); - } - } else { - // slope != 1.0 - if (tmp_mask_tensor) { - aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_mask_tensor, - alibi_output_tensor, n_head, src0->ne, src_nb0, - max_bias, dst); - } else { - aclnn_alibi(ctx, acl_input_mul_scale_tensor, - acl_src1_fp32_tensor, alibi_output_tensor, n_head, - src0->ne, src_nb0, max_bias, dst); - } - } - - // softmax - aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst); - ggml_cann_release_resources(ctx, alibi_output_tensor); - } else { - aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst); + aclnn_add_alibi(ctx, src1, src0, src_tensor_buffer, max_bias); } - - ggml_cann_release_resources(ctx, acl_src0, acl_src1_fp32_tensor, acl_dst, - acl_scale, acl_input_mul_scale_tensor, tmp_mask_tensor); + // softmax + aclnn_softmax(ctx, softmax_tensor, 3, acl_dst); + ggml_cann_release_resources(ctx, acl_src0, acl_dst, acl_scale, softmax_tensor); } /** @@ -3208,104 +3167,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ // Compute the slope if needed. Derived from ggml_cann_softmax(). if(maxBias != 0.0f){ // alibi - const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3]; - const int64_t n_head = src0->ne[2]; - const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head)); - float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor); - float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor); - // init arange - ggml_cann_pool_alloc arange_allocator(ctx.pool(), - ne2_ne3 * faElemSize); - void* tmp_arange_buffer = arange_allocator.get(); + const int64_t n_heads = src0->ne[2]; + ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float)); + void* slope_buffer = slope_allocator.get(); + aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias); - // arange1: [1, ..., n_heads_log2_floor+1) - float start = 1; - float stop = n_heads_log2_floor + 1; - float step = 1; - int64_t n_elements_arange = n_heads_log2_floor; - - int64_t tmp_arange1_ne[] = {n_heads_log2_floor}; - size_t tmp_arange1_nb[] = {faElemSize}; - aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor( - tmp_arange_buffer, faDataType, faElemSize, - tmp_arange1_ne, tmp_arange1_nb, - GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - - aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange); - - aclTensor* tmp_arange2_tensor = nullptr; - if (n_heads_log2_floor < ne2_ne3) { - // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1) - start = 1; - stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1; - step = 2; - n_elements_arange = ne2_ne3 - n_heads_log2_floor; - int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor}; - size_t tmp_arange2_nb[] = {faElemSize}; - - aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor( - (char*)tmp_arange_buffer + - n_heads_log2_floor * faElemSize, - faDataType, faElemSize, - tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step, - n_elements_arange); + int64_t slope_ne[] = {1, 1, n_heads, 1}; + size_t slope_nb[GGML_MAX_DIMS]; + slope_nb[0] = sizeof(float); + for(int i = 1;ine[2], src0->ne[3]}; - size_t tmp_mk_nb[GGML_MAX_DIMS]; - tmp_mk_nb[0] = faElemSize; - for (int i = 1; i < GGML_MAX_DIMS; i++) { - tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1]; - } - aclTensor* tmp_mk_tensor = ggml_cann_create_tensor( - tmp_mk_base_buffer, faDataType, faElemSize, - tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS, - ACL_FORMAT_ND); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor); - - ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor, - tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor, - tmp_arange_tensor, tmp_mk_tensor); + ggml_cann_release_resources(ctx, slope_tensor); } } diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index cf575b367..3d3520f19 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2391,7 +2391,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, // only support F32 and F16. return false; } - return true; + return ggml_is_contiguous(op); } break; case GGML_OP_CONT: { // TODO: support GGML_TYPE_BF16 @@ -2456,8 +2456,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, // value of paddingW should be at most half of kernelW return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2)); } - case GGML_OP_SUM: case GGML_OP_DUP: + return ggml_is_contiguous(op); + case GGML_OP_SUM: case GGML_OP_IM2COL: case GGML_OP_CONCAT: case GGML_OP_REPEAT: @@ -2503,9 +2504,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, if (op->src[2]) { return false; } - // TODO: support broadcast - // ref: https://github.com/ggml-org/llama.cpp/pull/14435 - return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1); + return true; case GGML_OP_FLASH_ATTN_EXT:{ // derived from [ggml-cuda.cu] if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){ @@ -2532,11 +2531,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, // DeepSeek MLA return false; } - // TODO: support broadcast - // ref: https://github.com/ggml-org/llama.cpp/pull/14435 - if (op->src[0]->ne[3] != 1) { - return false; - } float logitSoftcap = 0.0f; memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float)); if(logitSoftcap != 0.0f) {