diff --git a/src/plain_client.cpp b/src/plain_client.cpp index 8dead179..64233a49 100644 --- a/src/plain_client.cpp +++ b/src/plain_client.cpp @@ -30,6 +30,7 @@ zmq::plain_client_t::plain_client_t (const options_t &options_) : mechanism_t (options_), + error_command_received (false), state (sending_hello) { } @@ -62,38 +63,45 @@ int zmq::plain_client_t::next_handshake_command (msg_t *msg_) int zmq::plain_client_t::process_handshake_command (msg_t *msg_) { - int rc = 0; + const unsigned char *cmd_data = + static_cast (msg_->data ()); + const size_t data_size = msg_->size (); - switch (state) { - case waiting_for_welcome: - rc = process_welcome (msg_); - if (rc == 0) - state = sending_initiate; - break; - case waiting_for_ready: - rc = process_ready (msg_); - if (rc == 0) - state = ready; - break; - default: - // Temporary support for security debugging - puts ("PLAIN I: invalid handshake command"); - errno = EPROTO; - rc = -1; - break; + int rc = 0; + if (data_size >= 8 && !memcmp (cmd_data, "\7WELCOME", 8)) + rc = process_welcome (cmd_data, data_size); + else + if (data_size >= 6 && !memcmp (cmd_data, "\5READY", 6)) + rc = process_ready (cmd_data, data_size); + else + if (data_size >= 6 && !memcmp (cmd_data, "\5ERROR", 6)) + rc = process_error (cmd_data, data_size); + else { + // Temporary support for security debugging + puts ("PLAIN I: invalid handshake command"); + errno = EPROTO; + rc = -1; } + if (rc == 0) { rc = msg_->close (); errno_assert (rc == 0); rc = msg_->init (); errno_assert (rc == 0); } + return rc; } zmq::mechanism_t::status_t zmq::plain_client_t::status () const { - return state == ready? mechanism_t::ready: mechanism_t::handshaking; + if (state == ready) + return mechanism_t::ready; + else + if (error_command_received) + return mechanism_t::error; + else + return mechanism_t::handshaking; } int zmq::plain_client_t::produce_hello (msg_t *msg_) const @@ -125,17 +133,18 @@ int zmq::plain_client_t::produce_hello (msg_t *msg_) const return 0; } -int zmq::plain_client_t::process_welcome (msg_t *msg_) +int zmq::plain_client_t::process_welcome ( + const unsigned char *cmd_data, size_t data_size) { - const unsigned char *ptr = static_cast (msg_->data ()); - const size_t bytes_left = msg_->size (); - - if (bytes_left != 8 || memcmp (ptr, "\x07WELCOME", 8)) { - // Temporary support for security debugging - puts ("PLAIN I: invalid PLAIN client, did not send WELCOME"); + if (state != waiting_for_welcome) { errno = EPROTO; return -1; } + if (data_size != 8) { + errno = EPROTO; + return -1; + } + state = sending_initiate; return 0; } @@ -170,16 +179,35 @@ int zmq::plain_client_t::produce_initiate (msg_t *msg_) const return 0; } -int zmq::plain_client_t::process_ready (msg_t *msg_) +int zmq::plain_client_t::process_ready ( + const unsigned char *cmd_data, size_t data_size) { - const unsigned char *ptr = static_cast (msg_->data ()); - const size_t bytes_left = msg_->size (); - - if (bytes_left < 6 || memcmp (ptr, "\x05READY", 6)) { - // Temporary support for security debugging - puts ("PLAIN I: invalid PLAIN client, did not send READY"); + if (state != waiting_for_ready) { errno = EPROTO; return -1; } - return parse_metadata (ptr + 6, bytes_left - 6); + const int rc = parse_metadata (cmd_data + 6, data_size - 6); + if (rc == 0) + state = ready; + return rc; +} + +int zmq::plain_client_t::process_error ( + const unsigned char *cmd_data, size_t data_size) +{ + if (state != waiting_for_welcome && state != waiting_for_ready) { + errno = EPROTO; + return -1; + } + if (data_size == 6) { + errno = EPROTO; + return -1; + } + const size_t size = static_cast (cmd_data [6]); + if (6 + 1 + size != data_size) { + errno = EPROTO; + return -1; + } + error_command_received = true; + return 0; } diff --git a/src/plain_client.hpp b/src/plain_client.hpp index c5c56af7..758c3bfc 100644 --- a/src/plain_client.hpp +++ b/src/plain_client.hpp @@ -42,6 +42,8 @@ namespace zmq private: + bool error_command_received; + enum state_t { sending_hello, waiting_for_welcome, @@ -55,8 +57,12 @@ namespace zmq int produce_hello (msg_t *msg_) const; int produce_initiate (msg_t *msg_) const; - int process_welcome (msg_t *msg); - int process_ready (msg_t *msg_); + int process_welcome ( + const unsigned char *cmd_data, size_t data_size); + int process_ready ( + const unsigned char *cmd_data, size_t data_size); + int process_error ( + const unsigned char *cmd_data, size_t data_size); }; }