vulkan: Add fusion support for RMS_NORM+MUL (#14366)

* vulkan: Add fusion support for RMS_NORM+MUL

- Add a use_count to ggml_tensor, so we can detect if an output is used more than once.
- Change the ggml-vulkan rms_norm shader to optionally multiply by another tensor.
- Add detection logic and basic fusion logic in ggml-vulkan.
- Add some testing support for fusion. Rather than computing one node at a time, allow
for computing the whole graph and just testing one node's results. Add rms_norm_mul tests
and enable a llama test.

* extract some common fusion logic

* fix -Winconsistent-missing-override

* move ggml_can_fuse to a common function

* build fix

* C and C++ versions of can_fuse

* move use count to the graph to avoid data races and double increments when used in multiple threads

* use hash table lookup to find node index

* change use_counts to be indexed by hash table slot

* minimize hash lookups

style fixes

* last node doesn't need single use.
fix type.
handle mul operands being swapped.

* remove redundant parameter

---------

Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
Jeff Bolz
2025-06-29 02:43:36 -05:00
committed by GitHub
parent 27208bf657
commit bd9c981d72
8 changed files with 263 additions and 56 deletions

View File

@ -382,6 +382,8 @@ struct test_case {
return 0;
}
virtual bool run_whole_graph() { return false; }
ggml_cgraph * gf = nullptr;
ggml_cgraph * gb = nullptr;
@ -574,7 +576,7 @@ struct test_case {
GGML_UNUSED(index);
};
const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);
const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud, run_whole_graph() ? out : nullptr);
if (!cmp_ok) {
printf("compare failed ");
@ -1896,6 +1898,63 @@ struct test_rms_norm_back : public test_case {
}
};
// GGML_OP_RMS_NORM + GGML_OP_MUL
struct test_rms_norm_mul : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const float eps;
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
return "RMS_NORM_MUL";
}
bool run_whole_graph() override { return true; }
std::string vars() override {
return VARS_TO_STR3(type, ne, eps);
}
test_rms_norm_mul(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
float eps = 1e-6f)
: type(type), ne(ne), eps(eps) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_set_param(b);
ggml_set_name(b, "b");
// Use a and b early, so we don't end up with an OP_NONE between rms_norm and mul
a = ggml_add(ctx, a, b);
ggml_tensor * out = ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.f, 10.f);
}
}
double max_nmse_err() override {
return 1e-6;
}
float grad_eps() override {
return 1.0f;
}
bool grad_precise() override {
return true;
}
};
// GGML_OP_SSM_CONV
struct test_ssm_conv : public test_case {
const ggml_type type;
@ -3736,6 +3795,7 @@ struct test_llama : public test_llm {
static constexpr float attn_factor = 1.0f;
static constexpr float beta_fast = 32.0f;
static constexpr float beta_slow = 1.0f;
bool fused;
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
@ -3751,7 +3811,9 @@ struct test_llama : public test_llm {
return 2e-3;
}
test_llama(int n_tokens = 1)
bool run_whole_graph() override { return fused; }
test_llama(int n_tokens = 1, bool fused = false)
: test_llm({
/*n_vocab =*/ 32000,
/*n_embd =*/ 3200,
@ -3763,7 +3825,9 @@ struct test_llama : public test_llm {
/*f_norm_eps =*/ 0.f,
/*f_norm_rms_eps =*/ 1e-5f,
/*n_tokens =*/ n_tokens,
}) {
})
, fused(fused)
{
}
ggml_tensor * build_graph(ggml_context * ctx) override {
@ -4306,6 +4370,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
test_cases.emplace_back(new test_rms_norm_mul(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
@ -4677,6 +4744,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
test_cases.emplace_back(new test_llama(2, true));
// these tests are disabled to save execution time, but they can be handy for debugging
#if 0
test_cases.emplace_back(new test_llama(1));