diff --git a/src/ws_address.cpp b/src/ws_address.cpp index 5298e7e3..9e5224b9 100644 --- a/src/ws_address.cpp +++ b/src/ws_address.cpp @@ -67,13 +67,15 @@ zmq::ws_address_t::ws_address_t (const sockaddr *sa_, socklen_t sa_len_) && sa_len_ >= static_cast (sizeof (_address.ipv6))) memcpy (&_address.ipv6, sa_, sizeof (_address.ipv6)); - _path = std::string ("/"); + _path = std::string (""); char hbuf[NI_MAXHOST]; const int rc = getnameinfo (addr (), addrlen (), hbuf, sizeof (hbuf), NULL, 0, NI_NUMERICHOST); - if (rc != 0) + if (rc != 0) { _host = std::string ("localhost"); + return; + } std::ostringstream os; @@ -128,7 +130,7 @@ int zmq::ws_address_t::to_string (std::string &addr_) const { std::ostringstream os; os << std::string ("ws://") << host () << std::string (":") - << _address.port (); + << _address.port () << _path; addr_ = os.str (); return 0; diff --git a/src/ws_encoder.cpp b/src/ws_encoder.cpp index 3f1cd1fe..dfccb8bb 100644 --- a/src/ws_encoder.cpp +++ b/src/ws_encoder.cpp @@ -55,19 +55,24 @@ void zmq::ws_encoder_t::message_ready () { int offset = 0; + _is_binary = false; + if (in_progress ()->is_ping ()) _tmp_buf[offset++] = 0x80 | zmq::ws_protocol_t::opcode_ping; else if (in_progress ()->is_pong ()) _tmp_buf[offset++] = 0x80 | zmq::ws_protocol_t::opcode_pong; else if (in_progress ()->is_close_cmd ()) _tmp_buf[offset++] = 0x80 | zmq::ws_protocol_t::opcode_close; - else + else { _tmp_buf[offset++] = 0x82; // Final | binary + _is_binary = true; + } _tmp_buf[offset] = _must_mask ? 0x80 : 0x00; size_t size = in_progress ()->size (); - size++; // TODO: check if binary + if (_is_binary) + size++; if (size <= 125) _tmp_buf[offset++] |= static_cast (size & 127); @@ -88,17 +93,17 @@ void zmq::ws_encoder_t::message_ready () offset += 4; } - // TODO: check if binary + if (_is_binary) { + // Encode flags. + unsigned char protocol_flags = 0; + if (in_progress ()->flags () & msg_t::more) + protocol_flags |= ws_protocol_t::more_flag; + if (in_progress ()->flags () & msg_t::command) + protocol_flags |= ws_protocol_t::command_flag; - // Encode flags. - unsigned char protocol_flags = 0; - if (in_progress ()->flags () & msg_t::more) - protocol_flags |= ws_protocol_t::more_flag; - if (in_progress ()->flags () & msg_t::command) - protocol_flags |= ws_protocol_t::command_flag; - - _tmp_buf[offset++] = - _must_mask ? protocol_flags ^ _mask[0] : protocol_flags; + _tmp_buf[offset++] = + _must_mask ? protocol_flags ^ _mask[0] : protocol_flags; + } next_step (_tmp_buf, offset, &ws_encoder_t::size_ready, false); } @@ -109,20 +114,23 @@ void zmq::ws_encoder_t::size_ready () assert (in_progress () != &_masked_msg); const size_t size = in_progress ()->size (); - _masked_msg.close (); - _masked_msg.init_size (size); - - int mask_index = 1; // TODO: check if binary message - unsigned char *dest = - static_cast (_masked_msg.data ()); unsigned char *src = static_cast (in_progress ()->data ()); - for (size_t i = 0, size = in_progress ()->size (); i < size; - ++i, mask_index++) + unsigned char *dest = src; + + // If msg is shared or data is constant we cannot mask in-place, allocate a new msg for it + if (in_progress ()->flags () & msg_t::shared + || in_progress ()->is_cmsg ()) { + _masked_msg.close (); + _masked_msg.init_size (size); + dest = static_cast (_masked_msg.data ()); + } + + int mask_index = _is_binary ? 1 : 0; + for (size_t i = 0; i < size; ++i, mask_index++) dest[i] = src[i] ^ _mask[mask_index % 4]; - next_step (_masked_msg.data (), _masked_msg.size (), - &ws_encoder_t::message_ready, true); + next_step (dest, size, &ws_encoder_t::message_ready, true); } else { next_step (in_progress ()->data (), in_progress ()->size (), &ws_encoder_t::message_ready, true); diff --git a/src/ws_encoder.hpp b/src/ws_encoder.hpp index b145e665..299259a7 100644 --- a/src/ws_encoder.hpp +++ b/src/ws_encoder.hpp @@ -50,6 +50,7 @@ class ws_encoder_t ZMQ_FINAL : public encoder_base_t bool _must_mask; unsigned char _mask[4]; msg_t _masked_msg; + bool _is_binary; ZMQ_NON_COPYABLE_NOR_MOVABLE (ws_encoder_t) }; diff --git a/src/ws_listener.cpp b/src/ws_listener.cpp index c8380818..9151f2c1 100644 --- a/src/ws_listener.cpp +++ b/src/ws_listener.cpp @@ -124,10 +124,14 @@ void zmq::ws_listener_t::in_event () std::string zmq::ws_listener_t::get_socket_name (zmq::fd_t fd_, socket_end_t socket_end_) const { - if (_wss) - return zmq::get_socket_name (fd_, socket_end_); + std::string socket_name; - return zmq::get_socket_name (fd_, socket_end_); + if (_wss) + socket_name = zmq::get_socket_name (fd_, socket_end_); + else + socket_name = zmq::get_socket_name (fd_, socket_end_); + + return socket_name + _address.path (); } int zmq::ws_listener_t::create_socket (const char *addr_) diff --git a/src/wss_address.cpp b/src/wss_address.cpp index 6adf90ad..abfec9eb 100644 --- a/src/wss_address.cpp +++ b/src/wss_address.cpp @@ -46,7 +46,7 @@ int zmq::wss_address_t::to_string (std::string &addr_) const { std::ostringstream os; os << std::string ("wss://") << host () << std::string (":") - << _address.port (); + << _address.port () << path (); addr_ = os.str (); return 0; diff --git a/tests/test_ws_transport.cpp b/tests/test_ws_transport.cpp index a700c853..da496370 100644 --- a/tests/test_ws_transport.cpp +++ b/tests/test_ws_transport.cpp @@ -41,7 +41,6 @@ void test_roundtrip () TEST_ASSERT_SUCCESS_ERRNO (zmq_bind (sb, "ws://*:*/roundtrip")); TEST_ASSERT_SUCCESS_ERRNO ( zmq_getsockopt (sb, ZMQ_LAST_ENDPOINT, connect_address, &addr_length)); - strcat (connect_address, "/roundtrip"); void *sc = test_context_socket (ZMQ_REQ); TEST_ASSERT_SUCCESS_ERRNO (zmq_connect (sc, connect_address)); @@ -79,7 +78,6 @@ void test_heartbeat () TEST_ASSERT_SUCCESS_ERRNO (zmq_bind (sb, "ws://*:*/heartbeat")); TEST_ASSERT_SUCCESS_ERRNO ( zmq_getsockopt (sb, ZMQ_LAST_ENDPOINT, connect_address, &addr_length)); - strcat (connect_address, "/heartbeat"); void *sc = test_context_socket (ZMQ_REQ); @@ -113,7 +111,6 @@ void test_short_message () TEST_ASSERT_SUCCESS_ERRNO (zmq_bind (sb, "ws://*:*/short")); TEST_ASSERT_SUCCESS_ERRNO ( zmq_getsockopt (sb, ZMQ_LAST_ENDPOINT, connect_address, &addr_length)); - strcat (connect_address, "/short"); void *sc = test_context_socket (ZMQ_REQ); TEST_ASSERT_SUCCESS_ERRNO (zmq_connect (sc, connect_address)); @@ -147,7 +144,6 @@ void test_large_message () TEST_ASSERT_SUCCESS_ERRNO (zmq_bind (sb, "ws://*:*/large")); TEST_ASSERT_SUCCESS_ERRNO ( zmq_getsockopt (sb, ZMQ_LAST_ENDPOINT, connect_address, &addr_length)); - strcat (connect_address, "/short"); void *sc = test_context_socket (ZMQ_REQ); TEST_ASSERT_SUCCESS_ERRNO (zmq_connect (sc, connect_address)); @@ -197,8 +193,6 @@ void test_curve () TEST_ASSERT_SUCCESS_ERRNO (zmq_bind (server, "ws://*:*/roundtrip")); TEST_ASSERT_SUCCESS_ERRNO (zmq_getsockopt (server, ZMQ_LAST_ENDPOINT, connect_address, &addr_length)); - strcat (connect_address, "/roundtrip"); - void *client = test_context_socket (ZMQ_REQ); TEST_ASSERT_SUCCESS_ERRNO ( @@ -215,6 +209,60 @@ void test_curve () test_context_socket_close (server); } + +void test_mask_shared_msg () +{ + char connect_address[MAX_SOCKET_STRING + strlen ("/mask-shared")]; + size_t addr_length = sizeof (connect_address); + void *sb = test_context_socket (ZMQ_DEALER); + TEST_ASSERT_SUCCESS_ERRNO (zmq_bind (sb, "ws://*:*/mask-shared")); + TEST_ASSERT_SUCCESS_ERRNO ( + zmq_getsockopt (sb, ZMQ_LAST_ENDPOINT, connect_address, &addr_length)); + + void *sc = test_context_socket (ZMQ_DEALER); + TEST_ASSERT_SUCCESS_ERRNO (zmq_connect (sc, connect_address)); + + zmq_msg_t msg; + zmq_msg_init_size ( + &msg, 255); // Message have to be long enough so it won't fit inside msg + unsigned char *data = (unsigned char *) zmq_msg_data (&msg); + for (int i = 0; i < 255; i++) + data[i] = i; + + // Taking a copy to make the msg shared + zmq_msg_t copy; + zmq_msg_init (©); + zmq_msg_copy (©, &msg); + + // Sending the shared msg + int rc = zmq_msg_send (&msg, sc, 0); + TEST_ASSERT_EQUAL_INT (255, rc); + + // Recv the msg and check that it was masked correctly + rc = zmq_msg_recv (&msg, sb, 0); + TEST_ASSERT_EQUAL_INT (255, rc); + data = (unsigned char *) zmq_msg_data (&msg); + for (int i = 0; i < 255; i++) + TEST_ASSERT_EQUAL_INT (i, data[i]); + + // Testing that copy was not masked + data = (unsigned char *) zmq_msg_data (©); + for (int i = 0; i < 255; i++) + TEST_ASSERT_EQUAL_INT (i, data[i]); + + // Constant msg cannot be masked as well, as it is constant + rc = zmq_send_const (sc, "HELLO", 5, 0); + TEST_ASSERT_EQUAL_INT (5, rc); + recv_string_expect_success (sb, "HELLO", 0); + + zmq_msg_close (©); + zmq_msg_close (&msg); + + test_context_socket_close (sc); + test_context_socket_close (sb); +} + + int main () { setup_test_environment (); @@ -225,6 +273,7 @@ int main () RUN_TEST (test_short_message); RUN_TEST (test_large_message); RUN_TEST (test_heartbeat); + RUN_TEST (test_mask_shared_msg); if (zmq_has ("curve")) RUN_TEST (test_curve);