From 926fe234e92b5627b01e613fb6e9495316276fb0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Tue, 20 May 2025 14:45:07 +0200
Subject: [PATCH] CUDA: skip fully masked-out KV in FA vec kernel (llama/13584)

* CUDA: skip fully masked-out KV in FA vec kernel
---
 ggml/src/ggml-cuda/fattn-vec-f16.cuh | 52 +++++++++++++++++++++++++---
 ggml/src/ggml-cuda/fattn-vec-f32.cuh | 51 ++++++++++++++++++++++++---
 2 files changed, 95 insertions(+), 8 deletions(-)

diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index d96e3921..798a59b2 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -2,9 +2,9 @@
 #include "fattn-common.cuh"
 
 template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#ifndef GGML_USE_HIP
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // GGML_USE_HIP
 static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -48,6 +48,12 @@ static __global__ void flash_attn_vec_ext_f16(
         NO_DEVICE_CODE;
         return;
     }
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    if (ncols > 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
@@ -91,6 +97,13 @@ static __global__ void flash_attn_vec_ext_f16(
             kqsum_shared[j][threadIdx.x] = 0.0f;
         }
     }
+
+    __shared__ half maskh_shared[ncols*D];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        maskh_shared[j*D + tid] = 0.0f;
+    }
+
     __syncthreads();
 
     // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
@@ -175,6 +188,35 @@ static __global__ void flash_attn_vec_ext_f16(
 
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
+        if (mask) {
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
+            }
+
+            __syncthreads();
+
+            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
+            // In such cases, skip the KV slice.
+            // On AMD __all_sync would not work correctly: its 0xFFFFFFFF mask assumes a warp size of 32, but AMD wavefronts are 64 threads wide.
+#ifndef GGML_USE_HIP
+            bool skip = true;
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                    const int i = i0 + threadIdx.x;
+
+                    const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
+                    skip = skip && isinf(tmp.x) && isinf(tmp.y);
+                }
+            }
+            if (__all_sync(0xFFFFFFFF, skip)) {
+                continue;
+            }
+#endif // GGML_USE_HIP
+        }
+
         // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
         // see https://github.com/ggerganov/llama.cpp/pull/7061 .
         // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
@@ -202,7 +244,7 @@ static __global__ void flash_attn_vec_ext_f16(
                     sum = logit_softcap*tanhf(sum);
                 }
 
-                sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+                sum += maskh_shared[j*D + i_KQ];
 
                 if (ncols == 1) {
                     kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
@@ -335,7 +377,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
-    if (Q->ne[1] == 1) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
         constexpr int cols_per_block = 1;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
index 7064675d..49c592ea 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -2,9 +2,9 @@
 #include "fattn-common.cuh"
 
 template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#ifndef GGML_USE_HIP
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // GGML_USE_HIP
 static __global__ void flash_attn_vec_ext_f32(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -60,6 +60,12 @@ static __global__ void flash_attn_vec_ext_f32(
         NO_DEVICE_CODE;
         return;
     }
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    if (ncols > 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
@@ -104,6 +110,13 @@ static __global__ void flash_attn_vec_ext_f32(
             kqsum_shared[j][threadIdx.x] = 0.0f;
        }
     }
+
+    __shared__ float maskf_shared[ncols*D];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        maskf_shared[j*D + tid] = 0.0f;
+    }
+
     __syncthreads();
 
     // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
@@ -181,6 +194,34 @@ static __global__ void flash_attn_vec_ext_f32(
 
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
+        if (mask) {
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
+            }
+
+            __syncthreads();
+
+            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
+            // In such cases, skip the KV slice.
+            // On AMD __all_sync would not work correctly: its 0xFFFFFFFF mask assumes a warp size of 32, but AMD wavefronts are 64 threads wide.
+#ifndef GGML_USE_HIP
+            bool skip = true;
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+                for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+                    const int i = i0 + threadIdx.x;
+
+                    skip = skip && isinf(maskf_shared[j*D + i]);
+                }
+            }
+            if (__all_sync(0xFFFFFFFF, skip)) {
+                continue;
+            }
+#endif // GGML_USE_HIP
+        }
+
         float kqmax_new_arr[ncols];
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
@@ -204,7 +245,7 @@ static __global__ void flash_attn_vec_ext_f32(
                 sum = logit_softcap*tanhf(sum);
             }
 
-            sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+            sum += maskf_shared[j*D + i_KQ];
 
             kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
 
@@ -326,7 +367,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
-    if (Q->ne[1] == 1) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
         constexpr int cols_per_block = 1;
        if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
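
For illustration, below is a minimal, self-contained sketch of the skip technique this patch introduces, reduced to its core. It is not part of the patch: the kernel name count_live_tiles, the TILE parameter, and the entire host-side setup are invented for the example. Like the patch, it assumes NVIDIA semantics (warp size 32, full 0xFFFFFFFF vote mask), which is why the real kernels guard the vote with #ifndef GGML_USE_HIP. Each lane ANDs isinf() checks over a strided slice of the mask tile, then a single __all_sync() vote decides whether the whole warp may skip the KV tile, since a tile whose mask entries are all -infinity contributes nothing to the softmax.

// skip_sketch.cu -- hypothetical standalone demo, not part of ggml
#include <cuda_fp16.h>
#include <cmath>
#include <cstdio>

#define WARP_SIZE 32

template <int TILE> // TILE = mask entries per KV tile, a multiple of WARP_SIZE
static __global__ void count_live_tiles(const half * __restrict__ mask, int * __restrict__ n_live) {
    const half * tile = mask + blockIdx.x*TILE;

    // Each of the 32 lanes scans a strided slice of the tile.
    // A masked-out position holds -inf, so isinf() identifies it.
    bool skip = true;
#pragma unroll
    for (int i0 = 0; i0 < TILE; i0 += WARP_SIZE) {
        skip = skip && isinf(__half2float(tile[i0 + threadIdx.x]));
    }

    // Warp vote: the tile is dead only if *all* lanes saw nothing but -inf.
    // The real kernels `continue;` here to move on to the next KV slice.
    if (__all_sync(0xFFFFFFFF, skip)) {
        return;
    }

    if (threadIdx.x == 0) {
        atomicAdd(n_live, 1); // stand-in for the actual KQ/VKQ work
    }
}

int main() {
    constexpr int TILE    = 128;
    constexpr int N_TILES = 4;

    // Host-side mask: tiles 0 and 2 fully masked out (-inf), tiles 1 and 3 live.
    half h_mask[N_TILES*TILE];
    for (int t = 0; t < N_TILES; ++t) {
        for (int i = 0; i < TILE; ++i) {
            h_mask[t*TILE + i] = __float2half(t % 2 == 0 ? -INFINITY : 0.0f);
        }
    }

    half * d_mask;
    int  * d_n_live;
    cudaMalloc(&d_mask,   sizeof(h_mask));
    cudaMalloc(&d_n_live, sizeof(int));
    cudaMemcpy(d_mask, h_mask, sizeof(h_mask), cudaMemcpyHostToDevice);
    cudaMemset(d_n_live, 0, sizeof(int));

    count_live_tiles<TILE><<<N_TILES, WARP_SIZE>>>(d_mask, d_n_live); // one warp per tile

    int n_live = 0;
    cudaMemcpy(&n_live, d_n_live, sizeof(int), cudaMemcpyDeviceToHost);
    printf("live tiles: %d/%d\n", n_live, N_TILES); // expected: 2/4

    cudaFree(d_mask);
    cudaFree(d_n_live);
    return 0;
}

The design keeps the per-tile overhead to a handful of isinf() checks plus one warp vote, reading only the mask values the kernels already stage into shared memory; when the vote succeeds, the KQ dot products, the softmax rescaling, and the VKQ accumulation for that slice are skipped entirely.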