diff --git a/src/sled/futures/future.h b/src/sled/futures/future.h index b70fc71..4ca0f2e 100644 --- a/src/sled/futures/future.h +++ b/src/sled/futures/future.h @@ -10,8 +10,9 @@ #include "sled/synchronization/event.h" #include "sled/synchronization/mutex.h" #include "sled/task_queue/task_queue_base.h" -#include "sled/variant.h" +#include "sled/utility/forward_on_copy.h" #include +#include #include namespace sled { @@ -48,7 +49,8 @@ struct FutureData { ~FutureData() { DecrementFuturesUsage(); } std::atomic_int state{kNotCompletedFuture}; - sled::variant value; + // sled::variant value; + sled::any value; std::list> success_callbacks; std::list> failure_callbacks; sled::Mutex mutex_; @@ -130,7 +132,8 @@ public: if (!IsCompleted()) Wait(); if (IsSucceeded()) { try { - return sled::get(data_->value); + // return sled::get(data_->value); + return sled::any_cast(data_->value); } catch (...) {} } return T(); @@ -140,7 +143,8 @@ public: { SLED_ASSERT(data_ != nullptr, "Future is not valid"); if (!IsCompleted()) { Wait(); } - return sled::get(data_->value); + // return sled::get(data_->value); + return sled::any_cast(data_->value); } FailureT FailureReason() const @@ -149,7 +153,8 @@ public: if (!IsCompleted()) { Wait(); } if (IsFailed()) { try { - return sled::get(data_->value); + // return sled::get(data_->value); + return sled::any_cast(data_->value); } catch (...) {} } return FailureT(); @@ -176,7 +181,8 @@ public: } if (call_it) { try { - f(sled::get(data_->value)); + // f(sled::get(data_->value)); + f(sled::any_cast(data_->value)); } catch (...) {} } return Future(data_); @@ -203,7 +209,8 @@ public: } if (call_it) { try { - f(sled::get(data_->value)); + // f(sled::get(data_->value)); + f(sled::any_cast(data_->value)); } catch (...) {} } return Future(data_); @@ -294,6 +301,20 @@ public: return result; } + template().Result()>)> + Future AndThen(Func &&f) const noexcept + { + return FlatMap([f](const T &) { return f(); }); + } + + template::type> + Future AndThenValue(T2 &&value) const noexcept + { + Future result = Future::Create(); + auto forward_on_copy = sled::MakeForwardOnCopy(std::forward(value)); + return Map([forward_on_copy](const T &) noexcept { return forward_on_copy.value(); }); + } + Future Via(TaskQueueBase *task_queue) const noexcept { SLED_ASSERT(task_queue != nullptr, "TaskQueue is not valid"); @@ -365,6 +386,7 @@ private: if (IsCompleted()) { return; } try { + // data_->value.template emplace(std::move(value)); // data_->value.template emplace(std::move(value)); data_->value = std::move(value); } catch (...) {} @@ -376,7 +398,8 @@ private: for (const auto &f : callbacks) { try { - f(sled::get(data_->value)); + // f(sled::get(data_->value)); + f(sled::any_cast(data_->value)); } catch (...) {} } } @@ -396,6 +419,7 @@ private: if (IsCompleted()) { return; } try { // data_->value.template emplace(std::move(reason)); + // data_->value = std::move(reason); data_->value = std::move(reason); } catch (...) {} data_->state.store(detail::kFailedFuture, std::memory_order_release); @@ -406,7 +430,8 @@ private: for (const auto &f : callbacks) { try { - f(sled::get(data_->value)); + // f(sled::get(data_->value)); + f(sled::any_cast(data_->value)); } catch (...) {} } } diff --git a/src/sled/futures/future_test.cc b/src/sled/futures/future_test.cc index fb9c47a..82b1290 100644 --- a/src/sled/futures/future_test.cc +++ b/src/sled/futures/future_test.cc @@ -1,9 +1,11 @@ #include #include +#include TEST_SUITE("future") { TEST_CASE("base success") + { sled::Promise p; auto f = p.GetFuture(); @@ -12,6 +14,7 @@ TEST_SUITE("future") CHECK(f.IsValid()); CHECK_EQ(f.Result(), 42); } + TEST_CASE("base failed") { sled::Promise p; @@ -23,31 +26,65 @@ TEST_SUITE("future") CHECK(f.IsValid()); CHECK_EQ(f.FailureReason(), "error"); } + + TEST_CASE("throw") + { + sled::Promise p; + auto f = p.GetFuture().Map([](int x) { + throw std::runtime_error("test"); + return x; + }); + + p.Success(42); + REQUIRE(f.IsCompleted()); + REQUIRE(f.IsFailed()); + CHECK_EQ(f.FailureReason(), "test"); + } + + TEST_CASE("base failed") + { + sled::Promise p; + auto f = p.GetFuture(); + p.Failure("error"); + REQUIRE(p.IsFilled()); + REQUIRE(f.IsCompleted()); + CHECK(f.Wait(-1)); + CHECK(f.IsValid()); + CHECK_EQ(f.FailureReason(), "error"); + } + TEST_CASE("thread success") {} - TEST_CASE("map") + TEST_CASE("Map") { sled::Promise p; auto f = p.GetFuture(); - auto f2 = f.Map([](int i) { return i + 1; }); + auto f2 = f.Map([](int i) { return i + 1; }) + .Map([](int i) { return std::to_string(i) + "00"; }) + .Map([](const std::string &str) { + std::stringstream ss(str); + int x; + ss >> x; + return x; + }); p.Success(42); CHECK(f2.Wait(-1)); - CHECK_EQ(f2.Result(), 43); + CHECK_EQ(f2.Result(), 4300); } TEST_CASE("FlatMap") { - // sled::Promise p; - // auto f = p.GetFuture().FlatMap([](int i) { - // auto str = std::to_string(i); - // sled::Promise p; - // p.Success(str); - // - // return p.GetFuture(); - // }); - // p.Success(42); - // CHECK(f.Wait(-1)); - // CHECK_EQ(f.Result(), "42"); + sled::Promise p; + auto f = p.GetFuture().FlatMap([](int i) { + auto str = std::to_string(i); + sled::Promise p; + p.Success(str); + + return p.GetFuture(); + }); + p.Success(42); + CHECK(f.IsCompleted()); + CHECK_EQ(f.Result(), "42"); } TEST_CASE("Via") diff --git a/src/sled/utility/forward_on_copy.h b/src/sled/utility/forward_on_copy.h new file mode 100644 index 0000000..53bff08 --- /dev/null +++ b/src/sled/utility/forward_on_copy.h @@ -0,0 +1,76 @@ +#ifndef SLED_UTILITY_FORWARD_ON_COPY_H +#define SLED_UTILITY_FORWARD_ON_COPY_H + +#pragma once +#include "sled/any.h" + +namespace sled { +namespace detail { + +struct ForwardOnCopyTag {}; + +};// namespace detail + +template +struct ForwardOnCopy { + ForwardOnCopy(ForwardOnCopy &other) + : is_moved_(other.is_moved_), + value_(is_moved_ ? other.value_ : std::move(other.value_)) + {} + + // ForwardOnCopy(ForwardOnCopy &&other) noexcept; + ForwardOnCopy &operator=(ForwardOnCopy &other) + { + is_moved_ = other.is_moved_; + value_ = is_moved_ ? other.value_ : std::move(other.value_); + return *this; + } + + // ForwardOnCopy &operator=(ForwardOnCopy &&other) noexcept; + T value() const + { + if (is_moved_) { + return std::move(sled::any_cast(value_)); + } else { + return sled::any_cast(value_); + } + } + +private: + bool is_moved_; + sled::any value_; +}; + +template +ForwardOnCopy +MakeForwardOnCopy(T value) +{ + ForwardOnCopy result; + result.is_moved_ = false; + result.value_ = value; + return result; +} + +template +ForwardOnCopy +MakeForwardOnCopy(T &value) +{ + ForwardOnCopy result; + result.is_moved_ = false; + result.value_ = value; + return result; +} + +template +ForwardOnCopy +MakeForwardOnCopy(T &&value) +{ + ForwardOnCopy result; + result.is_moved_ = true; + result.value_ = std::move(value); + return result; +} + +}// namespace sled + +#endif// SLED_UTILITY_FORWARD_ON_COPY_H