mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-30 06:03:37 -04:00
CUDA: use async data loading for FlashAttention (#11894)
* CUDA: use async data loading for FlashAttention

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
This commit is contained in:
@@ -716,7 +716,9 @@ void launch_fattn(
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t main_stream = ctx.stream();
|
||||
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int nsm = ggml_cuda_info().devices[id].nsm;
|
||||
|
||||
ggml_cuda_pool_alloc<half> K_f16(pool);
|
||||
ggml_cuda_pool_alloc<half> V_f16(pool);
|
||||
@@ -768,13 +770,14 @@ void launch_fattn(
|
||||
dim3 blocks_num;
|
||||
if (parallel_blocks == 0) {
|
||||
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
||||
const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
|
||||
const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
|
||||
const bool short_context = K->ne[1] < 4096;
|
||||
const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm);
|
||||
const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
|
||||
|
||||
const int nblocks_stream_k = 2*nsm;
|
||||
|
||||
blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
|
||||
const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
|
||||
|
||||
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
|
||||
blocks_num.y = 1;
|
||||
blocks_num.z = 1;
|
||||
|
||||
@@ -827,7 +830,7 @@ void launch_fattn(
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
if constexpr (parallel_blocks == 0) {
|
||||
if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
||||
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 blocks_num_combine = blocks_num;
|
||||
|
||||
|
Reference in New Issue
Block a user