mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-29 12:35:16 +00:00
CUDA: use mma PTX instructions for FlashAttention (#11583)
* CUDA: use mma PTX instructions for FlashAttention * __shfl_sync workaround for movmatrix * add __shfl_sync to HIP Co-authored-by: Diego Devesa <slarengh@gmail.com>
This commit is contained in:
@ -12,13 +12,13 @@ SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.p
|
||||
DECL_FATTN_VEC_F{vkq_size}_CASE({head_size}, {type_k}, {type_v});
|
||||
"""
|
||||
|
||||
SOURCE_FATTN_WMMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-wmma-f16.cuh"
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
"""
|
||||
|
||||
SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n"
|
||||
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n"
|
||||
|
||||
TYPES_MMQ = [
|
||||
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
|
||||
@ -57,20 +57,12 @@ for vkq_size in [16, 32]:
|
||||
with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
|
||||
|
||||
for kq_acc_t in ["half", "float"]:
|
||||
for cols_per_block in [8, 16, 32]:
|
||||
if kq_acc_t == "float" and cols_per_block == 8:
|
||||
continue
|
||||
for cols_per_block in [8, 16, 32, 64]:
|
||||
with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_MMA_START)
|
||||
|
||||
with open(f"fattn-wmma-f16-instance-kq{kq_acc_t}-cpb{cols_per_block}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_WMMA_START)
|
||||
|
||||
for head_size in [64, 80, 96, 112, 128, 256]:
|
||||
if cols_per_block == 8 and head_size % 32 != 0: # wmma fragment is 8x32
|
||||
continue
|
||||
if kq_acc_t == "float" and cols_per_block == 32 and head_size == 256: # register spilling, bad performance
|
||||
continue
|
||||
f.write(SOURCE_FATTN_WMMA_CASE.format(kq_acc_t=kq_acc_t, cols_per_block=cols_per_block, head_size=head_size))
|
||||
for head_size in [64, 80, 96, 112, 128, 256]:
|
||||
f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size))
|
||||
|
||||
for type in TYPES_MMQ:
|
||||
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
|
||||
|
Reference in New Issue
Block a user