Commit 28009f18 authored by tqcq's avatar tqcq
Browse files

Merge branch 'master' of https://code.uocat.com/tqcq/sled

parents 12c9c36b 015bb678
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -197,7 +197,6 @@ if(SLED_BUILD_TESTS)

  sled_add_test(NAME sled_thread_pool_test SRCS
                src/sled/system/thread_pool_test.cc)

  sled_add_test(NAME sled_event_bus_test SRCS
                src/sled/event_bus/event_bus_test.cc)
  sled_add_test(NAME sled_lua_test SRCS tests/lua_test.cc)
@@ -215,7 +214,8 @@ if(SLED_BUILD_TESTS)
  sled_add_test(NAME sled_inja_test SRCS src/sled/nonstd/inja_test.cc)
  sled_add_test(NAME sled_fsm_test SRCS src/sled/nonstd/fsm_test.cc)
  sled_add_test(NAME sled_timestamp_test SRCS src/sled/units/timestamp_test.cc)
  sled_add_test(NAME sled_future_test SRCS src/sled/futures/future_test.cc)
  sled_add_test(NAME sled_future_test SRCS src/sled/futures/future_test.cc
                src/sled/futures/when_all_test.cc)
  sled_add_test(
    NAME sled_cache_test SRCS src/sled/cache/lru_cache_test.cc
    src/sled/cache/fifo_cache_test.cc src/sled/cache/expire_cache_test.cc)
+7 −0
Original line number Diff line number Diff line
@@ -33,6 +33,7 @@ enum FutureState {
    kNotCompletedFuture = 0,
    kSuccessFuture      = 1,
    kFailedFuture       = 2,
    kCancelled          = 3,
};

SLED_EXPORT void IncrementFuturesUsage();
@@ -129,6 +130,12 @@ public:
        return data_->state.load(std::memory_order_acquire) == future_detail::kSuccessFuture;
    }

    bool IsCancelled() const noexcept
    {
        SLED_ASSERT(data_ != nullptr, "Future is not valid");
        return data_->state.load(std::memory_order_acquire) == future_detail::kCancelled;
    }

    bool IsValid() const noexcept { return static_cast<bool>(data_); }

    bool Wait(int64_t timeout_ms) const noexcept { return Wait(sled::TimeDelta::Millis(timeout_ms)); }
+2 −1
Original line number Diff line number Diff line
@@ -18,7 +18,8 @@ class Promise final {
                  "Promise<_, void> is not allowed. Use Promise<_, bool> instead");

public:
    using Value                             = T;
    using ValueType                         = T;
    using FailureType                       = FailureT;
    Promise()                               = default;
    Promise(const Promise &)                = default;
    Promise(Promise &&) noexcept            = default;
+124 −0
Original line number Diff line number Diff line
#ifndef SLED_FUTURES_WHEN_ALL_H
#define SLED_FUTURES_WHEN_ALL_H

#pragma once
#include "future.h"
#include "sled/meta/type_traits.h"
#include "sled/synchronization/mutex.h"

namespace sled {

namespace futures {

struct WhenAllState {
    sled::Mutex mutex;
    sled::ConditionVariable cv;
    std::atomic<int> count{0};
    std::atomic<bool> has_failed{false};
};

template<typename T, typename FailureT>
void
WhenAllImpl(std::weak_ptr<WhenAllState> &weak_state, Future<T, FailureT> &future)
{
    {
        auto state = weak_state.lock();
        if (!state) { return; }
        state->count++;
    }
    future.OnSuccess([weak_state](const T &) {
        auto state = weak_state.lock();
        if (!state) { return; }
        if (state->count.fetch_sub(1) == 0) {
            sled::MutexLock lock(&state->mutex);
            state->cv.NotifyAll();
        }
    });
    future.OnFailure([weak_state](const FailureT &) {
        auto state = weak_state.lock();
        if (!state) { return; }
        state->has_failed = true;
        sled::MutexLock lock(&state->mutex);
        state->cv.NotifyAll();
    });
}

template<typename T, typename FailureT, typename... Args>
void
WhenAllImpl(std::weak_ptr<WhenAllState> &weak_state, Future<T, FailureT> &future, Args &&...futures)
{
    {
        auto state = weak_state.lock();
        if (!state) { return; }
        state->count++;
    }

    future.OnSuccess([weak_state](const T &) {
        auto state = weak_state.lock();
        if (!state) { return; }
        if (state->count.fetch_sub(1) == 1) {
            sled::MutexLock lock(&state->mutex);
            state->cv.NotifyAll();
        }
    });

    future.OnFailure([weak_state](const FailureT &) {
        auto state = weak_state.lock();
        if (!state) { return; }
        state->has_failed = true;
        sled::MutexLock lock(&state->mutex);
        state->cv.NotifyAll();
    });

    WhenAllImpl(weak_state, std::forward<Args>(futures)...);
}

}// namespace futures

template<typename T, typename FailureT, typename... Args>
bool
WhenAll(Future<T, FailureT> &future, Args &&...futures)
{
    auto state                                      = std::make_shared<futures::WhenAllState>();
    std::weak_ptr<futures::WhenAllState> weak_state = state;
    futures::WhenAllImpl(weak_state, future, std::forward<Args>(futures)...);

    {
        sled::MutexLock lock(&state->mutex);
        state->cv.Wait(lock, [state] { return state->count.load() == 0 || state->has_failed.load(); });
    }

    return !state->has_failed.load();
}

template<typename ContainerT>
bool
WhenAll(ContainerT container)
{
    auto state                                      = std::make_shared<futures::WhenAllState>();
    std::weak_ptr<futures::WhenAllState> weak_state = state;
    for (auto &f : container) { futures::WhenAllImpl(weak_state, f); }
    {
        sled::MutexLock lock(&state->mutex);
        state->cv.Wait(lock, [state] { return state->count.load() == 0 || state->has_failed.load(); });
    }
    return !state->has_failed.load();
}

template<typename IteratorT>
bool
WhenAll(IteratorT begin, IteratorT end)
{
    auto state                                      = std::make_shared<futures::WhenAllState>();
    std::weak_ptr<futures::WhenAllState> weak_state = state;
    for (auto it = begin; it != end; ++it) { futures::WhenAllImpl(weak_state, *it); }
    {
        sled::MutexLock lock(&state->mutex);
        state->cv.Wait(lock, [state] { return state->count.load() == 0 || state->has_failed.load(); });
    }
    return !state->has_failed.load();
}

}// namespace sled

#endif// SLED_FUTURES_WHEN_ALL_H
+107 −0
Original line number Diff line number Diff line
#include <sled/futures/when_all.h>
#include <sled/random.h>
#include <sled/system/thread_pool.h>
#include <sled/time_utils.h>

TEST_SUITE("futures when_all")
{
    TEST_CASE("single thread")
    {
        sled::Promise<int> p1;
        sled::Promise<int> p2;
        sled::Promise<int> p3;

        auto f1 = p1.GetFuture();
        auto f2 = p2.GetFuture();
        auto f3 = p3.GetFuture();

        SUBCASE("all success")
        {
            p1.Success(1);
            p2.Success(2);
            p3.Success(3);

            CHECK(WhenAll(f1, f2, f3));
            CHECK_EQ(f1.Result(), 1);
            CHECK_EQ(f2.Result(), 2);
            CHECK_EQ(f3.Result(), 3);
        }
        SUBCASE("all failed")
        {
            p1.Failure(sled::failure::FailureFromString<sled::Promise<int>::FailureType>("1"));
            p2.Failure(sled::failure::FailureFromString<sled::Promise<int>::FailureType>("2"));
            p3.Failure(sled::failure::FailureFromString<sled::Promise<int>::FailureType>("3"));

            CHECK_FALSE(WhenAll(f1, f2, f3));
            CHECK(f1.IsFailed());
            CHECK(f2.IsFailed());
            CHECK(f3.IsFailed());
        }
    }

    TEST_CASE("multi thread")
    {
        sled::ThreadPool pool(10);
        std::vector<sled::Future<int>> futures;

        SUBCASE("all success")
        {
            for (int i = 0; i < 1000; ++i) {
                sled::Promise<int> p;
                auto f = p.GetFuture();
                pool.PostTask([p, i] { p.Success(i); });
                futures.push_back(f);
            }
            CHECK(WhenAll(futures));
            for (int i = 0; i < 1000; ++i) { CHECK_EQ(futures[i].Result(), i); }
        }

        SUBCASE("all failed")
        {
            for (int i = 0; i < 1000; ++i) {
                sled::Promise<int> p;
                auto f = p.GetFuture();
                pool.PostTask([p, i] {
                    p.Failure(sled::failure::FailureFromString<sled::Promise<int>::FailureType>(std::to_string(i)));
                });
                futures.push_back(f);
            }
            CHECK_FALSE(WhenAll(futures));
            bool has_failed = false;
            for (int i = 0; i < 1000; ++i) {
                if (futures[i].IsFailed()) {
                    has_failed = true;
                    break;
                }
            }
            CHECK(has_failed);
        }

        SUBCASE("some failed")
        {
            for (int i = 0; i < 1000; ++i) {
                sled::Promise<int> p;
                auto f = p.GetFuture();
                pool.PostTask([p, i] {
                    sled::Random random(sled::TimeUTCNanos());
                    sled::Thread::SleepMs(random.Rand(1, 15));
                    if (i % 2 == 0) {
                        p.Success(i);
                    } else {
                        p.Failure(sled::failure::FailureFromString<sled::Promise<int>::FailureType>(std::to_string(i)));
                    }
                });
                futures.push_back(f);
            }
            CHECK_FALSE(WhenAll(futures));
            bool has_failed = false;
            for (int i = 0; i < 1000; ++i) {
                if (futures[i].IsFailed()) {
                    has_failed = true;
                    break;
                }
            }
            CHECK(has_failed);
        }
    }
}