diff --git a/src/curve_server.cpp b/src/curve_server.cpp index 84426a09..4255fa2f 100644 --- a/src/curve_server.cpp +++ b/src/curve_server.cpp @@ -33,9 +33,11 @@ #include "wire.hpp" zmq::curve_server_t::curve_server_t (session_base_t *session_, + const std::string &peer_address_, const options_t &options_) : mechanism_t (options_), session (session_), + peer_address (peer_address_), state (expect_hello), expecting_zap_reply (false), cn_nonce (1) @@ -512,7 +514,7 @@ void zmq::curve_server_t::send_zap_request (const uint8_t *key) rc = session->write_zap_msg (&msg); errno_assert (rc == 0); - // Sequence frame + // Request ID frame rc = msg.init_size (1); errno_assert (rc == 0); memcpy (msg.data (), "1", 1); @@ -527,6 +529,14 @@ void zmq::curve_server_t::send_zap_request (const uint8_t *key) rc = session->write_zap_msg (&msg); errno_assert (rc == 0); + // Address frame + rc = msg.init_size (peer_address.length ()); + errno_assert (rc == 0); + memcpy (msg.data (), peer_address.c_str (), peer_address.length ()); + msg.set_flags (msg_t::more); + rc = session->write_zap_msg (&msg); + errno_assert (rc == 0); + // Mechanism frame rc = msg.init_size (5); errno_assert (rc == 0); @@ -546,18 +556,19 @@ void zmq::curve_server_t::send_zap_request (const uint8_t *key) int zmq::curve_server_t::receive_and_process_zap_reply () { int rc = 0; - msg_t msg [6]; + msg_t msg [7]; // ZAP reply consists of 7 frames - for (int i = 0; i < 6; i++) { + // Initialize all reply frames + for (int i = 0; i < 7; i++) { rc = msg [i].init (); errno_assert (rc == 0); } - for (int i = 0; i < 6; i++) { + for (int i = 0; i < 7; i++) { rc = session->read_zap_msg (&msg [i]); if (rc == -1) break; - if ((msg [i].flags () & msg_t::more) == (i < 5? 0: msg_t::more)) { + if ((msg [i].flags () & msg_t::more) == (i < 6? 0: msg_t::more)) { errno = EPROTO; rc = -1; break; @@ -579,7 +590,7 @@ int zmq::curve_server_t::receive_and_process_zap_reply () goto error; } - // Sequence number frame + // Request id frame if (msg [2].size () != 1 || memcmp (msg [2].data (), "1", 1)) { errno = EPROTO; goto error; @@ -591,8 +602,12 @@ int zmq::curve_server_t::receive_and_process_zap_reply () goto error; } + // Process metadata frame + rc = parse_metadata (static_cast (msg [6].data ()), + msg [6].size ()); + error: - for (int i = 0; i < 6; i++) { + for (int i = 0; i < 7; i++) { const int rc2 = msg [i].close (); errno_assert (rc2 == 0); } diff --git a/src/curve_server.hpp b/src/curve_server.hpp index cfd0d206..8e9aaf8c 100644 --- a/src/curve_server.hpp +++ b/src/curve_server.hpp @@ -50,6 +50,7 @@ namespace zmq public: curve_server_t (session_base_t *session_, + const std::string &peer_address_, const options_t &options_); virtual ~curve_server_t (); @@ -74,6 +75,8 @@ namespace zmq session_base_t * const session; + const std::string peer_address; + // Current FSM state state_t state; diff --git a/src/stream_engine.cpp b/src/stream_engine.cpp index 031c6a04..0b9dad5c 100644 --- a/src/stream_engine.cpp +++ b/src/stream_engine.cpp @@ -84,6 +84,9 @@ zmq::stream_engine_t::stream_engine_t (fd_t fd_, const options_t &options_, // Put the socket into non-blocking mode. unblock_socket (s); + if (!get_peer_ip_address (s, peer_address)) + peer_address = ""; + #ifdef SO_NOSIGPIPE // Make sure that SIGPIPE signal is not generated when writing to a // connection that was already closed by the peer. @@ -534,7 +537,8 @@ bool zmq::stream_engine_t::handshake () else if (memcmp (greeting_recv + 12, "CURVE\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) == 0) { if (options.as_server) - mechanism = new (std::nothrow) curve_server_t (session, options); + mechanism = new (std::nothrow) + curve_server_t (session, peer_address, options); else mechanism = new (std::nothrow) curve_client_t (options); alloc_assert (mechanism); diff --git a/src/stream_engine.hpp b/src/stream_engine.hpp index d4d36b1e..29bd0ace 100644 --- a/src/stream_engine.hpp +++ b/src/stream_engine.hpp @@ -187,6 +187,8 @@ namespace zmq // Socket zmq::socket_base_t *socket; + std::string peer_address; + stream_engine_t (const stream_engine_t&); const stream_engine_t &operator = (const stream_engine_t&); }; diff --git a/tests/test_security_curve.cpp b/tests/test_security_curve.cpp index 59e9fca6..94ab9f93 100644 --- a/tests/test_security_curve.cpp +++ b/tests/test_security_curve.cpp @@ -29,6 +29,7 @@ zap_handler (void *zap) char *version = s_recv (zap); char *sequence = s_recv (zap); char *domain = s_recv (zap); + char *address = s_recv (zap); char *mechanism = s_recv (zap); char *client_key = s_recv (zap); @@ -39,11 +40,13 @@ zap_handler (void *zap) s_sendmore (zap, sequence); s_sendmore (zap, "200"); s_sendmore (zap, "OK"); - s_send (zap, "anonymous"); + s_sendmore (zap, "anonymous"); + s_send (zap, ""); free (version); free (sequence); free (domain); + free (address); free (mechanism); free (client_key);