mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-26 19:55:04 +00:00
metal : improve FA + improve MoE (#12612)
* ggml : FA with different K, V head sizes (CPU) ggml-ci * metal : add FA with HS=192 * metal : extend FA to support different K and V head sizes ggml-ci * metal : add FA vector kernels for heads K 192 and V 128 ggml-ci * ggml : restrict op on other backends to equal head sizes ggml-ci * metal : optimize FA-vec kernel ggml-ci * metal : FA remove mq registers * metal : improve MoE mul_mat_id condition ggml-ci * metal : fix comments + remove unnecessary addition ggml-ci * metal : avoid too much shared memory usage with mul_mat_id ggml-ci
This commit is contained in:
@ -3217,7 +3217,8 @@ struct test_leaky_relu : public test_case {
|
||||
|
||||
// GGML_OP_FLASH_ATTN_EXT
|
||||
struct test_flash_attn_ext : public test_case {
|
||||
const int64_t hs; // head size
|
||||
const int64_t hsk; // K head size
|
||||
const int64_t hsv; // V head size
|
||||
const int64_t nh; // num heads
|
||||
const int64_t nr; // repeat in Q, tests for grouped-query attention
|
||||
const int64_t kv; // kv size
|
||||
@ -3233,7 +3234,7 @@ struct test_flash_attn_ext : public test_case {
|
||||
std::array<int32_t, 4> permute;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR11(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
|
||||
return VARS_TO_STR12(hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
@ -3243,17 +3244,18 @@ struct test_flash_attn_ext : public test_case {
|
||||
uint64_t op_flops(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
// Just counting matmul costs:
|
||||
// Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
|
||||
return 2 * 2 * nh*nr * nb * hs * kv;
|
||||
// Q*K^T is nb x hsk x kv, P*V is nb x kv x hsv, per head
|
||||
return 2 * nh*nr * nb * (hsk + hsv) * kv;
|
||||
}
|
||||
|
||||
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
|
||||
test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
|
||||
bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
|
||||
ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
|
||||
: hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
|
||||
: hsk(hsk), hsv(hsv), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
|
||||
const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
|
||||
const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV));
|
||||
|
||||
auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
|
||||
int64_t ne[4] = {ne0, ne1, ne2, ne3};
|
||||
@ -3268,13 +3270,13 @@ struct test_flash_attn_ext : public test_case {
|
||||
return t;
|
||||
};
|
||||
|
||||
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh*nr, 1);
|
||||
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr, 1);
|
||||
ggml_set_name(q, "q");
|
||||
|
||||
ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
|
||||
ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, 1);
|
||||
ggml_set_name(k, "k");
|
||||
|
||||
ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
|
||||
ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, 1);
|
||||
ggml_set_name(v, "v");
|
||||
|
||||
ggml_tensor * m = nullptr;
|
||||
@ -3283,7 +3285,7 @@ struct test_flash_attn_ext : public test_case {
|
||||
ggml_set_name(m, "m");
|
||||
}
|
||||
|
||||
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
|
||||
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hsk), max_bias, logit_softcap);
|
||||
ggml_flash_attn_ext_set_prec(out, prec);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
@ -4412,27 +4414,32 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_timestep_embedding());
|
||||
test_cases.emplace_back(new test_leaky_relu());
|
||||
|
||||
for (int hs : { 64, 80, 128, 256, }) {
|
||||
for (bool mask : { true, false } ) {
|
||||
for (float max_bias : { 0.0f, 8.0f }) {
|
||||
if (!mask && max_bias > 0.0f) continue;
|
||||
for (float logit_softcap : {0.0f, 10.0f}) {
|
||||
if (hs != 128 && logit_softcap != 0.0f) continue;
|
||||
for (int nh : { 4, }) {
|
||||
for (int nr : { 1, 4, 16 }) {
|
||||
if (nr == 16 && hs != 128) continue;
|
||||
for (int kv : { 512, 1024, }) {
|
||||
if (nr != 1 && kv != 512) continue;
|
||||
for (int nb : { 1, 3, 32, 35, }) {
|
||||
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
|
||||
if (hs != 128 && prec == GGML_PREC_DEFAULT) continue;
|
||||
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(
|
||||
hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
|
||||
// run fewer test cases permuted
|
||||
if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
|
||||
for (int hsk : { 64, 80, 128, 192, 256, }) {
|
||||
for (int hsv : { 64, 80, 128, 192, 256, }) {
|
||||
if (hsk != 192 && hsk != hsv) continue;
|
||||
if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
|
||||
|
||||
for (bool mask : { true, false } ) {
|
||||
for (float max_bias : { 0.0f, 8.0f }) {
|
||||
if (!mask && max_bias > 0.0f) continue;
|
||||
for (float logit_softcap : {0.0f, 10.0f}) {
|
||||
if (hsk != 128 && logit_softcap != 0.0f) continue;
|
||||
for (int nh : { 4, }) {
|
||||
for (int nr : { 1, 4, 16 }) {
|
||||
if (nr == 16 && hsk != 128) continue;
|
||||
for (int kv : { 512, 1024, }) {
|
||||
if (nr != 1 && kv != 512) continue;
|
||||
for (int nb : { 1, 3, 32, 35, }) {
|
||||
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
|
||||
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
|
||||
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(
|
||||
hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
|
||||
hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
|
||||
// run fewer test cases permuted
|
||||
if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(
|
||||
hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user