feat add thread_pool
This commit is contained in:
parent
e1cb39690b
commit
3c1b92dedb
@ -51,6 +51,7 @@ target_sources(
|
||||
src/synchronization/thread_local.cc
|
||||
src/system/location.cc
|
||||
src/system/thread.cc
|
||||
src/system/thread_pool.cc
|
||||
src/task_queue/pending_task_safety_flag.cc
|
||||
src/task_queue/task_queue_base.cc
|
||||
src/timer/task_queue_timeout.cc
|
||||
@ -81,6 +82,7 @@ if(SLED_BUILD_BENCHMARK)
|
||||
src/random_bench.cc
|
||||
src/strings/base64_bench.cc
|
||||
src/system/fiber/fiber_bench.cc
|
||||
src/system/thread_pool_bench.cc
|
||||
src/system_time_bench.cc)
|
||||
target_link_libraries(sled_benchmark PRIVATE sled benchmark::benchmark
|
||||
benchmark::benchmark_main)
|
||||
@ -99,6 +101,7 @@ if(SLED_BUILD_TESTS)
|
||||
src/cleanup_test.cc
|
||||
src/status_or_test.cc
|
||||
src/system/fiber/fiber_test.cc
|
||||
src/system/thread_pool_test.cc
|
||||
)
|
||||
target_link_libraries(sled_tests PRIVATE sled GTest::gtest GTest::gtest_main)
|
||||
add_test(NAME sled_tests COMMAND sled_tests)
|
||||
|
34
include/sled/system/thread_pool.h
Normal file
34
include/sled/system/thread_pool.h
Normal file
@ -0,0 +1,34 @@
|
||||
#pragma once
|
||||
#ifndef SLED_SYSTEM_THREAD_POOL_H
|
||||
#define SLED_SYSTEM_THREAD_POOL_H
|
||||
#include "sled/system/fiber/scheduler.h"
|
||||
#include <functional>
|
||||
#include <future>
|
||||
|
||||
namespace sled {
|
||||
class ThreadPool final {
|
||||
public:
|
||||
/**
|
||||
* @param num_threads The number of threads to create in the thread pool. If
|
||||
* -1, the number of threads will be equal to the number of hardware threads
|
||||
**/
|
||||
ThreadPool(int num_threads = -1);
|
||||
~ThreadPool();
|
||||
|
||||
template<typename F, typename... Args>
|
||||
auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))>
|
||||
{
|
||||
std::function<decltype(f(args...))()> func =
|
||||
std::bind(std::forward<F>(f), std::forward<Args>(args)...);
|
||||
auto task_ptr =
|
||||
std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);
|
||||
scheduler->enqueue(marl::Task([task_ptr]() { (*task_ptr)(); }));
|
||||
return task_ptr->get_future();
|
||||
}
|
||||
|
||||
private:
|
||||
sled::Scheduler *scheduler;
|
||||
};
|
||||
|
||||
}// namespace sled
|
||||
#endif// SLED_SYSTEM_THREAD_POOL_H
|
15
src/system/thread_pool.cc
Normal file
15
src/system/thread_pool.cc
Normal file
@ -0,0 +1,15 @@
|
||||
#include "sled/system/thread_pool.h"
|
||||
|
||||
namespace sled {
|
||||
ThreadPool::ThreadPool(int num_threads)
|
||||
{
|
||||
if (num_threads == -1) {
|
||||
num_threads = std::thread::hardware_concurrency();
|
||||
}
|
||||
scheduler = new sled::Scheduler(
|
||||
sled::Scheduler::Config().setWorkerThreadCount(num_threads));
|
||||
}
|
||||
|
||||
ThreadPool::~ThreadPool() { delete scheduler; }
|
||||
|
||||
}// namespace sled
|
20
src/system/thread_pool_bench.cc
Normal file
20
src/system/thread_pool_bench.cc
Normal file
@ -0,0 +1,20 @@
|
||||
#include "sled/system/fiber/wait_group.h"
|
||||
#include <benchmark/benchmark.h>
|
||||
#include <future>
|
||||
#include <sled/system/thread_pool.h>
|
||||
|
||||
static void
|
||||
ThreadPoolBench(benchmark::State &state)
|
||||
{
|
||||
sled::ThreadPool pool(-1);
|
||||
for (auto _ : state) {
|
||||
std::vector<std::future<int>> futures;
|
||||
for (int i = 0; i < state.range(0); i++) {
|
||||
std::future<int> f = pool.submit([]() { return 1; });
|
||||
futures.push_back(std::move(f));
|
||||
}
|
||||
for (auto &f : futures) { f.get(); }
|
||||
}
|
||||
}
|
||||
|
||||
BENCHMARK(ThreadPoolBench)->RangeMultiplier(10)->Range(10, 10000);
|
65
src/system/thread_pool_test.cc
Normal file
65
src/system/thread_pool_test.cc
Normal file
@ -0,0 +1,65 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <random>
|
||||
#include <sled/system/thread_pool.h>
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 mt(rd());
|
||||
std::uniform_int_distribution<int> dist(-10, 10);
|
||||
auto rnd = std::bind(dist, mt);
|
||||
|
||||
void
|
||||
simulate_hard_computation()
|
||||
{
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(20 + rnd()));
|
||||
}
|
||||
|
||||
// Simple function that adds multiplies two numbers and prints the result
|
||||
void
|
||||
multiply(const int a, const int b)
|
||||
{
|
||||
simulate_hard_computation();
|
||||
const int res = a * b;
|
||||
}
|
||||
|
||||
// Same as before but now we have an output parameter
|
||||
void
|
||||
multiply_output(int &out, const int a, const int b)
|
||||
{
|
||||
simulate_hard_computation();
|
||||
out = a * b;
|
||||
}
|
||||
|
||||
// Same as before but now we have an output parameter
|
||||
int
|
||||
multiply_return(const int a, const int b)
|
||||
{
|
||||
simulate_hard_computation();
|
||||
const int res = a * b;
|
||||
return res;
|
||||
}
|
||||
|
||||
class ThreadPoolTest : public ::testing::Test {
|
||||
public:
|
||||
void SetUp() override { tp = new sled::ThreadPool(); }
|
||||
|
||||
void TearDown() override { delete tp; }
|
||||
|
||||
sled::ThreadPool *tp;
|
||||
};
|
||||
|
||||
TEST_F(ThreadPoolTest, Output)
|
||||
{
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
int out;
|
||||
tp->submit(multiply_output, std::ref(out), i, i).get();
|
||||
EXPECT_EQ(out, i * i);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ThreadPoolTest, Return)
|
||||
{
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
auto f = tp->submit(multiply_return, i, i);
|
||||
EXPECT_EQ(f.get(), i * i);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user