llama.cpp (mirror of https://github.com/ggml-org/llama.cpp.git)
CUDA: optimize FA for GQA + large batches (#12014)
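The diff below shows the accompanying changes to the backend test suite (tests/test-backend-ops.cpp): test_flash_attn_ext gains an nr parameter, the number of times each K/V head is repeated in Q, so FLASH_ATTN_EXT is now also exercised with grouped-query attention shapes and its FLOPs estimate is scaled accordingly.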
@@ -3119,6 +3119,7 @@ struct test_leaky_relu : public test_case {
 struct test_flash_attn_ext : public test_case {
     const int64_t hs; // head size
     const int64_t nh; // num heads
+    const int64_t nr; // repeat in Q, tests for grouped-query attention
     const int64_t kv; // kv size
     const int64_t nb; // batch size
 
@@ -3131,7 +3132,7 @@ struct test_flash_attn_ext : public test_case {
     std::array<int32_t, 4> permute;
 
     std::string vars() override {
-        return VARS_TO_STR9(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
+        return VARS_TO_STR10(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
     }
 
     double max_nmse_err() override {
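(Switching from VARS_TO_STR9 to VARS_TO_STR10 adds nr to the string that identifies each test case in the test-backend-ops output, so the GQA variants are distinguishable in reports; the exact output format is not visible in this hunk.)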
@@ -3142,13 +3143,13 @@ struct test_flash_attn_ext : public test_case {
         GGML_UNUSED(t);
         // Just counting matmul costs:
         // Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
-        return 2 * 2 * nh * nb * hs * kv;
+        return 2 * 2 * nh*nr * nb * hs * kv;
     }
 
-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
             bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
             std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
+        : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
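As a sanity check on the updated cost model: a minimal standalone sketch evaluating the op_flops() formula by hand for one of the GQA shapes added in the test sweep below (hs = 128, nh = 4, nr = 4, kv = 512, nb = 32). The program is mine for illustration; only the formula comes from the diff.

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t hs = 128, nh = 4, nr = 4, kv = 512, nb = 32;
    // Two matmuls (Q*K^T and P*V), 2 FLOPs (mul + add) per element,
    // each nb x hs x kv per head, summed over nh*nr query heads:
    const int64_t flops = 2 * 2 * nh*nr * nb * hs * kv;
    printf("%lld FLOPs\n", (long long) flops); // 134217728, i.e. ~0.13 GFLOP
}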
@@ -3166,13 +3167,13 @@ struct test_flash_attn_ext : public test_case {
             return t;
         };
 
-        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh, 1);
+        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh*nr, 1);
         ggml_set_name(q, "q");
 
         ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
         ggml_set_name(k, "k");
 
         ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
         ggml_set_name(v, "v");
 
         ggml_tensor * m = nullptr;
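Note that only q is built with nh*nr heads; k and v keep nh heads, and the FlashAttention op broadcasts each K/V head across a group of nr query heads. A minimal sketch of that mapping, assuming the usual i_q / nr grouping (kv_head_for is a hypothetical helper for illustration, not a function from llama.cpp):

#include <cstdint>
#include <cstdio>

// Hypothetical helper: which K/V head serves query head iq when each
// K/V head is shared by nr consecutive query heads.
static int64_t kv_head_for(int64_t iq, int64_t nr) {
    return iq / nr;
}

int main() {
    const int64_t nh = 4, nr = 4; // 4 K/V heads, nh*nr = 16 query heads
    for (int64_t iq = 0; iq < nh*nr; ++iq) {
        printf("Q head %2lld -> KV head %lld\n",
               (long long) iq, (long long) kv_head_for(iq, nr));
    }
}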
@@ -4278,14 +4279,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                 if (!mask && max_bias > 0.0f) continue;
                 for (float logit_softcap : {0.0f, 10.0f}) {
                     if (hs != 128 && logit_softcap != 0.0f) continue;
-                    for (int nh : { 32, }) {
-                        for (int kv : { 512, 1024, }) {
-                            for (int nb : { 1, 3, 32, 35, }) {
-                                for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                                    test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
-                                    // run fewer test cases permuted
-                                    if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
-                                        test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
+                    for (int nh : { 4, }) {
+                        for (int nr : { 1, 4, 16 }) {
+                            if (nr == 16 && hs != 128) continue;
+                            for (int kv : { 512, 1024, }) {
+                                if (nr != 1 && kv != 512) continue;
+                                for (int nb : { 1, 3, 32, 35, }) {
+                                    for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                                        test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV));
+                                        // run fewer test cases permuted
+                                        if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                            test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
+                                        }
                                     }
                                 }
                             }
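The two continue guards keep the new sweep from exploding: nr == 16 only runs with hs == 128, and the repeated-head cases (nr != 1) only run with kv == 512. A quick sketch enumerating the surviving (hs, nr, kv) combinations; the outer head-size values are not visible in this hunk, so { 64, 128 } is an assumption for illustration:

#include <cstdio>
#include <initializer_list>

int main() {
    for (int hs : { 64, 128 }) {                     // assumed outer head-size loop
        for (int nr : { 1, 4, 16 }) {
            if (nr == 16 && hs != 128) continue;     // large repeat only at hs 128
            for (int kv : { 512, 1024 }) {
                if (nr != 1 && kv != 512) continue;  // GQA cases only at kv 512
                printf("hs=%d nr=%d kv=%d\n", hs, nr, kv);
            }
        }
    }
}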