feat(util/task_queue,library/versioned_msg_buf): rudimentary support for task interruption

This commit is contained in:
Gabriel Ebner 2016-11-21 08:01:49 -05:00 committed by Leonardo de Moura
parent f69164d621
commit e1cb1a8cd2
13 changed files with 106 additions and 40 deletions

View file

@ -118,10 +118,12 @@ elaborator::elaborator(environment const & env, options const & opts, metavar_co
}
elaborator::~elaborator() {
if (m_uses_infom && get_global_info_manager()) {
m_info.instantiate_mvars(m_ctx.mctx());
get_global_info_manager()->merge(m_info);
}
try {
if (m_uses_infom && get_global_info_manager()) {
m_info.instantiate_mvars(m_ctx.mctx());
get_global_info_manager()->merge(m_info);
}
} catch (...) {}
}
auto elaborator::mk_pp_ctx() -> pp_fn {

View file

@ -306,15 +306,9 @@ void parser::protected_call(std::function<void()> && f, std::function<void()> &&
} catch (parser_error & ex) {
CATCH((mk_message(ex.m_pos, ERROR) << ex.get_msg()).report(),
throw_parser_exception(ex.what(), ex.m_pos));
} catch (interrupted & ex) {
reset_interrupt();
if (m_verbose)
(mk_message(m_last_cmd_pos, ERROR) << "!!!Interrupted!!!").report();
sync();
if (m_use_exceptions)
throw;
} catch (interrupted) {
throw;
} catch (throwable & ex) {
reset_interrupt();
CATCH(mk_message(m_last_cmd_pos, ERROR).set_exception(ex).report(),
throw_nested_exception(ex, m_last_cmd_pos));
}

View file

@ -12,6 +12,7 @@ Author: Gabriel Ebner
#include "util/lean_path.h"
#include "frontends/lean/parser.h"
#include "library/module.h"
#include "versioned_msg_buf.h"
#include <sys/stat.h>
#include <frontends/lean/pp.h>
#include <util/file_lock.h>
@ -400,4 +401,12 @@ module_id const & get_global_module_id() {
return *g_scoped_module_id;
}
void generic_module_task::set_result(generic_task_result const & self) {
if (m_auto_cancel) {
if (auto vmb = dynamic_cast<versioned_msg_buf *>(m_msg_buf))
vmb->cancel_when_invalidated(m_bucket, self);
}
generic_task::set_result(self);
}
}

View file

@ -146,6 +146,8 @@ public:
m_bucket(get_scope_message_context().new_sub_bucket()),
m_pos(pos), m_auto_cancel(auto_cancel), m_kind(kind) {}
void set_result(generic_task_result const & self) override;
task_kind get_kind() const { return m_kind; }
module_id get_module() const { return m_mod; }
@ -158,6 +160,10 @@ public:
module_task(optional<pos_info> const & pos, task_kind kind, bool auto_cancel = true) :
generic_module_task(pos, kind, auto_cancel) {}
void set_result(generic_task_result const & self) override {
generic_module_task::set_result(self);
}
virtual T execute_core() = 0;
T execute() final override;
@ -170,13 +176,15 @@ T module_task<T>::execute() {
scoped_message_buffer scoped_msg_buf(m_msg_buf);
scope_message_context scope_msg_ctx(m_bucket);
if (m_auto_cancel && !m_msg_buf->is_bucket_valid(m_bucket)) {
throw task_cancellation_exception();
throw interrupted();
}
try {
scope_traces_as_messages scope_traces(get_module(), get_pos_or_something());
return execute_core();
} catch (task_cancellation_exception) {
throw;
} catch (interrupted) {
throw;
} catch (throwable & ex) {
environment env;
message_builder builder(env, m_ios, get_module(), get_pos_or_something(), ERROR);

View file

@ -20,6 +20,7 @@ void versioned_msg_buf::start_bucket(message_bucket_id const & bucket) {
auto & buf = m_buf[bucket.m_bucket];
if (buf.m_version < bucket.m_version) {
buf.m_version = bucket.m_version;
buf.m_cancel_on_invalidation.reset();
buf.m_msgs.clear();
buf.m_infom.reset();
}
@ -30,9 +31,7 @@ void versioned_msg_buf::report(message_bucket_id const & bucket, message const &
unique_lock<mutex> lock(m_mutex);
auto & buf = m_buf[bucket.m_bucket];
if (buf.m_version < bucket.m_version) {
throw exception("missing call to start_bucket");
} else if (buf.m_version == bucket.m_version) {
if (buf.m_version == bucket.m_version) {
buf.m_msgs.push_back(msg);
}
}
@ -56,9 +55,16 @@ void versioned_msg_buf::finish_bucket(message_bucket_id const & bucket, name_set
});
}
void versioned_msg_buf::cancel_bucket(name const & bucket) {
auto & bck_buf = m_buf[bucket];
bck_buf.m_children.for_each([&] (name const & c) { cancel_bucket(c); });
if (auto & t = bck_buf.m_cancel_on_invalidation) { t.cancel(); t.reset(); }
}
void versioned_msg_buf::erase_bucket(name const & bucket) {
m_buf[bucket].m_children.for_each(
[&] (name const & c) { erase_bucket(c); });
auto & bck_buf = m_buf[bucket];
bck_buf.m_children.for_each([&] (name const & c) { erase_bucket(c); });
if (auto & t = bck_buf.m_cancel_on_invalidation) t.cancel();
m_buf.erase(bucket);
}
@ -83,9 +89,7 @@ void versioned_msg_buf::report_info_manager(message_bucket_id const & bucket, in
unique_lock<mutex> lock(m_mutex);
auto & buf = m_buf[bucket.m_bucket];
if (buf.m_version < bucket.m_version) {
throw exception("missing call to start_bucket");
} else if (buf.m_version == bucket.m_version) {
if (buf.m_version == bucket.m_version) {
buf.m_infom = std::unique_ptr<info_manager>(new info_manager(infom));
}
}
@ -110,4 +114,14 @@ std::vector<info_manager> versioned_msg_buf::get_info_managers() {
return result;
}
void versioned_msg_buf::cancel_when_invalidated(message_bucket_id const & bucket, generic_task_result const & t) {
unique_lock<mutex> lock(m_mutex);
auto & buf = m_buf[bucket.m_bucket];
if (buf.m_version < bucket.m_version) {
if (auto & t_old = buf.m_cancel_on_invalidation) t_old.cancel();
buf.m_cancel_on_invalidation = t;
}
}
}

View file

@ -20,12 +20,15 @@ class versioned_msg_buf : public message_buffer {
std::unique_ptr<info_manager> m_infom;
period m_version = 0;
generic_task_result m_cancel_on_invalidation;
name_set m_children;
};
mutex m_mutex;
std::unordered_map<name, msg_bucket, name_hash> m_buf;
void cancel_bucket(name const & bucket);
void erase_bucket(name const & bucket);
bool is_bucket_valid_core(message_bucket_id const & bucket);
@ -38,6 +41,8 @@ public:
bool is_bucket_valid(message_bucket_id const & bucket) override;
void report_info_manager(message_bucket_id const & bucket, info_manager const & infom) override;
void cancel_when_invalidated(message_bucket_id const & bucket, generic_task_result const & t);
std::vector<message> get_messages();
std::vector<info_manager> get_info_managers();
};

View file

@ -61,13 +61,11 @@ public:
};
/** \brief Exception used to sign that a computation was interrupted */
class interrupted : public throwable {
class interrupted {
public:
interrupted() {}
virtual ~interrupted() noexcept {}
virtual char const * what() const noexcept { return "interrupted"; }
virtual throwable * clone() const { return new interrupted(); }
virtual void rethrow() const { throw *this; }
};
class stack_space_exception : public throwable {

View file

@ -25,7 +25,7 @@ bool interrupt_requested() {
}
void check_interrupted() {
if (interrupt_requested()) {
if (interrupt_requested() && !std::uncaught_exception()) {
reset_interrupt();
throw interrupted();
}
@ -51,6 +51,10 @@ void sleep_for(unsigned ms, unsigned step_ms) {
check_interrupted();
}
atomic<bool> *get_interrupt_flag() {
return &get_g_interrupt();
}
atomic_bool * interruptible_thread::get_flag_addr() {
return &get_g_interrupt();
}

View file

@ -11,6 +11,9 @@ Author: Leonardo de Moura
#include "util/exception.h"
namespace lean {
atomic<bool> * get_interrupt_flag();
/**
\brief Mark flag for interrupting current thread.
*/

View file

@ -7,6 +7,8 @@ Author: Gabriel Ebner
#include <string>
#include <algorithm>
#include "util/mt_task_queue.h"
#include "interrupt.h"
#include "flet.h"
#if defined(LEAN_MULTI_THREAD)
namespace lean {
@ -63,6 +65,8 @@ void mt_task_queue::spawn_worker() {
lean_assert(!m_shutting_down);
auto this_worker = std::make_shared<worker_info>();
this_worker->m_thread = thread([=] {
this_worker->m_interrupt_flag = get_interrupt_flag();
scope_global_task_queue scope(this);
unique_lock<mutex> lock(m_mutex);
scoped_add<int> dec_required(m_required_workers, -1);
@ -79,22 +83,22 @@ void mt_task_queue::spawn_worker() {
continue;
}
this_worker->m_current_task = dequeue();
auto & t = this_worker->m_current_task;
auto t = dequeue();
if (t->m_state.load() != task_result_state::QUEUED) continue;
t->m_state = task_result_state::EXECUTING;
bool is_ok;
auto cb = m_progress_cb;
reset_interrupt();
{
flet<generic_task_result> _(this_worker->m_current_task, t);
scoped_current_task scope_cur_task(&t);
lock.unlock();
if (cb) cb(t->m_task);
is_ok = t->execute();
lock.lock();
}
reset_interrupt();
t->m_state = is_ok ? task_result_state::FINISHED : task_result_state::FAILED;
t->m_task->m_has_finished.notify_all();
@ -117,7 +121,6 @@ void mt_task_queue::spawn_worker() {
}
t->clear_task();
this_worker->m_current_task.reset();
}
});
m_workers.push_back(this_worker);
@ -171,7 +174,10 @@ void mt_task_queue::bump_prio(generic_task_result const & t, task_priority const
}
bool mt_task_queue::check_deps(generic_task_result const & t) {
auto deps = t->m_task->get_dependencies();
std::vector<generic_task_result> deps;
try {
deps = t->m_task->get_dependencies();
} catch (...) {}
for (auto & dep : deps) {
if (dep && dep->m_state.load() == task_result_state::QUEUED)
bump_prio(dep, t->m_task->m_prio);
@ -232,6 +238,13 @@ void mt_task_queue::cancel(generic_task_result const & t) {
propagate_failure(t);
t->clear_task();
return;
case task_result_state::EXECUTING:
for (auto & w : m_workers) {
if (w->m_current_task == t) {
w->m_interrupt_flag->store(true);
}
}
return;
default: return;
}
}

View file

@ -31,6 +31,7 @@ class mt_task_queue : public task_queue {
struct worker_info {
thread m_thread;
generic_task_result m_current_task;
atomic<bool> * m_interrupt_flag = nullptr;
};
std::vector<std::shared_ptr<worker_info>> m_workers;
bool m_shutting_down = false;

View file

@ -15,6 +15,8 @@ std::string generic_task::description() const {
return out.str();
}
void generic_task::set_result(generic_task_result const &) {}
generic_task_result_cell::generic_task_result_cell(generic_task * t) :
m_rc(0), m_task(t), m_desc(t->description()) {}
@ -25,6 +27,21 @@ void generic_task_result_cell::clear_task() {
}
}
bool generic_task_result_cell::execute() {
lean_assert(!has_evaluated());
try {
execute_and_store_result();
return true;
} catch (interrupted) {
m_ex = std::make_exception_ptr(
task_cancellation_exception(generic_task_result(this)));
return false;
} catch (...) {
m_ex = std::current_exception();
return false;
}
}
LEAN_THREAD_PTR(task_queue, g_tq);
scope_global_task_queue::scope_global_task_queue(task_queue * tq) {
m_old_tq = g_tq;

View file

@ -45,7 +45,8 @@ class generic_task_result_cell {
return state != task_result_state::QUEUED && state != task_result_state::EXECUTING;
}
virtual bool execute() = 0;
virtual void execute_and_store_result() = 0;
bool execute();
};
class generic_task_result {
@ -117,6 +118,8 @@ public:
std::string description() const;
virtual std::vector<generic_task_result> get_dependencies() { return {}; }
virtual void set_result(generic_task_result const & self);
virtual bool is_tiny() const { return false; }
};
@ -139,14 +142,8 @@ class task_result_cell : public generic_task_result_cell {
task<T> * get_ptr() { return static_cast<task<T> *>(m_task); }
virtual bool execute() {
try {
m_result = { get_ptr()->execute() };
return true;
} catch (...) {
m_ex = std::current_exception();
return false;
}
virtual void execute_and_store_result() override {
m_result = { get_ptr()->execute() };
}
public:
@ -214,6 +211,7 @@ public:
task_result<typename T::result> task(
new task_result_cell<typename T::result>(
new T(std::forward<As>(args)...)));
task->m_task->set_result(task);
submit(task);
return task;
}