feat thread use std::thread
Some checks failed
linux-arm-gcc / linux-gcc-armhf (push) Successful in 4m37s
linux-mips64-gcc / linux-gcc-mips64el (Release) (push) Has been cancelled
linux-aarch64-cpu-gcc / linux-gcc-aarch64 (push) Has been cancelled
linux-x64-gcc / linux-gcc (Release) (push) Has been cancelled
linux-x64-gcc / linux-gcc (Debug) (push) Has been cancelled
linux-mips64-gcc / linux-gcc-mips64el (Debug) (push) Has been cancelled

This commit is contained in:
tqcq 2024-04-30 22:56:20 +08:00
parent 41e460ffd9
commit bb82578636
6 changed files with 110 additions and 53 deletions

View File

@ -184,7 +184,6 @@ if(SLED_BUILD_TESTS)
sled_all_tests
SRCS
src/sled/debugging/demangle_test.cc
src/sled/async/async_test.cc
src/sled/filesystem/path_test.cc
src/sled/log/fmt_test.cc
src/sled/synchronization/sequence_checker_test.cc

View File

@ -20,18 +20,19 @@ FiberScheduler::FiberScheduler()
{
}
static ThreadPool thread_pool;
void
FiberScheduler::schedule(async::task_run_handle t)
{
static ThreadPool thread_pool;
auto move_on_copy = sled::MakeMoveOnCopy(t);
thread_pool.submit([move_on_copy] { move_on_copy.value.run_with_wait_handler(SleepWaitHandler); });
// thread_pool.submit([move_on_copy] { move_on_copy.value.run(); });
thread_pool.PostTask([move_on_copy] { move_on_copy.value.run(); });
}
}// namespace sled
// clang-format on
namespace async {
sled::FiberScheduler &
default_scheduler()

View File

@ -1,25 +1,87 @@
#include "sled/system/thread_pool.h"
#include "sled/log/log.h"
#include "sled/system/location.h"
#include "sled/task_queue/task_queue_base.h"
namespace sled {
constexpr char ThreadPool::kTag[];
ThreadPool::ThreadPool(int num_threads) : delayed_thread_(sled::Thread::Create())
{
if (num_threads == -1) { num_threads = std::thread::hardware_concurrency(); }
scheduler_ = new sled::Scheduler(sled::Scheduler::Config().setWorkerThreadCount(num_threads));
if (num_threads <= 0) { num_threads = std::thread::hardware_concurrency(); }
auto state = std::make_shared<State>();
for (int i = 0; i < num_threads; i++) {
threads_.emplace_back(std::thread([state] {
state->idle++;
while (state->is_running) {
std::function<void()> task;
sled::Location loc = SLED_FROM_HERE;
// fetch task
{
sled::MutexLock lock(&state->mutex);
state->cv.Wait(lock, [state] { return !state->task_queue.empty() || !state->is_running; });
if (!state->task_queue.empty()) {
task = std::move(state->task_queue.front().first);
loc = state->task_queue.front().second;
state->task_queue.pop();
}
if (!state->task_queue.empty()) { state->cv.NotifyOne(); }
}
if (task) {
state->idle--;
try {
task();
} catch (const std::exception &e) {
LOGE(kTag, "ThreadPool::ThreadPool() task exception: {}, from={}", e.what(), loc.ToString());
} catch (...) {
LOGE(kTag, "ThreadPool::ThreadPool() task unknown exception, from={}", loc.ToString());
}
state->idle++;
}
}
}));
}
state_ = state;
delayed_thread_->Start();
}
ThreadPool::~ThreadPool() { delete scheduler_; }
ThreadPool::~ThreadPool() { Delete(); }
void
ThreadPool::Delete()
{}
{
if (state_) {
sled::MutexLock lock(&state_->mutex);
state_->is_running = false;
state_->cv.NotifyAll();
state_ = nullptr;
}
delayed_thread_.reset();
for (auto &thread : threads_) { thread.join(); }
}
int
ThreadPool::idle() const
{
auto state = state_;
SLED_ASSERT(state != nullptr, "ThreadPool::idle() state_ is nullptr");
if (state->is_running) { return state->idle; }
return 0;
}
void
ThreadPool::PostTaskImpl(std::function<void()> &&task, const PostTaskTraits &traits, const Location &location)
{
scheduler_->enqueue(marl::Task([task] { task(); }));
auto state = state_;
SLED_ASSERT(state != nullptr, "ThreadPool::PostTaskImpl() state_ is nullptr");
if (!state->is_running) {
LOGW(kTag, "ThreadPool::PostTaskImpl() state is not running");
return;
}
sled::MutexLock lock(&state->mutex);
state->task_queue.emplace(std::move(task), location);
state->cv.NotifyOne();
}
void
@ -28,19 +90,23 @@ ThreadPool::PostDelayedTaskImpl(std::function<void()> &&task,
const PostDelayedTaskTraits &traits,
const Location &location)
{
auto move_task_to_fiber = [task]() { task(); };
auto weak_state = std::weak_ptr<State>(state_);
auto delay_post = [task, location, weak_state]() {
auto state = weak_state.lock();
if (!state || !state->is_running) { return; }
sled::MutexLock lock(&state->mutex);
state->task_queue.emplace(std::move(task), location);
state->cv.NotifyOne();
};
if (traits.high_precision) {
delayed_thread_->PostDelayedTaskWithPrecision(
TaskQueueBase::DelayPrecision::kHigh,
std::move(move_task_to_fiber),
std::move(delay_post),
delay,
location);
} else {
delayed_thread_->PostDelayedTaskWithPrecision(
TaskQueueBase::DelayPrecision::kLow,
std::move(move_task_to_fiber),
delay,
location);
delayed_thread_
->PostDelayedTaskWithPrecision(TaskQueueBase::DelayPrecision::kLow, std::move(delay_post), delay, location);
}
}

View File

@ -1,6 +1,7 @@
#pragma once
#ifndef SLED_SYSTEM_THREAD_POOL_H
#define SLED_SYSTEM_THREAD_POOL_H
#include "sled/synchronization/mutex.h"
#include "sled/system/fiber/scheduler.h"
#include "sled/system/thread.h"
#include <functional>
@ -9,6 +10,7 @@
namespace sled {
class ThreadPool final : public TaskQueueBase {
public:
static constexpr char kTag[] = "ThreadPool";
/**
* @param num_threads The number of threads to create in the thread pool. If
* -1, the number of threads will be equal to the number of hardware threads
@ -16,18 +18,20 @@ public:
ThreadPool(int num_threads = -1);
~ThreadPool();
template<typename F, typename... Args>
auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))>
{
std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);
auto future = task_ptr->get_future();
scheduler_->enqueue(marl::Task([task_ptr]() { (*task_ptr)(); }));
return future;
}
// template<typename F, typename... Args>
// auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))>
// {
// std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
// auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);
// auto future = task_ptr->get_future();
// // scheduler_->enqueue(marl::Task([task_ptr]() { (*task_ptr)(); }));
// return future;
// }
void Delete() override;
int idle() const;
protected:
void PostTaskImpl(std::function<void()> &&task, const PostTaskTraits &traits, const Location &location) override;
@ -37,8 +41,18 @@ protected:
const Location &location) override;
private:
sled::Scheduler *scheduler_;
std::unique_ptr<sled::Thread> delayed_thread_;
struct State {
std::atomic<bool> is_running{true};
std::atomic<int> idle{0};
sled::Mutex mutex;
sled::ConditionVariable cv;
std::queue<std::pair<std::function<void()>, sled::Location>> task_queue;
};
std::shared_ptr<State> state_;
std::vector<std::thread> threads_;
// sled::Scheduler *scheduler_;
std::unique_ptr<sled::Thread> delayed_thread_ = nullptr;
};
}// namespace sled

View File

@ -7,8 +7,8 @@ ThreadPoolBench(picobench::state &state)
{
sled::ThreadPool pool(-1);
for (auto _ : state) {
std::future<int> f = pool.submit([]() { return 1; });
(void) f.get();
auto res = pool.BlockingCall([]() { return 1; });
(void) res;
}
}

View File

@ -41,29 +41,6 @@ multiply_return(const int a, const int b)
TEST_SUITE("ThreadPool")
{
TEST_CASE("submit")
{
sled::ThreadPool *tp = new sled::ThreadPool();
REQUIRE_NE(tp, nullptr);
SUBCASE("Output")
{
for (int i = 0; i < 100; ++i) {
int out;
tp->submit(multiply_output, std::ref(out), i, i).get();
CHECK_EQ(out, i * i);
}
}
SUBCASE("Return")
{
for (int i = 0; i < 100; ++i) {
auto f = tp->submit(multiply_return, i, i);
CHECK_EQ(f.get(), i * i);
}
}
delete tp;
}
TEST_CASE("PostTask")
{
sled::ThreadPool *tp = new sled::ThreadPool();
@ -86,7 +63,7 @@ TEST_SUITE("ThreadPool")
TEST_CASE("10^6 task test")
{
sled::ThreadPool *tp = new sled::ThreadPool();
const int task_num = 1E6;
const int task_num = 1E4;
sled::WaitGroup wg(task_num);
for (int i = 0; i < task_num; i++) {
tp->PostTask([wg] { wg.Done(); });