// Copyright 2019 The Marl Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "osfiber.h"  // Must come first. See osfiber_ucontext.h.

#include "marl/scheduler.h"

#include "marl/debug.h"
#include "marl/thread.h"
#include "marl/trace.h"

#if defined(_WIN32)
#include <intrin.h>  // __nop()
#endif

// Enable to trace scheduler events.
#define ENABLE_TRACE_EVENTS 0

// Enable to print verbose debug logging.
#define ENABLE_DEBUG_LOGGING 0

#if ENABLE_TRACE_EVENTS
#define TRACE(...) MARL_SCOPED_EVENT(__VA_ARGS__)
#else
#define TRACE(...)
#endif

#if ENABLE_DEBUG_LOGGING
#define DBG_LOG(msg, ...) \
  printf("%.3x " msg "\n", (int)threadID() & 0xfff, __VA_ARGS__)
#else
#define DBG_LOG(msg, ...)
#endif

#define ASSERT_FIBER_STATE(FIBER, STATE)                                   \
  MARL_ASSERT(FIBER->state == STATE,                                       \
              "fiber %d was in state %s, but expected %s", (int)FIBER->id, \
              Fiber::toString(FIBER->state), Fiber::toString(STATE))

namespace {

#if ENABLE_DEBUG_LOGGING
// threadID() returns a uint64_t representing the currently executing thread.
// threadID() is only intended to be used for debugging purposes.
inline uint64_t threadID() {
  auto id = std::this_thread::get_id();
  return std::hash<std::thread::id>()(id);
}
#endif

inline void nop() {
#if defined(_WIN32)
  __nop();
#else
  __asm__ __volatile__("nop");
#endif
}

inline marl::Scheduler::Config setConfigDefaults(
    const marl::Scheduler::Config& cfgIn) {
  marl::Scheduler::Config cfg{cfgIn};
  if (cfg.workerThread.count > 0 && !cfg.workerThread.affinityPolicy) {
    cfg.workerThread.affinityPolicy = marl::Thread::Affinity::Policy::anyOf(
        marl::Thread::Affinity::all(cfg.allocator), cfg.allocator);
  }
  return cfg;
}

}  // anonymous namespace

namespace marl {

////////////////////////////////////////////////////////////////////////////////
// Scheduler
////////////////////////////////////////////////////////////////////////////////
MARL_INSTANTIATE_THREAD_LOCAL(Scheduler*, Scheduler::bound, nullptr);

Scheduler* Scheduler::get() {
  return bound;
}

void Scheduler::setBound(Scheduler* scheduler) {
  bound = scheduler;
}

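// Typical embedder usage (illustrative sketch based on the marl README, not
// code in this file):
//
//   marl::Scheduler scheduler(marl::Scheduler::Config::allCores());
//   scheduler.bind();
//   defer(scheduler.unbind());  // from marl/defer.h
//
// bind() associates this scheduler with the calling thread; enqueueing tasks
// or blocking on marl primitives from a thread requires a bound scheduler.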
void Scheduler::bind() {
  MARL_ASSERT(get() == nullptr, "Scheduler already bound");
  setBound(this);
  {
    marl::lock lock(singleThreadedWorkers.mutex);
    auto worker = cfg.allocator->make_unique<Worker>(
        this, Worker::Mode::SingleThreaded, -1);
    worker->start();
    auto tid = std::this_thread::get_id();
    singleThreadedWorkers.byTid.emplace(tid, std::move(worker));
  }
}

void Scheduler::unbind() {
  MARL_ASSERT(get() != nullptr, "No scheduler bound");
  auto worker = Worker::getCurrent();
  worker->stop();
  {
    marl::lock lock(get()->singleThreadedWorkers.mutex);
    auto tid = std::this_thread::get_id();
    auto& workers = get()->singleThreadedWorkers.byTid;
    auto it = workers.find(tid);
    MARL_ASSERT(it != workers.end(), "singleThreadedWorker not found");
    MARL_ASSERT(it->second.get() == worker, "worker is not bound?");
    workers.erase(it);
    if (workers.empty()) {
      get()->singleThreadedWorkers.unbind.notify_one();
    }
  }
  setBound(nullptr);
}

Scheduler::Scheduler(const Config& config)
    : cfg(setConfigDefaults(config)),
      workerThreads{},
      singleThreadedWorkers(config.allocator) {
  for (int i = 0; i < cfg.workerThread.count; i++) {
    spinningWorkers[i] = -1;
    workerThreads[i] =
        cfg.allocator->create<Worker>(this, Worker::Mode::MultiThreaded, i);
  }
  for (int i = 0; i < cfg.workerThread.count; i++) {
    workerThreads[i]->start();
  }
}

Scheduler::~Scheduler() {
  {
    // Wait until all the single threaded workers have been unbound.
    marl::lock lock(singleThreadedWorkers.mutex);
    lock.wait(singleThreadedWorkers.unbind,
              [this]() REQUIRES(singleThreadedWorkers.mutex) {
                return singleThreadedWorkers.byTid.empty();
              });
  }

  // Release all worker threads.
  // This will wait for all in-flight tasks to complete before returning.
  for (int i = cfg.workerThread.count - 1; i >= 0; i--) {
    workerThreads[i]->stop();
  }
  for (int i = cfg.workerThread.count - 1; i >= 0; i--) {
    cfg.allocator->destroy(workerThreads[i]);
  }
}

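// enqueue() hands a task to a worker. SameThread tasks always go to the
// calling thread's worker. Otherwise a worker that recently started spinning
// for work is preferred (its queue is likely empty and its thread awake),
// falling back to plain round-robin, retrying until a worker's queue lock can
// be taken without blocking.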
void Scheduler::enqueue(Task&& task) {
  if (task.is(Task::Flags::SameThread)) {
    Worker::getCurrent()->enqueue(std::move(task));
    return;
  }
  if (cfg.workerThread.count > 0) {
    while (true) {
      // Prioritize workers that have recently started spinning.
      auto i = --nextSpinningWorkerIdx % cfg.workerThread.count;
      auto idx = spinningWorkers[i].exchange(-1);
      if (idx < 0) {
        // If a spinning worker couldn't be found, round-robin the workers.
        idx = nextEnqueueIndex++ % cfg.workerThread.count;
      }

      auto worker = workerThreads[idx];
      if (worker->tryLock()) {
        worker->enqueueAndUnlock(std::move(task));
        return;
      }
    }
  } else {
    if (auto worker = Worker::getCurrent()) {
      worker->enqueue(std::move(task));
    } else {
      MARL_FATAL(
          "singleThreadedWorker not found. Did you forget to call "
          "marl::Scheduler::bind()?");
    }
  }
}

const Scheduler::Config& Scheduler::config() const {
  return cfg;
}

bool Scheduler::stealWork(Worker* thief, uint64_t from, Task& out) {
  if (cfg.workerThread.count > 0) {
    auto thread = workerThreads[from % cfg.workerThread.count];
    if (thread != thief) {
      if (thread->steal(out)) {
        return true;
      }
    }
  }
  return false;
}

void Scheduler::onBeginSpinning(int workerId) {
  auto idx = nextSpinningWorkerIdx++ % cfg.workerThread.count;
  spinningWorkers[idx] = workerId;
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Config
////////////////////////////////////////////////////////////////////////////////
Scheduler::Config Scheduler::Config::allCores() {
  return Config().setWorkerThreadCount(Thread::numLogicalCPUs());
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Fiber
////////////////////////////////////////////////////////////////////////////////
Scheduler::Fiber::Fiber(Allocator::unique_ptr<OSFiber>&& impl, uint32_t id)
    : id(id), impl(std::move(impl)), worker(Worker::getCurrent()) {
  MARL_ASSERT(worker != nullptr, "No Scheduler::Worker bound");
}

Scheduler::Fiber* Scheduler::Fiber::current() {
  auto worker = Worker::getCurrent();
  return worker != nullptr ? worker->getCurrentFiber() : nullptr;
}

void Scheduler::Fiber::notify() {
  worker->enqueue(this);
}

void Scheduler::Fiber::wait(marl::lock& lock, const Predicate& pred) {
  MARL_ASSERT(worker == Worker::getCurrent(),
              "Scheduler::Fiber::wait() must only be called on the currently "
              "executing fiber");
  worker->wait(lock, nullptr, pred);
}

void Scheduler::Fiber::switchTo(Fiber* to) {
  MARL_ASSERT(worker == Worker::getCurrent(),
              "Scheduler::Fiber::switchTo() must only be called on the "
              "currently executing fiber");
  if (to != this) {
    impl->switchTo(to->impl.get());
  }
}

Allocator::unique_ptr<Scheduler::Fiber> Scheduler::Fiber::create(
    Allocator* allocator,
    uint32_t id,
    size_t stackSize,
    const std::function<void()>& func) {
  return allocator->make_unique<Fiber>(
      OSFiber::createFiber(allocator, stackSize, func), id);
}

Allocator::unique_ptr<Scheduler::Fiber>
Scheduler::Fiber::createFromCurrentThread(Allocator* allocator, uint32_t id) {
  return allocator->make_unique<Fiber>(
      OSFiber::createFiberFromCurrentThread(allocator), id);
}

const char* Scheduler::Fiber::toString(State state) {
  switch (state) {
    case State::Idle:
      return "Idle";
    case State::Yielded:
      return "Yielded";
    case State::Queued:
      return "Queued";
    case State::Running:
      return "Running";
    case State::Waiting:
      return "Waiting";
  }
  MARL_ASSERT(false, "bad fiber state");
  return "<unknown>";
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::WaitingFibers
////////////////////////////////////////////////////////////////////////////////
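// WaitingFibers keeps two containers in sync: 'timeouts', a set ordered by
// (timepoint, fiber) so the earliest deadline is always at begin(), and
// 'fibers', a map from fiber to its timeout so a waiting fiber can be located
// and erased without scanning.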
Scheduler::WaitingFibers::WaitingFibers(Allocator* allocator)
    : timeouts(allocator), fibers(allocator) {}

Scheduler::WaitingFibers::operator bool() const {
  return !fibers.empty();
}

Scheduler::Fiber* Scheduler::WaitingFibers::take(const TimePoint& timeout) {
  if (!*this) {
    return nullptr;
  }
  auto it = timeouts.begin();
  if (timeout < it->timepoint) {
    return nullptr;
  }
  auto fiber = it->fiber;
  timeouts.erase(it);
  auto deleted = fibers.erase(fiber) != 0;
  (void)deleted;
  MARL_ASSERT(deleted, "WaitingFibers::take() maps out of sync");
  return fiber;
}

Scheduler::TimePoint Scheduler::WaitingFibers::next() const {
  MARL_ASSERT(*this,
              "WaitingFibers::next() called when there are no waiting fibers");
  return timeouts.begin()->timepoint;
}

void Scheduler::WaitingFibers::add(const TimePoint& timeout, Fiber* fiber) {
  timeouts.emplace(Timeout{timeout, fiber});
  bool added = fibers.emplace(fiber, timeout).second;
  (void)added;
  MARL_ASSERT(added, "WaitingFibers::add() fiber already waiting");
}

void Scheduler::WaitingFibers::erase(Fiber* fiber) {
  auto it = fibers.find(fiber);
  if (it != fibers.end()) {
    auto timeout = it->second;
    auto erased = timeouts.erase(Timeout{timeout, fiber}) != 0;
    (void)erased;
    MARL_ASSERT(erased, "WaitingFibers::erase() maps out of sync");
    fibers.erase(it);
  }
}

bool Scheduler::WaitingFibers::contains(Fiber* fiber) const {
  return fibers.count(fiber) != 0;
}

bool Scheduler::WaitingFibers::Timeout::operator<(const Timeout& o) const {
  if (timepoint != o.timepoint) {
    return timepoint < o.timepoint;
  }
  return fiber < o.fiber;
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Worker
////////////////////////////////////////////////////////////////////////////////
MARL_INSTANTIATE_THREAD_LOCAL(Scheduler::Worker*,
                              Scheduler::Worker::current,
                              nullptr);

Scheduler::Worker::Worker(Scheduler* scheduler, Mode mode, uint32_t id)
    : id(id),
      mode(mode),
      scheduler(scheduler),
      work(scheduler->cfg.allocator),
      idleFibers(scheduler->cfg.allocator) {}

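// start() brings the worker online. In MultiThreaded mode it launches a
// dedicated thread (with the configured affinity and optional initializer)
// which adopts itself as the worker's main fiber and runs the scheduling loop
// until shutdown. In SingleThreaded mode it only adopts the calling thread as
// the worker's main fiber; work is performed later, when that thread blocks in
// the scheduler or the worker is stopped.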
void Scheduler::Worker::start() {
  switch (mode) {
    case Mode::MultiThreaded: {
      auto allocator = scheduler->cfg.allocator;
      auto& affinityPolicy = scheduler->cfg.workerThread.affinityPolicy;
      auto affinity = affinityPolicy->get(id, allocator);
      thread = Thread(std::move(affinity), [=] {
        Thread::setName("Thread<%.2d>", int(id));

        if (auto const& initFunc = scheduler->cfg.workerThread.initializer) {
          initFunc(id);
        }

        Scheduler::setBound(scheduler);
        Worker::current = this;
        mainFiber = Fiber::createFromCurrentThread(scheduler->cfg.allocator, 0);
        currentFiber = mainFiber.get();
        {
          marl::lock lock(work.mutex);
          run();
        }
        mainFiber.reset();
        Worker::current = nullptr;
      });
      break;
    }
    case Mode::SingleThreaded: {
      Worker::current = this;
      mainFiber = Fiber::createFromCurrentThread(scheduler->cfg.allocator, 0);
      currentFiber = mainFiber.get();
      break;
    }
    default:
      MARL_ASSERT(false, "Unknown mode: %d", int(mode));
  }
}

void Scheduler::Worker::stop() {
  switch (mode) {
    case Mode::MultiThreaded: {
      enqueue(Task([this] { shutdown = true; }, Task::Flags::SameThread));
      thread.join();
      break;
    }
    case Mode::SingleThreaded: {
      marl::lock lock(work.mutex);
      shutdown = true;
      runUntilShutdown();
      Worker::current = nullptr;
      break;
    }
    default:
      MARL_ASSERT(false, "Unknown mode: %d", int(mode));
  }
}

bool Scheduler::Worker::wait(const TimePoint* timeout) {
  DBG_LOG("%d: WAIT(%d)", (int)id, (int)currentFiber->id);
  {
    marl::lock lock(work.mutex);
    suspend(timeout);
  }
  return timeout == nullptr || std::chrono::system_clock::now() < *timeout;
}

bool Scheduler::Worker::wait(lock& waitLock,
                             const TimePoint* timeout,
                             const Predicate& pred) {
  DBG_LOG("%d: WAIT(%d)", (int)id, (int)currentFiber->id);
  while (!pred()) {
    // Lock the work mutex to call suspend().
    work.mutex.lock();

    // Unlock the wait mutex with the work mutex lock held.
    // Order is important here as we need to ensure that the fiber is not
    // enqueued (via Fiber::notify()) between the waitLock.unlock() and fiber
    // switch, otherwise the Fiber::notify() call may be ignored and the fiber
    // is never woken.
    waitLock.unlock_no_tsa();

    // Suspend the fiber.
    suspend(timeout);

    // Fiber resumed. We don't need the work mutex locked any more.
    work.mutex.unlock();

    // Re-lock to either return due to timeout, or call pred().
    waitLock.lock_no_tsa();

    // Check timeout.
    if (timeout != nullptr && std::chrono::system_clock::now() >= *timeout) {
      return false;
    }

    // Spurious wake up. Spin again.
  }
  return true;
}

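// suspend() parks the currently executing fiber: it transitions to Waiting
// (registered in work.waiting with its deadline) when a timeout is given, or
// to Yielded otherwise, and the worker switches to another queued fiber, a
// recycled idle fiber, or a freshly created one. When the suspended fiber is
// eventually resumed, its state is restored to Running before returning.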
void Scheduler::Worker::suspend(
    const std::chrono::system_clock::time_point* timeout) {
  // Current fiber is yielding as it is blocked.
  if (timeout != nullptr) {
    changeFiberState(currentFiber, Fiber::State::Running,
                     Fiber::State::Waiting);
    work.waiting.add(*timeout, currentFiber);
  } else {
    changeFiberState(currentFiber, Fiber::State::Running,
                     Fiber::State::Yielded);
  }

  // First wait until there's something else this worker can do.
  waitForWork();

  work.numBlockedFibers++;

  if (!work.fibers.empty()) {
    // There's another fiber that has become unblocked, resume that.
    work.num--;
    auto to = containers::take(work.fibers);
    ASSERT_FIBER_STATE(to, Fiber::State::Queued);
    switchToFiber(to);
  } else if (!idleFibers.empty()) {
    // There's an old fiber we can reuse, resume that.
    auto to = containers::take(idleFibers);
    ASSERT_FIBER_STATE(to, Fiber::State::Idle);
    switchToFiber(to);
  } else {
    // Tasks to process and no existing fibers to resume.
    // Spawn a new fiber.
    switchToFiber(createWorkerFiber());
  }

  work.numBlockedFibers--;

  setFiberState(currentFiber, Fiber::State::Running);
}

bool Scheduler::Worker::tryLock() {
  return work.mutex.try_lock();
}

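// enqueue(Fiber*) makes a fiber runnable again. Fibers that are already
// Running or Queued are left untouched; Waiting fibers are first removed from
// the timeout list. The fiber is then pushed onto work.fibers in the Queued
// state, and the worker is woken only if it is currently blocked in
// Work::wait().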
void Scheduler::Worker::enqueue(Fiber* fiber) {
  bool notify = false;
  {
    marl::lock lock(work.mutex);
    DBG_LOG("%d: ENQUEUE(%d %s)", (int)id, (int)fiber->id,
            Fiber::toString(fiber->state));
    switch (fiber->state) {
      case Fiber::State::Running:
      case Fiber::State::Queued:
        return;  // Nothing to do here - task is already queued or running.
      case Fiber::State::Waiting:
        work.waiting.erase(fiber);
        break;
      case Fiber::State::Idle:
      case Fiber::State::Yielded:
        break;
    }
    notify = work.notifyAdded;
    work.fibers.push_back(fiber);
    MARL_ASSERT(!work.waiting.contains(fiber),
                "fiber is unexpectedly in the waiting list");
    setFiberState(fiber, Fiber::State::Queued);
    work.num++;
  }

  if (notify) {
    work.added.notify_one();
  }
}

void Scheduler::Worker::enqueue(Task&& task) {
  work.mutex.lock();
  enqueueAndUnlock(std::move(task));
}

void Scheduler::Worker::enqueueAndUnlock(Task&& task) {
  auto notify = work.notifyAdded;
  work.tasks.push_back(std::move(task));
  work.num++;
  work.mutex.unlock();
  if (notify) {
    work.added.notify_one();
  }
}

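// steal() attempts to take a single task from this worker without blocking.
// It fails fast if the queue looks empty, if the work mutex is contended, or
// if the next task is pinned to this worker (Task::Flags::SameThread).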
bool Scheduler::Worker::steal(Task& out) {
  if (work.num.load() == 0) {
    return false;
  }
  if (!work.mutex.try_lock()) {
    return false;
  }
  if (work.tasks.empty() || work.tasks.front().is(Task::Flags::SameThread)) {
    work.mutex.unlock();
    return false;
  }
  work.num--;
  out = containers::take(work.tasks);
  work.mutex.unlock();
  return true;
}

void Scheduler::Worker::run() {
  if (mode == Mode::MultiThreaded) {
    MARL_NAME_THREAD("Thread<%.2d> Fiber<%.2d>", int(id), Fiber::current()->id);
    // This is the entry point for a multi-threaded worker.
    // Start with a regular condition-variable wait for work. This avoids
    // starting the thread with a spinForWorkAndLock().
    work.wait([this]() REQUIRES(work.mutex) {
      return work.num > 0 || work.waiting || shutdown;
    });
  }
  ASSERT_FIBER_STATE(currentFiber, Fiber::State::Running);
  runUntilShutdown();
  switchToFiber(mainFiber.get());
}

void Scheduler::Worker::runUntilShutdown() {
  while (!shutdown || work.num > 0 || work.numBlockedFibers > 0U) {
    waitForWork();
    runUntilIdle();
  }
}

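// waitForWork() blocks until there is work to run, or until shutdown is
// requested and no fibers remain blocked. Multi-threaded workers first
// advertise themselves as spinning and briefly spin-wait (dropping the work
// mutex while they do); any waiting fibers whose timeouts have expired are
// re-queued before returning.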
void Scheduler::Worker::waitForWork() {
  MARL_ASSERT(work.num == work.fibers.size() + work.tasks.size(),
              "work.num out of sync");
  if (work.num > 0) {
    return;
  }

  if (mode == Mode::MultiThreaded) {
    scheduler->onBeginSpinning(id);
    work.mutex.unlock();
    spinForWorkAndLock();
  }

  work.wait([this]() REQUIRES(work.mutex) {
    return work.num > 0 || (shutdown && work.numBlockedFibers == 0U);
  });
  if (work.waiting) {
    enqueueFiberTimeouts();
  }
}

void Scheduler::Worker::enqueueFiberTimeouts() {
  auto now = std::chrono::system_clock::now();
  while (auto fiber = work.waiting.take(now)) {
    changeFiberState(fiber, Fiber::State::Waiting, Fiber::State::Queued);
    DBG_LOG("%d: TIMEOUT(%d)", (int)id, (int)fiber->id);
    work.fibers.push_back(fiber);
    work.num++;
  }
}

void Scheduler::Worker::changeFiberState(Fiber* fiber,
                                         Fiber::State from,
                                         Fiber::State to) const {
  (void)from;  // Unused parameter when ENABLE_DEBUG_LOGGING is disabled.
  DBG_LOG("%d: CHANGE_FIBER_STATE(%d %s -> %s)", (int)id, (int)fiber->id,
          Fiber::toString(from), Fiber::toString(to));
  ASSERT_FIBER_STATE(fiber, from);
  fiber->state = to;
}

void Scheduler::Worker::setFiberState(Fiber* fiber, Fiber::State to) const {
  DBG_LOG("%d: SET_FIBER_STATE(%d %s -> %s)", (int)id, (int)fiber->id,
          Fiber::toString(fiber->state), Fiber::toString(to));
  fiber->state = to;
}

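// spinForWorkAndLock() burns roughly one millisecond executing bursts of nops,
// polling work.num and periodically trying to steal a task from a randomly
// chosen worker, before giving up and simply taking the lock. It always
// returns with work.mutex held.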
void Scheduler::Worker::spinForWorkAndLock() {
  TRACE("SPIN");
  Task stolen;

  constexpr auto duration = std::chrono::milliseconds(1);
  auto start = std::chrono::high_resolution_clock::now();
  while (std::chrono::high_resolution_clock::now() - start < duration) {
    for (int i = 0; i < 256; i++)  // Empirically picked magic number!
    {
      // clang-format off
      nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
      nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
      nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
      nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
      // clang-format on

      if (work.num > 0) {
        work.mutex.lock();
        if (work.num > 0) {
          return;
        } else {
          // Our new task was stolen by another worker. Keep spinning.
          work.mutex.unlock();
        }
      }
    }

    if (scheduler->stealWork(this, rng(), stolen)) {
      work.mutex.lock();
      work.tasks.emplace_back(std::move(stolen));
      work.num++;
      return;
    }

    std::this_thread::yield();
  }
  work.mutex.lock();
}

void Scheduler::Worker::runUntilIdle() {
  ASSERT_FIBER_STATE(currentFiber, Fiber::State::Running);
  MARL_ASSERT(work.num == work.fibers.size() + work.tasks.size(),
              "work.num out of sync");
  while (!work.fibers.empty() || !work.tasks.empty()) {
    // Note: we cannot take and store on the stack more than a single fiber
    // or task at a time, as the Fiber may yield and these items may get
    // held on a suspended fiber's stack.

    while (!work.fibers.empty()) {
      work.num--;
      auto fiber = containers::take(work.fibers);
      // Sanity checks:
      MARL_ASSERT(idleFibers.count(fiber) == 0, "dequeued fiber is idle");
      MARL_ASSERT(fiber != currentFiber, "dequeued fiber is currently running");
      ASSERT_FIBER_STATE(fiber, Fiber::State::Queued);

      changeFiberState(currentFiber, Fiber::State::Running, Fiber::State::Idle);
      auto added = idleFibers.emplace(currentFiber).second;
      (void)added;
      MARL_ASSERT(added, "fiber already idle");

      switchToFiber(fiber);
      changeFiberState(currentFiber, Fiber::State::Idle, Fiber::State::Running);
    }

    if (!work.tasks.empty()) {
      work.num--;
      auto task = containers::take(work.tasks);
      work.mutex.unlock();

      // Run the task.
      task();

      // std::function<> can carry arguments with complex destructors.
      // Ensure these are destructed outside of the lock.
      task = Task();

      work.mutex.lock();
    }
  }
}

Scheduler::Fiber* Scheduler::Worker::createWorkerFiber() {
  auto fiberId = static_cast<uint32_t>(workerFibers.size() + 1);
  DBG_LOG("%d: CREATE(%d)", (int)id, (int)fiberId);
  auto fiber = Fiber::create(scheduler->cfg.allocator, fiberId,
                             scheduler->cfg.fiberStackSize,
                             [&]() REQUIRES(work.mutex) { run(); });
  auto ptr = fiber.get();
  workerFibers.emplace_back(std::move(fiber));
  return ptr;
}

void Scheduler::Worker::switchToFiber(Fiber* to) {
  DBG_LOG("%d: SWITCH(%d -> %d)", (int)id, (int)currentFiber->id, (int)to->id);
  MARL_ASSERT(to == mainFiber.get() || idleFibers.count(to) == 0,
              "switching to idle fiber");
  auto from = currentFiber;
  currentFiber = to;
  from->switchTo(to);
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Worker::Work
////////////////////////////////////////////////////////////////////////////////
Scheduler::Worker::Work::Work(Allocator* allocator)
    : tasks(allocator), fibers(allocator), waiting(allocator) {}

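// Work::wait() blocks on the 'added' condition variable until the predicate
// passes, using the earliest fiber timeout as a deadline when one exists.
// notifyAdded is raised for the duration of the wait so that the enqueue
// paths only pay for a notify_one() when a worker is actually blocked here.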
template <typename F>
void Scheduler::Worker::Work::wait(F&& f) {
  notifyAdded = true;
  if (waiting) {
    mutex.wait_until_locked(added, waiting.next(), f);
  } else {
    mutex.wait_locked(added, f);
  }
  notifyAdded = false;
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::SingleThreadedWorkers
////////////////////////////////////////////////////////////////////////////////
Scheduler::SingleThreadedWorkers::SingleThreadedWorkers(Allocator* allocator)
    : byTid(allocator) {}

}  // namespace marl