feat thread use std::thread
Some checks failed
linux-arm-gcc / linux-gcc-armhf (push) Successful in 4m37s
linux-mips64-gcc / linux-gcc-mips64el (Release) (push) Has been cancelled
linux-aarch64-cpu-gcc / linux-gcc-aarch64 (push) Has been cancelled
linux-x64-gcc / linux-gcc (Release) (push) Has been cancelled
linux-x64-gcc / linux-gcc (Debug) (push) Has been cancelled
linux-mips64-gcc / linux-gcc-mips64el (Debug) (push) Has been cancelled

This commit is contained in:
tqcq 2024-04-30 22:56:20 +08:00
parent 41e460ffd9
commit bb82578636
6 changed files with 110 additions and 53 deletions

View File

@ -184,7 +184,6 @@ if(SLED_BUILD_TESTS)
sled_all_tests sled_all_tests
SRCS SRCS
src/sled/debugging/demangle_test.cc src/sled/debugging/demangle_test.cc
src/sled/async/async_test.cc
src/sled/filesystem/path_test.cc src/sled/filesystem/path_test.cc
src/sled/log/fmt_test.cc src/sled/log/fmt_test.cc
src/sled/synchronization/sequence_checker_test.cc src/sled/synchronization/sequence_checker_test.cc

View File

@ -20,18 +20,19 @@ FiberScheduler::FiberScheduler()
{ {
} }
static ThreadPool thread_pool;
void void
FiberScheduler::schedule(async::task_run_handle t) FiberScheduler::schedule(async::task_run_handle t)
{ {
static ThreadPool thread_pool;
auto move_on_copy = sled::MakeMoveOnCopy(t); auto move_on_copy = sled::MakeMoveOnCopy(t);
thread_pool.submit([move_on_copy] { move_on_copy.value.run_with_wait_handler(SleepWaitHandler); }); thread_pool.PostTask([move_on_copy] { move_on_copy.value.run(); });
// thread_pool.submit([move_on_copy] { move_on_copy.value.run(); });
} }
}// namespace sled }// namespace sled
// clang-format on // clang-format on
namespace async { namespace async {
sled::FiberScheduler & sled::FiberScheduler &
default_scheduler() default_scheduler()

View File

@ -1,25 +1,87 @@
#include "sled/system/thread_pool.h" #include "sled/system/thread_pool.h"
#include "sled/log/log.h"
#include "sled/system/location.h" #include "sled/system/location.h"
#include "sled/task_queue/task_queue_base.h" #include "sled/task_queue/task_queue_base.h"
namespace sled { namespace sled {
constexpr char ThreadPool::kTag[];
ThreadPool::ThreadPool(int num_threads) : delayed_thread_(sled::Thread::Create()) ThreadPool::ThreadPool(int num_threads) : delayed_thread_(sled::Thread::Create())
{ {
if (num_threads == -1) { num_threads = std::thread::hardware_concurrency(); } if (num_threads <= 0) { num_threads = std::thread::hardware_concurrency(); }
scheduler_ = new sled::Scheduler(sled::Scheduler::Config().setWorkerThreadCount(num_threads)); auto state = std::make_shared<State>();
for (int i = 0; i < num_threads; i++) {
threads_.emplace_back(std::thread([state] {
state->idle++;
while (state->is_running) {
std::function<void()> task;
sled::Location loc = SLED_FROM_HERE;
// fetch task
{
sled::MutexLock lock(&state->mutex);
state->cv.Wait(lock, [state] { return !state->task_queue.empty() || !state->is_running; });
if (!state->task_queue.empty()) {
task = std::move(state->task_queue.front().first);
loc = state->task_queue.front().second;
state->task_queue.pop();
}
if (!state->task_queue.empty()) { state->cv.NotifyOne(); }
}
if (task) {
state->idle--;
try {
task();
} catch (const std::exception &e) {
LOGE(kTag, "ThreadPool::ThreadPool() task exception: {}, from={}", e.what(), loc.ToString());
} catch (...) {
LOGE(kTag, "ThreadPool::ThreadPool() task unknown exception, from={}", loc.ToString());
}
state->idle++;
}
}
}));
}
state_ = state;
delayed_thread_->Start(); delayed_thread_->Start();
} }
ThreadPool::~ThreadPool() { delete scheduler_; } ThreadPool::~ThreadPool() { Delete(); }
void void
ThreadPool::Delete() ThreadPool::Delete()
{} {
if (state_) {
sled::MutexLock lock(&state_->mutex);
state_->is_running = false;
state_->cv.NotifyAll();
state_ = nullptr;
}
delayed_thread_.reset();
for (auto &thread : threads_) { thread.join(); }
}
int
ThreadPool::idle() const
{
auto state = state_;
SLED_ASSERT(state != nullptr, "ThreadPool::idle() state_ is nullptr");
if (state->is_running) { return state->idle; }
return 0;
}
void void
ThreadPool::PostTaskImpl(std::function<void()> &&task, const PostTaskTraits &traits, const Location &location) ThreadPool::PostTaskImpl(std::function<void()> &&task, const PostTaskTraits &traits, const Location &location)
{ {
scheduler_->enqueue(marl::Task([task] { task(); })); auto state = state_;
SLED_ASSERT(state != nullptr, "ThreadPool::PostTaskImpl() state_ is nullptr");
if (!state->is_running) {
LOGW(kTag, "ThreadPool::PostTaskImpl() state is not running");
return;
}
sled::MutexLock lock(&state->mutex);
state->task_queue.emplace(std::move(task), location);
state->cv.NotifyOne();
} }
void void
@ -28,19 +90,23 @@ ThreadPool::PostDelayedTaskImpl(std::function<void()> &&task,
const PostDelayedTaskTraits &traits, const PostDelayedTaskTraits &traits,
const Location &location) const Location &location)
{ {
auto move_task_to_fiber = [task]() { task(); }; auto weak_state = std::weak_ptr<State>(state_);
auto delay_post = [task, location, weak_state]() {
auto state = weak_state.lock();
if (!state || !state->is_running) { return; }
sled::MutexLock lock(&state->mutex);
state->task_queue.emplace(std::move(task), location);
state->cv.NotifyOne();
};
if (traits.high_precision) { if (traits.high_precision) {
delayed_thread_->PostDelayedTaskWithPrecision( delayed_thread_->PostDelayedTaskWithPrecision(
TaskQueueBase::DelayPrecision::kHigh, TaskQueueBase::DelayPrecision::kHigh,
std::move(move_task_to_fiber), std::move(delay_post),
delay, delay,
location); location);
} else { } else {
delayed_thread_->PostDelayedTaskWithPrecision( delayed_thread_
TaskQueueBase::DelayPrecision::kLow, ->PostDelayedTaskWithPrecision(TaskQueueBase::DelayPrecision::kLow, std::move(delay_post), delay, location);
std::move(move_task_to_fiber),
delay,
location);
} }
} }

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#ifndef SLED_SYSTEM_THREAD_POOL_H #ifndef SLED_SYSTEM_THREAD_POOL_H
#define SLED_SYSTEM_THREAD_POOL_H #define SLED_SYSTEM_THREAD_POOL_H
#include "sled/synchronization/mutex.h"
#include "sled/system/fiber/scheduler.h" #include "sled/system/fiber/scheduler.h"
#include "sled/system/thread.h" #include "sled/system/thread.h"
#include <functional> #include <functional>
@ -9,6 +10,7 @@
namespace sled { namespace sled {
class ThreadPool final : public TaskQueueBase { class ThreadPool final : public TaskQueueBase {
public: public:
static constexpr char kTag[] = "ThreadPool";
/** /**
* @param num_threads The number of threads to create in the thread pool. If * @param num_threads The number of threads to create in the thread pool. If
* -1, the number of threads will be equal to the number of hardware threads * -1, the number of threads will be equal to the number of hardware threads
@ -16,18 +18,20 @@ public:
ThreadPool(int num_threads = -1); ThreadPool(int num_threads = -1);
~ThreadPool(); ~ThreadPool();
template<typename F, typename... Args> // template<typename F, typename... Args>
auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))> // auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))>
{ // {
std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...); // std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func); // auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);
auto future = task_ptr->get_future(); // auto future = task_ptr->get_future();
scheduler_->enqueue(marl::Task([task_ptr]() { (*task_ptr)(); })); // // scheduler_->enqueue(marl::Task([task_ptr]() { (*task_ptr)(); }));
return future; // return future;
} // }
void Delete() override; void Delete() override;
int idle() const;
protected: protected:
void PostTaskImpl(std::function<void()> &&task, const PostTaskTraits &traits, const Location &location) override; void PostTaskImpl(std::function<void()> &&task, const PostTaskTraits &traits, const Location &location) override;
@ -37,8 +41,18 @@ protected:
const Location &location) override; const Location &location) override;
private: private:
sled::Scheduler *scheduler_; struct State {
std::unique_ptr<sled::Thread> delayed_thread_; std::atomic<bool> is_running{true};
std::atomic<int> idle{0};
sled::Mutex mutex;
sled::ConditionVariable cv;
std::queue<std::pair<std::function<void()>, sled::Location>> task_queue;
};
std::shared_ptr<State> state_;
std::vector<std::thread> threads_;
// sled::Scheduler *scheduler_;
std::unique_ptr<sled::Thread> delayed_thread_ = nullptr;
}; };
}// namespace sled }// namespace sled

View File

@ -7,8 +7,8 @@ ThreadPoolBench(picobench::state &state)
{ {
sled::ThreadPool pool(-1); sled::ThreadPool pool(-1);
for (auto _ : state) { for (auto _ : state) {
std::future<int> f = pool.submit([]() { return 1; }); auto res = pool.BlockingCall([]() { return 1; });
(void) f.get(); (void) res;
} }
} }

View File

@ -41,29 +41,6 @@ multiply_return(const int a, const int b)
TEST_SUITE("ThreadPool") TEST_SUITE("ThreadPool")
{ {
TEST_CASE("submit")
{
sled::ThreadPool *tp = new sled::ThreadPool();
REQUIRE_NE(tp, nullptr);
SUBCASE("Output")
{
for (int i = 0; i < 100; ++i) {
int out;
tp->submit(multiply_output, std::ref(out), i, i).get();
CHECK_EQ(out, i * i);
}
}
SUBCASE("Return")
{
for (int i = 0; i < 100; ++i) {
auto f = tp->submit(multiply_return, i, i);
CHECK_EQ(f.get(), i * i);
}
}
delete tp;
}
TEST_CASE("PostTask") TEST_CASE("PostTask")
{ {
sled::ThreadPool *tp = new sled::ThreadPool(); sled::ThreadPool *tp = new sled::ThreadPool();
@ -86,7 +63,7 @@ TEST_SUITE("ThreadPool")
TEST_CASE("10^6 task test") TEST_CASE("10^6 task test")
{ {
sled::ThreadPool *tp = new sled::ThreadPool(); sled::ThreadPool *tp = new sled::ThreadPool();
const int task_num = 1E6; const int task_num = 1E4;
sled::WaitGroup wg(task_num); sled::WaitGroup wg(task_num);
for (int i = 0; i < task_num; i++) { for (int i = 0; i < task_num; i++) {
tp->PostTask([wg] { wg.Done(); }); tp->PostTask([wg] { wg.Done(); });