diff --git a/CMakeLists.txt b/CMakeLists.txt index 5d8b01a..f926135 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -197,7 +197,6 @@ if(SLED_BUILD_TESTS) sled_add_test(NAME sled_thread_pool_test SRCS src/sled/system/thread_pool_test.cc) - sled_add_test(NAME sled_event_bus_test SRCS src/sled/event_bus/event_bus_test.cc) sled_add_test(NAME sled_lua_test SRCS tests/lua_test.cc) @@ -215,7 +214,8 @@ if(SLED_BUILD_TESTS) sled_add_test(NAME sled_inja_test SRCS src/sled/nonstd/inja_test.cc) sled_add_test(NAME sled_fsm_test SRCS src/sled/nonstd/fsm_test.cc) sled_add_test(NAME sled_timestamp_test SRCS src/sled/units/timestamp_test.cc) - sled_add_test(NAME sled_future_test SRCS src/sled/futures/future_test.cc) + sled_add_test(NAME sled_future_test SRCS src/sled/futures/future_test.cc + src/sled/futures/when_all_test.cc) sled_add_test( NAME sled_cache_test SRCS src/sled/cache/lru_cache_test.cc src/sled/cache/fifo_cache_test.cc src/sled/cache/expire_cache_test.cc) diff --git a/src/sled/futures/future.h b/src/sled/futures/future.h index cece94b..138a8a7 100644 --- a/src/sled/futures/future.h +++ b/src/sled/futures/future.h @@ -33,6 +33,7 @@ enum FutureState { kNotCompletedFuture = 0, kSuccessFuture = 1, kFailedFuture = 2, + kCancelled = 3, }; SLED_EXPORT void IncrementFuturesUsage(); @@ -129,6 +130,12 @@ public: return data_->state.load(std::memory_order_acquire) == future_detail::kSuccessFuture; } + bool IsCancelled() const noexcept + { + SLED_ASSERT(data_ != nullptr, "Future is not valid"); + return data_->state.load(std::memory_order_acquire) == future_detail::kCancelled; + } + bool IsValid() const noexcept { return static_cast(data_); } bool Wait(int64_t timeout_ms) const noexcept { return Wait(sled::TimeDelta::Millis(timeout_ms)); } diff --git a/src/sled/futures/internal/promise.h b/src/sled/futures/internal/promise.h index 7252f85..e02d8e1 100644 --- a/src/sled/futures/internal/promise.h +++ b/src/sled/futures/internal/promise.h @@ -18,7 +18,8 @@ class Promise final { "Promise<_, void> is not allowed. Use Promise<_, bool> instead"); public: - using Value = T; + using ValueType = T; + using FailureType = FailureT; Promise() = default; Promise(const Promise &) = default; Promise(Promise &&) noexcept = default; diff --git a/src/sled/futures/when_all.h b/src/sled/futures/when_all.h new file mode 100644 index 0000000..4e5953a --- /dev/null +++ b/src/sled/futures/when_all.h @@ -0,0 +1,124 @@ +#ifndef SLED_FUTURES_WHEN_ALL_H +#define SLED_FUTURES_WHEN_ALL_H + +#pragma once +#include "future.h" +#include "sled/meta/type_traits.h" +#include "sled/synchronization/mutex.h" + +namespace sled { + +namespace futures { + +struct WhenAllState { + sled::Mutex mutex; + sled::ConditionVariable cv; + std::atomic count{0}; + std::atomic has_failed{false}; +}; + +template +void +WhenAllImpl(std::weak_ptr &weak_state, Future &future) +{ + { + auto state = weak_state.lock(); + if (!state) { return; } + state->count++; + } + future.OnSuccess([weak_state](const T &) { + auto state = weak_state.lock(); + if (!state) { return; } + if (state->count.fetch_sub(1) == 0) { + sled::MutexLock lock(&state->mutex); + state->cv.NotifyAll(); + } + }); + future.OnFailure([weak_state](const FailureT &) { + auto state = weak_state.lock(); + if (!state) { return; } + state->has_failed = true; + sled::MutexLock lock(&state->mutex); + state->cv.NotifyAll(); + }); +} + +template +void +WhenAllImpl(std::weak_ptr &weak_state, Future &future, Args &&...futures) +{ + { + auto state = weak_state.lock(); + if (!state) { return; } + state->count++; + } + + future.OnSuccess([weak_state](const T &) { + auto state = weak_state.lock(); + if (!state) { return; } + if (state->count.fetch_sub(1) == 1) { + sled::MutexLock lock(&state->mutex); + state->cv.NotifyAll(); + } + }); + + future.OnFailure([weak_state](const FailureT &) { + auto state = weak_state.lock(); + if (!state) { return; } + state->has_failed = true; + sled::MutexLock lock(&state->mutex); + state->cv.NotifyAll(); + }); + + WhenAllImpl(weak_state, std::forward(futures)...); +} + +}// namespace futures + +template +bool +WhenAll(Future &future, Args &&...futures) +{ + auto state = std::make_shared(); + std::weak_ptr weak_state = state; + futures::WhenAllImpl(weak_state, future, std::forward(futures)...); + + { + sled::MutexLock lock(&state->mutex); + state->cv.Wait(lock, [state] { return state->count.load() == 0 || state->has_failed.load(); }); + } + + return !state->has_failed.load(); +} + +template +bool +WhenAll(ContainerT container) +{ + auto state = std::make_shared(); + std::weak_ptr weak_state = state; + for (auto &f : container) { futures::WhenAllImpl(weak_state, f); } + { + sled::MutexLock lock(&state->mutex); + state->cv.Wait(lock, [state] { return state->count.load() == 0 || state->has_failed.load(); }); + } + return !state->has_failed.load(); +} + +template +bool +WhenAll(IteratorT begin, IteratorT end) +{ + auto state = std::make_shared(); + std::weak_ptr weak_state = state; + for (auto it = begin; it != end; ++it) { futures::WhenAllImpl(weak_state, *it); } + { + sled::MutexLock lock(&state->mutex); + state->cv.Wait(lock, [state] { return state->count.load() == 0 || state->has_failed.load(); }); + } + return !state->has_failed.load(); +} + +}// namespace sled + +#endif// SLED_FUTURES_WHEN_ALL_H diff --git a/src/sled/futures/when_all_test.cc b/src/sled/futures/when_all_test.cc new file mode 100644 index 0000000..2402c37 --- /dev/null +++ b/src/sled/futures/when_all_test.cc @@ -0,0 +1,107 @@ +#include +#include +#include +#include + +TEST_SUITE("futures when_all") +{ + TEST_CASE("single thread") + { + sled::Promise p1; + sled::Promise p2; + sled::Promise p3; + + auto f1 = p1.GetFuture(); + auto f2 = p2.GetFuture(); + auto f3 = p3.GetFuture(); + + SUBCASE("all success") + { + p1.Success(1); + p2.Success(2); + p3.Success(3); + + CHECK(WhenAll(f1, f2, f3)); + CHECK_EQ(f1.Result(), 1); + CHECK_EQ(f2.Result(), 2); + CHECK_EQ(f3.Result(), 3); + } + SUBCASE("all failed") + { + p1.Failure(sled::failure::FailureFromString::FailureType>("1")); + p2.Failure(sled::failure::FailureFromString::FailureType>("2")); + p3.Failure(sled::failure::FailureFromString::FailureType>("3")); + + CHECK_FALSE(WhenAll(f1, f2, f3)); + CHECK(f1.IsFailed()); + CHECK(f2.IsFailed()); + CHECK(f3.IsFailed()); + } + } + + TEST_CASE("multi thread") + { + sled::ThreadPool pool(10); + std::vector> futures; + + SUBCASE("all success") + { + for (int i = 0; i < 1000; ++i) { + sled::Promise p; + auto f = p.GetFuture(); + pool.PostTask([p, i] { p.Success(i); }); + futures.push_back(f); + } + CHECK(WhenAll(futures)); + for (int i = 0; i < 1000; ++i) { CHECK_EQ(futures[i].Result(), i); } + } + + SUBCASE("all failed") + { + for (int i = 0; i < 1000; ++i) { + sled::Promise p; + auto f = p.GetFuture(); + pool.PostTask([p, i] { + p.Failure(sled::failure::FailureFromString::FailureType>(std::to_string(i))); + }); + futures.push_back(f); + } + CHECK_FALSE(WhenAll(futures)); + bool has_failed = false; + for (int i = 0; i < 1000; ++i) { + if (futures[i].IsFailed()) { + has_failed = true; + break; + } + } + CHECK(has_failed); + } + + SUBCASE("some failed") + { + for (int i = 0; i < 1000; ++i) { + sled::Promise p; + auto f = p.GetFuture(); + pool.PostTask([p, i] { + sled::Random random(sled::TimeUTCNanos()); + sled::Thread::SleepMs(random.Rand(1, 15)); + if (i % 2 == 0) { + p.Success(i); + } else { + p.Failure(sled::failure::FailureFromString::FailureType>(std::to_string(i))); + } + }); + futures.push_back(f); + } + CHECK_FALSE(WhenAll(futures)); + bool has_failed = false; + for (int i = 0; i < 1000; ++i) { + if (futures[i].IsFailed()) { + has_failed = true; + break; + } + } + CHECK(has_failed); + } + } +}