CUDA: fix race condition in FA vector kernels (llama/13742)

This commit is contained in:
Johannes Gäßler 2025-05-24 11:46:19 +02:00 committed by Georgi Gerganov
parent 994b4f86ab
commit f1576b2659
2 changed files with 2 additions and 0 deletions

View File

@@ -212,6 +212,7 @@ static __global__ void flash_attn_vec_ext_f16(
}
}
if (__all_sync(0xFFFFFFFF, skip)) {
__syncthreads();
continue;
}
#endif // GGML_USE_HIP

View File

@@ -217,6 +217,7 @@ static __global__ void flash_attn_vec_ext_f32(
}
}
if (__all_sync(0xFFFFFFFF, skip)) {
__syncthreads();
continue;
}
#endif // GGML_USE_HIP