From 92b8810ec7aa6d778bc287cc918443cf67b962e2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Wed, 30 Jul 2025 15:46:13 +0200
Subject: [PATCH] CUDA: skip masked KV slices for all FA kernels (#14924)

---
 ggml/src/ggml-cuda/common.cuh        | 14 ++++++
 ggml/src/ggml-cuda/fattn-common.cuh  | 75 +++++++++++++++++++++++++++-
 ggml/src/ggml-cuda/fattn-mma-f16.cuh | 21 ++++++--
 ggml/src/ggml-cuda/fattn-tile-f16.cu |  4 +-
 ggml/src/ggml-cuda/fattn-tile-f32.cu |  4 +-
 ggml/src/ggml-cuda/fattn-vec-f16.cuh | 26 ++--------
 ggml/src/ggml-cuda/fattn-vec-f32.cuh | 25 ++--------
 ggml/src/ggml-cuda/fattn-wmma-f16.cu |  4 +-
 ggml/src/ggml-cuda/fattn.cu          |  3 +-
 9 files changed, 120 insertions(+), 56 deletions(-)

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 44daafbf7..4560d85c4 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -432,6 +432,20 @@ static __global__ void reduce_rows_f32(const float * x, float * dst, const int n
     dst[row] = norm ? sum / ncols : sum;
 }
 
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ int warp_reduce_all(int x) {
+#ifdef GGML_USE_HIP
+#pragma unroll
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
+    }
+    return x;
+#else
+    static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
+    return __all_sync(0xffffffff, x);
+#endif // GGML_USE_HIP
+}
+
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 0cc74f284..b6db446c6 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -15,6 +15,7 @@ typedef void (* fattn_kernel_t)(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int * __restrict__ KV_max,
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
@@ -500,6 +501,55 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+template <int ncols1>
+__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
+static __global__ void flash_attn_mask_to_KV_max(
+        const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
+    const int ne31 = gridDim.x;
+    const int tid = threadIdx.x;
+    const int sequence = blockIdx.y;
+    const int jt = blockIdx.x;
+
+    mask += sequence*s33 + jt*ncols1*s31;
+
+    __shared__ int buf_iw[WARP_SIZE];
+    if (tid < WARP_SIZE) {
+        buf_iw[tid] = 1;
+    }
+    __syncthreads();
+
+    int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
+    for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
+        int all_inf = 1;
+
+#pragma unroll
+        for (int j = 0; j < ncols1; ++j) {
+            const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
+            all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
+        }
+
+        all_inf = warp_reduce_all(all_inf);
+        if (tid % WARP_SIZE == 0) {
+            buf_iw[tid / WARP_SIZE] = all_inf;
+        }
+        __syncthreads();
+        all_inf = buf_iw[tid % WARP_SIZE];
+        __syncthreads();
+        all_inf = warp_reduce_all(all_inf);
+
+        if (!all_inf) {
+            KV_max_sj += FATTN_KQ_STRIDE;
+            break;
+        }
+    }
+
+    if (threadIdx.x != 0) {
+        return;
+    }
+
+    KV_max[sequence*ne31 + jt] = KV_max_sj;
+}
+
 template // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
@@ -711,6 +761,7 @@ void launch_fattn(
 
     ggml_cuda_pool_alloc<half>   K_f16(pool);
     ggml_cuda_pool_alloc<half>   V_f16(pool);
+    ggml_cuda_pool_alloc<int>    KV_max(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
@@ -779,11 +830,30 @@ void launch_fattn(
         V_data = (char *) V_f16.ptr;
     }
 
-    int parallel_blocks = 1;
-
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
+    // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
+    // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
+    // multiple sequences of possibly different lengths.
+    if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+        const int s31 = mask->nb[1] / sizeof(half2);
+        const int s33 = mask->nb[3] / sizeof(half2);
+
+        const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
+        const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
+
+        const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
+        const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
+
+        KV_max.alloc(ne_KV_max);
+        flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
+            ((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
+        CUDA_CHECK(cudaGetLastError());
+    }
+
+    int parallel_blocks = 1;
+
     const dim3 block_dim(warp_size, nwarps, 1);
     int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
     CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
@@ -870,6 +940,7 @@ void launch_fattn(
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
+        KV_max.ptr,
         !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 8e847d361..a86b95428 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -392,7 +392,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
     }
 }
 
-template
+template
 static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const float2 * const __restrict__ Q_f2,
         const half2  * const __restrict__ K_h2,
@@ -922,7 +923,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     }
 
     // Iterate over ne11 == previous tokens:
-    for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
+    int kb0 = kb0_start;
+    for (; kb0 < kb0_stop-1; ++kb0) {
         constexpr bool last_iter = false;
         flash_attn_ext_f16_iter
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
@@ -932,7 +934,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         constexpr bool last_iter = true;
         flash_attn_ext_f16_iter
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     }
 
     // With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1204,6 +1206,7 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int * __restrict__ KV_max,
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
@@ -1280,7 +1283,11 @@ static __global__ void flash_attn_ext_f16(
         const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 
         const int kb0_start_kernel = kb0_start * kb_niter;
-        const int kb0_stop_kernel = kb0_stop * kb_niter;
+        int kb0_stop_kernel = kb0_stop * kb_niter;
+
+        if (KV_max) {
+            kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
+        }
 
         constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         if (kb0_start == 0) {
@@ -1321,7 +1328,11 @@ static __global__ void flash_attn_ext_f16(
     const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 
     const int kb0_start_kernel = kb0_start * kb_niter;
-    const int kb0_stop_kernel = kb0_stop * kb_niter;
+    int kb0_stop_kernel = kb0_stop * kb_niter;
+
+    if (KV_max) {
+        kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
+    }
 
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu
index 678288c13..9d0b24ae7 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int * __restrict__ KV_max,
        float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
@@ -90,7 +91,8 @@ static __global__ void flash_attn_tile_ext_f16(
 
     __syncthreads();
 
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         half kqmax_new[ncols/nwarps];
diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu
index bc283d9a7..be72f76fb 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f32.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f32(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int * __restrict__ KV_max,
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
@@ -99,7 +100,8 @@ static __global__ void flash_attn_tile_ext_f32(
 
     __syncthreads();
 
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         float kqmax_new[ncols/nwarps];
diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index afef815ce..a2df2f66b 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int * __restrict__ KV_max,
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
@@ -177,10 +178,11 @@ static __global__ void flash_attn_vec_ext_f16(
 
     half2 VKQ[ncols] = {{0.0f, 0.0f}};
 
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
     K += blockIdx.y*D * nb11;
     V += blockIdx.y*D * nb21;
     maskh += blockIdx.y*D;
-    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
          // Increment pointers after each loop:
          K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
 
@@ -191,29 +193,7 @@ static __global__ void flash_attn_vec_ext_f16(
             for (int j = 0; j < ncols; ++j) {
                 maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
             }
 
-            __syncthreads();
-
-            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
-            // In such cases, skip the KV slice.
-            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
-#ifndef GGML_USE_HIP
-            bool skip = true;
-#pragma unroll
-            for (int j = 0; j < ncols; ++j) {
-#pragma unroll
-                for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-                    const int i = i0 + threadIdx.x;
-
-                    const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
-                    skip = skip && isinf(tmp.x) && isinf(tmp.y);
-                }
-            }
-            if (__all_sync(0xFFFFFFFF, skip)) {
-                __syncthreads();
-                continue;
-            }
-#endif // GGML_USE_HIP
         }
 
         // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
index 3595e2969..9ab0fc133 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int * __restrict__ KV_max,
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
@@ -183,10 +184,11 @@ static __global__ void flash_attn_vec_ext_f32(
 
     float VKQ[ncols] = {0.0f};
 
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
     K += blockIdx.y*D * nb11;
     V += blockIdx.y*D * nb21;
     maskh += blockIdx.y*D;
-    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
          // Increment pointers after each loop:
          K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
 
@@ -197,28 +199,7 @@ static __global__ void flash_attn_vec_ext_f32(
             for (int j = 0; j < ncols; ++j) {
                 maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
             }
 
-            __syncthreads();
-
-            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
-            // In such cases, skip the KV slice.
-            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
-#ifndef GGML_USE_HIP
-            bool skip = true;
-#pragma unroll
-            for (int j = 0; j < ncols; ++j) {
-#pragma unroll
-                for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
-                    const int i = i0 + threadIdx.x;
-
-                    skip = skip && isinf(maskf_shared[j*D + i]);
-                }
-            }
-            if (__all_sync(0xFFFFFFFF, skip)) {
-                __syncthreads();
-                continue;
-            }
-#endif // GGML_USE_HIP
         }
 
         float kqmax_new_arr[ncols];
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
index f4393ec57..40554204b 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -29,6 +29,7 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int * __restrict__ KV_max,
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
@@ -165,7 +166,8 @@ static __global__ void flash_attn_ext_f16(
     __syncthreads();
 
     // Iterate over ne11 == previous tokens:
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
         // Calculate tile of KQ:
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index d9f161305..a51136f6b 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -315,7 +315,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
-    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
+    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies &&
+        (Q->ne[3] > 1 || cc < GGML_CUDA_CC_ADA_LOVELACE) && !mma_needs_data_conversion;
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
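Editor's note, not part of the patch above: the following is a minimal, self-contained sketch of the mask-scan idea that flash_attn_mask_to_KV_max implements. It walks the KV dimension backwards in fixed-size tiles, checks whether a tile is -inf for every query row handled by the block, and records the first boundary below which unmasked data remains, so an attention kernel could stop its KV loop at that bound. All names here (TILE, scan_mask_kv_max, rows_per_block) are invented for the example, the mask is a plain dense float array rather than the half2 layout used in the patch, and the block-wide reduction uses __syncthreads_and() instead of the warp_reduce_all()/shared-memory combination above.

// toy_kv_max.cu -- standalone illustration; assumptions as described above.
#include <cuda_runtime.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#define TILE 256 // granularity at which trailing KV columns can be skipped (assumption)

// One block per group of query rows; launched with blockDim.x == TILE.
__global__ void scan_mask_kv_max(const float * mask, int * kv_max,
                                 const int rows_per_block, const int n_kv) {
    const int row0 = blockIdx.x*rows_per_block; // first query row handled by this block
    const int tid  = threadIdx.x;

    int bound = n_kv; // default: nothing can be skipped
    for (int kv0 = n_kv - TILE; kv0 >= 0; kv0 -= TILE) {
        // A tile may only be skipped if every mask value in it is -inf for all rows of this block.
        int all_inf = 1;
        for (int j = 0; j < rows_per_block; ++j) {
            all_inf = all_inf && isinf(mask[(row0 + j)*n_kv + kv0 + tid]);
        }
        // Block-wide AND over the per-thread results.
        if (__syncthreads_and(all_inf)) {
            bound = kv0;   // whole tile masked out: lower the bound and keep scanning
        } else {
            break;         // found live data: everything below the bound must be processed
        }
    }
    if (tid == 0) {
        kv_max[blockIdx.x] = bound;
    }
}

int main() {
    const int n_rows = 8, rows_per_block = 4, n_kv = 1024;

    // Toy mask: columns [512, 1024) are fully masked for every row.
    float * mask_h = (float *) malloc(n_rows*n_kv*sizeof(float));
    for (int r = 0; r < n_rows; ++r) {
        for (int c = 0; c < n_kv; ++c) {
            mask_h[r*n_kv + c] = c < 512 ? 0.0f : -INFINITY;
        }
    }

    float * mask_d;
    int   * kv_max_d;
    cudaMalloc(&mask_d, n_rows*n_kv*sizeof(float));
    cudaMalloc(&kv_max_d, (n_rows/rows_per_block)*sizeof(int));
    cudaMemcpy(mask_d, mask_h, n_rows*n_kv*sizeof(float), cudaMemcpyHostToDevice);

    scan_mask_kv_max<<<n_rows/rows_per_block, TILE>>>(mask_d, kv_max_d, rows_per_block, n_kv);

    int kv_max_h[2];
    cudaMemcpy(kv_max_h, kv_max_d, sizeof(kv_max_h), cudaMemcpyDeviceToHost);
    printf("KV bound per row block: %d %d (expected 512 512)\n", kv_max_h[0], kv_max_h[1]);

    cudaFree(mask_d);
    cudaFree(kv_max_d);
    free(mask_h);
    return 0;
}

In the patch itself the equivalent bound is computed once per (sequence, row-block) pair into KV_max, and each FA kernel then clamps its KV loop limit (kb0_stop_kernel or k_VKQ_max) against it, which is what lets every kernel skip fully masked KV slices instead of only the vector kernels as before.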