From a9fe57bf3830f92fda83e474c41d0c2a67f71bc5 Mon Sep 17 00:00:00 2001 From: tqcq Date: Sat, 26 Oct 2024 09:34:46 +0000 Subject: [PATCH] feat: fast update --- drama/src/rev-mc.c | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/drama/src/rev-mc.c b/drama/src/rev-mc.c index e568076..a15a491 100644 --- a/drama/src/rev-mc.c +++ b/drama/src/rev-mc.c @@ -307,9 +307,8 @@ find_row_mask(std::vector &sets, uint64_t row_mask = init_row_mask; const uint64_t last_mask = init_last_mask; row_mask <<= CL_SHIFT;// skip the lowest 6 bits since they're used for CL addressing - bool need_update_row_mask = false; - auto resolve = [=](uint64_t row_mask) -> bool { + auto resolve = [=](uint64_t row_mask) -> bool { if (row_mask & LS_BITMASK(CL_SHIFT)) { return false; } for (auto addr_pool : same_row_sets) { addr_tuple base_addr = addr_pool[0]; @@ -334,6 +333,8 @@ find_row_mask(std::vector &sets, fprintf(stderr, "thread_num: %d\n", thread_num); uint64_t step = 1000000; + std::atomic cur_pos{0}; + std::atomic base_pos{0}; // for (uint64_t i = row_mask; i < last_mask; ++step) { i = next_bit_permutation(i); } // fprintf(stderr, "total_step: %ld\n", step); // step /= thread_num; @@ -343,17 +344,29 @@ find_row_mask(std::vector &sets, for (int i = 0; i < thread_num; ++i) { workers.emplace_back([&] { while (!found) { - uint64_t cur_mask = last_mask; + uint64_t cur_mask = last_mask; + uint64_t step_count = 0; + uint64_t my_pos = 0; { std::lock_guard _(lock); if (row_mask >= last_mask || found) { break; } - if (need_update_row_mask) { - for (int i = 0; i < step && row_mask < last_mask; ++i) { - row_mask = next_bit_permutation(row_mask); - } + cur_mask = row_mask; + step_count = cur_pos - base_pos; + my_pos = cur_pos.fetch_add(1); + } + while (row_mask < last_mask && step_count > 0) { + --step_count; + for (int i = 0; i < step && row_mask < last_mask; ++i) { + row_mask = next_bit_permutation(row_mask); + } + } + // update pos + { + std::lock_guard _(lock); + if (my_pos > base_pos) { + base_pos = my_pos; + row_mask = cur_mask; } - cur_mask = row_mask; - need_update_row_mask = true; } for (int i = 0; i < step && cur_mask < last_mask; ++i) {