0
0
mirror of https://github.com/zeromq/libzmq.git synced 2025-01-21 15:12:03 +08:00
libzmq/builds/msvc/vs2015_xp/test_zmq/test_multithread.cpp

230 lines
7.1 KiB
C++

/*
The server thread listens on a ZMQ_SERVER socket and forwards each incoming
message to the worker threads through a ZMQ_PUSH-ZMQ_PULL queue.
Each worker thread receives a message from the queue and sends it back
through the ZMQ_SERVER socket.
Each client thread opens CLIENT_CONNECTION ZMQ_CLIENT sockets,
sends a random-size message on each socket and checks the server's answer.
*/
#define ZMQ_BUILD_DRAFT_API
#include "../../../../include/zmq.h"
#pragma comment(lib,"libzmq.lib")
#include <assert.h>
#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <atomic>
#include <thread>
// Test configuration. Typed compile-time constants instead of macros; the
// names are unchanged so every use site (array bounds, printf arguments,
// zmq_bind/zmq_connect endpoints) keeps working.
constexpr char SERVER_ADDR[] = "tcp://127.0.0.1:12345"; // endpoint the server binds and clients connect to
constexpr int SERVER_WORKER_COUNT = 3;  // worker threads count
constexpr int CLIENT_COUNT = 5;         // client threads count
constexpr int CLIENT_CONNECTION = 100;  // ZMQ_CLIENT sockets at each client
// Send/receive rounds before a client closes and reopens one random socket.
// (Spelling "RECCONECT" kept for compatibility with existing uses.)
constexpr int CLIENT_RECCONECT = 1000;
constexpr int MESSAGE_MAX_SIZE = 1024;  // maximum random payload length in bytes
//*******************************************************************
//****** MESSAGE ****************************************************
//*******************************************************************
// Build a self-checking test message of (2 * sizeof(int) + val) bytes:
//   [int val][val bytes, each equal to (val & 0xFF)][int checksum]
// where checksum = val + val * (val & 0xFF), i.e. val plus the byte sum of
// the payload; message_check() verifies this layout.
// `msg` is initialized here (zmq_msg_init_size); `val` must be > 0.
void message_fill(zmq_msg_t* msg, int val) {
    assert(val > 0);
    const int fill = val & 0xFF;
    const int check_sum = val + fill * val;
    const int total = sizeof(int) * 2 + val;

    const int rc = zmq_msg_init_size(msg, total);
    assert(rc == 0);
    (void)rc; // only inspected by assert

    uint8_t* const base = static_cast<uint8_t*>(zmq_msg_data(msg));
    memcpy(base, &val, sizeof(int));                          // header: payload length
    memset(base + sizeof(int), fill, val);                    // payload bytes
    memcpy(base + sizeof(int) + val, &check_sum, sizeof(int)); // trailer: checksum
}
// Validate a message built by message_fill() and return its payload length
// `val`, or -1 (after printing a diagnostic to stderr) if it is malformed.
// Expected layout:
//   [int val][val payload bytes][int checksum]   checksum = val + sum(payload bytes)
// The minimum-size check is done at runtime rather than only via assert: the
// data arrives over the network, and with NDEBUG an assert-only guard would
// let a short message be read out of bounds by the memcpy calls below.
int message_check(zmq_msg_t* msg) {
    const uint8_t* data = static_cast<const uint8_t*>(zmq_msg_data(msg));
    const size_t size = zmq_msg_size(msg);
    const size_t overhead = sizeof(int) * 2; // leading length + trailing checksum
    if (size <= overhead) {
        fprintf(stderr, "wrong message: size = %d\n", static_cast<int>(size));
        return -1;
    }
    // check size: the leading int must match the payload length exactly
    int val;
    memcpy(&val, data, sizeof(int));
    if (val <= 0 || size - overhead != static_cast<size_t>(val)) {
        fprintf(stderr, "wrong message: val = %d size = %d\n", val, static_cast<int>(size));
        return -1;
    }
    // check sum
    data += sizeof(int);
    int cs = val;
    for (int i = 0; i < val; i++) {
        cs += data[i];
    }
    int check_sum;
    memcpy(&check_sum, data + val, sizeof(int));
    if (check_sum != cs) {
        fprintf(stderr, "wrong message: cs = %d check_sum = %d\n", cs, check_sum);
        return -1;
    }
    return val;
}
//*******************************************************************
//****** SERVER *****************************************************
//*******************************************************************
// Server state shared with the worker threads: server() creates the context
// and the ZMQ_SERVER socket; workers use server_ctx for their queue socket and
// server_sock to send replies.
void *server_ctx = nullptr;
void *server_sock = nullptr;
// Per-worker count of messages sent back to clients (statistic read by main()).
// `= {}` value-initializes each element; `= {0}` would copy-initialize an
// atomic from an int, which needs the deleted copy constructor before C++17.
std::atomic<int> worker_cnt[SERVER_WORKER_COUNT] = {};
// worker thread
void worker(int num) {
printf("worker %d start\n", num);
void* queue = zmq_socket(server_ctx, ZMQ_PULL); assert(queue);
int rc = zmq_connect(queue, "inproc://queue"); assert(rc == 0);
while (1) {
// receive messages from the queue
zmq_msg_t msg;
rc = zmq_msg_init(&msg); assert(rc == 0);
rc = zmq_msg_recv(&msg, queue, 0); assert(rc > 0);
// check message
//printf("worker %d recv %d bytes at %X from %X\n", num, zmq_msg_size(&msg), zmq_msg_data(&msg), zmq_msg_routing_id(&msg));
// send to client
rc = zmq_msg_send(&msg, server_sock, 0); assert(rc != -1);
worker_cnt[num]++;
}
zmq_close(queue);
}
// Server thread: creates the shared context, binds the inproc PUSH queue and
// the ZMQ_SERVER socket, starts SERVER_WORKER_COUNT workers, then forwards
// every client message to the queue forever (never returns, so the worker
// threads are never joined).
void server() {
    server_ctx = zmq_ctx_new(); assert(server_ctx);
    // create queue (bound before the workers start, since they connect to it)
    void* queue = zmq_socket(server_ctx, ZMQ_PUSH); assert(queue);
    int rc = zmq_bind(queue, "inproc://queue"); assert(rc == 0);
    // ZMQ_SERVER for client messages. Created BEFORE the workers are started:
    // workers read the global server_sock, and the completion of the
    // std::thread constructor synchronizes-with the start of the thread, so
    // the write below is visible to them without a data race.
    server_sock = zmq_socket(server_ctx, ZMQ_SERVER); assert(server_sock);
    rc = zmq_bind(server_sock, SERVER_ADDR); assert(rc == 0);
    // start workers
    std::thread w[SERVER_WORKER_COUNT];
    for (int i = 0; i < SERVER_WORKER_COUNT; i++) w[i] = std::thread(worker, i);
    while (1) {
        // wait client message
        zmq_msg_t msg;
        rc = zmq_msg_init(&msg); assert(rc == 0);
        rc = zmq_msg_recv(&msg, server_sock, 0); assert(rc > 0);
        // send message to queue (ownership passes to the next worker)
        rc = zmq_msg_send(&msg, queue, 0); assert(rc > 0);
    }
}
//*******************************************************************
//****** CLIENT *****************************************************
//*******************************************************************
// Per-client count of received replies (statistic read by main()).
// `= {}` value-initializes each element without copying an atomic.
std::atomic<int> client_cnt[CLIENT_COUNT] = {};
// Number of client threads that have opened and tested all their sockets.
// Brace-initialized: `std::atomic<int> x = 0;` requires the deleted copy
// constructor before C++17.
std::atomic<int> client_ready{0};
// Client thread `num`: opens CLIENT_CONNECTION ZMQ_CLIENT sockets on its own
// context and verifies each with one echo round-trip, then waits until all
// client threads are ready. After that it loops forever:
//   - send one random-size message on every socket;
//   - receive and verify every reply, counting them in client_cnt[num];
//   - every CLIENT_RECCONECT rounds, replace one random socket with a new one.
void client(int num)
{
    void *ctx = zmq_ctx_new(); assert(ctx);
    void *sock[CLIENT_CONNECTION];
    int rc;
    // open ZMQ_CLIENT connections, testing each with one round-trip
    for (int i = 0; i < CLIENT_CONNECTION; i++) {
        sock[i] = zmq_socket(ctx, ZMQ_CLIENT); assert(sock[i]);
        rc = zmq_connect(sock[i], SERVER_ADDR); assert(rc == 0);
        zmq_msg_t msg;
        int v = rand() % 256 + 1;
        message_fill(&msg, v);
        rc = zmq_msg_send(&msg, sock[i], 0); assert(rc > 0);
        rc = zmq_msg_init(&msg); assert(rc == 0);
        rc = zmq_msg_recv(&msg, sock[i], 0); assert(rc > 0);
        rc = message_check(&msg); assert(rc == v);
        zmq_msg_close(&msg);
    }
    printf("client %d open %d connections\n", num, CLIENT_CONNECTION);
    client_ready++;
    while (client_ready < CLIENT_COUNT) Sleep(10); // wait while all clients open sockets
    int rounds_since_reconnect = 0;
    int val[CLIENT_CONNECTION];
    while (1) {
        zmq_msg_t msg;
        // send messages
        for (int i = 0; i < CLIENT_CONNECTION; i++) {
            val[i] = rand() % MESSAGE_MAX_SIZE + 1;
            message_fill(&msg, val[i]);
            rc = zmq_msg_send(&msg, sock[i], 0); assert(rc > 0);
        }
        // recv and check
        for (int i = 0; i < CLIENT_CONNECTION; i++) {
            rc = zmq_msg_init(&msg); assert(rc == 0);
            rc = zmq_msg_recv(&msg, sock[i], 0); assert(rc > 0);
            rc = message_check(&msg);
            if (rc != val[i] && rc > 0) {
                fprintf(stderr, "wrong message: send %d recv %d \n", val[i], rc);
            }
            zmq_msg_close(&msg);
            client_cnt[num]++;
        }
        // Reconnect one socket every CLIENT_RECCONECT rounds. The counter is
        // reset so this repeats instead of firing only once. All replies of
        // this round were received above, so nothing is in flight on sock[n].
        if (++rounds_since_reconnect >= CLIENT_RECCONECT) {
            rounds_since_reconnect = 0;
            int n = rand() % CLIENT_CONNECTION;
            zmq_close(sock[n]);
            sock[n] = zmq_socket(ctx, ZMQ_CLIENT); assert(sock[n]);
            rc = zmq_connect(sock[n], SERVER_ADDR); assert(rc == 0);
        }
    }
}
//*******************************************************************
// Starts CLIENT_COUNT client threads and the server thread. Once all clients
// are ready it checks every second that each worker and client counter
// advanced (reporting stalled ones on stderr) and prints the running total of
// client replies and the count over the last interval. Loops forever; the
// threads are never joined.
int main (void) {
    int v1, v2, v3; zmq_version(&v1, &v2, &v3);
    printf("ZMQ version %d.%d.%d. Compile %s %s\n", v1, v2, v3, __DATE__, __TIME__);
    std::thread ct[CLIENT_COUNT];
    for (int i = 0; i < CLIENT_COUNT; i++) ct[i] = std::thread(client, i);
    std::thread st(server);
    int w[SERVER_WORKER_COUNT] = { 0 };
    int c[CLIENT_COUNT] = { 0 };
    int total = 0;
    while (1) {
        Sleep(1000);
        if (client_ready < CLIENT_COUNT) continue;
        // check workers: load each atomic once so the comparison and the
        // stored snapshot refer to the same value
        for (int i = 0; i < SERVER_WORKER_COUNT; i++) {
            const int now = worker_cnt[i].load();
            if (w[i] == now) {
                fprintf(stderr, "worker %d not work \n", i);
            }
            w[i] = now;
        }
        // check clients
        int t = 0;
        for (int i = 0; i < CLIENT_COUNT; i++) {
            const int now = client_cnt[i].load();
            if (c[i] == now) {
                fprintf(stderr, "client %d not work \n", i);
            }
            c[i] = now;
            t += now;
        }
        printf("\rTotal %d messages. Speed %d per second ", t, t - total);
        // The line ends with '\r', not '\n', so a line-buffered or fully
        // buffered stdout would not show it without an explicit flush.
        fflush(stdout);
        total = t;
    }
    return 0;
}