HIP: implement FlashAttention via rocWMMA for CDNA and RDNA3+ (#12032)

Adds GGML_HIP_ROCWMMA_FATTN and rocwmma header check
Adds rocWMMA support to fattn-wmma-f16

---

Signed-off-by: Carl Klemm <carl@uvos.xyz>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Co-authored-by: Ben Jackson <ben@ben.com>
This commit is contained in:
David Huang
2025-03-04 05:10:54 +08:00
committed by GitHub
parent dfd6b2c0be
commit becade5de7
6 changed files with 145 additions and 95 deletions

View File

@@ -57,12 +57,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
GGML_UNUSED(Q_v);
T sum = 0.0f;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_1;
@@ -70,7 +71,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
const int shift = k_KQ & (QI8_1/2);
const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
const int u = Q_q8[k_KQ_0/WARP_SIZE];
const int u = Q_q8[k_KQ_0/warp_size];
const int sumi = ggml_cuda_dp4a(v, u, 0);
@@ -78,14 +79,14 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size];
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
} else
#endif // FP16_AVAILABLE
{
const float2 * Q_ds = (const float2 *) Q_ds_v;
sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
}
}
@@ -97,12 +98,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
GGML_UNUSED(Q_v);
T sum = 0.0f;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_1;
@@ -110,7 +112,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
const int shift = k_KQ & (QI8_1/2);
const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
const int u = Q_q8[k_KQ_0/WARP_SIZE];
const int u = Q_q8[k_KQ_0/warp_size];
const int sumi = ggml_cuda_dp4a(v, u, 0);
@@ -118,7 +120,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size];
const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
} else
@@ -126,8 +128,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
{
const float2 * Q_ds = (const float2 *) Q_ds_v;
const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
sum += (T) (sumid4d8 + m4s8scaled);
}
@@ -141,12 +143,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
GGML_UNUSED(Q_v);
T sum = 0.0f;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_1;
@@ -161,7 +164,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
v |= (vh << 18) & 0x00100000; // 2 -> 20
v |= (vh << 25) & 0x10000000; // 3 -> 28
const int u = Q_q8[k_KQ_0/WARP_SIZE];
const int u = Q_q8[k_KQ_0/warp_size];
const int sumi = ggml_cuda_dp4a(v, u, 0);
@@ -169,14 +172,14 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size];
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
} else
#endif // FP16_AVAILABLE
{
const float2 * Q_ds = (const float2 *) Q_ds_v;
sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
}
}
@@ -188,12 +191,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
GGML_UNUSED(Q_v);
T sum = 0.0f;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_1;
@@ -208,7 +212,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
v |= (vh << 18) & 0x00100000; // 2 -> 20
v |= (vh << 25) & 0x10000000; // 3 -> 28
const int u = Q_q8[k_KQ_0/WARP_SIZE];
const int u = Q_q8[k_KQ_0/warp_size];
const int sumi = ggml_cuda_dp4a(v, u, 0);
@@ -216,7 +220,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size];
const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
} else
@@ -224,8 +228,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
{
const float2 * Q_ds = (const float2 *) Q_ds_v;
const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
sum += (T) (sumid5d8 + m5s8scaled);
}
@@ -239,12 +243,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
GGML_UNUSED(Q_v);
T sum = 0.0f;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_0;
@@ -255,13 +260,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
T Q_d;
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]);
Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
} else {
const float2 * Q_ds = (const float2 *) Q_ds_v;
Q_d = Q_ds[k_KQ_0/WARP_SIZE].x;
Q_d = Q_ds[k_KQ_0/warp_size].x;
}
sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d);
sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
}
return sum;
@@ -272,6 +277,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
const half2 * K_h2 = (const half2 *) K_c;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);
@@ -282,11 +288,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
half2 sum2 = make_half2(0.0f, 0.0f);
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const half2 K_ik = K_h2[k_KQ];
sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
}
return __low2half(sum2) + __high2half(sum2);
@@ -298,12 +304,12 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
float sum = 0.0f;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const half2 K_ik = K_h2[k_KQ];
sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x;
sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y;
sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x;
sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
}
return sum;
@@ -698,6 +704,8 @@ void launch_fattn(
GGML_ASSERT(Q->ne[3] == 1);
const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
const int id = ggml_cuda_get_device();
@@ -750,7 +758,7 @@ void launch_fattn(
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
const dim3 block_dim(WARP_SIZE, nwarps, 1);
const dim3 block_dim(warp_size, nwarps, 1);
dim3 blocks_num;
if (parallel_blocks == 0) {
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
@@ -796,6 +804,8 @@ void launch_fattn(
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
GGML_ASSERT(block_dim.x % warp_size == 0);
GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
(const char *) Q->data,
K_data,