Merge branch 'master' of https://code.uocat.com/tqcq/sled
Some checks failed
linux-arm-gcc / linux-gcc-armhf (push) Successful in 1m54s
linux-x64-gcc / linux-gcc (Debug) (push) Successful in 1m50s
linux-mips64-gcc / linux-gcc-mips64el (Release) (push) Successful in 2m8s
linux-mips64-gcc / linux-gcc-mips64el (Debug) (push) Successful in 3m0s
linux-x64-gcc / linux-gcc (Release) (push) Successful in 3m17s
linux-aarch64-cpu-gcc / linux-gcc-aarch64 (push) Has been cancelled

This commit is contained in:
tqcq 2024-05-01 05:09:46 +00:00
commit 28009f189a
5 changed files with 242 additions and 3 deletions

View File

@ -197,7 +197,6 @@ if(SLED_BUILD_TESTS)
sled_add_test(NAME sled_thread_pool_test SRCS
src/sled/system/thread_pool_test.cc)
sled_add_test(NAME sled_event_bus_test SRCS
src/sled/event_bus/event_bus_test.cc)
sled_add_test(NAME sled_lua_test SRCS tests/lua_test.cc)
@ -215,7 +214,8 @@ if(SLED_BUILD_TESTS)
sled_add_test(NAME sled_inja_test SRCS src/sled/nonstd/inja_test.cc)
sled_add_test(NAME sled_fsm_test SRCS src/sled/nonstd/fsm_test.cc)
sled_add_test(NAME sled_timestamp_test SRCS src/sled/units/timestamp_test.cc)
sled_add_test(NAME sled_future_test SRCS src/sled/futures/future_test.cc)
sled_add_test(NAME sled_future_test SRCS src/sled/futures/future_test.cc
src/sled/futures/when_all_test.cc)
sled_add_test(
NAME sled_cache_test SRCS src/sled/cache/lru_cache_test.cc
src/sled/cache/fifo_cache_test.cc src/sled/cache/expire_cache_test.cc)

View File

@ -33,6 +33,7 @@ enum FutureState {
kNotCompletedFuture = 0,
kSuccessFuture = 1,
kFailedFuture = 2,
kCancelled = 3,
};
SLED_EXPORT void IncrementFuturesUsage();
@ -129,6 +130,12 @@ public:
return data_->state.load(std::memory_order_acquire) == future_detail::kSuccessFuture;
}
bool IsCancelled() const noexcept
{
SLED_ASSERT(data_ != nullptr, "Future is not valid");
return data_->state.load(std::memory_order_acquire) == future_detail::kCancelled;
}
bool IsValid() const noexcept { return static_cast<bool>(data_); }
bool Wait(int64_t timeout_ms) const noexcept { return Wait(sled::TimeDelta::Millis(timeout_ms)); }

View File

@ -18,7 +18,8 @@ class Promise final {
"Promise<_, void> is not allowed. Use Promise<_, bool> instead");
public:
using Value = T;
using ValueType = T;
using FailureType = FailureT;
Promise() = default;
Promise(const Promise &) = default;
Promise(Promise &&) noexcept = default;

124
src/sled/futures/when_all.h Normal file
View File

@ -0,0 +1,124 @@
#ifndef SLED_FUTURES_WHEN_ALL_H
#define SLED_FUTURES_WHEN_ALL_H
#pragma once
#include "future.h"
#include "sled/meta/type_traits.h"
#include "sled/synchronization/mutex.h"
namespace sled {
namespace futures {
struct WhenAllState {
sled::Mutex mutex;
sled::ConditionVariable cv;
std::atomic<int> count{0};
std::atomic<bool> has_failed{false};
};
template<typename T, typename FailureT>
void
WhenAllImpl(std::weak_ptr<WhenAllState> &weak_state, Future<T, FailureT> &future)
{
{
auto state = weak_state.lock();
if (!state) { return; }
state->count++;
}
future.OnSuccess([weak_state](const T &) {
auto state = weak_state.lock();
if (!state) { return; }
if (state->count.fetch_sub(1) == 0) {
sled::MutexLock lock(&state->mutex);
state->cv.NotifyAll();
}
});
future.OnFailure([weak_state](const FailureT &) {
auto state = weak_state.lock();
if (!state) { return; }
state->has_failed = true;
sled::MutexLock lock(&state->mutex);
state->cv.NotifyAll();
});
}
template<typename T, typename FailureT, typename... Args>
void
WhenAllImpl(std::weak_ptr<WhenAllState> &weak_state, Future<T, FailureT> &future, Args &&...futures)
{
{
auto state = weak_state.lock();
if (!state) { return; }
state->count++;
}
future.OnSuccess([weak_state](const T &) {
auto state = weak_state.lock();
if (!state) { return; }
if (state->count.fetch_sub(1) == 1) {
sled::MutexLock lock(&state->mutex);
state->cv.NotifyAll();
}
});
future.OnFailure([weak_state](const FailureT &) {
auto state = weak_state.lock();
if (!state) { return; }
state->has_failed = true;
sled::MutexLock lock(&state->mutex);
state->cv.NotifyAll();
});
WhenAllImpl(weak_state, std::forward<Args>(futures)...);
}
}// namespace futures
template<typename T, typename FailureT, typename... Args>
bool
WhenAll(Future<T, FailureT> &future, Args &&...futures)
{
auto state = std::make_shared<futures::WhenAllState>();
std::weak_ptr<futures::WhenAllState> weak_state = state;
futures::WhenAllImpl(weak_state, future, std::forward<Args>(futures)...);
{
sled::MutexLock lock(&state->mutex);
state->cv.Wait(lock, [state] { return state->count.load() == 0 || state->has_failed.load(); });
}
return !state->has_failed.load();
}
template<typename ContainerT>
bool
WhenAll(ContainerT container)
{
auto state = std::make_shared<futures::WhenAllState>();
std::weak_ptr<futures::WhenAllState> weak_state = state;
for (auto &f : container) { futures::WhenAllImpl(weak_state, f); }
{
sled::MutexLock lock(&state->mutex);
state->cv.Wait(lock, [state] { return state->count.load() == 0 || state->has_failed.load(); });
}
return !state->has_failed.load();
}
template<typename IteratorT>
bool
WhenAll(IteratorT begin, IteratorT end)
{
auto state = std::make_shared<futures::WhenAllState>();
std::weak_ptr<futures::WhenAllState> weak_state = state;
for (auto it = begin; it != end; ++it) { futures::WhenAllImpl(weak_state, *it); }
{
sled::MutexLock lock(&state->mutex);
state->cv.Wait(lock, [state] { return state->count.load() == 0 || state->has_failed.load(); });
}
return !state->has_failed.load();
}
}// namespace sled
#endif// SLED_FUTURES_WHEN_ALL_H

View File

@ -0,0 +1,107 @@
#include <sled/futures/when_all.h>
#include <sled/random.h>
#include <sled/system/thread_pool.h>
#include <sled/time_utils.h>
TEST_SUITE("futures when_all")
{
TEST_CASE("single thread")
{
sled::Promise<int> p1;
sled::Promise<int> p2;
sled::Promise<int> p3;
auto f1 = p1.GetFuture();
auto f2 = p2.GetFuture();
auto f3 = p3.GetFuture();
SUBCASE("all success")
{
p1.Success(1);
p2.Success(2);
p3.Success(3);
CHECK(WhenAll(f1, f2, f3));
CHECK_EQ(f1.Result(), 1);
CHECK_EQ(f2.Result(), 2);
CHECK_EQ(f3.Result(), 3);
}
SUBCASE("all failed")
{
p1.Failure(sled::failure::FailureFromString<sled::Promise<int>::FailureType>("1"));
p2.Failure(sled::failure::FailureFromString<sled::Promise<int>::FailureType>("2"));
p3.Failure(sled::failure::FailureFromString<sled::Promise<int>::FailureType>("3"));
CHECK_FALSE(WhenAll(f1, f2, f3));
CHECK(f1.IsFailed());
CHECK(f2.IsFailed());
CHECK(f3.IsFailed());
}
}
TEST_CASE("multi thread")
{
sled::ThreadPool pool(10);
std::vector<sled::Future<int>> futures;
SUBCASE("all success")
{
for (int i = 0; i < 1000; ++i) {
sled::Promise<int> p;
auto f = p.GetFuture();
pool.PostTask([p, i] { p.Success(i); });
futures.push_back(f);
}
CHECK(WhenAll(futures));
for (int i = 0; i < 1000; ++i) { CHECK_EQ(futures[i].Result(), i); }
}
SUBCASE("all failed")
{
for (int i = 0; i < 1000; ++i) {
sled::Promise<int> p;
auto f = p.GetFuture();
pool.PostTask([p, i] {
p.Failure(sled::failure::FailureFromString<sled::Promise<int>::FailureType>(std::to_string(i)));
});
futures.push_back(f);
}
CHECK_FALSE(WhenAll(futures));
bool has_failed = false;
for (int i = 0; i < 1000; ++i) {
if (futures[i].IsFailed()) {
has_failed = true;
break;
}
}
CHECK(has_failed);
}
SUBCASE("some failed")
{
for (int i = 0; i < 1000; ++i) {
sled::Promise<int> p;
auto f = p.GetFuture();
pool.PostTask([p, i] {
sled::Random random(sled::TimeUTCNanos());
sled::Thread::SleepMs(random.Rand(1, 15));
if (i % 2 == 0) {
p.Success(i);
} else {
p.Failure(sled::failure::FailureFromString<sled::Promise<int>::FailureType>(std::to_string(i)));
}
});
futures.push_back(f);
}
CHECK_FALSE(WhenAll(futures));
bool has_failed = false;
for (int i = 0; i < 1000; ++i) {
if (futures[i].IsFailed()) {
has_failed = true;
break;
}
}
CHECK(has_failed);
}
}
}