feat update pending_task_safety_flag
All checks were successful
linux-x64-gcc / linux-gcc (Release) (push) Successful in 1m37s
linux-x64-gcc / linux-gcc (Debug) (push) Successful in 2m0s

This commit is contained in:
tqcq 2024-03-19 15:12:26 +08:00
parent 7e1a443130
commit d716ac684e
4 changed files with 68 additions and 38 deletions

View File

@ -10,11 +10,11 @@
#include "sled/ref_counted_base.h" #include "sled/ref_counted_base.h"
#include "sled/scoped_refptr.h" #include "sled/scoped_refptr.h"
#include <functional>
namespace sled { namespace sled {
class PendingTaskSafetyFlag final class PendingTaskSafetyFlag final : public sled::RefCountedNonVirtual<PendingTaskSafetyFlag> {
: public sled::RefCountedNonVirtual<PendingTaskSafetyFlag> {
public: public:
static sled::scoped_refptr<PendingTaskSafetyFlag> Create(); static sled::scoped_refptr<PendingTaskSafetyFlag> Create();
static sled::scoped_refptr<PendingTaskSafetyFlag> CreateDetached(); static sled::scoped_refptr<PendingTaskSafetyFlag> CreateDetached();
@ -29,11 +29,36 @@ protected:
explicit PendingTaskSafetyFlag(bool alive) : alive_(alive) {} explicit PendingTaskSafetyFlag(bool alive) : alive_(alive) {}
private: private:
static sled::scoped_refptr<PendingTaskSafetyFlag> static sled::scoped_refptr<PendingTaskSafetyFlag> CreateInternal(bool alive);
CreateInternal(bool alive);
bool alive_ = true; bool alive_ = true;
}; };
class ScopedTaskSafety final {
public:
ScopedTaskSafety() = default;
explicit ScopedTaskSafety(scoped_refptr<PendingTaskSafetyFlag> flag) : flag_(std::move(flag)) {}
~ScopedTaskSafety() { flag_->SetNotAlive(); }
void reset(scoped_refptr<PendingTaskSafetyFlag> new_flag = PendingTaskSafetyFlag::Create())
{
flag_->SetNotAlive();
flag_ = std::move(new_flag);
}
private:
scoped_refptr<PendingTaskSafetyFlag> flag_;
};
inline std::function<void()>
SafeTask(scoped_refptr<PendingTaskSafetyFlag> flag, std::function<void()> task)
{
return [task, flag]() mutable {
if (flag->alive()) { std::move(task)(); }
};
}
}// namespace sled }// namespace sled
#endif// SLED_TASK_QUEUE_PENDING_TASK_SAFETY_FLAG_H #endif// SLED_TASK_QUEUE_PENDING_TASK_SAFETY_FLAG_H

View File

@ -1,7 +1,9 @@
#pragma once #pragma once
#include "sled/scoped_refptr.h"
#ifndef SLED_TIMER_QUEUE_TIMEOUT_H #ifndef SLED_TIMER_QUEUE_TIMEOUT_H
#define SLED_TIMER_QUEUE_TIMEOUT_H #define SLED_TIMER_QUEUE_TIMEOUT_H
#include "sled/task_queue/pending_task_safety_flag.h"
#include "sled/task_queue/task_queue_base.h" #include "sled/task_queue/task_queue_base.h"
#include "sled/timer/timeout.h" #include "sled/timer/timeout.h"
#include <limits> #include <limits>
@ -12,8 +14,7 @@ typedef uint64_t TimeMs;
class TaskQueueTimeoutFactory { class TaskQueueTimeoutFactory {
public: public:
TaskQueueTimeoutFactory( TaskQueueTimeoutFactory(sled::TaskQueueBase &task_queue,
sled::TaskQueueBase &task_queue,
std::function<TimeMs()> get_time, std::function<TimeMs()> get_time,
std::function<void(TimeoutID timeout_id)> on_expired) std::function<void(TimeoutID timeout_id)> on_expired)
: task_queue_(task_queue), : task_queue_(task_queue),
@ -22,18 +23,15 @@ public:
{} {}
std::unique_ptr<Timeout> std::unique_ptr<Timeout>
CreateTimeout(sled::TaskQueueBase::DelayPrecision precision = CreateTimeout(sled::TaskQueueBase::DelayPrecision precision = sled::TaskQueueBase::DelayPrecision::kHigh)
sled::TaskQueueBase::DelayPrecision::kHigh)
{ {
return std::unique_ptr<TaskQueueTimeout>( return std::unique_ptr<TaskQueueTimeout>(new TaskQueueTimeout(*this, precision));
new TaskQueueTimeout(*this, precision));
} }
private: private:
class TaskQueueTimeout : public Timeout { class TaskQueueTimeout : public Timeout {
public: public:
TaskQueueTimeout(TaskQueueTimeoutFactory &parent, TaskQueueTimeout(TaskQueueTimeoutFactory &parent, sled::TaskQueueBase::DelayPrecision precision);
sled::TaskQueueBase::DelayPrecision precision);
~TaskQueueTimeout() override; ~TaskQueueTimeout() override;
void Start(DurationMs duration, TimeoutID timeout_id) override; void Start(DurationMs duration, TimeoutID timeout_id) override;
void Stop() override; void Stop() override;
@ -44,6 +42,7 @@ private:
TimeMs posted_task_expiration_ = std::numeric_limits<TimeMs>::max(); TimeMs posted_task_expiration_ = std::numeric_limits<TimeMs>::max();
TimeMs timeout_expiration_ = std::numeric_limits<TimeMs>::max(); TimeMs timeout_expiration_ = std::numeric_limits<TimeMs>::max();
TimeoutID timeout_id_ = TimeoutID(0); TimeoutID timeout_id_ = TimeoutID(0);
scoped_refptr<PendingTaskSafetyFlag> safety_flag_;
}; };
sled::TaskQueueBase &task_queue_; sled::TaskQueueBase &task_queue_;

View File

@ -6,8 +6,7 @@ sled::scoped_refptr<PendingTaskSafetyFlag>
PendingTaskSafetyFlag::CreateInternal(bool alive) PendingTaskSafetyFlag::CreateInternal(bool alive)
{ {
// Explicit new, to access private constructor. // Explicit new, to access private constructor.
return sled::scoped_refptr<PendingTaskSafetyFlag>( return sled::scoped_refptr<PendingTaskSafetyFlag>(new PendingTaskSafetyFlag(alive));
new PendingTaskSafetyFlag(alive));
} }
sled::scoped_refptr<PendingTaskSafetyFlag> sled::scoped_refptr<PendingTaskSafetyFlag>
@ -31,7 +30,7 @@ PendingTaskSafetyFlag::CreateDetachedInactive()
void void
PendingTaskSafetyFlag::SetNotAlive() PendingTaskSafetyFlag::SetNotAlive()
{ {
alive_ = true; alive_ = false;
} }
void void

View File

@ -1,5 +1,6 @@
#include "sled/timer/task_queue_timeout.h" #include "sled/timer/task_queue_timeout.h"
#include "sled/log/log.h" #include "sled/log/log.h"
#include "sled/task_queue/pending_task_safety_flag.h"
#include "sled/units/time_delta.h" #include "sled/units/time_delta.h"
namespace sled { namespace sled {
@ -23,15 +24,18 @@ TaskQueueTimeoutFactory::TaskQueueTimeout::Start(DurationMs duration_ms, Timeout
LOGV("timer", LOGV("timer",
"New timeout duration is less than scheduled - " "New timeout duration is less than scheduled - "
"ghosting old delayed task"); "ghosting old delayed task");
safety_flag_->SetNotAlive();
safety_flag_ = PendingTaskSafetyFlag::Create();
} }
posted_task_expiration_ = timeout_expiration_; posted_task_expiration_ = timeout_expiration_;
auto safety_flag = safety_flag_;
parent_.task_queue_.PostDelayedTaskWithPrecision( parent_.task_queue_.PostDelayedTaskWithPrecision(
precision_, precision_,
SafeTask(safety_flag_,
[timeout_id, this]() { [timeout_id, this]() {
if (timeout_id != this->timeout_id_) { return; } // if (timeout_id != this->timeout_id_) { return; }
LOGV("timer", "Timeout expired: {}", timeout_id); LOGV("timer", "Timeout expired: {}", timeout_id);
ASSERT(posted_task_expiration_ != std::numeric_limits<TimeMs>::max(), ""); ASSERT(posted_task_expiration_ != std::numeric_limits<TimeMs>::max(), "");
posted_task_expiration_ = std::numeric_limits<TimeMs>::max(); posted_task_expiration_ = std::numeric_limits<TimeMs>::max();
@ -40,17 +44,20 @@ TaskQueueTimeoutFactory::TaskQueueTimeout::Start(DurationMs duration_ms, Timeout
// do nothing // do nothing
} else { } else {
const TimeMs now = parent_.get_time_(); const TimeMs now = parent_.get_time_();
if (timeout_expiration_ <= now) { const DurationMs remaining = timeout_expiration_ - now;
bool is_expired = timeout_expiration_ <= now;
timeout_expiration_ = std::numeric_limits<TimeMs>::max(); timeout_expiration_ = std::numeric_limits<TimeMs>::max();
if (!is_expired) {
// continue wait
Start(remaining, timeout_id);
} else {
LOGV("timer", "Timeout Triggered: {}", timeout_id); LOGV("timer", "Timeout Triggered: {}", timeout_id);
parent_.on_expired_(timeout_id_); parent_.on_expired_(timeout_id_);
} else {
const DurationMs remaining = timeout_expiration_ - now;
timeout_expiration_ = std::numeric_limits<TimeMs>::max();
Start(remaining, timeout_id);
} }
} }
}, }),
sled::TimeDelta::Millis(duration_ms)); sled::TimeDelta::Millis(duration_ms));
} }