diff --git a/src/stream_engine.cpp b/src/stream_engine.cpp index 9dc96696..7a3a7d98 100644 --- a/src/stream_engine.cpp +++ b/src/stream_engine.cpp @@ -490,224 +490,18 @@ bool zmq::stream_engine_t::handshake () zmq_assert (_handshaking); zmq_assert (_greeting_bytes_read < _greeting_size); // Receive the greeting. - while (_greeting_bytes_read < _greeting_size) { - const int n = tcp_read (_s, _greeting_recv + _greeting_bytes_read, - _greeting_size - _greeting_bytes_read); - if (n == 0) { - errno = EPIPE; - error (connection_error); - return false; - } - if (n == -1) { - if (errno != EAGAIN) - error (connection_error); - return false; - } - - _greeting_bytes_read += n; - - // We have received at least one byte from the peer. - // If the first byte is not 0xff, we know that the - // peer is using unversioned protocol. - if (_greeting_recv[0] != 0xff) - break; - - if (_greeting_bytes_read < signature_size) - continue; - - // Inspect the right-most bit of the 10th byte (which coincides - // with the 'flags' field if a regular message was sent). - // Zero indicates this is a header of a routing id message - // (i.e. the peer is using the unversioned protocol). - if (!(_greeting_recv[9] & 0x01)) - break; - - // The peer is using versioned protocol. - // Send the major version number. - if (_outpos + _outsize == _greeting_send + signature_size) { - if (_outsize == 0) - set_pollout (_handle); - _outpos[_outsize++] = 3; // Major version number - } - - if (_greeting_bytes_read > signature_size) { - if (_outpos + _outsize == _greeting_send + signature_size + 1) { - if (_outsize == 0) - set_pollout (_handle); - - // Use ZMTP/2.0 to talk to older peers. - if (_greeting_recv[10] == ZMTP_1_0 - || _greeting_recv[10] == ZMTP_2_0) - _outpos[_outsize++] = _options.type; - else { - _outpos[_outsize++] = 0; // Minor version number - memset (_outpos + _outsize, 0, 20); - - zmq_assert (_options.mechanism == ZMQ_NULL - || _options.mechanism == ZMQ_PLAIN - || _options.mechanism == ZMQ_CURVE - || _options.mechanism == ZMQ_GSSAPI); - - if (_options.mechanism == ZMQ_NULL) - memcpy (_outpos + _outsize, "NULL", 4); - else if (_options.mechanism == ZMQ_PLAIN) - memcpy (_outpos + _outsize, "PLAIN", 5); - else if (_options.mechanism == ZMQ_GSSAPI) - memcpy (_outpos + _outsize, "GSSAPI", 6); - else if (_options.mechanism == ZMQ_CURVE) - memcpy (_outpos + _outsize, "CURVE", 5); - _outsize += 20; - memset (_outpos + _outsize, 0, 32); - _outsize += 32; - _greeting_size = v3_greeting_size; - } - } - } - } + const int rc = receive_greeting (); + if (rc == -1) + return false; + const bool unversioned = rc != 0; // Position of the revision field in the greeting. const size_t revision_pos = 10; - // Is the peer using ZMTP/1.0 with no revision number? - // If so, we send and receive rest of routing id message - if (_greeting_recv[0] != 0xff || !(_greeting_recv[9] & 0x01)) { - if (_session->zap_enabled ()) { - // reject ZMTP 1.0 connections if ZAP is enabled - error (protocol_error); - return false; - } - - _encoder = new (std::nothrow) v1_encoder_t (out_batch_size); - alloc_assert (_encoder); - - _decoder = - new (std::nothrow) v1_decoder_t (in_batch_size, _options.maxmsgsize); - alloc_assert (_decoder); - - // We have already sent the message header. - // Since there is no way to tell the encoder to - // skip the message header, we simply throw that - // header data away. - const size_t header_size = - _options.routing_id_size + 1 >= UCHAR_MAX ? 10 : 2; - unsigned char tmp[10], *bufferp = tmp; - - // Prepare the routing id message and load it into encoder. - // Then consume bytes we have already sent to the peer. - const int rc = _tx_msg.init_size (_options.routing_id_size); - zmq_assert (rc == 0); - memcpy (_tx_msg.data (), _options.routing_id, _options.routing_id_size); - _encoder->load_msg (&_tx_msg); - size_t buffer_size = _encoder->encode (&bufferp, header_size); - zmq_assert (buffer_size == header_size); - - // Make sure the decoder sees the data we have already received. - _inpos = _greeting_recv; - _insize = _greeting_bytes_read; - - // To allow for interoperability with peers that do not forward - // their subscriptions, we inject a phantom subscription message - // message into the incoming message stream. - if (_options.type == ZMQ_PUB || _options.type == ZMQ_XPUB) - _subscription_required = true; - - // We are sending our routing id now and the next message - // will come from the socket. - _next_msg = &stream_engine_t::pull_msg_from_session; - - // We are expecting routing id message. - _process_msg = &stream_engine_t::process_routing_id_msg; - } else if (_greeting_recv[revision_pos] == ZMTP_1_0) { - if (_session->zap_enabled ()) { - // reject ZMTP 1.0 connections if ZAP is enabled - error (protocol_error); - return false; - } - - _encoder = new (std::nothrow) v1_encoder_t (out_batch_size); - alloc_assert (_encoder); - - _decoder = - new (std::nothrow) v1_decoder_t (in_batch_size, _options.maxmsgsize); - alloc_assert (_decoder); - } else if (_greeting_recv[revision_pos] == ZMTP_2_0) { - if (_session->zap_enabled ()) { - // reject ZMTP 2.0 connections if ZAP is enabled - error (protocol_error); - return false; - } - - _encoder = new (std::nothrow) v2_encoder_t (out_batch_size); - alloc_assert (_encoder); - - _decoder = new (std::nothrow) - v2_decoder_t (in_batch_size, _options.maxmsgsize, _options.zero_copy); - alloc_assert (_decoder); - } else { - _encoder = new (std::nothrow) v2_encoder_t (out_batch_size); - alloc_assert (_encoder); - - _decoder = new (std::nothrow) - v2_decoder_t (in_batch_size, _options.maxmsgsize, _options.zero_copy); - alloc_assert (_decoder); - - if (_options.mechanism == ZMQ_NULL - && memcmp (_greeting_recv + 12, - "NULL\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) - == 0) { - _mechanism = new (std::nothrow) - null_mechanism_t (_session, _peer_address, _options); - alloc_assert (_mechanism); - } else if (_options.mechanism == ZMQ_PLAIN - && memcmp (_greeting_recv + 12, - "PLAIN\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) - == 0) { - if (_options.as_server) - _mechanism = new (std::nothrow) - plain_server_t (_session, _peer_address, _options); - else - _mechanism = - new (std::nothrow) plain_client_t (_session, _options); - alloc_assert (_mechanism); - } -#ifdef ZMQ_HAVE_CURVE - else if (_options.mechanism == ZMQ_CURVE - && memcmp (_greeting_recv + 12, - "CURVE\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) - == 0) { - if (_options.as_server) - _mechanism = new (std::nothrow) - curve_server_t (_session, _peer_address, _options); - else - _mechanism = - new (std::nothrow) curve_client_t (_session, _options); - alloc_assert (_mechanism); - } -#endif -#ifdef HAVE_LIBGSSAPI_KRB5 - else if (_options.mechanism == ZMQ_GSSAPI - && memcmp (_greeting_recv + 12, - "GSSAPI\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) - == 0) { - if (_options.as_server) - _mechanism = new (std::nothrow) - gssapi_server_t (_session, _peer_address, _options); - else - _mechanism = - new (std::nothrow) gssapi_client_t (_session, _options); - alloc_assert (_mechanism); - } -#endif - else { - _session->get_socket ()->event_handshake_failed_protocol ( - _session->get_endpoint (), - ZMQ_PROTOCOL_ERROR_ZMTP_MECHANISM_MISMATCH); - error (protocol_error); - return false; - } - _next_msg = &stream_engine_t::next_handshake_command; - _process_msg = &stream_engine_t::process_handshake_command; - } + if (!(this + ->*select_handshake_fun (unversioned, + _greeting_recv[revision_pos])) ()) + return false; // Start polling for output if necessary. if (_outsize == 0) @@ -725,6 +519,269 @@ bool zmq::stream_engine_t::handshake () return true; } +int zmq::stream_engine_t::receive_greeting () +{ + bool unversioned = false; + while (_greeting_bytes_read < _greeting_size) { + const int n = tcp_read (_s, _greeting_recv + _greeting_bytes_read, + _greeting_size - _greeting_bytes_read); + if (n == 0) { + errno = EPIPE; + error (connection_error); + return -1; + } + if (n == -1) { + if (errno != EAGAIN) + error (connection_error); + return -1; + } + + _greeting_bytes_read += n; + + // We have received at least one byte from the peer. + // If the first byte is not 0xff, we know that the + // peer is using unversioned protocol. + if (_greeting_recv[0] != 0xff) { + unversioned = true; + break; + } + + if (_greeting_bytes_read < signature_size) + continue; + + // Inspect the right-most bit of the 10th byte (which coincides + // with the 'flags' field if a regular message was sent). + // Zero indicates this is a header of a routing id message + // (i.e. the peer is using the unversioned protocol). + if (!(_greeting_recv[9] & 0x01)) { + unversioned = true; + break; + } + + // The peer is using versioned protocol. + receive_greeting_versioned (); + } + return unversioned ? 1 : 0; +} + +void zmq::stream_engine_t::receive_greeting_versioned () +{ + // Send the major version number. + if (_outpos + _outsize == _greeting_send + signature_size) { + if (_outsize == 0) + set_pollout (_handle); + _outpos[_outsize++] = 3; // Major version number + } + + if (_greeting_bytes_read > signature_size) { + if (_outpos + _outsize == _greeting_send + signature_size + 1) { + if (_outsize == 0) + set_pollout (_handle); + + // Use ZMTP/2.0 to talk to older peers. + if (_greeting_recv[10] == ZMTP_1_0 + || _greeting_recv[10] == ZMTP_2_0) + _outpos[_outsize++] = _options.type; + else { + _outpos[_outsize++] = 0; // Minor version number + memset (_outpos + _outsize, 0, 20); + + zmq_assert (_options.mechanism == ZMQ_NULL + || _options.mechanism == ZMQ_PLAIN + || _options.mechanism == ZMQ_CURVE + || _options.mechanism == ZMQ_GSSAPI); + + if (_options.mechanism == ZMQ_NULL) + memcpy (_outpos + _outsize, "NULL", 4); + else if (_options.mechanism == ZMQ_PLAIN) + memcpy (_outpos + _outsize, "PLAIN", 5); + else if (_options.mechanism == ZMQ_GSSAPI) + memcpy (_outpos + _outsize, "GSSAPI", 6); + else if (_options.mechanism == ZMQ_CURVE) + memcpy (_outpos + _outsize, "CURVE", 5); + _outsize += 20; + memset (_outpos + _outsize, 0, 32); + _outsize += 32; + _greeting_size = v3_greeting_size; + } + } + } +} + +zmq::stream_engine_t::handshake_fun_t +zmq::stream_engine_t::select_handshake_fun (bool unversioned, + unsigned char revision) +{ + // Is the peer using ZMTP/1.0 with no revision number? + if (unversioned) { + return &stream_engine_t::handshake_v1_0_unversioned; + } + switch (revision) { + case ZMTP_1_0: + return &stream_engine_t::handshake_v1_0; + case ZMTP_2_0: + return &stream_engine_t::handshake_v2_0; + default: + return &stream_engine_t::handshake_v3_0; + } +} + +bool zmq::stream_engine_t::handshake_v1_0_unversioned () +{ + // We send and receive rest of routing id message + if (_session->zap_enabled ()) { + // reject ZMTP 1.0 connections if ZAP is enabled + error (protocol_error); + return false; + } + + _encoder = new (std::nothrow) v1_encoder_t (out_batch_size); + alloc_assert (_encoder); + + _decoder = + new (std::nothrow) v1_decoder_t (in_batch_size, _options.maxmsgsize); + alloc_assert (_decoder); + + // We have already sent the message header. + // Since there is no way to tell the encoder to + // skip the message header, we simply throw that + // header data away. + const size_t header_size = + _options.routing_id_size + 1 >= UCHAR_MAX ? 10 : 2; + unsigned char tmp[10], *bufferp = tmp; + + // Prepare the routing id message and load it into encoder. + // Then consume bytes we have already sent to the peer. + const int rc = _tx_msg.init_size (_options.routing_id_size); + zmq_assert (rc == 0); + memcpy (_tx_msg.data (), _options.routing_id, _options.routing_id_size); + _encoder->load_msg (&_tx_msg); + const size_t buffer_size = _encoder->encode (&bufferp, header_size); + zmq_assert (buffer_size == header_size); + + // Make sure the decoder sees the data we have already received. + _inpos = _greeting_recv; + _insize = _greeting_bytes_read; + + // To allow for interoperability with peers that do not forward + // their subscriptions, we inject a phantom subscription message + // message into the incoming message stream. + if (_options.type == ZMQ_PUB || _options.type == ZMQ_XPUB) + _subscription_required = true; + + // We are sending our routing id now and the next message + // will come from the socket. + _next_msg = &stream_engine_t::pull_msg_from_session; + + // We are expecting routing id message. + _process_msg = &stream_engine_t::process_routing_id_msg; + + return true; +} + +bool zmq::stream_engine_t::handshake_v1_0 () +{ + if (_session->zap_enabled ()) { + // reject ZMTP 1.0 connections if ZAP is enabled + error (protocol_error); + return false; + } + + _encoder = new (std::nothrow) v1_encoder_t (out_batch_size); + alloc_assert (_encoder); + + _decoder = + new (std::nothrow) v1_decoder_t (in_batch_size, _options.maxmsgsize); + alloc_assert (_decoder); + + return true; +} + +bool zmq::stream_engine_t::handshake_v2_0 () +{ + if (_session->zap_enabled ()) { + // reject ZMTP 2.0 connections if ZAP is enabled + error (protocol_error); + return false; + } + + _encoder = new (std::nothrow) v2_encoder_t (out_batch_size); + alloc_assert (_encoder); + + _decoder = new (std::nothrow) + v2_decoder_t (in_batch_size, _options.maxmsgsize, _options.zero_copy); + alloc_assert (_decoder); + + return true; +} + +bool zmq::stream_engine_t::handshake_v3_0 () +{ + _encoder = new (std::nothrow) v2_encoder_t (out_batch_size); + alloc_assert (_encoder); + + _decoder = new (std::nothrow) + v2_decoder_t (in_batch_size, _options.maxmsgsize, _options.zero_copy); + alloc_assert (_decoder); + + if (_options.mechanism == ZMQ_NULL + && memcmp (_greeting_recv + 12, "NULL\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", + 20) + == 0) { + _mechanism = new (std::nothrow) + null_mechanism_t (_session, _peer_address, _options); + alloc_assert (_mechanism); + } else if (_options.mechanism == ZMQ_PLAIN + && memcmp (_greeting_recv + 12, + "PLAIN\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) + == 0) { + if (_options.as_server) + _mechanism = new (std::nothrow) + plain_server_t (_session, _peer_address, _options); + else + _mechanism = new (std::nothrow) plain_client_t (_session, _options); + alloc_assert (_mechanism); + } +#ifdef ZMQ_HAVE_CURVE + else if (_options.mechanism == ZMQ_CURVE + && memcmp (_greeting_recv + 12, + "CURVE\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) + == 0) { + if (_options.as_server) + _mechanism = new (std::nothrow) + curve_server_t (_session, _peer_address, _options); + else + _mechanism = new (std::nothrow) curve_client_t (_session, _options); + alloc_assert (_mechanism); + } +#endif +#ifdef HAVE_LIBGSSAPI_KRB5 + else if (_options.mechanism == ZMQ_GSSAPI + && memcmp (_greeting_recv + 12, + "GSSAPI\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) + == 0) { + if (_options.as_server) + _mechanism = new (std::nothrow) + gssapi_server_t (_session, _peer_address, _options); + else + _mechanism = + new (std::nothrow) gssapi_client_t (_session, _options); + alloc_assert (_mechanism); + } +#endif + else { + _session->get_socket ()->event_handshake_failed_protocol ( + _session->get_endpoint (), + ZMQ_PROTOCOL_ERROR_ZMTP_MECHANISM_MISMATCH); + error (protocol_error); + return false; + } + _next_msg = &stream_engine_t::next_handshake_command; + _process_msg = &stream_engine_t::process_handshake_command; + + return true; +} + int zmq::stream_engine_t::routing_id_msg (msg_t *msg_) { int rc = msg_->init_size (_options.routing_id_size); diff --git a/src/stream_engine.hpp b/src/stream_engine.hpp index e714fbd9..2c25ce55 100644 --- a/src/stream_engine.hpp +++ b/src/stream_engine.hpp @@ -93,12 +93,22 @@ class stream_engine_t : public io_object_t, public i_engine // Function to handle network disconnections. void error (error_reason_t reason_); - // Receives the greeting message from the peer. - int receive_greeting (); - // Detects the protocol used by the peer. bool handshake (); + // Receive the greeting from the peer. + int receive_greeting (); + void receive_greeting_versioned (); + + typedef bool (stream_engine_t::*handshake_fun_t) (); + static handshake_fun_t select_handshake_fun (bool unversioned, + unsigned char revision); + + bool handshake_v1_0_unversioned (); + bool handshake_v1_0 (); + bool handshake_v2_0 (); + bool handshake_v3_0 (); + int routing_id_msg (msg_t *msg_); int process_routing_id_msg (msg_t *msg_);