diff --git a/src/sled/system/thread_pool.cc b/src/sled/system/thread_pool.cc index 3c54427..e77612a 100644 --- a/src/sled/system/thread_pool.cc +++ b/src/sled/system/thread_pool.cc @@ -13,7 +13,7 @@ ThreadPool::ThreadPool(int num_threads) : delayed_thread_(sled::Thread::Create() auto state = std::make_shared(); for (int i = 0; i < num_threads; i++) { threads_.emplace_back(std::thread([state] { - state->idle++; + state->idle.fetch_add(1, std::memory_order_relaxed); while (state->is_running) { std::function task; sled::Location loc = SLED_FROM_HERE; @@ -26,10 +26,16 @@ ThreadPool::ThreadPool(int num_threads) : delayed_thread_(sled::Thread::Create() loc = state->task_queue.front().second; state->task_queue.pop(); } + // FIXME: can't exit if task add self, must check + if (!state->is_running) { + state->idle.fetch_sub(1, std::memory_order_relaxed); + break; + } + if (!state->task_queue.empty()) { state->cv.NotifyOne(); } } if (task) { - state->idle--; + state->idle.fetch_sub(1, std::memory_order_release); try { task(); } catch (const std::exception &e) { @@ -37,7 +43,7 @@ ThreadPool::ThreadPool(int num_threads) : delayed_thread_(sled::Thread::Create() } catch (...) { LOGE(kTag, "ThreadPool::ThreadPool() task unknown exception, from={}", loc.ToString()); } - state->idle++; + state->idle.fetch_add(1, std::memory_order_relaxed); } } }));