From 8a74c8e78f20c6fce9dcf525337f17529ef53cae Mon Sep 17 00:00:00 2001 From: tqcq Date: Sat, 26 Oct 2024 13:12:24 +0000 Subject: [PATCH] feat: update --- drama/src/rev-mc.c | 55 +++++++++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/drama/src/rev-mc.c b/drama/src/rev-mc.c index ed00035..9ac0afa 100644 --- a/drama/src/rev-mc.c +++ b/drama/src/rev-mc.c @@ -340,6 +340,35 @@ find_row_mask(std::vector &sets, // step /= thread_num; // if (!step) { ++step; } fprintf(stderr, "worker_step: %ld\n", outer_step); + auto try_update_row_mask = [&](uint64_t pos, uint64_t mask) { + if (pos > g_base_pos.load(std::memory_order_relaxed)) { + std::lock_guard _(lock); + if (pos > g_base_pos.load()) { + g_base_pos = row_mask; + row_mask = mask; + } + } + }; + auto try_get_row_mask = [&](uint64_t end_pos, uint64_t &diff, uint64_t &cur_mask) { + uint64_t v = g_base_pos.load(std::memory_order_relaxed); + if (end_pos - diff < v && end_pos >= v) { + std::lock_guard _(lock); + v = g_base_pos.load(); + if (end_pos - diff < v && end_pos >= v) { + diff = end_pos - v; + cur_mask = row_mask; + } + } + }; + + auto set_found = [&](uint64_t cur_mask) { + bool expected = false; + if (found.compare_exchange_strong(expected, true)) { + std::lock_guard _(lock); + found.store(true); + row_mask = cur_mask; + } + }; std::atomic cnt{0}; std::atomic progress{0}; @@ -361,32 +390,12 @@ find_row_mask(std::vector &sets, cur_mask = next_bit_permutation(cur_mask); } diff -= std::min(inner_step, diff); - - if (end_pos - diff > g_base_pos) { - std::lock_guard _(lock); - if (found) { break; } - if (end_pos - diff > g_base_pos) { - g_base_pos = end_pos - diff; - row_mask = cur_mask; - } - } else if (end_pos >= g_base_pos) { - std::lock_guard _(lock); - if (end_pos >= g_base_pos) { - diff = end_pos - g_base_pos; - cur_mask = row_mask; - } - } + try_update_row_mask(end_pos - diff, cur_mask); + try_get_row_mask(end_pos, diff, cur_mask); } for (int i = outer_step; i > 0 && cur_mask < last_mask; --i) { - if (resolve(cur_mask) && !found) { - bool expected = false; - if (found.compare_exchange_strong(expected, true)) { - std::lock_guard _(lock); - found.store(true); - row_mask = cur_mask; - } - } + if (resolve(cur_mask)) { set_found(cur_mask); } cur_mask = next_bit_permutation(cur_mask); } }