mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-18 05:56:00 -04:00
CUDA: fix negative KV_max values in FA (#15321)
This commit is contained in:
@@ -539,11 +539,15 @@ static __global__ void flash_attn_mask_to_KV_max(
|
|||||||
all_inf = warp_reduce_all(all_inf);
|
all_inf = warp_reduce_all(all_inf);
|
||||||
|
|
||||||
if (!all_inf) {
|
if (!all_inf) {
|
||||||
KV_max_sj += FATTN_KQ_STRIDE;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
|
||||||
|
// If the break was triggered it's the lower edge of the tile with the first non-masked values.
|
||||||
|
// In either case, walk back the decrementation by FATTN_KQ_STRIDE.
|
||||||
|
KV_max_sj += FATTN_KQ_STRIDE;
|
||||||
|
|
||||||
if (threadIdx.x != 0) {
|
if (threadIdx.x != 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user