From 0bdccf4aa2f5c67af967193caf31d42d5c49bde2 Mon Sep 17 00:00:00 2001 From: Zhanyong Wan Date: Fri, 7 Mar 2025 09:53:19 -0800 Subject: [PATCH] Add a `DistanceFrom()` matcher for general distance comparison. We have a bunch of matchers for asserting that a value is near the target value, e.g. `DoubleNear()` and `FloatNear()`. These matchers only work for specific types (`double` and `float`). They are not flexible enough to support other types that have the notion of a "distance" (e.g. N-dimensional points and vectors, which are commonly used in ML). In this diff, we generalize the idea to a `DistanceFrom(target, get_distance, m)` matcher that works on arbitrary types that have the "distance" concept (the `get_distance` argument is optional and can be omitted for types that support `-`, and `std::abs()`). What it does: 1. compute the distance between the value and the target using `get_distance(value, target)`; if `get_distance` is omitted, compute the distance as `std::abs(value - target)`. 2. match the distance against matcher `m`; if the match succeeds, the `DistanceFrom()` match succeeds. Examples: ``` // 0.5's distance from 0.6 should be <= 0.2. EXPECT_THAT(0.5, DistanceFrom(0.6, Le(0.2))); Vector2D v1(3.0, 4.0), v2(3.2, 6.0); // v1's distance from v2, as computed by EuclideanDistance(v1, v2), // should be >= 1.0. EXPECT_THAT(v1, DistanceFrom(v2, EuclideanDistance, Ge(1.0))); ``` PiperOrigin-RevId: 734593292 Change-Id: Id6bb7074dc4aa4d8abd78b57ad2426637e590de5 --- docs/reference/matchers.md | 2 + googlemock/include/gmock/gmock-matchers.h | 128 +++++++++++++++ .../test/gmock-matchers-arithmetic_test.cc | 146 ++++++++++++++++++ 3 files changed, 276 insertions(+) diff --git a/docs/reference/matchers.md b/docs/reference/matchers.md index 16397ef2..1fc45f2c 100644 --- a/docs/reference/matchers.md +++ b/docs/reference/matchers.md @@ -42,6 +42,8 @@ Matcher | Description | `Lt(value)` | `argument < value` | | `Ne(value)` | `argument != value` | | `IsFalse()` | `argument` evaluates to `false` in a Boolean context. | +| `DistanceFrom(target, m)` | The distance between `argument` and `target` (computed by `std::abs(argument - target)`) matches `m`. | +| `DistanceFrom(target, get_distance, m)` | The distance between `argument` and `target` (computed by `get_distance(argument, target)`) matches `m`. | | `IsTrue()` | `argument` evaluates to `true` in a Boolean context. | | `IsNull()` | `argument` is a `NULL` pointer (raw or smart). | | `NotNull()` | `argument` is a non-null pointer (raw or smart). | diff --git a/googlemock/include/gmock/gmock-matchers.h b/googlemock/include/gmock/gmock-matchers.h index 47311d46..49a6ff2e 100644 --- a/googlemock/include/gmock/gmock-matchers.h +++ b/googlemock/include/gmock/gmock-matchers.h @@ -2855,6 +2855,54 @@ class ContainsMatcherImpl : public QuantifierMatcherImpl { } }; +// Implements DistanceFrom(target, get_distance, distance_matcher) for the given +// argument types: +// * V is the type of the value to be matched. +// * T is the type of the target value. +// * Distance is the type of the distance between V and T. +// * GetDistance is the type of the functor for computing the distance between +// V and T. +template +class DistanceFromMatcherImpl : public MatcherInterface { + public: + // Arguments: + // * target: the target value. + // * get_distance: the functor for computing the distance between the value + // being matched and target. + // * distance_matcher: the matcher for checking the distance. + DistanceFromMatcherImpl(T target, GetDistance get_distance, + Matcher distance_matcher) + : target_(std::move(target)), + get_distance_(std::move(get_distance)), + distance_matcher_(std::move(distance_matcher)) {} + + // Describes what this matcher does. + void DescribeTo(::std::ostream* os) const override { + distance_matcher_.DescribeTo(os); + *os << " away from " << PrintToString(target_); + } + + void DescribeNegationTo(::std::ostream* os) const override { + distance_matcher_.DescribeNegationTo(os); + *os << " away from " << PrintToString(target_); + } + + bool MatchAndExplain(V value, MatchResultListener* listener) const override { + const auto distance = get_distance_(value, target_); + const bool match = distance_matcher_.Matches(distance); + if (!match && listener->IsInterested()) { + *listener << "which is " << PrintToString(distance) << " away from " + << PrintToString(target_); + } + return match; + } + + private: + const T target_; + const GetDistance get_distance_; + const Matcher distance_matcher_; +}; + // Implements Each(element_matcher) for the given argument type Container. // Symmetric to ContainsMatcherImpl. template @@ -2990,6 +3038,50 @@ auto Second(T& x, Rank1) -> decltype((x.second)) { // NOLINT } } // namespace pair_getters +// Default functor for computing the distance between two values. +struct DefaultGetDistance { + template + auto operator()(const T& lhs, const U& rhs) const { + return std::abs(lhs - rhs); + } +}; + +// Implements polymorphic DistanceFrom(target, get_distance, distance_matcher) +// matcher. Template arguments: +// * T is the type of the target value. +// * GetDistance is the type of the functor for computing the distance between +// the value being matched and the target. +// * DistanceMatcher is the type of the matcher for checking the distance. +template +class DistanceFromMatcher { + public: + // Arguments: + // * target: the target value. + // * get_distance: the functor for computing the distance between the value + // being matched and target. + // * distance_matcher: the matcher for checking the distance. + DistanceFromMatcher(T target, GetDistance get_distance, + DistanceMatcher distance_matcher) + : target_(std::move(target)), + get_distance_(std::move(get_distance)), + distance_matcher_(std::move(distance_matcher)) {} + + DistanceFromMatcher(const DistanceFromMatcher& other) = default; + + // Implicitly converts to a monomorphic matcher of the given type. + template + operator Matcher() const { // NOLINT + using Distance = decltype(get_distance_(std::declval(), target_)); + return Matcher(new DistanceFromMatcherImpl( + target_, get_distance_, distance_matcher_)); + } + + private: + const T target_; + const GetDistance get_distance_; + const DistanceMatcher distance_matcher_; +}; + // Implements Key(inner_matcher) for the given argument pair type. // Key(inner_matcher) matches an std::pair whose 'first' field matches // inner_matcher. For example, Contains(Key(Ge(5))) can be used to match an @@ -4372,6 +4464,42 @@ inline internal::FloatingEqMatcher DoubleNear(double rhs, return internal::FloatingEqMatcher(rhs, false, max_abs_error); } +// The DistanceFrom(target, get_distance, m) and DistanceFrom(target, m) +// matchers work on arbitrary types that have the "distance" concept. What they +// do: +// +// 1. compute the distance between the value and the target using +// get_distance(value, target) if get_distance is provided; otherwise compute +// the distance as std::abs(value - target). +// 2. match the distance against the user-provided matcher m; if the match +// succeeds, the DistanceFrom() match succeeds. +// +// Examples: +// +// // 0.5's distance from 0.6 should be <= 0.2. +// EXPECT_THAT(0.5, DistanceFrom(0.6, Le(0.2))); +// +// Vector2D v1(3.0, 4.0), v2(3.2, 6.0); +// // v1's distance from v2, as computed by EuclideanDistance(v1, v2), +// // should be >= 1.0. +// EXPECT_THAT(v1, DistanceFrom(v2, EuclideanDistance, Ge(1.0))); + +template +inline internal::DistanceFromMatcher +DistanceFrom(T target, GetDistance get_distance, + DistanceMatcher distance_matcher) { + return internal::DistanceFromMatcher( + std::move(target), std::move(get_distance), std::move(distance_matcher)); +} + +template +inline internal::DistanceFromMatcher +DistanceFrom(T target, DistanceMatcher distance_matcher) { + return DistanceFrom(std::move(target), internal::DefaultGetDistance(), + std::move(distance_matcher)); +} + // Creates a matcher that matches any double argument approximately equal to // rhs, up to the specified max absolute error bound, including NaN values when // rhs is NaN. The max absolute error bound must be non-negative. diff --git a/googlemock/test/gmock-matchers-arithmetic_test.cc b/googlemock/test/gmock-matchers-arithmetic_test.cc index 06b0b477..6a3fc89c 100644 --- a/googlemock/test/gmock-matchers-arithmetic_test.cc +++ b/googlemock/test/gmock-matchers-arithmetic_test.cc @@ -34,6 +34,7 @@ #include #include #include +#include #include #include "gmock/gmock.h" @@ -398,6 +399,151 @@ TEST(NanSensitiveDoubleNearTest, CanDescribeSelfWithNaNs) { EXPECT_EQ("are an almost-equal pair", Describe(m)); } +// Tests that DistanceFrom() can describe itself properly. +TEST(DistanceFrom, CanDescribeSelf) { + Matcher m = DistanceFrom(1.5, Lt(0.1)); + EXPECT_EQ(Describe(m), "is < 0.1 away from 1.5"); + + m = DistanceFrom(2.5, Gt(0.2)); + EXPECT_EQ(Describe(m), "is > 0.2 away from 2.5"); +} + +// Tests that DistanceFrom() can explain match failure. +TEST(DistanceFrom, CanExplainMatchFailure) { + Matcher m = DistanceFrom(1.5, Lt(0.1)); + EXPECT_EQ(Explain(m, 2.0), "which is 0.5 away from 1.5"); +} + +// Tests that DistanceFrom() matches a double that is within the given range of +// the given value. +TEST(DistanceFrom, MatchesDoubleWithinRange) { + const Matcher m = DistanceFrom(0.5, Le(0.1)); + EXPECT_TRUE(m.Matches(0.45)); + EXPECT_TRUE(m.Matches(0.5)); + EXPECT_TRUE(m.Matches(0.55)); + EXPECT_FALSE(m.Matches(0.39)); + EXPECT_FALSE(m.Matches(0.61)); +} + +// Tests that DistanceFrom() matches a double reference that is within the given +// range of the given value. +TEST(DistanceFrom, MatchesDoubleRefWithinRange) { + const Matcher m = DistanceFrom(0.5, Le(0.1)); + EXPECT_TRUE(m.Matches(0.45)); + EXPECT_TRUE(m.Matches(0.5)); + EXPECT_TRUE(m.Matches(0.55)); + EXPECT_FALSE(m.Matches(0.39)); + EXPECT_FALSE(m.Matches(0.61)); +} + +// Tests that DistanceFrom() can be implicitly converted to a matcher depending +// on the type of the argument. +TEST(DistanceFrom, CanBeImplicitlyConvertedToMatcher) { + EXPECT_THAT(0.58, DistanceFrom(0.5, Le(0.1))); + EXPECT_THAT(0.2, Not(DistanceFrom(0.5, Le(0.1)))); + + EXPECT_THAT(0.58f, DistanceFrom(0.5f, Le(0.1f))); + EXPECT_THAT(0.7f, Not(DistanceFrom(0.5f, Le(0.1f)))); +} + +// Tests that DistanceFrom() can be used on compatible types (i.e. not +// everything has to be of the same type). +TEST(DistanceFrom, CanBeUsedOnCompatibleTypes) { + EXPECT_THAT(0.58, DistanceFrom(0.5, Le(0.1f))); + EXPECT_THAT(0.2, Not(DistanceFrom(0.5, Le(0.1f)))); + + EXPECT_THAT(0.58, DistanceFrom(0.5f, Le(0.1))); + EXPECT_THAT(0.2, Not(DistanceFrom(0.5f, Le(0.1)))); + + EXPECT_THAT(0.58, DistanceFrom(0.5f, Le(0.1f))); + EXPECT_THAT(0.2, Not(DistanceFrom(0.5f, Le(0.1f)))); + + EXPECT_THAT(0.58f, DistanceFrom(0.5, Le(0.1))); + EXPECT_THAT(0.2f, Not(DistanceFrom(0.5, Le(0.1)))); + + EXPECT_THAT(0.58f, DistanceFrom(0.5, Le(0.1f))); + EXPECT_THAT(0.2f, Not(DistanceFrom(0.5, Le(0.1f)))); + + EXPECT_THAT(0.58f, DistanceFrom(0.5f, Le(0.1))); + EXPECT_THAT(0.2f, Not(DistanceFrom(0.5f, Le(0.1)))); +} + +// A 2-dimensional point. For testing using DistanceFrom() with a custom type +// that doesn't have a built-in distance function. +class Point { + public: + Point(double x, double y) : x_(x), y_(y) {} + double x() const { return x_; } + double y() const { return y_; } + + private: + double x_; + double y_; +}; + +// Returns the distance between two points. +double PointDistance(const Point& lhs, const Point& rhs) { + return std::sqrt(std::pow(lhs.x() - rhs.x(), 2) + + std::pow(lhs.y() - rhs.y(), 2)); +} + +// Tests that DistanceFrom() can be used on a type with a custom distance +// function. +TEST(DistanceFrom, CanBeUsedOnTypeWithCustomDistanceFunction) { + const Matcher m = + DistanceFrom(Point(0.5, 0.5), PointDistance, Le(0.1)); + EXPECT_THAT(Point(0.45, 0.45), m); + EXPECT_THAT(Point(0.2, 0.45), Not(m)); +} + +// A wrapper around a double value. For testing using DistanceFrom() with a +// custom type that has neither a built-in distance function nor a built-in +// distance comparator. +class Double { + public: + explicit Double(double value) : value_(value) {} + Double(const Double& other) = default; + double value() const { return value_; } + + // Defines how to print a Double value. We don't use the AbslStringify API + // because googletest doesn't require absl yet. + friend void PrintTo(const Double& value, std::ostream* os) { + *os << "Double(" << value.value() << ")"; + } + + private: + double value_; +}; + +// Returns the distance between two Double values. +Double DoubleDistance(Double lhs, Double rhs) { + return Double(std::abs(lhs.value() - rhs.value())); +} + +MATCHER_P(DoubleLe, rhs, (negation ? "is > " : "is <= ") + PrintToString(rhs)) { + return arg.value() <= rhs.value(); +} + +// Tests that DistanceFrom() can describe itself properly for a type with a +// custom printer. +TEST(DistanceFrom, CanDescribeWithCustomPrinter) { + const Matcher m = + DistanceFrom(Double(0.5), DoubleDistance, DoubleLe(Double(0.1))); + EXPECT_EQ(Describe(m), "is <= Double(0.1) away from Double(0.5)"); + EXPECT_EQ(DescribeNegation(m), "is > Double(0.1) away from Double(0.5)"); +} + +// Tests that DistanceFrom() can be used with a custom distance function and +// comparator. +TEST(DistanceFrom, CanCustomizeDistanceAndComparator) { + const Matcher m = + DistanceFrom(Double(0.5), DoubleDistance, DoubleLe(Double(0.1))); + EXPECT_TRUE(m.Matches(Double(0.45))); + EXPECT_TRUE(m.Matches(Double(0.5))); + EXPECT_FALSE(m.Matches(Double(0.39))); + EXPECT_FALSE(m.Matches(Double(0.61))); +} + // Tests that Not(m) matches any value that doesn't match m. TEST(NotTest, NegatesMatcher) { Matcher m;