diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index e9b5c3063..109253838 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -174,7 +174,10 @@ static __global__ void flash_attn_vec_ext_f16( K += blockIdx.y*D * nb11; V += blockIdx.y*D * nb21; maskh += blockIdx.y*D; - for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) { + for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D, + // Increment pointers after each loop: + K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) { + // Calculate KQ tile and keep track of new maximum KQ values: if (mask) { @@ -291,10 +294,6 @@ static __global__ void flash_attn_vec_ext_f16( } } - K += gridDim.y*D * nb11; - V += gridDim.y*D * nb21; - maskh += gridDim.y*D; - __syncthreads(); } diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 6a4bdc0ff..2cf2e408e 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -180,7 +180,10 @@ static __global__ void flash_attn_vec_ext_f32( K += blockIdx.y*D * nb11; V += blockIdx.y*D * nb21; maskh += blockIdx.y*D; - for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) { + for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D, + // Increment pointers after each loop: + K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) { + // Calculate KQ tile and keep track of new maximum KQ values: if (mask) { @@ -286,10 +289,6 @@ static __global__ void flash_attn_vec_ext_f32( } } - K += gridDim.y*D * nb11; - V += gridDim.y*D * nb21; - maskh += gridDim.y*D; - __syncthreads(); }