diff --git a/src/curve_client.cpp b/src/curve_client.cpp index e1012e7c..8c85d168 100644 --- a/src/curve_client.cpp +++ b/src/curve_client.cpp @@ -275,17 +275,21 @@ int zmq::curve_client_t::process_error ( } if (msg_size < 7) { session->get_socket ()->event_handshake_failed_protocol ( - session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR); + session->get_endpoint (), + ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR); errno = EPROTO; return -1; } const size_t error_reason_len = static_cast (msg_data [6]); if (error_reason_len > msg_size - 7) { session->get_socket ()->event_handshake_failed_protocol ( - session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR); + session->get_endpoint (), + ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR); errno = EPROTO; return -1; } + const char *error_reason = reinterpret_cast (msg_data) + 7; + handle_error_reason (error_reason, error_reason_len); state = error_received; return 0; } diff --git a/src/mechanism_base.cpp b/src/mechanism_base.cpp index e255cb01..bcc9f1b3 100644 --- a/src/mechanism_base.cpp +++ b/src/mechanism_base.cpp @@ -52,3 +52,14 @@ int zmq::mechanism_base_t::check_basic_command_structure (msg_t *msg_) return 0; } +void zmq::mechanism_base_t::handle_error_reason (const char *error_reason, + int error_reason_len) +{ + if (error_reason_len == 3 && error_reason[1] == '0' + && error_reason[2] == '0' && error_reason[0] >= '3' + && error_reason[0] <= '5') { + // it is a ZAP status code, so emit an authentication failure event + session->get_socket ()->event_handshake_failed_auth ( + session->get_endpoint (), (error_reason[0] - '0') * 100); + } +} diff --git a/src/mechanism_base.hpp b/src/mechanism_base.hpp index e98a7f7e..620b6fa7 100644 --- a/src/mechanism_base.hpp +++ b/src/mechanism_base.hpp @@ -43,6 +43,8 @@ class mechanism_base_t : public mechanism_t session_base_t *const session; int check_basic_command_structure (msg_t *msg_); + + void handle_error_reason (const char *error_reason, int error_reason_len); }; } diff --git a/src/null_mechanism.cpp b/src/null_mechanism.cpp index 8f181920..c8192bf3 100644 --- a/src/null_mechanism.cpp +++ b/src/null_mechanism.cpp @@ -83,15 +83,21 @@ int zmq::null_mechanism_t::next_handshake_command (msg_t *msg_) } if (zap_reply_received && status_code != "200") { - const size_t status_code_len = 3; - const int rc = msg_->init_size (6 + 1 + status_code_len); - zmq_assert (rc == 0); - unsigned char *msg_data = static_cast (msg_->data ()); - memcpy (msg_data, "\5ERROR", 6); - msg_data [6] = status_code_len; - memcpy (msg_data + 7, status_code.c_str (), status_code_len); error_command_sent = true; - return 0; + if (status_code != "300") { + const size_t status_code_len = 3; + const int rc = msg_->init_size (6 + 1 + status_code_len); + zmq_assert (rc == 0); + unsigned char *msg_data = + static_cast (msg_->data ()); + memcpy (msg_data, "\5ERROR", 6); + msg_data[6] = status_code_len; + memcpy (msg_data + 7, status_code.c_str (), status_code_len); + return 0; + } else { + errno = EAGAIN; + return -1; + } } make_command_with_basic_properties (msg_, "\5READY", 6); @@ -165,6 +171,8 @@ int zmq::null_mechanism_t::process_error_command ( errno = EPROTO; return -1; } + const char *error_reason = reinterpret_cast (cmd_data) + 7; + handle_error_reason (error_reason, error_reason_len); error_command_received = true; return 0; } diff --git a/src/plain_client.cpp b/src/plain_client.cpp index 43918cab..c4fcf280 100644 --- a/src/plain_client.cpp +++ b/src/plain_client.cpp @@ -86,9 +86,8 @@ int zmq::plain_client_t::process_handshake_command (msg_t *msg_) if (data_size >= 6 && !memcmp (cmd_data, "\5ERROR", 6)) rc = process_error (cmd_data, data_size); else { - // TODO see comment in curve_server_t::process_handshake_command session->get_socket ()->event_handshake_failed_protocol ( - session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_UNSPECIFIED); + session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND); errno = EPROTO; rc = -1; } @@ -215,6 +214,8 @@ int zmq::plain_client_t::process_error ( errno = EPROTO; return -1; } + const char *error_reason = reinterpret_cast (cmd_data) + 7; + handle_error_reason (error_reason, error_reason_len); state = error_command_received; return 0; } diff --git a/src/zap_client.cpp b/src/zap_client.cpp index 1d820d37..c1d719be 100644 --- a/src/zap_client.cpp +++ b/src/zap_client.cpp @@ -281,10 +281,19 @@ void zap_client_common_handshake_t::handle_zap_status_code () // we can assume here that status_code is a valid ZAP status code, // i.e. 200, 300, 400 or 500 - if (status_code[0] == '2') { - state = zap_reply_ok_state; - } else { - state = sending_error; + switch (status_code[0]) { + case '2': + state = zap_reply_ok_state; + break; + case '3': + // a 300 error code (temporary failure) + // should NOT result in an ERROR message, but instead the + // client should be silently disconnected (see CURVEZMQ RFC) + // therefore, go immediately to state error_sent + state = error_sent; + break; + default: + state = sending_error; } } diff --git a/tests/test_security_curve.cpp b/tests/test_security_curve.cpp index 57a2dc71..51dec8f6 100644 --- a/tests/test_security_curve.cpp +++ b/tests/test_security_curve.cpp @@ -53,24 +53,6 @@ const char large_identity[] = "0123456789012345678901234567890123456789" "0123456789012345678901234567890123456789" "012345678901234"; -#ifdef ZMQ_BUILD_DRAFT_API -// assert_* are macros rather than functions, to allow assertion failures be -// attributed to the causing source code line -#define assert_no_more_monitor_events_with_timeout(monitor, timeout) \ - { \ - int event_count = 0; \ - int event, err; \ - while ((event = get_monitor_event_with_timeout ((monitor), &err, NULL, \ - (timeout))) \ - != -1) { \ - ++event_count; \ - fprintf (stderr, "Unexpected event: %x (err = %i)\n", event, err); \ - } \ - assert (event_count == 0); \ - } - -#endif - static void zap_handler_large_identity (void *ctx) { zap_handler_generic (ctx, zap_ok, large_identity); @@ -81,21 +63,23 @@ void expect_new_client_curve_bounce_fail (void *ctx, char *client_public, char *client_secret, char *my_endpoint, - void *server) + void *server, + void **client_mon = NULL) { curve_client_data_t curve_client_data = {server_public, client_public, client_secret}; - expect_new_client_bounce_fail ( - ctx, my_endpoint, server, socket_config_curve_client, &curve_client_data); + expect_new_client_bounce_fail (ctx, my_endpoint, server, + socket_config_curve_client, + &curve_client_data, client_mon); } -void test_garbage_key(void *ctx, - void *server, - void *server_mon, - char *my_endpoint, - char *server_public, - char *client_public, - char *client_secret) +void test_null_key (void *ctx, + void *server, + void *server_mon, + char *my_endpoint, + char *server_public, + char *client_public, + char *client_secret) { expect_new_client_curve_bounce_fail (ctx, server_public, client_public, client_secret, my_endpoint, server); @@ -113,7 +97,9 @@ void test_garbage_key(void *ctx, // long) fprintf (stderr, - "count of ZMQ_EVENT_HANDSHAKE_FAILED_ENCRYPTION events: %i\n", + "count of " + "ZMQ_EVENT_HANDSHAKE_FAILED_PROTOCOL/" + "ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC events: %i\n", handshake_failed_encryption_event_count); #endif } @@ -123,8 +109,10 @@ void test_curve_security_with_valid_credentials ( { curve_client_data_t curve_client_data = { valid_server_public, valid_client_public, valid_client_secret}; - void *client = create_and_connect_client ( - ctx, my_endpoint, socket_config_curve_client, &curve_client_data); + void *client_mon; + void *client = + create_and_connect_client (ctx, my_endpoint, socket_config_curve_client, + &curve_client_data, &client_mon); bounce (server, client); int rc = zmq_close (client); assert (rc == 0); @@ -134,6 +122,14 @@ void test_curve_security_with_valid_credentials ( assert (event == ZMQ_EVENT_HANDSHAKE_SUCCEEDED); assert_no_more_monitor_events_with_timeout (server_mon, timeout); + + event = get_monitor_event_with_timeout (client_mon, NULL, NULL, -1); + assert (event == ZMQ_EVENT_HANDSHAKE_SUCCEEDED); + + assert_no_more_monitor_events_with_timeout (client_mon, timeout); + + rc = zmq_close (client_mon); + assert (rc == 0); #endif } @@ -147,18 +143,27 @@ void test_curve_security_with_bogus_client_credentials ( char bogus_secret [41]; zmq_curve_keypair (bogus_public, bogus_secret); + void *client_mon; expect_new_client_curve_bounce_fail (ctx, valid_server_public, bogus_public, - bogus_secret, my_endpoint, server); + bogus_secret, my_endpoint, server, + &client_mon); - int event_count = 0; + int server_event_count = 0; #ifdef ZMQ_BUILD_DRAFT_API - event_count = expect_monitor_event_multiple ( + server_event_count = expect_monitor_event_multiple ( server_mon, ZMQ_EVENT_HANDSHAKE_FAILED_AUTH, 400); - assert (event_count <= 1); + assert (server_event_count <= 1); + + int client_event_count = expect_monitor_event_multiple ( + client_mon, ZMQ_EVENT_HANDSHAKE_FAILED_AUTH, 400); + assert (client_event_count == 1); + + int rc = zmq_close (client_mon); + assert (rc == 0); #endif // there may be more than one ZAP request due to repeated attempts by the client - assert (0 == event_count + assert (0 == server_event_count || 1 <= zmq_atomic_counter_value (zap_requests_handled)); } @@ -445,7 +450,7 @@ void recv_greeting (int fd) } int connect_exchange_greeting_and_send_hello (char *my_endpoint, - zmq::curve_client_tools_t &tools) + zmq::curve_client_tools_t &tools) { int s = connect_vanilla_socket (my_endpoint); @@ -643,35 +648,35 @@ int main (void) shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, handler); - char garbage_key[] = "0000000000000000000000000000000000000000"; + char null_key[] = "0000000000000000000000000000000000000000"; - // Check CURVE security with a garbage server key + // Check CURVE security with a null server key // This will be caught by the curve_server class, not passed to ZAP - fprintf (stderr, "test_garbage_server_key\n"); + fprintf (stderr, "test_null_key (server)\n"); setup_context_and_server_side (&ctx, &handler, &zap_thread, &server, &server_mon, my_endpoint); - test_garbage_key (ctx, server, server_mon, my_endpoint, garbage_key, + test_null_key (ctx, server, server_mon, my_endpoint, null_key, valid_client_public, valid_client_secret); shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, handler); - // Check CURVE security with a garbage client public key + // Check CURVE security with a null client public key // This will be caught by the curve_server class, not passed to ZAP - fprintf (stderr, "test_garbage_client_public_key\n"); + fprintf (stderr, "test_null_key (client public)\n"); setup_context_and_server_side (&ctx, &handler, &zap_thread, &server, &server_mon, my_endpoint); - test_garbage_key (ctx, server, server_mon, my_endpoint, valid_server_public, - garbage_key, valid_client_secret); + test_null_key (ctx, server, server_mon, my_endpoint, valid_server_public, + null_key, valid_client_secret); shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, handler); - // Check CURVE security with a garbage client secret key + // Check CURVE security with a null client secret key // This will be caught by the curve_server class, not passed to ZAP - fprintf (stderr, "test_garbage_client_secret_key\n"); + fprintf (stderr, "test_null_key (client secret)\n"); setup_context_and_server_side (&ctx, &handler, &zap_thread, &server, &server_mon, my_endpoint); - test_garbage_key (ctx, server, server_mon, my_endpoint, valid_server_public, - valid_client_public, garbage_key); + test_null_key (ctx, server, server_mon, my_endpoint, valid_server_public, + valid_client_public, null_key); shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, handler); @@ -682,6 +687,7 @@ int main (void) server_mon, timeout); shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, handler); + fprintf (stderr, "test_curve_security_with_null_client_credentials\n"); setup_context_and_server_side (&ctx, &handler, &zap_thread, &server, &server_mon, my_endpoint); diff --git a/tests/test_security_zap.cpp b/tests/test_security_zap.cpp index b07b2492..c718f84e 100644 --- a/tests/test_security_zap.cpp +++ b/tests/test_security_zap.cpp @@ -66,10 +66,11 @@ void test_zap_unsuccessful (void *ctx, int expected_event, int expected_err, socket_config_fn socket_config_, - void *socket_config_data_) + void *socket_config_data_, + void **client_mon = NULL) { expect_new_client_bounce_fail (ctx, my_endpoint, server, socket_config_, - socket_config_data_); + socket_config_data_, client_mon); int events_received = 0; #ifdef ZMQ_BUILD_DRAFT_API @@ -77,7 +78,8 @@ void test_zap_unsuccessful (void *ctx, expect_monitor_event_multiple (server_mon, expected_event, expected_err); #endif - // there may be more than one ZAP request due to repeated attempts by the client + // there may be more than one ZAP request due to repeated attempts by the + // client (actually only in case if ZAP status code 300) assert (events_received == 0 || 1 <= zmq_atomic_counter_value (zap_requests_handled)); } @@ -99,6 +101,59 @@ void test_zap_protocol_error (void *ctx, socket_config_, socket_config_data_); } +void test_zap_unsuccessful_status_300 (void *ctx, + char *my_endpoint, + void *server, + void *server_mon, + socket_config_fn client_socket_config_, + void *client_socket_config_data_) +{ + void *client_mon; + test_zap_unsuccessful (ctx, my_endpoint, server, server_mon, +#ifdef ZMQ_BUILD_DRAFT_API + ZMQ_EVENT_HANDSHAKE_FAILED_AUTH, 300, +#else + 0, 0, +#endif + client_socket_config_, client_socket_config_data_, + &client_mon); + +#ifdef ZMQ_BUILD_DRAFT_API + assert_no_more_monitor_events_with_timeout (client_mon, 250); + + int rc = zmq_close (client_mon); + assert (rc == 0); +#endif +} + +void test_zap_unsuccessful_status_500 (void *ctx, + char *my_endpoint, + void *server, + void *server_mon, + socket_config_fn client_socket_config_, + void *client_socket_config_data_) +{ + void *client_mon; + test_zap_unsuccessful (ctx, my_endpoint, server, server_mon, +#ifdef ZMQ_BUILD_DRAFT_API + ZMQ_EVENT_HANDSHAKE_FAILED_AUTH, 500, +#else + 0, 0, +#endif + client_socket_config_, client_socket_config_data_, + &client_mon); + +#ifdef ZMQ_BUILD_DRAFT_API + int events_received = 0; + events_received = expect_monitor_event_multiple ( + client_mon, ZMQ_EVENT_HANDSHAKE_FAILED_AUTH, 500); + assert(events_received == 1); + + int rc = zmq_close (client_mon); + assert (rc == 0); +#endif +} + void test_zap_errors (socket_config_fn server_socket_config_, void *server_socket_config_data_, socket_config_fn client_socket_config_, @@ -192,13 +247,9 @@ void test_zap_errors (socket_config_fn server_socket_config_, &ctx, &handler, &zap_thread, &server, &server_mon, my_endpoint, &zap_handler_wrong_status_temporary_failure, server_socket_config_, server_socket_config_data_); - test_zap_unsuccessful (ctx, my_endpoint, server, server_mon, -#ifdef ZMQ_BUILD_DRAFT_API - ZMQ_EVENT_HANDSHAKE_FAILED_AUTH, 300, -#else - 0, 0, -#endif - client_socket_config_, client_socket_config_data_); + test_zap_unsuccessful_status_300 (ctx, my_endpoint, server, server_mon, + client_socket_config_, + client_socket_config_data_); shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, handler); @@ -207,13 +258,9 @@ void test_zap_errors (socket_config_fn server_socket_config_, setup_context_and_server_side ( &ctx, &handler, &zap_thread, &server, &server_mon, my_endpoint, &zap_handler_wrong_status_internal_error, server_socket_config_); - test_zap_unsuccessful (ctx, my_endpoint, server, server_mon, -#ifdef ZMQ_BUILD_DRAFT_API - ZMQ_EVENT_HANDSHAKE_FAILED_AUTH, 500, -#else - 0, 0, -#endif - client_socket_config_, client_socket_config_data_); + test_zap_unsuccessful_status_500 (ctx, my_endpoint, server, server_mon, + client_socket_config_, + client_socket_config_data_); shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, handler); } diff --git a/tests/testutil_security.hpp b/tests/testutil_security.hpp index 61241d43..61c375ca 100644 --- a/tests/testutil_security.hpp +++ b/tests/testutil_security.hpp @@ -288,6 +288,30 @@ void zap_handler (void *ctx) zap_handler_generic (ctx, zap_ok); } +void setup_handshake_socket_monitor (void *ctx, + void *server, + void **server_mon, + const char *monitor_endpoint) +{ +#ifdef ZMQ_BUILD_DRAFT_API + // Monitor handshake events on the server + int rc = zmq_socket_monitor (server, monitor_endpoint, + ZMQ_EVENT_HANDSHAKE_SUCCEEDED + | ZMQ_EVENT_HANDSHAKE_FAILED_NO_DETAIL + | ZMQ_EVENT_HANDSHAKE_FAILED_AUTH + | ZMQ_EVENT_HANDSHAKE_FAILED_PROTOCOL); + assert (rc == 0); + + // Create socket for collecting monitor events + *server_mon = zmq_socket (ctx, ZMQ_PAIR); + assert (*server_mon); + + // Connect it to the inproc endpoints so they'll get events + rc = zmq_connect (*server_mon, monitor_endpoint); + assert (rc == 0); +#endif +} + void setup_context_and_server_side ( void **ctx, void **handler, @@ -335,25 +359,9 @@ void setup_context_and_server_side ( rc = zmq_getsockopt (*server, ZMQ_LAST_ENDPOINT, my_endpoint, &len); assert (rc == 0); -#ifdef ZMQ_BUILD_DRAFT_API - char monitor_endpoint [] = "inproc://monitor-server"; - - // Monitor handshake events on the server - rc = zmq_socket_monitor (*server, monitor_endpoint, - ZMQ_EVENT_HANDSHAKE_SUCCEEDED - | ZMQ_EVENT_HANDSHAKE_FAILED_NO_DETAIL - | ZMQ_EVENT_HANDSHAKE_FAILED_AUTH - | ZMQ_EVENT_HANDSHAKE_FAILED_PROTOCOL); - assert (rc == 0); - - // Create socket for collecting monitor events - *server_mon = zmq_socket (*ctx, ZMQ_PAIR); - assert (*server_mon); - - // Connect it to the inproc endpoints so they'll get events - rc = zmq_connect (*server_mon, monitor_endpoint); - assert (rc == 0); -#endif + const char server_monitor_endpoint [] = "inproc://monitor-server"; + setup_handshake_socket_monitor (*ctx, *server, server_mon, + server_monitor_endpoint); } void shutdown_context_and_server_side (void *ctx, @@ -389,7 +397,8 @@ void shutdown_context_and_server_side (void *ctx, void *create_and_connect_client (void *ctx, char *my_endpoint, socket_config_fn socket_config_, - void *socket_config_data_) + void *socket_config_data_, + void **client_mon = NULL) { void *client = zmq_socket (ctx, ZMQ_DEALER); assert (client); @@ -399,6 +408,12 @@ void *create_and_connect_client (void *ctx, int rc = zmq_connect (client, my_endpoint); assert (rc == 0); + if (client_mon) + { + setup_handshake_socket_monitor (ctx, client, client_mon, + "inproc://client-monitor"); + } + return client; } @@ -406,10 +421,11 @@ void expect_new_client_bounce_fail (void *ctx, char *my_endpoint, void *server, socket_config_fn socket_config_, - void *socket_config_data_) + void *socket_config_data_, + void **client_mon = NULL) { void *client = create_and_connect_client (ctx, my_endpoint, socket_config_, - socket_config_data_); + socket_config_data_, client_mon); expect_bounce_fail (server, client); close_zero_linger (client); } @@ -484,6 +500,19 @@ int get_monitor_event_with_timeout (void *monitor, } #ifdef ZMQ_BUILD_DRAFT_API + +void print_unexpected_event (int event, + int err, + int expected_event, + int expected_err) +{ + fprintf( + stderr, + "Unexpected event: 0x%x, value = %i/0x%x (expected: 0x%x, value " + "= %i/0x%x)\n", + event, err, err, expected_event, expected_err, expected_err); +} + // expects that one or more occurrences of the expected event are received // via the specified socket monitor // returns the number of occurrences of the expected event @@ -497,16 +526,22 @@ int expect_monitor_event_multiple (void *server_mon, { int count_of_expected_events = 0; int client_closed_connection = 0; - // infinite timeout at the start - int timeout = -1; + int timeout = 250; + int wait_time = 0; int event; int err; while ( (event = get_monitor_event_with_timeout (server_mon, &err, NULL, timeout)) - != -1) { - timeout = 250; - + != -1 || !count_of_expected_events) { + if (event == -1) { + wait_time += timeout; + fprintf (stderr, + "Still waiting for first event after %ims (expected event " + "%x (value %i/%x))\n", + wait_time, expected_event, expected_err, expected_err); + continue; + } // ignore errors with EPIPE/ECONNRESET/ECONNABORTED, which can happen // ECONNRESET can happen on very slow machines, when the engine writes // to the peer and then tries to read the socket before the peer reads @@ -522,11 +557,7 @@ int expect_monitor_event_multiple (void *server_mon, } if (event != expected_event || (-1 != expected_err && err != expected_err)) { - fprintf ( - stderr, - "Unexpected event: 0x%x, value = %i/0x%x (expected: 0x%x, value " - "= %i/0x%x)\n", - event, err, err, expected_event, expected_err, expected_err); + print_unexpected_event (event, err, expected_event, expected_err); assert (false); } ++count_of_expected_events; @@ -535,6 +566,31 @@ int expect_monitor_event_multiple (void *server_mon, return count_of_expected_events; } + +// assert_* are macros rather than functions, to allow assertion failures be +// attributed to the causing source code line +#define assert_no_more_monitor_events_with_timeout(monitor, timeout) \ + { \ + int event_count = 0; \ + int event, err; \ + while ((event = get_monitor_event_with_timeout ((monitor), &err, NULL, \ + (timeout))) \ + != -1) { \ + if (event == ZMQ_EVENT_HANDSHAKE_FAILED_NO_DETAIL \ + && (err == EPIPE || err == ECONNRESET \ + || err == ECONNABORTED)) { \ + fprintf (stderr, \ + "Ignored event (skipping any further events): %x " \ + "(err = %i)\n", \ + event, err); \ + continue; \ + } \ + ++event_count; \ + print_unexpected_event (event, err, 0, 0); \ + } \ + assert (event_count == 0); \ + } + #endif #endif