From 24d3524bfd147dcd3d3f440da4966cf7f1c85d31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 28 Jul 2025 14:30:22 +0200 Subject: [PATCH] CUDA: fix pointer incrementation in FA (llama/14916) --- ggml/src/ggml-cuda/fattn-vec-f16.cuh | 9 ++++----- ggml/src/ggml-cuda/fattn-vec-f32.cuh | 9 ++++----- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index e9b5c306..10925383 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -174,7 +174,10 @@ static __global__ void flash_attn_vec_ext_f16( K += blockIdx.y*D * nb11; V += blockIdx.y*D * nb21; maskh += blockIdx.y*D; - for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) { + for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D, + // Increment pointers after each loop: + K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) { + // Calculate KQ tile and keep track of new maximum KQ values: if (mask) { @@ -291,10 +294,6 @@ static __global__ void flash_attn_vec_ext_f16( } } - K += gridDim.y*D * nb11; - V += gridDim.y*D * nb21; - maskh += gridDim.y*D; - __syncthreads(); } diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 6a4bdc0f..2cf2e408 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -180,7 +180,10 @@ static __global__ void flash_attn_vec_ext_f32( K += blockIdx.y*D * nb11; V += blockIdx.y*D * nb21; maskh += blockIdx.y*D; - for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) { + for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D, + // Increment pointers after each loop: + K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) { + // Calculate KQ tile and keep track of new maximum KQ values: if (mask) { @@ -286,10 +289,6 @@ static __global__ void flash_attn_vec_ext_f32( } } - K += gridDim.y*D * nb11; - V += gridDim.y*D * nb21; - maskh += gridDim.y*D; - __syncthreads(); }