/* Copyright (c) 2016 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Author: Gabriel Ebner */ #pragma once #include #include #include #include #include "util/thread.h" #include "util/optional.h" #include "util/rc.h" namespace lean { enum class task_result_state { QUEUED, EXECUTING, FINISHED, FAILED }; class generic_task; class generic_task_result_cell { MK_LEAN_RC() void dealloc() { delete this; } friend class task_queue; friend class st_task_queue; friend class mt_task_queue; template friend class task_result_cell; friend class generic_task_result; generic_task * m_task = nullptr; atomic m_state { task_result_state::QUEUED }; std::string m_desc; std::exception_ptr m_ex; virtual ~generic_task_result_cell() { clear_task(); } void clear_task(); generic_task_result_cell(generic_task * t); generic_task_result_cell(std::string const & desc) : m_rc(0), m_state(task_result_state::FINISHED), m_desc(desc) {} bool has_evaluated() const { auto state = m_state.load(); return state != task_result_state::QUEUED && state != task_result_state::EXECUTING; } virtual bool execute() = 0; }; class generic_task_result { friend class st_task_queue; friend class mt_task_queue; template friend class task_result; generic_task_result_cell * m_ptr = nullptr; generic_task_result_cell * operator->() const { return m_ptr; } generic_task_result_cell & operator*() const { return *m_ptr; } public: generic_task_result(generic_task_result_cell * t) : m_ptr(t) { if (t) t->inc_ref(); } generic_task_result() {} generic_task_result(generic_task_result && t) : m_ptr(t.m_ptr) { t.m_ptr = nullptr; } generic_task_result(generic_task_result const & t) : m_ptr(t.m_ptr) { if (m_ptr) m_ptr->inc_ref(); } ~generic_task_result() { if (m_ptr) m_ptr->dec_ref(); m_ptr = nullptr; } generic_task_result & operator=(generic_task_result const & t) { LEAN_COPY_REF(t); } generic_task_result & operator=(generic_task_result && t) { LEAN_MOVE_REF(t); } bool operator==(generic_task_result const & t) const { return m_ptr == t.m_ptr; } operator bool() const { return m_ptr != nullptr; } struct hash { size_t operator()(generic_task_result const & t) const { return std::hash()(t.m_ptr); } }; std::string description() const { return m_ptr->m_desc; } void cancel() const; void reset() { *this = nullptr; } }; struct task_priority { unsigned m_prio = static_cast(-1); optional m_not_before; bool operator<(task_priority const & p) const { if (m_prio < p.m_prio) return true; if (m_not_before && p.m_not_before && *m_not_before < *p.m_not_before) return true; if (!m_not_before && p.m_not_before) return true; return false; } void bump(task_priority const & p) { if (p.m_prio < m_prio) m_prio = p.m_prio; if (m_not_before && p.m_not_before && *p.m_not_before < *m_not_before) *m_not_before = *p.m_not_before; } }; class generic_task { friend class task_queue; friend class st_task_queue; friend class mt_task_queue; task_priority m_prio; std::vector m_reverse_deps; condition_variable m_has_finished; public: virtual ~generic_task() {} virtual void description(std::ostream &) const = 0; std::string description() const; virtual std::vector get_dependencies() { return {}; } virtual bool is_tiny() const { return false; } }; template class task : public generic_task { public: typedef T result; virtual ~task() {} virtual T execute() = 0; }; template class task_result_cell : public generic_task_result_cell { friend class task_queue; template friend class task_result; optional m_result; task * get_ptr() { return static_cast *>(m_task); } virtual bool execute() { try { m_result = { get_ptr()->execute() }; return true; } catch (...) { m_ex = std::current_exception(); return false; } } public: task_result_cell(task * t) : generic_task_result_cell(t) {} task_result_cell(T const & t, std::string const & desc) : generic_task_result_cell(desc), m_result(t) {} }; template class task_result : public generic_task_result { friend class task_queue; task_result_cell * get_ptr() { return static_cast *>(m_ptr); } task_result_cell const * get_ptr() const { return static_cast *>(m_ptr); } public: task_result(task_result_cell * t) : generic_task_result(t) {} task_result() : generic_task_result() {} task_result(task_result const & t) : generic_task_result(t) {} task_result(task_result && t) : generic_task_result(t) {} task_result & operator=(task_result const & t) { LEAN_COPY_REF(t); } task_result & operator=(task_result && t) { LEAN_MOVE_REF(t); } T const & get() const; optional peek() const { if (m_ptr->m_state.load() == task_result_state::FINISHED) { return get_ptr()->m_result; } else { return optional(); } } }; template task_result mk_pure_task_result(T const & t, std::string const & desc) { return task_result(new task_result_cell(t, desc)); } class task_cancellation_exception : public std::exception { generic_task_result m_cancelled_task; std::string m_msg; public: task_cancellation_exception() : task_cancellation_exception(generic_task_result()) {} task_cancellation_exception(generic_task_result const & cancelled_task); char const * what() const noexcept override; generic_task_result get_cancelled_task() const { return m_cancelled_task; } }; class task_queue { virtual void submit(generic_task_result const &) = 0; protected: task_queue() {} public: virtual ~task_queue() {} virtual optional get_current_task() = 0; virtual bool empty() = 0; template task_result submit(As... args) { task_result task( new task_result_cell( new T(std::forward(args)...))); submit(task); return task; } template T const & get_result(task_result const & t) { wait(t); lean_assert(t.get_ptr()->m_result); return *t.get_ptr()->m_result; } virtual void wait(generic_task_result const & t) = 0; virtual void cancel(generic_task_result const & t) = 0; using progress_cb = std::function; // NOLINT // disabling lint because it this this is cast ^^^ virtual void set_progress_callback(progress_cb const &) = 0; }; class scope_global_task_queue { task_queue * m_old_tq; public: scope_global_task_queue(task_queue * tq); ~scope_global_task_queue(); }; task_queue & get_global_task_queue(); template T const & task_result::get() const { return get_global_task_queue().get_result(*this); } inline void generic_task_result::cancel() const { get_global_task_queue().cancel(*this); } }