diff --git a/tests/test_security_curve.cpp b/tests/test_security_curve.cpp index 88578785..ed18f4fc 100644 --- a/tests/test_security_curve.cpp +++ b/tests/test_security_curve.cpp @@ -148,13 +148,42 @@ enum zap_protocol_t zap_too_many_parts }; -static void zap_handler_generic (void *handler, zap_protocol_t zap_protocol) +static void zap_handler_generic (void *ctx, zap_protocol_t zap_protocol) { + void *control = zmq_socket (ctx, ZMQ_REQ); + assert (control); + int rc = zmq_connect (control, "inproc://handler-control"); + assert (rc == 0); + + void *handler = zmq_socket (ctx, ZMQ_REP); + assert (handler); + rc = zmq_bind (handler, "inproc://zeromq.zap.01"); + assert (rc == 0); + + // Signal main thread that we are ready + rc = s_send (control, "GO"); + assert (rc == 2); + + zmq_pollitem_t items [] = { + { control, 0, ZMQ_POLLIN, 0 }, + { handler, 0, ZMQ_POLLIN, 0 }, + }; + // Process ZAP requests forever - while (true) { + while (zmq_poll (items, 2, -1) >= 0) { + if (items [0].revents & ZMQ_POLLIN) { + char *buf = s_recv (control); + assert (buf); + assert (streq (buf, "STOP")); + free (buf); + break; // Terminating - main thread signal + } + if (!(items [1].revents & ZMQ_POLLIN)) + continue; + char *version = s_recv (handler); if (!version) - break; // Terminating + break; // Terminating - peer's socket closed char *sequence = s_recv (handler); char *domain = s_recv (handler); @@ -206,37 +235,43 @@ static void zap_handler_generic (void *handler, zap_protocol_t zap_protocol) zmq_atomic_counter_inc (zap_requests_handled); } - zmq_close (handler); + rc = zmq_unbind (handler, "inproc://zeromq.zap.01"); + assert (rc == 0); + close_zero_linger (handler); + + rc = s_send (control, "STOPPED"); + assert (rc == 7); + close_zero_linger (control); } -static void zap_handler (void *handler) +static void zap_handler (void *ctx) { - zap_handler_generic (handler, zap_ok); + zap_handler_generic (ctx, zap_ok); } -static void zap_handler_wrong_version (void *handler) +static void zap_handler_wrong_version (void *ctx) { - zap_handler_generic (handler, zap_wrong_version); + zap_handler_generic (ctx, zap_wrong_version); } -static void zap_handler_wrong_request_id (void *handler) +static void zap_handler_wrong_request_id (void *ctx) { - zap_handler_generic (handler, zap_wrong_request_id); + zap_handler_generic (ctx, zap_wrong_request_id); } -static void zap_handler_wrong_status_invalid (void *handler) +static void zap_handler_wrong_status_invalid (void *ctx) { - zap_handler_generic (handler, zap_status_invalid); + zap_handler_generic (ctx, zap_status_invalid); } -static void zap_handler_wrong_status_internal_error (void *handler) +static void zap_handler_wrong_status_internal_error (void *ctx) { - zap_handler_generic (handler, zap_status_internal_error); + zap_handler_generic (ctx, zap_status_internal_error); } -static void zap_handler_too_many_parts (void *handler) +static void zap_handler_too_many_parts (void *ctx) { - zap_handler_generic (handler, zap_too_many_parts); + zap_handler_generic (ctx, zap_too_many_parts); } void *create_and_connect_curve_client (void *ctx, @@ -360,13 +395,17 @@ void setup_context_and_server_side (void **ctx, zap_requests_handled = zmq_atomic_counter_new (); assert (zap_requests_handled != NULL); - // We create and bind ZAP socket in main thread to avoid case - // where child thread does not start up fast enough. *handler = zmq_socket (*ctx, ZMQ_REP); assert (*handler); - int rc = zmq_bind (*handler, "inproc://zeromq.zap.01"); + int rc = zmq_bind (*handler, "inproc://handler-control"); assert (rc == 0); - *zap_thread = zmq_threadstart (zap_handler_, *handler); + + *zap_thread = zmq_threadstart (zap_handler_, *ctx); + + char *buf = s_recv (*handler); + assert (buf); + assert (streq (buf, "GO")); + free (buf); // Server socket will accept connections *server = zmq_socket (*ctx, ZMQ_DEALER); @@ -413,19 +452,30 @@ void setup_context_and_server_side (void **ctx, void shutdown_context_and_server_side (void *ctx, void *zap_thread, void *server, - void *server_mon) + void *server_mon, + void *handler) { + int rc = s_send (handler, "STOP"); + assert (rc == 4); + char *buf = s_recv (handler); + assert (buf); + assert (streq (buf, "STOPPED")); + free (buf); + rc = zmq_unbind (handler, "inproc://handler-control"); + assert (rc == 0); + close_zero_linger (handler); + #ifdef ZMQ_BUILD_DRAFT_API close_zero_linger (server_mon); #endif close_zero_linger (server); - int rc = zmq_ctx_term (ctx); - assert (rc == 0); - // Wait until ZAP handler terminates zmq_threadclose (zap_thread); + rc = zmq_ctx_term (ctx); + assert (rc == 0); + zmq_atomic_counter_destroy (&zap_requests_handled); } @@ -632,7 +682,8 @@ int main (void) &server_mon, my_endpoint); test_curve_security_with_valid_credentials (ctx, my_endpoint, server, server_mon, timeout); - shutdown_context_and_server_side (ctx, zap_thread, server, server_mon); + shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, + handler); char garbage_key [] = "0000000000000000000000000000000000000000"; @@ -643,7 +694,8 @@ int main (void) &server_mon, my_endpoint); test_garbage_key (ctx, server, server_mon, my_endpoint, garbage_key, valid_client_public, valid_client_secret); - shutdown_context_and_server_side (ctx, zap_thread, server, server_mon); + shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, + handler); // Check CURVE security with a garbage client public key // This will be caught by the curve_server class, not passed to ZAP @@ -652,7 +704,8 @@ int main (void) &server_mon, my_endpoint); test_garbage_key (ctx, server, server_mon, my_endpoint, valid_server_public, garbage_key, valid_client_secret); - shutdown_context_and_server_side (ctx, zap_thread, server, server_mon); + shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, + handler); // Check CURVE security with a garbage client secret key // This will be caught by the curve_server class, not passed to ZAP @@ -661,30 +714,35 @@ int main (void) &server_mon, my_endpoint); test_garbage_key (ctx, server, server_mon, my_endpoint, valid_server_public, valid_client_public, garbage_key); - shutdown_context_and_server_side (ctx, zap_thread, server, server_mon); + shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, + handler); setup_context_and_server_side (&ctx, &handler, &zap_thread, &server, &server_mon, my_endpoint); test_curve_security_with_bogus_client_credentials (ctx, my_endpoint, server, server_mon, timeout); - shutdown_context_and_server_side (ctx, zap_thread, server, server_mon); + shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, + handler); setup_context_and_server_side (&ctx, &handler, &zap_thread, &server, &server_mon, my_endpoint); test_curve_security_with_null_client_credentials (ctx, my_endpoint, server, server_mon); - shutdown_context_and_server_side (ctx, zap_thread, server, server_mon); + shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, + handler); setup_context_and_server_side (&ctx, &handler, &zap_thread, &server, &server_mon, my_endpoint); test_curve_security_with_plain_client_credentials (ctx, my_endpoint, server, server_mon); - shutdown_context_and_server_side (ctx, zap_thread, server, server_mon); + shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, + handler); setup_context_and_server_side (&ctx, &handler, &zap_thread, &server, &server_mon, my_endpoint); test_curve_security_unauthenticated_message (my_endpoint, server, timeout); - shutdown_context_and_server_side (ctx, zap_thread, server, server_mon); + shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, + handler); // Invalid ZAP protocol tests @@ -694,7 +752,8 @@ int main (void) &zap_handler_wrong_version); test_curve_security_zap_protocol_error (ctx, my_endpoint, server, server_mon); - shutdown_context_and_server_side (ctx, zap_thread, server, server_mon); + shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, + handler); // wrong request id setup_context_and_server_side (&ctx, &handler, &zap_thread, &server, @@ -702,7 +761,8 @@ int main (void) &zap_handler_wrong_request_id); test_curve_security_zap_protocol_error (ctx, my_endpoint, server, server_mon); - shutdown_context_and_server_side (ctx, zap_thread, server, server_mon); + shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, + handler); // status invalid (not a 3-digit number) setup_context_and_server_side (&ctx, &handler, &zap_thread, &server, @@ -710,7 +770,8 @@ int main (void) &zap_handler_wrong_status_invalid); test_curve_security_zap_protocol_error (ctx, my_endpoint, server, server_mon); - shutdown_context_and_server_side (ctx, zap_thread, server, server_mon); + shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, + handler); // too many parts setup_context_and_server_side (&ctx, &handler, &zap_thread, &server, @@ -718,7 +779,8 @@ int main (void) &zap_handler_too_many_parts); test_curve_security_zap_protocol_error (ctx, my_endpoint, server, server_mon); - shutdown_context_and_server_side (ctx, zap_thread, server, server_mon); + shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, + handler); // ZAP non-standard cases @@ -737,7 +799,8 @@ int main (void) 0, 0 #endif ); - shutdown_context_and_server_side (ctx, zap_thread, server, server_mon); + shutdown_context_and_server_side (ctx, zap_thread, server, server_mon, + handler); ctx = zmq_ctx_new (); test_curve_security_invalid_keysize (ctx);