diff --git a/include/zmq.h b/include/zmq.h index 5dc11461..8329b37e 100644 --- a/include/zmq.h +++ b/include/zmq.h @@ -147,6 +147,7 @@ ZMQ_EXPORT int zmq_getmsgopt (zmq_msg_t *msg, int option, void *optval, /******************************************************************************/ ZMQ_EXPORT void *zmq_init (int io_threads); +ZMQ_EXPORT void *zmq_init_thread_safe (int io_threads); ZMQ_EXPORT int zmq_term (void *context); /******************************************************************************/ diff --git a/src/ctx.cpp b/src/ctx.cpp index d8783be8..44fd0465 100644 --- a/src/ctx.cpp +++ b/src/ctx.cpp @@ -81,6 +81,16 @@ zmq::ctx_t::ctx_t (uint32_t io_threads_) : zmq_assert (rc == 0); } +void zmq::ctx_t::set_thread_safe() +{ + thread_safe_flag = true; +} + +bool zmq::ctx_t::get_thread_safe() const +{ + return thread_safe_flag; +} + bool zmq::ctx_t::check_tag () { return tag == 0xbadcafe0; diff --git a/src/ctx.hpp b/src/ctx.hpp index d259f594..40b4acff 100644 --- a/src/ctx.hpp +++ b/src/ctx.hpp @@ -99,6 +99,10 @@ namespace zmq reaper_tid = 1 }; + // create thread safe sockets + void set_thread_safe(); + bool get_thread_safe() const; + ~ctx_t (); private: @@ -151,6 +155,8 @@ namespace zmq zmq::socket_base_t *log_socket; mutex_t log_sync; + bool thread_safe_flag; + ctx_t (const ctx_t&); const ctx_t &operator = (const ctx_t&); }; diff --git a/src/socket_base.cpp b/src/socket_base.cpp index 8167786d..7cd24c66 100644 --- a/src/socket_base.cpp +++ b/src/socket_base.cpp @@ -874,6 +874,11 @@ void zmq::socket_base_t::extract_flags (msg_t *msg_) rcvmore = msg_->flags () & msg_t::more ? true : false; } +void zmq::socket_base_t::set_thread_safe() +{ + thread_safe_flag = true; +} + void zmq::socket_base_t::lock() { sync.lock(); diff --git a/src/socket_base.hpp b/src/socket_base.hpp index 3df50802..f22e1ef0 100644 --- a/src/socket_base.hpp +++ b/src/socket_base.hpp @@ -96,6 +96,7 @@ namespace zmq void hiccuped (pipe_t *pipe_); void terminated (pipe_t *pipe_); bool thread_safe() const { return thread_safe_flag; } + void set_thread_safe(); // should be in constructor, here for compat void lock(); void unlock(); protected: diff --git a/src/zmq.cpp b/src/zmq.cpp index 9aca8f9d..777cbc3f 100644 --- a/src/zmq.cpp +++ b/src/zmq.cpp @@ -90,7 +90,7 @@ const char *zmq_strerror (int errnum_) return zmq::errno_to_string (errnum_); } -void *zmq_init (int io_threads_) +static zmq::ctx_t *inner_init (int io_threads_) { if (io_threads_ < 0) { errno = EINVAL; @@ -139,7 +139,19 @@ void *zmq_init (int io_threads_) // Create 0MQ context. zmq::ctx_t *ctx = new (std::nothrow) zmq::ctx_t ((uint32_t) io_threads_); alloc_assert (ctx); - return (void*) ctx; + return ctx; +} + +void *zmq_init (int io_threads_) +{ + return (void*) inner_init (io_threads_); +} + +void *zmq_init_thread_safe (int io_threads_) +{ + zmq::ctx_t *ctx = inner_init (io_threads_); + ctx->set_thread_safe(); + return (void*) ctx; } int zmq_term (void *ctx_) @@ -174,7 +186,10 @@ void *zmq_socket (void *ctx_, int type_) errno = EFAULT; return NULL; } - return (void*) (((zmq::ctx_t*) ctx_)->create_socket (type_)); + zmq::ctx_t *ctx = (zmq::ctx_t*) ctx_; + zmq::socket_base_t *s = ctx->create_socket (type_); + if (ctx->get_thread_safe ()) s->set_thread_safe (); + return (void*) s; } int zmq_close (void *s_)