mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-02 07:55:57 +02:00
CUDA: fix race condition in FA vector kernels (llama/13742)
This commit is contained in:
parent
994b4f86ab
commit
f1576b2659
@ -212,6 +212,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (__all_sync(0xFFFFFFFF, skip)) {
|
if (__all_sync(0xFFFFFFFF, skip)) {
|
||||||
|
__syncthreads();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
#endif // GGML_USE_HIP
|
#endif // GGML_USE_HIP
|
||||||
|
@ -217,6 +217,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (__all_sync(0xFFFFFFFF, skip)) {
|
if (__all_sync(0xFFFFFFFF, skip)) {
|
||||||
|
__syncthreads();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
#endif // GGML_USE_HIP
|
#endif // GGML_USE_HIP
|
||||||
|
Loading…
x
Reference in New Issue
Block a user