411 lines
13 KiB
C++
411 lines
13 KiB
C++
// Copyright 2020 The Marl Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// https://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
// marl::DAG<> provides an ahead of time, declarative, directed acyclic
|
|
// task graph.
|
|
|
|
#ifndef marl_dag_h
|
|
#define marl_dag_h
|
|
|
|
#include "containers.h"
|
|
#include "export.h"
|
|
#include "memory.h"
|
|
#include "scheduler.h"
|
|
#include "waitgroup.h"
|
|
|
|
namespace marl {
|
|
namespace detail {
|
|
using DAGCounter = std::atomic<uint32_t>;
|
|
template <typename T>
|
|
struct DAGRunContext {
|
|
T data;
|
|
Allocator::unique_ptr<DAGCounter> counters;
|
|
|
|
template <typename F>
|
|
MARL_NO_EXPORT inline void invoke(F&& f) {
|
|
f(data);
|
|
}
|
|
};
|
|
template <>
|
|
struct DAGRunContext<void> {
|
|
Allocator::unique_ptr<DAGCounter> counters;
|
|
|
|
template <typename F>
|
|
MARL_NO_EXPORT inline void invoke(F&& f) {
|
|
f();
|
|
}
|
|
};
|
|
template <typename T>
|
|
struct DAGWork {
|
|
using type = std::function<void(T)>;
|
|
};
|
|
template <>
|
|
struct DAGWork<void> {
|
|
using type = std::function<void()>;
|
|
};
|
|
} // namespace detail
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// Forward declarations
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
template <typename T> class DAGNodeBuilder;
template <typename T> class DAGBuilder;
template <typename T> class DAG;
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// DAGBase<T>
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// DAGBase is the common base class of DAG<T> and DAG<void>. It has no public
// API.
|
|
template <typename T>
|
|
class DAGBase {
|
|
protected:
|
|
friend DAGBuilder<T>;
|
|
friend DAGNodeBuilder<T>;
|
|
|
|
using RunContext = detail::DAGRunContext<T>;
|
|
using Counter = detail::DAGCounter;
|
|
using NodeIndex = size_t;
|
|
using Work = typename detail::DAGWork<T>::type;
|
|
static const constexpr size_t NumReservedNodes = 32;
|
|
static const constexpr size_t NumReservedNumOuts = 4;
|
|
static const constexpr size_t InvalidCounterIndex = ~static_cast<size_t>(0);
|
|
static const constexpr NodeIndex RootIndex = 0;
|
|
static const constexpr NodeIndex InvalidNodeIndex =
|
|
~static_cast<NodeIndex>(0);
|
|
|
|
// DAG work node.
|
|
struct Node {
|
|
MARL_NO_EXPORT inline Node() = default;
|
|
MARL_NO_EXPORT inline Node(Work&& work);
|
|
MARL_NO_EXPORT inline Node(const Work& work);
|
|
|
|
// The work to perform for this node in the graph.
|
|
Work work;
|
|
|
|
// counterIndex if valid, is the index of the counter in the RunContext for
|
|
// this node. The counter is decremented for each completed dependency task
|
|
// (ins), and once it reaches 0, this node will be invoked.
|
|
size_t counterIndex = InvalidCounterIndex;
|
|
|
|
// Indices for all downstream nodes.
|
|
containers::vector<NodeIndex, NumReservedNumOuts> outs;
|
|
};
|
|
|
|
// initCounters() allocates and initializes the ctx->coutners from
|
|
// initialCounters.
|
|
MARL_NO_EXPORT inline void initCounters(RunContext* ctx,
|
|
Allocator* allocator);
|
|
|
|
// notify() is called each time a dependency task (ins) has completed for the
|
|
// node with the given index.
|
|
// If all dependency tasks have completed (or this is the root node) then
|
|
// notify() returns true and the caller should then call invoke().
|
|
MARL_NO_EXPORT inline bool notify(RunContext*, NodeIndex);
|
|
|
|
// invoke() calls the work function for the node with the given index, then
|
|
// calls notify() and possibly invoke() for all the dependee nodes.
|
|
MARL_NO_EXPORT inline void invoke(RunContext*, NodeIndex, WaitGroup*);
|
|
|
|
// nodes is the full list of the nodes in the graph.
|
|
// nodes[0] is always the root node, which has no dependencies (ins).
|
|
containers::vector<Node, NumReservedNodes> nodes;
|
|
|
|
// initialCounters is a list of initial counter values to be copied to
|
|
// RunContext::counters on DAG<>::run().
|
|
// initialCounters is indexed by Node::counterIndex, and only contains counts
|
|
// for nodes that have at least 2 dependencies (ins) - because of this the
|
|
// number of entries in initialCounters may be fewer than nodes.
|
|
containers::vector<uint32_t, NumReservedNodes> initialCounters;
|
|
};
|
|
|
|
template <typename T>
|
|
DAGBase<T>::Node::Node(Work&& work) : work(std::move(work)) {}
|
|
|
|
template <typename T>
|
|
DAGBase<T>::Node::Node(const Work& work) : work(work) {}
|
|
|
|
template <typename T>
|
|
void DAGBase<T>::initCounters(RunContext* ctx, Allocator* allocator) {
|
|
auto numCounters = initialCounters.size();
|
|
ctx->counters = allocator->make_unique_n<Counter>(numCounters);
|
|
for (size_t i = 0; i < numCounters; i++) {
|
|
ctx->counters.get()[i] = {initialCounters[i]};
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
bool DAGBase<T>::notify(RunContext* ctx, NodeIndex nodeIdx) {
|
|
Node* node = &nodes[nodeIdx];
|
|
|
|
// If we have multiple dependencies, decrement the counter and check whether
|
|
// we've reached 0.
|
|
if (node->counterIndex == InvalidCounterIndex) {
|
|
return true;
|
|
}
|
|
auto counters = ctx->counters.get();
|
|
auto counter = --counters[node->counterIndex];
|
|
return counter == 0;
|
|
}
|
|
|
|
template <typename T>
|
|
void DAGBase<T>::invoke(RunContext* ctx, NodeIndex nodeIdx, WaitGroup* wg) {
|
|
Node* node = &nodes[nodeIdx];
|
|
|
|
// Run this node's work.
|
|
if (node->work) {
|
|
ctx->invoke(node->work);
|
|
}
|
|
|
|
// Then call notify() on all dependees (outs), and invoke() those that
|
|
// returned true.
|
|
// We buffer the node to invoke (toInvoke) so we can schedule() all but the
|
|
// last node to invoke(), and directly call the last invoke() on this thread.
|
|
// This is done to avoid the overheads of scheduling when a direct call would
|
|
// suffice.
|
|
NodeIndex toInvoke = InvalidNodeIndex;
|
|
for (NodeIndex idx : node->outs) {
|
|
if (notify(ctx, idx)) {
|
|
if (toInvoke != InvalidNodeIndex) {
|
|
wg->add(1);
|
|
// Schedule while promoting the WaitGroup capture from a pointer
|
|
// reference to a value. This ensures that the WaitGroup isn't dropped
|
|
// while in use.
|
|
schedule(
|
|
[=](WaitGroup wg) {
|
|
invoke(ctx, toInvoke, &wg);
|
|
wg.done();
|
|
},
|
|
*wg);
|
|
}
|
|
toInvoke = idx;
|
|
}
|
|
}
|
|
if (toInvoke != InvalidNodeIndex) {
|
|
invoke(ctx, toInvoke, wg);
|
|
}
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// DAGNodeBuilder<T>
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// DAGNodeBuilder is the builder interface for a DAG node.
|
|
template <typename T>
|
|
class DAGNodeBuilder {
|
|
using NodeIndex = typename DAGBase<T>::NodeIndex;
|
|
|
|
public:
|
|
// then() builds and returns a new DAG node that will be invoked after this
|
|
// node has completed.
|
|
//
|
|
// F is a function that will be called when the new DAG node is invoked, with
|
|
// the signature:
|
|
// void(T) when T is not void
|
|
// or
|
|
// void() when T is void
|
|
template <typename F>
|
|
MARL_NO_EXPORT inline DAGNodeBuilder then(F&&);
|
|
|
|
private:
|
|
friend DAGBuilder<T>;
|
|
MARL_NO_EXPORT inline DAGNodeBuilder(DAGBuilder<T>*, NodeIndex);
|
|
DAGBuilder<T>* builder;
|
|
NodeIndex index;
|
|
};
|
|
|
|
template <typename T>
|
|
DAGNodeBuilder<T>::DAGNodeBuilder(DAGBuilder<T>* builder, NodeIndex index)
|
|
: builder(builder), index(index) {}
|
|
|
|
template <typename T>
|
|
template <typename F>
|
|
DAGNodeBuilder<T> DAGNodeBuilder<T>::then(F&& work) {
|
|
auto node = builder->node(std::forward<F>(work));
|
|
builder->addDependency(*this, node);
|
|
return node;
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// DAGBuilder<T>
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
template <typename T>
|
|
class DAGBuilder {
|
|
public:
|
|
// DAGBuilder constructor
|
|
MARL_NO_EXPORT inline DAGBuilder(Allocator* allocator = Allocator::Default);
|
|
|
|
// root() returns the root DAG node.
|
|
MARL_NO_EXPORT inline DAGNodeBuilder<T> root();
|
|
|
|
// node() builds and returns a new DAG node with no initial dependencies.
|
|
// The returned node must be attached to the graph in order to invoke F or any
|
|
// of the dependees of this returned node.
|
|
//
|
|
// F is a function that will be called when the new DAG node is invoked, with
|
|
// the signature:
|
|
// void(T) when T is not void
|
|
// or
|
|
// void() when T is void
|
|
template <typename F>
|
|
MARL_NO_EXPORT inline DAGNodeBuilder<T> node(F&& work);
|
|
|
|
// node() builds and returns a new DAG node that depends on all the tasks in
|
|
// after to be completed before invoking F.
|
|
//
|
|
// F is a function that will be called when the new DAG node is invoked, with
|
|
// the signature:
|
|
// void(T) when T is not void
|
|
// or
|
|
// void() when T is void
|
|
template <typename F>
|
|
MARL_NO_EXPORT inline DAGNodeBuilder<T> node(
|
|
F&& work,
|
|
std::initializer_list<DAGNodeBuilder<T>> after);
|
|
|
|
// addDependency() adds parent as dependency on child. All dependencies of
|
|
// child must have completed before child is invoked.
|
|
MARL_NO_EXPORT inline void addDependency(DAGNodeBuilder<T> parent,
|
|
DAGNodeBuilder<T> child);
|
|
|
|
// build() constructs and returns the DAG. No other methods of this class may
|
|
// be called after calling build().
|
|
MARL_NO_EXPORT inline Allocator::unique_ptr<DAG<T>> build();
|
|
|
|
private:
|
|
static const constexpr size_t NumReservedNumIns = 4;
|
|
using Node = typename DAG<T>::Node;
|
|
|
|
// The DAG being built.
|
|
Allocator::unique_ptr<DAG<T>> dag;
|
|
|
|
// Number of dependencies (ins) for each node in dag->nodes.
|
|
containers::vector<uint32_t, NumReservedNumIns> numIns;
|
|
};
|
|
|
|
template <typename T>
|
|
DAGBuilder<T>::DAGBuilder(Allocator* allocator /* = Allocator::Default */)
|
|
: dag(allocator->make_unique<DAG<T>>()), numIns(allocator) {
|
|
// Add root
|
|
dag->nodes.emplace_back(Node{});
|
|
numIns.emplace_back(0);
|
|
}
|
|
|
|
template <typename T>
|
|
DAGNodeBuilder<T> DAGBuilder<T>::root() {
|
|
return DAGNodeBuilder<T>{this, DAGBase<T>::RootIndex};
|
|
}
|
|
|
|
template <typename T>
|
|
template <typename F>
|
|
DAGNodeBuilder<T> DAGBuilder<T>::node(F&& work) {
|
|
return node(std::forward<F>(work), {});
|
|
}
|
|
|
|
template <typename T>
|
|
template <typename F>
|
|
DAGNodeBuilder<T> DAGBuilder<T>::node(
|
|
F&& work,
|
|
std::initializer_list<DAGNodeBuilder<T>> after) {
|
|
MARL_ASSERT(numIns.size() == dag->nodes.size(),
|
|
"NodeBuilder vectors out of sync");
|
|
auto index = dag->nodes.size();
|
|
numIns.emplace_back(0);
|
|
dag->nodes.emplace_back(Node{std::forward<F>(work)});
|
|
auto node = DAGNodeBuilder<T>{this, index};
|
|
for (auto in : after) {
|
|
addDependency(in, node);
|
|
}
|
|
return node;
|
|
}
|
|
|
|
template <typename T>
|
|
void DAGBuilder<T>::addDependency(DAGNodeBuilder<T> parent,
|
|
DAGNodeBuilder<T> child) {
|
|
numIns[child.index]++;
|
|
dag->nodes[parent.index].outs.push_back(child.index);
|
|
}
|
|
|
|
template <typename T>
|
|
Allocator::unique_ptr<DAG<T>> DAGBuilder<T>::build() {
|
|
auto numNodes = dag->nodes.size();
|
|
MARL_ASSERT(numIns.size() == dag->nodes.size(),
|
|
"NodeBuilder vectors out of sync");
|
|
for (size_t i = 0; i < numNodes; i++) {
|
|
if (numIns[i] > 1) {
|
|
auto& node = dag->nodes[i];
|
|
node.counterIndex = dag->initialCounters.size();
|
|
dag->initialCounters.push_back(numIns[i]);
|
|
}
|
|
}
|
|
return std::move(dag);
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// DAG<T>
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
template <typename T = void>
|
|
class DAG : public DAGBase<T> {
|
|
public:
|
|
using Builder = DAGBuilder<T>;
|
|
using NodeBuilder = DAGNodeBuilder<T>;
|
|
|
|
// run() invokes the function of each node in the graph of the DAG, passing
|
|
// data to each, starting with the root node. All dependencies need to have
|
|
// completed their function before dependees will be invoked.
|
|
MARL_NO_EXPORT inline void run(T& data,
|
|
Allocator* allocator = Allocator::Default);
|
|
};
|
|
|
|
template <typename T>
|
|
void DAG<T>::run(T& arg, Allocator* allocator /* = Allocator::Default */) {
|
|
typename DAGBase<T>::RunContext ctx{arg};
|
|
this->initCounters(&ctx, allocator);
|
|
WaitGroup wg;
|
|
this->invoke(&ctx, this->RootIndex, &wg);
|
|
wg.wait();
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// DAG<void>
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
template <>
|
|
class DAG<void> : public DAGBase<void> {
|
|
public:
|
|
using Builder = DAGBuilder<void>;
|
|
using NodeBuilder = DAGNodeBuilder<void>;
|
|
|
|
// run() invokes the function of each node in the graph of the DAG, starting
|
|
// with the root node. All dependencies need to have completed their function
|
|
// before dependees will be invoked.
|
|
MARL_NO_EXPORT inline void run(Allocator* allocator = Allocator::Default);
|
|
};
|
|
|
|
// DAG<void> is a full specialization, so DAGBase<void> members are found by
// ordinary lookup; `inline` matches the in-class declaration for this
// header-defined function.
inline void DAG<void>::run(Allocator* allocator /* = Allocator::Default */) {
  RunContext context{};
  initCounters(&context, allocator);

  // Start from the root on this thread, then block until every scheduled
  // task has signalled the WaitGroup.
  WaitGroup pending;
  invoke(&context, RootIndex, &pending);
  pending.wait();
}
|
|
|
|
} // namespace marl
|
|
|
|
#endif // marl_dag_h
|