mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-30 06:03:37 -04:00
CUDA: Improve flash decoding kernel GPU occupancy for BS=1 case (#12183)
- Find out active blocks per SM using cudaOccupancyMaxActiveBlocksPerMultiprocessor API. Use this value to determine the optimal parallel_blocks value. - Prefer vector flash attention kernels over MMA kernel for BS=1 Fixes Issue: #12182 --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
@@ -606,48 +606,47 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
*dst = dst_val / rowsum;
|
||||
}
|
||||
|
||||
template<int D, int parallel_blocks> // D == head size
|
||||
template<int D> // D == head size
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void flash_attn_combine_results(
|
||||
const float * __restrict__ VKQ_parts,
|
||||
const float2 * __restrict__ VKQ_meta,
|
||||
float * __restrict__ dst) {
|
||||
VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
|
||||
VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
|
||||
dst += D * gridDim.y*blockIdx.x;
|
||||
float * __restrict__ dst,
|
||||
const int parallel_blocks) {
|
||||
VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
|
||||
VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
|
||||
dst += D * gridDim.z*blockIdx.x;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
__builtin_assume(tid < D);
|
||||
|
||||
__shared__ float2 meta[parallel_blocks];
|
||||
extern __shared__ float2 meta[];
|
||||
if (tid < 2*parallel_blocks) {
|
||||
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
|
||||
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
float kqmax = meta[0].x;
|
||||
#pragma unroll
|
||||
for (int l = 1; l < parallel_blocks; ++l) {
|
||||
kqmax = max(kqmax, meta[l].x);
|
||||
}
|
||||
|
||||
float VKQ_numerator = 0.0f;
|
||||
float VKQ_denominator = 0.0f;
|
||||
#pragma unroll
|
||||
for (int l = 0; l < parallel_blocks; ++l) {
|
||||
const float diff = meta[l].x - kqmax;
|
||||
const float KQ_max_scale = expf(diff);
|
||||
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
||||
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
||||
|
||||
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
|
||||
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
|
||||
VKQ_denominator += KQ_max_scale * meta[l].y;
|
||||
}
|
||||
|
||||
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
|
||||
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
|
||||
}
|
||||
|
||||
static void on_no_fattn_vec_case(const int D) {
|
||||
@@ -671,12 +670,10 @@ static void on_no_fattn_vec_case(const int D) {
|
||||
}
|
||||
}
|
||||
|
||||
// parallel_blocks == 0 is stream-k decomposition
|
||||
template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
|
||||
template <int D, int ncols1, int ncols2, int KQ_stride>
|
||||
void launch_fattn(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
|
||||
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
|
||||
const int warp_size = WARP_SIZE
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
||||
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
|
||||
) {
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
|
||||
@@ -748,12 +745,14 @@ void launch_fattn(
|
||||
nb23 = nb23*bs*sizeof(half)/ts;
|
||||
}
|
||||
|
||||
int parallel_blocks = 1;
|
||||
|
||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
||||
|
||||
const dim3 block_dim(warp_size, nwarps, 1);
|
||||
dim3 blocks_num;
|
||||
if (parallel_blocks == 0) {
|
||||
if (stream_k) {
|
||||
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
||||
const int max_blocks = 2*nsm;
|
||||
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
|
||||
@@ -769,9 +768,43 @@ void launch_fattn(
|
||||
|
||||
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
|
||||
} else {
|
||||
blocks_num.x = parallel_blocks*ntiles_x;
|
||||
blocks_num.y = Q->ne[2];
|
||||
blocks_num.z = Q->ne[3];
|
||||
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
|
||||
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
||||
|
||||
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
|
||||
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
|
||||
|
||||
// parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
|
||||
parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
|
||||
|
||||
// parallel_blocks must not be larger than what the tensor size allows:
|
||||
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
|
||||
|
||||
// If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
|
||||
// Test whether parallel_blocks can be set to a higher value for better efficiency.
|
||||
const int blocks_per_wave = nsm * max_blocks_per_sm;
|
||||
int nwaves_best = 0;
|
||||
int efficiency_percent_best = 0;
|
||||
for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
|
||||
const int nblocks_total = ntiles_total * parallel_blocks_test;
|
||||
const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
|
||||
const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
|
||||
|
||||
// Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
|
||||
if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (efficiency_percent > efficiency_percent_best) {
|
||||
nwaves_best = nwaves;
|
||||
efficiency_percent_best = efficiency_percent;
|
||||
parallel_blocks = parallel_blocks_test;
|
||||
}
|
||||
}
|
||||
|
||||
blocks_num.x = ntiles_x;
|
||||
blocks_num.y = parallel_blocks;
|
||||
blocks_num.z = Q->ne[2]*Q->ne[3];
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
@@ -803,7 +836,7 @@ void launch_fattn(
|
||||
K_data,
|
||||
V_data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
(parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
@@ -815,7 +848,7 @@ void launch_fattn(
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
if constexpr (parallel_blocks == 0) {
|
||||
if (stream_k) {
|
||||
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
|
||||
@@ -824,13 +857,14 @@ void launch_fattn(
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
||||
}
|
||||
} else if constexpr (parallel_blocks > 1) {
|
||||
} else if (parallel_blocks > 1) {
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
|
||||
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
|
||||
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
|
||||
|
||||
flash_attn_combine_results<D, parallel_blocks>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
||||
flash_attn_combine_results<D>
|
||||
<<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
|
||||
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
Reference in New Issue
Block a user