diff --git a/include/sled/futures/detail/base_future.h b/include/sled/futures/detail/base_future.h index bbaff54..55285cb 100644 --- a/include/sled/futures/detail/base_future.h +++ b/include/sled/futures/detail/base_future.h @@ -1,63 +1,208 @@ #ifndef SLED_FUTURES_DETAIL_BASE_FUTURE_H #define SLED_FUTURES_DETAIL_BASE_FUTURE_H -#include "sled/any.h" +#include "sled/log/log.h" +#include "sled/optional.h" #include "sled/synchronization/mutex.h" #include #include namespace sled { -template -class Promise; +namespace { +enum class State { + kCancel, + kPending, + kTimeout, + kError, + kValue, +}; +}; template -class FutureState { -public: - T Get() +struct FutureState { + mutable sled::Mutex mutex; + mutable sled::ConditionVariable cond_var; + + sled::optional value; + std::exception_ptr error; + + State state = State::kPending; + + void AssertHasValue() const { ASSERT(state == State::kValue, "can't find value"); } + + void AssertHasError() const { ASSERT(state == State::kError, "can't find error"); } + + void AssertHasTimeout() const { ASSERT(state == State::kTimeout, "can't find timeout"); } + + void Wait(sled::MutexLock *lock_ptr = nullptr) const { - sled::MutexLock lock(&mutex_); - cv_.Wait(&mutex_, [this]() { return done_; }); - return sled::any_cast(value_); + if (lock_ptr) { + if (state != State::kPending) { return; } + cond_var.Wait(*lock_ptr, [this] { return state != State::kPending; }); + } else { + sled::MutexLock lock(&mutex); + if (state != State::kPending) { return; } + cond_var.Wait(lock, [this] { return state != State::kPending; }); + } } void SetError(std::exception_ptr e) { - sled::MutexLock lock(&mutex_); - value_ = e; - done_ = true; - }; + sled::MutexLock lock(&mutex); + if (state == State::kCancel) { return; } + ASSERT(state == State::kPending, "state must be kPending"); + error = std::move(e); + state = State::kError; + cond_var.NotifyAll(); + } - template - typename std::enable_if::value && std::is_convertible::value>::type SetValue(U &&value) + void SetTimeout() { - sled::MutexLock lock(&mutex_); - value_ = static_cast(std::forward(value)); - done_ = true; + sled::MutexLock lock(&mutex); + if (state == State::kCancel) { return; } + ASSERT(state == State::kPending, "state must be kPending"); + state = State::kTimeout; + cond_var.NotifyAll(); } template - typename std::enable_if::value>::type SetValue(U &&value) + typename std::enable_if::value>::type SetValue(U &&val) { - sled::MutexLock lock(&mutex_); - done_ = true; + sled::MutexLock lock(&mutex); + if (state == State::kCancel) { return; } + ASSERT(state == State::kPending, "state must be kPending"); + value = std::forward(val); + state = State::kValue; + cond_var.NotifyAll(); + } +}; + +template<> +struct FutureState { + mutable sled::Mutex mutex; + mutable sled::ConditionVariable cond_var; + + std::exception_ptr error; + + State state = State::kPending; + + void AssertHasValue() const { ASSERT(state == State::kValue, "can't find value"); } + + void AssertHasError() const { ASSERT(state == State::kError, "can't find error"); } + + void AssertHasTimeout() const { ASSERT(state == State::kTimeout, "can't find timeout"); } + + void Wait(sled::MutexLock *lock_ptr = nullptr) const + { + if (lock_ptr) { + if (state != State::kPending) { return; } + cond_var.Wait(*lock_ptr, [this] { return state != State::kPending; }); + } else { + sled::MutexLock lock(&mutex); + if (state != State::kPending) { return; } + cond_var.Wait(lock, [this] { return state != State::kPending; }); + } } -private: - sled::Mutex mutex_; - sled::ConditionVariable cv_; - sled::any value_; - bool done_{false}; + void SetTimeout() + { + sled::MutexLock lock(&mutex); + if (state == State::kCancel) { return; } + ASSERT(state == State::kPending, "state must be kPending"); + state = State::kTimeout; + cond_var.NotifyAll(); + } + + void SetError(std::exception_ptr e) + { + sled::MutexLock lock(&mutex); + if (state == State::kCancel) { return; } + ASSERT(state == State::kPending, "state must be kPending"); + error = std::move(e); + state = State::kError; + cond_var.NotifyAll(); + } + + void SetValue() + { + sled::MutexLock lock(&mutex); + if (state == State::kCancel) { return; } + ASSERT(state == State::kPending, "state must be kPending"); + state = State::kValue; + cond_var.NotifyAll(); + } }; template -class BaseFuture { +class Future { public: - template - T Get() const + // using ValueType = typename std::remove_reference::type; + Future(std::shared_ptr> state) : state_(std::move(state)) {} + + T Get() const & { - return state_->Get(); + sled::MutexLock lock(&state_->mutex); + state_->Wait(&lock); + state_->AssertHasValue(); + return state_->value.value(); } + T &Get() & + { + sled::MutexLock lock(&state_->mutex); + state_->Wait(&lock); + state_->AssertHasValue(); + return state_->value.value(); + } + + T &&Get() && + { + sled::MutexLock lock(&state_->mutex); + state_->Wait(&lock); + state_->AssertHasValue(); + return std::move(state_->value.value()); + } + +private: + std::shared_ptr> state_; +}; + +template<> +class Future { +public: + Future(std::shared_ptr> state) : state_(std::move(state)) {} + + void Wait() const { state_->Wait(); } + + void Get() const { Wait(); } + +protected: + std::shared_ptr> state_; +}; + +template +class Promise { +public: + Promise() : state_(new FutureState()) {} + + Future GetFuture() const { return Future(state_); } + + template + typename std::enable_if::value && std::is_convertible::value>::type SetValue(U &&val) + { + state_->SetValue(val); + } + + template + typename std::enable_if::value>::type SetValue() + { + state_->SetValue(); + } + + void SetError(std::exception_ptr e) { state_->SetError(e); } + + void SetTimeout() { state_->SetTimeout(); } + private: std::shared_ptr> state_; }; diff --git a/include/sled/futures/detail/just.h b/include/sled/futures/detail/just.h new file mode 100644 index 0000000..58ac377 --- /dev/null +++ b/include/sled/futures/detail/just.h @@ -0,0 +1,39 @@ +#ifndef SLED_FUTURES_DETAIL_JUST_H +#define SLED_FUTURES_DETAIL_JUST_H + +#include + +namespace sled { +namespace detail { + +template +struct JustOperation { + T value; + R receiver; + + void Start() { receiver.SetValue(std::move(value)); } + + void Stop() { receiver.SetStopped(); } +}; + +template +struct JustSender { + T value; + + template + JustOperation Connect(R receiver) + { + return {value, receiver}; + } +}; + +template +JustSender +Just(T value) +{ + return {value}; +} + +}// namespace detail +}// namespace sled +#endif// SLED_FUTURES_DETAIL_JUST_H diff --git a/include/sled/futures/detail/retry.h b/include/sled/futures/detail/retry.h new file mode 100644 index 0000000..d82736c --- /dev/null +++ b/include/sled/futures/detail/retry.h @@ -0,0 +1,120 @@ +#ifndef SLED_FUTURES_DETAIL_RETRY_H +#define SLED_FUTURES_DETAIL_RETRY_H + +#include "sled/log/log.h" +#include "sled/synchronization/mutex.h" +#include "traits.h" +#include + +namespace sled { +namespace detail { + +namespace { +struct RetryState { + enum State { kPending, kDone, kRetry }; + + sled::Mutex mutex; + sled::ConditionVariable cv; + int retry_count = 0; + State state = kPending; +}; +}// namespace + +template +struct RetryReceiver { + std::shared_ptr state; + R receiver; + bool stopped = false; + + template + void SetValue(U &&val) + { + { + sled::MutexLock lock(&state->mutex); + if (stopped) { return; } + state->state = RetryState::kDone; + state->cv.NotifyAll(); + } + receiver.SetValue(std::forward(val)); + } + + void SetError(std::exception_ptr e) + { + // notify + { + sled::MutexLock lock(&state->mutex); + if (stopped) { return; } + if (state->retry_count > 0) { + --state->retry_count; + state->state = RetryState::kRetry; + return; + } else { + state->state = RetryState::kDone; + state->cv.NotifyAll(); + } + } + receiver.SetError(e); + } + + void SetStopped() + { + { + sled::MutexLock lock(&state->mutex); + if (stopped) { return; } + stopped = true; + state->state = RetryState::kDone; + state->cv.NotifyAll(); + } + receiver.SetStopped(); + } +}; + +template +struct RetryOperation { + int retry_count; + std::shared_ptr state; + ConnectResultT op; + + void Start() + { + { + sled::MutexLock lock(&state->mutex); + state->retry_count = retry_count; + state->state = RetryState::kPending; + } + do { + op.Start(); + sled::MutexLock lock(&state->mutex); + state->cv.Wait(lock, [this] { return state->state != RetryState::kPending; }); + if (state->state == RetryState::kDone) { break; } + state->state = RetryState::kPending; + } while (true); + } + + void Stop() { op.Stop(); } +}; + +template +struct RetrySender { + S sender; + int retry_count; + + template + RetryOperation> Connect(R receiver) + { + auto state = std::make_shared(); + auto op = sender.Connect(RetryReceiver{state, receiver}); + return {retry_count, state, op}; + } +}; + +template +RetrySender +Retry(S sender, int retry_count) +{ + return {sender, retry_count}; +} + +}// namespace detail +}// namespace sled +#endif// SLED_FUTURES_DETAIL_RETRY_H diff --git a/include/sled/futures/detail/then.h b/include/sled/futures/detail/then.h new file mode 100644 index 0000000..811941b --- /dev/null +++ b/include/sled/futures/detail/then.h @@ -0,0 +1,71 @@ +#ifndef SLED_FUTURES_DETAIL_THEN_H +#define SLED_FUTURES_DETAIL_THEN_H + +#include "traits.h" +#include + +namespace sled { +namespace detail { + +template +struct ThenReceiver { + R receiver; + F func; + bool stopped = false; + + template + void SetValue(U &&val) + { + if (stopped) { return; } + try { + receiver.SetValue(func(std::forward(val))); + } catch (...) { + SetError(std::current_exception()); + } + } + + void SetError(std::exception_ptr e) + { + if (stopped) { return; } + receiver.SetError(e); + } + + void SetStopped() + { + if (stopped) { return; } + stopped = true; + receiver.SetStopped(); + } +}; + +template +struct ThenOperation { + ConnectResultT op; + + void Start() { op.Start(); } + + void Stop() { op.Stop(); } +}; + +template +struct ThenSender { + S sender; + F func; + + template + ThenOperation> Connect(R receiver) + { + return {sender.Connect(ThenReceiver{receiver, func})}; + } +}; + +template +ThenSender +Then(S sender, F &&func) +{ + return {sender, std::forward(func)}; +} + +}// namespace detail +}// namespace sled +#endif// SLED_FUTURES_DETAIL_THEN_H diff --git a/include/sled/futures/detail/traits.h b/include/sled/futures/detail/traits.h new file mode 100644 index 0000000..332f930 --- /dev/null +++ b/include/sled/futures/detail/traits.h @@ -0,0 +1,18 @@ +#ifndef SLED_FUTURES_DETAIL_TRAITS_H +#define SLED_FUTURES_DETAIL_TRAITS_H + +#include + +namespace sled { +namespace detail { +template +struct ConnectResult { + typedef decltype(std::declval().Connect(std::declval())) type; +}; + +template +using ConnectResultT = typename ConnectResult::type; + +}// namespace detail +}// namespace sled +#endif// SLED_FUTURES_DETAIL_TRAITS_H diff --git a/include/sled/futures/detail/via.h b/include/sled/futures/detail/via.h new file mode 100644 index 0000000..7aa8990 --- /dev/null +++ b/include/sled/futures/detail/via.h @@ -0,0 +1,72 @@ +#ifndef SLED_FUTURES_DETAIL_VIA_H +#define SLED_FUTURES_DETAIL_VIA_H +#include "traits.h" +#include +#include +#include + +namespace sled { +namespace detail { +template +struct ViaReceiver { + R receiver; + F schedule; + bool stopped = false; + + template + void SetValue(U &&val) + { + if (stopped) { return; } + try { + auto func = std::bind(&R::SetValue, &receiver, std::forward(val)); + schedule(std::move(func)); + } catch (...) { + SetError(std::current_exception()); + } + } + + void SetError(std::exception_ptr e) + { + if (stopped) { return; } + receiver.SetError(e); + } + + void SetStopped() + { + if (stopped) { return; } + stopped = true; + receiver.SetStopped(); + } +}; + +template +struct ViaOperation { + ConnectResultT op; + + void Start() { op.Start(); } + + void Stop() { op.Stop(); } +}; + +template +struct ViaSender { + S sender; + F schedule; + + template + ViaOperation> Connect(R receiver) + { + return {sender.Connect(ViaReceiver{receiver, schedule})}; + } +}; + +template +ViaSender +Via(S sender, F &&schedule) +{ + return {sender, std::forward(schedule)}; +} + +}// namespace detail +}// namespace sled +#endif// SLED_FUTURES_DETAIL_VIA_H diff --git a/include/sled/sled.h b/include/sled/sled.h index 5c8df03..d3a06d0 100644 --- a/include/sled/sled.h +++ b/include/sled/sled.h @@ -11,7 +11,7 @@ #include "sled/filesystem/temporary_file.h" // futures -#include "sled/futures/promise.h" +// #include "sled/futures/promise.h" // lang #include "lang/attributes.h"