From 12a81af45f0dbbab24bd819a15f57c03ceb1be90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 2 Jul 2025 13:42:12 +0200 Subject: [PATCH] CUDA: broadcasting for FlashAttention mask (#14500) --- ggml/src/ggml-cuda/fattn-common.cuh | 5 ++++- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 12 ++++++++---- ggml/src/ggml-cuda/fattn-tile-f16.cu | 10 ++++++---- ggml/src/ggml-cuda/fattn-tile-f32.cu | 10 ++++++---- ggml/src/ggml-cuda/fattn-vec-f16.cuh | 8 +++++--- ggml/src/ggml-cuda/fattn-vec-f32.cuh | 9 ++++++--- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 14 ++++++++------ 7 files changed, 43 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index cfab2b5eb..075f14a49 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -32,7 +32,9 @@ typedef void (* fattn_kernel_t)( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -851,7 +853,8 @@ void launch_fattn( scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, + mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, Q->nb[1], Q->nb[2], Q->nb[3], nb11, nb12, nb13, nb21, nb22, nb23, diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index e230f6d49..709589854 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1223,7 +1223,9 @@ static __global__ void flash_attn_ext_f16( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -1288,7 +1290,8 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : + (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1); float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); @@ -1327,7 +1330,8 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : + (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1); float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); @@ -1348,8 +1352,8 @@ static __global__ void flash_attn_ext_f16( GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); - GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 9283560d5..0c967f178 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -6,7 +6,7 @@ template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(nwarps*WARP_SIZE, 1) +__launch_bounds__(nwarps*WARP_SIZE, 2) #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_tile_ext_f16( const char * __restrict__ Q, @@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f16( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -64,7 +66,7 @@ static __global__ void flash_attn_tile_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio)); const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic0; + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); const int stride_KV2 = nb11 / sizeof(half2); @@ -288,8 +290,8 @@ static __global__ void flash_attn_tile_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 32673adb5..124d5d3e8 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -6,7 +6,7 @@ template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(nwarps*WARP_SIZE, 1) +__launch_bounds__(nwarps*WARP_SIZE, 2) #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_tile_ext_f32( const char * __restrict__ Q, @@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f32( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -58,8 +60,8 @@ static __global__ void flash_attn_tile_ext_f32( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); @@ -76,7 +78,7 @@ static __global__ void flash_attn_tile_ext_f32( const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio)); const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic0; + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); const int stride_KV2 = nb11 / sizeof(half2); diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 35e649cb3..e78fb1819 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f16( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -68,7 +70,7 @@ static __global__ void flash_attn_vec_ext_f16( K += nb12*(blockIdx.z / gqa_ratio); V += nb22*(blockIdx.z / gqa_ratio); - const half * maskh = (const half *) mask + ne11*ic0; + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); @@ -342,8 +344,8 @@ static __global__ void flash_attn_vec_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 953967917..c22baf417 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f32( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -51,8 +53,8 @@ static __global__ void flash_attn_vec_ext_f32( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); @@ -79,7 +81,8 @@ static __global__ void flash_attn_vec_ext_f32( Q += nb02* blockIdx.z + nb01*ic0; K += nb12*(blockIdx.z / gqa_ratio); V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic0; + + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index f3b794c36..c95ca7b1f 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -46,7 +46,9 @@ static __global__ void flash_attn_ext_f16( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -94,11 +96,11 @@ static __global__ void flash_attn_ext_f16( constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0); - const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; - const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); + const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0); + const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const half2 * mask2 = (const half2 *) maskh; const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); @@ -440,7 +442,7 @@ static __global__ void flash_attn_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);