diff --git a/CMakeLists.txt b/CMakeLists.txt index d206d10..ea28e61 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -193,6 +193,7 @@ set(TILE_SRCS "tile/io/detail/eintr_safe.cc" "tile/io/native/acceptor.cc" "tile/io/descriptor.cc" + "tile/io/util/rate_limiter.cc" "tile/io/event_loop.cc" "tile/init.cc" "tile/init/on_init.cc" @@ -291,6 +292,7 @@ if(TILE_BUILD_TESTS) add_test(NAME ${test_name} COMMAND ${test_name}) endmacro() + tile_add_test(io_util_rate_limiter_test "tile/io/util/rate_limiter_test.cc") tile_add_test(base_exposed_var_test "tile/base/exposed_var_test.cc") # tile_add_test(fiber_detail_scheduler_test "tile/fiber/detail/scheduler_test.cc") tile_add_test(base_internal_meta_test "tile/base/internal/meta_test.cc") diff --git a/tile/io/util/rate_limiter.cc b/tile/io/util/rate_limiter.cc new file mode 100644 index 0000000..ff2569b --- /dev/null +++ b/tile/io/util/rate_limiter.cc @@ -0,0 +1,84 @@ +#include "tile/io/util/rate_limiter.h" + +#include "tile/base/chrono.h" +#include "tile/base/logging.h" + +namespace tile { + +namespace { +class NullLimiter : public RateLimiter { +public: + std::size_t GetQuota() override { return 0; } + void ConsumeBytes(std::size_t consumed) override {} +}; +} // namespace + +RateLimiter *RateLimiter::GetDefaultRxRateLimiter() { + static NullLimiter null_limiter; + return &null_limiter; +} + +RateLimiter *RateLimiter::GetDefaultTxRateLimiter() { + static NullLimiter null_limiter; + return &null_limiter; +} + +TokenBucketRateLimiter::TokenBucketRateLimiter(std::size_t bucket_quota, + std::size_t quota_per_tick, + std::chrono::nanoseconds tick, + bool over_consumption_allowed) + : max_quota_(bucket_quota), quota_per_tick_(quota_per_tick), tick_(tick), + over_consumption_allowed_(over_consumption_allowed) { + TILE_CHECK_GT(bucket_quota, 0); + TILE_CHECK_GT(quota_per_tick, 0); + last_refill_ = ReadSteadyClock().time_since_epoch() / tick_; + curr_quota_ = max_quota_; +} + +std::size_t TokenBucketRateLimiter::GetQuota() { + auto now = ReadSteadyClock().time_since_epoch() / tick_; + std::uint64_t last_refill = internal::Exchange(last_refill_, now); + + curr_quota_ += quota_per_tick_ * (now - last_refill); + + if (curr_quota_ > 0) { + curr_quota_ = std::min(curr_quota_, max_quota_); + return curr_quota_; + } else { + return 0; + } +} + +void TokenBucketRateLimiter::ConsumeBytes(std::size_t consumed) { + TILE_CHECK(over_consumption_allowed_ || consumed <= curr_quota_); + curr_quota_ -= consumed; +} + +ThreadSafeRateLimiter::ThreadSafeRateLimiter(MaybeOwning limiter, + std::size_t burst_limit) + : burst_limit_(burst_limit), impl_(std::move(limiter)) { + TILE_CHECK_GT(burst_limit_, 0); +} + +std::size_t ThreadSafeRateLimiter::GetQuota() { + std::lock_guard lock(lock_); + return impl_->GetQuota(); +} +void ThreadSafeRateLimiter::ConsumeBytes(std::size_t consumed) { + std::lock_guard lock(lock_); + impl_->ConsumeBytes(consumed); +} + +LayeredRateLimiter::LayeredRateLimiter(RateLimiter *upper, + MaybeOwning ours) + : upper_(upper), ours_(std::move(ours)) {} + +std::size_t LayeredRateLimiter::GetQuota() { + return std::min(upper_->GetQuota(), ours_->GetQuota()); +} + +void LayeredRateLimiter::ConsumeBytes(std::size_t consumed) { + upper_->ConsumeBytes(consumed); + ours_->ConsumeBytes(consumed); +} +} // namespace tile diff --git a/tile/io/util/rate_limiter.h b/tile/io/util/rate_limiter.h new file mode 100644 index 0000000..8ed24c6 --- /dev/null +++ b/tile/io/util/rate_limiter.h @@ -0,0 +1,70 @@ +#ifndef TILE_IO_UTIL_RATE_LIMITER_H +#define TILE_IO_UTIL_RATE_LIMITER_H + +#pragma once + +#include "tile/base/maybe_owning.h" + +#include +#include + +namespace tile { +class RateLimiter { +public: + virtual ~RateLimiter() = default; + + virtual std::size_t GetQuota() = 0; + virtual void ConsumeBytes(std::size_t consumed) = 0; + static RateLimiter *GetDefaultRxRateLimiter(); + static RateLimiter *GetDefaultTxRateLimiter(); +}; + +class TokenBucketRateLimiter : public RateLimiter { +public: + TokenBucketRateLimiter( + std::size_t bucket_quota, std::size_t quota_per_tick, + std::chrono::nanoseconds tick = std::chrono::milliseconds(1), + bool over_consumption_allowed = true); + + std::size_t GetQuota() override; + void ConsumeBytes(std::size_t consumed) override; + +private: + std::size_t max_quota_; + std::size_t quota_per_tick_; + std::chrono::nanoseconds tick_; + bool over_consumption_allowed_; + + std::uint64_t last_refill_; + std::int64_t curr_quota_{0}; +}; + +class ThreadSafeRateLimiter : public RateLimiter { +public: + explicit ThreadSafeRateLimiter( + MaybeOwning limiter, + std::size_t burst_limit = std::numeric_limits::max()); + std::size_t GetQuota() override; + void ConsumeBytes(std::size_t consumed) override; + +private: + std::size_t burst_limit_; + std::mutex lock_; + MaybeOwning impl_; +}; + +class LayeredRateLimiter : public RateLimiter { +public: + LayeredRateLimiter(RateLimiter *upper, MaybeOwning ours); + + std::size_t GetQuota() override; + void ConsumeBytes(std::size_t consumed) override; + +private: + RateLimiter *upper_; + MaybeOwning ours_; +}; + +} // namespace tile + +#endif // TILE_IO_UTIL_RATE_LIMITER_H diff --git a/tile/io/util/rate_limiter_test.cc b/tile/io/util/rate_limiter_test.cc new file mode 100644 index 0000000..753e8f1 --- /dev/null +++ b/tile/io/util/rate_limiter_test.cc @@ -0,0 +1,188 @@ +#include "rate_limiter.h" + +#include "gtest/gtest.h" + +#include "tile/base/chrono.h" +#include "tile/base/make_unique.h" +#include "tile/base/random.h" + +namespace tile { + +TEST(RateLimiter, TokenBucketRateLimiter) { + TokenBucketRateLimiter limiter(1000, 1); + std::size_t total = 0; + auto start = ReadSteadyClock(); + while (ReadSteadyClock() - start < std::chrono::seconds(5)) { + auto current = limiter.GetQuota(); + total += current; + limiter.ConsumeBytes(current); + + // with sleep + std::this_thread::sleep_for(std::chrono::milliseconds(1) * Random(100)); + } + ASSERT_NEAR(6000, total, 200); +} + +TEST(RateLimiter, TokenBucketRateLimiter2) { + TokenBucketRateLimiter limiter(1000, 1); + std::size_t total = 0; + auto start = ReadSteadyClock(); + while (ReadSteadyClock() - start < std::chrono::seconds(5)) { + auto current = limiter.GetQuota(); + total += current; + limiter.ConsumeBytes(current); + } + ASSERT_NEAR(6000, total, 200); +} + +TEST(RateLimiter, TokenBucketRateLimiterCapBurst) { + TokenBucketRateLimiter limiter(25, 500); + for (int i = 0; i != 10; ++i) { + ASSERT_EQ(25, limiter.GetQuota()); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } +} + +TEST(RateLimiter, TokenBucketRateLimiterCapBurst2) { + TokenBucketRateLimiter limiter(1000, 500); + for (int i = 0; i != 10; ++i) { + ASSERT_EQ(1000, limiter.GetQuota()); + std::this_thread::sleep_for( + std::chrono::milliseconds(10)); // Enough to fully fill the bucket. + } +} + +TEST(RateLimiter, MultithreadedRateLimiter) { + ThreadSafeRateLimiter limiter(make_unique(1000, 1)); + std::atomic total{0}; + std::vector ts; + + for (int i = 0; i != 10; ++i) { + ts.emplace_back(std::thread([&] { + auto start = ReadSteadyClock(); + while (ReadSteadyClock() - start < std::chrono::seconds(5)) { + auto current = limiter.GetQuota(); + total += current; + limiter.ConsumeBytes(current); + std::this_thread::sleep_for(std::chrono::milliseconds(10) * Random(10)); + } + })); + } + + for (auto &&t : ts) { + t.join(); + } + + ASSERT_NEAR(6000, total.load(), 500); +} + +TEST(RateLimiter, LayeredRateLimiter) { + ThreadSafeRateLimiter base_limiter( + make_unique(1000, 1)); + auto our_limiter = make_unique( + make_unique(1000, 100)); + LayeredRateLimiter layered_limiter(&base_limiter, std::move(our_limiter)); + std::atomic total{0}; + std::vector ts; + + for (int i = 0; i != 10; ++i) { + ts.emplace_back(std::thread([&] { + auto start = ReadSteadyClock(); + while (ReadSteadyClock() - start < std::chrono::seconds(5)) { + auto current = layered_limiter.GetQuota(); + total += current; + layered_limiter.ConsumeBytes(current); + std::this_thread::sleep_for(std::chrono::milliseconds(10) * Random(10)); + } + })); + } + + for (auto &&t : ts) { + t.join(); + } + + ASSERT_NEAR(6000, total.load(), 500); // `msrl` takes effect. +} + +TEST(RateLimiter, LayeredRateLimiter2) { + ThreadSafeRateLimiter base_limiter( + make_unique(1000, 100)); + auto our_limiter = make_unique( + make_unique(1000, 1)); + LayeredRateLimiter layered_limiter(&base_limiter, std::move(our_limiter)); + std::atomic total{0}; + std::vector ts; + + for (int i = 0; i != 10; ++i) { + ts.emplace_back(std::thread([&] { + auto start = ReadSteadyClock(); + while (ReadSteadyClock() - start < std::chrono::seconds(5)) { + auto current = layered_limiter.GetQuota(); + total += current; + layered_limiter.ConsumeBytes(current); + std::this_thread::sleep_for(std::chrono::milliseconds(1) * Random(10)); + } + })); + } + + for (auto &&t : ts) { + t.join(); + } + + ASSERT_NEAR(6000, total.load(), 500); // `tbsrl` takes effect. +} + +TEST(RateLimiter, LayeredRateLimiter3) { + ThreadSafeRateLimiter base_limiter( + make_unique(1000, 1)); + std::atomic total{0}; + std::vector ts; + + for (int i = 0; i != 10; ++i) { + ts.emplace_back(std::thread([&] { + auto our_limiter = make_unique(1000, 100); + LayeredRateLimiter layered_limiter(&base_limiter, std::move(our_limiter)); + auto start = ReadSteadyClock(); + while (ReadSteadyClock() - start < std::chrono::seconds(5)) { + auto current = layered_limiter.GetQuota(); + total += current; + layered_limiter.ConsumeBytes(current); + } + })); + } + + for (auto &&t : ts) { + t.join(); + } + + ASSERT_NEAR(6000, total.load(), 500); // `msrl` takes effect. +} + +TEST(RateLimiter, LayeredRateLimiter4) { + ThreadSafeRateLimiter base_limiter( + make_unique(1000, 100)); + std::atomic total{0}; + std::vector ts; + + for (int i = 0; i != 10; ++i) { + ts.emplace_back(std::thread([&] { + auto our_limiter = make_unique( + make_unique(1000, 1)); + LayeredRateLimiter layered_limiter(&base_limiter, std::move(our_limiter)); + auto start = ReadSteadyClock(); + while (ReadSteadyClock() - start < std::chrono::seconds(5)) { + auto current = layered_limiter.GetQuota(); + total += current; + layered_limiter.ConsumeBytes(current); + } + })); + } + + for (auto &&t : ts) { + t.join(); + } + + ASSERT_NEAR(60000, total.load(), 5000); // `tbsrl` takes effect. +} + +} // namespace tile