mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-05 00:25:26 -04:00
CUDA: fix pointer incrementation in FA (#14916)
This commit is contained in:
@@ -174,7 +174,10 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||||||
K += blockIdx.y*D * nb11;
|
K += blockIdx.y*D * nb11;
|
||||||
V += blockIdx.y*D * nb21;
|
V += blockIdx.y*D * nb21;
|
||||||
maskh += blockIdx.y*D;
|
maskh += blockIdx.y*D;
|
||||||
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
|
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
|
||||||
|
// Increment pointers after each loop:
|
||||||
|
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
|
||||||
|
|
||||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||||
|
|
||||||
if (mask) {
|
if (mask) {
|
||||||
@@ -291,10 +294,6 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
K += gridDim.y*D * nb11;
|
|
||||||
V += gridDim.y*D * nb21;
|
|
||||||
maskh += gridDim.y*D;
|
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -180,7 +180,10 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||||||
K += blockIdx.y*D * nb11;
|
K += blockIdx.y*D * nb11;
|
||||||
V += blockIdx.y*D * nb21;
|
V += blockIdx.y*D * nb21;
|
||||||
maskh += blockIdx.y*D;
|
maskh += blockIdx.y*D;
|
||||||
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
|
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
|
||||||
|
// Increment pointers after each loop:
|
||||||
|
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
|
||||||
|
|
||||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||||
|
|
||||||
if (mask) {
|
if (mask) {
|
||||||
@@ -286,10 +289,6 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
K += gridDim.y*D * nb11;
|
|
||||||
V += gridDim.y*D * nb21;
|
|
||||||
maskh += gridDim.y*D;
|
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user