feat add when_all
Some checks failed
linux-mips64-gcc / linux-gcc-mips64el (Debug) (push) Successful in 1m55s
linux-arm-gcc / linux-gcc-armhf (push) Successful in 2m0s
linux-x64-gcc / linux-gcc (Debug) (push) Successful in 2m1s
linux-x64-gcc / linux-gcc (Release) (push) Successful in 2m18s
linux-mips64-gcc / linux-gcc-mips64el (Release) (push) Successful in 4m42s
linux-aarch64-cpu-gcc / linux-gcc-aarch64 (push) Has been cancelled
Some checks failed
linux-mips64-gcc / linux-gcc-mips64el (Debug) (push) Successful in 1m55s
linux-arm-gcc / linux-gcc-armhf (push) Successful in 2m0s
linux-x64-gcc / linux-gcc (Debug) (push) Successful in 2m1s
linux-x64-gcc / linux-gcc (Release) (push) Successful in 2m18s
linux-mips64-gcc / linux-gcc-mips64el (Release) (push) Successful in 4m42s
linux-aarch64-cpu-gcc / linux-gcc-aarch64 (push) Has been cancelled
This commit is contained in:
parent
907d39108f
commit
015bb678cd
@ -197,7 +197,6 @@ if(SLED_BUILD_TESTS)
|
||||
|
||||
sled_add_test(NAME sled_thread_pool_test SRCS
|
||||
src/sled/system/thread_pool_test.cc)
|
||||
|
||||
sled_add_test(NAME sled_event_bus_test SRCS
|
||||
src/sled/event_bus/event_bus_test.cc)
|
||||
sled_add_test(NAME sled_lua_test SRCS tests/lua_test.cc)
|
||||
@ -215,7 +214,8 @@ if(SLED_BUILD_TESTS)
|
||||
sled_add_test(NAME sled_inja_test SRCS src/sled/nonstd/inja_test.cc)
|
||||
sled_add_test(NAME sled_fsm_test SRCS src/sled/nonstd/fsm_test.cc)
|
||||
sled_add_test(NAME sled_timestamp_test SRCS src/sled/units/timestamp_test.cc)
|
||||
sled_add_test(NAME sled_future_test SRCS src/sled/futures/future_test.cc)
|
||||
sled_add_test(NAME sled_future_test SRCS src/sled/futures/future_test.cc
|
||||
src/sled/futures/when_all_test.cc)
|
||||
sled_add_test(
|
||||
NAME sled_cache_test SRCS src/sled/cache/lru_cache_test.cc
|
||||
src/sled/cache/fifo_cache_test.cc src/sled/cache/expire_cache_test.cc)
|
||||
|
@ -33,6 +33,7 @@ enum FutureState {
|
||||
kNotCompletedFuture = 0,
|
||||
kSuccessFuture = 1,
|
||||
kFailedFuture = 2,
|
||||
kCancelled = 3,
|
||||
};
|
||||
|
||||
SLED_EXPORT void IncrementFuturesUsage();
|
||||
@ -129,6 +130,12 @@ public:
|
||||
return data_->state.load(std::memory_order_acquire) == future_detail::kSuccessFuture;
|
||||
}
|
||||
|
||||
bool IsCancelled() const noexcept
|
||||
{
|
||||
SLED_ASSERT(data_ != nullptr, "Future is not valid");
|
||||
return data_->state.load(std::memory_order_acquire) == future_detail::kCancelled;
|
||||
}
|
||||
|
||||
bool IsValid() const noexcept { return static_cast<bool>(data_); }
|
||||
|
||||
bool Wait(int64_t timeout_ms) const noexcept { return Wait(sled::TimeDelta::Millis(timeout_ms)); }
|
||||
|
@ -18,7 +18,8 @@ class Promise final {
|
||||
"Promise<_, void> is not allowed. Use Promise<_, bool> instead");
|
||||
|
||||
public:
|
||||
using Value = T;
|
||||
using ValueType = T;
|
||||
using FailureType = FailureT;
|
||||
Promise() = default;
|
||||
Promise(const Promise &) = default;
|
||||
Promise(Promise &&) noexcept = default;
|
||||
|
124
src/sled/futures/when_all.h
Normal file
124
src/sled/futures/when_all.h
Normal file
@ -0,0 +1,124 @@
|
||||
#ifndef SLED_FUTURES_WHEN_ALL_H
|
||||
#define SLED_FUTURES_WHEN_ALL_H
|
||||
|
||||
#pragma once
|
||||
#include "future.h"
|
||||
#include "sled/meta/type_traits.h"
|
||||
#include "sled/synchronization/mutex.h"
|
||||
|
||||
namespace sled {
|
||||
|
||||
namespace futures {
|
||||
|
||||
struct WhenAllState {
|
||||
sled::Mutex mutex;
|
||||
sled::ConditionVariable cv;
|
||||
std::atomic<int> count{0};
|
||||
std::atomic<bool> has_failed{false};
|
||||
};
|
||||
|
||||
template<typename T, typename FailureT>
|
||||
void
|
||||
WhenAllImpl(std::weak_ptr<WhenAllState> &weak_state, Future<T, FailureT> &future)
|
||||
{
|
||||
{
|
||||
auto state = weak_state.lock();
|
||||
if (!state) { return; }
|
||||
state->count++;
|
||||
}
|
||||
future.OnSuccess([weak_state](const T &) {
|
||||
auto state = weak_state.lock();
|
||||
if (!state) { return; }
|
||||
if (state->count.fetch_sub(1) == 0) {
|
||||
sled::MutexLock lock(&state->mutex);
|
||||
state->cv.NotifyAll();
|
||||
}
|
||||
});
|
||||
future.OnFailure([weak_state](const FailureT &) {
|
||||
auto state = weak_state.lock();
|
||||
if (!state) { return; }
|
||||
state->has_failed = true;
|
||||
sled::MutexLock lock(&state->mutex);
|
||||
state->cv.NotifyAll();
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T, typename FailureT, typename... Args>
|
||||
void
|
||||
WhenAllImpl(std::weak_ptr<WhenAllState> &weak_state, Future<T, FailureT> &future, Args &&...futures)
|
||||
{
|
||||
{
|
||||
auto state = weak_state.lock();
|
||||
if (!state) { return; }
|
||||
state->count++;
|
||||
}
|
||||
|
||||
future.OnSuccess([weak_state](const T &) {
|
||||
auto state = weak_state.lock();
|
||||
if (!state) { return; }
|
||||
if (state->count.fetch_sub(1) == 1) {
|
||||
sled::MutexLock lock(&state->mutex);
|
||||
state->cv.NotifyAll();
|
||||
}
|
||||
});
|
||||
|
||||
future.OnFailure([weak_state](const FailureT &) {
|
||||
auto state = weak_state.lock();
|
||||
if (!state) { return; }
|
||||
state->has_failed = true;
|
||||
sled::MutexLock lock(&state->mutex);
|
||||
state->cv.NotifyAll();
|
||||
});
|
||||
|
||||
WhenAllImpl(weak_state, std::forward<Args>(futures)...);
|
||||
}
|
||||
|
||||
}// namespace futures
|
||||
|
||||
template<typename T, typename FailureT, typename... Args>
|
||||
bool
|
||||
WhenAll(Future<T, FailureT> &future, Args &&...futures)
|
||||
{
|
||||
auto state = std::make_shared<futures::WhenAllState>();
|
||||
std::weak_ptr<futures::WhenAllState> weak_state = state;
|
||||
futures::WhenAllImpl(weak_state, future, std::forward<Args>(futures)...);
|
||||
|
||||
{
|
||||
sled::MutexLock lock(&state->mutex);
|
||||
state->cv.Wait(lock, [state] { return state->count.load() == 0 || state->has_failed.load(); });
|
||||
}
|
||||
|
||||
return !state->has_failed.load();
|
||||
}
|
||||
|
||||
template<typename ContainerT>
|
||||
bool
|
||||
WhenAll(ContainerT container)
|
||||
{
|
||||
auto state = std::make_shared<futures::WhenAllState>();
|
||||
std::weak_ptr<futures::WhenAllState> weak_state = state;
|
||||
for (auto &f : container) { futures::WhenAllImpl(weak_state, f); }
|
||||
{
|
||||
sled::MutexLock lock(&state->mutex);
|
||||
state->cv.Wait(lock, [state] { return state->count.load() == 0 || state->has_failed.load(); });
|
||||
}
|
||||
return !state->has_failed.load();
|
||||
}
|
||||
|
||||
template<typename IteratorT>
|
||||
bool
|
||||
WhenAll(IteratorT begin, IteratorT end)
|
||||
{
|
||||
auto state = std::make_shared<futures::WhenAllState>();
|
||||
std::weak_ptr<futures::WhenAllState> weak_state = state;
|
||||
for (auto it = begin; it != end; ++it) { futures::WhenAllImpl(weak_state, *it); }
|
||||
{
|
||||
sled::MutexLock lock(&state->mutex);
|
||||
state->cv.Wait(lock, [state] { return state->count.load() == 0 || state->has_failed.load(); });
|
||||
}
|
||||
return !state->has_failed.load();
|
||||
}
|
||||
|
||||
}// namespace sled
|
||||
|
||||
#endif// SLED_FUTURES_WHEN_ALL_H
|
107
src/sled/futures/when_all_test.cc
Normal file
107
src/sled/futures/when_all_test.cc
Normal file
@ -0,0 +1,107 @@
|
||||
#include <sled/futures/when_all.h>
|
||||
#include <sled/random.h>
|
||||
#include <sled/system/thread_pool.h>
|
||||
#include <sled/time_utils.h>
|
||||
|
||||
TEST_SUITE("futures when_all")
|
||||
{
|
||||
TEST_CASE("single thread")
|
||||
{
|
||||
sled::Promise<int> p1;
|
||||
sled::Promise<int> p2;
|
||||
sled::Promise<int> p3;
|
||||
|
||||
auto f1 = p1.GetFuture();
|
||||
auto f2 = p2.GetFuture();
|
||||
auto f3 = p3.GetFuture();
|
||||
|
||||
SUBCASE("all success")
|
||||
{
|
||||
p1.Success(1);
|
||||
p2.Success(2);
|
||||
p3.Success(3);
|
||||
|
||||
CHECK(WhenAll(f1, f2, f3));
|
||||
CHECK_EQ(f1.Result(), 1);
|
||||
CHECK_EQ(f2.Result(), 2);
|
||||
CHECK_EQ(f3.Result(), 3);
|
||||
}
|
||||
SUBCASE("all failed")
|
||||
{
|
||||
p1.Failure(sled::failure::FailureFromString<sled::Promise<int>::FailureType>("1"));
|
||||
p2.Failure(sled::failure::FailureFromString<sled::Promise<int>::FailureType>("2"));
|
||||
p3.Failure(sled::failure::FailureFromString<sled::Promise<int>::FailureType>("3"));
|
||||
|
||||
CHECK_FALSE(WhenAll(f1, f2, f3));
|
||||
CHECK(f1.IsFailed());
|
||||
CHECK(f2.IsFailed());
|
||||
CHECK(f3.IsFailed());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("multi thread")
|
||||
{
|
||||
sled::ThreadPool pool(10);
|
||||
std::vector<sled::Future<int>> futures;
|
||||
|
||||
SUBCASE("all success")
|
||||
{
|
||||
for (int i = 0; i < 1000; ++i) {
|
||||
sled::Promise<int> p;
|
||||
auto f = p.GetFuture();
|
||||
pool.PostTask([p, i] { p.Success(i); });
|
||||
futures.push_back(f);
|
||||
}
|
||||
CHECK(WhenAll(futures));
|
||||
for (int i = 0; i < 1000; ++i) { CHECK_EQ(futures[i].Result(), i); }
|
||||
}
|
||||
|
||||
SUBCASE("all failed")
|
||||
{
|
||||
for (int i = 0; i < 1000; ++i) {
|
||||
sled::Promise<int> p;
|
||||
auto f = p.GetFuture();
|
||||
pool.PostTask([p, i] {
|
||||
p.Failure(sled::failure::FailureFromString<sled::Promise<int>::FailureType>(std::to_string(i)));
|
||||
});
|
||||
futures.push_back(f);
|
||||
}
|
||||
CHECK_FALSE(WhenAll(futures));
|
||||
bool has_failed = false;
|
||||
for (int i = 0; i < 1000; ++i) {
|
||||
if (futures[i].IsFailed()) {
|
||||
has_failed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
CHECK(has_failed);
|
||||
}
|
||||
|
||||
SUBCASE("some failed")
|
||||
{
|
||||
for (int i = 0; i < 1000; ++i) {
|
||||
sled::Promise<int> p;
|
||||
auto f = p.GetFuture();
|
||||
pool.PostTask([p, i] {
|
||||
sled::Random random(sled::TimeUTCNanos());
|
||||
sled::Thread::SleepMs(random.Rand(1, 15));
|
||||
if (i % 2 == 0) {
|
||||
p.Success(i);
|
||||
} else {
|
||||
p.Failure(sled::failure::FailureFromString<sled::Promise<int>::FailureType>(std::to_string(i)));
|
||||
}
|
||||
});
|
||||
futures.push_back(f);
|
||||
}
|
||||
CHECK_FALSE(WhenAll(futures));
|
||||
bool has_failed = false;
|
||||
for (int i = 0; i < 1000; ++i) {
|
||||
if (futures[i].IsFailed()) {
|
||||
has_failed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
CHECK(has_failed);
|
||||
}
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user