mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-18 23:57:09 +02:00
CUDA: fix FA tg at long context for CC >= 8.9 (llama/13852)
This commit is contained in:
parent
0035b8527c
commit
9a500394ad
@ -623,8 +623,8 @@ static __global__ void flash_attn_combine_results(
|
|||||||
__builtin_assume(tid < D);
|
__builtin_assume(tid < D);
|
||||||
|
|
||||||
extern __shared__ float2 meta[];
|
extern __shared__ float2 meta[];
|
||||||
if (tid < 2*parallel_blocks) {
|
for (int i = tid; i < 2*parallel_blocks; i += D) {
|
||||||
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
|
((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
|
||||||
}
|
}
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
Loading…
x
Reference in New Issue
Block a user