diff --git a/include/sled/sequence_checker.h b/include/sled/sequence_checker.h deleted file mode 100644 index d151f11..0000000 --- a/include/sled/sequence_checker.h +++ /dev/null @@ -1,25 +0,0 @@ -/** - * @file : sequence_checker - * @created : Saturday Feb 03, 2024 13:32:22 CST - * @license : MIT - **/ - -#pragma once -#ifndef SLED_SEQUENCE_CHECKER_H -#define SLED_SEQUENCE_CHECKER_H - -namespace sled { - -class SequenceChecker : public internal::SequenceCheckerImpl { -public: - enum InitialState : bool { - kDetached = false, - kAttached = true, - }; - - explicit SequenceChecker(InitialState initial_state = kAttached) : Impl(initial_state) {} -}; - -}// namespace sled - -#endif// SLED_SEQUENCE_CHECKER_H diff --git a/include/sled/sled.h b/include/sled/sled.h index 80c332f..dbe6fc6 100644 --- a/include/sled/sled.h +++ b/include/sled/sled.h @@ -38,6 +38,7 @@ #include "sled/strings/utils.h" // synchorization +#include "seld/synchronization/sequence_checker.h" #include "sled/synchronization/event.h" #include "sled/synchronization/mutex.h" #include "sled/synchronization/one_time_event.h" diff --git a/include/sled/strings/base64.h b/include/sled/strings/base64.h index 400b96d..4ce916f 100644 --- a/include/sled/strings/base64.h +++ b/include/sled/strings/base64.h @@ -9,6 +9,7 @@ #define SLED_STRINGS_BASE64_H #include "sled/status_or.h" +#include #include #include @@ -16,13 +17,44 @@ namespace sled { class Base64 { public: - static std::string Encode(const uint8_t *const ptr, size_t len); - static std::string Encode(const std::vector &data); - static std::string Encode(const std::string &data); - static std::string Encode(const char *const data); - static StatusOr Decode(const std::string &base64); - static StatusOr Decode(const std::vector &base64); - static StatusOr Decode(const uint8_t *const ptr, size_t len); + static size_t DecodedLength(const char *base64_data, size_t base64_len); + static std::string Encode(const uint8_t *ptr, size_t len); + static StatusOr Decode(const uint8_t *ptr, size_t len); + + // EncodedLength + static inline size_t EncodedLength(size_t data_len) { return (data_len + 2) / 3 * 4; } + + static inline size_t DecodedLength(const std::string &str) { return DecodedLength(str.data(), str.size()); } + + static inline size_t DecodedLength(const char *base64_str) { return DecodedLength(base64_str, strlen(base64_str)); } + + // Encode + static inline std::string Encode(const std::vector &data) + { + return Encode(data.data(), data.size()); + } + + static inline std::string Encode(const std::string &data) { return Encode((uint8_t *) data.data(), data.size()); } + + static inline std::string Encode(const char *const data) { return Encode((uint8_t *) data, strlen(data)); } + + // Decode + static inline StatusOr Decode(const char *ptr, size_t len) + { + return Decode((const uint8_t *) ptr, len); + } + + static inline StatusOr Decode(const char *ptr) { return Decode(ptr, strlen(ptr)); } + + static inline StatusOr Decode(const std::string &base64) + { + return Decode(base64.data(), base64.size()); + } + + static inline StatusOr Decode(const std::vector &base64) + { + return Decode(base64.data(), base64.size()); + } }; }// namespace sled diff --git a/src/strings/base64.cc b/src/strings/base64.cc index e13d84f..0c6da52 100644 --- a/src/strings/base64.cc +++ b/src/strings/base64.cc @@ -4,6 +4,7 @@ #include #include #include +#include namespace sled { const char kBase64Chars[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; @@ -17,38 +18,43 @@ IsBase64(char c) return isalnum(c) || c == '+' || c == '/'; } -std::string -Base64::Encode(const uint8_t *const ptr, size_t len) +size_t +Base64::DecodedLength(const char *base64_data, size_t base64_len) { - auto data = (unsigned char *) (ptr); - return Encode(std::vector(data, data + len)); + if (base64_len <= 0) { return 0; } + /** + * The number of padding characters at the end of the base64 data + * is the number of '=' characters at the end of the base64 data. + **/ + if (base64_data) { + size_t padding = (4 - (base64_len % 4)) % 4; + while (base64_data[--base64_len] == '=') { ++padding; } + return ((base64_len + 3) / 4) * 3 - ((5 + padding * 6) / 8); + } else { + return base64_len / 4 * 3 + ((base64_len % 4) * 6 + 5) / 8; + } } std::string -Base64::Encode(const std::string &input) -{ - return Encode((uint8_t *) input.data(), input.length()); -} +Base64::Encode(const uint8_t *ptr, size_t len) -std::string -Base64::Encode(const char *const data) { - return Encode((uint8_t *) data, strlen(data)); -} + // std::stringstream ss; + std::string result(EncodedLength(len), 0); + int write_idx = 0; -std::string -Base64::Encode(const std::vector &data) -{ - std::stringstream ss; int value = 0; int value_bits = 0; - for (unsigned char c : data) { - value = (value << 8) + c; + while (len > 0) { + value = (value << 8) + *ptr; value_bits += 8; while (value_bits >= 6) { value_bits -= 6; - ss << kBase64Chars[(value >> value_bits) & 0x3F]; + // ss << kBase64Chars[(value >> value_bits) & 0x3F]; + result[write_idx++] = kBase64Chars[(value >> value_bits) & 0x3F]; } + ++ptr; + --len; } /** @@ -56,31 +62,40 @@ Base64::Encode(const std::vector &data) * 2 -> 4 -> (8 - value_bits - 2) * 4 -> 2 -> (8 - value_bits - 2) **/ - if (value_bits > 0) { ss << kBase64Chars[((value << 8) >> (value_bits + 2)) & 0x3F]; } - while (ss.str().size() % 4) { ss << '='; } + if (value_bits > 0) { + result[write_idx++] = kBase64Chars[(value << (6 - value_bits)) & 0x3F]; + // ss << kBase64Chars[((value << 8) >> (value_bits + 2)) & 0x3F]; + } + // while (ss.str().size() % 4) { ss << '='; } + while (write_idx % 4) { result[write_idx++] = '='; } - return ss.str(); + // return ss.str(); + return std::move(result); } StatusOr -Base64::Decode(const std::string &input) +Base64::Decode(const uint8_t *ptr, size_t len) { CallOnce(once_flag, [&] { std::fill(kInvBase64Chars.begin(), kInvBase64Chars.end(), -1); for (int i = 0; kBase64Chars[i]; i++) { kInvBase64Chars[kBase64Chars[i]] = i; } }); - std::stringstream ss; + int write_idx = 0; + std::string data(DecodedLength((char *) ptr, len), 0); + // std::stringstream ss; int value = 0; int value_bits = 0; int index = 0; - for (unsigned char c : input) { + for (int i = 0; i < len; i++) { + char c = ptr[i]; if (-1 != kInvBase64Chars[c]) { // valid base64 character value = (value << 6) | kInvBase64Chars[c]; value_bits += 6; if (value_bits >= 8) { - ss << char((value >> (value_bits - 8)) & 0xFF); + data[write_idx++] = (value >> (value_bits - 8)) & 0xFF; + // ss << char((value >> (value_bits - 8)) & 0xFF); value_bits -= 8; } } else if (c == '=') { @@ -93,19 +108,9 @@ Base64::Decode(const std::string &input) } ++index; } + while (write_idx < data.size()) data.pop_back(); - return make_status_or(ss.str()); + return make_status_or(data); } -StatusOr -Base64::Decode(const std::vector &base64) -{ - return Decode(std::string(base64.begin(), base64.end())); -} - -StatusOr -Base64::Decode(const uint8_t *const ptr, size_t len) -{ - return Decode(std::string((char *) ptr, len)); -} }// namespace sled diff --git a/src/strings/base64_bench.cc b/src/strings/base64_bench.cc index 3c36cdd..f9cf85f 100644 --- a/src/strings/base64_bench.cc +++ b/src/strings/base64_bench.cc @@ -38,5 +38,5 @@ Base64Decode(benchmark::State &state) } } -BENCHMARK(Base64Encode)->RangeMultiplier(10)->Range(10, 1000000); -BENCHMARK(Base64Decode)->RangeMultiplier(10)->Range(10, 1000000); +BENCHMARK(Base64Encode)->RangeMultiplier(100)->Range(10, 100000); +BENCHMARK(Base64Decode)->RangeMultiplier(100)->Range(10, 100000); diff --git a/src/strings/base64_test.cc b/src/strings/base64_test.cc index 32f59e2..497ff42 100644 --- a/src/strings/base64_test.cc +++ b/src/strings/base64_test.cc @@ -1,25 +1,47 @@ #include #include -#define TEST_ENCODE_DECODE(base64, text) \ - do { \ - EXPECT_EQ(sled::Base64::Encode(text), std::string(base64)); \ - auto res = sled::Base64::Decode(base64); \ - EXPECT_TRUE(res.ok()); \ - EXPECT_EQ(res.value(), text); \ - } while (0) - -TEST(Base64, Encode) { EXPECT_EQ("aGVsbG8gd29ybGQK", sled::Base64::Encode("hello world\n")); } - -TEST(Base64, Decode) { EXPECT_EQ("hello world\n", sled::Base64::Decode("aGVsbG8gd29ybGQK").value()); } - -TEST(Base64, EncodeAndDecode) +TEST(Base64, EncodedLength) { - TEST_ENCODE_DECODE("aGVsbG8gd29ybGQK", "hello world\n"); - TEST_ENCODE_DECODE("U2VuZCByZWluZm9yY2VtZW50cwo=", "Send reinforcements\n"); - TEST_ENCODE_DECODE("", ""); - TEST_ENCODE_DECODE("IA==", " "); - TEST_ENCODE_DECODE("AA==", std::string("\0", 1)); - TEST_ENCODE_DECODE("AAA=", std::string("\0\0", 2)); - TEST_ENCODE_DECODE("AAAA", std::string("\0\0\0", 3)); + EXPECT_EQ(0, sled::Base64::EncodedLength(0)); + EXPECT_EQ(4, sled::Base64::EncodedLength(1)); + EXPECT_EQ(4, sled::Base64::EncodedLength(2)); + EXPECT_EQ(4, sled::Base64::EncodedLength(3)); + EXPECT_EQ(8, sled::Base64::EncodedLength(4)); + EXPECT_EQ(8, sled::Base64::EncodedLength(5)); + EXPECT_EQ(8, sled::Base64::EncodedLength(6)); + EXPECT_EQ(12, sled::Base64::EncodedLength(7)); +} + +TEST(Base64, DecodedLength) +{ + EXPECT_EQ(0, sled::Base64::DecodedLength(nullptr, 0)); + EXPECT_EQ(1, sled::Base64::DecodedLength(nullptr, 1)); + EXPECT_EQ(2, sled::Base64::DecodedLength(nullptr, 2)); + EXPECT_EQ(2, sled::Base64::DecodedLength(nullptr, 3)); + EXPECT_EQ(3, sled::Base64::DecodedLength(nullptr, 4)); + + EXPECT_EQ(0, sled::Base64::DecodedLength("", 0)); +} + +TEST(Base64, Encode) +{ + EXPECT_EQ("aGVsbG8gd29ybGQK", sled::Base64::Encode("hello world\n")); + EXPECT_EQ("U2VuZCByZWluZm9yY2VtZW50cwo=", sled::Base64::Encode("Send reinforcements\n")); + EXPECT_EQ("", sled::Base64::Encode("")); + EXPECT_EQ("IA==", sled::Base64::Encode(" ")); + EXPECT_EQ("AA==", sled::Base64::Encode(std::string("\0", 1))); + EXPECT_EQ("AAA=", sled::Base64::Encode(std::string("\0\0", 2))); + EXPECT_EQ("AAAA", sled::Base64::Encode(std::string("\0\0\0", 3))); +} + +TEST(Base64, Decode) +{ + EXPECT_EQ("hello world\n", sled::Base64::Decode("aGVsbG8gd29ybGQK").value()); + EXPECT_EQ("Send reinforcements\n", sled::Base64::Decode("U2VuZCByZWluZm9yY2VtZW50cwo=").value()); + EXPECT_EQ("", sled::Base64::Decode("").value()); + EXPECT_EQ(" ", sled::Base64::Decode("IA==").value()); + EXPECT_EQ(std::string("\0", 1), sled::Base64::Decode("AA==").value()); + EXPECT_EQ(std::string("\0\0", 2), sled::Base64::Decode("AAA=").value()); + EXPECT_EQ(std::string("\0\0\0", 3), sled::Base64::Decode("AAAA").value()); }