vulkan: enable coopmat2 FA gqa and split_k optimizations more often (#12931)

The grouped query attention optmization doesn't require a power of two ratio,
the only thing relying on it was the modulo operation written as bitwise &.

split_k need not depend on gqa_ratio - enable it any time there's only one
workgroup in the X dimension. The shader gets the split index from the x coord,
and multiple workgroups in the X dimension (pre-split) indicates a larger
FA operation that wouldn't need splitting.
This commit is contained in:
Jeff Bolz
2025-04-16 13:37:25 -05:00
committed by GitHub
parent b43d89e311
commit 015022bb53
3 changed files with 7 additions and 5 deletions

View File

@ -5531,7 +5531,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
uint32_t workgroups_y = (uint32_t)neq2;
uint32_t workgroups_z = (uint32_t)neq3;
if (N == 1 && qk_ratio > 1 && is_pow2(qk_ratio) && gqa_ratio <= flash_attention_num_small_rows &&
if (N == 1 && qk_ratio > 1 && gqa_ratio <= flash_attention_num_small_rows &&
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
@ -5544,8 +5544,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
uint32_t split_kv = KV;
uint32_t split_k = 1;
if (gqa_ratio > 1 && ctx->device->shader_core_count > 0) {
GGML_ASSERT(workgroups_x == 1);
// Try to use split_k when KV is large enough to be worth the overhead
if (workgroups_x == 1 && ctx->device->shader_core_count > 0 && KV >= 512) {
// Try to run two workgroups per SM.
split_k = ctx->device->shader_core_count * 2 / workgroups_y;
if (split_k > 1) {