Commit bb825786 authored by tqcq's avatar tqcq
Browse files

feat thread use std::thread

parent 41e460ff
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -184,7 +184,6 @@ if(SLED_BUILD_TESTS)
    sled_all_tests
    SRCS
    src/sled/debugging/demangle_test.cc
    src/sled/async/async_test.cc
    src/sled/filesystem/path_test.cc
    src/sled/log/fmt_test.cc
    src/sled/synchronization/sequence_checker_test.cc
+4 −3
Original line number Diff line number Diff line
@@ -20,18 +20,19 @@ FiberScheduler::FiberScheduler()
{
}

static ThreadPool thread_pool;

void
FiberScheduler::schedule(async::task_run_handle t)
{
    static ThreadPool thread_pool;
    auto move_on_copy = sled::MakeMoveOnCopy(t);
    thread_pool.submit([move_on_copy] { move_on_copy.value.run_with_wait_handler(SleepWaitHandler); });
    // thread_pool.submit([move_on_copy] { move_on_copy.value.run(); });
    thread_pool.PostTask([move_on_copy] { move_on_copy.value.run(); });
}

}// namespace sled

// clang-format on

namespace async {
sled::FiberScheduler &
default_scheduler()
+78 −12
Original line number Diff line number Diff line
#include "sled/system/thread_pool.h"
#include "sled/log/log.h"
#include "sled/system/location.h"
#include "sled/task_queue/task_queue_base.h"

namespace sled {

constexpr char ThreadPool::kTag[];

ThreadPool::ThreadPool(int num_threads) : delayed_thread_(sled::Thread::Create())
{
    if (num_threads == -1) { num_threads = std::thread::hardware_concurrency(); }
    scheduler_ = new sled::Scheduler(sled::Scheduler::Config().setWorkerThreadCount(num_threads));
    if (num_threads <= 0) { num_threads = std::thread::hardware_concurrency(); }
    auto state = std::make_shared<State>();
    for (int i = 0; i < num_threads; i++) {
        threads_.emplace_back(std::thread([state] {
            state->idle++;
            while (state->is_running) {
                std::function<void()> task;
                sled::Location loc = SLED_FROM_HERE;
                // fetch task
                {
                    sled::MutexLock lock(&state->mutex);
                    state->cv.Wait(lock, [state] { return !state->task_queue.empty() || !state->is_running; });
                    if (!state->task_queue.empty()) {
                        task = std::move(state->task_queue.front().first);
                        loc  = state->task_queue.front().second;
                        state->task_queue.pop();
                    }
                    if (!state->task_queue.empty()) { state->cv.NotifyOne(); }
                }
                if (task) {
                    state->idle--;
                    try {
                        task();
                    } catch (const std::exception &e) {
                        LOGE(kTag, "ThreadPool::ThreadPool() task exception: {}, from={}", e.what(), loc.ToString());
                    } catch (...) {
                        LOGE(kTag, "ThreadPool::ThreadPool() task unknown exception, from={}", loc.ToString());
                    }
                    state->idle++;
                }
            }
        }));
    }
    state_ = state;
    delayed_thread_->Start();
}

ThreadPool::~ThreadPool() { delete scheduler_; }
ThreadPool::~ThreadPool() { Delete(); }

void
ThreadPool::Delete()
{}
{
    if (state_) {
        sled::MutexLock lock(&state_->mutex);
        state_->is_running = false;
        state_->cv.NotifyAll();
        state_ = nullptr;
    }
    delayed_thread_.reset();
    for (auto &thread : threads_) { thread.join(); }
}

int
ThreadPool::idle() const
{
    auto state = state_;
    SLED_ASSERT(state != nullptr, "ThreadPool::idle() state_ is nullptr");
    if (state->is_running) { return state->idle; }
    return 0;
}

void
ThreadPool::PostTaskImpl(std::function<void()> &&task, const PostTaskTraits &traits, const Location &location)
{
    scheduler_->enqueue(marl::Task([task] { task(); }));
    auto state = state_;
    SLED_ASSERT(state != nullptr, "ThreadPool::PostTaskImpl() state_ is nullptr");
    if (!state->is_running) {
        LOGW(kTag, "ThreadPool::PostTaskImpl() state is not running");
        return;
    }
    sled::MutexLock lock(&state->mutex);
    state->task_queue.emplace(std::move(task), location);
    state->cv.NotifyOne();
}

void
@@ -28,19 +90,23 @@ ThreadPool::PostDelayedTaskImpl(std::function<void()> &&task,
                                const PostDelayedTaskTraits &traits,
                                const Location &location)
{
    auto move_task_to_fiber = [task]() { task(); };
    auto weak_state = std::weak_ptr<State>(state_);
    auto delay_post = [task, location, weak_state]() {
        auto state = weak_state.lock();
        if (!state || !state->is_running) { return; }
        sled::MutexLock lock(&state->mutex);
        state->task_queue.emplace(std::move(task), location);
        state->cv.NotifyOne();
    };
    if (traits.high_precision) {
        delayed_thread_->PostDelayedTaskWithPrecision(
            TaskQueueBase::DelayPrecision::kHigh,
            std::move(move_task_to_fiber),
            std::move(delay_post),
            delay,
            location);
    } else {
        delayed_thread_->PostDelayedTaskWithPrecision(
            TaskQueueBase::DelayPrecision::kLow,
            std::move(move_task_to_fiber),
            delay,
            location);
        delayed_thread_
            ->PostDelayedTaskWithPrecision(TaskQueueBase::DelayPrecision::kLow, std::move(delay_post), delay, location);
    }
}

+25 −11
Original line number Diff line number Diff line
#pragma once
#ifndef SLED_SYSTEM_THREAD_POOL_H
#define SLED_SYSTEM_THREAD_POOL_H
#include "sled/synchronization/mutex.h"
#include "sled/system/fiber/scheduler.h"
#include "sled/system/thread.h"
#include <functional>
@@ -9,6 +10,7 @@
namespace sled {
class ThreadPool final : public TaskQueueBase {
public:
    static constexpr char kTag[] = "ThreadPool";
    /**
    * @param num_threads The number of threads to create in the thread pool. If
    * -1, the number of threads will be equal to the number of hardware threads
@@ -16,18 +18,20 @@ public:
    ThreadPool(int num_threads = -1);
    ~ThreadPool();

    template<typename F, typename... Args>
    auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))>
    {
        std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
        auto task_ptr                              = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);
        auto future                                = task_ptr->get_future();
        scheduler_->enqueue(marl::Task([task_ptr]() { (*task_ptr)(); }));
        return future;
    }
    // template<typename F, typename... Args>
    // auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))>
    // {
    //     std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
    //     auto task_ptr                              = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);
    //     auto future                                = task_ptr->get_future();
    //     // scheduler_->enqueue(marl::Task([task_ptr]() { (*task_ptr)(); }));
    //     return future;
    // }

    void Delete() override;

    int idle() const;

protected:
    void PostTaskImpl(std::function<void()> &&task, const PostTaskTraits &traits, const Location &location) override;

@@ -37,8 +41,18 @@ protected:
                             const Location &location) override;

private:
    sled::Scheduler *scheduler_;
    std::unique_ptr<sled::Thread> delayed_thread_;
    struct State {
        std::atomic<bool> is_running{true};
        std::atomic<int> idle{0};
        sled::Mutex mutex;
        sled::ConditionVariable cv;
        std::queue<std::pair<std::function<void()>, sled::Location>> task_queue;
    };

    std::shared_ptr<State> state_;
    std::vector<std::thread> threads_;
    // sled::Scheduler *scheduler_;
    std::unique_ptr<sled::Thread> delayed_thread_ = nullptr;
};

}// namespace sled
+2 −2
Original line number Diff line number Diff line
@@ -7,8 +7,8 @@ ThreadPoolBench(picobench::state &state)
{
    sled::ThreadPool pool(-1);
    for (auto _ : state) {
        std::future<int> f = pool.submit([]() { return 1; });
        (void) f.get();
        auto res = pool.BlockingCall([]() { return 1; });
        (void) res;
    }
}

Loading