diff --git a/CMakeLists.txt b/CMakeLists.txt index e0f61d2..2adefa4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -77,12 +77,25 @@ target_link_libraries(sled PUBLIC rpc_core fmt marl) if(SLED_BUILD_BENCHMARK) find_package(benchmark REQUIRED) - add_executable(sled_benchmark benchmark/strings/base64_benchmark.cc) + add_executable(sled_benchmark "src/system/fiber/fiber_bench.cc") target_link_libraries(sled_benchmark PRIVATE sled benchmark::benchmark benchmark::benchmark_main) endif(SLED_BUILD_BENCHMARK) if(SLED_BUILD_TESTS) - find_package(gtest REQUIRED) - add_executable(sled_tests "") + include(FetchContent) + FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip + ) + FetchContent_MakeAvailable(googletest) + add_executable(sled_tests + src/filesystem/path_test.cc + src/strings/base64_test.cc + src/cleanup_test.cc + src/status_or_test.cc + src/system/fiber/fiber_test.cc + ) + target_link_libraries(sled_tests PRIVATE sled GTest::gtest GTest::gtest_main) + add_test(NAME sled_tests COMMAND sled_tests) endif(SLED_BUILD_TESTS) diff --git a/benchmark/strings/base64_benchmark.cc b/benchmark/strings/base64_benchmark.cc deleted file mode 100644 index c9e2740..0000000 --- a/benchmark/strings/base64_benchmark.cc +++ /dev/null @@ -1,59 +0,0 @@ -#include -#include -#include -#include - -struct strings {}; - -static std::string -AllocRandomString(sled::Random &random, int len) -{ - std::stringstream ss; - for (int i = len; i > 0; i--) { ss << random.Rand(); } - return ss.str(); -} - -void -BenchmarkBase64Encode(benchmark::State &state) -{ - state.PauseTiming(); - sled::Random random(2393); - std::vector test_data; - for (int i = 0; i < state.range(0); i++) { - test_data.emplace_back(AllocRandomString(random, state.range(1))); - } - - state.ResumeTiming(); - for (int i = 0; i < state.range(2); i++) { - for (const auto &str : test_data) { - auto base64 = sled::Base64::Encode(str); - } - } -} - -std::string -uint2str(unsigned int num) -{ - std::ostringstream oss; - oss << num; - return oss.str(); -} - -void -test(benchmark::State &state) -{ - for (int i = 0; i < 1000000; i++) (void) uint2str(i); - state.end(); -} - -BENCHMARK(test); -/* -BENCHMARK(BenchmarkBase64Encode) - ->ArgsProduct({ - // generate the num of strings - benchmark::CreateRange(10, 1000, 10), - // generate the length of each string - benchmark::CreateRange(10, 1000, 10), - benchmark::CreateRange(10, 1000, 10), - }); -*/ diff --git a/include/sled/synchronization/mutex.h b/include/sled/synchronization/mutex.h index f27e39e..dc2cccf 100644 --- a/include/sled/synchronization/mutex.h +++ b/include/sled/synchronization/mutex.h @@ -5,12 +5,14 @@ **/ #pragma once +#include "marl/conditionvariable.h" #ifndef SLED_SYNCHRONIZATION_MUTEX_H #define SLED_SYNCHRONIZATION_MUTEX_H #include "sled/units/time_delta.h" #include #include +#include #include #include @@ -31,24 +33,26 @@ struct HasLockAndUnlock { }; }// namespace internal -class Mutex final { -public: - Mutex() = default; - Mutex(const Mutex &) = delete; - Mutex &operator=(const Mutex &) = delete; +using Mutex = marl::mutex; - inline void Lock() { impl_.lock(); }; - - inline bool TryLock() { return impl_.try_lock(); } - - inline void AssertHeld() {} - - inline void Unlock() { impl_.unlock(); } - -private: - std::mutex impl_; - friend class ConditionVariable; -}; +// class Mutex final { +// public: +// Mutex() = default; +// Mutex(const Mutex &) = delete; +// Mutex &operator=(const Mutex &) = delete; +// +// inline void Lock() { impl_.lock(); }; +// +// inline bool TryLock() { return impl_.try_lock(); } +// +// inline void AssertHeld() {} +// +// inline void Unlock() { impl_.unlock(); } +// +// private: +// std::mutex impl_; +// friend class ConditionVariable; +// }; class RecursiveMutex final { public: @@ -85,8 +89,22 @@ private: friend class ConditionVariable; }; -using MutexLock = LockGuard; -using MutexGuard = LockGuard; +class MutexGuard final { +public: + MutexGuard(Mutex *mutex) : lock_(*mutex) {} + + MutexGuard(const MutexGuard &) = delete; + MutexGuard &operator=(const MutexGuard &) = delete; + +private: + friend class ConditionVariable; + marl::lock lock_; +}; + +using MutexLock = MutexGuard; +// using MutexGuard = marl::lock; +// using MutexLock = LockGuard; +// using MutexGuard = LockGuard; using RecursiveMutexLock = LockGuard; // class MutexLock final { @@ -121,44 +139,77 @@ using RecursiveMutexLock = LockGuard; class ConditionVariable final { public: static constexpr TimeDelta kForever = TimeDelta::PlusInfinity(); - ConditionVariable() = default; - ConditionVariable(const ConditionVariable &) = delete; - ConditionVariable &operator=(const ConditionVariable &) = delete; - template - inline bool Wait(LockGuard &guard, Predicate pred) - { - std::unique_lock lock(guard.mutex_->impl_, std::adopt_lock); - cv_.wait(lock, pred); - return true; - } - - template - inline bool - WaitFor(LockGuard &guard, TimeDelta timeout, Predicate pred) - { - std::unique_lock lock(guard.mutex_->impl_, std::adopt_lock); - if (timeout == kForever) { - cv_.wait(lock, pred); - return true; - } else { - return cv_.wait_for(lock, std::chrono::milliseconds(timeout.ms()), - pred); - } - } - - // template - // bool WaitUntil(Mutex *mutex, TimeDelta timeout, Predicate pred) - // {} + // inline ConditionVariable(); inline void NotifyOne() { cv_.notify_one(); } inline void NotifyAll() { cv_.notify_all(); } + template + inline void Wait(MutexLock &lock, Predicate &&pred) + { + cv_.wait(lock, std::forward(pred)); + } + + template + inline bool WaitFor(MutexLock &lock, TimeDelta timeout, Predicate &&pred) + { + if (timeout == TimeDelta::PlusInfinity()) { + cv_.wait(lock.lock_, std::forward(pred)); + return true; + } else { + return cv_.wait_for(lock.lock_, + std::chrono::milliseconds(timeout.ms()), + std::forward(pred)); + } + } + private: - std::condition_variable cv_; + marl::ConditionVariable cv_; }; +// class ConditionVariable final { +// public: +// static constexpr TimeDelta kForever = TimeDelta::PlusInfinity(); +// ConditionVariable() = default; +// ConditionVariable(const ConditionVariable &) = delete; +// ConditionVariable &operator=(const ConditionVariable &) = delete; +// +// template +// inline bool Wait(LockGuard &guard, Predicate pred) +// { +// std::unique_lock lock(guard.mutex_->impl_, std::adopt_lock); +// cv_.wait(lock, pred); +// return true; +// } +// +// template +// inline bool +// WaitFor(LockGuard &guard, TimeDelta timeout, Predicate pred) +// { +// std::unique_lock lock(guard.mutex_->impl_, std::adopt_lock); +// if (timeout == kForever) { +// cv_.wait(lock, pred); +// return true; +// } else { +// return cv_.wait_for(lock, std::chrono::milliseconds(timeout.ms()), +// pred); +// } +// } +// +// // template +// // bool WaitUntil(Mutex *mutex, TimeDelta timeout, Predicate pred) +// // {} +// +// inline void NotifyOne() { cv_.notify_one(); } +// +// inline void NotifyAll() { cv_.notify_all(); } +// +// private: +// std::condition_variable cv_; +// }; + }// namespace sled #endif// SLED_SYNCHRONIZATION_MUTEX_H diff --git a/include/sled/system/fiber/fiber_wait_group.h b/include/sled/system/fiber/fiber_wait_group.h deleted file mode 100644 index b747a31..0000000 --- a/include/sled/system/fiber/fiber_wait_group.h +++ /dev/null @@ -1,10 +0,0 @@ -#pragma once -#ifndef SLED_SYSTEM_FIBER_FIBER_WAIT_GROUP_H -#define SLED_SYSTEM_FIBER_FIBER_WAIT_GROUP_H -#include - -namespace sled { -using FiberWaitGroup = marl::WaitGroup; -} - -#endif// SLED_SYSTEM_FIBER_FIBER_WAIT_GROUP_H diff --git a/include/sled/system/fiber/fiber_scheduler.h b/include/sled/system/fiber/scheduler.h similarity index 87% rename from include/sled/system/fiber/fiber_scheduler.h rename to include/sled/system/fiber/scheduler.h index 24f611d..9b807a5 100644 --- a/include/sled/system/fiber/fiber_scheduler.h +++ b/include/sled/system/fiber/scheduler.h @@ -1,12 +1,12 @@ #pragma once -#ifndef SLED_SYSTEM_FIBER_FIBER_SCHEDULER_H -#define SLED_SYSTEM_FIBER_FIBER_SCHEDULER_H +#ifndef SLED_SYSTEM_FIBER_SCHEDULER_H +#define SLED_SYSTEM_FIBER_SCHEDULER_H #include #include #include namespace sled { -using FiberScheduler = marl::Scheduler; +using Scheduler = marl::Scheduler; // schedule() schedules the task T to be asynchronously called using the // currently bound scheduler. @@ -42,4 +42,4 @@ Schedule(Function &&f) } }// namespace sled -#endif// SLED_SYSTEM_FIBER_FIBER_SCHEDULER_H +#endif// SLED_SYSTEM_FIBER_SCHEDULER_H diff --git a/include/sled/system/fiber/wait_group.h b/include/sled/system/fiber/wait_group.h new file mode 100644 index 0000000..9dc5ad2 --- /dev/null +++ b/include/sled/system/fiber/wait_group.h @@ -0,0 +1,26 @@ +#pragma once +#ifndef SLED_SYSTEM_FIBER_WAIT_GROUP_H +#define SLED_SYSTEM_FIBER_WAIT_GROUP_H +#include + +namespace sled { + +class WaitGroup final { +public: + inline WaitGroup(unsigned int count = 0, + marl::Allocator *allocator = marl::Allocator::Default) + : wg_(new marl::WaitGroup(count, allocator)) + {} + + inline void Add(unsigned int count = 1) const { wg_->add(count); }; + + inline bool Done() const { return wg_->done(); } + + inline void Wait() const { wg_->wait(); } + +private: + mutable std::shared_ptr wg_; +}; +}// namespace sled + +#endif// SLED_SYSTEM_FIBER_WAIT_GROUP_H diff --git a/src/cleanup_test.cc b/src/cleanup_test.cc new file mode 100644 index 0000000..76ba443 --- /dev/null +++ b/src/cleanup_test.cc @@ -0,0 +1,16 @@ +#include +#include +#include + +TEST(Cleanup, TestCleanup) +{ + sled::Random rand(1314); + for (int i = 0; i < 100; ++i) { + int a = rand.Rand(10000); + int b = rand.Rand(10000, 20000); + { + sled::Cleanup<> c([=, &a]() { a = b; }); + } + ASSERT_EQ(a, b); + } +} diff --git a/src/filesystem/path_test.cc b/src/filesystem/path_test.cc new file mode 100644 index 0000000..e54bc7f --- /dev/null +++ b/src/filesystem/path_test.cc @@ -0,0 +1,23 @@ +#include +#include + +TEST(Path, TestCurrent) +{ + sled::Path path = sled::Path::Current(); + std::string str = path.ToString(); + EXPECT_FALSE(str.empty()); +} + +TEST(Path, TestHome) +{ + sled::Path path = sled::Path::Home(); + std::string str = path.ToString(); + EXPECT_FALSE(str.empty()); +} + +TEST(Path, TestTempDir) +{ + sled::Path path = sled::Path::TempDir(); + std::string str = path.ToString(); + EXPECT_FALSE(str.empty()); +} diff --git a/src/network/async_resolver.cc b/src/network/async_resolver.cc index eb2b7d9..4ab7b75 100644 --- a/src/network/async_resolver.cc +++ b/src/network/async_resolver.cc @@ -45,7 +45,6 @@ AsyncResolver::AsyncResolver() : error_(-1), state_(new State) {} AsyncResolver::~AsyncResolver() { MutexLock lock(&state_->mutex); - ; state_->status = State::Status::kDead; } diff --git a/src/network/physical_socket_server.cc b/src/network/physical_socket_server.cc index d107b73..d52b1d1 100644 --- a/src/network/physical_socket_server.cc +++ b/src/network/physical_socket_server.cc @@ -400,7 +400,7 @@ void PhysicalSocket::SetError(int error) { // MutexLock lock(&mutex_); - LockGuard lock(&mutex_); + MutexGuard lock(&mutex_); error_ = error; } diff --git a/src/status_or_test.cc b/src/status_or_test.cc new file mode 100644 index 0000000..e117c1e --- /dev/null +++ b/src/status_or_test.cc @@ -0,0 +1,22 @@ +#include +#include + +TEST(StatusOr, TestStatusOr) +{ + sled::StatusOr so; + EXPECT_FALSE(so.ok()); + so = sled::StatusOr(1); + EXPECT_TRUE(so.ok()); + EXPECT_EQ(so.value(), 1); + EXPECT_EQ(so.status().code(), sled::StatusCode::kOk); +} + +TEST(StatusOr, make_status_or) +{ + auto from_raw_str = sled::make_status_or("hello"); + auto from_string = sled::make_status_or(std::string("world")); + EXPECT_TRUE(from_raw_str.ok()); + EXPECT_TRUE(from_string.ok()); + EXPECT_EQ(from_raw_str.value(), "hello"); + EXPECT_EQ(from_string.value(), "world"); +} diff --git a/src/strings/base64.cc b/src/strings/base64.cc index ded6149..5b81f20 100644 --- a/src/strings/base64.cc +++ b/src/strings/base64.cc @@ -45,6 +45,11 @@ Base64::Encode(const std::string &input) } } + /** + * value_bits + * 2 -> 4 -> (8 - value_bits - 2) + * 4 -> 2 -> (8 - value_bits - 2) + **/ if (value_bits > 0) { ss << kBase64Chars[((value << 8) >> (value_bits + 2)) & 0x3F]; } @@ -63,10 +68,10 @@ Base64::Decode(const std::string &input) for (unsigned char c : input) { if (-1 != kInvBase64Chars[c]) { // valid base64 character - value = (value << 6) + kInvBase64Chars[c]; + value = (value << 6) | kInvBase64Chars[c]; value_bits += 6; - if (value_bits >= 0) { - ss << char((value >> value_bits) & 0xFF); + if (value_bits >= 8) { + ss << char((value >> (value_bits - 8)) & 0xFF); value_bits -= 8; } } else if (c == '=') { @@ -81,6 +86,7 @@ Base64::Decode(const std::string &input) } ++index; } + return make_status_or(ss.str()); } }// namespace sled diff --git a/src/strings/base64_test.cc b/src/strings/base64_test.cc new file mode 100644 index 0000000..d8c2963 --- /dev/null +++ b/src/strings/base64_test.cc @@ -0,0 +1,23 @@ +#include +#include + +#define CONCAT_IMPL(A, B) A##B +#define CONCAT(A, B) CONCAT_IMPL(A, B) + +#define TEST_ENCODE_DECODE(base64, text) \ + do { \ + ASSERT_EQ(sled::Base64::Encode(text), base64); \ + auto CONCAT(res, __LINE__) = sled::Base64::Decode(base64); \ + ASSERT_TRUE(CONCAT(res, __LINE__).ok()); \ + ASSERT_EQ(CONCAT(res, __LINE__).value(), text); \ + } while (0) + +TEST(Base64, EncodeAndDecode) +{ + TEST_ENCODE_DECODE("aGVsbG8gd29ybGQK", "hello world\n"); + TEST_ENCODE_DECODE("U2VuZCByZWluZm9yY2VtZW50cwo=", "Send reinforcements\n"); + TEST_ENCODE_DECODE("", ""); + TEST_ENCODE_DECODE("AA==", std::string("\0", 1)); + TEST_ENCODE_DECODE("AAA=", std::string("\0\0", 2)); + TEST_ENCODE_DECODE("AAAA", std::string("\0\0\0", 3)); +} diff --git a/src/system/fiber/fiber_bench.cc b/src/system/fiber/fiber_bench.cc new file mode 100644 index 0000000..cc36772 --- /dev/null +++ b/src/system/fiber/fiber_bench.cc @@ -0,0 +1,25 @@ +#include +#include +#include + +static void +MultiQueue(benchmark::State &state) +{ + sled::Scheduler scheduler(sled::Scheduler::Config::allCores()); + scheduler.bind(); + defer(scheduler.unbind()); + const int num_tasks = state.range(0); + sled::WaitGroup wg(num_tasks); + sled::WaitGroup start_flag(1); + + for (int i = 0; i < num_tasks; i++) { + sled::Schedule([=] { + start_flag.Wait(); + wg.Done(); + }); + } + start_flag.Done(); + wg.Wait(); +} + +BENCHMARK(MultiQueue)->RangeMultiplier(10)->Range(10, 10000); diff --git a/src/system/fiber/fiber_test.cc b/src/system/fiber/fiber_test.cc new file mode 100644 index 0000000..75f8f36 --- /dev/null +++ b/src/system/fiber/fiber_test.cc @@ -0,0 +1,25 @@ +#include +#include +#include + +TEST(FiberScheduler, TestFiberScheduler) +{ + sled::Scheduler scheduler(sled::Scheduler::Config::allCores()); + scheduler.bind(); + defer(scheduler.unbind()); + + std::atomic counter = {0}; + sled::WaitGroup wg(1); + sled::WaitGroup wg2(1000); + for (int i = 0; i < 1000; i++) { + sled::Schedule([&] { + wg.Wait(); + wg2.Done(); + counter++; + }); + } + + sled::Schedule([=] { wg.Done(); }); + wg2.Wait(); + ASSERT_EQ(counter.load(), 1000); +}