From 2baea5e4b328f74de657a53da2f8e7ddd39e7c37 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Fri, 8 Aug 2025 08:19:58 +0200
Subject: [PATCH] CUDA: attention sinks for mma FlashAttention (llama/15157)

---
 ggml/src/ggml-cuda/fattn-mma-f16.cuh | 87 ++++++++++++++++++++++------
 ggml/src/ggml-cuda/fattn.cu          |  2 +-
 ggml/src/ggml-cuda/ggml-cuda.cu      |  3 +-
 3 files changed, 73 insertions(+), 19 deletions(-)

diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 37125384..39731baa 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -785,6 +785,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const half2  * const __restrict__ K_h2,
         const half2  * const __restrict__ V_h2,
         const half2  * const __restrict__ mask_h2,
+        const float  * const __restrict__ sinks_f,
         float2       * const __restrict__ dstk,
         float2       * const __restrict__ dstk_fixup,
         const float scale,
@@ -957,6 +958,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
     }
 
+    // If attention sinks are used, potentially re-scale if KQ_max is small.
+    // Also add the sink as a value to KQ_rowsum; this is done after the synchronization of KQ_rowsum,
+    // so it is done unconditionally for every thread.
+    if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
+        float KQ_max_scale[cols_per_thread];
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+            static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
+            const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
+            const float sink = sinks_f[jc % ncols2];
+
+            const float KQ_max_new = fmaxf(KQ_max[col], sink);
+            const float KQ_max_diff = KQ_max[col] - KQ_max_new;
+            KQ_max_scale[col] = expf(KQ_max_diff);
+            KQ_max[col] = KQ_max_new;
+
+            *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
+
+            const float KQ_max_add = expf(sink - KQ_max_new);
+            KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
+        }
+
+        if (ntiles == 1) {
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+#pragma unroll
+            for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
+#pragma unroll
+                for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+                    VKQ_C[i].x[l] *= KQ_max_scale_h2;
+                }
+            }
+        } else {
+#pragma unroll
+            for (int col = 0; col < cols_per_thread; ++col) {
+                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+                for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
+#pragma unroll
+                    for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
+                        VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+                    }
+                }
+            }
+        }
+    }
+
     // Combine VKQ accumulator values if np > 1.
     // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
     // So also write VKQ accumulators to shared memory in column-major format if np == 1.
@@ -1271,18 +1318,21 @@ static __global__ void flash_attn_ext_f16(
 
     while (kbc < kbc_stop && kb0_stop == iter_k) {
         const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-        const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
-        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
+        const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
 
-        const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
-        const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
+        const int head0 = zt * ncols2;
+
+        const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02* head0);
+        const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
         const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
             (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-        float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
+        float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
 
-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
+        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+        const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
 
         const int kb0_start_kernel = kb0_start * kb_niter;
         int       kb0_stop_kernel  = kb0_stop  * kb_niter;
@@ -1295,12 +1345,12 @@ static __global__ void flash_attn_ext_f16(
         if (kb0_start == 0) {
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         } else {
            constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         }
 
@@ -1316,18 +1366,21 @@ static __global__ void flash_attn_ext_f16(
     }
 
     const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-    const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
-    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
+    const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
 
-    const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
-    const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
+    const int head0 = zt * ncols2;
+
+    const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02* head0);
+    const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
     const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
         (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-    float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
+    float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
 
-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+    const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
 
     const int kb0_start_kernel = kb0_start * kb_niter;
     int       kb0_stop_kernel  = kb0_stop  * kb_niter;
@@ -1339,7 +1392,7 @@ static __global__ void flash_attn_ext_f16(
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
     flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-        (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+        (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
         ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 8ddd0415..6c1185de 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -282,7 +282,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
     // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
-    if (sinks) {
+    if (sinks && !fp16_mma_available(cc)) {
         if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index ec7ab255..19e9c405 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3532,7 +3532,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
             }
             // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
-            if (op->src[4] && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { // currently only sinks for head_size 64 and 128 are supported
+            if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
+                           && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
                 return false;
             }
             if (op->src[0]->ne[0] == 192) {
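
What the new block in flash_attn_ext_f16_process_tile does: once the loop over K/V tiles has finished, the per-head sink is folded into the streaming softmax. The running maximum is lifted to max(KQ_max, sink), the VKQ accumulators and KQ_rowsum are rescaled by exp(KQ_max_old - KQ_max_new), and exp(sink - KQ_max_new) is added to KQ_rowsum, i.e. the sink behaves like one extra key with a constant logit and no value row. Below is a minimal scalar sketch of the same update in plain host C++ with illustrative names; the kernel performs this per tile column on half2 accumulators.

```cpp
#include <cmath>
#include <vector>

// Fold an attention sink into online-softmax state (scalar model).
void apply_attention_sink(float sink, float & kq_max, float & kq_rowsum,
                          std::vector<float> & vkq_acc) {
    const float kq_max_new = std::fmax(kq_max, sink);
    const float scale      = std::exp(kq_max - kq_max_new); // <= 1, rescales old state
    kq_max = kq_max_new;

    // The previous accumulators were computed relative to the old maximum.
    for (float & v : vkq_acc) {
        v *= scale;
    }
    // The sink enlarges the softmax denominator but adds nothing to the
    // numerator, so every regular attention weight shrinks proportionally.
    kq_rowsum = scale*kq_rowsum + std::exp(sink - kq_max_new);
}
```

The kernel's final normalization by KQ_rowsum, unchanged elsewhere, then yields the sink-damped attention weights.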
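One line in that block deserves a remark: *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD; is a branchless flush-to-zero. The comparison evaluates to 0 or 1, and multiplying the float's bit pattern by that predicate either keeps exp(KQ_max_diff) intact or turns it into +0.0f, avoiding a divergent branch. A standalone illustration follows, with a stand-in threshold parameter and memcpy in place of the kernel's pointer cast so the type punning stays well-defined in portable C++:

```cpp
#include <cstdint>
#include <cstring>
#include <cmath>

// exp(diff), flushed to exactly +0.0f when diff is below the threshold.
float exp_ftz(float diff, float threshold) {
    float scale = std::exp(diff);
    std::uint32_t bits;
    std::memcpy(&bits, &scale, sizeof(bits));
    bits *= diff >= threshold;  // predicate is exactly 0 or 1
    std::memcpy(&scale, &bits, sizeof(bits));
    return scale;               // all-zero bits decode to +0.0f
}
```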
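The renaming of head to zt in flash_attn_ext_f16 makes the existing semantics explicit: one kernel instance covers ncols2 consecutive heads, so the value decoded from the flattened work index kbc is a head-group index, and head0 = zt*ncols2 is the first real head of that group. A hypothetical standalone version of the decomposition, under the same definitions (iter_k/iter_j = number of k/j tiles, ne02 = number of Q heads):

```cpp
struct TileCoords {
    int sequence; // batch/sequence index
    int zt;       // head index in units of ncols2
    int jt;       // j index of the current tile
    int head0;    // first actual head covered by this block
};

// Mirrors the kbc decomposition in flash_attn_ext_f16.
TileCoords decompose_kbc(int kbc, int iter_k, int iter_j, int ne02, int ncols2) {
    const int per_sequence = iter_k*iter_j*(ne02/ncols2);
    const int sequence     = kbc / per_sequence;
    const int zt           = (kbc - per_sequence*sequence) / (iter_k*iter_j);
    const int jt           = (kbc - per_sequence*sequence - iter_k*iter_j*zt) / iter_k;
    return {sequence, zt, jt, zt*ncols2};
}
```

All per-tile pointers then follow from head0: Q and dst advance by whole heads, K and V advance by head0 / gqa_ratio because gqa_ratio query heads share one K/V head, and sinks_f points at the ncols2 sink values of the group, which the tile kernel indexes as sinks_f[jc % ncols2].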