mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-26 19:55:04 +00:00
CUDA: fix race condition in FA vector kernels (#13742)
This commit is contained in:
@ -212,6 +212,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
}
|
||||
}
|
||||
if (__all_sync(0xFFFFFFFF, skip)) {
|
||||
__syncthreads();
|
||||
continue;
|
||||
}
|
||||
#endif // GGML_USE_HIP
|
||||
|
@ -217,6 +217,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
}
|
||||
}
|
||||
if (__all_sync(0xFFFFFFFF, skip)) {
|
||||
__syncthreads();
|
||||
continue;
|
||||
}
|
||||
#endif // GGML_USE_HIP
|
||||
|
Reference in New Issue
Block a user