CUDA: optimize FA for GQA + large batches (#12014)

This commit is contained in:
Johannes Gäßler
2025-02-22 12:20:17 +01:00
committed by GitHub
parent 335eb04a91
commit 5fa07c2f93
32 changed files with 940 additions and 411 deletions

View File

@ -3119,6 +3119,7 @@ struct test_leaky_relu : public test_case {
struct test_flash_attn_ext : public test_case {
const int64_t hs; // head size
const int64_t nh; // num heads
const int64_t nr; // repeat in Q, tests for grouped-query attention
const int64_t kv; // kv size
const int64_t nb; // batch size
@ -3131,7 +3132,7 @@ struct test_flash_attn_ext : public test_case {
std::array<int32_t, 4> permute;
std::string vars() override {
return VARS_TO_STR9(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
return VARS_TO_STR10(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
}
double max_nmse_err() override {
@ -3142,13 +3143,13 @@ struct test_flash_attn_ext : public test_case {
GGML_UNUSED(t);
// Just counting matmul costs:
// Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
return 2 * 2 * nh * nb * hs * kv;
return 2 * 2 * nh*nr * nb * hs * kv;
}
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
std::array<int32_t, 4> permute = {0, 1, 2, 3})
: hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
: hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
@ -3166,13 +3167,13 @@ struct test_flash_attn_ext : public test_case {
return t;
};
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh, 1);
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh*nr, 1);
ggml_set_name(q, "q");
ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
ggml_set_name(k, "k");
ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
ggml_set_name(v, "v");
ggml_tensor * m = nullptr;
@ -4278,14 +4279,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
if (!mask && max_bias > 0.0f) continue;
for (float logit_softcap : {0.0f, 10.0f}) {
if (hs != 128 && logit_softcap != 0.0f) continue;
for (int nh : { 32, }) {
for (int kv : { 512, 1024, }) {
for (int nb : { 1, 3, 32, 35, }) {
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
// run fewer test cases permuted
if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
for (int nh : { 4, }) {
for (int nr : { 1, 4, 16 }) {
if (nr == 16 && hs != 128) continue;
for (int kv : { 512, 1024, }) {
if (nr != 1 && kv != 512) continue;
for (int nb : { 1, 3, 32, 35, }) {
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV));
// run fewer test cases permuted
if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
}
}
}
}