diff --git a/CMakeLists.txt b/CMakeLists.txt
index af601c1..58e3bdc 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -184,7 +184,6 @@ if(SLED_BUILD_TESTS)
     sled_all_tests
     SRCS
     src/sled/debugging/demangle_test.cc
-    src/sled/async/async_test.cc
     src/sled/filesystem/path_test.cc
     src/sled/log/fmt_test.cc
     src/sled/synchronization/sequence_checker_test.cc
diff --git a/src/sled/async/async.cc b/src/sled/async/async.cc
index 44bed72..f848d26 100644
--- a/src/sled/async/async.cc
+++ b/src/sled/async/async.cc
@@ -20,18 +20,19 @@ FiberScheduler::FiberScheduler() {
 }
 
+static ThreadPool thread_pool;
+
 void
 FiberScheduler::schedule(async::task_run_handle t)
 {
-    static ThreadPool thread_pool;
     auto move_on_copy = sled::MakeMoveOnCopy(t);
-    thread_pool.submit([move_on_copy] { move_on_copy.value.run_with_wait_handler(SleepWaitHandler); });
-    // thread_pool.submit([move_on_copy] { move_on_copy.value.run(); });
+    thread_pool.PostTask([move_on_copy] { move_on_copy.value.run(); });
 }
 
 }// namespace sled
 
 // clang-format on
+
 namespace async {
 sled::FiberScheduler &
 default_scheduler()
diff --git a/src/sled/system/thread_pool.cc b/src/sled/system/thread_pool.cc
index cfa35f2..3c54427 100644
--- a/src/sled/system/thread_pool.cc
+++ b/src/sled/system/thread_pool.cc
@@ -1,25 +1,87 @@
 #include "sled/system/thread_pool.h"
+#include "sled/log/log.h"
 #include "sled/system/location.h"
 #include "sled/task_queue/task_queue_base.h"
 
 namespace sled {
+
+constexpr char ThreadPool::kTag[];
+
 ThreadPool::ThreadPool(int num_threads) : delayed_thread_(sled::Thread::Create())
 {
-    if (num_threads == -1) { num_threads = std::thread::hardware_concurrency(); }
-    scheduler_ = new sled::Scheduler(sled::Scheduler::Config().setWorkerThreadCount(num_threads));
+    if (num_threads <= 0) { num_threads = std::thread::hardware_concurrency(); }
+    auto state = std::make_shared<State>();
+    for (int i = 0; i < num_threads; i++) {
+        threads_.emplace_back(std::thread([state] {
+            state->idle++;
+            while (state->is_running) {
+
std::function<void()> task;
+                sled::Location loc = SLED_FROM_HERE;
+                // fetch task
+                {
+                    sled::MutexLock lock(&state->mutex);
+                    state->cv.Wait(lock, [state] { return !state->task_queue.empty() || !state->is_running; });
+                    if (!state->task_queue.empty()) {
+                        task = std::move(state->task_queue.front().first);
+                        loc = state->task_queue.front().second;
+                        state->task_queue.pop();
+                    }
+                    if (!state->task_queue.empty()) { state->cv.NotifyOne(); }
+                }
+                if (task) {
+                    state->idle--;
+                    try {
+                        task();
+                    } catch (const std::exception &e) {
+                        LOGE(kTag, "ThreadPool::ThreadPool() task exception: {}, from={}", e.what(), loc.ToString());
+                    } catch (...) {
+                        LOGE(kTag, "ThreadPool::ThreadPool() task unknown exception, from={}", loc.ToString());
+                    }
+                    state->idle++;
+                }
+            }
+        }));
+    }
+    state_ = state;
     delayed_thread_->Start();
 }
 
-ThreadPool::~ThreadPool() { delete scheduler_; }
+ThreadPool::~ThreadPool() { Delete(); }
 
 void
 ThreadPool::Delete()
-{}
+{
+    if (state_) {
+        sled::MutexLock lock(&state_->mutex);
+        state_->is_running = false;
+        state_->cv.NotifyAll();
+        state_ = nullptr;
+    }
+    delayed_thread_.reset();
+    for (auto &thread : threads_) { thread.join(); }
+}
+
+int
+ThreadPool::idle() const
+{
+    auto state = state_;
+    SLED_ASSERT(state != nullptr, "ThreadPool::idle() state_ is nullptr");
+    if (state->is_running) { return state->idle; }
+    return 0;
+}
 
 void
 ThreadPool::PostTaskImpl(std::function<void()> &&task, const PostTaskTraits &traits, const Location &location)
 {
-    scheduler_->enqueue(marl::Task([task] { task(); }));
+    auto state = state_;
+    SLED_ASSERT(state != nullptr, "ThreadPool::PostTaskImpl() state_ is nullptr");
+    if (!state->is_running) {
+        LOGW(kTag, "ThreadPool::PostTaskImpl() state is not running");
+        return;
+    }
+    sled::MutexLock lock(&state->mutex);
+    state->task_queue.emplace(std::move(task), location);
+    state->cv.NotifyOne();
 }
 
 void
@@ -28,19 +90,23 @@ ThreadPool::PostDelayedTaskImpl(std::function<void()> &&task,
                                 const PostDelayedTaskTraits &traits,
                                 const Location &location)
 {
-
auto move_task_to_fiber = [task]() { task(); };
+    auto weak_state = std::weak_ptr<State>(state_);
+    auto delay_post = [task, location, weak_state]() {
+        auto state = weak_state.lock();
+        if (!state || !state->is_running) { return; }
+        sled::MutexLock lock(&state->mutex);
+        state->task_queue.emplace(std::move(task), location);
+        state->cv.NotifyOne();
+    };
     if (traits.high_precision) {
         delayed_thread_->PostDelayedTaskWithPrecision(
             TaskQueueBase::DelayPrecision::kHigh,
-            std::move(move_task_to_fiber),
+            std::move(delay_post),
             delay,
             location);
     } else {
-        delayed_thread_->PostDelayedTaskWithPrecision(
-            TaskQueueBase::DelayPrecision::kLow,
-            std::move(move_task_to_fiber),
-            delay,
-            location);
+        delayed_thread_
+            ->PostDelayedTaskWithPrecision(TaskQueueBase::DelayPrecision::kLow, std::move(delay_post), delay, location);
     }
 }
diff --git a/src/sled/system/thread_pool.h b/src/sled/system/thread_pool.h
index c2ccdda..758a316 100644
--- a/src/sled/system/thread_pool.h
+++ b/src/sled/system/thread_pool.h
@@ -1,6 +1,7 @@
 #pragma once
 #ifndef SLED_SYSTEM_THREAD_POOL_H
 #define SLED_SYSTEM_THREAD_POOL_H
+#include "sled/synchronization/mutex.h"
 #include "sled/system/fiber/scheduler.h"
 #include "sled/system/thread.h"
 #include <functional>
@@ -9,6 +10,7 @@ namespace sled {
 
 class ThreadPool final : public TaskQueueBase {
 public:
+    static constexpr char kTag[] = "ThreadPool";
     /**
      * @param num_threads The number of threads to create in the thread pool.
If * -1, the number of threads will be equal to the number of hardware threads
@@ -16,18 +18,20 @@ public:
     ThreadPool(int num_threads = -1);
     ~ThreadPool();
 
-    template <typename F, typename... Args>
-    auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))>
-    {
-        std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
-        auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);
-        auto future = task_ptr->get_future();
-        scheduler_->enqueue(marl::Task([task_ptr]() { (*task_ptr)(); }));
-        return future;
-    }
+    // template <typename F, typename... Args>
+    // auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))>
+    // {
+    //     std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
+    //     auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);
+    //     auto future = task_ptr->get_future();
+    //     // scheduler_->enqueue(marl::Task([task_ptr]() { (*task_ptr)(); }));
+    //     return future;
+    // }
 
     void Delete() override;
 
+    int idle() const;
+
 protected:
     void PostTaskImpl(std::function<void()> &&task, const PostTaskTraits &traits, const Location &location) override;
@@ -37,8 +41,18 @@ protected:
                                 const Location &location) override;
 
 private:
-    sled::Scheduler *scheduler_;
-    std::unique_ptr<Thread> delayed_thread_;
+    struct State {
+        std::atomic<bool> is_running{true};
+        std::atomic<int> idle{0};
+        sled::Mutex mutex;
+        sled::ConditionVariable cv;
+        std::queue<std::pair<std::function<void()>, sled::Location>> task_queue;
+    };
+
+    std::shared_ptr<State> state_;
+    std::vector<std::thread> threads_;
+    // sled::Scheduler *scheduler_;
+    std::unique_ptr<Thread> delayed_thread_ = nullptr;
 };
 
 }// namespace sled
diff --git a/src/sled/system/thread_pool_bench.cc b/src/sled/system/thread_pool_bench.cc
index 1f769af..028971a 100644
--- a/src/sled/system/thread_pool_bench.cc
+++ b/src/sled/system/thread_pool_bench.cc
@@ -7,8 +7,8 @@ ThreadPoolBench(picobench::state &state)
 {
     sled::ThreadPool pool(-1);
     for (auto _ : state) {
-        std::future<int> f = pool.submit([]() { return 1; });
-        (void) f.get();
+        auto res = pool.BlockingCall([]() { return 1; });
+        (void) res;
     }
 }
diff --git a/src/sled/system/thread_pool_test.cc b/src/sled/system/thread_pool_test.cc
index
8a0ce0f..1a51507 100644
--- a/src/sled/system/thread_pool_test.cc
+++ b/src/sled/system/thread_pool_test.cc
@@ -41,29 +41,6 @@ multiply_return(const int a, const int b)
 
 TEST_SUITE("ThreadPool")
 {
-    TEST_CASE("submit")
-    {
-        sled::ThreadPool *tp = new sled::ThreadPool();
-        REQUIRE_NE(tp, nullptr);
-
-        SUBCASE("Output")
-        {
-            for (int i = 0; i < 100; ++i) {
-                int out;
-                tp->submit(multiply_output, std::ref(out), i, i).get();
-                CHECK_EQ(out, i * i);
-            }
-        }
-        SUBCASE("Return")
-        {
-            for (int i = 0; i < 100; ++i) {
-                auto f = tp->submit(multiply_return, i, i);
-                CHECK_EQ(f.get(), i * i);
-            }
-        }
-
-        delete tp;
-    }
     TEST_CASE("PostTask")
     {
         sled::ThreadPool *tp = new sled::ThreadPool();
@@ -86,7 +63,7 @@ TEST_SUITE("ThreadPool")
     TEST_CASE("10^6 task test")
     {
         sled::ThreadPool *tp = new sled::ThreadPool();
-        const int task_num = 1E6;
+        const int task_num = 1E4;
         sled::WaitGroup wg(task_num);
         for (int i = 0; i < task_num; i++) { tp->PostTask([wg] { wg.Done(); });