feat add thread_pool
This commit is contained in:
parent
e1cb39690b
commit
3c1b92dedb
@ -51,6 +51,7 @@ target_sources(
|
|||||||
src/synchronization/thread_local.cc
|
src/synchronization/thread_local.cc
|
||||||
src/system/location.cc
|
src/system/location.cc
|
||||||
src/system/thread.cc
|
src/system/thread.cc
|
||||||
|
src/system/thread_pool.cc
|
||||||
src/task_queue/pending_task_safety_flag.cc
|
src/task_queue/pending_task_safety_flag.cc
|
||||||
src/task_queue/task_queue_base.cc
|
src/task_queue/task_queue_base.cc
|
||||||
src/timer/task_queue_timeout.cc
|
src/timer/task_queue_timeout.cc
|
||||||
@ -81,6 +82,7 @@ if(SLED_BUILD_BENCHMARK)
|
|||||||
src/random_bench.cc
|
src/random_bench.cc
|
||||||
src/strings/base64_bench.cc
|
src/strings/base64_bench.cc
|
||||||
src/system/fiber/fiber_bench.cc
|
src/system/fiber/fiber_bench.cc
|
||||||
|
src/system/thread_pool_bench.cc
|
||||||
src/system_time_bench.cc)
|
src/system_time_bench.cc)
|
||||||
target_link_libraries(sled_benchmark PRIVATE sled benchmark::benchmark
|
target_link_libraries(sled_benchmark PRIVATE sled benchmark::benchmark
|
||||||
benchmark::benchmark_main)
|
benchmark::benchmark_main)
|
||||||
@ -99,6 +101,7 @@ if(SLED_BUILD_TESTS)
|
|||||||
src/cleanup_test.cc
|
src/cleanup_test.cc
|
||||||
src/status_or_test.cc
|
src/status_or_test.cc
|
||||||
src/system/fiber/fiber_test.cc
|
src/system/fiber/fiber_test.cc
|
||||||
|
src/system/thread_pool_test.cc
|
||||||
)
|
)
|
||||||
target_link_libraries(sled_tests PRIVATE sled GTest::gtest GTest::gtest_main)
|
target_link_libraries(sled_tests PRIVATE sled GTest::gtest GTest::gtest_main)
|
||||||
add_test(NAME sled_tests COMMAND sled_tests)
|
add_test(NAME sled_tests COMMAND sled_tests)
|
||||||
|
34
include/sled/system/thread_pool.h
Normal file
34
include/sled/system/thread_pool.h
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
#pragma once
|
||||||
|
#ifndef SLED_SYSTEM_THREAD_POOL_H
|
||||||
|
#define SLED_SYSTEM_THREAD_POOL_H
|
||||||
|
#include "sled/system/fiber/scheduler.h"
|
||||||
|
#include <functional>
|
||||||
|
#include <future>
|
||||||
|
|
||||||
|
namespace sled {
|
||||||
|
class ThreadPool final {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @param num_threads The number of threads to create in the thread pool. If
|
||||||
|
* -1, the number of threads will be equal to the number of hardware threads
|
||||||
|
**/
|
||||||
|
ThreadPool(int num_threads = -1);
|
||||||
|
~ThreadPool();
|
||||||
|
|
||||||
|
template<typename F, typename... Args>
|
||||||
|
auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))>
|
||||||
|
{
|
||||||
|
std::function<decltype(f(args...))()> func =
|
||||||
|
std::bind(std::forward<F>(f), std::forward<Args>(args)...);
|
||||||
|
auto task_ptr =
|
||||||
|
std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);
|
||||||
|
scheduler->enqueue(marl::Task([task_ptr]() { (*task_ptr)(); }));
|
||||||
|
return task_ptr->get_future();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
sled::Scheduler *scheduler;
|
||||||
|
};
|
||||||
|
|
||||||
|
}// namespace sled
|
||||||
|
#endif// SLED_SYSTEM_THREAD_POOL_H
|
15
src/system/thread_pool.cc
Normal file
15
src/system/thread_pool.cc
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
#include "sled/system/thread_pool.h"
|
||||||
|
|
||||||
|
namespace sled {
|
||||||
|
ThreadPool::ThreadPool(int num_threads)
|
||||||
|
{
|
||||||
|
if (num_threads == -1) {
|
||||||
|
num_threads = std::thread::hardware_concurrency();
|
||||||
|
}
|
||||||
|
scheduler = new sled::Scheduler(
|
||||||
|
sled::Scheduler::Config().setWorkerThreadCount(num_threads));
|
||||||
|
}
|
||||||
|
|
||||||
|
ThreadPool::~ThreadPool() { delete scheduler; }
|
||||||
|
|
||||||
|
}// namespace sled
|
20
src/system/thread_pool_bench.cc
Normal file
20
src/system/thread_pool_bench.cc
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
#include "sled/system/fiber/wait_group.h"
|
||||||
|
#include <benchmark/benchmark.h>
|
||||||
|
#include <future>
|
||||||
|
#include <sled/system/thread_pool.h>
|
||||||
|
|
||||||
|
static void
|
||||||
|
ThreadPoolBench(benchmark::State &state)
|
||||||
|
{
|
||||||
|
sled::ThreadPool pool(-1);
|
||||||
|
for (auto _ : state) {
|
||||||
|
std::vector<std::future<int>> futures;
|
||||||
|
for (int i = 0; i < state.range(0); i++) {
|
||||||
|
std::future<int> f = pool.submit([]() { return 1; });
|
||||||
|
futures.push_back(std::move(f));
|
||||||
|
}
|
||||||
|
for (auto &f : futures) { f.get(); }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
BENCHMARK(ThreadPoolBench)->RangeMultiplier(10)->Range(10, 10000);
|
65
src/system/thread_pool_test.cc
Normal file
65
src/system/thread_pool_test.cc
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <random>
|
||||||
|
#include <sled/system/thread_pool.h>
|
||||||
|
|
||||||
|
std::random_device rd;
|
||||||
|
std::mt19937 mt(rd());
|
||||||
|
std::uniform_int_distribution<int> dist(-10, 10);
|
||||||
|
auto rnd = std::bind(dist, mt);
|
||||||
|
|
||||||
|
void
|
||||||
|
simulate_hard_computation()
|
||||||
|
{
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(20 + rnd()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simple function that adds multiplies two numbers and prints the result
|
||||||
|
void
|
||||||
|
multiply(const int a, const int b)
|
||||||
|
{
|
||||||
|
simulate_hard_computation();
|
||||||
|
const int res = a * b;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Same as before but now we have an output parameter
|
||||||
|
void
|
||||||
|
multiply_output(int &out, const int a, const int b)
|
||||||
|
{
|
||||||
|
simulate_hard_computation();
|
||||||
|
out = a * b;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Same as before but now we have an output parameter
|
||||||
|
int
|
||||||
|
multiply_return(const int a, const int b)
|
||||||
|
{
|
||||||
|
simulate_hard_computation();
|
||||||
|
const int res = a * b;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
class ThreadPoolTest : public ::testing::Test {
|
||||||
|
public:
|
||||||
|
void SetUp() override { tp = new sled::ThreadPool(); }
|
||||||
|
|
||||||
|
void TearDown() override { delete tp; }
|
||||||
|
|
||||||
|
sled::ThreadPool *tp;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(ThreadPoolTest, Output)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < 100; ++i) {
|
||||||
|
int out;
|
||||||
|
tp->submit(multiply_output, std::ref(out), i, i).get();
|
||||||
|
EXPECT_EQ(out, i * i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ThreadPoolTest, Return)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < 100; ++i) {
|
||||||
|
auto f = tp->submit(multiply_return, i, i);
|
||||||
|
EXPECT_EQ(f.get(), i * i);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user