/* Copyright (c) 2016 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Author: Gabriel Ebner */ #pragma once #include #include #include #include #include "util/thread.h" #include "util/optional.h" #include "util/rc.h" #include "util/message_definitions.h" namespace lean { enum class task_result_state { CREATED, QUEUED, WAITING, EXECUTING, FINISHED, FAILED }; class generic_task; struct generic_task_result_cell { MK_LEAN_RC() void dealloc() { delete this; } generic_task * m_task = nullptr; atomic m_state { task_result_state::CREATED }; std::string m_desc; std::exception_ptr m_ex; virtual ~generic_task_result_cell() { clear_task(); } void clear_task(); generic_task_result_cell(generic_task * t); generic_task_result_cell(std::string const & desc) : m_rc(0), m_state(task_result_state::FINISHED), m_desc(desc) {} bool has_evaluated() const { return m_state.load() > task_result_state::EXECUTING; } virtual void execute_and_store_result() = 0; }; class generic_task_result { friend class task_queue; template friend class task_result; generic_task_result_cell * m_ptr = nullptr; public: generic_task_result(generic_task_result_cell * t) : m_ptr(t) { if (t) t->inc_ref(); } generic_task_result() {} generic_task_result(generic_task_result && t) : m_ptr(t.m_ptr) { t.m_ptr = nullptr; } generic_task_result(generic_task_result const & t) : m_ptr(t.m_ptr) { if (m_ptr) m_ptr->inc_ref(); } ~generic_task_result() { if (m_ptr) m_ptr->dec_ref(); m_ptr = nullptr; } generic_task_result & operator=(generic_task_result const & t) { LEAN_COPY_REF(t); } generic_task_result & operator=(generic_task_result && t) { LEAN_MOVE_REF(t); } bool operator==(generic_task_result const & t) const { return m_ptr == t.m_ptr; } bool operator!=(generic_task_result const & t) const { return !(*this == t); } operator bool() const { return m_ptr != nullptr; } struct hash { size_t operator()(generic_task_result const & t) const { return std::hash()(t.m_ptr); } }; std::string description() const { return m_ptr->m_desc; } void cancel() const; void reset() { *this = nullptr; } }; struct task_priority { unsigned m_prio = static_cast(-1); optional m_not_before; bool operator<(task_priority const & p) const { if (m_prio < p.m_prio) return true; if (m_not_before && p.m_not_before && *m_not_before < *p.m_not_before) return true; if (!m_not_before && p.m_not_before) return true; return false; } void bump(task_priority const & p) { if (p.m_prio < m_prio) m_prio = p.m_prio; if (m_not_before && p.m_not_before && *p.m_not_before < *m_not_before) *m_not_before = *p.m_not_before; } }; typedef std::string module_id; enum class task_kind { parse, elab, print }; module_id get_current_module(); pos_info get_current_task_pos(); class scoped_task_context { module_id * m_old_id; pos_info * m_old_pos; module_id m_id; pos_info m_pos; public: scoped_task_context(module_id const & mod, pos_info const & pos); ~scoped_task_context(); }; struct task_scheduling_data { task_priority m_prio; std::vector m_reverse_deps; condition_variable m_has_finished; }; class generic_task { friend class task_queue; task_scheduling_data m_data; // metadata message_bucket_id m_bucket; module_id m_mod; pos_info m_pos; public: generic_task(); virtual ~generic_task() {} virtual void description(std::ostream &) const = 0; std::string description() const; virtual std::vector get_dependencies() { return {}; } virtual bool is_tiny() const { return false; } virtual bool do_priority_inversion() const { return true; } virtual task_kind get_kind() const { return task_kind::elab; } virtual pos_info get_pos() const { return get_task_pos(); } virtual pos_info get_end_pos() const { return get_pos(); } message_bucket_id const & get_bucket() const { return m_bucket; } period get_version() const { return m_bucket.m_version; } module_id const & get_module_id() const { return m_mod; } pos_info const & get_task_pos() const { return m_pos; } }; template class task : public generic_task { public: typedef T result; virtual ~task() {} virtual T execute() = 0; }; template struct task_result_cell : public generic_task_result_cell { optional m_result; task * get_ptr() { return static_cast *>(m_task); } virtual void execute_and_store_result() override { m_result = { get_ptr()->execute() }; } task_result_cell(task * t) : generic_task_result_cell(t) {} task_result_cell(T const & t, std::string const & desc) : generic_task_result_cell(desc), m_result(t) {} }; template class task_result : public generic_task_result { friend class task_queue; optional const & get_current_result() const { return static_cast *>(m_ptr)->m_result; } public: task_result(task_result_cell * t) : generic_task_result(t) {} task_result() : generic_task_result() {} task_result(task_result const & t) : generic_task_result(t) {} task_result(task_result && t) : generic_task_result(t) {} task_result & operator=(task_result const & t) { LEAN_COPY_REF(t); } task_result & operator=(task_result && t) { LEAN_MOVE_REF(t); } T const & get() const; optional peek() const { if (m_ptr->m_state.load() == task_result_state::FINISHED) { return get_current_result(); } else { return optional(); } } }; template task_result mk_pure_task_result(T const & t, std::string const & desc) { return task_result(new task_result_cell(t, desc)); } class task_cancellation_exception : public std::exception { std::string m_msg; public: task_cancellation_exception() : task_cancellation_exception(generic_task_result()) {} task_cancellation_exception(generic_task_result const & cancelled_task); char const * what() const noexcept override; }; class task_queue { virtual void prepare_task(generic_task_result const &) = 0; protected: task_queue() {} // Friendship forwarding. generic_task_result_cell * unwrap(generic_task_result const & tr) const { return tr.m_ptr; } task_scheduling_data & get_data(generic_task * t) const { return t->m_data; } task_scheduling_data & get_data(generic_task_result const & tr) const { return get_data(unwrap(tr)->m_task); } void set_bucket(generic_task_result const & tr, message_bucket_id const & id) const { tr.m_ptr->m_task->m_bucket = id; } public: virtual ~task_queue() {} virtual optional get_current_task() = 0; virtual bool empty() = 0; virtual void join() = 0; virtual void submit(generic_task_result const &) = 0; template task_result mk_lazy_task(As... args) { task_result task( new task_result_cell( new T(std::forward(args)...))); prepare_task(task); return task; } template task_result submit(As... args) { task_result task = mk_lazy_task(std::forward(args)...); submit(task); return task; } template T const & get_result(task_result const & t) { while (true) { switch (unwrap(t)->m_state.load()) { case task_result_state::FINISHED: return *t.get_current_result(); case task_result_state::FAILED: std::rethrow_exception(unwrap(t)->m_ex); default: wait(t); } } } virtual void wait(generic_task_result const & t) = 0; virtual void cancel(generic_task_result const & t) = 0; virtual void cancel_if(std::function const & pred) = 0; // NOLINT using progress_cb = std::function; // NOLINT // disabling lint because it this this is cast ^^^ virtual void set_progress_callback(progress_cb const &) = 0; }; class scope_global_task_queue { task_queue * m_old_tq; public: scope_global_task_queue(task_queue * tq); ~scope_global_task_queue(); }; task_queue * get_global_task_queue(); template T const & task_result::get() const { return get_global_task_queue()->get_result(*this); } inline void generic_task_result::cancel() const { get_global_task_queue()->cancel(*this); } }