Loading CMakeLists.txt +0 −1 Original line number Diff line number Diff line Loading @@ -184,7 +184,6 @@ if(SLED_BUILD_TESTS) sled_all_tests SRCS src/sled/debugging/demangle_test.cc src/sled/async/async_test.cc src/sled/filesystem/path_test.cc src/sled/log/fmt_test.cc src/sled/synchronization/sequence_checker_test.cc Loading src/sled/async/async.cc +4 −3 Original line number Diff line number Diff line Loading @@ -20,18 +20,19 @@ FiberScheduler::FiberScheduler() { } static ThreadPool thread_pool; void FiberScheduler::schedule(async::task_run_handle t) { static ThreadPool thread_pool; auto move_on_copy = sled::MakeMoveOnCopy(t); thread_pool.submit([move_on_copy] { move_on_copy.value.run_with_wait_handler(SleepWaitHandler); }); // thread_pool.submit([move_on_copy] { move_on_copy.value.run(); }); thread_pool.PostTask([move_on_copy] { move_on_copy.value.run(); }); } }// namespace sled // clang-format on namespace async { sled::FiberScheduler & default_scheduler() Loading src/sled/system/thread_pool.cc +78 −12 Original line number Diff line number Diff line #include "sled/system/thread_pool.h" #include "sled/log/log.h" #include "sled/system/location.h" #include "sled/task_queue/task_queue_base.h" namespace sled { constexpr char ThreadPool::kTag[]; ThreadPool::ThreadPool(int num_threads) : delayed_thread_(sled::Thread::Create()) { if (num_threads == -1) { num_threads = std::thread::hardware_concurrency(); } scheduler_ = new sled::Scheduler(sled::Scheduler::Config().setWorkerThreadCount(num_threads)); if (num_threads <= 0) { num_threads = std::thread::hardware_concurrency(); } auto state = std::make_shared<State>(); for (int i = 0; i < num_threads; i++) { threads_.emplace_back(std::thread([state] { state->idle++; while (state->is_running) { std::function<void()> task; sled::Location loc = SLED_FROM_HERE; // fetch task { sled::MutexLock lock(&state->mutex); state->cv.Wait(lock, [state] { return !state->task_queue.empty() || !state->is_running; }); if (!state->task_queue.empty()) { task = std::move(state->task_queue.front().first); loc = state->task_queue.front().second; state->task_queue.pop(); } if (!state->task_queue.empty()) { state->cv.NotifyOne(); } } if (task) { state->idle--; try { task(); } catch (const std::exception &e) { LOGE(kTag, "ThreadPool::ThreadPool() task exception: {}, from={}", e.what(), loc.ToString()); } catch (...) { LOGE(kTag, "ThreadPool::ThreadPool() task unknown exception, from={}", loc.ToString()); } state->idle++; } } })); } state_ = state; delayed_thread_->Start(); } ThreadPool::~ThreadPool() { delete scheduler_; } ThreadPool::~ThreadPool() { Delete(); } void ThreadPool::Delete() {} { if (state_) { sled::MutexLock lock(&state_->mutex); state_->is_running = false; state_->cv.NotifyAll(); state_ = nullptr; } delayed_thread_.reset(); for (auto &thread : threads_) { thread.join(); } } int ThreadPool::idle() const { auto state = state_; SLED_ASSERT(state != nullptr, "ThreadPool::idle() state_ is nullptr"); if (state->is_running) { return state->idle; } return 0; } void ThreadPool::PostTaskImpl(std::function<void()> &&task, const PostTaskTraits &traits, const Location &location) { scheduler_->enqueue(marl::Task([task] { task(); })); auto state = state_; SLED_ASSERT(state != nullptr, "ThreadPool::PostTaskImpl() state_ is nullptr"); if (!state->is_running) { LOGW(kTag, "ThreadPool::PostTaskImpl() state is not running"); return; } sled::MutexLock lock(&state->mutex); state->task_queue.emplace(std::move(task), location); state->cv.NotifyOne(); } void Loading @@ -28,19 +90,23 @@ ThreadPool::PostDelayedTaskImpl(std::function<void()> &&task, const PostDelayedTaskTraits &traits, const Location &location) { auto move_task_to_fiber = [task]() { task(); }; auto weak_state = std::weak_ptr<State>(state_); auto delay_post = [task, location, weak_state]() { auto state = weak_state.lock(); if (!state || !state->is_running) { return; } sled::MutexLock lock(&state->mutex); state->task_queue.emplace(std::move(task), location); state->cv.NotifyOne(); }; if (traits.high_precision) { delayed_thread_->PostDelayedTaskWithPrecision( TaskQueueBase::DelayPrecision::kHigh, std::move(move_task_to_fiber), std::move(delay_post), delay, location); } else { delayed_thread_->PostDelayedTaskWithPrecision( TaskQueueBase::DelayPrecision::kLow, std::move(move_task_to_fiber), delay, location); delayed_thread_ ->PostDelayedTaskWithPrecision(TaskQueueBase::DelayPrecision::kLow, std::move(delay_post), delay, location); } } Loading src/sled/system/thread_pool.h +25 −11 Original line number Diff line number Diff line #pragma once #ifndef SLED_SYSTEM_THREAD_POOL_H #define SLED_SYSTEM_THREAD_POOL_H #include "sled/synchronization/mutex.h" #include "sled/system/fiber/scheduler.h" #include "sled/system/thread.h" #include <functional> Loading @@ -9,6 +10,7 @@ namespace sled { class ThreadPool final : public TaskQueueBase { public: static constexpr char kTag[] = "ThreadPool"; /** * @param num_threads The number of threads to create in the thread pool. If * -1, the number of threads will be equal to the number of hardware threads Loading @@ -16,18 +18,20 @@ public: ThreadPool(int num_threads = -1); ~ThreadPool(); template<typename F, typename... Args> auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))> { std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...); auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func); auto future = task_ptr->get_future(); scheduler_->enqueue(marl::Task([task_ptr]() { (*task_ptr)(); })); return future; } // template<typename F, typename... Args> // auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))> // { // std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...); // auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func); // auto future = task_ptr->get_future(); // // scheduler_->enqueue(marl::Task([task_ptr]() { (*task_ptr)(); })); // return future; // } void Delete() override; int idle() const; protected: void PostTaskImpl(std::function<void()> &&task, const PostTaskTraits &traits, const Location &location) override; Loading @@ -37,8 +41,18 @@ protected: const Location &location) override; private: sled::Scheduler *scheduler_; std::unique_ptr<sled::Thread> delayed_thread_; struct State { std::atomic<bool> is_running{true}; std::atomic<int> idle{0}; sled::Mutex mutex; sled::ConditionVariable cv; std::queue<std::pair<std::function<void()>, sled::Location>> task_queue; }; std::shared_ptr<State> state_; std::vector<std::thread> threads_; // sled::Scheduler *scheduler_; std::unique_ptr<sled::Thread> delayed_thread_ = nullptr; }; }// namespace sled Loading src/sled/system/thread_pool_bench.cc +2 −2 Original line number Diff line number Diff line Loading @@ -7,8 +7,8 @@ ThreadPoolBench(picobench::state &state) { sled::ThreadPool pool(-1); for (auto _ : state) { std::future<int> f = pool.submit([]() { return 1; }); (void) f.get(); auto res = pool.BlockingCall([]() { return 1; }); (void) res; } } Loading Loading
CMakeLists.txt +0 −1 Original line number Diff line number Diff line Loading @@ -184,7 +184,6 @@ if(SLED_BUILD_TESTS) sled_all_tests SRCS src/sled/debugging/demangle_test.cc src/sled/async/async_test.cc src/sled/filesystem/path_test.cc src/sled/log/fmt_test.cc src/sled/synchronization/sequence_checker_test.cc Loading
src/sled/async/async.cc +4 −3 Original line number Diff line number Diff line Loading @@ -20,18 +20,19 @@ FiberScheduler::FiberScheduler() { } static ThreadPool thread_pool; void FiberScheduler::schedule(async::task_run_handle t) { static ThreadPool thread_pool; auto move_on_copy = sled::MakeMoveOnCopy(t); thread_pool.submit([move_on_copy] { move_on_copy.value.run_with_wait_handler(SleepWaitHandler); }); // thread_pool.submit([move_on_copy] { move_on_copy.value.run(); }); thread_pool.PostTask([move_on_copy] { move_on_copy.value.run(); }); } }// namespace sled // clang-format on namespace async { sled::FiberScheduler & default_scheduler() Loading
src/sled/system/thread_pool.cc +78 −12 Original line number Diff line number Diff line #include "sled/system/thread_pool.h" #include "sled/log/log.h" #include "sled/system/location.h" #include "sled/task_queue/task_queue_base.h" namespace sled { constexpr char ThreadPool::kTag[]; ThreadPool::ThreadPool(int num_threads) : delayed_thread_(sled::Thread::Create()) { if (num_threads == -1) { num_threads = std::thread::hardware_concurrency(); } scheduler_ = new sled::Scheduler(sled::Scheduler::Config().setWorkerThreadCount(num_threads)); if (num_threads <= 0) { num_threads = std::thread::hardware_concurrency(); } auto state = std::make_shared<State>(); for (int i = 0; i < num_threads; i++) { threads_.emplace_back(std::thread([state] { state->idle++; while (state->is_running) { std::function<void()> task; sled::Location loc = SLED_FROM_HERE; // fetch task { sled::MutexLock lock(&state->mutex); state->cv.Wait(lock, [state] { return !state->task_queue.empty() || !state->is_running; }); if (!state->task_queue.empty()) { task = std::move(state->task_queue.front().first); loc = state->task_queue.front().second; state->task_queue.pop(); } if (!state->task_queue.empty()) { state->cv.NotifyOne(); } } if (task) { state->idle--; try { task(); } catch (const std::exception &e) { LOGE(kTag, "ThreadPool::ThreadPool() task exception: {}, from={}", e.what(), loc.ToString()); } catch (...) { LOGE(kTag, "ThreadPool::ThreadPool() task unknown exception, from={}", loc.ToString()); } state->idle++; } } })); } state_ = state; delayed_thread_->Start(); } ThreadPool::~ThreadPool() { delete scheduler_; } ThreadPool::~ThreadPool() { Delete(); } void ThreadPool::Delete() {} { if (state_) { sled::MutexLock lock(&state_->mutex); state_->is_running = false; state_->cv.NotifyAll(); state_ = nullptr; } delayed_thread_.reset(); for (auto &thread : threads_) { thread.join(); } } int ThreadPool::idle() const { auto state = state_; SLED_ASSERT(state != nullptr, "ThreadPool::idle() state_ is nullptr"); if (state->is_running) { return state->idle; } return 0; } void ThreadPool::PostTaskImpl(std::function<void()> &&task, const PostTaskTraits &traits, const Location &location) { scheduler_->enqueue(marl::Task([task] { task(); })); auto state = state_; SLED_ASSERT(state != nullptr, "ThreadPool::PostTaskImpl() state_ is nullptr"); if (!state->is_running) { LOGW(kTag, "ThreadPool::PostTaskImpl() state is not running"); return; } sled::MutexLock lock(&state->mutex); state->task_queue.emplace(std::move(task), location); state->cv.NotifyOne(); } void Loading @@ -28,19 +90,23 @@ ThreadPool::PostDelayedTaskImpl(std::function<void()> &&task, const PostDelayedTaskTraits &traits, const Location &location) { auto move_task_to_fiber = [task]() { task(); }; auto weak_state = std::weak_ptr<State>(state_); auto delay_post = [task, location, weak_state]() { auto state = weak_state.lock(); if (!state || !state->is_running) { return; } sled::MutexLock lock(&state->mutex); state->task_queue.emplace(std::move(task), location); state->cv.NotifyOne(); }; if (traits.high_precision) { delayed_thread_->PostDelayedTaskWithPrecision( TaskQueueBase::DelayPrecision::kHigh, std::move(move_task_to_fiber), std::move(delay_post), delay, location); } else { delayed_thread_->PostDelayedTaskWithPrecision( TaskQueueBase::DelayPrecision::kLow, std::move(move_task_to_fiber), delay, location); delayed_thread_ ->PostDelayedTaskWithPrecision(TaskQueueBase::DelayPrecision::kLow, std::move(delay_post), delay, location); } } Loading
src/sled/system/thread_pool.h +25 −11 Original line number Diff line number Diff line #pragma once #ifndef SLED_SYSTEM_THREAD_POOL_H #define SLED_SYSTEM_THREAD_POOL_H #include "sled/synchronization/mutex.h" #include "sled/system/fiber/scheduler.h" #include "sled/system/thread.h" #include <functional> Loading @@ -9,6 +10,7 @@ namespace sled { class ThreadPool final : public TaskQueueBase { public: static constexpr char kTag[] = "ThreadPool"; /** * @param num_threads The number of threads to create in the thread pool. If * -1, the number of threads will be equal to the number of hardware threads Loading @@ -16,18 +18,20 @@ public: ThreadPool(int num_threads = -1); ~ThreadPool(); template<typename F, typename... Args> auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))> { std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...); auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func); auto future = task_ptr->get_future(); scheduler_->enqueue(marl::Task([task_ptr]() { (*task_ptr)(); })); return future; } // template<typename F, typename... Args> // auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))> // { // std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...); // auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func); // auto future = task_ptr->get_future(); // // scheduler_->enqueue(marl::Task([task_ptr]() { (*task_ptr)(); })); // return future; // } void Delete() override; int idle() const; protected: void PostTaskImpl(std::function<void()> &&task, const PostTaskTraits &traits, const Location &location) override; Loading @@ -37,8 +41,18 @@ protected: const Location &location) override; private: sled::Scheduler *scheduler_; std::unique_ptr<sled::Thread> delayed_thread_; struct State { std::atomic<bool> is_running{true}; std::atomic<int> idle{0}; sled::Mutex mutex; sled::ConditionVariable cv; std::queue<std::pair<std::function<void()>, sled::Location>> task_queue; }; std::shared_ptr<State> state_; std::vector<std::thread> threads_; // sled::Scheduler *scheduler_; std::unique_ptr<sled::Thread> delayed_thread_ = nullptr; }; }// namespace sled Loading
src/sled/system/thread_pool_bench.cc +2 −2 Original line number Diff line number Diff line Loading @@ -7,8 +7,8 @@ ThreadPoolBench(picobench::state &state) { sled::ThreadPool pool(-1); for (auto _ : state) { std::future<int> f = pool.submit([]() { return 1; }); (void) f.get(); auto res = pool.BlockingCall([]() { return 1; }); (void) res; } } Loading