CANN: Add broadcast for softmax and FA (#15208)
* refactor softmax
* fix fa
* fix mask shape
* format
* add comments
* Remove whitespace
@@ -812,7 +812,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_cann_release_resources(ctx, src_trans_tensor);
return;
} else {
GGML_ABORT("Unsupport dst is not tontiguous.");
GGML_ABORT("Unsupport dst is not contiguous.");
}
}
ggml_cann_release_resources(ctx, acl_src, acl_dst);
@@ -1330,160 +1330,196 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
}

/**
* @brief Applies the Alibi (Attention with Linear Biases) mechanism to the
* @details This function implements the Alibi mechanism, which introduces
* learnable biases into the attention scores to simulate relative
* position encoding without the need for explicit positional
* embeddings.
* @brief Generate a range of values and apply a scalar base exponentiation.
*
* @param ctx The backend CANN context for executing operations.
* @param acl_src The source tensor representing the query or key.
* @param acl_position The position tensor containing relative positions.
* @param acl_dst The destination tensor where the result will be stored.
* @param n_head The number of attention heads.
* @param src_ne The dimensions of the source tensor.
* @param src_nb0 The byte size of the first dimension of the source
tensor.
* @param max_bias The maximum bias value used in the Alibi mechanism.
* @param dst The destination tensor object for additional metadata.
* This function creates an evenly spaced sequence from `start` to `stop` (exclusive),
* with step size `step`, stores it in a temporary buffer, and then computes:
*
* The function performs the following steps:
* 1. Calculates the logarithm floor of the number of heads to determine the
base for bias calculation.
* 2. Initializes arrays with arithmetic sequences and fills them with bias
values.
* 3. Computes the bias tensor based on the calculated biases and arithmetic
sequences.
* 4. Reshapes the bias tensor to match the dimensions of the input tensors.
* 5. Multiplies the position tensor by the bias tensor.
* 6. Adds the result of the multiplication to the source tensor to produce the
final output.
* @f[
* slope[i] = m^{\left( start + i \cdot step \right)}, \quad 0 \le i < size
* @f]
*
* The results are written to the provided @p slope_buffer.
*
* @param ctx CANN backend context for memory allocation and operator execution.
* @param slope_buffer Pointer to the output buffer (float array) for the computed slope values.
* @param m Scalar base for the exponentiation.
* @param size Number of elements in the generated sequence.
* @param start Starting exponent offset.
* @param stop Stopping exponent offset (exclusive).
* @param step Step size for the exponent increment.
*/
static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_position, aclTensor* acl_dst,
const int n_head, int64_t* src_ne, const size_t src_nb0,
float max_bias, ggml_tensor* dst) {
const int64_t ne2_ne3 = src_ne[2] * src_ne[3];
GGML_ASSERT(src_nb0 == sizeof(float));
GGML_ASSERT(n_head == src_ne[2]);
static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
float m, int64_t size, float start, float stop, float step){
int64_t ne[] = {size};
size_t nb[] = {sizeof(float)};

const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float));
void* arange_buffer = arange_allocator.get();

float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
aclTensor* arange_tensor = ggml_cann_create_tensor(
arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
aclnn_arange(ctx, arange_tensor, start, stop, step, size);

// init arange
ggml_cann_pool_alloc arange_allocator(ctx.pool(),
ne2_ne3 * ggml_type_size(dst->type));
void* tmp_arange_buffer = arange_allocator.get();
aclTensor* slope_tensor = ggml_cann_create_tensor(
slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);

// arange1: [1, ..., n_heads_log2_floor+1)
float start = 1;
float stop = n_heads_log2_floor + 1;
float step = 1;
int64_t n_elements_arange = n_heads_log2_floor;
aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);

int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
size_t tmp_arange1_nb[] = {sizeof(dst->type)};
aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), tmp_arange1_ne, tmp_arange1_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);

aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);

aclTensor* tmp_arange2_tensor = nullptr;
if (n_heads_log2_floor < ne2_ne3) {
// arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
start = 1;
stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
step = 2;
n_elements_arange = ne2_ne3 - n_heads_log2_floor;
int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
size_t tmp_arange2_nb[] = {sizeof(dst->type)};

aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
(char*)tmp_arange_buffer +
n_heads_log2_floor * ggml_type_size(dst->type),
ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
n_elements_arange);
}

// init mk_base
ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
ne2_ne3 * ggml_type_size(dst->type));
void* tmp_mk_base_buffer = mk_base_allocator.get();
int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
size_t tmp_mk_base1_nb[] = {sizeof(dst->type)};
aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), tmp_mk_base1_ne, tmp_mk_base1_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);

aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);

aclTensor* tmp_mk_base2_tensor = nullptr;
if (n_heads_log2_floor < ne2_ne3) {
int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
size_t tmp_mk_base2_nb[] = {sizeof(dst->type)};
aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
(char*)tmp_mk_base_buffer +
n_heads_log2_floor * ggml_type_size(dst->type),
ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
}

// init mk
int64_t tmp_mk_base_ne[] = {ne2_ne3};
size_t tmp_mk_base_nb[] = {sizeof(dst->type)};
aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);

// reshape mk
int64_t tmp_mk_ne[] = {1, 1, src_ne[2], src_ne[3]};
size_t tmp_mk_nb[GGML_MAX_DIMS];
tmp_mk_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
}
aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
ACL_FORMAT_ND);

// acl_position * mk
int64_t tmp_output_ne[] = {src_ne[0], src_ne[1], src_ne[2], src_ne[3]};
size_t tmp_output_nb[GGML_MAX_DIMS];
tmp_output_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
tmp_output_nb[i] = tmp_output_nb[i - 1] * tmp_output_ne[i - 1];
}
ggml_cann_pool_alloc output_allocator(ctx.pool(), ggml_nbytes(dst));
void* tmp_output_buffer = output_allocator.get();
aclTensor* tmp_output_tensor = ggml_cann_create_tensor(
tmp_output_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), tmp_output_ne, tmp_output_nb, GGML_MAX_DIMS,
ACL_FORMAT_ND);
aclnn_mul(ctx, acl_position, tmp_mk_tensor, tmp_output_tensor);

// add
aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst);
ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
tmp_arange_tensor, tmp_mk_tensor, tmp_output_tensor);
GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, sc, arange_tensor, slope_tensor);
ggml_cann_release_resources(ctx, sc, arange_tensor, slope_tensor);
}
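
// Illustrative host-side sketch (not part of this diff) of the computation that
// aclnn_get_slope_inner() offloads to the NPU: an arange over [start, stop) with the
// given step, followed by a scalar-base power, i.e. slope[i] = m^(start + i * step).
// The helper name and the std::vector return type are assumptions for the example.
#include <cmath>
#include <cstdint>
#include <vector>

static std::vector<float> get_slope_inner_ref(float m, int64_t size, float start, float step) {
    std::vector<float> slope(size);
    for (int64_t i = 0; i < size; i++) {
        // same formula as the doc comment above
        slope[i] = std::pow(m, start + (float) i * step);
    }
    return slope;
}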

void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
/**
* @brief Compute slope values for multiple attention heads based on ALiBi bias parameters.
*
* This function generates slope values for each attention head according to the ALiBi
* (Attention with Linear Biases) method. It splits the computation into two ranges depending
* on whether the head index is less than @p n_head_log2 or not, and uses different base values
* (`m0` and `m1`) for the exponentiation.
*
* @f[
* slope[h] =
* \begin{cases}
* m_0^{(h + 1)}, & h < n\_head\_log2 \\
* m_1^{\left( 2 \cdot (h - n\_head\_log2) + 1 \right)}, & h \geq n\_head\_log2
* \end{cases}
* \quad , \quad \text{if } max\_bias > 0
* @f]
*
* If @p max_bias <= 0, all slope values are set to 1.0.
*
* @param ctx CANN backend context for memory allocation and operator execution.
* @param n_head Total number of attention heads.
* @param slope_buffer Pointer to the output buffer (float array) for storing slopes.
* @param max_bias Maximum bias value for slope computation.
*
*/
static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
void* slope_buffer, float max_bias) {
const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head));

float m0 = powf(2.0f, -(max_bias) / n_head_log2);
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

// const float slope = (max_bias > 0.0f) ?
// h < n_head_log2 ?
// powf(m0, h + 1) :
// powf(m1, 2*(h - n_head_log2) + 1) :
// 1.0f;
// arange1
float start = 0 + 1;
float end = (n_head_log2 - 1) + 1;
float step = 1;
float count = n_head_log2;
// end needs to be +1 because aclnn uses a left-closed, right-open interval.
aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step);
if (n_head_log2 < n_head) {
// arange2
start = 2 * (n_head_log2 - n_head_log2) + 1;
end = 2 * ((n_head - 1) - n_head_log2) + 1;
step = 2;
count = n_head - n_head_log2;
aclnn_get_slope_inner(
ctx, (char *) slope_buffer + n_head_log2 * sizeof(float),
m1, count, start, end + 1, step);
}
}
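
// Illustrative host-side reference (not part of this diff) for the per-head ALiBi slopes
// that aclnn_get_slope() assembles on the NPU from two arange + power calls. It follows
// the piecewise formula documented above; the helper name and return type are assumptions.
#include <cmath>
#include <cstdint>
#include <vector>

static std::vector<float> get_slope_ref(int64_t n_head, float max_bias) {
    std::vector<float> slope(n_head, 1.0f);  // all slopes are 1.0 when max_bias <= 0
    if (max_bias <= 0.0f) {
        return slope;
    }
    const int n_head_log2 = 1u << (uint32_t) std::floor(std::log2((double) n_head));
    const float m0 = std::pow(2.0f, -max_bias / n_head_log2);
    const float m1 = std::pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
    for (int64_t h = 0; h < n_head; h++) {
        slope[h] = h < n_head_log2 ? std::pow(m0, (float) (h + 1))
                                   : std::pow(m1, (float) (2 * (h - n_head_log2) + 1));
    }
    return slope;
}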

/**
* @brief Add ALiBi (Attention with Linear Biases) positional biases to the attention mask.
*
* This function computes the ALiBi slopes for each attention head (if max_bias > 0),
* multiplies them with the attention mask to produce bias tensors, and adds these biases
* to the destination tensor (@p dst).
*
* The function performs necessary broadcasting of the mask and slope tensors to match
* the shape of the destination tensor, then applies element-wise multiplication and addition
* using CANN operators.
*
* @param ctx CANN backend context for memory management and operator execution.
* @param mask Input attention mask tensor, assumed to be contiguous.
* @param dst Destination tensor to which ALiBi biases will be added.
* @param dst_ptr Pointer to the memory of the destination tensor.
* @param max_bias Maximum bias value controlling the slope scaling.
*
* @note
* - Write data into dst_ptr using only the shape information of the dst tensor.
* - `GGML_MAX_DIMS + 2` is used to extend tensor dimensions for broadcasting.
*/
static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask,
ggml_tensor* dst, void* dst_ptr, float max_bias) {
void* slope_buffer = nullptr;
void* bias_buffer = nullptr;

if (max_bias > 0.0f) {
int64_t n_heads = dst->ne[2];
ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
slope_buffer = slope_allocator.get();
ggml_cann_pool_alloc bias_allocator(
ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst));
bias_buffer = bias_allocator.get();
aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias);
}

// broadcast for mask, slop and dst;
int64_t nr2 = dst->ne[2] / mask->ne[2];
int64_t nr3 = dst->ne[3] / mask->ne[3];

// broadcast the mask across rows
int64_t mask_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], 1, mask->ne[3], 1 };
size_t mask_nb[] = {
mask_nb[0] = mask->nb[0], mask_nb[1] = mask->nb[1], mask_nb[2] = mask->nb[2],
mask_nb[3] = mask->nb[2], mask_nb[4] = mask->nb[3], mask_nb[5] = mask->nb[3]
};

int64_t dst_ne[] = { dst->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], nr3 };
size_t dst_nb[] = {
dst_nb[0] = dst->nb[0], dst_nb[1] = dst->nb[1], dst_nb[2] = dst->nb[2],
dst_nb[3] = dst->nb[2], dst_nb[4] = dst->nb[3], dst_nb[5] = dst->nb[3]
};

// slope is a 1 dim tensor, slope.ne2 == dst.ne2
int64_t slope_ne[] = { 1, 1, mask->ne[2], nr2, 1, 1 };
size_t slope_nb[GGML_MAX_DIMS + 2];
slope_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS + 2; i++) {
slope_nb[i] = slope_nb[i - 1] * slope_ne[i - 1];
}

aclTensor* acl_slope = ggml_cann_create_tensor(
slope_buffer, ACL_FLOAT, sizeof(float),
slope_ne, slope_nb, GGML_MAX_DIMS + 2);
aclTensor* acl_mask = ggml_cann_create_tensor(
mask, mask_ne, mask_nb, GGML_MAX_DIMS + 2);

// write data into dst_ptr using only the shape information of the dst tensor.
aclTensor* acl_dst = ggml_cann_create_tensor(
dst_ptr, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), dst_ne, dst_nb,
GGML_MAX_DIMS + 2);

if (max_bias > 0.0f) {
int64_t bias_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], 1 };
size_t bias_nb[GGML_MAX_DIMS + 2];
bias_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS + 2; i++) {
bias_nb[i] = bias_nb[i - 1] * bias_ne[i - 1];
}
aclTensor* bias_tensor = ggml_cann_create_tensor(
bias_buffer, ACL_FLOAT, sizeof(float),
bias_ne, bias_nb, GGML_MAX_DIMS + 2);

aclnn_mul(ctx, acl_slope, acl_mask, bias_tensor);
aclnn_add(ctx, acl_dst, bias_tensor);
ggml_cann_release_resources(ctx, bias_tensor);
} else {
aclnn_add(ctx, acl_dst, acl_mask);
}
ggml_cann_release_resources(ctx, acl_slope, acl_mask, acl_dst);
}
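
// Shape walk-through with assumed example sizes (not from this diff) for the 6-D broadcast
// views built in aclnn_add_alibi() above. With dst->ne = {128, 64, 32, 4} and
// mask->ne = {128, 64, 8, 1}:
//   nr2 = 32 / 8 = 4, nr3 = 4 / 1 = 4
//   mask view:  {128, 64, 8, 1, 1, 1}   (strides repeat it nr2 x nr3 times)
//   dst view:   {128, 64, 8, 4, 1, 4}
//   slope view: {  1,  1, 8, 4, 1, 1}   (one slope per head, broadcast over rows/cols)
// The sketch below mirrors how those ne arrays are derived; the helper name is hypothetical.
#include <cstdint>

static void build_alibi_broadcast_ne(const int64_t dst_shape[4], const int64_t mask_shape[4],
                                     int64_t mask_ne[6], int64_t dst_ne[6], int64_t slope_ne[6]) {
    const int64_t nr2 = dst_shape[2] / mask_shape[2];  // head-dim repeat factor
    const int64_t nr3 = dst_shape[3] / mask_shape[3];  // batch-dim repeat factor
    const int64_t m[6] = { mask_shape[0], dst_shape[1], mask_shape[2], 1,   mask_shape[3], 1   };
    const int64_t d[6] = { dst_shape[0],  dst_shape[1], mask_shape[2], nr2, mask_shape[3], nr3 };
    const int64_t s[6] = { 1,             1,            mask_shape[2], nr2, 1,             1   };
    for (int i = 0; i < 6; i++) {
        mask_ne[i]  = m[i];
        dst_ne[i]   = d[i];
        slope_ne[i] = s[i];
    }
}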

void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_cann_dup(ctx, dst);
}

@@ -1501,118 +1537,41 @@ void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
* @param acl_dst The destination tensor where the softmax results will be
* stored.
*/
static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src,
int64_t dim, aclTensor* acl_dst) {
static void aclnn_softmax(ggml_backend_cann_context & ctx,
aclTensor* acl_src, int64_t dim, aclTensor * acl_dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, acl_src, dim, acl_dst);
}

void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor* src0 = dst->src[0];
ggml_tensor* src1 = dst->src[1]; // mask

aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
aclTensor* acl_dst = ggml_cann_create_tensor(dst);

float scale = 1.0f;
float scale = 1.0f;
float max_bias = 0.0f;

memcpy(&scale, (float*)dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float*)dst->op_params + 1, sizeof(float));
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));

// input mul scale
aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);
ggml_cann_pool_alloc src_tensor_allocator(ctx.pool(), ggml_nbytes(src0));
void* src_tensor_buffer = src_tensor_allocator.get();
aclTensor* softmax_tensor = ggml_cann_create_tensor(
src_tensor_buffer, ggml_cann_type_mapping(src0->type),
ggml_element_size(src0), src0->ne, src0->nb,GGML_MAX_DIMS);

size_t n_bytes = ggml_nbytes(src0);
ggml_cann_pool_alloc mul_scale_allocator(ctx.pool(), n_bytes);
void* input_mul_scale_buffer = mul_scale_allocator.get();
aclTensor* acl_input_mul_scale_tensor = ggml_cann_create_tensor(
input_mul_scale_buffer, ACL_FLOAT, ggml_type_size(src0->type), src0->ne,
src0->nb, GGML_MAX_DIMS);

bool inplace = false;
aclnn_muls(ctx, acl_src0, scale, acl_input_mul_scale_tensor, inplace);
aclnn_muls(ctx, acl_src0, scale, softmax_tensor, false);

// mask
aclTensor* acl_src1_fp32_tensor = nullptr;
aclTensor* tmp_mask_tensor = nullptr;
ggml_cann_pool_alloc src1_fp32_allocator(ctx.pool());
if (src1) {
const bool use_f16 = src1->type == GGML_TYPE_F16;
if (use_f16) {
// cast to fp32
size_t n_bytes = ggml_nelements(src1) * sizeof(float_t);
size_t src1_fp32_nb[GGML_MAX_DIMS];
src1_fp32_nb[0] = sizeof(float_t);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
src1_fp32_nb[i] = src1_fp32_nb[i - 1] * src1->ne[i - 1];
}
src1_fp32_allocator.alloc(n_bytes);
void* src1_fp32_buffer = src1_fp32_allocator.get();
acl_src1_fp32_tensor = ggml_cann_create_tensor(
src1_fp32_buffer, ACL_FLOAT, sizeof(float), src1->ne,
src1_fp32_nb, GGML_MAX_DIMS);
aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT);
ggml_cann_release_resources(ctx, acl_src1);
} else {
acl_src1_fp32_tensor = ggml_cann_create_tensor(src1);
}

// broadcast the mask across rows, only use ne11 of ne01 in mask
if (src1->ne[1] != src0->ne[1]) {
// mask shape: [1,1,ne11,ne10]
int64_t tmp_mask_ne[] = {src0->ne[0], src0->ne[1], 1, 1};
size_t tmp_mask_nb[GGML_MAX_DIMS];
tmp_mask_nb[0] = sizeof(float_t);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
tmp_mask_nb[i] = tmp_mask_nb[i - 1] * tmp_mask_ne[i - 1];
}
tmp_mask_tensor = ggml_cann_create_tensor(
src1->data, ACL_FLOAT, sizeof(float), tmp_mask_ne, tmp_mask_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND);
}

// alibi
const int n_head = src0->ne[2];
const size_t src_nb0 = src0->nb[0];

n_bytes = ggml_nbytes(dst);
ggml_cann_pool_alloc output_allocator(ctx.pool(), n_bytes);
void* output_buffer = output_allocator.get();
aclTensor* alibi_output_tensor = ggml_cann_create_tensor(
output_buffer, ACL_FLOAT, ggml_type_size(dst->type), dst->ne,
dst->nb, GGML_MAX_DIMS);
if (max_bias <= 0.0f) {
// slope = 1.0
if (tmp_mask_tensor) {
aclnn_add(ctx, tmp_mask_tensor, acl_input_mul_scale_tensor,
alibi_output_tensor);
} else {
aclnn_add(ctx, acl_src1_fp32_tensor, acl_input_mul_scale_tensor,
alibi_output_tensor);
}
} else {
// slope != 1.0
if (tmp_mask_tensor) {
aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_mask_tensor,
alibi_output_tensor, n_head, src0->ne, src_nb0,
max_bias, dst);
} else {
aclnn_alibi(ctx, acl_input_mul_scale_tensor,
acl_src1_fp32_tensor, alibi_output_tensor, n_head,
src0->ne, src_nb0, max_bias, dst);
}
}

// softmax
aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst);
ggml_cann_release_resources(ctx, alibi_output_tensor);
} else {
aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst);
aclnn_add_alibi(ctx, src1, src0, src_tensor_buffer, max_bias);
}

ggml_cann_release_resources(ctx, acl_src0, acl_src1_fp32_tensor, acl_dst,
acl_scale, acl_input_mul_scale_tensor, tmp_mask_tensor);
// softmax
aclnn_softmax(ctx, softmax_tensor, 3, acl_dst);
ggml_cann_release_resources(ctx, acl_src0, acl_dst, acl_scale, softmax_tensor);
}
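
// CPU reference sketch (assumed f32 inputs, not part of this diff) of the fused op that
// ggml_cann_softmax() now runs on the NPU: dst = softmax(src0 * scale + slope[head] * mask),
// with the softmax taken over each contiguous row. Broadcasting of mask and slope across
// heads/batches is handled by aclnn_add_alibi() above; only a single row is shown here.
#include <cmath>
#include <cstdint>

static void soft_max_row_ref(const float * x, const float * mask, float slope,
                             float scale, float * out, int64_t n) {
    float max_val = -INFINITY;
    for (int64_t i = 0; i < n; i++) {
        out[i] = x[i] * scale + (mask ? slope * mask[i] : 0.0f);
        if (out[i] > max_val) {
            max_val = out[i];
        }
    }
    float sum = 0.0f;
    for (int64_t i = 0; i < n; i++) {
        out[i] = std::exp(out[i] - max_val);  // subtract the row max for numerical stability
        sum += out[i];
    }
    for (int64_t i = 0; i < n; i++) {
        out[i] /= sum;
    }
}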

/**
@@ -3208,104 +3167,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
// Compute the slope if needed. Derived from ggml_cann_softmax().
if(maxBias != 0.0f){
// alibi
const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
const int64_t n_head = src0->ne[2];
const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor);
float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor);
// init arange
ggml_cann_pool_alloc arange_allocator(ctx.pool(),
ne2_ne3 * faElemSize);
void* tmp_arange_buffer = arange_allocator.get();
const int64_t n_heads = src0->ne[2];
ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
void* slope_buffer = slope_allocator.get();
aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias);

// arange1: [1, ..., n_heads_log2_floor+1)
float start = 1;
float stop = n_heads_log2_floor + 1;
float step = 1;
int64_t n_elements_arange = n_heads_log2_floor;

int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
size_t tmp_arange1_nb[] = {faElemSize};
aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
tmp_arange_buffer, faDataType, faElemSize,
tmp_arange1_ne, tmp_arange1_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);

aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);

aclTensor* tmp_arange2_tensor = nullptr;
if (n_heads_log2_floor < ne2_ne3) {
// arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
start = 1;
stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
step = 2;
n_elements_arange = ne2_ne3 - n_heads_log2_floor;
int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
size_t tmp_arange2_nb[] = {faElemSize};

aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
(char*)tmp_arange_buffer +
n_heads_log2_floor * faElemSize,
faDataType, faElemSize,
tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
n_elements_arange);
int64_t slope_ne[] = {1, 1, n_heads, 1};
size_t slope_nb[GGML_MAX_DIMS];
slope_nb[0] = sizeof(float);
for(int i = 1;i<GGML_MAX_DIMS;i++) {
slope_nb[i] = slope_nb[i-1] * slope_ne[0];
}

// init mk_base
ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
ne2_ne3 * faElemSize);
void* tmp_mk_base_buffer = mk_base_allocator.get();
int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
size_t tmp_mk_base1_nb[] = {faElemSize};
aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
tmp_mk_base_buffer, faDataType, faElemSize,
tmp_mk_base1_ne, tmp_mk_base1_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclTensor* slope_tensor = ggml_cann_create_tensor(
slope_buffer, ACL_FLOAT, sizeof(float),
slope_ne, slope_nb, GGML_MAX_DIMS);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, slope_tensor);

aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);

aclTensor* tmp_mk_base2_tensor = nullptr;
if (n_heads_log2_floor < ne2_ne3) {
int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
size_t tmp_mk_base2_nb[] = {faElemSize};
aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
(char*)tmp_mk_base_buffer +
n_heads_log2_floor * faElemSize,
faDataType, faElemSize,
tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
}

// init mk
int64_t tmp_mk_base_ne[] = {ne2_ne3};
size_t tmp_mk_base_nb[] = {faElemSize};
aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
tmp_mk_base_buffer, faDataType, faElemSize,
tmp_mk_base_ne, tmp_mk_base_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
tmp_arange_buffer, faDataType, faElemSize,
tmp_mk_base_ne, tmp_mk_base_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);

// reshape mk
int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]};
size_t tmp_mk_nb[GGML_MAX_DIMS];
tmp_mk_nb[0] = faElemSize;
for (int i = 1; i < GGML_MAX_DIMS; i++) {
tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
}
aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
tmp_mk_base_buffer, faDataType, faElemSize,
tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
ACL_FORMAT_ND);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor);

ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
tmp_arange_tensor, tmp_mk_tensor);
ggml_cann_release_resources(ctx, slope_tensor);
}
}
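
// Host-side sketch (illustrative only, not part of this diff) of what the InplaceMul call
// above does in the flash-attention path: each head's broadcast mask/pse slice in
// bcast_pse_tensor is scaled by that head's ALiBi slope before the attention kernel runs.
// The flat [batch][head][rows*cols] layout and the helper name are assumptions.
#include <cstdint>

static void scale_pse_by_slope_ref(float * pse, const float * slope,
                                   int64_t n_batch, int64_t n_head, int64_t rows_x_cols) {
    for (int64_t b = 0; b < n_batch; b++) {
        for (int64_t h = 0; h < n_head; h++) {
            float * slice = pse + (b * n_head + h) * rows_x_cols;
            for (int64_t i = 0; i < rows_x_cols; i++) {
                slice[i] *= slope[h];  // slope[h] broadcasts over the whole [rows, cols] plane
            }
        }
    }
}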
@@ -2391,7 +2391,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// only support F32 and F16.
return false;
}
return true;
return ggml_is_contiguous(op);
} break;
case GGML_OP_CONT: {
// TODO: support GGML_TYPE_BF16
@@ -2456,8 +2456,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// value of paddingW should be at most half of kernelW
return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
}
case GGML_OP_SUM:
case GGML_OP_DUP:
return ggml_is_contiguous(op);
case GGML_OP_SUM:
case GGML_OP_IM2COL:
case GGML_OP_CONCAT:
case GGML_OP_REPEAT:
@@ -2503,9 +2504,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
if (op->src[2]) {
return false;
}
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
return true;
case GGML_OP_FLASH_ATTN_EXT:{
// derived from [ggml-cuda.cu]
if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
@@ -2532,11 +2531,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// DeepSeek MLA
return false;
}
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
if (op->src[0]->ne[3] != 1) {
return false;
}
float logitSoftcap = 0.0f;
memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
if(logitSoftcap != 0.0f) {
Block a user