From 2099a9d5dbdcd2841d02dd396a71d0de70bf4490 Mon Sep 17 00:00:00 2001 From: Piotr Date: Thu, 27 Mar 2025 23:41:04 +0100 Subject: [PATCH] server : Support listening on a unix socket (#12613) * server : Bump cpp-httplib to include AF_UNIX windows support Signed-off-by: Piotr Stankiewicz * server : Allow running the server example on a unix socket Signed-off-by: Piotr Stankiewicz --------- Signed-off-by: Piotr Stankiewicz --- common/arg.cpp | 2 +- examples/server/httplib.h | 562 +++++++++++++++++++++---------------- examples/server/server.cpp | 23 +- 3 files changed, 331 insertions(+), 256 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index b6bfe6f89..8292adaac 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1979,7 +1979,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); add_opt(common_arg( {"--host"}, "HOST", - string_format("ip address to listen (default: %s)", params.hostname.c_str()), + string_format("ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: %s)", params.hostname.c_str()), [](common_params & params, const std::string & value) { params.hostname = value; } diff --git a/examples/server/httplib.h b/examples/server/httplib.h index 593beb501..0f981dc89 100644 --- a/examples/server/httplib.h +++ b/examples/server/httplib.h @@ -8,7 +8,7 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.19.0" +#define CPPHTTPLIB_VERSION "0.20.0" /* * Configuration @@ -188,15 +188,16 @@ using ssize_t = long; #include #include +// afunix.h uses types declared in winsock2.h, so has to be included after it. +#include + #ifndef WSA_FLAG_NO_HANDLE_INHERIT #define WSA_FLAG_NO_HANDLE_INHERIT 0x80 #endif +using nfds_t = unsigned long; using socket_t = SOCKET; using socklen_t = int; -#ifdef CPPHTTPLIB_USE_POLL -#define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) -#endif #else // not _WIN32 @@ -216,16 +217,11 @@ using socklen_t = int; #ifdef __linux__ #include #endif -#include -#ifdef CPPHTTPLIB_USE_POLL -#include -#endif #include +#include +#include #include #include -#ifndef __VMS -#include -#endif #include #include #include @@ -247,7 +243,6 @@ using socket_t = int; #include #include #include -#include #include #include #include @@ -320,6 +315,10 @@ using socket_t = int; #include #endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT +#include +#endif + /* * Declaration */ @@ -435,6 +434,15 @@ private: } // namespace detail +enum SSLVerifierResponse { + // no decision has been made, use the built-in certificate verifier + NoDecisionMade, + // connection certificate is verified and accepted + CertificateAccepted, + // connection certificate was processed but is rejected + CertificateRejected +}; + enum StatusCode { // Information responses Continue_100 = 100, @@ -670,7 +678,7 @@ struct Request { bool is_chunked_content_provider_ = false; size_t authorization_count_ = 0; std::chrono::time_point start_time_ = - std::chrono::steady_clock::time_point::min(); + (std::chrono::steady_clock::time_point::min)(); }; struct Response { @@ -736,7 +744,8 @@ public: virtual ~Stream() = default; virtual bool is_readable() const = 0; - virtual bool is_writable() const = 0; + virtual bool wait_readable() const = 0; + virtual bool wait_writable() const = 0; virtual ssize_t read(char *ptr, size_t size) = 0; virtual ssize_t write(const char *ptr, size_t size) = 0; @@ -879,7 +888,7 @@ public: * Captures parameters in request path and stores them in Request::path_params * * Capture name is a substring of a pattern from : to /. - * The rest of the pattern is matched agains the request path directly + * The rest of the pattern is matched against the request path directly * Parameters are captured starting from the next character after * the end of the last matched static pattern fragment until the next /. * @@ -1109,7 +1118,7 @@ private: virtual bool process_and_close_socket(socket_t sock); std::atomic is_running_{false}; - std::atomic is_decommisioned{false}; + std::atomic is_decommissioned{false}; struct MountPointEntry { std::string mount_point; @@ -1483,7 +1492,8 @@ public: #ifdef CPPHTTPLIB_OPENSSL_SUPPORT void enable_server_certificate_verification(bool enabled); void enable_server_hostname_verification(bool enabled); - void set_server_certificate_verifier(std::function verifier); + void set_server_certificate_verifier( + std::function verifier); #endif void set_logger(Logger logger); @@ -1600,7 +1610,7 @@ protected: #ifdef CPPHTTPLIB_OPENSSL_SUPPORT bool server_certificate_verification_ = true; bool server_hostname_verification_ = true; - std::function server_certificate_verifier_; + std::function server_certificate_verifier_; #endif Logger logger_; @@ -1913,7 +1923,8 @@ public: #ifdef CPPHTTPLIB_OPENSSL_SUPPORT void enable_server_certificate_verification(bool enabled); void enable_server_hostname_verification(bool enabled); - void set_server_certificate_verifier(std::function verifier); + void set_server_certificate_verifier( + std::function verifier); #endif void set_logger(Logger logger); @@ -2046,6 +2057,10 @@ inline void duration_to_sec_and_usec(const T &duration, U callback) { callback(static_cast(sec), static_cast(usec)); } +template inline constexpr size_t str_len(const char (&)[N]) { + return N - 1; +} + inline bool is_numeric(const std::string &str) { return !str.empty() && std::all_of(str.begin(), str.end(), ::isdigit); } @@ -2205,9 +2220,9 @@ inline const char *status_message(int status) { inline std::string get_bearer_token_auth(const Request &req) { if (req.has_header("Authorization")) { - static std::string BearerHeaderPrefix = "Bearer "; + constexpr auto bearer_header_prefix_len = detail::str_len("Bearer "); return req.get_header_value("Authorization") - .substr(BearerHeaderPrefix.length()); + .substr(bearer_header_prefix_len); } return ""; } @@ -2382,8 +2397,6 @@ std::string encode_query_param(const std::string &value); std::string decode_url(const std::string &s, bool convert_plus_to_space); -void read_file(const std::string &path, std::string &out); - std::string trim_copy(const std::string &s); void divide( @@ -2439,7 +2452,7 @@ ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags); ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags); -enum class EncodingType { None = 0, Gzip, Brotli }; +enum class EncodingType { None = 0, Gzip, Brotli, Zstd }; EncodingType encoding_type(const Request &req, const Response &res); @@ -2449,7 +2462,8 @@ public: ~BufferStream() override = default; bool is_readable() const override; - bool is_writable() const override; + bool wait_readable() const override; + bool wait_writable() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -2551,6 +2565,34 @@ private: }; #endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT +class zstd_compressor : public compressor { +public: + zstd_compressor(); + ~zstd_compressor(); + + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override; + +private: + ZSTD_CCtx *ctx_ = nullptr; +}; + +class zstd_decompressor : public decompressor { +public: + zstd_decompressor(); + ~zstd_decompressor(); + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, + Callback callback) override; + +private: + ZSTD_DCtx *ctx_ = nullptr; +}; +#endif + // NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` // to store data. The call can set memory on stack for performance. class stream_line_reader { @@ -2569,7 +2611,7 @@ private: char *fixed_buffer_; const size_t fixed_buffer_size_; size_t fixed_buffer_used_size_ = 0; - std::string glowable_buffer_; + std::string growable_buffer_; }; class mmap { @@ -2910,18 +2952,9 @@ inline std::string decode_url(const std::string &s, return result; } -inline void read_file(const std::string &path, std::string &out) { - std::ifstream fs(path, std::ios_base::binary); - fs.seekg(0, std::ios_base::end); - auto size = fs.tellg(); - fs.seekg(0); - out.resize(static_cast(size)); - fs.read(&out[0], static_cast(size)); -} - inline std::string file_extension(const std::string &path) { std::smatch m; - static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + thread_local auto re = std::regex("\\.([a-zA-Z0-9]+)$"); if (std::regex_search(path, m, re)) { return m[1].str(); } return std::string(); } @@ -3005,18 +3038,18 @@ inline stream_line_reader::stream_line_reader(Stream &strm, char *fixed_buffer, fixed_buffer_size_(fixed_buffer_size) {} inline const char *stream_line_reader::ptr() const { - if (glowable_buffer_.empty()) { + if (growable_buffer_.empty()) { return fixed_buffer_; } else { - return glowable_buffer_.data(); + return growable_buffer_.data(); } } inline size_t stream_line_reader::size() const { - if (glowable_buffer_.empty()) { + if (growable_buffer_.empty()) { return fixed_buffer_used_size_; } else { - return glowable_buffer_.size(); + return growable_buffer_.size(); } } @@ -3027,7 +3060,7 @@ inline bool stream_line_reader::end_with_crlf() const { inline bool stream_line_reader::getline() { fixed_buffer_used_size_ = 0; - glowable_buffer_.clear(); + growable_buffer_.clear(); #ifndef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR char prev_byte = 0; @@ -3065,11 +3098,11 @@ inline void stream_line_reader::append(char c) { fixed_buffer_[fixed_buffer_used_size_++] = c; fixed_buffer_[fixed_buffer_used_size_] = '\0'; } else { - if (glowable_buffer_.empty()) { + if (growable_buffer_.empty()) { assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); - glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + growable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); } - glowable_buffer_ += c; + growable_buffer_ += c; } } @@ -3246,35 +3279,23 @@ inline ssize_t send_socket(socket_t sock, const void *ptr, size_t size, }); } +inline int poll_wrapper(struct pollfd *fds, nfds_t nfds, int timeout) { +#ifdef _WIN32 + return ::WSAPoll(fds, nfds, timeout); +#else + return ::poll(fds, nfds, timeout); +#endif +} + template inline ssize_t select_impl(socket_t sock, time_t sec, time_t usec) { -#ifdef CPPHTTPLIB_USE_POLL struct pollfd pfd; pfd.fd = sock; pfd.events = (Read ? POLLIN : POLLOUT); auto timeout = static_cast(sec * 1000 + usec / 1000); - return handle_EINTR([&]() { return poll(&pfd, 1, timeout); }); -#else -#ifndef _WIN32 - if (sock >= FD_SETSIZE) { return -1; } -#endif - - fd_set fds, *rfds, *wfds; - FD_ZERO(&fds); - FD_SET(sock, &fds); - rfds = (Read ? &fds : nullptr); - wfds = (Read ? nullptr : &fds); - - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); - - return handle_EINTR([&]() { - return select(static_cast(sock + 1), rfds, wfds, nullptr, &tv); - }); -#endif + return handle_EINTR([&]() { return poll_wrapper(&pfd, 1, timeout); }); } inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { @@ -3287,14 +3308,14 @@ inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) { inline Error wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { -#ifdef CPPHTTPLIB_USE_POLL struct pollfd pfd_read; pfd_read.fd = sock; pfd_read.events = POLLIN | POLLOUT; auto timeout = static_cast(sec * 1000 + usec / 1000); - auto poll_res = handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + auto poll_res = + handle_EINTR([&]() { return poll_wrapper(&pfd_read, 1, timeout); }); if (poll_res == 0) { return Error::ConnectionTimeout; } @@ -3308,38 +3329,6 @@ inline Error wait_until_socket_is_ready(socket_t sock, time_t sec, } return Error::Connection; -#else -#ifndef _WIN32 - if (sock >= FD_SETSIZE) { return Error::Connection; } -#endif - - fd_set fdsr; - FD_ZERO(&fdsr); - FD_SET(sock, &fdsr); - - auto fdsw = fdsr; - auto fdse = fdsr; - - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); - - auto ret = handle_EINTR([&]() { - return select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv); - }); - - if (ret == 0) { return Error::ConnectionTimeout; } - - if (ret > 0 && (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { - auto error = 0; - socklen_t len = sizeof(error); - auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, - reinterpret_cast(&error), &len); - auto successful = res >= 0 && !error; - return successful ? Error::Success : Error::Connection; - } - return Error::Connection; -#endif } inline bool is_socket_alive(socket_t sock) { @@ -3359,11 +3348,12 @@ public: time_t write_timeout_sec, time_t write_timeout_usec, time_t max_timeout_msec = 0, std::chrono::time_point start_time = - std::chrono::steady_clock::time_point::min()); + (std::chrono::steady_clock::time_point::min)()); ~SocketStream() override; bool is_readable() const override; - bool is_writable() const override; + bool wait_readable() const override; + bool wait_writable() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -3378,7 +3368,7 @@ private: time_t write_timeout_sec_; time_t write_timeout_usec_; time_t max_timeout_msec_; - const std::chrono::time_point start_time; + const std::chrono::time_point start_time_; std::vector read_buff_; size_t read_buff_off_ = 0; @@ -3395,11 +3385,12 @@ public: time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, time_t max_timeout_msec = 0, std::chrono::time_point start_time = - std::chrono::steady_clock::time_point::min()); + (std::chrono::steady_clock::time_point::min)()); ~SSLSocketStream() override; bool is_readable() const override; - bool is_writable() const override; + bool wait_readable() const override; + bool wait_writable() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -3415,7 +3406,7 @@ private: time_t write_timeout_sec_; time_t write_timeout_usec_; time_t max_timeout_msec_; - const std::chrono::time_point start_time; + const std::chrono::time_point start_time_; }; #endif @@ -3550,7 +3541,6 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, hints.ai_flags = socket_flags; } -#ifndef _WIN32 if (hints.ai_family == AF_UNIX) { const auto addrlen = host.length(); if (addrlen > sizeof(sockaddr_un::sun_path)) { return INVALID_SOCKET; } @@ -3574,11 +3564,19 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, sizeof(addr) - sizeof(addr.sun_path) + addrlen); #ifndef SOCK_CLOEXEC +#ifndef _WIN32 fcntl(sock, F_SETFD, FD_CLOEXEC); +#endif #endif if (socket_options) { socket_options(sock); } +#ifdef _WIN32 + // Setting SO_REUSEADDR seems not to work well with AF_UNIX on windows, so + // remove the option. + detail::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 0); +#endif + bool dummy; if (!bind_or_connect(sock, hints, dummy)) { close_socket(sock); @@ -3587,7 +3585,6 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, } return sock; } -#endif auto service = std::to_string(port); @@ -3993,6 +3990,12 @@ inline EncodingType encoding_type(const Request &req, const Response &res) { if (ret) { return EncodingType::Gzip; } #endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + // TODO: 'Accept-Encoding' has zstd, not zstd;q=0 + ret = s.find("zstd") != std::string::npos; + if (ret) { return EncodingType::Zstd; } +#endif + return EncodingType::None; } @@ -4201,6 +4204,61 @@ inline bool brotli_decompressor::decompress(const char *data, } #endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT +inline zstd_compressor::zstd_compressor() { + ctx_ = ZSTD_createCCtx(); + ZSTD_CCtx_setParameter(ctx_, ZSTD_c_compressionLevel, ZSTD_fast); +} + +inline zstd_compressor::~zstd_compressor() { ZSTD_freeCCtx(ctx_); } + +inline bool zstd_compressor::compress(const char *data, size_t data_length, + bool last, Callback callback) { + std::array buff{}; + + ZSTD_EndDirective mode = last ? ZSTD_e_end : ZSTD_e_continue; + ZSTD_inBuffer input = {data, data_length, 0}; + + bool finished; + do { + ZSTD_outBuffer output = {buff.data(), CPPHTTPLIB_COMPRESSION_BUFSIZ, 0}; + size_t const remaining = ZSTD_compressStream2(ctx_, &output, &input, mode); + + if (ZSTD_isError(remaining)) { return false; } + + if (!callback(buff.data(), output.pos)) { return false; } + + finished = last ? (remaining == 0) : (input.pos == input.size); + + } while (!finished); + + return true; +} + +inline zstd_decompressor::zstd_decompressor() { ctx_ = ZSTD_createDCtx(); } + +inline zstd_decompressor::~zstd_decompressor() { ZSTD_freeDCtx(ctx_); } + +inline bool zstd_decompressor::is_valid() const { return ctx_ != nullptr; } + +inline bool zstd_decompressor::decompress(const char *data, size_t data_length, + Callback callback) { + std::array buff{}; + ZSTD_inBuffer input = {data, data_length, 0}; + + while (input.pos < input.size) { + ZSTD_outBuffer output = {buff.data(), CPPHTTPLIB_COMPRESSION_BUFSIZ, 0}; + size_t const remaining = ZSTD_decompressStream(ctx_, &output, &input); + + if (ZSTD_isError(remaining)) { return false; } + + if (!callback(buff.data(), output.pos)) { return false; } + } + + return true; +} +#endif + inline bool has_header(const Headers &headers, const std::string &key) { return headers.find(key) != headers.end(); } @@ -4227,6 +4285,9 @@ inline bool parse_header(const char *beg, const char *end, T fn) { p++; } + auto name = std::string(beg, p); + if (!detail::fields::is_field_name(name)) { return false; } + if (p == end) { return false; } auto key_end = p; @@ -4242,10 +4303,6 @@ inline bool parse_header(const char *beg, const char *end, T fn) { if (!key_len) { return false; } auto key = std::string(beg, key_end); - // auto val = (case_ignore::equal(key, "Location") || - // case_ignore::equal(key, "Referer")) - // ? std::string(p, end) - // : decode_url(std::string(p, end), false); auto val = std::string(p, end); if (!detail::fields::is_field_value(val)) { return false; } @@ -4341,7 +4398,8 @@ inline bool read_content_without_length(Stream &strm, uint64_t r = 0; for (;;) { auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); - if (n <= 0) { return false; } + if (n == 0) { return true; } + if (n < 0) { return false; } if (!out(buf, static_cast(n), r, 0)) { return false; } r += static_cast(n); @@ -4384,7 +4442,7 @@ inline bool read_content_chunked(Stream &strm, T &x, assert(chunk_len == 0); - // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentiones "The chunked + // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentions "The chunked // transfer coding is complete when a chunk with a chunk-size of zero is // received, possibly followed by a trailer section, and finally terminated by // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 @@ -4394,8 +4452,8 @@ inline bool read_content_chunked(Stream &strm, T &x, // to be ok whether the final CRLF exists or not in the chunked data. // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 // - // According to the reference code in RFC 9112, cpp-htpplib now allows - // chuncked transfer coding data without the final CRLF. + // According to the reference code in RFC 9112, cpp-httplib now allows + // chunked transfer coding data without the final CRLF. if (!line_reader.getline()) { return true; } while (strcmp(line_reader.ptr(), "\r\n") != 0) { @@ -4442,6 +4500,13 @@ bool prepare_content_receiver(T &x, int &status, #else status = StatusCode::UnsupportedMediaType_415; return false; +#endif + } else if (encoding == "zstd") { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; #endif } @@ -4565,7 +4630,7 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider, data_sink.write = [&](const char *d, size_t l) -> bool { if (ok) { - if (strm.is_writable() && write_data(strm, d, l)) { + if (write_data(strm, d, l)) { offset += l; } else { ok = false; @@ -4574,10 +4639,10 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider, return ok; }; - data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; while (offset < end_offset && !is_shutting_down()) { - if (!strm.is_writable()) { + if (!strm.wait_writable()) { error = Error::Write; return false; } else if (!content_provider(offset, end_offset - offset, data_sink)) { @@ -4615,17 +4680,17 @@ write_content_without_length(Stream &strm, data_sink.write = [&](const char *d, size_t l) -> bool { if (ok) { offset += l; - if (!strm.is_writable() || !write_data(strm, d, l)) { ok = false; } + if (!write_data(strm, d, l)) { ok = false; } } return ok; }; - data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; data_sink.done = [&](void) { data_available = false; }; while (data_available && !is_shutting_down()) { - if (!strm.is_writable()) { + if (!strm.wait_writable()) { return false; } else if (!content_provider(offset, 0, data_sink)) { return false; @@ -4660,10 +4725,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, // Emit chunked response header and footer for each chunk auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; - if (!strm.is_writable() || - !write_data(strm, chunk.data(), chunk.size())) { - ok = false; - } + if (!write_data(strm, chunk.data(), chunk.size())) { ok = false; } } } else { ok = false; @@ -4672,7 +4734,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, return ok; }; - data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; auto done_with_trailer = [&](const Headers *trailer) { if (!ok) { return; } @@ -4692,17 +4754,14 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, if (!payload.empty()) { // Emit chunked response header and footer for each chunk auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; - if (!strm.is_writable() || - !write_data(strm, chunk.data(), chunk.size())) { + if (!write_data(strm, chunk.data(), chunk.size())) { ok = false; return; } } - static const std::string done_marker("0\r\n"); - if (!write_data(strm, done_marker.data(), done_marker.size())) { - ok = false; - } + constexpr const char done_marker[] = "0\r\n"; + if (!write_data(strm, done_marker, str_len(done_marker))) { ok = false; } // Trailer if (trailer) { @@ -4714,8 +4773,8 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, } } - static const std::string crlf("\r\n"); - if (!write_data(strm, crlf.data(), crlf.size())) { ok = false; } + constexpr const char crlf[] = "\r\n"; + if (!write_data(strm, crlf, str_len(crlf))) { ok = false; } }; data_sink.done = [&](void) { done_with_trailer(nullptr); }; @@ -4725,7 +4784,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, }; while (data_available && !is_shutting_down()) { - if (!strm.is_writable()) { + if (!strm.wait_writable()) { error = Error::Write; return false; } else if (!content_provider(offset, 0, data_sink)) { @@ -4957,13 +5016,13 @@ public: return false; } - static const std::string header_content_type = "Content-Type:"; + constexpr const char header_content_type[] = "Content-Type:"; if (start_with_case_ignore(header, header_content_type)) { file_.content_type = - trim_copy(header.substr(header_content_type.size())); + trim_copy(header.substr(str_len(header_content_type))); } else { - static const std::regex re_content_disposition( + thread_local const std::regex re_content_disposition( R"~(^Content-Disposition:\s*form-data;\s*(.*)$)~", std::regex_constants::icase); @@ -4985,8 +5044,8 @@ public: it = params.find("filename*"); if (it != params.end()) { - // Only allow UTF-8 enconnding... - static const std::regex re_rfc5987_encoding( + // Only allow UTF-8 encoding... + thread_local const std::regex re_rfc5987_encoding( R"~(^UTF-8''(.+?)$)~", std::regex_constants::icase); std::smatch m2; @@ -5058,10 +5117,10 @@ private: file_.content_type.clear(); } - bool start_with_case_ignore(const std::string &a, - const std::string &b) const { - if (a.size() < b.size()) { return false; } - for (size_t i = 0; i < b.size(); i++) { + bool start_with_case_ignore(const std::string &a, const char *b) const { + const auto b_len = strlen(b); + if (a.size() < b_len) { return false; } + for (size_t i = 0; i < b_len; i++) { if (case_ignore::to_lower(a[i]) != case_ignore::to_lower(b[i])) { return false; } @@ -5148,19 +5207,18 @@ private: }; inline std::string random_string(size_t length) { - static const char data[] = + constexpr const char data[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; - // std::random_device might actually be deterministic on some - // platforms, but due to lack of support in the c++ standard library, - // doing better requires either some ugly hacks or breaking portability. - static std::random_device seed_gen; - - // Request 128 bits of entropy for initialization - static std::seed_seq seed_sequence{seed_gen(), seed_gen(), seed_gen(), - seed_gen()}; - - static std::mt19937 engine(seed_sequence); + thread_local auto engine([]() { + // std::random_device might actually be deterministic on some + // platforms, but due to lack of support in the c++ standard library, + // doing better requires either some ugly hacks or breaking portability. + std::random_device seed_gen; + // Request 128 bits of entropy for initialization + std::seed_seq seed_sequence{seed_gen(), seed_gen(), seed_gen(), seed_gen()}; + return std::mt19937(seed_sequence); + }()); std::string result; for (size_t i = 0; i < length; i++) { @@ -5232,7 +5290,7 @@ serialize_multipart_formdata(const MultipartFormDataItems &items, inline bool range_error(Request &req, Response &res) { if (!req.ranges.empty() && 200 <= res.status && res.status < 300) { - ssize_t contant_len = static_cast( + ssize_t content_len = static_cast( res.content_length_ ? res.content_length_ : res.body.size()); ssize_t prev_first_pos = -1; @@ -5252,12 +5310,12 @@ inline bool range_error(Request &req, Response &res) { if (first_pos == -1 && last_pos == -1) { first_pos = 0; - last_pos = contant_len; + last_pos = content_len; } if (first_pos == -1) { - first_pos = contant_len - last_pos; - last_pos = contant_len - 1; + first_pos = content_len - last_pos; + last_pos = content_len - 1; } // NOTE: RFC-9110 '14.1.2. Byte Ranges': @@ -5269,13 +5327,13 @@ inline bool range_error(Request &req, Response &res) { // with a value that is one less than the current length of the selected // representation). // https://www.rfc-editor.org/rfc/rfc9110.html#section-14.1.2-6 - if (last_pos == -1 || last_pos >= contant_len) { - last_pos = contant_len - 1; + if (last_pos == -1 || last_pos >= content_len) { + last_pos = content_len - 1; } // Range must be within content length if (!(0 <= first_pos && first_pos <= last_pos && - last_pos <= contant_len - 1)) { + last_pos <= content_len - 1)) { return true; } @@ -5674,7 +5732,8 @@ inline bool parse_www_authenticate(const Response &res, bool is_proxy) { auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; if (res.has_header(auth_key)) { - static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + thread_local auto re = + std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); auto s = res.get_header_value(auth_key); auto pos = s.find(' '); if (pos != std::string::npos) { @@ -5758,7 +5817,7 @@ inline void hosted_at(const std::string &hostname, inline std::string append_query_params(const std::string &path, const Params ¶ms) { std::string path_with_query = path; - const static std::regex re("[^?]+\\?.*"); + thread_local const std::regex re("[^?]+\\?.*"); auto delm = std::regex_match(path, re) ? '&' : '?'; path_with_query += delm + detail::params_to_query_str(params); return path_with_query; @@ -5987,14 +6046,14 @@ inline ssize_t Stream::write(const std::string &s) { namespace detail { -inline void calc_actual_timeout(time_t max_timeout_msec, - time_t duration_msec, time_t timeout_sec, - time_t timeout_usec, time_t &actual_timeout_sec, +inline void calc_actual_timeout(time_t max_timeout_msec, time_t duration_msec, + time_t timeout_sec, time_t timeout_usec, + time_t &actual_timeout_sec, time_t &actual_timeout_usec) { auto timeout_msec = (timeout_sec * 1000) + (timeout_usec / 1000); auto actual_timeout_msec = - std::min(max_timeout_msec - duration_msec, timeout_msec); + (std::min)(max_timeout_msec - duration_msec, timeout_msec); actual_timeout_sec = actual_timeout_msec / 1000; actual_timeout_usec = (actual_timeout_msec % 1000) * 1000; @@ -6010,12 +6069,16 @@ inline SocketStream::SocketStream( read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), write_timeout_usec_(write_timeout_usec), - max_timeout_msec_(max_timeout_msec), start_time(start_time), + max_timeout_msec_(max_timeout_msec), start_time_(start_time), read_buff_(read_buff_size_, 0) {} inline SocketStream::~SocketStream() = default; inline bool SocketStream::is_readable() const { + return read_buff_off_ < read_buff_content_size_; +} + +inline bool SocketStream::wait_readable() const { if (max_timeout_msec_ <= 0) { return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } @@ -6028,7 +6091,7 @@ inline bool SocketStream::is_readable() const { return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; } -inline bool SocketStream::is_writable() const { +inline bool SocketStream::wait_writable() const { return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && is_socket_alive(sock_); } @@ -6055,7 +6118,7 @@ inline ssize_t SocketStream::read(char *ptr, size_t size) { } } - if (!is_readable()) { return -1; } + if (!wait_readable()) { return -1; } read_buff_off_ = 0; read_buff_content_size_ = 0; @@ -6080,7 +6143,7 @@ inline ssize_t SocketStream::read(char *ptr, size_t size) { } inline ssize_t SocketStream::write(const char *ptr, size_t size) { - if (!is_writable()) { return -1; } + if (!wait_writable()) { return -1; } #if defined(_WIN32) && !defined(_WIN64) size = @@ -6104,14 +6167,16 @@ inline socket_t SocketStream::socket() const { return sock_; } inline time_t SocketStream::duration() const { return std::chrono::duration_cast( - std::chrono::steady_clock::now() - start_time) + std::chrono::steady_clock::now() - start_time_) .count(); } // Buffer stream implementation inline bool BufferStream::is_readable() const { return true; } -inline bool BufferStream::is_writable() const { return true; } +inline bool BufferStream::wait_readable() const { return true; } + +inline bool BufferStream::wait_writable() const { return true; } inline ssize_t BufferStream::read(char *ptr, size_t size) { #if defined(_MSC_VER) && _MSC_VER < 1910 @@ -6141,7 +6206,7 @@ inline time_t BufferStream::duration() const { return 0; } inline const std::string &BufferStream::get_buffer() const { return buffer; } inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { - static constexpr char marker[] = "/:"; + constexpr const char marker[] = "/:"; // One past the last ending position of a path param substring std::size_t last_param_end = 0; @@ -6162,7 +6227,7 @@ inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { static_fragments_.push_back( pattern.substr(last_param_end, marker_pos - last_param_end + 1)); - const auto param_name_start = marker_pos + 2; + const auto param_name_start = marker_pos + str_len(marker); auto sep_pos = pattern.find(separator, param_name_start); if (sep_pos == std::string::npos) { sep_pos = pattern.length(); } @@ -6469,12 +6534,12 @@ inline Server &Server::set_payload_max_length(size_t length) { inline bool Server::bind_to_port(const std::string &host, int port, int socket_flags) { auto ret = bind_internal(host, port, socket_flags); - if (ret == -1) { is_decommisioned = true; } + if (ret == -1) { is_decommissioned = true; } return ret >= 0; } inline int Server::bind_to_any_port(const std::string &host, int socket_flags) { auto ret = bind_internal(host, 0, socket_flags); - if (ret == -1) { is_decommisioned = true; } + if (ret == -1) { is_decommissioned = true; } return ret; } @@ -6488,7 +6553,7 @@ inline bool Server::listen(const std::string &host, int port, inline bool Server::is_running() const { return is_running_; } inline void Server::wait_until_ready() const { - while (!is_running_ && !is_decommisioned) { + while (!is_running_ && !is_decommissioned) { std::this_thread::sleep_for(std::chrono::milliseconds{1}); } } @@ -6500,10 +6565,10 @@ inline void Server::stop() { detail::shutdown_socket(sock); detail::close_socket(sock); } - is_decommisioned = false; + is_decommissioned = false; } -inline void Server::decommission() { is_decommisioned = true; } +inline void Server::decommission() { is_decommissioned = true; } inline bool Server::parse_request_line(const char *s, Request &req) const { auto len = strlen(s); @@ -6526,7 +6591,7 @@ inline bool Server::parse_request_line(const char *s, Request &req) const { if (count != 3) { return false; } } - static const std::set methods{ + thread_local const std::set methods{ "GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH", "PRI"}; @@ -6680,6 +6745,10 @@ Server::write_content_with_provider(Stream &strm, const Request &req, } else if (type == detail::EncodingType::Brotli) { #ifdef CPPHTTPLIB_BROTLI_SUPPORT compressor = detail::make_unique(); +#endif + } else if (type == detail::EncodingType::Zstd) { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + compressor = detail::make_unique(); #endif } else { compressor = detail::make_unique(); @@ -6862,7 +6931,7 @@ Server::create_server_socket(const std::string &host, int port, inline int Server::bind_internal(const std::string &host, int port, int socket_flags) { - if (is_decommisioned) { return -1; } + if (is_decommissioned) { return -1; } if (!is_valid()) { return -1; } @@ -6889,7 +6958,7 @@ inline int Server::bind_internal(const std::string &host, int port, } inline bool Server::listen_internal() { - if (is_decommisioned) { return false; } + if (is_decommissioned) { return false; } auto ret = true; is_running_ = true; @@ -6913,7 +6982,7 @@ inline bool Server::listen_internal() { #endif #if defined _WIN32 - // sockets conneced via WASAccept inherit flags NO_HANDLE_INHERIT, + // sockets connected via WASAccept inherit flags NO_HANDLE_INHERIT, // OVERLAPPED socket_t sock = WSAAccept(svr_sock_, nullptr, nullptr, nullptr, 0); #elif defined SOCK_CLOEXEC @@ -6955,7 +7024,7 @@ inline bool Server::listen_internal() { task_queue->shutdown(); } - is_decommisioned = !ret; + is_decommissioned = !ret; return ret; } @@ -7095,6 +7164,8 @@ inline void Server::apply_ranges(const Request &req, Response &res, res.set_header("Content-Encoding", "gzip"); } else if (type == detail::EncodingType::Brotli) { res.set_header("Content-Encoding", "br"); + } else if (type == detail::EncodingType::Zstd) { + res.set_header("Content-Encoding", "zstd"); } } } @@ -7134,6 +7205,11 @@ inline void Server::apply_ranges(const Request &req, Response &res, #ifdef CPPHTTPLIB_BROTLI_SUPPORT compressor = detail::make_unique(); content_encoding = "br"; +#endif + } else if (type == detail::EncodingType::Zstd) { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + compressor = detail::make_unique(); + content_encoding = "zstd"; #endif } @@ -7189,20 +7265,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr, res.version = "HTTP/1.1"; res.headers = default_headers_; -#ifdef _WIN32 - // TODO: Increase FD_SETSIZE statically (libzmq), dynamically (MySQL). -#else -#ifndef CPPHTTPLIB_USE_POLL - // Socket file descriptor exceeded FD_SETSIZE... - if (strm.socket() >= FD_SETSIZE) { - Headers dummy; - detail::read_headers(strm, dummy); - res.status = StatusCode::InternalServerError_500; - return write_response(strm, close_connection, req, res); - } -#endif -#endif - // Request line and headers if (!parse_request_line(line_reader.ptr(), req) || !detail::read_headers(strm, req.headers)) { @@ -7394,6 +7456,16 @@ inline ClientImpl::ClientImpl(const std::string &host, int port, client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} inline ClientImpl::~ClientImpl() { + // Wait until all the requests in flight are handled. + size_t retry_count = 10; + while (retry_count-- > 0) { + { + std::lock_guard guard(socket_mutex_); + if (socket_requests_in_flight_ == 0) { break; } + } + std::this_thread::sleep_for(std::chrono::milliseconds{1}); + } + std::lock_guard guard(socket_mutex_); shutdown_socket(socket_); close_socket(socket_); @@ -7519,9 +7591,9 @@ inline bool ClientImpl::read_response_line(Stream &strm, const Request &req, if (!line_reader.getline()) { return false; } #ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR - const static std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r?\n"); + thread_local const std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r?\n"); #else - const static std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r\n"); + thread_local const std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r\n"); #endif std::cmatch m; @@ -7577,7 +7649,7 @@ inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { #endif if (!is_alive) { - // Attempt to avoid sigpipe by shutting down nongracefully if it seems + // Attempt to avoid sigpipe by shutting down non-gracefully if it seems // like the other side has already closed the connection Also, there // cannot be any requests in flight from other threads since we locked // request_mutex_, so safe to close everything immediately @@ -7753,7 +7825,7 @@ inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) { auto location = res.get_header_value("location"); if (location.empty()) { return false; } - const static std::regex re( + thread_local const std::regex re( R"((?:(https?):)?(?://(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); std::smatch m; @@ -7862,6 +7934,10 @@ inline bool ClientImpl::write_request(Stream &strm, Request &req, #ifdef CPPHTTPLIB_ZLIB_SUPPORT if (!accept_encoding.empty()) { accept_encoding += ", "; } accept_encoding += "gzip, deflate"; +#endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "zstd"; #endif req.set_header("Accept-Encoding", accept_encoding); } @@ -8213,8 +8289,7 @@ inline bool ClientImpl::process_socket( std::function callback) { return detail::process_client_socket( socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_, max_timeout_msec_, start_time, - std::move(callback)); + write_timeout_usec_, max_timeout_msec_, start_time, std::move(callback)); } inline bool ClientImpl::is_ssl() const { return false; } @@ -9009,7 +9084,7 @@ inline void ClientImpl::enable_server_hostname_verification(bool enabled) { } inline void ClientImpl::set_server_certificate_verifier( - std::function verifier) { + std::function verifier) { server_certificate_verifier_ = verifier; } #endif @@ -9062,18 +9137,13 @@ inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, socket_t sock, // Note that it is not always possible to avoid SIGPIPE, this is merely a // best-efforts. if (shutdown_gracefully) { -#ifdef _WIN32 (void)(sock); - SSL_shutdown(ssl); -#else - detail::set_socket_opt_time(sock, SOL_SOCKET, SO_RCVTIMEO, 1, 0); - - auto ret = SSL_shutdown(ssl); - while (ret == 0) { - std::this_thread::sleep_for(std::chrono::milliseconds{100}); - ret = SSL_shutdown(ssl); + // SSL_shutdown() returns 0 on first call (indicating close_notify alert + // sent) and 1 on subsequent call (indicating close_notify alert received) + if (SSL_shutdown(ssl) == 0) { + // Expected to return 1, but even if it doesn't, we free ssl + SSL_shutdown(ssl); } -#endif } std::lock_guard guard(ctx_mutex); @@ -9124,19 +9194,11 @@ inline bool process_client_socket_ssl( time_t max_timeout_msec, std::chrono::time_point start_time, T callback) { SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec, - max_timeout_msec, start_time); + write_timeout_sec, write_timeout_usec, max_timeout_msec, + start_time); return callback(strm); } -class SSLInit { -public: - SSLInit() { - OPENSSL_init_ssl( - OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); - } -}; - // SSL socket stream implementation inline SSLSocketStream::SSLSocketStream( socket_t sock, SSL *ssl, time_t read_timeout_sec, time_t read_timeout_usec, @@ -9147,13 +9209,17 @@ inline SSLSocketStream::SSLSocketStream( read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), write_timeout_usec_(write_timeout_usec), - max_timeout_msec_(max_timeout_msec), start_time(start_time) { + max_timeout_msec_(max_timeout_msec), start_time_(start_time) { SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); } inline SSLSocketStream::~SSLSocketStream() = default; inline bool SSLSocketStream::is_readable() const { + return SSL_pending(ssl_) > 0; +} + +inline bool SSLSocketStream::wait_readable() const { if (max_timeout_msec_ <= 0) { return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } @@ -9166,7 +9232,7 @@ inline bool SSLSocketStream::is_readable() const { return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; } -inline bool SSLSocketStream::is_writable() const { +inline bool SSLSocketStream::wait_writable() const { return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && is_socket_alive(sock_) && !is_ssl_peer_could_be_closed(ssl_, sock_); } @@ -9174,7 +9240,7 @@ inline bool SSLSocketStream::is_writable() const { inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { if (SSL_pending(ssl_) > 0) { return SSL_read(ssl_, ptr, static_cast(size)); - } else if (is_readable()) { + } else if (wait_readable()) { auto ret = SSL_read(ssl_, ptr, static_cast(size)); if (ret < 0) { auto err = SSL_get_error(ssl_, ret); @@ -9188,7 +9254,7 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { #endif if (SSL_pending(ssl_) > 0) { return SSL_read(ssl_, ptr, static_cast(size)); - } else if (is_readable()) { + } else if (wait_readable()) { std::this_thread::sleep_for(std::chrono::microseconds{10}); ret = SSL_read(ssl_, ptr, static_cast(size)); if (ret >= 0) { return ret; } @@ -9205,7 +9271,7 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { } inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { - if (is_writable()) { + if (wait_writable()) { auto handle_size = static_cast( std::min(size, (std::numeric_limits::max)())); @@ -9220,7 +9286,7 @@ inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { #else while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { #endif - if (is_writable()) { + if (wait_writable()) { std::this_thread::sleep_for(std::chrono::microseconds{10}); ret = SSL_write(ssl_, ptr, static_cast(handle_size)); if (ret >= 0) { return ret; } @@ -9249,12 +9315,10 @@ inline socket_t SSLSocketStream::socket() const { return sock_; } inline time_t SSLSocketStream::duration() const { return std::chrono::duration_cast( - std::chrono::steady_clock::now() - start_time) + std::chrono::steady_clock::now() - start_time_) .count(); } -static SSLInit sslinit_; - } // namespace detail // SSL HTTP server implementation @@ -9623,12 +9687,18 @@ inline bool SSLClient::initialize_ssl(Socket &socket, Error &error) { } if (server_certificate_verification_) { + auto verification_status = SSLVerifierResponse::NoDecisionMade; + if (server_certificate_verifier_) { - if (!server_certificate_verifier_(ssl2)) { - error = Error::SSLServerVerification; - return false; - } - } else { + verification_status = server_certificate_verifier_(ssl2); + } + + if (verification_status == SSLVerifierResponse::CertificateRejected) { + error = Error::SSLServerVerification; + return false; + } + + if (verification_status == SSLVerifierResponse::NoDecisionMade) { verify_result_ = SSL_get_verify_result(ssl2); if (verify_result_ != X509_V_OK) { @@ -9740,8 +9810,8 @@ SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { auto type = GEN_DNS; - struct in6_addr addr6{}; - struct in_addr addr{}; + struct in6_addr addr6 = {}; + struct in_addr addr = {}; size_t addr_len = 0; #ifndef __MINGW32__ @@ -10389,7 +10459,7 @@ inline void Client::enable_server_hostname_verification(bool enabled) { } inline void Client::set_server_certificate_verifier( - std::function verifier) { + std::function verifier) { cli_->set_server_certificate_verifier(verifier); } #endif @@ -10433,8 +10503,4 @@ inline SSL_CTX *Client::ssl_context() const { } // namespace httplib -#if defined(_WIN32) && defined(CPPHTTPLIB_USE_POLL) -#undef poll -#endif - #endif // CPPHTTPLIB_HTTPLIB_H diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 18caa9127..77dd316d9 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -4459,15 +4459,24 @@ int main(int argc, char ** argv) { llama_backend_free(); }; - // bind HTTP listen port bool was_bound = false; - if (params.port == 0) { - int bound_port = svr->bind_to_any_port(params.hostname); - if ((was_bound = (bound_port >= 0))) { - params.port = bound_port; - } + if (string_ends_with(std::string(params.hostname), ".sock")) { + LOG_INF("%s: setting address family to AF_UNIX\n", __func__); + svr->set_address_family(AF_UNIX); + // bind_to_port requires a second arg, any value other than 0 should + // simply get ignored + was_bound = svr->bind_to_port(params.hostname, 8080); } else { - was_bound = svr->bind_to_port(params.hostname, params.port); + LOG_INF("%s: binding port with default address family\n", __func__); + // bind HTTP listen port + if (params.port == 0) { + int bound_port = svr->bind_to_any_port(params.hostname); + if ((was_bound = (bound_port >= 0))) { + params.port = bound_port; + } + } else { + was_bound = svr->bind_to_port(params.hostname, params.port); + } } if (!was_bound) {