diff --git a/drama/src/rev-mc.c b/drama/src/rev-mc.c index 2e04d8c..96b8f1f 100644 --- a/drama/src/rev-mc.c +++ b/drama/src/rev-mc.c @@ -320,9 +320,6 @@ find_row_mask(std::vector &sets, return true; }; - std::atomic_bool found{false}; - std::mutex lock; - // while (row_mask < last_mask) { // row_mask = next_bit_permutation(row_mask); // } @@ -332,14 +329,17 @@ find_row_mask(std::vector &sets, if (cpu == 0 && thread_num > 4) { thread_num -= 4; } fprintf(stderr, "thread_num: %d\n", thread_num); - uint64_t step = 1000 * 10; + std::atomic_bool found{false}; + std::mutex lock; + uint64_t inner_step = 1000; + uint64_t outer_step = 1000 * 10; std::atomic g_cur_pos{0}; std::atomic g_base_pos{0}; // for (uint64_t i = row_mask; i < last_mask; ++step) { i = next_bit_permutation(i); } // fprintf(stderr, "total_step: %ld\n", step); // step /= thread_num; // if (!step) { ++step; } - fprintf(stderr, "worker_step: %ld\n", step); + fprintf(stderr, "worker_step: %ld\n", outer_step); std::atomic cnt{0}; std::atomic progress{0}; @@ -348,17 +348,17 @@ find_row_mask(std::vector &sets, while (!found) { uint64_t cur_mask = 0; uint64_t step_count = 0; - uint64_t my_pos = 0; + uint64_t my_end_pos = 0; { std::lock_guard _(lock); if (row_mask >= last_mask || found) { return; } step_count = g_cur_pos - g_base_pos; cur_mask = row_mask; - my_pos = g_cur_pos.fetch_add(step); + my_end_pos = g_cur_pos.fetch_add(outer_step); } while (cur_mask < last_mask && step_count > 0 && !found.load(std::memory_order_relaxed)) { - for (int i = std::min(step_count, step / 10); i > 0 && cur_mask < last_mask; --i) { + for (int i = std::min(step_count, inner_step); i > 0 && cur_mask < last_mask; --i) { cur_mask = next_bit_permutation(cur_mask); if (cnt.fetch_add(1) % 10000000 == 0) { fprintf(stderr, "cnt : %ld, step_count: %5ld, base_pos: %5ld progress: %ld\n", @@ -369,25 +369,25 @@ find_row_mask(std::vector &sets, if (found) { break; } - step_count -= std::min(step_count, step / 10); + step_count -= std::min(step_count, inner_step); // my is new - if (my_pos > g_base_pos.load(std::memory_order_relaxed) + step_count) { + if (my_end_pos > g_base_pos.load(std::memory_order_relaxed) + step_count) { std::lock_guard _(lock); - if (my_pos > g_base_pos + step_count) { - g_base_pos = my_pos; + if (my_end_pos > g_base_pos + step_count) { + g_base_pos = my_end_pos; row_mask = cur_mask; } - } else if (g_base_pos.load(std::memory_order_relaxed) <= my_pos) { + } else if (g_base_pos.load(std::memory_order_relaxed) <= my_end_pos) { std::lock_guard _(lock); - if (g_base_pos <= my_pos) { + if (g_base_pos <= my_end_pos) { cur_mask = row_mask; - step_count = my_pos - g_base_pos; + step_count = my_end_pos - g_base_pos; } } } - for (int i = step; i > 0 && cur_mask < last_mask && !found.load(std::memory_order_relaxed); --i) { + for (int i = outer_step; i > 0 && cur_mask < last_mask && !found.load(std::memory_order_relaxed); --i) { if (resolve(cur_mask)) { bool expected = false; if (found.compare_exchange_strong(expected, true)) { @@ -397,7 +397,7 @@ find_row_mask(std::vector &sets, } } } - progress.fetch_add(step); + progress.fetch_add(outer_step); } }); }