0
0
mirror of https://github.com/zeromq/libzmq.git synced 2024-12-27 15:41:05 +08:00

Merge pull request #3763 from sigiesec/replace-strcpy

Avoid possible buffers overruns in ws_engine
This commit is contained in:
Luca Boccassi 2019-12-25 16:13:13 +01:00 committed by GitHub
commit 246cc77efc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 73 additions and 37 deletions

View File

@ -460,6 +460,7 @@ if(NOT MSVC)
check_cxx_symbol_exists(mkdtemp stdlib.h HAVE_MKDTEMP) check_cxx_symbol_exists(mkdtemp stdlib.h HAVE_MKDTEMP)
check_cxx_symbol_exists(accept4 sys/socket.h HAVE_ACCEPT4) check_cxx_symbol_exists(accept4 sys/socket.h HAVE_ACCEPT4)
check_cxx_symbol_exists(strnlen string.h HAVE_STRNLEN) check_cxx_symbol_exists(strnlen string.h HAVE_STRNLEN)
check_cxx_symbol_exists(strlcpy string.h ZMQ_HAVE_STRLCPY)
else() else()
set(HAVE_STRNLEN 1) set(HAVE_STRNLEN 1)
endif() endif()

View File

@ -51,6 +51,7 @@
#cmakedefine ZMQ_HAVE_PTHREAD_SET_AFFINITY #cmakedefine ZMQ_HAVE_PTHREAD_SET_AFFINITY
#cmakedefine HAVE_ACCEPT4 #cmakedefine HAVE_ACCEPT4
#cmakedefine HAVE_STRNLEN #cmakedefine HAVE_STRNLEN
#cmakedefine ZMQ_HAVE_STRLCPY
#cmakedefine ZMQ_HAVE_IPC #cmakedefine ZMQ_HAVE_IPC

View File

@ -751,6 +751,20 @@ AC_COMPILE_IFELSE(
AC_MSG_RESULT([no]) AC_MSG_RESULT([no])
]) ])
# string.h doesn't seem to be included by default in Fedora 30
AC_MSG_CHECKING([whether strlcpy is available])
AC_COMPILE_IFELSE(
[AC_LANG_PROGRAM(
[[#include <string.h>]],
[[char buf [100]; size_t bar = strlcpy (buf, "foo", 100); (void)bar; return 0;]])
],[
AC_MSG_RESULT([yes])
AC_DEFINE(ZMQ_HAVE_STRLCPY, [1],
[strlcpy is available])
],[
AC_MSG_RESULT([no])
])
# pthread_setname is non-posix, and there are at least 4 different implementations # pthread_setname is non-posix, and there are at least 4 different implementations
AC_MSG_CHECKING([whether signature of pthread_setname_np() has 1 argument]) AC_MSG_CHECKING([whether signature of pthread_setname_np() has 1 argument])
AC_COMPILE_IFELSE( AC_COMPILE_IFELSE(

View File

@ -411,6 +411,7 @@ void zmq::print_backtrace (void)
while (unw_step (&cursor) > 0) { while (unw_step (&cursor) > 0) {
unw_word_t offset; unw_word_t offset;
unw_proc_info_t p_info; unw_proc_info_t p_info;
static const char unknown[] = "?";
const char *file_name; const char *file_name;
char *demangled_name; char *demangled_name;
char func_name[256] = ""; char func_name[256] = "";
@ -422,14 +423,14 @@ void zmq::print_backtrace (void)
rc = unw_get_proc_name (&cursor, func_name, 256, &offset); rc = unw_get_proc_name (&cursor, func_name, 256, &offset);
if (rc == -UNW_ENOINFO) if (rc == -UNW_ENOINFO)
strcpy (func_name, "?"); memcpy (func_name, unknown, sizeof unknown);
addr = (void *) (p_info.start_ip + offset); addr = (void *) (p_info.start_ip + offset);
if (dladdr (addr, &dl_info) && dl_info.dli_fname) if (dladdr (addr, &dl_info) && dl_info.dli_fname)
file_name = dl_info.dli_fname; file_name = dl_info.dli_fname;
else else
file_name = "?"; file_name = unknown;
demangled_name = abi::__cxa_demangle (func_name, NULL, NULL, &rc); demangled_name = abi::__cxa_demangle (func_name, NULL, NULL, &rc);

View File

@ -879,7 +879,7 @@ int zmq::create_ipc_wildcard_address (std::string &path_, std::string &file_)
// We need room for tmp_path + trailing NUL // We need room for tmp_path + trailing NUL
std::vector<char> buffer (tmp_path.length () + 1); std::vector<char> buffer (tmp_path.length () + 1);
strcpy (&buffer[0], tmp_path.c_str ()); memcpy (&buffer[0], tmp_path.c_str (), tmp_path.length () + 1);
#if defined HAVE_MKDTEMP #if defined HAVE_MKDTEMP
// Create the directory. POSIX requires that mkdtemp() creates the // Create the directory. POSIX requires that mkdtemp() creates the

View File

@ -785,6 +785,7 @@ int zmq::options_t::setsockopt (int option_,
} }
break; break;
#ifdef ZMQ_HAVE_WSS
case ZMQ_WSS_KEY_PEM: case ZMQ_WSS_KEY_PEM:
// TODO: check if valid certificate // TODO: check if valid certificate
wss_key_pem = std::string ((char *) optval_, optvallen_); wss_key_pem = std::string ((char *) optval_, optvallen_);
@ -803,7 +804,7 @@ int zmq::options_t::setsockopt (int option_,
case ZMQ_WSS_TRUST_SYSTEM: case ZMQ_WSS_TRUST_SYSTEM:
return do_setsockopt_int_as_bool_strict (optval_, optvallen_, return do_setsockopt_int_as_bool_strict (optval_, optvallen_,
&wss_trust_system); &wss_trust_system);
#endif
#endif #endif
default: default:

View File

@ -114,15 +114,12 @@ zmq::session_base_t::session_base_t (class io_thread_t *io_thread_,
_socket (socket_), _socket (socket_),
_io_thread (io_thread_), _io_thread (io_thread_),
_has_linger_timer (false), _has_linger_timer (false),
_addr (addr_), _addr (addr_)
_wss_hostname (NULL) #ifdef ZMQ_HAVE_WSS
,
_wss_hostname (options_.wss_hostname)
#endif
{ {
if (options_.wss_hostname.length () > 0) {
_wss_hostname =
static_cast<char *> (malloc (options_.wss_hostname.length () + 1));
assert (_wss_hostname);
strcpy (_wss_hostname, options_.wss_hostname.c_str ());
}
} }
const zmq::endpoint_uri_pair_t &zmq::session_base_t::get_endpoint () const const zmq::endpoint_uri_pair_t &zmq::session_base_t::get_endpoint () const
@ -145,9 +142,6 @@ zmq::session_base_t::~session_base_t ()
if (_engine) if (_engine)
_engine->terminate (); _engine->terminate ();
if (_wss_hostname)
free (_wss_hostname);
LIBZMQ_DELETE (_addr); LIBZMQ_DELETE (_addr);
} }
@ -701,8 +695,8 @@ zmq::own_t *zmq::session_base_t::create_connecter_tcp (io_thread_t *io_thread_,
zmq::own_t *zmq::session_base_t::create_connecter_ws (io_thread_t *io_thread_, zmq::own_t *zmq::session_base_t::create_connecter_ws (io_thread_t *io_thread_,
bool wait_) bool wait_)
{ {
return new (std::nothrow) return new (std::nothrow) ws_connecter_t (io_thread_, this, options, _addr,
ws_connecter_t (io_thread_, this, options, _addr, wait_, false, NULL); wait_, false, std::string ());
} }
#endif #endif

View File

@ -192,9 +192,11 @@ class session_base_t : public own_t, public io_object_t, public i_pipe_events
// Protocol and address to use when connecting. // Protocol and address to use when connecting.
address_t *_addr; address_t *_addr;
#ifdef ZMQ_HAVE_WSS
// TLS handshake, we need to take a copy when the session is created, // TLS handshake, we need to take a copy when the session is created,
// in order to maintain the value at the creation time // in order to maintain the value at the creation time
char *_wss_hostname; const std::string _wss_hostname;
#endif
ZMQ_NON_COPYABLE_NOR_MOVABLE (session_base_t) ZMQ_NON_COPYABLE_NOR_MOVABLE (session_base_t)
}; };

View File

@ -74,7 +74,7 @@ zmq::ws_connecter_t::ws_connecter_t (class io_thread_t *io_thread_,
address_t *addr_, address_t *addr_,
bool delayed_start_, bool delayed_start_,
bool wss_, bool wss_,
const char *tls_hostname_) : const std::string &tls_hostname_) :
stream_connecter_base_t ( stream_connecter_base_t (
io_thread_, session_, options_, addr_, delayed_start_), io_thread_, session_, options_, addr_, delayed_start_),
_connect_timer_started (false), _connect_timer_started (false),

View File

@ -47,7 +47,7 @@ class ws_connecter_t : public stream_connecter_base_t
address_t *addr_, address_t *addr_,
bool delayed_start_, bool delayed_start_,
bool wss_, bool wss_,
const char *tls_hostname_); const std::string &tls_hostname_);
~ws_connecter_t (); ~ws_connecter_t ();
protected: protected:
@ -89,7 +89,7 @@ class ws_connecter_t : public stream_connecter_base_t
bool _connect_timer_started; bool _connect_timer_started;
bool _wss; bool _wss;
const char *_hostname; const std::string &_hostname;
ZMQ_NON_COPYABLE_NOR_MOVABLE (ws_connecter_t) ZMQ_NON_COPYABLE_NOR_MOVABLE (ws_connecter_t)
}; };

View File

@ -52,6 +52,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
#endif #endif
#endif #endif
#include <cstring>
#include "tcp.hpp" #include "tcp.hpp"
#include "ws_engine.hpp" #include "ws_engine.hpp"
#include "session_base.hpp" #include "session_base.hpp"
@ -71,6 +73,23 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
#ifdef ZMQ_HAVE_WINDOWS #ifdef ZMQ_HAVE_WINDOWS
#define strcasecmp _stricmp #define strcasecmp _stricmp
#else
#ifndef ZMQ_HAVE_STRLCPY
static size_t strlcpy (char *dest_, const char *src_, const size_t dest_size_)
{
size_t remain = dest_size_;
for (; remain && *src_; --remain, ++src_, ++dest_) {
*dest_ = *src_;
}
return dest_size_ - remain;
}
#endif
template <size_t size>
static int strcpy_s (char (&dest_)[size], const char *const src_)
{
const size_t res = strlcpy (dest_, src_, size);
return res >= size ? ERANGE : 0;
}
#endif #endif
// OSX uses a different name for this socket option // OSX uses a different name for this socket option
@ -118,14 +137,14 @@ zmq::ws_engine_t::~ws_engine_t ()
void zmq::ws_engine_t::start_ws_handshake () void zmq::ws_engine_t::start_ws_handshake ()
{ {
if (_client) { if (_client) {
char protocol[21]; const char *protocol;
if (_options.mechanism == ZMQ_NULL) if (_options.mechanism == ZMQ_NULL)
strcpy (protocol, "ZWS2.0/NULL,ZWS2.0"); protocol = "ZWS2.0/NULL,ZWS2.0";
else if (_options.mechanism == ZMQ_PLAIN) else if (_options.mechanism == ZMQ_PLAIN)
strcpy (protocol, "ZWS2.0/PLAIN"); protocol = "ZWS2.0/PLAIN";
#ifdef ZMQ_HAVE_CURVE #ifdef ZMQ_HAVE_CURVE
else if (_options.mechanism == ZMQ_CURVE) else if (_options.mechanism == ZMQ_CURVE)
strcpy (protocol, "ZWS2.0/CURVE"); protocol = "ZWS2.0/CURVE";
#endif #endif
else else
assert (false); assert (false);
@ -440,7 +459,7 @@ bool zmq::ws_engine_t::server_handshake ()
strcasecmp ("upgrade", _header_value) == 0; strcasecmp ("upgrade", _header_value) == 0;
else if (strcasecmp ("Sec-WebSocket-Key", _header_name) else if (strcasecmp ("Sec-WebSocket-Key", _header_name)
== 0) == 0)
strcpy (_websocket_key, _header_value); strcpy_s (_websocket_key, _header_value);
else if (strcasecmp ("Sec-WebSocket-Protocol", _header_name) else if (strcasecmp ("Sec-WebSocket-Protocol", _header_name)
== 0) { == 0) {
// Currently only the ZWS2.0 is supported // Currently only the ZWS2.0 is supported
@ -453,7 +472,7 @@ bool zmq::ws_engine_t::server_handshake ()
p++; p++;
if (select_protocol (p)) { if (select_protocol (p)) {
strcpy (_websocket_protocol, p); strcpy_s (_websocket_protocol, p);
break; break;
} }
@ -820,11 +839,11 @@ bool zmq::ws_engine_t::client_handshake ()
strcasecmp ("upgrade", _header_value) == 0; strcasecmp ("upgrade", _header_value) == 0;
else if (strcasecmp ("Sec-WebSocket-Accept", _header_name) else if (strcasecmp ("Sec-WebSocket-Accept", _header_name)
== 0) == 0)
strcpy (_websocket_accept, _header_value); strcpy_s (_websocket_accept, _header_value);
else if (strcasecmp ("Sec-WebSocket-Protocol", _header_name) else if (strcasecmp ("Sec-WebSocket-Protocol", _header_name)
== 0) { == 0) {
if (select_protocol (_header_value)) if (select_protocol (_header_value))
strcpy (_websocket_protocol, _header_value); strcpy_s (_websocket_protocol, _header_value);
} }
_client_handshake_state = client_header_field_cr; _client_handshake_state = client_header_field_cr;
} else if (_header_value_position + 1 > MAX_HEADER_VALUE_LENGTH) } else if (_header_value_position + 1 > MAX_HEADER_VALUE_LENGTH)

View File

@ -294,8 +294,9 @@ void zmq::ws_listener_t::create_engine (fd_t fd_)
i_engine *engine = NULL; i_engine *engine = NULL;
if (_wss) if (_wss)
#ifdef ZMQ_HAVE_WSS #ifdef ZMQ_HAVE_WSS
engine = new (std::nothrow) wss_engine_t ( engine = new (std::nothrow)
fd_, options, endpoint_pair, _address, false, _tls_cred, NULL); wss_engine_t (fd_, options, endpoint_pair, _address, false, _tls_cred,
std::string ());
#else #else
assert (false); assert (false);
#endif #endif

View File

@ -58,7 +58,7 @@ zmq::wss_engine_t::wss_engine_t (fd_t fd_,
ws_address_t &address_, ws_address_t &address_,
bool client_, bool client_,
void *tls_server_cred_, void *tls_server_cred_,
const char *hostname_) : const std::string &hostname_) :
ws_engine_t (fd_, options_, endpoint_uri_pair_, address_, client_), ws_engine_t (fd_, options_, endpoint_uri_pair_, address_, client_),
_established (false), _established (false),
_tls_client_cred (NULL) _tls_client_cred (NULL)
@ -88,11 +88,13 @@ zmq::wss_engine_t::wss_engine_t (fd_t fd_,
rc = gnutls_init (&_tls_session, GNUTLS_CLIENT | GNUTLS_NONBLOCK); rc = gnutls_init (&_tls_session, GNUTLS_CLIENT | GNUTLS_NONBLOCK);
assert (rc == GNUTLS_E_SUCCESS); assert (rc == GNUTLS_E_SUCCESS);
if (hostname_) if (!hostname_.empty ())
gnutls_server_name_set (_tls_session, GNUTLS_NAME_DNS, hostname_, gnutls_server_name_set (_tls_session, GNUTLS_NAME_DNS,
strlen (hostname_)); hostname_.c_str (), hostname_.size ());
gnutls_session_set_ptr (_tls_session, (void *) hostname_); gnutls_session_set_ptr (
_tls_session,
hostname_.empty () ? NULL : const_cast<char *> (hostname_.c_str ()));
rc = gnutls_credentials_set (_tls_session, GNUTLS_CRD_CERTIFICATE, rc = gnutls_credentials_set (_tls_session, GNUTLS_CRD_CERTIFICATE,
_tls_client_cred); _tls_client_cred);

View File

@ -46,7 +46,7 @@ class wss_engine_t : public ws_engine_t
ws_address_t &address_, ws_address_t &address_,
bool client_, bool client_,
void *tls_server_cred_, void *tls_server_cred_,
const char *hostname_); const std::string &hostname_);
~wss_engine_t (); ~wss_engine_t ();
void out_event (); void out_event ();