diff --git a/src/curve_mechanism_base.cpp b/src/curve_mechanism_base.cpp index b37cb9da..4636ea58 100644 --- a/src/curve_mechanism_base.cpp +++ b/src/curve_mechanism_base.cpp @@ -35,7 +35,6 @@ #include "session_base.hpp" #ifdef ZMQ_HAVE_CURVE - zmq::curve_mechanism_base_t::curve_mechanism_base_t ( session_base_t *session_, const options_t &options_, @@ -76,35 +75,99 @@ zmq::curve_encoding_t::curve_encoding_t (const char *encode_nonce_prefix_, { } +// Right now, we only transport the lower two bit flags of zmq::msg_t, so they +// are binary identical, and we can just use a bitmask to select them. If we +// happened to add more flags, this might change. +static const uint8_t flag_mask = zmq::msg_t::more | zmq::msg_t::command; +static const size_t flags_len = 1; +static const size_t nonce_prefix_len = 16; +static const char message_command[] = "\x07MESSAGE"; +static const size_t message_command_len = sizeof (message_command) - 1; +static const size_t message_header_len = + message_command_len + sizeof (zmq::curve_encoding_t::nonce_t); + +#ifndef ZMQ_USE_LIBSODIUM +static const size_t crypto_box_MACBYTES = 16; +#endif + +int zmq::curve_encoding_t::check_validity (msg_t *msg_, int *error_event_code_) +{ + const size_t size = msg_->size (); + const uint8_t *const message = static_cast (msg_->data ()); + + if (size < message_command_len + || 0 != memcmp (message, message_command, message_command_len)) { + *error_event_code_ = ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND; + errno = EPROTO; + return -1; + } + + if (size < message_header_len + crypto_box_MACBYTES + flags_len) { + *error_event_code_ = ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_MESSAGE; + errno = EPROTO; + return -1; + } + + { + const uint64_t nonce = get_uint64 (message + message_command_len); + if (nonce <= _cn_peer_nonce) { + *error_event_code_ = ZMQ_PROTOCOL_ERROR_ZMTP_INVALID_SEQUENCE; + errno = EPROTO; + return -1; + } + set_peer_nonce (nonce); + } + + return 0; +} + int zmq::curve_encoding_t::encode (msg_t *msg_) { - const size_t mlen = crypto_box_ZEROBYTES + 1 + msg_->size (); - uint8_t message_nonce[crypto_box_NONCEBYTES]; - memcpy (message_nonce, _encode_nonce_prefix, 16); - put_uint64 (message_nonce + 16, _cn_nonce); - - uint8_t flags = 0; - if (msg_->flags () & msg_t::more) - flags |= 0x01; - if (msg_->flags () & msg_t::command) - flags |= 0x02; + memcpy (message_nonce, _encode_nonce_prefix, nonce_prefix_len); + put_uint64 (message_nonce + nonce_prefix_len, get_and_inc_nonce ()); +#ifdef ZMQ_USE_LIBSODIUM + const size_t mlen = flags_len + msg_->size (); std::vector message_plaintext (mlen); +#else + const size_t mlen = crypto_box_ZEROBYTES + flags_len + msg_->size (); + std::vector message_plaintext_with_zerobytes (mlen); + uint8_t *const message_plaintext = + &message_plaintext_with_zerobytes[crypto_box_ZEROBYTES]; - std::fill (message_plaintext.begin (), - message_plaintext.begin () + crypto_box_ZEROBYTES, 0); - message_plaintext[crypto_box_ZEROBYTES] = flags; + std::fill (message_plaintext_with_zerobytes.begin (), + message_plaintext_with_zerobytes.begin () + crypto_box_ZEROBYTES, + 0); +#endif + + const uint8_t flags = msg_->flags () & flag_mask; + message_plaintext[0] = flags; // this is copying the data from insecure memory, so there is no point in // using secure_allocator_t for message_plaintext if (msg_->size () > 0) - memcpy (&message_plaintext[crypto_box_ZEROBYTES + 1], msg_->data (), - msg_->size ()); + memcpy (&message_plaintext[flags_len], msg_->data (), msg_->size ()); +#ifdef ZMQ_USE_LIBSODIUM + msg_t msg_box; + int rc = + msg_box.init_size (message_header_len + mlen + crypto_box_MACBYTES); + zmq_assert (rc == 0); + + rc = crypto_box_easy_afternm ( + static_cast (msg_box.data ()) + message_header_len, + &message_plaintext[0], mlen, message_nonce, _cn_precom); + zmq_assert (rc == 0); + + msg_->move (msg_box); + + uint8_t *const message = static_cast (msg_->data ()); +#else std::vector message_box (mlen); - int rc = crypto_box_afternm (&message_box[0], &message_plaintext[0], mlen, - message_nonce, _cn_precom); + int rc = + crypto_box_afternm (&message_box[0], &message_plaintext_with_zerobytes[0], + mlen, message_nonce, _cn_precom); zmq_assert (rc == 0); rc = msg_->close (); @@ -113,76 +176,89 @@ int zmq::curve_encoding_t::encode (msg_t *msg_) rc = msg_->init_size (16 + mlen - crypto_box_BOXZEROBYTES); zmq_assert (rc == 0); - uint8_t *message = static_cast (msg_->data ()); + uint8_t *const message = static_cast (msg_->data ()); - memcpy (message, "\x07MESSAGE", 8); - memcpy (message + 8, message_nonce + 16, 8); - memcpy (message + 16, &message_box[crypto_box_BOXZEROBYTES], + memcpy (message + message_header_len, &message_box[crypto_box_BOXZEROBYTES], mlen - crypto_box_BOXZEROBYTES); +#endif - _cn_nonce++; + memcpy (message, message_command, message_command_len); + memcpy (message + message_command_len, message_nonce + nonce_prefix_len, + sizeof (nonce_t)); return 0; } int zmq::curve_encoding_t::decode (msg_t *msg_, int *error_event_code_) { - const size_t size = msg_->size (); - const uint8_t *message = static_cast (msg_->data ()); - - if (size < 8 || 0 != memcmp (message, "\x07MESSAGE", 8)) { - *error_event_code_ = ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND; - errno = EPROTO; - return -1; + int rc = check_validity (msg_, error_event_code_); + if (0 != rc) { + return rc; } - if (size < 33) { - *error_event_code_ = ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_MESSAGE; - errno = EPROTO; - return -1; - } + uint8_t *const message = static_cast (msg_->data ()); uint8_t message_nonce[crypto_box_NONCEBYTES]; - memcpy (message_nonce, _decode_nonce_prefix, 16); - memcpy (message_nonce + 16, message + 8, 8); - const uint64_t nonce = get_uint64 (message + 8); - if (nonce <= _cn_peer_nonce) { - *error_event_code_ = ZMQ_PROTOCOL_ERROR_ZMTP_INVALID_SEQUENCE; - errno = EPROTO; - return -1; - } - _cn_peer_nonce = nonce; + memcpy (message_nonce, _decode_nonce_prefix, nonce_prefix_len); + memcpy (message_nonce + nonce_prefix_len, message + message_command_len, + sizeof (nonce_t)); - const size_t clen = crypto_box_BOXZEROBYTES + msg_->size () - 16; +#ifdef ZMQ_USE_LIBSODIUM + const size_t clen = msg_->size () - message_header_len; - std::vector message_plaintext (clen); + uint8_t *const message_plaintext = message + message_header_len; + + rc = crypto_box_open_easy_afternm (message_plaintext, + message + message_header_len, clen, + message_nonce, _cn_precom); +#else + const size_t clen = + crypto_box_BOXZEROBYTES + msg_->size () - message_header_len; + + std::vector message_plaintext_with_zerobytes (clen); std::vector message_box (clen); std::fill (message_box.begin (), message_box.begin () + crypto_box_BOXZEROBYTES, 0); - memcpy (&message_box[crypto_box_BOXZEROBYTES], message + 16, - msg_->size () - 16); + memcpy (&message_box[crypto_box_BOXZEROBYTES], message + message_header_len, + msg_->size () - message_header_len); + + rc = crypto_box_open_afternm (&message_plaintext_with_zerobytes[0], + &message_box[0], clen, message_nonce, + _cn_precom); + + const uint8_t *const message_plaintext = + &message_plaintext_with_zerobytes[crypto_box_ZEROBYTES]; +#endif - int rc = crypto_box_open_afternm (&message_plaintext[0], &message_box[0], - clen, message_nonce, _cn_precom); if (rc == 0) { + const uint8_t flags = message_plaintext[0]; + +#ifdef ZMQ_USE_LIBSODIUM + const size_t plaintext_size = clen - flags_len - crypto_box_MACBYTES; + + if (plaintext_size > 0) { + memmove (msg_->data (), &message_plaintext[flags_len], + plaintext_size); + } + + msg_->shrink (plaintext_size); +#else rc = msg_->close (); zmq_assert (rc == 0); - rc = msg_->init_size (clen - 1 - crypto_box_ZEROBYTES); + rc = msg_->init_size (clen - flags_len - crypto_box_ZEROBYTES); zmq_assert (rc == 0); - const uint8_t flags = message_plaintext[crypto_box_ZEROBYTES]; - if (flags & 0x01) - msg_->set_flags (msg_t::more); - if (flags & 0x02) - msg_->set_flags (msg_t::command); - // this is copying the data to insecure memory, so there is no point in // using secure_allocator_t for message_plaintext - if (msg_->size () > 0) - memcpy (msg_->data (), &message_plaintext[crypto_box_ZEROBYTES + 1], + if (msg_->size () > 0) { + memcpy (msg_->data (), &message_plaintext[flags_len], msg_->size ()); + } +#endif + + msg_->set_flags (flags & flag_mask); } else { // CURVE I : connection key used for MESSAGE is wrong *error_event_code_ = ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC; diff --git a/src/curve_mechanism_base.hpp b/src/curve_mechanism_base.hpp index 1bb65ec1..a341e744 100644 --- a/src/curve_mechanism_base.hpp +++ b/src/curve_mechanism_base.hpp @@ -64,18 +64,19 @@ class curve_encoding_t uint8_t *get_writable_precom_buffer () { return _cn_precom; } const uint8_t *get_precom_buffer () const { return _cn_precom; } - uint64_t get_and_inc_nonce () { return _cn_nonce++; } - void set_peer_nonce (uint64_t peer_nonce_) - { - _cn_peer_nonce = peer_nonce_; - }; + typedef uint64_t nonce_t; + + nonce_t get_and_inc_nonce () { return _cn_nonce++; } + void set_peer_nonce (nonce_t peer_nonce_) { _cn_peer_nonce = peer_nonce_; }; private: + int check_validity (msg_t *msg_, int *error_event_code_); + const char *_encode_nonce_prefix; const char *_decode_nonce_prefix; - uint64_t _cn_nonce; - uint64_t _cn_peer_nonce; + nonce_t _cn_nonce; + nonce_t _cn_peer_nonce; // Intermediary buffer used to speed up boxing and unboxing. uint8_t _cn_precom[crypto_box_BEFORENMBYTES]; diff --git a/src/msg.cpp b/src/msg.cpp index 7822a548..612a18b2 100644 --- a/src/msg.cpp +++ b/src/msg.cpp @@ -361,6 +361,30 @@ size_t zmq::msg_t::size () const } } +void zmq::msg_t::shrink (size_t new_size_) +{ + // Check the validity of the message. + zmq_assert (check ()); + zmq_assert (new_size_ <= size ()); + + switch (_u.base.type) { + case type_vsm: + _u.vsm.size = static_cast (new_size_); + break; + case type_lmsg: + _u.lmsg.content->size = new_size_; + break; + case type_zclmsg: + _u.zclmsg.content->size = new_size_; + break; + case type_cmsg: + _u.cmsg.size = new_size_; + break; + default: + zmq_assert (false); + } +} + unsigned char zmq::msg_t::flags () const { return _u.base.flags; diff --git a/src/msg.hpp b/src/msg.hpp index a75be068..636ab3e5 100644 --- a/src/msg.hpp +++ b/src/msg.hpp @@ -161,6 +161,8 @@ class msg_t // references drops to 0, the message is closed and false is returned. bool rm_refs (int refs_); + void shrink (size_t new_size_); + // Size in bytes of the largest message that is still copied around // rather than being reference-counted. enum