vulkan: fix rms_norm+mul fusion (#14545)

The fused operation was grabbing the epsilon value from the wrong place.

Add an env var to disable fusion.

Add some missing checks for supported shapes/types.

Handle fused rms_norm+mul in check_results.
This commit is contained in:
Jeff Bolz
2025-07-06 03:08:16 -05:00
committed by GitHub
parent a0374a67e2
commit e592be1575
2 changed files with 88 additions and 24 deletions

View File

@ -2583,10 +2583,6 @@ struct test_rms_norm_mul : public test_case {
}
}
double max_nmse_err() override {
return 1e-6;
}
float grad_eps() override {
return 1.0f;
}
@ -5058,7 +5054,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
test_cases.emplace_back(new test_rms_norm_mul(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}