Mirror of https://github.com/ggml-org/llama.cpp.git, synced 2025-07-01 21:15:06 +00:00
vulkan: Add fusion support for RMS_NORM+MUL (#14366)

* vulkan: Add fusion support for RMS_NORM+MUL

  - Add a use_count to ggml_tensor, so we can detect if an output is used more than once.
  - Change the ggml-vulkan rms_norm shader to optionally multiply by another tensor.
  - Add detection logic and basic fusion logic in ggml-vulkan.
  - Add some testing support for fusion. Rather than computing one node at a time, allow for computing the whole graph and just testing one node's results. Add rms_norm_mul tests and enable a llama test.

* extract some common fusion logic
* fix -Winconsistent-missing-override
* move ggml_can_fuse to a common function
* build fix
* C and C++ versions of can_fuse
* move use count to the graph to avoid data races and double increments when used in multiple threads
* use hash table lookup to find node index
* change use_counts to be indexed by hash table slot
* minimize hash lookups; style fixes
* last node doesn't need single use; fix type; handle mul operands being swapped
* remove redundant parameter

---------

Co-authored-by: slaren <slarengh@gmail.com>
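The detection logic described above is worth a concrete illustration. The sketch below is not the actual ggml_can_fuse() implementation (which, per the commit log, tracks use counts in a hash table on the graph); it is a minimal C approximation, using a hypothetical flat use_counts array, of the eligibility check the commit describes: the RMS_NORM node must feed directly into the MUL node (in either operand position) and its result must have no other consumers.

// Hedged sketch, not the actual ggml code: assumes use_counts[i] holds how
// many times the output of graph->nodes[i] is consumed within the graph.
static bool can_fuse_rms_norm_mul(const struct ggml_cgraph * graph, int i,
                                  const int * use_counts) {
    if (i + 1 >= graph->n_nodes) {
        return false;
    }
    const struct ggml_tensor * norm = graph->nodes[i];
    const struct ggml_tensor * mul  = graph->nodes[i + 1];
    if (norm->op != GGML_OP_RMS_NORM || mul->op != GGML_OP_MUL) {
        return false;
    }
    // the mul operands may be swapped, as the commit log notes
    if (mul->src[0] != norm && mul->src[1] != norm) {
        return false;
    }
    // the intermediate rms_norm result must have exactly one consumer;
    // the last node of the fused pattern carries no such restriction
    return use_counts[i] == 1;
}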
tests/test-backend-ops.cpp

@@ -382,6 +382,8 @@ struct test_case {
         return 0;
     }

+    virtual bool run_whole_graph() { return false; }
+
     ggml_cgraph * gf = nullptr;
     ggml_cgraph * gb = nullptr;

@@ -574,7 +576,7 @@ struct test_case {
             GGML_UNUSED(index);
         };

-        const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);
+        const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud, run_whole_graph() ? out : nullptr);

         if (!cmp_ok) {
             printf("compare failed ");
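The new trailing argument switches the comparison mode: when it is non-null, the comparator evaluates the whole graph on both backends and checks only the named output tensor, so backend-side fusion across adjacent nodes is exercised rather than bypassed. A test opts in with the pattern visible later in this diff:

// In a test_case subclass (pattern as it appears in this commit):
bool run_whole_graph() override { return true; } // compute all nodes, compare only "out"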
@@ -1896,6 +1898,63 @@ struct test_rms_norm_back : public test_case {
     }
 };

+// GGML_OP_RMS_NORM + GGML_OP_MUL
+struct test_rms_norm_mul : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const float eps;
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "RMS_NORM_MUL";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, eps);
+    }
+
+    test_rms_norm_mul(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {64, 5, 4, 3},
+            float eps = 1e-6f)
+        : type(type), ne(ne), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+        ggml_set_param(b);
+        ggml_set_name(b, "b");
+
+        // Use a and b early, so we don't end up with an OP_NONE between rms_norm and mul
+        a = ggml_add(ctx, a, b);
+        ggml_tensor * out = ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -10.f, 10.f);
+        }
+    }
+
+    double max_nmse_err() override {
+        return 1e-6;
+    }
+
+    float grad_eps() override {
+        return 1.0f;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
 // GGML_OP_SSM_CONV
 struct test_ssm_conv : public test_case {
     const ggml_type type;
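For reference, the value the fused path has to reproduce is just RMS normalization followed by an elementwise multiply: per row, out[i] = a[i] / sqrt(mean(a^2) + eps) * b[i]. A scalar C sketch of that reference follows (a hypothetical helper, not code from the repo and not the vulkan shader):

#include <math.h>

// Reference result of RMS_NORM(a, eps) * b over one row of n floats.
// This is a plain restatement of the math the fused kernel must match.
static void rms_norm_mul_row(const float * a, const float * b, float * out,
                             int n, float eps) {
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        sum += a[i] * a[i];        // sum of squares over the row
    }
    const float scale = 1.0f / sqrtf(sum / n + eps);
    for (int i = 0; i < n; i++) {
        out[i] = a[i] * scale * b[i];
    }
}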
@@ -3736,6 +3795,7 @@ struct test_llama : public test_llm {
     static constexpr float attn_factor = 1.0f;
     static constexpr float beta_fast = 32.0f;
     static constexpr float beta_slow = 1.0f;
+    bool fused;

     std::string op_desc(ggml_tensor * t) override {
         GGML_UNUSED(t);
@@ -3751,7 +3811,9 @@ struct test_llama : public test_llm {
         return 2e-3;
     }

-    test_llama(int n_tokens = 1)
+    bool run_whole_graph() override { return fused; }
+
+    test_llama(int n_tokens = 1, bool fused = false)
     : test_llm({
         /*n_vocab =*/ 32000,
         /*n_embd =*/ 3200,
@@ -3763,7 +3825,9 @@ struct test_llama : public test_llm {
         /*f_norm_eps =*/ 0.f,
         /*f_norm_rms_eps =*/ 1e-5f,
         /*n_tokens =*/ n_tokens,
-    }) {
+    })
+    , fused(fused)
+    {
     }

     ggml_tensor * build_graph(ggml_context * ctx) override {
@@ -4306,6 +4370,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
         test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
+    for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
+        test_cases.emplace_back(new test_rms_norm_mul(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+    }

     test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));

@@ -4677,6 +4744,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {

     test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));

+    test_cases.emplace_back(new test_llama(2, true));
     // these tests are disabled to save execution time, but they can be handy for debugging
 #if 0
     test_cases.emplace_back(new test_llama(1));