0
0
mirror of https://github.com/google/googletest.git synced 2025-03-20 10:53:47 +00:00

Add a DistanceFrom() matcher for general distance comparison.

We have a bunch of matchers for asserting that a value is near the target value, e.g.
`DoubleNear()` and `FloatNear()`. These matchers only work for specific types (`double` and `float`). They are not flexible enough to support other types that have the notion of a "distance" (e.g. N-dimensional points and vectors, which are commonly used in ML).

In this diff, we generalize the idea to a `DistanceFrom(target, get_distance, m)` matcher that works on arbitrary types that have the "distance" concept (the `get_distance` argument is optional and can be omitted for types that support `-`, and `std::abs()`). What it does:

1. compute the distance between the value and the target using `get_distance(value, target)`; if `get_distance` is omitted, compute the distance as `std::abs(value - target)`.
2. match the distance against matcher `m`; if the match succeeds, the `DistanceFrom()` match succeeds.

Examples:

```
  // 0.5's distance from 0.6 should be <= 0.2.
  EXPECT_THAT(0.5, DistanceFrom(0.6, Le(0.2)));

  Vector2D v1(3.0, 4.0), v2(3.2, 6.0);
  // v1's distance from v2, as computed by EuclideanDistance(v1, v2),
  // should be >= 1.0.
  EXPECT_THAT(v1, DistanceFrom(v2, EuclideanDistance, Ge(1.0)));
```

PiperOrigin-RevId: 734593292
Change-Id: Id6bb7074dc4aa4d8abd78b57ad2426637e590de5
This commit is contained in:
Zhanyong Wan 2025-03-07 09:53:19 -08:00 committed by Copybara-Service
parent e88cb95b92
commit 0bdccf4aa2
3 changed files with 276 additions and 0 deletions

View File

@ -42,6 +42,8 @@ Matcher | Description
| `Lt(value)` | `argument < value` |
| `Ne(value)` | `argument != value` |
| `IsFalse()` | `argument` evaluates to `false` in a Boolean context. |
| `DistanceFrom(target, m)` | The distance between `argument` and `target` (computed by `std::abs(argument - target)`) matches `m`. |
| `DistanceFrom(target, get_distance, m)` | The distance between `argument` and `target` (computed by `get_distance(argument, target)`) matches `m`. |
| `IsTrue()` | `argument` evaluates to `true` in a Boolean context. |
| `IsNull()` | `argument` is a `NULL` pointer (raw or smart). |
| `NotNull()` | `argument` is a non-null pointer (raw or smart). |

View File

@ -2855,6 +2855,54 @@ class ContainsMatcherImpl : public QuantifierMatcherImpl<Container> {
}
};
// Implements DistanceFrom(target, get_distance, distance_matcher) for the given
// argument types:
// * V is the type of the value to be matched.
// * T is the type of the target value.
// * Distance is the type of the distance between V and T.
// * GetDistance is the type of the functor for computing the distance between
// V and T.
template <typename V, typename T, typename Distance, typename GetDistance>
class DistanceFromMatcherImpl : public MatcherInterface<V> {
public:
// Arguments:
// * target: the target value.
// * get_distance: the functor for computing the distance between the value
// being matched and target.
// * distance_matcher: the matcher for checking the distance.
DistanceFromMatcherImpl(T target, GetDistance get_distance,
Matcher<const Distance&> distance_matcher)
: target_(std::move(target)),
get_distance_(std::move(get_distance)),
distance_matcher_(std::move(distance_matcher)) {}
// Describes what this matcher does.
void DescribeTo(::std::ostream* os) const override {
distance_matcher_.DescribeTo(os);
*os << " away from " << PrintToString(target_);
}
void DescribeNegationTo(::std::ostream* os) const override {
distance_matcher_.DescribeNegationTo(os);
*os << " away from " << PrintToString(target_);
}
bool MatchAndExplain(V value, MatchResultListener* listener) const override {
const auto distance = get_distance_(value, target_);
const bool match = distance_matcher_.Matches(distance);
if (!match && listener->IsInterested()) {
*listener << "which is " << PrintToString(distance) << " away from "
<< PrintToString(target_);
}
return match;
}
private:
const T target_;
const GetDistance get_distance_;
const Matcher<const Distance&> distance_matcher_;
};
// Implements Each(element_matcher) for the given argument type Container.
// Symmetric to ContainsMatcherImpl.
template <typename Container>
@ -2990,6 +3038,50 @@ auto Second(T& x, Rank1) -> decltype((x.second)) { // NOLINT
}
} // namespace pair_getters
// Default functor for computing the distance between two values.
struct DefaultGetDistance {
template <typename T, typename U>
auto operator()(const T& lhs, const U& rhs) const {
return std::abs(lhs - rhs);
}
};
// Implements polymorphic DistanceFrom(target, get_distance, distance_matcher)
// matcher. Template arguments:
// * T is the type of the target value.
// * GetDistance is the type of the functor for computing the distance between
// the value being matched and the target.
// * DistanceMatcher is the type of the matcher for checking the distance.
template <typename T, typename GetDistance, typename DistanceMatcher>
class DistanceFromMatcher {
public:
// Arguments:
// * target: the target value.
// * get_distance: the functor for computing the distance between the value
// being matched and target.
// * distance_matcher: the matcher for checking the distance.
DistanceFromMatcher(T target, GetDistance get_distance,
DistanceMatcher distance_matcher)
: target_(std::move(target)),
get_distance_(std::move(get_distance)),
distance_matcher_(std::move(distance_matcher)) {}
DistanceFromMatcher(const DistanceFromMatcher& other) = default;
// Implicitly converts to a monomorphic matcher of the given type.
template <typename V>
operator Matcher<V>() const { // NOLINT
using Distance = decltype(get_distance_(std::declval<V>(), target_));
return Matcher<V>(new DistanceFromMatcherImpl<V, T, Distance, GetDistance>(
target_, get_distance_, distance_matcher_));
}
private:
const T target_;
const GetDistance get_distance_;
const DistanceMatcher distance_matcher_;
};
// Implements Key(inner_matcher) for the given argument pair type.
// Key(inner_matcher) matches an std::pair whose 'first' field matches
// inner_matcher. For example, Contains(Key(Ge(5))) can be used to match an
@ -4372,6 +4464,42 @@ inline internal::FloatingEqMatcher<double> DoubleNear(double rhs,
return internal::FloatingEqMatcher<double>(rhs, false, max_abs_error);
}
// The DistanceFrom(target, get_distance, m) and DistanceFrom(target, m)
// matchers work on arbitrary types that have the "distance" concept. What they
// do:
//
// 1. compute the distance between the value and the target using
// get_distance(value, target) if get_distance is provided; otherwise compute
// the distance as std::abs(value - target).
// 2. match the distance against the user-provided matcher m; if the match
// succeeds, the DistanceFrom() match succeeds.
//
// Examples:
//
// // 0.5's distance from 0.6 should be <= 0.2.
// EXPECT_THAT(0.5, DistanceFrom(0.6, Le(0.2)));
//
// Vector2D v1(3.0, 4.0), v2(3.2, 6.0);
// // v1's distance from v2, as computed by EuclideanDistance(v1, v2),
// // should be >= 1.0.
// EXPECT_THAT(v1, DistanceFrom(v2, EuclideanDistance, Ge(1.0)));
template <typename T, typename GetDistance, typename DistanceMatcher>
inline internal::DistanceFromMatcher<T, GetDistance, DistanceMatcher>
DistanceFrom(T target, GetDistance get_distance,
DistanceMatcher distance_matcher) {
return internal::DistanceFromMatcher<T, GetDistance, DistanceMatcher>(
std::move(target), std::move(get_distance), std::move(distance_matcher));
}
template <typename T, typename DistanceMatcher>
inline internal::DistanceFromMatcher<T, internal::DefaultGetDistance,
DistanceMatcher>
DistanceFrom(T target, DistanceMatcher distance_matcher) {
return DistanceFrom(std::move(target), internal::DefaultGetDistance(),
std::move(distance_matcher));
}
// Creates a matcher that matches any double argument approximately equal to
// rhs, up to the specified max absolute error bound, including NaN values when
// rhs is NaN. The max absolute error bound must be non-negative.

View File

@ -34,6 +34,7 @@
#include <cmath>
#include <limits>
#include <memory>
#include <ostream>
#include <string>
#include "gmock/gmock.h"
@ -398,6 +399,151 @@ TEST(NanSensitiveDoubleNearTest, CanDescribeSelfWithNaNs) {
EXPECT_EQ("are an almost-equal pair", Describe(m));
}
// Tests that DistanceFrom() can describe itself properly.
TEST(DistanceFrom, CanDescribeSelf) {
Matcher<double> m = DistanceFrom(1.5, Lt(0.1));
EXPECT_EQ(Describe(m), "is < 0.1 away from 1.5");
m = DistanceFrom(2.5, Gt(0.2));
EXPECT_EQ(Describe(m), "is > 0.2 away from 2.5");
}
// Tests that DistanceFrom() can explain match failure.
TEST(DistanceFrom, CanExplainMatchFailure) {
Matcher<double> m = DistanceFrom(1.5, Lt(0.1));
EXPECT_EQ(Explain(m, 2.0), "which is 0.5 away from 1.5");
}
// Tests that DistanceFrom() matches a double that is within the given range of
// the given value.
TEST(DistanceFrom, MatchesDoubleWithinRange) {
const Matcher<double> m = DistanceFrom(0.5, Le(0.1));
EXPECT_TRUE(m.Matches(0.45));
EXPECT_TRUE(m.Matches(0.5));
EXPECT_TRUE(m.Matches(0.55));
EXPECT_FALSE(m.Matches(0.39));
EXPECT_FALSE(m.Matches(0.61));
}
// Tests that DistanceFrom() matches a double reference that is within the given
// range of the given value.
TEST(DistanceFrom, MatchesDoubleRefWithinRange) {
const Matcher<const double&> m = DistanceFrom(0.5, Le(0.1));
EXPECT_TRUE(m.Matches(0.45));
EXPECT_TRUE(m.Matches(0.5));
EXPECT_TRUE(m.Matches(0.55));
EXPECT_FALSE(m.Matches(0.39));
EXPECT_FALSE(m.Matches(0.61));
}
// Tests that DistanceFrom() can be implicitly converted to a matcher depending
// on the type of the argument.
TEST(DistanceFrom, CanBeImplicitlyConvertedToMatcher) {
EXPECT_THAT(0.58, DistanceFrom(0.5, Le(0.1)));
EXPECT_THAT(0.2, Not(DistanceFrom(0.5, Le(0.1))));
EXPECT_THAT(0.58f, DistanceFrom(0.5f, Le(0.1f)));
EXPECT_THAT(0.7f, Not(DistanceFrom(0.5f, Le(0.1f))));
}
// Tests that DistanceFrom() can be used on compatible types (i.e. not
// everything has to be of the same type).
TEST(DistanceFrom, CanBeUsedOnCompatibleTypes) {
EXPECT_THAT(0.58, DistanceFrom(0.5, Le(0.1f)));
EXPECT_THAT(0.2, Not(DistanceFrom(0.5, Le(0.1f))));
EXPECT_THAT(0.58, DistanceFrom(0.5f, Le(0.1)));
EXPECT_THAT(0.2, Not(DistanceFrom(0.5f, Le(0.1))));
EXPECT_THAT(0.58, DistanceFrom(0.5f, Le(0.1f)));
EXPECT_THAT(0.2, Not(DistanceFrom(0.5f, Le(0.1f))));
EXPECT_THAT(0.58f, DistanceFrom(0.5, Le(0.1)));
EXPECT_THAT(0.2f, Not(DistanceFrom(0.5, Le(0.1))));
EXPECT_THAT(0.58f, DistanceFrom(0.5, Le(0.1f)));
EXPECT_THAT(0.2f, Not(DistanceFrom(0.5, Le(0.1f))));
EXPECT_THAT(0.58f, DistanceFrom(0.5f, Le(0.1)));
EXPECT_THAT(0.2f, Not(DistanceFrom(0.5f, Le(0.1))));
}
// A 2-dimensional point. For testing using DistanceFrom() with a custom type
// that doesn't have a built-in distance function.
class Point {
public:
Point(double x, double y) : x_(x), y_(y) {}
double x() const { return x_; }
double y() const { return y_; }
private:
double x_;
double y_;
};
// Returns the distance between two points.
double PointDistance(const Point& lhs, const Point& rhs) {
return std::sqrt(std::pow(lhs.x() - rhs.x(), 2) +
std::pow(lhs.y() - rhs.y(), 2));
}
// Tests that DistanceFrom() can be used on a type with a custom distance
// function.
TEST(DistanceFrom, CanBeUsedOnTypeWithCustomDistanceFunction) {
const Matcher<Point> m =
DistanceFrom(Point(0.5, 0.5), PointDistance, Le(0.1));
EXPECT_THAT(Point(0.45, 0.45), m);
EXPECT_THAT(Point(0.2, 0.45), Not(m));
}
// A wrapper around a double value. For testing using DistanceFrom() with a
// custom type that has neither a built-in distance function nor a built-in
// distance comparator.
class Double {
public:
explicit Double(double value) : value_(value) {}
Double(const Double& other) = default;
double value() const { return value_; }
// Defines how to print a Double value. We don't use the AbslStringify API
// because googletest doesn't require absl yet.
friend void PrintTo(const Double& value, std::ostream* os) {
*os << "Double(" << value.value() << ")";
}
private:
double value_;
};
// Returns the distance between two Double values.
Double DoubleDistance(Double lhs, Double rhs) {
return Double(std::abs(lhs.value() - rhs.value()));
}
MATCHER_P(DoubleLe, rhs, (negation ? "is > " : "is <= ") + PrintToString(rhs)) {
return arg.value() <= rhs.value();
}
// Tests that DistanceFrom() can describe itself properly for a type with a
// custom printer.
TEST(DistanceFrom, CanDescribeWithCustomPrinter) {
const Matcher<Double> m =
DistanceFrom(Double(0.5), DoubleDistance, DoubleLe(Double(0.1)));
EXPECT_EQ(Describe(m), "is <= Double(0.1) away from Double(0.5)");
EXPECT_EQ(DescribeNegation(m), "is > Double(0.1) away from Double(0.5)");
}
// Tests that DistanceFrom() can be used with a custom distance function and
// comparator.
TEST(DistanceFrom, CanCustomizeDistanceAndComparator) {
const Matcher<Double> m =
DistanceFrom(Double(0.5), DoubleDistance, DoubleLe(Double(0.1)));
EXPECT_TRUE(m.Matches(Double(0.45)));
EXPECT_TRUE(m.Matches(Double(0.5)));
EXPECT_FALSE(m.Matches(Double(0.39)));
EXPECT_FALSE(m.Matches(Double(0.61)));
}
// Tests that Not(m) matches any value that doesn't match m.
TEST(NotTest, NegatesMatcher) {
Matcher<int> m;