From e1cb1a8cd247c5d8cbacbef5b4cdf637644ec425 Mon Sep 17 00:00:00 2001 From: Gabriel Ebner Date: Mon, 21 Nov 2016 08:01:49 -0500 Subject: [PATCH] feat(util/task_queue,library/versioned_msg_buf): rudimentary support for task interruption --- src/frontends/lean/elaborator.cpp | 10 ++++++---- src/frontends/lean/parser.cpp | 10 ++-------- src/library/module_mgr.cpp | 9 +++++++++ src/library/module_mgr.h | 10 +++++++++- src/library/versioned_msg_buf.cpp | 30 ++++++++++++++++++++++-------- src/library/versioned_msg_buf.h | 5 +++++ src/util/exception.h | 4 +--- src/util/interrupt.cpp | 6 +++++- src/util/interrupt.h | 3 +++ src/util/mt_task_queue.cpp | 25 +++++++++++++++++++------ src/util/mt_task_queue.h | 1 + src/util/task_queue.cpp | 17 +++++++++++++++++ src/util/task_queue.h | 16 +++++++--------- 13 files changed, 106 insertions(+), 40 deletions(-) diff --git a/src/frontends/lean/elaborator.cpp b/src/frontends/lean/elaborator.cpp index 8b6760c9a2..94199509fb 100644 --- a/src/frontends/lean/elaborator.cpp +++ b/src/frontends/lean/elaborator.cpp @@ -118,10 +118,12 @@ elaborator::elaborator(environment const & env, options const & opts, metavar_co } elaborator::~elaborator() { - if (m_uses_infom && get_global_info_manager()) { - m_info.instantiate_mvars(m_ctx.mctx()); - get_global_info_manager()->merge(m_info); - } + try { + if (m_uses_infom && get_global_info_manager()) { + m_info.instantiate_mvars(m_ctx.mctx()); + get_global_info_manager()->merge(m_info); + } + } catch (...) {} } auto elaborator::mk_pp_ctx() -> pp_fn { diff --git a/src/frontends/lean/parser.cpp b/src/frontends/lean/parser.cpp index 2a189828fe..918aabcecc 100644 --- a/src/frontends/lean/parser.cpp +++ b/src/frontends/lean/parser.cpp @@ -306,15 +306,9 @@ void parser::protected_call(std::function && f, std::function && } catch (parser_error & ex) { CATCH((mk_message(ex.m_pos, ERROR) << ex.get_msg()).report(), throw_parser_exception(ex.what(), ex.m_pos)); - } catch (interrupted & ex) { - reset_interrupt(); - if (m_verbose) - (mk_message(m_last_cmd_pos, ERROR) << "!!!Interrupted!!!").report(); - sync(); - if (m_use_exceptions) - throw; + } catch (interrupted) { + throw; } catch (throwable & ex) { - reset_interrupt(); CATCH(mk_message(m_last_cmd_pos, ERROR).set_exception(ex).report(), throw_nested_exception(ex, m_last_cmd_pos)); } diff --git a/src/library/module_mgr.cpp b/src/library/module_mgr.cpp index c8d0ef9aef..2501956cb4 100644 --- a/src/library/module_mgr.cpp +++ b/src/library/module_mgr.cpp @@ -12,6 +12,7 @@ Author: Gabriel Ebner #include "util/lean_path.h" #include "frontends/lean/parser.h" #include "library/module.h" +#include "versioned_msg_buf.h" #include #include #include @@ -400,4 +401,12 @@ module_id const & get_global_module_id() { return *g_scoped_module_id; } +void generic_module_task::set_result(generic_task_result const & self) { + if (m_auto_cancel) { + if (auto vmb = dynamic_cast(m_msg_buf)) + vmb->cancel_when_invalidated(m_bucket, self); + } + generic_task::set_result(self); +} + } diff --git a/src/library/module_mgr.h b/src/library/module_mgr.h index ac716c0521..e3c9255624 100644 --- a/src/library/module_mgr.h +++ b/src/library/module_mgr.h @@ -146,6 +146,8 @@ public: m_bucket(get_scope_message_context().new_sub_bucket()), m_pos(pos), m_auto_cancel(auto_cancel), m_kind(kind) {} + void set_result(generic_task_result const & self) override; + task_kind get_kind() const { return m_kind; } module_id get_module() const { return m_mod; } @@ -158,6 +160,10 @@ public: module_task(optional const & pos, task_kind kind, bool auto_cancel = true) : generic_module_task(pos, kind, auto_cancel) {} + void set_result(generic_task_result const & self) override { + generic_module_task::set_result(self); + } + virtual T execute_core() = 0; T execute() final override; @@ -170,13 +176,15 @@ T module_task::execute() { scoped_message_buffer scoped_msg_buf(m_msg_buf); scope_message_context scope_msg_ctx(m_bucket); if (m_auto_cancel && !m_msg_buf->is_bucket_valid(m_bucket)) { - throw task_cancellation_exception(); + throw interrupted(); } try { scope_traces_as_messages scope_traces(get_module(), get_pos_or_something()); return execute_core(); } catch (task_cancellation_exception) { throw; + } catch (interrupted) { + throw; } catch (throwable & ex) { environment env; message_builder builder(env, m_ios, get_module(), get_pos_or_something(), ERROR); diff --git a/src/library/versioned_msg_buf.cpp b/src/library/versioned_msg_buf.cpp index cc00fed140..2717f3e67d 100644 --- a/src/library/versioned_msg_buf.cpp +++ b/src/library/versioned_msg_buf.cpp @@ -20,6 +20,7 @@ void versioned_msg_buf::start_bucket(message_bucket_id const & bucket) { auto & buf = m_buf[bucket.m_bucket]; if (buf.m_version < bucket.m_version) { buf.m_version = bucket.m_version; + buf.m_cancel_on_invalidation.reset(); buf.m_msgs.clear(); buf.m_infom.reset(); } @@ -30,9 +31,7 @@ void versioned_msg_buf::report(message_bucket_id const & bucket, message const & unique_lock lock(m_mutex); auto & buf = m_buf[bucket.m_bucket]; - if (buf.m_version < bucket.m_version) { - throw exception("missing call to start_bucket"); - } else if (buf.m_version == bucket.m_version) { + if (buf.m_version == bucket.m_version) { buf.m_msgs.push_back(msg); } } @@ -56,9 +55,16 @@ void versioned_msg_buf::finish_bucket(message_bucket_id const & bucket, name_set }); } +void versioned_msg_buf::cancel_bucket(name const & bucket) { + auto & bck_buf = m_buf[bucket]; + bck_buf.m_children.for_each([&] (name const & c) { cancel_bucket(c); }); + if (auto & t = bck_buf.m_cancel_on_invalidation) { t.cancel(); t.reset(); } +} + void versioned_msg_buf::erase_bucket(name const & bucket) { - m_buf[bucket].m_children.for_each( - [&] (name const & c) { erase_bucket(c); }); + auto & bck_buf = m_buf[bucket]; + bck_buf.m_children.for_each([&] (name const & c) { erase_bucket(c); }); + if (auto & t = bck_buf.m_cancel_on_invalidation) t.cancel(); m_buf.erase(bucket); } @@ -83,9 +89,7 @@ void versioned_msg_buf::report_info_manager(message_bucket_id const & bucket, in unique_lock lock(m_mutex); auto & buf = m_buf[bucket.m_bucket]; - if (buf.m_version < bucket.m_version) { - throw exception("missing call to start_bucket"); - } else if (buf.m_version == bucket.m_version) { + if (buf.m_version == bucket.m_version) { buf.m_infom = std::unique_ptr(new info_manager(infom)); } } @@ -110,4 +114,14 @@ std::vector versioned_msg_buf::get_info_managers() { return result; } +void versioned_msg_buf::cancel_when_invalidated(message_bucket_id const & bucket, generic_task_result const & t) { + unique_lock lock(m_mutex); + + auto & buf = m_buf[bucket.m_bucket]; + if (buf.m_version < bucket.m_version) { + if (auto & t_old = buf.m_cancel_on_invalidation) t_old.cancel(); + buf.m_cancel_on_invalidation = t; + } +} + } diff --git a/src/library/versioned_msg_buf.h b/src/library/versioned_msg_buf.h index c5e46cb82a..7149ddbb71 100644 --- a/src/library/versioned_msg_buf.h +++ b/src/library/versioned_msg_buf.h @@ -20,12 +20,15 @@ class versioned_msg_buf : public message_buffer { std::unique_ptr m_infom; period m_version = 0; + generic_task_result m_cancel_on_invalidation; + name_set m_children; }; mutex m_mutex; std::unordered_map m_buf; + void cancel_bucket(name const & bucket); void erase_bucket(name const & bucket); bool is_bucket_valid_core(message_bucket_id const & bucket); @@ -38,6 +41,8 @@ public: bool is_bucket_valid(message_bucket_id const & bucket) override; void report_info_manager(message_bucket_id const & bucket, info_manager const & infom) override; + void cancel_when_invalidated(message_bucket_id const & bucket, generic_task_result const & t); + std::vector get_messages(); std::vector get_info_managers(); }; diff --git a/src/util/exception.h b/src/util/exception.h index 98d768e4ed..6ef0d8fd93 100644 --- a/src/util/exception.h +++ b/src/util/exception.h @@ -61,13 +61,11 @@ public: }; /** \brief Exception used to sign that a computation was interrupted */ -class interrupted : public throwable { +class interrupted { public: interrupted() {} virtual ~interrupted() noexcept {} virtual char const * what() const noexcept { return "interrupted"; } - virtual throwable * clone() const { return new interrupted(); } - virtual void rethrow() const { throw *this; } }; class stack_space_exception : public throwable { diff --git a/src/util/interrupt.cpp b/src/util/interrupt.cpp index 4d92153d1c..18b1a87c5e 100644 --- a/src/util/interrupt.cpp +++ b/src/util/interrupt.cpp @@ -25,7 +25,7 @@ bool interrupt_requested() { } void check_interrupted() { - if (interrupt_requested()) { + if (interrupt_requested() && !std::uncaught_exception()) { reset_interrupt(); throw interrupted(); } @@ -51,6 +51,10 @@ void sleep_for(unsigned ms, unsigned step_ms) { check_interrupted(); } +atomic *get_interrupt_flag() { + return &get_g_interrupt(); +} + atomic_bool * interruptible_thread::get_flag_addr() { return &get_g_interrupt(); } diff --git a/src/util/interrupt.h b/src/util/interrupt.h index bca8028945..4cbc7fa53d 100644 --- a/src/util/interrupt.h +++ b/src/util/interrupt.h @@ -11,6 +11,9 @@ Author: Leonardo de Moura #include "util/exception.h" namespace lean { + +atomic * get_interrupt_flag(); + /** \brief Mark flag for interrupting current thread. */ diff --git a/src/util/mt_task_queue.cpp b/src/util/mt_task_queue.cpp index 40a5ac8c2e..e733d2d422 100644 --- a/src/util/mt_task_queue.cpp +++ b/src/util/mt_task_queue.cpp @@ -7,6 +7,8 @@ Author: Gabriel Ebner #include #include #include "util/mt_task_queue.h" +#include "interrupt.h" +#include "flet.h" #if defined(LEAN_MULTI_THREAD) namespace lean { @@ -63,6 +65,8 @@ void mt_task_queue::spawn_worker() { lean_assert(!m_shutting_down); auto this_worker = std::make_shared(); this_worker->m_thread = thread([=] { + this_worker->m_interrupt_flag = get_interrupt_flag(); + scope_global_task_queue scope(this); unique_lock lock(m_mutex); scoped_add dec_required(m_required_workers, -1); @@ -79,22 +83,22 @@ void mt_task_queue::spawn_worker() { continue; } - this_worker->m_current_task = dequeue(); - - auto & t = this_worker->m_current_task; - + auto t = dequeue(); if (t->m_state.load() != task_result_state::QUEUED) continue; t->m_state = task_result_state::EXECUTING; bool is_ok; auto cb = m_progress_cb; + reset_interrupt(); { + flet _(this_worker->m_current_task, t); scoped_current_task scope_cur_task(&t); lock.unlock(); if (cb) cb(t->m_task); is_ok = t->execute(); lock.lock(); } + reset_interrupt(); t->m_state = is_ok ? task_result_state::FINISHED : task_result_state::FAILED; t->m_task->m_has_finished.notify_all(); @@ -117,7 +121,6 @@ void mt_task_queue::spawn_worker() { } t->clear_task(); - this_worker->m_current_task.reset(); } }); m_workers.push_back(this_worker); @@ -171,7 +174,10 @@ void mt_task_queue::bump_prio(generic_task_result const & t, task_priority const } bool mt_task_queue::check_deps(generic_task_result const & t) { - auto deps = t->m_task->get_dependencies(); + std::vector deps; + try { + deps = t->m_task->get_dependencies(); + } catch (...) {} for (auto & dep : deps) { if (dep && dep->m_state.load() == task_result_state::QUEUED) bump_prio(dep, t->m_task->m_prio); @@ -232,6 +238,13 @@ void mt_task_queue::cancel(generic_task_result const & t) { propagate_failure(t); t->clear_task(); return; + case task_result_state::EXECUTING: + for (auto & w : m_workers) { + if (w->m_current_task == t) { + w->m_interrupt_flag->store(true); + } + } + return; default: return; } } diff --git a/src/util/mt_task_queue.h b/src/util/mt_task_queue.h index b4b68e17d1..959e7f2963 100644 --- a/src/util/mt_task_queue.h +++ b/src/util/mt_task_queue.h @@ -31,6 +31,7 @@ class mt_task_queue : public task_queue { struct worker_info { thread m_thread; generic_task_result m_current_task; + atomic * m_interrupt_flag = nullptr; }; std::vector> m_workers; bool m_shutting_down = false; diff --git a/src/util/task_queue.cpp b/src/util/task_queue.cpp index 5086a80c6e..f5dbac5389 100644 --- a/src/util/task_queue.cpp +++ b/src/util/task_queue.cpp @@ -15,6 +15,8 @@ std::string generic_task::description() const { return out.str(); } +void generic_task::set_result(generic_task_result const &) {} + generic_task_result_cell::generic_task_result_cell(generic_task * t) : m_rc(0), m_task(t), m_desc(t->description()) {} @@ -25,6 +27,21 @@ void generic_task_result_cell::clear_task() { } } +bool generic_task_result_cell::execute() { + lean_assert(!has_evaluated()); + try { + execute_and_store_result(); + return true; + } catch (interrupted) { + m_ex = std::make_exception_ptr( + task_cancellation_exception(generic_task_result(this))); + return false; + } catch (...) { + m_ex = std::current_exception(); + return false; + } +} + LEAN_THREAD_PTR(task_queue, g_tq); scope_global_task_queue::scope_global_task_queue(task_queue * tq) { m_old_tq = g_tq; diff --git a/src/util/task_queue.h b/src/util/task_queue.h index 90a77fe306..db95d22e4c 100644 --- a/src/util/task_queue.h +++ b/src/util/task_queue.h @@ -45,7 +45,8 @@ class generic_task_result_cell { return state != task_result_state::QUEUED && state != task_result_state::EXECUTING; } - virtual bool execute() = 0; + virtual void execute_and_store_result() = 0; + bool execute(); }; class generic_task_result { @@ -117,6 +118,8 @@ public: std::string description() const; virtual std::vector get_dependencies() { return {}; } + virtual void set_result(generic_task_result const & self); + virtual bool is_tiny() const { return false; } }; @@ -139,14 +142,8 @@ class task_result_cell : public generic_task_result_cell { task * get_ptr() { return static_cast *>(m_task); } - virtual bool execute() { - try { - m_result = { get_ptr()->execute() }; - return true; - } catch (...) { - m_ex = std::current_exception(); - return false; - } + virtual void execute_and_store_result() override { + m_result = { get_ptr()->execute() }; } public: @@ -214,6 +211,7 @@ public: task_result task( new task_result_cell( new T(std::forward(args)...))); + task->m_task->set_result(task); submit(task); return task; }