From 946b1f685909c8c9c044f145bce819c02f327eaa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 28 Jul 2025 14:30:22 +0200 Subject: [PATCH] CUDA: fix pointer incrementation in FA (#14916) --- ggml/src/ggml-cuda/fattn-vec-f16.cuh | 9 ++++----- ggml/src/ggml-cuda/fattn-vec-f32.cuh | 9 ++++----- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index e9b5c3063..109253838 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -174,7 +174,10 @@ static __global__ void flash_attn_vec_ext_f16( K += blockIdx.y*D * nb11; V += blockIdx.y*D * nb21; maskh += blockIdx.y*D; - for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) { + for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D, + // Increment pointers after each loop: + K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) { + // Calculate KQ tile and keep track of new maximum KQ values: if (mask) { @@ -291,10 +294,6 @@ static __global__ void flash_attn_vec_ext_f16( } } - K += gridDim.y*D * nb11; - V += gridDim.y*D * nb21; - maskh += gridDim.y*D; - __syncthreads(); } diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 6a4bdc0ff..2cf2e408e 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -180,7 +180,10 @@ static __global__ void flash_attn_vec_ext_f32( K += blockIdx.y*D * nb11; V += blockIdx.y*D * nb21; maskh += blockIdx.y*D; - for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) { + for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D, + // Increment pointers after each loop: + K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) { + // Calculate KQ tile and keep track of new maximum KQ values: if (mask) { @@ -286,10 +289,6 @@ static __global__ void flash_attn_vec_ext_f32( } } - K += gridDim.y*D * nb11; - V += gridDim.y*D * nb21; - maskh += gridDim.y*D; - __syncthreads(); }