From 449d2b008fd6a2adb7ba88a983ae131bca86bc59 Mon Sep 17 00:00:00 2001 From: tqcq Date: Sat, 26 Oct 2024 12:54:02 +0000 Subject: [PATCH] feat: update --- drama/src/rev-mc.c | 56 +++++++++++++++++----------------------------- 1 file changed, 21 insertions(+), 35 deletions(-) diff --git a/drama/src/rev-mc.c b/drama/src/rev-mc.c index 27d7bb4..0823a3f 100644 --- a/drama/src/rev-mc.c +++ b/drama/src/rev-mc.c @@ -346,58 +346,44 @@ find_row_mask(std::vector &sets, for (int i = 0; i < thread_num; ++i) { workers.emplace_back([&] { while (!found) { - uint64_t cur_mask = 0; - uint64_t step_count = 0; - uint64_t my_end_pos = 0; + uint64_t cur_mask = 0; + uint64_t diff = 0; + uint64_t end_pos = 0; { std::lock_guard _(lock); - if (row_mask >= last_mask || found) { return; } - step_count = g_cur_pos - g_base_pos; - - cur_mask = row_mask; - my_end_pos = g_cur_pos.fetch_add(outer_step); + cur_mask = row_mask; + diff = g_cur_pos - g_base_pos; + end_pos = g_cur_pos.fetch_add(outer_step); } - while (cur_mask < last_mask && step_count > 0 && !found.load(std::memory_order_relaxed)) { - for (int i = std::min(step_count, inner_step); i > 0 && cur_mask < last_mask; --i) { + + while (diff > 0 && cur_mask < last_mask) { + for (int i = std::min(inner_step, diff); i > 0 && cur_mask < last_mask; --i) { cur_mask = next_bit_permutation(cur_mask); - if (cnt.fetch_add(1) % 10000000 == 0) { - fprintf(stderr, "cnt : %ld, step_count: %5ld, base_pos: %5ld progress: %ld\n", - cnt.load(std::memory_order_relaxed), step_count, g_base_pos.load(), - progress.load()); - } } + diff -= std::min(inner_step, diff); - if (found) { break; } - - step_count -= std::min(step_count, inner_step); - - // my is new - if (my_end_pos > g_base_pos.load(std::memory_order_relaxed) + step_count) { - std::lock_guard _(lock); - if (my_end_pos > g_base_pos + step_count) { - g_base_pos = my_end_pos - step_count; - row_mask = cur_mask; - } - } else if (g_base_pos.load(std::memory_order_relaxed) <= my_end_pos) { - std::lock_guard _(lock); - if (g_base_pos <= my_end_pos) { - cur_mask = row_mask; - step_count = my_end_pos - g_base_pos; - } + std::lock_guard _(lock); + if (end_pos - diff > g_base_pos) { + if (found) { break; } + g_base_pos = end_pos - diff; + row_mask = cur_mask; + } else if (end_pos >= g_base_pos) { + diff = end_pos - g_base_pos; + cur_mask = row_mask; } } - for (int i = outer_step; i > 0 && cur_mask < last_mask && !found.load(std::memory_order_relaxed); --i) { + for (int i = outer_step; i > 0 && cur_mask < last_mask; --i) { if (resolve(cur_mask)) { bool expected = false; if (found.compare_exchange_strong(expected, true)) { std::lock_guard _(lock); + found.store(true); row_mask = cur_mask; - break; } } + cur_mask = next_bit_permutation(cur_mask); } - progress.fetch_add(outer_step); } }); }