From d304f6da0219d05f7fd31c267d425c700ebbe780 Mon Sep 17 00:00:00 2001 From: tqcq <99722391+tqcq@users.noreply.github.com> Date: Fri, 22 Mar 2024 17:44:43 +0800 Subject: [PATCH] feat update --- CMakeLists.txt | 1 + include/sled/exec/detail/just.h | 2 +- include/sled/exec/detail/retry.h | 67 +++++++++++++++++++++++ include/sled/exec/detail/sync_wait.h | 6 +- include/sled/exec/detail/then.h | 13 +++-- include/sled/futures/base_cell.h | 15 +++++ include/sled/futures/future.h | 51 +++++++++++++++++ include/sled/futures/just.h | 28 ++++++++++ include/sled/futures/then.h | 40 ++++++++++++++ include/sled/system/thread.h | 35 ++---------- include/sled/system/thread_pool.h | 24 +++++--- include/sled/task_queue/task_queue_base.h | 59 ++++++++++---------- src/futures/future_test.cc | 8 +++ src/system/thread.cc | 39 ++++--------- src/system/thread_pool.cc | 36 ++++++++++-- src/task_queue/task_queue_base.cc | 15 ++++- 16 files changed, 330 insertions(+), 109 deletions(-) create mode 100644 include/sled/exec/detail/retry.h create mode 100644 include/sled/futures/base_cell.h create mode 100644 include/sled/futures/future.h create mode 100644 include/sled/futures/just.h create mode 100644 include/sled/futures/then.h create mode 100644 src/futures/future_test.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index ec49ef8..45260b0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -120,6 +120,7 @@ if(SLED_BUILD_TESTS) src/exec/just_test.cc src/any_test.cc src/filesystem/path_test.cc + src/futures/future_test.cc # src/profiling/profiling_test.cc src/strings/base64_test.cc src/cleanup_test.cc diff --git a/include/sled/exec/detail/just.h b/include/sled/exec/detail/just.h index df38285..a75c370 100644 --- a/include/sled/exec/detail/just.h +++ b/include/sled/exec/detail/just.h @@ -12,7 +12,7 @@ struct JustOperation { TReceiver receiver; T value; - void Start() { receiver.SetValue(std::move(value)); } + void Start() { receiver.SetValue(value); } }; template diff --git a/include/sled/exec/detail/retry.h b/include/sled/exec/detail/retry.h new file mode 100644 index 0000000..ec16a18 --- /dev/null +++ b/include/sled/exec/detail/retry.h @@ -0,0 +1,67 @@ +#ifndef SLED_EXEC_DETAIL_RETRY_H +#define SLED_EXEC_DETAIL_RETRY_H +#include +#include +#include +#pragma once + +#include "traits.h" + +namespace sled { + +struct RetryState { + int retry_count; + bool need_retry; +}; + +template +struct RetryReceiver { + TReceiver receiver; + std::shared_ptr state; + + template + void SetValue(T &&value) + { + receiver.SetValue(value); + } + + void SetError(std::exception_ptr err) + { + if (state->retry_count < 0) {} + } + + void SetStopped() { receiver.SetStopped(); } +}; + +template +struct RetryOperation { + ConnectResultT> op; + std::shared_ptr state; + + void Start() {} +}; + +template +struct RetrySender { + using S = typename std::remove_cv::type>::type; + using result_t = SenderResultT; + S sender; + int retry_count; + + template + RetryOperation Connect(TReceiver &&receiver) + { + auto retry_state = std::make_shared(new RetryState{retry_count, false}); + return {sender.Connect(RetryReceiver{receiver, retry_state}), retry_state}; + } +}; + +template +RetrySender +Retry(TSender &&sender, int retry_count) +{ + return {std::forward(sender), retry_count}; +} + +}// namespace sled +#endif// SLED_EXEC_DETAIL_RETRY_H diff --git a/include/sled/exec/detail/sync_wait.h b/include/sled/exec/detail/sync_wait.h index 9b380c3..63dbace 100644 --- a/include/sled/exec/detail/sync_wait.h +++ b/include/sled/exec/detail/sync_wait.h @@ -8,7 +8,7 @@ #include namespace sled { -struct SyncWaitData { +struct SyncWaitState { sled::Mutex lock; sled::ConditionVariable cv; std::exception_ptr err; @@ -17,7 +17,7 @@ struct SyncWaitData { template struct SyncWaitReceiver { - SyncWaitData &data; + SyncWaitState &data; sled::optional &value; void SetValue(T &&val) @@ -49,7 +49,7 @@ sled::optional> SyncWait(TSender sender) { using T = SenderResultT; - SyncWaitData data; + SyncWaitState data; sled::optional value; auto op = sender.Connect(SyncWaitReceiver{data, value}); diff --git a/include/sled/exec/detail/then.h b/include/sled/exec/detail/then.h index 2df6db5..77b403b 100644 --- a/include/sled/exec/detail/then.h +++ b/include/sled/exec/detail/then.h @@ -15,7 +15,11 @@ struct ThenReceiver { template void SetValue(T &&value) { - receiver.SetValue(func(std::forward(value))); + try { + receiver.SetValue(func(std::forward(value))); + } catch (...) { + receiver.SetError(std::current_exception()); + } } void SetError(std::exception_ptr err) { receiver.SetError(err); } @@ -32,8 +36,9 @@ struct ThenOperation { template struct ThenSender { - using result_t = typename eggs::invoke_result_t>; - TSender sender; + using S = typename std::remove_cv::type>::type; + using result_t = typename eggs::invoke_result_t>; + S sender; F func; template @@ -45,7 +50,7 @@ struct ThenSender { template ThenSender -Then(TSender sender, F &&func) +Then(TSender &&sender, F &&func) { return {std::forward(sender), std::forward(func)}; } diff --git a/include/sled/futures/base_cell.h b/include/sled/futures/base_cell.h new file mode 100644 index 0000000..e13fee0 --- /dev/null +++ b/include/sled/futures/base_cell.h @@ -0,0 +1,15 @@ +#ifndef SLED_FUTURES_BASE_CELL_H +#define SLED_FUTURES_BASE_CELL_H +#include +#pragma once + +namespace sled { +namespace futures { + +struct BaseCell { + void *scheduler; +}; +}// namespace futures + +}// namespace sled +#endif// SLED_FUTURES_BASE_CELL_H diff --git a/include/sled/futures/future.h b/include/sled/futures/future.h new file mode 100644 index 0000000..8d06c56 --- /dev/null +++ b/include/sled/futures/future.h @@ -0,0 +1,51 @@ +#ifndef SLED_FUTURES_FUTHRE_H +#define SLED_FUTURES_FUTHRE_H + +#include "sled/any.h" +#include "sled/exec/detail/invoke_result.h" +#include "sled/optional.h" +#include "sled/synchronization/mutex.h" +#include +#include + +namespace sled { + +template +class Future; +template +class Promise; + +template +struct FPState : std::enable_shared_from_this> { + sled::Mutex lock; + sled::optional data; + std::exception_ptr err; + bool done; + sled::any priv; +}; + +template +class Future { +public: + using result_t = T; + + Future(std::shared_ptr> state) : state_(state) {} + +private: + std::shared_ptr> state_; +}; + +template +class Promise { +public: + using result_t = T; + + void SetValue(T &&value) {} + + void SetError(std::exception_ptr err) {} + + Future GetFuture() {} +}; + +}// namespace sled +#endif// SLED_FUTURES_FUTHRE_H diff --git a/include/sled/futures/just.h b/include/sled/futures/just.h new file mode 100644 index 0000000..6265c33 --- /dev/null +++ b/include/sled/futures/just.h @@ -0,0 +1,28 @@ +#ifndef SLED_FUTURES_JUST_H +#define SLED_FUTURES_JUST_H +#include +#pragma once + +namespace sled { +namespace futures { +template +struct JustCell { + T value; + + template + void Start(R receiver) + { + receiver.SetValue(value); + } +}; + +template +JustCell +Just(T &&t) +{ + return {std::forward(t)}; +} + +}// namespace futures +}// namespace sled +#endif// SLED_FUTURES_JUST_H diff --git a/include/sled/futures/then.h b/include/sled/futures/then.h new file mode 100644 index 0000000..a709054 --- /dev/null +++ b/include/sled/futures/then.h @@ -0,0 +1,40 @@ + +#ifndef SLED_FUTURES_THEN_H +#define SLED_FUTURES_THEN_H +#include +#include +#pragma once + +namespace sled { +namespace futures { + +template +struct ThenCell { + S sender; + F func; + + // T value; + + template + void Start(R receiver) + { + sender.Start(); + } + + template + void SetValue(U &&value) + {} + + void SetError(std::exception_ptr err) {} +}; + +template +ThenCell +Then(S sender, F &&func) +{ + return {std::forward(sender), std::forward(func)}; +} + +}// namespace futures +}// namespace sled +#endif// SLED_FUTURES_THEN_H diff --git a/include/sled/system/thread.h b/include/sled/system/thread.h index 084f1ae..339cae2 100644 --- a/include/sled/system/thread.h +++ b/include/sled/system/thread.h @@ -59,25 +59,6 @@ public: Thread(const Thread &) = delete; Thread &operator=(const Thread &) = delete; - void BlockingCall(std::function functor, - const Location &location = Location::Current()) - { - BlockingCallImpl(functor, location); - } - - template::type, - typename = typename std::enable_if::value, - ReturnT>::type> - ReturnT BlockingCall(Functor &&functor, - const Location &location = Location::Current()) - { - ReturnT result; - BlockingCall([&] { result = std::forward(functor)(); }, - location); - return result; - } - static std::unique_ptr CreateWithSocketServer(); static std::unique_ptr Create(); static Thread *Current(); @@ -122,8 +103,7 @@ protected: bool operator<(const DelayedMessage &dmsg) const { return (dmsg.run_time_ms < run_time_ms) - || ((dmsg.run_time_ms == run_time_ms) - && (dmsg.message_number < message_number)); + || ((dmsg.run_time_ms == run_time_ms) && (dmsg.message_number < message_number)); } int64_t delay_ms; @@ -132,15 +112,12 @@ protected: mutable std::function functor; }; - void PostTaskImpl(std::function &&task, - const PostTaskTraits &traits, - const Location &location) override; + void PostTaskImpl(std::function &&task, const PostTaskTraits &traits, const Location &location) override; void PostDelayedTaskImpl(std::function &&task, TimeDelta delay, const PostDelayedTaskTraits &traits, const Location &location) override; - virtual void BlockingCallImpl(std::function functor, - const Location &location); + void BlockingCallImpl(std::function &&functor, const Location &location) override; void DoInit(); void DoDestroy(); @@ -150,8 +127,7 @@ private: std::function Get(int cmsWait); void Dispatch(std::function &&task); static void *PreRun(void *pv); - bool WrapCurrentWithThreadManager(ThreadManager *thread_manager, - bool need_synchronize_access); + bool WrapCurrentWithThreadManager(ThreadManager *thread_manager, bool need_synchronize_access); bool IsRunning(); // for ThreadManager @@ -171,8 +147,7 @@ private: std::unique_ptr thread_; bool owned_; - std::unique_ptr - task_queue_registration_; + std::unique_ptr task_queue_registration_; friend class ThreadManager; }; diff --git a/include/sled/system/thread_pool.h b/include/sled/system/thread_pool.h index 98e8945..05fc6b5 100644 --- a/include/sled/system/thread_pool.h +++ b/include/sled/system/thread_pool.h @@ -2,11 +2,12 @@ #ifndef SLED_SYSTEM_THREAD_POOL_H #define SLED_SYSTEM_THREAD_POOL_H #include "sled/system/fiber/scheduler.h" +#include "sled/system/thread.h" #include #include namespace sled { -class ThreadPool final { +class ThreadPool final : public TaskQueueBase { public: /** * @param num_threads The number of threads to create in the thread pool. If @@ -18,16 +19,25 @@ public: template auto submit(F &&f, Args &&...args) -> std::future { - std::function func = - std::bind(std::forward(f), std::forward(args)...); - auto task_ptr = - std::make_shared>(func); - scheduler->enqueue(marl::Task([task_ptr]() { (*task_ptr)(); })); + std::function func = std::bind(std::forward(f), std::forward(args)...); + auto task_ptr = std::make_shared>(func); + scheduler_->enqueue(marl::Task([task_ptr]() { (*task_ptr)(); })); return task_ptr->get_future(); } + void Delete() override; + +protected: + void PostTaskImpl(std::function &&task, const PostTaskTraits &traits, const Location &location) override; + + void PostDelayedTaskImpl(std::function &&task, + TimeDelta delay, + const PostDelayedTaskTraits &traits, + const Location &location) override; + private: - sled::Scheduler *scheduler; + sled::Scheduler *scheduler_; + std::unique_ptr delayed_thread_; }; }// namespace sled diff --git a/include/sled/task_queue/task_queue_base.h b/include/sled/task_queue/task_queue_base.h index 4c17666..10463ce 100644 --- a/include/sled/task_queue/task_queue_base.h +++ b/include/sled/task_queue/task_queue_base.h @@ -22,42 +22,34 @@ public: }; struct Deleter { - void operator()(TaskQueueBase *task_queue) const - { - task_queue->Delete(); - } + void operator()(TaskQueueBase *task_queue) const { task_queue->Delete(); } }; virtual void Delete() = 0; - inline void PostTask(std::function &&task, - const Location &location = Location::Current()) + inline void PostTask(std::function &&task, const Location &location = Location::Current()) { PostTaskImpl(std::move(task), PostTaskTraits{}, location); } - inline void PostDelayedTask(std::function &&task, - TimeDelta delay, - const Location &location = Location::Current()) + inline void + PostDelayedTask(std::function &&task, TimeDelta delay, const Location &location = Location::Current()) { - PostDelayedTaskImpl(std::move(task), delay, PostDelayedTaskTraits{}, - location); + PostDelayedTaskImpl(std::move(task), delay, PostDelayedTaskTraits{}, location); } - inline void - PostDelayedHighPrecisionTask(std::function &&task, - TimeDelta delay, - const Location &location = Location::Current()) + inline void PostDelayedHighPrecisionTask(std::function &&task, + TimeDelta delay, + const Location &location = Location::Current()) { static PostDelayedTaskTraits traits(true); PostDelayedTaskImpl(std::move(task), delay, traits, location); } - inline void - PostDelayedTaskWithPrecision(DelayPrecision precision, - std::function &&task, - TimeDelta delay, - const Location &location = Location::Current()) + inline void PostDelayedTaskWithPrecision(DelayPrecision precision, + std::function &&task, + TimeDelta delay, + const Location &location = Location::Current()) { switch (precision) { case DelayPrecision::kLow: @@ -69,6 +61,21 @@ public: } } + void BlockingCall(std::function functor, const Location &location = Location::Current()) + { + BlockingCallImpl(std::move(functor), location); + } + + template::type, + typename = typename std::enable_if::value, ReturnT>::type> + ReturnT BlockingCall(Functor &&functor, const Location &location = Location::Current()) + { + ReturnT result; + BlockingCall([&] { result = std::forward(functor)(); }, location); + return result; + } + static TaskQueueBase *Current(); bool IsCurrent() const { return Current() == this; }; @@ -77,20 +84,17 @@ protected: struct PostTaskTraits {}; struct PostDelayedTaskTraits { - PostDelayedTaskTraits(bool high_precision = false) - : high_precision(high_precision) - {} + PostDelayedTaskTraits(bool high_precision = false) : high_precision(high_precision) {} bool high_precision = false; }; - virtual void PostTaskImpl(std::function &&task, - const PostTaskTraits &traits, - const Location &location) = 0; + virtual void PostTaskImpl(std::function &&task, const PostTaskTraits &traits, const Location &location) = 0; virtual void PostDelayedTaskImpl(std::function &&task, TimeDelta delay, const PostDelayedTaskTraits &traits, const Location &location) = 0; + virtual void BlockingCallImpl(std::function &&task, const Location &location); virtual ~TaskQueueBase() = default; class CurrentTaskQueueSetter { @@ -98,8 +102,7 @@ protected: explicit CurrentTaskQueueSetter(TaskQueueBase *task_queue); ~CurrentTaskQueueSetter(); CurrentTaskQueueSetter(const CurrentTaskQueueSetter &) = delete; - CurrentTaskQueueSetter & - operator=(const CurrentTaskQueueSetter &) = delete; + CurrentTaskQueueSetter &operator=(const CurrentTaskQueueSetter &) = delete; private: TaskQueueBase *const previous_; diff --git a/src/futures/future_test.cc b/src/futures/future_test.cc new file mode 100644 index 0000000..bf568cb --- /dev/null +++ b/src/futures/future_test.cc @@ -0,0 +1,8 @@ +#include +#include + +TEST(Future, basic) +{ + // sled::Future x; + // auto res = x.Then([](int) {}); +} diff --git a/src/system/thread.cc b/src/system/thread.cc index 1ea853b..1ead5cc 100644 --- a/src/system/thread.cc +++ b/src/system/thread.cc @@ -45,8 +45,7 @@ void ThreadManager::RemoveInternal(Thread *message_queue) { MutexLock lock(&cirt_); - auto iter = std::find(message_queues_.begin(), message_queues_.end(), - message_queue); + auto iter = std::find(message_queues_.begin(), message_queues_.end(), message_queue); if (iter != message_queues_.end()) { message_queues_.erase(iter); } } @@ -96,8 +95,7 @@ ThreadManager::ProcessAllMessageQueueInternal() MutexLock lock(&cirt_); for (Thread *queue : message_queues_) { queues_not_done.fetch_add(1); - auto sub = - MakeCleanup([&queues_not_done] { queues_not_done.fetch_sub(1); }); + auto sub = MakeCleanup([&queues_not_done] { queues_not_done.fetch_sub(1); }); queue->PostDelayedTask([&sub] {}, TimeDelta::Zero()); } @@ -115,9 +113,7 @@ ThreadManager::SetCurrentThreadInternal(Thread *message_queue) Thread::Thread(SocketServer *ss) : Thread(ss, /*do_init=*/true) {} -Thread::Thread(std::unique_ptr ss) - : Thread(std::move(ss), /*do_init=*/true) -{} +Thread::Thread(std::unique_ptr ss) : Thread(std::move(ss), /*do_init=*/true) {} Thread::Thread(SocketServer *ss, bool do_init) : delayed_next_num_(0), @@ -131,11 +127,7 @@ Thread::Thread(SocketServer *ss, bool do_init) if (do_init) { DoInit(); } } -Thread::Thread(std::unique_ptr ss, bool do_init) - : Thread(ss.get(), do_init) -{ - own_ss_ = std::move(ss); -} +Thread::Thread(std::unique_ptr ss, bool do_init) : Thread(ss.get(), do_init) { own_ss_ = std::move(ss); } Thread::~Thread() { @@ -244,13 +236,10 @@ Thread::Get(int cmsWait) cmsNext = cmsDelayNext; } else { cmsNext = std::max(0, cmsTotal - cmsElapsed); - if ((cmsDelayNext != kForever) && (cmsDelayNext < cmsNext)) { - cmsNext = cmsDelayNext; - } + if ((cmsDelayNext != kForever) && (cmsDelayNext < cmsNext)) { cmsNext = cmsDelayNext; } } { - if (!ss_->Wait(cmsNext == kForever ? SocketServer::kForever - : TimeDelta::Millis(cmsNext), + if (!ss_->Wait(cmsNext == kForever ? SocketServer::kForever : TimeDelta::Millis(cmsNext), /*process_io=*/true)) { return nullptr; } @@ -266,9 +255,7 @@ Thread::Get(int cmsWait) } void -Thread::PostTaskImpl(std::function &&task, - const PostTaskTraits &traits, - const Location &location) +Thread::PostTaskImpl(std::function &&task, const PostTaskTraits &traits, const Location &location) { if (IsQuitting()) { return; } { @@ -303,8 +290,7 @@ Thread::PostDelayedTaskImpl(std::function &&task, } void -Thread::BlockingCallImpl(std::function functor, - const Location &location) +Thread::BlockingCallImpl(std::function &&functor, const Location &location) { if (IsQuitting()) { return; } if (IsCurrent()) { @@ -373,8 +359,7 @@ Thread::SetName(const std::string &name, const void *obj) void Thread::EnsureIsCurrentTaskQueue() { - task_queue_registration_.reset( - new TaskQueueBase::CurrentTaskQueueSetter(this)); + task_queue_registration_.reset(new TaskQueueBase::CurrentTaskQueueSetter(this)); } void @@ -426,8 +411,7 @@ Thread::PreRun(void *pv) } bool -Thread::WrapCurrentWithThreadManager(ThreadManager *thread_manager, - bool need_synchronize_access) +Thread::WrapCurrentWithThreadManager(ThreadManager *thread_manager, bool need_synchronize_access) { // assert(!IsRunning()); owned_ = false; @@ -498,8 +482,7 @@ Thread::Current() return thread; } -AutoSocketServerThread::AutoSocketServerThread(SocketServer *ss) - : Thread(ss, /*do_init=*/false) +AutoSocketServerThread::AutoSocketServerThread(SocketServer *ss) : Thread(ss, /*do_init=*/false) { DoInit(); old_thread_ = ThreadManager::Instance()->CurrentThread(); diff --git a/src/system/thread_pool.cc b/src/system/thread_pool.cc index 43ec03f..51773f5 100644 --- a/src/system/thread_pool.cc +++ b/src/system/thread_pool.cc @@ -1,15 +1,39 @@ #include "sled/system/thread_pool.h" +#include "sled/system/location.h" +#include "sled/task_queue/task_queue_base.h" namespace sled { ThreadPool::ThreadPool(int num_threads) { - if (num_threads == -1) { - num_threads = std::thread::hardware_concurrency(); - } - scheduler = new sled::Scheduler( - sled::Scheduler::Config().setWorkerThreadCount(num_threads)); + if (num_threads == -1) { num_threads = std::thread::hardware_concurrency(); } + scheduler_ = new sled::Scheduler(sled::Scheduler::Config().setWorkerThreadCount(num_threads)); } -ThreadPool::~ThreadPool() { delete scheduler; } +ThreadPool::~ThreadPool() { delete scheduler_; } + +void +ThreadPool::Delete() +{} + +void +ThreadPool::PostTaskImpl(std::function &&task, const PostTaskTraits &traits, const Location &location) +{ + scheduler_->enqueue(marl::Task([task] { task(); })); +} + +void +ThreadPool::PostDelayedTaskImpl(std::function &&task, + TimeDelta delay, + const PostDelayedTaskTraits &traits, + const Location &location) +{ + if (traits.high_precision) { + delayed_thread_->PostDelayedTaskWithPrecision(TaskQueueBase::DelayPrecision::kHigh, std::move(task), delay, + location); + } else { + delayed_thread_->PostDelayedTaskWithPrecision(TaskQueueBase::DelayPrecision::kLow, std::move(task), delay, + location); + } +} }// namespace sled diff --git a/src/task_queue/task_queue_base.cc b/src/task_queue/task_queue_base.cc index 877979f..47b2a78 100644 --- a/src/task_queue/task_queue_base.cc +++ b/src/task_queue/task_queue_base.cc @@ -1,4 +1,5 @@ #include "sled/task_queue/task_queue_base.h" +#include "sled/synchronization/event.h" namespace sled { namespace { @@ -11,12 +12,22 @@ TaskQueueBase::Current() return current; } -TaskQueueBase::CurrentTaskQueueSetter::CurrentTaskQueueSetter(TaskQueueBase *task_queue) - : previous_(current) +TaskQueueBase::CurrentTaskQueueSetter::CurrentTaskQueueSetter(TaskQueueBase *task_queue) : previous_(current) { current = task_queue; } TaskQueueBase::CurrentTaskQueueSetter::~CurrentTaskQueueSetter() { current = previous_; } +void +TaskQueueBase::BlockingCallImpl(std::function &&functor, const sled::Location &from) +{ + Event done; + PostTask([functor, &done] { + functor(); + done.Set(); + }); + done.Wait(Event::kForever); +} + }// namespace sled