CUDA: fix overflow in FA, tune performance (#14840)

This commit is contained in:
Johannes Gäßler
2025-07-23 21:43:25 +02:00
committed by Aaron Teo
parent 9db975e327
commit 5ad021f924
8 changed files with 98 additions and 246 deletions

View File

@@ -23,33 +23,13 @@ typedef void (* fattn_kernel_t)(
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3);
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33);
typedef half (*vec_dot_KQ_f16_t)(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
@@ -892,14 +872,11 @@ void launch_fattn(
mask ? ((const char *) mask->data) : nullptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
nb11, nb12, nb13,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
nb21, nb22, nb23,
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
);
CUDA_CHECK(cudaGetLastError());

View File

@@ -408,7 +408,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const int stride_K,
const int stride_V,
const int stride_mask,
const int jt,
half2 * const __restrict__ tile_Q,
half2 * const __restrict__ tile_K,
half2 * const __restrict__ tile_V,
@@ -455,7 +454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
cp_async_wait_all();
__syncthreads();
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
(V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
} else {
constexpr bool use_cp_async = nstages == 1;
if (ncols2 > 1 || mask_h2) {
@@ -471,7 +470,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
if (nstages <= 1) {
constexpr bool use_cp_async = nstages == 1;
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
(K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
(K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
if (use_cp_async) {
cp_async_wait_all();
}
@@ -715,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
}
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
(K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
}
}
@@ -732,7 +731,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
if (nstages <= 1 && i0_start < reusable_cutoff) {
constexpr bool use_cp_async = nstages == 1;
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
(V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
if (use_cp_async) {
cp_async_wait_all();
}
@@ -771,8 +770,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K);
GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
@@ -920,7 +918,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
}
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
(K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
}
// Iterate over ne11 == previous tokens:
@@ -928,13 +926,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr bool last_iter = false;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
}
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
constexpr bool last_iter = true;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
}
// With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1214,33 +1212,13 @@ static __global__ void flash_attn_ext_f16(
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
// Skip unused kernel variants for faster compilation:
@@ -1359,8 +1337,7 @@ static __global__ void flash_attn_ext_f16(
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
GGML_UNUSED(nb22); GGML_UNUSED(nb23);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
}

View File

@@ -21,33 +21,13 @@ static __global__ void flash_attn_tile_ext_f16(
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
// Skip unused kernel variants for faster compilation:
@@ -127,7 +107,7 @@ static __global__ void flash_attn_tile_ext_f16(
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
}
}
@@ -221,7 +201,7 @@ static __global__ void flash_attn_tile_ext_f16(
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
}
}
@@ -300,8 +280,7 @@ static __global__ void flash_attn_tile_ext_f16(
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
GGML_UNUSED(nb23);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}

View File

@@ -21,33 +21,13 @@ static __global__ void flash_attn_tile_ext_f32(
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#ifdef FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
@@ -66,8 +46,7 @@ static __global__ void flash_attn_tile_ext_f32(
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
GGML_UNUSED(nb23);
NO_DEVICE_CODE;
return;
}
@@ -135,7 +114,7 @@ static __global__ void flash_attn_tile_ext_f32(
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
}
@@ -231,8 +210,9 @@ static __global__ void flash_attn_tile_ext_f32(
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
KV_tmp2[k*(D/2) + i].x = __low2float(tmp);
KV_tmp2[k*(D/2) + i].y = __high2float(tmp);
}
}
@@ -312,7 +292,6 @@ static __global__ void flash_attn_tile_ext_f32(
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}

View File

@@ -18,33 +18,13 @@ static __global__ void flash_attn_vec_ext_f16(
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
// Skip unused kernel variants for faster compilation:
@@ -191,13 +171,16 @@ static __global__ void flash_attn_vec_ext_f16(
half2 VKQ[ncols] = {{0.0f, 0.0f}};
K += blockIdx.y*D * nb11;
V += blockIdx.y*D * nb21;
maskh += blockIdx.y*D;
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
if (mask) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
}
__syncthreads();
@@ -244,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
#pragma unroll
for (int j = 0; j < ncols; ++j) {
half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum((float)sum);
if (use_logit_softcap) {
@@ -300,14 +283,18 @@ static __global__ void flash_attn_vec_ext_f16(
}
half2 V_k;
reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid);
reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
}
}
K += gridDim.y*D * nb11;
V += gridDim.y*D * nb21;
maskh += gridDim.y*D;
__syncthreads();
}
@@ -351,8 +338,7 @@ static __global__ void flash_attn_vec_ext_f16(
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
GGML_UNUSED(nb23);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}

View File

@@ -18,33 +18,13 @@ static __global__ void flash_attn_vec_ext_f32(
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#ifdef FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
@@ -59,8 +39,7 @@ static __global__ void flash_attn_vec_ext_f32(
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
GGML_UNUSED(nb23);
NO_DEVICE_CODE;
return;
}
@@ -198,13 +177,16 @@ static __global__ void flash_attn_vec_ext_f32(
float VKQ[ncols] = {0.0f};
K += blockIdx.y*D * nb11;
V += blockIdx.y*D * nb21;
maskh += blockIdx.y*D;
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
if (mask) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
}
__syncthreads();
@@ -246,7 +228,7 @@ static __global__ void flash_attn_vec_ext_f32(
#pragma unroll
for (int j = 0; j < ncols; ++j) {
float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum(sum);
if (use_logit_softcap) {
@@ -297,13 +279,17 @@ static __global__ void flash_attn_vec_ext_f32(
break;
}
const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
const float V_ki = dequantize_1_v(V + k*nb21, tid);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j] += V_ki*KQ[j*D + k];
}
}
K += gridDim.y*D * nb11;
V += gridDim.y*D * nb21;
maskh += gridDim.y*D;
__syncthreads();
}
@@ -348,7 +334,6 @@ static __global__ void flash_attn_vec_ext_f32(
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}

View File

@@ -37,33 +37,13 @@ static __global__ void flash_attn_ext_f16(
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -197,7 +177,7 @@ static __global__ void flash_attn_ext_f16(
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
frag_a_K K_a;
wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@@ -344,7 +324,7 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
frag_a_V v_a;
wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -451,7 +431,6 @@ static __global__ void flash_attn_ext_f16(
GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
}

View File

@@ -280,23 +280,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
if (GGML_CUDA_CC_IS_AMD(cc)) {
#if defined(GGML_HIP_ROCWMMA_FATTN)
if (fp16_mma_available(cc)) {
if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
return;
}
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
// On AMD the tile kernels perform poorly, use the vec kernel instead:
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
}
return;
}
if (!fast_fp16_available(cc)) {
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);