diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu
index 0fcfaa32e..1e23f8f79 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -49,10 +49,11 @@ static __global__ void flash_attn_tile_ext_f16(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float2 * Q_f2   = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half2  * K_h2   = (const half2  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half2  * V_h2   = (const half2  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh  = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float  * sinksf = (const float  *) (sinks);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
@@ -242,6 +243,31 @@ static __global__ void flash_attn_tile_ext_f16(
         __syncthreads();
     }
 
+    //Attention sink: adjust running max and sum once per head
+    if (sinksf && blockIdx.y == 0) {
+        const half sink = __float2half(sinksf[head]);
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j));
+            kqmax[j0/nwarps] = kqmax_new_j;
+
+            const half val = hexp(sink - kqmax[j0/nwarps]);
+            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
+            if (threadIdx.x == 0) {
+                kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val);
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
+            }
+        }
+    }
+
     float2 * dst2 = (float2 *) dst;
 
 #pragma unroll
diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu
index 23550cbbd..c58194937 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f32.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -60,10 +60,11 @@ static __global__ void flash_attn_tile_ext_f32(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float2 * Q_f2   = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half2  * K_h2   = (const half2  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half2  * V_h2   = (const half2  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh  = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float  * sinksf = (const float  *) (sinks);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
@@ -252,6 +253,33 @@ static __global__ void flash_attn_tile_ext_f32(
 
         __syncthreads();
     }
+
+    //Attention sink: adjust running max and sum once per head
+    if (sinksf && blockIdx.y == 0) {
+        const float sink = sinksf[head];
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
+            kqmax[j0/nwarps] = kqmax_new_j;
+
+            const float val = expf(sink - kqmax[j0/nwarps]);
+            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
+            if (threadIdx.x == 0) {
+                kqsum[j0/nwarps] += val;
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
+                VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
+            }
+        }
+    }
+
     float2 * dst2 = (float2 *) dst;
 
 #pragma unroll
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
index 93d4d8102..fdc4d17da 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -82,11 +82,12 @@ static __global__ void flash_attn_ext_f16(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
-    const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
-    const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
-    const half2 * mask2 = (const half2 *) maskh;
+    const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const half2 * mask2 = (const half2 *) maskh;
+    const float * sinksf = (const float *) sinks;
 
     const int stride_Q = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
@@ -381,6 +382,53 @@ static __global__ void flash_attn_ext_f16(
         __syncthreads();
     }
 
+    // Apply attention sinks
+    if (sinksf && blockIdx.y == 0) {
+        const float sinkf = sinksf[head];
+        const half sinkh = __float2half(sinkf);
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (std::is_same<KQ_acc_t, float>::value) {
+                float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);
+
+                const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);
+                KQ_max_f[j0/nwarps] = kqmax_new;
+
+                KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);
+
+                const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                    const int i = i0 + threadIdx.x;
+                    if (i0 + warp_size > D/2 && i >= D/2) break;
+                    VKQ2[j*(D_padded/2) + i] *= scale_h2;
+                }
+            } else {
+                half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);
+                half kqmax_new = fmaxf(kqmax_old, sinkh);
+                KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);
+
+                const half KQ_max_scale_h = hexp(kqmax_old - kqmax_new);
+                const half2 KQ_max_scale = __half2half2(KQ_max_scale_h);
+
+                KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;
+                const half val = hexp(sinkh - kqmax_new);
+                KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);
+
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                    const int i = i0 + threadIdx.x;
+                    if (i0 + warp_size > D/2 && i >= D/2) break;
+                    VKQ2[j*(D_padded/2) + i] *= KQ_max_scale;
+                }
+            }
+        }
+
+        __syncthreads();
+    }
 #pragma unroll
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j_VKQ = j0 + threadIdx.y;
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 6c1185dea..22e90d0e7 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -274,23 +274,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
     const ggml_tensor * mask = dst->src[3];
-    const ggml_tensor * sinks = dst->src[4];
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
-    if (sinks && !fp16_mma_available(cc)) {
-        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-        }
-        return;
-    }
-
 #if defined(GGML_HIP_ROCWMMA_FATTN)
     if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
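
For reference, the numerics behind the blocks added above: an attention sink contributes one extra logit per head to the softmax denominator but no value vector, so after the K/V loop each query row only needs its running maximum raised to include the sink, its running sum rescaled and bumped by exp(sink - new_max), and its accumulator rescaled by the same factor. The guard on blockIdx.y == 0 folds the sink in exactly once per row when the KV sequence is split across several blocks. The host-side C++ below is a minimal sketch of that update; the struct and function names are illustrative and not symbols from the patch.

    // Reference sketch (not part of the patch): fold one attention-sink logit
    // into an online-softmax row state, mirroring the per-row update in the kernels.
    #include <cmath>
    #include <vector>

    struct OnlineSoftmaxRow {
        float              max; // running maximum of the logits seen so far
        float              sum; // running sum of exp(logit - max)
        std::vector<float> acc; // unnormalized V-weighted accumulator for this row
    };

    static void apply_attention_sink(OnlineSoftmaxRow & row, const float sink) {
        const float max_new = std::fmax(row.max, sink);
        const float scale   = std::exp(row.max - max_new); // rescales every term accumulated so far
        row.max = max_new;

        // The sink adds exp(sink - max_new) to the denominator only; it has no value row.
        row.sum = row.sum*scale + std::exp(sink - max_new);
        for (float & v : row.acc) {
            v *= scale;
        }
        // The normalized output stays acc[i] / sum, exactly as in the kernels' epilogue.
    }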