feat add thread_pool
All checks were successful
linux-x64-gcc / linux-gcc (Debug) (push) Successful in 34s
linux-x64-gcc / linux-gcc (Release) (push) Successful in 40s

This commit is contained in:
tqcq 2024-03-11 14:24:15 +08:00
parent e1cb39690b
commit 3c1b92dedb
5 changed files with 137 additions and 0 deletions

View File

@ -51,6 +51,7 @@ target_sources(
src/synchronization/thread_local.cc
src/system/location.cc
src/system/thread.cc
src/system/thread_pool.cc
src/task_queue/pending_task_safety_flag.cc
src/task_queue/task_queue_base.cc
src/timer/task_queue_timeout.cc
@ -81,6 +82,7 @@ if(SLED_BUILD_BENCHMARK)
src/random_bench.cc
src/strings/base64_bench.cc
src/system/fiber/fiber_bench.cc
src/system/thread_pool_bench.cc
src/system_time_bench.cc)
target_link_libraries(sled_benchmark PRIVATE sled benchmark::benchmark
benchmark::benchmark_main)
@ -99,6 +101,7 @@ if(SLED_BUILD_TESTS)
src/cleanup_test.cc
src/status_or_test.cc
src/system/fiber/fiber_test.cc
src/system/thread_pool_test.cc
)
target_link_libraries(sled_tests PRIVATE sled GTest::gtest GTest::gtest_main)
add_test(NAME sled_tests COMMAND sled_tests)

View File

@ -0,0 +1,34 @@
#pragma once
#ifndef SLED_SYSTEM_THREAD_POOL_H
#define SLED_SYSTEM_THREAD_POOL_H
#include "sled/system/fiber/scheduler.h"
#include <functional>
#include <future>
namespace sled {
class ThreadPool final {
public:
/**
* @param num_threads The number of threads to create in the thread pool. If
* -1, the number of threads will be equal to the number of hardware threads
**/
ThreadPool(int num_threads = -1);
~ThreadPool();
template<typename F, typename... Args>
auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))>
{
std::function<decltype(f(args...))()> func =
std::bind(std::forward<F>(f), std::forward<Args>(args)...);
auto task_ptr =
std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);
scheduler->enqueue(marl::Task([task_ptr]() { (*task_ptr)(); }));
return task_ptr->get_future();
}
private:
sled::Scheduler *scheduler;
};
}// namespace sled
#endif// SLED_SYSTEM_THREAD_POOL_H

15
src/system/thread_pool.cc Normal file
View File

@ -0,0 +1,15 @@
#include "sled/system/thread_pool.h"
namespace sled {
ThreadPool::ThreadPool(int num_threads)
{
if (num_threads == -1) {
num_threads = std::thread::hardware_concurrency();
}
scheduler = new sled::Scheduler(
sled::Scheduler::Config().setWorkerThreadCount(num_threads));
}
ThreadPool::~ThreadPool() { delete scheduler; }
}// namespace sled

View File

@ -0,0 +1,20 @@
#include "sled/system/fiber/wait_group.h"
#include <benchmark/benchmark.h>
#include <future>
#include <sled/system/thread_pool.h>
static void
ThreadPoolBench(benchmark::State &state)
{
sled::ThreadPool pool(-1);
for (auto _ : state) {
std::vector<std::future<int>> futures;
for (int i = 0; i < state.range(0); i++) {
std::future<int> f = pool.submit([]() { return 1; });
futures.push_back(std::move(f));
}
for (auto &f : futures) { f.get(); }
}
}
BENCHMARK(ThreadPoolBench)->RangeMultiplier(10)->Range(10, 10000);

View File

@ -0,0 +1,65 @@
#include <gtest/gtest.h>
#include <random>
#include <sled/system/thread_pool.h>
std::random_device rd;
std::mt19937 mt(rd());
std::uniform_int_distribution<int> dist(-10, 10);
auto rnd = std::bind(dist, mt);
void
simulate_hard_computation()
{
std::this_thread::sleep_for(std::chrono::milliseconds(20 + rnd()));
}
// Simple function that adds multiplies two numbers and prints the result
void
multiply(const int a, const int b)
{
simulate_hard_computation();
const int res = a * b;
}
// Same as before but now we have an output parameter
void
multiply_output(int &out, const int a, const int b)
{
simulate_hard_computation();
out = a * b;
}
// Same as before but now we have an output parameter
int
multiply_return(const int a, const int b)
{
simulate_hard_computation();
const int res = a * b;
return res;
}
class ThreadPoolTest : public ::testing::Test {
public:
void SetUp() override { tp = new sled::ThreadPool(); }
void TearDown() override { delete tp; }
sled::ThreadPool *tp;
};
TEST_F(ThreadPoolTest, Output)
{
for (int i = 0; i < 100; ++i) {
int out;
tp->submit(multiply_output, std::ref(out), i, i).get();
EXPECT_EQ(out, i * i);
}
}
TEST_F(ThreadPoolTest, Return)
{
for (int i = 0; i < 100; ++i) {
auto f = tp->submit(multiply_return, i, i);
EXPECT_EQ(f.get(), i * i);
}
}