diff --git a/src/library/mt_task_queue.cpp b/src/library/mt_task_queue.cpp index bf12f67484..41ffae93da 100644 --- a/src/library/mt_task_queue.cpp +++ b/src/library/mt_task_queue.cpp @@ -55,7 +55,7 @@ mt_task_queue::mt_task_queue(unsigned num_workers, mt_tq_prioritizer const & pri mt_task_queue::~mt_task_queue() { { unique_lock lock(m_mutex); - m_queue_removed.wait(lock, [=] { return m_queue.empty(); }); + m_queue_removed.wait(lock, [=] { return empty_core(); }); m_shutting_down = true; m_queue_added.notify_all(); m_wake_up_worker.notify_all(); @@ -112,16 +112,22 @@ void mt_task_queue::spawn_worker() { t->m_state = is_ok ? task_result_state::FINISHED : task_result_state::FAILED; t->m_task->m_has_finished.notify_all(); - if (t->m_state.load() == task_result_state::FINISHED) { + if (is_ok) { for (auto & rdep : t->m_task->m_reverse_deps) { if (rdep->has_evaluated()) { m_waiting.erase(rdep); } else { - if (m_waiting.count(rdep) && check_deps(rdep)) { - m_waiting.erase(rdep); - if (!rdep->has_evaluated()) { - enqueue(rdep); - } + switch (rdep->m_state.load()) { + case task_result_state::WAITING: + if (check_deps(rdep)) { + m_waiting.erase(rdep); + if (!rdep->has_evaluated()) + enqueue(rdep); + } + break; + case task_result_state::FAILED: break; + default: + lean_unreachable(); } } } @@ -130,6 +136,7 @@ void mt_task_queue::spawn_worker() { } t->clear_task(); + m_queue_removed.notify_all(); } }); m_workers.push_back(this_worker); @@ -137,16 +144,21 @@ void mt_task_queue::spawn_worker() { void mt_task_queue::propagate_failure(generic_task_result const & tr) { lean_assert(tr->m_state.load() == task_result_state::FAILED); + m_waiting.erase(tr); if (auto t = tr->m_task) { tr->m_task->m_has_finished.notify_all(); for (auto & rdep : t->m_reverse_deps) { - if (rdep->m_state.load() != task_result_state::QUEUED) continue; - rdep->m_ex = tr->m_ex; - rdep->m_state = task_result_state::FAILED; - m_waiting.erase(rdep); - propagate_failure(rdep); + switch (rdep->m_state.load()) { + case task_result_state::WAITING: + case task_result_state::QUEUED: + rdep->m_ex = tr->m_ex; + rdep->m_state = task_result_state::FAILED; + propagate_failure(rdep); + break; + default: break; + } } } @@ -158,28 +170,41 @@ void mt_task_queue::submit(generic_task_result const & t) { check_interrupted(); t->m_task->m_prio = m_prioritizer(t->m_task); if (check_deps(t)) { - if (!t->has_evaluated()) enqueue(t); + if (!t->has_evaluated()) { + enqueue(t); + } } else { + t->m_state = task_result_state::WAITING; m_waiting.insert(t); } } void mt_task_queue::bump_prio(generic_task_result const & t, task_priority const & new_prio) { - if (t->m_task && new_prio < t->m_task->m_prio && t->m_state.load() == task_result_state::QUEUED) { - if (!m_waiting.count(t)) { + if (t->m_task && new_prio < t->m_task->m_prio) { + switch (t->m_state.load()) { + case task_result_state::QUEUED: { auto prio = t->m_task->m_prio.m_prio; - auto & q = m_queue[prio]; + auto &q = m_queue[prio]; auto it = std::find(q.begin(), q.end(), t); lean_assert(it != q.end()); q.erase(it); if (q.empty()) m_queue.erase(prio); t->m_task->m_prio.bump(new_prio); + check_deps(t); enqueue(t); - } else { - t->m_task->m_prio.bump(new_prio); + break; + } + case task_result_state::WAITING: + t->m_task->m_prio.bump(new_prio); + check_deps(t); + break; + case task_result_state::EXECUTING: + case task_result_state::FINISHED: + case task_result_state::FAILED: + break; + default: lean_unreachable(); } - check_deps(t); } } @@ -189,14 +214,15 @@ bool mt_task_queue::check_deps(generic_task_result const & t) { deps = t->m_task->get_dependencies(); } catch (...) {} for (auto & dep : deps) { - if (dep && dep->m_state.load() == task_result_state::QUEUED) - bump_prio(dep, t->m_task->m_prio); + if (dep) bump_prio(dep, t->m_task->m_prio); } for (auto & dep : deps) { if (!dep) continue; switch (dep->m_state.load()) { + case task_result_state::WAITING: case task_result_state::QUEUED: case task_result_state::EXECUTING: + lean_assert(dep->m_task); dep->m_task->m_reverse_deps.push_back(t); return false; case task_result_state::FINISHED: @@ -206,6 +232,7 @@ bool mt_task_queue::check_deps(generic_task_result const & t) { t->m_state = task_result_state::FAILED; propagate_failure(t); return true; + default: lean_unreachable(); } } return true; @@ -260,10 +287,11 @@ void mt_task_queue::cancel_if(const std::function & pred) void mt_task_queue::cancel_core(generic_task_result const & t) { switch (t->m_state.load()) { + case task_result_state::WAITING: + m_waiting.erase(t); case task_result_state::QUEUED: t->m_ex = std::make_exception_ptr(task_cancellation_exception(t)); - t->m_state.store(task_result_state::FAILED); - if (t->m_task) t->m_task->m_has_finished.notify_all(); + t->m_state = task_result_state::FAILED; propagate_failure(t); t->clear_task(); return; @@ -283,13 +311,17 @@ void mt_task_queue::cancel(generic_task_result const & t) { cancel_core(t); } -bool mt_task_queue::empty() { - unique_lock lock(m_mutex); +bool mt_task_queue::empty_core() { for (auto & w : m_workers) { if (w->m_current_task) return false; } - return m_queue.empty(); + return m_queue.empty() && m_waiting.empty(); +} + +bool mt_task_queue::empty() { + unique_lock lock(m_mutex); + return empty_core(); } optional mt_task_queue::get_current_task() { @@ -304,23 +336,23 @@ optional mt_task_queue::get_current_task() { generic_task_result mt_task_queue::dequeue() { lean_assert(!m_queue.empty()); - auto & highest_prio = m_queue.begin()->second; + auto it = m_queue.begin(); + auto & highest_prio = it->second; lean_assert(!highest_prio.empty()); auto result = std::move(highest_prio.front()); highest_prio.pop_front(); if (highest_prio.empty()) { - m_queue.erase(m_queue.begin()); + m_queue.erase(it); } - m_queue_removed.notify_all(); return result; } void mt_task_queue::enqueue(generic_task_result const & t) { - if (t->m_state.load() == task_result_state::QUEUED) { - lean_assert(t->m_task); - m_queue[t->m_task->m_prio.m_prio].push_back(t); - m_queue_added.notify_one(); - } + lean_assert(t->m_state.load() < task_result_state::EXECUTING); + lean_assert(t->m_task); + t->m_state = task_result_state::QUEUED; + m_queue[t->m_task->m_prio.m_prio].push_back(t); + m_queue_added.notify_one(); } void mt_task_queue::reprioritize(mt_tq_prioritizer const & p) { diff --git a/src/library/mt_task_queue.h b/src/library/mt_task_queue.h index e7849fb198..fa936f6cbb 100644 --- a/src/library/mt_task_queue.h +++ b/src/library/mt_task_queue.h @@ -48,6 +48,8 @@ class mt_task_queue : public task_queue { io_state m_ios; message_buffer * m_msg_buf; + bool empty_core(); + generic_task_result dequeue(); void enqueue(generic_task_result const &); diff --git a/src/library/task_queue.h b/src/library/task_queue.h index 68ae372695..07dfa189f6 100644 --- a/src/library/task_queue.h +++ b/src/library/task_queue.h @@ -16,7 +16,12 @@ Author: Gabriel Ebner namespace lean { -enum class task_result_state { QUEUED, EXECUTING, FINISHED, FAILED }; +enum class task_result_state { + CREATED, + QUEUED, WAITING, + EXECUTING, + FINISHED, FAILED +}; class generic_task; class generic_task_result_cell { @@ -30,7 +35,7 @@ class generic_task_result_cell { friend class generic_task_result; generic_task * m_task = nullptr; - atomic m_state { task_result_state::QUEUED }; + atomic m_state { task_result_state::CREATED }; std::string m_desc; std::exception_ptr m_ex; @@ -42,8 +47,7 @@ class generic_task_result_cell { m_rc(0), m_state(task_result_state::FINISHED), m_desc(desc) {} bool has_evaluated() const { - auto state = m_state.load(); - return state != task_result_state::QUEUED && state != task_result_state::EXECUTING; + return m_state.load() > task_result_state::EXECUTING; } virtual void execute_and_store_result() = 0;