feat update pending_task_safety_flag

This commit is contained in:
tqcq
2024-03-19 15:12:26 +08:00
parent 7e1a443130
commit d716ac684e
4 changed files with 68 additions and 38 deletions

View File

@@ -10,11 +10,11 @@
#include "sled/ref_counted_base.h"
#include "sled/scoped_refptr.h"
#include <functional>
namespace sled {
class PendingTaskSafetyFlag final
: public sled::RefCountedNonVirtual<PendingTaskSafetyFlag> {
class PendingTaskSafetyFlag final : public sled::RefCountedNonVirtual<PendingTaskSafetyFlag> {
public:
static sled::scoped_refptr<PendingTaskSafetyFlag> Create();
static sled::scoped_refptr<PendingTaskSafetyFlag> CreateDetached();
@@ -29,11 +29,36 @@ protected:
explicit PendingTaskSafetyFlag(bool alive) : alive_(alive) {}
private:
static sled::scoped_refptr<PendingTaskSafetyFlag>
CreateInternal(bool alive);
static sled::scoped_refptr<PendingTaskSafetyFlag> CreateInternal(bool alive);
bool alive_ = true;
};
class ScopedTaskSafety final {
public:
ScopedTaskSafety() = default;
explicit ScopedTaskSafety(scoped_refptr<PendingTaskSafetyFlag> flag) : flag_(std::move(flag)) {}
~ScopedTaskSafety() { flag_->SetNotAlive(); }
void reset(scoped_refptr<PendingTaskSafetyFlag> new_flag = PendingTaskSafetyFlag::Create())
{
flag_->SetNotAlive();
flag_ = std::move(new_flag);
}
private:
scoped_refptr<PendingTaskSafetyFlag> flag_;
};
inline std::function<void()>
SafeTask(scoped_refptr<PendingTaskSafetyFlag> flag, std::function<void()> task)
{
return [task, flag]() mutable {
if (flag->alive()) { std::move(task)(); }
};
}
}// namespace sled
#endif// SLED_TASK_QUEUE_PENDING_TASK_SAFETY_FLAG_H

View File

@@ -1,7 +1,9 @@
#pragma once
#include "sled/scoped_refptr.h"
#ifndef SLED_TIMER_QUEUE_TIMEOUT_H
#define SLED_TIMER_QUEUE_TIMEOUT_H
#include "sled/task_queue/pending_task_safety_flag.h"
#include "sled/task_queue/task_queue_base.h"
#include "sled/timer/timeout.h"
#include <limits>
@@ -12,28 +14,24 @@ typedef uint64_t TimeMs;
class TaskQueueTimeoutFactory {
public:
TaskQueueTimeoutFactory(
sled::TaskQueueBase &task_queue,
std::function<TimeMs()> get_time,
std::function<void(TimeoutID timeout_id)> on_expired)
TaskQueueTimeoutFactory(sled::TaskQueueBase &task_queue,
std::function<TimeMs()> get_time,
std::function<void(TimeoutID timeout_id)> on_expired)
: task_queue_(task_queue),
get_time_(get_time),
on_expired_(on_expired)
{}
std::unique_ptr<Timeout>
CreateTimeout(sled::TaskQueueBase::DelayPrecision precision =
sled::TaskQueueBase::DelayPrecision::kHigh)
CreateTimeout(sled::TaskQueueBase::DelayPrecision precision = sled::TaskQueueBase::DelayPrecision::kHigh)
{
return std::unique_ptr<TaskQueueTimeout>(
new TaskQueueTimeout(*this, precision));
return std::unique_ptr<TaskQueueTimeout>(new TaskQueueTimeout(*this, precision));
}
private:
class TaskQueueTimeout : public Timeout {
public:
TaskQueueTimeout(TaskQueueTimeoutFactory &parent,
sled::TaskQueueBase::DelayPrecision precision);
TaskQueueTimeout(TaskQueueTimeoutFactory &parent, sled::TaskQueueBase::DelayPrecision precision);
~TaskQueueTimeout() override;
void Start(DurationMs duration, TimeoutID timeout_id) override;
void Stop() override;
@@ -44,6 +42,7 @@ private:
TimeMs posted_task_expiration_ = std::numeric_limits<TimeMs>::max();
TimeMs timeout_expiration_ = std::numeric_limits<TimeMs>::max();
TimeoutID timeout_id_ = TimeoutID(0);
scoped_refptr<PendingTaskSafetyFlag> safety_flag_;
};
sled::TaskQueueBase &task_queue_;