CUDA: fix negative KV_max values in FA (#15321)

This commit is contained in:
Johannes Gäßler
2025-08-14 23:21:24 +02:00
committed by GitHub
parent df36bce667
commit 4227c9be42

View File

@@ -539,11 +539,15 @@ static __global__ void flash_attn_mask_to_KV_max(
all_inf = warp_reduce_all(all_inf); all_inf = warp_reduce_all(all_inf);
if (!all_inf) { if (!all_inf) {
KV_max_sj += FATTN_KQ_STRIDE;
break; break;
} }
} }
// If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
// If the break was triggered it's the lower edge of the tile with the first non-masked values.
// In either case, walk back the decrementation by FATTN_KQ_STRIDE.
KV_max_sj += FATTN_KQ_STRIDE;
if (threadIdx.x != 0) { if (threadIdx.x != 0) {
return; return;
} }