CANN: Add basic support for the Flash Attention kernel (#13627)

* cann: add the basic FA support

* cann: update the readme

* cann: update the FlashAttention with PSEShift

* cann: update the input parameters in FA

* cann: update the alibi with max_bias

* cann: add the constraints of softcap

* cann: update the docs CANN.md

* cann: update the docs CANN.md

* cann: fix typo of CANN.md

* cann: add some comments and update the CANN.md

* cann: update the CANN.md

* cann: update the inner precise for fusedInferAttention

* cann: update the constraints of flash_attn_ext on ggml-cann.cpp

* cann: clean the whitespace

* cann: clean the whitespace

* cann: add a new endline
Author: Bizhao Shi
Date:   2025-05-26 10:20:18 +08:00 (committed by GitHub)
Parent: e121edc432
Commit: 2d38b6e400
9 changed files with 392 additions and 0 deletions

docs/backend/CANN.md (+9 lines, mode changed: Normal file → Executable file)

@@ -280,6 +280,15 @@ cmake --build build --config release
### **GitHub contribution**:
Please add the **[CANN]** prefix/tag in issues/PRs titles to help the CANN-team check/address them without delay.
## Updates
### Basic Flash Attention Support
The basic FA kernel, implemented with aclnn operators, has been added in aclnn_ops.cpp.
Currently, FA only supports FP16 KV tensors and no logit softcap.
Since the aclnn flash-attention interface cannot support logit softcap, future updates will only cover the quantized KV versions.
Authors from Peking University: Bizhao Shi (bshi@pku.edu.cn), Yuxin Yang (yxyang@pku.edu.cn), Ruiyang Ma (ruiyang@stu.pku.edu.cn), and Guojie Luo (gluo@pku.edu.cn).
We would like to thank Tuo Dai, Shanni Li, and all of the project maintainers from Huawei Technologies Co., Ltd. for their help during the code development and the pull request.
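
For reference, the kernel evaluates the usual scaled-dot-product attention with an optional ALiBi-scaled mask, roughly

$$\mathrm{out} = \mathrm{softmax}\big(\mathrm{scale}\cdot QK^{\top} + \mathrm{slope}\cdot \mathrm{mask}\big)\,V,$$

whereas logit softcap would additionally squash each scaled logit $x$ into $c \cdot \tanh(x / c)$ before the softmax; the fused aclnn kernel does not expose that transform, hence the constraint above.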
## TODO
- Support more models and data types.

ggml/src/ggml-cann/CMakeLists.txt (+0 lines, mode changed: Normal file → Executable file)

ggml/src/ggml-cann/Doxyfile (+0 lines, mode changed: Normal file → Executable file)

ggml/src/ggml-cann/acl_tensor.cpp (+2 lines, mode changed: Normal file → Executable file)

@@ -31,6 +31,8 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
            return ACL_FLOAT;
        case GGML_TYPE_F16:
            return ACL_FLOAT16;
        case GGML_TYPE_BF16:
            return ACL_BF16;
        case GGML_TYPE_I8:
            return ACL_INT8;
        case GGML_TYPE_I16:

ggml/src/ggml-cann/acl_tensor.h (+0 lines, mode changed: Normal file → Executable file)

ggml/src/ggml-cann/aclnn_ops.cpp (+330 lines, mode changed: Normal file → Executable file)

@@ -66,6 +66,7 @@
#include <aclnnop/aclnn_gt_scalar.h>
#include <aclnnop/aclnn_pow.h>
#include <aclnnop/aclnn_grouped_matmul_v2.h>
#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
#include <float.h>
#include <cmath>
@@ -74,11 +75,13 @@
#include <vector>
#include "ggml-impl.h"
#include "ggml.h"
#define GGML_COMMON_DECL_C
#include "../ggml-common.h"
void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0,
                 aclTensor ** acl_src1, aclTensor ** acl_dst) {
    GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0));
@@ -2861,3 +2864,330 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
            break;
    }
}

void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
    ggml_tensor* src0 = dst->src[0]; // q, fp32
    ggml_tensor* src1 = dst->src[1]; // k, fp16
    ggml_tensor* src2 = dst->src[2]; // v, fp16
    ggml_tensor* src3 = dst->src[3]; // mask, fp16
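    // Shape convention (ggml): for q/k/v, ne[0] = head dim, ne[1] = sequence
    // length, ne[2] = number of heads, ne[3] = batch; the optional mask is
    // broadcast across heads below.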
    float maxBias      = 0.0f;
    float scaleValue   = 1.0f;
    float logitSoftcap = 0.0f;
    memcpy(&scaleValue,   (float*)dst->op_params + 0, sizeof(float));
    memcpy(&maxBias,      (float*)dst->op_params + 1, sizeof(float));
    memcpy(&logitSoftcap, (float*)dst->op_params + 2, sizeof(float));
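    // op_params layout of GGML_OP_FLASH_ATTN_EXT: [0] = scale (typically
    // 1/sqrt(head_dim)), [1] = max_bias (ALiBi), [2] = logit softcap.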
    if(logitSoftcap == 0.0f){
        size_t faElemSize = sizeof(uint16_t);
        auto   faDataType = ACL_FLOAT16; //ACL_BF16;

        aclTensor* acl_src0_f16_tensor = nullptr;
        aclTensor* acl_src1_f16_tensor = nullptr;
        aclTensor* acl_src2_f16_tensor = nullptr;
        aclTensor* acl_dst_f16_tensor  = nullptr;

        // Step 1: cast the src0 (Query) to fp16 if needed
        ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
        void* src0_f16_buffer = nullptr;

        if(ggml_cann_type_mapping(src0->type) != faDataType){
            aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
            src0_f16_buffer = src0_f16_allocator.alloc(
                ggml_nelements(src0) * faElemSize);

            int64_t* src0_f16_ne = src0->ne;
            size_t   src0_f16_nb[GGML_MAX_DIMS];
            src0_f16_nb[0] = sizeof(uint16_t);
            for(int i = 1; i < GGML_MAX_DIMS; ++i){
                src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1];
            }

            acl_src0_f16_tensor = ggml_cann_create_tensor(
                src0_f16_buffer, faDataType, faElemSize,
                src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS
            );
            aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
            ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
        }else{
            acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
        }
        // Step 2: create the acl tensors for src1 (Key), src2 (Value),
        //         and the direct output from FusedInferAttention
        acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
        acl_src2_f16_tensor = ggml_cann_create_tensor(src2);

        ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
        void* out_f16_buffer = out_f16_allocator.alloc(
            ggml_nelements(dst) * faElemSize);

        int64_t* out_f16_ne = src0->ne;
        size_t   out_f16_nb[GGML_MAX_DIMS];
        out_f16_nb[0] = faElemSize;
        for(int i = 1; i < GGML_MAX_DIMS; ++i){
            out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
        }
        acl_dst_f16_tensor = ggml_cann_create_tensor(
            out_f16_buffer, faDataType, faElemSize,
            out_f16_ne, out_f16_nb, GGML_MAX_DIMS
        );
        // Step 3: create the PSEShift tensor if needed
        //         this tensor is considered as mask (f16) in the llama.cpp
        aclTensor* bcast_pse_tensor = nullptr;
        int64_t bcast_pse_ne[GGML_MAX_DIMS];
        size_t  bcast_pse_nb[GGML_MAX_DIMS];
        ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
        void* bcast_pse_buffer = nullptr;

        if(src3 != nullptr){
            bcast_pse_buffer = bcast_pse_allocator.alloc(
                ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));

            if(src0->ne[1] > 1){
                // Case 1: broadcast pse for prefill stage with multiple head
                aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
                bcast_pse_ne[0] = src3->ne[0];
                bcast_pse_ne[1] = src3->ne[1];
                bcast_pse_ne[2] = src0->ne[2];
                bcast_pse_ne[3] = src3->ne[3];

                bcast_pse_nb[0] = sizeof(uint16_t);
                for(int i = 1; i < GGML_MAX_DIMS; ++i){
                    bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
                }

                bcast_pse_tensor = ggml_cann_create_tensor(
                    bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
                    bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);

                int64_t repeats[] = {1, src0->ne[2], 1, 1};
                aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);

                ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
            }else{
                // Case 2: trunc the first row and broadcast pse for decode stage with multiple head
                int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
                size_t* trunc_pse_nb = src3->nb;

                aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
                    src3->data, ACL_FLOAT16, sizeof(uint16_t),
                    trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);

                bcast_pse_ne[0] = src3->ne[0];
                bcast_pse_ne[1] = src0->ne[1];
                bcast_pse_ne[2] = src0->ne[2];
                bcast_pse_ne[3] = src3->ne[3];

                bcast_pse_nb[0] = sizeof(uint16_t);
                for(int i = 1; i < GGML_MAX_DIMS; ++i){
                    bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
                }

                bcast_pse_tensor = ggml_cann_create_tensor(
                    bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
                    bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);

                int64_t repeats[] = {1, src0->ne[2], 1, 1};
                aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);

                ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
            }
            // Compute the slope if needed. Derived from ggml_cann_softmax().
            if(maxBias != 0.0f){
                // alibi
                const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
                const int64_t n_head  = src0->ne[2];
                const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
                float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor);
                float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor);
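                // Scalar form of the per-head slope materialized by the tensor
                // ops below (mirrors ggml_cann_softmax()):
                //   slope(h) = m0^(h + 1)                            if h <  n_heads_log2_floor
                //   slope(h) = m1^(2*(h - n_heads_log2_floor) + 1)   otherwise
                // for head index h in [0, ne2_ne3).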
                // init arange
                ggml_cann_pool_alloc arange_allocator(ctx.pool(),
                                                      ne2_ne3 * faElemSize);
                void* tmp_arange_buffer = arange_allocator.get();

                // arange1: [1, ..., n_heads_log2_floor+1)
                float start = 1;
                float stop  = n_heads_log2_floor + 1;
                float step  = 1;
                int64_t n_elements_arange = n_heads_log2_floor;

                int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
                size_t  tmp_arange1_nb[] = {faElemSize};
                aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
                    tmp_arange_buffer, faDataType, faElemSize,
                    tmp_arange1_ne, tmp_arange1_nb,
                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
                aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);

                aclTensor* tmp_arange2_tensor = nullptr;
                if (n_heads_log2_floor < ne2_ne3) {
                    // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
                    start = 1;
                    stop  = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
                    step  = 2;
                    n_elements_arange = ne2_ne3 - n_heads_log2_floor;
                    int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
                    size_t  tmp_arange2_nb[] = {faElemSize};

                    // assign to the outer pointer (instead of shadowing it) so the
                    // tensor is actually released below
                    tmp_arange2_tensor = ggml_cann_create_tensor(
                        (char*)tmp_arange_buffer +
                            n_heads_log2_floor * faElemSize,
                        faDataType, faElemSize,
                        tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
                    aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
                                 n_elements_arange);
                }

                // init mk_base
                ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
                                                       ne2_ne3 * faElemSize);
                void* tmp_mk_base_buffer = mk_base_allocator.get();
                int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
                size_t  tmp_mk_base1_nb[] = {faElemSize};
                aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
                    tmp_mk_base_buffer, faDataType, faElemSize,
                    tmp_mk_base1_ne, tmp_mk_base1_nb,
                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
                aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);

                aclTensor* tmp_mk_base2_tensor = nullptr;
                if (n_heads_log2_floor < ne2_ne3) {
                    int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
                    size_t  tmp_mk_base2_nb[] = {faElemSize};
                    // same fix as above: assign, do not re-declare
                    tmp_mk_base2_tensor = ggml_cann_create_tensor(
                        (char*)tmp_mk_base_buffer +
                            n_heads_log2_floor * faElemSize,
                        faDataType, faElemSize,
                        tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
                    aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
                }

                // init mk
                int64_t tmp_mk_base_ne[] = {ne2_ne3};
                size_t  tmp_mk_base_nb[] = {faElemSize};
                aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
                    tmp_mk_base_buffer, faDataType, faElemSize,
                    tmp_mk_base_ne, tmp_mk_base_nb,
                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
                aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
                    tmp_arange_buffer, faDataType, faElemSize,
                    tmp_mk_base_ne, tmp_mk_base_nb,
                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
                aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);

                // reshape mk
                int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]};
                size_t  tmp_mk_nb[GGML_MAX_DIMS];
                tmp_mk_nb[0] = faElemSize;
                for (int i = 1; i < GGML_MAX_DIMS; i++) {
                    tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
                }
                aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
                    tmp_mk_base_buffer, faDataType, faElemSize,
                    tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
                    ACL_FORMAT_ND);

                // scale the broadcast mask by the per-head slopes
                GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor);

                ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
                                            tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
                                            tmp_arange_tensor, tmp_mk_tensor);
            }
        }

        // Step 4: set the inputs for FusedInferAttention.
        int kvTensorNum = 1;
        aclTensor* acl_q_tensor    = acl_src0_f16_tensor;
        aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
        aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor};
        auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
        auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);

        int64_t numHeads         = src0->ne[2]; // N
        int64_t numKeyValueHeads = src1->ne[2];
        // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
        int64_t preTokens  = 65535;
        int64_t nextTokens = 65535;
        char layout[5] = {'B', 'N', 'S', 'D', 0};
        int64_t sparseMode    = 0;
        int64_t innerPrecise  = (src0->ne[1] == 1) ? 0 : 2;
        int64_t blockSize     = 0;
        int64_t antiquantMode = 0;
        bool softmaxLseFlag   = false;
        int64_t keyAntiquantMode   = 0;
        int64_t valueAntiquantMode = 0;
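        // Assumption: with sparseMode = 0 and preTokens/nextTokens left large,
        // attention covers the full sequence (no sparse window); see the
        // FusedInferAttentionScoreV2 docs linked in Step 5 below.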
        // Step 5: launch the FusedInferAttentionScoreV2 kernel.
        // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
        GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
            acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
            bcast_pse_tensor, nullptr,            // pse, mask
            nullptr, nullptr,                     // actSeqLen, actSeqLenkv
            nullptr, nullptr,                     // deqScale1, quantScale1
            nullptr, nullptr, nullptr,            // deqScale2, quantScale2, quantOffset2
            nullptr, nullptr,                     // antiquantScale, antiquantOffset
            nullptr,                              // blockTable
            nullptr, nullptr,                     // qPadSize, kvPadSize
            nullptr, nullptr,                     // kAntiquantScale, kAntiQuantOffset
            nullptr, nullptr,                     // vAntiquantScale, vAntiQuantOffset
            nullptr, nullptr, nullptr,            // kSharedPrefix, vSharedPrefix, actSharedLen
            numHeads, scaleValue,                 // heads, scaleValue
            preTokens, nextTokens,                // preTokens, nextTokens
            layout,                               // inputLayout
            numKeyValueHeads,                     // numKVHeads
            sparseMode, innerPrecise,             // sparseMode, innerPrecise
            blockSize, antiquantMode,             // blockSize, antiquantMode
            softmaxLseFlag,                       // softmaxLseFlag
            keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
            acl_dst_f16_tensor,                   // attentionOut
            nullptr                               // softmaxLse
        );
        // Step 6: post-processing, permute and cast to f32
        int64_t new_dim[] = {0, 2, 1, 3};
        aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);

        if(ggml_cann_type_mapping(dst->type) != faDataType){
            ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
            perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
            void* perm_out_f16_buffer = perm_out_f16_allocator.get();

            int64_t* perm_out_f16_ne = dst->ne;
            size_t   perm_out_f16_nb[GGML_MAX_DIMS];
            perm_out_f16_nb[0] = faElemSize;
            for(int i = 1; i < GGML_MAX_DIMS; ++i){
                perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
            }
            aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
                perm_out_f16_buffer, faDataType, faElemSize,
                perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
            aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
            aclnn_cast(ctx,
                acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
            ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
        }else{
            // only need to permute
            aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
        }

        ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
                                    acl_src1_f16_tensor,
                                    acl_src2_f16_tensor,
                                    acl_dst_f16_tensor,
                                    acl_dst_tensor);
        if(src3 != nullptr){
            ggml_cann_release_resources(ctx, bcast_pse_tensor);
        }
    }else{
        GGML_ABORT("Function is not implemented.");
    }
}
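
For readers of the diff, here is a minimal scalar sketch (not part of the patch) of the computation the fused kernel performs per batch, assuming contiguous row-major fp32 buffers and per-head ALiBi slopes as derived above; all names are illustrative:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// out = softmax(scale * Q·K^T + slope[h] * mask) · V, per head h.
// Q: [n_head][n_q][d], K/V: [n_head_kv][n_kv][d], mask: [n_q][n_kv].
static void flash_attn_reference(const std::vector<float>& q, const std::vector<float>& k,
                                 const std::vector<float>& v, const std::vector<float>& mask,
                                 const std::vector<float>& slope, std::vector<float>& out,
                                 int n_head, int n_head_kv, int n_q, int n_kv, int d, float scale) {
    const int gqa = n_head / n_head_kv;           // grouped-query attention ratio
    std::vector<float> s(n_kv);
    for (int h = 0; h < n_head; ++h) {
        const int hkv = h / gqa;                  // kv head shared by this query head
        for (int iq = 0; iq < n_q; ++iq) {
            float smax = -INFINITY;
            for (int ik = 0; ik < n_kv; ++ik) {   // scaled dot products + slope-scaled mask
                float acc = 0.0f;
                for (int c = 0; c < d; ++c) {
                    acc += q[(h * n_q + iq) * d + c] * k[(hkv * n_kv + ik) * d + c];
                }
                s[ik] = acc * scale + slope[h] * mask[iq * n_kv + ik];
                smax  = std::max(smax, s[ik]);
            }
            float sum = 0.0f;                     // softmax over the kv dimension
            for (int ik = 0; ik < n_kv; ++ik) {
                s[ik] = std::exp(s[ik] - smax);
                sum  += s[ik];
            }
            for (int c = 0; c < d; ++c) {         // weighted sum of V rows
                float acc = 0.0f;
                for (int ik = 0; ik < n_kv; ++ik) {
                    acc += s[ik] * v[(hkv * n_kv + ik) * d + c];
                }
                out[(h * n_q + iq) * d + c] = acc / sum;
            }
        }
    }
}
```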

ggml/src/ggml-cann/aclnn_ops.h (+15 lines, mode changed: Normal file → Executable file)

@@ -714,6 +714,21 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 */
void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);

/**
 * @brief Performs the Flash Attention extended operator using the CANN backend.
 *
 * @details This function implements the memory-efficient Flash Attention algorithm
 *          for computing scaled dot-product attention with hardware acceleration.
 *          The result is stored in the destination tensor `dst`.
 *
 *          This operation is accelerated using the CANN backend to improve runtime performance.
 *
 * @param ctx The CANN context used for operations.
 * @param dst The destination tensor where the result will be stored.
 *            dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`.
 */
void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst);
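
For context, a hedged sketch of how a caller typically produces the `dst` node this function consumes; `build_fa_node`, the shape comments, and the parameter choices are illustrative assumptions, only `ggml_flash_attn_ext` itself comes from ggml.h:

```cpp
#include <math.h>
#include "ggml.h"

// Illustrative only: builds a GGML_OP_FLASH_ATTN_EXT node whose op_params
// ([0] scale, [1] max_bias, [2] logit_softcap) are what the CANN kernel reads.
static struct ggml_tensor * build_fa_node(struct ggml_context * ctx,
                                          struct ggml_tensor * q,    // [head_dim, n_q,  n_head,    n_batch]
                                          struct ggml_tensor * k,    // [head_dim, n_kv, n_head_kv, n_batch], F16
                                          struct ggml_tensor * v,    // [head_dim, n_kv, n_head_kv, n_batch], F16
                                          struct ggml_tensor * mask) // F16 mask or NULL
{
    const float scale         = 1.0f / sqrtf((float) q->ne[0]); // 1/sqrt(head_dim)
    const float max_bias      = 0.0f;  // > 0.0f enables the ALiBi slope path
    const float logit_softcap = 0.0f;  // must remain 0.0f for the CANN backend
    return ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias, logit_softcap);
}
```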
/*
 * @brief A generic wrapper for ACL resources with custom deleter support.
 */

ggml/src/ggml-cann/common.h (+0 lines, mode changed: Normal file → Executable file)

ggml/src/ggml-cann/ggml-cann.cpp (+36 lines, mode changed: Normal file → Executable file)

@@ -36,6 +36,7 @@
#include "ggml-backend-impl.h"
#include "ggml-cann/aclnn_ops.h"
#include "ggml-cann/common.h"
#include "ggml.h"
#define GGML_COMMON_DECL_C
@@ -1748,6 +1749,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
        case GGML_OP_COUNT_EQUAL:
            ggml_cann_count_equal(ctx, dst);
            break;
        case GGML_OP_FLASH_ATTN_EXT:
            ggml_cann_flash_attn_ext(ctx, dst);
            break;
        default:
            return false;
    }
@@ -2177,6 +2181,38 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
        case GGML_OP_PAD_REFLECT_1D:
        case GGML_OP_COUNT_EQUAL:
            return true;
        case GGML_OP_FLASH_ATTN_EXT:{
            // derived from [ggml-cuda.cu]
            if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
                return false;
            }
            if(op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 && op->src[1]->type != GGML_TYPE_BF16){
                return false;
            }
            if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
                return false;
            }
            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
                // different head sizes of K and V are not supported yet
                return false;
            }
            if (op->src[0]->ne[0] == 192) {
                return false;
            }
            if (op->src[0]->ne[0] == 576) {
                // DeepSeek MLA
                return false;
            }
            if (op->src[0]->ne[3] != 1) {
                return false;
            }
            float logitSoftcap = 0.0f;
            memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
            if(logitSoftcap != 0.0f) {
                return false;
            }
            return true;
        }
        default:
            return false;
    }