diff --git a/CMakeLists.txt b/CMakeLists.txt index 592bc4a..8e50c33 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,6 +51,7 @@ target_sources( src/synchronization/thread_local.cc src/system/location.cc src/system/thread.cc + src/system/thread_pool.cc src/task_queue/pending_task_safety_flag.cc src/task_queue/task_queue_base.cc src/timer/task_queue_timeout.cc @@ -81,6 +82,7 @@ if(SLED_BUILD_BENCHMARK) src/random_bench.cc src/strings/base64_bench.cc src/system/fiber/fiber_bench.cc + src/system/thread_pool_bench.cc src/system_time_bench.cc) target_link_libraries(sled_benchmark PRIVATE sled benchmark::benchmark benchmark::benchmark_main) @@ -99,6 +101,7 @@ if(SLED_BUILD_TESTS) src/cleanup_test.cc src/status_or_test.cc src/system/fiber/fiber_test.cc + src/system/thread_pool_test.cc ) target_link_libraries(sled_tests PRIVATE sled GTest::gtest GTest::gtest_main) add_test(NAME sled_tests COMMAND sled_tests) diff --git a/include/sled/system/thread_pool.h b/include/sled/system/thread_pool.h new file mode 100644 index 0000000..98e8945 --- /dev/null +++ b/include/sled/system/thread_pool.h @@ -0,0 +1,34 @@ +#pragma once +#ifndef SLED_SYSTEM_THREAD_POOL_H +#define SLED_SYSTEM_THREAD_POOL_H +#include "sled/system/fiber/scheduler.h" +#include +#include + +namespace sled { +class ThreadPool final { +public: + /** + * @param num_threads The number of threads to create in the thread pool. If + * -1, the number of threads will be equal to the number of hardware threads + **/ + ThreadPool(int num_threads = -1); + ~ThreadPool(); + + template + auto submit(F &&f, Args &&...args) -> std::future + { + std::function func = + std::bind(std::forward(f), std::forward(args)...); + auto task_ptr = + std::make_shared>(func); + scheduler->enqueue(marl::Task([task_ptr]() { (*task_ptr)(); })); + return task_ptr->get_future(); + } + +private: + sled::Scheduler *scheduler; +}; + +}// namespace sled +#endif// SLED_SYSTEM_THREAD_POOL_H diff --git a/src/system/thread_pool.cc b/src/system/thread_pool.cc new file mode 100644 index 0000000..43ec03f --- /dev/null +++ b/src/system/thread_pool.cc @@ -0,0 +1,15 @@ +#include "sled/system/thread_pool.h" + +namespace sled { +ThreadPool::ThreadPool(int num_threads) +{ + if (num_threads == -1) { + num_threads = std::thread::hardware_concurrency(); + } + scheduler = new sled::Scheduler( + sled::Scheduler::Config().setWorkerThreadCount(num_threads)); +} + +ThreadPool::~ThreadPool() { delete scheduler; } + +}// namespace sled diff --git a/src/system/thread_pool_bench.cc b/src/system/thread_pool_bench.cc new file mode 100644 index 0000000..ab0d577 --- /dev/null +++ b/src/system/thread_pool_bench.cc @@ -0,0 +1,20 @@ +#include "sled/system/fiber/wait_group.h" +#include +#include +#include + +static void +ThreadPoolBench(benchmark::State &state) +{ + sled::ThreadPool pool(-1); + for (auto _ : state) { + std::vector> futures; + for (int i = 0; i < state.range(0); i++) { + std::future f = pool.submit([]() { return 1; }); + futures.push_back(std::move(f)); + } + for (auto &f : futures) { f.get(); } + } +} + +BENCHMARK(ThreadPoolBench)->RangeMultiplier(10)->Range(10, 10000); diff --git a/src/system/thread_pool_test.cc b/src/system/thread_pool_test.cc new file mode 100644 index 0000000..b3796a8 --- /dev/null +++ b/src/system/thread_pool_test.cc @@ -0,0 +1,65 @@ +#include +#include +#include + +std::random_device rd; +std::mt19937 mt(rd()); +std::uniform_int_distribution dist(-10, 10); +auto rnd = std::bind(dist, mt); + +void +simulate_hard_computation() +{ + std::this_thread::sleep_for(std::chrono::milliseconds(20 + rnd())); +} + +// Simple function that adds multiplies two numbers and prints the result +void +multiply(const int a, const int b) +{ + simulate_hard_computation(); + const int res = a * b; +} + +// Same as before but now we have an output parameter +void +multiply_output(int &out, const int a, const int b) +{ + simulate_hard_computation(); + out = a * b; +} + +// Same as before but now we have an output parameter +int +multiply_return(const int a, const int b) +{ + simulate_hard_computation(); + const int res = a * b; + return res; +} + +class ThreadPoolTest : public ::testing::Test { +public: + void SetUp() override { tp = new sled::ThreadPool(); } + + void TearDown() override { delete tp; } + + sled::ThreadPool *tp; +}; + +TEST_F(ThreadPoolTest, Output) +{ + for (int i = 0; i < 100; ++i) { + int out; + tp->submit(multiply_output, std::ref(out), i, i).get(); + EXPECT_EQ(out, i * i); + } +} + +TEST_F(ThreadPoolTest, Return) +{ + for (int i = 0; i < 100; ++i) { + auto f = tp->submit(multiply_return, i, i); + EXPECT_EQ(f.get(), i * i); + } +}