CUDA: fix FA tg at long context for CC >= 8.9 (llama/13852)

This commit is contained in:
Johannes Gäßler 2025-05-28 13:33:37 +02:00 committed by Georgi Gerganov
parent 0035b8527c
commit 9a500394ad

View File

@ -623,8 +623,8 @@ static __global__ void flash_attn_combine_results(
__builtin_assume(tid < D);
extern __shared__ float2 meta[];
if (tid < 2*parallel_blocks) {
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
for (int i = tid; i < 2*parallel_blocks; i += D) {
((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
}
__syncthreads();