vulkan: scalar flash attention implementation (#13324)

* vulkan: scalar flash attention implementation

* vulkan: always use fp32 for scalar flash attention

* vulkan: use vector loads in scalar flash attention shader

* vulkan: remove PV matrix, helps with register usage

* vulkan: reduce register usage in scalar FA, but perf may be slightly worse

* vulkan: load each Q value once. optimize O reduction. more tuning

* vulkan: support q4_0/q8_0 KV in scalar FA

* CI: increase timeout to accommodate newly-supported tests

* vulkan: for scalar FA, select between 1 and 8 rows

* vulkan: avoid using Float16 capability in scalar FA
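
Note on the technique named in the items above: "scalar" flash attention computes attention with ordinary multiply-adds rather than hardware matrix instructions, streaming over K/V in blocks while keeping a running max and running sum (online softmax). The sketch below is only an illustration of that general algorithm under stated assumptions. It is not the shader code from this commit. The function names, the `block` parameter, and the single-query-row layout are hypothetical, and Python floats are double precision rather than the fp32 the commit mentions.

```python
import math

def scalar_flash_attention(q, K, V, block=4):
    """Attention output for one query row, streaming over K/V in blocks.

    Hypothetical illustration: plain scalar arithmetic, no matrix ops.
    """
    scale = 1.0 / math.sqrt(len(q))
    m = -math.inf            # running max of the scores seen so far
    l = 0.0                  # running sum of exp(score - m)
    o = [0.0] * len(V[0])    # unnormalized output accumulator
    for start in range(0, len(K), block):
        ks = K[start:start + block]
        vs = V[start:start + block]
        # Scores for this block: scaled dot products q . k
        s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in ks]
        m_new = max(m, max(s))
        # Rescale earlier partial results to the new max.
        # On the first block m is -inf, so alpha is exp(-inf) = 0.0.
        alpha = math.exp(m - m_new)
        l *= alpha
        o = [alpha * x for x in o]
        for sj, v in zip(s, vs):
            p = math.exp(sj - m_new)
            l += p
            o = [x + p * vx for x, vx in zip(o, v)]
        m = m_new
    # Final normalization by the softmax denominator
    return [x / l for x in o]

def reference_attention(q, K, V):
    """Unblocked softmax(q K^T / sqrt(d)) V, used only to check the sketch."""
    scale = 1.0 / math.sqrt(len(q))
    s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    total = sum(p)
    return [sum(pj * v[c] for pj, v in zip(p, V)) / total
            for c in range(len(V[0]))]
```

The blocked version never materializes the full score row, which is the property that lets a shader keep its working set in registers. The result should match `reference_attention` up to floating-point rounding for any `block` size.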
Jeff Bolz
2025-05-09 23:07:07 -07:00
committed by GitHub
parent 7c28a74e07
commit dc1d2adfc0
4 changed files with 646 additions and 94 deletions


@@ -307,7 +307,7 @@ jobs:
 run: |
 cd build
 # This is using llvmpipe and runs slower than other backends
-ctest -L main --verbose --timeout 2700
+ctest -L main --verbose --timeout 3600
 ubuntu-22-cmake-hip:
 runs-on: ubuntu-22.04